"""
Retrieval module for the RAG solution using LlamaIndex and Qdrant.
This module provides functionality to retrieve relevant documents
from the vector storage based on a query text.
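
Example (assumes a running Qdrant instance whose collection has already been
populated by the ingestion pipeline, and that this file is importable as
retrieval):

    from retrieval import retrieve_documents

    results = retrieve_documents("What is this document about?", top_k=3)
    for r in results:
        print(r["score"], r["content"][:80])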
"""
from typing import List, Dict, Any
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever
from loguru import logger
from pathlib import Path
from vector_storage import get_vector_store_and_index
# Import the new configuration module
from config import setup_global_models
def initialize_retriever(
collection_name: str = "documents_llamaindex",
similarity_top_k: int = 5,
host: str = "localhost",
port: int = 6333
) -> RetrieverQueryEngine:
"""
Initialize the retriever query engine with the vector store.
Args:
collection_name: Name of the Qdrant collection
similarity_top_k: Number of top similar documents to retrieve
host: Qdrant host address
port: Qdrant REST API port
Returns:
RetrieverQueryEngine configured with the vector store
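
    Example (a sketch; assumes Qdrant is reachable and the collection exists):

        engine = initialize_retriever(similarity_top_k=3)
        response = engine.query("What is this document about?")
        print(response)  # synthesized answer; raw chunks are in response.source_nodes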
"""
logger.info(f"Initializing retriever for collection: {collection_name}")
try:
# Set up the global models to prevent defaulting to OpenAI
setup_global_models()
# Get the vector store and index from the existing configuration
vector_store, index = get_vector_store_and_index()
# Create a retriever from the index
retriever = VectorIndexRetriever(
index=index,
similarity_top_k=similarity_top_k
)
# Create the query engine
query_engine = RetrieverQueryEngine(
retriever=retriever
)
logger.info("Retriever initialized successfully")
return query_engine
except Exception as e:
logger.error(f"Failed to initialize retriever: {str(e)}")
raise
def retrieve_documents(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
"""
Retrieve documents from the vector storage based on the query text.
Args:
query: The query text to search for
top_k: Number of top similar documents to retrieve
Returns:
List of dictionaries containing document content and metadata
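
    Example of the returned shape (values are illustrative):

        [
            {
                "content": "chunk text ...",
                "metadata": {"filename": "report.pdf"},
                "score": 0.83,
            },
        ]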
"""
logger.info(f"Retrieving documents for query: '{query[:50]}...' (top_k={top_k})")
try:
        # Build a fresh query engine for each call; consider caching the
        # engine if issuing many queries in a row
query_engine = initialize_retriever(similarity_top_k=top_k)
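        # Note: RetrieverQueryEngine.query() does not just retrieve the top-k
        # chunks; it also synthesizes an answer with the globally configured
        # LLM, while the raw retrieved chunks arrive in response.source_nodes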
# Perform the query
response = query_engine.query(query)
# Extract documents and their metadata
results = []
        # A query engine response normally exposes the retrieved chunks via source_nodes
if hasattr(response, 'source_nodes'):
for node in response.source_nodes:
doc_info = {
"content": node.text,
"metadata": node.metadata,
"score": node.score if hasattr(node, 'score') else None
}
results.append(doc_info)
else:
# If the response doesn't have source nodes, try to extract text content
results.append({
"content": str(response),
"metadata": {},
"score": None
})
logger.info(f"Retrieved {len(results)} documents for query: '{query[:30]}...'")
return results
except Exception as e:
logger.error(f"Error retrieving documents for query '{query[:30]}...': {str(e)}")
raise
def retrieve_documents_with_query_engine(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
"""
Alternative method to retrieve documents using a direct query engine approach.
Args:
query: The query text to search for
top_k: Number of top similar documents to retrieve
Returns:
List of dictionaries containing document content and metadata
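
    Unlike retrieve_documents(), this variant normalizes the metadata into a
    fixed set of keys (filename, file_path, page_label, ...) and repairs
    mis-encoded text via _ensure_proper_encoding().

    Example (illustrative):

        results = retrieve_documents_with_query_engine("budget summary", top_k=3)
        print(results[0]["metadata"]["filename"])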
"""
logger.info(f"Retrieving documents with direct query engine for query: '{query[:50]}...' (top_k={top_k})")
try:
# Set up the global models to prevent defaulting to OpenAI
setup_global_models()
# Get the vector store and index from the existing configuration
vector_store, index = get_vector_store_and_index()
# Create a retriever from the index
retriever = VectorIndexRetriever(
index=index,
similarity_top_k=top_k
)
# Create the query engine
query_engine = RetrieverQueryEngine(
retriever=retriever
)
# Set the global models again right before the query to ensure they're used
setup_global_models()
# Perform the query
response = query_engine.query(query)
# Extract documents and their metadata
results = []
# Process source nodes to extract content and metadata
if hasattr(response, 'source_nodes'):
for node in response.source_nodes:
# Extract node information
# Get all available metadata from the node
node_metadata = node.metadata or {}
# The actual text content is in node.text
content = node.text or ""
# Ensure proper encoding for content
if isinstance(content, bytes):
content = content.decode('utf-8', errors='replace')
elif not isinstance(content, str):
content = str(content)
# Apply the encoding fix to clean up any garbled characters
content = _ensure_proper_encoding(content)
# Create a comprehensive metadata dictionary with proper encoding
doc_info = {
"content": content,
"metadata": {
"filename": _ensure_proper_encoding(node_metadata.get("filename", "unknown")),
"file_path": _ensure_proper_encoding(node_metadata.get("file_path", "unknown")),
"page_label": _ensure_proper_encoding(node_metadata.get("page_label",
node_metadata.get("page", "unknown"))),
"section": _ensure_proper_encoding(node_metadata.get("section", "unknown")),
"paragraph": _ensure_proper_encoding(node_metadata.get("paragraph", "unknown")),
"chunk_number": _ensure_proper_encoding(node_metadata.get("chunk_number", "unknown")),
"total_chunks": _ensure_proper_encoding(node_metadata.get("total_chunks", "unknown")),
"file_type": _ensure_proper_encoding(node_metadata.get("file_type", "unknown")),
"original_doc_id": _ensure_proper_encoding(node_metadata.get("original_doc_id", "unknown")),
"slide_id": _ensure_proper_encoding(node_metadata.get("slide_id",
node_metadata.get("slide_id", "unknown"))),
"sheet_name": _ensure_proper_encoding(node_metadata.get("sheet_name",
node_metadata.get("sheet_name", "unknown"))),
"processed_at": _ensure_proper_encoding(node_metadata.get("processed_at", "unknown")),
# Include any additional metadata that might be present
**{_ensure_proper_encoding(k): _ensure_proper_encoding(v) for k, v in node_metadata.items()
if k not in ["filename", "file_path", "page_label", "page",
"section", "paragraph", "chunk_number",
"total_chunks", "file_type", "original_doc_id",
"slide_id", "sheet_name", "processed_at"]}
},
"score": getattr(node, 'score', None)
}
results.append(doc_info)
else:
# Fallback if no source nodes are available
content = str(response)
if isinstance(content, bytes):
content = content.decode('utf-8', errors='replace')
results.append({
"content": content,
"metadata": {},
"score": None
})
logger.info(f"Retrieved {len(results)} documents for query: '{query[:30]}...'")
return results
except Exception as e:
logger.error(f"Error retrieving documents for query '{query[:30]}...': {str(e)}")
raise
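# If only the raw chunks are needed, the LLM answer-synthesis step above can
# be skipped by querying the retriever directly. A minimal sketch, assuming
# get_vector_store_and_index() returns a queryable index as above:
def retrieve_nodes_only(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
    """Retrieve raw chunks without synthesizing an answer (sketch)."""
    setup_global_models()
    _, index = get_vector_store_and_index()
    retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
    # retriever.retrieve() returns NodeWithScore objects; only the embedding
    # model is used here, no LLM call is made
    nodes = retriever.retrieve(query)
    return [
        {
            "content": _ensure_proper_encoding(n.node.get_content()),
            "metadata": n.node.metadata or {},
            "score": n.score,
        }
        for n in nodes
    ]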
def _ensure_proper_encoding(text):
"""
Helper function to ensure proper encoding of text, especially for non-ASCII characters like Cyrillic.
Args:
text: Text that may need encoding correction
Returns:
Properly encoded text string
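
    Example (illustrative):

        _ensure_proper_encoding("Привет".encode("cp1251"))  # -> "Привет"
        _ensure_proper_encoding(None)  # -> "unknown"
        _ensure_proper_encoding(42)  # -> "42"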
"""
if text is None:
return "unknown"
if isinstance(text, bytes):
# Decode bytes to string with proper encoding
try:
return text.decode('utf-8')
except UnicodeDecodeError:
# If UTF-8 fails, try other encodings commonly used for Russian/Cyrillic text
try:
return text.decode('cp1251') # Windows Cyrillic encoding
except UnicodeDecodeError:
try:
return text.decode('koi8-r') # Russian encoding
except UnicodeDecodeError:
# If all else fails, decode with errors='replace'
return text.decode('utf-8', errors='replace')
elif isinstance(text, str):
        # A str is already Unicode; round-tripping through UTF-8 only
        # catches invalid data such as lone surrogates
try:
# Try to encode and decode to ensure it's valid UTF-8
return text.encode('utf-8').decode('utf-8')
except UnicodeEncodeError:
# If there are encoding issues, try to fix them
return text.encode('utf-8', errors='replace').decode('utf-8', errors='replace')
else:
# Convert other types to string and ensure proper encoding
text_str = str(text)
try:
return text_str.encode('utf-8').decode('utf-8')
except UnicodeEncodeError:
return text_str.encode('utf-8', errors='replace').decode('utf-8', errors='replace')
if __name__ == "__main__":
# Example usage
    import sys
# Create logs directory if it doesn't exist
logs_dir = Path("logs")
logs_dir.mkdir(exist_ok=True)
# Remove default logger to customize it
logger.remove()
# Add file handler with rotation
logger.add(
"logs/dev.log",
rotation="10 MB",
retention="10 days",
level="INFO",
format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {file}:{line} | {message}"
)
# Add stdout handler
logger.add(
sys.stdout,
level="INFO",
format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
colorize=True
)
logger.info("Testing retrieval functionality...")
try:
# Test query
test_query = "What is this document about?"
results = retrieve_documents_with_query_engine(test_query, top_k=3)
print(f"Found {len(results)} results for query: '{test_query}'")
for i, result in enumerate(results):
print(f"\nResult {i+1}:")
print(f"Content preview: {result['content'][:200]}...")
print(f"Metadata: {result['metadata']}")
print(f"Score: {result['score']}")
except Exception as e:
logger.error(f"Error in test run: {e}")
print(f"Error: {e}")