From 6b3fa1cfaab5676b1bead6511780c52fe1ec8c92 Mon Sep 17 00:00:00 2001 From: idchlife Date: Thu, 26 Feb 2026 19:02:05 +0300 Subject: [PATCH] more sophisticated chat like retrieval for llamaindex --- services/rag/llamaindex/PLANNING.md | 24 ++ services/rag/llamaindex/chat_engine.py | 321 +++++++++++++++++++++++++ services/rag/llamaindex/server.py | 47 +++- 3 files changed, 390 insertions(+), 2 deletions(-) create mode 100644 services/rag/llamaindex/chat_engine.py diff --git a/services/rag/llamaindex/PLANNING.md b/services/rag/llamaindex/PLANNING.md index 052ff14..65cad9a 100644 --- a/services/rag/llamaindex/PLANNING.md +++ b/services/rag/llamaindex/PLANNING.md @@ -69,3 +69,27 @@ Chosen data folder: relatve ./../../../data - from the current folder - [x] Create file `server.py`, with web framework fastapi, for example - [x] Add POST endpoint "/api/test-query" which will use agent, and retrieve response for query, sent in JSON format, field "query" + +# Phase 12 (upgrade from simple retrieval to agent-like chat in LlamaIndex) + +- [x] Revisit Phase 5 assumption ("simple retrieval only") and explicitly allow agent/chat orchestration in LlamaIndex for QA over documents. +- [x] Create new module for chat orchestration (for example `agent.py` or `chat_engine.py`) that separates: + 1) retrieval of source nodes + 2) answer synthesis with explicit prompt + 3) response formatting with sources/metadata +- [x] Implement a LlamaIndex-based chat feature (agent-like behavior) using framework-native primitives (chat engine / agent workflow / tool-calling approach supported by installed version), so the model can iteratively query retrieval tools when needed. +- [x] Add a retrieval tool wrapper for document search that returns structured snippets (`filename`, `file_path`, `page_label/page`, `chunk_number`, content preview, score) instead of raw text only. +- [x] Add a grounded answer prompt/template for the LlamaIndex chat path with rules: + - answer only from retrieved context + - if information is missing, say so directly + - prefer exact dates/years and quote filenames/pages where possible + - avoid generic claims not supported by sources +- [x] Add response mode that returns both: + - final answer text + - list of retrieved sources (content snippet + metadata + score) +- [x] Add post-processing for retrieved nodes before synthesis: + - deduplicate near-identical chunks + - drop empty / near-empty chunks + - optionally filter low-information chunks (headers/footers) +- [x] Add optional metadata-aware retrieval improvements (years/events/keywords) parity with LangChain approach (folder near current folder), if feasible in the chosen LlamaIndex primitives. +- [x] Update `server.py` endpoint to use the new agent-like chat path (keep simple retrieval endpoint available as fallback or debug mode). diff --git a/services/rag/llamaindex/chat_engine.py b/services/rag/llamaindex/chat_engine.py new file mode 100644 index 0000000..162ad70 --- /dev/null +++ b/services/rag/llamaindex/chat_engine.py @@ -0,0 +1,321 @@ +""" +Agent-like chat orchestration for grounded QA over documents using LlamaIndex. + +This module separates: +1) retrieval of source nodes +2) answer synthesis with an explicit grounded prompt +3) response formatting with sources/metadata +""" + +from __future__ import annotations + +import asyncio +import json +import re +from dataclasses import dataclass +from typing import Any, Iterable, List + +from llama_index.core import PromptTemplate +from llama_index.core.agent import AgentWorkflow +from llama_index.core.retrievers import VectorIndexRetriever +from llama_index.core.schema import NodeWithScore +from llama_index.core.tools import FunctionTool +from loguru import logger + +from config import get_llm_model, setup_global_models +from vector_storage import get_vector_store_and_index + + +GROUNDED_SYNTHESIS_PROMPT = PromptTemplate( + """You are a grounded QA assistant for a document knowledge base. + +Rules: +- Answer ONLY from the provided context snippets. +- If the context is insufficient, say directly that the information is not available in the retrieved sources. +- Prefer exact dates/years and cite filenames/pages when possible. +- Avoid generic claims that are not supported by the snippets. +- If multiple sources disagree, mention the conflict briefly. + +User question: +{query} + +Optional draft from tool-using agent (may be incomplete): +{agent_draft} + +Context snippets (JSON): +{context_json} + +Return a concise answer with source mentions in plain text. +""" +) + + +@dataclass +class RetrievalSnippet: + content: str + score: float | None + metadata: dict[str, Any] + + def to_api_dict(self) -> dict[str, Any]: + metadata = self.metadata or {} + content_preview = self.content.strip().replace("\n", " ") + if len(content_preview) > 400: + content_preview = content_preview[:400] + "..." + return { + "content_snippet": content_preview, + "score": self.score, + "metadata": { + "filename": metadata.get("filename", "unknown"), + "file_path": metadata.get("file_path", "unknown"), + "page_label": metadata.get("page_label", metadata.get("page", "unknown")), + "chunk_number": metadata.get("chunk_number", "unknown"), + "total_chunks": metadata.get("total_chunks", "unknown"), + "file_type": metadata.get("file_type", "unknown"), + "processed_at": metadata.get("processed_at", "unknown"), + }, + } + + +def _normalize_text(value: Any) -> str: + if value is None: + return "" + if isinstance(value, bytes): + return value.decode("utf-8", errors="replace") + return str(value) + + +def _extract_years(query: str) -> list[int]: + years = [] + for match in re.findall(r"\b(19\d{2}|20\d{2}|21\d{2})\b", query): + try: + years.append(int(match)) + except ValueError: + continue + return sorted(set(years)) + + +def _extract_keywords(query: str) -> list[str]: + words = re.findall(r"[A-Za-zА-Яа-я0-9_-]{4,}", query.lower()) + stop = {"what", "when", "where", "which", "that", "this", "with", "from", "about"} + keywords = [w for w in words if w not in stop] + return list(dict.fromkeys(keywords))[:6] + + +def _node_text(node_with_score: NodeWithScore) -> str: + node = getattr(node_with_score, "node", None) + if node is None: + return _normalize_text(getattr(node_with_score, "text", "")) + return _normalize_text(getattr(node, "text", getattr(node_with_score, "text", ""))) + + +def _node_metadata(node_with_score: NodeWithScore) -> dict[str, Any]: + node = getattr(node_with_score, "node", None) + if node is None: + return dict(getattr(node_with_score, "metadata", {}) or {}) + return dict(getattr(node, "metadata", {}) or {}) + + +def _similarity_key(text: str) -> str: + text = re.sub(r"\s+", " ", text.strip().lower()) + return text[:250] + + +def _is_low_information_chunk(text: str) -> bool: + compact = " ".join(text.split()) + if len(compact) < 20: + return True + alpha_chars = sum(ch.isalpha() for ch in compact) + if alpha_chars < 8: + return True + # Repetitive headers/footers often contain too few unique tokens. + tokens = [t for t in re.split(r"\W+", compact.lower()) if t] + return len(tokens) >= 3 and len(set(tokens)) <= 2 + + +def post_process_nodes(nodes: list[NodeWithScore]) -> list[NodeWithScore]: + """ + Post-process retrieved nodes: + - drop empty / near-empty chunks + - optionally drop low-information chunks + - deduplicate near-identical chunks + """ + filtered: list[NodeWithScore] = [] + seen = set() + + for nws in nodes: + text = _node_text(nws) + if not text or not text.strip(): + continue + if _is_low_information_chunk(text): + continue + + meta = _node_metadata(nws) + dedup_key = ( + meta.get("file_path", ""), + meta.get("page_label", meta.get("page", "")), + meta.get("chunk_number", ""), + _similarity_key(text), + ) + if dedup_key in seen: + continue + seen.add(dedup_key) + filtered.append(nws) + + return filtered + + +def retrieve_source_nodes(query: str, top_k: int = 5, search_multiplier: int = 3) -> list[NodeWithScore]: + """ + Retrieve source nodes with light metadata-aware query expansion (years/keywords). + """ + setup_global_models() + _, index = get_vector_store_and_index() + retriever = VectorIndexRetriever(index=index, similarity_top_k=max(top_k * search_multiplier, top_k)) + + queries = [query] + years = _extract_years(query) + keywords = _extract_keywords(query) + queries.extend(str(year) for year in years) + queries.extend(keywords[:3]) + + collected: list[NodeWithScore] = [] + for q in list(dict.fromkeys([q for q in queries if q and q.strip()])): + try: + logger.info(f"Retrieving nodes for query variant: {q}") + collected.extend(retriever.retrieve(q)) + except Exception as e: + logger.warning(f"Retrieval variant failed for '{q}': {e}") + + processed = post_process_nodes(collected) + processed.sort(key=lambda n: (getattr(n, "score", None) is None, -(getattr(n, "score", 0.0) or 0.0))) + return processed[:top_k] + + +def build_structured_snippets(nodes: list[NodeWithScore]) -> list[dict[str, Any]]: + """Return structured snippets for tools/API responses.""" + snippets: list[dict[str, Any]] = [] + for nws in nodes: + snippet = RetrievalSnippet( + content=_node_text(nws), + score=getattr(nws, "score", None), + metadata=_node_metadata(nws), + ) + snippets.append(snippet.to_api_dict()) + return snippets + + +def retrieval_tool_search(query: str, top_k: int = 5) -> str: + """ + Tool wrapper for document retrieval returning structured JSON snippets. + """ + nodes = retrieve_source_nodes(query=query, top_k=top_k) + snippets = build_structured_snippets(nodes) + payload = { + "query": query, + "count": len(snippets), + "snippets": snippets, + } + return json.dumps(payload, ensure_ascii=False) + + +def synthesize_answer(query: str, sources: list[dict[str, Any]], agent_draft: str = "") -> str: + """ + Answer synthesis from retrieved sources using an explicit grounded prompt. + """ + llm = get_llm_model() + context_json = json.dumps(sources, ensure_ascii=False, indent=2) + prompt = GROUNDED_SYNTHESIS_PROMPT.format( + query=query, + agent_draft=agent_draft or "(none)", + context_json=context_json, + ) + logger.info("Synthesizing grounded answer from retrieved sources") + response = llm.complete(prompt) + return _normalize_text(getattr(response, "text", response)) + + +def format_chat_response(query: str, final_answer: str, sources: list[dict[str, Any]], mode: str) -> dict[str, Any]: + """ + Response formatting with answer + structured sources. + """ + return { + "query": query, + "answer": final_answer, + "sources": sources, + "mode": mode, + } + + +def _extract_agent_result_text(result: Any) -> str: + if result is None: + return "" + if hasattr(result, "response"): + return _normalize_text(getattr(result, "response")) + return _normalize_text(result) + + +async def _run_agent_workflow_async(query: str, top_k: int) -> str: + """ + Run LlamaIndex AgentWorkflow with a retrieval tool. Returns agent draft answer text. + """ + setup_global_models() + llm = get_llm_model() + tool = FunctionTool.from_defaults( + fn=retrieval_tool_search, + name="document_search", + description=( + "Search documents and return structured snippets as JSON with fields: " + "filename, file_path, page_label/page, chunk_number, content_snippet, score. " + "Use this before answering factual questions about documents." + ), + ) + + system_prompt = ( + "You are a QA agent over a document store. Use the document_search tool when factual " + "information may come from documents. If tool output is insufficient, say so." + ) + agent = AgentWorkflow.from_tools_or_functions( + [tool], + llm=llm, + system_prompt=system_prompt, + verbose=False, + ) + handler = agent.run(user_msg=query, max_iterations=4) + result = await handler + return _extract_agent_result_text(result) + + +def run_agent_workflow(query: str, top_k: int = 5) -> str: + """ + Synchronous wrapper around the async LlamaIndex agent workflow. + """ + try: + return asyncio.run(_run_agent_workflow_async(query=query, top_k=top_k)) + except RuntimeError: + # Fallback if already in an event loop; skip agent workflow in that case. + logger.warning("Async event loop already running; skipping agent workflow and using direct retrieval+synthesis") + return "" + except Exception as e: + logger.warning(f"Agent workflow failed, will fallback to direct retrieval+synthesis: {e}") + return "" + + +def chat_with_documents(query: str, top_k: int = 5, use_agent: bool = True) -> dict[str, Any]: + """ + Full chat orchestration entrypoint: + - optionally run agent workflow (tool-calling) + - retrieve + post-process sources + - synthesize grounded answer + - format response + """ + logger.info(f"Starting chat orchestration for query: {query[:80]}") + agent_draft = "" + mode = "retrieval+synthesis" + if use_agent: + agent_draft = run_agent_workflow(query=query, top_k=top_k) + mode = "agent+retrieval+synthesis" if agent_draft else "retrieval+synthesis" + + nodes = retrieve_source_nodes(query=query, top_k=top_k) + sources = build_structured_snippets(nodes) + final_answer = synthesize_answer(query=query, sources=sources, agent_draft=agent_draft) + return format_chat_response(query=query, final_answer=final_answer, sources=sources, mode=mode) diff --git a/services/rag/llamaindex/server.py b/services/rag/llamaindex/server.py index 51530a0..d054c77 100644 --- a/services/rag/llamaindex/server.py +++ b/services/rag/llamaindex/server.py @@ -12,6 +12,7 @@ from fastapi.middleware.cors import CORSMiddleware from loguru import logger from pydantic import BaseModel, Field +from chat_engine import chat_with_documents from retrieval import initialize_retriever load_dotenv() @@ -55,6 +56,10 @@ app.add_middleware( class TestQueryRequest(BaseModel): query: str = Field(..., min_length=1, description="User query text") top_k: int = Field(5, ge=1, le=20, description="Number of retrieved chunks") + mode: str = Field( + "agent", + description="agent (Phase 12 default) or retrieval (fallback/debug)", + ) class SourceItem(BaseModel): @@ -67,6 +72,7 @@ class TestQueryResponse(BaseModel): query: str response: str sources: list[SourceItem] + mode: str | None = None error: bool success: bool @@ -85,9 +91,43 @@ def test_query(payload: TestQueryRequest) -> TestQueryResponse: if not query: raise HTTPException(status_code=400, detail="Field 'query' must not be empty") - logger.info(f"Received /api/test-query request (top_k={payload.top_k})") + logger.info( + f"Received /api/test-query request (top_k={payload.top_k}, mode={payload.mode})" + ) try: + if payload.mode.lower() in {"agent", "chat"}: + chat_result = chat_with_documents( + query=query, + top_k=payload.top_k, + use_agent=True, + ) + sources = [ + SourceItem( + content=str(src.get("content_snippet", "")), + score=src.get("score"), + metadata=src.get("metadata", {}) or {}, + ) + for src in chat_result.get("sources", []) + ] + logger.info( + f"/api/test-query completed via agent-like chat path (sources={len(sources)})" + ) + return TestQueryResponse( + query=query, + response=str(chat_result.get("answer", "")), + sources=sources, + mode=str(chat_result.get("mode", "agent+retrieval+synthesis")), + error=False, + success=True, + ) + + if payload.mode.lower() not in {"retrieval", "debug"}: + raise HTTPException( + status_code=400, + detail="Unsupported mode. Use 'agent' (default) or 'retrieval'.", + ) + query_engine = initialize_retriever(similarity_top_k=payload.top_k) result = query_engine.query(query) @@ -103,11 +143,14 @@ def test_query(payload: TestQueryRequest) -> TestQueryResponse: ) response_text = str(result) - logger.info(f"/api/test-query completed successfully (sources={len(sources)})") + logger.info( + f"/api/test-query completed via retrieval fallback (sources={len(sources)})" + ) return TestQueryResponse( query=query, response=response_text, sources=sources, + mode="retrieval", error=False, success=True, )