Compare commits

..

11 Commits

18 changed files with 4029 additions and 11 deletions

6
.gitignore vendored
View File

@@ -1,2 +1,8 @@
data-unpacked-archives
data-broken-archives
.env
tmp/
__pycache__
venv
services/rag/.DS_Store
EVALUATION_RESULT.json

2314
DOCUMENTS_TO_TEST.md Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,377 @@
#!/usr/bin/env python3
from __future__ import annotations
import hashlib
import json
import os
import random
import re
import tempfile
from collections import defaultdict
from pathlib import Path
from typing import Any
from dotenv import load_dotenv
from qdrant_client import QdrantClient
from qdrant_client.http.models import Filter, FieldCondition, MatchValue
import requests
try:
from langchain_community.document_loaders import PyPDFLoader, TextLoader
except Exception: # pragma: no cover
PyPDFLoader = None
TextLoader = None
try:
from langchain_community.document_loaders import UnstructuredWordDocumentLoader
except Exception: # pragma: no cover
UnstructuredWordDocumentLoader = None
try:
from langchain_community.document_loaders import UnstructuredPowerPointLoader
except Exception: # pragma: no cover
UnstructuredPowerPointLoader = None
try:
from langchain_community.document_loaders import UnstructuredExcelLoader
except Exception: # pragma: no cover
UnstructuredExcelLoader = None
try:
from langchain_community.document_loaders import UnstructuredODTLoader
except Exception: # pragma: no cover
UnstructuredODTLoader = None
# Repository root = directory containing this script.
ROOT = Path(__file__).resolve().parent
# Per-framework service directories; the langchain one also holds the .env we load.
LANGCHAIN_DIR = ROOT / "services" / "rag" / "langchain"
LLAMAINDEX_DIR = ROOT / "services" / "rag" / "llamaindex"
# Input list of YaDisk remote paths and the generated questions markdown output.
YADISK_JSON = ROOT / "yadisk_files.json"
OUTPUT_MD = ROOT / "DOCUMENTS_TO_TEST.md"
def safe_stem_from_remote(remote_path: str) -> str:
    """Build a filesystem-safe stem from the remote path's filename.

    Alphanumerics, '-' and '_' are kept; every other character becomes '_'.
    An empty stem falls back to "file".
    """
    raw_stem = Path(Path(remote_path).name).stem or "file"
    safe_chars = [ch if ch.isalnum() or ch in ("-", "_") else "_" for ch in raw_stem]
    return "".join(safe_chars)
def llama_prefect_filename(remote_path: str) -> str:
    """Deterministic local filename used by the LlamaIndex/Prefect import.

    Format: <safe stem>_<first 10 hex chars of md5(remote_path)><suffix>.
    The sanitisation mirrors safe_stem_from_remote (inlined here).
    """
    remote_name = Path(remote_path).name or "downloaded_file"
    digest = hashlib.md5(remote_path.encode("utf-8")).hexdigest()[:10]
    # Inline of safe_stem_from_remote: keep alnum/-/_, replace the rest with '_'.
    stem = Path(Path(remote_path).name).stem or "file"
    safe_stem = "".join(ch if ch.isalnum() or ch in ("-", "_") else "_" for ch in stem)
    return f"{safe_stem}_{digest}{Path(remote_name).suffix}"
def get_loader(local_path: str):
    """Pick a LangChain document loader for *local_path* by file extension.

    Returns None when the extension is unsupported or when the matching
    loader class failed to import at module load time.
    """
    ext = Path(local_path).suffix.lower()
    # Shared kwargs for the Unstructured-based loaders (Russian OCR, hi-res).
    unstructured_kwargs = {"strategy": "hi_res", "languages": ["rus"]}
    if ext == ".pdf":
        return None if PyPDFLoader is None else PyPDFLoader(local_path)
    if ext in {".doc", ".docx"}:
        if UnstructuredWordDocumentLoader is None:
            return None
        return UnstructuredWordDocumentLoader(local_path, **unstructured_kwargs)
    if ext == ".pptx":
        if UnstructuredPowerPointLoader is None:
            return None
        return UnstructuredPowerPointLoader(local_path, **unstructured_kwargs)
    if ext in {".xls", ".xlsx"}:
        if UnstructuredExcelLoader is None:
            return None
        return UnstructuredExcelLoader(local_path, **unstructured_kwargs)
    if ext == ".odt":
        if UnstructuredODTLoader is None:
            return None
        return UnstructuredODTLoader(local_path, **unstructured_kwargs)
    if ext in {".txt", ".md"}:
        if TextLoader is None:
            return None
        return TextLoader(local_path, encoding="utf-8")
    return None
def supported_loader_extensions() -> set[str]:
    """Return the file extensions whose loader class imported successfully."""
    loader_to_exts = (
        (PyPDFLoader, {".pdf"}),
        (UnstructuredWordDocumentLoader, {".doc", ".docx"}),
        (UnstructuredPowerPointLoader, {".pptx"}),
        (UnstructuredExcelLoader, {".xls", ".xlsx"}),
        (UnstructuredODTLoader, {".odt"}),
        (TextLoader, {".txt", ".md"}),
    )
    available: set[str] = set()
    for loader, exts in loader_to_exts:
        if loader is not None:
            available |= exts
    return available
def collect_langchain_paths(client: QdrantClient) -> set[str]:
    """Scroll the whole `documents_langchain` collection and collect source file paths.

    Paths are taken from payload metadata ("file_path", falling back to "source").
    """
    found: set[str] = set()
    cursor = None
    while True:
        batch, cursor = client.scroll(
            collection_name="documents_langchain",
            offset=cursor,
            limit=1000,
            with_payload=True,
            with_vectors=False,
        )
        if not batch:
            break
        for point in batch:
            metadata = (point.payload or {}).get("metadata") or {}
            file_path = metadata.get("file_path") or metadata.get("source")
            if isinstance(file_path, str) and file_path:
                found.add(file_path)
        # A None cursor means the scroll reached the end of the collection.
        if cursor is None:
            break
    return found
def collect_llama_filenames(client: QdrantClient) -> set[str]:
    """Scroll the whole `documents_llamaindex` collection and collect stored filenames."""
    found: set[str] = set()
    cursor = None
    while True:
        batch, cursor = client.scroll(
            collection_name="documents_llamaindex",
            offset=cursor,
            limit=1000,
            with_payload=True,
            with_vectors=False,
        )
        if not batch:
            break
        for point in batch:
            filename = (point.payload or {}).get("filename")
            if isinstance(filename, str) and filename:
                found.add(filename)
        # A None cursor means the scroll reached the end of the collection.
        if cursor is None:
            break
    return found
def first_unique(matches: list[str], fallback: str) -> str:
    """Return the first entry that is non-empty after stripping, else *fallback*.

    Despite the name, no deduplication happens — only "first non-blank".
    """
    return next((m.strip() for m in matches if m.strip()), fallback)
def build_questions(remote_path: str, text: str) -> dict[str, list[str]]:
    """Generate Russian evaluation questions for one document, grouped by RAG metric theme.

    Heuristically mines the extracted *text* for years, dates, numbers, quoted
    phrases and organisation-like mentions, then fills question templates with
    whatever was found. Fix vs. original: removed the unused local `entity_q`
    (it was assigned but never referenced in the returned dict).
    """
    # Collapse all whitespace and only inspect a bounded prefix of the text.
    text = " ".join((text or "").split())
    text_preview = text[:15000]
    # Plausible four-digit years (1900-2199), deduplicated, ascending.
    years = sorted(
        {
            int(m)
            for m in re.findall(r"\b(19\d{2}|20\d{2}|21\d{2})\b", text_preview)
            if 1900 <= int(m) <= 2199
        }
    )
    # Numeric dates in D.M.Y or Y-M-D style with '.', '/' or '-' separators.
    dates = re.findall(
        r"\b(?:\d{1,2}[./-]\d{1,2}[./-]\d{2,4}|\d{4}[./-]\d{1,2}[./-]\d{1,2})\b",
        text_preview,
    )
    # Any standalone number of two or more digits.
    numbers = re.findall(r"\b\d{2,}\b", text_preview)
    # Phrases inside straight quotes or guillemets (4-120 chars).
    quoted = re.findall(r"[\"«]([^\"»\n]{4,120})[\"»]", text_preview)
    # Mentions that look like organisations/state bodies plus a short tail of context.
    org_like = re.findall(
        r"\b(?:ООО|АО|ПАО|ФГУП|Минтранс|Министерств[ао]|Правительств[ао]|Форум)\b[^\n,.]{0,80}",
        text_preview,
        flags=re.IGNORECASE,
    )
    # Template a year/date/number question from the first hit, else a generic probe.
    year_q = (
        f"В каком году в документе описывается ключевое событие ({years[0]}) и как это подтверждается контекстом?"
        if years
        else "Есть ли в документе указание на год события? Если да, какой именно год упомянут?"
    )
    date_q = (
        f"Какая дата ({dates[0]}) встречается в документе и к какому событию/разделу она относится?"
        if dates
        else "Какие календарные даты или периоды (если есть) упомянуты в документе?"
    )
    num_q = (
        f"Какое числовое значение ({numbers[0]}) встречается в документе и в каком контексте оно используется?"
        if numbers
        else "Есть ли в документе количественные показатели (суммы, проценты, номера, объемы) и что они обозначают?"
    )
    # Prefer a quoted phrase, then an organisation-like mention, then the filename itself.
    entity = first_unique(quoted, first_unique(org_like, Path(remote_path).name))
    # Human-readable topic hint derived from the filename stem.
    topic_hint = Path(remote_path).stem.replace("_", " ").replace("-", " ")
    topic_hint = " ".join(topic_hint.split())[:120]
    return {
        "Entity/Fact Recall (Response Relevance)": [
            f"Что известно про «{entity}» в материалах базы?",
            f"В контексте темы «{topic_hint}» кто выступает ключевым участником и какова его роль?",
        ],
        # Rephrase "в документе" -> "в материалах" so the questions suit a multi-document RAG query.
        "Numerical & Temporal Precision": [
            year_q.replace("в документе", "в материалах").replace("документе", "материалах"),
            date_q.replace("в документе", "в материалах").replace("документе", "материалах"),
            num_q.replace("в документе", "в материалах").replace("документе", "материалах"),
        ],
        "Context Precision (Evidence-anchored)": [
            f"Найди в базе фрагмент, который лучше всего подтверждает тезис по теме «{topic_hint}», и объясни его релевантность.",
            f"Есть ли в базе схожие по теме «{topic_hint}», но нерелевантные фрагменты, которые можно ошибочно выбрать?",
        ],
        "Faithfulness / Non-hallucination": [
            f"Какая информация по теме «{topic_hint}» отсутствует в найденном контексте и не должна быть додумана?",
            f"Если прямого ответа по теме «{topic_hint}» в материалах нет, как корректно ответить без галлюцинаций?",
        ],
        "Reasoning & Synthesis": [
            f"Сформулируй краткий вывод по теме «{topic_hint}» в 2-3 пунктах, опираясь на несколько найденных фрагментов.",
            f"Какие ограничения, риски или условия по теме «{topic_hint}» упоминаются в материалах, и как они влияют на вывод?",
        ],
    }
