# rag-solution/services/rag/langchain/enrichment.py
"""Document enrichment module for loading documents into vector storage."""
import hashlib
import os
from pathlib import Path
from typing import Iterator, List, Tuple
from dotenv import load_dotenv
from langchain_community.document_loaders import PyPDFLoader
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
# Dynamically import other loaders to handle optional dependencies
try:
    from langchain_community.document_loaders import UnstructuredWordDocumentLoader
except ImportError:
    UnstructuredWordDocumentLoader = None
try:
    from langchain_community.document_loaders import UnstructuredPowerPointLoader
except ImportError:
    UnstructuredPowerPointLoader = None
try:
    from langchain_community.document_loaders import UnstructuredExcelLoader
except ImportError:
    UnstructuredExcelLoader = None
try:
    from langchain_community.document_loaders import UnstructuredImageLoader
except ImportError:
    UnstructuredImageLoader = None
try:
    from langchain_community.document_loaders import UnstructuredODTLoader
except ImportError:
    UnstructuredODTLoader = None
from loguru import logger
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from helpers import (
    LocalFilesystemAdaptiveCollection,
    YandexDiskAdaptiveCollection,
    _AdaptiveCollection,
    _AdaptiveFile,
    extract_russian_event_names,
    extract_years_from_text,
)
# Load environment variables
load_dotenv()
# Define the path to the data directory
DATA_DIR = Path("../../../data").resolve()
DB_PATH = Path("document_tracking.db").resolve()
ENRICHMENT_SOURCE = os.getenv("ENRICHMENT_SOURCE", "local").lower()
ENRICHMENT_LOCAL_PATH = os.getenv("ENRICHMENT_LOCAL_PATH")
ENRICHMENT_YADISK_PATH = os.getenv("ENRICHMENT_YADISK_PATH")
YADISK_TOKEN = os.getenv("YADISK_TOKEN")
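# Example .env (illustrative values only; the paths and token below are
# placeholders, not real configuration):
#   ENRICHMENT_SOURCE=local
#   ENRICHMENT_LOCAL_PATH=/data/enrichment
#   # or, for Yandex Disk:
#   # ENRICHMENT_SOURCE=yadisk
#   # ENRICHMENT_YADISK_PATH=/enrichment
#   # YADISK_TOKEN=<oauth-token>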
SUPPORTED_EXTENSIONS = {
    ".pdf",
    ".docx",
    ".doc",
    ".pptx",
    ".xlsx",
    ".xls",
    ".jpg",
    ".jpeg",
    ".png",
    ".gif",
    ".bmp",
    ".tiff",
    ".webp",
    ".odt",
}
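# try_guess_source maps a file extension to a human-readable (Russian) source label
# stored in chunk metadata: "таблица" = spreadsheet/table, "изображение" = image,
# "презентация" = presentation, "документ" = generic document (fallback).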
def try_guess_source(extension: str) -> str:
    if extension in [".xlsx", ".xls"]:
        return "таблица"
    elif extension in [".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp"]:
        return "изображение"
    elif extension in [".pptx"]:
        return "презентация"
    else:
        return "документ"
Base = declarative_base()
class ProcessedDocument(Base):
    """Database model for tracking processed documents."""
    __tablename__ = "processed_documents"
    id = Column(Integer, primary_key=True)
    file_path = Column(String, unique=True, nullable=False)
    file_hash = Column(String, nullable=False)
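# ProcessedDocument rows store the SHA-256 of every ingested file; the enricher
# checks this table before loading, so re-runs skip content that is already in
# the vector store even if the file has been renamed or moved.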
class DocumentEnricher:
    """Class responsible for enriching documents and loading them to vector storage."""
    def __init__(self, vector_store):
        self.vector_store = vector_store
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
        )
        # Initialize database for tracking processed documents
        self._init_db()
    def _init_db(self):
        """Initialize the SQLite database for tracking processed documents."""
        self.engine = create_engine(f"sqlite:///{DB_PATH}")
        Base.metadata.create_all(self.engine)
        Session = sessionmaker(bind=self.engine)
        self.session = Session()
    def _get_file_hash(self, file_path: str) -> str:
        """Calculate SHA256 hash of a file."""
        hash_sha256 = hashlib.sha256()
        with open(file_path, "rb") as f:
            # Read file in chunks to handle large files
            for chunk in iter(lambda: f.read(4096), b""):
                hash_sha256.update(chunk)
        return hash_sha256.hexdigest()
    def _is_document_hash_processed(self, file_hash: str) -> bool:
        """Check if a document hash has already been processed."""
        existing = (
            self.session.query(ProcessedDocument).filter_by(file_hash=file_hash).first()
        )
        return existing is not None
    def _mark_document_processed(self, file_identifier: str, file_hash: str):
        """Mark a document as processed in the database."""
        doc_record = ProcessedDocument(file_path=file_identifier, file_hash=file_hash)
        self.session.add(doc_record)
        self.session.commit()
    def _get_loader_for_extension(self, file_path: str):
        """Get the appropriate loader for a given file extension."""
        ext = Path(file_path).suffix.lower()
        if ext == ".pdf":
            return PyPDFLoader(file_path)
        elif ext in [".docx", ".doc"]:
            if UnstructuredWordDocumentLoader is None:
                logger.warning(
                    f"UnstructuredWordDocumentLoader not available for {file_path}. Skipping."
                )
                return None
            return UnstructuredWordDocumentLoader(
                file_path, **{"strategy": "hi_res", "languages": ["rus"]}
            )
        elif ext == ".pptx":
            if UnstructuredPowerPointLoader is None:
                logger.warning(
                    f"UnstructuredPowerPointLoader not available for {file_path}. Skipping."
                )
                return None
            return UnstructuredPowerPointLoader(
                file_path, **{"strategy": "hi_res", "languages": ["rus"]}
            )
        elif ext in [".xlsx", ".xls"]:
            if UnstructuredExcelLoader is None:
                logger.warning(
                    f"UnstructuredExcelLoader not available for {file_path}. Skipping."
                )
                return None
            return UnstructuredExcelLoader(
                file_path, **{"strategy": "hi_res", "languages": ["rus"]}
            )
        elif ext in [".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp"]:
            if UnstructuredImageLoader is None:
                logger.warning(
                    f"UnstructuredImageLoader not available for {file_path}. Skipping."
                )
                return None
            # Use OCR strategy for images to extract text
            return UnstructuredImageLoader(
                file_path, **{"strategy": "ocr_only", "languages": ["rus"]}
            )
        elif ext == ".odt":
            if UnstructuredODTLoader is None:
                logger.warning(
                    f"UnstructuredODTLoader not available for {file_path}. Skipping."
                )
                return None
            return UnstructuredODTLoader(
                file_path, **{"strategy": "hi_res", "languages": ["rus"]}
            )
        else:
            return None
    def _load_one_adaptive_file(
        self, adaptive_file: _AdaptiveFile
    ) -> Tuple[List[Document], str | None]:
        """Load and split one adaptive file by using its local working callback."""
        loaded_docs: List[Document] = []
        file_hash: str | None = None
        # Normalise the extension once so both the metadata and the source label
        # use the lowercase form.
        extension = adaptive_file.extension.lower()
        source_identifier = try_guess_source(extension)
        def process_local_file(local_file_path: str):
            nonlocal loaded_docs, file_hash
            file_hash = self._get_file_hash(local_file_path)
            if self._is_document_hash_processed(file_hash):
                logger.info(
                    f"Skipping already processed document hash for: {source_identifier}"
                )
                return
            loader = self._get_loader_for_extension(local_file_path)
            if loader is None:
                logger.warning(f"No loader available for file: {source_identifier}")
                return
            docs = loader.load()
            for doc in docs:
                doc.metadata["source"] = source_identifier
                doc.metadata["filename"] = adaptive_file.filename
                doc.metadata["file_path"] = source_identifier
                doc.metadata["file_size"] = os.path.getsize(local_file_path)
                doc.metadata["file_extension"] = extension
                if "page" in doc.metadata:
                    doc.metadata["page_number"] = doc.metadata["page"]
            split_docs = self.text_splitter.split_documents(docs)
            for chunk in split_docs:
                years = extract_years_from_text(chunk.page_content)
                events = extract_russian_event_names(chunk.page_content)
                chunk.metadata["years"] = years
                chunk.metadata["events"] = events
            loaded_docs = split_docs
        adaptive_file.work_with_file_locally(process_local_file)
        return loaded_docs, file_hash
    def load_and_split_documents(
        self, adaptive_collection: _AdaptiveCollection, recursive: bool = True
    ) -> Iterator[Tuple[List[Document], List[Tuple[str, str]]]]:
        """Load documents from adaptive collection and split them appropriately.

        Yields batches of split document chunks together with the
        (filename, hash) records of the files that produced them.
        """
        docs_chunk: List[Document] = []
        processed_file_records: dict[str, str] = {}
        for adaptive_file in adaptive_collection.iterate(recursive=recursive):
            if len(processed_file_records) >= 2:
                yield docs_chunk, list(processed_file_records.items())
                docs_chunk = []
                processed_file_records = {}
            if adaptive_file.extension.lower() not in SUPPORTED_EXTENSIONS:
                logger.debug(
                    f"Skipping unsupported file extension for {adaptive_file.filename}: {adaptive_file.extension}"
                )
                continue
            logger.info(f"Processing document: {adaptive_file.filename}")
            try:
                split_docs, file_hash = self._load_one_adaptive_file(adaptive_file)
                if split_docs:
                    docs_chunk.extend(split_docs)
                if file_hash:
                    processed_file_records[adaptive_file.filename] = file_hash
            except Exception as e:
                logger.error(f"Error processing {adaptive_file.filename}: {str(e)}")
                continue
        # Flush the final partial batch so the last files are not silently dropped
        if docs_chunk or processed_file_records:
            yield docs_chunk, list(processed_file_records.items())
    def enrich_and_store(self, adaptive_collection: _AdaptiveCollection):
        """Load, enrich, and store documents in the vector store."""
        logger.info("Starting enrichment process...")
        # Load and split documents in batches
        for documents, processed_file_records in self.load_and_split_documents(
            adaptive_collection
        ):
            if not documents:
                # An empty batch (e.g. all files already processed) should not
                # stop the generator; move on to the next batch.
                logger.info("No new document chunks in this batch, skipping.")
                continue
            logger.info(
                f"Loaded and split {len(documents)} document chunks, adding to vector store..."
            )
            logger.debug(
                f"Documents len: {len(documents)}, processed_file_records len: {len(processed_file_records)}"
            )
            # Add documents to vector store
            try:
                self.vector_store.add_documents(documents)
                # Only mark documents as processed after successful insertion to vector store
                for file_identifier, file_hash in processed_file_records:
                    self._mark_document_processed(file_identifier, file_hash)
                logger.info(
                    f"Successfully added {len(documents)} document chunks to vector store and marked {len(processed_file_records)} files as processed."
                )
            except Exception as e:
                logger.error(f"Error adding documents to vector store: {str(e)}")
                raise
def get_enrichment_adaptive_collection() -> _AdaptiveCollection:
    """Create adaptive collection based on environment source configuration."""
    source = ENRICHMENT_SOURCE
    if source == "local":
        local_path = ENRICHMENT_LOCAL_PATH
        if local_path is None:
            raise RuntimeError(
                "Enrichment strategy is local, but no ENRICHMENT_LOCAL_PATH is defined!"
            )
        logger.info(f"Using local adaptive collection from path: {local_path}")
        return LocalFilesystemAdaptiveCollection(local_path)
    if source == "yadisk":
        if not YADISK_TOKEN:
            raise ValueError("YADISK_TOKEN must be set when ENRICHMENT_SOURCE=yadisk")
        if not ENRICHMENT_YADISK_PATH:
            raise ValueError(
                "ENRICHMENT_YADISK_PATH must be set when ENRICHMENT_SOURCE=yadisk"
            )
        logger.info(
            f"Using Yandex Disk adaptive collection from path: {ENRICHMENT_YADISK_PATH}"
        )
        return YandexDiskAdaptiveCollection(
            token=YADISK_TOKEN,
            base_dir=ENRICHMENT_YADISK_PATH,
        )
    raise ValueError(
        f"Unsupported ENRICHMENT_SOURCE='{source}'. Allowed values: local, yadisk"
    )
def run_enrichment_process(vector_store):
    """Run the full enrichment process."""
    logger.info("Starting document enrichment process")
    adaptive_collection = get_enrichment_adaptive_collection()
    # Initialize the document enricher
    enricher = DocumentEnricher(vector_store)
    # Run the enrichment process
    enricher.enrich_and_store(adaptive_collection)
    logger.info("Document enrichment process completed!")
if __name__ == "__main__":
    # Example usage
    from vector_storage import initialize_vector_store
    # Initialize vector store
    vector_store = initialize_vector_store()
    # Run enrichment process
    run_enrichment_process(vector_store)
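# Typical invocation (assuming the sibling vector_storage module that provides
# initialize_vector_store is importable from this directory):
#   python enrichment.py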