""" Agent-like chat orchestration for grounded QA over documents using LlamaIndex. This module separates: 1) retrieval of source nodes 2) answer synthesis with an explicit grounded prompt 3) response formatting with sources/metadata """ from __future__ import annotations import asyncio import json import re from dataclasses import dataclass from typing import Any, Iterable, List from llama_index.core import PromptTemplate from llama_index.core.agent import AgentWorkflow from llama_index.core.base.llms.types import ChatMessage, MessageRole from llama_index.core.retrievers import VectorIndexRetriever from llama_index.core.schema import NodeWithScore from llama_index.core.tools import FunctionTool from loguru import logger from config import get_llm_model, setup_global_models from vector_storage import get_vector_store_and_index GROUNDED_SYNTHESIS_PROMPT = PromptTemplate( """You are a grounded QA assistant for a document knowledge base. Rules: - Answer ONLY from the provided context snippets. - If the context is insufficient, say directly that the information is not available in the retrieved sources. - Prefer exact dates/years and cite filenames/pages when possible. - Avoid generic claims that are not supported by the snippets. - If multiple sources disagree, mention the conflict briefly. User question: {query} Optional draft from tool-using agent (may be incomplete): {agent_draft} Context snippets (JSON): {context_json} Return a concise answer with source mentions in plain text. """ ) @dataclass class RetrievalSnippet: content: str score: float | None metadata: dict[str, Any] def to_api_dict(self) -> dict[str, Any]: metadata = self.metadata or {} content_preview = self.content.strip().replace("\n", " ") if len(content_preview) > 400: content_preview = content_preview[:400] + "..." return { "content_snippet": content_preview, "score": self.score, "metadata": { "filename": metadata.get("filename", "unknown"), "file_path": metadata.get("file_path", "unknown"), "page_label": metadata.get("page_label", metadata.get("page", "unknown")), "chunk_number": metadata.get("chunk_number", "unknown"), "total_chunks": metadata.get("total_chunks", "unknown"), "file_type": metadata.get("file_type", "unknown"), "processed_at": metadata.get("processed_at", "unknown"), }, } def _normalize_text(value: Any) -> str: if value is None: return "" if isinstance(value, bytes): return value.decode("utf-8", errors="replace") return str(value) def _extract_years(query: str) -> list[int]: years = [] for match in re.findall(r"\b(19\d{2}|20\d{2}|21\d{2})\b", query): try: years.append(int(match)) except ValueError: continue return sorted(set(years)) def _extract_keywords(query: str) -> list[str]: words = re.findall(r"[A-Za-zА-Яа-я0-9_-]{4,}", query.lower()) stop = {"what", "when", "where", "which", "that", "this", "with", "from", "about"} keywords = [w for w in words if w not in stop] return list(dict.fromkeys(keywords))[:6] def _node_text(node_with_score: NodeWithScore) -> str: node = getattr(node_with_score, "node", None) if node is None: return _normalize_text(getattr(node_with_score, "text", "")) return _normalize_text(getattr(node, "text", getattr(node_with_score, "text", ""))) def _node_metadata(node_with_score: NodeWithScore) -> dict[str, Any]: node = getattr(node_with_score, "node", None) if node is None: return dict(getattr(node_with_score, "metadata", {}) or {}) return dict(getattr(node, "metadata", {}) or {}) def _similarity_key(text: str) -> str: text = re.sub(r"\s+", " ", text.strip().lower()) return text[:250] def _is_low_information_chunk(text: str) -> bool: compact = " ".join(text.split()) if len(compact) < 20: return True alpha_chars = sum(ch.isalpha() for ch in compact) if alpha_chars < 8: return True # Repetitive headers/footers often contain too few unique tokens. tokens = [t for t in re.split(r"\W+", compact.lower()) if t] return len(tokens) >= 3 and len(set(tokens)) <= 2 def post_process_nodes(nodes: list[NodeWithScore]) -> list[NodeWithScore]: """ Post-process retrieved nodes: - drop empty / near-empty chunks - optionally drop low-information chunks - deduplicate near-identical chunks """ filtered: list[NodeWithScore] = [] seen = set() for nws in nodes: text = _node_text(nws) if not text or not text.strip(): continue if _is_low_information_chunk(text): continue meta = _node_metadata(nws) dedup_key = ( meta.get("file_path", ""), meta.get("page_label", meta.get("page", "")), meta.get("chunk_number", ""), _similarity_key(text), ) if dedup_key in seen: continue seen.add(dedup_key) filtered.append(nws) return filtered def retrieve_source_nodes(query: str, top_k: int = 5, search_multiplier: int = 3) -> list[NodeWithScore]: """ Retrieve source nodes with light metadata-aware query expansion (years/keywords). """ setup_global_models() _, index = get_vector_store_and_index() retriever = VectorIndexRetriever(index=index, similarity_top_k=max(top_k * search_multiplier, top_k)) queries = [query] years = _extract_years(query) keywords = _extract_keywords(query) queries.extend(str(year) for year in years) queries.extend(keywords[:3]) collected: list[NodeWithScore] = [] for q in list(dict.fromkeys([q for q in queries if q and q.strip()])): try: logger.info(f"Retrieving nodes for query variant: {q}") collected.extend(retriever.retrieve(q)) except Exception as e: logger.warning(f"Retrieval variant failed for '{q}': {e}") processed = post_process_nodes(collected) processed.sort(key=lambda n: (getattr(n, "score", None) is None, -(getattr(n, "score", 0.0) or 0.0))) return processed[:top_k] def build_structured_snippets(nodes: list[NodeWithScore]) -> list[dict[str, Any]]: """Return structured snippets for tools/API responses.""" snippets: list[dict[str, Any]] = [] for nws in nodes: snippet = RetrievalSnippet( content=_node_text(nws), score=getattr(nws, "score", None), metadata=_node_metadata(nws), ) snippets.append(snippet.to_api_dict()) return snippets def retrieval_tool_search(query: str, top_k: int = 5) -> str: """ Tool wrapper for document retrieval returning structured JSON snippets. """ nodes = retrieve_source_nodes(query=query, top_k=top_k) snippets = build_structured_snippets(nodes) payload = { "query": query, "count": len(snippets), "snippets": snippets, } return json.dumps(payload, ensure_ascii=False) def synthesize_answer(query: str, sources: list[dict[str, Any]], agent_draft: str = "") -> str: """ Answer synthesis from retrieved sources using an explicit grounded prompt. """ llm = get_llm_model() context_json = json.dumps(sources, ensure_ascii=False, indent=2) prompt = GROUNDED_SYNTHESIS_PROMPT.format( query=query, agent_draft=agent_draft or "(none)", context_json=context_json, ) logger.info("Synthesizing grounded answer from retrieved sources") # Prefer chat API for chat-capable models; fallback to completion if unavailable. try: if hasattr(llm, "chat"): chat_response = llm.chat( [ ChatMessage(role=MessageRole.SYSTEM, content="You answer with grounded citations only."), ChatMessage(role=MessageRole.USER, content=prompt), ] ) return _normalize_text(getattr(chat_response, "message", chat_response).content) except Exception as e: logger.warning(f"LLM chat synthesis failed, falling back to completion: {e}") response = llm.complete(prompt) return _normalize_text(getattr(response, "text", response)) def format_chat_response(query: str, final_answer: str, sources: list[dict[str, Any]], mode: str) -> dict[str, Any]: """ Response formatting with answer + structured sources. """ return { "query": query, "answer": final_answer, "sources": sources, "mode": mode, } def _extract_agent_result_text(result: Any) -> str: if result is None: return "" if hasattr(result, "response"): return _normalize_text(getattr(result, "response")) return _normalize_text(result) async def _run_agent_workflow_async(query: str, top_k: int) -> str: """ Run LlamaIndex AgentWorkflow with a retrieval tool. Returns agent draft answer text. """ setup_global_models() llm = get_llm_model() tool = FunctionTool.from_defaults( fn=retrieval_tool_search, name="document_search", description=( "Search documents and return structured snippets as JSON with fields: " "filename, file_path, page_label/page, chunk_number, content_snippet, score. " "Use this before answering factual questions about documents." ), ) system_prompt = ( "You are a QA agent over a document store. Use the document_search tool when factual " "information may come from documents. If tool output is insufficient, say so." ) agent = AgentWorkflow.from_tools_or_functions( [tool], llm=llm, system_prompt=system_prompt, verbose=False, ) handler = agent.run(user_msg=query, max_iterations=4) result = await handler return _extract_agent_result_text(result) def run_agent_workflow(query: str, top_k: int = 5) -> str: """ Synchronous wrapper around the async LlamaIndex agent workflow. """ try: return asyncio.run(_run_agent_workflow_async(query=query, top_k=top_k)) except RuntimeError: # Fallback if already in an event loop; skip agent workflow in that case. logger.warning("Async event loop already running; skipping agent workflow and using direct retrieval+synthesis") return "" except Exception as e: logger.warning(f"Agent workflow failed, will fallback to direct retrieval+synthesis: {e}") return "" def chat_with_documents(query: str, top_k: int = 5, use_agent: bool = True) -> dict[str, Any]: """ Full chat orchestration entrypoint: - optionally run agent workflow (tool-calling) - retrieve + post-process sources - synthesize grounded answer - format response """ logger.info(f"Starting chat orchestration for query: {query[:80]}") agent_draft = "" mode = "retrieval+synthesis" if use_agent: agent_draft = run_agent_workflow(query=query, top_k=top_k) mode = "agent+retrieval+synthesis" if agent_draft else "retrieval+synthesis" nodes = retrieve_source_nodes(query=query, top_k=top_k) sources = build_structured_snippets(nodes) final_answer = synthesize_answer(query=query, sources=sources, agent_draft=agent_draft) return format_chat_response(query=query, final_answer=final_answer, sources=sources, mode=mode)