def extract_document_text(docs: list[Any]) -> str:
    """Concatenate loaded document chunks into one string, capped at 25000 chars.

    Each item may expose text via `page_content` (LangChain) or `text`
    (LlamaIndex-style); blank chunks are skipped.
    """
    collected: list[str] = []
    joined_len = 0
    for doc in docs:
        content = getattr(doc, "page_content", None)
        if content is None:
            content = getattr(doc, "text", None)
        if not isinstance(content, str) or not content.strip():
            continue
        piece = content.strip()
        # Track the joined length (pieces + 1-char separators) to stop early.
        joined_len += len(piece) + (1 if collected else 0)
        collected.append(piece)
        if joined_len > 25000:
            break
    return "\n".join(collected)[:25000]
def download_yadisk_file(remote_path: str, token: str, local_path: str) -> None:
    """Download one Yandex Disk file to *local_path* via the public REST API.

    First resolves a temporary download href, then streams the payload.
    Raises requests.HTTPError on any non-2xx response.
    """
    meta = requests.get(
        "https://cloud-api.yandex.net/v1/disk/resources/download",
        headers={"Authorization": f"OAuth {token}"},
        params={"path": remote_path},
        timeout=30,
    )
    meta.raise_for_status()
    payload = requests.get(meta.json()["href"], timeout=180)
    payload.raise_for_status()
    Path(local_path).write_bytes(payload.content)
