GenAIDevTOProd committed
Commit 35c5459 · verified · Parent(s): b22198e

Upload 2 files

Files changed (2):
  1. app.py +132 -0
  2. main.py +441 -0
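
Neither file ships a dependency manifest; judging from the imports, an environment along these lines is assumed (package names inferred, not part of the commit):

pip install gradio fastapi uvicorn pydantic sentence-transformers faiss-cpu numpy pandas pyarrow
pip install pyspark datasets  # optional: Spark chunking and /bulk_load_hf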
app.py ADDED
@@ -0,0 +1,132 @@
+ import gradio as gr
+ import numpy as np
+ import faiss
+ from sentence_transformers import SentenceTransformer
+
+ # --- minimal core (in-memory only) ---
+ MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
+ _model = SentenceTransformer(MODEL_NAME)
+ _dim = int(_model.encode(["_probe_"], convert_to_numpy=True).shape[1])  # 384
+
+ _index = faiss.IndexFlatIP(_dim)  # cosine similarity via inner product on L2-normalized vectors
+ _ids, _texts, _metas = [], [], []
+
+ def _normalize(v: np.ndarray) -> np.ndarray:
+     n = np.linalg.norm(v, axis=1, keepdims=True) + 1e-12
+     return (v / n).astype("float32")
+
+ def _chunk(text: str, size: int, overlap: int):
+     # character sliding window over whitespace-normalized text
+     t = " ".join((text or "").split())
+     n = len(t); s = 0; out = []
+     if overlap >= size: overlap = max(size - 1, 0)  # ensure forward progress
+     while s < n:
+         e = min(s + size, n)
+         out.append((t[s:e], s, e))
+         if e == n: break
+         s = max(e - overlap, 0)
+     return out
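
A quick sanity check of the window arithmetic (inputs are made up): with size=10 and overlap=3 each step advances size - overlap = 7 characters, so a 20-character string yields three chunks:

# hypothetical check of _chunk(): steps advance by size - overlap = 7 chars
chunks = _chunk("abcdefghijklmnopqrst", size=10, overlap=3)
assert chunks == [("abcdefghij", 0, 10), ("hijklmnopq", 7, 17), ("opqrst", 14, 20)]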
+
+ def reset():
+     global _index, _ids, _texts, _metas
+     _index = faiss.IndexFlatIP(_dim)
+     _ids, _texts, _metas = [], [], []
+     return gr.update(value="Index reset."), gr.update(value=0)
+
+ def load_sample():
+     docs = [
+         ("a", "PySpark scales ETL across clusters.", {"tag": "spark"}),
+         ("b", "FAISS powers fast vector similarity search used in retrieval.", {"tag": "faiss"})
+     ]
+     return "\n".join([d[1] for d in docs])
+
+ def ingest(docs_text, size, overlap):
+     if not docs_text.strip():
+         return "Provide at least one line of text.", len(_ids)
+     # one document per line
+     lines = [ln.strip() for ln in docs_text.splitlines() if ln.strip()]
+     rows = []
+     for i, ln in enumerate(lines):
+         pid = f"doc-{len(_ids)}-{i}"
+         for ctext, s, e in _chunk(ln, size, overlap):
+             rows.append((f"{pid}::offset:{s}-{e}", ctext, {"parent_id": pid, "start": s, "end": e}))
+     if not rows:
+         return "No chunks produced.", len(_ids)
+     vecs = _normalize(_model.encode([r[1] for r in rows], convert_to_numpy=True))
+     _index.add(vecs)
+     for rid, txt, meta in rows:
+         _ids.append(rid); _texts.append(txt); _metas.append(meta)
+     return f"Ingested docs={len(lines)} chunks={len(rows)}", len(_ids)
+
+ def answer(q, k, max_context_chars):
+     if _index.ntotal == 0:
+         return {"answer": "Index is empty. Ingest first.", "matches": []}
+     qv = _normalize(_model.encode([q], convert_to_numpy=True))
+     D, I = _index.search(qv, int(k))
+     matches = []
+     for i, s in zip(I[0].tolist(), D[0].tolist()):
+         if i < 0: continue
+         matches.append({
+             "id": _ids[i],
+             "score": float(s),
+             "text": _texts[i],
+             "meta": _metas[i]
+         })
+     # compose a simple extractive answer from the retrieved contexts
+     blob, total = [], 0
+     for m in matches:
+         t = m["text"]; cut = min(len(t), max_context_chars - total)
+         if cut <= 0: break
+         blob.append(t[:cut]); total += cut
+         if total >= max_context_chars: break
+     if not blob:
+         out = "No relevant context."
+     else:
+         lines = [ln for ln in " ".join(blob).split(". ") if ln]
+         hits = [ln for ln in lines if any(tok in ln.lower() for tok in q.lower().split())] or lines[:2]
+         out = "Based on retrieved context:\n- " + "\n- ".join(hits[:4])
+     return {"answer": out, "matches": matches}
+
+ with gr.Blocks(title="RAG-as-a-Service") as demo:
+     gr.Markdown("### RAG-as-a-Service - Gradio\nIn-memory FAISS + MiniLM; one-line-per-doc ingest; quick answers.")
+
+     with gr.Row():
+         with gr.Column():
+             docs = gr.Textbox(label="Documents (one per line)", lines=6, placeholder="One document per line…")
+             with gr.Row():
+                 chunk_size = gr.Slider(64, 1024, value=256, step=16, label="Chunk size")
+                 overlap = gr.Slider(0, 256, value=32, step=8, label="Overlap")
+             with gr.Row():
+                 ingest_btn = gr.Button("Ingest")
+                 sample_btn = gr.Button("Load sample")
+                 reset_btn = gr.Button("Reset")
+             ingest_status = gr.Textbox(label="Ingest status", interactive=False)
+             index_size = gr.Number(label="Index size", interactive=False, value=0)
+         with gr.Column():
+             q = gr.Textbox(label="Query", placeholder="Ask something...")
+             k = gr.Slider(1, 10, value=5, step=1, label="Top-K")
+             max_chars = gr.Slider(200, 4000, value=1000, step=100, label="Max context chars")
+             run = gr.Button("Answer")
+             out = gr.JSON(label="Answer + matches")
+
+     ingest_btn.click(
+         ingest,
+         [docs, chunk_size, overlap],
+         [ingest_status, index_size],
+         api_name="ingest"  # exposes POST /api/ingest
+     )
+     sample_btn.click(load_sample, None, docs)
+     reset_btn.click(
+         reset,
+         None,
+         [ingest_status, index_size],
+         api_name="reset"  # exposes POST /api/reset (optional)
+     )
+     run.click(
+         answer,
+         [q, k, max_chars],
+         out,
+         api_name="answer"  # exposes POST /api/answer
+     )
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
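
Since each handler is registered with an api_name, the demo can also be driven programmatically. A sketch using the official gradio_client package; the local URL and return shapes are assumptions, and the exact HTTP routes behind api_name vary across Gradio versions:

# sketch: driving the app headlessly via gradio_client (assumes a local launch)
from gradio_client import Client
client = Client("http://127.0.0.1:7860")  # a Spaces URL works here too
status, size = client.predict("PySpark scales ETL across clusters.", 256, 32, api_name="/ingest")
result = client.predict("what scales ETL?", 5, 1000, api_name="/answer")
print(status, size, result)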
main.py ADDED
@@ -0,0 +1,441 @@
+ from typing import Dict, List, Any, Optional, Tuple
+ from fastapi import FastAPI, HTTPException, Request, Depends
+ from fastapi.responses import HTMLResponse
+ from pydantic import BaseModel, Field
+ from pathlib import Path
+ import numpy as np, json, os, time, uuid, pandas as pd
+ from sentence_transformers import SentenceTransformer
+ import faiss
+
+ # optional engines
+ try:
+     from pyspark.sql import SparkSession, functions as F
+     from pyspark.sql.types import StringType
+     SPARK_AVAILABLE = True
+ except Exception:
+     SPARK_AVAILABLE = False
+ try:
+     from sentence_transformers import CrossEncoder
+     RERANK_AVAILABLE = True
+ except Exception:
+     RERANK_AVAILABLE = False
+
+ APP_VERSION = "1.0.0"
+ EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
+
+ DATA_DIR = Path("./data"); DATA_DIR.mkdir(parents=True, exist_ok=True)
+ INDEX_FP = DATA_DIR / "index.faiss"
+ META_FP = DATA_DIR / "meta.jsonl"
+ PARQ_FP = DATA_DIR / "meta.parquet"
+ CFG_FP = DATA_DIR / "store.json"
+
+ # --------- Schemas ----------
+ class EchoRequest(BaseModel):
+     message: str
+ class HealthResponse(BaseModel):
+     status: str; version: str; index_size: int = 0; model: str = ""; spark: bool = False
+     persisted: bool = False; rerank: bool = False; index_type: str = "flat"
+ class EmbedRequest(BaseModel):
+     texts: List[str] = Field(..., min_items=1); preview_n: int = Field(default=6, ge=0, le=32); normalize: bool = True
+ class EmbedResponse(BaseModel):
+     dim: int; count: int; preview: List[List[float]]
+ class Doc(BaseModel):
+     id: Optional[str] = None; text: str; meta: Dict[str, Any] = Field(default_factory=dict)
+ class ChunkConfig(BaseModel):
+     size: int = Field(default=800, gt=0); overlap: int = Field(default=120, ge=0)
+ class IngestRequest(BaseModel):
+     docs: List[Doc]; chunk: ChunkConfig = Field(default_factory=ChunkConfig); normalize: bool = True; use_spark: Optional[bool] = None
+ class Match(BaseModel):
+     id: str; score: float; text: Optional[str] = None; meta: Dict[str, Any] = Field(default_factory=dict)
+ class QueryRequest(BaseModel):
+     q: str; k: int = Field(default=5, ge=1, le=50); return_text: bool = True
+ class QueryResponse(BaseModel):
+     matches: List[Match]
+ class ExplainMatch(Match):
+     start: int; end: int; token_overlap: float
+ class ExplainRequest(QueryRequest): pass
+ class ExplainResponse(BaseModel):
+     matches: List[ExplainMatch]
+ class AnswerRequest(BaseModel):
+     q: str; k: int = Field(default=5, ge=1, le=50); model: str = Field(default="mock")
+     max_context_chars: int = Field(default=1600, ge=200, le=20000)
+     return_contexts: bool = True; rerank: bool = False
+     rerank_model: str = Field(default="cross-encoder/ms-marco-MiniLM-L-6-v2")
+ class AnswerResponse(BaseModel):
+     answer: str; contexts: List[Match] = []
+ class ReindexParams(BaseModel):
+     index_type: str = Field(default="flat", pattern="^(flat|ivf|hnsw)$")
+     nlist: int = Field(default=64, ge=1, le=65536); M: int = Field(default=32, ge=4, le=128)
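
Concretely, a minimal /ingest body implied by these models looks like this (values are illustrative; omitted fields fall back to the schema defaults above):

# illustrative IngestRequest payload; omitted fields use the schema defaults
payload = {
    "docs": [{"id": "doc-1", "text": "FAISS powers vector search.", "meta": {"tag": "faiss"}}],
    "chunk": {"size": 800, "overlap": 120},  # ChunkConfig defaults
    "normalize": True,
    "use_spark": None,  # auto-select: Spark when available, else pure Python
}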
+
+ # --------- Embeddings ----------
+ class LazyEmbedder:
+     def __init__(self, model_name: str = EMBED_MODEL_NAME):
+         self.model_name = model_name; self._model: Optional[SentenceTransformer] = None; self._dim: Optional[int] = None
+     def _ensure(self):
+         if self._model is None:
+             self._model = SentenceTransformer(self.model_name)
+             self._dim = int(self._model.encode(["_probe_"], convert_to_numpy=True).shape[1])  # type: ignore
+     @property
+     def dim(self) -> int:
+         self._ensure(); return int(self._dim)  # type: ignore
+     def encode(self, texts: List[str], normalize: bool = True) -> np.ndarray:
+         self._ensure()
+         vecs = self._model.encode(texts, batch_size=32, show_progress_bar=False, convert_to_numpy=True)  # type: ignore
+         if normalize:
+             norms = np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-12
+             vecs = vecs / norms
+         return vecs.astype("float32")
+ _embedder = LazyEmbedder()
+
+ # --------- Reranker ----------
+ class LazyReranker:
+     def __init__(self): self._model = None; self._name = None
+     def ensure(self, name: str):
+         if not RERANK_AVAILABLE: return
+         if self._model is None or self._name != name:
+             self._model = CrossEncoder(name); self._name = name
+     def score(self, q: str, texts: List[str]) -> List[float]:
+         if not RERANK_AVAILABLE or self._model is None: return [0.0] * len(texts)
+         return [float(s) for s in self._model.predict([(q, t) for t in texts])]  # type: ignore
+ _reranker = LazyReranker()
+
+ # --------- Chunking ----------
+ def chunk_text_py(text: str, size: int, overlap: int):
+     t = " ".join((text or "").split()); n = len(t); out = []; s = 0
+     if overlap >= size: overlap = max(size - 1, 0)
+     while s < n:
+         e = min(s + size, n); out.append((t[s:e], (s, e)))
+         if e == n: break
+         s = max(e - overlap, 0)
+     return out
+ def spark_clean_and_chunk(docs: List[Doc], size: int, overlap: int):
+     if not SPARK_AVAILABLE: raise RuntimeError("Spark not available")
+     spark = SparkSession.builder.appName("RAG-ETL").getOrCreate()
+     import json as _j
+     rows = [{"id": d.id or f"doc-{i}", "text": d.text, "meta_json": _j.dumps(d.meta)} for i, d in enumerate(docs)]
+     df = spark.createDataFrame(rows).withColumn("text", F.regexp_replace(F.col("text"), r"\s+", " ")).withColumn("text", F.trim(F.col("text"))).filter(F.length("text") > 0)
+     sz, ov = int(size), int(overlap)
+     if ov >= sz: ov = max(sz - 1, 0)
+     @F.udf(returnType=StringType())
+     def chunk_udf(text: str, pid: str, meta_json: str) -> str:
+         t = " ".join((text or "").split()); n = len(t); s = 0; base = _j.loads(meta_json) if meta_json else {}; out = []
+         while s < n:
+             e = min(s + sz, n); cid = f"{pid}::offset:{s}-{e}"; m = dict(base); m.update({"parent_id": pid, "start": s, "end": e})
+             # store meta as a JSON string so the array<map<string,string>> parse below
+             # succeeds; a nested object would make from_json yield null entries
+             out.append({"id": cid, "text": t[s:e], "meta": _j.dumps(m)})
+             if e == n: break
+             s = max(e - ov, 0)
+         return _j.dumps(out)
+     df = df.withColumn("chunks_json", chunk_udf(F.col("text"), F.col("id"), F.col("meta_json")))
+     exploded = df.select(F.explode(F.from_json("chunks_json", "array<map<string,string>>")).alias("c"))
+     out = exploded.select(F.col("c")["id"].alias("id"), F.col("c")["text"].alias("text"), F.col("c")["meta"].alias("meta_json")).collect()
+     import json as _j2
+     return [{"id": r["id"], "text": r["text"], "meta": _j2.loads(r["meta_json"]) if r["meta_json"] else {}} for r in out]
+
+ # --------- Vector index ----------
+ class VectorIndex:
+     def __init__(self, dim: int, index_type: str = "flat", nlist: int = 64, M: int = 32):
+         self.dim = dim; self.type = index_type; self.metric = "ip"; self.nlist = nlist; self.M = M
+         if index_type == "flat":
+             self.index = faiss.IndexFlatIP(dim)
+         elif index_type == "ivf":
+             quant = faiss.IndexFlatIP(dim)
+             self.index = faiss.IndexIVFFlat(quant, dim, max(1, nlist), faiss.METRIC_INNER_PRODUCT)
+         elif index_type == "hnsw":
+             self.index = faiss.IndexHNSWFlat(dim, max(4, M)); self.metric = "l2"
+         else:
+             raise ValueError("bad index_type")
+     def train(self, vecs: np.ndarray):
+         if hasattr(self.index, "is_trained") and not self.index.is_trained:
+             self.index.train(vecs)
+     def add(self, vecs: np.ndarray):
+         self.train(vecs); self.index.add(vecs)
+     def search(self, qvec: np.ndarray, k: int):
+         D, I = self.index.search(qvec, k)
+         # HNSW returns squared L2 distances; for unit vectors cos = 1 - d^2/2
+         scores = (1.0 - 0.5 * D[0]).tolist() if self.metric == "l2" else D[0].tolist()
+         return I[0].tolist(), scores
+     def save(self, fp: Path): faiss.write_index(self.index, str(fp))
+     @staticmethod
+     def load(fp: Path) -> "VectorIndex":
+         idx = faiss.read_index(str(fp))
+         vi = VectorIndex(idx.d, "flat"); vi.index = idx
+         vi.metric = "ip" if isinstance(idx, faiss.IndexFlatIP) or "IVF" in str(type(idx)) else "l2"
+         return vi
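
The train() guard above only matters for IVF indexes. A standalone check of that behavior (dimension and counts are arbitrary):

# standalone sketch of the is_trained guard (sizes are arbitrary)
import faiss, numpy as np
xs = np.random.rand(1000, 384).astype("float32")
faiss.normalize_L2(xs)  # cosine via inner product, as in VectorIndex
ivf = faiss.IndexIVFFlat(faiss.IndexFlatIP(384), 384, 64, faiss.METRIC_INNER_PRODUCT)
assert not ivf.is_trained           # IVF must see training vectors first
ivf.train(xs); ivf.add(xs)
flat = faiss.IndexFlatIP(384)
assert flat.is_trained              # flat indexes are trained from the start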
+
+ # --------- Store ----------
+ class MemoryIndex:
+     def __init__(self, dim: int, index_type: str = "flat", nlist: int = 64, M: int = 32):
+         self.ids: List[str] = []; self.texts: List[str] = []; self.metas: List[Dict[str, Any]] = []
+         self.vindex = VectorIndex(dim, index_type=index_type, nlist=nlist, M=M)
+     def add(self, vecs: np.ndarray, rows: List[Dict[str, Any]]):
+         if vecs.shape[0] != len(rows): raise ValueError("Vector count != row count")
+         self.vindex.add(vecs)
+         for r in rows: self.ids.append(r["id"]); self.texts.append(r["text"]); self.metas.append(r["meta"])
+     def size(self) -> int: return self.vindex.index.ntotal
+     def search(self, qvec: np.ndarray, k: int): return self.vindex.search(qvec, k)
+     def save(self):
+         self.vindex.save(INDEX_FP)
+         with META_FP.open("w", encoding="utf-8") as f:
+             for i in range(len(self.ids)):
+                 f.write(json.dumps({"id": self.ids[i], "text": self.texts[i], "meta": self.metas[i]}) + "\n")
+         try:
+             df = pd.DataFrame({"id": self.ids, "text": self.texts, "meta_json": [json.dumps(m) for m in self.metas]})
+             df.to_parquet(PARQ_FP, index=False)
+         except Exception:
+             pass
+         CFG_FP.write_text(json.dumps({"model": EMBED_MODEL_NAME, "dim": _embedder.dim, "index_type": self.vindex.type, "nlist": self.vindex.nlist, "M": self.vindex.M}), encoding="utf-8")
+     @staticmethod
+     def load_if_exists() -> Optional["MemoryIndex"]:
+         if not INDEX_FP.exists() or not META_FP.exists(): return None
+         cfg = {"index_type": "flat", "nlist": 64, "M": 32}
+         if CFG_FP.exists():
+             try: cfg.update(json.loads(CFG_FP.read_text()))
+             except Exception: pass
+         vi = VectorIndex.load(INDEX_FP)
+         store = MemoryIndex(dim=vi.dim, index_type=cfg.get("index_type", "flat"), nlist=cfg.get("nlist", 64), M=cfg.get("M", 32))
+         store.vindex = vi
+         ids, texts, metas = [], [], []
+         with META_FP.open("r", encoding="utf-8") as f:
+             for line in f:
+                 rec = json.loads(line); ids.append(rec["id"]); texts.append(rec["text"]); metas.append(rec.get("meta", {}))
+         store.ids, store.texts, store.metas = ids, texts, metas
+         return store
+     @staticmethod
+     def reset_files():
+         for p in [INDEX_FP, META_FP, PARQ_FP, CFG_FP]:
+             try:
+                 if p.exists(): p.unlink()
+             except Exception:
+                 pass
+
+ _mem_store: Optional[MemoryIndex] = MemoryIndex.load_if_exists()
+ def require_store() -> MemoryIndex:
+     if _mem_store is None or _mem_store.size() == 0:
+         raise HTTPException(status_code=400, detail="Index empty. Ingest documents first.")
+     return _mem_store
+
+ # --------- Helpers ----------
+ def _token_overlap(q: str, txt: str) -> float:
+     qt = {t for t in q.lower().split() if t}; tt = {t for t in (txt or "").lower().split() if t}
+     if not qt: return 0.0
+     return float(len(qt & tt)) / float(len(qt))
+ def _topk(q: str, k: int) -> List[Match]:
+     store = require_store(); qvec = _embedder.encode([q], normalize=True)
+     idxs, scores = store.search(qvec, k); out = []
+     for i, s in zip(idxs, scores):
+         if i == -1: continue
+         out.append(Match(id=store.ids[i], score=float(s), text=store.texts[i], meta=store.metas[i]))
+     return out
+ def _compose_contexts(matches: List[Match], max_chars: int) -> str:
+     buf = []; total = 0
+     for m in matches:
+         t = m.text or ""; cut = min(len(t), max_chars - total)
+         if cut <= 0: break
+         buf.append(t[:cut]); total += cut
+         if total >= max_chars: break
+     return "\n\n".join(buf).strip()
+ def _answer_with_mock(q: str, contexts: str) -> str:
+     if not contexts: return "No indexed context available to answer the question."
+     lines = [ln.strip() for ln in contexts.split("\n") if ln.strip()]
+     hits = [ln for ln in lines if any(t in ln.lower() for t in q.lower().split())]
+     if not hits: hits = lines[:2]
+     return "Based on retrieved context, here’s a concise answer:\n- " + "\n- ".join(hits[:4])
+ def _maybe_rerank(q: str, matches: List[Match], enabled: bool, model_name: str) -> List[Match]:
+     if not enabled: return matches
+     try:
+         _reranker.ensure(model_name)
+         scores = _reranker.score(q, [m.text or "" for m in matches])
+         order = sorted(range(len(matches)), key=lambda i: scores[i], reverse=True)
+         return [matches[i] for i in order]
+     except Exception:
+         return matches
+ def _write_parquet_if_missing():
+     if not PARQ_FP.exists() and META_FP.exists():
+         try:
+             rows = [json.loads(line) for line in META_FP.open("r", encoding="utf-8")]
+             if rows:
+                 pd.DataFrame({"id": [r["id"] for r in rows],
+                               "text": [r["text"] for r in rows],
+                               "meta_json": [json.dumps(r.get("meta", {})) for r in rows]}).to_parquet(PARQ_FP, index=False)
+         except Exception:
+             pass
+
+ # --------- Auth/limits/metrics ----------
+ API_KEY = os.getenv("API_KEY", "")
+ _rate = {"capacity": 60, "refill_per_sec": 1.0}
+ _buckets: Dict[str, Dict[str, float]] = {}
+ _metrics = {"requests": 0, "by_endpoint": {}, "started": time.time()}
+
+ def _allow(ip: str) -> bool:
+     now = time.time(); b = _buckets.get(ip, {"tokens": _rate["capacity"], "ts": now})
+     tokens = min(b["tokens"] + (now - b["ts"]) * _rate["refill_per_sec"], _rate["capacity"])
+     if tokens < 1.0:
+         _buckets[ip] = {"tokens": tokens, "ts": now}; return False
+     _buckets[ip] = {"tokens": tokens - 1.0, "ts": now}; return True
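
In effect each client IP can burst up to 60 requests and then sustain roughly one per second. A self-contained replay of the same token-bucket arithmetic (constants copied from _rate, timestamps made up):

# self-contained simulation of the bucket above (no real clock needed)
capacity, refill = 60.0, 1.0
tokens, ts = capacity, 0.0
def allow(now: float) -> bool:
    global tokens, ts
    tokens = min(tokens + (now - ts) * refill, capacity); ts = now
    if tokens < 1.0: return False
    tokens -= 1.0; return True
print(sum(allow(0.0) for _ in range(70)))  # 60: the burst capacity
print(allow(0.5))   # False: only half a token refilled so far
print(allow(1.0))   # True: another 0.5 s adds the missing half token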
+ async def guard(request: Request):
+     if API_KEY and request.headers.get("x-api-key", "") != API_KEY:
+         raise HTTPException(status_code=401, detail="invalid api key")
+     ip = request.client.host if request.client else "local"
+     if not _allow(ip):
+         raise HTTPException(status_code=429, detail="rate limited")
+
+ app = FastAPI(title="RAG-as-a-Service", version=APP_VERSION, description="Steps 10–13")
+
+ @app.middleware("http")
+ async def req_meta(request: Request, call_next):
+     rid = str(uuid.uuid4()); _metrics["requests"] += 1
+     ep = f"{request.method} {request.url.path}"; _metrics["by_endpoint"][ep] = _metrics["by_endpoint"].get(ep, 0) + 1
+     resp = await call_next(request)
+     try: resp.headers["x-request-id"] = rid
+     except Exception: pass
+     return resp
+
+ # --------- API ----------
+ @app.get("/", response_class=HTMLResponse)
+ def root():
+     return """<!doctype html><html><head><meta charset="utf-8"><title>RAG-as-a-Service</title></head>
+ <body style="font-family:system-ui;margin:2rem;max-width:900px">
+ <h2>RAG-as-a-Service</h2>
+ <input id="q" style="width:70%" placeholder="Ask a question"><button onclick="ask()">Ask</button>
+ <pre id="out" style="background:#111;color:#eee;padding:1rem;border-radius:8px;white-space:pre-wrap"></pre>
+ <script>
+ async function ask(){
+   const q=document.getElementById('q').value;
+   const res=await fetch('/answer',{method:'POST',headers:{'content-type':'application/json'},body:JSON.stringify({q, k:5, return_contexts:true})});
+   document.getElementById('out').textContent=JSON.stringify(await res.json(),null,2);
+ }
+ </script></body></html>"""
+
+ @app.get("/health", response_model=HealthResponse)
+ def health() -> HealthResponse:
+     size = _mem_store.size() if _mem_store is not None else 0
+     persisted = INDEX_FP.exists() and META_FP.exists()
+     idx_type = "flat"
+     if CFG_FP.exists():
+         try: idx_type = json.loads(CFG_FP.read_text()).get("index_type", "flat")
+         except Exception: pass
+     return HealthResponse(status="ok", version=APP_VERSION, index_size=size, model=EMBED_MODEL_NAME, spark=SPARK_AVAILABLE, persisted=persisted, rerank=RERANK_AVAILABLE, index_type=idx_type)
+
+ @app.get("/metrics")
+ def metrics():
+     up = time.time() - _metrics["started"]
+     return {"requests": _metrics["requests"], "by_endpoint": _metrics["by_endpoint"], "uptime_sec": round(up, 2)}
+
+ @app.post("/echo", dependencies=[Depends(guard)])
+ def echo(payload: EchoRequest) -> Dict[str, str]:
+     return {"echo": payload.message, "length": str(len(payload.message))}
+
+ @app.post("/embed", response_model=EmbedResponse, dependencies=[Depends(guard)])
+ def embed(payload: EmbedRequest) -> EmbedResponse:
+     vecs = _embedder.encode(payload.texts, normalize=payload.normalize)
+     preview = [[float(round(v, 5)) for v in row[:payload.preview_n]] for row in vecs] if payload.preview_n > 0 else []
+     return EmbedResponse(dim=int(vecs.shape[1]), count=int(vecs.shape[0]), preview=preview)
+
+ @app.post("/ingest", dependencies=[Depends(guard)])
+ def ingest(req: IngestRequest) -> Dict[str, Any]:
+     global _mem_store
+     if _mem_store is None:
+         cfg = {"index_type": "flat", "nlist": 64, "M": 32}
+         if CFG_FP.exists():
+             try: cfg.update(json.loads(CFG_FP.read_text()))
+             except Exception: pass
+         _mem_store = MemoryIndex(dim=_embedder.dim, index_type=cfg["index_type"], nlist=cfg["nlist"], M=cfg["M"])
+     use_spark = SPARK_AVAILABLE if req.use_spark is None else bool(req.use_spark)
+     rows = []; engine = "python"
+     if use_spark:
+         try:
+             rows = spark_clean_and_chunk(req.docs, size=req.chunk.size, overlap=req.chunk.overlap); engine = "spark"
+         except Exception:
+             rows = []  # fall back to the pure-Python chunker below
+     if not rows:
+         engine = "python"  # report the engine that actually produced the chunks
+         for d in req.docs:
+             pid = d.id or "doc"
+             for ctext, (s, e) in chunk_text_py(d.text, size=req.chunk.size, overlap=req.chunk.overlap):
+                 meta = dict(d.meta); meta.update({"parent_id": pid, "start": s, "end": e})
+                 rows.append({"id": f"{pid}::offset:{s}-{e}", "text": ctext, "meta": meta})
+     if not rows: raise HTTPException(status_code=400, detail="No non-empty chunks produced")
+     vecs = _embedder.encode([r["text"] for r in rows], normalize=req.normalize)
+     _mem_store.add(vecs, rows); _mem_store.save()
+     if not PARQ_FP.exists():
+         try:
+             pd.DataFrame({"id": [r["id"] for r in rows], "text": [r["text"] for r in rows], "meta_json": [json.dumps(r["meta"]) for r in rows]}).to_parquet(PARQ_FP, index=False)
+         except Exception: pass
+     return {"docs": len(req.docs), "chunks": len(rows), "index_size": _mem_store.size(), "engine": engine, "persisted": True}
+
+ @app.post("/query", response_model=QueryResponse, dependencies=[Depends(guard)])
+ def query(req: QueryRequest) -> QueryResponse:
+     matches = _topk(req.q, req.k)
+     if not req.return_text: matches = [Match(id=m.id, score=m.score, text=None, meta=m.meta) for m in matches]
+     return QueryResponse(matches=matches)
+
+ @app.post("/explain", response_model=ExplainResponse, dependencies=[Depends(guard)])
+ def explain(req: ExplainRequest) -> ExplainResponse:
+     matches = _topk(req.q, req.k); out = []
+     for m in matches:
+         meta = m.meta; start = int(meta.get("start", 0)); end = int(meta.get("end", 0))
+         out.append(ExplainMatch(id=m.id, score=m.score, text=m.text if req.return_text else None, meta=meta, start=start, end=end, token_overlap=float(round(_token_overlap(req.q, m.text or ""), 4))))
+     return ExplainResponse(matches=out)
+
+ @app.post("/answer", response_model=AnswerResponse, dependencies=[Depends(guard)])
+ def answer(req: AnswerRequest) -> AnswerResponse:
+     matches = _topk(req.q, req.k)
+     matches = _maybe_rerank(req.q, matches, enabled=req.rerank, model_name=req.rerank_model)
+     ctx = _compose_contexts(matches, req.max_context_chars)
+     out = _answer_with_mock(req.q, ctx)  # only the "mock" extractive answerer is implemented; other model values behave the same
+     return AnswerResponse(answer=out, contexts=matches if req.return_contexts else [])
+
+ @app.post("/reindex", dependencies=[Depends(guard)])
+ def reindex(params: ReindexParams) -> Dict[str, Any]:
+     global _mem_store
+     if not META_FP.exists():
+         raise HTTPException(status_code=400, detail="no metadata on disk")
+
+     rows = [json.loads(line) for line in META_FP.open("r", encoding="utf-8")]
+     if not rows:
+         raise HTTPException(status_code=400, detail="empty metadata")
+
+     texts = [r["text"] for r in rows]
+     vecs = _embedder.encode(texts, normalize=True)
+
+     # Cap nlist to dataset size for IVF
+     idx_type = params.index_type
+     eff_nlist = params.nlist
+     if idx_type == "ivf":
+         eff_nlist = max(1, min(eff_nlist, len(rows)))
+
+     try:
+         _mem_store = MemoryIndex(dim=_embedder.dim, index_type=idx_type, nlist=eff_nlist, M=params.M)
+         _mem_store.add(vecs, [{"id": r["id"], "text": r["text"], "meta": r.get("meta", {})} for r in rows])
+         _mem_store.save()
+         return {
+             "reindexed": True,
+             "index_type": idx_type,
+             "index_size": _mem_store.size(),
+             "nlist": eff_nlist,
+             "M": params.M
+         }
+     except Exception as e:
+         # Fall back to flat if IVF/HNSW training/add fails for any reason
+         _mem_store = MemoryIndex(dim=_embedder.dim, index_type="flat")
+         _mem_store.add(vecs, [{"id": r["id"], "text": r["text"], "meta": r.get("meta", {})} for r in rows])
+         _mem_store.save()
+         return {
+             "reindexed": True,
+             "index_type": "flat",
+             "index_size": _mem_store.size(),
+             "note": f"fallback due to: {str(e)[:120]}"
+         }
+
+ @app.post("/reset", dependencies=[Depends(guard)])
+ def reset() -> Dict[str, Any]:
+     global _mem_store; _mem_store = None; MemoryIndex.reset_files(); return {"reset": True}
+
+ @app.post("/bulk_load_hf", dependencies=[Depends(guard)])
+ def bulk_load_hf(repo: str, split: str = "train", text_field: str = "text", id_field: Optional[str] = None, meta_fields: Optional[List[str]] = None, chunk_size: int = 800, overlap: int = 120):
+     try:
+         from datasets import load_dataset
+         ds = load_dataset(repo, split=split)
+         docs = []
+         for rec in ds:
+             rid = str(rec[id_field]) if id_field and id_field in rec else None
+             meta = {k: rec[k] for k in (meta_fields or []) if k in rec}
+             docs.append(Doc(id=rid, text=str(rec[text_field]), meta=meta))
+         return ingest(IngestRequest(docs=docs, chunk=ChunkConfig(size=chunk_size, overlap=overlap), normalize=True))
+     except Exception as e:
+         raise HTTPException(status_code=400, detail=f"bulk_load_hf failed: {e}")
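
A hedged end-to-end smoke test against a local run (assumes `uvicorn main:app --port 8000`; the x-api-key header and its value are only relevant when the API_KEY env var is set):

# end-to-end sketch; endpoint names mirror the routes defined above
import requests
base = "http://127.0.0.1:8000"
headers = {"x-api-key": "change-me"}  # only enforced if API_KEY is set
print(requests.get(f"{base}/health").json())
requests.post(f"{base}/ingest", headers=headers, json={
    "docs": [{"id": "a", "text": "FAISS powers fast vector similarity search.", "meta": {"tag": "faiss"}}],
    "chunk": {"size": 800, "overlap": 120},
})
print(requests.post(f"{base}/query", headers=headers, json={"q": "what powers vector search?", "k": 3}).json())
print(requests.post(f"{base}/answer", headers=headers, json={"q": "what powers vector search?", "rerank": False}).json())
print(requests.post(f"{base}/reindex", headers=headers, json={"index_type": "ivf", "nlist": 8}).json())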