more sophisticated chat like retrieval for llamaindex

This commit is contained in:
2026-02-26 19:02:05 +03:00
parent 468d5fb572
commit 6b3fa1cfaa
3 changed files with 390 additions and 2 deletions

View File

@@ -69,3 +69,27 @@ Chosen data folder: relative ./../../../data - from the current folder
- [x] Create file `server.py`, with web framework fastapi, for example - [x] Create file `server.py`, with web framework fastapi, for example
- [x] Add POST endpoint "/api/test-query" which will use agent, and retrieve response for query, sent in JSON format, field "query" - [x] Add POST endpoint "/api/test-query" which will use agent, and retrieve response for query, sent in JSON format, field "query"
# Phase 12 (upgrade from simple retrieval to agent-like chat in LlamaIndex)
- [x] Revisit Phase 5 assumption ("simple retrieval only") and explicitly allow agent/chat orchestration in LlamaIndex for QA over documents.
- [x] Create new module for chat orchestration (for example `agent.py` or `chat_engine.py`) that separates:
1) retrieval of source nodes
2) answer synthesis with explicit prompt
3) response formatting with sources/metadata
- [x] Implement a LlamaIndex-based chat feature (agent-like behavior) using framework-native primitives (chat engine / agent workflow / tool-calling approach supported by installed version), so the model can iteratively query retrieval tools when needed.
- [x] Add a retrieval tool wrapper for document search that returns structured snippets (`filename`, `file_path`, `page_label/page`, `chunk_number`, content preview, score) instead of raw text only.
- [x] Add a grounded answer prompt/template for the LlamaIndex chat path with rules:
- answer only from retrieved context
- if information is missing, say so directly
- prefer exact dates/years and quote filenames/pages where possible
- avoid generic claims not supported by sources
- [x] Add response mode that returns both:
- final answer text
- list of retrieved sources (content snippet + metadata + score)
- [x] Add post-processing for retrieved nodes before synthesis:
- deduplicate near-identical chunks
- drop empty / near-empty chunks
- optionally filter low-information chunks (headers/footers)
- [x] Add optional metadata-aware retrieval improvements (years/events/keywords) parity with LangChain approach (folder near current folder), if feasible in the chosen LlamaIndex primitives.
- [x] Update `server.py` endpoint to use the new agent-like chat path (keep simple retrieval endpoint available as fallback or debug mode).

View File

@@ -0,0 +1,321 @@
"""
Agent-like chat orchestration for grounded QA over documents using LlamaIndex.
This module separates:
1) retrieval of source nodes
2) answer synthesis with an explicit grounded prompt
3) response formatting with sources/metadata
"""
from __future__ import annotations
import asyncio
import json
import re
from dataclasses import dataclass
from typing import Any, Iterable, List
from llama_index.core import PromptTemplate
from llama_index.core.agent import AgentWorkflow
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import NodeWithScore
from llama_index.core.tools import FunctionTool
from loguru import logger
from config import get_llm_model, setup_global_models
from vector_storage import get_vector_store_and_index
# Grounded-answer synthesis prompt used by synthesize_answer().
# Template variables: {query} (user question), {agent_draft} (optional draft
# from the tool-using agent; callers pass "(none)" when absent), and
# {context_json} (retrieved snippets serialized as JSON).
GROUNDED_SYNTHESIS_PROMPT = PromptTemplate(
"""You are a grounded QA assistant for a document knowledge base.
Rules:
- Answer ONLY from the provided context snippets.
- If the context is insufficient, say directly that the information is not available in the retrieved sources.
- Prefer exact dates/years and cite filenames/pages when possible.
- Avoid generic claims that are not supported by the snippets.
- If multiple sources disagree, mention the conflict briefly.
User question:
{query}
Optional draft from tool-using agent (may be incomplete):
{agent_draft}
Context snippets (JSON):
{context_json}
Return a concise answer with source mentions in plain text.
"""
)
@dataclass
class RetrievalSnippet:
    """One retrieved chunk: raw content, similarity score, and node metadata."""

    content: str
    score: float | None
    metadata: dict[str, Any]

    def to_api_dict(self) -> dict[str, Any]:
        """Serialize into the API response shape.

        The content is whitespace-trimmed, newlines are flattened to spaces,
        and the preview is capped at 400 characters (with a "..." suffix).
        Missing metadata fields default to "unknown".
        """
        meta = self.metadata or {}
        preview = self.content.strip().replace("\n", " ")
        if len(preview) > 400:
            preview = f"{preview[:400]}..."
        # page_label preferred; legacy "page" key used as fallback
        page = meta.get("page_label", meta.get("page", "unknown"))
        return {
            "content_snippet": preview,
            "score": self.score,
            "metadata": {
                "filename": meta.get("filename", "unknown"),
                "file_path": meta.get("file_path", "unknown"),
                "page_label": page,
                "chunk_number": meta.get("chunk_number", "unknown"),
                "total_chunks": meta.get("total_chunks", "unknown"),
                "file_type": meta.get("file_type", "unknown"),
                "processed_at": meta.get("processed_at", "unknown"),
            },
        }
