more sophisticated chat like retrieval for llamaindex

This commit is contained in:
2026-02-26 19:02:05 +03:00
parent 468d5fb572
commit 6b3fa1cfaa
3 changed files with 390 additions and 2 deletions

View File

@@ -12,6 +12,7 @@ from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from pydantic import BaseModel, Field
from chat_engine import chat_with_documents
from retrieval import initialize_retriever
load_dotenv()
@@ -55,6 +56,10 @@ app.add_middleware(
class TestQueryRequest(BaseModel):
query: str = Field(..., min_length=1, description="User query text")
top_k: int = Field(5, ge=1, le=20, description="Number of retrieved chunks")
mode: str = Field(
"agent",
description="agent (Phase 12 default) or retrieval (fallback/debug)",
)
class SourceItem(BaseModel):
@@ -67,6 +72,7 @@ class TestQueryResponse(BaseModel):
query: str
response: str
sources: list[SourceItem]
mode: str | None = None
error: bool
success: bool
@@ -85,9 +91,43 @@ def test_query(payload: TestQueryRequest) -> TestQueryResponse:
if not query:
raise HTTPException(status_code=400, detail="Field 'query' must not be empty")
logger.info(f"Received /api/test-query request (top_k={payload.top_k})")
logger.info(
f"Received /api/test-query request (top_k={payload.top_k}, mode={payload.mode})"
)
try:
if payload.mode.lower() in {"agent", "chat"}:
chat_result = chat_with_documents(
query=query,
top_k=payload.top_k,
use_agent=True,
)
sources = [
SourceItem(
content=str(src.get("content_snippet", "")),
score=src.get("score"),
metadata=src.get("metadata", {}) or {},
)
for src in chat_result.get("sources", [])
]
logger.info(
f"/api/test-query completed via agent-like chat path (sources={len(sources)})"
)
return TestQueryResponse(
query=query,
response=str(chat_result.get("answer", "")),
sources=sources,
mode=str(chat_result.get("mode", "agent+retrieval+synthesis")),
error=False,
success=True,
)
if payload.mode.lower() not in {"retrieval", "debug"}:
raise HTTPException(
status_code=400,
detail="Unsupported mode. Use 'agent' (default) or 'retrieval'.",
)
query_engine = initialize_retriever(similarity_top_k=payload.top_k)
result = query_engine.query(query)
@@ -103,11 +143,14 @@ def test_query(payload: TestQueryRequest) -> TestQueryResponse:
)
response_text = str(result)
logger.info(f"/api/test-query completed successfully (sources={len(sources)})")
logger.info(
f"/api/test-query completed via retrieval fallback (sources={len(sources)})"
)
return TestQueryResponse(
query=query,
response=response_text,
sources=sources,
mode="retrieval",
error=False,
success=True,
)