Files
rag-solution/run_rag_batch_eval.py
2026-03-11 22:30:02 +03:00

354 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import math
import re
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import requests
# Endpoints of the two locally running RAG backends being compared.
LANGCHAIN_URL = "http://localhost:8331/api/test-query"
LLAMAINDEX_URL = "http://localhost:8334/api/test-query"
# Input: markdown listing documents and their test questions.
INPUT_MD = Path("/Users/idchlife/www/work/rag-solution/DOCUMENTS_TO_TEST.md")
# Output: markdown evaluation report rewritten on every run.
OUTPUT_MD = Path("/Users/idchlife/www/work/rag-solution/RAG_EVALUATION.md")
# Russian stopwords (question words, prepositions, and a few domain-common
# fillers) excluded from tokens during question/answer overlap scoring.
STOPWORDS_RU = {
    "что",
    "кто",
    "как",
    "какой",
    "какая",
    "какие",
    "ли",
    "в",
    "на",
    "по",
    "и",
    "или",
    "для",
    "из",
    "с",
    "о",
    "об",
    "а",
    "не",
    "к",
    "до",
    "от",
    "это",
    "есть",
    "если",
    "какому",
    "каком",
    "году",
    "материалах",
    "базы",
    "найди",
}
@dataclass
class QuestionItem:
    """One evaluation question plus the per-backend answers and scores filled
    in as the batch runs."""

    section: str  # name of the "### " section the question belongs to
    question: str  # the question text sent to both backends
    langchain_answer: str = ""  # raw answer (or "ERROR: ..." on failure)
    llamaindex_answer: str = ""  # raw answer (or "ERROR: ..." on failure)
    langchain_score: float = 0.0  # heuristic score in [0, 1]
    llamaindex_score: float = 0.0  # heuristic score in [0, 1]
    winner: str = "Tie"  # "LangChain" | "LlamaIndex" | "Tie"
    rationale: str = ""  # human-readable component breakdown of both scores
@dataclass
class DocumentItem:
    """One "## " document block from the input markdown and its question
    sections."""

    header: str  # raw "## ..." heading line
    path: str  # backtick-quoted document path parsed from the header, "" if absent
    # (section name, questions in that section) pairs, in file order
    sections: list[tuple[str, list[QuestionItem]]] = field(default_factory=list)
def split_documents(md_text: str) -> tuple[list[str], list[str]]:
    """Split markdown text into a preamble and per-document blocks.

    Returns ``(header_lines, blocks)`` where ``header_lines`` are all lines
    that appear before the first "## " heading and each entry of ``blocks``
    is the joined text of one document, starting at its "## " heading.
    """
    preamble: list[str] = []
    blocks: list[list[str]] = []
    for raw_line in md_text.splitlines():
        if raw_line.startswith("## "):
            # A new document starts; open a fresh buffer for it.
            blocks.append([raw_line])
        elif blocks:
            blocks[-1].append(raw_line)
        else:
            # No document seen yet: the line belongs to the file preamble.
            preamble.append(raw_line)
    return preamble, ["\n".join(block) for block in blocks]
def parse_document_block(block: str) -> DocumentItem:
    """Parse one "## " document block into a DocumentItem.

    The document path is the first backtick-quoted span found in the heading
    line ("" when none). Each "### " heading opens a section; "- " bullets
    under it become questions. Bullets that appear before the first section
    heading are discarded, since there is no named section to attach them to.
    """
    block_lines = block.splitlines()
    title = block_lines[0].strip()
    path_match = re.search(r"`([^`]+)`", title)
    parsed_sections: list[tuple[str, list[QuestionItem]]] = []
    section_name = ""
    pending: list[QuestionItem] = []
    for body_line in block_lines[1:]:
        if body_line.startswith("### "):
            # Flush the previous section (if any) before starting a new one.
            if section_name:
                parsed_sections.append((section_name, pending))
            section_name = body_line[4:].strip()
            pending = []
        elif body_line.startswith("- "):
            text = body_line[2:].strip()
            if text:
                pending.append(QuestionItem(section=section_name, question=text))
    if section_name:
        parsed_sections.append((section_name, pending))
    return DocumentItem(
        header=title,
        path=path_match.group(1) if path_match else "",
        sections=parsed_sections,
    )