def fetch_text_from_yadisk(remote_path: str, token: str) -> str:
    """Download a remote document, parse it, and return up to 25k chars of text.

    Returns "" when no loader supports the file's extension. The temporary
    local copy is always deleted, even on failure.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=Path(remote_path).suffix) as tmp:
        local_path = tmp.name
    try:
        download_yadisk_file(remote_path, token, local_path)
        loader = get_loader(local_path)
        return "" if loader is None else extract_document_text(loader.load())
    finally:
        if os.path.exists(local_path):
            os.unlink(local_path)
def main() -> int:
    """Select 100 YaDisk documents indexed in BOTH Qdrant collections and write DOCUMENTS_TO_TEST.md."""
    # Credentials and the Qdrant location come from the langchain service .env.
    load_dotenv(LANGCHAIN_DIR / ".env")
    qdrant_host = os.getenv("QDRANT_HOST")
    qdrant_rest_port = int(os.getenv("QDRANT_REST_PORT", "6333"))
    yadisk_token = os.getenv("YADISK_TOKEN", "").strip()
    if not qdrant_host:
        raise RuntimeError("QDRANT_HOST is missing in langchain .env")
    if not yadisk_token:
        raise RuntimeError("YADISK_TOKEN is missing in langchain .env")
    with YADISK_JSON.open("r", encoding="utf-8") as f:
        raw_paths = json.load(f)
    if not isinstance(raw_paths, list):
        raise RuntimeError("yadisk_files.json must be a JSON list of paths")
    all_paths = [str(x) for x in raw_paths if isinstance(x, str)]
    # Keep only absolute disk:/ paths with an extension we can actually parse.
    allowed_ext = supported_loader_extensions()
    filtered_by_ext = [
        p for p in all_paths if Path(p).suffix.lower() in allowed_ext and p.startswith("disk:/")
    ]
    client = QdrantClient(host=qdrant_host, port=qdrant_rest_port, timeout=60)
    langchain_paths = collect_langchain_paths(client)
    llama_filenames = collect_llama_filenames(client)
    # A document is a candidate only if BOTH pipelines indexed it.
    candidates = []
    for path in filtered_by_ext:
        if path not in langchain_paths:
            continue
        if llama_prefect_filename(path) not in llama_filenames:
            continue
        candidates.append(path)
    # Fixed seed so repeated runs sample the same documents in the same order.
    random.seed(42)
    random.shuffle(candidates)
    if len(candidates) < 100:
        raise RuntimeError(
            f"Only {len(candidates)} candidate documents found in both collections; need 100"
        )
    rows: list[dict[str, Any]] = []
    attempts = 0
    # Walk candidates until 100 documents yield non-empty extracted text.
    for remote_path in candidates:
        if len(rows) >= 100:
            break
        attempts += 1
        idx = len(rows) + 1
        print(f"[TRY {attempts:03d}] loading {remote_path}")
        try:
            text = fetch_text_from_yadisk(remote_path, yadisk_token)
        except Exception as e:
            print(f" -> skip (download/read error): {e}")
            continue
        if not text.strip():
            print(" -> skip (empty extracted text)")
            continue
        rows.append(
            {
                "index": idx,
                "path": remote_path,
                "questions": build_questions(remote_path, text),
            }
        )
        print(f"[OK {idx:03d}/100] prepared questions for {remote_path}")
    if len(rows) < 100:
        raise RuntimeError(
            f"Only {len(rows)} documents were successfully downloaded/read and turned into questions"
        )
    # Render the markdown: preamble first, then one "## NNN. `path`" block per document.
    lines: list[str] = []
    lines.append("# DOCUMENTS_TO_TEST")
    lines.append("")
    lines.append("This dataset contains 100 YaDisk documents that were verified as present in both:")
    lines.append("- `documents_langchain` (Qdrant)")
    lines.append("- `documents_llamaindex` (Qdrant)")
    lines.append("")
    lines.append("Question sections are aligned with common RAG evaluation themes (retrieval + generation):")
    lines.append("- Response relevance / entity-fact recall")
    lines.append("- Numerical and temporal precision")
    lines.append("- Context precision")
    lines.append("- Faithfulness / non-hallucination")
    lines.append("- Reasoning / synthesis")
    lines.append("")
    lines.append(
        "_References used for evaluation themes: RAGAS metrics and NVIDIA RAG pipeline evaluation docs._"
    )
    lines.append("")
    for row in rows:
        lines.append(f"## {row['index']:03d}. `{row['path']}`")
        lines.append("")
        for section, qs in row["questions"].items():
            lines.append(f"### {section}")
            for q in qs:
                lines.append(f"- {q}")
            lines.append("")
    OUTPUT_MD.write_text("\n".join(lines), encoding="utf-8")
    print(f"Written: {OUTPUT_MD}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())

732
rag_evaluation.py Normal file
View File

@@ -0,0 +1,732 @@
#!/usr/bin/env python3
"""
RAG evaluation script (file-batch mode).
Key behavior:
- Step = one document file (all its questions), not one question.
- Pre-download/caching in ./tmp/rag-evaluation (skip if already downloaded).
- Sequential API calls only (LangChain then LlamaIndex).
- Pairwise answer evaluation (both systems in one judge prompt).
- JSON output with append/overwrite support for batch runs and re-runs.
"""
from __future__ import annotations
import argparse
import datetime as dt
import hashlib
import json
import os
import re
import time
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any
try:
import requests
except ImportError as e: # pragma: no cover
raise SystemExit(
"Missing dependency: requests. Run with your project venv "
"(for example services/rag/langchain/venv/bin/python rag_evaluation.py ...)"
) from e
from dotenv import load_dotenv
load_dotenv()
# =============================================================================
# Configuration
# =============================================================================
# RAG endpoints under comparison (LangChain vs LlamaIndex test-query APIs).
LANGCHAIN_URL = os.getenv("LANGCHAIN_URL", "http://localhost:8331/api/test-query")
LLAMAINDEX_URL = os.getenv("LLAMAINDEX_URL", "http://localhost:8334/api/test-query")
# OpenAI-compatible evaluator endpoint. You can point this at OpenAI-compatible providers.
OPENAI_CHAT_URL = os.getenv(
    "OPENAI_CHAT_URL", "https://foundation-models.api.cloud.ru/v1"
)
OPENAI_CHAT_KEY = os.getenv("OPENAI_CHAT_KEY", "")
OPENAI_CHAT_MODEL = os.getenv("OPENAI_CHAT_MODEL", "MiniMaxAI/MiniMax-M2")
# OAuth token for downloading source documents from Yandex Disk.
YADISK_TOKEN = os.getenv("YADISK_TOKEN", "")
# Input question list, output results JSON, and the local download cache dir.
BASE_DIR = Path(__file__).resolve().parent
INPUT_MD = BASE_DIR / "DOCUMENTS_TO_TEST.md"
OUTPUT_JSON = BASE_DIR / "EVALUATION_RESULT.json"
TMP_DIR = BASE_DIR / "tmp" / "rag-evaluation"
# Timeouts in seconds: RAG queries, judge calls, and the two YaDisk requests.
RAG_TIMEOUT = int(os.getenv("RAG_TIMEOUT", "120"))
EVAL_TIMEOUT = int(os.getenv("EVAL_TIMEOUT", "90"))
YADISK_META_TIMEOUT = int(os.getenv("YADISK_META_TIMEOUT", "30"))
YADISK_DOWNLOAD_TIMEOUT = int(os.getenv("YADISK_DOWNLOAD_TIMEOUT", "180"))
# =============================================================================
# Data structures
# =============================================================================
@dataclass
class QuestionResult:
    """Evaluation outcome for a single question asked to both RAG systems."""

    # Section heading (metric theme) from DOCUMENTS_TO_TEST.md this question came from.
    section: str
    question: str
    # Raw answers; "ERROR: ..." strings mark transport/HTTP failures.
    langchain_answer: str = ""
    llamaindex_answer: str = ""
    # Judge scores in [-1.0, 1.0]; -1.0 flags a technical failure.
    langchain_score: float = 0.0
    llamaindex_score: float = 0.0
    # One of "LangChain" | "LlamaIndex" | "Tie".
    winner: str = "Tie"
    rationale: str = ""
    evaluator_model: str = ""
    # ISO-8601 UTC timestamp of when judging happened.
    evaluated_at: str = ""
@dataclass
class DocumentEvaluation:
    """All question results for one source document plus download/cache metadata."""

    index: int
    path: str
    # Local cache file path and its status ("cached_existing" | "downloaded" | "error:...").
    cache_file: str = ""
    cache_status: str = ""
    questions: list[QuestionResult] = field(default_factory=list)
    # ISO-8601 UTC timestamps bracketing the per-document processing step.
    started_at: str = ""
    finished_at: str = ""
# =============================================================================
# Markdown parsing
# =============================================================================
def split_documents(md_text: str) -> tuple[list[str], list[str]]:
    """Split markdown into a header and per-document blocks.

    Returns (header lines before the first '## ' heading, one joined string
    per '## ' block).
    """
    header: list[str] = []
    blocks: list[list[str]] = []
    for line in md_text.splitlines():
        if line.startswith("## "):
            # A new document block starts at every second-level heading.
            blocks.append([line])
        elif blocks:
            blocks[-1].append(line)
        else:
            header.append(line)
    return header, ["\n".join(block) for block in blocks]
def parse_document_block(idx: int, block: str) -> tuple[str, list[QuestionResult]]:
    """Parse one '## ' block into (backticked document path, list of questions).

    *idx* is accepted for caller symmetry but not used. Questions are "- "
    bullets grouped under "### " section headings.
    """
    block_lines = block.splitlines()
    path_match = re.search(r"`([^`]+)`", block_lines[0].strip())
    doc_path = path_match.group(1) if path_match else ""
    current_section = ""
    parsed: list[QuestionResult] = []
    for line in block_lines[1:]:
        if line.startswith("### "):
            current_section = line[4:].strip()
            continue
        if line.startswith("- "):
            question_text = line[2:].strip()
            if question_text:
                parsed.append(QuestionResult(section=current_section, question=question_text))
    return doc_path, parsed
def parse_all_docs(md_path: Path) -> list[tuple[int, str, list[QuestionResult]]]:
    """Parse DOCUMENTS_TO_TEST.md into (1-based index, doc path, questions) triples."""
    _, blocks = split_documents(md_path.read_text(encoding="utf-8"))
    return [
        (i, *parse_document_block(i, block))
        for i, block in enumerate(blocks, start=1)
    ]
# =============================================================================
# Caching / Yandex Disk
# =============================================================================
def cache_file_name(remote_path: str) -> str:
    """Deterministic local cache filename for a remote YaDisk path.

    Bug fix: the previous implementation derived the name from the builtin
    hash(), which is salted per process (PYTHONHASHSEED), so the "cache"
    filename changed on every run and downloads were never reused. A stable
    md5 digest (matching the sibling generator script's approach) fixes that.
    """
    digest = hashlib.md5(remote_path.encode("utf-8")).hexdigest()[:12]
    suffix = Path(remote_path).suffix or ".bin"
    return f"{digest}{suffix}"
def download_yadisk_to_cache(remote_path: str, token: str, cache_path: Path) -> str:
    """
    Download file into cache path if missing.
    Returns status: "cached_existing" | "downloaded" | "error:..."
    """
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    # A non-empty cached copy is reused as-is; zero-byte files are re-fetched.
    if cache_path.exists() and cache_path.stat().st_size > 0:
        return "cached_existing"
    if not token:
        return "error:missing_yadisk_token"
    try:
        # Step 1: resolve the temporary download href from the YaDisk API.
        meta = requests.get(
            "https://cloud-api.yandex.net/v1/disk/resources/download",
            headers={"Authorization": f"OAuth {token}"},
            params={"path": remote_path},
            timeout=YADISK_META_TIMEOUT,
        )
        meta.raise_for_status()
        # Step 2: fetch the actual file payload.
        payload = requests.get(meta.json()["href"], timeout=YADISK_DOWNLOAD_TIMEOUT)
        payload.raise_for_status()
        cache_path.write_bytes(payload.content)
        return "error:empty_download" if cache_path.stat().st_size == 0 else "downloaded"
    except Exception as exc:  # noqa: BLE001
        return f"error:{exc}"
# =============================================================================
# File text extraction (for evaluator context)
# =============================================================================
def extract_text_from_file(path: Path) -> str:
    """Best-effort plain-text extraction used as the judge's reference context.

    Plain-text formats are read directly; docx/pdf/xlsx go through optional
    third-party parsers. Parser failures are returned as bracketed marker
    strings rather than raised.
    """
    ext = path.suffix.lower()
    if ext in {".txt", ".md", ".csv", ".json", ".xml", ".html", ".htm"}:
        return path.read_text(encoding="utf-8", errors="ignore")
    if ext in {".docx", ".doc"}:
        try:
            from docx import Document  # type: ignore

            paragraphs = Document(str(path)).paragraphs
            return "\n".join(p.text for p in paragraphs)
        except Exception as exc:  # noqa: BLE001
            return f"[DOC parse error: {exc}]"
    if ext == ".pdf":
        try:
            import PyPDF2  # type: ignore

            with path.open("rb") as handle:
                reader = PyPDF2.PdfReader(handle)
                extracted = [page.extract_text() or "" for page in reader.pages]
            return "\n".join(extracted)
        except Exception as exc:  # noqa: BLE001
            return f"[PDF parse error: {exc}]"
    if ext in {".xlsx", ".xls"}:
        try:
            from openpyxl import load_workbook  # type: ignore

            workbook = load_workbook(str(path), read_only=True)
            rows: list[str] = []
            # Cap the extraction at ~5000 rows to bound memory and prompt size.
            for sheet in workbook.worksheets:
                for row in sheet.iter_rows(values_only=True):
                    rows.append("\t".join("" if cell is None else str(cell) for cell in row))
                    if len(rows) > 5000:
                        break
                if len(rows) > 5000:
                    break
            return "\n".join(rows)
        except Exception as exc:  # noqa: BLE001
            return f"[XLS parse error: {exc}]"
    # Unknown extension: try to read as text, else mark as binary.
    try:
        return path.read_text(encoding="utf-8", errors="ignore")
    except Exception:
        return f"[Binary file: {path.name}]"
# =============================================================================
# RAG API calls (sequential)
# =============================================================================
def call_rag(url: str, query: str, timeout: int) -> str:
    """POST a query to a RAG endpoint and return its stripped 'response' text.

    Any transport/HTTP/JSON failure is folded into an "ERROR: ..." string so
    the evaluator can score it instead of crashing.
    """
    try:
        resp = requests.post(url, json={"query": query}, timeout=timeout)
        resp.raise_for_status()
        answer = resp.json().get("response", "")
        return "" if answer is None else str(answer).strip()
    except Exception as exc:  # noqa: BLE001
        return f"ERROR: {exc}"
def call_langchain(query: str, timeout: int) -> str:
    """Query the LangChain test endpoint; returns answer text or an "ERROR: ..." string."""
    return call_rag(LANGCHAIN_URL, query, timeout)
def call_llamaindex(query: str, timeout: int) -> str:
    """Query the LlamaIndex endpoint in agent mode.

    Mirrors call_rag's error handling but adds the "mode": "agent" field
    that the LlamaIndex service expects.
    """
    body = {"query": query, "mode": "agent"}
    try:
        resp = requests.post(LLAMAINDEX_URL, json=body, timeout=timeout)
        resp.raise_for_status()
        answer = resp.json().get("response", "")
        return "" if answer is None else str(answer).strip()
    except Exception as exc:  # noqa: BLE001
        return f"ERROR: {exc}"
# =============================================================================
# Evaluator
# =============================================================================
def _rule_score(answer: str) -> float:
if not answer or not answer.strip():
return 0.0
if answer.startswith("ERROR:"):
return -1.0
score = 0.3
if len(answer) > 120:
score += 0.2
if re.search(r"\d", answer):
score += 0.1
if re.search(r"[.!?]", answer):
score += 0.1
if re.search(r"(не найден|недостаточно|нет информации)", answer.lower()):
score += 0.05
return min(1.0, score)
# Judge criterion injected into the evaluator prompt, keyed by the
# "### " section headings produced in DOCUMENTS_TO_TEST.md.
SECTION_CRITERIA: dict[str, str] = {
    "Entity/Fact Recall (Response Relevance)": "Оцени точность извлечения сущностей/фактов и релевантность вопросу.",
    "Numerical & Temporal Precision": "Оцени точность чисел, дат, периодов и временных связей.",
    "Context Precision (Evidence-anchored)": "Оцени, насколько ответ опирается на релевантный контекст без лишнего.",
    "Faithfulness / Non-hallucination": "Оцени отсутствие галлюцинаций и корректное поведение при отсутствии фактов.",
    "Reasoning & Synthesis": "Оцени качество синтеза фактов и логичность итогового вывода.",
}
def build_pair_eval_prompt(
    question: str,
    section: str,
    langchain_answer: str,
    llamaindex_answer: str,
    document_text: str,
) -> str:
    """Render the Russian judge prompt comparing both answers against the source document.

    The judge is instructed to return JSON only; the caller extracts the
    scores/winner from that JSON.
    """
    criteria = SECTION_CRITERIA.get(
        section, "Оцени релевантность, точность и полезность."
    )
    # Bound the reference context so the prompt stays within model limits.
    context = document_text[:9000]
    return f"""Ты судья качества RAG-ответов. Сравни два ответа на один вопрос.
Вопрос:
{question}
Секция оценки:
{section}
Критерий:
{criteria}
Ответ A (LangChain):
{langchain_answer}
Ответ B (LlamaIndex):
{llamaindex_answer}
Опорный контекст документа:
{context}
Верни ТОЛЬКО JSON:
{{
"langchain_score": <float от -1.0 до 1.0>,
"llamaindex_score": <float от -1.0 до 1.0>,
"winner": "LangChain|LlamaIndex|Tie",
"rationale": "<кратко по сути>"
}}
Правила:
- Технические ошибки/таймауты должны получать -1.0.
- Пустой ответ без ошибки = 0.0.
- Галлюцинации сильно штрафуются.
- Если разница незначительная, выбирай Tie.
"""
def evaluate_pair_with_llm(
    question: str,
    section: str,
    langchain_answer: str,
    llamaindex_answer: str,
    document_text: str,
) -> tuple[float, float, str, str]:
    """Judge both answers; returns (langchain_score, llamaindex_score, winner, rationale).

    Precedence: deterministic handling of technical "ERROR:" answers first,
    then the _rule_score heuristic when no API key is configured, then the
    LLM judge (falling back to the heuristic again if the judge call fails).
    """
    # Deterministic short-circuit for technical failures
    if langchain_answer.startswith("ERROR:") and llamaindex_answer.startswith("ERROR:"):
        return -1.0, -1.0, "Tie", "Обе системы вернули техническую ошибку."
    if langchain_answer.startswith("ERROR:"):
        return (
            -1.0,
            _rule_score(llamaindex_answer),
            "LlamaIndex",
            "LangChain технически не ответил.",
        )
    if llamaindex_answer.startswith("ERROR:"):
        return (
            _rule_score(langchain_answer),
            -1.0,
            "LangChain",
            "LlamaIndex технически не ответил.",
        )
    if not OPENAI_CHAT_KEY:
        # fallback heuristic
        lc = _rule_score(langchain_answer)
        li = _rule_score(llamaindex_answer)
        # Scores within 0.05 of each other are treated as a tie.
        if abs(lc - li) < 0.05:
            return lc, li, "Tie", "Эвристическая оценка без LLM (ключ не задан)."
        return (
            (lc, li, "LangChain", "Эвристическая оценка без LLM.")
            if lc > li
            else (
                lc,
                li,
                "LlamaIndex",
                "Эвристическая оценка без LLM.",
            )
        )
    prompt = build_pair_eval_prompt(
        question=question,
        section=section,
        langchain_answer=langchain_answer,
        llamaindex_answer=llamaindex_answer,
        document_text=document_text,
    )
    headers = {
        "Authorization": f"Bearer {OPENAI_CHAT_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": OPENAI_CHAT_MODEL,
        "messages": [
            {
                "role": "system",
                "content": "Ты строгий судья качества RAG. Отвечай только JSON.",
            },
            {"role": "user", "content": prompt},
        ],
        # Deterministic judging: zero temperature, short bounded output.
        "temperature": 0.0,
        "max_tokens": 400,
    }
    try:
        r = requests.post(
            f"{OPENAI_CHAT_URL.rstrip('/')}/chat/completions",
            headers=headers,
            json=payload,
            timeout=EVAL_TIMEOUT,
        )
        r.raise_for_status()
        data = r.json()
        content = data.get("choices", [{}])[0].get("message", {}).get("content", "")
        # The judge may wrap its JSON in prose; grab the outermost {...} span.
        m = re.search(r"\{.*\}", content, re.DOTALL)
        raw = m.group(0) if m else content
        parsed = json.loads(raw)
        lc = float(parsed.get("langchain_score", 0.0))
        li = float(parsed.get("llamaindex_score", 0.0))
        winner = str(parsed.get("winner", "Tie"))
        rationale = str(parsed.get("rationale", ""))
        # Normalise any unexpected winner label to a tie.
        if winner not in {"LangChain", "LlamaIndex", "Tie"}:
            winner = "Tie"
        return lc, li, winner, rationale
    except Exception as e:  # noqa: BLE001
        # Judge unavailable or returned unparseable output: fall back to the
        # heuristic scores and record why in the rationale.
        lc = _rule_score(langchain_answer)
        li = _rule_score(llamaindex_answer)
        if abs(lc - li) < 0.05:
            return lc, li, "Tie", f"Fallback heuristic; LLM eval error: {e}"
        return (
            (lc, li, "LangChain", f"Fallback heuristic; LLM eval error: {e}")
            if lc > li
            else (
                lc,
                li,
                "LlamaIndex",
                f"Fallback heuristic; LLM eval error: {e}",
            )
        )
# =============================================================================
# JSON storage
# =============================================================================
def now_iso() -> str:
    """Current UTC time as an ISO-8601 string (timezone-aware)."""
    return dt.datetime.now(tz=dt.timezone.utc).isoformat()
def default_json_payload(
    all_docs: list[tuple[int, str, list[QuestionResult]]],
) -> dict[str, Any]:
    """Build a fresh results skeleton: run metadata plus one pending entry per document."""
    return {
        "meta": {
            "created_at": now_iso(),
            "updated_at": now_iso(),
            "input_file": str(INPUT_MD),
            "langchain_url": LANGCHAIN_URL,
            "llamaindex_url": LLAMAINDEX_URL,
            "evaluator_model": OPENAI_CHAT_MODEL,
            "notes": [
                "step = one file (all file questions)",
                "sequential API calls only",
                "cache dir: ./tmp/rag-evaluation",
            ],
        },
        # One entry per input document; questions carry default (unevaluated) fields.
        "documents": [
            {
                "index": idx,
                "path": path,
                "cache_file": "",
                "cache_status": "not_processed",
                "started_at": "",
                "finished_at": "",
                "questions": [asdict(q) for q in questions],
            }
            for idx, path, questions in all_docs
        ],
        # Per-run batch summaries appended by update_batch_stats.
        "batches": [],
    }
def load_or_init_json(
    all_docs: list[tuple[int, str, list[QuestionResult]]],
    output_json: Path,
    mode: str,
) -> dict[str, Any]:
    """Load the existing results store, or build a fresh skeleton.

    A fresh payload is returned when mode == "overwrite", when the file does
    not exist, or when the existing file is unreadable / missing "documents".
    """
    if mode != "overwrite" and output_json.exists():
        try:
            existing = json.loads(output_json.read_text(encoding="utf-8"))
            if "documents" in existing:
                return existing
        except Exception:
            # Corrupt/unreadable store: silently start over with a skeleton.
            pass
    return default_json_payload(all_docs)
def upsert_document_result(store: dict[str, Any], result: DocumentEvaluation) -> None:
    """Insert or replace a document's evaluation record in the store, keyed by path.

    Refactor: the serialized entry dict was duplicated verbatim in the update
    and append branches; it is now built once.
    """
    entry = {
        "index": result.index,
        "path": result.path,
        "cache_file": result.cache_file,
        "cache_status": result.cache_status,
        "started_at": result.started_at,
        "finished_at": result.finished_at,
        "questions": [asdict(q) for q in result.questions],
    }
    docs = store.setdefault("documents", [])
    for i, doc in enumerate(docs):
        if doc.get("path") == result.path:
            docs[i] = entry
            return
    docs.append(entry)
def update_batch_stats(store: dict[str, Any], batch_meta: dict[str, Any]) -> None:
    """Append one batch's metadata to the store and bump the updated_at timestamp."""
    store.setdefault("batches", []).append(batch_meta)
    store.setdefault("meta", {})["updated_at"] = now_iso()
def atomic_write_json(path: Path, payload: dict[str, Any]) -> None:
    """Atomically write JSON to avoid partial/corrupted files on interruption.

    Serializes to a ".tmp" sibling first, then renames over the target so
    readers never observe a half-written file.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    staging = path.with_suffix(path.suffix + ".tmp")
    serialized = json.dumps(payload, ensure_ascii=False, indent=2)
    staging.write_text(serialized, encoding="utf-8")
    staging.replace(path)
def compute_batch_summary(results: list[DocumentEvaluation]) -> dict[str, Any]:
    """Aggregate per-question winners and average scores into a batch summary dict.

    The overall "ranking" requires a margin of more than 0.01 in average
    score; anything closer is reported as a tie.
    """
    wins = {"LangChain": 0, "LlamaIndex": 0, "Tie": 0}
    lc_scores: list[float] = []
    li_scores: list[float] = []
    question_count = 0
    for document in results:
        for question in document.questions:
            question_count += 1
            wins[question.winner] = wins.get(question.winner, 0) + 1
            lc_scores.append(question.langchain_score)
            li_scores.append(question.llamaindex_score)
    # max(1, ...) guards against division by zero on an empty batch.
    avg_lc = sum(lc_scores) / max(1, len(lc_scores))
    avg_li = sum(li_scores) / max(1, len(li_scores))
    ranking = "Tie"
    if avg_lc > avg_li + 0.01:
        ranking = "LangChain"
    elif avg_li > avg_lc + 0.01:
        ranking = "LlamaIndex"
    return {
        "documents_processed": len(results),
        "questions_processed": question_count,
        "wins": wins,
        "avg_langchain": round(avg_lc, 4),
        "avg_llamaindex": round(avg_li, 4),
        "ranking": ranking,
    }
# =============================================================================
# Main flow
# =============================================================================
def run_evaluation(doc_from: int, doc_to: int, mode: str) -> None:
    """Process documents doc_from..doc_to (1-based, inclusive); one step = one file.

    Each step downloads/caches the document, queries both RAG systems for
    every question sequentially, judges the answer pair, and persists the
    store after the step so an interrupted run loses at most one file.
    """
    all_docs = parse_all_docs(INPUT_MD)
    total_docs = len(all_docs)
    # Clamp the requested range to the documents that actually exist.
    doc_from = max(1, doc_from)
    doc_to = min(total_docs, doc_to)
    if doc_from > doc_to:
        raise ValueError(f"Invalid doc range: {doc_from}:{doc_to}")
    store = load_or_init_json(all_docs, OUTPUT_JSON, mode)
    TMP_DIR.mkdir(parents=True, exist_ok=True)
    selected = [d for d in all_docs if doc_from <= d[0] <= doc_to]
    print(
        f"Total docs: {total_docs}. Processing docs {doc_from}:{doc_to} ({len(selected)} steps)."
    )
    print(f"Cache dir: {TMP_DIR}")
    print(f"Output JSON: {OUTPUT_JSON}")
    batch_results: list[DocumentEvaluation] = []
    batch_started = now_iso()
    for step, (idx, doc_path, questions) in enumerate(selected, start=1):
        print(f"\n[STEP {step}/{len(selected)}] File #{idx}: {doc_path}")
        started = now_iso()
        # Reuse a previously downloaded copy when present (cache_status records which).
        cache_name = cache_file_name(doc_path)
        cache_path = TMP_DIR / cache_name
        cache_status = download_yadisk_to_cache(doc_path, YADISK_TOKEN, cache_path)
        print(f" -> cache: {cache_status} ({cache_path})")
        doc_text = ""
        if cache_status.startswith("error:"):
            # Judge still runs, but with a marker instead of real document context.
            doc_text = f"[CACHE_ERROR] {cache_status}"
        else:
            doc_text = extract_text_from_file(cache_path)
        print(f" -> extracted text length: {len(doc_text)}")
        evaluated_questions: list[QuestionResult] = []
        for qn, q in enumerate(questions, start=1):
            qr = QuestionResult(section=q.section, question=q.question)
            print(f" [{qn}/{len(questions)}] {q.question[:90]}")
            # Sequential API calls only: LangChain first, then LlamaIndex.
            t0 = time.time()
            qr.langchain_answer = call_langchain(q.question, timeout=RAG_TIMEOUT)
            print(f" LangChain: {time.time() - t0:.1f}s")
            t0 = time.time()
            qr.llamaindex_answer = call_llamaindex(q.question, timeout=RAG_TIMEOUT)
            print(f" LlamaIndex: {time.time() - t0:.1f}s")
            # Pairwise judging: both answers scored in a single judge call.
            lc, li, winner, rationale = evaluate_pair_with_llm(
                question=q.question,
                section=q.section,
                langchain_answer=qr.langchain_answer,
                llamaindex_answer=qr.llamaindex_answer,
                document_text=doc_text,
            )
            qr.langchain_score = lc
            qr.llamaindex_score = li
            qr.winner = winner
            qr.rationale = rationale
            qr.evaluator_model = OPENAI_CHAT_MODEL
            qr.evaluated_at = now_iso()
            evaluated_questions.append(qr)
        doc_result = DocumentEvaluation(
            index=idx,
            path=doc_path,
            cache_file=str(cache_path),
            cache_status=cache_status,
            questions=evaluated_questions,
            started_at=started,
            finished_at=now_iso(),
        )
        upsert_document_result(store, doc_result)
        batch_results.append(doc_result)
        # Save incremental progress after each file/step
        atomic_write_json(OUTPUT_JSON, store)
        print(" -> step saved")
    summary = compute_batch_summary(batch_results)
    batch_meta = {
        "started_at": batch_started,
        "finished_at": now_iso(),
        "range": f"{doc_from}:{doc_to}",
        "summary": summary,
        "mode": mode,
    }
    update_batch_stats(store, batch_meta)
    atomic_write_json(OUTPUT_JSON, store)
    print("\nBatch complete.")
    print(json.dumps(summary, ensure_ascii=False, indent=2))
    print(f"Saved to: {OUTPUT_JSON}")
def parse_range(value: str) -> tuple[int, int]:
    """argparse type: parse "from:to" into a (from, to) tuple of positive ints."""
    match = re.fullmatch(r"(\d+):(\d+)", value.strip())
    if match is None:
        raise argparse.ArgumentTypeError(
            "Range must be in format from:to (example: 1:10)"
        )
    lo, hi = (int(group) for group in match.groups())
    if min(lo, hi) <= 0:
        raise argparse.ArgumentTypeError("Range values must be positive")
    return lo, hi
def main() -> int:
    """CLI entry point: parse the document range and run one evaluation batch."""
    parser = argparse.ArgumentParser(
        description="RAG evaluation in file-batch mode (JSON output)"
    )
    parser.add_argument(
        "doc_range",
        type=parse_range,
        help="Document range in format from:to (step = one file). Example: 1:10",
    )
    parser.add_argument(
        "--mode",
        choices=["append", "overwrite"],
        default="append",
        help="append: upsert evaluated docs into existing JSON; overwrite: rebuild JSON from input docs",
    )
    args = parser.parse_args()
    doc_from, doc_to = args.doc_range
    # Informational only: warn that the default judge model may be weaker than alternatives.
    if "MiniMax" in OPENAI_CHAT_MODEL or "MiniMax" in OPENAI_CHAT_URL:
        print(
            "NOTE: evaluator model is MiniMax. It works, but for stricter judging quality, "
            "gpt-4.1-mini/gpt-4.1 (if available on your endpoint) is usually stronger."
        )
    run_evaluation(doc_from=doc_from, doc_to=doc_to, mode=args.mode)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())

8
requirements.txt Normal file
View File

@@ -0,0 +1,8 @@
certifi==2026.2.25
charset-normalizer==3.4.5
dotenv==0.9.9
idna==3.11
python-dotenv==1.2.2
requests==2.32.5
urllib3==2.6.3
yadisk==3.4.0

BIN
services/rag/.DS_Store vendored

Binary file not shown.

View File

@@ -216,3 +216,6 @@ __marimo__/
.streamlit/secrets.toml
document_tracking.db
.env.test
yadisk_imported_paths.csv
yadisk_imported_paths.json

View File

@@ -125,3 +125,7 @@ During this Phase we create asynchronous process of enrichment, utilizing async/
- [x] Make tabbed UI with top level tabs. First tab exists and is selected. Each tab should have copy of demo ui, meaning the chat window with ability to specify the api url
- [x] At the end of the tabs there should be button with plus sign, which will add new tab. Tabs to be called by numbers.
- [x] There should be 3 predefined tabs opened. First one should have the predefined api url "https://rag.langchain.overwatch.su/api/test-query", second "https://rag.llamaindex.overwatch.su/api/test-query", third "https://rag.haystack.overwatch.su/api/test-query"
# Phase 17 (creating json with list of documents that are supported for import)
- [x] Make cli command that takes json file with list of paths, filters them to only those that are being imported into the vector storage (can be checked in enrichment), then this file should be saved in the current folder as "yadisk_imported_paths.json" and in "yadisk_imported_paths.csv" file. In case of CSV - it should be formatted as csv of course.

View File

@@ -20,6 +20,9 @@ from vector_storage import initialize_vector_store
# Load environment variables
load_dotenv()
CHAT_REQUEST_TIMEOUT_SECONDS = float(os.getenv("CHAT_REQUEST_TIMEOUT_SECONDS", "45"))
CHAT_MAX_RETRIES = int(os.getenv("CHAT_MAX_RETRIES", "0"))
def get_llm_model_info(
llm_model: Optional[str] = None,
@@ -149,10 +152,12 @@ def create_chat_agent(
openai_api_base=base_url_or_api_base,
openai_api_key=api_key,
temperature=0.1,
request_timeout=CHAT_REQUEST_TIMEOUT_SECONDS,
max_retries=CHAT_MAX_RETRIES,
)
logger.info(
f"Using OpenAI-compatible model: {model_name} via {base_url_or_api_base}"
f"Using OpenAI-compatible model: {model_name} via {base_url_or_api_base}, timeout={CHAT_REQUEST_TIMEOUT_SECONDS}s, retries={CHAT_MAX_RETRIES}"
)
else: # Default to ollama
# Initialize the Ollama chat model
@@ -160,9 +165,13 @@ def create_chat_agent(
model=model_name,
base_url=base_url_or_api_base, # Default Ollama URL
temperature=0.1,
sync_client_kwargs={"timeout": CHAT_REQUEST_TIMEOUT_SECONDS},
async_client_kwargs={"timeout": CHAT_REQUEST_TIMEOUT_SECONDS},
)
logger.info(f"Using Ollama model: {model_name}")
logger.info(
f"Using Ollama model: {model_name}, timeout={CHAT_REQUEST_TIMEOUT_SECONDS}s"
)
# Create the document retrieval tool
retrieval_tool = DocumentRetrievalTool()

View File

@@ -1,4 +1,6 @@
import os
import csv
import json
from pathlib import Path
import click
@@ -126,5 +128,60 @@ def chat(collection_name, model):
click.echo(f"Error: {str(e)}")
@cli.command(
name="export-supported-paths",
help="Filter JSON paths by enrichment-supported extensions and export JSON/CSV",
)
@click.argument("input_json", type=click.Path(exists=True, dir_okay=False, path_type=Path))
def export_supported_paths(input_json: Path):
"""Export supported document paths into yadisk_imported_paths.json and yadisk_imported_paths.csv."""
logger.info(f"Filtering supported paths from input file: {input_json}")
try:
from enrichment import SUPPORTED_EXTENSIONS
with input_json.open("r", encoding="utf-8") as source_file:
raw_data = json.load(source_file)
if not isinstance(raw_data, list):
raise ValueError("Input JSON must contain an array of file paths")
filtered_paths = []
seen_paths = set()
for item in raw_data:
path_str = str(item).strip()
if not path_str:
continue
if path_str in seen_paths:
continue
extension = Path(path_str).suffix.lower()
if extension in SUPPORTED_EXTENSIONS:
filtered_paths.append(path_str)
seen_paths.add(path_str)
output_json = Path.cwd() / "yadisk_imported_paths.json"
output_csv = Path.cwd() / "yadisk_imported_paths.csv"
with output_json.open("w", encoding="utf-8") as output_json_file:
json.dump(filtered_paths, output_json_file, ensure_ascii=False, indent=2)
with output_csv.open("w", encoding="utf-8", newline="") as output_csv_file:
writer = csv.writer(output_csv_file)
writer.writerow(["path"])
for path_item in filtered_paths:
writer.writerow([path_item])
click.echo(
f"Export complete: {len(filtered_paths)} supported paths saved to {output_json.name} and {output_csv.name}"
)
logger.info(
f"Exported {len(filtered_paths)} supported paths to {output_json} and {output_csv}"
)
except Exception as error:
logger.error(f"Failed to export supported paths: {error}")
click.echo(f"Error: {error}")
if __name__ == "__main__":
cli()

View File

@@ -25,6 +25,10 @@ OLLAMA_EMBEDDING_MODEL = os.getenv("OLLAMA_EMBEDDING_MODEL", "nomic-embed-text")
OPENAI_EMBEDDING_MODEL = os.getenv("OPENAI_EMBEDDING_MODEL", "text-embedding-ada-002")
OPENAI_EMBEDDING_BASE_URL = os.getenv("OPENAI_EMBEDDING_BASE_URL")
OPENAI_EMBEDDING_API_KEY = os.getenv("OPENAI_EMBEDDING_API_KEY")
EMBEDDING_REQUEST_TIMEOUT_SECONDS = float(
os.getenv("EMBEDDING_REQUEST_TIMEOUT_SECONDS", "30")
)
EMBEDDING_MAX_RETRIES = int(os.getenv("EMBEDDING_MAX_RETRIES", "0"))
def initialize_vector_store(
@@ -53,6 +57,8 @@ def initialize_vector_store(
model=OPENAI_EMBEDDING_MODEL,
openai_api_base=OPENAI_EMBEDDING_BASE_URL,
openai_api_key=OPENAI_EMBEDDING_API_KEY,
request_timeout=EMBEDDING_REQUEST_TIMEOUT_SECONDS,
max_retries=EMBEDDING_MAX_RETRIES,
)
elif EMBEDDING_STRATEGY == "none":
embeddings = None
@@ -63,6 +69,8 @@ def initialize_vector_store(
embeddings = OllamaEmbeddings(
model=OLLAMA_EMBEDDING_MODEL,
base_url="http://localhost:11434", # Default Ollama URL
sync_client_kwargs={"timeout": EMBEDDING_REQUEST_TIMEOUT_SECONDS},
async_client_kwargs={"timeout": EMBEDDING_REQUEST_TIMEOUT_SECONDS},
)
# Check if collection exists and create if needed

View File

@@ -14,6 +14,10 @@ QDRANT_GRPC_PORT=6334
# OpenAI Configuration (for reference - uncomment and configure when using OpenAI strategy)
# OPENAI_CHAT_URL=https://api.openai.com/v1
# OPENAI_CHAT_KEY=your_openai_api_key_here
# OPENAI_CHAT_TEMPERATURE=0.1
# OPENAI_CHAT_MAX_TOKENS=1024
# OPENAI_CHAT_REASONING_EFFORT=low
# OPENAI_CHAT_IS_FUNCTION_CALLING_MODEL=false
# OPENAI_EMBEDDING_MODEL=text-embedding-3-small
# OPENAI_EMBEDDING_BASE_URL=https://api.openai.com/v1
# OPENAI_EMBEDDING_API_KEY=your_openai_api_key_here

View File

@@ -69,3 +69,34 @@ Chosen data folder: relatve ./../../../data - from the current folder
- [x] Create file `server.py`, with web framework fastapi, for example
- [x] Add POST endpoint "/api/test-query" which will use agent, and retrieve response for query, sent in JSON format, field "query"
# Phase 12 (upgrade from simple retrieval to agent-like chat in LlamaIndex)
- [x] Revisit Phase 5 assumption ("simple retrieval only") and explicitly allow agent/chat orchestration in LlamaIndex for QA over documents.
- [x] Create new module for chat orchestration (for example `agent.py` or `chat_engine.py`) that separates:
1) retrieval of source nodes
2) answer synthesis with explicit prompt
3) response formatting with sources/metadata
- [x] Implement a LlamaIndex-based chat feature (agent-like behavior) using framework-native primitives (chat engine / agent workflow / tool-calling approach supported by installed version), so the model can iteratively query retrieval tools when needed.
- [x] Add a retrieval tool wrapper for document search that returns structured snippets (`filename`, `file_path`, `page_label/page`, `chunk_number`, content preview, score) instead of raw text only.
- [x] Add a grounded answer prompt/template for the LlamaIndex chat path with rules:
- answer only from retrieved context
- if information is missing, say so directly
- prefer exact dates/years and quote filenames/pages where possible
- avoid generic claims not supported by sources
- [x] Add response mode that returns both:
- final answer text
- list of retrieved sources (content snippet + metadata + score)
- [x] Add post-processing for retrieved nodes before synthesis:
- deduplicate near-identical chunks
- drop empty / near-empty chunks
- optionally filter low-information chunks (headers/footers)
- [x] Add optional metadata-aware retrieval improvements (years/events/keywords) parity with LangChain approach (folder near current folder), if feasible in the chosen LlamaIndex primitives.
- [x] Update `server.py` endpoint to use the new agent-like chat path (keep simple retrieval endpoint available as fallback or debug mode).
# Phase 13 (make wrapper around OpenAI llamaindex class or install library to help use any openai compatible models)
Inside the config.py file, an object is created for use as the Chat Model. Unfortunately, it allows only "supported" models as the value of the "model" argument in its constructor.
- [x] Search online for a library or plugin for llamaindex that fixes the OpenAI class behaviour or provides a replacement that allows any models. If found, install it in the project venv, then update requirements.txt and replace the usage of OpenAI with the new one in the code.
- [x] Fallback option: create wrapper of OpenAI class, that will inherit and replace methods/features that check for "registered" models. Use the replacement then in the code.

View File

@@ -0,0 +1,335 @@
"""
Agent-like chat orchestration for grounded QA over documents using LlamaIndex.
This module separates:
1) retrieval of source nodes
2) answer synthesis with an explicit grounded prompt
3) response formatting with sources/metadata
"""
from __future__ import annotations
import asyncio
import json
import re
from dataclasses import dataclass
from typing import Any, Iterable, List
from llama_index.core import PromptTemplate
from llama_index.core.agent import AgentWorkflow
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import NodeWithScore
from llama_index.core.tools import FunctionTool
from loguru import logger
from config import get_llm_model, setup_global_models
from vector_storage import get_vector_store_and_index
# Grounded-answer prompt: constrains the LLM to answer strictly from the
# supplied JSON context snippets (plus an optional agent draft), preferring
# exact dates and filename/page citations over generic claims.
GROUNDED_SYNTHESIS_PROMPT = PromptTemplate(
    """You are a grounded QA assistant for a document knowledge base.
