"""Agent module for the RAG solution with Ollama-powered chat agent.""" import os from typing import List, Dict, Any, Optional from langchain_core.tools import BaseTool, tool from langchain_core.runnables import RunnableConfig from langchain_core.messages import HumanMessage, AIMessage, BaseMessage from langchain_core.agents import AgentFinish from langgraph.prebuilt import create_react_agent from langchain_ollama import ChatOllama from langchain_openai import ChatOpenAI from langchain_core.prompts import ChatPromptTemplate from loguru import logger from retrieval import create_retriever from vector_storage import initialize_vector_store class DocumentRetrievalTool(BaseTool): """Tool for retrieving documents from the vector store based on a query.""" name: str = "document_retrieval" description: str = "Retrieve documents from the vector store based on a query. Input should be a search query string." # Add retriever as a field retriever: object = None def __init__(self, **kwargs): # Initialize the retriever before calling super().__init__() super().__init__(**kwargs) self.retriever = create_retriever() def _run(self, query: str) -> str: """Execute the document retrieval.""" try: # Use the retriever to get relevant documents results = self.retriever.invoke(query) if not results: return "No relevant documents found for the query." # Format the results to return to the agent formatted_results = [] for i, doc in enumerate(results): content_preview = doc.page_content[:500] # Limit content preview metadata = doc.metadata formatted_doc = ( f"Document {i+1}:\n" f"Source: {metadata.get('source', 'Unknown')}\n" f"Filename: {metadata.get('filename', 'Unknown')}\n" f"Page: {metadata.get('page_number', metadata.get('page', 'N/A'))}\n" f"Content: {content_preview}...\n\n" ) formatted_results.append(formatted_doc) return "".join(formatted_results) except Exception as e: logger.error(f"Error during document retrieval: {str(e)}") return f"Error during document retrieval: {str(e)}" async def _arun(self, query: str): """Async version of the document retrieval.""" return self._run(query) def create_chat_agent( collection_name: str = "documents_langchain", llm_model: str = None ) -> Any: """ Create a chat agent with document retrieval capabilities. 

def create_chat_agent(
    collection_name: str = "documents_langchain",
    llm_model: Optional[str] = None,
) -> Any:
    """
    Create a chat agent with document retrieval capabilities.

    Args:
        collection_name: Name of the Qdrant collection to use
        llm_model: Name of the model to use (defaults to an environment
            variable chosen by the model strategy)

    Returns:
        Configured chat agent
    """
    logger.info("Creating chat agent with document retrieval capabilities")

    # Determine which model strategy to use
    chat_model_strategy = os.getenv("CHAT_MODEL_STRATEGY", "ollama").lower()

    if chat_model_strategy == "openai":
        # Use an OpenAI-compatible API
        openai_chat_url = os.getenv("OPENAI_CHAT_URL")
        openai_chat_key = os.getenv("OPENAI_CHAT_KEY")

        if not openai_chat_url or not openai_chat_key:
            raise ValueError(
                "OPENAI_CHAT_URL and OPENAI_CHAT_KEY must be set "
                "when using the OpenAI strategy"
            )

        # Get the model name from the environment if not provided
        if llm_model is None:
            llm_model = os.getenv("OPENAI_CHAT_MODEL", "gpt-3.5-turbo")  # Common default

        # Initialize the OpenAI-compatible chat model
        llm = ChatOpenAI(
            model=llm_model,
            openai_api_base=openai_chat_url,
            openai_api_key=openai_chat_key,
            temperature=0.1,
        )
        logger.info(f"Using OpenAI-compatible model: {llm_model} via {openai_chat_url}")
    else:
        # Default to Ollama
        if llm_model is None:
            llm_model = os.getenv("OLLAMA_CHAT_MODEL", "llama3.1")

        # Initialize the Ollama chat model
        llm = ChatOllama(
            model=llm_model,
            base_url="http://localhost:11434",  # Default Ollama URL
            temperature=0.1,
        )
        logger.info(f"Using Ollama model: {llm_model}")

    # Create the document retrieval tool.
    # NOTE: collection_name is accepted for API symmetry but is not forwarded;
    # create_retriever() uses its own configured collection.
    retrieval_tool = DocumentRetrievalTool()

    # Create the ReAct agent with the LLM and tools
    tools = [retrieval_tool]
    agent = create_react_agent(llm, tools)

    logger.info("Chat agent created successfully")
    return agent
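
# Configuration sketch for the two model strategies above. The variable names
# mirror the os.getenv() calls in create_chat_agent(); the values shown are
# placeholders, not provider defaults:
#
#   # Ollama (default strategy)
#   export CHAT_MODEL_STRATEGY=ollama
#   export OLLAMA_CHAT_MODEL=llama3.1
#
#   # OpenAI-compatible endpoint
#   export CHAT_MODEL_STRATEGY=openai
#   export OPENAI_CHAT_URL=https://api.example.com/v1   # placeholder URL
#   export OPENAI_CHAT_KEY=<your-api-key>
#   export OPENAI_CHAT_MODEL=gpt-3.5-turbo
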

def chat_with_agent(
    query: str,
    collection_name: str = "documents_langchain",
    llm_model: Optional[str] = None,
    history: Optional[List[BaseMessage]] = None,
) -> Dict[str, Any]:
    """
    Chat with the agent and get a response based on the query and document retrieval.

    Args:
        query: The user's query
        collection_name: Name of the Qdrant collection to use
        llm_model: Name of the model to use (defaults to an environment
            variable chosen by the model strategy)
        history: Conversation history (list of messages)

    Returns:
        Dictionary containing the agent's response and metadata
    """
    logger.info(f"Starting chat with query: {query}")

    # Create the agent
    agent = create_chat_agent(collection_name, llm_model)

    # Prepare the input for the agent
    if history is None:
        history = []

    # Add the user's query to the history
    history.append(HumanMessage(content=query))

    agent_input = {"messages": history}

    try:
        # Invoke the agent
        result = agent.invoke(agent_input)

        # Extract the agent's response
        messages = result.get("messages", [])
        ai_message = None

        # Find the most recent AI message in the results
        for msg in reversed(messages):
            if isinstance(msg, AIMessage):
                ai_message = msg
                break

        if ai_message is None:
            # If no AI message was found, fall back to the last message content
            if messages:
                last_msg = messages[-1]
                response_content = getattr(last_msg, "content", str(last_msg))
            else:
                response_content = "I couldn't generate a response to your query."
        else:
            response_content = ai_message.content

        # Create the response dictionary
        response = {
            "response": response_content,
            "query": query,
            "history": messages,  # Return the updated history
            "success": True,
        }

        logger.info("Chat completed successfully")
        return response
    except Exception as e:
        logger.error(f"Error during chat: {str(e)}")
        return {
            "response": f"I encountered an error while processing your request: {str(e)}",
            "query": query,
            "history": history,
            "success": False,
        }


def run_chat_loop(
    collection_name: str = "documents_langchain",
    llm_model: Optional[str] = None,
):
    """
    Run an interactive chat loop with the agent.

    Args:
        collection_name: Name of the Qdrant collection to use
        llm_model: Name of the model to use (defaults to an environment
            variable chosen by the model strategy)
    """
    logger.info("Starting interactive chat loop")

    # Determine which model strategy is in use and inform the user
    chat_model_strategy = os.getenv("CHAT_MODEL_STRATEGY", "ollama").lower()
    if chat_model_strategy == "openai":
        model_info = os.getenv("OPENAI_CHAT_MODEL", "gpt-3.5-turbo")
        print(f"Chat Agent initialized with OpenAI-compatible model: {model_info}")
    else:
        model_info = os.getenv("OLLAMA_CHAT_MODEL", "llama3.1")
        print(f"Chat Agent initialized with Ollama model: {model_info}")
    print("Type 'quit', 'exit', or 'q' to end the conversation.\n")

    history: List[BaseMessage] = []

    while True:
        try:
            # Get user input
            user_input = input("You: ").strip()

            # Check for exit commands
            if user_input.lower() in {"quit", "exit", "q"}:
                print("Ending chat session. Goodbye!")
                break

            if not user_input:
                continue

            # Get a response from the agent
            response_data = chat_with_agent(
                query=user_input,
                collection_name=collection_name,
                llm_model=llm_model,
                history=history,
            )

            # Update the history with the new messages
            history = response_data.get("history", [])

            # Print the agent's response
            print(f"Agent: {response_data.get('response', 'No response generated')}\n")
        except KeyboardInterrupt:
            print("\nEnding chat session. Goodbye!")
            break
        except Exception as e:
            logger.error(f"Error in chat loop: {str(e)}")
            print(f"An error occurred: {str(e)}")
            break


if __name__ == "__main__":
    print("Initializing chat agent...")
    # Run the interactive chat loop
    run_chat_loop()
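
# A minimal programmatic usage sketch (in comments so the module stays
# import-safe): assumes documents have already been ingested into the default
# "documents_langchain" collection and a chat model is reachable. The query
# strings are illustrative only.
#
#   first = chat_with_agent("Summarise the indexed documents.")
#   follow_up = chat_with_agent(
#       "Which source covers configuration?", history=first["history"]
#   )
#   if follow_up["success"]:
#       print(follow_up["response"])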