proper usage of embedding models if defined in .env

2026-02-05 01:07:25 +03:00
parent 31d198afb8
commit effbc7d00f
2 changed files with 57 additions and 33 deletions
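
Both touched files resolve their model providers from .env at runtime. For orientation, a minimal sketch of that selection logic: EMBEDDING_STRATEGY, OPENAI_CHAT_URL, and OPENAI_CHAT_KEY appear verbatim in the diff below, while CHAT_MODEL_STRATEGY is an assumed key name (the diff only shows the local variable chat_model_strategy).

# Orientation sketch of the .env-driven provider selection in this commit.
# CHAT_MODEL_STRATEGY is an assumed key name; the other keys are from the diff.
import os

from dotenv import load_dotenv

load_dotenv()

chat_strategy = os.getenv("CHAT_MODEL_STRATEGY", "ollama")  # assumed name
embedding_strategy = os.getenv("EMBEDDING_STRATEGY", "ollama")

if chat_strategy == "openai" and not (
    os.getenv("OPENAI_CHAT_URL") and os.getenv("OPENAI_CHAT_KEY")
):
    raise ValueError("OPENAI_CHAT_URL and OPENAI_CHAT_KEY must be set")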

View File

@@ -1,16 +1,17 @@
 """Agent module for the RAG solution with Ollama-powered chat agent."""
 import os
-from typing import List, Dict, Any, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 from dotenv import load_dotenv
-from langchain_core.tools import BaseTool, tool
-from langchain_core.runnables import RunnableConfig
-from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
-from langchain_core.agents import AgentFinish
-from langgraph.prebuilt import create_react_agent
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables import RunnableConfig
+from langchain_core.tools import BaseTool, tool
 from langchain_ollama import ChatOllama
 from langchain_openai import ChatOpenAI
-from langchain_core.prompts import ChatPromptTemplate
+from langgraph.prebuilt import create_react_agent
 from loguru import logger
 from retrieval import create_retriever
@@ -39,20 +40,36 @@ def get_llm_model_info(llm_model: str = None) -> Tuple[str, str, str, str, str]:
         openai_chat_key = os.getenv("OPENAI_CHAT_KEY")
         if not openai_chat_url or not openai_chat_key:
-            raise ValueError("OPENAI_CHAT_URL and OPENAI_CHAT_KEY must be set when using OpenAI strategy")
+            raise ValueError(
+                "OPENAI_CHAT_URL and OPENAI_CHAT_KEY must be set when using OpenAI strategy"
+            )
         # Get the model name from environment if not provided
         if llm_model is None:
-            llm_model = os.getenv("OPENAI_CHAT_MODEL", "gpt-3.5-turbo")  # Default to a common model
+            llm_model = os.getenv(
+                "OPENAI_CHAT_MODEL", "PREDEFINED_EXTERNAL_MODEL"
+            )  # Default to a common model
-        return chat_model_strategy, llm_model, openai_chat_url, openai_chat_key, "ChatOpenAI"
+        return (
+            chat_model_strategy,
+            llm_model,
+            openai_chat_url,
+            openai_chat_key,
+            "ChatOpenAI",
+        )
     else:  # Default to ollama
         # Use Ollama
         # Get the model name from environment if not provided
         if llm_model is None:
             llm_model = os.getenv("OLLAMA_CHAT_MODEL", "llama3.1")
-        return chat_model_strategy, llm_model, "http://localhost:11434", "", "ChatOllama"
+        return (
+            chat_model_strategy,
+            llm_model,
+            "http://localhost:11434",
+            "",
+            "ChatOllama",
+        )


 class DocumentRetrievalTool(BaseTool):
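
For reference, the reshaped helper returns a five-element tuple: strategy, model name, base URL or API base, API key, and the chat model class name. A minimal usage sketch; the import path agent is an assumption, since the commit page hides the filename.

# Usage sketch for the 5-tuple above; "agent" as the module name is assumed.
from agent import get_llm_model_info

strategy, model_name, base_url, api_key, model_type = get_llm_model_info()
print(strategy, model_name, model_type)  # e.g. "ollama llama3.1 ChatOllama"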
@@ -84,7 +101,7 @@ class DocumentRetrievalTool(BaseTool):
             metadata = doc.metadata
             formatted_doc = (
-                f"Document {i+1}:\n"
+                f"Document {i + 1}:\n"
                 f"Source: {metadata.get('source', 'Unknown')}\n"
                 f"Filename: {metadata.get('filename', 'Unknown')}\n"
                 f"Page: {metadata.get('page_number', metadata.get('page', 'N/A'))}\n"
@@ -104,8 +121,7 @@ class DocumentRetrievalTool(BaseTool):


 def create_chat_agent(
-    collection_name: str = "documents_langchain",
-    llm_model: str = None
+    collection_name: str = "documents_langchain", llm_model: str = None
 ) -> Any:
     """
     Create a chat agent with document retrieval capabilities.
@@ -120,7 +136,9 @@ def create_chat_agent(
     logger.info("Creating chat agent with document retrieval capabilities")

     # Get model information using the utility function
-    strategy, model_name, base_url_or_api_base, api_key, model_type = get_llm_model_info(llm_model)
+    strategy, model_name, base_url_or_api_base, api_key, model_type = (
+        get_llm_model_info(llm_model)
+    )

     if strategy == "openai":
         # Initialize the OpenAI-compatible chat model
@@ -131,7 +149,9 @@ def create_chat_agent(
             temperature=0.1,
         )
-        logger.info(f"Using OpenAI-compatible model: {model_name} via {base_url_or_api_base}")
+        logger.info(
+            f"Using OpenAI-compatible model: {model_name} via {base_url_or_api_base}"
+        )
     else:  # Default to ollama
         # Initialize the Ollama chat model
         llm = ChatOllama(
@@ -158,7 +178,7 @@ def chat_with_agent(
     query: str,
     collection_name: str = "documents_langchain",
     llm_model: str = None,
-    history: List[BaseMessage] = None
+    history: List[BaseMessage] = None,
 ) -> Dict[str, Any]:
     """
     Chat with the agent and get a response based on the query and document retrieval.
@@ -185,9 +205,7 @@ def chat_with_agent(
     history.append(HumanMessage(content=query))

     # Prepare the input for the agent executor
-    agent_input = {
-        "messages": history
-    }
+    agent_input = {"messages": history}

     try:
         # Invoke the agent
@@ -207,7 +225,7 @@ def chat_with_agent(
         # If no AI message was found, return the last message content
         if messages:
             last_msg = messages[-1]
-            response_content = getattr(last_msg, 'content', str(last_msg))
+            response_content = getattr(last_msg, "content", str(last_msg))
         else:
             response_content = "I couldn't generate a response to your query."
     else:
@@ -218,7 +236,7 @@ def chat_with_agent(
         "response": response_content,
         "query": query,
         "history": messages,  # Return updated history
-        "success": True
+        "success": True,
     }

     logger.info("Chat completed successfully")
@@ -230,14 +248,11 @@ def chat_with_agent(
             "response": f"I encountered an error while processing your request: {str(e)}",
             "query": query,
             "history": history,
-            "success": False
+            "success": False,
         }


-def run_chat_loop(
-    collection_name: str = "documents_langchain",
-    llm_model: str = None
-):
+def run_chat_loop(collection_name: str = "documents_langchain", llm_model: str = None):
     """
     Run an interactive chat loop with the agent.
@@ -265,7 +280,7 @@ def run_chat_loop(
             user_input = input("You: ").strip()

             # Check for exit commands
-            if user_input.lower() in ['quit', 'exit', 'q']:
+            if user_input.lower() in ["quit", "exit", "q"]:
                 print("Ending chat session. Goodbye!")
                 break
@@ -277,7 +292,7 @@ def run_chat_loop(
                 query=user_input,
                 collection_name=collection_name,
                 llm_model=llm_model,
-                history=history
+                history=history,
             )

             # Update history with the new messages
@@ -298,6 +313,6 @@ def run_chat_loop(
 if __name__ == "__main__":
     # Example usage
     print("Initializing chat agent...")

     # Run the interactive chat loop
-    run_chat_loop()
\ No newline at end of file
+    run_chat_loop()
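
The trailing-comma and quoting changes above do not alter behavior; chat_with_agent still returns a dict with response, query, history, and success. A multi-turn usage sketch under that contract (the agent module name is again an assumption):

# Multi-turn sketch; the result keys come from the return dicts in the diff.
from agent import chat_with_agent

history = None
for q in ["What documents mention Qdrant?", "Summarize the first one."]:
    result = chat_with_agent(query=q, history=history)
    if result["success"]:
        print(result["response"])
    history = result["history"]  # carry updated history into the next turn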

View File

@@ -4,10 +4,11 @@ import os
 from typing import Optional
 from dotenv import load_dotenv
-from langchain_qdrant import QdrantVectorStore
 from langchain_core.documents import Document
 from langchain_ollama import OllamaEmbeddings
 from langchain_openai import OpenAIEmbeddings
+from langchain_qdrant import QdrantVectorStore
 from loguru import logger
 from qdrant_client import QdrantClient

 # Load environment variables
@@ -43,7 +44,9 @@ def initialize_vector_store(
     if EMBEDDING_STRATEGY == "openai":
         # Validate required OpenAI embedding variables
         if not OPENAI_EMBEDDING_API_KEY or not OPENAI_EMBEDDING_BASE_URL:
-            raise ValueError("OPENAI_EMBEDDING_API_KEY and OPENAI_EMBEDDING_BASE_URL must be set when using OpenAI embedding strategy")
+            raise ValueError(
+                "OPENAI_EMBEDDING_API_KEY and OPENAI_EMBEDDING_BASE_URL must be set when using OpenAI embedding strategy"
+            )

         # Initialize OpenAI embeddings
         embeddings = OpenAIEmbeddings(
@@ -51,6 +54,10 @@ def initialize_vector_store(
             openai_api_base=OPENAI_EMBEDDING_BASE_URL,
             openai_api_key=OPENAI_EMBEDDING_API_KEY,
         )
+    elif EMBEDDING_STRATEGY == "none":
+        embeddings = None
+        logger.warning("Embedding strategy for vector storage is NONE! FYI")
     else:  # Default to ollama
         # Initialize Ollama embeddings
         embeddings = OllamaEmbeddings(
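
The new "none" branch leaves embeddings as None, so the store is created without an embedding function and similarity search presumably cannot work in that mode. A hedged sketch of a caller-side guard; nothing in the diff mandates this, it is just one way to fail fast:

# Caller-side guard for EMBEDDING_STRATEGY="none"; purely illustrative.
import os

if os.getenv("EMBEDDING_STRATEGY", "ollama") == "none":
    raise RuntimeError(
        "EMBEDDING_STRATEGY=none disables embeddings; set 'ollama' or "
        "'openai' in .env before using similarity search"
    )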
@@ -118,7 +125,9 @@ def add_documents_to_vector_store(
         vector_store.add_documents(batch)


-def search_vector_store(vector_store: QdrantVectorStore, query: str, top_k: int = 5) -> list:
+def search_vector_store(
+    vector_store: QdrantVectorStore, query: str, top_k: int = 5
+) -> list:
     """
     Search the vector store for similar documents.
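
Finally, a usage sketch for the reformatted search helper. The module name vector_store and the initialize_vector_store call shape are assumptions; only search_vector_store's signature appears in the diff, and the returned list is presumed to hold Document objects:

# Usage sketch; the module name and initializer signature are assumptions.
from vector_store import initialize_vector_store, search_vector_store

store = initialize_vector_store(collection_name="documents_langchain")
for doc in search_vector_store(store, "What is RAG?", top_k=3):
    print(doc.metadata.get("source", "Unknown"), doc.page_content[:80])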