def tokenize(text: str) -> list[str]:
    """Lowercase *text* and return its word tokens, keeping only tokens longer
    than two characters that are not Russian stopwords."""

    def _keep(token: str) -> bool:
        return len(token) > 2 and token not in STOPWORDS_RU

    return list(filter(_keep, re.findall(r"[A-Za-zА-Яа-я0-9_]+", text.lower())))
def score_answer(question: str, answer: str) -> tuple[float, dict[str, float]]:
    """Heuristically score *answer* against *question*, returning a value in
    [0, 1] plus its component breakdown.

    Components:
      - overlap: fraction of question tokens that appear in the answer
      - len: length credit, saturating at 500 chars and damped past 2800
      - specificity: credit for numbers and capitalized Cyrillic name-like runs
      - structure: credit for bullet/numbered lists and sentence count
    An error/refusal marker anywhere in the answer subtracts a flat 0.6.
    """
    text = (answer or "").strip()
    if not text:
        return 0.0, {"len": 0.0, "overlap": 0.0, "specificity": 0.0, "structure": 0.0}

    question_tokens = set(tokenize(question))
    answer_tokens = set(tokenize(text))
    if question_tokens:
        overlap = len(question_tokens & answer_tokens) / max(1, len(question_tokens))
    else:
        overlap = 0.0

    length_score = min(1.0, len(text) / 500.0)
    if len(text) > 2800:
        # Mildly discourage rambling answers.
        length_score *= 0.85

    digit_hits = len(re.findall(r"\b\d+(?:[.,]\d+)?\b", text))
    name_hits = len(re.findall(r"[А-ЯЁ][а-яё]{2,}(?:\s+[А-ЯЁ][а-яё]{2,}){0,2}", text))
    specificity = min(1.0, (digit_hits * 0.08) + (name_hits * 0.05))

    has_bullets = 1.0 if re.search(r"(^|\n)\s*(?:\d+\.|-)\s+", text) else 0.0
    sentence_hits = len(re.findall(r"[.!?]", text))
    structure = min(1.0, has_bullets * 0.5 + min(0.5, sentence_hits / 6.0))

    penalty = 0.0
    if re.search(
        r"\b(ошибк|error|не удалось|failed|исключени|exception)\b", text.lower()
    ):
        penalty = 0.6

    total = (
        (0.38 * overlap)
        + (0.26 * length_score)
        + (0.20 * specificity)
        + (0.16 * structure)
        - penalty
    )
    clamped = max(0.0, min(1.0, total))
    return clamped, {
        "len": length_score,
        "overlap": overlap,
        "specificity": specificity,
        "structure": structure,
    }
def compare_answers(
    question: str, lc_answer: str, li_answer: str
) -> tuple[str, float, float, str]:
    """Score both answers against the question and pick a winner.

    Score differences below 0.04 count as a tie. Returns
    ``(winner, langchain_score, llamaindex_score, rationale)``.
    """
    langchain_total, langchain_parts = score_answer(question, lc_answer)
    llamaindex_total, llamaindex_parts = score_answer(question, li_answer)
    margin = langchain_total - llamaindex_total
    if abs(margin) < 0.04:
        verdict = "Tie"
    else:
        verdict = "LangChain" if margin > 0 else "LlamaIndex"
    explanation = (
        f"LC(overlap={langchain_parts['overlap']:.2f}, len={langchain_parts['len']:.2f}, spec={langchain_parts['specificity']:.2f}, "
        f"struct={langchain_parts['structure']:.2f}) vs "
        f"LI(overlap={llamaindex_parts['overlap']:.2f}, len={llamaindex_parts['len']:.2f}, spec={llamaindex_parts['specificity']:.2f}, "
        f"struct={llamaindex_parts['structure']:.2f})"
    )
    return verdict, langchain_total, llamaindex_total, explanation
