Spaces:
Sleeping
Sleeping
# app/api.py | |
from __future__ import annotations | |
from typing import List, Optional | |
from collections import deque | |
from datetime import datetime | |
from time import perf_counter | |
import re | |
import os | |
import faiss | |
from fastapi import FastAPI, UploadFile, File, HTTPException | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import JSONResponse, RedirectResponse | |
from pydantic import BaseModel, Field | |
from .rag_system import SimpleRAG, UPLOAD_DIR, INDEX_DIR | |
# ------------------------------------------------------------------------------ | |
# App setup | |
# ------------------------------------------------------------------------------ | |
# ASGI application object; title and version are the metadata passed to FastAPI.
app = FastAPI(version="1.3.0", title="RAG API")

# Fully permissive CORS policy: any origin, method and header is accepted.
# NOTE(review): wildcard origins are combined with allow_credentials=True here;
# confirm against the FastAPI CORS docs that this is intended before exposing
# the service to browsers that send cookies.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Single shared retrieval engine used by the handlers in this module.
rag = SimpleRAG()
# ------------------------------------------------------------------------------ | |
# Models | |
# ------------------------------------------------------------------------------ | |
class UploadResponse(BaseModel):
    """Response body describing the outcome of a PDF upload."""

    # Name of the file the response refers to.
    filename: str
    # Number of chunks the indexing step reported for this file.
    chunks_added: int
class AskRequest(BaseModel):
    """Request body for asking a question against the indexed documents."""

    # Required, must contain at least one character.
    question: str = Field(..., min_length=1)
    # Number of passages to retrieve; defaults to 5 and is limited to 1..20.
    top_k: int = Field(default=5, ge=1, le=20)
class AskResponse(BaseModel):
    """Response body carrying an answer and the passages it was built from."""

    # Final answer text returned to the client.
    answer: str
    # Passages that were used as context for the answer (may be empty).
    contexts: List[str]
class HistoryItem(BaseModel):
    """One previously asked question together with when it was recorded."""

    # The question text as it was asked.
    question: str
    # Time of recording, stored as a string.
    timestamp: str
class HistoryResponse(BaseModel):
    """Response body listing recent questions and the current chunk count."""

    # Number of chunks currently held by the retrieval engine.
    total_chunks: int
    # Recent questions; an empty list when nothing has been asked yet.
    history: List[HistoryItem] = Field(default_factory=list)
# ------------------------------------------------------------------------------ | |
# Lightweight stats store (in-memory) | |
# ------------------------------------------------------------------------------ | |
class StatsStore:
    """In-memory usage counters for the API.

    Nothing is persisted: all values start from zero when the process starts.
    The attributes are plain public fields because other code in this module
    reads and resets them directly.
    """

    def __init__(self):
        # Running totals.
        self.documents_indexed = 0
        self.questions_answered = 0
        # Most recent request latencies in milliseconds (bounded window).
        self.latencies_ms = deque(maxlen=500)
        # Questions per UTC day: index 0 is today, index 6 is six days ago.
        self.last7_questions = deque([0] * 7, maxlen=7)
        # Most recent questions, newest first.
        self.history = deque(maxlen=50)
        # UTC date that bucket 0 of ``last7_questions`` currently stands for.
        self._bucket_day = self._utc_now().date()

    @staticmethod
    def _utc_now():
        """Return the current time as a timezone-aware UTC datetime."""
        # Imported here so the class does not depend on a new file-level import.
        from datetime import timezone

        return datetime.now(timezone.utc)

    def _roll_day_buckets(self, today) -> None:
        """Shift the daily buckets so that index 0 represents *today*."""
        elapsed_days = (today - self._bucket_day).days
        if elapsed_days <= 0:
            # Same day (or the clock moved backwards): nothing to rotate.
            return
        # One fresh bucket per elapsed day; more than 7 would only push zeros.
        for _ in range(min(elapsed_days, 7)):
            self.last7_questions.appendleft(0)
        self._bucket_day = today

    def add_docs(self, n: int) -> None:
        """Add *n* to the indexed-documents counter; non-positive values are ignored."""
        if n > 0:
            self.documents_indexed += n

    def add_question(self, latency_ms: Optional[int] = None, q: Optional[str] = None) -> None:
        """Record one answered question.

        Args:
            latency_ms: Request latency in milliseconds; skipped when ``None``.
            q: Question text; when non-empty it is added to ``history`` with a
                timezone-aware UTC ISO-8601 timestamp.
        """
        now = self._utc_now()
        self.questions_answered += 1
        if latency_ms is not None:
            self.latencies_ms.append(int(latency_ms))

        # Rotate first so the increment lands in the bucket for the current day.
        self._roll_day_buckets(now.date())
        if not self.last7_questions:
            # Defensive: the deque may have been replaced by an empty one.
            self.last7_questions.appendleft(0)
        self.last7_questions[0] += 1

        if q:
            self.history.appendleft({"question": q, "timestamp": now.isoformat()})

    def avg_ms(self) -> int:
        """Return the mean recorded latency in whole milliseconds (0 if none)."""
        return int(sum(self.latencies_ms) / len(self.latencies_ms)) if self.latencies_ms else 0


