proper usage of embedding models if defined in .env
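The commit keys both chat-model and embedding selection off variables loaded from `.env`. For orientation, here is an illustrative `.env` sketch: only the variable names that literally appear in the hunks below are taken from the diff; the chat-strategy variable name and every value are assumptions.

```
# Illustrative .env sketch. CHAT_MODEL_STRATEGY is an assumed name (the diff
# only shows the local variable chat_model_strategy); all values are placeholders.
CHAT_MODEL_STRATEGY=openai                  # assumed: "openai" or "ollama"
OPENAI_CHAT_URL=https://llm.example.com/v1
OPENAI_CHAT_KEY=sk-placeholder
OPENAI_CHAT_MODEL=PREDEFINED_EXTERNAL_MODEL
OLLAMA_CHAT_MODEL=llama3.1

EMBEDDING_STRATEGY=openai                   # "openai", "ollama", or "none"
OPENAI_EMBEDDING_API_KEY=sk-placeholder
OPENAI_EMBEDDING_BASE_URL=https://llm.example.com/v1
```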
@@ -1,16 +1,17 @@
 """Agent module for the RAG solution with Ollama-powered chat agent."""
 
 import os
-from typing import List, Dict, Any, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 from dotenv import load_dotenv
-from langchain_core.tools import BaseTool, tool
-from langchain_core.runnables import RunnableConfig
-from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
-from langchain_core.agents import AgentFinish
-from langgraph.prebuilt import create_react_agent
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables import RunnableConfig
+from langchain_core.tools import BaseTool, tool
 from langchain_ollama import ChatOllama
 from langchain_openai import ChatOpenAI
-from langchain_core.prompts import ChatPromptTemplate
+from langgraph.prebuilt import create_react_agent
 from loguru import logger
 
 from retrieval import create_retriever
@@ -39,20 +40,36 @@ def get_llm_model_info(llm_model: str = None) -> Tuple[str, str, str, str, str]:
         openai_chat_key = os.getenv("OPENAI_CHAT_KEY")
 
         if not openai_chat_url or not openai_chat_key:
-            raise ValueError("OPENAI_CHAT_URL and OPENAI_CHAT_KEY must be set when using OpenAI strategy")
+            raise ValueError(
+                "OPENAI_CHAT_URL and OPENAI_CHAT_KEY must be set when using OpenAI strategy"
+            )
 
         # Get the model name from environment if not provided
         if llm_model is None:
-            llm_model = os.getenv("OPENAI_CHAT_MODEL", "gpt-3.5-turbo")  # Default to a common model
+            llm_model = os.getenv(
+                "OPENAI_CHAT_MODEL", "PREDEFINED_EXTERNAL_MODEL"
+            )  # Default to a common model
 
-        return chat_model_strategy, llm_model, openai_chat_url, openai_chat_key, "ChatOpenAI"
+        return (
+            chat_model_strategy,
+            llm_model,
+            openai_chat_url,
+            openai_chat_key,
+            "ChatOpenAI",
+        )
     else:  # Default to ollama
         # Use Ollama
         # Get the model name from environment if not provided
         if llm_model is None:
             llm_model = os.getenv("OLLAMA_CHAT_MODEL", "llama3.1")
 
-        return chat_model_strategy, llm_model, "http://localhost:11434", "", "ChatOllama"
+        return (
+            chat_model_strategy,
+            llm_model,
+            "http://localhost:11434",
+            "",
+            "ChatOllama",
+        )
 
 
 class DocumentRetrievalTool(BaseTool):
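Read as a whole, the reformatted function is an env-driven dispatch: pick a backend strategy, validate its variables, and hand back a tuple the caller builds a model from. A self-contained sketch of that logic, assuming the strategy comes from a `CHAT_MODEL_STRATEGY` env var (the hunk only shows the already-read local `chat_model_strategy`):

```python
import os
from typing import Optional, Tuple


def get_llm_model_info(llm_model: Optional[str] = None) -> Tuple[str, str, str, str, str]:
    """Return (strategy, model, base_url, api_key, class_name) from .env settings."""
    strategy = os.getenv("CHAT_MODEL_STRATEGY", "ollama")  # assumed env var name
    if strategy == "openai":
        url = os.getenv("OPENAI_CHAT_URL")
        key = os.getenv("OPENAI_CHAT_KEY")
        if not url or not key:
            raise ValueError(
                "OPENAI_CHAT_URL and OPENAI_CHAT_KEY must be set when using OpenAI strategy"
            )
        model = llm_model or os.getenv("OPENAI_CHAT_MODEL", "PREDEFINED_EXTERNAL_MODEL")
        return strategy, model, url, key, "ChatOpenAI"
    # Anything else falls back to a local Ollama server on its default port.
    model = llm_model or os.getenv("OLLAMA_CHAT_MODEL", "llama3.1")
    return strategy, model, "http://localhost:11434", "", "ChatOllama"
```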
@@ -104,8 +121,7 @@ class DocumentRetrievalTool(BaseTool):
 
 
 def create_chat_agent(
-    collection_name: str = "documents_langchain",
-    llm_model: str = None
+    collection_name: str = "documents_langchain", llm_model: str = None
 ) -> Any:
     """
     Create a chat agent with document retrieval capabilities.
@@ -120,7 +136,9 @@ def create_chat_agent(
     logger.info("Creating chat agent with document retrieval capabilities")
 
     # Get model information using the utility function
-    strategy, model_name, base_url_or_api_base, api_key, model_type = get_llm_model_info(llm_model)
+    strategy, model_name, base_url_or_api_base, api_key, model_type = (
+        get_llm_model_info(llm_model)
+    )
 
     if strategy == "openai":
         # Initialize the OpenAI-compatible chat model
@@ -131,7 +149,9 @@ def create_chat_agent(
             temperature=0.1,
         )
 
-        logger.info(f"Using OpenAI-compatible model: {model_name} via {base_url_or_api_base}")
+        logger.info(
+            f"Using OpenAI-compatible model: {model_name} via {base_url_or_api_base}"
+        )
     else:  # Default to ollama
         # Initialize the Ollama chat model
         llm = ChatOllama(
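The truncated constructor calls above dispatch on the returned tuple. A hedged sketch of how that wiring plausibly looks with the standard `langchain_openai` / `langchain_ollama` constructors; only `temperature=0.1` is visible in the hunk, so the remaining kwargs are assumptions:

```python
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI

strategy, model_name, base_url, api_key, _model_type = get_llm_model_info(None)

if strategy == "openai":
    # kwargs assumed from the standard ChatOpenAI signature
    llm = ChatOpenAI(model=model_name, base_url=base_url, api_key=api_key, temperature=0.1)
else:
    # kwargs assumed from the standard ChatOllama signature
    llm = ChatOllama(model=model_name, base_url=base_url, temperature=0.1)
```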
@@ -158,7 +178,7 @@ def chat_with_agent(
     query: str,
     collection_name: str = "documents_langchain",
     llm_model: str = None,
-    history: List[BaseMessage] = None
+    history: List[BaseMessage] = None,
 ) -> Dict[str, Any]:
     """
     Chat with the agent and get a response based on the query and document retrieval.
@@ -185,9 +205,7 @@ def chat_with_agent(
     history.append(HumanMessage(content=query))
 
     # Prepare the input for the agent executor
-    agent_input = {
-        "messages": history
-    }
+    agent_input = {"messages": history}
 
     try:
         # Invoke the agent
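The collapsed `agent_input = {"messages": history}` matches the input shape a `langgraph.prebuilt.create_react_agent` graph expects. A minimal invocation sketch; the model choice and the empty tool list are placeholders (the real agent passes the document-retrieval tool):

```python
from langchain_core.messages import HumanMessage
from langchain_ollama import ChatOllama
from langgraph.prebuilt import create_react_agent

llm = ChatOllama(model="llama3.1")         # any chat model works here
agent = create_react_agent(llm, tools=[])  # tools=[] is a placeholder

history = [HumanMessage(content="What do the documents say about embeddings?")]
result = agent.invoke({"messages": history})

# The graph returns the whole conversation; the last message is the reply.
print(result["messages"][-1].content)
```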
@@ -207,7 +225,7 @@ def chat_with_agent(
         # If no AI message was found, return the last message content
         if messages:
             last_msg = messages[-1]
-            response_content = getattr(last_msg, 'content', str(last_msg))
+            response_content = getattr(last_msg, "content", str(last_msg))
         else:
             response_content = "I couldn't generate a response to your query."
     else:
@@ -218,7 +236,7 @@ def chat_with_agent(
             "response": response_content,
             "query": query,
             "history": messages,  # Return updated history
-            "success": True
+            "success": True,
         }
 
         logger.info("Chat completed successfully")
@@ -230,14 +248,11 @@ def chat_with_agent(
             "response": f"I encountered an error while processing your request: {str(e)}",
             "query": query,
             "history": history,
-            "success": False
+            "success": False,
         }
 
 
-def run_chat_loop(
-    collection_name: str = "documents_langchain",
-    llm_model: str = None
-):
+def run_chat_loop(collection_name: str = "documents_langchain", llm_model: str = None):
     """
     Run an interactive chat loop with the agent.
 
@@ -265,7 +280,7 @@ def run_chat_loop(
             user_input = input("You: ").strip()
 
             # Check for exit commands
-            if user_input.lower() in ['quit', 'exit', 'q']:
+            if user_input.lower() in ["quit", "exit", "q"]:
                 print("Ending chat session. Goodbye!")
                 break
 
@@ -277,7 +292,7 @@ def run_chat_loop(
                 query=user_input,
                 collection_name=collection_name,
                 llm_model=llm_model,
-                history=history
+                history=history,
             )
 
             # Update history with the new messages
 
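Taken together, one iteration of the loop reads as below; this usage sketch just restates the reformatted call and the history hand-off visible in the hunks, with the module name assumed:

```python
from agent import chat_with_agent  # module name assumed
from langchain_core.messages import BaseMessage

history: list[BaseMessage] = []

result = chat_with_agent(
    query="hello",
    collection_name="documents_langchain",
    llm_model=None,
    history=history,
)
if result["success"]:
    print(result["response"])
history = result["history"]  # carry the updated history into the next turn
```

The hunks that follow are from the vector-store/retrieval module touched by the same commit.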
@@ -4,10 +4,11 @@ import os
 from typing import Optional
 
 from dotenv import load_dotenv
-from langchain_qdrant import QdrantVectorStore
 from langchain_core.documents import Document
 from langchain_ollama import OllamaEmbeddings
 from langchain_openai import OpenAIEmbeddings
+from langchain_qdrant import QdrantVectorStore
 from loguru import logger
 from qdrant_client import QdrantClient
 
 # Load environment variables
@@ -43,7 +44,9 @@ def initialize_vector_store(
     if EMBEDDING_STRATEGY == "openai":
         # Validate required OpenAI embedding variables
         if not OPENAI_EMBEDDING_API_KEY or not OPENAI_EMBEDDING_BASE_URL:
-            raise ValueError("OPENAI_EMBEDDING_API_KEY and OPENAI_EMBEDDING_BASE_URL must be set when using OpenAI embedding strategy")
+            raise ValueError(
+                "OPENAI_EMBEDDING_API_KEY and OPENAI_EMBEDDING_BASE_URL must be set when using OpenAI embedding strategy"
+            )
 
         # Initialize OpenAI embeddings
         embeddings = OpenAIEmbeddings(
@@ -51,6 +54,10 @@ def initialize_vector_store(
             openai_api_base=OPENAI_EMBEDDING_BASE_URL,
             openai_api_key=OPENAI_EMBEDDING_API_KEY,
         )
+    elif EMBEDDING_STRATEGY == "none":
+        embeddings = None
+
+        logger.warning("Embedding strategy for vector storage is NONE! FYI")
     else:  # Default to ollama
         # Initialize Ollama embeddings
         embeddings = OllamaEmbeddings(
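This hunk is the heart of the commit message: the embedding backend is now chosen from `.env`, with an explicit `none` escape hatch. A self-contained sketch of the resulting dispatch; the two model-name env vars and their defaults are assumptions, since the hunk truncates both constructor calls:

```python
import os

from dotenv import load_dotenv
from langchain_ollama import OllamaEmbeddings
from langchain_openai import OpenAIEmbeddings
from loguru import logger

load_dotenv()
EMBEDDING_STRATEGY = os.getenv("EMBEDDING_STRATEGY", "ollama")

if EMBEDDING_STRATEGY == "openai":
    api_key = os.getenv("OPENAI_EMBEDDING_API_KEY")
    base_url = os.getenv("OPENAI_EMBEDDING_BASE_URL")
    if not api_key or not base_url:
        raise ValueError(
            "OPENAI_EMBEDDING_API_KEY and OPENAI_EMBEDDING_BASE_URL must be set "
            "when using OpenAI embedding strategy"
        )
    embeddings = OpenAIEmbeddings(
        model=os.getenv("OPENAI_EMBEDDING_MODEL", "text-embedding-3-small"),  # assumed
        openai_api_base=base_url,
        openai_api_key=api_key,
    )
elif EMBEDDING_STRATEGY == "none":
    embeddings = None  # downstream code must tolerate a store without an embedder
    logger.warning("Embedding strategy for vector storage is NONE! FYI")
else:  # default to ollama
    embeddings = OllamaEmbeddings(
        model=os.getenv("OLLAMA_EMBEDDING_MODEL", "nomic-embed-text"),  # assumed
    )
```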
@@ -118,7 +125,9 @@ def add_documents_to_vector_store(
         vector_store.add_documents(batch)
 
 
-def search_vector_store(vector_store: QdrantVectorStore, query: str, top_k: int = 5) -> list:
+def search_vector_store(
+    vector_store: QdrantVectorStore, query: str, top_k: int = 5
+) -> list:
     """
     Search the vector store for similar documents.
 
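For completeness, a usage sketch of the re-wrapped helper; it assumes `search_vector_store` wraps QdrantVectorStore's standard similarity search and that `vector_store` was built by `initialize_vector_store`:

```python
# vector_store: an initialized QdrantVectorStore (see initialize_vector_store)
docs = search_vector_store(vector_store, "how are embeddings configured?", top_k=3)
for doc in docs:
    print(doc.metadata.get("source"), doc.page_content[:80])
```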