rag-solution/services/rag/langchain/prefect/02_yadisk_predefined_enrich.py

"""Prefect flow to enrich Yandex Disk files from a predefined JSON file list."""
from __future__ import annotations

import asyncio
import json
import os
import sys
import tempfile
from pathlib import Path
from typing import List

from dotenv import load_dotenv
from prefect import flow, task

load_dotenv()

PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

PREFECT_API_URL = os.getenv("PREFECT_API_URL")
YADISK_TOKEN = os.getenv("YADISK_TOKEN")
ENRICH_CONCURRENCY = int(os.getenv("PREFECT_YADISK_ENRICH_CONCURRENCY", "8"))
OUTPUT_FILE_LIST = (PROJECT_ROOT / "../../../yadisk_files.json").resolve()

if PREFECT_API_URL:
    os.environ["PREFECT_API_URL"] = PREFECT_API_URL

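# Environment summary (a sketch derived from the reads above; descriptions are
# inferred from how each variable is used, not from external docs):
#   PREFECT_API_URL                      optional; forwarded into os.environ for Prefect
#   YADISK_TOKEN                         required; OAuth token for Yandex Disk access
#   PREFECT_YADISK_ENRICH_CONCURRENCY    optional; max concurrent downloads (default: 8)
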
class _ProgressTracker:
    """Track processed/error counts and render an in-place console progress bar."""

    def __init__(self, total: int):
        self.total = total
        self.processed = 0
        self.errors = 0
        self._lock = asyncio.Lock()

    async def mark_done(self, error: bool = False):
        async with self._lock:
            self.processed += 1
            if error:
                self.errors += 1
            self._render()

    def _render(self):
        total = max(self.total, 1)
        width = 30
        filled = int(width * self.processed / total)
        bar = "#" * filled + "-" * (width - filled)
        left = max(self.total - self.processed, 0)
        print(
            f"\r[{bar}] {self.processed}/{self.total} processed | left: {left} | errors: {self.errors}",
            end="",
            flush=True,
        )
        if self.processed >= self.total:
            print()

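# Minimal usage sketch for _ProgressTracker (illustrative only, not part of the
# flow): one tracker per run, mark_done() awaited from each worker coroutine.
# The asyncio.Lock keeps the counters consistent under concurrent updates.
#
#     tracker = _ProgressTracker(total=2)
#     await tracker.mark_done()            # advances the bar
#     await tracker.mark_done(error=True)  # advances the bar and counts an error
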
async def _download_yadisk_file(async_disk, remote_path: str, local_path: str) -> None:
    await async_disk.download(remote_path, local_path)

def _process_local_file_for_enrichment(enricher, vector_store, local_path: str, remote_path: str) -> bool:
    """Process one downloaded file and upload chunks into the vector store.

    Returns True when the file was processed/uploaded, False when skipped.
    """
    extension = Path(remote_path).suffix.lower()
    file_hash = enricher._get_file_hash(local_path)
    if enricher._is_document_hash_processed(file_hash):
        return False

    loader = enricher._get_loader_for_extension(local_path)
    if loader is None:
        return False

    docs = loader.load()
    filename = Path(remote_path).name
    for doc in docs:
        doc.metadata["source"] = remote_path
        doc.metadata["filename"] = filename
        doc.metadata["file_path"] = remote_path
        doc.metadata["file_size"] = os.path.getsize(local_path)
        doc.metadata["file_extension"] = extension
        if "page" in doc.metadata:
            doc.metadata["page_number"] = doc.metadata["page"]

    split_docs = enricher.text_splitter.split_documents(docs)

    # Deferred import: helpers is resolvable only after PROJECT_ROOT is on sys.path.
    from helpers import extract_russian_event_names, extract_years_from_text

    for chunk in split_docs:
        chunk.metadata["years"] = extract_years_from_text(chunk.page_content)
        chunk.metadata["events"] = extract_russian_event_names(chunk.page_content)

    if not split_docs:
        return False

    vector_store.add_documents(split_docs)
    enricher._mark_document_processed(remote_path, file_hash)
    return True

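# The per-file coroutine below bounds concurrency with a semaphore, downloads
# into a uniquely named temp file, and offloads the blocking parse/upload work
# to a worker thread via asyncio.to_thread so the event loop stays responsive.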
async def _process_remote_file(async_disk, remote_path: str, semaphore: asyncio.Semaphore, tracker: _ProgressTracker, enricher, vector_store):
    async with semaphore:
        temp_path = None
        had_error = False
        try:
            suffix = Path(remote_path).suffix
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
                temp_path = tmp_file.name
            await _download_yadisk_file(async_disk, remote_path, temp_path)
            await asyncio.to_thread(
                _process_local_file_for_enrichment,
                enricher,
                vector_store,
                temp_path,
                remote_path,
            )
        except Exception:
            # Phase requirement: swallow per-file errors and continue processing.
            had_error = True
        finally:
            if temp_path and os.path.exists(temp_path):
                try:
                    os.unlink(temp_path)
                except OSError:
                    had_error = True
            await tracker.mark_done(error=had_error)

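# Expected shape of yadisk_files.json, inferred from json.load() and the suffix
# filter below: a flat JSON array of remote path strings. The paths shown here
# are hypothetical examples, not values from the repository.
#
#     ["disk:/docs/report.pdf", "disk:/docs/notes.txt"]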
@task(name="prefilter_yadisk_file_paths")
def prefilter_yadisk_file_paths() -> List[str]:
"""Load file list JSON and keep only extensions supported by enrichment."""
from enrichment import SUPPORTED_EXTENSIONS
if not OUTPUT_FILE_LIST.exists():
raise FileNotFoundError(f"File list not found: {OUTPUT_FILE_LIST}")
with open(OUTPUT_FILE_LIST, "r", encoding="utf-8") as input_file:
raw_paths = json.load(input_file)
filtered = [
path for path in raw_paths if Path(str(path)).suffix.lower() in SUPPORTED_EXTENSIONS
]
return filtered
@task(name="enrich_filtered_yadisk_files_async")
async def enrich_filtered_yadisk_files_async(filtered_paths: List[str]) -> dict:
"""Download/process Yandex Disk files concurrently and enrich LangChain vector store."""
if not YADISK_TOKEN:
raise ValueError("YADISK_TOKEN is required for Yandex Disk enrichment")
if not filtered_paths:
print("No supported files found for enrichment.")
return {"total": 0, "processed": 0, "errors": 0}
try:
import yadisk
except ImportError as error:
raise RuntimeError("yadisk package is required for this flow") from error
if not hasattr(yadisk, "AsyncYaDisk"):
raise RuntimeError("Installed yadisk package does not expose AsyncYaDisk")
from enrichment import DocumentEnricher
from vector_storage import initialize_vector_store
vector_store = initialize_vector_store()
enricher = DocumentEnricher(vector_store)
tracker = _ProgressTracker(total=len(filtered_paths))
semaphore = asyncio.Semaphore(max(1, ENRICH_CONCURRENCY))
async with yadisk.AsyncYaDisk(token=YADISK_TOKEN) as async_disk:
tasks = [
asyncio.create_task(
_process_remote_file(
async_disk=async_disk,
remote_path=remote_path,
semaphore=semaphore,
tracker=tracker,
enricher=enricher,
vector_store=vector_store,
)
)
for remote_path in filtered_paths
]
await asyncio.gather(*tasks)
return {
"total": tracker.total,
"processed": tracker.processed,
"errors": tracker.errors,
}
@flow(name="yadisk_predefined_enrich")
async def yadisk_predefined_enrich() -> dict:
    filtered_paths = prefilter_yadisk_file_paths()
    return await enrich_filtered_yadisk_files_async(filtered_paths)


def serve_yadisk_predefined_enrich() -> None:
    yadisk_predefined_enrich.serve(name="yadisk-predefined-enrich")


if __name__ == "__main__":
    serve_mode = os.getenv("PREFECT_SERVE", "0") == "1"
    if serve_mode:
        serve_yadisk_predefined_enrich()
    else:
        asyncio.run(yadisk_predefined_enrich())
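
# Typical invocations (a sketch; assumes the environment described above):
#   python 02_yadisk_predefined_enrich.py                  # run the flow once
#   PREFECT_SERVE=1 python 02_yadisk_predefined_enrich.py  # serve as a Prefect deployment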