Spaces:
Running
Running
Upload app.py
Browse files
app.py
CHANGED
@@ -12,6 +12,8 @@ from tqdm import tqdm
|
|
12 |
from datasets import load_dataset
|
13 |
import pandas as pd
|
14 |
from sentence_transformers import SentenceTransformer
|
|
|
|
|
15 |
|
16 |
# --- Page Config (MUST BE FIRST Streamlit call) ---
|
17 |
st.set_page_config(layout="wide")
|
@@ -88,14 +90,12 @@ def load_dataset_from_hf():
|
|
88 |
df = pd.read_parquet(parquet_path)
|
89 |
logging.info(f"Dataset loaded into DataFrame with shape: {df.shape}")
|
90 |
|
91 |
-
# Verify required columns
|
92 |
required_cols = ['id', 'document', 'embedding', 'metadata']
|
93 |
if not all(col in df.columns for col in required_cols):
|
94 |
st.error(f"Dataset Parquet file is missing required columns. Found: {df.columns}. Required: {required_cols}")
|
95 |
logging.error(f"Dataset Parquet file missing required columns. Found: {df.columns}")
|
96 |
-
return None
|
97 |
|
98 |
-
# Ensure embeddings are lists of floats
|
99 |
logging.info("Ensuring embeddings are in list format...")
|
100 |
if not df.empty and df['embedding'].iloc[0] is not None and (not isinstance(df['embedding'].iloc[0], list) or not isinstance(df['embedding'].iloc[0][0], float)):
|
101 |
df['embedding'] = df['embedding'].apply(lambda x: list(map(float, x)) if isinstance(x, (np.ndarray, list)) else None)
|
@@ -111,7 +111,7 @@ def load_dataset_from_hf():
|
|
111 |
if df.empty:
|
112 |
st.error("No valid data loaded from the dataset after processing embeddings.")
|
113 |
logging.error("DataFrame empty after embedding processing.")
|
114 |
-
return None
|
115 |
|
116 |
return df
|
117 |
|
@@ -122,7 +122,7 @@ def load_dataset_from_hf():
|
|
122 |
st.error(f"Failed to load data from dataset: {e}")
|
123 |
logging.exception(f"An unexpected error occurred during data load: {e}")
|
124 |
|
125 |
-
return None
|
126 |
|
127 |
# --- Initialize Clients and Models ---
|
128 |
generation_client = initialize_hf_client()
|
@@ -130,90 +130,112 @@ embedding_model = load_local_embedding_model()
|
|
130 |
# ---
|
131 |
|
132 |
# --- Setup ChromaDB Collection (using Session State) ---
|
133 |
-
|
134 |
-
|
135 |
-
if
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
cleaned_dict = {}
|
171 |
-
if isinstance(item, dict):
|
172 |
-
current_meta = item
|
173 |
else:
|
174 |
-
try:
|
175 |
-
except:
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
if error_count > 0: logging.warning(f"Encountered errors in {error_count} batches during add.")
|
200 |
-
|
201 |
-
final_count = collection_instance.count()
|
202 |
-
logging.info(f"Final document count in Chroma collection: {final_count}")
|
203 |
-
if final_count > 0:
|
204 |
-
st.session_state.chroma_collection = collection_instance
|
205 |
-
st.success("Vector database loaded successfully!")
|
206 |
-
else:
|
207 |
-
st.error("Failed to load documents into the vector database.")
|
208 |
-
|
209 |
-
except Exception as setup_e:
|
210 |
-
st.error(f"Failed to setup ChromaDB: {setup_e}")
|
211 |
-
logging.exception(f"Failed to setup ChromaDB: {setup_e}")
|
212 |
else:
|
213 |
-
|
|
|
214 |
|
215 |
-
|
216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
# ---
|
218 |
|
219 |
# --- Helper Functions ---
|
@@ -235,7 +257,6 @@ def query_hf_inference(prompt, client_instance=None, model_name=HF_GENERATION_MO
|
|
235 |
|
236 |
def generate_query_variations(query, llm_func, model_name=HF_GENERATION_MODEL, num_variations=3):
|
237 |
"""Uses LLM (HF Inference API) to generate alternative phrasings."""
|
238 |
-
# ... (rest of function remains the same) ...
|
239 |
prompt = f"""Given the user query: "{query}"
|
240 |
Generate {num_variations} alternative phrasings or related queries someone might use to find the same information.
|
241 |
Focus on synonyms, different levels of specificity, and related concepts.
|
@@ -268,10 +289,8 @@ Output:"""
|
|
268 |
logging.error(f"Failed to generate query variations: {e}")
|
269 |
return []
|
270 |
|
271 |
-
|
272 |
def generate_prompt(query, context_chunks):
|
273 |
"""Generates a prompt for the LLM."""
|
274 |
-
# ... (function remains the same) ...
|
275 |
context_str = "\n\n".join(context_chunks)
|
276 |
liaison_directory_url = "https://libguides.gc.cuny.edu/directory/subject"
|
277 |
prompt = f"""Based on the following context from the library guides, answer the user's question.
|
|
|
12 |
from datasets import load_dataset
|
13 |
import pandas as pd
|
14 |
from sentence_transformers import SentenceTransformer
|
15 |
+
# Import config if needed for EphemeralClient settings, though default might be fine
|
16 |
+
import chromadb.config
|
17 |
|
18 |
# --- Page Config (MUST BE FIRST Streamlit call) ---
|
19 |
st.set_page_config(layout="wide")
|
|
|
90 |
df = pd.read_parquet(parquet_path)
|
91 |
logging.info(f"Dataset loaded into DataFrame with shape: {df.shape}")
|
92 |
|
|
|
93 |
required_cols = ['id', 'document', 'embedding', 'metadata']
|
94 |
if not all(col in df.columns for col in required_cols):
|
95 |
st.error(f"Dataset Parquet file is missing required columns. Found: {df.columns}. Required: {required_cols}")
|
96 |
logging.error(f"Dataset Parquet file missing required columns. Found: {df.columns}")
|
97 |
+
return None
|
98 |
|
|
|
99 |
logging.info("Ensuring embeddings are in list format...")
|
100 |
if not df.empty and df['embedding'].iloc[0] is not None and (not isinstance(df['embedding'].iloc[0], list) or not isinstance(df['embedding'].iloc[0][0], float)):
|
101 |
df['embedding'] = df['embedding'].apply(lambda x: list(map(float, x)) if isinstance(x, (np.ndarray, list)) else None)
|
|
|
111 |
if df.empty:
|
112 |
st.error("No valid data loaded from the dataset after processing embeddings.")
|
113 |
logging.error("DataFrame empty after embedding processing.")
|
114 |
+
return None
|
115 |
|
116 |
return df
|
117 |
|
|
|
122 |
st.error(f"Failed to load data from dataset: {e}")
|
123 |
logging.exception(f"An unexpected error occurred during data load: {e}")
|
124 |
|
125 |
+
return None
|
126 |
|
127 |
# --- Initialize Clients and Models ---
|
128 |
generation_client = initialize_hf_client()
|
|
|
130 |
# ---
|
131 |
|
132 |
# --- Setup ChromaDB Collection (using Session State) ---
# Loads (or reuses) an in-memory Chroma collection populated from the
# Hugging Face dataset, caching the result in Streamlit session state so
# script reruns do not rebuild the index.


def _clean_metadata(raw_items):
    """Coerce raw per-document metadata into Chroma-compatible dicts.

    Chroma metadata values must be str/int/float/bool, so ``None`` becomes
    ``""`` and any other type is stringified best-effort. Non-dict items
    are parsed as JSON when they are strings, otherwise treated as empty
    metadata.

    Args:
        raw_items: iterable of metadata entries (dicts, JSON strings, or
            arbitrary objects).

    Returns:
        list[dict]: one sanitized dict per input item, in order.
    """
    cleaned = []
    for item in raw_items:
        current_meta = item if isinstance(item, dict) else {}
        if not isinstance(item, dict):
            # Metadata may arrive JSON-encoded; fall back to empty on parse failure.
            try:
                current_meta = json.loads(item) if isinstance(item, str) else {}
            except Exception:
                current_meta = {}

        cleaned_dict = {}
        if isinstance(current_meta, dict):
            for key, value in current_meta.items():
                if value is None:
                    cleaned_dict[key] = ""
                elif isinstance(value, (str, int, float, bool)):
                    cleaned_dict[key] = value
                else:
                    # Best-effort stringification; silently drop values that
                    # cannot even be converted to str.
                    try:
                        cleaned_dict[key] = str(value)
                    except Exception:
                        pass
        cleaned.append(cleaned_dict)
    return cleaned


def setup_chroma_collection():
    """Create (or fetch from session state) the Chroma vector collection.

    Returns:
        The populated Chroma collection, or ``None`` when a prerequisite
        failed (models/clients not initialized, dataset load failed, or a
        ChromaDB error occurred).

    Side effects:
        Stores the collection in ``st.session_state.chroma_collection`` and
        reports progress/errors through the Streamlit UI.
    """
    # Reuse the collection built on a previous Streamlit rerun.
    if 'chroma_collection' in st.session_state and st.session_state.chroma_collection is not None:
        logging.info("Using existing Chroma collection from session state.")
        return st.session_state.chroma_collection

    # Proceed only if essential components are loaded; the initializer
    # functions return None on failure, so test identity rather than truthiness.
    if embedding_model is None or generation_client is None:
        st.error("Cannot setup ChromaDB: Required models/clients failed to initialize.")
        return None

    with st.spinner("Loading and preparing vector database..."):
        df = load_dataset_from_hf()
        if df is None or df.empty:
            st.error("Failed to load embedding data. Cannot initialize vector database.")
            return None

        try:
            logging.info("Initializing Ephemeral ChromaDB client...")
            # EphemeralClient keeps the index purely in memory (nothing persisted).
            chroma_client = chromadb.EphemeralClient(
                settings=chromadb.config.Settings(
                    anonymized_telemetry=False,  # Optional: disable telemetry
                    allow_reset=True,            # Optional: allows resetting
                )
            )

            # Drop a stale collection with the same name, if any (robustness).
            try:
                existing_collections = [col.name for col in chroma_client.list_collections()]
                if COLLECTION_NAME in existing_collections:
                    chroma_client.delete_collection(name=COLLECTION_NAME)
                    logging.info(f"Deleted existing collection: {COLLECTION_NAME}")
            except Exception as delete_e:
                logging.warning(f"Could not check/delete existing collection (might be okay): {delete_e}")

            logging.info(f"Creating collection: {COLLECTION_NAME}")
            collection_instance = chroma_client.create_collection(
                name=COLLECTION_NAME,
                metadata={"hnsw:space": "cosine"}  # No embedding function needed here
            )

            logging.info(f"Adding {len(df)} documents to ChromaDB in batches of {ADD_BATCH_SIZE}...")
            start_time = time.time()
            error_count = 0
            # Ceiling division: include the final partial batch.
            num_batches = (len(df) + ADD_BATCH_SIZE - 1) // ADD_BATCH_SIZE

            for i in range(num_batches):
                batch_df = df.iloc[i * ADD_BATCH_SIZE:(i + 1) * ADD_BATCH_SIZE]
                try:
                    collection_instance.add(
                        ids=batch_df['id'].tolist(),
                        embeddings=batch_df['embedding'].tolist(),
                        documents=batch_df['document'].tolist(),
                        metadatas=_clean_metadata(batch_df['metadata'].tolist()),
                    )
                except Exception as e:
                    # Keep going: one bad batch should not abort the whole load.
                    logging.error(f"Error adding batch {i+1}/{num_batches} to Chroma: {e}")
                    error_count += 1

            end_time = time.time()
            logging.info(f"Finished loading data into ChromaDB. Took {end_time - start_time:.2f} seconds.")
            if error_count > 0:
                logging.warning(f"Encountered errors in {error_count} batches during add.")

            final_count = collection_instance.count()
            logging.info(f"Final document count in Chroma collection: {final_count}")
            if final_count > 0:
                st.session_state.chroma_collection = collection_instance
                st.success("Vector database loaded successfully!")
                return collection_instance
            else:
                st.error("Failed to load documents into the vector database.")
                return None

        except Exception as setup_e:
            st.error(f"Failed to setup ChromaDB: {setup_e}")
            logging.exception(f"Failed to setup ChromaDB: {setup_e}")
            return None


# --- Initialize collection ---
# Call the setup function which populates session state if needed
collection = setup_chroma_collection()
|
239 |
# ---
|
240 |
|
241 |
# --- Helper Functions ---
|
|
|
257 |
|
258 |
def generate_query_variations(query, llm_func, model_name=HF_GENERATION_MODEL, num_variations=3):
|
259 |
"""Uses LLM (HF Inference API) to generate alternative phrasings."""
|
|
|
260 |
prompt = f"""Given the user query: "{query}"
|
261 |
Generate {num_variations} alternative phrasings or related queries someone might use to find the same information.
|
262 |
Focus on synonyms, different levels of specificity, and related concepts.
|
|
|
289 |
logging.error(f"Failed to generate query variations: {e}")
|
290 |
return []
|
291 |
|
|
|
292 |
def generate_prompt(query, context_chunks):
|
293 |
"""Generates a prompt for the LLM."""
|
|
|
294 |
context_str = "\n\n".join(context_chunks)
|
295 |
liaison_directory_url = "https://libguides.gc.cuny.edu/directory/subject"
|
296 |
prompt = f"""Based on the following context from the library guides, answer the user's question.
|