# Source: rag-solution/services/rag/llamaindex/prefect/01_yadisk_predefined_enrich.py
# (269 lines, 9.2 KiB, Python)

"""
Prefect flow for enriching documents from a predefined YaDisk file list.
Flow steps:
1. Load file paths from ../../../yadisk_files.json
2. Filter them by supported enrichment extensions
3. Download each file from YaDisk asynchronously
4. Enrich each downloaded file into vector storage
5. Remove downloaded temporary files after processing
"""
from __future__ import annotations
import asyncio
import hashlib
import json
import os
import sys
import tempfile
from pathlib import Path
from typing import Any
from dotenv import load_dotenv
from loguru import logger
from tqdm import tqdm
# Project root: the directory two levels above this file (parents[1]).
ROOT_DIR = Path(__file__).resolve().parents[1]
# Load environment variables (PREFECT_API_URL, YADISK_TOKEN, ...) from the root .env.
load_dotenv(ROOT_DIR / ".env")
# Ensure ROOT_DIR is importable so local modules (e.g. `enrichment`) resolve.
if str(ROOT_DIR) not in sys.path:
    sys.path.insert(0, str(ROOT_DIR))
import yadisk
from enrichment import get_supported_enrichment_extensions, process_document_file
from prefect import flow, task
# Default location of the predefined YaDisk file list (see module docstring step 1).
DEFAULT_YADISK_LIST_PATH = (ROOT_DIR / "../../../yadisk_files.json").resolve()
def setup_prefect_flow_logging() -> None:
    """Configure loguru handlers for flow execution.

    Replaces any pre-existing handlers with two INFO-level sinks: a
    rotating file log under ``<ROOT_DIR>/logs/dev.log`` and a colorized
    stdout stream.
    """
    log_directory = ROOT_DIR / "logs"
    log_directory.mkdir(exist_ok=True)
    # Drop default handlers so the sinks below are the only outputs.
    logger.remove()
    file_log_format = (
        "{time:YYYY-MM-DD HH:mm:ss} | {level} | {file}:{line} | {message}"
    )
    logger.add(
        str(log_directory / "dev.log"),
        rotation="10 MB",
        retention="10 days",
        level="INFO",
        format=file_log_format,
    )
    logger.add(
        sys.stdout,
        level="INFO",
        format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
        colorize=True,
    )
def _normalize_yadisk_paths(payload: Any) -> list[str]:
"""Extract a list of file paths from several common JSON shapes."""
if isinstance(payload, list):
return [str(item) for item in payload if isinstance(item, (str, Path))]
if isinstance(payload, dict):
for key in ("paths", "files", "items"):
value = payload.get(key)
if isinstance(value, list):
normalized: list[str] = []
for item in value:
if isinstance(item, str):
normalized.append(item)
elif isinstance(item, dict):
for item_key in ("path", "remote_path", "file_path", "name"):
if item_key in item and item[item_key]:
normalized.append(str(item[item_key]))
break
return normalized
raise ValueError(
"Unsupported yadisk_files.json structure. Expected list or dict with paths/files/items."
)
def _make_temp_local_path(base_dir: Path, remote_path: str) -> Path:
"""Create a deterministic temp file path for a remote YaDisk path."""
remote_name = Path(remote_path).name or "downloaded_file"
suffix = Path(remote_name).suffix
stem = Path(remote_name).stem or "file"
digest = hashlib.md5(remote_path.encode("utf-8")).hexdigest()[:10]
safe_stem = "".join(ch if ch.isalnum() or ch in ("-", "_") else "_" for ch in stem)
return base_dir / f"{safe_stem}_{digest}{suffix}"
@task(name="load_yadisk_paths")
def load_yadisk_paths(json_file_path: str) -> list[str]:
    """Load remote file paths from a JSON file.

    Raises:
        FileNotFoundError: If the JSON file does not exist.
        ValueError: If the JSON structure is not recognized.
    """
    json_path = Path(json_file_path)
    if not json_path.exists():
        raise FileNotFoundError(f"YaDisk paths JSON file not found: {json_path}")
    payload = json.loads(json_path.read_text(encoding="utf-8"))
    remote_paths = _normalize_yadisk_paths(payload)
    logger.info(f"Loaded {len(remote_paths)} paths from {json_path}")
    return remote_paths
@task(name="filter_supported_yadisk_paths")
def filter_supported_yadisk_paths(paths: list[str]) -> list[str]:
    """Keep only paths supported by enrichment extension filters."""
    supported_extensions = get_supported_enrichment_extensions()

    def _is_supported(candidate: str) -> bool:
        # Extension match is case-insensitive.
        return Path(str(candidate)).suffix.lower() in supported_extensions

    kept = [candidate for candidate in paths if _is_supported(candidate)]
    logger.info(
        f"Filtered YaDisk paths: {len(kept)}/{len(paths)} supported "
        f"(extensions: {', '.join(sorted(supported_extensions))})"
    )
    return kept
