#!/usr/bin/env python3
|
|
"""
|
|
HTTP API server for querying the vector storage via the existing retrieval pipeline.
|
|
"""
|
|
|
|
from pathlib import Path
|
|
import sys
|
|
|
|
from dotenv import load_dotenv
|
|
from fastapi import FastAPI, HTTPException
|
|
from pydantic import BaseModel, Field
|
|
from loguru import logger
|
|
|
|
from retrieval import initialize_retriever
|
|
|
|
# Load environment variables from a local .env file before any module below reads them.
load_dotenv()
|
|
|
|
|
|
def setup_logging() -> None:
    """Configure loguru with a rotating file sink and a colorized stdout sink.

    Removes loguru's default handler first so messages are not emitted twice.
    """
    logs_dir = Path("logs")
    logs_dir.mkdir(exist_ok=True)

    logger.remove()
    # File sink: rotate at 10 MB, prune files older than 10 days.
    # Use the computed logs_dir instead of repeating a hard-coded path string.
    logger.add(
        logs_dir / "dev.log",
        rotation="10 MB",
        retention="10 days",
        level="INFO",
        format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {file}:{line} | {message}",
    )
    # Console sink: shorter format, colorized for interactive use.
    logger.add(
        sys.stdout,
        level="INFO",
        format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
        colorize=True,
    )
|
|
|
|
|
|
# Configure logging at import time so startup messages are captured
# even when the app is served by an external ASGI runner.
setup_logging()

app = FastAPI(title="LlamaIndex RAG API", version="1.0.0")
|
|
|
|
|
|
class TestQueryRequest(BaseModel):
    """Request body for POST /api/test-query."""

    # Free-text user query; min_length=1 rejects the empty string but not
    # whitespace-only input.
    query: str = Field(..., min_length=1, description="User query text")
    # Number of chunks to retrieve, bounded to 1..20; defaults to 5.
    top_k: int = Field(5, ge=1, le=20, description="Number of retrieved chunks")
|
|
|
|
|
|
class SourceItem(BaseModel):
    """One retrieved chunk returned alongside the generated answer."""

    # Chunk text content.
    content: str
    # Similarity score; None when the engine does not provide one.
    score: float | None = None
    # Arbitrary node metadata from the retriever — keys depend on the pipeline.
    metadata: dict = Field(default_factory=dict)
|
|
|
|
|
|
class TestQueryResponse(BaseModel):
    """Response body for POST /api/test-query."""

    # The (stripped) query that was executed.
    query: str
    # Generated answer text from the query engine.
    response: str
    # Retrieved chunks backing the answer; may be empty.
    sources: list[SourceItem]
|
|
|
|
|
|
@app.get("/health")
def health() -> dict[str, str]:
    """Liveness probe: always reports the service as up."""
    status_payload = dict(status="ok")
    return status_payload
|
|
|
|
|
|
@app.post("/api/test-query", response_model=TestQueryResponse)
def test_query(payload: TestQueryRequest) -> TestQueryResponse:
    """Query the vector store using the existing retrieval/query engine.

    Args:
        payload: Validated request with the query text and ``top_k``.

    Returns:
        TestQueryResponse with the engine's answer and the retrieved sources.

    Raises:
        HTTPException: 400 for whitespace-only queries, 500 on engine failure.
    """
    # min_length=1 on the model does not reject whitespace-only input,
    # so strip and re-check here.
    query = payload.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Field 'query' must not be empty")

    logger.info(f"Received /api/test-query request (top_k={payload.top_k})")

    try:
        # NOTE(review): the retriever is rebuilt on every request; if
        # initialization is expensive, consider caching per top_k — confirm
        # against the retrieval pipeline.
        query_engine = initialize_retriever(similarity_top_k=payload.top_k)
        result = query_engine.query(query)

        # Not every engine response exposes source nodes; default to empty.
        sources: list[SourceItem] = [
            SourceItem(
                content=str(getattr(node, "text", "")),
                score=getattr(node, "score", None),
                metadata=getattr(node, "metadata", {}) or {},
            )
            for node in getattr(result, "source_nodes", [])
        ]

        response_text = str(result)
        logger.info(
            f"/api/test-query completed successfully (sources={len(sources)})"
        )
        return TestQueryResponse(query=query, response=response_text, sources=sources)

    except HTTPException:
        # Re-raise our own 400 untouched instead of masking it as a 500.
        raise
    except Exception as e:
        # logger.exception records the full traceback, not just the message.
        logger.exception(f"/api/test-query failed: {e}")
        # Chain the cause for debugging; detail stays generic so internals
        # are not leaked to the client.
        raise HTTPException(status_code=500, detail="Failed to process query") from e
|
|
|
|
|
|
if __name__ == "__main__":
    # Imported lazily so the module can be served by an external ASGI runner
    # without requiring uvicorn at import time.
    import uvicorn

    # NOTE(review): 0.0.0.0 binds all interfaces — fine inside a container,
    # but confirm this port is not exposed directly on an untrusted network.
    uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=False)
|