broadfield-dev committed
Commit 9376ac0 · verified · 1 parent: 04eac3c

Update build_rag.py

Files changed (1):
  1. build_rag.py +24 -13
build_rag.py CHANGED
@@ -1,7 +1,10 @@
+# build_rag.py (Updated with Normalization and Cosine Distance)
+
 import json
 import os
 import pandas as pd
 import torch
+import torch.nn.functional as F  # Import the functional module
 from transformers import AutoTokenizer, AutoModel
 import chromadb
 import sys
@@ -18,6 +21,7 @@ STATUS_FILE = "build_status.log"
 JSON_DIRECTORY = 'bible_json'
 CHUNK_SIZE = 3
 EMBEDDING_BATCH_SIZE = 16
+# (BOOK_ID_TO_NAME dictionary remains the same)
 BOOK_ID_TO_NAME = {
     1: "Genesis", 2: "Exodus", 3: "Leviticus", 4: "Numbers", 5: "Deuteronomy",
     6: "Joshua", 7: "Judges", 8: "Ruth", 9: "1 Samuel", 10: "2 Samuel",
@@ -36,13 +40,12 @@ BOOK_ID_TO_NAME = {
 }
 
 def update_status(message):
-    """Writes a new status to the log file."""
-    print(message) # Also print to Space logs
+    print(message)
     with open(STATUS_FILE, "w") as f:
         f.write(message)
 
 def process_bible_json_files(directory_path: str, chunk_size: int) -> pd.DataFrame:
-    # (This function's internal logic is unchanged)
+    # (This function is unchanged)
     all_verses = []
     if not os.path.exists(directory_path) or not os.listdir(directory_path):
         raise FileNotFoundError(f"Directory '{directory_path}' is empty or does not exist.")
@@ -72,7 +75,6 @@ def process_bible_json_files(directory_path: str, chunk_size: int) -> pd.DataFra
     return pd.DataFrame(all_chunks)
 
 def main():
-    """Main build process."""
     update_status("IN_PROGRESS: Step 1/5 - Processing JSON files...")
     bible_chunks_df = process_bible_json_files(JSON_DIRECTORY, chunk_size=CHUNK_SIZE)
 
@@ -81,29 +83,38 @@ def main():
         import shutil
         shutil.rmtree(CHROMA_PATH)
     client = chromadb.PersistentClient(path=CHROMA_PATH)
-    collection = client.create_collection(name=COLLECTION_NAME)
+
+    # *** FIX 1: SET THE DISTANCE FUNCTION FOR THE COLLECTION ***
+    collection = client.create_collection(
+        name=COLLECTION_NAME,
+        metadata={"hnsw:space": "cosine"}  # Use cosine distance
+    )
 
     update_status(f"IN_PROGRESS: Step 3/5 - Loading embedding model '{MODEL_NAME}'...")
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
     model = AutoModel.from_pretrained(MODEL_NAME, device_map="auto")
 
-    update_status("IN_PROGRESS: Step 4/5 - Generating embeddings and populating database...")
-    total_chunks = len(bible_chunks_df)
-    for i in tqdm(range(0, total_chunks, EMBEDDING_BATCH_SIZE), desc="Embedding Chunks"):
+    update_status("IN_PROGRESS: Step 4/5 - Generating and NORMALIZING embeddings...")
+    for i in tqdm(range(0, len(bible_chunks_df), EMBEDDING_BATCH_SIZE), desc="Embedding Chunks"):
         batch_df = bible_chunks_df.iloc[i:i+EMBEDDING_BATCH_SIZE]
         texts = batch_df['text'].tolist()
+
         inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
         with torch.no_grad():
             outputs = model(**inputs)
-        embeddings = outputs.last_hidden_state.mean(dim=1).cpu().tolist()
+
+        # *** FIX 2: NORMALIZE THE EMBEDDINGS ***
+        embeddings = F.normalize(outputs.last_hidden_state.mean(dim=1), p=2, dim=1)
+
         collection.add(
             ids=[str(j) for j in range(i, i + len(batch_df))],
-            embeddings=embeddings,
+            embeddings=embeddings.cpu().tolist(),  # Convert to list after normalization
             documents=texts,
             metadatas=batch_df[['reference', 'version']].to_dict('records')
         )
 
     update_status(f"IN_PROGRESS: Step 5/5 - Pushing database to Hugging Face Hub '{DATASET_REPO}'...")
+    # (This part is unchanged)
     create_repo(repo_id=DATASET_REPO, repo_type="dataset", exist_ok=True)
     api = HfApi()
     api.upload_folder(
@@ -118,10 +129,10 @@ if __name__ == "__main__":
     try:
         main()
     except Exception as e:
+        # (Error handling is unchanged)
         error_message = traceback.format_exc()
-        # Be specific about token errors
         if "401" in str(e) or "Unauthorized" in str(e):
-            update_status("FAILED: Hugging Face authentication error. Please ensure your HF_TOKEN secret is set correctly and has WRITE permissions.")
+            update_status("FAILED: Hugging Face authentication error. Ensure your HF_TOKEN secret has WRITE permissions.")
         else:
-            update_status(f"FAILED: An unexpected error occurred. Check Space logs for details. Error: {e}")
+            update_status(f"FAILED: An unexpected error occurred. Check Space logs. Error: {e}")
         print(error_message, file=sys.stderr)
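One caveat about the pooling in Fix 2, which this commit does not address: because the tokenizer pads each batch, outputs.last_hidden_state.mean(dim=1) averages over padding positions too, so a chunk's embedding can shift slightly depending on which texts share its batch. A common remedy is an attention-mask-weighted mean; the helper below is a hypothetical sketch, not code from this repository:

import torch
import torch.nn.functional as F

def masked_mean_pool(last_hidden_state: torch.Tensor,
                     attention_mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper (not in this commit): zero out pad positions,
    # then divide by the count of real tokens so padding no longer
    # dilutes the average, and L2-normalize as in Fix 2.
    mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)  # (B, T, 1)
    summed = (last_hidden_state * mask).sum(dim=1)                  # (B, H)
    counts = mask.sum(dim=1).clamp(min=1e-9)                        # (B, 1)
    return F.normalize(summed / counts, p=2, dim=1)

# Possible drop-in replacement for the F.normalize(...) line in the loop:
#   embeddings = masked_mean_pool(outputs.last_hidden_state, inputs["attention_mask"])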
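On the consumer side, the cosine space only pays off if queries are embedded with the same mean-pool-then-normalize recipe before calling collection.query. A minimal query-side sketch under that assumption follows; it is not part of this commit, and the constant values shown are placeholders that must match the real MODEL_NAME, CHROMA_PATH, and COLLECTION_NAME defined at the top of build_rag.py, outside these hunks:

# query_rag.py -- hypothetical companion sketch, not part of this commit.
# The three constants below are placeholders; use the values from build_rag.py.
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
import chromadb

MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"  # placeholder
CHROMA_PATH = "chroma_db"                              # placeholder
COLLECTION_NAME = "bible"                              # placeholder

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

def embed_query(text: str) -> list:
    # Same recipe as the build script: mean-pool the last hidden state,
    # then L2-normalize so cosine distances are meaningful.
    inputs = tokenizer([text], return_tensors="pt", padding=True,
                       truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    pooled = outputs.last_hidden_state.mean(dim=1)
    return F.normalize(pooled, p=2, dim=1)[0].tolist()

client = chromadb.PersistentClient(path=CHROMA_PATH)
collection = client.get_collection(name=COLLECTION_NAME)
results = collection.query(
    query_embeddings=[embed_query("love your neighbor")],
    n_results=5,
)
for meta, dist in zip(results["metadatas"][0], results["distances"][0]):
    print(meta["reference"], meta["version"], round(dist, 4))

Since ChromaDB persists the hnsw:space setting with the collection, get_collection reuses it automatically; the returned distances are cosine distances in [0, 2], with 0 meaning the query and chunk point in the same direction.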