"""
Configuration module for managing model strategies in the RAG solution.

This module provides functions to get appropriate model configurations
based on environment variables for both embeddings and chat models.
"""
import os

from dotenv import load_dotenv
from loguru import logger

# Load environment variables from .env file
load_dotenv()

def get_embedding_model():
    """
    Get the appropriate embedding model based on the EMBEDDING_STRATEGY environment variable.

    Supported strategies:
        - "ollama": OllamaEmbedding; model from OLLAMA_EMBEDDING_MODEL,
          server from OLLAMA_BASE_URL (defaults to the local Ollama port).
        - "openai": OpenAILikeEmbedding; model from OPENAI_EMBEDDING_MODEL,
          endpoint from OPENAI_EMBEDDING_BASE_URL, key from OPENAI_EMBEDDING_API_KEY.

    Returns:
        An embedding model instance based on the selected strategy.

    Raises:
        ValueError: If EMBEDDING_STRATEGY is not one of the supported values.
    """
    strategy = os.getenv("EMBEDDING_STRATEGY", "ollama").lower()

    if strategy == "ollama":
        from llama_index.embeddings.ollama import OllamaEmbedding

        ollama_embed_model = os.getenv("OLLAMA_EMBEDDING_MODEL", "qwen3-embedding:4b")
        # Was hard-coded to localhost; now overridable (consistent with the
        # OpenAI branch) so remote Ollama hosts work. Default is unchanged.
        ollama_base_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")

        logger.info(f"Initializing Ollama embedding model: {ollama_embed_model}")

        return OllamaEmbedding(
            model_name=ollama_embed_model, base_url=ollama_base_url
        )

    elif strategy == "openai":
        from llama_index.embeddings.openai_like import OpenAILikeEmbedding

        openai_base_url = os.getenv(
            "OPENAI_EMBEDDING_BASE_URL", "https://api.openai.com/v1"
        )
        openai_api_key = os.getenv("OPENAI_EMBEDDING_API_KEY", "dummy_key_for_template")
        openai_embed_model = os.getenv(
            "OPENAI_EMBEDDING_MODEL", "text-embedding-3-small"
        )

        # Mirror the key into the conventional variable: some client code
        # paths read OPENAI_API_KEY from the environment directly.
        os.environ["OPENAI_API_KEY"] = openai_api_key

        logger.info(f"Initializing OpenAI embedding model: {openai_embed_model}")

        return OpenAILikeEmbedding(
            model_name=openai_embed_model,
            api_base=openai_base_url,
            api_key=openai_api_key,
        )

    raise ValueError(
        f"Unsupported EMBEDDING_STRATEGY: {strategy}. Supported values are 'ollama' and 'openai'"
    )
def get_llm_model():
    """
    Get the appropriate LLM model based on the CHAT_STRATEGY environment variable.

    Supported strategies:
        - "ollama": Ollama; model from OLLAMA_CHAT_MODEL,
          server from OLLAMA_BASE_URL (defaults to the local Ollama port).
        - "openai": OpenAICompatibleLLM configured through the
          OPENAI_CHAT_* environment variables.

    Returns:
        An LLM model instance based on the selected strategy.

    Raises:
        ValueError: If CHAT_STRATEGY is not one of the supported values.
    """
    strategy = os.getenv("CHAT_STRATEGY", "ollama").lower()

    if strategy == "ollama":
        from llama_index.llms.ollama import Ollama

        ollama_chat_model = os.getenv("OLLAMA_CHAT_MODEL", "nemotron-mini:4b")
        # Was hard-coded to localhost; now overridable (consistent with the
        # OpenAI branch) so remote Ollama hosts work. Default is unchanged.
        ollama_base_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")

        logger.info(f"Initializing Ollama chat model: {ollama_chat_model}")

        return Ollama(
            model=ollama_chat_model,
            base_url=ollama_base_url,
            request_timeout=120.0,  # Increase timeout for longer responses
        )

    elif strategy == "openai":
        from helpers.openai_compatible_llm import OpenAICompatibleLLM

        openai_chat_url = os.getenv("OPENAI_CHAT_URL", "https://api.openai.com/v1")
        openai_chat_key = os.getenv("OPENAI_CHAT_KEY", "dummy_key_for_template")
        openai_chat_model = os.getenv("OPENAI_CHAT_MODEL", "gpt-3.5-turbo")
        openai_chat_temperature = float(os.getenv("OPENAI_CHAT_TEMPERATURE", "0.1"))

        # Empty/unset means "use the default" (1024) instead of int("") crashing.
        openai_chat_max_tokens_env = os.getenv("OPENAI_CHAT_MAX_TOKENS", "").strip()
        openai_chat_max_tokens = (
            int(openai_chat_max_tokens_env) if openai_chat_max_tokens_env else 1024
        )

        # Empty/unset means "no explicit reasoning effort".
        openai_reasoning_effort = (
            os.getenv("OPENAI_CHAT_REASONING_EFFORT", "").strip() or None
        )
        openai_is_fc_model = (
            os.getenv("OPENAI_CHAT_IS_FUNCTION_CALLING_MODEL", "false").lower()
            == "true"
        )

        # Mirror the key into the conventional variable: some client code
        # paths read OPENAI_API_KEY from the environment directly.
        os.environ["OPENAI_API_KEY"] = openai_chat_key

        logger.info(
            f"Initializing OpenAI-compatible chat model: {openai_chat_model} "
            f"(base={openai_chat_url}, max_tokens={openai_chat_max_tokens}, "
            f"reasoning_effort={openai_reasoning_effort}, function_calling={openai_is_fc_model})"
        )

        return OpenAICompatibleLLM(
            model=openai_chat_model,
            api_base=openai_chat_url,
            api_key=openai_chat_key,
            temperature=openai_chat_temperature,
            max_tokens=openai_chat_max_tokens,
            reasoning_effort=openai_reasoning_effort,
            timeout=120.0,
            is_function_calling_model=openai_is_fc_model,
        )

    raise ValueError(
        f"Unsupported CHAT_STRATEGY: {strategy}. Supported values are 'ollama' and 'openai'"
    )
def get_model_configurations():
    """
    Resolve both model configurations from environment variables.

    Returns:
        A tuple of (embedding_model, llm_model).
    """
    # Delegate to the per-model factories; each reads its own env settings.
    return get_embedding_model(), get_llm_model()
def setup_global_models():
    """
    Set up the global models in LlamaIndex Settings to prevent defaulting to OpenAI.
    """
    from llama_index.core import Settings

    embedding, chat_llm = get_model_configurations()

    # Register both globally so every index/query engine built afterwards
    # uses the strategy-selected models instead of llama-index defaults.
    Settings.embed_model = embedding
    Settings.llm = chat_llm

    logger.info("Global models configured successfully based on environment variables")