rag-solution/services/rag/langchain/agent.py

"""Agent module for the RAG solution with Ollama-powered chat agent."""
import os
from typing import List, Dict, Any, Optional
from langchain_core.tools import BaseTool, tool
from langchain_core.runnables import RunnableConfig
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from langchain_core.agents import AgentFinish
from langgraph.prebuilt import create_react_agent
from langchain_ollama import ChatOllama
from langchain_core.prompts import ChatPromptTemplate
from loguru import logger
from retrieval import create_retriever
from vector_storage import initialize_vector_store
class DocumentRetrievalTool(BaseTool):
    """Tool for retrieving documents from the vector store based on a query."""

    name: str = "document_retrieval"
    description: str = (
        "Retrieve documents from the vector store based on a query. "
        "Input should be a search query string."
    )
    # Declare the retriever as a Pydantic field so BaseTool accepts it
    retriever: Any = None

    def __init__(self, **kwargs):
        # Call super().__init__() first so Pydantic sets up the model fields,
        # then attach the retriever
        super().__init__(**kwargs)
        self.retriever = create_retriever()

    def _run(self, query: str) -> str:
        """Execute the document retrieval."""
        try:
            # Use the retriever to get relevant documents
            results = self.retriever.invoke(query)
            if not results:
                return "No relevant documents found for the query."
            # Format the results to return to the agent
            formatted_results = []
            for i, doc in enumerate(results):
                content_preview = doc.page_content[:500]  # Limit content preview
                metadata = doc.metadata
                formatted_doc = (
                    f"Document {i + 1}:\n"
                    f"Source: {metadata.get('source', 'Unknown')}\n"
                    f"Filename: {metadata.get('filename', 'Unknown')}\n"
                    f"Page: {metadata.get('page_number', metadata.get('page', 'N/A'))}\n"
                    f"Content: {content_preview}...\n\n"
                )
                formatted_results.append(formatted_doc)
            return "".join(formatted_results)
        except Exception as e:
            logger.error(f"Error during document retrieval: {str(e)}")
            return f"Error during document retrieval: {str(e)}"

    async def _arun(self, query: str) -> str:
        """Async version of the document retrieval."""
        return self._run(query)
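
# A minimal usage sketch (assumes create_retriever() can reach a populated
# vector store; the query string is purely illustrative):
#
#     tool = DocumentRetrievalTool()
#     print(tool.invoke("deployment steps for the ingestion service"))
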
def create_chat_agent(
    collection_name: str = "documents_langchain",
    llm_model: Optional[str] = None,
) -> Any:
    """
    Create a chat agent with document retrieval capabilities.

    Args:
        collection_name: Name of the Qdrant collection to use
        llm_model: Name of the Ollama model to use (defaults to OLLAMA_CHAT_MODEL env var)

    Returns:
        Configured chat agent
    """
    logger.info("Creating chat agent with document retrieval capabilities")
    # Get the model name from the environment if not provided
    if llm_model is None:
        llm_model = os.getenv("OLLAMA_CHAT_MODEL", "llama3.1")
    # Initialize the Ollama chat model
    llm = ChatOllama(
        model=llm_model,
        base_url="http://localhost:11434",  # Default Ollama URL
        temperature=0.1,
    )
    # Create the document retrieval tool. Note: collection_name is not yet
    # forwarded to the tool; the retriever uses its own configured collection.
    retrieval_tool = DocumentRetrievalTool()
    # Create the ReAct agent with the LLM and tools
    tools = [retrieval_tool]
    agent = create_react_agent(llm, tools)
    logger.info("Chat agent created successfully")
    return agent
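
# Invocation sketch (assumes Ollama is serving the chosen model at
# localhost:11434; the question is hypothetical):
#
#     agent = create_chat_agent()
#     state = agent.invoke({"messages": [HumanMessage(content="What do the docs cover?")]})
#     print(state["messages"][-1].content)
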
def chat_with_agent(
    query: str,
    collection_name: str = "documents_langchain",
    llm_model: Optional[str] = None,
    history: Optional[List[BaseMessage]] = None,
) -> Dict[str, Any]:
    """
    Chat with the agent and get a response based on the query and document retrieval.

    Args:
        query: The user's query
        collection_name: Name of the Qdrant collection to use
        llm_model: Name of the Ollama model to use
        history: Conversation history (list of messages)

    Returns:
        Dictionary containing the agent's response and metadata
    """
    logger.info(f"Starting chat with query: {query}")
    # Create the agent
    agent = create_chat_agent(collection_name, llm_model)
    # Copy the history so the caller's list is not mutated
    history = list(history) if history else []
    # Add the user's query to the history
    history.append(HumanMessage(content=query))
    # Prepare the input for the agent
    agent_input = {"messages": history}
    try:
        # Invoke the agent
        result = agent.invoke(agent_input)
        # Extract the agent's response: the last AIMessage in the results
        messages = result.get("messages", [])
        ai_message = None
        for msg in reversed(messages):
            if isinstance(msg, AIMessage):
                ai_message = msg
                break
        if ai_message is None:
            # If no AI message was found, fall back to the last message content
            if messages:
                last_msg = messages[-1]
                response_content = getattr(last_msg, "content", str(last_msg))
            else:
                response_content = "I couldn't generate a response to your query."
        else:
            response_content = ai_message.content
        # Create the response dictionary
        response = {
            "response": response_content,
            "query": query,
            "history": messages,  # Return updated history
            "success": True,
        }
        logger.info("Chat completed successfully")
        return response
    except Exception as e:
        logger.error(f"Error during chat: {str(e)}")
        return {
            "response": f"I encountered an error while processing your request: {str(e)}",
            "query": query,
            "history": history,
            "success": False,
        }
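
# Multi-turn sketch (hypothetical queries; the returned history feeds the
# next call so the agent keeps conversational context):
#
#     first = chat_with_agent("Summarize the architecture document")
#     second = chat_with_agent("Who maintains it?", history=first["history"])
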
def run_chat_loop(
    collection_name: str = "documents_langchain",
    llm_model: Optional[str] = None,
):
    """
    Run an interactive chat loop with the agent.

    Args:
        collection_name: Name of the Qdrant collection to use
        llm_model: Name of the Ollama model to use
    """
    logger.info("Starting interactive chat loop")
    print("Chat Agent initialized. Type 'quit' or 'exit' to end the conversation.\n")
    history = []
    while True:
        try:
            # Get user input
            user_input = input("You: ").strip()
            # Check for exit commands
            if user_input.lower() in ["quit", "exit", "q"]:
                print("Ending chat session. Goodbye!")
                break
            if not user_input:
                continue
            # Get a response from the agent
            response_data = chat_with_agent(
                query=user_input,
                collection_name=collection_name,
                llm_model=llm_model,
                history=history,
            )
            # Update the history with the new messages
            history = response_data.get("history", [])
            # Print the agent's response
            print(f"Agent: {response_data.get('response', 'No response generated')}\n")
        except KeyboardInterrupt:
            print("\nEnding chat session. Goodbye!")
            break
        except Exception as e:
            logger.error(f"Error in chat loop: {str(e)}")
            print(f"An error occurred: {str(e)}")
            break


if __name__ == "__main__":
    # Example usage
    print("Initializing chat agent...")
    # Run the interactive chat loop
    run_chat_loop()