# Source file: rag-solution/services/rag/llamaindex/chat_engine.py

"""
Agent-like chat orchestration for grounded QA over documents using LlamaIndex.
This module separates:
1) retrieval of source nodes
2) answer synthesis with an explicit grounded prompt
3) response formatting with sources/metadata
"""
from __future__ import annotations
import asyncio
import json
import re
from dataclasses import dataclass
from typing import Any, Iterable, List
from llama_index.core import PromptTemplate
from llama_index.core.agent import AgentWorkflow
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import NodeWithScore
from llama_index.core.tools import FunctionTool
from loguru import logger
from config import get_llm_model, setup_global_models
from vector_storage import get_vector_store_and_index
# Prompt for the synthesis step: constrains the LLM to answer strictly from the
# retrieved snippets ({context_json}), cite sources, and optionally refine a
# draft produced by the tool-using agent ({agent_draft}). Filled via
# PromptTemplate.format in synthesize_answer().
GROUNDED_SYNTHESIS_PROMPT = PromptTemplate(
"""You are a grounded QA assistant for a document knowledge base.
Rules:
- Answer ONLY from the provided context snippets.
- If the context is insufficient, say directly that the information is not available in the retrieved sources.
- Prefer exact dates/years and cite filenames/pages when possible.
- Avoid generic claims that are not supported by the snippets.
- If multiple sources disagree, mention the conflict briefly.
User question:
{query}
Optional draft from tool-using agent (may be incomplete):
{agent_draft}
Context snippets (JSON):
{context_json}
Return a concise answer with source mentions in plain text.
"""
)
@dataclass
class RetrievalSnippet:
    """One retrieved document chunk: raw text, similarity score, and source metadata."""

    content: str
    score: float | None
    metadata: dict[str, Any]

    def to_api_dict(self) -> dict[str, Any]:
        """Shape the snippet for API responses, truncating the preview to 400 chars."""
        meta = self.metadata or {}
        preview = self.content.strip().replace("\n", " ")
        if len(preview) > 400:
            preview = preview[:400] + "..."
        # Keep metadata key order stable for consumers that serialize this dict.
        ordered_keys = (
            "filename",
            "file_path",
            "page_label",
            "chunk_number",
            "total_chunks",
            "file_type",
            "processed_at",
        )
        meta_out: dict[str, Any] = {}
        for key in ordered_keys:
            # "page_label" falls back to the legacy "page" field before "unknown".
            fallback = meta.get("page", "unknown") if key == "page_label" else "unknown"
            meta_out[key] = meta.get(key, fallback)
        return {"content_snippet": preview, "score": self.score, "metadata": meta_out}
def _normalize_text(value: Any) -> str:
if value is None:
return ""
if isinstance(value, bytes):
return value.decode("utf-8", errors="replace")
return str(value)
def _extract_years(query: str) -> list[int]:
years = []
for match in re.findall(r"\b(19\d{2}|20\d{2}|21\d{2})\b", query):
try:
years.append(int(match))
except ValueError:
continue
return sorted(set(years))
def _extract_keywords(query: str) -> list[str]:
words = re.findall(r"[A-Za-zА-Яа-я0-9_-]{4,}", query.lower())
stop = {"what", "when", "where", "which", "that", "this", "with", "from", "about"}
keywords = [w for w in words if w not in stop]
return list(dict.fromkeys(keywords))[:6]
def _node_text(node_with_score: NodeWithScore) -> str:
    """Best-effort extraction of the text payload from a NodeWithScore wrapper."""
    wrapper_text = getattr(node_with_score, "text", "")
    inner = getattr(node_with_score, "node", None)
    if inner is None:
        return _normalize_text(wrapper_text)
    # Prefer the inner node's text; fall back to the wrapper's.
    return _normalize_text(getattr(inner, "text", wrapper_text))
