HemanM committed (verified)
Commit 9ee54df · Parent(s): 1849681

Create rag_search.py

Files changed (1)
  1. rag_search.py  +106 -0
rag_search.py (new file, +106 lines):
"""
Step 4: Retrieval helper (loads FAISS + metadata and searches top-k chunks).

What this module provides:
- RAGSearcher: class that loads the FAISS index and metadata created by indexer.py
- search(query, k): returns a list of hit dicts [{score, text, meta}]
- summarize_hits(hits): tiny, extractive-style summary (placeholder for Step 5 Evo)
- format_sources(hits): collapses to a neat "Sources:" list
"""

from pathlib import Path
import json
from typing import List, Dict

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

# Paths must match indexer.py
DATA_DIR = Path("data")
INDEX_PATH = DATA_DIR / "index.faiss"
META_PATH = DATA_DIR / "meta.json"
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"


class RAGSearcher:
    """
    Loads the FAISS index + metadata and performs semantic search.
    If files are missing, it raises a RuntimeError (the UI will catch this
    and show a friendly message).
    """

    def __init__(self):
        if not INDEX_PATH.exists() or not META_PATH.exists():
            raise RuntimeError(
                "Index not found. Build it first with the 'Build/Refresh Index' button."
            )
        # Load FAISS index and metadata
        self.index = faiss.read_index(str(INDEX_PATH))
        self.metas: List[Dict] = json.loads(META_PATH.read_text(encoding="utf-8"))
        # Load the embedding model (small + fast)
        self.model = SentenceTransformer(EMBED_MODEL)

    def search(self, query: str, k: int = 6) -> List[Dict]:
        """
        Returns top-k hits with score, text, and meta fields.
        - score ~ cosine similarity (because we normalized at indexing time)
        """
        if not query or len(query.strip()) < 3:
            return []

        # Encode the query to the same space used by the index
        qvec = self.model.encode(
            [query], convert_to_numpy=True, normalize_embeddings=True
        )
        scores, idxs = self.index.search(qvec, k)

        hits: List[Dict] = []
        for score, idx in zip(scores[0], idxs[0]):
            if idx < 0:
                continue
            meta = self.metas[int(idx)]
            text = Path(meta["chunk_file"]).read_text(encoding="utf-8")
            hits.append(
                {
                    "score": float(score),
                    "text": text,
                    "meta": meta,  # contains: file, chunk_file, chunk_id
                }
            )
        return hits


def summarize_hits(hits: List[Dict], max_points: int = 4) -> str:
    """
    Very small, safe extractive "summary":
    - Take the first few hits and slice the first ~350 chars of each as bullet points.
    - This is a placeholder. In Step 5, we'll replace it with Evo synthesis.
    """
    if not hits:
        return "I couldn't find relevant information. Try rephrasing your question."
    bullets = []
    for h in hits[:max_points]:
        snippet = " ".join(h["text"].strip().split())
        if len(snippet) > 350:
            snippet = snippet[:350] + "..."
        bullets.append(f"- {snippet}")
    return "\n".join(bullets)


def format_sources(hits: List[Dict], max_files: int = 5) -> str:
    """
    Collapses the hit list to unique source files, and returns a short bulleted list.
    """
    if not hits:
        return "Sources: (none)"
    order = []  # unique source files, in first-seen order
    for h in hits:
        f = h["meta"]["file"]
        if f not in order:
            order.append(f)
        if len(order) >= max_files:
            break
    bullets = [f"- `{Path(f).name}`" for f in order]
    return "Sources:\n" + "\n".join(bullets)