"""
|
|
Retrieval module for the RAG solution using LlamaIndex and Qdrant.
|
|
|
|
This module provides functionality to retrieve relevant documents
|
|
from the vector storage based on a query text.
|
|
"""

from pathlib import Path
from typing import Any, Dict, List

from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever
from loguru import logger

from vector_storage import get_vector_store_and_index

# Import the configuration module that registers the local models globally
from config import setup_global_models
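
# Typical usage, as a minimal sketch (assumes this file is importable as
# `retrieval` and that the Qdrant collection has already been populated):
#
#     from retrieval import retrieve_documents
#     results = retrieve_documents("What is this document about?", top_k=3)
#     for r in results:
#         print(r["score"], r["metadata"].get("filename"), r["content"][:80])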


def initialize_retriever(
    collection_name: str = "documents_llamaindex",
    similarity_top_k: int = 5,
    host: str = "localhost",
    port: int = 6333
) -> RetrieverQueryEngine:
    """
    Initialize the retriever query engine with the vector store.

    Note: collection_name, host and port are accepted for future use but are
    not forwarded yet; the store and index currently come from
    get_vector_store_and_index() and its own configuration.

    Args:
        collection_name: Name of the Qdrant collection
        similarity_top_k: Number of top similar documents to retrieve
        host: Qdrant host address
        port: Qdrant REST API port

    Returns:
        RetrieverQueryEngine configured with the vector store
    """
    logger.info(f"Initializing retriever for collection: {collection_name}")

    try:
        # Set up the global models to prevent defaulting to OpenAI
        setup_global_models()

        # Get the vector store and index from the existing configuration
        vector_store, index = get_vector_store_and_index()

        # Create a retriever from the index
        retriever = VectorIndexRetriever(
            index=index,
            similarity_top_k=similarity_top_k
        )

        # Create the query engine; with no synthesizer given, it uses the
        # globally configured LLM to synthesize answers from retrieved nodes
        query_engine = RetrieverQueryEngine(
            retriever=retriever
        )

        logger.info("Retriever initialized successfully")
        return query_engine

    except Exception as e:
        logger.error(f"Failed to initialize retriever: {str(e)}")
        raise
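

# If only the raw chunks are needed (no LLM-synthesized answer), the retriever
# can be used on its own. This is a minimal sketch, not part of the original
# module API, and the helper name is hypothetical: BaseRetriever.retrieve()
# returns NodeWithScore objects and makes no LLM call, unlike
# RetrieverQueryEngine.query() above.
def _retrieve_nodes_sketch(query: str, top_k: int = 5):
    setup_global_models()  # the embedding model is still needed to encode the query
    _, index = get_vector_store_and_index()
    retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
    # Each returned item exposes .text, .metadata and .score
    return retriever.retrieve(query)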


def retrieve_documents(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
    """
    Retrieve documents from the vector storage based on the query text.

    Args:
        query: The query text to search for
        top_k: Number of top similar documents to retrieve

    Returns:
        List of dictionaries containing document content and metadata
    """
    logger.info(f"Retrieving documents for query: '{query[:50]}...' (top_k={top_k})")

    try:
        # Initialize the query engine
        query_engine = initialize_retriever(similarity_top_k=top_k)

        # Perform the query
        response = query_engine.query(query)

        # Extract documents and their metadata
        results = []

        # The synthesized response carries the retrieved chunks as source nodes
        if hasattr(response, 'source_nodes'):
            for node in response.source_nodes:
                doc_info = {
                    "content": node.text,
                    "metadata": node.metadata,
                    "score": getattr(node, 'score', None)
                }
                results.append(doc_info)
        else:
            # Fall back to the response text itself if no source nodes exist
            results.append({
                "content": str(response),
                "metadata": {},
                "score": None
            })

        logger.info(f"Retrieved {len(results)} documents for query: '{query[:30]}...'")
        return results

    except Exception as e:
        logger.error(f"Error retrieving documents for query '{query[:30]}...': {str(e)}")
        raise
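

# Shape of a single returned item (values are purely illustrative):
#
#     {"content": "...chunk text...",
#      "metadata": {"filename": "report.pdf", "page_label": "3"},
#      "score": 0.83}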


def retrieve_documents_with_query_engine(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
    """
    Alternative method to retrieve documents using a direct query engine approach.

    Args:
        query: The query text to search for
        top_k: Number of top similar documents to retrieve

    Returns:
        List of dictionaries containing document content and metadata
    """
    logger.info(f"Retrieving documents with direct query engine for query: '{query[:50]}...' (top_k={top_k})")

    try:
        # Set up the global models to prevent defaulting to OpenAI
        setup_global_models()

        # Get the vector store and index from the existing configuration
        vector_store, index = get_vector_store_and_index()

        # Create a retriever from the index
        retriever = VectorIndexRetriever(
            index=index,
            similarity_top_k=top_k
        )

        # Create the query engine
        query_engine = RetrieverQueryEngine(
            retriever=retriever
        )

        # Set the global models again right before the query to ensure they're used
        setup_global_models()

        # Perform the query
        response = query_engine.query(query)

        # Extract documents and their metadata
        results = []

        # Process source nodes to extract content and metadata
        if hasattr(response, 'source_nodes'):
            for node in response.source_nodes:
                # Get all available metadata from the node
                node_metadata = node.metadata or {}

                # The actual text content is in node.text
                content = node.text or ""

                # Ensure proper encoding for content
                if isinstance(content, bytes):
                    content = content.decode('utf-8', errors='replace')
                elif not isinstance(content, str):
                    content = str(content)

                # Apply the encoding fix to clean up any garbled characters
                content = _ensure_proper_encoding(content)

                # Known metadata keys that are normalized explicitly below
                known_keys = ["filename", "file_path", "page_label", "page",
                              "section", "paragraph", "chunk_number",
                              "total_chunks", "file_type", "original_doc_id",
                              "slide_id", "sheet_name", "processed_at"]

                # Create a comprehensive metadata dictionary with proper encoding
                doc_info = {
                    "content": content,
                    "metadata": {
                        "filename": _ensure_proper_encoding(node_metadata.get("filename", "unknown")),
                        "file_path": _ensure_proper_encoding(node_metadata.get("file_path", "unknown")),
                        "page_label": _ensure_proper_encoding(node_metadata.get("page_label",
                                                                                node_metadata.get("page", "unknown"))),
                        "section": _ensure_proper_encoding(node_metadata.get("section", "unknown")),
                        "paragraph": _ensure_proper_encoding(node_metadata.get("paragraph", "unknown")),
                        "chunk_number": _ensure_proper_encoding(node_metadata.get("chunk_number", "unknown")),
                        "total_chunks": _ensure_proper_encoding(node_metadata.get("total_chunks", "unknown")),
                        "file_type": _ensure_proper_encoding(node_metadata.get("file_type", "unknown")),
                        "original_doc_id": _ensure_proper_encoding(node_metadata.get("original_doc_id", "unknown")),
                        "slide_id": _ensure_proper_encoding(node_metadata.get("slide_id", "unknown")),
                        "sheet_name": _ensure_proper_encoding(node_metadata.get("sheet_name", "unknown")),
                        "processed_at": _ensure_proper_encoding(node_metadata.get("processed_at", "unknown")),
                        # Include any additional metadata that might be present
                        **{_ensure_proper_encoding(k): _ensure_proper_encoding(v)
                           for k, v in node_metadata.items() if k not in known_keys}
                    },
                    "score": getattr(node, 'score', None)
                }
                results.append(doc_info)
        else:
            # Fallback if no source nodes are available
            results.append({
                "content": str(response),
                "metadata": {},
                "score": None
            })

        logger.info(f"Retrieved {len(results)} documents for query: '{query[:30]}...'")
        return results

    except Exception as e:
        logger.error(f"Error retrieving documents for query '{query[:30]}...': {str(e)}")
        raise


def _ensure_proper_encoding(text: Any) -> str:
    """
    Helper function to ensure proper encoding of text, especially for
    non-ASCII characters such as Cyrillic.

    Args:
        text: Text that may need encoding correction

    Returns:
        Properly encoded text string
    """
    if text is None:
        return "unknown"

    if isinstance(text, bytes):
        # Decode bytes to a string, trying UTF-8 first
        try:
            return text.decode('utf-8')
        except UnicodeDecodeError:
            # If UTF-8 fails, try encodings commonly used for Russian/Cyrillic text
            try:
                return text.decode('cp1251')  # Windows Cyrillic encoding
            except UnicodeDecodeError:
                try:
                    return text.decode('koi8-r')  # Russian encoding
                except UnicodeDecodeError:
                    # If all else fails, decode with errors='replace'
                    return text.decode('utf-8', errors='replace')
    elif isinstance(text, str):
        # Ensure the string round-trips as valid UTF-8
        try:
            return text.encode('utf-8').decode('utf-8')
        except UnicodeEncodeError:
            # Replace unencodable characters (e.g. lone surrogates)
            return text.encode('utf-8', errors='replace').decode('utf-8', errors='replace')
    else:
        # Convert other types to string and ensure proper encoding
        text_str = str(text)
        try:
            return text_str.encode('utf-8').decode('utf-8')
        except UnicodeEncodeError:
            return text_str.encode('utf-8', errors='replace').decode('utf-8', errors='replace')
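

# A quick illustration of the cp1251 fallback (the bytes are hypothetical):
# b'\xcf\xf0\xe8\xe2\xe5\xf2' is not valid UTF-8, so the helper falls through
# to cp1251 and returns 'Привет'.
#
#     assert _ensure_proper_encoding(b'\xcf\xf0\xe8\xe2\xe5\xf2') == 'Привет'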
if __name__ == "__main__":
|
|
# Example usage
|
|
from loguru import logger
|
|
import sys
|
|
|
|
# Create logs directory if it doesn't exist
|
|
logs_dir = Path("logs")
|
|
logs_dir.mkdir(exist_ok=True)
|
|
|
|
# Remove default logger to customize it
|
|
logger.remove()
|
|
|
|
# Add file handler with rotation
|
|
logger.add(
|
|
"logs/dev.log",
|
|
rotation="10 MB",
|
|
retention="10 days",
|
|
level="INFO",
|
|
format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {file}:{line} | {message}"
|
|
)
|
|
|
|
# Add stdout handler
|
|
logger.add(
|
|
sys.stdout,
|
|
level="INFO",
|
|
format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
|
|
colorize=True
|
|
)
|
|
|
|
logger.info("Testing retrieval functionality...")
|
|
|
|
try:
|
|
# Test query
|
|
test_query = "What is this document about?"
|
|
results = retrieve_documents_with_query_engine(test_query, top_k=3)
|
|
|
|
print(f"Found {len(results)} results for query: '{test_query}'")
|
|
for i, result in enumerate(results):
|
|
print(f"\nResult {i+1}:")
|
|
print(f"Content preview: {result['content'][:200]}...")
|
|
print(f"Metadata: {result['metadata']}")
|
|
print(f"Score: {result['score']}")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in test run: {e}")
|
|
print(f"Error: {e}") |