def _node_metadata(node_with_score: NodeWithScore) -> dict[str, Any]:
node = getattr(node_with_score, "node", None)
if node is None:
return dict(getattr(node_with_score, "metadata", {}) or {})
return dict(getattr(node, "metadata", {}) or {})
def _similarity_key(text: str) -> str:
text = re.sub(r"\s+", " ", text.strip().lower())
return text[:250]
def _is_low_information_chunk(text: str) -> bool:
compact = " ".join(text.split())
if len(compact) < 20:
return True
alpha_chars = sum(ch.isalpha() for ch in compact)
if alpha_chars < 8:
return True
# Repetitive headers/footers often contain too few unique tokens.
tokens = [t for t in re.split(r"\W+", compact.lower()) if t]
return len(tokens) >= 3 and len(set(tokens)) <= 2
def post_process_nodes(nodes: list[NodeWithScore]) -> list[NodeWithScore]:
    """
    Clean up retrieved nodes before ranking:
    - remove empty / whitespace-only chunks
    - remove low-information chunks (headers, footers, separators)
    - drop near-duplicates (same file/page/chunk and similar text prefix)
    """
    kept: list[NodeWithScore] = []
    seen_keys: set = set()
    for candidate in nodes:
        text = _node_text(candidate)
        if not text.strip():
            continue
        if _is_low_information_chunk(text):
            continue
        meta = _node_metadata(candidate)
        key = (
            meta.get("file_path", ""),
            meta.get("page_label", meta.get("page", "")),
            meta.get("chunk_number", ""),
            _similarity_key(text),
        )
        if key not in seen_keys:
            seen_keys.add(key)
            kept.append(candidate)
    return kept
def retrieve_source_nodes(query: str, top_k: int = 5, search_multiplier: int = 3) -> list[NodeWithScore]:
    """
    Retrieve, clean, and rank source nodes for *query*.

    The query is expanded with detected years and up to three keywords; results
    are over-fetched by *search_multiplier*, post-processed, sorted by score
    (scoreless nodes last), and cut to *top_k*.
    """
    setup_global_models()
    _, index = get_vector_store_and_index()
    retriever = VectorIndexRetriever(index=index, similarity_top_k=max(top_k * search_multiplier, top_k))
    variants = [query]
    variants.extend(str(year) for year in _extract_years(query))
    variants.extend(_extract_keywords(query)[:3])
    # Ordered de-duplication of non-blank variants.
    unique_variants = dict.fromkeys(v for v in variants if v and v.strip())
    gathered: list[NodeWithScore] = []
    for variant in unique_variants:
        try:
            logger.info(f"Retrieving nodes for query variant: {variant}")
            gathered.extend(retriever.retrieve(variant))
        except Exception as exc:
            logger.warning(f"Retrieval variant failed for '{variant}': {exc}")
    ranked = post_process_nodes(gathered)
    # Scored nodes first (descending score), None-scored nodes at the end.
    ranked.sort(key=lambda n: (getattr(n, "score", None) is None, -(getattr(n, "score", 0.0) or 0.0)))
    return ranked[:top_k]
def build_structured_snippets(nodes: list[NodeWithScore]) -> list[dict[str, Any]]:
    """Convert retrieved nodes into API-ready snippet dictionaries."""
    return [
        RetrievalSnippet(
            content=_node_text(nws),
            score=getattr(nws, "score", None),
            metadata=_node_metadata(nws),
        ).to_api_dict()
        for nws in nodes
    ]
def retrieval_tool_search(query: str, top_k: int = 5) -> str:
    """Tool entry point: retrieve documents and return structured snippets as a JSON string."""
    snippets = build_structured_snippets(retrieve_source_nodes(query=query, top_k=top_k))
    return json.dumps(
        {"query": query, "count": len(snippets), "snippets": snippets},
        ensure_ascii=False,
    )
def synthesize_answer(query: str, sources: list[dict[str, Any]], agent_draft: str = "") -> str:
    """Produce a grounded answer from *sources* via the explicit synthesis prompt."""
    prompt = GROUNDED_SYNTHESIS_PROMPT.format(
        query=query,
        agent_draft=agent_draft if agent_draft else "(none)",
        context_json=json.dumps(sources, ensure_ascii=False, indent=2),
    )
    logger.info("Synthesizing grounded answer from retrieved sources")
    completion = get_llm_model().complete(prompt)
    # LLM responses expose .text; fall back to stringifying the raw object.
    return _normalize_text(getattr(completion, "text", completion))
