Upload 2 files
app.py
ADDED
@@ -0,0 +1,132 @@
import gradio as gr
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer

# --- minimal core (in-memory only) ---
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
_model = SentenceTransformer(MODEL_NAME)
_dim = int(_model.encode(["_probe_"], convert_to_numpy=True).shape[1])  # 384

_index = faiss.IndexFlatIP(_dim)  # cosine via L2-normalized inner product
_ids, _texts, _metas = [], [], []

def _normalize(v: np.ndarray) -> np.ndarray:
    n = np.linalg.norm(v, axis=1, keepdims=True) + 1e-12
    return (v / n).astype("float32")

def _chunk(text: str, size: int, overlap: int):
    t = " ".join((text or "").split())
    n = len(t); s = 0; out = []
    if overlap >= size: overlap = max(size - 1, 0)
    while s < n:
        e = min(s + size, n)
        out.append((t[s:e], s, e))
        if e == n: break
        s = max(e - overlap, 0)
    return out
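A quick sanity check of the sliding-window chunker above (illustrative only; windows advance by size - overlap characters):

    for chunk, start, end in _chunk("abcdefghijklmnopqrst", size=10, overlap=3):
        print(start, end, chunk)
    # -> 0 10 abcdefghij | 7 17 hijklmnopq | 14 20 opqrst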

def reset():
    global _index, _ids, _texts, _metas
    _index = faiss.IndexFlatIP(_dim)
    _ids, _texts, _metas = [], [], []
    return gr.update(value="Index reset."), gr.update(value=0)

def load_sample():
    docs = [
        ("a", "PySpark scales ETL across clusters.", {"tag": "spark"}),
        ("b", "FAISS powers fast vector similarity search used in retrieval.", {"tag": "faiss"})
    ]
    return "\n".join([d[1] for d in docs])

def ingest(docs_text, size, overlap):
    if not docs_text.strip():
        return "Provide at least one line of text.", len(_ids)
    size, overlap = int(size), int(overlap)  # sliders may deliver floats; slicing needs ints
    # one document per line
    lines = [ln.strip() for ln in docs_text.splitlines() if ln.strip()]
    rows = []
    for i, ln in enumerate(lines):
        pid = f"doc-{len(_ids)}-{i}"
        for ctext, s, e in _chunk(ln, size, overlap):
            rows.append((f"{pid}::offset:{s}-{e}", ctext, {"parent_id": pid, "start": s, "end": e}))
    if not rows:
        return "No chunks produced.", len(_ids)
    vecs = _normalize(_model.encode([r[1] for r in rows], convert_to_numpy=True))
    _index.add(vecs)
    for rid, txt, meta in rows:
        _ids.append(rid); _texts.append(txt); _metas.append(meta)
    return f"Ingested docs={len(lines)} chunks={len(rows)}", len(_ids)

def answer(q, k, max_context_chars):
    if _index.ntotal == 0:
        return {"answer": "Index is empty. Ingest first.", "matches": []}
    k, max_context_chars = int(k), int(max_context_chars)  # slider values may be floats
    qv = _normalize(_model.encode([q], convert_to_numpy=True))
    D, I = _index.search(qv, k)
    matches = []
    for i, s in zip(I[0].tolist(), D[0].tolist()):
        if i < 0: continue
        matches.append({
            "id": _ids[i],
            "score": float(s),
            "text": _texts[i],
            "meta": _metas[i]
        })
    # compose a simple answer from the retrieved contexts
    blob, total = [], 0
    for m in matches:
        t = m["text"]; cut = min(len(t), max_context_chars - total)
        if cut <= 0: break
        blob.append(t[:cut]); total += cut
        if total >= max_context_chars: break
    if not blob:
        out = "No relevant context."
    else:
        lines = [ln for ln in " ".join(blob).split(". ") if ln]
        hits = [ln for ln in lines if any(tok in ln.lower() for tok in q.lower().split())] or lines[:2]
        out = "Based on retrieved context:\n- " + "\n- ".join(hits[:4])
    return {"answer": out, "matches": matches}

with gr.Blocks(title="RAG-as-a-Service") as demo:
    gr.Markdown("### RAG-as-a-Service - Gradio\nIn-memory FAISS + MiniLM; one-line-per-doc ingest; quick answers.")

    with gr.Row():
        with gr.Column():
            docs = gr.Textbox(label="Documents (one per line)", lines=6, placeholder="One document per line…")
            with gr.Row():
                chunk_size = gr.Slider(64, 1024, value=256, step=16, label="Chunk size")
                overlap = gr.Slider(0, 256, value=32, step=8, label="Overlap")
            with gr.Row():
                ingest_btn = gr.Button("Ingest")
                sample_btn = gr.Button("Load sample")
                reset_btn = gr.Button("Reset")
            ingest_status = gr.Textbox(label="Ingest status", interactive=False)
            index_size = gr.Number(label="Index size", interactive=False, value=0)
        with gr.Column():
            q = gr.Textbox(label="Query", placeholder="Ask something...")
            k = gr.Slider(1, 10, value=5, step=1, label="Top-K")
            max_chars = gr.Slider(200, 4000, value=1000, step=100, label="Max context chars")
            run = gr.Button("Answer")
            out = gr.JSON(label="Answer + matches")

    ingest_btn.click(
        ingest,
        [docs, chunk_size, overlap],
        [ingest_status, index_size],
        api_name="ingest"  # exposes POST /api/ingest
    )
    sample_btn.click(load_sample, None, docs)
    reset_btn.click(
        reset,
        None,
        [ingest_status, index_size],
        api_name="reset"  # exposes POST /api/reset (optional)
    )
    run.click(
        answer,
        [q, k, max_chars],
        out,
        api_name="answer"  # exposes POST /api/answer
    )

if __name__ == "__main__":
    demo.launch(share=True)
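For reference, the named endpoints can be driven from Python with gradio_client (a sketch; the URL assumes a local launch, and the argument order mirrors the input lists wired above):

    from gradio_client import Client

    client = Client("http://127.0.0.1:7860")  # or the printed share URL
    print(client.predict("Doc one.\nDoc two.", 256, 32, api_name="/ingest"))
    print(client.predict("vector search", 5, 1000, api_name="/answer"))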
main.py
ADDED
@@ -0,0 +1,441 @@
from typing import Dict, List, Any, Optional, Tuple
from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Field
from pathlib import Path
import numpy as np, json, os, time, uuid, pandas as pd
from sentence_transformers import SentenceTransformer
import faiss

# optional engines
try:
    from pyspark.sql import SparkSession, functions as F
    from pyspark.sql.types import StringType
    SPARK_AVAILABLE = True
except Exception:
    SPARK_AVAILABLE = False
try:
    from sentence_transformers import CrossEncoder
    RERANK_AVAILABLE = True
except Exception:
    RERANK_AVAILABLE = False

APP_VERSION = "1.0.0"
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

DATA_DIR = Path("./data"); DATA_DIR.mkdir(parents=True, exist_ok=True)
INDEX_FP = DATA_DIR / "index.faiss"
META_FP = DATA_DIR / "meta.jsonl"
PARQ_FP = DATA_DIR / "meta.parquet"
CFG_FP = DATA_DIR / "store.json"

# --------- Schemas ----------
class EchoRequest(BaseModel):
    message: str
class HealthResponse(BaseModel):
    status: str; version: str; index_size: int = 0; model: str = ""; spark: bool = False
    persisted: bool = False; rerank: bool = False; index_type: str = "flat"
class EmbedRequest(BaseModel):
    texts: List[str] = Field(..., min_items=1); preview_n: int = Field(default=6, ge=0, le=32); normalize: bool = True
class EmbedResponse(BaseModel):
    dim: int; count: int; preview: List[List[float]]
class Doc(BaseModel):
    id: Optional[str] = None; text: str; meta: Dict[str, Any] = Field(default_factory=dict)
class ChunkConfig(BaseModel):
    size: int = Field(default=800, gt=0); overlap: int = Field(default=120, ge=0)
class IngestRequest(BaseModel):
    docs: List[Doc]; chunk: ChunkConfig = Field(default_factory=ChunkConfig); normalize: bool = True; use_spark: Optional[bool] = None
class Match(BaseModel):
    id: str; score: float; text: Optional[str] = None; meta: Dict[str, Any] = Field(default_factory=dict)
class QueryRequest(BaseModel):
    q: str; k: int = Field(default=5, ge=1, le=50); return_text: bool = True
class QueryResponse(BaseModel):
    matches: List[Match]
class ExplainMatch(Match):
    start: int; end: int; token_overlap: float
class ExplainRequest(QueryRequest): pass
class ExplainResponse(BaseModel):
    matches: List[ExplainMatch]
class AnswerRequest(BaseModel):
    q: str; k: int = Field(default=5, ge=1, le=50); model: str = Field(default="mock")
    max_context_chars: int = Field(default=1600, ge=200, le=20000)
    return_contexts: bool = True; rerank: bool = False
    rerank_model: str = Field(default="cross-encoder/ms-marco-MiniLM-L-6-v2")
class AnswerResponse(BaseModel):
    answer: str; contexts: List[Match] = []
class ReindexParams(BaseModel):
    index_type: str = Field(default="flat", pattern="^(flat|ivf|hnsw)$")
    nlist: int = Field(default=64, ge=1, le=65536); M: int = Field(default=32, ge=4, le=128)

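For reference, a request body that satisfies IngestRequest above (illustrative values):

    {
      "docs": [
        {"id": "doc-1", "text": "FAISS powers fast vector similarity search.", "meta": {"tag": "faiss"}}
      ],
      "chunk": {"size": 200, "overlap": 40},
      "normalize": true,
      "use_spark": false
    }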
# --------- Embeddings ----------
class LazyEmbedder:
    def __init__(self, model_name: str = EMBED_MODEL_NAME):
        self.model_name = model_name; self._model: Optional[SentenceTransformer] = None; self._dim: Optional[int] = None
    def _ensure(self):
        if self._model is None:
            self._model = SentenceTransformer(self.model_name)
            self._dim = int(self._model.encode(["_probe_"], convert_to_numpy=True).shape[1])  # type: ignore
    @property
    def dim(self) -> int:
        self._ensure(); return int(self._dim)  # type: ignore
    def encode(self, texts: List[str], normalize: bool = True) -> np.ndarray:
        self._ensure()
        vecs = self._model.encode(texts, batch_size=32, show_progress_bar=False, convert_to_numpy=True)  # type: ignore
        if normalize:
            norms = np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-12
            vecs = vecs / norms
        return vecs.astype("float32")
_embedder = LazyEmbedder()

# --------- Reranker ----------
class LazyReranker:
    def __init__(self): self._model = None; self._name = None
    def ensure(self, name: str):
        if not RERANK_AVAILABLE: return
        if self._model is None or self._name != name:
            self._model = CrossEncoder(name); self._name = name
    def score(self, q: str, texts: List[str]) -> List[float]:
        if not RERANK_AVAILABLE or self._model is None: return [0.0]*len(texts)
        return [float(s) for s in self._model.predict([(q, t) for t in texts])]  # type: ignore
_reranker = LazyReranker()

# --------- Chunking ----------
def chunk_text_py(text: str, size: int, overlap: int):
    t = " ".join((text or "").split()); n = len(t); out = []; s = 0
    if overlap >= size: overlap = max(size - 1, 0)
    while s < n:
        e = min(s + size, n); out.append((t[s:e], (s, e)))
        if e == n: break
        s = max(e - overlap, 0)
    return out
def spark_clean_and_chunk(docs: List[Doc], size: int, overlap: int):
    if not SPARK_AVAILABLE: raise RuntimeError("Spark not available")
    spark = SparkSession.builder.appName("RAG-ETL").getOrCreate()
    import json as _j
    rows = [{"id": d.id or f"doc-{i}", "text": d.text, "meta_json": _j.dumps(d.meta)} for i, d in enumerate(docs)]
    df = spark.createDataFrame(rows).withColumn("text", F.regexp_replace(F.col("text"), r"\s+", " ")).withColumn("text", F.trim(F.col("text"))).filter(F.length("text") > 0)
    sz, ov = int(size), int(overlap)
    if ov >= sz: ov = max(sz - 1, 0)
    @F.udf(returnType=StringType())
    def chunk_udf(text: str, pid: str, meta_json: str) -> str:
        t = " ".join((text or "").split()); n = len(t); s = 0; base = _j.loads(meta_json) if meta_json else {}; out = []
        while s < n:
            e = min(s + sz, n); cid = f"{pid}::offset:{s}-{e}"; m = dict(base); m.update({"parent_id": pid, "start": s, "end": e})
            # serialize meta to a JSON string so the array<map<string,string>> schema below parses cleanly
            out.append({"id": cid, "text": t[s:e], "meta": _j.dumps(m)})
            if e == n: break
            s = max(e - ov, 0)
        return _j.dumps(out)
    df = df.withColumn("chunks_json", chunk_udf(F.col("text"), F.col("id"), F.col("meta_json")))
    exploded = df.select(F.explode(F.from_json("chunks_json", "array<map<string,string>>")).alias("c"))
    out = exploded.select(F.col("c")["id"].alias("id"), F.col("c")["text"].alias("text"), F.col("c")["meta"].alias("meta_json")).collect()
    import json as _j2
    return [{"id": r["id"], "text": r["text"], "meta": _j2.loads(r["meta_json"]) if r["meta_json"] else {}} for r in out]

# --------- Vector index ----------
class VectorIndex:
    def __init__(self, dim: int, index_type: str = "flat", nlist: int = 64, M: int = 32):
        self.dim = dim; self.type = index_type; self.metric = "ip"; self.nlist = nlist; self.M = M
        if index_type == "flat":
            self.index = faiss.IndexFlatIP(dim)
        elif index_type == "ivf":
            quant = faiss.IndexFlatIP(dim)
            self.index = faiss.IndexIVFFlat(quant, dim, max(1, nlist), faiss.METRIC_INNER_PRODUCT)
        elif index_type == "hnsw":
            self.index = faiss.IndexHNSWFlat(dim, max(4, M)); self.metric = "l2"
        else:
            raise ValueError("bad index_type")
    def train(self, vecs: np.ndarray):
        if hasattr(self.index, "is_trained") and not self.index.is_trained:
            self.index.train(vecs)
    def add(self, vecs: np.ndarray):
        self.train(vecs); self.index.add(vecs)
    def search(self, qvec: np.ndarray, k: int):
        D, I = self.index.search(qvec, k)
        # HNSW reports squared L2; for unit vectors 1 - d/2 recovers cosine similarity
        scores = (1.0 - 0.5*D[0]).tolist() if self.metric == "l2" else D[0].tolist()
        return I[0].tolist(), scores
    def save(self, fp: Path): faiss.write_index(self.index, str(fp))
    @staticmethod
    def load(fp: Path) -> "VectorIndex":
        idx = faiss.read_index(str(fp))
        vi = VectorIndex(idx.d, "flat"); vi.index = idx
        vi.metric = "ip" if isinstance(idx, faiss.IndexFlatIP) or "IVF" in str(type(idx)) else "l2"
        return vi

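A quick check of the HNSW score mapping in search above (a minimal sketch; it assumes unit-normalized vectors, which the embedder produces): for unit vectors the squared L2 distance is d = 2 - 2*cos, so cos = 1 - d/2.

    import numpy as np

    a = np.array([1.0, 0.0]); b = np.array([0.6, 0.8])  # both unit-norm
    d2 = float(np.sum((a - b) ** 2))   # squared L2, as IndexHNSWFlat reports
    cos = float(a @ b)                 # inner product of unit vectors = cosine
    assert abs((1.0 - 0.5 * d2) - cos) < 1e-9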
# --------- Store ----------
class MemoryIndex:
    def __init__(self, dim: int, index_type: str = "flat", nlist: int = 64, M: int = 32):
        self.ids: List[str] = []; self.texts: List[str] = []; self.metas: List[Dict[str, Any]] = []
        self.vindex = VectorIndex(dim, index_type=index_type, nlist=nlist, M=M)
    def add(self, vecs: np.ndarray, rows: List[Dict[str, Any]]):
        if vecs.shape[0] != len(rows): raise ValueError("Vector count != row count")
        self.vindex.add(vecs)
        for r in rows: self.ids.append(r["id"]); self.texts.append(r["text"]); self.metas.append(r["meta"])
    def size(self) -> int: return self.vindex.index.ntotal
    def search(self, qvec: np.ndarray, k: int): return self.vindex.search(qvec, k)
    def save(self):
        self.vindex.save(INDEX_FP)
        with META_FP.open("w", encoding="utf-8") as f:
            for i in range(len(self.ids)):
                f.write(json.dumps({"id": self.ids[i], "text": self.texts[i], "meta": self.metas[i]}) + "\n")
        try:
            df = pd.DataFrame({"id": self.ids, "text": self.texts, "meta_json": [json.dumps(m) for m in self.metas]})
            df.to_parquet(PARQ_FP, index=False)
        except Exception:
            pass
        CFG_FP.write_text(json.dumps({"model": EMBED_MODEL_NAME, "dim": _embedder.dim, "index_type": self.vindex.type, "nlist": self.vindex.nlist, "M": self.vindex.M}), encoding="utf-8")
    @staticmethod
    def load_if_exists() -> Optional["MemoryIndex"]:
        if not INDEX_FP.exists() or not META_FP.exists(): return None
        cfg = {"index_type": "flat", "nlist": 64, "M": 32}
        if CFG_FP.exists():
            try: cfg.update(json.loads(CFG_FP.read_text()))
            except Exception: pass
        vi = VectorIndex.load(INDEX_FP)
        store = MemoryIndex(dim=vi.dim, index_type=cfg.get("index_type", "flat"), nlist=cfg.get("nlist", 64), M=cfg.get("M", 32))
        store.vindex = vi
        ids, texts, metas = [], [], []
        with META_FP.open("r", encoding="utf-8") as f:
            for line in f:
                rec = json.loads(line); ids.append(rec["id"]); texts.append(rec["text"]); metas.append(rec.get("meta", {}))
        store.ids, store.texts, store.metas = ids, texts, metas
        return store
    @staticmethod
    def reset_files():
        for p in [INDEX_FP, META_FP, PARQ_FP, CFG_FP]:
            try:
                if p.exists(): p.unlink()
            except Exception:
                pass

_mem_store: Optional[MemoryIndex] = MemoryIndex.load_if_exists()
def require_store() -> MemoryIndex:
    if _mem_store is None or _mem_store.size() == 0:
        raise HTTPException(status_code=400, detail="Index empty. Ingest documents first.")
    return _mem_store

|
216 |
+
# --------- Helpers ----------
|
217 |
+
def _token_overlap(q: str, txt: str) -> float:
|
218 |
+
qt={t for t in q.lower().split() if t}; tt={t for t in (txt or "").lower().split() if t}
|
219 |
+
if not qt: return 0.0
|
220 |
+
return float(len(qt & tt))/float(len(qt))
|
221 |
+
def _topk(q: str, k: int) -> List[Match]:
|
222 |
+
store=require_store(); qvec=_embedder.encode([q], normalize=True)
|
223 |
+
idxs,scores=store.search(qvec,k); out=[]
|
224 |
+
for i,s in zip(idxs,scores):
|
225 |
+
if i==-1: continue
|
226 |
+
out.append(Match(id=store.ids[i], score=float(s), text=store.texts[i], meta=store.metas[i]))
|
227 |
+
return out
|
228 |
+
def _compose_contexts(matches: List[Match], max_chars: int) -> str:
|
229 |
+
buf=[]; total=0
|
230 |
+
for m in matches:
|
231 |
+
t=m.text or ""; cut=min(len(t), max_chars-total)
|
232 |
+
if cut<=0: break
|
233 |
+
buf.append(t[:cut]); total+=cut
|
234 |
+
if total>=max_chars: break
|
235 |
+
return "\n\n".join(buf).strip()
|
236 |
+
def _answer_with_mock(q: str, contexts: str) -> str:
|
237 |
+
if not contexts: return "No indexed context available to answer the question."
|
238 |
+
lines=[ln.strip() for ln in contexts.split("\n") if ln.strip()]
|
239 |
+
hits=[ln for ln in lines if any(t in ln.lower() for t in q.lower().split())]
|
240 |
+
if not hits: hits=lines[:2]
|
241 |
+
return "Based on retrieved context, here’s a concise answer:\n- " + "\n- ".join(hits[:4])
|
242 |
+
def _maybe_rerank(q: str, matches: List[Match], enabled: bool, model_name: str) -> List[Match]:
|
243 |
+
if not enabled: return matches
|
244 |
+
try:
|
245 |
+
_reranker.ensure(model_name)
|
246 |
+
scores=_reranker.score(q, [m.text or "" for m in matches])
|
247 |
+
order=sorted(range(len(matches)), key=lambda i: scores[i], reverse=True)
|
248 |
+
return [matches[i] for i in order]
|
249 |
+
except Exception:
|
250 |
+
return matches
|
251 |
+
def _write_parquet_if_missing():
|
252 |
+
if not PARQ_FP.exists() and META_FP.exists():
|
253 |
+
try:
|
254 |
+
rows=[json.loads(line) for line in META_FP.open("r",encoding="utf-8")]
|
255 |
+
if rows:
|
256 |
+
pd.DataFrame({"id":[r["id"] for r in rows],
|
257 |
+
"text":[r["text"] for r in rows],
|
258 |
+
"meta_json":[json.dumps(r.get("meta",{})) for r in rows]}).to_parquet(PARQ_FP,index=False)
|
259 |
+
except Exception:
|
260 |
+
pass
|
261 |
+
|
262 |
+
# --------- Auth/limits/metrics ----------
|
263 |
+
API_KEY = os.getenv("API_KEY","")
|
264 |
+
_rate = {"capacity":60,"refill_per_sec":1.0}
|
265 |
+
_buckets: Dict[str, Dict[str, float]] = {}
|
266 |
+
_metrics = {"requests":0,"by_endpoint":{}, "started": time.time()}
|
267 |
+
|
268 |
+
def _allow(ip: str) -> bool:
|
269 |
+
now=time.time(); b=_buckets.get(ip,{"tokens":_rate["capacity"],"ts":now})
|
270 |
+
tokens=min(b["tokens"]+(now-b["ts"])*_rate["refill_per_sec"], _rate["capacity"])
|
271 |
+
if tokens<1.0:
|
272 |
+
_buckets[ip]={"tokens":tokens,"ts":now}; return False
|
273 |
+
_buckets[ip]={"tokens":tokens-1.0,"ts":now}; return True
|
274 |
+
async def guard(request: Request):
|
275 |
+
if API_KEY and request.headers.get("x-api-key","")!=API_KEY:
|
276 |
+
raise HTTPException(status_code=401, detail="invalid api key")
|
277 |
+
ip=request.client.host if request.client else "local"
|
278 |
+
if not _allow(ip):
|
279 |
+
raise HTTPException(status_code=429, detail="rate limited")
|
280 |
+
|
app = FastAPI(title="RAG-as-a-Service", version=APP_VERSION, description="Steps 10–13")

@app.middleware("http")
async def req_meta(request: Request, call_next):
    rid = str(uuid.uuid4()); _metrics["requests"] += 1
    ep = f"{request.method} {request.url.path}"; _metrics["by_endpoint"][ep] = _metrics["by_endpoint"].get(ep, 0) + 1
    resp = await call_next(request)
    try: resp.headers["x-request-id"] = rid
    except Exception: pass
    return resp

# --------- API ----------
@app.get("/", response_class=HTMLResponse)
def root():
    return """<!doctype html><html><head><meta charset="utf-8"><title>RAG-as-a-Service</title></head>
<body style="font-family:system-ui;margin:2rem;max-width:900px">
<h2>RAG-as-a-Service</h2>
<input id="q" style="width:70%" placeholder="Ask a question"><button onclick="ask()">Ask</button>
<pre id="out" style="background:#111;color:#eee;padding:1rem;border-radius:8px;white-space:pre-wrap"></pre>
<script>
async function ask(){
  const q=document.getElementById('q').value;
  const res=await fetch('/answer',{method:'POST',headers:{'content-type':'application/json'},body:JSON.stringify({q, k:5, return_contexts:true})});
  document.getElementById('out').textContent=JSON.stringify(await res.json(),null,2);
}
</script></body></html>"""

@app.get("/health", response_model=HealthResponse)
def health() -> HealthResponse:
    size = _mem_store.size() if _mem_store is not None else 0
    persisted = INDEX_FP.exists() and META_FP.exists()
    idx_type = "flat"
    if CFG_FP.exists():
        try: idx_type = json.loads(CFG_FP.read_text()).get("index_type", "flat")
        except Exception: pass
    return HealthResponse(status="ok", version=APP_VERSION, index_size=size, model=EMBED_MODEL_NAME, spark=SPARK_AVAILABLE, persisted=persisted, rerank=RERANK_AVAILABLE, index_type=idx_type)

@app.get("/metrics")
def metrics():
    up = time.time() - _metrics["started"]
    return {"requests": _metrics["requests"], "by_endpoint": _metrics["by_endpoint"], "uptime_sec": round(up, 2)}

@app.post("/echo", dependencies=[Depends(guard)])
def echo(payload: EchoRequest) -> Dict[str, str]:
    return {"echo": payload.message, "length": str(len(payload.message))}

@app.post("/embed", response_model=EmbedResponse, dependencies=[Depends(guard)])
def embed(payload: EmbedRequest) -> EmbedResponse:
    vecs = _embedder.encode(payload.texts, normalize=payload.normalize)
    preview = [[float(round(v, 5)) for v in row[:payload.preview_n]] for row in vecs] if payload.preview_n > 0 else []
    return EmbedResponse(dim=int(vecs.shape[1]), count=int(vecs.shape[0]), preview=preview)

@app.post("/ingest", dependencies=[Depends(guard)])
def ingest(req: IngestRequest) -> Dict[str, Any]:
    global _mem_store
    if _mem_store is None:
        cfg = {"index_type": "flat", "nlist": 64, "M": 32}
        if CFG_FP.exists():
            try: cfg.update(json.loads(CFG_FP.read_text()))
            except Exception: pass
        _mem_store = MemoryIndex(dim=_embedder.dim, index_type=cfg["index_type"], nlist=cfg["nlist"], M=cfg["M"])
    use_spark = SPARK_AVAILABLE if req.use_spark is None else bool(req.use_spark)
    rows = []
    if use_spark:
        try: rows = spark_clean_and_chunk(req.docs, size=req.chunk.size, overlap=req.chunk.overlap)
        except Exception: rows = []
    if not rows:
        for d in req.docs:
            pid = d.id or "doc"
            for ctext, (s, e) in chunk_text_py(d.text, size=req.chunk.size, overlap=req.chunk.overlap):
                meta = dict(d.meta); meta.update({"parent_id": pid, "start": s, "end": e})
                rows.append({"id": f"{pid}::offset:{s}-{e}", "text": ctext, "meta": meta})
    if not rows: raise HTTPException(status_code=400, detail="No non-empty chunks produced")
    vecs = _embedder.encode([r["text"] for r in rows], normalize=req.normalize)
    _mem_store.add(vecs, rows); _mem_store.save()
    if not PARQ_FP.exists():
        try:
            pd.DataFrame({"id": [r["id"] for r in rows], "text": [r["text"] for r in rows], "meta_json": [json.dumps(r["meta"]) for r in rows]}).to_parquet(PARQ_FP, index=False)
        except Exception: pass
    return {"docs": len(req.docs), "chunks": len(rows), "index_size": _mem_store.size(), "engine": "spark" if use_spark else "python", "persisted": True}

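Exercising /ingest from Python, for illustration (assumes a local uvicorn run on port 8000 and no API_KEY set; otherwise pass an x-api-key header):

    import requests

    payload = {"docs": [{"id": "doc-1", "text": "FAISS powers fast vector similarity search.", "meta": {"tag": "faiss"}}],
               "chunk": {"size": 200, "overlap": 40}}
    r = requests.post("http://localhost:8000/ingest", json=payload)
    print(r.json())  # e.g. {"docs": 1, "chunks": 1, "index_size": ..., "engine": ..., "persisted": true}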
@app.post("/query", response_model=QueryResponse, dependencies=[Depends(guard)])
def query(req: QueryRequest) -> QueryResponse:
    matches = _topk(req.q, req.k)
    if not req.return_text: matches = [Match(id=m.id, score=m.score, text=None, meta=m.meta) for m in matches]
    return QueryResponse(matches=matches)

@app.post("/explain", response_model=ExplainResponse, dependencies=[Depends(guard)])
def explain(req: ExplainRequest) -> ExplainResponse:
    matches = _topk(req.q, req.k); out = []
    for m in matches:
        meta = m.meta; start = int(meta.get("start", 0)); end = int(meta.get("end", 0))
        out.append(ExplainMatch(id=m.id, score=m.score, text=m.text if req.return_text else None, meta=meta, start=start, end=end, token_overlap=float(round(_token_overlap(req.q, m.text or ""), 4))))
    return ExplainResponse(matches=out)

@app.post("/answer", response_model=AnswerResponse, dependencies=[Depends(guard)])
def answer(req: AnswerRequest) -> AnswerResponse:
    matches = _topk(req.q, req.k)
    matches = _maybe_rerank(req.q, matches, enabled=req.rerank, model_name=req.rerank_model)
    ctx = _compose_contexts(matches, req.max_context_chars)
    out = _answer_with_mock(req.q, ctx)  # only the mock composer exists; req.model currently has no effect
    return AnswerResponse(answer=out, contexts=matches if req.return_contexts else [])

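And the retrieval endpoints, under the same assumptions (local server, no API key):

    import requests

    print(requests.post("http://localhost:8000/query",
                        json={"q": "vector search", "k": 3}).json())
    print(requests.post("http://localhost:8000/answer",
                        json={"q": "vector search", "k": 5, "rerank": False}).json())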
@app.post("/reindex", dependencies=[Depends(guard)])
def reindex(params: ReindexParams) -> Dict[str, Any]:
    global _mem_store
    if not META_FP.exists():
        raise HTTPException(status_code=400, detail="no metadata on disk")

    rows = [json.loads(line) for line in META_FP.open("r", encoding="utf-8")]
    if not rows:
        raise HTTPException(status_code=400, detail="empty metadata")

    texts = [r["text"] for r in rows]
    vecs = _embedder.encode(texts, normalize=True)

    # Cap nlist to dataset size for IVF
    idx_type = params.index_type
    eff_nlist = params.nlist
    if idx_type == "ivf":
        eff_nlist = max(1, min(eff_nlist, len(rows)))

    try:
        _mem_store = MemoryIndex(dim=_embedder.dim, index_type=idx_type, nlist=eff_nlist, M=params.M)
        _mem_store.add(vecs, [{"id": r["id"], "text": r["text"], "meta": r.get("meta", {})} for r in rows])
        _mem_store.save()
        return {
            "reindexed": True,
            "index_type": idx_type,
            "index_size": _mem_store.size(),
            "nlist": eff_nlist,
            "M": params.M
        }
    except Exception as e:
        # Fall back to flat if IVF/HNSW training/add fails for any reason
        _mem_store = MemoryIndex(dim=_embedder.dim, index_type="flat")
        _mem_store.add(vecs, [{"id": r["id"], "text": r["text"], "meta": r.get("meta", {})} for r in rows])
        _mem_store.save()
        return {
            "reindexed": True,
            "index_type": "flat",
            "index_size": _mem_store.size(),
            "note": f"fallback due to: {str(e)[:120]}"
        }

@app.post("/reset", dependencies=[Depends(guard)])
def reset() -> Dict[str, Any]:
    global _mem_store; _mem_store = None; MemoryIndex.reset_files(); return {"reset": True}

@app.post("/bulk_load_hf", dependencies=[Depends(guard)])
def bulk_load_hf(repo: str, split: str = "train", text_field: str = "text", id_field: Optional[str] = None, meta_fields: Optional[List[str]] = None, chunk_size: int = 800, overlap: int = 120):
    try:
        from datasets import load_dataset
        ds = load_dataset(repo, split=split)
        docs = []
        for rec in ds:
            rid = str(rec[id_field]) if id_field and id_field in rec else None
            meta = {k: rec[k] for k in (meta_fields or []) if k in rec}
            docs.append(Doc(id=rid, text=str(rec[text_field]), meta=meta))
        return ingest(IngestRequest(docs=docs, chunk=ChunkConfig(size=chunk_size, overlap=overlap), normalize=True))
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"bulk_load_hf failed: {e}")
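Finally, rebuilding the index from the persisted meta.jsonl (illustrative; the server caps nlist to the row count for IVF):

    import requests

    r = requests.post("http://localhost:8000/reindex", json={"index_type": "ivf", "nlist": 32})
    print(r.json())  # {"reindexed": true, "index_type": "ivf", ...} or a flat fallback with a "note"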