broadfield-dev committed on
Commit
fc56677
·
verified ·
1 Parent(s): 9376ac0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -43
app.py CHANGED
@@ -1,104 +1,78 @@
 
 
1
  import sys
2
  import subprocess
3
  from flask import Flask, render_template, request, flash, redirect, url_for, jsonify
4
  import torch
 
5
  from transformers import AutoTokenizer, AutoModel
6
  import os
7
  import chromadb
8
  from huggingface_hub import snapshot_download
9
 
10
- # --- 1. Initialize Flask App ---
11
  app = Flask(__name__)
12
  app.secret_key = os.urandom(24)
13
 
14
- # --- 2. Configuration & Resource Loading ---
15
- print("Starting application...")
16
-
17
- # --- Configuration (Must match build_rag.py) ---
18
  CHROMA_PATH = "chroma_db"
19
  COLLECTION_NAME = "bible_verses"
20
  MODEL_NAME = "google/embeddinggemma-300m"
21
  DATASET_REPO = "broadfield-dev/bible-chromadb-gemma"
22
- STATUS_FILE = "build_status.log" # File to track build status
23
 
24
- # --- Global variables for resources ---
25
  chroma_collection = None
26
  tokenizer = None
27
  embedding_model = None
28
 
29
  def load_resources():
30
- """Downloads the DB from the Hub if not present, then loads it and the model."""
31
  global chroma_collection, tokenizer, embedding_model
32
- if chroma_collection and embedding_model:
33
- return True
34
-
35
  print("Attempting to load resources...")
36
  try:
37
  if not os.path.exists(CHROMA_PATH) or not os.listdir(CHROMA_PATH):
38
  print(f"Local DB not found. Downloading from '{DATASET_REPO}'...")
39
- snapshot_download(
40
- repo_id=DATASET_REPO,
41
- repo_type="dataset",
42
- local_dir=CHROMA_PATH,
43
- local_dir_use_symlinks=False
44
- )
45
  print("Database files downloaded.")
46
  else:
47
  print("Local database files found.")
48
-
49
  client = chromadb.PersistentClient(path=CHROMA_PATH)
50
  collection = client.get_collection(name=COLLECTION_NAME)
51
  if collection.count() == 0:
52
  print("Warning: Database collection is empty.")
53
  return False
54
-
55
  chroma_collection = collection
56
  print(f"Successfully connected to DB with {collection.count()} items.")
57
-
58
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
59
  embedding_model = AutoModel.from_pretrained(MODEL_NAME)
60
  print(f"Embedding model '{MODEL_NAME}' loaded successfully.")
61
-
62
  return True
63
  except Exception as e:
64
- print(f"Could not load resources. The database may not be built yet or the repo is empty.")
65
  print(f"Error: {e}")
66
  return False
67
 
68
  resources_loaded = load_resources()
69
 
70
- # --- 3. Define App Routes ---
71
-
72
  @app.route('/')
73
  def home():
74
  return render_template('index.html')
75
 
76
  @app.route('/build-rag', methods=['POST'])
77
  def build_rag_route():
78
- """Triggers the build script and immediately responds."""
79
  try:
80
- # Clear old status and set to "In Progress"
81
- with open(STATUS_FILE, "w") as f:
82
- f.write("IN_PROGRESS: Starting build process...")
83
-
84
- # Start the build process in the background
85
  subprocess.Popen([sys.executable, "build_rag.py"])
86
-
87
  return jsonify({"status": "started"})
88
  except Exception as e:
89
- with open(STATUS_FILE, "w") as f:
90
- f.write(f"FAILED: Could not start process - {e}")
91
  return jsonify({"status": "error", "message": str(e)}), 500
92
 
93
  @app.route('/status')
94
  def status():
95
- """Endpoint for the frontend to poll for build status."""
96
- if not os.path.exists(STATUS_FILE):
97
- return jsonify({"status": "NOT_STARTED"})
98
-
99
- with open(STATUS_FILE, "r") as f:
100
- status_line = f.read().strip()
101
-
102
  status_code, _, message = status_line.partition(': ')
103
  return jsonify({"status": status_code, "message": message})
104
 
@@ -118,10 +92,12 @@ def search():
118
  inputs = tokenizer(user_query, return_tensors="pt")
119
  with torch.no_grad():
120
  outputs = embedding_model(**inputs)
121
- query_embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
 
 
122
 
123
  search_results = chroma_collection.query(
124
- query_embeddings=query_embedding.tolist(),
125
  n_results=10
126
  )
127
 
@@ -138,6 +114,5 @@ def search():
138
 