def format_chat_response(query: str, final_answer: str, sources: list[dict[str, Any]], mode: str) -> dict[str, Any]:
    """Assemble the final chat payload: question, answer, structured sources, and mode tag."""
    payload: dict[str, Any] = {"query": query}
    payload["answer"] = final_answer
    payload["sources"] = sources
    payload["mode"] = mode
    return payload
def _extract_agent_result_text(result: Any) -> str:
    """Pull the answer text out of an agent result object (or stringify it)."""
    if result is None:
        return ""
    # AgentWorkflow results carry the answer on .response; otherwise use the object itself.
    payload = getattr(result, "response", result)
    return _normalize_text(payload)
async def _run_agent_workflow_async(query: str, top_k: int) -> str:
    """
    Run a LlamaIndex AgentWorkflow equipped with a document-retrieval tool.

    Args:
        query: The user question for the agent to answer.
        top_k: Number of snippets the retrieval tool should return per search.

    Returns:
        The agent's draft answer text ("" when no usable result was produced).
    """
    setup_global_models()
    llm = get_llm_model()

    # Bind top_k into the tool. Previously the parameter was accepted but never
    # used, so the agent always searched with retrieval_tool_search's default of 5.
    def document_search(query: str) -> str:
        """Search documents and return structured snippets as JSON."""
        return retrieval_tool_search(query=query, top_k=top_k)

    tool = FunctionTool.from_defaults(
        fn=document_search,
        name="document_search",
        description=(
            "Search documents and return structured snippets as JSON with fields: "
            "filename, file_path, page_label/page, chunk_number, content_snippet, score. "
            "Use this before answering factual questions about documents."
        ),
    )
    system_prompt = (
        "You are a QA agent over a document store. Use the document_search tool when factual "
        "information may come from documents. If tool output is insufficient, say so."
    )
    agent = AgentWorkflow.from_tools_or_functions(
        [tool],
        llm=llm,
        system_prompt=system_prompt,
        verbose=False,
    )
    handler = agent.run(user_msg=query, max_iterations=4)
    result = await handler
    return _extract_agent_result_text(result)
def run_agent_workflow(query: str, top_k: int = 5) -> str:
    """
    Blocking wrapper around the async agent workflow.

    Returns "" when the workflow cannot run (nested event loop) or fails for
    any other reason, so callers can fall back to retrieval + synthesis.
    """
    try:
        return asyncio.run(_run_agent_workflow_async(query=query, top_k=top_k))
    except RuntimeError:
        # asyncio.run refuses to start inside an already-running loop; degrade gracefully.
        logger.warning("Async event loop already running; skipping agent workflow and using direct retrieval+synthesis")
    except Exception as e:
        logger.warning(f"Agent workflow failed, will fallback to direct retrieval+synthesis: {e}")
    return ""
def chat_with_documents(query: str, top_k: int = 5, use_agent: bool = True) -> dict[str, Any]:
    """
    End-to-end chat orchestration:
    1. optionally run the tool-calling agent workflow for a draft answer
    2. retrieve and post-process source nodes
    3. synthesize a grounded answer from the sources (and draft, if any)
    4. format the API response with answer, sources, and mode tag
    """
    logger.info(f"Starting chat orchestration for query: {query[:80]}")
    draft = ""
    mode = "retrieval+synthesis"
    if use_agent:
        draft = run_agent_workflow(query=query, top_k=top_k)
        if draft:
            mode = "agent+retrieval+synthesis"
    source_nodes = retrieve_source_nodes(query=query, top_k=top_k)
    sources = build_structured_snippets(source_nodes)
    answer = synthesize_answer(query=query, sources=sources, agent_draft=draft)
    return format_chat_response(query=query, final_answer=answer, sources=sources, mode=mode)