HamidOmarov committed on
Commit
7715973
·
verified ·
1 Parent(s): 26ad320

Update app/api.py

Browse files
Files changed (1) hide show
  1. app/api.py +200 -16
app/api.py CHANGED
@@ -1,14 +1,24 @@
1
  # app/api.py
2
- from typing import List
3
 
4
- import faiss, os
 
 
 
 
 
 
 
5
  from fastapi import FastAPI, UploadFile, File, HTTPException
6
  from fastapi.middleware.cors import CORSMiddleware
7
  from fastapi.responses import JSONResponse, RedirectResponse
8
- from pydantic import BaseModel
9
 
10
  from .rag_system import SimpleRAG, UPLOAD_DIR, INDEX_DIR
11
 
 
 
 
12
  app = FastAPI(title="RAG API", version="1.3.0")
13
 
14
  app.add_middleware(
@@ -21,23 +31,148 @@ app.add_middleware(
21
 
22
  rag = SimpleRAG()
23
 
24
- # ---------- Schemas ----------
 
 
25
  class UploadResponse(BaseModel):
26
  filename: str
27
  chunks_added: int
28
 
29
  class AskRequest(BaseModel):
30
- question: str
31
- top_k: int = 5
32
 
33
  class AskResponse(BaseModel):
34
  answer: str
35
  contexts: List[str]
36
 
 
 
 
 
37
  class HistoryResponse(BaseModel):
38
  total_chunks: int
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- # ---------- Utility ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  @app.get("/")
42
  def root():
43
  return RedirectResponse(url="/docs")
@@ -56,9 +191,11 @@ def debug_translate():
56
  except Exception as e:
57
  return JSONResponse(status_code=500, content={"ok": False, "error": str(e)})
58
 
59
- # ---------- Core ----------
60
  @app.post("/upload_pdf", response_model=UploadResponse)
61
  async def upload_pdf(file: UploadFile = File(...)):
 
 
 
62
  dest = UPLOAD_DIR / file.filename
63
  with open(dest, "wb") as f:
64
  while True:
@@ -66,30 +203,71 @@ async def upload_pdf(file: UploadFile = File(...)):
66
  if not chunk:
67
  break
68
  f.write(chunk)
 
69
  added = rag.add_pdf(dest)
70
  if added == 0:
71
- # Clear message for scanned/empty PDFs
72
  raise HTTPException(status_code=400, detail="No extractable text found (likely a scanned image PDF).")
 
 
73
  return UploadResponse(filename=file.filename, chunks_added=added)
74
 
75
  @app.post("/ask_question", response_model=AskResponse)
76
  def ask_question(payload: AskRequest):
77
- hits = rag.search(payload.question, k=max(1, payload.top_k))
78
- contexts = [c for c, _ in hits]
79
- answer = rag.synthesize_answer(payload.question, contexts)
80
- return AskResponse(answer=answer, contexts=contexts or rag.last_added[:5])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  @app.get("/get_history", response_model=HistoryResponse)
83
  def get_history():
84
- return HistoryResponse(total_chunks=len(rag.chunks))
 
 
 
85
 
86
  @app.get("/stats")
87
- def stats():
 
88
  return {
 
 
 
 
89
  "total_chunks": len(rag.chunks),
90
  "faiss_ntotal": int(getattr(rag.index, "ntotal", 0)),
91
  "model_dim": int(getattr(rag.index, "d", rag.embed_dim)),
92
- "last_added_chunks": len(rag.last_added),
93
  "version": app.version,
94
  }
95
 
@@ -104,6 +282,12 @@ def reset_index():
104
  os.remove(p)
105
  except FileNotFoundError:
106
  pass
 
 
 
 
 
 
107
  return {"ok": True}
108
  except Exception as e:
109
  raise HTTPException(status_code=500, detail=str(e))
 
1
  # app/api.py
2
+ from __future__ import annotations
3
 
4
+ from typing import List, Optional
5
+ from collections import deque
6
+ from datetime import datetime
7
+ from time import perf_counter
8
+ import re
9
+ import os
10
+
11
+ import faiss
12
  from fastapi import FastAPI, UploadFile, File, HTTPException
13
  from fastapi.middleware.cors import CORSMiddleware
14
  from fastapi.responses import JSONResponse, RedirectResponse
15
+ from pydantic import BaseModel, Field
16
 
17
  from .rag_system import SimpleRAG, UPLOAD_DIR, INDEX_DIR
18
 
19
+ # ------------------------------------------------------------------------------
20
+ # App setup
21
+ # ------------------------------------------------------------------------------
22
  app = FastAPI(title="RAG API", version="1.3.0")
23
 
24
  app.add_middleware(
 
31
 
32
  rag = SimpleRAG()
33
 
34
+ # ------------------------------------------------------------------------------
35
+ # Models
36
+ # ------------------------------------------------------------------------------
37
# Response body for POST /upload_pdf.
class UploadResponse(BaseModel):
    # Name of the uploaded file, echoed back from the request.
    filename: str
    # Number of chunks the upload contributed to the index.
    chunks_added: int
40
 
41
# Request body for POST /ask_question.
class AskRequest(BaseModel):
    # Required; the empty string is rejected during validation.
    question: str = Field(..., min_length=1)
    # How many chunks to retrieve; defaults to 5 and is limited to 1..20.
    top_k: int = Field(default=5, ge=1, le=20)
44
 
45
# Response body for POST /ask_question.
class AskResponse(BaseModel):
    # Final answer text returned to the client.
    answer: str
    # Context passages the answer was built from (may be empty).
    contexts: List[str]
48
 
49
# One entry of the recent-questions history.
class HistoryItem(BaseModel):
    # The question text as it was recorded.
    question: str
    # ISO-8601 timestamp string recorded together with the question.
    timestamp: str
52
+
53
# Response body for GET /get_history.
class HistoryResponse(BaseModel):
    # Number of chunks currently held by the RAG store.
    total_chunks: int
    # Recently asked questions; empty when none were recorded.
    history: List[HistoryItem] = []
56
+
57
+ # ------------------------------------------------------------------------------
58
+ # Lightweight stats store (in-memory)
59
+ # ------------------------------------------------------------------------------
60
+ class StatsStore:
61
+ def __init__(self):
62
+ self.documents_indexed = 0
63
+ self.questions_answered = 0
64
+ self.latencies_ms = deque(maxlen=500)
65
+ # Mon..Sun simple counter (index 0 = today for simplicity)
66
+ self.last7_questions = deque([0] * 7, maxlen=7)
67
+ self.history = deque(maxlen=50) # recent questions
68
+
69
+ def add_docs(self, n: int):
70
+ if n > 0:
71
+ self.documents_indexed += n
72
+
73
+ def add_question(self, latency_ms: Optional[int] = None, q: Optional[str] = None):
74
+ self.questions_answered += 1
75
+ if latency_ms is not None:
76
+ self.latencies_ms.append(int(latency_ms))
77
+ if len(self.last7_questions) < 7:
78
+ self.last7_questions.appendleft(1)
79
+ else:
80
+ # attribute to "today" bucket
81
+ self.last7_questions[0] += 1
82
+ if q:
83
+ self.history.appendleft(
84
+ {"question": q, "timestamp": datetime.utcnow().isoformat()}
85
+ )
86
+
87
+ @property
88
+ def avg_ms(self) -> int:
89
+ return int(sum(self.latencies_ms) / len(self.latencies_ms)) if self.latencies_ms else 0
90
+
91
+ stats = StatsStore()
92
+
93
+ # ------------------------------------------------------------------------------
94
+ # Helpers
95
+ # ------------------------------------------------------------------------------
96
# Phrases that mark a synthesized answer as boilerplate.
_GENERIC_PATTERNS = [
    r"\bbased on document context\b",
    r"\bappears to be\b",
    r"\bgeneral (?:summary|overview)\b",
]

# Words ignored when matching question terms against sentences.
_STOPWORDS = {
    "the", "a", "an", "of", "for", "and", "or", "in", "on", "to", "from", "with", "by", "is", "are",
    "was", "were", "be", "been", "being", "at", "as", "that", "this", "these", "those", "it",
    "its", "into", "than", "then", "so", "such", "about", "over", "per", "via", "vs", "within",
}

# Compiled once at import time instead of on every call.
_GENERIC_RE = re.compile("|".join(_GENERIC_PATTERNS))
# Runs of letters/digits in any script; "_" still separates words.
_WORD_RE = re.compile(r"[^\W_]+")
# Rough sentence boundary: whitespace after . ! ? or any run of newlines.
_SENTENCE_SPLIT_RE = re.compile(r"(?<=[\.!\?])\s+|\n+")


def is_generic_answer(text: str) -> bool:
    """Return True when ``text`` is empty, very short, or boilerplate.

    "Very short" means fewer than 15 characters after stripping; boilerplate
    means any of ``_GENERIC_PATTERNS`` matches the lower-cased text.
    """
    if not text:
        return True
    low = text.strip().lower()
    if len(low) < 15:
        return True
    return _GENERIC_RE.search(low) is not None


def tokenize(s: str) -> List[str]:
    """Split ``s`` into lower-cased words, dropping stopwords and short words.

    Words are runs of Unicode letters and digits, so non-English text is kept
    intact rather than being cut at every non-ASCII character. Words of one or
    two characters are discarded.
    """
    return [
        w for w in _WORD_RE.findall(s.lower())
        if len(w) > 2 and w not in _STOPWORDS
    ]


def extractive_answer(question: str, contexts: List[str], max_chars: int = 500) -> str:
    """Build an answer from the context sentences that best match the question.

    Sentences are ranked by how many distinct question tokens they contain
    (ties go to the longer sentence) and packed into the answer in rank order
    until ``max_chars`` is reached.

    Args:
        question: The user's question.
        contexts: Retrieved passages; ``None`` or empty entries are tolerated.
        max_chars: Upper bound on the length of the returned text.

    Returns:
        A fixed "couldn't find" message when ``contexts`` is empty; otherwise
        the packed sentences, or the first ``max_chars`` characters of the
        first context when no sentence could be used.
    """
    if not contexts:
        return "I couldn't find relevant information in the indexed documents for this question."

    first_context = contexts[0] or ""

    q_tokens = set(tokenize(question))
    if not q_tokens:
        # Question had no usable words; match against the first context instead.
        q_tokens = set(tokenize(first_context))

    # Collect unique sentences in order of appearance. Retrieved chunks may
    # repeat text, and a repeated sentence must not be quoted twice.
    sentences: List[str] = []
    seen = set()
    for ctx in contexts:
        for raw in _SENTENCE_SPLIT_RE.split((ctx or "").strip()):
            sent = raw.strip()
            if sent and sent not in seen:
                seen.add(sent)
                sentences.append(sent)

    fallback = first_context[:max_chars]
    if not sentences:
        return fallback

    ranked = sorted(
        ((len(q_tokens & set(tokenize(sent))), sent) for sent in sentences),
        key=lambda pair: (pair[0], len(pair[1])),
        reverse=True,
    )

    picked: List[str] = []
    used = 0
    for score, sent in ranked:
        if score <= 0 and picked:
            break
        # A joining space is only needed after the first sentence.
        cost = len(sent) + (1 if picked else 0)
        if used + cost > max_chars:
            if picked:
                # Too long to append; a shorter relevant sentence may still fit.
                continue
            if score > 0:
                # The best match alone is over budget: truncate it rather
                # than discarding it for an unrelated prefix.
                return sent[:max_chars].strip()
            break
        picked.append(sent)
        used += cost

    if not picked:
        return fallback
    return " ".join(picked).strip()
172
+
173
+ # ------------------------------------------------------------------------------
174
+ # Routes
175
+ # ------------------------------------------------------------------------------
176
# Landing route: sends visitors of "/" to the interactive API docs.
@app.get("/")
def root():
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect
 
191
  except Exception as e:
192
  return JSONResponse(status_code=500, content={"ok": False, "error": str(e)})
193
 
 
194
  @app.post("/upload_pdf", response_model=UploadResponse)
195
  async def upload_pdf(file: UploadFile = File(...)):
196
+ if not file.filename.lower().endswith(".pdf"):
197
+ raise HTTPException(status_code=400, detail="Only PDF files are allowed.")
198
+
199
  dest = UPLOAD_DIR / file.filename
200
  with open(dest, "wb") as f:
201
  while True:
 
203
  if not chunk:
204
  break
205
  f.write(chunk)
206
+
207
  added = rag.add_pdf(dest)
208
  if added == 0:
 
209
  raise HTTPException(status_code=400, detail="No extractable text found (likely a scanned image PDF).")
210
+
211
+ stats.add_docs(added)
212
  return UploadResponse(filename=file.filename, chunks_added=added)
213
 
214
# Answers a question from the indexed documents:
#   1. retrieve up to top_k contexts (falling back to the most recently added
#      chunks when the search yields nothing),
#   2. ask the RAG object to synthesize an answer,
#   3. swap in a keyword-based extractive answer when the synthesized text is
#      empty or generic.
# Every handled request is recorded in the stats store with its latency.
@app.post("/ask_question", response_model=AskResponse)
def ask_question(payload: AskRequest):
    question = (payload.question or "").strip()
    if not question:
        raise HTTPException(status_code=400, detail="Missing 'question'.")

    top_k = max(1, int(payload.top_k))
    started = perf_counter()

    # Retrieval; hits are expected to be (text, score) pairs.
    try:
        hits = rag.search(question, k=top_k)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Search failed: {e}")

    contexts = [text for text, _ in (hits or []) if text]
    if not contexts and getattr(rag, "last_added", None):
        contexts = rag.last_added[:top_k]

    if not contexts:
        elapsed_ms = int((perf_counter() - started) * 1000)
        stats.add_question(elapsed_ms, q=question)
        return AskResponse(
            answer="I couldn't find relevant information in the indexed documents for this question.",
            contexts=[],
        )

    # Synthesis is best-effort: any failure falls through to the extractive path.
    try:
        answer = rag.synthesize_answer(question, contexts) or ""
    except Exception:
        answer = ""

    if is_generic_answer(answer):
        answer = extractive_answer(question, contexts, max_chars=600)

    elapsed_ms = int((perf_counter() - started) * 1000)
    stats.add_question(elapsed_ms, q=question)
    return AskResponse(answer=answer.strip(), contexts=contexts)
251
 
252
# Reports the current chunk count together with the recorded questions,
# in the order the stats store keeps them.
@app.get("/get_history", response_model=HistoryResponse)
def get_history():
    chunk_count = len(rag.chunks)
    recent = [HistoryItem(**entry) for entry in stats.history]
    return HistoryResponse(total_chunks=chunk_count, history=recent)
258
 
259
# Dashboard metrics: usage counters from the stats store followed by the
# index-level fields that earlier clients already relied on.
@app.get("/stats")
def stats_endpoint():
    report = {
        "documents_indexed": stats.documents_indexed,
        "questions_answered": stats.questions_answered,
        "avg_ms": stats.avg_ms,
        "last7_questions": list(stats.last7_questions),
        "total_chunks": len(rag.chunks),
        # Index attributes are read defensively in case the index lacks them.
        "faiss_ntotal": int(getattr(rag.index, "ntotal", 0)),
        "model_dim": int(getattr(rag.index, "d", rag.embed_dim)),
        "last_added_chunks": len(getattr(rag, "last_added", [])),
        "version": app.version,
    }
    return report
273
 
 
282
  os.remove(p)
283
  except FileNotFoundError:
284
  pass
285
+ # also reset stats counters to avoid stale analytics
286
+ stats.documents_indexed = 0
287
+ stats.questions_answered = 0
288
+ stats.latencies_ms.clear()
289
+ stats.last7_questions = deque([0] * 7, maxlen=7)
290
+ stats.history.clear()
291
  return {"ok": True}
292
  except Exception as e:
293
  raise HTTPException(status_code=500, detail=str(e))