Files
rag-solution/services/rag/llamaindex/config.py

145 lines
4.2 KiB
Python

"""
Configuration module for managing model strategies in the RAG solution.
This module provides functions to get appropriate model configurations
based on environment variables for both embeddings and chat models.
"""
import os
from dotenv import load_dotenv
from loguru import logger
# Load environment variables from .env file
load_dotenv()
def get_embedding_model():
    """
    Get the appropriate embedding model based on the EMBEDDING_STRATEGY env var.

    Supported strategies:
        "ollama" (default): OllamaEmbedding against an Ollama server.
        "openai": OpenAILikeEmbedding against an OpenAI-compatible endpoint.

    Returns:
        An embedding model instance for the selected strategy.

    Raises:
        ValueError: if EMBEDDING_STRATEGY is not 'ollama' or 'openai'.
    """
    strategy = os.getenv("EMBEDDING_STRATEGY", "ollama").lower()
    if strategy == "ollama":
        # Imported lazily so the 'openai' strategy doesn't require the ollama package.
        from llama_index.embeddings.ollama import OllamaEmbedding

        ollama_embed_model = os.getenv("OLLAMA_EMBEDDING_MODEL", "qwen3-embedding:4b")
        # Fix: base URL was hard-coded; read it from the environment like every
        # other setting, keeping the previous value as the default.
        ollama_base_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
        logger.info(f"Initializing Ollama embedding model: {ollama_embed_model}")
        return OllamaEmbedding(
            model_name=ollama_embed_model, base_url=ollama_base_url
        )
    elif strategy == "openai":
        from llama_index.embeddings.openai_like import OpenAILikeEmbedding

        openai_base_url = os.getenv(
            "OPENAI_EMBEDDING_BASE_URL", "https://api.openai.com/v1"
        )
        openai_api_key = os.getenv("OPENAI_EMBEDDING_API_KEY", "dummy_key_for_template")
        openai_embed_model = os.getenv(
            "OPENAI_EMBEDDING_MODEL", "text-embedding-3-small"
        )
        # Mirror the key into OPENAI_API_KEY: some llama_index code paths read
        # it from the environment even when api_key is passed explicitly.
        os.environ["OPENAI_API_KEY"] = openai_api_key
        logger.info(f"Initializing OpenAI embedding model: {openai_embed_model}")
        return OpenAILikeEmbedding(
            model_name=openai_embed_model,
            api_base=openai_base_url,
            api_key=openai_api_key,
        )
    else:
        raise ValueError(
            f"Unsupported EMBEDDING_STRATEGY: {strategy}. Supported values are 'ollama' and 'openai'"
        )
def get_llm_model():
    """
    Get the appropriate LLM model based on the CHAT_STRATEGY env var.

    Supported strategies:
        "ollama" (default): Ollama chat model against an Ollama server.
        "openai": OpenAI-compatible chat model.

    Returns:
        An LLM model instance for the selected strategy.

    Raises:
        ValueError: if CHAT_STRATEGY is not 'ollama' or 'openai'.
    """
    strategy = os.getenv("CHAT_STRATEGY", "ollama").lower()
    if strategy == "ollama":
        # Imported lazily so the 'openai' strategy doesn't require the ollama package.
        from llama_index.llms.ollama import Ollama

        ollama_chat_model = os.getenv("OLLAMA_CHAT_MODEL", "nemotron-mini:4b")
        # Fix: base URL was hard-coded; read it from the environment like every
        # other setting, keeping the previous value as the default.
        ollama_base_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
        logger.info(f"Initializing Ollama chat model: {ollama_chat_model}")
        return Ollama(
            model=ollama_chat_model,
            base_url=ollama_base_url,
            request_timeout=120.0,  # Increase timeout for longer responses
        )
    elif strategy == "openai":
        from llama_index.llms.openai import OpenAI

        openai_chat_url = os.getenv("OPENAI_CHAT_URL", "https://api.openai.com/v1")
        openai_chat_key = os.getenv("OPENAI_CHAT_KEY", "dummy_key_for_template")
        openai_chat_model = os.getenv("OPENAI_CHAT_MODEL", "gpt-3.5-turbo")
        # Mirror the key into OPENAI_API_KEY: some llama_index code paths read
        # it from the environment even when api_key is passed explicitly.
        os.environ["OPENAI_API_KEY"] = openai_chat_key
        logger.info(f"Initializing OpenAI chat model: {openai_chat_model}")
        # Pass the key explicitly for consistency with get_embedding_model.
        return OpenAI(
            model=openai_chat_model,
            api_base=openai_chat_url,
            api_key=openai_chat_key,
        )
    else:
        raise ValueError(
            f"Unsupported CHAT_STRATEGY: {strategy}. Supported values are 'ollama' and 'openai'"
        )
def get_model_configurations():
    """
    Build both model instances from the current environment configuration.

    Returns:
        tuple: (embedding_model, llm_model), selected via the
        EMBEDDING_STRATEGY and CHAT_STRATEGY environment variables.
    """
    # Tuple evaluation is left-to-right, so the embedding model is still
    # initialized before the chat model.
    return get_embedding_model(), get_llm_model()
def setup_global_models():
    """
    Register the configured models as LlamaIndex globals.

    Assigning Settings.embed_model and Settings.llm up front prevents
    llama_index from silently defaulting to its built-in OpenAI models.
    """
    from llama_index.core import Settings

    embedding, llm = get_model_configurations()
    Settings.embed_model = embedding  # global embedding model
    Settings.llm = llm  # global LLM
    logger.info("Global models configured successfully based on environment variables")