Use embedding models properly when they are defined in .env
This commit is contained in:
@@ -1,16 +1,17 @@
|
|||||||
"""Agent module for the RAG solution with Ollama-powered chat agent."""
|
"""Agent module for the RAG solution with Ollama-powered chat agent."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import List, Dict, Any, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from langchain_core.tools import BaseTool, tool
|
|
||||||
from langchain_core.runnables import RunnableConfig
|
|
||||||
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
|
|
||||||
from langchain_core.agents import AgentFinish
|
from langchain_core.agents import AgentFinish
|
||||||
from langgraph.prebuilt import create_react_agent
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||||
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
from langchain_core.tools import BaseTool, tool
|
||||||
from langchain_ollama import ChatOllama
|
from langchain_ollama import ChatOllama
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
from langchain_core.prompts import ChatPromptTemplate
|
from langgraph.prebuilt import create_react_agent
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from retrieval import create_retriever
|
from retrieval import create_retriever
|
||||||
@@ -39,20 +40,36 @@ def get_llm_model_info(llm_model: str = None) -> Tuple[str, str, str, str, str]:
|
|||||||
openai_chat_key = os.getenv("OPENAI_CHAT_KEY")
|
openai_chat_key = os.getenv("OPENAI_CHAT_KEY")
|
||||||
|
|
||||||
if not openai_chat_url or not openai_chat_key:
|
if not openai_chat_url or not openai_chat_key:
|
||||||
raise ValueError("OPENAI_CHAT_URL and OPENAI_CHAT_KEY must be set when using OpenAI strategy")
|
raise ValueError(
|
||||||
|
"OPENAI_CHAT_URL and OPENAI_CHAT_KEY must be set when using OpenAI strategy"
|
||||||
|
)
|
||||||
|
|
||||||
# Get the model name from environment if not provided
|
# Get the model name from environment if not provided
|
||||||
if llm_model is None:
|
if llm_model is None:
|
||||||
llm_model = os.getenv("OPENAI_CHAT_MODEL", "gpt-3.5-turbo") # Default to a common model
|
llm_model = os.getenv(
|
||||||
|
"OPENAI_CHAT_MODEL", "PREDEFINED_EXTERNAL_MODEL"
|
||||||
|
) # Default to a common model
|
||||||
|
|
||||||
return chat_model_strategy, llm_model, openai_chat_url, openai_chat_key, "ChatOpenAI"
|
return (
|
||||||
|
chat_model_strategy,
|
||||||
|
llm_model,
|
||||||
|
openai_chat_url,
|
||||||
|
openai_chat_key,
|
||||||
|
"ChatOpenAI",
|
||||||
|
)
|
||||||
else: # Default to ollama
|
else: # Default to ollama
|
||||||
# Use Ollama
|
# Use Ollama
|
||||||
# Get the model name from environment if not provided
|
# Get the model name from environment if not provided
|
||||||
if llm_model is None:
|
if llm_model is None:
|
||||||
llm_model = os.getenv("OLLAMA_CHAT_MODEL", "llama3.1")
|
llm_model = os.getenv("OLLAMA_CHAT_MODEL", "llama3.1")
|
||||||
|
|
||||||
return chat_model_strategy, llm_model, "http://localhost:11434", "", "ChatOllama"
|
return (
|
||||||
|
chat_model_strategy,
|
||||||
|
llm_model,
|
||||||
|
"http://localhost:11434",
|
||||||
|
"",
|
||||||
|
"ChatOllama",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DocumentRetrievalTool(BaseTool):
|
class DocumentRetrievalTool(BaseTool):
|
||||||
@@ -84,7 +101,7 @@ class DocumentRetrievalTool(BaseTool):
|
|||||||
metadata = doc.metadata
|
metadata = doc.metadata
|
||||||
|
|
||||||
formatted_doc = (
|
formatted_doc = (
|
||||||
f"Document {i+1}:\n"
|
f"Document {i + 1}:\n"
|
||||||
f"Source: {metadata.get('source', 'Unknown')}\n"
|
f"Source: {metadata.get('source', 'Unknown')}\n"
|
||||||
f"Filename: {metadata.get('filename', 'Unknown')}\n"
|
f"Filename: {metadata.get('filename', 'Unknown')}\n"
|
||||||
f"Page: {metadata.get('page_number', metadata.get('page', 'N/A'))}\n"
|
f"Page: {metadata.get('page_number', metadata.get('page', 'N/A'))}\n"
|
||||||
@@ -104,8 +121,7 @@ class DocumentRetrievalTool(BaseTool):
|
|||||||
|
|
||||||
|
|
||||||
def create_chat_agent(
|
def create_chat_agent(
|
||||||
collection_name: str = "documents_langchain",
|
collection_name: str = "documents_langchain", llm_model: str = None
|
||||||
llm_model: str = None
|
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Create a chat agent with document retrieval capabilities.
|
Create a chat agent with document retrieval capabilities.
|
||||||
@@ -120,7 +136,9 @@ def create_chat_agent(
|
|||||||
logger.info("Creating chat agent with document retrieval capabilities")
|
logger.info("Creating chat agent with document retrieval capabilities")
|
||||||
|
|
||||||
# Get model information using the utility function
|
# Get model information using the utility function
|
||||||
strategy, model_name, base_url_or_api_base, api_key, model_type = get_llm_model_info(llm_model)
|
strategy, model_name, base_url_or_api_base, api_key, model_type = (
|
||||||
|
get_llm_model_info(llm_model)
|
||||||
|
)
|
||||||
|
|
||||||
if strategy == "openai":
|
if strategy == "openai":
|
||||||
# Initialize the OpenAI-compatible chat model
|
# Initialize the OpenAI-compatible chat model
|
||||||
@@ -131,7 +149,9 @@ def create_chat_agent(
|
|||||||
temperature=0.1,
|
temperature=0.1,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Using OpenAI-compatible model: {model_name} via {base_url_or_api_base}")
|
logger.info(
|
||||||
|
f"Using OpenAI-compatible model: {model_name} via {base_url_or_api_base}"
|
||||||
|
)
|
||||||
else: # Default to ollama
|
else: # Default to ollama
|
||||||
# Initialize the Ollama chat model
|
# Initialize the Ollama chat model
|
||||||
llm = ChatOllama(
|
llm = ChatOllama(
|
||||||
@@ -158,7 +178,7 @@ def chat_with_agent(
|
|||||||
query: str,
|
query: str,
|
||||||
collection_name: str = "documents_langchain",
|
collection_name: str = "documents_langchain",
|
||||||
llm_model: str = None,
|
llm_model: str = None,
|
||||||
history: List[BaseMessage] = None
|
history: List[BaseMessage] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Chat with the agent and get a response based on the query and document retrieval.
|
Chat with the agent and get a response based on the query and document retrieval.
|
||||||
@@ -185,9 +205,7 @@ def chat_with_agent(
|
|||||||
history.append(HumanMessage(content=query))
|
history.append(HumanMessage(content=query))
|
||||||
|
|
||||||
# Prepare the input for the agent executor
|
# Prepare the input for the agent executor
|
||||||
agent_input = {
|
agent_input = {"messages": history}
|
||||||
"messages": history
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Invoke the agent
|
# Invoke the agent
|
||||||
@@ -207,7 +225,7 @@ def chat_with_agent(
|
|||||||
# If no AI message was found, return the last message content
|
# If no AI message was found, return the last message content
|
||||||
if messages:
|
if messages:
|
||||||
last_msg = messages[-1]
|
last_msg = messages[-1]
|
||||||
response_content = getattr(last_msg, 'content', str(last_msg))
|
response_content = getattr(last_msg, "content", str(last_msg))
|
||||||
else:
|
else:
|
||||||
response_content = "I couldn't generate a response to your query."
|
response_content = "I couldn't generate a response to your query."
|
||||||
else:
|
else:
|
||||||
@@ -218,7 +236,7 @@ def chat_with_agent(
|
|||||||
"response": response_content,
|
"response": response_content,
|
||||||
"query": query,
|
"query": query,
|
||||||
"history": messages, # Return updated history
|
"history": messages, # Return updated history
|
||||||
"success": True
|
"success": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info("Chat completed successfully")
|
logger.info("Chat completed successfully")
|
||||||
@@ -230,14 +248,11 @@ def chat_with_agent(
|
|||||||
"response": f"I encountered an error while processing your request: {str(e)}",
|
"response": f"I encountered an error while processing your request: {str(e)}",
|
||||||
"query": query,
|
"query": query,
|
||||||
"history": history,
|
"history": history,
|
||||||
"success": False
|
"success": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def run_chat_loop(
|
def run_chat_loop(collection_name: str = "documents_langchain", llm_model: str = None):
|
||||||
collection_name: str = "documents_langchain",
|
|
||||||
llm_model: str = None
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Run an interactive chat loop with the agent.
|
Run an interactive chat loop with the agent.
|
||||||
|
|
||||||
@@ -265,7 +280,7 @@ def run_chat_loop(
|
|||||||
user_input = input("You: ").strip()
|
user_input = input("You: ").strip()
|
||||||
|
|
||||||
# Check for exit commands
|
# Check for exit commands
|
||||||
if user_input.lower() in ['quit', 'exit', 'q']:
|
if user_input.lower() in ["quit", "exit", "q"]:
|
||||||
print("Ending chat session. Goodbye!")
|
print("Ending chat session. Goodbye!")
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -277,7 +292,7 @@ def run_chat_loop(
|
|||||||
query=user_input,
|
query=user_input,
|
||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
llm_model=llm_model,
|
llm_model=llm_model,
|
||||||
history=history
|
history=history,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update history with the new messages
|
# Update history with the new messages
|
||||||
|
|||||||
@@ -4,10 +4,11 @@ import os
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from langchain_qdrant import QdrantVectorStore
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_ollama import OllamaEmbeddings
|
from langchain_ollama import OllamaEmbeddings
|
||||||
from langchain_openai import OpenAIEmbeddings
|
from langchain_openai import OpenAIEmbeddings
|
||||||
|
from langchain_qdrant import QdrantVectorStore
|
||||||
|
from loguru import logger
|
||||||
from qdrant_client import QdrantClient
|
from qdrant_client import QdrantClient
|
||||||
|
|
||||||
# Load environment variables
|
# Load environment variables
|
||||||
@@ -43,7 +44,9 @@ def initialize_vector_store(
|
|||||||
if EMBEDDING_STRATEGY == "openai":
|
if EMBEDDING_STRATEGY == "openai":
|
||||||
# Validate required OpenAI embedding variables
|
# Validate required OpenAI embedding variables
|
||||||
if not OPENAI_EMBEDDING_API_KEY or not OPENAI_EMBEDDING_BASE_URL:
|
if not OPENAI_EMBEDDING_API_KEY or not OPENAI_EMBEDDING_BASE_URL:
|
||||||
raise ValueError("OPENAI_EMBEDDING_API_KEY and OPENAI_EMBEDDING_BASE_URL must be set when using OpenAI embedding strategy")
|
raise ValueError(
|
||||||
|
"OPENAI_EMBEDDING_API_KEY and OPENAI_EMBEDDING_BASE_URL must be set when using OpenAI embedding strategy"
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize OpenAI embeddings
|
# Initialize OpenAI embeddings
|
||||||
embeddings = OpenAIEmbeddings(
|
embeddings = OpenAIEmbeddings(
|
||||||
@@ -51,6 +54,10 @@ def initialize_vector_store(
|
|||||||
openai_api_base=OPENAI_EMBEDDING_BASE_URL,
|
openai_api_base=OPENAI_EMBEDDING_BASE_URL,
|
||||||
openai_api_key=OPENAI_EMBEDDING_API_KEY,
|
openai_api_key=OPENAI_EMBEDDING_API_KEY,
|
||||||
)
|
)
|
||||||
|
elif EMBEDDING_STRATEGY == "none":
|
||||||
|
embeddings = None
|
||||||
|
|
||||||
|
logger.warning("Embedding strategy for vector storage is NONE! FYI")
|
||||||
else: # Default to ollama
|
else: # Default to ollama
|
||||||
# Initialize Ollama embeddings
|
# Initialize Ollama embeddings
|
||||||
embeddings = OllamaEmbeddings(
|
embeddings = OllamaEmbeddings(
|
||||||
@@ -118,7 +125,9 @@ def add_documents_to_vector_store(
|
|||||||
vector_store.add_documents(batch)
|
vector_store.add_documents(batch)
|
||||||
|
|
||||||
|
|
||||||
def search_vector_store(vector_store: QdrantVectorStore, query: str, top_k: int = 5) -> list:
|
def search_vector_store(
|
||||||
|
vector_store: QdrantVectorStore, query: str, top_k: int = 5
|
||||||
|
) -> list:
|
||||||
"""
|
"""
|
||||||
Search the vector store for similar documents.
|
Search the vector store for similar documents.
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user