Working chat with AI agent with data retrieval
@@ -36,6 +36,6 @@ Chosen data folder: relative ./../../../data - from the current folder

# Phase 6 (chat feature, as agent, for usage in the cli)

- [ ] Create file `agent.py`, which will contain an agent powered by the chat model. It should integrate with Ollama, using the model specified in the .env property OLLAMA_CHAT_MODEL.
- [ ] Integrate this agent with the existing retrieval solution in retrieval.py.
- [ ] Integrate this agent with the CLI, as a command that starts a chat with the agent. If there is a built-in solution for console communication with the agent, initiate it on the CLI command.
- [x] Create file `agent.py`, which will contain an agent powered by the chat model. It should integrate with Ollama, using the model specified in the .env property OLLAMA_CHAT_MODEL (see the `.env` sketch below).
- [x] Integrate this agent with the existing retrieval solution in retrieval.py.
- [x] Integrate this agent with the CLI, as a command that starts a chat with the agent. If there is a built-in solution for console communication with the agent, initiate it on the CLI command.
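For reference, a minimal `.env` sketch for the property named above (the model name is only illustrative; `agent.py` below falls back to `llama3.1` when the variable is unset):

```
OLLAMA_CHAT_MODEL=llama3.1
```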
@@ -75,9 +75,9 @@ The project is organized into 6 development phases as outlined in `PLANNING.md`:

- [x] Implement metadata retrieval (filename, page, section, etc.)

### Phase 6: Chat Agent

- [ ] Create `agent.py` with Ollama-powered chat agent
- [ ] Integrate with retrieval functionality
- [ ] Add CLI command for chat interaction
- [x] Create `agent.py` with Ollama-powered chat agent
- [x] Integrate with retrieval functionality
- [x] Add CLI command for chat interaction
## Environment Configuration

@@ -169,6 +169,18 @@ The project is in early development phase. The virtual environment is set up and

- Retrieval returns documents with metadata including source, filename, page number, file extension, etc.
- Used QdrantVectorStore from langchain-qdrant package for compatibility with newer LangChain versions
### Phase 6 Implementation Notes

- Created the `agent.py` module with an Ollama-powered chat agent using LangGraph
- Integrated the agent with retrieval functionality to provide context-aware responses
- Added CLI command `chat` for interactive conversation with the RAG agent
- The agent uses a document retrieval tool to fetch relevant information based on user queries
- Implemented proper error handling and conversation history management
### Issue Fix Notes

- Fixed the DocumentRetrievalTool class to properly declare and initialize the retriever field
- Resolved a Pydantic field declaration issue that caused an "object has no field" error (see the sketch after this list)
- Ensured the proper initialization sequence for the retriever within the tool class
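A minimal sketch of the pattern behind this fix, with illustrative names rather than the project's actual code: LangChain's BaseTool is a Pydantic model, and Pydantic rejects assignment to attributes that were never declared as fields, which is exactly what produces the "object has no field" error.

```python
from pydantic import BaseModel


class BrokenTool(BaseModel):
    name: str = "demo"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Not declared as a field, so Pydantic's __setattr__ rejects it:
        # ValueError: "BrokenTool" object has no field "retriever"
        self.retriever = object()


class FixedTool(BaseModel):
    name: str = "demo"
    # Declaring the field up front makes the assignment in __init__ legal
    retriever: object = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.retriever = object()


try:
    BrokenTool()
except ValueError as exc:
    print(exc)

print(FixedTool().retriever is not None)  # True
```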
### Troubleshooting Notes

- If encountering a "No module named 'unstructured_inference'" error, install unstructured-inference
- If seeing OCR-related errors, ensure tesseract is installed at the system level and unstructured-pytesseract is available (example installs below)
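The corresponding installs, as a sketch (pip package names as given above; the system-level tesseract command depends on the OS, with Debian/Ubuntu shown as one example):

```bash
pip install unstructured-inference unstructured-pytesseract
# System-level tesseract, e.g. on Debian/Ubuntu:
sudo apt-get install -y tesseract-ocr
```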
242 services/rag/langchain/agent.py Normal file

@@ -0,0 +1,242 @@
"""Agent module for the RAG solution with Ollama-powered chat agent."""
|
||||
|
||||
import os
|
||||
from typing import List, Dict, Any, Optional
|
||||
from langchain_core.tools import BaseTool, tool
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
|
||||
from langchain_core.agents import AgentFinish
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from langchain_ollama import ChatOllama
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from loguru import logger
|
||||
|
||||
from retrieval import create_retriever
|
||||
from vector_storage import initialize_vector_store
|
||||
|
||||
|
||||
class DocumentRetrievalTool(BaseTool):
    """Tool for retrieving documents from the vector store based on a query."""

    name: str = "document_retrieval"
    description: str = (
        "Retrieve documents from the vector store based on a query. "
        "Input should be a search query string."
    )
    # Declare the retriever as a Pydantic field so it can be assigned in __init__
    retriever: Any = None

    def __init__(self, **kwargs):
        # Let BaseTool (a Pydantic model) set up its declared fields first,
        # then attach the retriever to the declared field
        super().__init__(**kwargs)
        self.retriever = create_retriever()

    def _run(self, query: str) -> str:
        """Execute the document retrieval."""
        try:
            # Use the retriever to get relevant documents
            results = self.retriever.invoke(query)

            if not results:
                return "No relevant documents found for the query."

            # Format the results to return to the agent
            formatted_results = []
            for i, doc in enumerate(results):
                content_preview = doc.page_content[:500]  # Limit content preview
                metadata = doc.metadata

                formatted_doc = (
                    f"Document {i+1}:\n"
                    f"Source: {metadata.get('source', 'Unknown')}\n"
                    f"Filename: {metadata.get('filename', 'Unknown')}\n"
                    f"Page: {metadata.get('page_number', metadata.get('page', 'N/A'))}\n"
                    f"Content: {content_preview}...\n\n"
                )
                formatted_results.append(formatted_doc)

            return "".join(formatted_results)

        except Exception as e:
            logger.error(f"Error during document retrieval: {str(e)}")
            return f"Error during document retrieval: {str(e)}"

    async def _arun(self, query: str) -> str:
        """Async version of the document retrieval."""
        return self._run(query)


def create_chat_agent(
    collection_name: str = "documents_langchain",
    llm_model: Optional[str] = None
) -> Any:
    """
    Create a chat agent with document retrieval capabilities.

    Args:
        collection_name: Name of the Qdrant collection to use
        llm_model: Name of the Ollama model to use (defaults to OLLAMA_CHAT_MODEL env var)

    Returns:
        Configured chat agent
    """
    logger.info("Creating chat agent with document retrieval capabilities")

    # Get the model name from environment if not provided
    if llm_model is None:
        llm_model = os.getenv("OLLAMA_CHAT_MODEL", "llama3.1")

    # Initialize the Ollama chat model
    llm = ChatOllama(
        model=llm_model,
        base_url="http://localhost:11434",  # Default Ollama URL
        temperature=0.1,
    )

    # Create the document retrieval tool
    # NOTE: collection_name is not forwarded here; create_retriever() is called
    # without arguments, so the tool uses the retriever's own default collection
    retrieval_tool = DocumentRetrievalTool()

    # Create the agent with the LLM and tools
    tools = [retrieval_tool]
    agent = create_react_agent(llm, tools)

    logger.info("Chat agent created successfully")

    return agent


def chat_with_agent(
    query: str,
    collection_name: str = "documents_langchain",
    llm_model: Optional[str] = None,
    history: Optional[List[BaseMessage]] = None
) -> Dict[str, Any]:
    """
    Chat with the agent and get a response based on the query and document retrieval.

    Args:
        query: The user's query
        collection_name: Name of the Qdrant collection to use
        llm_model: Name of the Ollama model to use
        history: Conversation history (list of messages)

    Returns:
        Dictionary containing the agent's response and metadata
    """
    logger.info(f"Starting chat with query: {query}")

    # Create the agent
    agent = create_chat_agent(collection_name, llm_model)

    # Prepare the conversation history
    if history is None:
        history = []

    # Add the user's query to the history
    history.append(HumanMessage(content=query))

    # Prepare the input for the agent executor
    agent_input = {
        "messages": history
    }

    try:
        # Invoke the agent
        result = agent.invoke(agent_input)

        # Extract the agent's response
        messages = result.get("messages", [])
        ai_message = None

        # Find the most recent AI message in the results
        for msg in reversed(messages):
            if isinstance(msg, AIMessage):
                ai_message = msg
                break

        if ai_message is None:
            # If no AI message was found, return the last message content
            if messages:
                last_msg = messages[-1]
                response_content = getattr(last_msg, 'content', str(last_msg))
            else:
                response_content = "I couldn't generate a response to your query."
        else:
            response_content = ai_message.content

        # Create the response dictionary
        response = {
            "response": response_content,
            "query": query,
            "history": messages,  # Return updated history
            "success": True
        }

        logger.info("Chat completed successfully")
        return response

    except Exception as e:
        logger.error(f"Error during chat: {str(e)}")
        return {
            "response": f"I encountered an error while processing your request: {str(e)}",
            "query": query,
            "history": history,
            "success": False
        }


def run_chat_loop(
    collection_name: str = "documents_langchain",
    llm_model: Optional[str] = None
):
    """
    Run an interactive chat loop with the agent.

    Args:
        collection_name: Name of the Qdrant collection to use
        llm_model: Name of the Ollama model to use
    """
    logger.info("Starting interactive chat loop")
    print("Chat Agent initialized. Type 'quit' or 'exit' to end the conversation.\n")

    history = []

    while True:
        try:
            # Get user input
            user_input = input("You: ").strip()

            # Check for exit commands
            if user_input.lower() in ['quit', 'exit', 'q']:
                print("Ending chat session. Goodbye!")
                break

            if not user_input:
                continue

            # Get response from the agent (note: chat_with_agent rebuilds the
            # agent on every turn; history carries the conversation across turns)
            response_data = chat_with_agent(
                query=user_input,
                collection_name=collection_name,
                llm_model=llm_model,
                history=history
            )

            # Update history with the new messages
            history = response_data.get("history", [])

            # Print the agent's response
            print(f"Agent: {response_data.get('response', 'No response generated')}\n")

        except KeyboardInterrupt:
            print("\nEnding chat session. Goodbye!")
            break
        except Exception as e:
            logger.error(f"Error in chat loop: {str(e)}")
            print(f"An error occurred: {str(e)}")
            break


if __name__ == "__main__":
    # Example usage
    print("Initializing chat agent...")

    # Run the interactive chat loop
    run_chat_loop()
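As a quick sanity check of the module above, a minimal one-shot usage sketch (this assumes Ollama is running on localhost:11434 and the default collection is already populated; the query string is illustrative):

```python
from agent import chat_with_agent

result = chat_with_agent("What does Phase 6 add to the project?")
if result["success"]:
    print(result["response"])
```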
@@ -113,5 +113,43 @@ def retrieve(query, collection_name, top_k):
        click.echo(f"Error: {str(e)}")


@cli.command(
    name="chat",
    help="Start an interactive chat session with the RAG agent",
)
@click.option(
    "--collection-name",
    default="documents_langchain",
    help="Name of the vector store collection",
)
@click.option(
    "--model",
    default=None,
    help="Name of the Ollama model to use for chat",
)
def chat(collection_name, model):
    """Start an interactive chat session with the RAG agent"""
    logger.info("Starting chat session with RAG agent")

    try:
        # Import here to avoid circular dependencies and only when needed
        from agent import run_chat_loop

        click.echo("Initializing chat agent...")
        click.echo("Type 'quit' or 'exit' to end the conversation.\n")

        # Run the interactive chat loop
        run_chat_loop(
            collection_name=collection_name,
            llm_model=model
        )

        logger.info("Chat session ended")

    except Exception as e:
        logger.error(f"Error during chat session: {str(e)}")
        click.echo(f"Error: {str(e)}")


if __name__ == "__main__":
    cli()
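Assuming the CLI module shown here is invoked as `python cli.py` (the actual file name is not visible in this diff), the new command can be started as `python cli.py chat`, or with explicit options: `python cli.py chat --collection-name documents_langchain --model llama3.1`.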