Add retrieval module and update Russian-language handling
This commit is contained in:
338
services/rag/llamaindex/retrieval.py
Normal file
@@ -0,0 +1,338 @@
"""
Retrieval module for the RAG solution using LlamaIndex and Qdrant.

This module provides functionality to retrieve relevant documents
from the vector storage based on a query text.
"""

import os
from pathlib import Path
from typing import List, Dict, Any

from llama_index.core import VectorStoreIndex, Settings
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.embeddings.ollama import OllamaEmbedding
from loguru import logger

from vector_storage import get_vector_store_and_index


def setup_global_models():
    """Set up the global models to prevent defaulting to OpenAI."""
    # Set up the embedding model
    ollama_embed_model = os.getenv("OLLAMA_EMBEDDING_MODEL", "qwen3-embedding:4b")
    ollama_base_url = "http://localhost:11434"

    embed_model = OllamaEmbedding(
        model_name=ollama_embed_model,
        base_url=ollama_base_url
    )

    # Set as the global embedding model
    Settings.embed_model = embed_model

    # Set up the LLM model
    ollama_chat_model = os.getenv("OLLAMA_CHAT_MODEL", "nemotron-mini:4b")

    from llama_index.llms.ollama import Ollama
    llm = Ollama(model=ollama_chat_model, base_url=ollama_base_url)

    # Set as the global LLM
    Settings.llm = llm
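
# Usage sketch (illustrative): call setup_global_models() once at startup so every
# LlamaIndex component below resolves to the local Ollama models rather than the
# OpenAI defaults. Assumes an Ollama server is listening on localhost:11434.
#
#     setup_global_models()
#     assert Settings.embed_model is not None and Settings.llm is not None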


def initialize_retriever(
    collection_name: str = "documents_llamaindex",
    similarity_top_k: int = 5,
    host: str = "localhost",
    port: int = 6333
) -> RetrieverQueryEngine:
    """
    Initialize the retriever query engine with the vector store.

    Args:
        collection_name: Name of the Qdrant collection
        similarity_top_k: Number of top similar documents to retrieve
        host: Qdrant host address (currently unused; the connection is taken
            from get_vector_store_and_index)
        port: Qdrant REST API port (currently unused, see above)

    Returns:
        RetrieverQueryEngine configured with the vector store
    """
    logger.info(f"Initializing retriever for collection: {collection_name}")

    try:
        # Set up the global models to prevent defaulting to OpenAI
        setup_global_models()

        # Get the vector store and index from the existing configuration
        vector_store, index = get_vector_store_and_index()

        # Create a retriever from the index
        retriever = VectorIndexRetriever(
            index=index,
            similarity_top_k=similarity_top_k
        )

        # Create the query engine
        query_engine = RetrieverQueryEngine(
            retriever=retriever
        )

        logger.info("Retriever initialized successfully")
        return query_engine

    except Exception as e:
        logger.error(f"Failed to initialize retriever: {str(e)}")
        raise
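
# Example (sketch): the engine returned by initialize_retriever can also be queried
# directly when a synthesized answer is wanted instead of raw documents.
#
#     engine = initialize_retriever(similarity_top_k=3)
#     answer = engine.query("О чём этот документ?")
#     print(answer)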


def retrieve_documents(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
    """
    Retrieve documents from the vector storage based on the query text.

    Args:
        query: The query text to search for
        top_k: Number of top similar documents to retrieve

    Returns:
        List of dictionaries containing document content and metadata
    """
    logger.info(f"Retrieving documents for query: '{query[:50]}...' (top_k={top_k})")

    try:
        # Initialize the query engine
        query_engine = initialize_retriever(similarity_top_k=top_k)

        # Perform the query
        response = query_engine.query(query)

        # Extract documents and their metadata
        results = []

        # If the response carries source nodes, report each retrieved chunk
        if hasattr(response, 'source_nodes'):
            for node in response.source_nodes:
                doc_info = {
                    "content": node.text,
                    "metadata": node.metadata,
                    "score": node.score if hasattr(node, 'score') else None
                }
                results.append(doc_info)
        else:
            # If the response doesn't have source nodes, fall back to its text content
            results.append({
                "content": str(response),
                "metadata": {},
                "score": None
            })

        logger.info(f"Retrieved {len(results)} documents for query: '{query[:30]}...'")
        return results

    except Exception as e:
        logger.error(f"Error retrieving documents for query '{query[:30]}...': {str(e)}")
        raise
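
# Shape of each item returned by retrieve_documents (illustrative):
#     {"content": "<chunk text>", "metadata": {...}, "score": <float or None>}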


def retrieve_documents_with_query_engine(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
    """
    Alternative method to retrieve documents using a direct query engine approach.

    Args:
        query: The query text to search for
        top_k: Number of top similar documents to retrieve

    Returns:
        List of dictionaries containing document content and metadata
    """
    logger.info(f"Retrieving documents with direct query engine for query: '{query[:50]}...' (top_k={top_k})")

    try:
        # Set up the global models to prevent defaulting to OpenAI
        setup_global_models()

        # Get the vector store and index from the existing configuration
        vector_store, index = get_vector_store_and_index()

        # Create a retriever from the index
        retriever = VectorIndexRetriever(
            index=index,
            similarity_top_k=top_k
        )

        # Create the query engine
        query_engine = RetrieverQueryEngine(
            retriever=retriever
        )

        # Set the global models again right before the query to ensure they're used
        setup_global_models()

        # Perform the query
        response = query_engine.query(query)

        # Extract documents and their metadata
        results = []

        # Process source nodes to extract content and metadata
        if hasattr(response, 'source_nodes'):
            for node in response.source_nodes:
                # Get all available metadata from the node
                node_metadata = node.metadata or {}

                # The actual text content is in node.text
                content = node.text or ""

                # Ensure proper encoding for content
                if isinstance(content, bytes):
                    content = content.decode('utf-8', errors='replace')
                elif not isinstance(content, str):
                    content = str(content)

                # Apply the encoding fix to clean up any garbled characters
                content = _ensure_proper_encoding(content)

                # Create a comprehensive metadata dictionary with proper encoding
                doc_info = {
                    "content": content,
                    "metadata": {
                        "filename": _ensure_proper_encoding(node_metadata.get("filename", "unknown")),
                        "file_path": _ensure_proper_encoding(node_metadata.get("file_path", "unknown")),
                        "page_label": _ensure_proper_encoding(node_metadata.get("page_label",
                                                                                node_metadata.get("page", "unknown"))),
                        "section": _ensure_proper_encoding(node_metadata.get("section", "unknown")),
                        "paragraph": _ensure_proper_encoding(node_metadata.get("paragraph", "unknown")),
                        "chunk_number": _ensure_proper_encoding(node_metadata.get("chunk_number", "unknown")),
                        "total_chunks": _ensure_proper_encoding(node_metadata.get("total_chunks", "unknown")),
                        "file_type": _ensure_proper_encoding(node_metadata.get("file_type", "unknown")),
                        "original_doc_id": _ensure_proper_encoding(node_metadata.get("original_doc_id", "unknown")),
                        "slide_id": _ensure_proper_encoding(node_metadata.get("slide_id", "unknown")),
                        "sheet_name": _ensure_proper_encoding(node_metadata.get("sheet_name", "unknown")),
                        "processed_at": _ensure_proper_encoding(node_metadata.get("processed_at", "unknown")),
                        # Include any additional metadata that might be present
                        **{_ensure_proper_encoding(k): _ensure_proper_encoding(v) for k, v in node_metadata.items()
                           if k not in ["filename", "file_path", "page_label", "page",
                                        "section", "paragraph", "chunk_number",
                                        "total_chunks", "file_type", "original_doc_id",
                                        "slide_id", "sheet_name", "processed_at"]}
                    },
                    "score": getattr(node, 'score', None)
                }
                results.append(doc_info)
        else:
            # Fallback if no source nodes are available
            content = str(response)
            if isinstance(content, bytes):
                content = content.decode('utf-8', errors='replace')
            results.append({
                "content": content,
                "metadata": {},
                "score": None
            })

        logger.info(f"Retrieved {len(results)} documents for query: '{query[:30]}...'")
        return results

    except Exception as e:
        logger.error(f"Error retrieving documents for query '{query[:30]}...': {str(e)}")
        raise


def _ensure_proper_encoding(text):
    """
    Helper function to ensure proper encoding of text, especially for non-ASCII characters such as Cyrillic.

    Args:
        text: Text that may need encoding correction

    Returns:
        Properly encoded text string
    """
    if text is None:
        return "unknown"

    if isinstance(text, bytes):
        # Decode bytes to string with proper encoding
        try:
            return text.decode('utf-8')
        except UnicodeDecodeError:
            # If UTF-8 fails, try other encodings commonly used for Russian/Cyrillic text
            try:
                return text.decode('cp1251')  # Windows Cyrillic encoding
            except UnicodeDecodeError:
                try:
                    return text.decode('koi8-r')  # Russian encoding
                except UnicodeDecodeError:
                    # If all else fails, decode with errors='replace'
                    return text.decode('utf-8', errors='replace')
    elif isinstance(text, str):
        # Ensure the string is valid UTF-8
        try:
            # Round-trip through UTF-8 to validate the string
            return text.encode('utf-8').decode('utf-8')
        except UnicodeEncodeError:
            # If there are encoding issues (e.g. lone surrogates), replace the offending characters
            return text.encode('utf-8', errors='replace').decode('utf-8', errors='replace')
    else:
        # Convert other types to string and ensure proper encoding
        text_str = str(text)
        try:
            return text_str.encode('utf-8').decode('utf-8')
        except UnicodeEncodeError:
            return text_str.encode('utf-8', errors='replace').decode('utf-8', errors='replace')
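
# Sanity check (sketch): cp1251-encoded Cyrillic bytes fail the UTF-8 attempt and are
# recovered by the cp1251 branch above.
#
#     _ensure_proper_encoding(b"\xcf\xf0\xe8\xe2\xe5\xf2")  # -> "Привет"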


if __name__ == "__main__":
    # Example usage
    import sys

    # Create logs directory if it doesn't exist
    logs_dir = Path("logs")
    logs_dir.mkdir(exist_ok=True)

    # Remove default logger to customize it
    logger.remove()

    # Add file handler with rotation
    logger.add(
        "logs/dev.log",
        rotation="10 MB",
        retention="10 days",
        level="INFO",
        format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {file}:{line} | {message}"
    )

    # Add stdout handler
    logger.add(
        sys.stdout,
        level="INFO",
        format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
        colorize=True
    )

    logger.info("Testing retrieval functionality...")

    try:
        # Test query
        test_query = "What is this document about?"
        results = retrieve_documents_with_query_engine(test_query, top_k=3)

        print(f"Found {len(results)} results for query: '{test_query}'")
        for i, result in enumerate(results):
            print(f"\nResult {i+1}:")
            print(f"Content preview: {result['content'][:200]}...")
            print(f"Metadata: {result['metadata']}")
            print(f"Score: {result['score']}")

    except Exception as e:
        logger.error(f"Error in test run: {e}")
        print(f"Error: {e}")