def _normalize_text(value: Any) -> str:
if value is None:
return ""
if isinstance(value, bytes):
return value.decode("utf-8", errors="replace")
return str(value)
def _extract_years(query: str) -> list[int]:
years = []
for match in re.findall(r"\b(19\d{2}|20\d{2}|21\d{2})\b", query):
try:
years.append(int(match))
except ValueError:
continue
return sorted(set(years))
def _extract_keywords(query: str) -> list[str]:
words = re.findall(r"[A-Za-zА-Яа-я0-9_-]{4,}", query.lower())
stop = {"what", "when", "where", "which", "that", "this", "with", "from", "about"}
keywords = [w for w in words if w not in stop]
return list(dict.fromkeys(keywords))[:6]
def _node_text(node_with_score: NodeWithScore) -> str:
    """Extract the text payload from a NodeWithScore wrapper.

    Prefers the wrapped node's ``text``; falls back to the wrapper's own
    ``text`` attribute. The result is normalized to str.
    """
    inner = getattr(node_with_score, "node", None)
    if inner is None:
        return _normalize_text(getattr(node_with_score, "text", ""))
    fallback = getattr(node_with_score, "text", "")
    return _normalize_text(getattr(inner, "text", fallback))
def _node_metadata(node_with_score: NodeWithScore) -> dict[str, Any]:
node = getattr(node_with_score, "node", None)
if node is None:
return dict(getattr(node_with_score, "metadata", {}) or {})
return dict(getattr(node, "metadata", {}) or {})
def _similarity_key(text: str) -> str:
text = re.sub(r"\s+", " ", text.strip().lower())
return text[:250]
def _is_low_information_chunk(text: str) -> bool:
compact = " ".join(text.split())
if len(compact) < 20:
return True
alpha_chars = sum(ch.isalpha() for ch in compact)
if alpha_chars < 8:
return True
# Repetitive headers/footers often contain too few unique tokens.
tokens = [t for t in re.split(r"\W+", compact.lower()) if t]
return len(tokens) >= 3 and len(set(tokens)) <= 2
def post_process_nodes(nodes: list[NodeWithScore]) -> list[NodeWithScore]:
    """Filter retrieved nodes before answer synthesis.

    Drops empty and low-information chunks, then removes near-duplicates
    keyed on (file_path, page, chunk_number, normalized text prefix), while
    preserving the incoming order of the survivors.
    """
    kept: list[NodeWithScore] = []
    seen_keys: set[tuple] = set()
    for candidate in nodes:
        text = _node_text(candidate)
        if not text or not text.strip():
            continue
        if _is_low_information_chunk(text):
            continue
        meta = _node_metadata(candidate)
        dedup_key = (
            meta.get("file_path", ""),
            meta.get("page_label", meta.get("page", "")),
            meta.get("chunk_number", ""),
            _similarity_key(text),
        )
        if dedup_key not in seen_keys:
            seen_keys.add(dedup_key)
            kept.append(candidate)
    return kept
def retrieve_source_nodes(query: str, top_k: int = 5, search_multiplier: int = 3) -> list[NodeWithScore]:
    """Retrieve and rank source nodes for *query*.

    Expands the query with detected years and up to three keywords, retrieves
    candidates for every unique non-empty variant (over-fetching by
    *search_multiplier*), post-processes them (dedup/filtering), then returns
    the *top_k* best — scored nodes first, descending by score.
    """
    setup_global_models()
    _, index = get_vector_store_and_index()
    retriever = VectorIndexRetriever(
        index=index,
        similarity_top_k=max(top_k * search_multiplier, top_k),
    )

    variants = [query]
    variants.extend(str(year) for year in _extract_years(query))
    variants.extend(_extract_keywords(query)[:3])
    unique_variants = list(dict.fromkeys(v for v in variants if v and v.strip()))

    candidates: list[NodeWithScore] = []
    for variant in unique_variants:
        try:
            logger.info(f"Retrieving nodes for query variant: {variant}")
            candidates.extend(retriever.retrieve(variant))
        except Exception as exc:
            # Best-effort per variant; a failed expansion must not kill retrieval.
            logger.warning(f"Retrieval variant failed for '{variant}': {exc}")

    ranked = post_process_nodes(candidates)
    # Unscored nodes sort last; scored nodes sort by descending score.
    ranked.sort(
        key=lambda n: (getattr(n, "score", None) is None, -(getattr(n, "score", 0.0) or 0.0))
    )
    return ranked[:top_k]
