# rag-solution/services/rag/llamaindex/retrieval.py
"""
Retrieval module for the RAG solution using LlamaIndex and Qdrant.
This module provides functionality to retrieve relevant documents
from the vector storage based on a query text.
"""
import os
from pathlib import Path
from typing import List, Dict, Any

from llama_index.core import VectorStoreIndex, Settings
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.vector_stores.qdrant import QdrantVectorStore
from loguru import logger

from vector_storage import get_vector_store_and_index
# Import the new configuration module
from config import setup_global_models
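# Note: setup_global_models() is assumed to configure LlamaIndex's global
# Settings (e.g. Settings.llm and Settings.embed_model) with non-OpenAI models;
# the actual behaviour lives in config.py and is not shown in this module.
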
def initialize_retriever(
    collection_name: str = "documents_llamaindex",
    similarity_top_k: int = 5,
    host: str = "localhost",
    port: int = 6333
) -> RetrieverQueryEngine:
    """
    Initialize the retriever query engine with the vector store.

    Args:
        collection_name: Name of the Qdrant collection
        similarity_top_k: Number of top similar documents to retrieve
        host: Qdrant host address
        port: Qdrant REST API port

    Note:
        collection_name, host and port are currently not forwarded to
        get_vector_store_and_index(), which reads its own configuration;
        only similarity_top_k affects the returned engine.

    Returns:
        RetrieverQueryEngine configured with the vector store
    """
    logger.info(f"Initializing retriever for collection: {collection_name}")
    try:
        # Set up the global models to prevent defaulting to OpenAI
        setup_global_models()

        # Get the vector store and index from the existing configuration
        vector_store, index = get_vector_store_and_index()

        # Create a retriever from the index
        retriever = VectorIndexRetriever(
            index=index,
            similarity_top_k=similarity_top_k
        )

        # Create the query engine
        query_engine = RetrieverQueryEngine(retriever=retriever)

        logger.info("Retriever initialized successfully")
        return query_engine
    except Exception as e:
        logger.error(f"Failed to initialize retriever: {str(e)}")
        raise
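

# The engine returned above can also be queried directly (sketch); this yields
# a synthesized Response object rather than the plain dictionaries produced by
# retrieve_documents() below:
#
#     engine = initialize_retriever(similarity_top_k=3)
#     response = engine.query("What is this document about?")
#     print(response)                     # synthesized answer text
#     print(len(response.source_nodes))   # number of retrieved source chunks
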
def retrieve_documents(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
    """
    Retrieve documents from the vector storage based on the query text.

    Args:
        query: The query text to search for
        top_k: Number of top similar documents to retrieve

    Returns:
        List of dictionaries with "content", "metadata" and "score" keys
    """
    logger.info(f"Retrieving documents for query: '{query[:50]}...' (top_k={top_k})")
    try:
        # Initialize the query engine
        query_engine = initialize_retriever(similarity_top_k=top_k)

        # Perform the query
        response = query_engine.query(query)

        # Extract documents and their metadata
        results = []

        # If the response carries source nodes, collect each node's text, metadata and score
        if hasattr(response, 'source_nodes'):
            for node in response.source_nodes:
                doc_info = {
                    "content": node.text,
                    "metadata": node.metadata,
                    "score": getattr(node, 'score', None)
                }
                results.append(doc_info)
        else:
            # If the response doesn't expose source nodes, fall back to its text content
            results.append({
                "content": str(response),
                "metadata": {},
                "score": None
            })

        logger.info(f"Retrieved {len(results)} documents for query: '{query[:30]}...'")
        return results
    except Exception as e:
        logger.error(f"Error retrieving documents for query '{query[:30]}...': {str(e)}")
        raise
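

# Each element returned by retrieve_documents() has the following shape
# (values below are purely illustrative):
#
#     {
#         "content": "Text of the matching chunk...",
#         "metadata": {"filename": "report.pdf", "page_label": "3", ...},
#         "score": 0.83,
#     }
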
def retrieve_documents_with_query_engine(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
    """
    Alternative method to retrieve documents using a direct query engine approach.

    Args:
        query: The query text to search for
        top_k: Number of top similar documents to retrieve

    Returns:
        List of dictionaries with "content", "metadata" and "score" keys
    """
    logger.info(f"Retrieving documents with direct query engine for query: '{query[:50]}...' (top_k={top_k})")
    try:
        # Set up the global models to prevent defaulting to OpenAI
        setup_global_models()

        # Get the vector store and index from the existing configuration
        vector_store, index = get_vector_store_and_index()

        # Create a retriever from the index
        retriever = VectorIndexRetriever(
            index=index,
            similarity_top_k=top_k
        )

        # Create the query engine
        query_engine = RetrieverQueryEngine(retriever=retriever)

        # Set the global models again right before the query to ensure they're used
        setup_global_models()

        # Perform the query
        response = query_engine.query(query)

        # Extract documents and their metadata
        results = []

        # Process source nodes to extract content and metadata
        if hasattr(response, 'source_nodes'):
            for node in response.source_nodes:
                # Get all available metadata from the node
                node_metadata = node.metadata or {}

                # The actual text content is in node.text
                content = node.text or ""

                # Ensure proper encoding for content
                if isinstance(content, bytes):
                    content = content.decode('utf-8', errors='replace')
                elif not isinstance(content, str):
                    content = str(content)

                # Apply the encoding fix to clean up any garbled characters
                content = _ensure_proper_encoding(content)

                # Metadata keys that are mapped explicitly below
                known_keys = [
                    "filename", "file_path", "page_label", "page",
                    "section", "paragraph", "chunk_number",
                    "total_chunks", "file_type", "original_doc_id",
                    "slide_id", "sheet_name", "processed_at"
                ]

                # Create a comprehensive metadata dictionary with proper encoding
                doc_info = {
                    "content": content,
                    "metadata": {
                        "filename": _ensure_proper_encoding(node_metadata.get("filename", "unknown")),
                        "file_path": _ensure_proper_encoding(node_metadata.get("file_path", "unknown")),
                        "page_label": _ensure_proper_encoding(node_metadata.get("page_label",
                                                                                node_metadata.get("page", "unknown"))),
                        "section": _ensure_proper_encoding(node_metadata.get("section", "unknown")),
                        "paragraph": _ensure_proper_encoding(node_metadata.get("paragraph", "unknown")),
                        "chunk_number": _ensure_proper_encoding(node_metadata.get("chunk_number", "unknown")),
                        "total_chunks": _ensure_proper_encoding(node_metadata.get("total_chunks", "unknown")),
                        "file_type": _ensure_proper_encoding(node_metadata.get("file_type", "unknown")),
                        "original_doc_id": _ensure_proper_encoding(node_metadata.get("original_doc_id", "unknown")),
                        "slide_id": _ensure_proper_encoding(node_metadata.get("slide_id", "unknown")),
                        "sheet_name": _ensure_proper_encoding(node_metadata.get("sheet_name", "unknown")),
                        "processed_at": _ensure_proper_encoding(node_metadata.get("processed_at", "unknown")),
                        # Include any additional metadata that might be present
                        **{_ensure_proper_encoding(k): _ensure_proper_encoding(v)
                           for k, v in node_metadata.items() if k not in known_keys}
                    },
                    "score": getattr(node, 'score', None)
                }
                results.append(doc_info)
        else:
            # Fallback if no source nodes are available
            content = str(response)
            if isinstance(content, bytes):
                content = content.decode('utf-8', errors='replace')
            results.append({
                "content": content,
                "metadata": {},
                "score": None
            })

        logger.info(f"Retrieved {len(results)} documents for query: '{query[:30]}...'")
        return results
    except Exception as e:
        logger.error(f"Error retrieving documents for query '{query[:30]}...': {str(e)}")
        raise
def _ensure_proper_encoding(text):
    """
    Helper function to ensure proper encoding of text, especially for non-ASCII characters like Cyrillic.

    Args:
        text: Text that may need encoding correction

    Returns:
        Properly encoded text string
    """
    if text is None:
        return "unknown"

    if isinstance(text, bytes):
        # Decode bytes to string, trying UTF-8 first
        try:
            return text.decode('utf-8')
        except UnicodeDecodeError:
            # If UTF-8 fails, try other encodings commonly used for Russian/Cyrillic text
            try:
                return text.decode('cp1251')  # Windows Cyrillic encoding
            except UnicodeDecodeError:
                try:
                    return text.decode('koi8-r')  # Russian encoding
                except UnicodeDecodeError:
                    # If all else fails, decode with errors='replace'
                    return text.decode('utf-8', errors='replace')
    elif isinstance(text, str):
        # Ensure the string round-trips as valid UTF-8
        try:
            return text.encode('utf-8').decode('utf-8')
        except UnicodeEncodeError:
            # If there are encoding issues, replace the offending characters
            return text.encode('utf-8', errors='replace').decode('utf-8', errors='replace')
    else:
        # Convert other types to string and apply the same UTF-8 round-trip
        text_str = str(text)
        try:
            return text_str.encode('utf-8').decode('utf-8')
        except UnicodeEncodeError:
            return text_str.encode('utf-8', errors='replace').decode('utf-8', errors='replace')
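

# For example, bytes coming from a Windows-1251 (cp1251) source are not valid
# UTF-8 and fall through to the cp1251 branch above (illustrative):
#
#     _ensure_proper_encoding("Привет".encode("cp1251"))  # -> "Привет"
#     _ensure_proper_encoding(None)                        # -> "unknown"
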
if __name__ == "__main__":
# Example usage
from loguru import logger
import sys
# Create logs directory if it doesn't exist
logs_dir = Path("logs")
logs_dir.mkdir(exist_ok=True)
# Remove default logger to customize it
logger.remove()
# Add file handler with rotation
logger.add(
"logs/dev.log",
rotation="10 MB",
retention="10 days",
level="INFO",
format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {file}:{line} | {message}"
)
# Add stdout handler
logger.add(
sys.stdout,
level="INFO",
format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
colorize=True
)
logger.info("Testing retrieval functionality...")
try:
# Test query
test_query = "What is this document about?"
results = retrieve_documents_with_query_engine(test_query, top_k=3)
print(f"Found {len(results)} results for query: '{test_query}'")
for i, result in enumerate(results):
print(f"\nResult {i+1}:")
print(f"Content preview: {result['content'][:200]}...")
print(f"Metadata: {result['metadata']}")
print(f"Score: {result['score']}")
except Exception as e:
logger.error(f"Error in test run: {e}")
print(f"Error: {e}")