"""Web server for the RAG solution with LangServe integration.""" import json import os from contextlib import asynccontextmanager from typing import Any, Dict from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from loguru import logger from pydantic import BaseModel from agent import chat_with_agent class QueryRequest(BaseModel): """Request model for the query endpoint.""" query: str collection_name: str = "documents_langchain" # llm_model: str = None class QueryResponse(BaseModel): """Response model for the query endpoint.""" response: str query: str success: bool error: str = None @asynccontextmanager async def lifespan(app: FastAPI): """Lifespan event handler for startup and shutdown.""" # Startup logger.info("Starting RAG server...") yield # Shutdown logger.info("Shutting down RAG server...") # Create FastAPI app app = FastAPI( title="RAG Solution API", description="API for Retrieval-Augmented Generation solution with Langchain", version="1.0.0", lifespan=lifespan, ) # Add CORS middleware app.add_middleware( CORSMiddleware, allow_origins=["*"], # In production, configure this properly allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @app.post("/api/test-query", response_model=QueryResponse) async def test_query(request: QueryRequest) -> QueryResponse: """ POST endpoint to query the RAG agent. Accepts a JSON payload with a "query" field and returns the agent's response. """ logger.info(f"Received query: {request.query}") try: # Call the existing chat_with_agent function from agent.py response_data = chat_with_agent( query=request.query, collection_name=request.collection_name, llm_model=request.llm_model, ) logger.info("Query processed successfully") return QueryResponse( response=response_data.get("response", ""), query=request.query, success=response_data.get("success", False), ) except Exception as e: logger.error(f"Error processing query: {str(e)}") error_msg = f"Error processing query: {str(e)}" return QueryResponse( response="I encountered an error while processing your request.", query=request.query, success=False, error=error_msg, ) @app.get("/") async def root(): """Root endpoint for health check.""" return {"message": "RAG Solution API is running", "status": "healthy"} if __name__ == "__main__": import uvicorn # Configure logging to output to both file and stdout as specified in requirements logs_dir = os.path.join(os.getcwd(), "logs") os.makedirs(logs_dir, exist_ok=True) logger.add("logs/dev.log", rotation="10 MB", retention="10 days") # Run the server uvicorn.run( "server:app", host="0.0.0.0", port=8000, reload=True, # Enable auto-reload during development )