def build_structured_snippets(nodes: list[NodeWithScore]) -> list[dict[str, Any]]:
    """Convert retrieved nodes into API-ready structured snippet dicts."""
    return [
        RetrievalSnippet(
            content=_node_text(nws),
            score=getattr(nws, "score", None),
            metadata=_node_metadata(nws),
        ).to_api_dict()
        for nws in nodes
    ]
def retrieval_tool_search(query: str, top_k: int = 5) -> str:
    """Document-search tool wrapper for agents.

    Retrieves up to *top_k* chunks and returns a JSON string with the
    original query, the snippet count, and the structured snippets.
    """
    snippets = build_structured_snippets(retrieve_source_nodes(query=query, top_k=top_k))
    return json.dumps(
        {
            "query": query,
            "count": len(snippets),
            "snippets": snippets,
        },
        ensure_ascii=False,
    )
def synthesize_answer(query: str, sources: list[dict[str, Any]], agent_draft: str = "") -> str:
    """Produce a grounded answer from *sources* via the explicit synthesis prompt.

    *agent_draft*, when non-empty, is forwarded to the prompt as an optional
    draft from the tool-using agent; otherwise "(none)" is substituted.
    """
    llm = get_llm_model()
    rendered = GROUNDED_SYNTHESIS_PROMPT.format(
        query=query,
        agent_draft=agent_draft if agent_draft else "(none)",
        context_json=json.dumps(sources, ensure_ascii=False, indent=2),
    )
    logger.info("Synthesizing grounded answer from retrieved sources")
    completion = llm.complete(rendered)
    return _normalize_text(getattr(completion, "text", completion))
def format_chat_response(query: str, final_answer: str, sources: list[dict[str, Any]], mode: str) -> dict[str, Any]:
    """Assemble the final chat payload: answer text plus structured sources."""
    payload: dict[str, Any] = {"query": query, "answer": final_answer}
    payload["sources"] = sources
    payload["mode"] = mode
    return payload
def _extract_agent_result_text(result: Any) -> str:
    """Extract plain answer text from an agent run result.

    Prefers a ``response`` attribute when present; the value is normalized
    to str ("" for None).
    """
    if result is None:
        return ""
    value = result.response if hasattr(result, "response") else result
    return _normalize_text(value)
async def _run_agent_workflow_async(query: str, top_k: int) -> str:
    """Run a LlamaIndex AgentWorkflow with a document-retrieval tool.

    Returns the agent's draft answer text ("" when the result is empty).

    Fix: *top_k* was previously accepted but never used — the tool was bound
    directly to retrieval_tool_search, which always ran with its default
    top_k. The caller's top_k is now closed over in the tool wrapper so the
    agent's searches honor it.
    """
    setup_global_models()
    llm = get_llm_model()

    def document_search(query: str) -> str:
        """Search documents; returns structured JSON snippets (bound top_k)."""
        return retrieval_tool_search(query=query, top_k=top_k)

    tool = FunctionTool.from_defaults(
        fn=document_search,
        name="document_search",
        description=(
            "Search documents and return structured snippets as JSON with fields: "
            "filename, file_path, page_label/page, chunk_number, content_snippet, score. "
            "Use this before answering factual questions about documents."
        ),
    )
    system_prompt = (
        "You are a QA agent over a document store. Use the document_search tool when factual "
        "information may come from documents. If tool output is insufficient, say so."
    )
    agent = AgentWorkflow.from_tools_or_functions(
        [tool],
        llm=llm,
        system_prompt=system_prompt,
        verbose=False,
    )
    handler = agent.run(user_msg=query, max_iterations=4)
    result = await handler
    return _extract_agent_result_text(result)
def run_agent_workflow(query: str, top_k: int = 5) -> str:
    """Synchronously drive the async agent workflow.

    Returns "" — so callers fall back to direct retrieval+synthesis — when an
    event loop is already running or the workflow fails for any reason.
    """
    try:
        return asyncio.run(_run_agent_workflow_async(query=query, top_k=top_k))
    except RuntimeError:
        # asyncio.run() refuses to nest inside a running loop.
        logger.warning("Async event loop already running; skipping agent workflow and using direct retrieval+synthesis")
    except Exception as e:
        logger.warning(f"Agent workflow failed, will fallback to direct retrieval+synthesis: {e}")
    return ""
