Files
rag-solution/services/rag/llamaindex/server.py

108 lines
2.9 KiB
Python

#!/usr/bin/env python3
"""
HTTP API server for querying the vector storage via the existing retrieval pipeline.
"""
import sys
from pathlib import Path
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from loguru import logger
from pydantic import BaseModel, Field
from retrieval import initialize_retriever
# Load variables from a .env file into the process environment.
# NOTE(review): this runs after `retrieval` has already been imported above;
# if that module reads environment variables at import time it will not see
# values from .env — confirm against retrieval.py.
load_dotenv()
def setup_logging() -> None:
    """Route loguru output to a rotating log file and to stdout.

    Creates the ``logs`` directory when it is missing, removes every handler
    currently registered on the logger, then installs two INFO-level sinks:
    ``logs/dev.log`` (rotated at 10 MB, retained for 10 days) and a colorized
    stdout sink.
    """
    Path("logs").mkdir(exist_ok=True)

    # Drop all existing handlers so only the two sinks below emit messages.
    logger.remove()

    # File sink: includes the source file and line for later debugging.
    file_sink_options = {
        "rotation": "10 MB",
        "retention": "10 days",
        "level": "INFO",
        "format": "{time:YYYY-MM-DD HH:mm:ss} | {level} | {file}:{line} | {message}",
    }
    logger.add("logs/dev.log", **file_sink_options)

    # Console sink: shorter format, colorized.
    console_sink_options = {
        "level": "INFO",
        "format": "{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
        "colorize": True,
    }
    logger.add(sys.stdout, **console_sink_options)
# Configure logging at import time, before the application object is created.
setup_logging()
# Application object: the route decorators below register handlers on it, and
# the __main__ entry point refers to it as "server:app".
app = FastAPI(title="LlamaIndex RAG API", version="1.0.0")
class TestQueryRequest(BaseModel):
    # Request body for POST /api/test-query.
    # Documented with comments rather than a class docstring so the generated
    # JSON schema description is left unchanged.

    # Required query text; must contain at least one character.
    query: str = Field(
        default=...,
        min_length=1,
        description="User query text",
    )

    # Number of chunks to retrieve: defaults to 5, allowed range is 1 to 20.
    top_k: int = Field(
        default=5,
        ge=1,
        le=20,
        description="Number of retrieved chunks",
    )
class SourceItem(BaseModel):
    # One retrieved source entry returned alongside the answer.
    # Documented with comments rather than a class docstring so the generated
    # JSON schema description is left unchanged.

    # Text of the retrieved node (test_query fills it from the node's `text`).
    content: str
    # Score reported for the node; None when the node carries no score.
    score: float | None = None
    # Node metadata; each instance gets its own empty dict by default.
    metadata: dict = Field(default_factory=dict)
class TestQueryResponse(BaseModel):
    # Response body for POST /api/test-query.
    # Documented with comments rather than a class docstring so the generated
    # JSON schema description is left unchanged.

    # The query that was executed (whitespace-trimmed by the handler).
    query: str
    # String form of the query engine's result.
    response: str
    # Source entries built from the result's source nodes; may be empty.
    sources: list[SourceItem]
@app.get("/health")
def health() -> dict[str, str]:
    # Health-check endpoint: always answers with a fixed "ok" status payload.
    # Documented with a comment rather than a docstring so the endpoint's
    # OpenAPI description is left unchanged.
    status_payload: dict[str, str] = {"status": "ok"}
    return status_payload
@app.post("/api/test-query", response_model=TestQueryResponse)
def test_query(payload: TestQueryRequest) -> TestQueryResponse:
    """
    Query the vector store using the existing retrieval/query engine.
    """
    # FastAPI uses the docstring above as the endpoint description, so the
    # implementation notes are kept in comments:
    #   * payload: validated request body (query text and top_k).
    #   * Returns the trimmed query, the string form of the engine result and
    #     one SourceItem per source node.
    #   * Raises HTTPException(400) when the query is empty after trimming
    #     (min_length=1 on the model still admits whitespace-only text).
    #   * Raises HTTPException(500) with a generic message for any other
    #     failure; the real error and its traceback go to the log only.
    query = payload.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Field 'query' must not be empty")

    logger.info(f"Received /api/test-query request (top_k={payload.top_k})")

    try:
        # NOTE(review): a query engine is built on every request; if
        # initialize_retriever is expensive, consider reusing engines — confirm.
        query_engine = initialize_retriever(similarity_top_k=payload.top_k)
        result = query_engine.query(query)

        # Treat a missing or None `source_nodes` as "no sources" instead of
        # failing the whole request while iterating.
        source_nodes = getattr(result, "source_nodes", None) or []
        sources: list[SourceItem] = [
            SourceItem(
                content=str(getattr(node, "text", "")),
                score=getattr(node, "score", None),
                metadata=getattr(node, "metadata", {}) or {},
            )
            for node in source_nodes
        ]

        response_text = str(result)
        logger.info(f"/api/test-query completed successfully (sources={len(sources)})")
        return TestQueryResponse(query=query, response=response_text, sources=sources)
    except HTTPException:
        # Deliberate HTTP errors pass through untouched.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message, so a
        # 500 can be diagnosed from the log; the client only sees the generic
        # detail below.
        logger.exception(f"/api/test-query failed: {e}")
        raise HTTPException(status_code=500, detail="Failed to process query") from e
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn when this file is run
    # directly. uvicorn is imported here so importing the module does not
    # require it.
    import uvicorn

    server_options = {
        "host": "0.0.0.0",  # listen on all network interfaces
        "port": 8334,
        "reload": False,
    }
    uvicorn.run("server:app", **server_options)