"""Agent module for the RAG solution with Ollama-powered chat agent."""

import os
from typing import Any, Dict, List, Optional, Tuple

from dotenv import load_dotenv
from langchain_core.agents import AgentFinish
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import BaseTool, tool
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
from loguru import logger

from retrieval import create_retriever
from vector_storage import initialize_vector_store

# Load environment variables
load_dotenv()
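
# Illustrative .env sketch for the variables read in this module (all values are
# placeholders; substitute your own endpoint, key, and model names):
#
#     CHAT_MODEL_STRATEGY=ollama            # or "openai"
#     OLLAMA_CHAT_MODEL=llama3.1
#     OPENAI_CHAT_URL=<openai-compatible-base-url>
#     OPENAI_CHAT_KEY=<api-key>
#     OPENAI_CHAT_MODEL=<model-name>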


def get_llm_model_info(
    llm_model: Optional[str] = None,
) -> Tuple[str, str, str, str, str]:
    """
    Get LLM model information based on environment configuration.

    Args:
        llm_model: Name of the model to use (defaults to environment variable based on strategy)

    Returns:
        Tuple containing (strategy, model_name, base_url_or_api_base, api_key, model_type)
    """
    # Determine which model strategy to use
    chat_model_strategy = os.getenv("CHAT_MODEL_STRATEGY", "ollama").lower()

    if chat_model_strategy == "openai":
        # Use an OpenAI-compatible API
        openai_chat_url = os.getenv("OPENAI_CHAT_URL")
        openai_chat_key = os.getenv("OPENAI_CHAT_KEY")

        if not openai_chat_url or not openai_chat_key:
            raise ValueError(
                "OPENAI_CHAT_URL and OPENAI_CHAT_KEY must be set when using the OpenAI strategy"
            )

        # Get the model name from the environment if not provided
        if llm_model is None:
            llm_model = os.getenv(
                "OPENAI_CHAT_MODEL", "PREDEFINED_EXTERNAL_MODEL"
            )  # Placeholder default; set OPENAI_CHAT_MODEL to the real model name

        return (
            chat_model_strategy,
            llm_model,
            openai_chat_url,
            openai_chat_key,
            "ChatOpenAI",
        )
    else:  # Default to Ollama
        # Get the model name from the environment if not provided
        if llm_model is None:
            llm_model = os.getenv("OLLAMA_CHAT_MODEL", "llama3.1")

        return (
            chat_model_strategy,
            llm_model,
            "http://localhost:11434",  # Default local Ollama URL
            "",
            "ChatOllama",
        )
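
# Example usage (illustrative sketch, assuming the environment variables above
# are set):
#
#     strategy, model_name, base_url, api_key, model_type = get_llm_model_info()
#     print(f"{model_type}: {model_name} via {base_url}")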


class DocumentRetrievalTool(BaseTool):
    """Tool for retrieving documents from the vector store based on a query."""

    name: str = "document_retrieval"
    description: str = "Retrieve documents from the vector store based on a query. Input should be a search query string."
    # Declare the retriever as a field so pydantic allows assigning it
    retriever: Any = None

    def __init__(self, **kwargs):
        # Let BaseTool/pydantic initialize the declared fields, then attach the retriever
        super().__init__(**kwargs)
        self.retriever = create_retriever()

    def _run(self, query: str) -> str:
        """Execute the document retrieval."""
        try:
            # Use the retriever to get relevant documents
            results = self.retriever.invoke(query)

            if not results:
                return "No relevant documents found for the query."

            # Format the results to return to the agent
            formatted_results = []
            for i, doc in enumerate(results):
                content_preview = doc.page_content[:500]  # Limit content preview
                metadata = doc.metadata

                formatted_doc = (
                    f"Document {i + 1}:\n"
                    f"Source: {metadata.get('source', 'Unknown')}\n"
                    f"Filename: {metadata.get('filename', 'Unknown')}\n"
                    f"Page: {metadata.get('page_number', metadata.get('page', 'N/A'))}\n"
                    f"Content: {content_preview}...\n\n"
                )
                formatted_results.append(formatted_doc)

            return "".join(formatted_results)

        except Exception as e:
            logger.error(f"Error during document retrieval: {str(e)}")
            return f"Error during document retrieval: {str(e)}"

    async def _arun(self, query: str) -> str:
        """Async version of the document retrieval."""
        return self._run(query)
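
# Example usage (illustrative sketch, assuming the vector store behind
# create_retriever() is reachable and already populated):
#
#     retrieval_tool = DocumentRetrievalTool()
#     print(retrieval_tool.invoke("<your search query>"))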


def create_chat_agent(
    collection_name: str = "documents_langchain", llm_model: Optional[str] = None
) -> Any:
    """
    Create a chat agent with document retrieval capabilities.

    Args:
        collection_name: Name of the Qdrant collection to use
        llm_model: Name of the model to use (defaults to environment variable based on strategy)

    Returns:
        Configured chat agent
    """
    logger.info("Creating chat agent with document retrieval capabilities")

    # Get model information using the utility function
    strategy, model_name, base_url_or_api_base, api_key, model_type = (
        get_llm_model_info(llm_model)
    )

    if strategy == "openai":
        # Initialize the OpenAI-compatible chat model
        llm = ChatOpenAI(
            model=model_name,
            openai_api_base=base_url_or_api_base,
            openai_api_key=api_key,
            temperature=0.1,
        )
        logger.info(
            f"Using OpenAI-compatible model: {model_name} via {base_url_or_api_base}"
        )
    else:  # Default to Ollama
        # Initialize the Ollama chat model
        llm = ChatOllama(
            model=model_name,
            base_url=base_url_or_api_base,  # URL resolved by get_llm_model_info
            temperature=0.1,
        )
        logger.info(f"Using Ollama model: {model_name}")

    # Create the document retrieval tool
    retrieval_tool = DocumentRetrievalTool()

    # Create the ReAct agent with the LLM and tools
    tools = [retrieval_tool]
    agent = create_react_agent(llm, tools)

    logger.info("Chat agent created successfully")

    return agent
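
# Example usage (illustrative sketch): the returned object is a LangGraph agent,
# so it can be invoked directly with a list of messages:
#
#     agent = create_chat_agent()
#     result = agent.invoke({"messages": [HumanMessage(content="<your question>")]})
#     print(result["messages"][-1].content)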


def chat_with_agent(
    query: str,
    collection_name: str = "documents_langchain",
    llm_model: Optional[str] = None,
    history: Optional[List[BaseMessage]] = None,
) -> Dict[str, Any]:
    """
    Chat with the agent and get a response based on the query and document retrieval.

    Args:
        query: The user's query
        collection_name: Name of the Qdrant collection to use
        llm_model: Name of the model to use (defaults to environment variable based on strategy)
        history: Conversation history (list of messages)

    Returns:
        Dictionary containing the agent's response and metadata
    """
    logger.info(f"Starting chat with query: {query}")

    # Create the agent
    agent = create_chat_agent(collection_name, llm_model)

    # Prepare the conversation history
    if history is None:
        history = []

    # Add the user's query to the history
    history.append(HumanMessage(content=query))

    # Prepare the input for the agent executor
    agent_input = {"messages": history}

    try:
        # Invoke the agent
        result = agent.invoke(agent_input)

        # Extract the agent's response
        messages = result.get("messages", [])
        ai_message = None

        # Find the most recent AI message in the results
        for msg in reversed(messages):
            if isinstance(msg, AIMessage):
                ai_message = msg
                break

        if ai_message is None:
            # If no AI message was found, fall back to the last message content
            if messages:
                last_msg = messages[-1]
                response_content = getattr(last_msg, "content", str(last_msg))
            else:
                response_content = "I couldn't generate a response to your query."
        else:
            response_content = ai_message.content

        # Create the response dictionary
        response = {
            "response": response_content,
            "query": query,
            "history": messages,  # Return the updated history
            "success": True,
        }

        logger.info("Chat completed successfully")
        return response

    except Exception as e:
        logger.error(f"Error during chat: {str(e)}")
        return {
            "response": f"I encountered an error while processing your request: {str(e)}",
            "query": query,
            "history": history,
            "success": False,
        }
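
# Example usage (illustrative sketch): a single-turn call followed by a
# follow-up that reuses the returned history:
#
#     first = chat_with_agent("<your question>")
#     print(first["response"])
#     follow_up = chat_with_agent("<a follow-up question>", history=first["history"])
#     print(follow_up["response"])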


def run_chat_loop(
    collection_name: str = "documents_langchain", llm_model: Optional[str] = None
):
    """
    Run an interactive chat loop with the agent.

    Args:
        collection_name: Name of the Qdrant collection to use
        llm_model: Name of the model to use (defaults to environment variable based on strategy)
    """
    logger.info("Starting interactive chat loop")

    # Get model information using the utility function
    strategy, model_name, _, _, _ = get_llm_model_info(llm_model)

    if strategy == "openai":
        print(f"Chat Agent initialized with OpenAI-compatible model: {model_name}")
    else:
        print(f"Chat Agent initialized with Ollama model: {model_name}")

    print("Type 'quit' or 'exit' to end the conversation.\n")

    history = []

    while True:
        try:
            # Get user input
            user_input = input("You: ").strip()

            # Check for exit commands
            if user_input.lower() in ["quit", "exit", "q"]:
                print("Ending chat session. Goodbye!")
                break

            if not user_input:
                continue

            # Get a response from the agent
            response_data = chat_with_agent(
                query=user_input,
                collection_name=collection_name,
                llm_model=llm_model,
                history=history,
            )

            # Update the history with the new messages
            history = response_data.get("history", [])

            # Print the agent's response
            print(f"Agent: {response_data.get('response', 'No response generated')}\n")

        except KeyboardInterrupt:
            print("\nEnding chat session. Goodbye!")
            break
        except Exception as e:
            logger.error(f"Error in chat loop: {str(e)}")
            print(f"An error occurred: {str(e)}")
            break
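
# Example (illustrative sketch): running the loop against a specific collection
# and model; both values are just the defaults used elsewhere in this module:
#
#     run_chat_loop(collection_name="documents_langchain", llm_model="llama3.1")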


if __name__ == "__main__":
    # Example usage
    print("Initializing chat agent...")

    # Run the interactive chat loop
    run_chat_loop()