Rong6693 committed on
Commit
7e02805
verified
1 Parent(s): 7849534

Create rag_utils.py

Browse files
Files changed (1) hide show
  1. rag_utils.py +98 -0
rag_utils.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # rag_utils.py
2
+ import os, json
3
+ import numpy as np
4
+ from typing import List, Tuple, Dict
5
+
6
+ import faiss
7
+ from sentence_transformers import SentenceTransformer
8
+
9
+ DATA_DIR = os.path.join(os.path.dirname(__file__), "data")  # "data" directory next to this module
10
+ CACHE_DIR = "/tmp" # HF Spaces: this location is writable on every startup
11
+ # Source JSON datasets read by the index builders.
12
+ TAROT_JSON = os.path.join(DATA_DIR, "tarot_data_full.json")
13
+ NUM_JSON = os.path.join(DATA_DIR, "numerology_data_full.json")
14
+ # Generated artifacts: one FAISS index file plus one JSON metadata file per dataset.
15
+ TAROT_IDX = os.path.join(CACHE_DIR, "faiss_tarot.index")
16
+ TAROT_META = os.path.join(CACHE_DIR, "faiss_tarot_meta.json")
17
+ NUM_IDX = os.path.join(CACHE_DIR, "faiss_num.index")
18
+ NUM_META = os.path.join(CACHE_DIR, "faiss_num_meta.json")
19
+
20
+ # Lightweight and easy-to-use embedding model
21
+ EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
22
+
23
# Module-wide cache for the embedding model; stays None until first requested.
_model = None


def get_model():
    """Return the shared SentenceTransformer instance, creating it on first use.

    The model named by EMBED_MODEL_NAME is constructed once and stored in the
    module-level ``_model``; later calls hand back that same object.
    """
    global _model
    if _model is not None:
        return _model
    _model = SentenceTransformer(EMBED_MODEL_NAME)
    return _model
29
+
30
def _write_index(records, label_key, index_path, meta_path):
    """Embed ``records`` and persist a FAISS index plus matching JSON metadata.

    Args:
        records: list of ``(position, label, text)`` tuples; ``text`` is what
            gets embedded, and the list order defines the row order in the index.
        label_key: metadata key under which each record's label is stored
            (``"card_name"`` or ``"number"``).
        index_path: destination file for the FAISS index.
        meta_path: destination file for the JSON metadata list.

    Raises:
        ValueError: if ``records`` is empty, since the embedding dimension
            cannot be determined from an empty batch.
    """
    if not records:
        raise ValueError(f"no records to index for {index_path}")

    model = get_model()
    embs = model.encode([text for _, _, text in records], normalize_embeddings=True)
    # FAISS expects a C-contiguous float32 matrix.
    vectors = np.ascontiguousarray(embs, dtype="float32")

    # Inner-product index over normalized embeddings.
    idx = faiss.IndexFlatIP(vectors.shape[1])
    idx.add(vectors)
    faiss.write_index(idx, index_path)

    meta = [{"i": i, label_key: label, "text": text} for i, label, text in records]
    # Explicit encoding so the file is read back identically on any locale.
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)


def _build_tarot():
    """Build the tarot index and metadata from TAROT_JSON.

    Each card's embedded text is its ``meaning_upright`` and ``advice`` fields
    joined by a space (missing fields count as empty). ``card_name`` is
    required; a card without it raises KeyError.
    """
    with open(TAROT_JSON, encoding="utf-8") as f:
        cards = json.load(f)

    records = [
        (
            i,
            card["card_name"],
            (card.get("meaning_upright", "") + " " + card.get("advice", "")).strip(),
        )
        for i, card in enumerate(cards)
    ]
    _write_index(records, "card_name", TAROT_IDX, TAROT_META)


def _build_num():
    """Build the numerology index and metadata from NUM_JSON.

    Each entry's embedded text is its number followed by the
    ``life_path_meaning`` and ``advice`` fields (missing fields count as
    empty). ``number`` is required; an entry without it raises KeyError.
    """
    with open(NUM_JSON, encoding="utf-8") as f:
        entries = json.load(f)

    records = [
        (
            i,
            entry["number"],
            (
                str(entry["number"])
                + " "
                + entry.get("life_path_meaning", "")
                + " "
                + entry.get("advice", "")
            ).strip(),
        )
        for i, entry in enumerate(entries)
    ]
    _write_index(records, "number", NUM_IDX, NUM_META)
67
+
68
def ensure_indexes():
    """Create any missing index/metadata pair under CACHE_DIR.

    A dataset is rebuilt when either its index file or its metadata file is
    absent. Tarot is checked first, then numerology.
    """
    os.makedirs(CACHE_DIR, exist_ok=True)
    targets = (
        (TAROT_IDX, TAROT_META, _build_tarot),
        (NUM_IDX, NUM_META, _build_num),
    )
    for index_file, meta_file, build in targets:
        if os.path.exists(index_file) and os.path.exists(meta_file):
            continue
        build()
74
+
75
def _search(index_path: str, meta_path: str, query: str, k: int = 3) -> List[Dict]:
    """Return up to ``k`` metadata records most similar to ``query``.

    Args:
        index_path: FAISS index file to load.
        meta_path: JSON metadata file whose list positions match index rows.
        query: free-text query to embed.
        k: maximum number of hits requested.

    Returns:
        A list of copies of the matching metadata dicts, best match first,
        each extended with ``"score"`` (float) and ``"rank"`` (1-based).
        Empty when ``k`` is not positive or the index holds no vectors.
    """
    idx = faiss.read_index(index_path)
    with open(meta_path, encoding="utf-8") as f:
        meta = json.load(f)

    # Asking for more neighbours than exist makes FAISS pad the labels with -1,
    # which as a Python list index would silently return the last record.
    k = min(k, idx.ntotal)
    if k <= 0:
        return []

    model = get_model()
    q = model.encode([query], normalize_embeddings=True).astype("float32")
    scores, positions = idx.search(q, k)

    results = []
    for score, j in zip(scores[0], positions[0]):
        j = int(j)
        if j < 0:  # defensive: padding label, not a real row
            continue
        entry = dict(meta[j])  # copy so the loaded metadata is not mutated
        entry["score"] = float(score)
        entry["rank"] = len(results) + 1
        results.append(entry)
    return results
91
+
92
def search_tarot(query: str, k: int = 3):
    """Search the tarot index for ``query`` and return the top ``k`` hits.

    Missing index files are built first via ``ensure_indexes``.
    """
    ensure_indexes()
    hits = _search(TAROT_IDX, TAROT_META, query, k)
    return hits
95
+
96
def search_numerology(query: str, k: int = 3):
    """Search the numerology index for ``query`` and return the top ``k`` hits.

    Missing index files are built first via ``ensure_indexes``.
    """
    ensure_indexes()
    hits = _search(NUM_IDX, NUM_META, query, k)
    return hits