def chat_with_documents(query: str, top_k: int = 5, use_agent: bool = True) -> dict[str, Any]:
    """End-to-end chat orchestration entrypoint.

    Optionally runs the tool-calling agent workflow to obtain a draft answer,
    then retrieves and post-processes sources, synthesizes a grounded answer,
    and returns the formatted response dict (query/answer/sources/mode).
    """
    logger.info(f"Starting chat orchestration for query: {query[:80]}")
    agent_draft = run_agent_workflow(query=query, top_k=top_k) if use_agent else ""
    # Mode reflects whether the agent actually produced a usable draft.
    mode = "agent+retrieval+synthesis" if agent_draft else "retrieval+synthesis"
    retrieved = retrieve_source_nodes(query=query, top_k=top_k)
    structured_sources = build_structured_snippets(retrieved)
    answer = synthesize_answer(query=query, sources=structured_sources, agent_draft=agent_draft)
    return format_chat_response(query=query, final_answer=answer, sources=structured_sources, mode=mode)

View File

@@ -12,6 +12,7 @@ from fastapi.middleware.cors import CORSMiddleware
from loguru import logger from loguru import logger
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from chat_engine import chat_with_documents
from retrieval import initialize_retriever from retrieval import initialize_retriever
load_dotenv() load_dotenv()
@@ -55,6 +56,10 @@ app.add_middleware(
class TestQueryRequest(BaseModel): class TestQueryRequest(BaseModel):
query: str = Field(..., min_length=1, description="User query text") query: str = Field(..., min_length=1, description="User query text")
top_k: int = Field(5, ge=1, le=20, description="Number of retrieved chunks") top_k: int = Field(5, ge=1, le=20, description="Number of retrieved chunks")
mode: str = Field(
"agent",
description="agent (Phase 12 default) or retrieval (fallback/debug)",
)
class SourceItem(BaseModel): class SourceItem(BaseModel):
@@ -67,6 +72,7 @@ class TestQueryResponse(BaseModel):
query: str query: str
response: str response: str
sources: list[SourceItem] sources: list[SourceItem]
mode: str | None = None
error: bool error: bool
success: bool success: bool
@@ -85,9 +91,43 @@ def test_query(payload: TestQueryRequest) -> TestQueryResponse:
if not query: if not query:
raise HTTPException(status_code=400, detail="Field 'query' must not be empty") raise HTTPException(status_code=400, detail="Field 'query' must not be empty")
logger.info(f"Received /api/test-query request (top_k={payload.top_k})") logger.info(
f"Received /api/test-query request (top_k={payload.top_k}, mode={payload.mode})"
)
try: try:
if payload.mode.lower() in {"agent", "chat"}:
chat_result = chat_with_documents(
query=query,
top_k=payload.top_k,
use_agent=True,
)
sources = [
SourceItem(
content=str(src.get("content_snippet", "")),
score=src.get("score"),
metadata=src.get("metadata", {}) or {},
)
for src in chat_result.get("sources", [])
]
logger.info(
f"/api/test-query completed via agent-like chat path (sources={len(sources)})"
)
return TestQueryResponse(
query=query,
response=str(chat_result.get("answer", "")),
sources=sources,
mode=str(chat_result.get("mode", "agent+retrieval+synthesis")),
error=False,
success=True,
)
if payload.mode.lower() not in {"retrieval", "debug"}:
raise HTTPException(
status_code=400,
detail="Unsupported mode. Use 'agent' (default) or 'retrieval'.",
)
query_engine = initialize_retriever(similarity_top_k=payload.top_k) query_engine = initialize_retriever(similarity_top_k=payload.top_k)
result = query_engine.query(query) result = query_engine.query(query)
@@ -103,11 +143,14 @@ def test_query(payload: TestQueryRequest) -> TestQueryResponse:
) )
response_text = str(result) response_text = str(result)
logger.info(f"/api/test-query completed successfully (sources={len(sources)})") logger.info(
f"/api/test-query completed via retrieval fallback (sources={len(sources)})"
)
return TestQueryResponse( return TestQueryResponse(
query=query, query=query,
response=response_text, response=response_text,
sources=sources, sources=sources,
mode="retrieval",
error=False, error=False,
success=True, success=True,
) )