Rules:
- Answer ONLY from the provided context snippets.
- If the context is insufficient, say directly that the information is not available in the retrieved sources.
- Prefer exact dates/years and cite filenames/pages when possible.
- Avoid generic claims that are not supported by the snippets.
- If multiple sources disagree, mention the conflict briefly.
User question:
{query}
Optional draft from tool-using agent (may be incomplete):
{agent_draft}
Context snippets (JSON):
{context_json}
Return a concise answer with source mentions in plain text.
"""
)
@dataclass
class RetrievalSnippet:
    """One retrieved chunk: its text, similarity score and node metadata."""

    content: str
    score: float | None
    metadata: dict[str, Any]

    def to_api_dict(self) -> dict[str, Any]:
        """Serialize the snippet into the structured shape used by tools/API.

        The content is flattened to one line and truncated to 400 chars;
        missing metadata fields are reported as the string "unknown".
        """
        meta = self.metadata or {}
        preview = self.content.strip().replace("\n", " ")
        if len(preview) > 400:
            preview = f"{preview[:400]}..."
        return {
            "content_snippet": preview,
            "score": self.score,
            "metadata": {
                "filename": meta.get("filename", "unknown"),
                "file_path": meta.get("file_path", "unknown"),
                "page_label": meta.get("page_label", meta.get("page", "unknown")),
                "chunk_number": meta.get("chunk_number", "unknown"),
                "total_chunks": meta.get("total_chunks", "unknown"),
                "file_type": meta.get("file_type", "unknown"),
                "processed_at": meta.get("processed_at", "unknown"),
            },
        }
def _normalize_text(value: Any) -> str:
if value is None:
return ""
if isinstance(value, bytes):
return value.decode("utf-8", errors="replace")
return str(value)
def _extract_years(query: str) -> list[int]:
years = []
for match in re.findall(r"\b(19\d{2}|20\d{2}|21\d{2})\b", query):
try:
years.append(int(match))
except ValueError:
continue
return sorted(set(years))
def _extract_keywords(query: str) -> list[str]:
words = re.findall(r"[A-Za-zА-Яа-я0-9_-]{4,}", query.lower())
stop = {"what", "when", "where", "which", "that", "this", "with", "from", "about"}
keywords = [w for w in words if w not in stop]
return list(dict.fromkeys(keywords))[:6]
def _node_text(node_with_score: NodeWithScore) -> str:
node = getattr(node_with_score, "node", None)
if node is None:
return _normalize_text(getattr(node_with_score, "text", ""))
return _normalize_text(getattr(node, "text", getattr(node_with_score, "text", "")))
def _node_metadata(node_with_score: NodeWithScore) -> dict[str, Any]:
node = getattr(node_with_score, "node", None)
if node is None:
return dict(getattr(node_with_score, "metadata", {}) or {})
return dict(getattr(node, "metadata", {}) or {})
def _similarity_key(text: str) -> str:
text = re.sub(r"\s+", " ", text.strip().lower())
return text[:250]
def _is_low_information_chunk(text: str) -> bool:
compact = " ".join(text.split())
if len(compact) < 20:
return True
alpha_chars = sum(ch.isalpha() for ch in compact)
if alpha_chars < 8:
return True
# Repetitive headers/footers often contain too few unique tokens.
tokens = [t for t in re.split(r"\W+", compact.lower()) if t]
return len(tokens) >= 3 and len(set(tokens)) <= 2
def post_process_nodes(nodes: list[NodeWithScore]) -> list[NodeWithScore]:
    """Filter retrieved nodes before synthesis.

    - drops empty / whitespace-only chunks
    - drops low-information chunks (see _is_low_information_chunk)
    - deduplicates near-identical chunks; the dedup key combines file path,
      page, chunk number and a normalized text prefix, so the same chunk
      retrieved via multiple query variants survives only once
    """
    kept: list[NodeWithScore] = []
    seen_keys: set = set()
    for candidate in nodes:
        text = _node_text(candidate)
        if not text or not text.strip():
            continue
        if _is_low_information_chunk(text):
            continue
        meta = _node_metadata(candidate)
        key = (
            meta.get("file_path", ""),
            meta.get("page_label", meta.get("page", "")),
            meta.get("chunk_number", ""),
            _similarity_key(text),
        )
        if key not in seen_keys:
            seen_keys.add(key)
            kept.append(candidate)
    return kept
def retrieve_source_nodes(query: str, top_k: int = 5, search_multiplier: int = 3) -> list[NodeWithScore]:
    """
    Retrieve source nodes with light metadata-aware query expansion (years/keywords).

    Over-fetches (top_k * search_multiplier per variant) so that
    post-processing can drop duplicates/noise and still return top_k nodes.
    Query variants are the original query plus any extracted years and the
    first three keywords; duplicates are removed while preserving order.
    """
    setup_global_models()
    _, index = get_vector_store_and_index()
    retriever = VectorIndexRetriever(index=index, similarity_top_k=max(top_k * search_multiplier, top_k))
    queries = [query]
    years = _extract_years(query)
    keywords = _extract_keywords(query)
    queries.extend(str(year) for year in years)
    queries.extend(keywords[:3])
    collected: list[NodeWithScore] = []
    # dict.fromkeys dedups the variants while keeping first-seen order.
    for q in list(dict.fromkeys([q for q in queries if q and q.strip()])):
        try:
            logger.info(f"Retrieving nodes for query variant: {q}")
            collected.extend(retriever.retrieve(q))
        except Exception as e:
            # Best-effort: one failing variant must not abort the whole retrieval.
            logger.warning(f"Retrieval variant failed for '{q}': {e}")
    processed = post_process_nodes(collected)
    # Scored nodes first, descending by score; unscored nodes sink to the end.
    processed.sort(key=lambda n: (getattr(n, "score", None) is None, -(getattr(n, "score", 0.0) or 0.0)))
    return processed[:top_k]
def build_structured_snippets(nodes: list[NodeWithScore]) -> list[dict[str, Any]]:
    """Convert retrieved nodes into the structured snippet dicts used by the
    retrieval tool and the API response (content preview + metadata + score)."""
    return [
        RetrievalSnippet(
            content=_node_text(node_with_score),
            score=getattr(node_with_score, "score", None),
            metadata=_node_metadata(node_with_score),
        ).to_api_dict()
        for node_with_score in nodes
    ]
def retrieval_tool_search(query: str, top_k: int = 5) -> str:
    """Tool wrapper for document retrieval.

    Returns a JSON string with the query, snippet count, and the structured
    snippets themselves, suitable for consumption by a tool-calling agent.
    """
    snippets = build_structured_snippets(retrieve_source_nodes(query=query, top_k=top_k))
    return json.dumps(
        {
            "query": query,
            "count": len(snippets),
            "snippets": snippets,
        },
        ensure_ascii=False,
    )
def synthesize_answer(query: str, sources: list[dict[str, Any]], agent_draft: str = "") -> str:
    """
    Answer synthesis from retrieved sources using an explicit grounded prompt.

    Args:
        query: Original user question.
        sources: Structured snippet dicts (as built by build_structured_snippets).
        agent_draft: Optional draft answer from the tool-using agent;
            "(none)" fills the prompt slot when empty.

    Returns:
        Plain-text answer produced by the configured LLM.
    """
    llm = get_llm_model()
    context_json = json.dumps(sources, ensure_ascii=False, indent=2)
    prompt = GROUNDED_SYNTHESIS_PROMPT.format(
        query=query,
        agent_draft=agent_draft or "(none)",
        context_json=context_json,
    )
    logger.info("Synthesizing grounded answer from retrieved sources")
    # Prefer chat API for chat-capable models; fallback to completion if unavailable.
    try:
        if hasattr(llm, "chat"):
            chat_response = llm.chat(
                [
                    ChatMessage(role=MessageRole.SYSTEM, content="You answer with grounded citations only."),
                    ChatMessage(role=MessageRole.USER, content=prompt),
                ]
            )
            return _normalize_text(getattr(chat_response, "message", chat_response).content)
    except Exception as e:
        logger.warning(f"LLM chat synthesis failed, falling back to completion: {e}")
    # Completion fallback: reached when chat fails OR when llm has no .chat at all.
    response = llm.complete(prompt)
    return _normalize_text(getattr(response, "text", response))
def format_chat_response(query: str, final_answer: str, sources: list[dict[str, Any]], mode: str) -> dict[str, Any]:
    """Package the final answer together with its sources and pipeline mode
    into the flat dict shape returned by chat_with_documents."""
    response: dict[str, Any] = {}
    response["query"] = query
    response["answer"] = final_answer
    response["sources"] = sources
    response["mode"] = mode
    return response
def _extract_agent_result_text(result: Any) -> str:
if result is None:
return ""
if hasattr(result, "response"):
return _normalize_text(getattr(result, "response"))
return _normalize_text(result)
async def _run_agent_workflow_async(query: str, top_k: int) -> str:
    """
    Run LlamaIndex AgentWorkflow with a retrieval tool. Returns agent draft answer text.

    NOTE(review): ``top_k`` is accepted but never forwarded to the
    document_search tool, which therefore uses its own default — confirm
    whether binding top_k into the tool was intended.
    """
    setup_global_models()
    llm = get_llm_model()
    tool = FunctionTool.from_defaults(
        fn=retrieval_tool_search,
        name="document_search",
        description=(
            "Search documents and return structured snippets as JSON with fields: "
            "filename, file_path, page_label/page, chunk_number, content_snippet, score. "
            "Use this before answering factual questions about documents."
        ),
    )
    system_prompt = (
        "You are a QA agent over a document store. Use the document_search tool when factual "
        "information may come from documents. If tool output is insufficient, say so."
    )
    agent = AgentWorkflow.from_tools_or_functions(
        [tool],
        llm=llm,
        system_prompt=system_prompt,
        verbose=False,
    )
    # Cap tool-call iterations to keep latency bounded.
    handler = agent.run(user_msg=query, max_iterations=4)
    result = await handler
    return _extract_agent_result_text(result)
def run_agent_workflow(query: str, top_k: int = 5) -> str:
    """
    Synchronous wrapper around the async LlamaIndex agent workflow.

    Returns the agent's draft answer text, or "" when the workflow cannot run
    (asyncio.run raises RuntimeError inside a running event loop) or fails for
    any other reason — callers treat "" as "no draft available".
    """
    try:
        return asyncio.run(_run_agent_workflow_async(query=query, top_k=top_k))
    except RuntimeError:
        # Fallback if already in an event loop; skip agent workflow in that case.
        logger.warning("Async event loop already running; skipping agent workflow and using direct retrieval+synthesis")
        return ""
    except Exception as e:
        # Deliberate best-effort: the chat pipeline still works without the agent draft.
        logger.warning(f"Agent workflow failed, will fallback to direct retrieval+synthesis: {e}")
        return ""
def chat_with_documents(query: str, top_k: int = 5, use_agent: bool = True) -> dict[str, Any]:
    """
    Full chat orchestration entrypoint:
    - optionally run agent workflow (tool-calling)
    - retrieve + post-process sources
    - synthesize grounded answer
    - format response

    Args:
        query: User question.
        top_k: Number of source chunks to retrieve.
        use_agent: When True, first attempt the tool-calling agent; its draft
            (if any) is fed into the grounded synthesis step.

    Returns:
        Dict with keys: query, answer, sources, mode.
    """
    logger.info(f"Starting chat orchestration for query: {query[:80]}")
    agent_draft = ""
    mode = "retrieval+synthesis"
    if use_agent:
        agent_draft = run_agent_workflow(query=query, top_k=top_k)
        # Mode reports "agent+..." only when the agent produced a non-empty draft.
        mode = "agent+retrieval+synthesis" if agent_draft else "retrieval+synthesis"
    nodes = retrieve_source_nodes(query=query, top_k=top_k)
    sources = build_structured_snippets(nodes)
    final_answer = synthesize_answer(query=query, sources=sources, agent_draft=agent_draft)
    return format_chat_response(query=query, final_answer=final_answer, sources=sources, mode=mode)

View File

@@ -93,18 +93,43 @@ def get_llm_model():
return llm
elif strategy == "openai":
from llama_index.llms.openai import OpenAI
from helpers.openai_compatible_llm import OpenAICompatibleLLM
openai_chat_url = os.getenv("OPENAI_CHAT_URL", "https://api.openai.com/v1")
openai_chat_key = os.getenv("OPENAI_CHAT_KEY", "dummy_key_for_template")
openai_chat_model = os.getenv("OPENAI_CHAT_MODEL", "gpt-3.5-turbo")
openai_chat_temperature = float(os.getenv("OPENAI_CHAT_TEMPERATURE", "0.1"))
openai_chat_max_tokens_env = os.getenv("OPENAI_CHAT_MAX_TOKENS", "").strip()
openai_chat_max_tokens = (
int(openai_chat_max_tokens_env) if openai_chat_max_tokens_env else 1024
)
openai_reasoning_effort = (
os.getenv("OPENAI_CHAT_REASONING_EFFORT", "").strip() or None
)
openai_is_fc_model = (
os.getenv("OPENAI_CHAT_IS_FUNCTION_CALLING_MODEL", "false").lower()
== "true"
)
# Set the API key in environment for OpenAI
os.environ["OPENAI_API_KEY"] = openai_chat_key
logger.info(f"Initializing OpenAI chat model: {openai_chat_model}")
logger.info(
f"Initializing OpenAI-compatible chat model: {openai_chat_model} "
f"(base={openai_chat_url}, max_tokens={openai_chat_max_tokens}, "
f"reasoning_effort={openai_reasoning_effort}, function_calling={openai_is_fc_model})"
)
llm = OpenAI(model=openai_chat_model, api_base=openai_chat_url)
llm = OpenAICompatibleLLM(
model=openai_chat_model,
api_base=openai_chat_url,
api_key=openai_chat_key,
temperature=openai_chat_temperature,
max_tokens=openai_chat_max_tokens,
reasoning_effort=openai_reasoning_effort,
timeout=120.0,
is_function_calling_model=openai_is_fc_model,
)
return llm

View File

@@ -0,0 +1,40 @@
"""
OpenAI-compatible LLM wrapper for LlamaIndex chat models.
This wrapper is used as a fallback/replacement for strict OpenAI model validation paths.
It relies on LlamaIndex `OpenAILike`, which supports arbitrary model names for
OpenAI-compatible endpoints.
"""
from llama_index.llms.openai_like import OpenAILike
class OpenAICompatibleLLM(OpenAILike):
    """
    Thin wrapper over OpenAILike with chat-friendly defaults.

    Exists to bypass the strict "registered model names only" validation of
    the stock OpenAI LLM class: OpenAILike accepts arbitrary model names for
    OpenAI-compatible endpoints, and this subclass pins is_chat_model=True
    and disables structured outputs.
    """
    def __init__(
        self,
        model: str,
        api_base: str,
        api_key: str,
        temperature: float = 0.1,
        timeout: float = 120.0,
        max_tokens: int | None = None,
        reasoning_effort: str | None = None,
        is_function_calling_model: bool = False,
    ):
        # All arguments are forwarded to OpenAILike unchanged; only the
        # chat/structured-output flags below are fixed by this wrapper.
        super().__init__(
            model=model,
            api_base=api_base,
            api_key=api_key,
            temperature=temperature,
            timeout=timeout,
            max_tokens=max_tokens,
            reasoning_effort=reasoning_effort,
            # Explicitly avoid "registered model only" assumptions.
            is_chat_model=True,
            is_function_calling_model=is_function_calling_model,
            should_use_structured_outputs=False,
        )

View File

@@ -74,6 +74,7 @@ llama-index-indices-managed-llama-cloud==0.9.4
llama-index-instrumentation==0.4.2
llama-index-llms-ollama==0.9.1
llama-index-llms-openai==0.6.17
llama-index-llms-openai-like==0.6.0
llama-index-readers-file==0.5.6
llama-index-readers-llama-parse==0.5.1
llama-index-vector-stores-qdrant==0.9.1

View File

@@ -3,14 +3,16 @@
HTTP API server for querying the vector storage via the existing retrieval pipeline.
"""
from pathlib import Path
import sys
from pathlib import Path
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from pydantic import BaseModel, Field
from chat_engine import chat_with_documents
from retrieval import initialize_retriever
load_dotenv()
@@ -41,10 +43,28 @@ setup_logging()
app = FastAPI(title="LlamaIndex RAG API", version="1.0.0")
origins = [
"*",
]
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=origins, # In production, configure this properly
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
allow_private_network=True,
)
class TestQueryRequest(BaseModel):
query: str = Field(..., min_length=1, description="User query text")
top_k: int = Field(5, ge=1, le=20, description="Number of retrieved chunks")
mode: str = Field(
"agent",
description="agent (Phase 12 default) or retrieval (fallback/debug)",
)
class SourceItem(BaseModel):
@@ -57,6 +77,9 @@ class TestQueryResponse(BaseModel):
query: str
response: str
sources: list[SourceItem]
mode: str | None = None
error: bool
success: bool
@app.get("/health")
@@ -73,9 +96,43 @@ def test_query(payload: TestQueryRequest) -> TestQueryResponse:
if not query:
raise HTTPException(status_code=400, detail="Field 'query' must not be empty")
logger.info(f"Received /api/test-query request (top_k={payload.top_k})")
logger.info(
f"Received /api/test-query request (top_k={payload.top_k}, mode={payload.mode})"
)
try:
if payload.mode.lower() in {"agent", "chat"}:
chat_result = chat_with_documents(
query=query,
top_k=payload.top_k,
use_agent=True,
)
sources = [
SourceItem(
content=str(src.get("content_snippet", "")),
score=src.get("score"),
metadata=src.get("metadata", {}) or {},
)
for src in chat_result.get("sources", [])
]
logger.info(
f"/api/test-query completed via agent-like chat path (sources={len(sources)})"
)
return TestQueryResponse(
query=query,
response=str(chat_result.get("answer", "")),
sources=sources,
mode=str(chat_result.get("mode", "agent+retrieval+synthesis")),
error=False,
success=True,
)
if payload.mode.lower() not in {"retrieval", "debug"}:
raise HTTPException(
status_code=400,
detail="Unsupported mode. Use 'agent' (default) or 'retrieval'.",
)
query_engine = initialize_retriever(similarity_top_k=payload.top_k)
result = query_engine.query(query)
@@ -92,9 +149,16 @@ def test_query(payload: TestQueryRequest) -> TestQueryResponse:
response_text = str(result)
logger.info(
f"/api/test-query completed successfully (sources={len(sources)})"
f"/api/test-query completed via retrieval fallback (sources={len(sources)})"
)
return TestQueryResponse(
query=query,
response=response_text,
sources=sources,
mode="retrieval",
error=False,
success=True,
)
return TestQueryResponse(query=query, response=response_text, sources=sources)
except HTTPException:
raise
@@ -106,4 +170,4 @@ def test_query(payload: TestQueryRequest) -> TestQueryResponse:
if __name__ == "__main__":
import uvicorn
uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=False)
uvicorn.run("server:app", host="0.0.0.0", port=8334, reload=False)