Zwounds committed
Commit d93b2e5 · verified · 1 Parent(s): 93c51f9

Upload app.py

Files changed (1): app.py (+107 -88)
app.py CHANGED
@@ -12,6 +12,8 @@ from tqdm import tqdm
 from datasets import load_dataset
 import pandas as pd
 from sentence_transformers import SentenceTransformer
+# Import config if needed for EphemeralClient settings, though default might be fine
+import chromadb.config
 
 # --- Page Config (MUST BE FIRST Streamlit call) ---
 st.set_page_config(layout="wide")
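
Note: the chromadb.config import added here is consumed further down in this commit, where the in-memory client is constructed. A minimal standalone sketch of that pattern, assuming only that chromadb is installed:

import chromadb
import chromadb.config

# In-memory client, configured the same way as in the setup code later in this commit
client = chromadb.EphemeralClient(
    settings=chromadb.config.Settings(
        anonymized_telemetry=False,  # disable telemetry
        allow_reset=True,            # allow client.reset() during development
    )
)
print(client.heartbeat())  # liveness check; returns a nanosecond timestamp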
@@ -88,14 +90,12 @@ def load_dataset_from_hf():
         df = pd.read_parquet(parquet_path)
         logging.info(f"Dataset loaded into DataFrame with shape: {df.shape}")
 
-        # Verify required columns
         required_cols = ['id', 'document', 'embedding', 'metadata']
         if not all(col in df.columns for col in required_cols):
             st.error(f"Dataset Parquet file is missing required columns. Found: {df.columns}. Required: {required_cols}")
             logging.error(f"Dataset Parquet file missing required columns. Found: {df.columns}")
-            return None # Return None on error
+            return None
 
-        # Ensure embeddings are lists of floats
         logging.info("Ensuring embeddings are in list format...")
         if not df.empty and df['embedding'].iloc[0] is not None and (not isinstance(df['embedding'].iloc[0], list) or not isinstance(df['embedding'].iloc[0][0], float)):
             df['embedding'] = df['embedding'].apply(lambda x: list(map(float, x)) if isinstance(x, (np.ndarray, list)) else None)
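
Note: the check above converts NumPy arrays to plain lists of Python floats before they are handed to ChromaDB. A toy sketch of that conversion, with a hypothetical two-row frame standing in for the real Parquet data:

import numpy as np
import pandas as pd

# Hypothetical frame mimicking the Parquet schema used by load_dataset_from_hf()
df = pd.DataFrame({"embedding": [np.array([0.1, 0.2], dtype=np.float32), None]})
df["embedding"] = df["embedding"].apply(
    lambda x: list(map(float, x)) if isinstance(x, (np.ndarray, list)) else None
)
assert isinstance(df["embedding"].iloc[0], list)  # ndarray -> list of floats
assert df["embedding"].iloc[1] is None            # anything else -> None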
@@ -111,7 +111,7 @@ def load_dataset_from_hf():
         if df.empty:
             st.error("No valid data loaded from the dataset after processing embeddings.")
             logging.error("DataFrame empty after embedding processing.")
-            return None # Return None on error
+            return None
 
         return df
 
@@ -122,7 +122,7 @@ def load_dataset_from_hf():
         st.error(f"Failed to load data from dataset: {e}")
         logging.exception(f"An unexpected error occurred during data load: {e}")
 
-        return None # Return None on any error
+        return None
 
 # --- Initialize Clients and Models ---
 generation_client = initialize_hf_client()
@@ -130,90 +130,112 @@ embedding_model = load_local_embedding_model()
 # ---
 
 # --- Setup ChromaDB Collection (using Session State) ---
-if 'chroma_collection' not in st.session_state:
-    st.session_state.chroma_collection = None
-    if embedding_model and generation_client: # Only proceed if models/clients loaded
-        with st.spinner("Loading and preparing vector database..."):
-            df = load_dataset_from_hf()
-            if df is not None and not df.empty:
-                try:
-                    logging.info("Initializing Ephemeral ChromaDB client...")
-                    chroma_client = chromadb.EphemeralClient() # Use Ephemeral Client
-
-                    # Delete collection if it somehow exists (unlikely for ephemeral)
-                    try:
-                        chroma_client.delete_collection(name=COLLECTION_NAME)
-                        logging.info(f"Deleted existing collection (if any): {COLLECTION_NAME}")
-                    except: pass
-
-                    logging.info(f"Creating collection: {COLLECTION_NAME}")
-                    collection_instance = chroma_client.create_collection(
-                        name=COLLECTION_NAME,
-                        metadata={"hnsw:space": "cosine"}
-                    )
-
-                    logging.info(f"Adding {len(df)} documents to ChromaDB in batches of {ADD_BATCH_SIZE}...")
-                    start_time = time.time()
-                    error_count = 0
-                    num_batches = (len(df) + ADD_BATCH_SIZE - 1) // ADD_BATCH_SIZE
-
-                    for i in range(num_batches):
-                        start_idx = i * ADD_BATCH_SIZE
-                        end_idx = start_idx + ADD_BATCH_SIZE
-                        batch_df = df.iloc[start_idx:end_idx]
-
-                        try:
-                            # Prepare and clean metadata for the batch
-                            metadatas_list_raw = batch_df['metadata'].tolist()
-                            cleaned_metadatas = []
-                            for item in metadatas_list_raw:
-                                cleaned_dict = {}
-                                if isinstance(item, dict):
-                                    current_meta = item
-                                else:
-                                    try: current_meta = json.loads(item) if isinstance(item, str) else {}
-                                    except: current_meta = {}
-
-                                if isinstance(current_meta, dict):
-                                    for key, value in current_meta.items():
-                                        if value is None: cleaned_dict[key] = ""
-                                        elif isinstance(value, (str, int, float, bool)): cleaned_dict[key] = value
-                                        else:
-                                            try: cleaned_dict[key] = str(value)
-                                            except: pass # Skip unconvertible types
-                                cleaned_metadatas.append(cleaned_dict)
-
-                            # Add the batch
-                            collection_instance.add(
-                                ids=batch_df['id'].tolist(),
-                                embeddings=batch_df['embedding'].tolist(),
-                                documents=batch_df['document'].tolist(),
-                                metadatas=cleaned_metadatas
-                            )
-                        except Exception as e:
-                            logging.error(f"Error adding batch {i+1}/{num_batches} to Chroma: {e}")
-                            error_count += 1
-
-                    end_time = time.time()
-                    logging.info(f"Finished loading data into ChromaDB. Took {end_time - start_time:.2f} seconds.")
-                    if error_count > 0: logging.warning(f"Encountered errors in {error_count} batches during add.")
-
-                    final_count = collection_instance.count()
-                    logging.info(f"Final document count in Chroma collection: {final_count}")
-                    if final_count > 0:
-                        st.session_state.chroma_collection = collection_instance
-                        st.success("Vector database loaded successfully!")
-                    else:
-                        st.error("Failed to load documents into the vector database.")
-
-                except Exception as setup_e:
-                    st.error(f"Failed to setup ChromaDB: {setup_e}")
-                    logging.exception(f"Failed to setup ChromaDB: {setup_e}")
-            else:
-                st.error("Failed to load data from the dataset. Cannot initialize database.")
-
-# Assign collection from session state for use in the app
-collection = st.session_state.get('chroma_collection', None)
+# This function now attempts to load or create the collection and stores it in session state
+def setup_chroma_collection():
+    if 'chroma_collection' in st.session_state and st.session_state.chroma_collection is not None:
+        logging.info("Using existing Chroma collection from session state.")
+        return st.session_state.chroma_collection
+
+    # Proceed with setup only if essential components are loaded
+    if not embedding_model or not generation_client:
+        st.error("Cannot setup ChromaDB: Required models/clients failed to initialize.")
+        return None
+
+    with st.spinner("Loading and preparing vector database..."):
+        df = load_dataset_from_hf()
+        if df is None or df.empty:
+            st.error("Failed to load embedding data. Cannot initialize vector database.")
+            return None
+
+        try:
+            logging.info("Initializing Ephemeral ChromaDB client...")
+            # Use EphemeralClient explicitly
+            chroma_client = chromadb.EphemeralClient(
+                settings=chromadb.config.Settings(
+                    anonymized_telemetry=False, # Optional: Disable telemetry
+                    allow_reset=True # Optional: Allows resetting
+                )
+            )
+
+            # Check if collection exists and delete if it does (robustness)
+            try:
+                existing_collections = [col.name for col in chroma_client.list_collections()]
+                if COLLECTION_NAME in existing_collections:
+                    chroma_client.delete_collection(name=COLLECTION_NAME)
+                    logging.info(f"Deleted existing collection: {COLLECTION_NAME}")
+            except Exception as delete_e:
+                logging.warning(f"Could not check/delete existing collection (might be okay): {delete_e}")
+
+
+            logging.info(f"Creating collection: {COLLECTION_NAME}")
+            collection_instance = chroma_client.create_collection(
+                name=COLLECTION_NAME,
+                metadata={"hnsw:space": "cosine"} # No embedding function needed here
+            )
+
+            logging.info(f"Adding {len(df)} documents to ChromaDB in batches of {ADD_BATCH_SIZE}...")
+            start_time = time.time()
+            error_count = 0
+            num_batches = (len(df) + ADD_BATCH_SIZE - 1) // ADD_BATCH_SIZE
+
+            for i in range(num_batches):
+                start_idx = i * ADD_BATCH_SIZE
+                end_idx = start_idx + ADD_BATCH_SIZE
+                batch_df = df.iloc[start_idx:end_idx]
+
+                try:
+                    # Prepare and clean metadata for the batch
+                    metadatas_list_raw = batch_df['metadata'].tolist()
+                    cleaned_metadatas = []
+                    for item in metadatas_list_raw:
+                        cleaned_dict = {}
+                        current_meta = item if isinstance(item, dict) else {}
+                        if not isinstance(item, dict):
+                            try: current_meta = json.loads(item) if isinstance(item, str) else {}
+                            except: current_meta = {}
+
+                        if isinstance(current_meta, dict):
+                            for key, value in current_meta.items():
+                                if value is None: cleaned_dict[key] = ""
+                                elif isinstance(value, (str, int, float, bool)): cleaned_dict[key] = value
+                                else:
+                                    try: cleaned_dict[key] = str(value)
+                                    except: pass
+                        cleaned_metadatas.append(cleaned_dict)
+
+                    # Add the batch
+                    collection_instance.add(
+                        ids=batch_df['id'].tolist(),
+                        embeddings=batch_df['embedding'].tolist(),
+                        documents=batch_df['document'].tolist(),
+                        metadatas=cleaned_metadatas
+                    )
+                except Exception as e:
+                    logging.error(f"Error adding batch {i+1}/{num_batches} to Chroma: {e}")
+                    error_count += 1
+
+            end_time = time.time()
+            logging.info(f"Finished loading data into ChromaDB. Took {end_time - start_time:.2f} seconds.")
+            if error_count > 0: logging.warning(f"Encountered errors in {error_count} batches during add.")
+
+            final_count = collection_instance.count()
+            logging.info(f"Final document count in Chroma collection: {final_count}")
+            if final_count > 0:
+                st.session_state.chroma_collection = collection_instance
+                st.success("Vector database loaded successfully!")
+                return collection_instance
+            else:
+                st.error("Failed to load documents into the vector database.")
+                return None
+
+        except Exception as setup_e:
+            st.error(f"Failed to setup ChromaDB: {setup_e}")
+            logging.exception(f"Failed to setup ChromaDB: {setup_e}")
+            return None
+
+# --- Initialize collection ---
+# Call the setup function which populates session state if needed
+collection = setup_chroma_collection()
 # ---
 
 # --- Helper Functions ---
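
Note: the batched add in the new setup_chroma_collection relies on ceiling division so that a final partial batch is still covered. A self-contained sketch of the arithmetic; the ADD_BATCH_SIZE value below is an assumption, app.py defines its own constant:

ADD_BATCH_SIZE = 500  # assumed value for illustration only
n_rows = 1234

num_batches = (n_rows + ADD_BATCH_SIZE - 1) // ADD_BATCH_SIZE  # ceiling division
assert num_batches == 3  # batches of 500, 500, and 234 rows

for i in range(num_batches):
    start_idx = i * ADD_BATCH_SIZE
    end_idx = start_idx + ADD_BATCH_SIZE
    # df.iloc[start_idx:end_idx] clamps past the last row, so the short final batch is safe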
@@ -235,7 +257,6 @@ def query_hf_inference(prompt, client_instance=None, model_name=HF_GENERATION_MODEL
 
 def generate_query_variations(query, llm_func, model_name=HF_GENERATION_MODEL, num_variations=3):
     """Uses LLM (HF Inference API) to generate alternative phrasings."""
-    # ... (rest of function remains the same) ...
     prompt = f"""Given the user query: "{query}"
 Generate {num_variations} alternative phrasings or related queries someone might use to find the same information.
 Focus on synonyms, different levels of specificity, and related concepts.
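
Note: this hunk shows only the prompt that requests the variations, not how the response is parsed. A hypothetical parser, purely illustrative of the expected line-per-variation shape (parse_variations is a made-up helper, not part of app.py):

def parse_variations(raw: str, num_variations: int = 3) -> list[str]:
    # Strip list markers ("-", "*", "1.") and keep non-empty lines
    lines = (line.strip("-*0123456789. ").strip() for line in raw.splitlines())
    return [line for line in lines if line][:num_variations]

print(parse_variations("1. open access journals\n2. free scholarly articles"))
# ['open access journals', 'free scholarly articles']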
@@ -268,10 +289,8 @@ Output:"""
         logging.error(f"Failed to generate query variations: {e}")
         return []
 
-
 def generate_prompt(query, context_chunks):
     """Generates a prompt for the LLM."""
-    # ... (function remains the same) ...
     context_str = "\n\n".join(context_chunks)
     liaison_directory_url = "https://libguides.gc.cuny.edu/directory/subject"
     prompt = f"""Based on the following context from the library guides, answer the user's question.
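
Note: the prompt template is truncated at the hunk boundary; only its first line is visible. A sketch of the overall assembly pattern, with the continuation invented for illustration (build_prompt is a hypothetical stand-in for generate_prompt):

def build_prompt(query: str, context_chunks: list[str]) -> str:
    context_str = "\n\n".join(context_chunks)
    liaison_directory_url = "https://libguides.gc.cuny.edu/directory/subject"
    return (
        "Based on the following context from the library guides, "
        "answer the user's question.\n\n"
        f"Context:\n{context_str}\n\n"
        f"Question: {query}\n\n"
        "If the context does not answer the question, point the user to the "
        f"subject liaison directory: {liaison_directory_url}"
    )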
 