# Module-wide stats instance shared by the handlers.
stats = StatsStore()
# ------------------------------------------------------------------------------ | |
# Helpers | |
# ------------------------------------------------------------------------------ | |
# Regular expressions (applied to lower-cased text) that mark an answer as
# boilerplate rather than a real response.
_GENERIC_PATTERNS = [
    r"\bbased on document context\b",
    r"\bappears to be\b",
    r"\bgeneral (?:summary|overview)\b",
]

# Common English function words that are ignored when comparing token sets.
_STOPWORDS = set(
    (
        "the a an of for and or in on to from with by is are "
        "was were be been being at as that this these those it "
        "its into than then so such about over per via vs within"
    ).split()
)
def is_generic_answer(text: str) -> bool:
    """Tell whether *text* is too empty or too boilerplate to be a real answer.

    An answer counts as generic when it is falsy, shorter than 15 characters
    after trimming, or matches any entry of ``_GENERIC_PATTERNS``
    (case-insensitively, because the text is lower-cased first).
    """
    if not text:
        return True

    normalized = text.strip().lower()
    if len(normalized) < 15:
        return True

    return any(re.search(pattern, normalized) for pattern in _GENERIC_PATTERNS)
def tokenize(s: str) -> List[str]:
    """Split *s* into lower-cased keyword tokens.

    Tokens are maximal runs of letters and digits. Matching is Unicode-aware,
    so words containing non-ASCII letters stay intact instead of being broken
    into fragments. Tokens of one or two characters and entries of
    ``_STOPWORDS`` are dropped.

    Args:
        s: Text to tokenize; ``None`` or an empty string yields no tokens.

    Returns:
        The remaining tokens in their original order (duplicates kept).
    """
    # ``[^\W_]`` is "word character except underscore": letters and digits in
    # any script, which keeps ASCII behaviour identical to ``[a-zA-Z0-9]``.
    words = re.findall(r"[^\W_]+", (s or "").lower())
    return [w for w in words if len(w) > 2 and w not in _STOPWORDS]
