More sophisticated chat-like retrieval for LlamaIndex
This commit is contained in:
@@ -12,6 +12,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from chat_engine import chat_with_documents
|
||||
from retrieval import initialize_retriever
|
||||
|
||||
load_dotenv()
|
||||
@@ -55,6 +56,10 @@ app.add_middleware(
|
||||
class TestQueryRequest(BaseModel):
    """Request body accepted by the /api/test-query endpoint."""

    # Text of the user's question; must contain at least one character.
    query: str = Field(
        default=...,
        min_length=1,
        description="User query text",
    )
    # How many chunks to retrieve; constrained to 1..20, defaulting to 5.
    top_k: int = Field(
        default=5,
        ge=1,
        le=20,
        description="Number of retrieved chunks",
    )
    # Selects the answering path; defaults to "agent" when omitted.
    mode: str = Field(
        default="agent",
        description="agent (Phase 12 default) or retrieval (fallback/debug)",
    )
|
||||
|
||||
|
||||
class SourceItem(BaseModel):
|
||||
@@ -67,6 +72,7 @@ class TestQueryResponse(BaseModel):
|
||||
query: str
|
||||
response: str
|
||||
sources: list[SourceItem]
|
||||
mode: str | None = None
|
||||
error: bool
|
||||
success: bool
|
||||
|
||||
@@ -85,9 +91,43 @@ def test_query(payload: TestQueryRequest) -> TestQueryResponse:
|
||||
if not query:
|
||||
raise HTTPException(status_code=400, detail="Field 'query' must not be empty")
|
||||
|
||||
logger.info(f"Received /api/test-query request (top_k={payload.top_k})")
|
||||
logger.info(
|
||||
f"Received /api/test-query request (top_k={payload.top_k}, mode={payload.mode})"
|
||||
)
|
||||
|
||||
try:
|
||||
if payload.mode.lower() in {"agent", "chat"}:
|
||||
chat_result = chat_with_documents(
|
||||
query=query,
|
||||
top_k=payload.top_k,
|
||||
use_agent=True,
|
||||
)
|
||||
sources = [
|
||||
SourceItem(
|
||||
content=str(src.get("content_snippet", "")),
|
||||
score=src.get("score"),
|
||||
metadata=src.get("metadata", {}) or {},
|
||||
)
|
||||
for src in chat_result.get("sources", [])
|
||||
]
|
||||
logger.info(
|
||||
f"/api/test-query completed via agent-like chat path (sources={len(sources)})"
|
||||
)
|
||||
return TestQueryResponse(
|
||||
query=query,
|
||||
response=str(chat_result.get("answer", "")),
|
||||
sources=sources,
|
||||
mode=str(chat_result.get("mode", "agent+retrieval+synthesis")),
|
||||
error=False,
|
||||
success=True,
|
||||
)
|
||||
|
||||
if payload.mode.lower() not in {"retrieval", "debug"}:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Unsupported mode. Use 'agent' (default) or 'retrieval'.",
|
||||
)
|
||||
|
||||
query_engine = initialize_retriever(similarity_top_k=payload.top_k)
|
||||
result = query_engine.query(query)
|
||||
|
||||
@@ -103,11 +143,14 @@ def test_query(payload: TestQueryRequest) -> TestQueryResponse:
|
||||
)
|
||||
|
||||
response_text = str(result)
|
||||
logger.info(f"/api/test-query completed successfully (sources={len(sources)})")
|
||||
logger.info(
|
||||
f"/api/test-query completed via retrieval fallback (sources={len(sources)})"
|
||||
)
|
||||
return TestQueryResponse(
|
||||
query=query,
|
||||
response=response_text,
|
||||
sources=sources,
|
||||
mode="retrieval",
|
||||
error=False,
|
||||
success=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user