"""
Document enrichment module for the RAG solution.

This module handles loading documents from the data directory, processing
them with appropriate loaders, splitting them into chunks, and storing them
in the vector database with proper metadata.
"""

import hashlib
import json
import os
import sqlite3
from contextlib import closing
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

from llama_index.core import Document, SimpleDirectoryReader
from llama_index.core.node_parser import CodeSplitter, SentenceSplitter
from loguru import logger
from tqdm import tqdm

# Import the new configuration module
from config import get_embedding_model
from vector_storage import get_vector_store_and_index

SUPPORTED_ENRICHMENT_EXTENSIONS = {
    ".csv",
    ".doc",
    ".docx",
    ".epub",
    ".htm",
    ".html",
    ".json",
    ".jsonl",
    ".md",
    ".odt",
    ".pdf",
    ".ppt",
    ".pptx",
    ".rtf",
    ".rst",
    ".tsv",
    ".txt",
    ".xls",
    ".xlsx",
    ".xml",
}

# Tree-sitter grammar name to hand to CodeSplitter for each source-code
# extension. NOTE(review): grammar names differ between tree-sitter bundles
# (e.g. "csharp" vs "c_sharp"); get_text_splitter falls back to a
# SentenceSplitter when a grammar cannot be loaded. ".h" is assumed to be C.
_CODE_LANGUAGE_BY_EXTENSION = {
    ".py": "python",
    ".js": "javascript",
    ".ts": "typescript",
    ".java": "java",
    ".cpp": "cpp",
    ".c": "c",
    ".h": "c",
    ".cs": "csharp",
    ".go": "go",
    ".rs": "rust",
    ".php": "php",
    ".html": "html",
    ".css": "css",
}

# (chunk_size, chunk_overlap) used by SentenceSplitter per extension.
_DEFAULT_CHUNKING = (768, 150)
_CHUNKING_BY_EXTENSION = {
    # Smaller chunks for dense PDF content
    ".pdf": (512, 100),
    # Slides typically have less text
    ".ppt": (256, 50),
    ".pptx": (256, 50),
    # Spreadsheets
    ".xls": (256, 50),
    ".xlsx": (256, 50),
    # Text-heavy word-processor documents
    ".doc": (768, 150),
    ".docx": (768, 150),
    ".odt": (768, 150),
    ".rtf": (768, 150),
    # Plain text tolerates larger chunks
    ".txt": (1024, 200),
    # Images are expected to be handled by multimodal models; any text that
    # does reach the splitter gets a simple mid-sized chunking
    ".png": (512, 100),
    ".jpg": (512, 100),
    ".jpeg": (512, 100),
    ".gif": (512, 100),
    ".bmp": (512, 100),
    ".svg": (512, 100),
}

# extension -> (metadata key that is normalised, value stored as "file_type")
_TYPE_METADATA_BY_EXTENSION = {
    ".pdf": ("page_label", "pdf"),
    ".docx": ("section", "document"),
    ".odt": ("section", "document"),
    ".doc": ("section", "document"),
    ".rtf": ("section", "document"),
    ".pptx": ("slide_id", "presentation"),
    ".ppt": ("slide_id", "presentation"),
    ".xlsx": ("sheet_name", "spreadsheet"),
    ".xls": ("sheet_name", "spreadsheet"),
    ".csv": ("sheet_name", "spreadsheet"),
    ".tsv": ("sheet_name", "spreadsheet"),
}


def get_supported_enrichment_extensions() -> set[str]:
    """Return a copy of the file extensions currently supported by enrichment."""
    return set(SUPPORTED_ENRICHMENT_EXTENSIONS)


