Files

174 lines
4.8 KiB
Python
Raw Permalink Normal View History

#!/usr/bin/env python3
"""
HTTP API server for querying the vector storage via the existing retrieval pipeline.
"""
import os
import sys
from pathlib import Path

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from pydantic import BaseModel, Field

from chat_engine import chat_with_documents
from retrieval import initialize_retriever
# Load a .env file (if one is found) into the process environment before any
# configuration is read; variables already set in the real environment are
# left as they are.
load_dotenv(override=False)
def setup_logging(log_dir: str | Path = "logs", level: str = "INFO") -> None:
    """Configure loguru to log to stdout and to a rotating file.

    All previously registered loguru sinks (including loguru's default one)
    are removed first, so calling this again replaces the configuration
    rather than duplicating output.

    Args:
        log_dir: Directory that receives ``dev.log``; created (with any
            missing parents) if needed. A relative path resolves against the
            current working directory.
        level: Minimum level for both the file and the stdout sink.
    """
    logs_dir = Path(log_dir)
    logs_dir.mkdir(parents=True, exist_ok=True)
    logger.remove()
    # File sink: derive its path from logs_dir (previously a duplicated
    # hard-coded string); rotates at 10 MB and keeps old files for 10 days.
    # Includes file:line, which the console format omits.
    logger.add(
        logs_dir / "dev.log",
        rotation="10 MB",
        retention="10 days",
        level=level,
        format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {file}:{line} | {message}",
    )
    logger.add(
        sys.stdout,
        level=level,
        format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
        colorize=True,
    )
setup_logging()

app = FastAPI(title="LlamaIndex RAG API", version="1.0.0")


def _parse_cors_origins(raw: str) -> list[str]:
    """Split a comma-separated origin list, trimming whitespace and dropping blanks."""
    return [origin.strip() for origin in raw.split(",") if origin.strip()]


# Allowed CORS origins come from the comma-separated CORS_ALLOW_ORIGINS
# environment variable (a .env entry also works, since load_dotenv() runs
# earlier at import time). Unset -> "*", the previous hard-coded behaviour;
# set but empty -> no cross-origin access.
# NOTE(review): "*" combined with allow_credentials=True makes Starlette
# reflect the caller's Origin back, and allow_private_network=True also
# answers browsers' private-network preflights, so any web page can call this
# API from a visitor's browser. Set an explicit origin list in production.
origins = _parse_cors_origins(os.getenv("CORS_ALLOW_ORIGINS", "*"))

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_private_network=True,
)
# Request body for POST /api/test-query. (Documented with comments rather than
# a class docstring, which pydantic would copy into the JSON schema.)
class TestQueryRequest(BaseModel):
    # Required, non-empty question text (the handler additionally rejects
    # whitespace-only input).
    query: str = Field(default=..., min_length=1, description="User query text")
    # Number of chunks to retrieve; values outside 1..20 fail validation.
    top_k: int = Field(default=5, ge=1, le=20, description="Number of retrieved chunks")
    # Free-form string here; the handler decides which values it accepts.
    mode: str = Field(
        default="agent",
        description="agent (Phase 12 default) or retrieval (fallback/debug)",
    )
# One retrieved chunk returned to the client: its text, an optional relevance
# score, and the metadata attached to it (an empty dict when there is none).
class SourceItem(BaseModel):
    content: str = Field(...)
    score: float | None = Field(default=None)
    metadata: dict = Field(default_factory=dict)
# Response body for POST /api/test-query. `mode` names the path that produced
# the answer. The handler sets error=False / success=True on every response
# it builds; failures are reported through HTTP errors instead.
class TestQueryResponse(BaseModel):
    query: str = Field(...)
    response: str = Field(...)
    sources: list[SourceItem] = Field(...)
    mode: str | None = Field(default=None)
    error: bool = Field(...)
    success: bool = Field(...)
# Liveness probe: always answers {"status": "ok"} without touching the
# retriever or the chat engine.
@app.get("/health")
def health() -> dict[str, str]:
    status_payload: dict[str, str] = {"status": "ok"}
    return status_payload
# Mode names accepted by /api/test-query, compared after strip() + lower().
_AGENT_MODES = frozenset({"agent", "chat"})
_RETRIEVAL_MODES = frozenset({"retrieval", "debug"})