async def _download_and_enrich_one(
    client: yadisk.AsyncClient,
    remote_path: str,
    temp_dir: Path,
    semaphore: asyncio.Semaphore,
    stats: dict[str, int],
    pbar: tqdm,
    pbar_lock: asyncio.Lock,
) -> None:
    """Download one YaDisk file, enrich it, remove it, and update stats.

    Args:
        client: Authenticated async YaDisk client used for the download.
        remote_path: Remote path of the file on Yandex Disk.
        temp_dir: Local directory where the file is stored temporarily.
        semaphore: Bounds how many files are downloaded/enriched at once.
        stats: Shared counters mutated in place ("processed", "skipped",
            "errors", "completed").
        pbar: Shared tqdm progress bar advanced once per file.
        pbar_lock: Serializes stats/pbar updates across coroutines.
    """
    local_path = _make_temp_local_path(temp_dir, remote_path)
    # Pessimistic default: any path that does not complete enrichment counts as an error.
    result_status = "error"
    async with semaphore:
        try:
            local_path.parent.mkdir(parents=True, exist_ok=True)
            await client.download(remote_path, str(local_path))
            # Run sync enrichment in a worker thread. Exceptions are swallowed below.
            enrich_result = await asyncio.to_thread(
                process_document_file, str(local_path)
            )
            result_status = str(enrich_result.get("status", "error"))
            logger.info(
                f"YaDisk processed: remote={remote_path}, local={local_path}, status={result_status}"
            )
        except Exception as e:
            # Explicitly swallow errors as requested.
            logger.error(f"YaDisk processing error for {remote_path}: {e}")
            result_status = "error"
        finally:
            # Best-effort removal of the downloaded temp file; a cleanup
            # failure must not mask the processing result.
            try:
                if local_path.exists():
                    local_path.unlink()
            except Exception as cleanup_error:
                logger.warning(
                    f"Failed to remove temp file {local_path}: {cleanup_error}"
                )
    # Counters and the progress bar are shared; update them under the lock.
    async with pbar_lock:
        if result_status == "processed":
            stats["processed"] += 1
        elif result_status == "skipped":
            stats["skipped"] += 1
        else:
            stats["errors"] += 1
        stats["completed"] += 1
        pbar.update(1)
        pbar.set_postfix(
            processed=stats["processed"],
            skipped=stats["skipped"],
            errors=stats["errors"],
        )
@flow(name="yadisk_predefined_enrich")
async def yadisk_predefined_enrich_flow(
    yadisk_json_path: str = str(DEFAULT_YADISK_LIST_PATH),
    concurrency: int = 4,
) -> dict[str, int]:
    """
    Download and enrich YaDisk files listed in the JSON file using async YaDisk client.

    Args:
        yadisk_json_path: Path to the JSON file with remote YaDisk paths.
        concurrency: Maximum number of files downloaded/enriched at once
            (clamped to at least 1).

    Returns:
        Counters dict with keys "total", "completed", "processed",
        "skipped", "errors".

    Raises:
        ValueError: If YADISK_TOKEN is missing from the environment/.env.
    """
    setup_prefect_flow_logging()
    prefect_api_url = os.getenv("PREFECT_API_URL", "").strip()
    yadisk_token = os.getenv("YADISK_TOKEN", "").strip()
    if not prefect_api_url:
        logger.warning("PREFECT_API_URL is not set in environment/.env")
    else:
        # Prefect reads this env var for API connectivity.
        os.environ["PREFECT_API_URL"] = prefect_api_url
        logger.info(f"Using Prefect API URL: {prefect_api_url}")
    if not yadisk_token:
        raise ValueError("YADISK_TOKEN is required in .env to access Yandex Disk")
    all_paths = load_yadisk_paths(yadisk_json_path)
    supported_paths = filter_supported_yadisk_paths(all_paths)
    stats = {
        "total": len(supported_paths),
        "completed": 0,
        "processed": 0,
        "skipped": 0,
        "errors": 0,
    }
    if not supported_paths:
        logger.info("No supported YaDisk paths to process")
        return stats
    # Guard against zero/negative concurrency values.
    concurrency = max(1, int(concurrency))
    logger.info(
        f"Starting async YaDisk enrichment for {len(supported_paths)} files with concurrency={concurrency}"
    )
    semaphore = asyncio.Semaphore(concurrency)
    pbar_lock = asyncio.Lock()
    # The temp dir (and any leftover files) is removed when this context exits.
    with tempfile.TemporaryDirectory(prefix="yadisk_enrich_") as temp_dir_str:
        temp_dir = Path(temp_dir_str)
        pbar = tqdm(total=len(supported_paths), desc="YaDisk enrich", unit="file")
        try:
            async with yadisk.AsyncClient(token=yadisk_token) as client:
                try:
                    is_token_valid = await client.check_token()
                    logger.info(f"YaDisk token validation result: {is_token_valid}")
                except Exception as token_check_error:
                    # Token check issues should not block processing attempts.
                    logger.warning(f"YaDisk token check failed: {token_check_error}")
                # One task per remote file; the semaphore inside the worker
                # limits how many run concurrently.
                tasks = [
                    asyncio.create_task(
                        _download_and_enrich_one(
                            client=client,
                            remote_path=remote_path,
                            temp_dir=temp_dir,
                            semaphore=semaphore,
                            stats=stats,
                            pbar=pbar,
                            pbar_lock=pbar_lock,
                        )
                    )
                    for remote_path in supported_paths
                ]
                # Worker function swallows per-file errors, but keep gather resilient too.
                await asyncio.gather(*tasks, return_exceptions=True)
        finally:
            pbar.close()
    logger.info(
        "YaDisk enrichment flow finished. "
        f"Total={stats['total']}, Completed={stats['completed']}, "
        f"Processed={stats['processed']}, Skipped={stats['skipped']}, Errors={stats['errors']}"
    )
    return stats
if __name__ == "__main__":
    print("SERVING PREFECT FLOW FOR YANDEX DISK ENRICHMENT OF PREDEFINED PATHS")
    # Serve the flow so a Prefect server can schedule and trigger runs of it.
    yadisk_predefined_enrich_flow.serve()