class DocumentTracker:
    """Track processed documents in SQLite to avoid re-processing them.

    A document counts as processed when a row with the same ``filepath`` and
    the same content checksum exists, so edited files are picked up again.
    """

    def __init__(self, db_path: str = "document_tracking.db"):
        """Remember the database location and make sure the table exists.

        Args:
            db_path: Path of the SQLite database file.
        """
        self.db_path = db_path
        self._init_db()

    def _init_db(self):
        """Create the tracking table if it does not exist yet."""
        # Uniqueness is on ``filepath``: two files that share a basename but
        # live in different folders must each keep their own record.
        # NOTE: databases created by earlier versions keep their original
        # ``filename UNIQUE`` constraint; CREATE TABLE IF NOT EXISTS does not
        # alter an existing table.
        with closing(sqlite3.connect(self.db_path)) as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS processed_documents (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    filename TEXT NOT NULL,
                    filepath TEXT UNIQUE NOT NULL,
                    checksum TEXT NOT NULL,
                    processed_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    metadata_json TEXT
                )
            """)
            conn.commit()
        logger.info(f"Document tracker initialized with database: {self.db_path}")

    def is_document_processed(self, filepath: str) -> bool:
        """Check if a document has already been processed.

        Args:
            filepath: Path of the file to look up.

        Returns:
            True when a record with this path and the file's current checksum
            exists.

        Raises:
            OSError: If the file cannot be read to compute its checksum.
        """
        # Hash first so an unreadable file never leaves a connection open.
        checksum = self._calculate_checksum(filepath)
        with closing(sqlite3.connect(self.db_path)) as conn:
            row = conn.execute(
                "SELECT COUNT(*) FROM processed_documents WHERE filepath = ? AND checksum = ?",
                (filepath, checksum),
            ).fetchone()
        return row[0] > 0

    def mark_document_processed(
        self, filepath: str, metadata: Optional[Dict[str, Any]] = None
    ):
        """Mark a document as processed in the database.

        SQLite errors while writing are logged, not raised.

        Args:
            filepath: Path of the processed file.
            metadata: Optional details stored as JSON in ``metadata_json``.

        Raises:
            OSError: If the file cannot be read to compute its checksum.
        """
        checksum = self._calculate_checksum(filepath)
        filename = Path(filepath).name
        # Store real JSON (the column is ``metadata_json``); ``default=str``
        # keeps this call from failing on values json cannot encode natively.
        metadata_json = (
            json.dumps(metadata, ensure_ascii=False, default=str)
            if metadata
            else None
        )

        with closing(sqlite3.connect(self.db_path)) as conn:
            try:
                conn.execute(
                    """
                    INSERT OR REPLACE INTO processed_documents
                    (filename, filepath, checksum, processed_at, metadata_json)
                    VALUES (?, ?, ?, CURRENT_TIMESTAMP, ?)
                    """,
                    (filename, filepath, checksum, metadata_json),
                )
                conn.commit()
                logger.info(f"Document marked as processed: {filepath}")
            except sqlite3.Error as e:
                logger.error(f"Error marking document as processed: {e}")

    def _calculate_checksum(self, filepath: str) -> str:
        """Calculate the MD5 checksum of a file, reading it in chunks.

        MD5 is used purely for change detection, not for security.
        """
        hash_md5 = hashlib.md5(usedforsecurity=False)
        with open(filepath, "rb") as f:
            # Read in chunks so large files are never loaded at once.
            for chunk in iter(lambda: f.read(65536), b""):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()


def get_text_splitter(file_extension: str):
    """Get an appropriate text splitter based on file type.

    Source-code extensions get a ``CodeSplitter`` configured with the grammar
    of that language. Everything else, including prose markup such as
    Markdown and reStructuredText, gets a ``SentenceSplitter`` whose chunk
    size depends on the document type.

    Args:
        file_extension: File suffix including the leading dot (any case).

    Returns:
        A node parser exposing ``get_nodes_from_documents``.
    """
    extension = file_extension.lower()
    chunk_size, chunk_overlap = _CHUNKING_BY_EXTENSION.get(
        extension, _DEFAULT_CHUNKING
    )

    language = _CODE_LANGUAGE_BY_EXTENSION.get(extension)
    if language is not None:
        try:
            return CodeSplitter(language=language, max_chars=1000)
        except Exception as exc:
            # A missing tree-sitter bundle or grammar should degrade chunk
            # quality, not abort enrichment of the file.
            logger.warning(
                f"CodeSplitter unavailable for language '{language}' ({exc}); "
                "falling back to SentenceSplitter"
            )

    return SentenceSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)


def ensure_proper_encoding(text):
    """
    Helper function to ensure proper encoding of text, especially for
    non-ASCII characters like Cyrillic.

    Args:
        text: Text that may need encoding correction (str, bytes, None or
            any object convertible with ``str``)

    Returns:
        Properly encoded text string; ``"unknown"`` when ``text`` is None
    """
    if text is None:
        return "unknown"

    if isinstance(text, bytes):
        # UTF-8 first, then the encodings commonly used for Russian/Cyrillic
        # text (Windows cp1251, then KOI8-R).
        for encoding in ("utf-8", "cp1251", "koi8-r"):
            try:
                return text.decode(encoding)
            except UnicodeDecodeError:
                continue
        # Last resort: keep what is decodable and replace the rest.
        return text.decode("utf-8", errors="replace")

    text_str = text if isinstance(text, str) else str(text)
    try:
        # Round-trip to verify the string is representable as UTF-8.
        return text_str.encode("utf-8").decode("utf-8")
    except UnicodeEncodeError:
        # E.g. lone surrogates: replace the unencodable characters.
        return text_str.encode("utf-8", errors="replace").decode(
            "utf-8", errors="replace"
        )


def _collect_supported_files(data_dir: Path, recursive: bool) -> List[str]:
    """Return the sorted paths of supported files found under ``data_dir``."""
    supported_extensions = get_supported_enrichment_extensions()
    found = []
    if recursive:
        for root, _dirs, files in os.walk(data_dir):
            found.extend(
                os.path.join(root, name)
                for name in files
                if Path(name).suffix.lower() in supported_extensions
            )
    else:
        found.extend(
            str(entry)
            for entry in data_dir.iterdir()
            if entry.is_file() and entry.suffix.lower() in supported_extensions
        )
    # Directory listing order is filesystem-dependent; sort for stable runs.
    return sorted(found)


def process_documents_from_data_folder(
    data_path: str = "../../../data", recursive: bool = True
):
    """
    Process all documents from the data folder using appropriate loaders and
    store them in the vector DB.

    Args:
        data_path: Path to the data folder; a relative path is resolved
            against the current working directory
        recursive: Whether to process subdirectories recursively
    """
    logger.info(f"Starting document enrichment from: {data_path}")

    # Resolve and validate the data directory before opening any storage.
    data_abs_path = Path(data_path)
    if not data_abs_path.is_absolute():
        data_abs_path = Path.cwd() / data_abs_path

    logger.info(f"Looking for documents in: {data_abs_path.absolute()}")

    if not data_abs_path.is_dir():
        logger.error(
            f"Data directory does not exist or is not a directory: {data_abs_path.absolute()}"
        )
        return

    tracker = DocumentTracker()
    _, index = get_vector_store_and_index()

    all_files = _collect_supported_files(data_abs_path, recursive)
    supported_extensions = get_supported_enrichment_extensions()
    logger.info(
        f"Found {len(all_files)} supported files to process (extensions: {', '.join(sorted(supported_extensions))})"
    )

    processed_count = 0
    skipped_count = 0
    error_count = 0
    total_files = len(all_files)

    with tqdm(total=total_files, desc="Processing documents", unit="file") as pbar:
        for position, file_path in enumerate(all_files, start=1):
            logger.info(f"Processing file: {file_path} ({position}/{total_files})")

            try:
                result = process_document_file(
                    file_path, tracker=tracker, index=index
                )
                status = result["status"]
            except Exception as e:
                # One broken file must not stop the whole batch.
                logger.error(f"Error processing file {file_path}: {str(e)}")
                status = "error"

            if status == "processed":
                processed_count += 1
            elif status == "skipped":
                skipped_count += 1
            else:
                error_count += 1

            pbar.set_postfix(
                {
                    "Processed": processed_count,
                    "Skipped": skipped_count,
                    "Errors": error_count,
                }
            )
            # Advance the bar regardless of success or failure.
            pbar.update(1)

    logger.info(
        f"Document enrichment completed. Processed: {processed_count}, Skipped: {skipped_count}, Errors: {error_count}"
    )


def _annotate_file_type(doc, file_ext: str) -> None:
    """Normalise the type-specific metadata key of ``doc`` and set ``file_type``."""
    type_info = _TYPE_METADATA_BY_EXTENSION.get(file_ext)
    if type_info is None:
        return
    metadata_key, file_type = type_info
    doc.metadata[metadata_key] = ensure_proper_encoding(
        doc.metadata.get(metadata_key, "unknown")
    )
    doc.metadata["file_type"] = file_type


def process_document_file(
    file_path: str,
    tracker: Optional[DocumentTracker] = None,
    index=None,
) -> Dict[str, Any]:
    """
    Process a single document file and store its chunks in the vector index.

    Args:
        file_path: Path of the file to ingest
        tracker: Tracker used to skip and record processed files; a default
            ``DocumentTracker`` is created when omitted
        index: Vector index to insert nodes into; obtained from
            ``get_vector_store_and_index`` when omitted

    Returns:
        A dict with status and counters.
        Status is one of: `processed`, `skipped`, `error`.
    """
    file_ext = Path(file_path).suffix.lower()
    if file_ext not in get_supported_enrichment_extensions():
        logger.info(f"Skipping unsupported extension for file: {file_path}")
        return {"status": "skipped", "reason": "unsupported_extension", "nodes": 0}

    if tracker is None:
        tracker = DocumentTracker()
    if tracker.is_document_processed(file_path):
        logger.info(f"Skipping already processed file: {file_path}")
        return {"status": "skipped", "reason": "already_processed", "nodes": 0}

    if index is None:
        _, index = get_vector_store_and_index()

    try:

        def file_metadata_func(file_path_str):
            filename = ensure_proper_encoding(Path(file_path_str).name)
            return {"filename": filename}

        reader = SimpleDirectoryReader(
            input_files=[file_path], file_metadata=file_metadata_func
        )
        documents = reader.load_data()

        # Identical for every document of this file, so computed once.
        encoded_file_path = ensure_proper_encoding(file_path)
        processed_at = datetime.now().isoformat()
        splitter = get_text_splitter(file_ext)

        # Build every node before touching the index: a failure while
        # splitting a later document must not leave earlier chunks in the
        # index without a tracker record (they would be duplicated on the
        # next run).
        # NOTE(review): chunks from a previous version of a changed file are
        # not removed here; confirm whether the vector store needs a delete.
        all_nodes = []
        for doc in documents:
            doc.metadata["file_path"] = encoded_file_path
            doc.metadata["processed_at"] = processed_at
            _annotate_file_type(doc, file_ext)

            nodes = splitter.get_nodes_from_documents([doc])
            original_doc_id = ensure_proper_encoding(doc.doc_id)
            for i, node in enumerate(nodes):
                node.metadata["original_doc_id"] = original_doc_id
                node.metadata["chunk_number"] = i
                node.metadata["total_chunks"] = len(nodes)
                node.metadata["file_path"] = encoded_file_path
                node.text = ensure_proper_encoding(node.text)

            all_nodes.extend(nodes)
            logger.info(f"Processed {len(nodes)} nodes from {encoded_file_path}")

        if all_nodes:
            index.insert_nodes(all_nodes)
        total_nodes_inserted = len(all_nodes)

        tracker.mark_document_processed(
            file_path,
            {"documents_count": len(documents), "nodes_count": total_nodes_inserted},
        )
        return {"status": "processed", "nodes": total_nodes_inserted}
    except Exception as e:
        logger.error(f"Error processing file {file_path}: {e}")
        return {"status": "error", "reason": str(e), "nodes": 0}


def enrich_documents():
    """Main function to run the document enrichment process."""
    logger.info("Starting document enrichment process")
    process_documents_from_data_folder()
    logger.info("Document enrichment process completed")


if __name__ == "__main__":
    # Example usage
    logger.info("Running document enrichment...")
    enrich_documents()