def extractive_answer(question: str, contexts: List[str], max_chars: int = 500) -> str:
    """Build an answer by quoting the context sentences closest to the question.

    Sentences are ranked by how many question tokens they share (ties go to
    the longer sentence) and concatenated, best first, while they fit into
    *max_chars*.

    Args:
        question: The user's question.
        contexts: Retrieved passages; ``None`` or empty entries are tolerated.
        max_chars: Upper bound on the length of the returned text. Negative
            values are treated as 0.

    Returns:
        The selected sentences joined by single spaces. If no sentence can be
        selected, a truncated best-matching sentence, or else a prefix of the
        first context. A fixed apology message when *contexts* is empty.
    """
    if not contexts:
        return "I couldn't find relevant information in the indexed documents for this question."

    max_chars = max(0, max_chars)
    first_context = contexts[0] or ""

    q_tokens = set(tokenize(question))
    if not q_tokens:
        # The question has no usable keywords (e.g. only short numbers):
        # rank against the vocabulary of the first context instead.
        q_tokens = set(tokenize(first_context))

    # Rough sentence split: after ., ! or ? followed by whitespace, or on newlines.
    sentences: List[str] = []
    for ctx in contexts:
        for piece in re.split(r"(?<=[\.!\?])\s+|\n+", (ctx or "").strip()):
            piece = piece.strip()
            if piece:
                sentences.append(piece)
    if not sentences:
        return first_context[:max_chars]

    # Highest keyword overlap first; among equals, longer sentences first.
    ranked = sorted(
        ((len(q_tokens & set(tokenize(sent))), sent) for sent in sentences),
        key=lambda pair: (pair[0], len(pair[1])),
        reverse=True,
    )
    best_score = ranked[0][0]

    picked: List[str] = []
    used = 0
    for score, sent in ranked:
        # Non-matching sentences are only acceptable as a single last resort
        # when nothing in the contexts overlaps with the question at all.
        if score <= 0 and (picked or best_score > 0):
            break
        needed = len(sent) + (1 if picked else 0)  # +1 for the joining space
        if used + needed > max_chars:
            # Skip rather than stop: a shorter, lower-ranked sentence may fit.
            continue
        picked.append(sent)
        used += needed

    if picked:
        return " ".join(picked)
    if best_score > 0:
        # Every relevant sentence is too long: quote the start of the best one
        # instead of falling back to unrelated text.
        return ranked[0][1][:max_chars]
    return first_context[:max_chars]
# ------------------------------------------------------------------------------ | |
# Routes | |
# ------------------------------------------------------------------------------ | |
def root():
    """Send the caller to the interactive API documentation at ``/docs``."""
    # NOTE(review): no route decorator is visible above this handler in the
    # reviewed text; confirm that it is registered on ``app``.
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect
def health():
    """Return a small status payload with the app version and summarizer label."""
    return dict(
        status="ok",
        version=app.version,
        summarizer="extractive_en + translate + fallback",
    )
def debug_translate():
    """Run one sample sentence through the az->en translation model.

    Returns ``{"ok": True, "example_out": ...}`` on success. Any exception
    (including a failed import of ``transformers``) is reported as a 500 JSON
    response containing the error text.
    """
    sample_sentence = "Sənəd təmiri və quraşdırılması ilə bağlı işlər görülüb."
    try:
        # Imported lazily, so the dependency is only needed when this runs.
        from transformers import pipeline

        translator = pipeline(
            "translation",
            model="Helsinki-NLP/opus-mt-az-en",
            cache_dir=str(rag.cache_dir),
            device=-1,
        )
        results = translator(sample_sentence, max_length=80)
        translated = results[0]["translation_text"]
    except Exception as e:
        return JSONResponse(status_code=500, content={"ok": False, "error": str(e)})
    return {"ok": True, "example_out": translated}
async def upload_pdf(file: UploadFile = File(...)):
    """Store an uploaded PDF under ``UPLOAD_DIR`` and index its text.

    Args:
        file: The uploaded file.

    Returns:
        ``UploadResponse`` with the stored file name and the number of chunks
        reported by ``rag.add_pdf``.

    Raises:
        HTTPException: 400 when the name is missing or does not end in
            ``.pdf``, or when indexing yields no chunks.
    """
    # The file name is client-controlled. Keep only its last path component so
    # values such as "../../x.pdf" or "C:\\dir\\x.pdf" cannot escape UPLOAD_DIR.
    # A missing name becomes "" and is rejected by the extension check below.
    client_name = (file.filename or "").replace("\\", "/")
    safe_name = os.path.basename(client_name).strip()
    if not safe_name.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Only PDF files are allowed.")

    dest = UPLOAD_DIR / safe_name
    with open(dest, "wb") as f:
        # Copy in 1 MiB pieces so large uploads are never held fully in memory.
        while True:
            chunk = await file.read(1024 * 1024)
            if not chunk:
                break
            f.write(chunk)

    added = rag.add_pdf(dest)
    if added == 0:
        raise HTTPException(status_code=400, detail="No extractable text found (likely a scanned image PDF).")

    stats.add_docs(added)
    return UploadResponse(filename=safe_name, chunks_added=added)