139
  return render_template('index.html', results=results_list, query=user_query)
140
 
141
- # --- 4. Run the App ---
142
  if __name__ == '__main__':
143
  app.run(host='0.0.0.0', port=7860)
 
1
+ # app.py (Updated with Normalization for the query)
2
+
3
  import sys
4
  import subprocess
5
  from flask import Flask, render_template, request, flash, redirect, url_for, jsonify
6
  import torch
7
+ import torch.nn.functional as F # Import the functional module
8
  from transformers import AutoTokenizer, AutoModel
9
  import os
10
  import chromadb
11
  from huggingface_hub import snapshot_download
12
 
13
+ # (App setup and load_resources function are unchanged)
14
  app = Flask(__name__)
15
  app.secret_key = os.urandom(24)
16
 
 
 
 
 
17
  CHROMA_PATH = "chroma_db"
18
  COLLECTION_NAME = "bible_verses"
19
  MODEL_NAME = "google/embeddinggemma-300m"
20
  DATASET_REPO = "broadfield-dev/bible-chromadb-gemma"
21
+ STATUS_FILE = "build_status.log"
22
 
 
23
  chroma_collection = None
24
  tokenizer = None
25
  embedding_model = None
26
 
27
  def load_resources():
28
+ # (This function is unchanged)
29
  global chroma_collection, tokenizer, embedding_model
30
+ if chroma_collection and embedding_model: return True
 
 
31
  print("Attempting to load resources...")
32
  try:
33
  if not os.path.exists(CHROMA_PATH) or not os.listdir(CHROMA_PATH):
34
  print(f"Local DB not found. Downloading from '{DATASET_REPO}'...")
35
+ snapshot_download(repo_id=DATASET_REPO, repo_type="dataset", local_dir=CHROMA_PATH, local_dir_use_symlinks=False)
 
 
 
 
 
36
  print("Database files downloaded.")
37
  else:
38
  print("Local database files found.")
 
39
  client = chromadb.PersistentClient(path=CHROMA_PATH)
40
  collection = client.get_collection(name=COLLECTION_NAME)
41
  if collection.count() == 0:
42
  print("Warning: Database collection is empty.")
43
  return False
 
44
  chroma_collection = collection
45
  print(f"Successfully connected to DB with {collection.count()} items.")
 
46
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
47
  embedding_model = AutoModel.from_pretrained(MODEL_NAME)
48
  print(f"Embedding model '{MODEL_NAME}' loaded successfully.")
 
49
  return True
50
  except Exception as e:
51
+ print(f"Could not load resources. DB may not be built or repo is empty.")
52
  print(f"Error: {e}")
53
  return False
54
 
55
  resources_loaded = load_resources()
56
 
57
+ # (home, build_rag_route, and status routes are unchanged)
 
58
  @app.route('/')
59
  def home():
60
  return render_template('index.html')
61
 
62
  @app.route('/build-rag', methods=['POST'])
63
  def build_rag_route():
 
64
  try:
65
+ with open(STATUS_FILE, "w") as f: f.write("IN_PROGRESS: Starting build process...")
 
 
 
 
66
  subprocess.Popen([sys.executable, "build_rag.py"])
 
67
  return jsonify({"status": "started"})
68
  except Exception as e:
69
+ with open(STATUS_FILE, "w") as f: f.write(f"FAILED: Could not start process - {e}")
 
70
  return jsonify({"status": "error", "message": str(e)}), 500
71
 
72
  @app.route('/status')
73
  def status():
74
+ if not os.path.exists(STATUS_FILE): return jsonify({"status": "NOT_STARTED"})
75
+ with open(STATUS_FILE, "r") as f: status_line = f.read().strip()
 
 
 
 
 
76
  status_code, _, message = status_line.partition(': ')
77
  return jsonify({"status": status_code, "message": message})
78
 
 
92
  inputs = tokenizer(user_query, return_tensors="pt")
93
  with torch.no_grad():
94
  outputs = embedding_model(**inputs)
95
+
96
+ # *** FIX: NORMALIZE THE QUERY EMBEDDING ***
97
+ query_embedding = F.normalize(outputs.last_hidden_state.mean(dim=1), p=2, dim=1)
98
 
99
  search_results = chroma_collection.query(
100
+ query_embeddings=query_embedding.cpu().tolist(),
101
  n_results=10
102
  )
103
 
 
114
 
115
  return render_template('index.html', results=results_list, query=user_query)
116
 
 
117
  if __name__ == '__main__':
118
  app.run(host='0.0.0.0', port=7860)