Working demo.html with connection to the API endpoint
@@ -1,7 +1,7 @@
 """Agent module for the RAG solution with Ollama-powered chat agent."""

 import os
-from typing import List, Dict, Any, Optional
+from typing import List, Dict, Any, Optional, Tuple
 from langchain_core.tools import BaseTool, tool
 from langchain_core.runnables import RunnableConfig
 from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
@@ -16,6 +16,41 @@ from retrieval import create_retriever
 from vector_storage import initialize_vector_store


+def get_llm_model_info(llm_model: str = None) -> Tuple[str, str, str, str, str]:
+    """
+    Get LLM model information based on environment configuration.
+
+    Args:
+        llm_model: Name of the model to use (defaults to environment variable based on strategy)
+
+    Returns:
+        Tuple containing (strategy, model_name, base_url_or_api_base, api_key, model_type)
+    """
+    # Determine which model strategy to use
+    chat_model_strategy = os.getenv("CHAT_MODEL_STRATEGY", "ollama").lower()
+
+    if chat_model_strategy == "openai":
+        # Use OpenAI-compatible API
+        openai_chat_url = os.getenv("OPENAI_CHAT_URL")
+        openai_chat_key = os.getenv("OPENAI_CHAT_KEY")
+
+        if not openai_chat_url or not openai_chat_key:
+            raise ValueError("OPENAI_CHAT_URL and OPENAI_CHAT_KEY must be set when using OpenAI strategy")
+
+        # Get the model name from environment if not provided
+        if llm_model is None:
+            llm_model = os.getenv("OPENAI_CHAT_MODEL", "gpt-3.5-turbo")  # Default to a common model
+
+        return chat_model_strategy, llm_model, openai_chat_url, openai_chat_key, "ChatOpenAI"
+    else:  # Default to ollama
+        # Use Ollama
+        # Get the model name from environment if not provided
+        if llm_model is None:
+            llm_model = os.getenv("OLLAMA_CHAT_MODEL", "llama3.1")
+
+        return chat_model_strategy, llm_model, "http://localhost:11434", "", "ChatOllama"
+
+
 class DocumentRetrievalTool(BaseTool):
     """Tool for retrieving documents from the vector store based on a query."""
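Editor's note: a minimal usage sketch of the new helper, for reference. It assumes this file is importable as agent; the endpoint URL, key, and the gpt-4o-mini model name are placeholders, not values from the repository.

    import os
    from agent import get_llm_model_info

    # Default path: no CHAT_MODEL_STRATEGY set, so the helper falls back to Ollama.
    os.environ.pop("CHAT_MODEL_STRATEGY", None)
    print(get_llm_model_info())
    # ('ollama', 'llama3.1', 'http://localhost:11434', '', 'ChatOllama')

    # OpenAI path: both URL and key must be set, or the helper raises ValueError.
    os.environ["CHAT_MODEL_STRATEGY"] = "openai"
    os.environ["OPENAI_CHAT_URL"] = "https://api.example.com/v1"  # placeholder endpoint
    os.environ["OPENAI_CHAT_KEY"] = "sk-placeholder"              # placeholder key
    print(get_llm_model_info("gpt-4o-mini"))
    # ('openai', 'gpt-4o-mini', 'https://api.example.com/v1', 'sk-placeholder', 'ChatOpenAI')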
@@ -80,44 +115,28 @@ def create_chat_agent(
     """
     logger.info("Creating chat agent with document retrieval capabilities")

-    # Determine which model strategy to use
-    chat_model_strategy = os.getenv("CHAT_MODEL_STRATEGY", "ollama").lower()
-
-    if chat_model_strategy == "openai":
-        # Use OpenAI-compatible API
-        openai_chat_url = os.getenv("OPENAI_CHAT_URL")
-        openai_chat_key = os.getenv("OPENAI_CHAT_KEY")
-
-        if not openai_chat_url or not openai_chat_key:
-            raise ValueError("OPENAI_CHAT_URL and OPENAI_CHAT_KEY must be set when using OpenAI strategy")
-
-        # Get the model name from environment if not provided
-        if llm_model is None:
-            llm_model = os.getenv("OPENAI_CHAT_MODEL", "gpt-3.5-turbo")  # Default to a common model
+    # Get model information using the utility function
+    strategy, model_name, base_url_or_api_base, api_key, model_type = get_llm_model_info(llm_model)

+    if strategy == "openai":
         # Initialize the OpenAI-compatible chat model
         llm = ChatOpenAI(
-            model=llm_model,
-            openai_api_base=openai_chat_url,
-            openai_api_key=openai_chat_key,
+            model=model_name,
+            openai_api_base=base_url_or_api_base,
+            openai_api_key=api_key,
             temperature=0.1,
         )

-        logger.info(f"Using OpenAI-compatible model: {llm_model} via {openai_chat_url}")
+        logger.info(f"Using OpenAI-compatible model: {model_name} via {base_url_or_api_base}")
     else:  # Default to ollama
-        # Use Ollama
-        # Get the model name from environment if not provided
-        if llm_model is None:
-            llm_model = os.getenv("OLLAMA_CHAT_MODEL", "llama3.1")
-
         # Initialize the Ollama chat model
         llm = ChatOllama(
-            model=llm_model,
-            base_url="http://localhost:11434",  # Default Ollama URL
+            model=model_name,
+            base_url=base_url_or_api_base,  # Default Ollama URL
             temperature=0.1,
         )

-        logger.info(f"Using Ollama model: {llm_model}")
+        logger.info(f"Using Ollama model: {model_name}")

     # Create the document retrieval tool
     retrieval_tool = DocumentRetrievalTool()
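Editor's note: get_llm_model_info returns a model_type string ("ChatOpenAI" or "ChatOllama"), but create_chat_agent still branches on strategy, leaving model_type unused. A minimal sketch of an alternative dispatch that would use it is below; the langchain_openai and langchain_ollama import paths are assumptions, since the provider imports fall outside this diff.

    from langchain_ollama import ChatOllama  # assumed provider package
    from langchain_openai import ChatOpenAI  # assumed provider package

    from agent import get_llm_model_info  # assumes this module is importable as agent

    def build_chat_model(llm_model: str = None):
        """Construct the chat model by dispatching on model_type instead of strategy."""
        strategy, model_name, base, key, model_type = get_llm_model_info(llm_model)
        if model_type == "ChatOpenAI":
            return ChatOpenAI(model=model_name, openai_api_base=base,
                              openai_api_key=key, temperature=0.1)
        return ChatOllama(model=model_name, base_url=base, temperature=0.1)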
@@ -224,14 +243,13 @@ def run_chat_loop(
     """
     logger.info("Starting interactive chat loop")

-    # Determine which model strategy is being used and inform the user
-    chat_model_strategy = os.getenv("CHAT_MODEL_STRATEGY", "ollama").lower()
-    if chat_model_strategy == "openai":
-        model_info = os.getenv("OPENAI_CHAT_MODEL", "gpt-3.5-turbo")
-        print(f"Chat Agent initialized with OpenAI-compatible model: {model_info}")
+    # Get model information using the utility function
+    strategy, model_name, _, _, _ = get_llm_model_info(llm_model)
+
+    if strategy == "openai":
+        print(f"Chat Agent initialized with OpenAI-compatible model: {model_name}")
     else:
-        model_info = os.getenv("OLLAMA_CHAT_MODEL", "llama3.1")
-        print(f"Chat Agent initialized with Ollama model: {model_info}")
+        print(f"Chat Agent initialized with Ollama model: {model_name}")

     print("Type 'quit' or 'exit' to end the conversation.\n")
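Editor's note: a hedged end-to-end sketch of exercising the refactored loop on the Ollama path. It assumes run_chat_loop can be called with its defaults, that the module is importable as agent, and that a local Ollama server is running with the model already pulled; values are illustrative.

    import os

    # Illustrative configuration only, not taken from the repository.
    os.environ["CHAT_MODEL_STRATEGY"] = "ollama"
    os.environ["OLLAMA_CHAT_MODEL"] = "llama3.1"

    from agent import run_chat_loop

    run_chat_loop()
    # Expected banner: "Chat Agent initialized with Ollama model: llama3.1"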