Pringled committed
Commit e49e0e9 · 1 Parent(s): 50c3ede
Files changed (1)
  1. app.py +128 -10
app.py CHANGED
@@ -10,7 +10,7 @@ from contextlib import contextmanager
 # Load the model at startup
 model = StaticModel.from_pretrained("minishlab/M2V_base_output")
 
-# Update default dataset to 'sst2' and set default threshold to 0.9
+# Default dataset parameters
 default_dataset1_name = "sst2"
 default_dataset1_split = "train"
 default_dataset2_name = "sst2"
@@ -47,7 +47,6 @@ def deduplicate(
     batch_size: int = 1024,
     progress=None
 ) -> tuple[np.ndarray, dict[int, int]]:
-    # Existing deduplication code remains unchanged
     # Building the index
     progress(0, desc="Building search index...")
     reach = Reach(
@@ -171,18 +170,137 @@ def perform_deduplication(
 
         elif deduplication_type == "Cross-dataset":
             # Similar code for cross-dataset deduplication
-            # Implement similar logic as above for cross-dataset
-            pass
+            # Load Dataset 1
+            status = "Loading Dataset 1..."
+            yield status, ""
+            if (
+                dataset1_name == default_dataset1_name
+                and dataset1_split == default_dataset1_split
+            ):
+                ds1 = ds_default1
+            else:
+                ds1 = load_dataset(dataset1_name, split=dataset1_split)
+
+            # Load Dataset 2
+            status = "Loading Dataset 2..."
+            yield status, ""
+            if (
+                dataset2_name == default_dataset2_name
+                and dataset2_split == default_dataset2_split
+            ):
+                ds2 = ds_default2
+            else:
+                ds2 = load_dataset(dataset2_name, split=dataset2_split)
+
+            # Extract texts from Dataset 1
+            status = "Extracting texts from Dataset 1..."
+            yield status, ""
+            texts1 = [example[dataset1_text_column] for example in ds1]
+
+            # Extract texts from Dataset 2
+            status = "Extracting texts from Dataset 2..."
+            yield status, ""
+            texts2 = [example[dataset2_text_column] for example in ds2]
+
+            # Compute embeddings for Dataset 1
+            status = "Computing embeddings for Dataset 1..."
+            yield status, ""
+            embedding_matrix1 = compute_embeddings(
+                texts1,
+                batch_size=64,
+                progress=progress,
+                desc="Computing embeddings for Dataset 1",
+            )
+
+            # Compute embeddings for Dataset 2
+            status = "Computing embeddings for Dataset 2..."
+            yield status, ""
+            embedding_matrix2 = compute_embeddings(
+                texts2,
+                batch_size=64,
+                progress=progress,
+                desc="Computing embeddings for Dataset 2",
+            )
+
+            # Deduplicate across datasets
+            status = "Deduplicating embeddings across datasets..."
+            yield status, ""
+            duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
+                embedding_matrix1, embedding_matrix2, threshold, progress=progress
+            )
+
+            num_duplicates = len(duplicate_indices_in_ds2)
+            num_total_ds2 = len(texts2)
+            num_unique_ds2 = num_total_ds2 - num_duplicates
+
+            result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
+            result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
+            result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
+
+            # Show deduplicated examples
+            if num_duplicates > 0:
+                result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
+                num_examples = min(5, num_duplicates)
+                for duplicate_idx in duplicate_indices_in_ds2[:num_examples]:
+                    original_idx = duplicate_to_original_mapping[duplicate_idx]
+                    original_text = texts1[original_idx]
+                    duplicate_text = texts2[duplicate_idx]
+                    differences = display_word_differences(original_text, duplicate_text)
+                    result_text += f"**Original text (Dataset 1):**\n{original_text}\n\n"
+                    result_text += f"**Duplicate text (Dataset 2):**\n{duplicate_text}\n\n"
+                    result_text += f"**Differences:**\n{differences}\n"
+                    result_text += "-" * 50 + "\n\n"
+            else:
+                result_text += "No duplicates found."
+
+            # Final status
+            status = "Deduplication completed."
+            yield status, result_text
 
     except Exception as e:
         yield f"An error occurred: {e}", ""
         raise e
 
-with gr.Blocks() as demo:
-    # Replace 'gr.Markdown' with 'gr.Textbox' for 'status_output'
-    status_output = gr.Textbox().style(height=150)
-    result_output = gr.Markdown()
+def deduplicate_across_datasets(
+    embedding_matrix_1: np.ndarray,
+    embedding_matrix_2: np.ndarray,
+    threshold: float,
+    batch_size: int = 1024,
+    progress=None
+) -> tuple[list[int], dict[int, int]]:
+    # Building the index from Dataset 1
+    progress(0, desc="Building search index from Dataset 1...")
+    reach = Reach(
+        vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))]
+    )
+
+    duplicate_indices_in_test = []
+    duplicate_to_original_mapping = {}
+
+    # Finding nearest neighbors between datasets
+    progress(0, desc="Finding nearest neighbors between datasets...")
+    results = reach.nearest_neighbor_threshold(
+        embedding_matrix_2,
+        threshold=threshold,
+        batch_size=batch_size,
+        show_progressbar=False,  # Disable internal progress bar
+    )
+
+    total_items = len(embedding_matrix_2)
+    # Processing duplicates with a progress bar
+    for i, similar_items in enumerate(
+        progress.tqdm(results, desc="Processing duplicates across datasets", total=total_items)
+    ):
+        similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
+
+        if similar_indices:
+            duplicate_indices_in_test.append(i)
+            duplicate_to_original_mapping[i] = similar_indices[0]
+
+    return duplicate_indices_in_test, duplicate_to_original_mapping
+
+with gr.Blocks() as demo:
+    gr.Markdown("# Semantic Deduplication")
 
     deduplication_type = gr.Radio(
         choices=["Single dataset", "Cross-dataset"],
@@ -209,8 +327,8 @@ with gr.Blocks() as demo:
 
     compute_button = gr.Button("Compute")
 
-    # Adjust the height of the status_output component
-    status_output = gr.Markdown().style(height=150)
+    # Use 'lines' parameter to set the height
+    status_output = gr.Textbox(lines=10, label="Status")
     result_output = gr.Markdown()
 
     # Function to update the visibility of dataset2_inputs
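
Note for reviewers: below is a minimal sketch for exercising the new Reach-based cross-dataset logic outside the Space, assuming only the numpy and reach packages. The find_cross_duplicates helper and the synthetic embeddings are illustrative stand-ins, not part of this commit; only Reach and nearest_neighbor_threshold come from the diff above.

import numpy as np
from reach import Reach

def find_cross_duplicates(emb1, emb2, threshold=0.9, batch_size=1024):
    # Index Dataset 1, query it with Dataset 2, and keep the first
    # neighbor whose similarity clears the threshold.
    index = Reach(vectors=emb1, items=[str(i) for i in range(len(emb1))])
    results = index.nearest_neighbor_threshold(
        emb2, threshold=threshold, batch_size=batch_size, show_progressbar=False
    )
    mapping = {}
    for i, hits in enumerate(results):
        keep = [int(name) for name, score in hits if score >= threshold]
        if keep:
            mapping[i] = keep[0]  # duplicate row in ds2 -> original row in ds1
    return mapping

rng = np.random.default_rng(0)
emb1 = rng.normal(size=(100, 64)).astype(np.float32)
emb2 = np.vstack([emb1[:5], rng.normal(size=(20, 64)).astype(np.float32)])
print(find_cross_duplicates(emb1, emb2))  # expect {0: 0, 1: 1, 2: 2, 3: 3, 4: 4}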
 
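
The status_output switch from gr.Markdown().style(height=150) to gr.Textbox(lines=10, label="Status") also fits how perform_deduplication is written: it is a generator, and each yield streams a (status, result) pair into the two output components. A self-contained sketch of that wiring, assuming recent Gradio generator support; the run generator below is a stand-in for the real handler, not code from this commit.

import time
import gradio as gr

def run():
    # Stand-in for perform_deduplication: yield (status, result) pairs.
    for step in ["Loading Dataset 1...", "Computing embeddings...", "Deduplicating..."]:
        time.sleep(0.5)
        yield step, ""
    yield "Deduplication completed.", "**Number of duplicates found:** 0"

with gr.Blocks() as demo:
    status_output = gr.Textbox(lines=10, label="Status")  # 'lines' sets the visible height
    result_output = gr.Markdown()
    compute_button = gr.Button("Compute")
    compute_button.click(run, outputs=[status_output, result_output])

demo.launch()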