"""
Agent-like chat orchestration for grounded QA over documents using LlamaIndex.

This module separates:
1) retrieval of source nodes
2) answer synthesis with an explicit grounded prompt
3) response formatting with sources/metadata
"""
|
|||
|
|
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import asyncio
|
|||
|
|
import json
|
|||
|
|
import re
|
|||
|
|
from dataclasses import dataclass
|
|||
|
|
from typing import Any, Iterable, List
|
|||
|
|
|
|||
|
|
from llama_index.core import PromptTemplate
|
|||
|
|
from llama_index.core.agent import AgentWorkflow
|
|||
|
|
from llama_index.core.retrievers import VectorIndexRetriever
|
|||
|
|
from llama_index.core.schema import NodeWithScore
|
|||
|
|
from llama_index.core.tools import FunctionTool
|
|||
|
|
from loguru import logger
|
|||
|
|
|
|||
|
|
from config import get_llm_model, setup_global_models
|
|||
|
|
from vector_storage import get_vector_store_and_index
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Prompt used by synthesize_answer(): instructs the LLM to answer strictly from
# the retrieved context snippets ({context_json}), optionally seeded with a
# draft produced by the tool-calling agent ({agent_draft}).
# NOTE: the template text is runtime behavior — keep it in sync with the
# structured-snippet fields emitted by build_structured_snippets().
GROUNDED_SYNTHESIS_PROMPT = PromptTemplate(
    """You are a grounded QA assistant for a document knowledge base.

Rules:
- Answer ONLY from the provided context snippets.
- If the context is insufficient, say directly that the information is not available in the retrieved sources.
- Prefer exact dates/years and cite filenames/pages when possible.
- Avoid generic claims that are not supported by the snippets.
- If multiple sources disagree, mention the conflict briefly.

User question:
{query}

Optional draft from tool-using agent (may be incomplete):
{agent_draft}

Context snippets (JSON):
{context_json}

Return a concise answer with source mentions in plain text.
"""
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class RetrievalSnippet:
    """One retrieved document chunk together with its score and source metadata."""

    content: str
    score: float | None
    metadata: dict[str, Any]

    def to_api_dict(self) -> dict[str, Any]:
        """Serialize for API responses: truncated one-line preview plus metadata.

        Missing metadata fields are reported as "unknown"; "page_label" falls
        back to a plain "page" key when present.
        """
        meta = self.metadata or {}
        preview = self.content.strip().replace("\n", " ")
        if len(preview) > 400:
            preview = preview[:400] + "..."
        # Key order is preserved deliberately — it is visible in JSON output.
        field_order = (
            "filename",
            "file_path",
            "page_label",
            "chunk_number",
            "total_chunks",
            "file_type",
            "processed_at",
        )
        out_meta: dict[str, Any] = {}
        for key in field_order:
            if key == "page_label":
                out_meta[key] = meta.get("page_label", meta.get("page", "unknown"))
            else:
                out_meta[key] = meta.get(key, "unknown")
        return {
            "content_snippet": preview,
            "score": self.score,
            "metadata": out_meta,
        }
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _normalize_text(value: Any) -> str:
|
|||
|
|
if value is None:
|
|||
|
|
return ""
|
|||
|
|
if isinstance(value, bytes):
|
|||
|
|
return value.decode("utf-8", errors="replace")
|
|||
|
|
return str(value)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _extract_years(query: str) -> list[int]:
|
|||
|
|
years = []
|
|||
|
|
for match in re.findall(r"\b(19\d{2}|20\d{2}|21\d{2})\b", query):
|
|||
|
|
try:
|
|||
|
|
years.append(int(match))
|
|||
|
|
except ValueError:
|
|||
|
|
continue
|
|||
|
|
return sorted(set(years))
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _extract_keywords(query: str) -> list[str]:
|
|||
|
|
words = re.findall(r"[A-Za-zА-Яа-я0-9_-]{4,}", query.lower())
|
|||
|
|
stop = {"what", "when", "where", "which", "that", "this", "with", "from", "about"}
|
|||
|
|
keywords = [w for w in words if w not in stop]
|
|||
|
|
return list(dict.fromkeys(keywords))[:6]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _node_text(node_with_score: NodeWithScore) -> str:
    """Best-effort extraction of the chunk text from a NodeWithScore wrapper."""
    inner = getattr(node_with_score, "node", None)
    if inner is not None:
        # Prefer the inner node's text; fall back to the wrapper's own text.
        return _normalize_text(getattr(inner, "text", getattr(node_with_score, "text", "")))
    return _normalize_text(getattr(node_with_score, "text", ""))
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _node_metadata(node_with_score: NodeWithScore) -> dict[str, Any]:
|
|||
|
|
node = getattr(node_with_score, "node", None)
|
|||
|
|
if node is None:
|
|||
|
|
return dict(getattr(node_with_score, "metadata", {}) or {})
|
|||
|
|
return dict(getattr(node, "metadata", {}) or {})
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _similarity_key(text: str) -> str:
|
|||
|
|
text = re.sub(r"\s+", " ", text.strip().lower())
|
|||
|
|
return text[:250]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _is_low_information_chunk(text: str) -> bool:
|
|||
|
|
compact = " ".join(text.split())
|
|||
|
|
if len(compact) < 20:
|
|||
|
|
return True
|
|||
|
|
alpha_chars = sum(ch.isalpha() for ch in compact)
|
|||
|
|
if alpha_chars < 8:
|
|||
|
|
return True
|
|||
|
|
# Repetitive headers/footers often contain too few unique tokens.
|
|||
|
|
tokens = [t for t in re.split(r"\W+", compact.lower()) if t]
|
|||
|
|
return len(tokens) >= 3 and len(set(tokens)) <= 2
|
|||
|
|
|
|||
|
|
|
|||
|
|
def post_process_nodes(nodes: list[NodeWithScore]) -> list[NodeWithScore]:
    """Filter retrieved nodes, preserving order.

    Drops empty/whitespace-only chunks, low-information chunks, and duplicates.
    Deduplication keys on (file_path, page, chunk_number, normalized text
    prefix) so near-identical chunks from the same location survive only once.
    """
    kept: list[NodeWithScore] = []
    seen_keys: set = set()

    for candidate in nodes:
        text = _node_text(candidate)
        if not text or not text.strip():
            continue
        if _is_low_information_chunk(text):
            continue

        meta = _node_metadata(candidate)
        key = (
            meta.get("file_path", ""),
            meta.get("page_label", meta.get("page", "")),
            meta.get("chunk_number", ""),
            _similarity_key(text),
        )
        if key not in seen_keys:
            seen_keys.add(key)
            kept.append(candidate)

    return kept
|
|||
|
|
|
|||
|
|
|
|||
|
|
def retrieve_source_nodes(query: str, top_k: int = 5, search_multiplier: int = 3) -> list[NodeWithScore]:
    """Retrieve source nodes with light metadata-aware query expansion.

    Query variants are the original query plus any detected years and up to
    three keywords. Each variant fetches ``top_k * search_multiplier``
    candidates; results are post-processed and the best ``top_k`` returned.
    """
    setup_global_models()
    _, index = get_vector_store_and_index()
    fetch_size = max(top_k * search_multiplier, top_k)
    retriever = VectorIndexRetriever(index=index, similarity_top_k=fetch_size)

    variants: list[str] = [query]
    variants += [str(year) for year in _extract_years(query)]
    variants += _extract_keywords(query)[:3]

    gathered: list[NodeWithScore] = []
    # dict.fromkeys de-duplicates variants while preserving order.
    for variant in dict.fromkeys(v for v in variants if v and v.strip()):
        try:
            logger.info(f"Retrieving nodes for query variant: {variant}")
            gathered.extend(retriever.retrieve(variant))
        except Exception as e:
            logger.warning(f"Retrieval variant failed for '{variant}': {e}")

    ranked = post_process_nodes(gathered)
    # Scored nodes first (descending score); unscored nodes sort last.
    ranked.sort(key=lambda n: (getattr(n, "score", None) is None, -(getattr(n, "score", 0.0) or 0.0)))
    return ranked[:top_k]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def build_structured_snippets(nodes: list[NodeWithScore]) -> list[dict[str, Any]]:
    """Convert retrieved nodes into API-ready snippet dictionaries."""
    return [
        RetrievalSnippet(
            content=_node_text(nws),
            score=getattr(nws, "score", None),
            metadata=_node_metadata(nws),
        ).to_api_dict()
        for nws in nodes
    ]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def retrieval_tool_search(query: str, top_k: int = 5) -> str:
    """Agent tool wrapper: search documents and return structured snippets as JSON text."""
    nodes = retrieve_source_nodes(query=query, top_k=top_k)
    snippets = build_structured_snippets(nodes)
    return json.dumps(
        {
            "query": query,
            "count": len(snippets),
            "snippets": snippets,
        },
        ensure_ascii=False,
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
def synthesize_answer(query: str, sources: list[dict[str, Any]], agent_draft: str = "") -> str:
    """Produce a grounded answer from *sources* via the LLM.

    The optional *agent_draft* (output of the tool-calling agent) is included
    in the prompt as a hint; "(none)" is substituted when it is empty.
    """
    llm = get_llm_model()
    prompt = GROUNDED_SYNTHESIS_PROMPT.format(
        query=query,
        agent_draft=agent_draft if agent_draft else "(none)",
        context_json=json.dumps(sources, ensure_ascii=False, indent=2),
    )
    logger.info("Synthesizing grounded answer from retrieved sources")
    completion = llm.complete(prompt)
    # LLM responses expose .text; fall back to stringifying the raw object.
    return _normalize_text(getattr(completion, "text", completion))
|
|||
|
|
|
|||
|
|
|
|||
|
|
def format_chat_response(query: str, final_answer: str, sources: list[dict[str, Any]], mode: str) -> dict[str, Any]:
    """Assemble the final API payload: echoed query, answer text, structured sources, and mode."""
    payload: dict[str, Any] = {"query": query}
    payload["answer"] = final_answer
    payload["sources"] = sources
    payload["mode"] = mode
    return payload
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _extract_agent_result_text(result: Any) -> str:
    """Pull a plain-text answer out of an agent run result ("" for None)."""
    if result is None:
        return ""
    # Agent results typically expose a .response attribute; otherwise stringify.
    target = result.response if hasattr(result, "response") else result
    return _normalize_text(target)
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def _run_agent_workflow_async(query: str, top_k: int) -> str:
    """Run a LlamaIndex AgentWorkflow with a document-retrieval tool.

    Args:
        query: User question passed to the agent.
        top_k: Number of snippets each tool search should return. The previous
            implementation accepted but ignored this parameter; it is now bound
            into the retrieval tool as its default so agent searches honor it.

    Returns:
        The agent's draft answer text ("" when the result is empty).
    """
    setup_global_models()
    llm = get_llm_model()

    def _document_search(query: str, top_k: int = top_k) -> str:
        """Tool entry point with the caller's top_k bound as the default."""
        return retrieval_tool_search(query=query, top_k=top_k)

    tool = FunctionTool.from_defaults(
        fn=_document_search,
        name="document_search",
        description=(
            "Search documents and return structured snippets as JSON with fields: "
            "filename, file_path, page_label/page, chunk_number, content_snippet, score. "
            "Use this before answering factual questions about documents."
        ),
    )

    system_prompt = (
        "You are a QA agent over a document store. Use the document_search tool when factual "
        "information may come from documents. If tool output is insufficient, say so."
    )
    agent = AgentWorkflow.from_tools_or_functions(
        [tool],
        llm=llm,
        system_prompt=system_prompt,
        verbose=False,
    )
    # max_iterations caps tool-call loops so a confused agent cannot spin forever.
    handler = agent.run(user_msg=query, max_iterations=4)
    result = await handler
    return _extract_agent_result_text(result)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def run_agent_workflow(query: str, top_k: int = 5) -> str:
    """Blocking wrapper around the async agent workflow.

    Returns "" (so callers fall back to direct retrieval+synthesis) when an
    event loop is already running or the workflow fails for any reason.
    """
    try:
        return asyncio.run(_run_agent_workflow_async(query=query, top_k=top_k))
    except RuntimeError:
        # asyncio.run refuses to start inside a running loop; skip the agent.
        logger.warning("Async event loop already running; skipping agent workflow and using direct retrieval+synthesis")
    except Exception as e:
        logger.warning(f"Agent workflow failed, will fallback to direct retrieval+synthesis: {e}")
    return ""
|
|||
|
|
|
|||
|
|
|
|||
|
|
def chat_with_documents(query: str, top_k: int = 5, use_agent: bool = True) -> dict[str, Any]:
    """End-to-end chat pipeline entrypoint.

    Steps: optionally run the tool-calling agent for a draft, retrieve and
    post-process sources, synthesize a grounded answer, and format the payload.
    """
    logger.info(f"Starting chat orchestration for query: {query[:80]}")

    draft = run_agent_workflow(query=query, top_k=top_k) if use_agent else ""
    # An empty draft (agent skipped or failed) downgrades the reported mode.
    mode = "agent+retrieval+synthesis" if draft else "retrieval+synthesis"

    nodes = retrieve_source_nodes(query=query, top_k=top_k)
    sources = build_structured_snippets(nodes)
    answer = synthesize_answer(query=query, sources=sources, agent_draft=draft)
    return format_chat_response(query=query, final_answer=answer, sources=sources, mode=mode)
|