Compare commits
17 Commits
93d538ecc6
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 6a6d8d6fc3 | |||
| 65ce290104 | |||
| 236f44b2c3 | |||
| 6c953a327f | |||
| 5721bad117 | |||
| 969d25209c | |||
| 6b3fa1cfaa | |||
| 468d5fb572 | |||
| 2e0a0718cd | |||
| 0cef887155 | |||
| c6715eb021 | |||
| 9ec3d9281d | |||
| ba1b7abf0a | |||
| 3e29ea70ed | |||
| 2c7ab06b3f | |||
| c29928cc89 | |||
| 77c578c9e6 |
6
.gitignore
vendored
6
.gitignore
vendored
@@ -1,2 +1,8 @@
|
|||||||
data-unpacked-archives
|
data-unpacked-archives
|
||||||
data-broken-archives
|
data-broken-archives
|
||||||
|
.env
|
||||||
|
tmp/
|
||||||
|
__pycache__
|
||||||
|
venv
|
||||||
|
services/rag/.DS_Store
|
||||||
|
EVALUATION_RESULT.json
|
||||||
|
|||||||
2314
DOCUMENTS_TO_TEST.md
Normal file
2314
DOCUMENTS_TO_TEST.md
Normal file
File diff suppressed because it is too large
Load Diff
48
ext_stats.py
Normal file
48
ext_stats.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
from collections import Counter
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_ext(path_str: str) -> str:
|
||||||
|
ext = Path(path_str).suffix.lower()
|
||||||
|
return ext if ext else "(no_ext)"
|
||||||
|
|
||||||
|
|
||||||
|
def load_paths(json_path: Path) -> list[str]:
|
||||||
|
with json_path.open("r", encoding="utf-8") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
if not isinstance(data, list):
|
||||||
|
raise ValueError("Expected JSON to be a list of file paths")
|
||||||
|
if not all(isinstance(item, str) for item in data):
|
||||||
|
raise ValueError("Expected JSON list to contain only strings")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Count file extensions from a JSON list of paths."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"json_path",
|
||||||
|
nargs="?",
|
||||||
|
default="yadisk_files.json",
|
||||||
|
help="Path to JSON file (default: yadisk_files.json)",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
json_path = Path(args.json_path)
|
||||||
|
if not json_path.exists():
|
||||||
|
raise SystemExit(f"JSON file not found: {json_path}")
|
||||||
|
|
||||||
|
paths = load_paths(json_path)
|
||||||
|
counts = Counter(normalize_ext(p) for p in paths)
|
||||||
|
|
||||||
|
for ext, count in counts.most_common():
|
||||||
|
print(f"{ext}\t{count}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
377
generate_documents_to_test.py
Normal file
377
generate_documents_to_test.py
Normal file
@@ -0,0 +1,377 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
import tempfile
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from qdrant_client import QdrantClient
|
||||||
|
from qdrant_client.http.models import Filter, FieldCondition, MatchValue
|
||||||
|
import requests
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langchain_community.document_loaders import PyPDFLoader, TextLoader
|
||||||
|
except Exception: # pragma: no cover
|
||||||
|
PyPDFLoader = None
|
||||||
|
TextLoader = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langchain_community.document_loaders import UnstructuredWordDocumentLoader
|
||||||
|
except Exception: # pragma: no cover
|
||||||
|
UnstructuredWordDocumentLoader = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langchain_community.document_loaders import UnstructuredPowerPointLoader
|
||||||
|
except Exception: # pragma: no cover
|
||||||
|
UnstructuredPowerPointLoader = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langchain_community.document_loaders import UnstructuredExcelLoader
|
||||||
|
except Exception: # pragma: no cover
|
||||||
|
UnstructuredExcelLoader = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langchain_community.document_loaders import UnstructuredODTLoader
|
||||||
|
except Exception: # pragma: no cover
|
||||||
|
UnstructuredODTLoader = None
|
||||||
|
|
||||||
|
ROOT = Path(__file__).resolve().parent
|
||||||
|
LANGCHAIN_DIR = ROOT / "services" / "rag" / "langchain"
|
||||||
|
LLAMAINDEX_DIR = ROOT / "services" / "rag" / "llamaindex"
|
||||||
|
YADISK_JSON = ROOT / "yadisk_files.json"
|
||||||
|
OUTPUT_MD = ROOT / "DOCUMENTS_TO_TEST.md"
|
||||||
|
|
||||||
|
|
||||||
|
def safe_stem_from_remote(remote_path: str) -> str:
|
||||||
|
stem = Path(Path(remote_path).name).stem or "file"
|
||||||
|
return "".join(ch if ch.isalnum() or ch in ("-", "_") else "_" for ch in stem)
|
||||||
|
|
||||||
|
|
||||||
|
def llama_prefect_filename(remote_path: str) -> str:
|
||||||
|
remote_name = Path(remote_path).name or "downloaded_file"
|
||||||
|
suffix = Path(remote_name).suffix
|
||||||
|
digest = hashlib.md5(remote_path.encode("utf-8")).hexdigest()[:10]
|
||||||
|
return f"{safe_stem_from_remote(remote_path)}_{digest}{suffix}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_loader(local_path: str):
|
||||||
|
ext = Path(local_path).suffix.lower()
|
||||||
|
if ext == ".pdf" and PyPDFLoader is not None:
|
||||||
|
return PyPDFLoader(local_path)
|
||||||
|
if ext in {".doc", ".docx"} and UnstructuredWordDocumentLoader is not None:
|
||||||
|
return UnstructuredWordDocumentLoader(
|
||||||
|
local_path, **{"strategy": "hi_res", "languages": ["rus"]}
|
||||||
|
)
|
||||||
|
if ext == ".pptx" and UnstructuredPowerPointLoader is not None:
|
||||||
|
return UnstructuredPowerPointLoader(
|
||||||
|
local_path, **{"strategy": "hi_res", "languages": ["rus"]}
|
||||||
|
)
|
||||||
|
if ext in {".xls", ".xlsx"} and UnstructuredExcelLoader is not None:
|
||||||
|
return UnstructuredExcelLoader(
|
||||||
|
local_path, **{"strategy": "hi_res", "languages": ["rus"]}
|
||||||
|
)
|
||||||
|
if ext == ".odt" and UnstructuredODTLoader is not None:
|
||||||
|
return UnstructuredODTLoader(
|
||||||
|
local_path, **{"strategy": "hi_res", "languages": ["rus"]}
|
||||||
|
)
|
||||||
|
if ext in {".txt", ".md"} and TextLoader is not None:
|
||||||
|
return TextLoader(local_path, encoding="utf-8")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def supported_loader_extensions() -> set[str]:
|
||||||
|
exts = set()
|
||||||
|
if PyPDFLoader is not None:
|
||||||
|
exts.add(".pdf")
|
||||||
|
if UnstructuredWordDocumentLoader is not None:
|
||||||
|
exts.update({".doc", ".docx"})
|
||||||
|
if UnstructuredPowerPointLoader is not None:
|
||||||
|
exts.add(".pptx")
|
||||||
|
if UnstructuredExcelLoader is not None:
|
||||||
|
exts.update({".xls", ".xlsx"})
|
||||||
|
if UnstructuredODTLoader is not None:
|
||||||
|
exts.add(".odt")
|
||||||
|
if TextLoader is not None:
|
||||||
|
exts.update({".txt", ".md"})
|
||||||
|
return exts
|
||||||
|
|
||||||
|
|
||||||
|
def collect_langchain_paths(client: QdrantClient) -> set[str]:
|
||||||
|
paths: set[str] = set()
|
||||||
|
offset = None
|
||||||
|
while True:
|
||||||
|
points, offset = client.scroll(
|
||||||
|
collection_name="documents_langchain",
|
||||||
|
offset=offset,
|
||||||
|
limit=1000,
|
||||||
|
with_payload=True,
|
||||||
|
with_vectors=False,
|
||||||
|
)
|
||||||
|
if not points:
|
||||||
|
break
|
||||||
|
for p in points:
|
||||||
|
payload = p.payload or {}
|
||||||
|
md = payload.get("metadata") or {}
|
||||||
|
fp = md.get("file_path") or md.get("source")
|
||||||
|
if isinstance(fp, str) and fp:
|
||||||
|
paths.add(fp)
|
||||||
|
if offset is None:
|
||||||
|
break
|
||||||
|
return paths
|
||||||
|
|
||||||
|
|
||||||
|
def collect_llama_filenames(client: QdrantClient) -> set[str]:
|
||||||
|
names: set[str] = set()
|
||||||
|
offset = None
|
||||||
|
while True:
|
||||||
|
points, offset = client.scroll(
|
||||||
|
collection_name="documents_llamaindex",
|
||||||
|
offset=offset,
|
||||||
|
limit=1000,
|
||||||
|
with_payload=True,
|
||||||
|
with_vectors=False,
|
||||||
|
)
|
||||||
|
if not points:
|
||||||
|
break
|
||||||
|
for p in points:
|
||||||
|
payload = p.payload or {}
|
||||||
|
name = payload.get("filename")
|
||||||
|
if isinstance(name, str) and name:
|
||||||
|
names.add(name)
|
||||||
|
if offset is None:
|
||||||
|
break
|
||||||
|
return names
|
||||||
|
|
||||||
|
|
||||||
|
def first_unique(matches: list[str], fallback: str) -> str:
|
||||||
|
for m in matches:
|
||||||
|
m = m.strip()
|
||||||
|
if m:
|
||||||
|
return m
|
||||||
|
return fallback
|
||||||
|
|
||||||
|
|
||||||
|
def build_questions(remote_path: str, text: str) -> dict[str, list[str]]:
|
||||||
|
text = " ".join((text or "").split())
|
||||||
|
text_preview = text[:15000]
|
||||||
|
years = sorted(
|
||||||
|
{
|
||||||
|
int(m)
|
||||||
|
for m in re.findall(r"\b(19\d{2}|20\d{2}|21\d{2})\b", text_preview)
|
||||||
|
if 1900 <= int(m) <= 2199
|
||||||
|
}
|
||||||
|
)
|
||||||
|
dates = re.findall(
|
||||||
|
r"\b(?:\d{1,2}[./-]\d{1,2}[./-]\d{2,4}|\d{4}[./-]\d{1,2}[./-]\d{1,2})\b",
|
||||||
|
text_preview,
|
||||||
|
)
|
||||||
|
numbers = re.findall(r"\b\d{2,}\b", text_preview)
|
||||||
|
quoted = re.findall(r"[\"«]([^\"»\n]{4,120})[\"»]", text_preview)
|
||||||
|
org_like = re.findall(
|
||||||
|
r"\b(?:ООО|АО|ПАО|ФГУП|Минтранс|Министерств[ао]|Правительств[ао]|Форум)\b[^\n,.]{0,80}",
|
||||||
|
text_preview,
|
||||||
|
flags=re.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
year_q = (
|
||||||
|
f"В каком году в документе описывается ключевое событие ({years[0]}) и как это подтверждается контекстом?"
|
||||||
|
if years
|
||||||
|
else "Есть ли в документе указание на год события? Если да, какой именно год упомянут?"
|
||||||
|
)
|
||||||
|
date_q = (
|
||||||
|
f"Какая дата ({dates[0]}) встречается в документе и к какому событию/разделу она относится?"
|
||||||
|
if dates
|
||||||
|
else "Какие календарные даты или периоды (если есть) упомянуты в документе?"
|
||||||
|
)
|
||||||
|
num_q = (
|
||||||
|
f"Какое числовое значение ({numbers[0]}) встречается в документе и в каком контексте оно используется?"
|
||||||
|
if numbers
|
||||||
|
else "Есть ли в документе количественные показатели (суммы, проценты, номера, объемы) и что они обозначают?"
|
||||||
|
)
|
||||||
|
entity = first_unique(quoted, first_unique(org_like, Path(remote_path).name))
|
||||||
|
topic_hint = Path(remote_path).stem.replace("_", " ").replace("-", " ")
|
||||||
|
topic_hint = " ".join(topic_hint.split())[:120]
|
||||||
|
entity_q = f"Что в документе говорится про «{entity}»?"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"Entity/Fact Recall (Response Relevance)": [
|
||||||
|
f"Что известно про «{entity}» в материалах базы?",
|
||||||
|
f"В контексте темы «{topic_hint}» кто выступает ключевым участником и какова его роль?",
|
||||||
|
],
|
||||||
|
"Numerical & Temporal Precision": [
|
||||||
|
year_q.replace("в документе", "в материалах").replace("документе", "материалах"),
|
||||||
|
date_q.replace("в документе", "в материалах").replace("документе", "материалах"),
|
||||||
|
num_q.replace("в документе", "в материалах").replace("документе", "материалах"),
|
||||||
|
],
|
||||||
|
"Context Precision (Evidence-anchored)": [
|
||||||
|
f"Найди в базе фрагмент, который лучше всего подтверждает тезис по теме «{topic_hint}», и объясни его релевантность.",
|
||||||
|
f"Есть ли в базе схожие по теме «{topic_hint}», но нерелевантные фрагменты, которые можно ошибочно выбрать?",
|
||||||
|
],
|
||||||
|
"Faithfulness / Non-hallucination": [
|
||||||
|
f"Какая информация по теме «{topic_hint}» отсутствует в найденном контексте и не должна быть додумана?",
|
||||||
|
f"Если прямого ответа по теме «{topic_hint}» в материалах нет, как корректно ответить без галлюцинаций?",
|
||||||
|
],
|
||||||
|
"Reasoning & Synthesis": [
|
||||||
|
f"Сформулируй краткий вывод по теме «{topic_hint}» в 2-3 пунктах, опираясь на несколько найденных фрагментов.",
|
||||||
|
f"Какие ограничения, риски или условия по теме «{topic_hint}» упоминаются в материалах, и как они влияют на вывод?",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def extract_document_text(docs: list[Any]) -> str:
|
||||||
|
chunks: list[str] = []
|
||||||
|
for doc in docs:
|
||||||
|
content = getattr(doc, "page_content", None)
|
||||||
|
if content is None:
|
||||||
|
content = getattr(doc, "text", None)
|
||||||
|
if isinstance(content, str) and content.strip():
|
||||||
|
chunks.append(content.strip())
|
||||||
|
if len(" ".join(chunks)) > 25000:
|
||||||
|
break
|
||||||
|
return "\n".join(chunks)[:25000]
|
||||||
|
|
||||||
|
|
||||||
|
def download_yadisk_file(remote_path: str, token: str, local_path: str) -> None:
|
||||||
|
headers = {"Authorization": f"OAuth {token}"}
|
||||||
|
response = requests.get(
|
||||||
|
"https://cloud-api.yandex.net/v1/disk/resources/download",
|
||||||
|
headers=headers,
|
||||||
|
params={"path": remote_path},
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
href = response.json()["href"]
|
||||||
|
file_response = requests.get(href, timeout=180)
|
||||||
|
file_response.raise_for_status()
|
||||||
|
with open(local_path, "wb") as f:
|
||||||
|
f.write(file_response.content)
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_text_from_yadisk(remote_path: str, token: str) -> str:
|
||||||
|
suffix = Path(remote_path).suffix
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
|
||||||
|
local_path = tmp.name
|
||||||
|
try:
|
||||||
|
download_yadisk_file(remote_path, token, local_path)
|
||||||
|
loader = get_loader(local_path)
|
||||||
|
if loader is None:
|
||||||
|
return ""
|
||||||
|
docs = loader.load()
|
||||||
|
return extract_document_text(docs)
|
||||||
|
finally:
|
||||||
|
if os.path.exists(local_path):
|
||||||
|
os.unlink(local_path)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
load_dotenv(LANGCHAIN_DIR / ".env")
|
||||||
|
qdrant_host = os.getenv("QDRANT_HOST")
|
||||||
|
qdrant_rest_port = int(os.getenv("QDRANT_REST_PORT", "6333"))
|
||||||
|
yadisk_token = os.getenv("YADISK_TOKEN", "").strip()
|
||||||
|
if not qdrant_host:
|
||||||
|
raise RuntimeError("QDRANT_HOST is missing in langchain .env")
|
||||||
|
if not yadisk_token:
|
||||||
|
raise RuntimeError("YADISK_TOKEN is missing in langchain .env")
|
||||||
|
|
||||||
|
with YADISK_JSON.open("r", encoding="utf-8") as f:
|
||||||
|
raw_paths = json.load(f)
|
||||||
|
if not isinstance(raw_paths, list):
|
||||||
|
raise RuntimeError("yadisk_files.json must be a JSON list of paths")
|
||||||
|
all_paths = [str(x) for x in raw_paths if isinstance(x, str)]
|
||||||
|
|
||||||
|
allowed_ext = supported_loader_extensions()
|
||||||
|
filtered_by_ext = [
|
||||||
|
p for p in all_paths if Path(p).suffix.lower() in allowed_ext and p.startswith("disk:/")
|
||||||
|
]
|
||||||
|
|
||||||
|
client = QdrantClient(host=qdrant_host, port=qdrant_rest_port, timeout=60)
|
||||||
|
langchain_paths = collect_langchain_paths(client)
|
||||||
|
llama_filenames = collect_llama_filenames(client)
|
||||||
|
|
||||||
|
candidates = []
|
||||||
|
for path in filtered_by_ext:
|
||||||
|
if path not in langchain_paths:
|
||||||
|
continue
|
||||||
|
if llama_prefect_filename(path) not in llama_filenames:
|
||||||
|
continue
|
||||||
|
candidates.append(path)
|
||||||
|
|
||||||
|
random.seed(42)
|
||||||
|
random.shuffle(candidates)
|
||||||
|
if len(candidates) < 100:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Only {len(candidates)} candidate documents found in both collections; need 100"
|
||||||
|
)
|
||||||
|
|
||||||
|
rows: list[dict[str, Any]] = []
|
||||||
|
attempts = 0
|
||||||
|
for remote_path in candidates:
|
||||||
|
if len(rows) >= 100:
|
||||||
|
break
|
||||||
|
attempts += 1
|
||||||
|
idx = len(rows) + 1
|
||||||
|
print(f"[TRY {attempts:03d}] loading {remote_path}")
|
||||||
|
try:
|
||||||
|
text = fetch_text_from_yadisk(remote_path, yadisk_token)
|
||||||
|
except Exception as e:
|
||||||
|
print(f" -> skip (download/read error): {e}")
|
||||||
|
continue
|
||||||
|
if not text.strip():
|
||||||
|
print(" -> skip (empty extracted text)")
|
||||||
|
continue
|
||||||
|
rows.append(
|
||||||
|
{
|
||||||
|
"index": idx,
|
||||||
|
"path": remote_path,
|
||||||
|
"questions": build_questions(remote_path, text),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
print(f"[OK {idx:03d}/100] prepared questions for {remote_path}")
|
||||||
|
|
||||||
|
if len(rows) < 100:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Only {len(rows)} documents were successfully downloaded/read and turned into questions"
|
||||||
|
)
|
||||||
|
|
||||||
|
lines: list[str] = []
|
||||||
|
lines.append("# DOCUMENTS_TO_TEST")
|
||||||
|
lines.append("")
|
||||||
|
lines.append("This dataset contains 100 YaDisk documents that were verified as present in both:")
|
||||||
|
lines.append("- `documents_langchain` (Qdrant)")
|
||||||
|
lines.append("- `documents_llamaindex` (Qdrant)")
|
||||||
|
lines.append("")
|
||||||
|
lines.append("Question sections are aligned with common RAG evaluation themes (retrieval + generation):")
|
||||||
|
lines.append("- Response relevance / entity-fact recall")
|
||||||
|
lines.append("- Numerical and temporal precision")
|
||||||
|
lines.append("- Context precision")
|
||||||
|
lines.append("- Faithfulness / non-hallucination")
|
||||||
|
lines.append("- Reasoning / synthesis")
|
||||||
|
lines.append("")
|
||||||
|
lines.append(
|
||||||
|
"_References used for evaluation themes: RAGAS metrics and NVIDIA RAG pipeline evaluation docs._"
|
||||||
|
)
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
lines.append(f"## {row['index']:03d}. `{row['path']}`")
|
||||||
|
lines.append("")
|
||||||
|
for section, qs in row["questions"].items():
|
||||||
|
lines.append(f"### {section}")
|
||||||
|
for q in qs:
|
||||||
|
lines.append(f"- {q}")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
OUTPUT_MD.write_text("\n".join(lines), encoding="utf-8")
|
||||||
|
print(f"Written: {OUTPUT_MD}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
732
rag_evaluation.py
Normal file
732
rag_evaluation.py
Normal file
@@ -0,0 +1,732 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
RAG evaluation script (file-batch mode).
|
||||||
|
|
||||||
|
Key behavior:
|
||||||
|
- Step = one document file (all its questions), not one question.
|
||||||
|
- Pre-download/caching in ./tmp/rag-evaluation (skip if already downloaded).
|
||||||
|
- Sequential API calls only (LangChain then LlamaIndex).
|
||||||
|
- Pairwise answer evaluation (both systems in one judge prompt).
|
||||||
|
- JSON output with append/overwrite support for batch runs and re-runs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import datetime as dt
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
try:
|
||||||
|
import requests
|
||||||
|
except ImportError as e: # pragma: no cover
|
||||||
|
raise SystemExit(
|
||||||
|
"Missing dependency: requests. Run with your project venv "
|
||||||
|
"(for example services/rag/langchain/venv/bin/python rag_evaluation.py ...)"
|
||||||
|
) from e
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Configuration
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
LANGCHAIN_URL = os.getenv("LANGCHAIN_URL", "http://localhost:8331/api/test-query")
|
||||||
|
LLAMAINDEX_URL = os.getenv("LLAMAINDEX_URL", "http://localhost:8334/api/test-query")
|
||||||
|
|
||||||
|
# OpenAI-compatible evaluator endpoint. You can point this at OpenAI-compatible providers.
|
||||||
|
OPENAI_CHAT_URL = os.getenv(
|
||||||
|
"OPENAI_CHAT_URL", "https://foundation-models.api.cloud.ru/v1"
|
||||||
|
)
|
||||||
|
OPENAI_CHAT_KEY = os.getenv("OPENAI_CHAT_KEY", "")
|
||||||
|
OPENAI_CHAT_MODEL = os.getenv("OPENAI_CHAT_MODEL", "MiniMaxAI/MiniMax-M2")
|
||||||
|
|
||||||
|
YADISK_TOKEN = os.getenv("YADISK_TOKEN", "")
|
||||||
|
|
||||||
|
BASE_DIR = Path(__file__).resolve().parent
|
||||||
|
INPUT_MD = BASE_DIR / "DOCUMENTS_TO_TEST.md"
|
||||||
|
OUTPUT_JSON = BASE_DIR / "EVALUATION_RESULT.json"
|
||||||
|
TMP_DIR = BASE_DIR / "tmp" / "rag-evaluation"
|
||||||
|
|
||||||
|
RAG_TIMEOUT = int(os.getenv("RAG_TIMEOUT", "120"))
|
||||||
|
EVAL_TIMEOUT = int(os.getenv("EVAL_TIMEOUT", "90"))
|
||||||
|
YADISK_META_TIMEOUT = int(os.getenv("YADISK_META_TIMEOUT", "30"))
|
||||||
|
YADISK_DOWNLOAD_TIMEOUT = int(os.getenv("YADISK_DOWNLOAD_TIMEOUT", "180"))
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Data structures
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class QuestionResult:
|
||||||
|
section: str
|
||||||
|
question: str
|
||||||
|
langchain_answer: str = ""
|
||||||
|
llamaindex_answer: str = ""
|
||||||
|
langchain_score: float = 0.0
|
||||||
|
llamaindex_score: float = 0.0
|
||||||
|
winner: str = "Tie"
|
||||||
|
rationale: str = ""
|
||||||
|
evaluator_model: str = ""
|
||||||
|
evaluated_at: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DocumentEvaluation:
|
||||||
|
index: int
|
||||||
|
path: str
|
||||||
|
cache_file: str = ""
|
||||||
|
cache_status: str = ""
|
||||||
|
questions: list[QuestionResult] = field(default_factory=list)
|
||||||
|
started_at: str = ""
|
||||||
|
finished_at: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Markdown parsing
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def split_documents(md_text: str) -> tuple[list[str], list[str]]:
|
||||||
|
lines = md_text.splitlines()
|
||||||
|
header: list[str] = []
|
||||||
|
docs: list[list[str]] = []
|
||||||
|
current: list[str] | None = None
|
||||||
|
for line in lines:
|
||||||
|
if line.startswith("## "):
|
||||||
|
if current is not None:
|
||||||
|
docs.append(current)
|
||||||
|
current = [line]
|
||||||
|
else:
|
||||||
|
if current is None:
|
||||||
|
header.append(line)
|
||||||
|
else:
|
||||||
|
current.append(line)
|
||||||
|
if current is not None:
|
||||||
|
docs.append(current)
|
||||||
|
return header, ["\n".join(d) for d in docs]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_document_block(idx: int, block: str) -> tuple[str, list[QuestionResult]]:
|
||||||
|
lines = block.splitlines()
|
||||||
|
header = lines[0].strip()
|
||||||
|
m = re.search(r"`([^`]+)`", header)
|
||||||
|
doc_path = m.group(1) if m else ""
|
||||||
|
|
||||||
|
section = ""
|
||||||
|
questions: list[QuestionResult] = []
|
||||||
|
for line in lines[1:]:
|
||||||
|
if line.startswith("### "):
|
||||||
|
section = line[4:].strip()
|
||||||
|
elif line.startswith("- "):
|
||||||
|
q = line[2:].strip()
|
||||||
|
if q:
|
||||||
|
questions.append(QuestionResult(section=section, question=q))
|
||||||
|
return doc_path, questions
|
||||||
|
|
||||||
|
|
||||||
|
def parse_all_docs(md_path: Path) -> list[tuple[int, str, list[QuestionResult]]]:
|
||||||
|
raw = md_path.read_text(encoding="utf-8")
|
||||||
|
_, blocks = split_documents(raw)
|
||||||
|
parsed: list[tuple[int, str, list[QuestionResult]]] = []
|
||||||
|
for i, block in enumerate(blocks, start=1):
|
||||||
|
path, questions = parse_document_block(i, block)
|
||||||
|
parsed.append((i, path, questions))
|
||||||
|
return parsed
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Caching / Yandex Disk
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def cache_file_name(remote_path: str) -> str:
|
||||||
|
# Deterministic local cache filename
|
||||||
|
digest = re.sub(r"[^a-z0-9]", "", str(abs(hash(remote_path))))[:12]
|
||||||
|
suffix = Path(remote_path).suffix or ".bin"
|
||||||
|
return f"{digest}{suffix}"
|
||||||
|
|
||||||
|
|
||||||
|
def download_yadisk_to_cache(remote_path: str, token: str, cache_path: Path) -> str:
|
||||||
|
"""
|
||||||
|
Download file into cache path if missing.
|
||||||
|
Returns status: "cached_existing" | "downloaded" | "error:..."
|
||||||
|
"""
|
||||||
|
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
if cache_path.exists() and cache_path.stat().st_size > 0:
|
||||||
|
return "cached_existing"
|
||||||
|
if not token:
|
||||||
|
return "error:missing_yadisk_token"
|
||||||
|
|
||||||
|
headers = {"Authorization": f"OAuth {token}"}
|
||||||
|
try:
|
||||||
|
r = requests.get(
|
||||||
|
"https://cloud-api.yandex.net/v1/disk/resources/download",
|
||||||
|
headers=headers,
|
||||||
|
params={"path": remote_path},
|
||||||
|
timeout=YADISK_META_TIMEOUT,
|
||||||
|
)
|
||||||
|
r.raise_for_status()
|
||||||
|
href = r.json()["href"]
|
||||||
|
f = requests.get(href, timeout=YADISK_DOWNLOAD_TIMEOUT)
|
||||||
|
f.raise_for_status()
|
||||||
|
cache_path.write_bytes(f.content)
|
||||||
|
if cache_path.stat().st_size == 0:
|
||||||
|
return "error:empty_download"
|
||||||
|
return "downloaded"
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
return f"error:{e}"
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# File text extraction (for evaluator context)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def extract_text_from_file(path: Path) -> str:
|
||||||
|
ext = path.suffix.lower()
|
||||||
|
if ext in {".txt", ".md", ".csv", ".json", ".xml", ".html", ".htm"}:
|
||||||
|
return path.read_text(encoding="utf-8", errors="ignore")
|
||||||
|
|
||||||
|
if ext in {".docx", ".doc"}:
|
||||||
|
try:
|
||||||
|
from docx import Document # type: ignore
|
||||||
|
|
||||||
|
doc = Document(str(path))
|
||||||
|
return "\n".join(p.text for p in doc.paragraphs)
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
return f"[DOC parse error: {e}]"
|
||||||
|
|
||||||
|
if ext == ".pdf":
|
||||||
|
try:
|
||||||
|
import PyPDF2 # type: ignore
|
||||||
|
|
||||||
|
out: list[str] = []
|
||||||
|
with path.open("rb") as f:
|
||||||
|
reader = PyPDF2.PdfReader(f)
|
||||||
|
for page in reader.pages:
|
||||||
|
out.append(page.extract_text() or "")
|
||||||
|
return "\n".join(out)
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
return f"[PDF parse error: {e}]"
|
||||||
|
|
||||||
|
if ext in {".xlsx", ".xls"}:
|
||||||
|
try:
|
||||||
|
from openpyxl import load_workbook # type: ignore
|
||||||
|
|
||||||
|
wb = load_workbook(str(path), read_only=True)
|
||||||
|
out: list[str] = []
|
||||||
|
for ws in wb.worksheets:
|
||||||
|
for row in ws.iter_rows(values_only=True):
|
||||||
|
out.append("\t".join("" if c is None else str(c) for c in row))
|
||||||
|
if len(out) > 5000:
|
||||||
|
break
|
||||||
|
if len(out) > 5000:
|
||||||
|
break
|
||||||
|
return "\n".join(out)
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
return f"[XLS parse error: {e}]"
|
||||||
|
|
||||||
|
# fallback
|
||||||
|
try:
|
||||||
|
return path.read_text(encoding="utf-8", errors="ignore")
|
||||||
|
except Exception:
|
||||||
|
return f"[Binary file: {path.name}]"
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# RAG API calls (sequential)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def call_rag(url: str, query: str, timeout: int) -> str:
|
||||||
|
payload = {"query": query}
|
||||||
|
try:
|
||||||
|
r = requests.post(url, json=payload, timeout=timeout)
|
||||||
|
r.raise_for_status()
|
||||||
|
data = r.json()
|
||||||
|
text = data.get("response", "")
|
||||||
|
if text is None:
|
||||||
|
return ""
|
||||||
|
return str(text).strip()
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
return f"ERROR: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
def call_langchain(query: str, timeout: int) -> str:
|
||||||
|
return call_rag(LANGCHAIN_URL, query, timeout)
|
||||||
|
|
||||||
|
|
||||||
|
def call_llamaindex(query: str, timeout: int) -> str:
|
||||||
|
payload = {"query": query, "mode": "agent"}
|
||||||
|
try:
|
||||||
|
r = requests.post(LLAMAINDEX_URL, json=payload, timeout=timeout)
|
||||||
|
r.raise_for_status()
|
||||||
|
data = r.json()
|
||||||
|
text = data.get("response", "")
|
||||||
|
if text is None:
|
||||||
|
return ""
|
||||||
|
return str(text).strip()
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
return f"ERROR: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Evaluator
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def _rule_score(answer: str) -> float:
|
||||||
|
if not answer or not answer.strip():
|
||||||
|
return 0.0
|
||||||
|
if answer.startswith("ERROR:"):
|
||||||
|
return -1.0
|
||||||
|
score = 0.3
|
||||||
|
if len(answer) > 120:
|
||||||
|
score += 0.2
|
||||||
|
if re.search(r"\d", answer):
|
||||||
|
score += 0.1
|
||||||
|
if re.search(r"[.!?]", answer):
|
||||||
|
score += 0.1
|
||||||
|
if re.search(r"(не найден|недостаточно|нет информации)", answer.lower()):
|
||||||
|
score += 0.05
|
||||||
|
return min(1.0, score)
|
||||||
|
|
||||||
|
|
||||||
|
SECTION_CRITERIA: dict[str, str] = {
|
||||||
|
"Entity/Fact Recall (Response Relevance)": "Оцени точность извлечения сущностей/фактов и релевантность вопросу.",
|
||||||
|
"Numerical & Temporal Precision": "Оцени точность чисел, дат, периодов и временных связей.",
|
||||||
|
"Context Precision (Evidence-anchored)": "Оцени, насколько ответ опирается на релевантный контекст без лишнего.",
|
||||||
|
"Faithfulness / Non-hallucination": "Оцени отсутствие галлюцинаций и корректное поведение при отсутствии фактов.",
|
||||||
|
"Reasoning & Synthesis": "Оцени качество синтеза фактов и логичность итогового вывода.",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def build_pair_eval_prompt(
|
||||||
|
question: str,
|
||||||
|
section: str,
|
||||||
|
langchain_answer: str,
|
||||||
|
llamaindex_answer: str,
|
||||||
|
document_text: str,
|
||||||
|
) -> str:
|
||||||
|
criteria = SECTION_CRITERIA.get(
|
||||||
|
section, "Оцени релевантность, точность и полезность."
|
||||||
|
)
|
||||||
|
context = document_text[:9000]
|
||||||
|
return f"""Ты судья качества RAG-ответов. Сравни два ответа на один вопрос.
|
||||||
|
|
||||||
|
Вопрос:
|
||||||
|
{question}
|
||||||
|
|
||||||
|
Секция оценки:
|
||||||
|
{section}
|
||||||
|
Критерий:
|
||||||
|
{criteria}
|
||||||
|
|
||||||
|
Ответ A (LangChain):
|
||||||
|
{langchain_answer}
|
||||||
|
|
||||||
|
Ответ B (LlamaIndex):
|
||||||
|
{llamaindex_answer}
|
||||||
|
|
||||||
|
Опорный контекст документа:
|
||||||
|
{context}
|
||||||
|
|
||||||
|
Верни ТОЛЬКО JSON:
|
||||||
|
{{
|
||||||
|
"langchain_score": <float от -1.0 до 1.0>,
|
||||||
|
"llamaindex_score": <float от -1.0 до 1.0>,
|
||||||
|
"winner": "LangChain|LlamaIndex|Tie",
|
||||||
|
"rationale": "<кратко по сути>"
|
||||||
|
}}
|
||||||
|
|
||||||
|
Правила:
|
||||||
|
- Технические ошибки/таймауты должны получать -1.0.
|
||||||
|
- Пустой ответ без ошибки = 0.0.
|
||||||
|
- Галлюцинации сильно штрафуются.
|
||||||
|
- Если разница незначительная, выбирай Tie.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_pair_with_llm(
|
||||||
|
question: str,
|
||||||
|
section: str,
|
||||||
|
langchain_answer: str,
|
||||||
|
llamaindex_answer: str,
|
||||||
|
document_text: str,
|
||||||
|
) -> tuple[float, float, str, str]:
|
||||||
|
# Deterministic short-circuit for technical failures
|
||||||
|
if langchain_answer.startswith("ERROR:") and llamaindex_answer.startswith("ERROR:"):
|
||||||
|
return -1.0, -1.0, "Tie", "Обе системы вернули техническую ошибку."
|
||||||
|
if langchain_answer.startswith("ERROR:"):
|
||||||
|
return (
|
||||||
|
-1.0,
|
||||||
|
_rule_score(llamaindex_answer),
|
||||||
|
"LlamaIndex",
|
||||||
|
"LangChain технически не ответил.",
|
||||||
|
)
|
||||||
|
if llamaindex_answer.startswith("ERROR:"):
|
||||||
|
return (
|
||||||
|
_rule_score(langchain_answer),
|
||||||
|
-1.0,
|
||||||
|
"LangChain",
|
||||||
|
"LlamaIndex технически не ответил.",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not OPENAI_CHAT_KEY:
|
||||||
|
# fallback heuristic
|
||||||
|
lc = _rule_score(langchain_answer)
|
||||||
|
li = _rule_score(llamaindex_answer)
|
||||||
|
if abs(lc - li) < 0.05:
|
||||||
|
return lc, li, "Tie", "Эвристическая оценка без LLM (ключ не задан)."
|
||||||
|
return (
|
||||||
|
(lc, li, "LangChain", "Эвристическая оценка без LLM.")
|
||||||
|
if lc > li
|
||||||
|
else (
|
||||||
|
lc,
|
||||||
|
li,
|
||||||
|
"LlamaIndex",
|
||||||
|
"Эвристическая оценка без LLM.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = build_pair_eval_prompt(
|
||||||
|
question=question,
|
||||||
|
section=section,
|
||||||
|
langchain_answer=langchain_answer,
|
||||||
|
llamaindex_answer=llamaindex_answer,
|
||||||
|
document_text=document_text,
|
||||||
|
)
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {OPENAI_CHAT_KEY}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
payload = {
|
||||||
|
"model": OPENAI_CHAT_MODEL,
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "Ты строгий судья качества RAG. Отвечай только JSON.",
|
||||||
|
},
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
],
|
||||||
|
"temperature": 0.0,
|
||||||
|
"max_tokens": 400,
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
r = requests.post(
|
||||||
|
f"{OPENAI_CHAT_URL.rstrip('/')}/chat/completions",
|
||||||
|
headers=headers,
|
||||||
|
json=payload,
|
||||||
|
timeout=EVAL_TIMEOUT,
|
||||||
|
)
|
||||||
|
r.raise_for_status()
|
||||||
|
data = r.json()
|
||||||
|
content = data.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||||
|
m = re.search(r"\{.*\}", content, re.DOTALL)
|
||||||
|
raw = m.group(0) if m else content
|
||||||
|
parsed = json.loads(raw)
|
||||||
|
lc = float(parsed.get("langchain_score", 0.0))
|
||||||
|
li = float(parsed.get("llamaindex_score", 0.0))
|
||||||
|
winner = str(parsed.get("winner", "Tie"))
|
||||||
|
rationale = str(parsed.get("rationale", ""))
|
||||||
|
if winner not in {"LangChain", "LlamaIndex", "Tie"}:
|
||||||
|
winner = "Tie"
|
||||||
|
return lc, li, winner, rationale
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
lc = _rule_score(langchain_answer)
|
||||||
|
li = _rule_score(llamaindex_answer)
|
||||||
|
if abs(lc - li) < 0.05:
|
||||||
|
return lc, li, "Tie", f"Fallback heuristic; LLM eval error: {e}"
|
||||||
|
return (
|
||||||
|
(lc, li, "LangChain", f"Fallback heuristic; LLM eval error: {e}")
|
||||||
|
if lc > li
|
||||||
|
else (
|
||||||
|
lc,
|
||||||
|
li,
|
||||||
|
"LlamaIndex",
|
||||||
|
f"Fallback heuristic; LLM eval error: {e}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# JSON storage
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def now_iso() -> str:
|
||||||
|
return dt.datetime.now(dt.timezone.utc).isoformat()
|
||||||
|
|
||||||
|
|
||||||
|
def default_json_payload(
|
||||||
|
all_docs: list[tuple[int, str, list[QuestionResult]]],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"meta": {
|
||||||
|
"created_at": now_iso(),
|
||||||
|
"updated_at": now_iso(),
|
||||||
|
"input_file": str(INPUT_MD),
|
||||||
|
"langchain_url": LANGCHAIN_URL,
|
||||||
|
"llamaindex_url": LLAMAINDEX_URL,
|
||||||
|
"evaluator_model": OPENAI_CHAT_MODEL,
|
||||||
|
"notes": [
|
||||||
|
"step = one file (all file questions)",
|
||||||
|
"sequential API calls only",
|
||||||
|
"cache dir: ./tmp/rag-evaluation",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"documents": [
|
||||||
|
{
|
||||||
|
"index": idx,
|
||||||
|
"path": path,
|
||||||
|
"cache_file": "",
|
||||||
|
"cache_status": "not_processed",
|
||||||
|
"started_at": "",
|
||||||
|
"finished_at": "",
|
||||||
|
"questions": [asdict(q) for q in questions],
|
||||||
|
}
|
||||||
|
for idx, path, questions in all_docs
|
||||||
|
],
|
||||||
|
"batches": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def load_or_init_json(
|
||||||
|
all_docs: list[tuple[int, str, list[QuestionResult]]],
|
||||||
|
output_json: Path,
|
||||||
|
mode: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
if mode == "overwrite" or not output_json.exists():
|
||||||
|
return default_json_payload(all_docs)
|
||||||
|
try:
|
||||||
|
data = json.loads(output_json.read_text(encoding="utf-8"))
|
||||||
|
if "documents" not in data:
|
||||||
|
return default_json_payload(all_docs)
|
||||||
|
return data
|
||||||
|
except Exception:
|
||||||
|
return default_json_payload(all_docs)
|
||||||
|
|
||||||
|
|
||||||
|
def upsert_document_result(store: dict[str, Any], result: DocumentEvaluation) -> None:
|
||||||
|
docs = store.setdefault("documents", [])
|
||||||
|
for i, doc in enumerate(docs):
|
||||||
|
if doc.get("path") == result.path:
|
||||||
|
docs[i] = {
|
||||||
|
"index": result.index,
|
||||||
|
"path": result.path,
|
||||||
|
"cache_file": result.cache_file,
|
||||||
|
"cache_status": result.cache_status,
|
||||||
|
"started_at": result.started_at,
|
||||||
|
"finished_at": result.finished_at,
|
||||||
|
"questions": [asdict(q) for q in result.questions],
|
||||||
|
}
|
||||||
|
return
|
||||||
|
docs.append(
|
||||||
|
{
|
||||||
|
"index": result.index,
|
||||||
|
"path": result.path,
|
||||||
|
"cache_file": result.cache_file,
|
||||||
|
"cache_status": result.cache_status,
|
||||||
|
"started_at": result.started_at,
|
||||||
|
"finished_at": result.finished_at,
|
||||||
|
"questions": [asdict(q) for q in result.questions],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def update_batch_stats(store: dict[str, Any], batch_meta: dict[str, Any]) -> None:
|
||||||
|
store.setdefault("batches", []).append(batch_meta)
|
||||||
|
store.setdefault("meta", {})["updated_at"] = now_iso()
|
||||||
|
|
||||||
|
|
||||||
|
def atomic_write_json(path: Path, payload: dict[str, Any]) -> None:
|
||||||
|
"""Atomically write JSON to avoid partial/corrupted files on interruption."""
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||||
|
tmp_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||||
|
tmp_path.replace(path)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_batch_summary(results: list[DocumentEvaluation]) -> dict[str, Any]:
|
||||||
|
wins = {"LangChain": 0, "LlamaIndex": 0, "Tie": 0}
|
||||||
|
scores_lc: list[float] = []
|
||||||
|
scores_li: list[float] = []
|
||||||
|
q_total = 0
|
||||||
|
for d in results:
|
||||||
|
for q in d.questions:
|
||||||
|
q_total += 1
|
||||||
|
wins[q.winner] = wins.get(q.winner, 0) + 1
|
||||||
|
scores_lc.append(q.langchain_score)
|
||||||
|
scores_li.append(q.llamaindex_score)
|
||||||
|
avg_lc = sum(scores_lc) / max(1, len(scores_lc))
|
||||||
|
avg_li = sum(scores_li) / max(1, len(scores_li))
|
||||||
|
if avg_lc > avg_li + 0.01:
|
||||||
|
ranking = "LangChain"
|
||||||
|
elif avg_li > avg_lc + 0.01:
|
||||||
|
ranking = "LlamaIndex"
|
||||||
|
else:
|
||||||
|
ranking = "Tie"
|
||||||
|
return {
|
||||||
|
"documents_processed": len(results),
|
||||||
|
"questions_processed": q_total,
|
||||||
|
"wins": wins,
|
||||||
|
"avg_langchain": round(avg_lc, 4),
|
||||||
|
"avg_llamaindex": round(avg_li, 4),
|
||||||
|
"ranking": ranking,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Main flow
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def run_evaluation(doc_from: int, doc_to: int, mode: str) -> None:
|
||||||
|
all_docs = parse_all_docs(INPUT_MD)
|
||||||
|
total_docs = len(all_docs)
|
||||||
|
doc_from = max(1, doc_from)
|
||||||
|
doc_to = min(total_docs, doc_to)
|
||||||
|
if doc_from > doc_to:
|
||||||
|
raise ValueError(f"Invalid doc range: {doc_from}:{doc_to}")
|
||||||
|
|
||||||
|
store = load_or_init_json(all_docs, OUTPUT_JSON, mode)
|
||||||
|
|
||||||
|
TMP_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
selected = [d for d in all_docs if doc_from <= d[0] <= doc_to]
|
||||||
|
print(
|
||||||
|
f"Total docs: {total_docs}. Processing docs {doc_from}:{doc_to} ({len(selected)} steps)."
|
||||||
|
)
|
||||||
|
print(f"Cache dir: {TMP_DIR}")
|
||||||
|
print(f"Output JSON: {OUTPUT_JSON}")
|
||||||
|
|
||||||
|
batch_results: list[DocumentEvaluation] = []
|
||||||
|
batch_started = now_iso()
|
||||||
|
|
||||||
|
for step, (idx, doc_path, questions) in enumerate(selected, start=1):
|
||||||
|
print(f"\n[STEP {step}/{len(selected)}] File #{idx}: {doc_path}")
|
||||||
|
started = now_iso()
|
||||||
|
cache_name = cache_file_name(doc_path)
|
||||||
|
cache_path = TMP_DIR / cache_name
|
||||||
|
cache_status = download_yadisk_to_cache(doc_path, YADISK_TOKEN, cache_path)
|
||||||
|
print(f" -> cache: {cache_status} ({cache_path})")
|
||||||
|
|
||||||
|
doc_text = ""
|
||||||
|
if cache_status.startswith("error:"):
|
||||||
|
doc_text = f"[CACHE_ERROR] {cache_status}"
|
||||||
|
else:
|
||||||
|
doc_text = extract_text_from_file(cache_path)
|
||||||
|
print(f" -> extracted text length: {len(doc_text)}")
|
||||||
|
|
||||||
|
evaluated_questions: list[QuestionResult] = []
|
||||||
|
for qn, q in enumerate(questions, start=1):
|
||||||
|
qr = QuestionResult(section=q.section, question=q.question)
|
||||||
|
print(f" [{qn}/{len(questions)}] {q.question[:90]}")
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
qr.langchain_answer = call_langchain(q.question, timeout=RAG_TIMEOUT)
|
||||||
|
print(f" LangChain: {time.time() - t0:.1f}s")
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
qr.llamaindex_answer = call_llamaindex(q.question, timeout=RAG_TIMEOUT)
|
||||||
|
print(f" LlamaIndex: {time.time() - t0:.1f}s")
|
||||||
|
|
||||||
|
lc, li, winner, rationale = evaluate_pair_with_llm(
|
||||||
|
question=q.question,
|
||||||
|
section=q.section,
|
||||||
|
langchain_answer=qr.langchain_answer,
|
||||||
|
llamaindex_answer=qr.llamaindex_answer,
|
||||||
|
document_text=doc_text,
|
||||||
|
)
|
||||||
|
qr.langchain_score = lc
|
||||||
|
qr.llamaindex_score = li
|
||||||
|
qr.winner = winner
|
||||||
|
qr.rationale = rationale
|
||||||
|
qr.evaluator_model = OPENAI_CHAT_MODEL
|
||||||
|
qr.evaluated_at = now_iso()
|
||||||
|
evaluated_questions.append(qr)
|
||||||
|
|
||||||
|
doc_result = DocumentEvaluation(
|
||||||
|
index=idx,
|
||||||
|
path=doc_path,
|
||||||
|
cache_file=str(cache_path),
|
||||||
|
cache_status=cache_status,
|
||||||
|
questions=evaluated_questions,
|
||||||
|
started_at=started,
|
||||||
|
finished_at=now_iso(),
|
||||||
|
)
|
||||||
|
upsert_document_result(store, doc_result)
|
||||||
|
batch_results.append(doc_result)
|
||||||
|
|
||||||
|
# Save incremental progress after each file/step
|
||||||
|
atomic_write_json(OUTPUT_JSON, store)
|
||||||
|
print(" -> step saved")
|
||||||
|
|
||||||
|
summary = compute_batch_summary(batch_results)
|
||||||
|
batch_meta = {
|
||||||
|
"started_at": batch_started,
|
||||||
|
"finished_at": now_iso(),
|
||||||
|
"range": f"{doc_from}:{doc_to}",
|
||||||
|
"summary": summary,
|
||||||
|
"mode": mode,
|
||||||
|
}
|
||||||
|
update_batch_stats(store, batch_meta)
|
||||||
|
atomic_write_json(OUTPUT_JSON, store)
|
||||||
|
|
||||||
|
print("\nBatch complete.")
|
||||||
|
print(json.dumps(summary, ensure_ascii=False, indent=2))
|
||||||
|
print(f"Saved to: {OUTPUT_JSON}")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_range(value: str) -> tuple[int, int]:
|
||||||
|
m = re.fullmatch(r"(\d+):(\d+)", value.strip())
|
||||||
|
if not m:
|
||||||
|
raise argparse.ArgumentTypeError(
|
||||||
|
"Range must be in format from:to (example: 1:10)"
|
||||||
|
)
|
||||||
|
a, b = int(m.group(1)), int(m.group(2))
|
||||||
|
if a <= 0 or b <= 0:
|
||||||
|
raise argparse.ArgumentTypeError("Range values must be positive")
|
||||||
|
return a, b
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="RAG evaluation in file-batch mode (JSON output)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"doc_range",
|
||||||
|
type=parse_range,
|
||||||
|
help="Document range in format from:to (step = one file). Example: 1:10",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--mode",
|
||||||
|
choices=["append", "overwrite"],
|
||||||
|
default="append",
|
||||||
|
help="append: upsert evaluated docs into existing JSON; overwrite: rebuild JSON from input docs",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
doc_from, doc_to = args.doc_range
|
||||||
|
|
||||||
|
if "MiniMax" in OPENAI_CHAT_MODEL or "MiniMax" in OPENAI_CHAT_URL:
|
||||||
|
print(
|
||||||
|
"NOTE: evaluator model is MiniMax. It works, but for stricter judging quality, "
|
||||||
|
"gpt-4.1-mini/gpt-4.1 (if available on your endpoint) is usually stronger."
|
||||||
|
)
|
||||||
|
|
||||||
|
run_evaluation(doc_from=doc_from, doc_to=doc_to, mode=args.mode)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
8
requirements.txt
Normal file
8
requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
certifi==2026.2.25
|
||||||
|
charset-normalizer==3.4.5
|
||||||
|
dotenv==0.9.9
|
||||||
|
idna==3.11
|
||||||
|
python-dotenv==1.2.2
|
||||||
|
requests==2.32.5
|
||||||
|
urllib3==2.6.3
|
||||||
|
yadisk==3.4.0
|
||||||
BIN
services/.DS_Store
vendored
Normal file
BIN
services/.DS_Store
vendored
Normal file
Binary file not shown.
BIN
services/opensource/.DS_Store
vendored
Normal file
BIN
services/opensource/.DS_Store
vendored
Normal file
Binary file not shown.
1
services/opensource/ragflow
Submodule
1
services/opensource/ragflow
Submodule
Submodule services/opensource/ragflow added at ce71d87867
BIN
services/rag/.DS_Store
vendored
BIN
services/rag/.DS_Store
vendored
Binary file not shown.
BIN
services/rag/langchain/.DS_Store
vendored
BIN
services/rag/langchain/.DS_Store
vendored
Binary file not shown.
@@ -3,6 +3,7 @@ OLLAMA_CHAT_MODEL=MODEL
|
|||||||
OPENAI_CHAT_URL=URL
|
OPENAI_CHAT_URL=URL
|
||||||
OPENAI_CHAT_KEY=KEY
|
OPENAI_CHAT_KEY=KEY
|
||||||
CHAT_MODEL_STRATEGY=ollama
|
CHAT_MODEL_STRATEGY=ollama
|
||||||
|
PREFECT_API_URL=URL
|
||||||
QDRANT_HOST=HOST
|
QDRANT_HOST=HOST
|
||||||
QDRANT_REST_PORT=PORT
|
QDRANT_REST_PORT=PORT
|
||||||
QDRANT_GRPC_PORT=PORT
|
QDRANT_GRPC_PORT=PORT
|
||||||
@@ -14,3 +15,4 @@ ENRICHMENT_PROCESSING_MODE=async/sync
|
|||||||
ENRICHMENT_ADAPTIVE_FILES_QUEUE_LIMIT=5
|
ENRICHMENT_ADAPTIVE_FILES_QUEUE_LIMIT=5
|
||||||
ENRICHMENT_ADAPTIVE_FILE_PROCESS_THREADS=4
|
ENRICHMENT_ADAPTIVE_FILE_PROCESS_THREADS=4
|
||||||
ENRICHMENT_ADAPTIVE_DOCUMENT_UPLOADS_THREADS=4
|
ENRICHMENT_ADAPTIVE_DOCUMENT_UPLOADS_THREADS=4
|
||||||
|
PREFECT_YADISK_ENRICH_CONCURRENCY=8
|
||||||
|
|||||||
3
services/rag/langchain/.gitignore
vendored
3
services/rag/langchain/.gitignore
vendored
@@ -216,3 +216,6 @@ __marimo__/
|
|||||||
.streamlit/secrets.toml
|
.streamlit/secrets.toml
|
||||||
document_tracking.db
|
document_tracking.db
|
||||||
.env.test
|
.env.test
|
||||||
|
|
||||||
|
yadisk_imported_paths.csv
|
||||||
|
yadisk_imported_paths.json
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ Use if possible logging, using library `loguru`, for steps. Use logrotation in f
|
|||||||
|
|
||||||
Chosen RAG framework: Langchain
|
Chosen RAG framework: Langchain
|
||||||
Chosen Vector Storage: Qdrant
|
Chosen Vector Storage: Qdrant
|
||||||
Chosen data folder: relatve ./../../../data - from the current folder
|
Chosen data folder: relative ./../../../data - from the current folder
|
||||||
|
|
||||||
# Phase 1 (cli entrypoint)
|
# Phase 1 (cli entrypoint)
|
||||||
|
|
||||||
@@ -85,7 +85,6 @@ During enrichment, we should use adaptive collection from the helpers, for loadi
|
|||||||
- [x] We still will need filetypes that we will need to skip, so while iterating over files we need to check their extension and skip them.
|
- [x] We still will need filetypes that we will need to skip, so while iterating over files we need to check their extension and skip them.
|
||||||
- [x] Adaptive files has filename in them, so it should be used when extracting metadata
|
- [x] Adaptive files has filename in them, so it should be used when extracting metadata
|
||||||
|
|
||||||
|
|
||||||
# Phase 13 (async processing of files)
|
# Phase 13 (async processing of files)
|
||||||
|
|
||||||
During this Phase we create asynchronous process of enrichment, utilizing async/await
|
During this Phase we create asynchronous process of enrichment, utilizing async/await
|
||||||
@@ -101,3 +100,32 @@ During this Phase we create asynchronous process of enrichment, utilizing async/
|
|||||||
- [x] Function process_adaptive_files_queue should be started in number of threads (defined in .env ENRICHMENT_ADAPTIVE_FILE_PROCESS_THREADS)
|
- [x] Function process_adaptive_files_queue should be started in number of threads (defined in .env ENRICHMENT_ADAPTIVE_FILE_PROCESS_THREADS)
|
||||||
- [x] Function upload_processed_documents_from_queue should be started in number of threads (defined in .env ENRICHMENT_ADAPTIVE_DOCUMENT_UPLOADS_THREADS)
|
- [x] Function upload_processed_documents_from_queue should be started in number of threads (defined in .env ENRICHMENT_ADAPTIVE_DOCUMENT_UPLOADS_THREADS)
|
||||||
- [x] Program should control threads. Function insert_adaptive_files_queue, after adaptive collection ends, then should wait untill all theads finish. What does finish mean? It means when our insert_adaptive_files_queue function realizes that there is no adaptive files left in collection, it marks shared variable between threads, that collection finished. When our other functions in threads sees that this variable became true - they deplete queue and do not go to the next loop to wait for new items in queue, and just finish. This would eventually finish the program. Each thread finishes, and main program too as usual after processing all of things.
|
- [x] Program should control threads. Function insert_adaptive_files_queue, after adaptive collection ends, then should wait untill all theads finish. What does finish mean? It means when our insert_adaptive_files_queue function realizes that there is no adaptive files left in collection, it marks shared variable between threads, that collection finished. When our other functions in threads sees that this variable became true - they deplete queue and do not go to the next loop to wait for new items in queue, and just finish. This would eventually finish the program. Each thread finishes, and main program too as usual after processing all of things.
|
||||||
|
|
||||||
|
# Phase 14 (integration of Prefect client, for creating flow and tasks on remote Prefect server)
|
||||||
|
|
||||||
|
- [x] Install Prefect client library.
|
||||||
|
- [x] Add .env variable PREFECT_API_URL, that will be used for connecting client to the prefect server
|
||||||
|
- [x] Create prefect client file in `prefect/01_yadisk_analyze.py`. In this file we will work with prefect flows and tasks for this phase.
|
||||||
|
- [x] Create prefect flow called "analyze_yadisk_file_urls"
|
||||||
|
- [x] Create prefect task "iterate_yadisk_folder_and_store_file_paths" that will connect to yandex disk with yadisk library, analyze everything inside folder `Общая` recursively and store file paths in the ./../../../yadisk_files.json, in array of strings.
|
||||||
|
- [x] In our pefect file add function for flow to serve, as per prefect documentation on serving flows
|
||||||
|
- [x] Tests will be done manually by hand, by executing this script and checking prefect dashboard. No automatical tests needed for this phase.
|
||||||
|
|
||||||
|
# Phase 15 (prefect enrichment process for langchain, with predefined values, also removal of non-documet formats)
|
||||||
|
|
||||||
|
- [x] Remove for now formats, extensions for images of any kind, archives of any kind, and add possible text documents, documents formats, like .txt, .xlsx, etc. in enrichment processes/functions.
|
||||||
|
- [x] Create prefect client file in `prefect/02_yadisk_predefined_enrich.py`. This file will firt load file from ./../../../yadisk_files.json into array of paths. After that, array of paths will be filtered, and only supported in enrichment extensions will be left. After that, code will iterate through each path in this filtered array, use yadisk library to download file, process it for enrichment, and the remove it after processing. There should be statistics for this, at runtime, with progressbar that shows how many files processed out of how many left. Also, near the progressbar there should be counter of errors. Yes, if there is an error, it should be swallowed, even if it is inside thred or async function.
|
||||||
|
- [x] For yandex disk integration use library yadisk. In .env file there should be variable YADISK_TOKEN for accessing the needed connection
|
||||||
|
- [x] Code for loading should be reflected upon, and then made it so it would be done in async way, with as much as possible simulatenous tasks. yadisk async integration should be used (async features can be checked here: https://pypi.org/project/yadisk/)
|
||||||
|
- [x] No tests for code should be done at this phase, all tests will be done manually, because loading of documents can take a long time for automated test.
|
||||||
|
|
||||||
|
# Phase 16 (making demo ui scalable)
|
||||||
|
|
||||||
|
- [x] Make demo-ui window containable and reusable part of html + js. This part will be used for creating multi-windowed demo ui.
|
||||||
|
- [x] Make tabbed UI with top level tabs. First tab exists and is selected. Each tab should have copy of demo ui, meaning the chat window with ability to specify the api url
|
||||||
|
- [x] At the end of the tabs there should be button with plus sign, which will add new tab. Tabs to be called by numbers.
|
||||||
|
- [x] There should predefined 3 tabs opened. First one should have predefined api url "https://rag.langchain.overwatch.su/api/test-query", second "https://rag.llamaindex.overwatch.su/api/test-query", third "https://rag.haystack.overwatch.su/api/test-query"
|
||||||
|
|
||||||
|
# Phase 17 (creating json with list of documents that are supported for import)
|
||||||
|
|
||||||
|
- [x] Make cli command that takes json file with list of paths, filters them to only those that are being imported into the vector storage (can be checked in enrichment), then this file should be saved in the current folder as "yadisk_imported_paths.json" and in "yadisk_imported_paths.csv" file. In case of CSV - it should be formatted as csv of course.
|
||||||
|
|||||||
@@ -20,6 +20,9 @@ from vector_storage import initialize_vector_store
|
|||||||
# Load environment variables
|
# Load environment variables
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
CHAT_REQUEST_TIMEOUT_SECONDS = float(os.getenv("CHAT_REQUEST_TIMEOUT_SECONDS", "45"))
|
||||||
|
CHAT_MAX_RETRIES = int(os.getenv("CHAT_MAX_RETRIES", "0"))
|
||||||
|
|
||||||
|
|
||||||
def get_llm_model_info(
|
def get_llm_model_info(
|
||||||
llm_model: Optional[str] = None,
|
llm_model: Optional[str] = None,
|
||||||
@@ -149,10 +152,12 @@ def create_chat_agent(
|
|||||||
openai_api_base=base_url_or_api_base,
|
openai_api_base=base_url_or_api_base,
|
||||||
openai_api_key=api_key,
|
openai_api_key=api_key,
|
||||||
temperature=0.1,
|
temperature=0.1,
|
||||||
|
request_timeout=CHAT_REQUEST_TIMEOUT_SECONDS,
|
||||||
|
max_retries=CHAT_MAX_RETRIES,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Using OpenAI-compatible model: {model_name} via {base_url_or_api_base}"
|
f"Using OpenAI-compatible model: {model_name} via {base_url_or_api_base}, timeout={CHAT_REQUEST_TIMEOUT_SECONDS}s, retries={CHAT_MAX_RETRIES}"
|
||||||
)
|
)
|
||||||
else: # Default to ollama
|
else: # Default to ollama
|
||||||
# Initialize the Ollama chat model
|
# Initialize the Ollama chat model
|
||||||
@@ -160,9 +165,13 @@ def create_chat_agent(
|
|||||||
model=model_name,
|
model=model_name,
|
||||||
base_url=base_url_or_api_base, # Default Ollama URL
|
base_url=base_url_or_api_base, # Default Ollama URL
|
||||||
temperature=0.1,
|
temperature=0.1,
|
||||||
|
sync_client_kwargs={"timeout": CHAT_REQUEST_TIMEOUT_SECONDS},
|
||||||
|
async_client_kwargs={"timeout": CHAT_REQUEST_TIMEOUT_SECONDS},
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Using Ollama model: {model_name}")
|
logger.info(
|
||||||
|
f"Using Ollama model: {model_name}, timeout={CHAT_REQUEST_TIMEOUT_SECONDS}s"
|
||||||
|
)
|
||||||
|
|
||||||
# Create the document retrieval tool
|
# Create the document retrieval tool
|
||||||
retrieval_tool = DocumentRetrievalTool()
|
retrieval_tool = DocumentRetrievalTool()
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import click
|
import click
|
||||||
@@ -126,5 +128,60 @@ def chat(collection_name, model):
|
|||||||
click.echo(f"Error: {str(e)}")
|
click.echo(f"Error: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command(
|
||||||
|
name="export-supported-paths",
|
||||||
|
help="Filter JSON paths by enrichment-supported extensions and export JSON/CSV",
|
||||||
|
)
|
||||||
|
@click.argument("input_json", type=click.Path(exists=True, dir_okay=False, path_type=Path))
|
||||||
|
def export_supported_paths(input_json: Path):
|
||||||
|
"""Export supported document paths into yadisk_imported_paths.json and yadisk_imported_paths.csv."""
|
||||||
|
logger.info(f"Filtering supported paths from input file: {input_json}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from enrichment import SUPPORTED_EXTENSIONS
|
||||||
|
|
||||||
|
with input_json.open("r", encoding="utf-8") as source_file:
|
||||||
|
raw_data = json.load(source_file)
|
||||||
|
|
||||||
|
if not isinstance(raw_data, list):
|
||||||
|
raise ValueError("Input JSON must contain an array of file paths")
|
||||||
|
|
||||||
|
filtered_paths = []
|
||||||
|
seen_paths = set()
|
||||||
|
for item in raw_data:
|
||||||
|
path_str = str(item).strip()
|
||||||
|
if not path_str:
|
||||||
|
continue
|
||||||
|
if path_str in seen_paths:
|
||||||
|
continue
|
||||||
|
|
||||||
|
extension = Path(path_str).suffix.lower()
|
||||||
|
if extension in SUPPORTED_EXTENSIONS:
|
||||||
|
filtered_paths.append(path_str)
|
||||||
|
seen_paths.add(path_str)
|
||||||
|
|
||||||
|
output_json = Path.cwd() / "yadisk_imported_paths.json"
|
||||||
|
output_csv = Path.cwd() / "yadisk_imported_paths.csv"
|
||||||
|
|
||||||
|
with output_json.open("w", encoding="utf-8") as output_json_file:
|
||||||
|
json.dump(filtered_paths, output_json_file, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
with output_csv.open("w", encoding="utf-8", newline="") as output_csv_file:
|
||||||
|
writer = csv.writer(output_csv_file)
|
||||||
|
writer.writerow(["path"])
|
||||||
|
for path_item in filtered_paths:
|
||||||
|
writer.writerow([path_item])
|
||||||
|
|
||||||
|
click.echo(
|
||||||
|
f"Export complete: {len(filtered_paths)} supported paths saved to {output_json.name} and {output_csv.name}"
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Exported {len(filtered_paths)} supported paths to {output_json} and {output_csv}"
|
||||||
|
)
|
||||||
|
except Exception as error:
|
||||||
|
logger.error(f"Failed to export supported paths: {error}")
|
||||||
|
click.echo(f"Error: {error}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
cli()
|
cli()
|
||||||
|
|||||||
@@ -3,117 +3,297 @@
|
|||||||
<head>
|
<head>
|
||||||
<meta charset="UTF-8" />
|
<meta charset="UTF-8" />
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
<title>RAG Solution Chat Interface</title>
|
<title>RAG Multi-Window Demo</title>
|
||||||
<style>
|
<style>
|
||||||
|
:root {
|
||||||
|
--bg: #f1efe8;
|
||||||
|
--paper: #fffdf7;
|
||||||
|
--ink: #1f2937;
|
||||||
|
--muted: #6b7280;
|
||||||
|
--line: #dfd8c9;
|
||||||
|
--accent: #0f766e;
|
||||||
|
--accent-2: #d97706;
|
||||||
|
--bot: #ece8dc;
|
||||||
|
--user: #115e59;
|
||||||
|
--danger-bg: #fde8e8;
|
||||||
|
--danger-ink: #9b1c1c;
|
||||||
|
}
|
||||||
|
|
||||||
* {
|
* {
|
||||||
|
box-sizing: border-box;
|
||||||
margin: 0;
|
margin: 0;
|
||||||
padding: 0;
|
padding: 0;
|
||||||
box-sizing: border-box;
|
|
||||||
font-family: "Segoe UI", Tahoma, Geneva, Verdana, sans-serif;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
body {
|
body {
|
||||||
background-color: #f5f7fa;
|
background:
|
||||||
color: #333;
|
radial-gradient(circle at 15% 15%, #f9d9a8 0%, transparent 35%),
|
||||||
line-height: 1.6;
|
radial-gradient(circle at 85% 20%, #b7e3d8 0%, transparent 40%),
|
||||||
|
linear-gradient(180deg, #f5f0e4 0%, #ede7d8 100%);
|
||||||
|
color: var(--ink);
|
||||||
|
font-family: "Trebuchet MS", "Segoe UI", sans-serif;
|
||||||
|
min-height: 100vh;
|
||||||
}
|
}
|
||||||
|
|
||||||
.container {
|
.app {
|
||||||
max-width: 900px;
|
max-width: 1100px;
|
||||||
margin: 0 auto;
|
margin: 0 auto;
|
||||||
padding: 20px;
|
padding: 20px 14px 24px;
|
||||||
}
|
}
|
||||||
|
|
||||||
header {
|
.shell {
|
||||||
background: linear-gradient(135deg, #6a11cb 0%, #2575fc 100%);
|
border: 1px solid rgba(70, 62, 43, 0.15);
|
||||||
color: white;
|
border-radius: 16px;
|
||||||
padding: 20px;
|
background: rgba(255, 253, 247, 0.92);
|
||||||
border-radius: 10px;
|
box-shadow: 0 18px 45px rgba(47, 41, 30, 0.12);
|
||||||
margin-bottom: 20px;
|
overflow: hidden;
|
||||||
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
|
backdrop-filter: blur(6px);
|
||||||
}
|
}
|
||||||
|
|
||||||
h1 {
|
.shell-header {
|
||||||
font-size: 1.8rem;
|
padding: 14px 18px;
|
||||||
margin-bottom: 10px;
|
border-bottom: 1px solid var(--line);
|
||||||
|
background: linear-gradient(180deg, rgba(255,255,255,0.9), rgba(244,239,228,0.85));
|
||||||
}
|
}
|
||||||
|
|
||||||
.api-endpoint-container {
|
.shell-header h1 {
|
||||||
display: flex;
|
font-size: 1.25rem;
|
||||||
gap: 10px;
|
letter-spacing: 0.02em;
|
||||||
margin-top: 15px;
|
|
||||||
flex-wrap: wrap;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.api-endpoint-container label {
|
.shell-header p {
|
||||||
|
margin-top: 4px;
|
||||||
|
color: var(--muted);
|
||||||
|
font-size: 0.92rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tabs-bar {
|
||||||
display: flex;
|
display: flex;
|
||||||
align-items: center;
|
align-items: center;
|
||||||
font-weight: bold;
|
gap: 8px;
|
||||||
|
padding: 10px 12px;
|
||||||
|
border-bottom: 1px solid var(--line);
|
||||||
|
background: rgba(247, 243, 233, 0.9);
|
||||||
|
overflow-x: auto;
|
||||||
}
|
}
|
||||||
|
|
||||||
.api-endpoint-container input {
|
.tab-btn {
|
||||||
flex: 1;
|
border: 1px solid var(--line);
|
||||||
min-width: 300px;
|
background: #faf7ee;
|
||||||
|
color: var(--ink);
|
||||||
padding: 8px 12px;
|
padding: 8px 12px;
|
||||||
border: none;
|
border-radius: 999px;
|
||||||
border-radius: 4px;
|
|
||||||
margin-left: 5px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.api-endpoint-container button {
|
|
||||||
background-color: #fff;
|
|
||||||
color: #2575fc;
|
|
||||||
border: none;
|
|
||||||
padding: 8px 15px;
|
|
||||||
border-radius: 4px;
|
|
||||||
cursor: pointer;
|
cursor: pointer;
|
||||||
font-weight: bold;
|
font-weight: 700;
|
||||||
transition: background-color 0.3s;
|
white-space: nowrap;
|
||||||
|
transition: transform 0.15s ease, background-color 0.15s ease, border-color 0.15s ease;
|
||||||
}
|
}
|
||||||
|
|
||||||
.api-endpoint-container button:hover {
|
.tab-btn:hover {
|
||||||
background-color: #e6f0ff;
|
transform: translateY(-1px);
|
||||||
|
border-color: #c9bea6;
|
||||||
}
|
}
|
||||||
|
|
||||||
.chat-container {
|
.tab-btn.active {
|
||||||
background-color: white;
|
background: var(--accent);
|
||||||
|
border-color: var(--accent);
|
||||||
|
color: #fff;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tab-btn.add-tab {
|
||||||
|
min-width: 38px;
|
||||||
|
padding-inline: 0;
|
||||||
|
text-align: center;
|
||||||
|
font-size: 1.1rem;
|
||||||
|
background: #fff;
|
||||||
|
color: var(--accent-2);
|
||||||
|
border-color: #e5c792;
|
||||||
|
}
|
||||||
|
|
||||||
|
.panel-host {
|
||||||
|
padding: 14px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-panel {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 12px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.panel-toolbar {
|
||||||
|
display: grid;
|
||||||
|
grid-template-columns: auto 1fr auto;
|
||||||
|
gap: 10px;
|
||||||
|
align-items: center;
|
||||||
|
background: var(--paper);
|
||||||
|
border: 1px solid var(--line);
|
||||||
|
border-radius: 12px;
|
||||||
|
padding: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.panel-toolbar label {
|
||||||
|
font-weight: 700;
|
||||||
|
color: #4b5563;
|
||||||
|
font-size: 0.92rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.endpoint-input {
|
||||||
|
width: 100%;
|
||||||
|
border: 1px solid #d8cfbd;
|
||||||
border-radius: 10px;
|
border-radius: 10px;
|
||||||
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
|
padding: 10px 12px;
|
||||||
|
font-size: 0.94rem;
|
||||||
|
background: #fff;
|
||||||
|
}
|
||||||
|
|
||||||
|
.endpoint-input:focus,
|
||||||
|
.message-input:focus {
|
||||||
|
outline: 2px solid rgba(15, 118, 110, 0.16);
|
||||||
|
border-color: var(--accent);
|
||||||
|
}
|
||||||
|
|
||||||
|
.panel-toolbar button,
|
||||||
|
.send-btn {
|
||||||
|
border: none;
|
||||||
|
border-radius: 10px;
|
||||||
|
cursor: pointer;
|
||||||
|
font-weight: 700;
|
||||||
|
}
|
||||||
|
|
||||||
|
.set-endpoint-btn {
|
||||||
|
padding: 10px 12px;
|
||||||
|
background: #fff;
|
||||||
|
color: var(--accent);
|
||||||
|
border: 1px solid #b9d4cf;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-card {
|
||||||
|
border: 1px solid var(--line);
|
||||||
|
border-radius: 14px;
|
||||||
overflow: hidden;
|
overflow: hidden;
|
||||||
height: 60vh;
|
background: #fff;
|
||||||
|
min-height: 62vh;
|
||||||
display: flex;
|
display: flex;
|
||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
}
|
}
|
||||||
|
|
||||||
.chat-header {
|
.chat-header {
|
||||||
background-color: #f8f9fa;
|
display: flex;
|
||||||
padding: 15px;
|
justify-content: space-between;
|
||||||
border-bottom: 1px solid #eaeaea;
|
align-items: center;
|
||||||
font-weight: bold;
|
gap: 12px;
|
||||||
color: #495057;
|
padding: 12px 14px;
|
||||||
|
border-bottom: 1px solid var(--line);
|
||||||
|
background: #fbf8ef;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-header-title {
|
||||||
|
font-weight: 800;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-header-endpoint {
|
||||||
|
color: var(--muted);
|
||||||
|
font-size: 0.85rem;
|
||||||
|
overflow: hidden;
|
||||||
|
text-overflow: ellipsis;
|
||||||
|
white-space: nowrap;
|
||||||
|
text-align: right;
|
||||||
}
|
}
|
||||||
|
|
||||||
.chat-messages {
|
.chat-messages {
|
||||||
flex: 1;
|
flex: 1;
|
||||||
padding: 20px;
|
|
||||||
overflow-y: auto;
|
overflow-y: auto;
|
||||||
|
padding: 14px;
|
||||||
display: flex;
|
display: flex;
|
||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
gap: 15px;
|
gap: 12px;
|
||||||
|
background:
|
||||||
|
linear-gradient(180deg, rgba(255,255,255,0.92), rgba(250,247,238,0.95)),
|
||||||
|
repeating-linear-gradient(
|
||||||
|
0deg,
|
||||||
|
rgba(218, 206, 181, 0.15),
|
||||||
|
rgba(218, 206, 181, 0.15) 1px,
|
||||||
|
transparent 1px,
|
||||||
|
transparent 28px
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
.message {
|
.message {
|
||||||
max-width: 80%;
|
max-width: 82%;
|
||||||
padding: 12px 16px;
|
padding: 11px 13px;
|
||||||
border-radius: 18px;
|
border-radius: 14px;
|
||||||
position: relative;
|
line-height: 1.45;
|
||||||
animation: fadeIn 0.3s ease;
|
animation: slideIn 0.2s ease;
|
||||||
|
box-shadow: 0 1px 0 rgba(0, 0, 0, 0.03);
|
||||||
}
|
}
|
||||||
|
|
||||||
@keyframes fadeIn {
|
.message.user-message {
|
||||||
|
align-self: flex-end;
|
||||||
|
background: var(--user);
|
||||||
|
color: #fff;
|
||||||
|
border-bottom-right-radius: 5px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.message.bot-message {
|
||||||
|
align-self: flex-start;
|
||||||
|
background: var(--bot);
|
||||||
|
color: #3c3f44;
|
||||||
|
border-bottom-left-radius: 5px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.message.error-message {
|
||||||
|
align-self: flex-start;
|
||||||
|
background: var(--danger-bg);
|
||||||
|
color: var(--danger-ink);
|
||||||
|
border: 1px solid #f3bcbc;
|
||||||
|
}
|
||||||
|
|
||||||
|
.message.typing-indicator {
|
||||||
|
align-self: flex-start;
|
||||||
|
background: #eef2f7;
|
||||||
|
color: #475569;
|
||||||
|
font-style: italic;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-input-row {
|
||||||
|
display: grid;
|
||||||
|
grid-template-columns: 1fr auto;
|
||||||
|
gap: 10px;
|
||||||
|
padding: 12px;
|
||||||
|
border-top: 1px solid var(--line);
|
||||||
|
background: #fbf8ef;
|
||||||
|
}
|
||||||
|
|
||||||
|
.message-input {
|
||||||
|
border: 1px solid #d8cfbd;
|
||||||
|
border-radius: 12px;
|
||||||
|
padding: 11px 12px;
|
||||||
|
font-size: 0.95rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.send-btn {
|
||||||
|
background: var(--accent);
|
||||||
|
color: #fff;
|
||||||
|
padding: 0 16px;
|
||||||
|
min-width: 86px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.send-btn:disabled,
|
||||||
|
.set-endpoint-btn:disabled {
|
||||||
|
opacity: 0.55;
|
||||||
|
cursor: not-allowed;
|
||||||
|
}
|
||||||
|
|
||||||
|
.footer-note {
|
||||||
|
padding: 8px 14px 14px;
|
||||||
|
color: var(--muted);
|
||||||
|
font-size: 0.85rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes slideIn {
|
||||||
from {
|
from {
|
||||||
opacity: 0;
|
opacity: 0;
|
||||||
transform: translateY(10px);
|
transform: translateY(6px);
|
||||||
}
|
}
|
||||||
to {
|
to {
|
||||||
opacity: 1;
|
opacity: 1;
|
||||||
@@ -121,281 +301,274 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
.user-message {
|
@media (max-width: 720px) {
|
||||||
align-self: flex-end;
|
.panel-toolbar {
|
||||||
background-color: #2575fc;
|
grid-template-columns: 1fr;
|
||||||
color: white;
|
|
||||||
border-bottom-right-radius: 4px;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.bot-message {
|
.chat-header {
|
||||||
align-self: flex-start;
|
|
||||||
background-color: #e9ecef;
|
|
||||||
color: #495057;
|
|
||||||
border-bottom-left-radius: 4px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.error-message {
|
|
||||||
align-self: flex-start;
|
|
||||||
background-color: #f8d7da;
|
|
||||||
color: #721c24;
|
|
||||||
border: 1px solid #f5c6cb;
|
|
||||||
border-radius: 18px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.input-area {
|
|
||||||
display: flex;
|
|
||||||
padding: 15px;
|
|
||||||
background-color: #f8f9fa;
|
|
||||||
border-top: 1px solid #eaeaea;
|
|
||||||
}
|
|
||||||
|
|
||||||
.input-area input {
|
|
||||||
flex: 1;
|
|
||||||
padding: 12px 15px;
|
|
||||||
border: 1px solid #ddd;
|
|
||||||
border-radius: 24px;
|
|
||||||
outline: none;
|
|
||||||
font-size: 1rem;
|
|
||||||
}
|
|
||||||
|
|
||||||
.input-area button {
|
|
||||||
background-color: #2575fc;
|
|
||||||
color: white;
|
|
||||||
border: none;
|
|
||||||
padding: 12px 20px;
|
|
||||||
border-radius: 24px;
|
|
||||||
margin-left: 10px;
|
|
||||||
cursor: pointer;
|
|
||||||
font-weight: bold;
|
|
||||||
transition: background-color 0.3s;
|
|
||||||
}
|
|
||||||
|
|
||||||
.input-area button:hover {
|
|
||||||
background-color: #1a68e8;
|
|
||||||
}
|
|
||||||
|
|
||||||
.input-area button:disabled {
|
|
||||||
background-color: #adb5bd;
|
|
||||||
cursor: not-allowed;
|
|
||||||
}
|
|
||||||
|
|
||||||
.typing-indicator {
|
|
||||||
align-self: flex-start;
|
|
||||||
background-color: #e9ecef;
|
|
||||||
color: #495057;
|
|
||||||
padding: 12px 16px;
|
|
||||||
border-radius: 18px;
|
|
||||||
font-style: italic;
|
|
||||||
}
|
|
||||||
|
|
||||||
footer {
|
|
||||||
text-align: center;
|
|
||||||
margin-top: 20px;
|
|
||||||
color: #6c757d;
|
|
||||||
font-size: 0.9rem;
|
|
||||||
}
|
|
||||||
|
|
||||||
@media (max-width: 768px) {
|
|
||||||
.container {
|
|
||||||
padding: 10px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.api-endpoint-container {
|
|
||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
|
align-items: flex-start;
|
||||||
}
|
}
|
||||||
|
|
||||||
.api-endpoint-container input {
|
.chat-header-endpoint {
|
||||||
min-width: auto;
|
text-align: left;
|
||||||
|
white-space: normal;
|
||||||
|
word-break: break-all;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message {
|
.message {
|
||||||
max-width: 90%;
|
max-width: 92%;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
</style>
|
</style>
|
||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
<div class="container">
|
<div class="app">
|
||||||
<header>
|
<div class="shell">
|
||||||
<h1>RAG Solution Chat Interface</h1>
|
<div class="shell-header">
|
||||||
<div class="api-endpoint-container">
|
<h1>RAG Demo Control Room</h1>
|
||||||
<label for="apiEndpoint">API Endpoint:</label>
|
<p>Multiple chat windows with independent API endpoints.</p>
|
||||||
<input
|
|
||||||
type="text"
|
|
||||||
id="apiEndpoint"
|
|
||||||
value="http://localhost:8000/api/test-query"
|
|
||||||
placeholder="Enter API endpoint URL"
|
|
||||||
/>
|
|
||||||
<button onclick="setApiEndpoint()">Set Endpoint</button>
|
|
||||||
</div>
|
</div>
|
||||||
</header>
|
|
||||||
|
|
||||||
<div class="chat-container">
|
<div class="tabs-bar" id="tabsBar"></div>
|
||||||
<div class="chat-header">Chat with RAG Agent</div>
|
<div class="panel-host" id="panelHost"></div>
|
||||||
<div class="chat-messages" id="chatMessages">
|
<div class="footer-note">Phase 16 scalable demo UI: reusable chat panel + tabbed windows.</div>
|
||||||
<div class="message bot-message">
|
|
||||||
Hello! I'm your RAG agent. Please enter your API endpoint and start
|
|
||||||
chatting.
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div class="input-area">
|
|
||||||
<input
|
|
||||||
type="text"
|
|
||||||
id="userInput"
|
|
||||||
placeholder="Type your message here..."
|
|
||||||
onkeypress="handleKeyPress(event)"
|
|
||||||
/>
|
|
||||||
<button onclick="sendMessage()" id="sendButton">Send</button>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<footer>
|
<template id="chatPanelTemplate">
|
||||||
<p>RAG Solution with LangChain | Chat Interface Demo</p>
|
<div class="chat-panel">
|
||||||
</footer>
|
<div class="panel-toolbar">
|
||||||
|
<label>API Endpoint</label>
|
||||||
|
<input class="endpoint-input" type="text" placeholder="Enter API endpoint URL" />
|
||||||
|
<button class="set-endpoint-btn" type="button">Set Endpoint</button>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div class="chat-card">
|
||||||
|
<div class="chat-header">
|
||||||
|
<div class="chat-header-title">Chat with RAG Agent</div>
|
||||||
|
<div class="chat-header-endpoint">Endpoint: not set</div>
|
||||||
|
</div>
|
||||||
|
<div class="chat-messages"></div>
|
||||||
|
<div class="chat-input-row">
|
||||||
|
<input class="message-input" type="text" placeholder="Type your message here..." />
|
||||||
|
<button class="send-btn" type="button">Send</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
|
||||||
<script>
|
<script>
|
||||||
// Store the API endpoint
|
const DEFAULT_TAB_ENDPOINTS = [
|
||||||
let apiEndpoint = document.getElementById("apiEndpoint").value;
|
"https://rag.langchain.overwatch.su/api/test-query",
|
||||||
|
"https://rag.llamaindex.overwatch.su/api/test-query",
|
||||||
|
"https://rag.haystack.overwatch.su/api/test-query",
|
||||||
|
];
|
||||||
|
|
||||||
// Set the API endpoint from the input field
|
class ChatPanel {
|
||||||
function setApiEndpoint() {
|
constructor(rootElement, initialEndpoint = "") {
|
||||||
const input = document.getElementById("apiEndpoint");
|
this.root = rootElement;
|
||||||
apiEndpoint = input.value.trim();
|
this.apiEndpoint = initialEndpoint;
|
||||||
|
|
||||||
if (!apiEndpoint) {
|
this.endpointInput = this.root.querySelector(".endpoint-input");
|
||||||
|
this.setEndpointButton = this.root.querySelector(".set-endpoint-btn");
|
||||||
|
this.headerEndpoint = this.root.querySelector(".chat-header-endpoint");
|
||||||
|
this.messagesEl = this.root.querySelector(".chat-messages");
|
||||||
|
this.messageInput = this.root.querySelector(".message-input");
|
||||||
|
this.sendButton = this.root.querySelector(".send-btn");
|
||||||
|
|
||||||
|
this.endpointInput.value = initialEndpoint;
|
||||||
|
this._renderEndpointLabel();
|
||||||
|
this._bindEvents();
|
||||||
|
this.addMessage(
|
||||||
|
initialEndpoint
|
||||||
|
? `Ready. Endpoint preset to: ${initialEndpoint}`
|
||||||
|
: "Hello. Set an API endpoint and start chatting.",
|
||||||
|
"bot-message",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
_bindEvents() {
|
||||||
|
this.setEndpointButton.addEventListener("click", () => this.setApiEndpoint());
|
||||||
|
this.sendButton.addEventListener("click", () => this.sendMessage());
|
||||||
|
this.messageInput.addEventListener("keydown", (event) => {
|
||||||
|
if (event.key === "Enter") {
|
||||||
|
event.preventDefault();
|
||||||
|
this.sendMessage();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
focusInput() {
|
||||||
|
this.messageInput.focus();
|
||||||
|
}
|
||||||
|
|
||||||
|
_renderEndpointLabel() {
|
||||||
|
this.headerEndpoint.textContent = this.apiEndpoint
|
||||||
|
? `Endpoint: ${this.apiEndpoint}`
|
||||||
|
: "Endpoint: not set";
|
||||||
|
}
|
||||||
|
|
||||||
|
setApiEndpoint() {
|
||||||
|
const candidate = this.endpointInput.value.trim();
|
||||||
|
if (!candidate) {
|
||||||
alert("Please enter a valid API endpoint URL");
|
alert("Please enter a valid API endpoint URL");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
this.apiEndpoint = candidate;
|
||||||
// Add notification that endpoint was set
|
this._renderEndpointLabel();
|
||||||
addMessage(`API endpoint set to: ${apiEndpoint}`, "bot-message");
|
this.addMessage(`API endpoint set to: ${candidate}`, "bot-message");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send a message to the API
|
addMessage(text, className, extraClass = "") {
|
||||||
async function sendMessage() {
|
const messageDiv = document.createElement("div");
|
||||||
const inputElement = document.getElementById("userInput");
|
messageDiv.className = `message ${className} ${extraClass}`.trim();
|
||||||
const message = inputElement.value.trim();
|
messageDiv.innerHTML = String(text).replace(/\n/g, "<br>");
|
||||||
const sendButton = document.getElementById("sendButton");
|
this.messagesEl.appendChild(messageDiv);
|
||||||
|
this.messagesEl.scrollTop = this.messagesEl.scrollHeight;
|
||||||
if (!message) {
|
return messageDiv;
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!apiEndpoint) {
|
removeMessage(node) {
|
||||||
|
if (node && node.parentNode) {
|
||||||
|
node.parentNode.removeChild(node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async sendMessage() {
|
||||||
|
const message = this.messageInput.value.trim();
|
||||||
|
if (!message) return;
|
||||||
|
if (!this.apiEndpoint) {
|
||||||
alert("Please set the API endpoint first");
|
alert("Please set the API endpoint first");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disable the send button and input during request
|
this.sendButton.disabled = true;
|
||||||
sendButton.disabled = true;
|
this.messageInput.disabled = true;
|
||||||
inputElement.disabled = true;
|
|
||||||
|
|
||||||
|
let typingIndicator = null;
|
||||||
try {
|
try {
|
||||||
// Add user message to chat
|
this.addMessage(message, "user-message");
|
||||||
addMessage(message, "user-message");
|
this.messageInput.value = "";
|
||||||
|
|
||||||
// Clear input
|
typingIndicator = this.addMessage("Thinking...", "typing-indicator", "typing");
|
||||||
inputElement.value = "";
|
|
||||||
|
|
||||||
// Show typing indicator
|
const response = await fetch(this.apiEndpoint, {
|
||||||
const typingIndicator = addMessage(
|
|
||||||
"Thinking...",
|
|
||||||
"typing-indicator",
|
|
||||||
"typing",
|
|
||||||
);
|
|
||||||
|
|
||||||
// Send request to API
|
|
||||||
const response = await fetch(apiEndpoint, {
|
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: {
|
headers: { "Content-Type": "application/json" },
|
||||||
"Content-Type": "application/json",
|
body: JSON.stringify({ query: message }),
|
||||||
},
|
|
||||||
body: JSON.stringify({
|
|
||||||
query: message,
|
|
||||||
}),
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Remove typing indicator
|
this.removeMessage(typingIndicator);
|
||||||
removeMessage(typingIndicator);
|
typingIndicator = null;
|
||||||
|
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
throw new Error(
|
throw new Error(`API request failed with status ${response.status}`);
|
||||||
`API request failed with status ${response.status}`,
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
|
|
||||||
// Add bot response to chat
|
|
||||||
if (data.success) {
|
if (data.success) {
|
||||||
addMessage(data.response, "bot-message");
|
this.addMessage(data.response, "bot-message");
|
||||||
} else {
|
} else {
|
||||||
addMessage(
|
this.addMessage(
|
||||||
`Error: ${data.error || "Unknown error occurred"}`,
|
`Error: ${data.error || "Unknown error occurred"}`,
|
||||||
"error-message",
|
"error-message",
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error("Error:", error);
|
if (typingIndicator) this.removeMessage(typingIndicator);
|
||||||
// Remove typing indicator if still present
|
this.addMessage(
|
||||||
const typingElements = document.querySelectorAll(".typing");
|
`Connection error: ${error.message}. Please check the API endpoint and try again. Reload page if needed to reinitialize endpoint state.`,
|
||||||
typingElements.forEach((el) => el.remove());
|
|
||||||
|
|
||||||
// Add error message to chat
|
|
||||||
addMessage(
|
|
||||||
`Connection error: ${error.message}. Please check the API endpoint and try again.`,
|
|
||||||
"error-message",
|
"error-message",
|
||||||
);
|
);
|
||||||
} finally {
|
} finally {
|
||||||
// Re-enable the send button and input
|
this.sendButton.disabled = false;
|
||||||
sendButton.disabled = false;
|
this.messageInput.disabled = false;
|
||||||
inputElement.disabled = false;
|
this.messageInput.focus();
|
||||||
inputElement.focus();
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add a message to the chat
|
class MultiChatApp {
|
||||||
function addMessage(text, className, id = null) {
|
constructor() {
|
||||||
const chatMessages = document.getElementById("chatMessages");
|
this.tabsBar = document.getElementById("tabsBar");
|
||||||
const messageDiv = document.createElement("div");
|
this.panelHost = document.getElementById("panelHost");
|
||||||
messageDiv.className = `message ${className}`;
|
this.panelTemplate = document.getElementById("chatPanelTemplate");
|
||||||
|
this.tabs = [];
|
||||||
if (id) {
|
this.activeTabId = null;
|
||||||
messageDiv.id = id;
|
this.nextTabNumber = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Format text with line breaks
|
init() {
|
||||||
const formattedText = text.replace(/\n/g, "<br>");
|
DEFAULT_TAB_ENDPOINTS.forEach((endpoint) => this.createTab(endpoint));
|
||||||
messageDiv.innerHTML = formattedText;
|
if (this.tabs.length > 0) {
|
||||||
|
this.selectTab(this.tabs[0].id);
|
||||||
chatMessages.appendChild(messageDiv);
|
}
|
||||||
|
this.renderTabs();
|
||||||
// Scroll to bottom
|
|
||||||
chatMessages.scrollTop = chatMessages.scrollHeight;
|
|
||||||
|
|
||||||
return messageDiv;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove a message from the chat
|
createTab(initialEndpoint = "") {
|
||||||
function removeMessage(element) {
|
const tab = {
|
||||||
if (element && element.parentNode) {
|
id: `tab-${crypto.randomUUID ? crypto.randomUUID() : `${Date.now()}-${Math.random()}`}`,
|
||||||
element.parentNode.removeChild(element);
|
title: String(this.nextTabNumber++),
|
||||||
}
|
endpoint: initialEndpoint,
|
||||||
}
|
panel: null,
|
||||||
|
panelNode: null,
|
||||||
// Handle Enter key press in the input field
|
|
||||||
function handleKeyPress(event) {
|
|
||||||
if (event.key === "Enter") {
|
|
||||||
sendMessage();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Focus on the input field when the page loads
|
|
||||||
window.onload = function () {
|
|
||||||
document.getElementById("userInput").focus();
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const fragment = this.panelTemplate.content.cloneNode(true);
|
||||||
|
const panelNode = fragment.firstElementChild;
|
||||||
|
const panel = new ChatPanel(panelNode, initialEndpoint);
|
||||||
|
tab.panel = panel;
|
||||||
|
tab.panelNode = panelNode;
|
||||||
|
this.tabs.push(tab);
|
||||||
|
this.renderTabs();
|
||||||
|
return tab;
|
||||||
|
}
|
||||||
|
|
||||||
|
selectTab(tabId) {
|
||||||
|
this.activeTabId = tabId;
|
||||||
|
const tab = this.tabs.find((item) => item.id === tabId);
|
||||||
|
if (!tab) return;
|
||||||
|
|
||||||
|
this.panelHost.innerHTML = "";
|
||||||
|
this.panelHost.appendChild(tab.panelNode);
|
||||||
|
tab.panel._renderEndpointLabel();
|
||||||
|
tab.panel.focusInput();
|
||||||
|
this.renderTabs();
|
||||||
|
}
|
||||||
|
|
||||||
|
renderTabs() {
|
||||||
|
this.tabsBar.innerHTML = "";
|
||||||
|
|
||||||
|
this.tabs.forEach((tab) => {
|
||||||
|
const btn = document.createElement("button");
|
||||||
|
btn.type = "button";
|
||||||
|
btn.className = `tab-btn ${tab.id === this.activeTabId ? "active" : ""}`.trim();
|
||||||
|
btn.textContent = tab.title;
|
||||||
|
btn.title = tab.endpoint || `Tab ${tab.title}`;
|
||||||
|
btn.addEventListener("click", () => this.selectTab(tab.id));
|
||||||
|
this.tabsBar.appendChild(btn);
|
||||||
|
});
|
||||||
|
|
||||||
|
const addBtn = document.createElement("button");
|
||||||
|
addBtn.type = "button";
|
||||||
|
addBtn.className = "tab-btn add-tab";
|
||||||
|
addBtn.textContent = "+";
|
||||||
|
addBtn.title = "Add new tab";
|
||||||
|
addBtn.addEventListener("click", () => {
|
||||||
|
const tab = this.createTab("");
|
||||||
|
this.selectTab(tab.id);
|
||||||
|
});
|
||||||
|
this.tabsBar.appendChild(addBtn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
window.addEventListener("DOMContentLoaded", () => {
|
||||||
|
const app = new MultiChatApp();
|
||||||
|
app.init();
|
||||||
|
});
|
||||||
</script>
|
</script>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from pathlib import Path
|
|||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from langchain_community.document_loaders import PyPDFLoader
|
from langchain_community.document_loaders import PyPDFLoader, TextLoader
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@@ -75,21 +75,26 @@ ENRICHMENT_ADAPTIVE_DOCUMENT_UPLOADS_THREADS = int(
|
|||||||
)
|
)
|
||||||
|
|
||||||
SUPPORTED_EXTENSIONS = {
|
SUPPORTED_EXTENSIONS = {
|
||||||
".pdf",
|
".csv",
|
||||||
".docx",
|
|
||||||
".doc",
|
".doc",
|
||||||
".pptx",
|
".docx",
|
||||||
".xlsx",
|
".epub",
|
||||||
".xls",
|
".htm",
|
||||||
".jpg",
|
".html",
|
||||||
".jpeg",
|
".json",
|
||||||
".png",
|
".jsonl",
|
||||||
".gif",
|
".md",
|
||||||
".bmp",
|
|
||||||
".tiff",
|
|
||||||
".webp",
|
|
||||||
".odt",
|
".odt",
|
||||||
".txt", # this one is obvious but was unexpected to see in data lol
|
".pdf",
|
||||||
|
".ppt",
|
||||||
|
".pptx",
|
||||||
|
".rtf",
|
||||||
|
".rst",
|
||||||
|
".tsv",
|
||||||
|
".txt",
|
||||||
|
".xls",
|
||||||
|
".xlsx",
|
||||||
|
".xml",
|
||||||
}
|
}
|
||||||
|
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
@@ -261,6 +266,8 @@ class DocumentEnricher:
|
|||||||
return UnstructuredODTLoader(
|
return UnstructuredODTLoader(
|
||||||
file_path, **{"strategy": "hi_res", "languages": ["rus"]}
|
file_path, **{"strategy": "hi_res", "languages": ["rus"]}
|
||||||
)
|
)
|
||||||
|
if ext in [".txt", ".md"]:
|
||||||
|
return TextLoader(file_path, encoding="utf-8")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _load_one_adaptive_file(
|
def _load_one_adaptive_file(
|
||||||
@@ -273,7 +280,7 @@ class DocumentEnricher:
|
|||||||
extension = adaptive_file.extension.lower()
|
extension = adaptive_file.extension.lower()
|
||||||
file_type = try_guess_file_type(extension)
|
file_type = try_guess_file_type(extension)
|
||||||
|
|
||||||
def process_local_file(local_file_path: str):
|
def process_local_file(original_path: str, local_file_path: str):
|
||||||
nonlocal loaded_docs, processed_record
|
nonlocal loaded_docs, processed_record
|
||||||
|
|
||||||
file_hash = self._get_file_hash(local_file_path)
|
file_hash = self._get_file_hash(local_file_path)
|
||||||
@@ -295,7 +302,7 @@ class DocumentEnricher:
|
|||||||
doc.metadata["file_type"] = file_type
|
doc.metadata["file_type"] = file_type
|
||||||
doc.metadata["source"] = source_identifier
|
doc.metadata["source"] = source_identifier
|
||||||
doc.metadata["filename"] = adaptive_file.filename
|
doc.metadata["filename"] = adaptive_file.filename
|
||||||
doc.metadata["file_path"] = source_identifier
|
doc.metadata["file_path"] = original_path
|
||||||
doc.metadata["file_size"] = os.path.getsize(local_file_path)
|
doc.metadata["file_size"] = os.path.getsize(local_file_path)
|
||||||
doc.metadata["file_extension"] = extension
|
doc.metadata["file_extension"] = extension
|
||||||
|
|
||||||
@@ -310,7 +317,7 @@ class DocumentEnricher:
|
|||||||
)
|
)
|
||||||
|
|
||||||
loaded_docs = split_docs
|
loaded_docs = split_docs
|
||||||
processed_record = (source_identifier, file_hash)
|
processed_record = (original_path, file_hash)
|
||||||
|
|
||||||
adaptive_file.work_with_file_locally(process_local_file)
|
adaptive_file.work_with_file_locally(process_local_file)
|
||||||
return loaded_docs, processed_record
|
return loaded_docs, processed_record
|
||||||
|
|||||||
@@ -123,9 +123,9 @@ class _AdaptiveFile(ABC):
|
|||||||
|
|
||||||
# This method allows to work with file locally, and lambda should be provided for this.
|
# This method allows to work with file locally, and lambda should be provided for this.
|
||||||
# Why separate method? For possible cleanup after work is done. And to download file, if needed
|
# Why separate method? For possible cleanup after work is done. And to download file, if needed
|
||||||
# Lambda: first argument is a local path
|
# Lambda: first argument is an original path, second: local path. In case of just local files, these will be the same
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def work_with_file_locally(self, func: Callable[[str], None]):
|
def work_with_file_locally(self, func: Callable[[str, str], None]):
|
||||||
"""Run callback with a local path to the file."""
|
"""Run callback with a local path to the file."""
|
||||||
|
|
||||||
|
|
||||||
@@ -143,8 +143,8 @@ class LocalFilesystemAdaptiveFile(_AdaptiveFile):
|
|||||||
super().__init__(filename, extension)
|
super().__init__(filename, extension)
|
||||||
self.local_path = local_path
|
self.local_path = local_path
|
||||||
|
|
||||||
def work_with_file_locally(self, func: Callable[[str], None]):
|
def work_with_file_locally(self, func: Callable[[str, str], None]):
|
||||||
func(self.local_path)
|
func(self.local_path, self.local_path)
|
||||||
|
|
||||||
|
|
||||||
class LocalFilesystemAdaptiveCollection(_AdaptiveCollection):
|
class LocalFilesystemAdaptiveCollection(_AdaptiveCollection):
|
||||||
@@ -196,10 +196,10 @@ class YandexDiskAdaptiveFile(_AdaptiveFile):
|
|||||||
temp_file.write(file_response.content)
|
temp_file.write(file_response.content)
|
||||||
return temp_file.name
|
return temp_file.name
|
||||||
|
|
||||||
def work_with_file_locally(self, func: Callable[[str], None]):
|
def work_with_file_locally(self, func: Callable[[str, str], None]):
|
||||||
temp_path = self._download_to_temp_file()
|
temp_path = self._download_to_temp_file()
|
||||||
try:
|
try:
|
||||||
func(temp_path)
|
func(self.remote_path, temp_path)
|
||||||
finally:
|
finally:
|
||||||
if os.path.exists(temp_path):
|
if os.path.exists(temp_path):
|
||||||
os.unlink(temp_path)
|
os.unlink(temp_path)
|
||||||
|
|||||||
85
services/rag/langchain/prefect/01_yadisk_analyze.py
Normal file
85
services/rag/langchain/prefect/01_yadisk_analyze.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
"""Prefect flow for analyzing Yandex Disk files and storing file paths."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from prefect.logging import get_run_logger
|
||||||
|
|
||||||
|
from prefect import flow, task
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
PREFECT_API_URL = os.getenv("PREFECT_API_URL")
|
||||||
|
YADISK_TOKEN = os.getenv("YADISK_TOKEN")
|
||||||
|
YADISK_ROOT_PATH = "Общая"
|
||||||
|
OUTPUT_JSON_PATH = (
|
||||||
|
Path(__file__).resolve().parent.parent / "../../../yadisk_files.json"
|
||||||
|
).resolve()
|
||||||
|
|
||||||
|
if PREFECT_API_URL:
|
||||||
|
os.environ["PREFECT_API_URL"] = PREFECT_API_URL
|
||||||
|
|
||||||
|
|
||||||
|
@task(name="iterate_yadisk_folder_and_store_file_paths")
|
||||||
|
def iterate_yadisk_folder_and_store_file_paths() -> List[str]:
|
||||||
|
"""Iterate Yandex Disk recursively from `Общая` and save file paths to JSON."""
|
||||||
|
if not YADISK_TOKEN:
|
||||||
|
raise ValueError("YADISK_TOKEN is required to analyze Yandex Disk")
|
||||||
|
|
||||||
|
try:
|
||||||
|
import yadisk
|
||||||
|
except ImportError as error:
|
||||||
|
raise RuntimeError(
|
||||||
|
"yadisk package is required for this task. Install dependencies in venv first."
|
||||||
|
) from error
|
||||||
|
|
||||||
|
yandex_disk = yadisk.YaDisk(token=YADISK_TOKEN)
|
||||||
|
file_paths: List[str] = []
|
||||||
|
|
||||||
|
logger = get_run_logger()
|
||||||
|
|
||||||
|
def walk_folder(folder_path: str) -> None:
|
||||||
|
for item in yandex_disk.listdir(folder_path):
|
||||||
|
item_type = getattr(item, "type", None)
|
||||||
|
item_path = getattr(item, "path", None)
|
||||||
|
|
||||||
|
if item_path is None and isinstance(item, dict):
|
||||||
|
item_type = item.get("type")
|
||||||
|
item_path = item.get("path")
|
||||||
|
|
||||||
|
if not item_path:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if item_type == "dir":
|
||||||
|
walk_folder(item_path)
|
||||||
|
elif item_type == "file":
|
||||||
|
logger.info(f"Added {len(file_paths)} file into the list")
|
||||||
|
file_paths.append(item_path)
|
||||||
|
|
||||||
|
walk_folder(YADISK_ROOT_PATH)
|
||||||
|
|
||||||
|
OUTPUT_JSON_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(OUTPUT_JSON_PATH, "w", encoding="utf-8") as output_file:
|
||||||
|
json.dump(file_paths, output_file, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
return file_paths
|
||||||
|
|
||||||
|
|
||||||
|
@flow(name="analyze_yadisk_file_urls")
|
||||||
|
def analyze_yadisk_file_urls() -> List[str]:
|
||||||
|
"""Run Yandex Disk analysis task and return collected file paths."""
|
||||||
|
return iterate_yadisk_folder_and_store_file_paths()
|
||||||
|
|
||||||
|
|
||||||
|
def serve_analyze_yadisk_file_urls() -> None:
|
||||||
|
"""Serve the flow as a deployment for remote Prefect execution."""
|
||||||
|
analyze_yadisk_file_urls.serve(name="analyze-yadisk-file-urls")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
serve_analyze_yadisk_file_urls()
|
||||||
216
services/rag/langchain/prefect/02_yadisk_predefined_enrich.py
Normal file
216
services/rag/langchain/prefect/02_yadisk_predefined_enrich.py
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
"""Prefect flow to enrich Yandex Disk files from a predefined JSON file list."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from prefect import flow, task
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||||
|
if str(PROJECT_ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(PROJECT_ROOT))
|
||||||
|
|
||||||
|
PREFECT_API_URL = os.getenv("PREFECT_API_URL")
|
||||||
|
YADISK_TOKEN = os.getenv("YADISK_TOKEN")
|
||||||
|
ENRICH_CONCURRENCY = int(os.getenv("PREFECT_YADISK_ENRICH_CONCURRENCY", "8"))
|
||||||
|
OUTPUT_FILE_LIST = (PROJECT_ROOT / "../../../yadisk_files.json").resolve()
|
||||||
|
|
||||||
|
if PREFECT_API_URL:
|
||||||
|
os.environ["PREFECT_API_URL"] = PREFECT_API_URL
|
||||||
|
|
||||||
|
|
||||||
|
class _ProgressTracker:
|
||||||
|
def __init__(self, total: int):
|
||||||
|
self.total = total
|
||||||
|
self.processed = 0
|
||||||
|
self.errors = 0
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def mark_done(self, error: bool = False):
|
||||||
|
async with self._lock:
|
||||||
|
self.processed += 1
|
||||||
|
if error:
|
||||||
|
self.errors += 1
|
||||||
|
self._render()
|
||||||
|
|
||||||
|
def _render(self):
|
||||||
|
total = max(self.total, 1)
|
||||||
|
width = 30
|
||||||
|
filled = int(width * self.processed / total)
|
||||||
|
bar = "#" * filled + "-" * (width - filled)
|
||||||
|
left = max(self.total - self.processed, 0)
|
||||||
|
print(
|
||||||
|
f"\r[{bar}] {self.processed}/{self.total} processed | left: {left} | errors: {self.errors}",
|
||||||
|
end="",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
if self.processed >= self.total:
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
async def _download_yadisk_file(async_disk, remote_path: str, local_path: str) -> None:
|
||||||
|
await async_disk.download(remote_path, local_path)
|
||||||
|
|
||||||
|
|
||||||
|
def _process_local_file_for_enrichment(enricher, vector_store, local_path: str, remote_path: str) -> bool:
|
||||||
|
"""Process one downloaded file and upload chunks into vector store.
|
||||||
|
|
||||||
|
Returns True when file was processed/uploaded, False when skipped.
|
||||||
|
"""
|
||||||
|
extension = Path(remote_path).suffix.lower()
|
||||||
|
file_hash = enricher._get_file_hash(local_path)
|
||||||
|
if enricher._is_document_hash_processed(file_hash):
|
||||||
|
return False
|
||||||
|
|
||||||
|
loader = enricher._get_loader_for_extension(local_path)
|
||||||
|
if loader is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
docs = loader.load()
|
||||||
|
filename = Path(remote_path).name
|
||||||
|
|
||||||
|
for doc in docs:
|
||||||
|
doc.metadata["source"] = remote_path
|
||||||
|
doc.metadata["filename"] = filename
|
||||||
|
doc.metadata["file_path"] = remote_path
|
||||||
|
doc.metadata["file_size"] = os.path.getsize(local_path)
|
||||||
|
doc.metadata["file_extension"] = extension
|
||||||
|
if "page" in doc.metadata:
|
||||||
|
doc.metadata["page_number"] = doc.metadata["page"]
|
||||||
|
|
||||||
|
split_docs = enricher.text_splitter.split_documents(docs)
|
||||||
|
from helpers import extract_russian_event_names, extract_years_from_text
|
||||||
|
|
||||||
|
for chunk in split_docs:
|
||||||
|
chunk.metadata["years"] = extract_years_from_text(chunk.page_content)
|
||||||
|
chunk.metadata["events"] = extract_russian_event_names(chunk.page_content)
|
||||||
|
|
||||||
|
if not split_docs:
|
||||||
|
return False
|
||||||
|
|
||||||
|
vector_store.add_documents(split_docs)
|
||||||
|
enricher._mark_document_processed(remote_path, file_hash)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def _process_remote_file(async_disk, remote_path: str, semaphore: asyncio.Semaphore, tracker: _ProgressTracker, enricher, vector_store):
|
||||||
|
async with semaphore:
|
||||||
|
temp_path = None
|
||||||
|
had_error = False
|
||||||
|
try:
|
||||||
|
suffix = Path(remote_path).suffix
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
|
||||||
|
temp_path = tmp_file.name
|
||||||
|
|
||||||
|
await _download_yadisk_file(async_disk, remote_path, temp_path)
|
||||||
|
await asyncio.to_thread(
|
||||||
|
_process_local_file_for_enrichment,
|
||||||
|
enricher,
|
||||||
|
vector_store,
|
||||||
|
temp_path,
|
||||||
|
remote_path,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Phase requirement: swallow per-file errors and continue processing.
|
||||||
|
had_error = True
|
||||||
|
finally:
|
||||||
|
if temp_path and os.path.exists(temp_path):
|
||||||
|
try:
|
||||||
|
os.unlink(temp_path)
|
||||||
|
except OSError:
|
||||||
|
had_error = True
|
||||||
|
await tracker.mark_done(error=had_error)
|
||||||
|
|
||||||
|
|
||||||
|
@task(name="prefilter_yadisk_file_paths")
|
||||||
|
def prefilter_yadisk_file_paths() -> List[str]:
|
||||||
|
"""Load file list JSON and keep only extensions supported by enrichment."""
|
||||||
|
from enrichment import SUPPORTED_EXTENSIONS
|
||||||
|
|
||||||
|
if not OUTPUT_FILE_LIST.exists():
|
||||||
|
raise FileNotFoundError(f"File list not found: {OUTPUT_FILE_LIST}")
|
||||||
|
|
||||||
|
with open(OUTPUT_FILE_LIST, "r", encoding="utf-8") as input_file:
|
||||||
|
raw_paths = json.load(input_file)
|
||||||
|
|
||||||
|
filtered = [
|
||||||
|
path for path in raw_paths if Path(str(path)).suffix.lower() in SUPPORTED_EXTENSIONS
|
||||||
|
]
|
||||||
|
return filtered
|
||||||
|
|
||||||
|
|
||||||
|
@task(name="enrich_filtered_yadisk_files_async")
|
||||||
|
async def enrich_filtered_yadisk_files_async(filtered_paths: List[str]) -> dict:
|
||||||
|
"""Download/process Yandex Disk files concurrently and enrich LangChain vector store."""
|
||||||
|
if not YADISK_TOKEN:
|
||||||
|
raise ValueError("YADISK_TOKEN is required for Yandex Disk enrichment")
|
||||||
|
|
||||||
|
if not filtered_paths:
|
||||||
|
print("No supported files found for enrichment.")
|
||||||
|
return {"total": 0, "processed": 0, "errors": 0}
|
||||||
|
|
||||||
|
try:
|
||||||
|
import yadisk
|
||||||
|
except ImportError as error:
|
||||||
|
raise RuntimeError("yadisk package is required for this flow") from error
|
||||||
|
|
||||||
|
if not hasattr(yadisk, "AsyncYaDisk"):
|
||||||
|
raise RuntimeError("Installed yadisk package does not expose AsyncYaDisk")
|
||||||
|
|
||||||
|
from enrichment import DocumentEnricher
|
||||||
|
from vector_storage import initialize_vector_store
|
||||||
|
|
||||||
|
vector_store = initialize_vector_store()
|
||||||
|
enricher = DocumentEnricher(vector_store)
|
||||||
|
|
||||||
|
tracker = _ProgressTracker(total=len(filtered_paths))
|
||||||
|
semaphore = asyncio.Semaphore(max(1, ENRICH_CONCURRENCY))
|
||||||
|
|
||||||
|
async with yadisk.AsyncYaDisk(token=YADISK_TOKEN) as async_disk:
|
||||||
|
tasks = [
|
||||||
|
asyncio.create_task(
|
||||||
|
_process_remote_file(
|
||||||
|
async_disk=async_disk,
|
||||||
|
remote_path=remote_path,
|
||||||
|
semaphore=semaphore,
|
||||||
|
tracker=tracker,
|
||||||
|
enricher=enricher,
|
||||||
|
vector_store=vector_store,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for remote_path in filtered_paths
|
||||||
|
]
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total": tracker.total,
|
||||||
|
"processed": tracker.processed,
|
||||||
|
"errors": tracker.errors,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@flow(name="yadisk_predefined_enrich")
|
||||||
|
async def yadisk_predefined_enrich() -> dict:
|
||||||
|
filtered_paths = prefilter_yadisk_file_paths()
|
||||||
|
return await enrich_filtered_yadisk_files_async(filtered_paths)
|
||||||
|
|
||||||
|
|
||||||
|
def serve_yadisk_predefined_enrich() -> None:
|
||||||
|
yadisk_predefined_enrich.serve(name="yadisk-predefined-enrich")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
serve_mode = os.getenv("PREFECT_SERVE", "0") == "1"
|
||||||
|
if serve_mode:
|
||||||
|
serve_yadisk_predefined_enrich()
|
||||||
|
else:
|
||||||
|
asyncio.run(yadisk_predefined_enrich())
|
||||||
@@ -55,3 +55,5 @@ unstructured-pytesseract>=0.3.12
|
|||||||
|
|
||||||
# System and utilities
|
# System and utilities
|
||||||
ollama>=0.3.0
|
ollama>=0.3.0
|
||||||
|
prefect>=2.19.0
|
||||||
|
yadisk>=3.4.0
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ import json
|
|||||||
import os
|
import os
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|||||||
@@ -25,6 +25,10 @@ OLLAMA_EMBEDDING_MODEL = os.getenv("OLLAMA_EMBEDDING_MODEL", "nomic-embed-text")
|
|||||||
OPENAI_EMBEDDING_MODEL = os.getenv("OPENAI_EMBEDDING_MODEL", "text-embedding-ada-002")
|
OPENAI_EMBEDDING_MODEL = os.getenv("OPENAI_EMBEDDING_MODEL", "text-embedding-ada-002")
|
||||||
OPENAI_EMBEDDING_BASE_URL = os.getenv("OPENAI_EMBEDDING_BASE_URL")
|
OPENAI_EMBEDDING_BASE_URL = os.getenv("OPENAI_EMBEDDING_BASE_URL")
|
||||||
OPENAI_EMBEDDING_API_KEY = os.getenv("OPENAI_EMBEDDING_API_KEY")
|
OPENAI_EMBEDDING_API_KEY = os.getenv("OPENAI_EMBEDDING_API_KEY")
|
||||||
|
EMBEDDING_REQUEST_TIMEOUT_SECONDS = float(
|
||||||
|
os.getenv("EMBEDDING_REQUEST_TIMEOUT_SECONDS", "30")
|
||||||
|
)
|
||||||
|
EMBEDDING_MAX_RETRIES = int(os.getenv("EMBEDDING_MAX_RETRIES", "0"))
|
||||||
|
|
||||||
|
|
||||||
def initialize_vector_store(
|
def initialize_vector_store(
|
||||||
@@ -53,6 +57,8 @@ def initialize_vector_store(
|
|||||||
model=OPENAI_EMBEDDING_MODEL,
|
model=OPENAI_EMBEDDING_MODEL,
|
||||||
openai_api_base=OPENAI_EMBEDDING_BASE_URL,
|
openai_api_base=OPENAI_EMBEDDING_BASE_URL,
|
||||||
openai_api_key=OPENAI_EMBEDDING_API_KEY,
|
openai_api_key=OPENAI_EMBEDDING_API_KEY,
|
||||||
|
request_timeout=EMBEDDING_REQUEST_TIMEOUT_SECONDS,
|
||||||
|
max_retries=EMBEDDING_MAX_RETRIES,
|
||||||
)
|
)
|
||||||
elif EMBEDDING_STRATEGY == "none":
|
elif EMBEDDING_STRATEGY == "none":
|
||||||
embeddings = None
|
embeddings = None
|
||||||
@@ -63,6 +69,8 @@ def initialize_vector_store(
|
|||||||
embeddings = OllamaEmbeddings(
|
embeddings = OllamaEmbeddings(
|
||||||
model=OLLAMA_EMBEDDING_MODEL,
|
model=OLLAMA_EMBEDDING_MODEL,
|
||||||
base_url="http://localhost:11434", # Default Ollama URL
|
base_url="http://localhost:11434", # Default Ollama URL
|
||||||
|
sync_client_kwargs={"timeout": EMBEDDING_REQUEST_TIMEOUT_SECONDS},
|
||||||
|
async_client_kwargs={"timeout": EMBEDDING_REQUEST_TIMEOUT_SECONDS},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if collection exists and create if needed
|
# Check if collection exists and create if needed
|
||||||
|
|||||||
@@ -6,9 +6,25 @@ EMBEDDING_STRATEGY=ollama
|
|||||||
OLLAMA_EMBEDDING_MODEL=MODEL
|
OLLAMA_EMBEDDING_MODEL=MODEL
|
||||||
OLLAMA_CHAT_MODEL=MODEL
|
OLLAMA_CHAT_MODEL=MODEL
|
||||||
|
|
||||||
|
# Qdrant Configuration
|
||||||
|
QDRANT_HOST=localhost
|
||||||
|
QDRANT_REST_PORT=6333
|
||||||
|
QDRANT_GRPC_PORT=6334
|
||||||
|
|
||||||
# OpenAI Configuration (for reference - uncomment and configure when using OpenAI strategy)
|
# OpenAI Configuration (for reference - uncomment and configure when using OpenAI strategy)
|
||||||
# OPENAI_CHAT_URL=https://api.openai.com/v1
|
# OPENAI_CHAT_URL=https://api.openai.com/v1
|
||||||
# OPENAI_CHAT_KEY=your_openai_api_key_here
|
# OPENAI_CHAT_KEY=your_openai_api_key_here
|
||||||
|
# OPENAI_CHAT_TEMPERATURE=0.1
|
||||||
|
# OPENAI_CHAT_MAX_TOKENS=1024
|
||||||
|
# OPENAI_CHAT_REASONING_EFFORT=low
|
||||||
|
# OPENAI_CHAT_IS_FUNCTION_CALLING_MODEL=false
|
||||||
# OPENAI_EMBEDDING_MODEL=text-embedding-3-small
|
# OPENAI_EMBEDDING_MODEL=text-embedding-3-small
|
||||||
# OPENAI_EMBEDDING_BASE_URL=https://api.openai.com/v1
|
# OPENAI_EMBEDDING_BASE_URL=https://api.openai.com/v1
|
||||||
# OPENAI_EMBEDDING_API_KEY=your_openai_api_key_here
|
# OPENAI_EMBEDDING_API_KEY=your_openai_api_key_here
|
||||||
|
|
||||||
|
# Yandex Disk + Prefect (Phase 9)
|
||||||
|
YADISK_TOKEN=your_yadisk_token_here
|
||||||
|
PREFECT_API_URL=https://your-prefect-server.example/api
|
||||||
|
QDRANT_HOST=HOST
|
||||||
|
QDRANT_REST_PORT=PORT
|
||||||
|
QDRANT_GRPC_PORT=PORT
|
||||||
|
|||||||
@@ -47,8 +47,56 @@ Chosen data folder: relatve ./../../../data - from the current folder
|
|||||||
|
|
||||||
- [x] Add log of how many files currently being processed in enrichment. We need to see how many total to process and how many processed each time new document being processed. If it's possible, also add progressbar showing percentage and those numbers on top of logs.
|
- [x] Add log of how many files currently being processed in enrichment. We need to see how many total to process and how many processed each time new document being processed. If it's possible, also add progressbar showing percentage and those numbers on top of logs.
|
||||||
|
|
||||||
# Phase 8 (chat feature, as agent, for usage in the cli)
|
# Phase 8 (comment unsupported formats for now)
|
||||||
|
|
||||||
- [ ] Create file `agent.py`, which will incorporate into itself agent, powered by the chat model. It should use integration with openai, env variables are configure
|
- [x] Remove for now formats, extensions for images of any kind, archives of any kind, and add possible text documents, documents formats, like .txt, .xlsx, etc. in enrichment processes/functions.
|
||||||
- [ ] Integrate this agent with the existing solution for retrieving, with retrieval.py, if it's possible in current chosen RAG framework
|
|
||||||
- [ ] Integrate this agent with the cli, as command to start chatting with the agent. If there is a built-in solution for console communication with the agent, initiate this on cli command.
|
# Phase 9 (integration of Prefect client, for creating flow and tasks on remote Prefect server)
|
||||||
|
|
||||||
|
- [x] Install Prefect client library.
|
||||||
|
- [x] Add .env variable PREFECT_API_URL, that will be used for connecting client to the prefect server
|
||||||
|
- [x] Create prefect client file in `prefect/01_yadisk_predefined_enrich.py`. This file will firt load file from ./../../../yadisk_files.json into array of paths. After that, array of paths will be filtered, and only supported in enrichment extensions will be left. After that, code will iterate through each path in this filtered array, use yadisk library to download file, process it for enrichment, and the remove it after processing. There should be statistics for this, at runtime, with progressbar that shows how many files processed out of how many left. Also, near the progressbar there should be counter of errors. Yes, if there is an error, it should be swallowed, even if it is inside thred or async function.
|
||||||
|
- [x] For yandex disk integration use library yadisk. In .env file there should be variable YADISK_TOKEN for accessing the needed connection
|
||||||
|
- [x] Code for loading should be reflected upon, and then made it so it would be done in async way, with as much as possible simulatenous tasks. yadisk async integration should be used (async features can be checked here: https://pypi.org/project/yadisk/)
|
||||||
|
- [x] No tests for code should be done at this phase, all tests will be done manually, because loading of documents can take a long time for automated test.
|
||||||
|
|
||||||
|
# Phase 10 (qdrant connection credentials in .env)
|
||||||
|
|
||||||
|
- [x] Add Qdrant connection variables to the .env file: QDRANT_HOST, QDRANT_REST_PORT, QDRANT_GRPC_PORT
|
||||||
|
- [x] Replace everywhere where Qdran connection used hardcoded values into the usage of Qdrant .env variables
|
||||||
|
|
||||||
|
# Phase 11 (http endpoint to retrieve data from the vector storage by query)
|
||||||
|
|
||||||
|
- [x] Create file `server.py`, with web framework fastapi, for example
|
||||||
|
- [x] Add POST endpoint "/api/test-query" which will use agent, and retrieve response for query, sent in JSON format, field "query"
|
||||||
|
|
||||||
|
# Phase 12 (upgrade from simple retrieval to agent-like chat in LlamaIndex)
|
||||||
|
|
||||||
|
- [x] Revisit Phase 5 assumption ("simple retrieval only") and explicitly allow agent/chat orchestration in LlamaIndex for QA over documents.
|
||||||
|
- [x] Create new module for chat orchestration (for example `agent.py` or `chat_engine.py`) that separates:
|
||||||
|
1) retrieval of source nodes
|
||||||
|
2) answer synthesis with explicit prompt
|
||||||
|
3) response formatting with sources/metadata
|
||||||
|
- [x] Implement a LlamaIndex-based chat feature (agent-like behavior) using framework-native primitives (chat engine / agent workflow / tool-calling approach supported by installed version), so the model can iteratively query retrieval tools when needed.
|
||||||
|
- [x] Add a retrieval tool wrapper for document search that returns structured snippets (`filename`, `file_path`, `page_label/page`, `chunk_number`, content preview, score) instead of raw text only.
|
||||||
|
- [x] Add a grounded answer prompt/template for the LlamaIndex chat path with rules:
|
||||||
|
- answer only from retrieved context
|
||||||
|
- if information is missing, say so directly
|
||||||
|
- prefer exact dates/years and quote filenames/pages where possible
|
||||||
|
- avoid generic claims not supported by sources
|
||||||
|
- [x] Add response mode that returns both:
|
||||||
|
- final answer text
|
||||||
|
- list of retrieved sources (content snippet + metadata + score)
|
||||||
|
- [x] Add post-processing for retrieved nodes before synthesis:
|
||||||
|
- deduplicate near-identical chunks
|
||||||
|
- drop empty / near-empty chunks
|
||||||
|
- optionally filter low-information chunks (headers/footers)
|
||||||
|
- [x] Add optional metadata-aware retrieval improvements (years/events/keywords) parity with LangChain approach (folder near current folder), if feasible in the chosen LlamaIndex primitives.
|
||||||
|
- [x] Update `server.py` endpoint to use the new agent-like chat path (keep simple retrieval endpoint available as fallback or debug mode).
|
||||||
|
|
||||||
|
# Phase 13 (make wrapper around OpenAI llamaindex class or install library to help use any openai compatible models)
|
||||||
|
|
||||||
|
Inside config.py file object is being created for using as Chat Model. Unfortunately, it allows only "supported" models as value for "model" argument in constructor.
|
||||||
|
|
||||||
|
- [x] Search online for library or plugin for llamainde, that fixes the OpenAI class behaviour or provides replacement, that will allow any models. If found, install it in project venv, update then requirements.txt and replace usage of OpenAI to new one in the code.
|
||||||
|
- [x] Fallback option: create wrapper of OpenAI class, that will inherit and replace methods/features that check for "registered" models. Use the replacement then in the code.
|
||||||
|
|||||||
335
services/rag/llamaindex/chat_engine.py
Normal file
335
services/rag/llamaindex/chat_engine.py
Normal file
@@ -0,0 +1,335 @@
|
|||||||
|
"""
|
||||||
|
Agent-like chat orchestration for grounded QA over documents using LlamaIndex.
|
||||||
|
|
||||||
|
This module separates:
|
||||||
|
1) retrieval of source nodes
|
||||||
|
2) answer synthesis with an explicit grounded prompt
|
||||||
|
3) response formatting with sources/metadata
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Iterable, List
|
||||||
|
|
||||||
|
from llama_index.core import PromptTemplate
|
||||||
|
from llama_index.core.agent import AgentWorkflow
|
||||||
|
from llama_index.core.base.llms.types import ChatMessage, MessageRole
|
||||||
|
from llama_index.core.retrievers import VectorIndexRetriever
|
||||||
|
from llama_index.core.schema import NodeWithScore
|
||||||
|
from llama_index.core.tools import FunctionTool
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from config import get_llm_model, setup_global_models
|
||||||
|
from vector_storage import get_vector_store_and_index
|
||||||
|
|
||||||
|
|
||||||
|
GROUNDED_SYNTHESIS_PROMPT = PromptTemplate(
|
||||||
|
"""You are a grounded QA assistant for a document knowledge base.
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
- Answer ONLY from the provided context snippets.
|
||||||
|
- If the context is insufficient, say directly that the information is not available in the retrieved sources.
|
||||||
|
- Prefer exact dates/years and cite filenames/pages when possible.
|
||||||
|
- Avoid generic claims that are not supported by the snippets.
|
||||||
|
- If multiple sources disagree, mention the conflict briefly.
|
||||||
|
|
||||||
|
User question:
|
||||||
|
{query}
|
||||||
|
|
||||||
|
Optional draft from tool-using agent (may be incomplete):
|
||||||
|
{agent_draft}
|
||||||
|
|
||||||
|
Context snippets (JSON):
|
||||||
|
{context_json}
|
||||||
|
|
||||||
|
Return a concise answer with source mentions in plain text.
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RetrievalSnippet:
|
||||||
|
content: str
|
||||||
|
score: float | None
|
||||||
|
metadata: dict[str, Any]
|
||||||
|
|
||||||
|
def to_api_dict(self) -> dict[str, Any]:
|
||||||
|
metadata = self.metadata or {}
|
||||||
|
content_preview = self.content.strip().replace("\n", " ")
|
||||||
|
if len(content_preview) > 400:
|
||||||
|
content_preview = content_preview[:400] + "..."
|
||||||
|
return {
|
||||||
|
"content_snippet": content_preview,
|
||||||
|
"score": self.score,
|
||||||
|
"metadata": {
|
||||||
|
"filename": metadata.get("filename", "unknown"),
|
||||||
|
"file_path": metadata.get("file_path", "unknown"),
|
||||||
|
"page_label": metadata.get("page_label", metadata.get("page", "unknown")),
|
||||||
|
"chunk_number": metadata.get("chunk_number", "unknown"),
|
||||||
|
"total_chunks": metadata.get("total_chunks", "unknown"),
|
||||||
|
"file_type": metadata.get("file_type", "unknown"),
|
||||||
|
"processed_at": metadata.get("processed_at", "unknown"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_text(value: Any) -> str:
|
||||||
|
if value is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(value, bytes):
|
||||||
|
return value.decode("utf-8", errors="replace")
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_years(query: str) -> list[int]:
|
||||||
|
years = []
|
||||||
|
for match in re.findall(r"\b(19\d{2}|20\d{2}|21\d{2})\b", query):
|
||||||
|
try:
|
||||||
|
years.append(int(match))
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
return sorted(set(years))
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_keywords(query: str) -> list[str]:
|
||||||
|
words = re.findall(r"[A-Za-zА-Яа-я0-9_-]{4,}", query.lower())
|
||||||
|
stop = {"what", "when", "where", "which", "that", "this", "with", "from", "about"}
|
||||||
|
keywords = [w for w in words if w not in stop]
|
||||||
|
return list(dict.fromkeys(keywords))[:6]
|
||||||
|
|
||||||
|
|
||||||
|
def _node_text(node_with_score: NodeWithScore) -> str:
|
||||||
|
node = getattr(node_with_score, "node", None)
|
||||||
|
if node is None:
|
||||||
|
return _normalize_text(getattr(node_with_score, "text", ""))
|
||||||
|
return _normalize_text(getattr(node, "text", getattr(node_with_score, "text", "")))
|
||||||
|
|
||||||
|
|
||||||
|
def _node_metadata(node_with_score: NodeWithScore) -> dict[str, Any]:
|
||||||
|
node = getattr(node_with_score, "node", None)
|
||||||
|
if node is None:
|
||||||
|
return dict(getattr(node_with_score, "metadata", {}) or {})
|
||||||
|
return dict(getattr(node, "metadata", {}) or {})
|
||||||
|
|
||||||
|
|
||||||
|
def _similarity_key(text: str) -> str:
|
||||||
|
text = re.sub(r"\s+", " ", text.strip().lower())
|
||||||
|
return text[:250]
|
||||||
|
|
||||||
|
|
||||||
|
def _is_low_information_chunk(text: str) -> bool:
|
||||||
|
compact = " ".join(text.split())
|
||||||
|
if len(compact) < 20:
|
||||||
|
return True
|
||||||
|
alpha_chars = sum(ch.isalpha() for ch in compact)
|
||||||
|
if alpha_chars < 8:
|
||||||
|
return True
|
||||||
|
# Repetitive headers/footers often contain too few unique tokens.
|
||||||
|
tokens = [t for t in re.split(r"\W+", compact.lower()) if t]
|
||||||
|
return len(tokens) >= 3 and len(set(tokens)) <= 2
|
||||||
|
|
||||||
|
|
||||||
|
def post_process_nodes(nodes: list[NodeWithScore]) -> list[NodeWithScore]:
|
||||||
|
"""
|
||||||
|
Post-process retrieved nodes:
|
||||||
|
- drop empty / near-empty chunks
|
||||||
|
- optionally drop low-information chunks
|
||||||
|
- deduplicate near-identical chunks
|
||||||
|
"""
|
||||||
|
filtered: list[NodeWithScore] = []
|
||||||
|
seen = set()
|
||||||
|
|
||||||
|
for nws in nodes:
|
||||||
|
text = _node_text(nws)
|
||||||
|
if not text or not text.strip():
|
||||||
|
continue
|
||||||
|
if _is_low_information_chunk(text):
|
||||||
|
continue
|
||||||
|
|
||||||
|
meta = _node_metadata(nws)
|
||||||
|
dedup_key = (
|
||||||
|
meta.get("file_path", ""),
|
||||||
|
meta.get("page_label", meta.get("page", "")),
|
||||||
|
meta.get("chunk_number", ""),
|
||||||
|
_similarity_key(text),
|
||||||
|
)
|
||||||
|
if dedup_key in seen:
|
||||||
|
continue
|
||||||
|
seen.add(dedup_key)
|
||||||
|
filtered.append(nws)
|
||||||
|
|
||||||
|
return filtered
|
||||||
|
|
||||||
|
|
||||||
|
def retrieve_source_nodes(query: str, top_k: int = 5, search_multiplier: int = 3) -> list[NodeWithScore]:
|
||||||
|
"""
|
||||||
|
Retrieve source nodes with light metadata-aware query expansion (years/keywords).
|
||||||
|
"""
|
||||||
|
setup_global_models()
|
||||||
|
_, index = get_vector_store_and_index()
|
||||||
|
retriever = VectorIndexRetriever(index=index, similarity_top_k=max(top_k * search_multiplier, top_k))
|
||||||
|
|
||||||
|
queries = [query]
|
||||||
|
years = _extract_years(query)
|
||||||
|
keywords = _extract_keywords(query)
|
||||||
|
queries.extend(str(year) for year in years)
|
||||||
|
queries.extend(keywords[:3])
|
||||||
|
|
||||||
|
collected: list[NodeWithScore] = []
|
||||||
|
for q in list(dict.fromkeys([q for q in queries if q and q.strip()])):
|
||||||
|
try:
|
||||||
|
logger.info(f"Retrieving nodes for query variant: {q}")
|
||||||
|
collected.extend(retriever.retrieve(q))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Retrieval variant failed for '{q}': {e}")
|
||||||
|
|
||||||
|
processed = post_process_nodes(collected)
|
||||||
|
processed.sort(key=lambda n: (getattr(n, "score", None) is None, -(getattr(n, "score", 0.0) or 0.0)))
|
||||||
|
return processed[:top_k]
|
||||||
|
|
||||||
|
|
||||||
|
def build_structured_snippets(nodes: list[NodeWithScore]) -> list[dict[str, Any]]:
|
||||||
|
"""Return structured snippets for tools/API responses."""
|
||||||
|
snippets: list[dict[str, Any]] = []
|
||||||
|
for nws in nodes:
|
||||||
|
snippet = RetrievalSnippet(
|
||||||
|
content=_node_text(nws),
|
||||||
|
score=getattr(nws, "score", None),
|
||||||
|
metadata=_node_metadata(nws),
|
||||||
|
)
|
||||||
|
snippets.append(snippet.to_api_dict())
|
||||||
|
return snippets
|
||||||
|
|
||||||
|
|
||||||
|
def retrieval_tool_search(query: str, top_k: int = 5) -> str:
|
||||||
|
"""
|
||||||
|
Tool wrapper for document retrieval returning structured JSON snippets.
|
||||||
|
"""
|
||||||
|
nodes = retrieve_source_nodes(query=query, top_k=top_k)
|
||||||
|
snippets = build_structured_snippets(nodes)
|
||||||
|
payload = {
|
||||||
|
"query": query,
|
||||||
|
"count": len(snippets),
|
||||||
|
"snippets": snippets,
|
||||||
|
}
|
||||||
|
return json.dumps(payload, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
def synthesize_answer(query: str, sources: list[dict[str, Any]], agent_draft: str = "") -> str:
|
||||||
|
"""
|
||||||
|
Answer synthesis from retrieved sources using an explicit grounded prompt.
|
||||||
|
"""
|
||||||
|
llm = get_llm_model()
|
||||||
|
context_json = json.dumps(sources, ensure_ascii=False, indent=2)
|
||||||
|
prompt = GROUNDED_SYNTHESIS_PROMPT.format(
|
||||||
|
query=query,
|
||||||
|
agent_draft=agent_draft or "(none)",
|
||||||
|
context_json=context_json,
|
||||||
|
)
|
||||||
|
logger.info("Synthesizing grounded answer from retrieved sources")
|
||||||
|
# Prefer chat API for chat-capable models; fallback to completion if unavailable.
|
||||||
|
try:
|
||||||
|
if hasattr(llm, "chat"):
|
||||||
|
chat_response = llm.chat(
|
||||||
|
[
|
||||||
|
ChatMessage(role=MessageRole.SYSTEM, content="You answer with grounded citations only."),
|
||||||
|
ChatMessage(role=MessageRole.USER, content=prompt),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return _normalize_text(getattr(chat_response, "message", chat_response).content)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"LLM chat synthesis failed, falling back to completion: {e}")
|
||||||
|
|
||||||
|
response = llm.complete(prompt)
|
||||||
|
return _normalize_text(getattr(response, "text", response))
|
||||||
|
|
||||||
|
|
||||||
|
def format_chat_response(query: str, final_answer: str, sources: list[dict[str, Any]], mode: str) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Response formatting with answer + structured sources.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"query": query,
|
||||||
|
"answer": final_answer,
|
||||||
|
"sources": sources,
|
||||||
|
"mode": mode,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_agent_result_text(result: Any) -> str:
|
||||||
|
if result is None:
|
||||||
|
return ""
|
||||||
|
if hasattr(result, "response"):
|
||||||
|
return _normalize_text(getattr(result, "response"))
|
||||||
|
return _normalize_text(result)
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_agent_workflow_async(query: str, top_k: int) -> str:
|
||||||
|
"""
|
||||||
|
Run LlamaIndex AgentWorkflow with a retrieval tool. Returns agent draft answer text.
|
||||||
|
"""
|
||||||
|
setup_global_models()
|
||||||
|
llm = get_llm_model()
|
||||||
|
tool = FunctionTool.from_defaults(
|
||||||
|
fn=retrieval_tool_search,
|
||||||
|
name="document_search",
|
||||||
|
description=(
|
||||||
|
"Search documents and return structured snippets as JSON with fields: "
|
||||||
|
"filename, file_path, page_label/page, chunk_number, content_snippet, score. "
|
||||||
|
"Use this before answering factual questions about documents."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
system_prompt = (
|
||||||
|
"You are a QA agent over a document store. Use the document_search tool when factual "
|
||||||
|
"information may come from documents. If tool output is insufficient, say so."
|
||||||
|
)
|
||||||
|
agent = AgentWorkflow.from_tools_or_functions(
|
||||||
|
[tool],
|
||||||
|
llm=llm,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
handler = agent.run(user_msg=query, max_iterations=4)
|
||||||
|
result = await handler
|
||||||
|
return _extract_agent_result_text(result)
|
||||||
|
|
||||||
|
|
||||||
|
def run_agent_workflow(query: str, top_k: int = 5) -> str:
|
||||||
|
"""
|
||||||
|
Synchronous wrapper around the async LlamaIndex agent workflow.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return asyncio.run(_run_agent_workflow_async(query=query, top_k=top_k))
|
||||||
|
except RuntimeError:
|
||||||
|
# Fallback if already in an event loop; skip agent workflow in that case.
|
||||||
|
logger.warning("Async event loop already running; skipping agent workflow and using direct retrieval+synthesis")
|
||||||
|
return ""
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Agent workflow failed, will fallback to direct retrieval+synthesis: {e}")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def chat_with_documents(query: str, top_k: int = 5, use_agent: bool = True) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Full chat orchestration entrypoint:
|
||||||
|
- optionally run agent workflow (tool-calling)
|
||||||
|
- retrieve + post-process sources
|
||||||
|
- synthesize grounded answer
|
||||||
|
- format response
|
||||||
|
"""
|
||||||
|
logger.info(f"Starting chat orchestration for query: {query[:80]}")
|
||||||
|
agent_draft = ""
|
||||||
|
mode = "retrieval+synthesis"
|
||||||
|
if use_agent:
|
||||||
|
agent_draft = run_agent_workflow(query=query, top_k=top_k)
|
||||||
|
mode = "agent+retrieval+synthesis" if agent_draft else "retrieval+synthesis"
|
||||||
|
|
||||||
|
nodes = retrieve_source_nodes(query=query, top_k=top_k)
|
||||||
|
sources = build_structured_snippets(nodes)
|
||||||
|
final_answer = synthesize_answer(query=query, sources=sources, agent_draft=agent_draft)
|
||||||
|
return format_chat_response(query=query, final_answer=final_answer, sources=sources, mode=mode)
|
||||||
@@ -93,18 +93,43 @@ def get_llm_model():
|
|||||||
return llm
|
return llm
|
||||||
|
|
||||||
elif strategy == "openai":
|
elif strategy == "openai":
|
||||||
from llama_index.llms.openai import OpenAI
|
from helpers.openai_compatible_llm import OpenAICompatibleLLM
|
||||||
|
|
||||||
openai_chat_url = os.getenv("OPENAI_CHAT_URL", "https://api.openai.com/v1")
|
openai_chat_url = os.getenv("OPENAI_CHAT_URL", "https://api.openai.com/v1")
|
||||||
openai_chat_key = os.getenv("OPENAI_CHAT_KEY", "dummy_key_for_template")
|
openai_chat_key = os.getenv("OPENAI_CHAT_KEY", "dummy_key_for_template")
|
||||||
openai_chat_model = os.getenv("OPENAI_CHAT_MODEL", "gpt-3.5-turbo")
|
openai_chat_model = os.getenv("OPENAI_CHAT_MODEL", "gpt-3.5-turbo")
|
||||||
|
openai_chat_temperature = float(os.getenv("OPENAI_CHAT_TEMPERATURE", "0.1"))
|
||||||
|
openai_chat_max_tokens_env = os.getenv("OPENAI_CHAT_MAX_TOKENS", "").strip()
|
||||||
|
openai_chat_max_tokens = (
|
||||||
|
int(openai_chat_max_tokens_env) if openai_chat_max_tokens_env else 1024
|
||||||
|
)
|
||||||
|
openai_reasoning_effort = (
|
||||||
|
os.getenv("OPENAI_CHAT_REASONING_EFFORT", "").strip() or None
|
||||||
|
)
|
||||||
|
openai_is_fc_model = (
|
||||||
|
os.getenv("OPENAI_CHAT_IS_FUNCTION_CALLING_MODEL", "false").lower()
|
||||||
|
== "true"
|
||||||
|
)
|
||||||
|
|
||||||
# Set the API key in environment for OpenAI
|
# Set the API key in environment for OpenAI
|
||||||
os.environ["OPENAI_API_KEY"] = openai_chat_key
|
os.environ["OPENAI_API_KEY"] = openai_chat_key
|
||||||
|
|
||||||
logger.info(f"Initializing OpenAI chat model: {openai_chat_model}")
|
logger.info(
|
||||||
|
f"Initializing OpenAI-compatible chat model: {openai_chat_model} "
|
||||||
|
f"(base={openai_chat_url}, max_tokens={openai_chat_max_tokens}, "
|
||||||
|
f"reasoning_effort={openai_reasoning_effort}, function_calling={openai_is_fc_model})"
|
||||||
|
)
|
||||||
|
|
||||||
llm = OpenAI(model=openai_chat_model, api_base=openai_chat_url)
|
llm = OpenAICompatibleLLM(
|
||||||
|
model=openai_chat_model,
|
||||||
|
api_base=openai_chat_url,
|
||||||
|
api_key=openai_chat_key,
|
||||||
|
temperature=openai_chat_temperature,
|
||||||
|
max_tokens=openai_chat_max_tokens,
|
||||||
|
reasoning_effort=openai_reasoning_effort,
|
||||||
|
timeout=120.0,
|
||||||
|
is_function_calling_model=openai_is_fc_model,
|
||||||
|
)
|
||||||
|
|
||||||
return llm
|
return llm
|
||||||
|
|
||||||
|
|||||||
@@ -6,24 +6,53 @@ processing them with appropriate loaders, splitting them into chunks,
|
|||||||
and storing them in the vector database with proper metadata.
|
and storing them in the vector database with proper metadata.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import hashlib
|
import hashlib
|
||||||
from pathlib import Path
|
import os
|
||||||
from typing import List, Dict, Any
|
|
||||||
from datetime import datetime
|
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from llama_index.core import Document, SimpleDirectoryReader
|
||||||
|
from llama_index.core.node_parser import CodeSplitter, SentenceSplitter
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from llama_index.core import SimpleDirectoryReader, Document
|
|
||||||
from llama_index.core.node_parser import SentenceSplitter, CodeSplitter
|
|
||||||
# Removed unused import
|
|
||||||
|
|
||||||
from vector_storage import get_vector_store_and_index
|
|
||||||
|
|
||||||
# Import the new configuration module
|
# Import the new configuration module
|
||||||
from config import get_embedding_model
|
from config import get_embedding_model
|
||||||
|
|
||||||
|
# Removed unused import
|
||||||
|
from vector_storage import get_vector_store_and_index
|
||||||
|
|
||||||
|
|
||||||
|
SUPPORTED_ENRICHMENT_EXTENSIONS = {
|
||||||
|
".csv",
|
||||||
|
".doc",
|
||||||
|
".docx",
|
||||||
|
".epub",
|
||||||
|
".htm",
|
||||||
|
".html",
|
||||||
|
".json",
|
||||||
|
".jsonl",
|
||||||
|
".md",
|
||||||
|
".odt",
|
||||||
|
".pdf",
|
||||||
|
".ppt",
|
||||||
|
".pptx",
|
||||||
|
".rtf",
|
||||||
|
".rst",
|
||||||
|
".tsv",
|
||||||
|
".txt",
|
||||||
|
".xls",
|
||||||
|
".xlsx",
|
||||||
|
".xml",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_supported_enrichment_extensions() -> set[str]:
|
||||||
|
"""Return the file extensions currently supported by enrichment."""
|
||||||
|
return set(SUPPORTED_ENRICHMENT_EXTENSIONS)
|
||||||
|
|
||||||
|
|
||||||
class DocumentTracker:
|
class DocumentTracker:
|
||||||
"""Class to handle tracking of processed documents to avoid re-processing."""
|
"""Class to handle tracking of processed documents to avoid re-processing."""
|
||||||
@@ -38,7 +67,7 @@ class DocumentTracker:
|
|||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# Create table for tracking processed documents
|
# Create table for tracking processed documents
|
||||||
cursor.execute('''
|
cursor.execute("""
|
||||||
CREATE TABLE IF NOT EXISTS processed_documents (
|
CREATE TABLE IF NOT EXISTS processed_documents (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
filename TEXT UNIQUE NOT NULL,
|
filename TEXT UNIQUE NOT NULL,
|
||||||
@@ -47,7 +76,7 @@ class DocumentTracker:
|
|||||||
processed_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
processed_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
metadata_json TEXT
|
metadata_json TEXT
|
||||||
)
|
)
|
||||||
''')
|
""")
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
@@ -63,7 +92,7 @@ class DocumentTracker:
|
|||||||
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT COUNT(*) FROM processed_documents WHERE filepath = ? AND checksum = ?",
|
"SELECT COUNT(*) FROM processed_documents WHERE filepath = ? AND checksum = ?",
|
||||||
(filepath, checksum)
|
(filepath, checksum),
|
||||||
)
|
)
|
||||||
count = cursor.fetchone()[0]
|
count = cursor.fetchone()[0]
|
||||||
|
|
||||||
@@ -79,11 +108,14 @@ class DocumentTracker:
|
|||||||
filename = Path(filepath).name
|
filename = Path(filepath).name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cursor.execute('''
|
cursor.execute(
|
||||||
|
"""
|
||||||
INSERT OR REPLACE INTO processed_documents
|
INSERT OR REPLACE INTO processed_documents
|
||||||
(filename, filepath, checksum, processed_at, metadata_json)
|
(filename, filepath, checksum, processed_at, metadata_json)
|
||||||
VALUES (?, ?, ?, CURRENT_TIMESTAMP, ?)
|
VALUES (?, ?, ?, CURRENT_TIMESTAMP, ?)
|
||||||
''', (filename, filepath, checksum, str(metadata) if metadata else None))
|
""",
|
||||||
|
(filename, filepath, checksum, str(metadata) if metadata else None),
|
||||||
|
)
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logger.info(f"Document marked as processed: {filepath}")
|
logger.info(f"Document marked as processed: {filepath}")
|
||||||
@@ -104,62 +136,67 @@ class DocumentTracker:
|
|||||||
|
|
||||||
def get_text_splitter(file_extension: str):
|
def get_text_splitter(file_extension: str):
|
||||||
"""Get appropriate text splitter based on file type."""
|
"""Get appropriate text splitter based on file type."""
|
||||||
from llama_index.core.node_parser import SentenceSplitter, CodeSplitter, TokenTextSplitter
|
from llama_index.core.node_parser import (
|
||||||
from llama_index.core.node_parser import MarkdownElementNodeParser
|
CodeSplitter,
|
||||||
|
MarkdownElementNodeParser,
|
||||||
|
SentenceSplitter,
|
||||||
|
TokenTextSplitter,
|
||||||
|
)
|
||||||
|
|
||||||
# For code files, use CodeSplitter
|
# For code files, use CodeSplitter
|
||||||
if file_extension.lower() in ['.py', '.js', '.ts', '.java', '.cpp', '.c', '.h', '.cs', '.go', '.rs', '.php', '.html', '.css', '.md', '.rst']:
|
if file_extension.lower() in [
|
||||||
|
".py",
|
||||||
|
".js",
|
||||||
|
".ts",
|
||||||
|
".java",
|
||||||
|
".cpp",
|
||||||
|
".c",
|
||||||
|
".h",
|
||||||
|
".cs",
|
||||||
|
".go",
|
||||||
|
".rs",
|
||||||
|
".php",
|
||||||
|
".html",
|
||||||
|
".css",
|
||||||
|
".md",
|
||||||
|
".rst",
|
||||||
|
]:
|
||||||
return CodeSplitter(language="python", max_chars=1000)
|
return CodeSplitter(language="python", max_chars=1000)
|
||||||
|
|
||||||
# For PDF files, use a parser that can handle multi-page documents
|
# For PDF files, use a parser that can handle multi-page documents
|
||||||
elif file_extension.lower() == '.pdf':
|
elif file_extension.lower() == ".pdf":
|
||||||
return SentenceSplitter(
|
return SentenceSplitter(
|
||||||
chunk_size=512, # Smaller chunks for dense PDF content
|
chunk_size=512, # Smaller chunks for dense PDF content
|
||||||
chunk_overlap=100
|
chunk_overlap=100,
|
||||||
)
|
)
|
||||||
|
|
||||||
# For presentation files (PowerPoint), use smaller chunks
|
# For presentation files (PowerPoint), use smaller chunks
|
||||||
elif file_extension.lower() == '.pptx':
|
elif file_extension.lower() == ".pptx":
|
||||||
return SentenceSplitter(
|
return SentenceSplitter(
|
||||||
chunk_size=256, # Slides typically have less text
|
chunk_size=256, # Slides typically have less text
|
||||||
chunk_overlap=50
|
chunk_overlap=50,
|
||||||
)
|
)
|
||||||
|
|
||||||
# For spreadsheets, use smaller chunks
|
# For spreadsheets, use smaller chunks
|
||||||
elif file_extension.lower() == '.xlsx':
|
elif file_extension.lower() == ".xlsx":
|
||||||
return SentenceSplitter(
|
return SentenceSplitter(chunk_size=256, chunk_overlap=50)
|
||||||
chunk_size=256,
|
|
||||||
chunk_overlap=50
|
|
||||||
)
|
|
||||||
|
|
||||||
# For text-heavy documents like Word, use medium-sized chunks
|
# For text-heavy documents like Word, use medium-sized chunks
|
||||||
elif file_extension.lower() in ['.docx', '.odt']:
|
elif file_extension.lower() in [".docx", ".odt"]:
|
||||||
return SentenceSplitter(
|
return SentenceSplitter(chunk_size=768, chunk_overlap=150)
|
||||||
chunk_size=768,
|
|
||||||
chunk_overlap=150
|
|
||||||
)
|
|
||||||
|
|
||||||
# For plain text files, use larger chunks
|
# For plain text files, use larger chunks
|
||||||
elif file_extension.lower() == '.txt':
|
elif file_extension.lower() == ".txt":
|
||||||
return SentenceSplitter(
|
return SentenceSplitter(chunk_size=1024, chunk_overlap=200)
|
||||||
chunk_size=1024,
|
|
||||||
chunk_overlap=200
|
|
||||||
)
|
|
||||||
|
|
||||||
# For image files, we'll handle them differently (metadata extraction)
|
# For image files, we'll handle them differently (metadata extraction)
|
||||||
elif file_extension.lower() in ['.png', '.jpg', '.jpeg', '.gif', '.bmp', '.svg']:
|
elif file_extension.lower() in [".png", ".jpg", ".jpeg", ".gif", ".bmp", ".svg"]:
|
||||||
# Images will be handled by multimodal models, return a simple splitter
|
# Images will be handled by multimodal models, return a simple splitter
|
||||||
return SentenceSplitter(
|
return SentenceSplitter(chunk_size=512, chunk_overlap=100)
|
||||||
chunk_size=512,
|
|
||||||
chunk_overlap=100
|
|
||||||
)
|
|
||||||
|
|
||||||
# For other files, use a standard SentenceSplitter
|
# For other files, use a standard SentenceSplitter
|
||||||
else:
|
else:
|
||||||
return SentenceSplitter(
|
return SentenceSplitter(chunk_size=768, chunk_overlap=150)
|
||||||
chunk_size=768,
|
|
||||||
chunk_overlap=150
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def ensure_proper_encoding(text):
|
def ensure_proper_encoding(text):
|
||||||
@@ -178,35 +215,41 @@ def ensure_proper_encoding(text):
|
|||||||
if isinstance(text, bytes):
|
if isinstance(text, bytes):
|
||||||
# Decode bytes to string with proper encoding
|
# Decode bytes to string with proper encoding
|
||||||
try:
|
try:
|
||||||
return text.decode('utf-8')
|
return text.decode("utf-8")
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
# If UTF-8 fails, try other encodings commonly used for Russian/Cyrillic text
|
# If UTF-8 fails, try other encodings commonly used for Russian/Cyrillic text
|
||||||
try:
|
try:
|
||||||
return text.decode('cp1251') # Windows Cyrillic encoding
|
return text.decode("cp1251") # Windows Cyrillic encoding
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
try:
|
try:
|
||||||
return text.decode('koi8-r') # Russian encoding
|
return text.decode("koi8-r") # Russian encoding
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
# If all else fails, decode with errors='replace'
|
# If all else fails, decode with errors='replace'
|
||||||
return text.decode('utf-8', errors='replace')
|
return text.decode("utf-8", errors="replace")
|
||||||
elif isinstance(text, str):
|
elif isinstance(text, str):
|
||||||
# Ensure the string is properly encoded
|
# Ensure the string is properly encoded
|
||||||
try:
|
try:
|
||||||
# Try to encode and decode to ensure it's valid UTF-8
|
# Try to encode and decode to ensure it's valid UTF-8
|
||||||
return text.encode('utf-8').decode('utf-8')
|
return text.encode("utf-8").decode("utf-8")
|
||||||
except UnicodeEncodeError:
|
except UnicodeEncodeError:
|
||||||
# If there are encoding issues, try to fix them
|
# If there are encoding issues, try to fix them
|
||||||
return text.encode('utf-8', errors='replace').decode('utf-8', errors='replace')
|
return text.encode("utf-8", errors="replace").decode(
|
||||||
|
"utf-8", errors="replace"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Convert other types to string and ensure proper encoding
|
# Convert other types to string and ensure proper encoding
|
||||||
text_str = str(text)
|
text_str = str(text)
|
||||||
try:
|
try:
|
||||||
return text_str.encode('utf-8').decode('utf-8')
|
return text_str.encode("utf-8").decode("utf-8")
|
||||||
except UnicodeEncodeError:
|
except UnicodeEncodeError:
|
||||||
return text_str.encode('utf-8', errors='replace').decode('utf-8', errors='replace')
|
return text_str.encode("utf-8", errors="replace").decode(
|
||||||
|
"utf-8", errors="replace"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def process_documents_from_data_folder(data_path: str = "../../../data", recursive: bool = True):
|
def process_documents_from_data_folder(
|
||||||
|
data_path: str = "../../../data", recursive: bool = True
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Process all documents from the data folder using appropriate loaders and store in vector DB.
|
Process all documents from the data folder using appropriate loaders and store in vector DB.
|
||||||
|
|
||||||
@@ -237,11 +280,7 @@ def process_documents_from_data_folder(data_path: str = "../../../data", recursi
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Find all supported files in the data directory
|
# Find all supported files in the data directory
|
||||||
supported_extensions = {
|
supported_extensions = get_supported_enrichment_extensions()
|
||||||
'.pdf', '.docx', '.xlsx', '.pptx', '.odt', '.txt',
|
|
||||||
'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.svg',
|
|
||||||
'.zip', '.rar', '.tar', '.gz'
|
|
||||||
}
|
|
||||||
|
|
||||||
# Walk through the directory structure
|
# Walk through the directory structure
|
||||||
all_files = []
|
all_files = []
|
||||||
@@ -258,111 +297,140 @@ def process_documents_from_data_folder(data_path: str = "../../../data", recursi
|
|||||||
if file_ext in supported_extensions:
|
if file_ext in supported_extensions:
|
||||||
all_files.append(str(file))
|
all_files.append(str(file))
|
||||||
|
|
||||||
logger.info(f"Found {len(all_files)} files to process")
|
logger.info(
|
||||||
|
f"Found {len(all_files)} supported files to process (extensions: {', '.join(sorted(supported_extensions))})"
|
||||||
|
)
|
||||||
|
|
||||||
processed_count = 0
|
processed_count = 0
|
||||||
skipped_count = 0
|
skipped_count = 0
|
||||||
|
error_count = 0
|
||||||
|
|
||||||
# Initialize progress bar
|
# Initialize progress bar
|
||||||
pbar = tqdm(total=len(all_files), desc="Processing documents", unit="file")
|
pbar = tqdm(total=len(all_files), desc="Processing documents", unit="file")
|
||||||
|
|
||||||
for file_path in all_files:
|
for file_path in all_files:
|
||||||
logger.info(f"Processing file: {file_path} ({processed_count + skipped_count + 1}/{len(all_files)})")
|
logger.info(
|
||||||
|
f"Processing file: {file_path} ({processed_count + skipped_count + 1}/{len(all_files)})"
|
||||||
# Check if document has already been processed
|
)
|
||||||
if tracker.is_document_processed(file_path):
|
|
||||||
logger.info(f"Skipping already processed file: {file_path}")
|
|
||||||
skipped_count += 1
|
|
||||||
pbar.set_postfix({"Processed": processed_count, "Skipped": skipped_count})
|
|
||||||
pbar.update(1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Load the document using SimpleDirectoryReader
|
result = process_document_file(file_path, tracker=tracker, index=index)
|
||||||
# This automatically selects the appropriate reader based on file extension
|
if result["status"] == "processed":
|
||||||
def file_metadata_func(file_path_str):
|
|
||||||
# Apply proper encoding to filename
|
|
||||||
filename = ensure_proper_encoding(Path(file_path_str).name)
|
|
||||||
return {"filename": filename}
|
|
||||||
|
|
||||||
reader = SimpleDirectoryReader(
|
|
||||||
input_files=[file_path],
|
|
||||||
file_metadata=file_metadata_func
|
|
||||||
)
|
|
||||||
documents = reader.load_data()
|
|
||||||
|
|
||||||
# Process each document
|
|
||||||
for doc in documents:
|
|
||||||
# Extract additional metadata based on document type
|
|
||||||
file_ext = Path(file_path).suffix
|
|
||||||
|
|
||||||
# Apply proper encoding to file path
|
|
||||||
encoded_file_path = ensure_proper_encoding(file_path)
|
|
||||||
|
|
||||||
# Add additional metadata
|
|
||||||
doc.metadata["file_path"] = encoded_file_path
|
|
||||||
doc.metadata["processed_at"] = datetime.now().isoformat()
|
|
||||||
|
|
||||||
# Handle document-type-specific metadata
|
|
||||||
if file_ext.lower() == '.pdf':
|
|
||||||
# PDF-specific metadata
|
|
||||||
doc.metadata["page_label"] = ensure_proper_encoding(doc.metadata.get("page_label", "unknown"))
|
|
||||||
doc.metadata["file_type"] = "pdf"
|
|
||||||
|
|
||||||
elif file_ext.lower() in ['.docx', '.odt']:
|
|
||||||
# Word document metadata
|
|
||||||
doc.metadata["section"] = ensure_proper_encoding(doc.metadata.get("section", "unknown"))
|
|
||||||
doc.metadata["file_type"] = "document"
|
|
||||||
|
|
||||||
elif file_ext.lower() == '.pptx':
|
|
||||||
# PowerPoint metadata
|
|
||||||
doc.metadata["slide_id"] = ensure_proper_encoding(doc.metadata.get("slide_id", "unknown"))
|
|
||||||
doc.metadata["file_type"] = "presentation"
|
|
||||||
|
|
||||||
elif file_ext.lower() == '.xlsx':
|
|
||||||
# Excel metadata
|
|
||||||
doc.metadata["sheet_name"] = ensure_proper_encoding(doc.metadata.get("sheet_name", "unknown"))
|
|
||||||
doc.metadata["file_type"] = "spreadsheet"
|
|
||||||
|
|
||||||
# Determine the appropriate text splitter based on file type
|
|
||||||
splitter = get_text_splitter(file_ext)
|
|
||||||
|
|
||||||
# Split the document into nodes
|
|
||||||
nodes = splitter.get_nodes_from_documents([doc])
|
|
||||||
|
|
||||||
# Insert nodes into the vector index
|
|
||||||
nodes_with_enhanced_metadata = []
|
|
||||||
for i, node in enumerate(nodes):
|
|
||||||
# Enhance node metadata with additional information
|
|
||||||
node.metadata["original_doc_id"] = ensure_proper_encoding(doc.doc_id)
|
|
||||||
node.metadata["chunk_number"] = i
|
|
||||||
node.metadata["total_chunks"] = len(nodes)
|
|
||||||
node.metadata["file_path"] = encoded_file_path
|
|
||||||
|
|
||||||
# Ensure the text content is properly encoded
|
|
||||||
node.text = ensure_proper_encoding(node.text)
|
|
||||||
|
|
||||||
nodes_with_enhanced_metadata.append(node)
|
|
||||||
|
|
||||||
# Add all nodes to the index at once
|
|
||||||
if nodes_with_enhanced_metadata:
|
|
||||||
index.insert_nodes(nodes_with_enhanced_metadata)
|
|
||||||
|
|
||||||
logger.info(f"Processed {len(nodes)} nodes from {encoded_file_path}")
|
|
||||||
|
|
||||||
# Mark document as processed only after successful insertion
|
|
||||||
tracker.mark_document_processed(file_path, {"nodes_count": len(documents)})
|
|
||||||
processed_count += 1
|
processed_count += 1
|
||||||
pbar.set_postfix({"Processed": processed_count, "Skipped": skipped_count})
|
elif result["status"] == "skipped":
|
||||||
|
skipped_count += 1
|
||||||
|
else:
|
||||||
|
error_count += 1
|
||||||
|
pbar.set_postfix(
|
||||||
|
{"Processed": processed_count, "Skipped": skipped_count, "Errors": error_count}
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing file {file_path}: {str(e)}")
|
logger.error(f"Error processing file {file_path}: {str(e)}")
|
||||||
|
error_count += 1
|
||||||
|
pbar.set_postfix(
|
||||||
|
{"Processed": processed_count, "Skipped": skipped_count, "Errors": error_count}
|
||||||
|
)
|
||||||
|
|
||||||
# Update progress bar regardless of success or failure
|
# Update progress bar regardless of success or failure
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
|
|
||||||
pbar.close()
|
pbar.close()
|
||||||
logger.info(f"Document enrichment completed. Processed: {processed_count}, Skipped: {skipped_count}")
|
logger.info(
|
||||||
|
f"Document enrichment completed. Processed: {processed_count}, Skipped: {skipped_count}, Errors: {error_count}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def process_document_file(
|
||||||
|
file_path: str,
|
||||||
|
tracker: Optional[DocumentTracker] = None,
|
||||||
|
index=None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Process a single document file and store its chunks in the vector index.
|
||||||
|
|
||||||
|
Returns a dict with status and counters. Status is one of:
|
||||||
|
`processed`, `skipped`, `error`.
|
||||||
|
"""
|
||||||
|
file_ext = Path(file_path).suffix.lower()
|
||||||
|
if file_ext not in get_supported_enrichment_extensions():
|
||||||
|
logger.info(f"Skipping unsupported extension for file: {file_path}")
|
||||||
|
return {"status": "skipped", "reason": "unsupported_extension", "nodes": 0}
|
||||||
|
|
||||||
|
tracker = tracker or DocumentTracker()
|
||||||
|
|
||||||
|
if tracker.is_document_processed(file_path):
|
||||||
|
logger.info(f"Skipping already processed file: {file_path}")
|
||||||
|
return {"status": "skipped", "reason": "already_processed", "nodes": 0}
|
||||||
|
|
||||||
|
if index is None:
|
||||||
|
_, index = get_vector_store_and_index()
|
||||||
|
|
||||||
|
try:
|
||||||
|
def file_metadata_func(file_path_str):
|
||||||
|
filename = ensure_proper_encoding(Path(file_path_str).name)
|
||||||
|
return {"filename": filename}
|
||||||
|
|
||||||
|
reader = SimpleDirectoryReader(
|
||||||
|
input_files=[file_path], file_metadata=file_metadata_func
|
||||||
|
)
|
||||||
|
documents = reader.load_data()
|
||||||
|
|
||||||
|
total_nodes_inserted = 0
|
||||||
|
for doc in documents:
|
||||||
|
current_file_ext = Path(file_path).suffix
|
||||||
|
encoded_file_path = ensure_proper_encoding(file_path)
|
||||||
|
|
||||||
|
doc.metadata["file_path"] = encoded_file_path
|
||||||
|
doc.metadata["processed_at"] = datetime.now().isoformat()
|
||||||
|
|
||||||
|
if current_file_ext.lower() == ".pdf":
|
||||||
|
doc.metadata["page_label"] = ensure_proper_encoding(
|
||||||
|
doc.metadata.get("page_label", "unknown")
|
||||||
|
)
|
||||||
|
doc.metadata["file_type"] = "pdf"
|
||||||
|
elif current_file_ext.lower() in [".docx", ".odt", ".doc", ".rtf"]:
|
||||||
|
doc.metadata["section"] = ensure_proper_encoding(
|
||||||
|
doc.metadata.get("section", "unknown")
|
||||||
|
)
|
||||||
|
doc.metadata["file_type"] = "document"
|
||||||
|
elif current_file_ext.lower() in [".pptx", ".ppt"]:
|
||||||
|
doc.metadata["slide_id"] = ensure_proper_encoding(
|
||||||
|
doc.metadata.get("slide_id", "unknown")
|
||||||
|
)
|
||||||
|
doc.metadata["file_type"] = "presentation"
|
||||||
|
elif current_file_ext.lower() in [".xlsx", ".xls", ".csv", ".tsv"]:
|
||||||
|
doc.metadata["sheet_name"] = ensure_proper_encoding(
|
||||||
|
doc.metadata.get("sheet_name", "unknown")
|
||||||
|
)
|
||||||
|
doc.metadata["file_type"] = "spreadsheet"
|
||||||
|
|
||||||
|
splitter = get_text_splitter(current_file_ext)
|
||||||
|
nodes = splitter.get_nodes_from_documents([doc])
|
||||||
|
|
||||||
|
nodes_with_enhanced_metadata = []
|
||||||
|
for i, node in enumerate(nodes):
|
||||||
|
node.metadata["original_doc_id"] = ensure_proper_encoding(doc.doc_id)
|
||||||
|
node.metadata["chunk_number"] = i
|
||||||
|
node.metadata["total_chunks"] = len(nodes)
|
||||||
|
node.metadata["file_path"] = encoded_file_path
|
||||||
|
node.text = ensure_proper_encoding(node.text)
|
||||||
|
nodes_with_enhanced_metadata.append(node)
|
||||||
|
|
||||||
|
if nodes_with_enhanced_metadata:
|
||||||
|
index.insert_nodes(nodes_with_enhanced_metadata)
|
||||||
|
total_nodes_inserted += len(nodes_with_enhanced_metadata)
|
||||||
|
|
||||||
|
logger.info(f"Processed {len(nodes)} nodes from {encoded_file_path}")
|
||||||
|
|
||||||
|
tracker.mark_document_processed(
|
||||||
|
file_path,
|
||||||
|
{"documents_count": len(documents), "nodes_count": total_nodes_inserted},
|
||||||
|
)
|
||||||
|
return {"status": "processed", "nodes": total_nodes_inserted}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing file {file_path}: {e}")
|
||||||
|
return {"status": "error", "reason": str(e), "nodes": 0}
|
||||||
|
|
||||||
|
|
||||||
def enrich_documents():
|
def enrich_documents():
|
||||||
|
|||||||
40
services/rag/llamaindex/helpers/openai_compatible_llm.py
Normal file
40
services/rag/llamaindex/helpers/openai_compatible_llm.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
"""
|
||||||
|
OpenAI-compatible LLM wrapper for LlamaIndex chat models.
|
||||||
|
|
||||||
|
This wrapper is used as a fallback/replacement for strict OpenAI model validation paths.
|
||||||
|
It relies on LlamaIndex `OpenAILike`, which supports arbitrary model names for
|
||||||
|
OpenAI-compatible endpoints.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from llama_index.llms.openai_like import OpenAILike
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAICompatibleLLM(OpenAILike):
|
||||||
|
"""
|
||||||
|
Thin wrapper over OpenAILike with chat-friendly defaults.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
api_base: str,
|
||||||
|
api_key: str,
|
||||||
|
temperature: float = 0.1,
|
||||||
|
timeout: float = 120.0,
|
||||||
|
max_tokens: int | None = None,
|
||||||
|
reasoning_effort: str | None = None,
|
||||||
|
is_function_calling_model: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
model=model,
|
||||||
|
api_base=api_base,
|
||||||
|
api_key=api_key,
|
||||||
|
temperature=temperature,
|
||||||
|
timeout=timeout,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
reasoning_effort=reasoning_effort,
|
||||||
|
# Explicitly avoid "registered model only" assumptions.
|
||||||
|
is_chat_model=True,
|
||||||
|
is_function_calling_model=is_function_calling_model,
|
||||||
|
should_use_structured_outputs=False,
|
||||||
|
)
|
||||||
268
services/rag/llamaindex/prefect/01_yadisk_predefined_enrich.py
Normal file
268
services/rag/llamaindex/prefect/01_yadisk_predefined_enrich.py
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
"""
|
||||||
|
Prefect flow for enriching documents from a predefined YaDisk file list.
|
||||||
|
|
||||||
|
Flow steps:
|
||||||
|
1. Load file paths from ../../../yadisk_files.json
|
||||||
|
2. Filter them by supported enrichment extensions
|
||||||
|
3. Download each file from YaDisk asynchronously
|
||||||
|
4. Enrich each downloaded file into vector storage
|
||||||
|
5. Remove downloaded temporary files after processing
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from loguru import logger
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
ROOT_DIR = Path(__file__).resolve().parents[1]
|
||||||
|
load_dotenv(ROOT_DIR / ".env")
|
||||||
|
|
||||||
|
if str(ROOT_DIR) not in sys.path:
|
||||||
|
sys.path.insert(0, str(ROOT_DIR))
|
||||||
|
|
||||||
|
import yadisk
|
||||||
|
|
||||||
|
from enrichment import get_supported_enrichment_extensions, process_document_file
|
||||||
|
from prefect import flow, task
|
||||||
|
|
||||||
|
DEFAULT_YADISK_LIST_PATH = (ROOT_DIR / "../../../yadisk_files.json").resolve()
|
||||||
|
|
||||||
|
|
||||||
|
def setup_prefect_flow_logging() -> None:
|
||||||
|
"""Configure loguru handlers for flow execution."""
|
||||||
|
logs_dir = ROOT_DIR / "logs"
|
||||||
|
logs_dir.mkdir(exist_ok=True)
|
||||||
|
logger.remove()
|
||||||
|
logger.add(
|
||||||
|
str(logs_dir / "dev.log"),
|
||||||
|
rotation="10 MB",
|
||||||
|
retention="10 days",
|
||||||
|
level="INFO",
|
||||||
|
format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {file}:{line} | {message}",
|
||||||
|
)
|
||||||
|
logger.add(
|
||||||
|
sys.stdout,
|
||||||
|
level="INFO",
|
||||||
|
format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
|
||||||
|
colorize=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_yadisk_paths(payload: Any) -> list[str]:
|
||||||
|
"""Extract a list of file paths from several common JSON shapes."""
|
||||||
|
if isinstance(payload, list):
|
||||||
|
return [str(item) for item in payload if isinstance(item, (str, Path))]
|
||||||
|
|
||||||
|
if isinstance(payload, dict):
|
||||||
|
for key in ("paths", "files", "items"):
|
||||||
|
value = payload.get(key)
|
||||||
|
if isinstance(value, list):
|
||||||
|
normalized: list[str] = []
|
||||||
|
for item in value:
|
||||||
|
if isinstance(item, str):
|
||||||
|
normalized.append(item)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
for item_key in ("path", "remote_path", "file_path", "name"):
|
||||||
|
if item_key in item and item[item_key]:
|
||||||
|
normalized.append(str(item[item_key]))
|
||||||
|
break
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
"Unsupported yadisk_files.json structure. Expected list or dict with paths/files/items."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_temp_local_path(base_dir: Path, remote_path: str) -> Path:
|
||||||
|
"""Create a deterministic temp file path for a remote YaDisk path."""
|
||||||
|
remote_name = Path(remote_path).name or "downloaded_file"
|
||||||
|
suffix = Path(remote_name).suffix
|
||||||
|
stem = Path(remote_name).stem or "file"
|
||||||
|
digest = hashlib.md5(remote_path.encode("utf-8")).hexdigest()[:10]
|
||||||
|
safe_stem = "".join(ch if ch.isalnum() or ch in ("-", "_") else "_" for ch in stem)
|
||||||
|
return base_dir / f"{safe_stem}_{digest}{suffix}"
|
||||||
|
|
||||||
|
|
||||||
|
@task(name="load_yadisk_paths")
|
||||||
|
def load_yadisk_paths(json_file_path: str) -> list[str]:
|
||||||
|
"""Load remote file paths from JSON file."""
|
||||||
|
path = Path(json_file_path)
|
||||||
|
if not path.exists():
|
||||||
|
raise FileNotFoundError(f"YaDisk paths JSON file not found: {path}")
|
||||||
|
|
||||||
|
with path.open("r", encoding="utf-8") as f:
|
||||||
|
payload = json.load(f)
|
||||||
|
|
||||||
|
paths = _normalize_yadisk_paths(payload)
|
||||||
|
logger.info(f"Loaded {len(paths)} paths from {path}")
|
||||||
|
return paths
|
||||||
|
|
||||||
|
|
||||||
|
@task(name="filter_supported_yadisk_paths")
|
||||||
|
def filter_supported_yadisk_paths(paths: list[str]) -> list[str]:
|
||||||
|
"""Keep only paths supported by enrichment extension filters."""
|
||||||
|
supported_extensions = get_supported_enrichment_extensions()
|
||||||
|
filtered = [p for p in paths if Path(str(p)).suffix.lower() in supported_extensions]
|
||||||
|
logger.info(
|
||||||
|
f"Filtered YaDisk paths: {len(filtered)}/{len(paths)} supported "
|
||||||
|
f"(extensions: {', '.join(sorted(supported_extensions))})"
|
||||||
|
)
|
||||||
|
return filtered
|
||||||
|
|
||||||
|
|
||||||
|
async def _download_and_enrich_one(
|
||||||
|
client: yadisk.AsyncClient,
|
||||||
|
remote_path: str,
|
||||||
|
temp_dir: Path,
|
||||||
|
semaphore: asyncio.Semaphore,
|
||||||
|
stats: dict[str, int],
|
||||||
|
pbar: tqdm,
|
||||||
|
pbar_lock: asyncio.Lock,
|
||||||
|
) -> None:
|
||||||
|
"""Download one YaDisk file, enrich it, remove it, and update stats."""
|
||||||
|
local_path = _make_temp_local_path(temp_dir, remote_path)
|
||||||
|
result_status = "error"
|
||||||
|
|
||||||
|
async with semaphore:
|
||||||
|
try:
|
||||||
|
local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
await client.download(remote_path, str(local_path))
|
||||||
|
|
||||||
|
# Run sync enrichment in a worker thread. Exceptions are swallowed below.
|
||||||
|
enrich_result = await asyncio.to_thread(
|
||||||
|
process_document_file, str(local_path)
|
||||||
|
)
|
||||||
|
result_status = str(enrich_result.get("status", "error"))
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"YaDisk processed: remote={remote_path}, local={local_path}, status={result_status}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# Explicitly swallow errors as requested.
|
||||||
|
logger.error(f"YaDisk processing error for {remote_path}: {e}")
|
||||||
|
result_status = "error"
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
if local_path.exists():
|
||||||
|
local_path.unlink()
|
||||||
|
except Exception as cleanup_error:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to remove temp file {local_path}: {cleanup_error}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async with pbar_lock:
|
||||||
|
if result_status == "processed":
|
||||||
|
stats["processed"] += 1
|
||||||
|
elif result_status == "skipped":
|
||||||
|
stats["skipped"] += 1
|
||||||
|
else:
|
||||||
|
stats["errors"] += 1
|
||||||
|
stats["completed"] += 1
|
||||||
|
|
||||||
|
pbar.update(1)
|
||||||
|
pbar.set_postfix(
|
||||||
|
processed=stats["processed"],
|
||||||
|
skipped=stats["skipped"],
|
||||||
|
errors=stats["errors"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@flow(name="yadisk_predefined_enrich")
|
||||||
|
async def yadisk_predefined_enrich_flow(
|
||||||
|
yadisk_json_path: str = str(DEFAULT_YADISK_LIST_PATH),
|
||||||
|
concurrency: int = 4,
|
||||||
|
) -> dict[str, int]:
|
||||||
|
"""
|
||||||
|
Download and enrich YaDisk files listed in the JSON file using async YaDisk client.
|
||||||
|
"""
|
||||||
|
setup_prefect_flow_logging()
|
||||||
|
|
||||||
|
prefect_api_url = os.getenv("PREFECT_API_URL", "").strip()
|
||||||
|
yadisk_token = os.getenv("YADISK_TOKEN", "").strip()
|
||||||
|
|
||||||
|
if not prefect_api_url:
|
||||||
|
logger.warning("PREFECT_API_URL is not set in environment/.env")
|
||||||
|
else:
|
||||||
|
# Prefect reads this env var for API connectivity.
|
||||||
|
os.environ["PREFECT_API_URL"] = prefect_api_url
|
||||||
|
logger.info(f"Using Prefect API URL: {prefect_api_url}")
|
||||||
|
|
||||||
|
if not yadisk_token:
|
||||||
|
raise ValueError("YADISK_TOKEN is required in .env to access Yandex Disk")
|
||||||
|
|
||||||
|
all_paths = load_yadisk_paths(yadisk_json_path)
|
||||||
|
supported_paths = filter_supported_yadisk_paths(all_paths)
|
||||||
|
|
||||||
|
stats = {
|
||||||
|
"total": len(supported_paths),
|
||||||
|
"completed": 0,
|
||||||
|
"processed": 0,
|
||||||
|
"skipped": 0,
|
||||||
|
"errors": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
if not supported_paths:
|
||||||
|
logger.info("No supported YaDisk paths to process")
|
||||||
|
return stats
|
||||||
|
|
||||||
|
concurrency = max(1, int(concurrency))
|
||||||
|
logger.info(
|
||||||
|
f"Starting async YaDisk enrichment for {len(supported_paths)} files with concurrency={concurrency}"
|
||||||
|
)
|
||||||
|
|
||||||
|
semaphore = asyncio.Semaphore(concurrency)
|
||||||
|
pbar_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory(prefix="yadisk_enrich_") as temp_dir_str:
|
||||||
|
temp_dir = Path(temp_dir_str)
|
||||||
|
pbar = tqdm(total=len(supported_paths), desc="YaDisk enrich", unit="file")
|
||||||
|
try:
|
||||||
|
async with yadisk.AsyncClient(token=yadisk_token) as client:
|
||||||
|
try:
|
||||||
|
is_token_valid = await client.check_token()
|
||||||
|
logger.info(f"YaDisk token validation result: {is_token_valid}")
|
||||||
|
except Exception as token_check_error:
|
||||||
|
# Token check issues should not block processing attempts.
|
||||||
|
logger.warning(f"YaDisk token check failed: {token_check_error}")
|
||||||
|
|
||||||
|
tasks = [
|
||||||
|
asyncio.create_task(
|
||||||
|
_download_and_enrich_one(
|
||||||
|
client=client,
|
||||||
|
remote_path=remote_path,
|
||||||
|
temp_dir=temp_dir,
|
||||||
|
semaphore=semaphore,
|
||||||
|
stats=stats,
|
||||||
|
pbar=pbar,
|
||||||
|
pbar_lock=pbar_lock,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for remote_path in supported_paths
|
||||||
|
]
|
||||||
|
|
||||||
|
# Worker function swallows per-file errors, but keep gather resilient too.
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
finally:
|
||||||
|
pbar.close()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"YaDisk enrichment flow finished. "
|
||||||
|
f"Total={stats['total']}, Completed={stats['completed']}, "
|
||||||
|
f"Processed={stats['processed']}, Skipped={stats['skipped']}, Errors={stats['errors']}"
|
||||||
|
)
|
||||||
|
return stats
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("SERVING PREFECT FLOW FOR YANDEX DISK ENRICHMENT OF PREDEFINED PATHS")
|
||||||
|
yadisk_predefined_enrich_flow.serve()
|
||||||
174
services/rag/llamaindex/requirements.txt
Normal file
174
services/rag/llamaindex/requirements.txt
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
aiofiles==25.1.0
|
||||||
|
aiohappyeyeballs==2.6.1
|
||||||
|
aiohttp==3.13.3
|
||||||
|
aiosignal==1.4.0
|
||||||
|
aiosqlite==0.22.1
|
||||||
|
alembic==1.18.4
|
||||||
|
amplitude-analytics==1.2.2
|
||||||
|
annotated-doc==0.0.4
|
||||||
|
annotated-types==0.7.0
|
||||||
|
anyio==4.12.1
|
||||||
|
apprise==1.9.7
|
||||||
|
asgi-lifespan==2.1.0
|
||||||
|
asyncpg==0.31.0
|
||||||
|
attrs==25.4.0
|
||||||
|
banks==2.3.0
|
||||||
|
beartype==0.22.9
|
||||||
|
beautifulsoup4==4.14.3
|
||||||
|
cachetools==7.0.1
|
||||||
|
certifi==2026.1.4
|
||||||
|
cffi==2.0.0
|
||||||
|
charset-normalizer==3.4.4
|
||||||
|
click==8.3.1
|
||||||
|
cloudpickle==3.1.2
|
||||||
|
colorama==0.4.6
|
||||||
|
coolname==3.0.0
|
||||||
|
croniter==6.0.0
|
||||||
|
cryptography==46.0.5
|
||||||
|
dataclasses-json==0.6.7
|
||||||
|
dateparser==1.3.0
|
||||||
|
defusedxml==0.7.1
|
||||||
|
Deprecated==1.2.18
|
||||||
|
dirtyjson==1.0.8
|
||||||
|
distro==1.9.0
|
||||||
|
docker==7.1.0
|
||||||
|
docx2txt==0.9
|
||||||
|
et_xmlfile==2.0.0
|
||||||
|
exceptiongroup==1.3.1
|
||||||
|
fakeredis==2.34.0
|
||||||
|
fastapi==0.131.0
|
||||||
|
filetype==1.2.0
|
||||||
|
frozenlist==1.8.0
|
||||||
|
fsspec==2026.1.0
|
||||||
|
graphviz==0.21
|
||||||
|
greenlet==3.3.1
|
||||||
|
griffe==1.15.0
|
||||||
|
grpcio==1.76.0
|
||||||
|
h11==0.16.0
|
||||||
|
h2==4.3.0
|
||||||
|
hpack==4.1.0
|
||||||
|
httpcore==1.0.9
|
||||||
|
httpx==0.28.1
|
||||||
|
humanize==4.15.0
|
||||||
|
hyperframe==6.1.0
|
||||||
|
idna==3.11
|
||||||
|
importlib_metadata==8.7.1
|
||||||
|
iniconfig==2.3.0
|
||||||
|
Jinja2==3.1.6
|
||||||
|
jinja2-humanize-extension==0.4.0
|
||||||
|
jiter==0.13.0
|
||||||
|
joblib==1.5.3
|
||||||
|
jsonpatch==1.33
|
||||||
|
jsonpointer==3.0.0
|
||||||
|
jsonschema==4.26.0
|
||||||
|
jsonschema-specifications==2025.9.1
|
||||||
|
llama-cloud==0.1.35
|
||||||
|
llama-cloud-services==0.6.54
|
||||||
|
llama-index==0.14.13
|
||||||
|
llama-index-cli==0.5.3
|
||||||
|
llama-index-core==0.14.13
|
||||||
|
llama-index-embeddings-ollama==0.8.6
|
||||||
|
llama-index-embeddings-openai==0.5.1
|
||||||
|
llama-index-embeddings-openai-like==0.2.2
|
||||||
|
llama-index-indices-managed-llama-cloud==0.9.4
|
||||||
|
llama-index-instrumentation==0.4.2
|
||||||
|
llama-index-llms-ollama==0.9.1
|
||||||
|
llama-index-llms-openai==0.6.17
|
||||||
|
llama-index-llms-openai-like==0.6.0
|
||||||
|
llama-index-readers-file==0.5.6
|
||||||
|
llama-index-readers-llama-parse==0.5.1
|
||||||
|
llama-index-vector-stores-qdrant==0.9.1
|
||||||
|
llama-index-workflows==2.13.1
|
||||||
|
llama-parse==0.6.54
|
||||||
|
loguru==0.7.3
|
||||||
|
lupa==2.6
|
||||||
|
lxml==6.0.2
|
||||||
|
Mako==1.3.10
|
||||||
|
Markdown==3.10.2
|
||||||
|
markdown-it-py==4.0.0
|
||||||
|
MarkupSafe==3.0.3
|
||||||
|
marshmallow==3.26.2
|
||||||
|
mdurl==0.1.2
|
||||||
|
multidict==6.7.1
|
||||||
|
mypy_extensions==1.1.0
|
||||||
|
nest-asyncio==1.6.0
|
||||||
|
networkx==3.6.1
|
||||||
|
nltk==3.9.2
|
||||||
|
numpy==2.4.2
|
||||||
|
oauthlib==3.3.1
|
||||||
|
ollama==0.6.1
|
||||||
|
openai==2.16.0
|
||||||
|
openpyxl==3.1.5
|
||||||
|
opentelemetry-api==1.39.1
|
||||||
|
orjson==3.11.7
|
||||||
|
packaging==25.0
|
||||||
|
pandas==2.3.3
|
||||||
|
pathspec==1.0.4
|
||||||
|
patool==4.0.4
|
||||||
|
pendulum==3.2.0
|
||||||
|
pillow==12.1.0
|
||||||
|
platformdirs==4.5.1
|
||||||
|
pluggy==1.6.0
|
||||||
|
portalocker==3.2.0
|
||||||
|
prefect==3.6.18
|
||||||
|
prometheus_client==0.24.1
|
||||||
|
propcache==0.4.1
|
||||||
|
protobuf==6.33.5
|
||||||
|
py-key-value-aio==0.4.4
|
||||||
|
pycparser==3.0
|
||||||
|
pydantic==2.12.5
|
||||||
|
pydantic-extra-types==2.11.0
|
||||||
|
pydantic-settings==2.13.1
|
||||||
|
pydantic_core==2.41.5
|
||||||
|
pydocket==0.17.9
|
||||||
|
Pygments==2.19.2
|
||||||
|
pypdf==6.6.2
|
||||||
|
pytest==9.0.2
|
||||||
|
pytest-asyncio==1.3.0
|
||||||
|
python-dateutil==2.9.0.post0
|
||||||
|
python-dotenv==1.2.1
|
||||||
|
python-json-logger==4.0.0
|
||||||
|
python-pptx==1.0.2
|
||||||
|
python-slugify==8.0.4
|
||||||
|
pytz==2025.2
|
||||||
|
PyYAML==6.0.3
|
||||||
|
qdrant-client==1.16.2
|
||||||
|
readchar==4.2.1
|
||||||
|
redis==7.2.0
|
||||||
|
referencing==0.37.0
|
||||||
|
regex==2026.1.15
|
||||||
|
requests==2.32.5
|
||||||
|
requests-oauthlib==2.0.0
|
||||||
|
rfc3339-validator==0.1.4
|
||||||
|
rich==14.3.3
|
||||||
|
rpds-py==0.30.0
|
||||||
|
ruamel.yaml==0.19.1
|
||||||
|
ruamel.yaml.clib==0.2.15
|
||||||
|
semver==3.0.4
|
||||||
|
shellingham==1.5.4
|
||||||
|
six==1.17.0
|
||||||
|
sniffio==1.3.1
|
||||||
|
sortedcontainers==2.4.0
|
||||||
|
soupsieve==2.8.3
|
||||||
|
SQLAlchemy==2.0.46
|
||||||
|
starlette==0.52.1
|
||||||
|
striprtf==0.0.26
|
||||||
|
tenacity==9.1.2
|
||||||
|
text-unidecode==1.3
|
||||||
|
tiktoken==0.12.0
|
||||||
|
toml==0.10.2
|
||||||
|
tqdm==4.67.3
|
||||||
|
typer==0.24.1
|
||||||
|
typing-inspect==0.9.0
|
||||||
|
typing-inspection==0.4.2
|
||||||
|
typing_extensions==4.15.0
|
||||||
|
tzdata==2025.3
|
||||||
|
tzlocal==5.3.1
|
||||||
|
urllib3==2.6.3
|
||||||
|
uvicorn==0.41.0
|
||||||
|
websockets==16.0
|
||||||
|
wrapt==1.17.3
|
||||||
|
xlsxwriter==3.2.9
|
||||||
|
yadisk==3.4.0
|
||||||
|
yarl==1.22.0
|
||||||
|
zipp==3.23.0
|
||||||
@@ -14,7 +14,7 @@ from llama_index.core.retrievers import VectorIndexRetriever
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from vector_storage import get_vector_store_and_index
|
from vector_storage import get_qdrant_connection_config, get_vector_store_and_index
|
||||||
|
|
||||||
# Import the new configuration module
|
# Import the new configuration module
|
||||||
from config import setup_global_models
|
from config import setup_global_models
|
||||||
@@ -23,8 +23,9 @@ from config import setup_global_models
|
|||||||
def initialize_retriever(
|
def initialize_retriever(
|
||||||
collection_name: str = "documents_llamaindex",
|
collection_name: str = "documents_llamaindex",
|
||||||
similarity_top_k: int = 5,
|
similarity_top_k: int = 5,
|
||||||
host: str = "localhost",
|
host: str | None = None,
|
||||||
port: int = 6333
|
port: int | None = None,
|
||||||
|
grpc_port: int | None = None,
|
||||||
) -> RetrieverQueryEngine:
|
) -> RetrieverQueryEngine:
|
||||||
"""
|
"""
|
||||||
Initialize the retriever query engine with the vector store.
|
Initialize the retriever query engine with the vector store.
|
||||||
@@ -32,8 +33,9 @@ def initialize_retriever(
|
|||||||
Args:
|
Args:
|
||||||
collection_name: Name of the Qdrant collection
|
collection_name: Name of the Qdrant collection
|
||||||
similarity_top_k: Number of top similar documents to retrieve
|
similarity_top_k: Number of top similar documents to retrieve
|
||||||
host: Qdrant host address
|
host: Qdrant host address (defaults to QDRANT_HOST from .env)
|
||||||
port: Qdrant REST API port
|
port: Qdrant REST API port (defaults to QDRANT_REST_PORT from .env)
|
||||||
|
grpc_port: Qdrant gRPC API port (defaults to QDRANT_GRPC_PORT from .env)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
RetrieverQueryEngine configured with the vector store
|
RetrieverQueryEngine configured with the vector store
|
||||||
@@ -44,8 +46,23 @@ def initialize_retriever(
|
|||||||
# Set up the global models to prevent defaulting to OpenAI
|
# Set up the global models to prevent defaulting to OpenAI
|
||||||
setup_global_models()
|
setup_global_models()
|
||||||
|
|
||||||
|
qdrant_config = get_qdrant_connection_config()
|
||||||
|
resolved_host = host or str(qdrant_config["host"])
|
||||||
|
resolved_port = port or int(qdrant_config["port"])
|
||||||
|
resolved_grpc_port = grpc_port or int(qdrant_config["grpc_port"])
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Retriever Qdrant connection: host={resolved_host}, "
|
||||||
|
f"rest_port={resolved_port}, grpc_port={resolved_grpc_port}"
|
||||||
|
)
|
||||||
|
|
||||||
# Get the vector store and index from the existing configuration
|
# Get the vector store and index from the existing configuration
|
||||||
vector_store, index = get_vector_store_and_index()
|
vector_store, index = get_vector_store_and_index(
|
||||||
|
collection_name=collection_name,
|
||||||
|
host=resolved_host,
|
||||||
|
port=resolved_port,
|
||||||
|
grpc_port=resolved_grpc_port,
|
||||||
|
)
|
||||||
|
|
||||||
# Create a retriever from the index
|
# Create a retriever from the index
|
||||||
retriever = VectorIndexRetriever(
|
retriever = VectorIndexRetriever(
|
||||||
|
|||||||
173
services/rag/llamaindex/server.py
Normal file
173
services/rag/llamaindex/server.py
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
HTTP API server for querying the vector storage via the existing retrieval pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from loguru import logger
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from chat_engine import chat_with_documents
|
||||||
|
from retrieval import initialize_retriever
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging() -> None:
|
||||||
|
"""Configure loguru to stdout and rotating file logs."""
|
||||||
|
logs_dir = Path("logs")
|
||||||
|
logs_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
logger.remove()
|
||||||
|
logger.add(
|
||||||
|
"logs/dev.log",
|
||||||
|
rotation="10 MB",
|
||||||
|
retention="10 days",
|
||||||
|
level="INFO",
|
||||||
|
format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {file}:{line} | {message}",
|
||||||
|
)
|
||||||
|
logger.add(
|
||||||
|
sys.stdout,
|
||||||
|
level="INFO",
|
||||||
|
format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
|
||||||
|
colorize=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
setup_logging()
|
||||||
|
|
||||||
|
app = FastAPI(title="LlamaIndex RAG API", version="1.0.0")
|
||||||
|
|
||||||
|
origins = [
|
||||||
|
"*",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add CORS middleware
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=origins, # In production, configure this properly
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
allow_private_network=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestQueryRequest(BaseModel):
|
||||||
|
query: str = Field(..., min_length=1, description="User query text")
|
||||||
|
top_k: int = Field(5, ge=1, le=20, description="Number of retrieved chunks")
|
||||||
|
mode: str = Field(
|
||||||
|
"agent",
|
||||||
|
description="agent (Phase 12 default) or retrieval (fallback/debug)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SourceItem(BaseModel):
|
||||||
|
content: str
|
||||||
|
score: float | None = None
|
||||||
|
metadata: dict = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class TestQueryResponse(BaseModel):
|
||||||
|
query: str
|
||||||
|
response: str
|
||||||
|
sources: list[SourceItem]
|
||||||
|
mode: str | None = None
|
||||||
|
error: bool
|
||||||
|
success: bool
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
def health() -> dict[str, str]:
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/test-query", response_model=TestQueryResponse)
|
||||||
|
def test_query(payload: TestQueryRequest) -> TestQueryResponse:
|
||||||
|
"""
|
||||||
|
Query the vector store using the existing retrieval/query engine.
|
||||||
|
"""
|
||||||
|
query = payload.query.strip()
|
||||||
|
if not query:
|
||||||
|
raise HTTPException(status_code=400, detail="Field 'query' must not be empty")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Received /api/test-query request (top_k={payload.top_k}, mode={payload.mode})"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if payload.mode.lower() in {"agent", "chat"}:
|
||||||
|
chat_result = chat_with_documents(
|
||||||
|
query=query,
|
||||||
|
top_k=payload.top_k,
|
||||||
|
use_agent=True,
|
||||||
|
)
|
||||||
|
sources = [
|
||||||
|
SourceItem(
|
||||||
|
content=str(src.get("content_snippet", "")),
|
||||||
|
score=src.get("score"),
|
||||||
|
metadata=src.get("metadata", {}) or {},
|
||||||
|
)
|
||||||
|
for src in chat_result.get("sources", [])
|
||||||
|
]
|
||||||
|
logger.info(
|
||||||
|
f"/api/test-query completed via agent-like chat path (sources={len(sources)})"
|
||||||
|
)
|
||||||
|
return TestQueryResponse(
|
||||||
|
query=query,
|
||||||
|
response=str(chat_result.get("answer", "")),
|
||||||
|
sources=sources,
|
||||||
|
mode=str(chat_result.get("mode", "agent+retrieval+synthesis")),
|
||||||
|
error=False,
|
||||||
|
success=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if payload.mode.lower() not in {"retrieval", "debug"}:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Unsupported mode. Use 'agent' (default) or 'retrieval'.",
|
||||||
|
)
|
||||||
|
|
||||||
|
query_engine = initialize_retriever(similarity_top_k=payload.top_k)
|
||||||
|
result = query_engine.query(query)
|
||||||
|
|
||||||
|
sources: list[SourceItem] = []
|
||||||
|
if hasattr(result, "source_nodes"):
|
||||||
|
for node in result.source_nodes:
|
||||||
|
sources.append(
|
||||||
|
SourceItem(
|
||||||
|
content=str(getattr(node, "text", "")),
|
||||||
|
score=getattr(node, "score", None),
|
||||||
|
metadata=getattr(node, "metadata", {}) or {},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
response_text = str(result)
|
||||||
|
logger.info(
|
||||||
|
f"/api/test-query completed via retrieval fallback (sources={len(sources)})"
|
||||||
|
)
|
||||||
|
return TestQueryResponse(
|
||||||
|
query=query,
|
||||||
|
response=response_text,
|
||||||
|
sources=sources,
|
||||||
|
mode="retrieval",
|
||||||
|
error=False,
|
||||||
|
success=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"/api/test-query failed: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to process query")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
uvicorn.run("server:app", host="0.0.0.0", port=8334, reload=False)
|
||||||
@@ -10,6 +10,7 @@ This module provides initialization and configuration for:
|
|||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
from llama_index.core import VectorStoreIndex
|
from llama_index.core import VectorStoreIndex
|
||||||
from llama_index.vector_stores.qdrant import QdrantVectorStore
|
from llama_index.vector_stores.qdrant import QdrantVectorStore
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@@ -18,12 +19,26 @@ from qdrant_client import QdrantClient
|
|||||||
# Import the new configuration module
|
# Import the new configuration module
|
||||||
from config import get_embedding_model
|
from config import get_embedding_model
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
def get_qdrant_connection_config() -> dict[str, int | str]:
|
||||||
|
"""Load Qdrant connection settings from environment variables."""
|
||||||
|
host = os.getenv("QDRANT_HOST", "localhost")
|
||||||
|
rest_port = int(os.getenv("QDRANT_REST_PORT", "6333"))
|
||||||
|
grpc_port = int(os.getenv("QDRANT_GRPC_PORT", "6334"))
|
||||||
|
return {
|
||||||
|
"host": host,
|
||||||
|
"port": rest_port,
|
||||||
|
"grpc_port": grpc_port,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def initialize_vector_storage(
|
def initialize_vector_storage(
|
||||||
collection_name: str = "documents_llamaindex",
|
collection_name: str = "documents_llamaindex",
|
||||||
host: str = "localhost",
|
host: Optional[str] = None,
|
||||||
port: int = 6333,
|
port: Optional[int] = None,
|
||||||
grpc_port: int = 6334,
|
grpc_port: Optional[int] = None,
|
||||||
) -> tuple[QdrantVectorStore, VectorStoreIndex]:
|
) -> tuple[QdrantVectorStore, VectorStoreIndex]:
|
||||||
"""
|
"""
|
||||||
Initialize Qdrant vector storage with embedding model based on configured strategy.
|
Initialize Qdrant vector storage with embedding model based on configured strategy.
|
||||||
@@ -37,11 +52,19 @@ def initialize_vector_storage(
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of (QdrantVectorStore, VectorStoreIndex)
|
Tuple of (QdrantVectorStore, VectorStoreIndex)
|
||||||
"""
|
"""
|
||||||
logger.info(f"Initializing vector storage with collection: {collection_name}")
|
qdrant_config = get_qdrant_connection_config()
|
||||||
|
host = host or str(qdrant_config["host"])
|
||||||
|
port = port or int(qdrant_config["port"])
|
||||||
|
grpc_port = grpc_port or int(qdrant_config["grpc_port"])
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Initializing vector storage with collection: {collection_name} "
|
||||||
|
f"(host={host}, rest_port={port}, grpc_port={grpc_port})"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize Qdrant client
|
# Initialize Qdrant client
|
||||||
client = QdrantClient(host=host, port=port)
|
client = QdrantClient(host=host, port=port, grpc_port=grpc_port)
|
||||||
|
|
||||||
# Get the embedding model based on the configured strategy
|
# Get the embedding model based on the configured strategy
|
||||||
embed_model = get_embedding_model()
|
embed_model = get_embedding_model()
|
||||||
@@ -131,14 +154,24 @@ def initialize_vector_storage(
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def get_vector_store_and_index() -> tuple[QdrantVectorStore, VectorStoreIndex]:
|
def get_vector_store_and_index(
|
||||||
|
collection_name: str = "documents_llamaindex",
|
||||||
|
host: Optional[str] = None,
|
||||||
|
port: Optional[int] = None,
|
||||||
|
grpc_port: Optional[int] = None,
|
||||||
|
) -> tuple[QdrantVectorStore, VectorStoreIndex]:
|
||||||
"""
|
"""
|
||||||
Convenience function to get the initialized vector store and index.
|
Convenience function to get the initialized vector store and index.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (QdrantVectorStore, VectorStoreIndex)
|
Tuple of (QdrantVectorStore, VectorStoreIndex)
|
||||||
"""
|
"""
|
||||||
return initialize_vector_storage()
|
return initialize_vector_storage(
|
||||||
|
collection_name=collection_name,
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
grpc_port=grpc_port,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
97955
yadisk_files.json
Normal file
97955
yadisk_files.json
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user