"""RAG Q&A app: ChromaDB retrieval with query expansion and cross-encoder reranking, served through Gradio."""

import os
import uuid
import json
from typing import List, Tuple, Dict, Any, Optional

import chromadb
from chromadb.config import Settings
from openai import OpenAI
import gradio as gr
from pypdf import PdfReader
from sentence_transformers import CrossEncoder


# Persistent Chroma client: vectors are stored on disk under ./chroma_db.
chroma_client = chromadb.PersistentClient(
    path="chroma_db",
    settings=Settings(anonymized_telemetry=False),
)

# One collection for all indexed documents, using cosine distance for the HNSW index.
collection = chroma_client.get_or_create_collection(
    name="rag_docs",
    metadata={"hnsw:space": "cosine"},
)


# Cross-encoder used for reranking, loaded lazily so the app starts fast.
_CROSS_ENCODER: Optional[CrossEncoder] = None

CROSS_ENCODER_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"


def get_cross_encoder() -> CrossEncoder:
    """Load the cross-encoder once and reuse it (the first call downloads the model)."""
    global _CROSS_ENCODER
    if _CROSS_ENCODER is None:
        _CROSS_ENCODER = CrossEncoder(CROSS_ENCODER_MODEL_NAME)
    return _CROSS_ENCODER


def get_openai_client(api_key: str) -> OpenAI:
    if not api_key or not api_key.strip():
        raise ValueError("OpenAI API key is missing.")
    return OpenAI(api_key=api_key.strip())


def extract_text_from_file(file_path: str) -> str:
    """Extract plain text from .txt/.md/.pdf files; unknown extensions are read as text."""
    ext = os.path.splitext(file_path)[1].lower()

    if ext in [".txt", ".md"]:
        with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
            return f.read()

    if ext == ".pdf":
        pages = []
        reader = PdfReader(file_path)
        for page in reader.pages:
            page_text = page.extract_text()
            if page_text:
                pages.append(page_text)
        return "\n".join(pages)

    # Fallback: treat any other extension as plain text.
    with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
        return f.read()


def chunk_text(text: str, chunk_size: int = 800, overlap: int = 200) -> List[str]:
    """Split text into fixed-size character windows that overlap by `overlap` characters."""
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    text = text.replace("\r\n", "\n").replace("\r", "\n")
    chunks = []
    start = 0
    while start < len(text):
        end = start + chunk_size
        chunks.append(text[start:end])
        start += chunk_size - overlap
    return chunks
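
# Worked example of the window arithmetic above: with chunk_size=800 and overlap=200,
# each start advances by 800 - 200 = 600 characters, so windows begin at 0, 600, 1200, ...
# and consecutive chunks share 200 characters across the boundary.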


def embed_texts(texts: List[str], api_key: str) -> List[List[float]]:
    """Embed a batch of texts with OpenAI's text-embedding-3-small model."""
    if not texts:
        return []
    client = get_openai_client(api_key)
    resp = client.embeddings.create(
        model="text-embedding-3-small",
        input=texts,
    )
    return [d.embedding for d in resp.data]
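
# Note: the embeddings endpoint caps how much can be sent in one request, so a very
# large corpus may need batching. A minimal sketch (the batch size of 100 is an
# assumption, not a documented limit):
#
#   def embed_texts_batched(texts, api_key, batch_size=100):
#       out = []
#       for i in range(0, len(texts), batch_size):
#           out.extend(embed_texts(texts[i:i + batch_size], api_key))
#       return out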


def add_documents_to_chroma(file_paths: List[str], api_key: str) -> str:
    """Extract, chunk, embed, and index each file into the Chroma collection."""
    if not file_paths:
        return "⚠️ No files were provided."

    total_chunks = 0
    for file_path in file_paths:
        file_name = os.path.basename(file_path)
        raw_text = extract_text_from_file(file_path)

        if not raw_text.strip():
            continue

        chunks = chunk_text(raw_text)
        embeddings = embed_texts(chunks, api_key)

        # One unique id per chunk; the source file name is kept as metadata for citations.
        ids = [f"{file_name}-{uuid.uuid4()}" for _ in chunks]
        metadatas = [{"source": file_name} for _ in chunks]

        collection.add(
            ids=ids,
            documents=chunks,
            metadatas=metadatas,
            embeddings=embeddings,
        )

        total_chunks += len(chunks)

    count = collection.count()
    return (
        f"✅ Indexed {len(file_paths)} file(s) into Chroma with {total_chunks} chunks. "
        f"Collection now has {count} vectors."
    )


def query_expansion(user_query: str, api_key: str) -> List[str]:
    """Ask the LLM for reworded versions of the question; the original query always comes first."""
    user_query = (user_query or "").strip()
    if not user_query:
        return []

    client = get_openai_client(api_key)

    system_prompt = (
        "You are an expert in information retrieval systems, particularly skilled in enhancing queries "
        "for document search efficiency."
    )

    user_prompt = f"""
Perform query expansion on the received question by considering alternative phrasings or synonyms commonly used in document retrieval contexts.
If there are multiple ways to phrase the user's question or common synonyms for key terms, provide several reworded versions.
If there are acronyms or words you are not familiar with, do not try to rephrase them.
Return at least 3 versions of the question.
Return ONLY valid JSON with this exact shape:
{{
  "expanded": ["q1", "q2", "q3"]
}}
Question:
{user_query}
""".strip()

    completion = client.chat.completions.create(
        model="gpt-4.1-mini",
        temperature=0.2,
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    )

    raw = completion.choices[0].message.content
    try:
        data = json.loads(raw)
        expanded = data.get("expanded", [])
    except json.JSONDecodeError:
        expanded = []

    # Keep only non-empty strings, padding with the original query if fewer than 3 came back.
    expanded = [q.strip() for q in expanded if isinstance(q, str) and q.strip()]
    while len(expanded) < 3:
        expanded.append(user_query)

    # Make sure the original query is surfaced first.
    if expanded and expanded[0] != user_query:
        expanded = [user_query] + expanded

    # Deduplicate while preserving order.
    seen = set()
    out = []
    for q in expanded:
        if q not in seen:
            seen.add(q)
            out.append(q)

    return out
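
# Illustrative behavior (actual model output varies): for "What were the key findings?"
# the expansions might include "What are the main results?" and "Summarize the principal
# conclusions." The original question is always kept as the first entry.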


def format_expansions_md(expanded: List[str]) -> str:
    if not expanded:
        return "*(No expansions yet — type a question and press Enter.)*"
    lines = [f"{i + 1}. {q}" for i, q in enumerate(expanded)]
    return "### 🧠 Expanded Queries\n\n" + "\n".join(lines)


def evaluate_answer(question: str, context: str, answer: str, api_key: str) -> dict:
    """LLM-as-judge: score the answer on five RAG metrics, each from 1 to 5."""
    client = get_openai_client(api_key)

    system_prompt = (
        "You are an impartial evaluator for a Retrieval-Augmented Generation (RAG) system. "
        "You will receive: (1) the user query, (2) the retrieved context, and (3) the model's answer. "
        "You must evaluate the answer on five metrics, each scored from 1 (very poor) to 5 (excellent):\n"
        "- Groundedness: Is the answer supported by the retrieved CONTEXT (not outside knowledge)?\n"
        "- Relevance: Does the answer address the USER QUERY directly and appropriately?\n"
        "- Faithfulness: Are the statements logically valid and consistent with the context (no contradictions)?\n"
        "- Context Precision: Does the answer avoid including irrelevant details from the context?\n"
        "- Context Recall: Does the answer capture all IMPORTANT information from the context needed to answer well?\n\n"
        "Return ONLY a single JSON object with this exact structure:\n"
        "{\n"
        '  "query": string,\n'
        '  "response": string,\n'
        '  "groundedness_evaluation": {"score": int, "justification": string},\n'
        '  "relevance_evaluation": {"score": int, "justification": string},\n'
        '  "faithfulness_evaluation": {"score": int, "justification": string},\n'
        '  "context_precision_evaluation": {"score": int, "justification": string},\n'
        '  "context_recall_evaluation": {"score": int, "justification": string}\n'
        "}"
    )

    user_prompt = (
        f"USER QUERY:\n{question}\n\n"
        f"RETRIEVED CONTEXT:\n{context}\n\n"
        f"MODEL ANSWER:\n{answer}"
    )

    completion = client.chat.completions.create(
        model="gpt-4.1-mini",
        temperature=0.0,
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    )

    raw = completion.choices[0].message.content
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        # Fall back to a well-formed record; the raw reply is kept for debugging.
        return {
            "query": question,
            "response": answer,
            "groundedness_evaluation": {"score": None, "justification": "Failed to parse JSON evaluation."},
            "relevance_evaluation": {"score": None, "justification": raw},
            "faithfulness_evaluation": {"score": None, "justification": ""},
            "context_precision_evaluation": {"score": None, "justification": ""},
            "context_recall_evaluation": {"score": None, "justification": ""},
        }


def retrieve_from_chroma(query: str, top_k: int, api_key: str) -> List[Dict[str, Any]]:
    """
    Retrieve top_k passages from Chroma using embeddings.
    Preserves ids + metadatas + distances + documents.

    Returns list[dict] with keys:
    - id: str
    - text: str
    - metadata: dict
    - distance: float|None
    """
    query = (query or "").strip()
    if not query:
        return []

    if collection.count() == 0:
        return []

    q_emb = embed_texts([query], api_key)[0]
    results = collection.query(
        query_embeddings=[q_emb],
        n_results=top_k,
    )

    # Chroma returns one list per query; we sent a single query, so take index 0.
    ids = (results.get("ids") or [[]])[0]
    docs = (results.get("documents") or [[]])[0]
    metas = (results.get("metadatas") or [[]])[0]
    dists = (results.get("distances") or [[]])[0]

    out = []
    for i in range(min(len(docs), len(ids), len(metas))):
        out.append({
            "id": ids[i],
            "text": docs[i],
            "metadata": metas[i] or {},
            "distance": dists[i] if i < len(dists) else None,
        })
    return out
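
# Note: because the collection uses {"hnsw:space": "cosine"}, the reported distance is
# cosine distance (1 - cosine similarity), so lower values mean closer matches.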


def cross_encoder_rerank(query: str, docs: List[Dict[str, Any]], top_n: int) -> List[Dict[str, Any]]:
    """
    Rerank retrieved passages with a HF cross-encoder
    ("cross-encoder/ms-marco-MiniLM-L-6-v2").

    Inputs:
    - query: str
    - docs: list of dicts from retrieve_from_chroma or merged retrieval
    - top_n: int

    Returns: list of docs with an added field:
    - score: float (higher is better)
    """
    query = (query or "").strip()
    if not query or not docs:
        return []

    model = get_cross_encoder()

    # Unlike the embedding model, the cross-encoder reads query and passage together
    # and scores each (query, passage) pair jointly.
    pairs = [(query, d.get("text", "")) for d in docs]
    scores = model.predict(pairs)

    reranked = []
    for d, s in zip(docs, scores):
        dd = dict(d)
        dd["score"] = float(s)
        reranked.append(dd)

    reranked.sort(key=lambda x: x.get("score", float("-inf")), reverse=True)
    return reranked[:top_n]
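
# Illustration (the scores are made up): model.predict([("q", "on-topic passage"),
# ("q", "off-topic passage")]) returns one score per pair, e.g. array([7.2, -4.1]).
# ms-marco cross-encoders emit raw logits, meaningful for ranking but not as probabilities.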


def build_prompt(query: str, reranked_docs: List[Dict[str, Any]]) -> Tuple[str, str]:
    """
    Build the final context string and the LLM prompt.

    Returns:
    - context: str (the final context string)
    - prompt: str (full prompt for the LLM)
    """
    parts = []
    for d in reranked_docs:
        md = d.get("metadata", {}) or {}
        source = md.get("source", "unknown")
        page = md.get("page", md.get("page_number", md.get("pageno", "")))

        header = f"Source: {source}"
        if page != "" and page is not None:
            header += f" | Page: {page}"

        parts.append(f"{header}\n{d.get('text', '')}".strip())

    context = "\n\n---\n\n".join(parts).strip()

    prompt = (
        "You are a helpful assistant that answers questions ONLY using the provided document context. "
        "If the context does not contain the answer, say you do not know.\n\n"
        f"Context from documents:\n\n{context}\n\n"
        f"Question: {query}\n\n"
        "Answer based only on the context above."
    )

    return context, prompt


def _merge_docs_by_id(doc_lists: List[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
    """
    Merge/dedupe docs (dicts) by Chroma chunk id. Keeps the best (lowest) distance if present.
    """
    merged: Dict[str, Dict[str, Any]] = {}
    for docs in doc_lists:
        for d in docs:
            cid = d.get("id")
            if not cid:
                continue
            if cid not in merged:
                merged[cid] = d
            else:
                # Same chunk retrieved by multiple queries: keep the closer match.
                old_dist = merged[cid].get("distance")
                new_dist = d.get("distance")
                if old_dist is not None and new_dist is not None and new_dist < old_dist:
                    merged[cid] = d
    return list(merged.values())
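
# Example: if chunk "report.pdf-1234" comes back at distance 0.31 for one query and
# 0.18 for another, the merged output contains it once, with the 0.18 entry kept.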


def query_rag_multi(selected_queries: List[str], api_key: str) -> str:
    """Run retrieval over all selected queries, merge the hits, answer, then self-evaluate."""
    selected_queries = [q.strip() for q in (selected_queries or []) if isinstance(q, str) and q.strip()]
    if not selected_queries:
        return "⚠️ Please select at least one expanded query."

    if collection.count() == 0:
        return "⚠️ No documents in the database yet. Upload and index some documents first."

    # One Chroma call with all query embeddings: results come back as one list per query.
    q_embs = embed_texts(selected_queries, api_key)
    results = collection.query(
        query_embeddings=q_embs,
        n_results=5,
    )

    all_ids = results.get("ids", [])
    all_docs = results.get("documents", [])
    all_metas = results.get("metadatas", [])
    all_dist = results.get("distances", None)

    doc_lists: List[List[Dict[str, Any]]] = []
    for qi in range(len(all_docs)):
        ids_i = all_ids[qi] if qi < len(all_ids) else []
        docs_i = all_docs[qi] if qi < len(all_docs) else []
        metas_i = all_metas[qi] if qi < len(all_metas) else []
        dist_i = all_dist[qi] if isinstance(all_dist, list) and qi < len(all_dist) else [None] * len(docs_i)

        out_i = []
        for cid, doc, meta, dist in zip(ids_i, docs_i, metas_i, dist_i):
            out_i.append({"id": cid, "text": doc, "metadata": meta or {}, "distance": dist})
        doc_lists.append(out_i)

    merged = _merge_docs_by_id(doc_lists)
    if not merged:
        return "I couldn't find any relevant context in the indexed documents."

    # Sort by distance ascending; docs without a distance go last.
    merged.sort(key=lambda d: (d.get("distance") is None, d.get("distance", 1e9)))
    top = merged[:5]

    context_parts = []
    for d in top:
        md = d.get("metadata", {}) or {}
        context_parts.append(f"Source: {md.get('source', 'unknown')}\n{d.get('text', '')}")
    context = "\n\n---\n\n".join(context_parts)

    client = get_openai_client(api_key)
    system_prompt = (
        "You are a helpful assistant that answers questions ONLY using the provided document context. "
        "If the context does not contain the answer, say you do not know."
    )
    user_prompt = (
        f"Context from documents:\n\n{context}\n\n"
        "Selected expanded queries:\n- " + "\n- ".join(selected_queries) + "\n\n"
        "Answer based only on the context above."
    )

    completion = client.chat.completions.create(
        model="gpt-4.1-mini",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.1,
    )

    response_text = completion.choices[0].message.content.strip()

    try:
        eval_dict = evaluate_answer(
            question=" | ".join(selected_queries),
            context=context,
            answer=response_text,
            api_key=api_key,
        )

        log_record = {
            "query": eval_dict.get("query"),
            "response": eval_dict.get("response"),
            "groundedness_evaluation": eval_dict.get("groundedness_evaluation"),
            "relevance_evaluation": eval_dict.get("relevance_evaluation"),
            "faithfulness_evaluation": eval_dict.get("faithfulness_evaluation"),
            "context_precision_evaluation": eval_dict.get("context_precision_evaluation"),
            "context_recall_evaluation": eval_dict.get("context_recall_evaluation"),
        }

        return (
            f"### 💬 Answer\n\n{response_text}\n\n"
            f"---\n\n"
            f"### 🔍 Self-evaluation (1–5)\n\n"
            f"```json\n{json.dumps(log_record, indent=2)}\n```"
        )
    except Exception as e:
        # The answer is still useful even when the judge call fails.
        return (
            f"### 💬 Answer\n\n{response_text}\n\n"
            f"---\n\n"
            f"⚠️ Self-evaluation failed: {e}"
        )


def format_rerank_results_md(query: str, reranked: List[Dict[str, Any]], top_n: int) -> str:
    if not reranked:
        return "*(No reranked results to display.)*"

    lines = []
    lines.append(f"### 🎯 Cross-Encoder Rerank Results (top {top_n})")
    lines.append("")
    lines.append("| Rank | Score | Source | Page | Snippet |")
    lines.append("|---:|---:|---|---:|---|")

    for i, d in enumerate(reranked, start=1):
        md = d.get("metadata", {}) or {}
        source = str(md.get("source", "unknown"))
        page = md.get("page", md.get("page_number", md.get("pageno", "")))
        score = d.get("score")
        score_str = f"{score:.4f}" if isinstance(score, (int, float)) else ""
        # Flatten newlines and escape pipes so snippets don't break the Markdown table.
        snippet = (d.get("text", "") or "").replace("\n", " ").replace("|", "\\|").strip()
        if len(snippet) > 160:
            snippet = snippet[:160] + "…"
        lines.append(f"| {i} | {score_str} | {source} | {page if page is not None else ''} | {snippet} |")

    return "\n".join(lines)


def gradio_ingest(files, api_key):
    if not api_key or not api_key.strip():
        return "❌ Please enter your OpenAI API key before indexing."

    if not files:
        return "⚠️ Please drop at least one document."

    file_paths = files if isinstance(files, list) else [files]

    try:
        status = add_documents_to_chroma(file_paths, api_key)
    except Exception as e:
        return f"❌ Error during indexing: {e}"
    return status


def gradio_expand(question: str, api_key: str):
    if not api_key or not api_key.strip():
        return gr.update(choices=[], value=[]), "❌ Please enter your OpenAI API key first."

    expanded = query_expansion(question, api_key)
    md = format_expansions_md(expanded)
    # Pre-select the original question so "Run Search" works with a single click.
    default_value = expanded[:1] if expanded else []
    return gr.update(choices=expanded, value=default_value), md


def gradio_run_selected(selected_queries: List[str], api_key: str) -> str:
    if not api_key or not api_key.strip():
        return "❌ Please enter your OpenAI API key before searching."
    if not selected_queries:
        return "⚠️ Please expand a question and select one or more to run."

    try:
        return query_rag_multi(selected_queries, api_key)
    except Exception as e:
        return f"❌ Error during question answering: {e}"


def gradio_cross_encode(original_question: str, selected_queries: List[str], api_key: str) -> Tuple[str, str]:
    """
    Cross-encode button:
    - Initial retrieval via Chroma: top_k=20 (per requirement)
    - Rerank via cross-encoder: top_n=5 (per requirement)
    - Show:
      (a) the top_n reranked passages,
      (b) their scores,
      (c) the final context string
    """
    if not api_key or not api_key.strip():
        return "❌ Please enter your OpenAI API key first.", ""

    if collection.count() == 0:
        return "⚠️ No documents in the database yet. Upload and index some documents first.", ""

    original_question = (original_question or "").strip()
    selected_queries = [q.strip() for q in (selected_queries or []) if isinstance(q, str) and q.strip()]

    if not original_question and not selected_queries:
        return "⚠️ Please type a question and/or select expansions first.", ""

    # Retrieve with the selected expansions when available, else the raw question.
    retrieval_queries = selected_queries if selected_queries else [original_question]

    retrieved_lists = [retrieve_from_chroma(q, top_k=20, api_key=api_key) for q in retrieval_queries]
    merged_docs = _merge_docs_by_id(retrieved_lists)

    if not merged_docs:
        return "I couldn't find any relevant context in the indexed documents.", ""

    # Rerank against a single query: prefer the user's original phrasing.
    scoring_query = original_question if original_question else retrieval_queries[0]

    reranked = cross_encoder_rerank(scoring_query, merged_docs, top_n=5)

    context, _prompt = build_prompt(scoring_query, reranked)

    md = format_rerank_results_md(scoring_query, reranked, top_n=5)
    return md, f"### 🧩 Final Context (for LLM)\n\n{context}"


with gr.Blocks(title="RAG with ChromaDB") as demo:
    gr.Markdown(
        """
# 📚 RAG Q&A with ChromaDB + Gradio (Multi-Select Query Expansion + Cross-Encoder Rerank)
1. Paste your **OpenAI API key** below.
2. **Drag & drop** one or more documents into the upload box.
3. Click **"Index documents"** to store them in a Chroma vector database.
4. Type a question and press **Enter** (or click **Expand**) to generate expanded queries.
5. Select **one or more** expanded queries.
6. Click **Run Search** for the normal pipeline, or **Cross Encode** to view reranked passages + scores + final context.
"""
    )

    with gr.Row():
        with gr.Column(scale=1):
            api_key_box = gr.Textbox(
                label="OpenAI API Key",
                placeholder="sk-... (this is kept in memory only for this session)",
                type="password",
            )

            file_input = gr.File(
                label="Drop your document(s) here",
                file_count="multiple",
                type="filepath",
            )
            ingest_button = gr.Button("Index documents")
            ingest_status = gr.Markdown("⚙️ Waiting for documents...")

        with gr.Column(scale=1):
            question_box = gr.Textbox(
                label="Type a question, then press Enter to expand",
                placeholder="e.g., What are the main findings in the report?",
                lines=3,
            )

            with gr.Row():
                expand_button = gr.Button("Expand")
                run_button = gr.Button("Run Search")
                cross_button = gr.Button("Cross Encode")

            expanded_checks = gr.CheckboxGroup(
                label="Choose one or more expanded queries to run",
                choices=[],
                value=[],
                interactive=True,
            )

            expansions_preview = gr.Markdown("*(No expansions yet — type a question and press Enter.)*")
            answer_box = gr.Markdown("💬 Answer will appear here (with self-evaluation).")

    gr.Markdown("---")
    rerank_results_box = gr.Markdown("*(Cross-encoder rerank results will appear here.)*")
    rerank_context_box = gr.Markdown("*(Final context for the LLM will appear here.)*")

    ingest_button.click(
        fn=gradio_ingest,
        inputs=[file_input, api_key_box],
        outputs=[ingest_status],
    )

    # Pressing Enter in the question box and clicking "Expand" trigger the same handler.
    question_box.submit(
        fn=gradio_expand,
        inputs=[question_box, api_key_box],
        outputs=[expanded_checks, expansions_preview],
    )

    expand_button.click(
        fn=gradio_expand,
        inputs=[question_box, api_key_box],
        outputs=[expanded_checks, expansions_preview],
    )

    run_button.click(
        fn=gradio_run_selected,
        inputs=[expanded_checks, api_key_box],
        outputs=[answer_box],
    )

    cross_button.click(
        fn=gradio_cross_encode,
        inputs=[question_box, expanded_checks, api_key_box],
        outputs=[rerank_results_box, rerank_context_box],
    )


if __name__ == "__main__":
    demo.launch()