- added file_type; it holds labels such as "таблица", "презентация", and so on
- file source metadata is now taken either from the local source or from Yandex Disk
"""Document enrichment module for loading documents into vector storage."""
|
|
|
|
import hashlib
|
|
import os
|
|
import queue
|
|
import threading
|
|
from pathlib import Path
|
|
from typing import List, Optional, Tuple
|
|
|
|
from dotenv import load_dotenv
|
|
from langchain_community.document_loaders import PyPDFLoader
|
|
from langchain_core.documents import Document
|
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|
from loguru import logger
|
|
from sqlalchemy import Column, Integer, String, create_engine
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
# Dynamically import other loaders to handle optional dependencies
|
|
try:
|
|
from langchain_community.document_loaders import UnstructuredWordDocumentLoader
|
|
except ImportError:
|
|
UnstructuredWordDocumentLoader = None
|
|
|
|
try:
|
|
from langchain_community.document_loaders import UnstructuredPowerPointLoader
|
|
except ImportError:
|
|
UnstructuredPowerPointLoader = None
|
|
|
|
try:
|
|
from langchain_community.document_loaders import UnstructuredExcelLoader
|
|
except ImportError:
|
|
UnstructuredExcelLoader = None
|
|
|
|
try:
|
|
from langchain_community.document_loaders import UnstructuredImageLoader
|
|
except ImportError:
|
|
UnstructuredImageLoader = None
|
|
|
|
try:
|
|
from langchain_community.document_loaders import UnstructuredODTLoader
|
|
except ImportError:
|
|
UnstructuredODTLoader = None
|
|
|
|
from helpers import (
|
|
LocalFilesystemAdaptiveCollection,
|
|
YandexDiskAdaptiveCollection,
|
|
YandexDiskAdaptiveFile,
|
|
_AdaptiveCollection,
|
|
_AdaptiveFile,
|
|
extract_russian_event_names,
|
|
extract_years_from_text,
|
|
)
|
|
|
|
# Load environment variables
|
|
load_dotenv()
|
|
|
|
# Define the path to the data directory
|
|
DATA_DIR = Path("../../../data").resolve()
|
|
DB_PATH = Path("document_tracking.db").resolve()
|
|
ENRICHMENT_SOURCE = os.getenv("ENRICHMENT_SOURCE", "local").lower()
|
|
ENRICHMENT_LOCAL_PATH = os.getenv("ENRICHMENT_LOCAL_PATH")
|
|
ENRICHMENT_YADISK_PATH = os.getenv("ENRICHMENT_YADISK_PATH")
|
|
YADISK_TOKEN = os.getenv("YADISK_TOKEN")
|
|
|
|
ENRICHMENT_PROCESSING_MODE = os.getenv("ENRICHMENT_PROCESSING_MODE", "async").lower()
|
|
ENRICHMENT_ADAPTIVE_FILES_QUEUE_LIMIT = int(
|
|
os.getenv("ENRICHMENT_ADAPTIVE_FILES_QUEUE_LIMIT", "5")
|
|
)
|
|
ENRICHMENT_ADAPTIVE_FILE_PROCESS_THREADS = int(
|
|
os.getenv("ENRICHMENT_ADAPTIVE_FILE_PROCESS_THREADS", "4")
|
|
)
|
|
ENRICHMENT_ADAPTIVE_DOCUMENT_UPLOADS_THREADS = int(
|
|
os.getenv("ENRICHMENT_ADAPTIVE_DOCUMENT_UPLOADS_THREADS", "4")
|
|
)
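
# Processing-mode notes (describing the logic in DocumentEnricher.__init__ below):
# "sync" forces single-item queues and single worker threads; any other value runs the
# threaded pipeline, where the queue limit bounds how many files are buffered at once
# and the two thread counts size the file-processing and document-upload worker pools.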

SUPPORTED_EXTENSIONS = {
    ".pdf",
    ".docx",
    ".doc",
    ".pptx",
    ".xlsx",
    ".xls",
    ".jpg",
    ".jpeg",
    ".png",
    ".gif",
    ".bmp",
    ".tiff",
    ".webp",
    ".odt",
    ".txt",  # obvious, but it was unexpected to see in the data
}

Base = declarative_base()


class ProcessedDocument(Base):
    """Database model for tracking processed documents."""

    __tablename__ = "processed_documents"

    id = Column(Integer, primary_key=True)
    file_path = Column(String, unique=True, nullable=False)
    file_hash = Column(String, nullable=False)


# Guess a Russian-language file type label, used as `file_type` metadata for search
def try_guess_file_type(extension: str) -> str:
    if extension in [".xlsx", ".xls"]:
        return "таблица"
    elif extension in [".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp"]:
        return "изображение"
    elif extension in [".pptx"]:
        return "презентация"
    else:
        return "документ"


def identify_adaptive_file_source(adaptive_file: _AdaptiveFile) -> str:
    if isinstance(adaptive_file, YandexDiskAdaptiveFile):
        return "Яндекс Диск"
    else:
        return "Локальный Файл"


class DocumentEnricher:
    """Class responsible for enriching documents and loading them to vector storage."""

    def __init__(self, vector_store):
        self.vector_store = vector_store
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
        )

        # In sync mode we force minimal concurrency values.
        if ENRICHMENT_PROCESSING_MODE == "sync":
            self.adaptive_files_queue_limit = 1
            self.file_process_threads_count = 1
            self.document_upload_threads_count = 1
        else:
            self.adaptive_files_queue_limit = max(
                1, ENRICHMENT_ADAPTIVE_FILES_QUEUE_LIMIT
            )
            self.file_process_threads_count = max(
                1, ENRICHMENT_ADAPTIVE_FILE_PROCESS_THREADS
            )
            self.document_upload_threads_count = max(
                1, ENRICHMENT_ADAPTIVE_DOCUMENT_UPLOADS_THREADS
            )
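
        # Pipeline overview (descriptive comment): the collection step feeds
        # ADAPTIVE_FILES_QUEUE, file-processing workers turn each adaptive file into
        # split Document chunks and push them onto PROCESSED_DOCUMENTS_QUEUE, and
        # upload workers write those chunks to the vector store. The two events below
        # tell the workers when no more items will arrive.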

        # Phase 13 queues
        self.ADAPTIVE_FILES_QUEUE: queue.Queue = queue.Queue(
            maxsize=self.adaptive_files_queue_limit
        )
        self.PROCESSED_DOCUMENTS_QUEUE: queue.Queue = queue.Queue(
            maxsize=max(1, self.adaptive_files_queue_limit * 2)
        )

        # Shared state for thread lifecycle
        self.collection_finished = threading.Event()
        self.processing_finished = threading.Event()

        # Initialize database for tracking processed documents
        self._init_db()

    def _init_db(self):
        """Initialize the SQLite database for tracking processed documents."""
        self.engine = create_engine(f"sqlite:///{DB_PATH}")
        Base.metadata.create_all(self.engine)
        self.SessionLocal = sessionmaker(bind=self.engine)
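
    # Deduplication is content-based: a file is skipped when the SHA256 of its bytes is
    # already recorded, regardless of its name or path on the source.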

    def _get_file_hash(self, file_path: str) -> str:
        """Calculate SHA256 hash of a file."""
        hash_sha256 = hashlib.sha256()
        with open(file_path, "rb") as file_handle:
            for chunk in iter(lambda: file_handle.read(4096), b""):
                hash_sha256.update(chunk)
        return hash_sha256.hexdigest()

    def _is_document_hash_processed(self, file_hash: str) -> bool:
        """Check if a document hash has already been processed."""
        session = self.SessionLocal()
        try:
            existing = (
                session.query(ProcessedDocument).filter_by(file_hash=file_hash).first()
            )
            return existing is not None
        finally:
            session.close()

    def _mark_document_processed(self, file_identifier: str, file_hash: str):
        """Mark a document as processed in the database."""
        session = self.SessionLocal()
        try:
            existing = (
                session.query(ProcessedDocument)
                .filter_by(file_path=file_identifier)
                .first()
            )
            if existing is not None:
                existing.file_hash = file_hash
            else:
                session.add(
                    ProcessedDocument(file_path=file_identifier, file_hash=file_hash)
                )
            session.commit()
        finally:
            session.close()

    def _get_loader_for_extension(self, file_path: str):
        """Get the appropriate loader for a given file extension."""
        ext = Path(file_path).suffix.lower()

        if ext == ".pdf":
            return PyPDFLoader(file_path)
        if ext in [".docx", ".doc"]:
            if UnstructuredWordDocumentLoader is None:
                logger.warning(
                    f"UnstructuredWordDocumentLoader not available for {file_path}. Skipping."
                )
                return None
            return UnstructuredWordDocumentLoader(
                file_path, strategy="hi_res", languages=["rus"]
            )
        if ext == ".pptx":
            if UnstructuredPowerPointLoader is None:
                logger.warning(
                    f"UnstructuredPowerPointLoader not available for {file_path}. Skipping."
                )
                return None
            return UnstructuredPowerPointLoader(
                file_path, strategy="hi_res", languages=["rus"]
            )
        if ext in [".xlsx", ".xls"]:
            if UnstructuredExcelLoader is None:
                logger.warning(
                    f"UnstructuredExcelLoader not available for {file_path}. Skipping."
                )
                return None
            return UnstructuredExcelLoader(
                file_path, strategy="hi_res", languages=["rus"]
            )
        if ext in [".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp"]:
            if UnstructuredImageLoader is None:
                logger.warning(
                    f"UnstructuredImageLoader not available for {file_path}. Skipping."
                )
                return None
            return UnstructuredImageLoader(
                file_path, strategy="ocr_only", languages=["rus"]
            )
        if ext == ".odt":
            if UnstructuredODTLoader is None:
                logger.warning(
                    f"UnstructuredODTLoader not available for {file_path}. Skipping."
                )
                return None
            return UnstructuredODTLoader(
                file_path, strategy="hi_res", languages=["rus"]
            )
        return None

    def _load_one_adaptive_file(
        self, adaptive_file: _AdaptiveFile
    ) -> Tuple[List[Document], Optional[Tuple[str, str]]]:
        """Load and split one adaptive file by using its local working callback."""
        loaded_docs: List[Document] = []
        processed_record: Optional[Tuple[str, str]] = None
        source_identifier = identify_adaptive_file_source(adaptive_file)
        extension = adaptive_file.extension.lower()
        file_type = try_guess_file_type(extension)
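
        # `work_with_file_locally` hands this callback a usable local path; for the
        # Yandex Disk source that is expected to involve fetching the file to a
        # temporary local copy first. Results are handed back through the nonlocal
        # variables above.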

        def process_local_file(local_file_path: str):
            nonlocal loaded_docs, processed_record

            file_hash = self._get_file_hash(local_file_path)
            if self._is_document_hash_processed(file_hash):
                logger.info(
                    f"Skipping already processed document hash for: {source_identifier}"
                )
                return

            loader = self._get_loader_for_extension(local_file_path)
            if loader is None:
                logger.warning(f"No loader available for file: {source_identifier}")
                return

            docs = loader.load()
            for doc in docs:
                doc.metadata["file_type"] = file_type
                doc.metadata["source"] = source_identifier
                doc.metadata["filename"] = adaptive_file.filename
                doc.metadata["file_path"] = source_identifier
                doc.metadata["file_size"] = os.path.getsize(local_file_path)
                doc.metadata["file_extension"] = extension

                if "page" in doc.metadata:
                    doc.metadata["page_number"] = doc.metadata["page"]

            split_docs = self.text_splitter.split_documents(docs)
            for chunk in split_docs:
                chunk.metadata["years"] = extract_years_from_text(chunk.page_content)
                chunk.metadata["events"] = extract_russian_event_names(
                    chunk.page_content
                )

            loaded_docs = split_docs
            processed_record = (source_identifier, file_hash)

        adaptive_file.work_with_file_locally(process_local_file)
        return loaded_docs, processed_record

    # Phase 13 API: inserts adaptive files into ADAPTIVE_FILES_QUEUE
    def insert_adaptive_files_queue(
        self, adaptive_collection: _AdaptiveCollection, recursive: bool = True
    ):
        for adaptive_file in adaptive_collection.iterate(recursive=recursive):
            if adaptive_file.extension.lower() not in SUPPORTED_EXTENSIONS:
                logger.debug(
                    f"Skipping unsupported file extension for {adaptive_file.filename}: {adaptive_file.extension}"
                )
                continue

            self.ADAPTIVE_FILES_QUEUE.put(adaptive_file)

        self.collection_finished.set()

    # Phase 13 API: reads adaptive files and writes processed docs into PROCESSED_DOCUMENTS_QUEUE
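    # Workers poll with a short timeout so they can notice the "no more input" event
    # instead of blocking forever on an empty queue.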
    def process_adaptive_files_queue(self):
        while True:
            try:
                adaptive_file = self.ADAPTIVE_FILES_QUEUE.get(timeout=0.2)
            except queue.Empty:
                if self.collection_finished.is_set():
                    return
                continue

            try:
                split_docs, processed_record = self._load_one_adaptive_file(
                    adaptive_file
                )
                if split_docs:
                    self.PROCESSED_DOCUMENTS_QUEUE.put((split_docs, processed_record))
            except Exception as error:
                logger.error(f"Error processing {adaptive_file.filename}: {error}")
            finally:
                self.ADAPTIVE_FILES_QUEUE.task_done()

    # Phase 13 API: uploads chunked docs and marks file processed
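    # Note: an upload error is logged and re-raised, which stops that uploader thread
    # after `task_done()` has been called for the failed item.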
    def upload_processed_documents_from_queue(self):
        while True:
            try:
                payload = self.PROCESSED_DOCUMENTS_QUEUE.get(timeout=0.2)
            except queue.Empty:
                if self.processing_finished.is_set():
                    return
                continue

            try:
                documents, processed_record = payload
                self.vector_store.add_documents(documents)

                if processed_record is not None:
                    self._mark_document_processed(
                        processed_record[0], processed_record[1]
                    )
            except Exception as error:
                logger.error(f"Error uploading processed documents: {error}")
                raise
            finally:
                self.PROCESSED_DOCUMENTS_QUEUE.task_done()

    def _run_threaded_pipeline(self, adaptive_collection: _AdaptiveCollection):
        """Run the Phase 13 queue/thread pipeline."""
        process_threads = [
            threading.Thread(
                target=self.process_adaptive_files_queue,
                name=f"adaptive-file-processor-{index}",
                daemon=True,
            )
            for index in range(self.file_process_threads_count)
        ]
        upload_threads = [
            threading.Thread(
                target=self.upload_processed_documents_from_queue,
                name=f"document-uploader-{index}",
                daemon=True,
            )
            for index in range(self.document_upload_threads_count)
        ]

        for thread in process_threads:
            thread.start()
        for thread in upload_threads:
            thread.start()

        # This one intentionally runs on the main thread per the Phase 13 requirement.
        self.insert_adaptive_files_queue(adaptive_collection, recursive=True)

        # Wait for the file queue to drain and the processing threads to finish.
        self.ADAPTIVE_FILES_QUEUE.join()
        for thread in process_threads:
            thread.join()

        # Signal upload workers that no more payloads are expected.
        self.processing_finished.set()

        # Wait for uploads to complete and the upload threads to finish.
        self.PROCESSED_DOCUMENTS_QUEUE.join()
        for thread in upload_threads:
            thread.join()

    def _run_sync_pipeline(self, adaptive_collection: _AdaptiveCollection):
        """Sequential pipeline for sync mode."""
        logger.info("Running enrichment in sync mode")
        self.insert_adaptive_files_queue(adaptive_collection, recursive=True)
        self.process_adaptive_files_queue()
        self.processing_finished.set()
        self.upload_processed_documents_from_queue()

    def enrich_and_store(self, adaptive_collection: _AdaptiveCollection):
        """Load, enrich, and store documents in the vector store."""
        logger.info("Starting enrichment process...")

        if ENRICHMENT_PROCESSING_MODE == "sync":
            logger.info("Document enrichment process starting in SYNC mode")
            self._run_sync_pipeline(adaptive_collection)
            return

        logger.info("Document enrichment process starting in ASYNC/THREAD mode")
        self._run_threaded_pipeline(adaptive_collection)


def get_enrichment_adaptive_collection(
    data_dir: str = str(DATA_DIR),
) -> _AdaptiveCollection:
    """Create the adaptive collection based on the environment source configuration."""
    source = ENRICHMENT_SOURCE
    if source == "local":
        local_path = ENRICHMENT_LOCAL_PATH or data_dir
        logger.info(f"Using local adaptive collection from path: {local_path}")
        return LocalFilesystemAdaptiveCollection(local_path)

    if source == "yadisk":
        if not YADISK_TOKEN:
            raise ValueError("YADISK_TOKEN must be set when ENRICHMENT_SOURCE=yadisk")
        if not ENRICHMENT_YADISK_PATH:
            raise ValueError(
                "ENRICHMENT_YADISK_PATH must be set when ENRICHMENT_SOURCE=yadisk"
            )
        logger.info(
            f"Using Yandex Disk adaptive collection from path: {ENRICHMENT_YADISK_PATH}"
        )
        return YandexDiskAdaptiveCollection(
            token=YADISK_TOKEN,
            base_dir=ENRICHMENT_YADISK_PATH,
        )

    raise ValueError(
        f"Unsupported ENRICHMENT_SOURCE='{source}'. Allowed values: local, yadisk"
    )


def run_enrichment_process(vector_store, data_dir: str = str(DATA_DIR)):
    """Run the full enrichment process."""
    logger.info("Starting document enrichment process")

    adaptive_collection = get_enrichment_adaptive_collection(data_dir=data_dir)

    # Initialize the document enricher
    enricher = DocumentEnricher(vector_store)

    # Run the enrichment process
    enricher.enrich_and_store(adaptive_collection)

    logger.info("Document enrichment process completed!")


if __name__ == "__main__":
    from vector_storage import initialize_vector_store

    vector_store = initialize_vector_store()
    run_enrichment_process(vector_store)