def _as_text(value: object) -> str:
    """Convert *value* to ``str``, mapping ``None`` to ``""`` instead of ``"None"``."""
    return "" if value is None else str(value)


def _run_agent_query(query: str, top_k: int) -> TestQueryResponse:
    """Answer *query* through ``chat_with_documents`` with the agent enabled."""
    chat_result = chat_with_documents(query=query, top_k=top_k, use_agent=True)
    # NOTE(review): assumes chat_result is a dict with optional "answer",
    # "mode" and "sources" keys, each source a dict with optional
    # "content_snippet"/"score"/"metadata" keys -- confirm against chat_engine.
    sources = [
        SourceItem(
            content=_as_text(src.get("content_snippet")),
            score=src.get("score"),
            metadata=src.get("metadata") or {},
        )
        for src in chat_result.get("sources") or []
    ]
    logger.info(
        f"/api/test-query completed via agent-like chat path (sources={len(sources)})"
    )
    reported_mode = chat_result.get("mode")
    return TestQueryResponse(
        query=query,
        response=_as_text(chat_result.get("answer")),
        sources=sources,
        mode="agent+retrieval+synthesis" if reported_mode is None else str(reported_mode),
        error=False,
        success=True,
    )


def _run_retrieval_query(query: str, top_k: int) -> TestQueryResponse:
    """Answer *query* with the plain retriever/query engine (fallback/debug path)."""
    # NOTE(review): a query engine is built on every request; if
    # initialize_retriever is expensive, consider caching it per top_k.
    query_engine = initialize_retriever(similarity_top_k=top_k)
    result = query_engine.query(query)
    # A result without source nodes (missing attribute or None) simply yields
    # no sources instead of failing the request.
    sources = [
        SourceItem(
            content=_as_text(getattr(node, "text", "")),
            score=getattr(node, "score", None),
            metadata=getattr(node, "metadata", {}) or {},
        )
        for node in getattr(result, "source_nodes", None) or []
    ]
    response_text = str(result)
    logger.info(
        f"/api/test-query completed via retrieval fallback (sources={len(sources)})"
    )
    return TestQueryResponse(
        query=query,
        response=response_text,
        sources=sources,
        mode="retrieval",
        error=False,
        success=True,
    )


@app.post("/api/test-query", response_model=TestQueryResponse)
def test_query(payload: TestQueryRequest) -> TestQueryResponse:
    """
    Query the vector store using the existing retrieval/query engine.
    \f
    Everything after the form feed above is hidden from the OpenAPI
    description by FastAPI.

    Modes (case-insensitive, surrounding whitespace ignored):
      * "agent" / "chat"       -> chat_with_documents(use_agent=True)
      * "retrieval" / "debug"  -> initialize_retriever(...).query(...)

    Raises:
        HTTPException(400): empty/whitespace-only query or unsupported mode.
        HTTPException(500): any other failure while answering; the traceback
            is logged and chained.
    """
    query = payload.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Field 'query' must not be empty")

    logger.info(
        f"Received /api/test-query request (top_k={payload.top_k}, mode={payload.mode})"
    )

    # Validate the mode before doing any work, outside the broad handler below.
    mode = payload.mode.strip().lower()
    if mode not in _AGENT_MODES and mode not in _RETRIEVAL_MODES:
        raise HTTPException(
            status_code=400,
            detail="Unsupported mode. Use 'agent' (default) or 'retrieval'.",
        )

    try:
        if mode in _AGENT_MODES:
            return _run_agent_query(query, payload.top_k)
        return _run_retrieval_query(query, payload.top_k)
    except HTTPException:
        # Deliberate HTTP errors from downstream code pass through unchanged.
        raise
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"/api/test-query failed: {e}")
        raise HTTPException(status_code=500, detail="Failed to process query") from e
if __name__ == "__main__":
    import uvicorn

    # Serve the already-built app object rather than the "server:app" import
    # string. With reload disabled an import string is not required, and using
    # one makes uvicorn import this file a second time under the module name
    # "server" (re-running load_dotenv, setup_logging and app creation). It
    # also fails outright if the file is saved under a different name.
    uvicorn.run(app, host="0.0.0.0", port=8334, reload=False)