def ask_question(payload: AskRequest):
    """Answer a question from the indexed documents.

    Steps: retrieve passages with ``rag.search``, fall back to the most
    recently added chunks when retrieval returns nothing, let
    ``rag.synthesize_answer`` produce an answer, and replace a generic or
    failed synthesis with ``extractive_answer``. Every call that gets past
    retrieval is recorded in ``stats``.

    Raises:
        HTTPException: 400 for a blank question, 500 when retrieval fails.
    """
    question = (payload.question or "").strip()
    if not question:
        raise HTTPException(status_code=400, detail="Missing 'question'.")

    top_k = max(1, int(payload.top_k))
    started = perf_counter()

    def elapsed_ms() -> int:
        """Milliseconds spent on this request so far."""
        return int((perf_counter() - started) * 1000)

    # Retrieval
    try:
        hits = rag.search(question, k=top_k)  # expected: List[Tuple[str, float]]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Search failed: {e}")

    contexts = [passage for passage, _ in (hits or []) if passage]
    if not contexts and getattr(rag, "last_added", None):
        # Nothing retrieved: use the most recently added chunks instead.
        contexts = rag.last_added[:top_k]

    if not contexts:
        stats.add_question(elapsed_ms(), q=question)
        return AskResponse(
            answer="I couldn't find relevant information in the indexed documents for this question.",
            contexts=[],
        )

    # Synthesis is best-effort: any failure falls through to the extractive path.
    try:
        answer = rag.synthesize_answer(question, contexts) or ""
    except Exception:
        answer = ""

    # Guard against generic/unchanging answers.
    if is_generic_answer(answer):
        answer = extractive_answer(question, contexts, max_chars=600)

    stats.add_question(elapsed_ms(), q=question)
    return AskResponse(answer=answer.strip(), contexts=contexts)
def get_history():
    """Return the recorded recent questions plus the current chunk count."""
    chunk_count = len(rag.chunks)
    recent = [HistoryItem(**entry) for entry in stats.history]
    return HistoryResponse(total_chunks=chunk_count, history=recent)
def stats_endpoint():
    """Return usage counters and index metrics as a plain dict.

    Keeps the backward-compatible fields and adds dashboard-friendly metrics.
    """
    index = rag.index
    return {
        "documents_indexed": stats.documents_indexed,
        "questions_answered": stats.questions_answered,
        # ``avg_ms`` is a method and must be called; passing the bound method
        # itself would put a non-serializable object into the payload.
        "avg_ms": stats.avg_ms(),
        "last7_questions": list(stats.last7_questions),
        "total_chunks": len(rag.chunks),
        "faiss_ntotal": int(getattr(index, "ntotal", 0)),
        "model_dim": int(getattr(index, "d", rag.embed_dim)),
        "last_added_chunks": len(getattr(rag, "last_added", [])),
        "version": app.version,
    }
def reset_index():
    """Drop the in-memory index, its on-disk files and all usage statistics.

    Returns ``{"ok": True}`` on success.

    Raises:
        HTTPException: 500 carrying the error text if any step fails.
    """
    try:
        # Fresh, empty index and chunk bookkeeping.
        rag.index = faiss.IndexFlatIP(rag.embed_dim)
        rag.chunks = []
        rag.last_added = []

        # Remove persisted artifacts; files that are already gone are fine.
        persisted = [INDEX_DIR / "faiss.index", INDEX_DIR / "meta.npy"]
        for path in persisted:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass

        # Also reset stats counters to avoid stale analytics.
        stats.documents_indexed = 0
        stats.questions_answered = 0
        stats.latencies_ms.clear()
        stats.last7_questions = deque([0] * 7, maxlen=7)
        stats.history.clear()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"ok": True}