OpenAI-compatible integration done
@@ -8,6 +8,7 @@ from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
 from langchain_core.agents import AgentFinish
 from langgraph.prebuilt import create_react_agent
 from langchain_ollama import ChatOllama
+from langchain_openai import ChatOpenAI
 from langchain_core.prompts import ChatPromptTemplate
 from loguru import logger
 
@@ -69,36 +70,64 @@ def create_chat_agent(
 ) -> Any:
     """
     Create a chat agent with document retrieval capabilities.
 
     Args:
         collection_name: Name of the Qdrant collection to use
-        llm_model: Name of the Ollama model to use (defaults to OLLAMA_CHAT_MODEL env var)
+        llm_model: Name of the model to use (defaults to environment variable based on strategy)
 
     Returns:
         Configured chat agent
     """
     logger.info("Creating chat agent with document retrieval capabilities")
 
-    # Get the model name from environment if not provided
-    if llm_model is None:
-        llm_model = os.getenv("OLLAMA_CHAT_MODEL", "llama3.1")
-
-    # Initialize the Ollama chat model
-    llm = ChatOllama(
-        model=llm_model,
-        base_url="http://localhost:11434",  # Default Ollama URL
-        temperature=0.1,
-    )
+    # Determine which model strategy to use
+    chat_model_strategy = os.getenv("CHAT_MODEL_STRATEGY", "ollama").lower()
+
+    if chat_model_strategy == "openai":
+        # Use OpenAI-compatible API
+        openai_chat_url = os.getenv("OPENAI_CHAT_URL")
+        openai_chat_key = os.getenv("OPENAI_CHAT_KEY")
+
+        if not openai_chat_url or not openai_chat_key:
+            raise ValueError("OPENAI_CHAT_URL and OPENAI_CHAT_KEY must be set when using OpenAI strategy")
+
+        # Get the model name from environment if not provided
+        if llm_model is None:
+            llm_model = os.getenv("OPENAI_CHAT_MODEL", "gpt-3.5-turbo")  # Default to a common model
+
+        # Initialize the OpenAI-compatible chat model
+        llm = ChatOpenAI(
+            model=llm_model,
+            openai_api_base=openai_chat_url,
+            openai_api_key=openai_chat_key,
+            temperature=0.1,
+        )
+
+        logger.info(f"Using OpenAI-compatible model: {llm_model} via {openai_chat_url}")
+    else:  # Default to ollama
+        # Get the model name from environment if not provided
+        if llm_model is None:
+            llm_model = os.getenv("OLLAMA_CHAT_MODEL", "llama3.1")
+
+        # Initialize the Ollama chat model
+        llm = ChatOllama(
+            model=llm_model,
+            base_url="http://localhost:11434",  # Default Ollama URL
+            temperature=0.1,
+        )
+
+        logger.info(f"Using Ollama model: {llm_model}")
 
     # Create the document retrieval tool
     retrieval_tool = DocumentRetrievalTool()
 
     # Create the agent with the LLM and tools
     tools = [retrieval_tool]
     agent = create_react_agent(llm, tools)
 
     logger.info("Chat agent created successfully")
 
     return agent
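Review note: the new branch is driven entirely by environment variables, and the ValueError guard fails fast when the OpenAI strategy is selected without an endpoint. A minimal sketch of exercising it, assuming the module is importable as `chat_agent` and that `llm_model` defaults to None (neither is shown in this diff):

import os

from chat_agent import create_chat_agent  # hypothetical module path

# Selecting the OpenAI strategy without OPENAI_CHAT_URL/OPENAI_CHAT_KEY
# should trip the new guard in create_chat_agent.
os.environ["CHAT_MODEL_STRATEGY"] = "openai"
os.environ.pop("OPENAI_CHAT_URL", None)
os.environ.pop("OPENAI_CHAT_KEY", None)
try:
    create_chat_agent(collection_name="documents")  # "documents" is an assumed collection name
except ValueError as err:
    print(err)

# With an endpoint configured, the same call builds a ChatOpenAI-backed agent.
# URL and key values here are illustrative; any OpenAI-compatible server should do.
os.environ["OPENAI_CHAT_URL"] = "http://localhost:8000/v1"
os.environ["OPENAI_CHAT_KEY"] = "dummy-key"
agent = create_chat_agent(collection_name="documents")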
@@ -110,47 +139,47 @@ def chat_with_agent(
 ) -> Dict[str, Any]:
     """
     Chat with the agent and get a response based on the query and document retrieval.
 
     Args:
         query: The user's query
         collection_name: Name of the Qdrant collection to use
-        llm_model: Name of the Ollama model to use
+        llm_model: Name of the model to use (defaults to environment variable based on strategy)
         history: Conversation history (list of messages)
 
     Returns:
         Dictionary containing the agent's response and metadata
     """
     logger.info(f"Starting chat with query: {query}")
 
     # Create the agent
     agent = create_chat_agent(collection_name, llm_model)
 
     # Initialize the conversation history if none was provided
     if history is None:
         history = []
 
     # Add the user's query to the history
     history.append(HumanMessage(content=query))
 
     # Prepare the input for the agent executor
     agent_input = {
         "messages": history
     }
 
     try:
         # Invoke the agent
         result = agent.invoke(agent_input)
 
         # Extract the agent's response
         messages = result.get("messages", [])
         ai_message = None
 
         # Find the AI message in the results
         for msg in reversed(messages):
             if isinstance(msg, AIMessage):
                 ai_message = msg
                 break
 
         if ai_message is None:
             # If no AI message was found, return the last message content
             if messages:
@@ -160,7 +189,7 @@ def chat_with_agent(
                 response_content = "I couldn't generate a response to your query."
         else:
             response_content = ai_message.content
 
         # Create the response dictionary
         response = {
             "response": response_content,
@@ -168,10 +197,10 @@ def chat_with_agent(
             "history": messages,  # Return updated history
             "success": True
         }
 
         logger.info("Chat completed successfully")
         return response
 
     except Exception as e:
         logger.error(f"Error during chat: {str(e)}")
         return {
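Review note: callers depend on the dict shape returned here ("response", "history", "success"). A short two-turn sketch, assuming llm_model and history default to None as the body suggests, with an assumed module path and collection name:

from chat_agent import chat_with_agent  # hypothetical module path

# First turn: no prior history, so the function starts a fresh message list.
first = chat_with_agent(
    query="What does the indexed documentation say about retries?",
    collection_name="documents",  # assumed Qdrant collection name
)
print(first["response"])

# Second turn: feed the returned history back in so the agent keeps context.
second = chat_with_agent(
    query="And what about timeouts?",
    collection_name="documents",
    history=first["history"],
)
if not second["success"]:
    print("Agent reported a failure for this turn")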
@@ -188,29 +217,39 @@ def run_chat_loop(
 ):
     """
     Run an interactive chat loop with the agent.
 
     Args:
         collection_name: Name of the Qdrant collection to use
-        llm_model: Name of the Ollama model to use
+        llm_model: Name of the model to use (defaults to environment variable based on strategy)
     """
     logger.info("Starting interactive chat loop")
-    print("Chat Agent initialized. Type 'quit' or 'exit' to end the conversation.\n")
+
+    # Determine which model strategy is being used and inform the user
+    chat_model_strategy = os.getenv("CHAT_MODEL_STRATEGY", "ollama").lower()
+    if chat_model_strategy == "openai":
+        model_info = os.getenv("OPENAI_CHAT_MODEL", "gpt-3.5-turbo")
+        print(f"Chat Agent initialized with OpenAI-compatible model: {model_info}")
+    else:
+        model_info = os.getenv("OLLAMA_CHAT_MODEL", "llama3.1")
+        print(f"Chat Agent initialized with Ollama model: {model_info}")
+
+    print("Type 'quit' or 'exit' to end the conversation.\n")
 
     history = []
 
     while True:
         try:
             # Get user input
             user_input = input("You: ").strip()
 
             # Check for exit commands
             if user_input.lower() in ['quit', 'exit', 'q']:
                 print("Ending chat session. Goodbye!")
                 break
 
             if not user_input:
                 continue
 
             # Get response from the agent
             response_data = chat_with_agent(
                 query=user_input,
@@ -218,13 +257,13 @@ def run_chat_loop(
                 llm_model=llm_model,
                 history=history
             )
 
             # Update history with the new messages
             history = response_data.get("history", [])
 
             # Print the agent's response
             print(f"Agent: {response_data.get('response', 'No response generated')}\n")
 
         except KeyboardInterrupt:
             print("\nEnding chat session. Goodbye!")
             break
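Usage note: run_chat_loop reads CHAT_MODEL_STRATEGY plus OPENAI_CHAT_MODEL or OLLAMA_CHAT_MODEL for its startup banner, and create_chat_agent reads the remaining variables, so a minimal entry point is just the following (module path and collection name are assumptions, not from this diff):

# run_chat.py -- minimal interactive entry point
from chat_agent import run_chat_loop  # hypothetical module path

if __name__ == "__main__":
    run_chat_loop(collection_name="documents")  # assumed collection name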