Pringled committed on
Commit
a847bef
·
1 Parent(s): 777bab9
Files changed (1)
  1. app.py +425 -252
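Stripped of the Gradio wiring, dataset loading, and progress reporting, the semantic-deduplication core this commit keeps in app.py is small. Below is a minimal, self-contained sketch reconstructed from the lines in the diff (same model name, same model2vec and Reach calls); the toy texts at the end are only an illustrative assumption, not part of the app.

import numpy as np
import model2vec
from reach import Reach

# Static embedding model used by the Space (name taken verbatim from the diff).
model = model2vec.StaticModel.from_pretrained("minishlab/M2V_base_output")

def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024):
    """Return indices to keep and a duplicate -> original index mapping."""
    reach = Reach(vectors=embedding_matrix,
                  items=[str(i) for i in range(len(embedding_matrix))])
    deduplicated_indices = set(range(len(embedding_matrix)))
    duplicate_to_original_mapping = {}

    # All neighbours whose similarity exceeds the threshold, computed in batches.
    results = reach.nearest_neighbor_threshold(
        embedding_matrix,
        threshold=threshold,
        batch_size=batch_size,
        show_progressbar=False,
    )
    for i, similar_items in enumerate(results):
        if i not in deduplicated_indices:
            continue  # already recorded as a duplicate of an earlier item
        for item in similar_items:
            sim_idx = int(item[0])
            if sim_idx != i and sim_idx in deduplicated_indices:
                deduplicated_indices.remove(sim_idx)
                duplicate_to_original_mapping[sim_idx] = i
    return np.array(list(deduplicated_indices)), duplicate_to_original_mapping

# Hypothetical usage with three toy sentences (not taken from the app itself).
texts = ["the movie was great", "the movie was great !", "an unrelated sentence"]
keep, mapping = deduplicate(model.encode(texts, show_progressbar=False), threshold=0.9)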
app.py CHANGED
@@ -4,7 +4,6 @@ import numpy as np
4
  import model2vec
5
  from reach import Reach
6
  from difflib import ndiff
7
- import time
8
 
9
  # Load the model at startup
10
  model = model2vec.StaticModel.from_pretrained("minishlab/M2V_base_output")
@@ -26,19 +25,7 @@ def batch_iterable(iterable, batch_size):
26
  for i in range(0, len(iterable), batch_size):
27
  yield iterable[i:i + batch_size]
28
 
29
- def log_time(message, start_time=None, logs=None):
30
- """Helper function to log the start and end times."""
31
- current_time = time.time()
32
- if start_time is not None:
33
- elapsed = current_time - start_time
34
- log_message = f"{message} - Took {elapsed:.2f} seconds"
35
- else:
36
- log_message = f"{message} - Started"
37
-
38
- if logs is not None:
39
- logs.append(log_message)
40
-
41
- def compute_embeddings(texts, batch_size, progress, logs, desc="Computing embeddings"):
42
  embeddings = []
43
  total_batches = (len(texts) + batch_size - 1) // batch_size
44
  for i, batch_texts in enumerate(batch_iterable(texts, batch_size)):
@@ -51,38 +38,26 @@ def deduplicate(
51
  embedding_matrix: np.ndarray,
52
  threshold: float,
53
  batch_size: int = 1024,
54
- progress=None,
55
- logs=None
56
  ) -> tuple[np.ndarray, dict[int, int]]:
57
- # Building the index
58
- log_time("Building search index", logs=logs)
59
- reach = Reach(
60
- vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))]
61
- )
62
 
63
  deduplicated_indices = set(range(len(embedding_matrix)))
64
  duplicate_to_original_mapping = {}
65
 
66
- # Finding nearest neighbors
67
- log_time("Finding nearest neighbors", logs=logs)
68
  results = reach.nearest_neighbor_threshold(
69
  embedding_matrix,
70
  threshold=threshold,
71
  batch_size=batch_size,
72
- show_progressbar=False, # Disable internal progress bar
73
  )
74
 
75
- # Processing duplicates with a progress bar
76
  total_items = len(embedding_matrix)
77
- log_time("Processing duplicates", logs=logs)
78
- for i, similar_items in enumerate(
79
- progress.tqdm(results, desc="Processing duplicates", total=total_items)
80
- ):
81
  if i not in deduplicated_indices:
82
  continue
83
 
84
  similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]
85
-
86
  for sim_idx in similar_indices:
87
  if sim_idx in deduplicated_indices:
88
  deduplicated_indices.remove(sim_idx)
@@ -94,11 +69,6 @@ def display_word_differences(x: str, y: str) -> str:
94
  diff = ndiff(x.split(), y.split())
95
  return " ".join([word for word in diff if word.startswith(("+", "-"))])
96
 
97
- def encode_texts(texts, progress=None, logs=None):
98
- embedding_matrix = model.encode(texts, show_progressbar=False)
99
- log_time("Encoding texts completed", logs=logs)
100
- return embedding_matrix
101
-
102
  def perform_deduplication(
103
  deduplication_type,
104
  dataset1_name,
@@ -110,59 +80,24 @@ def perform_deduplication(
110
  threshold=default_threshold,
111
  progress=gr.Progress(track_tqdm=True),
112
  ):
113
- logs = [] # To store log messages
114
  try:
115
- # Convert threshold to float
116
  threshold = float(threshold)
117
 
118
- # Initialize status message
119
- log_time("Deduplication started", logs=logs)
120
-
121
  if deduplication_type == "Single dataset":
122
- # Load Dataset 1
123
- start_time = time.time()
124
- log_time("Loading Dataset 1", logs=logs)
125
- if (
126
- dataset1_name == default_dataset1_name
127
- and dataset1_split == default_dataset1_split
128
- ):
129
- ds = ds_default1
130
- else:
131
- ds = load_dataset(dataset1_name, split=dataset1_split)
132
- log_time("Loading Dataset 1 completed", start_time=start_time, logs=logs)
133
-
134
- # Extract texts
135
- start_time = time.time()
136
- log_time("Extracting texts from Dataset 1", logs=logs)
137
  texts = [example[dataset1_text_column] for example in ds]
138
- log_time("Extracting texts from Dataset 1 completed", start_time=start_time, logs=logs)
139
-
140
- # Compute embeddings
141
- start_time = time.time()
142
- log_time("Computing embeddings for Dataset 1", logs=logs)
143
- embedding_matrix = encode_texts(texts, progress=progress, logs=logs)
144
- log_time("Computing embeddings for Dataset 1 completed", start_time=start_time, logs=logs)
145
-
146
- # Deduplicate
147
- start_time = time.time()
148
- log_time("Deduplicating embeddings", logs=logs)
149
- deduplicated_indices, duplicate_to_original_mapping = deduplicate(
150
- embedding_matrix, threshold, progress=progress, logs=logs
151
- )
152
- log_time("Deduplication completed", start_time=start_time, logs=logs)
153
-
154
- # Prepare the results
155
  num_duplicates = len(duplicate_to_original_mapping)
156
  num_total = len(texts)
157
  num_deduplicated = len(deduplicated_indices)
158
 
159
  result_text = f"**Total documents:** {num_total}\n"
160
  result_text += f"**Number of duplicates found:** {num_duplicates}\n"
161
- result_text += (
162
- f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
163
- )
164
 
165
- # Show deduplicated examples
166
  if num_duplicates > 0:
167
  result_text += "**Examples of duplicates found:**\n\n"
168
  num_examples = min(5, num_duplicates)
@@ -177,16 +112,12 @@ def perform_deduplication(
177
  else:
178
  result_text += "No duplicates found."
179
 
180
- log_time("Deduplication process finished", logs=logs)
181
- full_log = "\n".join(logs) # Combine all logs into one output
182
- yield full_log, result_text
183
 
184
  except Exception as e:
185
- full_log = "\n".join(logs) # Combine all logs into one output in case of an error
186
- yield f"An error occurred: {e}", ""
187
- raise e
188
 
189
- # Adjust the height of the status_output component using custom CSS
190
  with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
191
  gr.Markdown("# Semantic Deduplication")
192
 
@@ -209,22 +140,14 @@ with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
209
  dataset2_split = gr.Textbox(value=default_dataset2_split, label="Dataset 2 Split")
210
  dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
211
 
212
- threshold = gr.Slider(
213
- minimum=0.0, maximum=1.0, value=default_threshold, label="Similarity Threshold"
214
- )
215
 
216
  compute_button = gr.Button("Compute")
217
 
218
- # Use 'gr.Markdown' with 'elem_id' and custom CSS to adjust height
219
- status_output = gr.Markdown(elem_id="status_output")
220
  result_output = gr.Markdown()
221
 
222
- # Function to update the visibility of dataset2_inputs
223
  def update_visibility(deduplication_type_value):
224
- if deduplication_type_value == "Cross-dataset":
225
- return gr.update(visible=True)
226
- else:
227
- return gr.update(visible=False)
228
 
229
  deduplication_type.change(
230
  update_visibility, inputs=deduplication_type, outputs=dataset2_inputs
@@ -242,21 +165,19 @@ with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
242
  dataset2_text_column,
243
  threshold,
244
  ],
245
- outputs=[status_output, result_output],
246
  )
247
 
248
  demo.launch()
249
 
250
 
251
-
252
  # import gradio as gr
253
  # from datasets import load_dataset
254
  # import numpy as np
255
- # #from model2vec import StaticModel
256
  # import model2vec
257
  # from reach import Reach
258
  # from difflib import ndiff
259
-
260
 
261
  # # Load the model at startup
262
  # model = model2vec.StaticModel.from_pretrained("minishlab/M2V_base_output")
@@ -273,13 +194,24 @@ demo.launch()
273
  # ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
274
  # ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
275
 
276
-
277
  # def batch_iterable(iterable, batch_size):
278
  # """Helper function to create batches from an iterable."""
279
  # for i in range(0, len(iterable), batch_size):
280
  # yield iterable[i:i + batch_size]
281
 
282
- # def compute_embeddings(texts, batch_size, progress, desc="Computing embeddings"):
283
  # embeddings = []
284
  # total_batches = (len(texts) + batch_size - 1) // batch_size
285
  # for i, batch_texts in enumerate(batch_iterable(texts, batch_size)):
@@ -292,10 +224,11 @@ demo.launch()
292
  # embedding_matrix: np.ndarray,
293
  # threshold: float,
294
  # batch_size: int = 1024,
295
- # progress=None
 
296
  # ) -> tuple[np.ndarray, dict[int, int]]:
297
  # # Building the index
298
- # progress(0, desc="Building search index...")
299
  # reach = Reach(
300
  # vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))]
301
  # )
@@ -304,7 +237,7 @@ demo.launch()
304
  # duplicate_to_original_mapping = {}
305
 
306
  # # Finding nearest neighbors
307
- # progress(0, desc="Finding nearest neighbors...")
308
  # results = reach.nearest_neighbor_threshold(
309
  # embedding_matrix,
310
  # threshold=threshold,
@@ -314,6 +247,7 @@ demo.launch()
314
 
315
  # # Processing duplicates with a progress bar
316
  # total_items = len(embedding_matrix)
 
317
  # for i, similar_items in enumerate(
318
  # progress.tqdm(results, desc="Processing duplicates", total=total_items)
319
  # ):
@@ -333,9 +267,9 @@ demo.launch()
333
  # diff = ndiff(x.split(), y.split())
334
  # return " ".join([word for word in diff if word.startswith(("+", "-"))])
335
 
336
-
337
- # def encode_texts(texts, progress=None):
338
  # embedding_matrix = model.encode(texts, show_progressbar=False)
 
339
  # return embedding_matrix
340
 
341
  # def perform_deduplication(
@@ -349,17 +283,18 @@ demo.launch()
349
  # threshold=default_threshold,
350
  # progress=gr.Progress(track_tqdm=True),
351
  # ):
 
352
  # try:
353
  # # Convert threshold to float
354
  # threshold = float(threshold)
355
 
356
  # # Initialize status message
357
- # status = ""
358
 
359
  # if deduplication_type == "Single dataset":
360
  # # Load Dataset 1
361
- # status = "Loading Dataset 1..."
362
- # yield status, ""
363
  # if (
364
  # dataset1_name == default_dataset1_name
365
  # and dataset1_split == default_dataset1_split
@@ -367,29 +302,27 @@ demo.launch()
367
  # ds = ds_default1
368
  # else:
369
  # ds = load_dataset(dataset1_name, split=dataset1_split)
 
370
 
371
  # # Extract texts
372
- # status = "Extracting texts from Dataset 1..."
373
- # yield status, ""
374
  # texts = [example[dataset1_text_column] for example in ds]
 
 
375
  # # Compute embeddings
376
- # status = "Computing embeddings for Dataset 1..."
377
- # yield status, ""
378
- # embedding_matrix = encode_texts(texts, progress=progress)
379
- # #embedding_matrix = model.encode(texts, show_progressbar=True)
380
- # # embedding_matrix = compute_embeddings(
381
- # # texts,
382
- # # batch_size=64,
383
- # # progress=progress,
384
- # # desc="Computing embeddings for Dataset 1",
385
- # # )
386
 
387
  # # Deduplicate
388
- # status = "Deduplicating embeddings..."
389
- # yield status, ""
390
  # deduplicated_indices, duplicate_to_original_mapping = deduplicate(
391
- # embedding_matrix, threshold, progress=progress
392
  # )
 
393
 
394
  # # Prepare the results
395
  # num_duplicates = len(duplicate_to_original_mapping)
@@ -417,141 +350,15 @@ demo.launch()
417
  # else:
418
  # result_text += "No duplicates found."
419
 
420
- # # Final status
421
- # status = "Deduplication completed."
422
- # yield status, result_text
423
-
424
- # elif deduplication_type == "Cross-dataset":
425
- # # Similar code for cross-dataset deduplication
426
- # # Load Dataset 1
427
- # status = "Loading Dataset 1..."
428
- # yield status, ""
429
- # if (
430
- # dataset1_name == default_dataset1_name
431
- # and dataset1_split == default_dataset1_split
432
- # ):
433
- # ds1 = ds_default1
434
- # else:
435
- # ds1 = load_dataset(dataset1_name, split=dataset1_split)
436
-
437
- # # Load Dataset 2
438
- # status = "Loading Dataset 2..."
439
- # yield status, ""
440
- # if (
441
- # dataset2_name == default_dataset2_name
442
- # and dataset2_split == default_dataset2_split
443
- # ):
444
- # ds2 = ds_default2
445
- # else:
446
- # ds2 = load_dataset(dataset2_name, split=dataset2_split)
447
-
448
- # # Extract texts from Dataset 1
449
- # status = "Extracting texts from Dataset 1..."
450
- # yield status, ""
451
- # texts1 = [example[dataset1_text_column] for example in ds1]
452
-
453
- # # Extract texts from Dataset 2
454
- # status = "Extracting texts from Dataset 2..."
455
- # yield status, ""
456
- # texts2 = [example[dataset2_text_column] for example in ds2]
457
-
458
- # # Compute embeddings for Dataset 1
459
- # status = "Computing embeddings for Dataset 1..."
460
- # yield status, ""
461
- # embedding_matrix1 = compute_embeddings(
462
- # texts1,
463
- # batch_size=64,
464
- # progress=progress,
465
- # desc="Computing embeddings for Dataset 1",
466
- # )
467
-
468
- # # Compute embeddings for Dataset 2
469
- # status = "Computing embeddings for Dataset 2..."
470
- # yield status, ""
471
- # embedding_matrix2 = compute_embeddings(
472
- # texts2,
473
- # batch_size=64,
474
- # progress=progress,
475
- # desc="Computing embeddings for Dataset 2",
476
- # )
477
-
478
- # # Deduplicate across datasets
479
- # status = "Deduplicating embeddings across datasets..."
480
- # yield status, ""
481
- # duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
482
- # embedding_matrix1, embedding_matrix2, threshold, progress=progress
483
- # )
484
-
485
- # num_duplicates = len(duplicate_indices_in_ds2)
486
- # num_total_ds2 = len(texts2)
487
- # num_unique_ds2 = num_total_ds2 - num_duplicates
488
-
489
- # result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
490
- # result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
491
- # result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
492
-
493
- # # Show deduplicated examples
494
- # if num_duplicates > 0:
495
- # result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
496
- # num_examples = min(5, num_duplicates)
497
- # for duplicate_idx in duplicate_indices_in_ds2[:num_examples]:
498
- # original_idx = duplicate_to_original_mapping[duplicate_idx]
499
- # original_text = texts1[original_idx]
500
- # duplicate_text = texts2[duplicate_idx]
501
- # differences = display_word_differences(original_text, duplicate_text)
502
- # result_text += f"**Original text (Dataset 1):**\n{original_text}\n\n"
503
- # result_text += f"**Duplicate text (Dataset 2):**\n{duplicate_text}\n\n"
504
- # result_text += f"**Differences:**\n{differences}\n"
505
- # result_text += "-" * 50 + "\n\n"
506
- # else:
507
- # result_text += "No duplicates found."
508
-
509
- # # Final status
510
- # status = "Deduplication completed."
511
- # yield status, result_text
512
 
513
  # except Exception as e:
 
514
  # yield f"An error occurred: {e}", ""
515
  # raise e
516
 
517
- # def deduplicate_across_datasets(
518
- # embedding_matrix_1: np.ndarray,
519
- # embedding_matrix_2: np.ndarray,
520
- # threshold: float,
521
- # batch_size: int = 1024,
522
- # progress=None
523
- # ) -> tuple[list[int], dict[int, int]]:
524
- # # Building the index from Dataset 1
525
- # progress(0, desc="Building search index from Dataset 1...")
526
- # reach = Reach(
527
- # vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))]
528
- # )
529
-
530
- # duplicate_indices_in_test = []
531
- # duplicate_to_original_mapping = {}
532
-
533
- # # Finding nearest neighbors between datasets
534
- # progress(0, desc="Finding nearest neighbors between datasets...")
535
- # results = reach.nearest_neighbor_threshold(
536
- # embedding_matrix_2,
537
- # threshold=threshold,
538
- # batch_size=batch_size,
539
- # show_progressbar=False, # Disable internal progress bar
540
- # )
541
-
542
- # total_items = len(embedding_matrix_2)
543
- # # Processing duplicates with a progress bar
544
- # for i, similar_items in enumerate(
545
- # progress.tqdm(results, desc="Processing duplicates across datasets", total=total_items)
546
- # ):
547
- # similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
548
-
549
- # if similar_indices:
550
- # duplicate_indices_in_test.append(i)
551
- # duplicate_to_original_mapping[i] = similar_indices[0]
552
-
553
- # return duplicate_indices_in_test, duplicate_to_original_mapping
554
-
555
  # # Adjust the height of the status_output component using custom CSS
556
  # with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
557
  # gr.Markdown("# Semantic Deduplication")
@@ -612,3 +419,369 @@ demo.launch()
612
  # )
613
 
614
  # demo.launch()
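In the hunks shown here, the cross-dataset variant appears only in commented-out form (deduplicate_across_datasets). For reference, a cleaned-up, runnable sketch of that function reconstructed from those commented lines, assuming the same Reach API the live deduplicate() uses:

from reach import Reach

def deduplicate_across_datasets(embedding_matrix_1, embedding_matrix_2, threshold, batch_size=1024):
    """Indices in dataset 2 that duplicate something in dataset 1, plus the mapping back."""
    # Index dataset 1, then query it with every embedding from dataset 2.
    reach = Reach(vectors=embedding_matrix_1,
                  items=[str(i) for i in range(len(embedding_matrix_1))])
    duplicate_indices_in_ds2 = []
    duplicate_to_original_mapping = {}
    results = reach.nearest_neighbor_threshold(
        embedding_matrix_2,
        threshold=threshold,
        batch_size=batch_size,
        show_progressbar=False,
    )
    for i, similar_items in enumerate(results):
        similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
        if similar_indices:
            duplicate_indices_in_ds2.append(i)
            duplicate_to_original_mapping[i] = similar_indices[0]
    return duplicate_indices_in_ds2, duplicate_to_original_mapping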
 
4
  import model2vec
5
  from reach import Reach
6
  from difflib import ndiff
 
7
 
8
  # Load the model at startup
9
  model = model2vec.StaticModel.from_pretrained("minishlab/M2V_base_output")
 
25
  for i in range(0, len(iterable), batch_size):
26
  yield iterable[i:i + batch_size]
27
 
28
+ def compute_embeddings(texts, batch_size, progress, desc="Computing embeddings"):
29
  embeddings = []
30
  total_batches = (len(texts) + batch_size - 1) // batch_size
31
  for i, batch_texts in enumerate(batch_iterable(texts, batch_size)):
 
38
  embedding_matrix: np.ndarray,
39
  threshold: float,
40
  batch_size: int = 1024,
41
+ progress=None
 
42
  ) -> tuple[np.ndarray, dict[int, int]]:
43
+ reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])
44
 
45
  deduplicated_indices = set(range(len(embedding_matrix)))
46
  duplicate_to_original_mapping = {}
47
 
 
 
48
  results = reach.nearest_neighbor_threshold(
49
  embedding_matrix,
50
  threshold=threshold,
51
  batch_size=batch_size,
52
+ show_progressbar=False,
53
  )
54
 
 
55
  total_items = len(embedding_matrix)
56
+ for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=total_items)):
 
 
 
57
  if i not in deduplicated_indices:
58
  continue
59
 
60
  similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]
 
61
  for sim_idx in similar_indices:
62
  if sim_idx in deduplicated_indices:
63
  deduplicated_indices.remove(sim_idx)
 
69
  diff = ndiff(x.split(), y.split())
70
  return " ".join([word for word in diff if word.startswith(("+", "-"))])
71
 
 
72
  def perform_deduplication(
73
  deduplication_type,
74
  dataset1_name,
 
80
  threshold=default_threshold,
81
  progress=gr.Progress(track_tqdm=True),
82
  ):
 
83
  try:
 
84
  threshold = float(threshold)
85
 
 
 
 
86
  if deduplication_type == "Single dataset":
87
+ ds = ds_default1 if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split else load_dataset(dataset1_name, split=dataset1_split)
88
  texts = [example[dataset1_text_column] for example in ds]
89
+
90
+ embedding_matrix = compute_embeddings(texts, batch_size=64, progress=progress)
91
+ deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, threshold, progress=progress)
92
+
93
  num_duplicates = len(duplicate_to_original_mapping)
94
  num_total = len(texts)
95
  num_deduplicated = len(deduplicated_indices)
96
 
97
  result_text = f"**Total documents:** {num_total}\n"
98
  result_text += f"**Number of duplicates found:** {num_duplicates}\n"
99
+ result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
 
 
100
 
 
101
  if num_duplicates > 0:
102
  result_text += "**Examples of duplicates found:**\n\n"
103
  num_examples = min(5, num_duplicates)
 
112
  else:
113
  result_text += "No duplicates found."
114
 
115
+ yield result_text
 
 
116
 
117
  except Exception as e:
118
+ yield f"An error occurred: {e}"
 
 
119
 
120
+ # Gradio interface setup
121
  with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
122
  gr.Markdown("# Semantic Deduplication")
123
 
 
140
  dataset2_split = gr.Textbox(value=default_dataset2_split, label="Dataset 2 Split")
141
  dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
142
 
143
+ threshold = gr.Slider(minimum=0.0, maximum=1.0, value=default_threshold, label="Similarity Threshold")
 
 
144
 
145
  compute_button = gr.Button("Compute")
146
 
 
 
147
  result_output = gr.Markdown()
148
 
 
149
  def update_visibility(deduplication_type_value):
150
+ return gr.update(visible=True) if deduplication_type_value == "Cross-dataset" else gr.update(visible=False)
 
 
 
151
 
152
  deduplication_type.change(
153
  update_visibility, inputs=deduplication_type, outputs=dataset2_inputs
 
165
  dataset2_text_column,
166
  threshold,
167
  ],
168
+ outputs=[result_output],
169
  )
170
 
171
  demo.launch()
172
 
173
 
 
174
  # import gradio as gr
175
  # from datasets import load_dataset
176
  # import numpy as np
 
177
  # import model2vec
178
  # from reach import Reach
179
  # from difflib import ndiff
180
+ # import time
181
 
182
  # # Load the model at startup
183
  # model = model2vec.StaticModel.from_pretrained("minishlab/M2V_base_output")
 
194
  # ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
195
  # ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
196
 
 
197
  # def batch_iterable(iterable, batch_size):
198
  # """Helper function to create batches from an iterable."""
199
  # for i in range(0, len(iterable), batch_size):
200
  # yield iterable[i:i + batch_size]
201
 
202
+ # def log_time(message, start_time=None, logs=None):
203
+ # """Helper function to log the start and end times."""
204
+ # current_time = time.time()
205
+ # if start_time is not None:
206
+ # elapsed = current_time - start_time
207
+ # log_message = f"{message} - Took {elapsed:.2f} seconds"
208
+ # else:
209
+ # log_message = f"{message} - Started"
210
+
211
+ # if logs is not None:
212
+ # logs.append(log_message)
213
+
214
+ # def compute_embeddings(texts, batch_size, progress, logs, desc="Computing embeddings"):
215
  # embeddings = []
216
  # total_batches = (len(texts) + batch_size - 1) // batch_size
217
  # for i, batch_texts in enumerate(batch_iterable(texts, batch_size)):
 
224
  # embedding_matrix: np.ndarray,
225
  # threshold: float,
226
  # batch_size: int = 1024,
227
+ # progress=None,
228
+ # logs=None
229
  # ) -> tuple[np.ndarray, dict[int, int]]:
230
  # # Building the index
231
+ # log_time("Building search index", logs=logs)
232
  # reach = Reach(
233
  # vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))]
234
  # )
 
237
  # duplicate_to_original_mapping = {}
238
 
239
  # # Finding nearest neighbors
240
+ # log_time("Finding nearest neighbors", logs=logs)
241
  # results = reach.nearest_neighbor_threshold(
242
  # embedding_matrix,
243
  # threshold=threshold,
 
247
 
248
  # # Processing duplicates with a progress bar
249
  # total_items = len(embedding_matrix)
250
+ # log_time("Processing duplicates", logs=logs)
251
  # for i, similar_items in enumerate(
252
  # progress.tqdm(results, desc="Processing duplicates", total=total_items)
253
  # ):
 
267
  # diff = ndiff(x.split(), y.split())
268
  # return " ".join([word for word in diff if word.startswith(("+", "-"))])
269
 
270
+ # def encode_texts(texts, progress=None, logs=None):
 
271
  # embedding_matrix = model.encode(texts, show_progressbar=False)
272
+ # log_time("Encoding texts completed", logs=logs)
273
  # return embedding_matrix
274
 
275
  # def perform_deduplication(
 
283
  # threshold=default_threshold,
284
  # progress=gr.Progress(track_tqdm=True),
285
  # ):
286
+ # logs = [] # To store log messages
287
  # try:
288
  # # Convert threshold to float
289
  # threshold = float(threshold)
290
 
291
  # # Initialize status message
292
+ # log_time("Deduplication started", logs=logs)
293
 
294
  # if deduplication_type == "Single dataset":
295
  # # Load Dataset 1
296
+ # start_time = time.time()
297
+ # log_time("Loading Dataset 1", logs=logs)
298
  # if (
299
  # dataset1_name == default_dataset1_name
300
  # and dataset1_split == default_dataset1_split
 
302
  # ds = ds_default1
303
  # else:
304
  # ds = load_dataset(dataset1_name, split=dataset1_split)
305
+ # log_time("Loading Dataset 1 completed", start_time=start_time, logs=logs)
306
 
307
  # # Extract texts
308
+ # start_time = time.time()
309
+ # log_time("Extracting texts from Dataset 1", logs=logs)
310
  # texts = [example[dataset1_text_column] for example in ds]
311
+ # log_time("Extracting texts from Dataset 1 completed", start_time=start_time, logs=logs)
312
+
313
  # # Compute embeddings
314
+ # start_time = time.time()
315
+ # log_time("Computing embeddings for Dataset 1", logs=logs)
316
+ # embedding_matrix = encode_texts(texts, progress=progress, logs=logs)
317
+ # log_time("Computing embeddings for Dataset 1 completed", start_time=start_time, logs=logs)
318
 
319
  # # Deduplicate
320
+ # start_time = time.time()
321
+ # log_time("Deduplicating embeddings", logs=logs)
322
  # deduplicated_indices, duplicate_to_original_mapping = deduplicate(
323
+ # embedding_matrix, threshold, progress=progress, logs=logs
324
  # )
325
+ # log_time("Deduplication completed", start_time=start_time, logs=logs)
326
 
327
  # # Prepare the results
328
  # num_duplicates = len(duplicate_to_original_mapping)
 
350
  # else:
351
  # result_text += "No duplicates found."
352
 
353
+ # log_time("Deduplication process finished", logs=logs)
354
+ # full_log = "\n".join(logs) # Combine all logs into one output
355
+ # yield full_log, result_text
356
 
357
  # except Exception as e:
358
+ # full_log = "\n".join(logs) # Combine all logs into one output in case of an error
359
  # yield f"An error occurred: {e}", ""
360
  # raise e
361
 
362
  # # Adjust the height of the status_output component using custom CSS
363
  # with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
364
  # gr.Markdown("# Semantic Deduplication")
 
419
  # )
420
 
421
  # demo.launch()
422
+
423
+
424
+
425
+ # # import gradio as gr
426
+ # # from datasets import load_dataset
427
+ # # import numpy as np
428
+ # # #from model2vec import StaticModel
429
+ # # import model2vec
430
+ # # from reach import Reach
431
+ # # from difflib import ndiff
432
+
433
+
434
+ # # # Load the model at startup
435
+ # # model = model2vec.StaticModel.from_pretrained("minishlab/M2V_base_output")
436
+
437
+ # # # Default dataset parameters
438
+ # # default_dataset1_name = "sst2"
439
+ # # default_dataset1_split = "train"
440
+ # # default_dataset2_name = "sst2"
441
+ # # default_dataset2_split = "validation"
442
+ # # default_text_column = "sentence"
443
+ # # default_threshold = 0.9
444
+
445
+ # # # Load the default datasets at startup
446
+ # # ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
447
+ # # ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
448
+
449
+
450
+ # # def batch_iterable(iterable, batch_size):
451
+ # # """Helper function to create batches from an iterable."""
452
+ # # for i in range(0, len(iterable), batch_size):
453
+ # # yield iterable[i:i + batch_size]
454
+
455
+ # # def compute_embeddings(texts, batch_size, progress, desc="Computing embeddings"):
456
+ # # embeddings = []
457
+ # # total_batches = (len(texts) + batch_size - 1) // batch_size
458
+ # # for i, batch_texts in enumerate(batch_iterable(texts, batch_size)):
459
+ # # batch_embeddings = model.encode(batch_texts, show_progressbar=False)
460
+ # # embeddings.append(batch_embeddings)
461
+ # # progress((i + 1) / total_batches, desc=desc)
462
+ # # return np.concatenate(embeddings, axis=0)
463
+
464
+ # # def deduplicate(
465
+ # # embedding_matrix: np.ndarray,
466
+ # # threshold: float,
467
+ # # batch_size: int = 1024,
468
+ # # progress=None
469
+ # # ) -> tuple[np.ndarray, dict[int, int]]:
470
+ # # # Building the index
471
+ # # progress(0, desc="Building search index...")
472
+ # # reach = Reach(
473
+ # # vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))]
474
+ # # )
475
+
476
+ # # deduplicated_indices = set(range(len(embedding_matrix)))
477
+ # # duplicate_to_original_mapping = {}
478
+
479
+ # # # Finding nearest neighbors
480
+ # # progress(0, desc="Finding nearest neighbors...")
481
+ # # results = reach.nearest_neighbor_threshold(
482
+ # # embedding_matrix,
483
+ # # threshold=threshold,
484
+ # # batch_size=batch_size,
485
+ # # show_progressbar=False, # Disable internal progress bar
486
+ # # )
487
+
488
+ # # # Processing duplicates with a progress bar
489
+ # # total_items = len(embedding_matrix)
490
+ # # for i, similar_items in enumerate(
491
+ # # progress.tqdm(results, desc="Processing duplicates", total=total_items)
492
+ # # ):
493
+ # # if i not in deduplicated_indices:
494
+ # # continue
495
+
496
+ # # similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]
497
+
498
+ # # for sim_idx in similar_indices:
499
+ # # if sim_idx in deduplicated_indices:
500
+ # # deduplicated_indices.remove(sim_idx)
501
+ # # duplicate_to_original_mapping[sim_idx] = i
502
+
503
+ # # return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
504
+
505
+ # # def display_word_differences(x: str, y: str) -> str:
506
+ # # diff = ndiff(x.split(), y.split())
507
+ # # return " ".join([word for word in diff if word.startswith(("+", "-"))])
508
+
509
+
510
+ # # def encode_texts(texts, progress=None):
511
+ # # embedding_matrix = model.encode(texts, show_progressbar=False)
512
+ # # return embedding_matrix
513
+
514
+ # # def perform_deduplication(
515
+ # # deduplication_type,
516
+ # # dataset1_name,
517
+ # # dataset1_split,
518
+ # # dataset1_text_column,
519
+ # # dataset2_name="",
520
+ # # dataset2_split="",
521
+ # # dataset2_text_column="",
522
+ # # threshold=default_threshold,
523
+ # # progress=gr.Progress(track_tqdm=True),
524
+ # # ):
525
+ # # try:
526
+ # # # Convert threshold to float
527
+ # # threshold = float(threshold)
528
+
529
+ # # # Initialize status message
530
+ # # status = ""
531
+
532
+ # # if deduplication_type == "Single dataset":
533
+ # # # Load Dataset 1
534
+ # # status = "Loading Dataset 1..."
535
+ # # yield status, ""
536
+ # # if (
537
+ # # dataset1_name == default_dataset1_name
538
+ # # and dataset1_split == default_dataset1_split
539
+ # # ):
540
+ # # ds = ds_default1
541
+ # # else:
542
+ # # ds = load_dataset(dataset1_name, split=dataset1_split)
543
+
544
+ # # # Extract texts
545
+ # # status = "Extracting texts from Dataset 1..."
546
+ # # yield status, ""
547
+ # # texts = [example[dataset1_text_column] for example in ds]
548
+ # # # Compute embeddings
549
+ # # status = "Computing embeddings for Dataset 1..."
550
+ # # yield status, ""
551
+ # # embedding_matrix = encode_texts(texts, progress=progress)
552
+ # # #embedding_matrix = model.encode(texts, show_progressbar=True)
553
+ # # # embedding_matrix = compute_embeddings(
554
+ # # # texts,
555
+ # # # batch_size=64,
556
+ # # # progress=progress,
557
+ # # # desc="Computing embeddings for Dataset 1",
558
+ # # # )
559
+
560
+ # # # Deduplicate
561
+ # # status = "Deduplicating embeddings..."
562
+ # # yield status, ""
563
+ # # deduplicated_indices, duplicate_to_original_mapping = deduplicate(
564
+ # # embedding_matrix, threshold, progress=progress
565
+ # # )
566
+
567
+ # # # Prepare the results
568
+ # # num_duplicates = len(duplicate_to_original_mapping)
569
+ # # num_total = len(texts)
570
+ # # num_deduplicated = len(deduplicated_indices)
571
+
572
+ # # result_text = f"**Total documents:** {num_total}\n"
573
+ # # result_text += f"**Number of duplicates found:** {num_duplicates}\n"
574
+ # # result_text += (
575
+ # # f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
576
+ # # )
577
+
578
+ # # # Show deduplicated examples
579
+ # # if num_duplicates > 0:
580
+ # # result_text += "**Examples of duplicates found:**\n\n"
581
+ # # num_examples = min(5, num_duplicates)
582
+ # # for duplicate_idx, original_idx in list(duplicate_to_original_mapping.items())[:num_examples]:
583
+ # # original_text = texts[original_idx]
584
+ # # duplicate_text = texts[duplicate_idx]
585
+ # # differences = display_word_differences(original_text, duplicate_text)
586
+ # # result_text += f"**Original text:**\n{original_text}\n\n"
587
+ # # result_text += f"**Duplicate text:**\n{duplicate_text}\n\n"
588
+ # # result_text += f"**Differences:**\n{differences}\n"
589
+ # # result_text += "-" * 50 + "\n\n"
590
+ # # else:
591
+ # # result_text += "No duplicates found."
592
+
593
+ # # # Final status
594
+ # # status = "Deduplication completed."
595
+ # # yield status, result_text
596
+
597
+ # # elif deduplication_type == "Cross-dataset":
598
+ # # # Similar code for cross-dataset deduplication
599
+ # # # Load Dataset 1
600
+ # # status = "Loading Dataset 1..."
601
+ # # yield status, ""
602
+ # # if (
603
+ # # dataset1_name == default_dataset1_name
604
+ # # and dataset1_split == default_dataset1_split
605
+ # # ):
606
+ # # ds1 = ds_default1
607
+ # # else:
608
+ # # ds1 = load_dataset(dataset1_name, split=dataset1_split)
609
+
610
+ # # # Load Dataset 2
611
+ # # status = "Loading Dataset 2..."
612
+ # # yield status, ""
613
+ # # if (
614
+ # # dataset2_name == default_dataset2_name
615
+ # # and dataset2_split == default_dataset2_split
616
+ # # ):
617
+ # # ds2 = ds_default2
618
+ # # else:
619
+ # # ds2 = load_dataset(dataset2_name, split=dataset2_split)
620
+
621
+ # # # Extract texts from Dataset 1
622
+ # # status = "Extracting texts from Dataset 1..."
623
+ # # yield status, ""
624
+ # # texts1 = [example[dataset1_text_column] for example in ds1]
625
+
626
+ # # # Extract texts from Dataset 2
627
+ # # status = "Extracting texts from Dataset 2..."
628
+ # # yield status, ""
629
+ # # texts2 = [example[dataset2_text_column] for example in ds2]
630
+
631
+ # # # Compute embeddings for Dataset 1
632
+ # # status = "Computing embeddings for Dataset 1..."
633
+ # # yield status, ""
634
+ # # embedding_matrix1 = compute_embeddings(
635
+ # # texts1,
636
+ # # batch_size=64,
637
+ # # progress=progress,
638
+ # # desc="Computing embeddings for Dataset 1",
639
+ # # )
640
+
641
+ # # # Compute embeddings for Dataset 2
642
+ # # status = "Computing embeddings for Dataset 2..."
643
+ # # yield status, ""
644
+ # # embedding_matrix2 = compute_embeddings(
645
+ # # texts2,
646
+ # # batch_size=64,
647
+ # # progress=progress,
648
+ # # desc="Computing embeddings for Dataset 2",
649
+ # # )
650
+
651
+ # # # Deduplicate across datasets
652
+ # # status = "Deduplicating embeddings across datasets..."
653
+ # # yield status, ""
654
+ # # duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
655
+ # # embedding_matrix1, embedding_matrix2, threshold, progress=progress
656
+ # # )
657
+
658
+ # # num_duplicates = len(duplicate_indices_in_ds2)
659
+ # # num_total_ds2 = len(texts2)
660
+ # # num_unique_ds2 = num_total_ds2 - num_duplicates
661
+
662
+ # # result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
663
+ # # result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
664
+ # # result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
665
+
666
+ # # # Show deduplicated examples
667
+ # # if num_duplicates > 0:
668
+ # # result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
669
+ # # num_examples = min(5, num_duplicates)
670
+ # # for duplicate_idx in duplicate_indices_in_ds2[:num_examples]:
671
+ # # original_idx = duplicate_to_original_mapping[duplicate_idx]
672
+ # # original_text = texts1[original_idx]
673
+ # # duplicate_text = texts2[duplicate_idx]
674
+ # # differences = display_word_differences(original_text, duplicate_text)
675
+ # # result_text += f"**Original text (Dataset 1):**\n{original_text}\n\n"
676
+ # # result_text += f"**Duplicate text (Dataset 2):**\n{duplicate_text}\n\n"
677
+ # # result_text += f"**Differences:**\n{differences}\n"
678
+ # # result_text += "-" * 50 + "\n\n"
679
+ # # else:
680
+ # # result_text += "No duplicates found."
681
+
682
+ # # # Final status
683
+ # # status = "Deduplication completed."
684
+ # # yield status, result_text
685
+
686
+ # # except Exception as e:
687
+ # # yield f"An error occurred: {e}", ""
688
+ # # raise e
689
+
690
+ # # def deduplicate_across_datasets(
691
+ # # embedding_matrix_1: np.ndarray,
692
+ # # embedding_matrix_2: np.ndarray,
693
+ # # threshold: float,
694
+ # # batch_size: int = 1024,
695
+ # # progress=None
696
+ # # ) -> tuple[list[int], dict[int, int]]:
697
+ # # # Building the index from Dataset 1
698
+ # # progress(0, desc="Building search index from Dataset 1...")
699
+ # # reach = Reach(
700
+ # # vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))]
701
+ # # )
702
+
703
+ # # duplicate_indices_in_test = []
704
+ # # duplicate_to_original_mapping = {}
705
+
706
+ # # # Finding nearest neighbors between datasets
707
+ # # progress(0, desc="Finding nearest neighbors between datasets...")
708
+ # # results = reach.nearest_neighbor_threshold(
709
+ # # embedding_matrix_2,
710
+ # # threshold=threshold,
711
+ # # batch_size=batch_size,
712
+ # # show_progressbar=False, # Disable internal progress bar
713
+ # # )
714
+
715
+ # # total_items = len(embedding_matrix_2)
716
+ # # # Processing duplicates with a progress bar
717
+ # # for i, similar_items in enumerate(
718
+ # # progress.tqdm(results, desc="Processing duplicates across datasets", total=total_items)
719
+ # # ):
720
+ # # similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
721
+
722
+ # # if similar_indices:
723
+ # # duplicate_indices_in_test.append(i)
724
+ # # duplicate_to_original_mapping[i] = similar_indices[0]
725
+
726
+ # # return duplicate_indices_in_test, duplicate_to_original_mapping
727
+
728
+ # # # Adjust the height of the status_output component using custom CSS
729
+ # # with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
730
+ # # gr.Markdown("# Semantic Deduplication")
731
+
732
+ # # deduplication_type = gr.Radio(
733
+ # # choices=["Single dataset", "Cross-dataset"],
734
+ # # label="Deduplication Type",
735
+ # # value="Single dataset",
736
+ # # )
737
+
738
+ # # with gr.Row():
739
+ # # dataset1_name = gr.Textbox(value=default_dataset1_name, label="Dataset 1 Name")
740
+ # # dataset1_split = gr.Textbox(value=default_dataset1_split, label="Dataset 1 Split")
741
+ # # dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
742
+
743
+ # # dataset2_inputs = gr.Column(visible=False)
744
+ # # with dataset2_inputs:
745
+ # # gr.Markdown("### Dataset 2")
746
+ # # with gr.Row():
747
+ # # dataset2_name = gr.Textbox(value=default_dataset2_name, label="Dataset 2 Name")
748
+ # # dataset2_split = gr.Textbox(value=default_dataset2_split, label="Dataset 2 Split")
749
+ # # dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
750
+
751
+ # # threshold = gr.Slider(
752
+ # # minimum=0.0, maximum=1.0, value=default_threshold, label="Similarity Threshold"
753
+ # # )
754
+
755
+ # # compute_button = gr.Button("Compute")
756
+
757
+ # # # Use 'gr.Markdown' with 'elem_id' and custom CSS to adjust height
758
+ # # status_output = gr.Markdown(elem_id="status_output")
759
+ # # result_output = gr.Markdown()
760
+
761
+ # # # Function to update the visibility of dataset2_inputs
762
+ # # def update_visibility(deduplication_type_value):
763
+ # # if deduplication_type_value == "Cross-dataset":
764
+ # # return gr.update(visible=True)
765
+ # # else:
766
+ # # return gr.update(visible=False)
767
+
768
+ # # deduplication_type.change(
769
+ # # update_visibility, inputs=deduplication_type, outputs=dataset2_inputs
770
+ # # )
771
+
772
+ # # compute_button.click(
773
+ # # fn=perform_deduplication,
774
+ # # inputs=[
775
+ # # deduplication_type,
776
+ # # dataset1_name,
777
+ # # dataset1_split,
778
+ # # dataset1_text_column,
779
+ # # dataset2_name,
780
+ # # dataset2_split,
781
+ # # dataset2_text_column,
782
+ # # threshold,
783
+ # # ],
784
+ # # outputs=[status_output, result_output],
785
+ # # )
786
+
787
+ # # demo.launch()