def call_langchain(query: str, timeout: int) -> str:
    """POST *query* to the LangChain test endpoint and return the stripped
    "response" field of its JSON body.

    Raises requests.HTTPError on a non-2xx status and the usual requests
    timeout/connection errors on transport failure.
    """
    resp = requests.post(LANGCHAIN_URL, json={"query": query}, timeout=timeout)
    resp.raise_for_status()
    return str(resp.json().get("response", "")).strip()
def call_llamaindex(query: str, timeout: int) -> str:
    """POST *query* to the LlamaIndex test endpoint (in "agent" mode) and
    return the stripped "response" field of its JSON body.

    Raises requests.HTTPError on a non-2xx status and the usual requests
    timeout/connection errors on transport failure.
    """
    resp = requests.post(
        LLAMAINDEX_URL, json={"query": query, "mode": "agent"}, timeout=timeout
    )
    resp.raise_for_status()
    return str(resp.json().get("response", "")).strip()
def truncate(text: str, max_len: int = 1400) -> str:
    """Strip *text* and cap it at *max_len* characters, appending a visible
    truncation marker when the cut happens."""
    cleaned = (text or "").strip()
    if len(cleaned) > max_len:
        return cleaned[:max_len] + "... [truncated]"
    return cleaned
def format_batch_summary(
    batch_docs: list[DocumentItem],
    batch_idx: int,
    docs_in_batch: int,
) -> str:
    """Render the markdown summary (win counts, average scores, ranking) for
    one evaluation batch."""
    win_counts = {"LangChain": 0, "LlamaIndex": 0, "Tie": 0}
    lc_scores: list[float] = []
    li_scores: list[float] = []
    question_total = 0
    for document in batch_docs:
        for _, section_questions in document.sections:
            for item in section_questions:
                question_total += 1
                win_counts[item.winner] += 1
                lc_scores.append(item.langchain_score)
                li_scores.append(item.llamaindex_score)
    # max(1, ...) guards the empty-batch case (averages become 0.0).
    mean_lc = sum(lc_scores) / max(1, len(lc_scores))
    mean_li = sum(li_scores) / max(1, len(li_scores))
    if mean_lc > mean_li + 0.01:
        ranking = "LangChain"
    elif mean_li > mean_lc + 0.01:
        ranking = "LlamaIndex"
    else:
        ranking = "Tie"
    report = [
        f"## Batch {batch_idx} Summary",
        "",
        f"- Documents processed in this batch: {docs_in_batch}",
        f"- Questions processed in this batch: {question_total}",
        f"- LangChain wins: {win_counts['LangChain']}",
        f"- LlamaIndex wins: {win_counts['LlamaIndex']}",
        f"- Ties: {win_counts['Tie']}",
        f"- Average score LangChain: {mean_lc:.3f}",
        f"- Average score LlamaIndex: {mean_li:.3f}",
        f"- Final ranking for this batch: {ranking}",
        "",
        "_Scoring note: relative heuristic rubric (query overlap, informativeness, specificity, structure), "
        "used only for side-by-side ranking in this batch._",
        "",
    ]
    return "\n".join(report)
def render_document_with_results(doc: DocumentItem, with_results: bool) -> str:
    """Render one document's sections and questions back to markdown.

    When *with_results* is true, each question is followed by both (truncated)
    answers, the winner/score line, and the scoring rationale. Otherwise the
    questions are listed bare and a "not processed yet" note is appended.
    """
    rendered = [doc.header, ""]
    for title, section_questions in doc.sections:
        rendered.append(f"### {title}")
        for item in section_questions:
            rendered.append(f"- {item.question}")
            if with_results:
                rendered.append("")
                rendered.append(" - `LangChain Answer`:")
                rendered.append(f" {truncate(item.langchain_answer)}")
                rendered.append(" - `LlamaIndex Answer`:")
                rendered.append(f" {truncate(item.llamaindex_answer)}")
                rendered.append(
                    f" - `Result`: winner={item.winner}, "
                    f"score_langchain={item.langchain_score:.3f}, score_llamaindex={item.llamaindex_score:.3f}"
                )
                rendered.append(f" - `Rationale`: {item.rationale}")
        rendered.append("")
    if not with_results:
        rendered.append("_Batch 1 status: not processed yet._")
        rendered.append("")
    return "\n".join(rendered)
