preparations for demo html page
This commit is contained in:
118
services/rag/langchain/server.py
Normal file
118
services/rag/langchain/server.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""Web server for the RAG solution with LangServe integration."""
|
||||
|
||||
import json
import os
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from pydantic import BaseModel

from agent import chat_with_agent
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
    """Request model for the query endpoint.

    Only ``query`` is required; the remaining fields fall back to defaults
    when the client omits them.
    """

    # Text of the question that is forwarded to the agent.
    query: str
    # Collection name passed through to the agent call.
    collection_name: str = "documents_langchain"
    # Optional model override. The query endpoint reads ``request.llm_model``,
    # so the field has to exist on the model; ``None`` means "no override".
    llm_model: Optional[str] = None
|
||||
|
||||
|
||||
class QueryResponse(BaseModel):
    """Response model for the query endpoint."""

    # Answer text to show to the caller.
    response: str
    # The query string this response belongs to (echoed back by the endpoint).
    query: str
    # Whether the query was processed without error.
    success: bool
    # Error description; stays ``None`` when nothing went wrong. Declared
    # Optional so the annotation matches the ``None`` default and an explicit
    # ``error=None`` passes validation.
    error: Optional[str] = None
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Log the start and the end of the application's lifetime.

    Everything before ``yield`` runs on startup and everything after it runs
    on shutdown. ``app`` is accepted because it is passed in by the framework;
    this handler does not use it.
    """
    logger.info("Starting RAG server...")  # startup phase

    yield  # control returns here when the application is shutting down

    logger.info("Shutting down RAG server...")  # shutdown phase
|
||||
|
||||
|
||||
# Create the FastAPI app; ``lifespan`` provides the startup/shutdown logging.
app = FastAPI(
    title="RAG Solution API",
    description="API for Retrieval-Augmented Generation solution with Langchain",
    version="1.0.0",
    lifespan=lifespan,
)

# Add CORS middleware.
# A wildcard origin list must not be combined with credentialed requests:
# together they let any website call this API with the user's cookies or
# auth headers. Credentials are therefore disabled while origins are "*";
# re-enable them only together with an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, configure this properly
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||
|
||||
|
||||
@app.post("/api/test-query", response_model=QueryResponse)
async def test_query(request: QueryRequest) -> QueryResponse:
    """
    POST endpoint to query the RAG agent.

    Accepts a JSON payload with a "query" field and returns the agent's response.

    Args:
        request: Validated request body (query text plus optional settings).

    Returns:
        A ``QueryResponse`` echoing the query. On the normal path ``response``
        and ``success`` are read from the dict returned by ``chat_with_agent``
        (falling back to ``""`` and ``False`` when a key is absent). If
        anything raises, a generic message is returned with ``success=False``
        and the error text in ``error``; no HTTP error is raised.
    """
    logger.info("Received query: {}", request.query)

    try:
        # Call the existing chat_with_agent function from agent.py.
        # NOTE(review): this is a plain (non-awaited) call inside an async
        # handler; if it blocks for long it will stall other requests --
        # confirm whether it should be moved to a worker thread.
        response_data = chat_with_agent(
            query=request.query,
            collection_name=request.collection_name,
            # Read defensively: the request model may not declare
            # ``llm_model``, and a plain attribute access would then fail
            # every request with AttributeError.
            llm_model=getattr(request, "llm_model", None),
        )

        logger.info("Query processed successfully")

        return QueryResponse(
            response=response_data.get("response", ""),
            query=request.query,
            success=response_data.get("success", False),
        )

    except Exception as e:
        # Top-level boundary: log with the traceback so the failure can be
        # diagnosed, then report it in the response body.
        logger.exception("Error processing query: {}", e)
        error_msg = f"Error processing query: {str(e)}"

        return QueryResponse(
            response="I encountered an error while processing your request.",
            query=request.query,
            success=False,
            error=error_msg,
        )
|
||||
|
||||
|
||||
@app.get("/")
async def root():
    """Health-check endpoint: report that the API is up."""
    health_payload = {
        "message": "RAG Solution API is running",
        "status": "healthy",
    }
    return health_payload
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import uvicorn

    # Ensure a "logs" directory exists under the current working directory,
    # then attach a rotating file sink there in addition to the logger's
    # existing output.
    os.makedirs(os.path.join(os.getcwd(), "logs"), exist_ok=True)
    logger.add("logs/dev.log", rotation="10 MB", retention="10 days")

    # Server settings gathered in one place before launching.
    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "reload": True,  # Enable auto-reload during development
    }

    # Run the server; the app is given as an import string ("module:attribute").
    uvicorn.run("server:app", **server_options)
|
||||
Reference in New Issue
Block a user