def main() -> int:
    """Run one batch of the side-by-side RAG evaluation.

    Reads the question markdown, queries both backends for every question in
    the selected batch, scores and compares the answers, then rewrites the
    output markdown with this batch's summary followed by all documents
    (results shown only for documents in the batch).

    Returns 0 on success. Raises RuntimeError when --batch-index points past
    the available documents.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-docs", type=int, default=10)  # documents per batch
    parser.add_argument("--batch-index", type=int, default=1)  # 1-based batch number
    parser.add_argument("--timeout", type=int, default=120)  # per-request timeout, seconds
    args = parser.parse_args()
    raw = INPUT_MD.read_text(encoding="utf-8")
    header_lines, doc_blocks = split_documents(raw)
    docs = [parse_document_block(b) for b in doc_blocks]
    # Slice out this batch's documents (half-open range into `docs`).
    start = (args.batch_index - 1) * args.batch_docs
    end = start + args.batch_docs
    if start >= len(docs):
        raise RuntimeError("Batch start is beyond available documents")
    batch_docs = docs[start:end]
    total_questions = sum(len(qs) for d in batch_docs for _, qs in d.sections)
    q_index = 0
    for doc in batch_docs:
        for _, questions in doc.sections:
            for q in questions:
                q_index += 1
                print(f"[{q_index:03d}/{total_questions}] {q.question}")
                # Query each backend independently; a failure on one side is
                # recorded as an "ERROR: ..." answer so scoring still runs
                # (score_answer penalizes the error marker).
                try:
                    t0 = time.time()
                    q.langchain_answer = call_langchain(
                        q.question, timeout=args.timeout
                    )
                    print(
                        f" -> LangChain OK in {time.time() - t0:.1f}s "
                        f"(chars={len(q.langchain_answer)})"
                    )
                except Exception as e:
                    q.langchain_answer = f"ERROR: {e}"
                    print(f" -> LangChain ERROR: {e}")
                try:
                    t0 = time.time()
                    q.llamaindex_answer = call_llamaindex(
                        q.question, timeout=args.timeout
                    )
                    print(
                        f" -> LlamaIndex OK in {time.time() - t0:.1f}s "
                        f"(chars={len(q.llamaindex_answer)})"
                    )
                except Exception as e:
                    q.llamaindex_answer = f"ERROR: {e}"
                    print(f" -> LlamaIndex ERROR: {e}")
                winner, lc_score, li_score, rationale = compare_answers(
                    q.question, q.langchain_answer, q.llamaindex_answer
                )
                q.winner = winner
                q.langchain_score = lc_score
                q.llamaindex_score = li_score
                q.rationale = rationale
    # Rebuild the whole report: original preamble, this batch's summary,
    # then every document (results rendered only for in-batch documents).
    output_parts: list[str] = []
    output_parts.extend(header_lines)
    output_parts.append("")
    output_parts.append(
        format_batch_summary(batch_docs, args.batch_index, len(batch_docs))
    )
    for i, doc in enumerate(docs):
        in_batch = start <= i < end
        output_parts.append(render_document_with_results(doc, with_results=in_batch))
    OUTPUT_MD.write_text("\n".join(output_parts).rstrip() + "\n", encoding="utf-8")
    print(f"Written: {OUTPUT_MD}")
    return 0
# Script entry point: propagate main()'s return code as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())