Pringled committed
Commit c58907b · 1 Parent(s): 1a5f99b
Files changed (1)
  1. app.py +897 -541
app.py CHANGED
@@ -4,12 +4,11 @@ import numpy as np
 from model2vec import StaticModel
 from reach import Reach
 from difflib import ndiff
-import tqdm
 
 # Load the model at startup
 model = StaticModel.from_pretrained("minishlab/M2V_base_output")
 
-# Update default dataset to 'sst2' and set default threshold to 0.9
 default_dataset1_name = "sst2"
 default_dataset1_split = "train"
 default_dataset2_name = "sst2"
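For context, a brief sketch (not part of the diff) of how these defaults are consumed later in app.py: the sst2 split is loaded with datasets.load_dataset, its "sentence" column is extracted, and the model encodes the texts. Every call below mirrors one that already appears in the file.

```python
from datasets import load_dataset
from model2vec import StaticModel

# Mirror of the startup code: load the model and the default dataset split.
model = StaticModel.from_pretrained("minishlab/M2V_base_output")
ds_default1 = load_dataset("sst2", split="train")          # default_dataset1_name / default_dataset1_split

# Pull the text column and embed it (same calls used later in the file).
texts = [example["sentence"] for example in ds_default1]   # default_text_column
embeddings = model.encode(texts, show_progressbar=False)   # numpy array, one row per sentence
```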
@@ -28,29 +27,42 @@ def batch_iterable(iterable, batch_size):
 
 def compute_embeddings(texts, batch_size, progress, desc="Computing embeddings"):
     embeddings = []
-    for batch in progress.tqdm(batch_iterable(texts, batch_size), total=(len(texts) + batch_size - 1) // batch_size, desc=desc):
-        batch_embeddings = model.encode(batch, show_progressbar=False)
         embeddings.append(batch_embeddings)
     return np.concatenate(embeddings, axis=0)
 
-def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024, progress=None) -> tuple[np.ndarray, dict[int, int]]:
-    """
-    Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.
-    """
-    reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])
 
     deduplicated_indices = set(range(len(embedding_matrix)))
     duplicate_to_original_mapping = {}
 
     results = reach.nearest_neighbor_threshold(
         embedding_matrix,
         threshold=threshold,
         batch_size=batch_size,
-        show_progressbar=False
     )
 
     total_items = len(embedding_matrix)
-    for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=total_items)):
         if i not in deduplicated_indices:
             continue
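The deduplicate function being rewritten in this hunk keeps one representative per group of near-duplicates: it indexes the embeddings with Reach, queries them against themselves with a similarity threshold, and drops every neighbour that clears it. A minimal, self-contained sketch of that pattern, using only the Reach and model2vec calls the file already makes (the toy sentences are made up):

```python
import numpy as np
from model2vec import StaticModel
from reach import Reach

model = StaticModel.from_pretrained("minishlab/M2V_base_output")
texts = ["a great movie", "a great film", "terrible acting"]   # toy input
emb = model.encode(texts, show_progressbar=False)

# Index the embeddings and query them against themselves.
reach = Reach(vectors=emb, items=[str(i) for i in range(len(emb))])
results = reach.nearest_neighbor_threshold(
    emb, threshold=0.9, batch_size=1024, show_progressbar=False
)

kept = set(range(len(emb)))
duplicate_to_original = {}
for i, similar_items in enumerate(results):
    if i not in kept:
        continue  # already marked as a duplicate of an earlier item
    for item in similar_items:          # item is (index-as-string, similarity)
        j = int(item[0])
        if j != i and j in kept:
            kept.remove(j)
            duplicate_to_original[j] = i

print(sorted(kept), duplicate_to_original)
```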
 
@@ -63,35 +75,9 @@ def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int
 
     return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
 
-def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix_2: np.ndarray, threshold: float, batch_size: int = 1024, progress=None) -> tuple[list[int], dict[int, int]]:
-    """
-    Deduplicate embeddings across two datasets and return the indices of duplicates between them.
-    """
-    reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])
-
-    duplicate_indices_in_test = []
-    duplicate_to_original_mapping = {}
-
-    results = reach.nearest_neighbor_threshold(
-        embedding_matrix_2,
-        threshold=threshold,
-        batch_size=batch_size,
-        show_progressbar=False
-    )
-
-    total_items = len(embedding_matrix_2)
-    for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates across datasets", total=total_items)):
-        similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
-
-        if similar_indices:
-            duplicate_indices_in_test.append(i)
-            duplicate_to_original_mapping[i] = similar_indices[0]
-
-    return duplicate_indices_in_test, duplicate_to_original_mapping
-
 def display_word_differences(x: str, y: str) -> str:
     diff = ndiff(x.split(), y.split())
-    return " ".join([word for word in diff if word.startswith(('+', '-'))])
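display_word_differences relies on difflib.ndiff over the two word lists and keeps only the words tagged + or -. A small usage example:

```python
from difflib import ndiff

def display_word_differences(x: str, y: str) -> str:
    diff = ndiff(x.split(), y.split())
    return " ".join([word for word in diff if word.startswith(("+", "-"))])

print(display_word_differences("a great movie overall", "a great film overall"))
# -> "- movie + film"
```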
 
 def perform_deduplication(
     deduplication_type,
@@ -102,26 +88,61 @@ def perform_deduplication(
     dataset2_split="",
     dataset2_text_column="",
     threshold=default_threshold,
-    progress=gr.Progress(track_tqdm=True)
 ):
     try:
         threshold = float(threshold)
 
         if deduplication_type == "Single dataset":
-            ds = ds_default1 if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split else load_dataset(dataset1_name, split=dataset1_split)
-            texts = [example[dataset1_text_column] for example in ds]
 
-            embedding_matrix = compute_embeddings(texts, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 1")
-            deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, threshold, progress=progress)
 
             num_duplicates = len(duplicate_to_original_mapping)
             num_total = len(texts)
             num_deduplicated = len(deduplicated_indices)
 
             result_text = f"**Total documents:** {num_total}\n"
             result_text += f"**Number of duplicates found:** {num_duplicates}\n"
-            result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
 
             if num_duplicates > 0:
                 result_text += "**Examples of duplicates found:**\n\n"
                 num_examples = min(5, num_duplicates)
@@ -136,19 +157,70 @@ def perform_deduplication(
             else:
                 result_text += "No duplicates found."
 
-            yield result_text
 
         elif deduplication_type == "Cross-dataset":
-            ds1 = ds_default1 if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split else load_dataset(dataset1_name, split=dataset1_split)
-            ds2 = ds_default2 if dataset2_name == default_dataset2_name and dataset2_split == default_dataset2_split else load_dataset(dataset2_name, split=dataset2_split)
 
             texts1 = [example[dataset1_text_column] for example in ds1]
-            texts2 = [example[dataset2_text_column] for example in ds2]
 
-            embedding_matrix1 = compute_embeddings(texts1, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 1")
-            embedding_matrix2 = compute_embeddings(texts2, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 2")
 
-            duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(embedding_matrix1, embedding_matrix2, threshold, progress=progress)
 
             num_duplicates = len(duplicate_indices_in_ds2)
             num_total_ds2 = len(texts2)
@@ -158,6 +230,7 @@ def perform_deduplication(
             result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
             result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
 
             if num_duplicates > 0:
                 result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
                 num_examples = min(5, num_duplicates)
@@ -173,19 +246,60 @@ def perform_deduplication(
             else:
                 result_text += "No duplicates found."
 
-            yield result_text
 
     except Exception as e:
         yield f"An error occurred: {e}", ""
 
-# Adjust the height of the status_output and result_output components
-with gr.Blocks(css="#status_output { height: 300px; overflow: auto; } #result_output { height: 300px; overflow: auto; }") as demo:
     gr.Markdown("# Semantic Deduplication")
 
     deduplication_type = gr.Radio(
         choices=["Single dataset", "Cross-dataset"],
         label="Deduplication Type",
-        value="Single dataset"
     )
 
     with gr.Row():
@@ -202,17 +316,16 @@ with gr.Blocks(css="#status_output { height: 300px; overflow: auto; } #result_ou
             dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
 
     threshold = gr.Slider(
-        minimum=0.0,
-        maximum=1.0,
-        value=default_threshold,
-        label="Similarity Threshold"
     )
 
     compute_button = gr.Button("Compute")
 
     status_output = gr.Markdown(elem_id="status_output")
-    result_output = gr.Markdown(elem_id="result_output")
 
     def update_visibility(deduplication_type_value):
         if deduplication_type_value == "Cross-dataset":
             return gr.update(visible=True)
@@ -220,9 +333,7 @@ with gr.Blocks(css="#status_output { height: 300px; overflow: auto; } #result_ou
             return gr.update(visible=False)
 
     deduplication_type.change(
-        update_visibility,
-        inputs=deduplication_type,
-        outputs=dataset2_inputs
     )
 
     compute_button.click(
@@ -235,13 +346,14 @@ with gr.Blocks(css="#status_output { height: 300px; overflow: auto; } #result_ou
             dataset2_name,
             dataset2_split,
             dataset2_text_column,
-            threshold
         ],
-        outputs=[status_output, result_output]
     )
 
 demo.launch()
 
 # import gradio as gr
 # from datasets import load_dataset
 # import numpy as np
@@ -281,23 +393,18 @@ demo.launch()
281
  # """
282
  # Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.
283
  # """
284
- # # Building the index
285
- # progress(0, desc="Building search index...")
286
  # reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])
287
 
288
  # deduplicated_indices = set(range(len(embedding_matrix)))
289
  # duplicate_to_original_mapping = {}
290
 
291
- # # Finding nearest neighbors
292
- # progress(0, desc="Finding nearest neighbors...")
293
  # results = reach.nearest_neighbor_threshold(
294
  # embedding_matrix,
295
  # threshold=threshold,
296
  # batch_size=batch_size,
297
- # show_progressbar=False # Disable internal progress bar
298
  # )
299
 
300
- # # Processing duplicates with a progress bar
301
  # total_items = len(embedding_matrix)
302
  # for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=total_items)):
303
  # if i not in deduplicated_indices:
@@ -316,24 +423,19 @@ demo.launch()
316
  # """
317
  # Deduplicate embeddings across two datasets and return the indices of duplicates between them.
318
  # """
319
- # # Building the index from Dataset 1
320
- # progress(0, desc="Building search index from Dataset 1...")
321
  # reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])
322
 
323
  # duplicate_indices_in_test = []
324
  # duplicate_to_original_mapping = {}
325
 
326
- # # Finding nearest neighbors between datasets
327
- # progress(0, desc="Finding nearest neighbors between datasets...")
328
  # results = reach.nearest_neighbor_threshold(
329
  # embedding_matrix_2,
330
  # threshold=threshold,
331
  # batch_size=batch_size,
332
- # show_progressbar=False # Disable internal progress bar
333
  # )
334
 
335
  # total_items = len(embedding_matrix_2)
336
- # # Processing duplicates with a progress bar
337
  # for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates across datasets", total=total_items)):
338
  # similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
339
 
@@ -359,39 +461,15 @@ demo.launch()
359
  # progress=gr.Progress(track_tqdm=True)
360
  # ):
361
  # try:
362
- # # Convert threshold to float
363
  # threshold = float(threshold)
364
 
365
- # # Initialize status message
366
- # status = ""
367
-
368
  # if deduplication_type == "Single dataset":
369
- # # Load Dataset 1
370
- # status = "Loading Dataset 1..."
371
- # yield status, ""
372
- # if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
373
- # ds = ds_default1
374
- # else:
375
- # ds = load_dataset(dataset1_name, split=dataset1_split)
376
-
377
- # # Extract texts
378
- # status = "Extracting texts from Dataset 1..."
379
- # yield status, ""
380
  # texts = [example[dataset1_text_column] for example in ds]
381
 
382
- # # Compute embeddings
383
- # status = "Computing embeddings for Dataset 1..."
384
- # yield status, ""
385
  # embedding_matrix = compute_embeddings(texts, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 1")
 
386
 
387
- # # Deduplicate
388
- # status = "Deduplicating embeddings..."
389
- # yield status, ""
390
- # deduplicated_indices, duplicate_to_original_mapping = deduplicate(
391
- # embedding_matrix, threshold, progress=progress
392
- # )
393
-
394
- # # Prepare the results
395
  # num_duplicates = len(duplicate_to_original_mapping)
396
  # num_total = len(texts)
397
  # num_deduplicated = len(deduplicated_indices)
@@ -400,7 +478,6 @@ demo.launch()
400
  # result_text += f"**Number of duplicates found:** {num_duplicates}\n"
401
  # result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
402
 
403
- # # Show deduplicated examples
404
  # if num_duplicates > 0:
405
  # result_text += "**Examples of duplicates found:**\n\n"
406
  # num_examples = min(5, num_duplicates)
@@ -415,53 +492,19 @@ demo.launch()
415
  # else:
416
  # result_text += "No duplicates found."
417
 
418
- # # Final status
419
- # status = "Deduplication completed."
420
- # yield status, result_text
421
 
422
  # elif deduplication_type == "Cross-dataset":
423
- # # Load Dataset 1
424
- # status = "Loading Dataset 1..."
425
- # yield status, ""
426
- # if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
427
- # ds1 = ds_default1
428
- # else:
429
- # ds1 = load_dataset(dataset1_name, split=dataset1_split)
430
-
431
- # # Load Dataset 2
432
- # status = "Loading Dataset 2..."
433
- # yield status, ""
434
- # if dataset2_name == default_dataset2_name and dataset2_split == default_dataset2_split:
435
- # ds2 = ds_default2
436
- # else:
437
- # ds2 = load_dataset(dataset2_name, split=dataset2_split)
438
 
439
- # # Extract texts from Dataset 1
440
- # status = "Extracting texts from Dataset 1..."
441
- # yield status, ""
442
  # texts1 = [example[dataset1_text_column] for example in ds1]
443
-
444
- # # Extract texts from Dataset 2
445
- # status = "Extracting texts from Dataset 2..."
446
- # yield status, ""
447
  # texts2 = [example[dataset2_text_column] for example in ds2]
448
 
449
- # # Compute embeddings for Dataset 1
450
- # status = "Computing embeddings for Dataset 1..."
451
- # yield status, ""
452
  # embedding_matrix1 = compute_embeddings(texts1, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 1")
453
-
454
- # # Compute embeddings for Dataset 2
455
- # status = "Computing embeddings for Dataset 2..."
456
- # yield status, ""
457
  # embedding_matrix2 = compute_embeddings(texts2, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 2")
458
 
459
- # # Deduplicate across datasets
460
- # status = "Deduplicating embeddings across datasets..."
461
- # yield status, ""
462
- # duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
463
- # embedding_matrix1, embedding_matrix2, threshold, progress=progress
464
- # )
465
 
466
  # num_duplicates = len(duplicate_indices_in_ds2)
467
  # num_total_ds2 = len(texts2)
@@ -471,7 +514,6 @@ demo.launch()
471
  # result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
472
  # result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
473
 
474
- # # Show deduplicated examples
475
  # if num_duplicates > 0:
476
  # result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
477
  # num_examples = min(5, num_duplicates)
@@ -487,15 +529,13 @@ demo.launch()
487
  # else:
488
  # result_text += "No duplicates found."
489
 
490
- # # Final status
491
- # status = "Deduplication completed."
492
- # yield status, result_text
493
 
494
  # except Exception as e:
495
  # yield f"An error occurred: {e}", ""
496
- # raise e
497
 
498
- # with gr.Blocks() as demo:
 
499
  # gr.Markdown("# Semantic Deduplication")
500
 
501
  # deduplication_type = gr.Radio(
@@ -526,10 +566,9 @@ demo.launch()
526
 
527
  # compute_button = gr.Button("Compute")
528
 
529
- # status_output = gr.Markdown()
530
- # result_output = gr.Markdown()
531
 
532
- # # Function to update the visibility of dataset2_inputs
533
  # def update_visibility(deduplication_type_value):
534
  # if deduplication_type_value == "Cross-dataset":
535
  # return gr.update(visible=True)
@@ -559,178 +598,322 @@ demo.launch()
559
 
560
  # demo.launch()
561
 
 
 
 
 
 
 
 
562
 
563
- # import gradio as gr
564
- # from datasets import load_dataset
565
- # import numpy as np
566
- # import model2vec
567
- # from reach import Reach
568
- # from difflib import ndiff
569
 
570
- # # Load the model at startup
571
- # model = model2vec.StaticModel.from_pretrained("minishlab/M2V_base_output")
 
 
 
 
 
572
 
573
- # # Default dataset parameters
574
- # default_dataset1_name = "sst2"
575
- # default_dataset1_split = "train"
576
- # default_dataset2_name = "sst2"
577
- # default_dataset2_split = "validation"
578
- # default_text_column = "sentence"
579
- # default_threshold = 0.9
580
 
581
- # # Load the default datasets at startup
582
- # ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
583
- # ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
 
584
 
585
- # def batch_iterable(iterable, batch_size):
586
- # """Helper function to create batches from an iterable."""
587
- # for i in range(0, len(iterable), batch_size):
588
- # yield iterable[i:i + batch_size]
 
 
589
 
590
- # def compute_embeddings(texts, batch_size, progress, desc="Computing embeddings"):
591
- # embeddings = []
592
- # total_batches = (len(texts) + batch_size - 1) // batch_size
593
- # for i, batch_texts in enumerate(batch_iterable(texts, batch_size)):
594
- # batch_embeddings = model.encode(batch_texts, show_progressbar=False)
595
- # embeddings.append(batch_embeddings)
596
- # progress((i + 1) / total_batches, desc=desc)
597
- # return np.concatenate(embeddings, axis=0)
598
 
599
- # def deduplicate(
600
- # embedding_matrix: np.ndarray,
601
- # threshold: float,
602
- # batch_size: int = 1024,
603
- # progress=None
604
- # ) -> tuple[np.ndarray, dict[int, int]]:
605
- # reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])
606
 
607
- # deduplicated_indices = set(range(len(embedding_matrix)))
608
- # duplicate_to_original_mapping = {}
 
 
 
 
 
 
609
 
610
- # results = reach.nearest_neighbor_threshold(
611
- # embedding_matrix,
612
- # threshold=threshold,
613
- # batch_size=batch_size,
614
- # show_progressbar=False,
615
- # )
616
 
617
- # total_items = len(embedding_matrix)
618
- # for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=total_items)):
619
- # if i not in deduplicated_indices:
620
- # continue
621
 
622
- # similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]
623
- # for sim_idx in similar_indices:
624
- # if sim_idx in deduplicated_indices:
625
- # deduplicated_indices.remove(sim_idx)
626
- # duplicate_to_original_mapping[sim_idx] = i
627
 
628
- # return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
629
 
630
- # def display_word_differences(x: str, y: str) -> str:
631
- # diff = ndiff(x.split(), y.split())
632
- # return " ".join([word for word in diff if word.startswith(("+", "-"))])
 
 
 
 
633
 
634
- # def perform_deduplication(
635
- # deduplication_type,
636
- # dataset1_name,
637
- # dataset1_split,
638
- # dataset1_text_column,
639
- # dataset2_name="",
640
- # dataset2_split="",
641
- # dataset2_text_column="",
642
- # threshold=default_threshold,
643
- # progress=gr.Progress(track_tqdm=True),
644
- # ):
645
- # try:
646
- # threshold = float(threshold)
647
 
648
- # if deduplication_type == "Single dataset":
649
- # ds = ds_default1 if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split else load_dataset(dataset1_name, split=dataset1_split)
650
- # texts = [example[dataset1_text_column] for example in ds]
 
 
 
 
 
651
 
652
- # embedding_matrix = compute_embeddings(texts, batch_size=64, progress=progress)
653
- # deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, threshold, progress=progress)
 
 
654
 
655
- # num_duplicates = len(duplicate_to_original_mapping)
656
- # num_total = len(texts)
657
- # num_deduplicated = len(deduplicated_indices)
658
 
659
- # result_text = f"**Total documents:** {num_total}\n"
660
- # result_text += f"**Number of duplicates found:** {num_duplicates}\n"
661
- # result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
662
 
663
- # if num_duplicates > 0:
664
- # result_text += "**Examples of duplicates found:**\n\n"
665
- # num_examples = min(5, num_duplicates)
666
- # for duplicate_idx, original_idx in list(duplicate_to_original_mapping.items())[:num_examples]:
667
- # original_text = texts[original_idx]
668
- # duplicate_text = texts[duplicate_idx]
669
- # differences = display_word_differences(original_text, duplicate_text)
670
- # result_text += f"**Original text:**\n{original_text}\n\n"
671
- # result_text += f"**Duplicate text:**\n{duplicate_text}\n\n"
672
- # result_text += f"**Differences:**\n{differences}\n"
673
- # result_text += "-" * 50 + "\n\n"
674
- # else:
675
- # result_text += "No duplicates found."
676
 
677
- # yield result_text
 
 
 
 
 
 
 
 
 
 
 
 
 
678
 
679
- # except Exception as e:
680
- # yield f"An error occurred: {e}"
681
 
682
- # # Gradio interface setup
683
- # with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
684
- # gr.Markdown("# Semantic Deduplication")
 
 
 
 
 
685
 
686
- # deduplication_type = gr.Radio(
687
- # choices=["Single dataset", "Cross-dataset"],
688
- # label="Deduplication Type",
689
- # value="Single dataset",
690
- # )
691
 
692
- # with gr.Row():
693
- # dataset1_name = gr.Textbox(value=default_dataset1_name, label="Dataset 1 Name")
694
- # dataset1_split = gr.Textbox(value=default_dataset1_split, label="Dataset 1 Split")
695
- # dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
696
 
697
- # dataset2_inputs = gr.Column(visible=False)
698
- # with dataset2_inputs:
699
- # gr.Markdown("### Dataset 2")
700
- # with gr.Row():
701
- # dataset2_name = gr.Textbox(value=default_dataset2_name, label="Dataset 2 Name")
702
- # dataset2_split = gr.Textbox(value=default_dataset2_split, label="Dataset 2 Split")
703
- # dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
704
 
705
- # threshold = gr.Slider(minimum=0.0, maximum=1.0, value=default_threshold, label="Similarity Threshold")
 
 
 
706
 
707
- # compute_button = gr.Button("Compute")
 
 
708
 
709
- # result_output = gr.Markdown()
 
 
 
 
 
 
 
 
 
 
 
 
 
710
 
711
- # def update_visibility(deduplication_type_value):
712
- # return gr.update(visible=True) if deduplication_type_value == "Cross-dataset" else gr.update(visible=False)
 
713
 
714
- # deduplication_type.change(
715
- # update_visibility, inputs=deduplication_type, outputs=dataset2_inputs
716
- # )
 
 
 
 
 
717
 
718
- # compute_button.click(
719
- # fn=perform_deduplication,
720
- # inputs=[
721
- # deduplication_type,
722
- # dataset1_name,
723
- # dataset1_split,
724
- # dataset1_text_column,
725
- # dataset2_name,
726
- # dataset2_split,
727
- # dataset2_text_column,
728
- # threshold,
729
- # ],
730
- # outputs=[result_output],
731
- # )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
732
 
733
- # demo.launch()
734
 
735
 
736
  # # import gradio as gr
@@ -739,7 +922,6 @@ demo.launch()
739
  # # import model2vec
740
  # # from reach import Reach
741
  # # from difflib import ndiff
742
- # # import time
743
 
744
  # # # Load the model at startup
745
  # # model = model2vec.StaticModel.from_pretrained("minishlab/M2V_base_output")
@@ -761,19 +943,7 @@ demo.launch()
761
  # # for i in range(0, len(iterable), batch_size):
762
  # # yield iterable[i:i + batch_size]
763
 
764
- # # def log_time(message, start_time=None, logs=None):
765
- # # """Helper function to log the start and end times."""
766
- # # current_time = time.time()
767
- # # if start_time is not None:
768
- # # elapsed = current_time - start_time
769
- # # log_message = f"{message} - Took {elapsed:.2f} seconds"
770
- # # else:
771
- # # log_message = f"{message} - Started"
772
-
773
- # # if logs is not None:
774
- # # logs.append(log_message)
775
-
776
- # # def compute_embeddings(texts, batch_size, progress, logs, desc="Computing embeddings"):
777
  # # embeddings = []
778
  # # total_batches = (len(texts) + batch_size - 1) // batch_size
779
  # # for i, batch_texts in enumerate(batch_iterable(texts, batch_size)):
@@ -786,38 +956,26 @@ demo.launch()
786
  # # embedding_matrix: np.ndarray,
787
  # # threshold: float,
788
  # # batch_size: int = 1024,
789
- # # progress=None,
790
- # # logs=None
791
  # # ) -> tuple[np.ndarray, dict[int, int]]:
792
- # # # Building the index
793
- # # log_time("Building search index", logs=logs)
794
- # # reach = Reach(
795
- # # vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))]
796
- # # )
797
 
798
  # # deduplicated_indices = set(range(len(embedding_matrix)))
799
  # # duplicate_to_original_mapping = {}
800
 
801
- # # # Finding nearest neighbors
802
- # # log_time("Finding nearest neighbors", logs=logs)
803
  # # results = reach.nearest_neighbor_threshold(
804
  # # embedding_matrix,
805
  # # threshold=threshold,
806
  # # batch_size=batch_size,
807
- # # show_progressbar=False, # Disable internal progress bar
808
  # # )
809
 
810
- # # # Processing duplicates with a progress bar
811
  # # total_items = len(embedding_matrix)
812
- # # log_time("Processing duplicates", logs=logs)
813
- # # for i, similar_items in enumerate(
814
- # # progress.tqdm(results, desc="Processing duplicates", total=total_items)
815
- # # ):
816
  # # if i not in deduplicated_indices:
817
  # # continue
818
 
819
  # # similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]
820
-
821
  # # for sim_idx in similar_indices:
822
  # # if sim_idx in deduplicated_indices:
823
  # # deduplicated_indices.remove(sim_idx)
@@ -829,11 +987,6 @@ demo.launch()
829
  # # diff = ndiff(x.split(), y.split())
830
  # # return " ".join([word for word in diff if word.startswith(("+", "-"))])
831
 
832
- # # def encode_texts(texts, progress=None, logs=None):
833
- # # embedding_matrix = model.encode(texts, show_progressbar=False)
834
- # # log_time("Encoding texts completed", logs=logs)
835
- # # return embedding_matrix
836
-
837
  # # def perform_deduplication(
838
  # # deduplication_type,
839
  # # dataset1_name,
@@ -845,59 +998,24 @@ demo.launch()
845
  # # threshold=default_threshold,
846
  # # progress=gr.Progress(track_tqdm=True),
847
  # # ):
848
- # # logs = [] # To store log messages
849
  # # try:
850
- # # # Convert threshold to float
851
  # # threshold = float(threshold)
852
 
853
- # # # Initialize status message
854
- # # log_time("Deduplication started", logs=logs)
855
-
856
  # # if deduplication_type == "Single dataset":
857
- # # # Load Dataset 1
858
- # # start_time = time.time()
859
- # # log_time("Loading Dataset 1", logs=logs)
860
- # # if (
861
- # # dataset1_name == default_dataset1_name
862
- # # and dataset1_split == default_dataset1_split
863
- # # ):
864
- # # ds = ds_default1
865
- # # else:
866
- # # ds = load_dataset(dataset1_name, split=dataset1_split)
867
- # # log_time("Loading Dataset 1 completed", start_time=start_time, logs=logs)
868
-
869
- # # # Extract texts
870
- # # start_time = time.time()
871
- # # log_time("Extracting texts from Dataset 1", logs=logs)
872
  # # texts = [example[dataset1_text_column] for example in ds]
873
- # # log_time("Extracting texts from Dataset 1 completed", start_time=start_time, logs=logs)
874
-
875
- # # # Compute embeddings
876
- # # start_time = time.time()
877
- # # log_time("Computing embeddings for Dataset 1", logs=logs)
878
- # # embedding_matrix = encode_texts(texts, progress=progress, logs=logs)
879
- # # log_time("Computing embeddings for Dataset 1 completed", start_time=start_time, logs=logs)
880
 
881
- # # # Deduplicate
882
- # # start_time = time.time()
883
- # # log_time("Deduplicating embeddings", logs=logs)
884
- # # deduplicated_indices, duplicate_to_original_mapping = deduplicate(
885
- # # embedding_matrix, threshold, progress=progress, logs=logs
886
- # # )
887
- # # log_time("Deduplication completed", start_time=start_time, logs=logs)
888
 
889
- # # # Prepare the results
890
  # # num_duplicates = len(duplicate_to_original_mapping)
891
  # # num_total = len(texts)
892
  # # num_deduplicated = len(deduplicated_indices)
893
 
894
  # # result_text = f"**Total documents:** {num_total}\n"
895
  # # result_text += f"**Number of duplicates found:** {num_duplicates}\n"
896
- # # result_text += (
897
- # # f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
898
- # # )
899
 
900
- # # # Show deduplicated examples
901
  # # if num_duplicates > 0:
902
  # # result_text += "**Examples of duplicates found:**\n\n"
903
  # # num_examples = min(5, num_duplicates)
@@ -912,16 +1030,12 @@ demo.launch()
912
  # # else:
913
  # # result_text += "No duplicates found."
914
 
915
- # # log_time("Deduplication process finished", logs=logs)
916
- # # full_log = "\n".join(logs) # Combine all logs into one output
917
- # # yield full_log, result_text
918
 
919
  # # except Exception as e:
920
- # # full_log = "\n".join(logs) # Combine all logs into one output in case of an error
921
- # # yield f"An error occurred: {e}", ""
922
- # # raise e
923
 
924
- # # # Adjust the height of the status_output component using custom CSS
925
  # # with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
926
  # # gr.Markdown("# Semantic Deduplication")
927
 
@@ -944,22 +1058,14 @@ demo.launch()
944
  # # dataset2_split = gr.Textbox(value=default_dataset2_split, label="Dataset 2 Split")
945
  # # dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
946
 
947
- # # threshold = gr.Slider(
948
- # # minimum=0.0, maximum=1.0, value=default_threshold, label="Similarity Threshold"
949
- # # )
950
 
951
  # # compute_button = gr.Button("Compute")
952
 
953
- # # # Use 'gr.Markdown' with 'elem_id' and custom CSS to adjust height
954
- # # status_output = gr.Markdown(elem_id="status_output")
955
  # # result_output = gr.Markdown()
956
 
957
- # # # Function to update the visibility of dataset2_inputs
958
  # # def update_visibility(deduplication_type_value):
959
- # # if deduplication_type_value == "Cross-dataset":
960
- # # return gr.update(visible=True)
961
- # # else:
962
- # # return gr.update(visible=False)
963
 
964
  # # deduplication_type.change(
965
  # # update_visibility, inputs=deduplication_type, outputs=dataset2_inputs
@@ -977,21 +1083,19 @@ demo.launch()
977
  # # dataset2_text_column,
978
  # # threshold,
979
  # # ],
980
- # # outputs=[status_output, result_output],
981
  # # )
982
 
983
  # # demo.launch()
984
 
985
 
986
-
987
  # # # import gradio as gr
988
  # # # from datasets import load_dataset
989
  # # # import numpy as np
990
- # # # #from model2vec import StaticModel
991
  # # # import model2vec
992
  # # # from reach import Reach
993
  # # # from difflib import ndiff
994
-
995
 
996
  # # # # Load the model at startup
997
  # # # model = model2vec.StaticModel.from_pretrained("minishlab/M2V_base_output")
@@ -1008,13 +1112,24 @@ demo.launch()
1008
  # # # ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
1009
  # # # ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
1010
 
1011
-
1012
  # # # def batch_iterable(iterable, batch_size):
1013
  # # # """Helper function to create batches from an iterable."""
1014
  # # # for i in range(0, len(iterable), batch_size):
1015
  # # # yield iterable[i:i + batch_size]
1016
 
1017
- # # # def compute_embeddings(texts, batch_size, progress, desc="Computing embeddings"):
 
 
 
 
 
 
 
 
 
 
 
 
1018
  # # # embeddings = []
1019
  # # # total_batches = (len(texts) + batch_size - 1) // batch_size
1020
  # # # for i, batch_texts in enumerate(batch_iterable(texts, batch_size)):
@@ -1027,10 +1142,11 @@ demo.launch()
1027
  # # # embedding_matrix: np.ndarray,
1028
  # # # threshold: float,
1029
  # # # batch_size: int = 1024,
1030
- # # # progress=None
 
1031
  # # # ) -> tuple[np.ndarray, dict[int, int]]:
1032
  # # # # Building the index
1033
- # # # progress(0, desc="Building search index...")
1034
  # # # reach = Reach(
1035
  # # # vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))]
1036
  # # # )
@@ -1039,7 +1155,7 @@ demo.launch()
1039
  # # # duplicate_to_original_mapping = {}
1040
 
1041
  # # # # Finding nearest neighbors
1042
- # # # progress(0, desc="Finding nearest neighbors...")
1043
  # # # results = reach.nearest_neighbor_threshold(
1044
  # # # embedding_matrix,
1045
  # # # threshold=threshold,
@@ -1049,6 +1165,7 @@ demo.launch()
1049
 
1050
  # # # # Processing duplicates with a progress bar
1051
  # # # total_items = len(embedding_matrix)
 
1052
  # # # for i, similar_items in enumerate(
1053
  # # # progress.tqdm(results, desc="Processing duplicates", total=total_items)
1054
  # # # ):
@@ -1068,9 +1185,9 @@ demo.launch()
1068
  # # # diff = ndiff(x.split(), y.split())
1069
  # # # return " ".join([word for word in diff if word.startswith(("+", "-"))])
1070
 
1071
-
1072
- # # # def encode_texts(texts, progress=None):
1073
  # # # embedding_matrix = model.encode(texts, show_progressbar=False)
 
1074
  # # # return embedding_matrix
1075
 
1076
  # # # def perform_deduplication(
@@ -1084,17 +1201,18 @@ demo.launch()
1084
  # # # threshold=default_threshold,
1085
  # # # progress=gr.Progress(track_tqdm=True),
1086
  # # # ):
 
1087
  # # # try:
1088
  # # # # Convert threshold to float
1089
  # # # threshold = float(threshold)
1090
 
1091
  # # # # Initialize status message
1092
- # # # status = ""
1093
 
1094
  # # # if deduplication_type == "Single dataset":
1095
  # # # # Load Dataset 1
1096
- # # # status = "Loading Dataset 1..."
1097
- # # # yield status, ""
1098
  # # # if (
1099
  # # # dataset1_name == default_dataset1_name
1100
  # # # and dataset1_split == default_dataset1_split
@@ -1102,29 +1220,27 @@ demo.launch()
1102
  # # # ds = ds_default1
1103
  # # # else:
1104
  # # # ds = load_dataset(dataset1_name, split=dataset1_split)
 
1105
 
1106
  # # # # Extract texts
1107
- # # # status = "Extracting texts from Dataset 1..."
1108
- # # # yield status, ""
1109
  # # # texts = [example[dataset1_text_column] for example in ds]
 
 
1110
  # # # # Compute embeddings
1111
- # # # status = "Computing embeddings for Dataset 1..."
1112
- # # # yield status, ""
1113
- # # # embedding_matrix = encode_texts(texts, progress=progress)
1114
- # # # #embedding_matrix = model.encode(texts, show_progressbar=True)
1115
- # # # # embedding_matrix = compute_embeddings(
1116
- # # # # texts,
1117
- # # # # batch_size=64,
1118
- # # # # progress=progress,
1119
- # # # # desc="Computing embeddings for Dataset 1",
1120
- # # # # )
1121
 
1122
  # # # # Deduplicate
1123
- # # # status = "Deduplicating embeddings..."
1124
- # # # yield status, ""
1125
  # # # deduplicated_indices, duplicate_to_original_mapping = deduplicate(
1126
- # # # embedding_matrix, threshold, progress=progress
1127
  # # # )
 
1128
 
1129
  # # # # Prepare the results
1130
  # # # num_duplicates = len(duplicate_to_original_mapping)
@@ -1152,141 +1268,15 @@ demo.launch()
1152
  # # # else:
1153
  # # # result_text += "No duplicates found."
1154
 
1155
- # # # # Final status
1156
- # # # status = "Deduplication completed."
1157
- # # # yield status, result_text
1158
-
1159
- # # # elif deduplication_type == "Cross-dataset":
1160
- # # # # Similar code for cross-dataset deduplication
1161
- # # # # Load Dataset 1
1162
- # # # status = "Loading Dataset 1..."
1163
- # # # yield status, ""
1164
- # # # if (
1165
- # # # dataset1_name == default_dataset1_name
1166
- # # # and dataset1_split == default_dataset1_split
1167
- # # # ):
1168
- # # # ds1 = ds_default1
1169
- # # # else:
1170
- # # # ds1 = load_dataset(dataset1_name, split=dataset1_split)
1171
-
1172
- # # # # Load Dataset 2
1173
- # # # status = "Loading Dataset 2..."
1174
- # # # yield status, ""
1175
- # # # if (
1176
- # # # dataset2_name == default_dataset2_name
1177
- # # # and dataset2_split == default_dataset2_split
1178
- # # # ):
1179
- # # # ds2 = ds_default2
1180
- # # # else:
1181
- # # # ds2 = load_dataset(dataset2_name, split=dataset2_split)
1182
-
1183
- # # # # Extract texts from Dataset 1
1184
- # # # status = "Extracting texts from Dataset 1..."
1185
- # # # yield status, ""
1186
- # # # texts1 = [example[dataset1_text_column] for example in ds1]
1187
-
1188
- # # # # Extract texts from Dataset 2
1189
- # # # status = "Extracting texts from Dataset 2..."
1190
- # # # yield status, ""
1191
- # # # texts2 = [example[dataset2_text_column] for example in ds2]
1192
-
1193
- # # # # Compute embeddings for Dataset 1
1194
- # # # status = "Computing embeddings for Dataset 1..."
1195
- # # # yield status, ""
1196
- # # # embedding_matrix1 = compute_embeddings(
1197
- # # # texts1,
1198
- # # # batch_size=64,
1199
- # # # progress=progress,
1200
- # # # desc="Computing embeddings for Dataset 1",
1201
- # # # )
1202
-
1203
- # # # # Compute embeddings for Dataset 2
1204
- # # # status = "Computing embeddings for Dataset 2..."
1205
- # # # yield status, ""
1206
- # # # embedding_matrix2 = compute_embeddings(
1207
- # # # texts2,
1208
- # # # batch_size=64,
1209
- # # # progress=progress,
1210
- # # # desc="Computing embeddings for Dataset 2",
1211
- # # # )
1212
-
1213
- # # # # Deduplicate across datasets
1214
- # # # status = "Deduplicating embeddings across datasets..."
1215
- # # # yield status, ""
1216
- # # # duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
1217
- # # # embedding_matrix1, embedding_matrix2, threshold, progress=progress
1218
- # # # )
1219
-
1220
- # # # num_duplicates = len(duplicate_indices_in_ds2)
1221
- # # # num_total_ds2 = len(texts2)
1222
- # # # num_unique_ds2 = num_total_ds2 - num_duplicates
1223
-
1224
- # # # result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
1225
- # # # result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
1226
- # # # result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
1227
-
1228
- # # # # Show deduplicated examples
1229
- # # # if num_duplicates > 0:
1230
- # # # result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
1231
- # # # num_examples = min(5, num_duplicates)
1232
- # # # for duplicate_idx in duplicate_indices_in_ds2[:num_examples]:
1233
- # # # original_idx = duplicate_to_original_mapping[duplicate_idx]
1234
- # # # original_text = texts1[original_idx]
1235
- # # # duplicate_text = texts2[duplicate_idx]
1236
- # # # differences = display_word_differences(original_text, duplicate_text)
1237
- # # # result_text += f"**Original text (Dataset 1):**\n{original_text}\n\n"
1238
- # # # result_text += f"**Duplicate text (Dataset 2):**\n{duplicate_text}\n\n"
1239
- # # # result_text += f"**Differences:**\n{differences}\n"
1240
- # # # result_text += "-" * 50 + "\n\n"
1241
- # # # else:
1242
- # # # result_text += "No duplicates found."
1243
-
1244
- # # # # Final status
1245
- # # # status = "Deduplication completed."
1246
- # # # yield status, result_text
1247
 
1248
  # # # except Exception as e:
 
1249
  # # # yield f"An error occurred: {e}", ""
1250
  # # # raise e
1251
 
1252
- # # # def deduplicate_across_datasets(
1253
- # # # embedding_matrix_1: np.ndarray,
1254
- # # # embedding_matrix_2: np.ndarray,
1255
- # # # threshold: float,
1256
- # # # batch_size: int = 1024,
1257
- # # # progress=None
1258
- # # # ) -> tuple[list[int], dict[int, int]]:
1259
- # # # # Building the index from Dataset 1
1260
- # # # progress(0, desc="Building search index from Dataset 1...")
1261
- # # # reach = Reach(
1262
- # # # vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))]
1263
- # # # )
1264
-
1265
- # # # duplicate_indices_in_test = []
1266
- # # # duplicate_to_original_mapping = {}
1267
-
1268
- # # # # Finding nearest neighbors between datasets
1269
- # # # progress(0, desc="Finding nearest neighbors between datasets...")
1270
- # # # results = reach.nearest_neighbor_threshold(
1271
- # # # embedding_matrix_2,
1272
- # # # threshold=threshold,
1273
- # # # batch_size=batch_size,
1274
- # # # show_progressbar=False, # Disable internal progress bar
1275
- # # # )
1276
-
1277
- # # # total_items = len(embedding_matrix_2)
1278
- # # # # Processing duplicates with a progress bar
1279
- # # # for i, similar_items in enumerate(
1280
- # # # progress.tqdm(results, desc="Processing duplicates across datasets", total=total_items)
1281
- # # # ):
1282
- # # # similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
1283
-
1284
- # # # if similar_indices:
1285
- # # # duplicate_indices_in_test.append(i)
1286
- # # # duplicate_to_original_mapping[i] = similar_indices[0]
1287
-
1288
- # # # return duplicate_indices_in_test, duplicate_to_original_mapping
1289
-
1290
  # # # # Adjust the height of the status_output component using custom CSS
1291
  # # # with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
1292
  # # # gr.Markdown("# Semantic Deduplication")
@@ -1347,3 +1337,369 @@ demo.launch()
1347
  # # # )
1348
 
1349
  # # # demo.launch()
4
  from model2vec import StaticModel
5
  from reach import Reach
6
  from difflib import ndiff
 
7
 
8
  # Load the model at startup
9
  model = StaticModel.from_pretrained("minishlab/M2V_base_output")
10
 
11
+ # Default dataset parameters
12
  default_dataset1_name = "sst2"
13
  default_dataset1_split = "train"
14
  default_dataset2_name = "sst2"
 
27
 
28
  def compute_embeddings(texts, batch_size, progress, desc="Computing embeddings"):
29
  embeddings = []
30
+ total_batches = (len(texts) + batch_size - 1) // batch_size
31
+ for i, batch_texts in enumerate(batch_iterable(texts, batch_size)):
32
+ batch_embeddings = model.encode(batch_texts, show_progressbar=False)
33
  embeddings.append(batch_embeddings)
34
+ progress((i + 1) / total_batches, desc=desc)
35
  return np.concatenate(embeddings, axis=0)
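The rewritten compute_embeddings above reports progress by calling progress(fraction, desc=...) once per batch, so in the app it receives the gr.Progress object injected into perform_deduplication; outside Gradio any callable with that shape works. A small sketch for exercising it locally (the print-based stand-in is hypothetical, not part of the app, and assumes the compute_embeddings and model defined in this file):

```python
# Hypothetical stand-in for gr.Progress, just for a local smoke test.
def fake_progress(fraction, desc=""):
    print(f"{desc}: {fraction:.0%}")

sample_texts = ["first sentence", "second sentence", "third sentence"]
emb = compute_embeddings(sample_texts, batch_size=2, progress=fake_progress, desc="Computing embeddings")
print(emb.shape)  # (3, embedding dimension)
```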
36
 
37
+ def deduplicate(
38
+ embedding_matrix: np.ndarray,
39
+ threshold: float,
40
+ batch_size: int = 1024,
41
+ progress=None
42
+ ) -> tuple[np.ndarray, dict[int, int]]:
43
+ # Building the index
44
+ progress(0, desc="Building search index...")
45
+ reach = Reach(
46
+ vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))]
47
+ )
48
 
49
  deduplicated_indices = set(range(len(embedding_matrix)))
50
  duplicate_to_original_mapping = {}
51
 
52
+ # Finding nearest neighbors
53
+ progress(0, desc="Finding nearest neighbors...")
54
  results = reach.nearest_neighbor_threshold(
55
  embedding_matrix,
56
  threshold=threshold,
57
  batch_size=batch_size,
58
+ show_progressbar=False, # Disable internal progress bar
59
  )
60
 
61
+ # Processing duplicates with a progress bar
62
  total_items = len(embedding_matrix)
63
+ for i, similar_items in enumerate(
64
+ progress.tqdm(results, desc="Processing duplicates", total=total_items)
65
+ ):
66
  if i not in deduplicated_indices:
67
  continue
68
 
 
75
 
76
  return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
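The pair returned here is what perform_deduplication turns into the report further down: the indices to keep and a duplicate-to-original map. A short sketch of consuming that return value (variable names taken from the calling code below):

```python
# Assuming the call made later in perform_deduplication:
# deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, threshold, progress=progress)
unique_texts = [texts[i] for i in sorted(deduplicated_indices)]     # documents kept
num_duplicates = len(duplicate_to_original_mapping)                 # documents removed
for duplicate_idx, original_idx in list(duplicate_to_original_mapping.items())[:5]:
    print("kept:   ", texts[original_idx])
    print("removed:", texts[duplicate_idx])
```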
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  def display_word_differences(x: str, y: str) -> str:
79
  diff = ndiff(x.split(), y.split())
80
+ return " ".join([word for word in diff if word.startswith(("+", "-"))])
81
 
82
  def perform_deduplication(
83
  deduplication_type,
 
88
  dataset2_split="",
89
  dataset2_text_column="",
90
  threshold=default_threshold,
91
+ progress=gr.Progress(track_tqdm=True),
92
  ):
93
  try:
94
+ # Convert threshold to float
95
  threshold = float(threshold)
96
 
97
+ # Initialize status message
98
+ status = ""
99
+
100
  if deduplication_type == "Single dataset":
101
+ # Load Dataset 1
102
+ status = "Loading Dataset 1..."
103
+ yield status, ""
104
+ if (
105
+ dataset1_name == default_dataset1_name
106
+ and dataset1_split == default_dataset1_split
107
+ ):
108
+ ds = ds_default1
109
+ else:
110
+ ds = load_dataset(dataset1_name, split=dataset1_split)
111
 
112
+ # Extract texts
113
+ status = "Extracting texts from Dataset 1..."
114
+ yield status, ""
115
+ texts = [example[dataset1_text_column] for example in ds]
116
 
117
+ # Compute embeddings
118
+ status = "Computing embeddings for Dataset 1..."
119
+ yield status, ""
120
+ embedding_matrix = compute_embeddings(
121
+ texts,
122
+ batch_size=64,
123
+ progress=progress,
124
+ desc="Computing embeddings for Dataset 1",
125
+ )
126
+
127
+ # Deduplicate
128
+ status = "Deduplicating embeddings..."
129
+ yield status, ""
130
+ deduplicated_indices, duplicate_to_original_mapping = deduplicate(
131
+ embedding_matrix, threshold, progress=progress
132
+ )
133
+
134
+ # Prepare the results
135
  num_duplicates = len(duplicate_to_original_mapping)
136
  num_total = len(texts)
137
  num_deduplicated = len(deduplicated_indices)
138
 
139
  result_text = f"**Total documents:** {num_total}\n"
140
  result_text += f"**Number of duplicates found:** {num_duplicates}\n"
141
+ result_text += (
142
+ f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
143
+ )
144
 
145
+ # Show deduplicated examples
146
  if num_duplicates > 0:
147
  result_text += "**Examples of duplicates found:**\n\n"
148
  num_examples = min(5, num_duplicates)
 
157
  else:
158
  result_text += "No duplicates found."
159
 
160
+ # Final status
161
+ status = "Deduplication completed."
162
+ yield status, result_text
163
 
164
  elif deduplication_type == "Cross-dataset":
165
+ # Similar code for cross-dataset deduplication
166
+ # Load Dataset 1
167
+ status = "Loading Dataset 1..."
168
+ yield status, ""
169
+ if (
170
+ dataset1_name == default_dataset1_name
171
+ and dataset1_split == default_dataset1_split
172
+ ):
173
+ ds1 = ds_default1
174
+ else:
175
+ ds1 = load_dataset(dataset1_name, split=dataset1_split)
176
+
177
+ # Load Dataset 2
178
+ status = "Loading Dataset 2..."
179
+ yield status, ""
180
+ if (
181
+ dataset2_name == default_dataset2_name
182
+ and dataset2_split == default_dataset2_split
183
+ ):
184
+ ds2 = ds_default2
185
+ else:
186
+ ds2 = load_dataset(dataset2_name, split=dataset2_split)
187
 
188
+ # Extract texts from Dataset 1
189
+ status = "Extracting texts from Dataset 1..."
190
+ yield status, ""
191
  texts1 = [example[dataset1_text_column] for example in ds1]
 
192
 
193
+ # Extract texts from Dataset 2
194
+ status = "Extracting texts from Dataset 2..."
195
+ yield status, ""
196
+ texts2 = [example[dataset2_text_column] for example in ds2]
197
 
198
+ # Compute embeddings for Dataset 1
199
+ status = "Computing embeddings for Dataset 1..."
200
+ yield status, ""
201
+ embedding_matrix1 = compute_embeddings(
202
+ texts1,
203
+ batch_size=64,
204
+ progress=progress,
205
+ desc="Computing embeddings for Dataset 1",
206
+ )
207
+
208
+ # Compute embeddings for Dataset 2
209
+ status = "Computing embeddings for Dataset 2..."
210
+ yield status, ""
211
+ embedding_matrix2 = compute_embeddings(
212
+ texts2,
213
+ batch_size=64,
214
+ progress=progress,
215
+ desc="Computing embeddings for Dataset 2",
216
+ )
217
+
218
+ # Deduplicate across datasets
219
+ status = "Deduplicating embeddings across datasets..."
220
+ yield status, ""
221
+ duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
222
+ embedding_matrix1, embedding_matrix2, threshold, progress=progress
223
+ )
224
 
225
  num_duplicates = len(duplicate_indices_in_ds2)
226
  num_total_ds2 = len(texts2)
 
230
  result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
231
  result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
232
 
233
+ # Show deduplicated examples
234
  if num_duplicates > 0:
235
  result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
236
  num_examples = min(5, num_duplicates)
 
246
  else:
247
  result_text += "No duplicates found."
248
 
249
+ # Final status
250
+ status = "Deduplication completed."
251
+ yield status, result_text
252
 
253
  except Exception as e:
254
  yield f"An error occurred: {e}", ""
255
+ raise e
256
+
257
+ def deduplicate_across_datasets(
258
+ embedding_matrix_1: np.ndarray,
259
+ embedding_matrix_2: np.ndarray,
260
+ threshold: float,
261
+ batch_size: int = 1024,
262
+ progress=None
263
+ ) -> tuple[list[int], dict[int, int]]:
264
+ # Building the index from Dataset 1
265
+ progress(0, desc="Building search index from Dataset 1...")
266
+ reach = Reach(
267
+ vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))]
268
+ )
269
+
270
+ duplicate_indices_in_test = []
271
+ duplicate_to_original_mapping = {}
272
+
273
+ # Finding nearest neighbors between datasets
274
+ progress(0, desc="Finding nearest neighbors between datasets...")
275
+ results = reach.nearest_neighbor_threshold(
276
+ embedding_matrix_2,
277
+ threshold=threshold,
278
+ batch_size=batch_size,
279
+ show_progressbar=False, # Disable internal progress bar
280
+ )
281
+
282
+ total_items = len(embedding_matrix_2)
283
+ # Processing duplicates with a progress bar
284
+ for i, similar_items in enumerate(
285
+ progress.tqdm(results, desc="Processing duplicates across datasets", total=total_items)
286
+ ):
287
+ similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
288
 
289
+ if similar_indices:
290
+ duplicate_indices_in_test.append(i)
291
+ duplicate_to_original_mapping[i] = similar_indices[0]
292
+
293
+ return duplicate_indices_in_test, duplicate_to_original_mapping
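deduplicate_across_datasets, restored here, returns the Dataset 2 row indices whose nearest Dataset 1 neighbour clears the threshold, plus a map back to that neighbour. A short sketch of interpreting its output, mirroring the report code in perform_deduplication (variable names assumed from that function):

```python
# Assuming embedding_matrix1/embedding_matrix2 and texts1/texts2 from perform_deduplication:
duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
    embedding_matrix1, embedding_matrix2, threshold, progress=progress
)

for duplicate_idx in duplicate_indices_in_ds2[:5]:
    original_idx = duplicate_to_original_mapping[duplicate_idx]
    print("Dataset 1:", texts1[original_idx])
    print("Dataset 2:", texts2[duplicate_idx])

# Dataset 2 rows that were not flagged are the unique ones.
flagged = set(duplicate_indices_in_ds2)
unique_in_ds2 = [texts2[i] for i in range(len(texts2)) if i not in flagged]
```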
294
+
295
+ # Adjust the height of the status_output component using custom CSS
296
+ with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
297
  gr.Markdown("# Semantic Deduplication")
298
 
299
  deduplication_type = gr.Radio(
300
  choices=["Single dataset", "Cross-dataset"],
301
  label="Deduplication Type",
302
+ value="Single dataset",
303
  )
304
 
305
  with gr.Row():
 
316
  dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
317
 
318
  threshold = gr.Slider(
319
+ minimum=0.0, maximum=1.0, value=default_threshold, label="Similarity Threshold"
 
 
 
320
  )
321
 
322
  compute_button = gr.Button("Compute")
323
 
324
+ # Use 'gr.Markdown' with 'elem_id' and custom CSS to adjust height
325
  status_output = gr.Markdown(elem_id="status_output")
326
+ result_output = gr.Markdown()
327
 
328
+ # Function to update the visibility of dataset2_inputs
329
  def update_visibility(deduplication_type_value):
330
  if deduplication_type_value == "Cross-dataset":
331
  return gr.update(visible=True)
 
333
  return gr.update(visible=False)
334
 
335
  deduplication_type.change(
336
+ update_visibility, inputs=deduplication_type, outputs=dataset2_inputs
 
 
337
  )
338
 
339
  compute_button.click(
 
346
  dataset2_name,
347
  dataset2_split,
348
  dataset2_text_column,
349
+ threshold,
350
  ],
351
+ outputs=[status_output, result_output],
352
  )
353
 
354
  demo.launch()
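Because perform_deduplication is a generator, every yield of a (status, result) pair streams straight into the two Markdown components wired up in compute_button.click above. A minimal standalone sketch of that same streaming pattern (toy callback and labels, not the app's):

```python
import gradio as gr

def stream_steps(n):
    # Toy generator: yields (status, result) pairs, like perform_deduplication does.
    n = int(n)
    for i in range(n):
        yield f"Step {i + 1}/{n}...", ""
    yield "Done.", f"Processed {n} steps."

with gr.Blocks() as toy_demo:
    steps = gr.Number(value=3, label="Steps")
    run = gr.Button("Run")
    status_md = gr.Markdown()
    result_md = gr.Markdown()
    run.click(fn=stream_steps, inputs=[steps], outputs=[status_md, result_md])

# toy_demo.launch()
```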
355
 
356
+
357
  # import gradio as gr
358
  # from datasets import load_dataset
359
  # import numpy as np
 
393
  # """
394
  # Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.
395
  # """
 
 
396
  # reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])
397
 
398
  # deduplicated_indices = set(range(len(embedding_matrix)))
399
  # duplicate_to_original_mapping = {}
400
 
 
 
401
  # results = reach.nearest_neighbor_threshold(
402
  # embedding_matrix,
403
  # threshold=threshold,
404
  # batch_size=batch_size,
405
+ # show_progressbar=False
406
  # )
407
 
 
408
  # total_items = len(embedding_matrix)
409
 # for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=total_items)):
 # if i not in deduplicated_indices:

 # """
 # Deduplicate embeddings across two datasets and return the indices of duplicates between them.
 # """
 # reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])

 # duplicate_indices_in_test = []
 # duplicate_to_original_mapping = {}

 # results = reach.nearest_neighbor_threshold(
 # embedding_matrix_2,
 # threshold=threshold,
 # batch_size=batch_size,
+ # show_progressbar=False
 # )

 # total_items = len(embedding_matrix_2)
 # for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates across datasets", total=total_items)):
 # similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]

 # progress=gr.Progress(track_tqdm=True)
 # ):
 # try:
 # threshold = float(threshold)

 # if deduplication_type == "Single dataset":
+ # ds = ds_default1 if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split else load_dataset(dataset1_name, split=dataset1_split)
 # texts = [example[dataset1_text_column] for example in ds]

 # embedding_matrix = compute_embeddings(texts, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 1")
+ # deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, threshold, progress=progress)

 # num_duplicates = len(duplicate_to_original_mapping)
 # num_total = len(texts)
 # num_deduplicated = len(deduplicated_indices)

 # result_text += f"**Number of duplicates found:** {num_duplicates}\n"
 # result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"

 # if num_duplicates > 0:
 # result_text += "**Examples of duplicates found:**\n\n"
 # num_examples = min(5, num_duplicates)

 # else:
 # result_text += "No duplicates found."

+ # yield result_text

 # elif deduplication_type == "Cross-dataset":
+ # ds1 = ds_default1 if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split else load_dataset(dataset1_name, split=dataset1_split)
+ # ds2 = ds_default2 if dataset2_name == default_dataset2_name and dataset2_split == default_dataset2_split else load_dataset(dataset2_name, split=dataset2_split)

 # texts1 = [example[dataset1_text_column] for example in ds1]
 # texts2 = [example[dataset2_text_column] for example in ds2]

 # embedding_matrix1 = compute_embeddings(texts1, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 1")
 # embedding_matrix2 = compute_embeddings(texts2, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 2")

+ # duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(embedding_matrix1, embedding_matrix2, threshold, progress=progress)

 # num_duplicates = len(duplicate_indices_in_ds2)
 # num_total_ds2 = len(texts2)

 # result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
 # result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"

 # if num_duplicates > 0:
 # result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
 # num_examples = min(5, num_duplicates)

 # else:
 # result_text += "No duplicates found."

+ # yield result_text

 # except Exception as e:
 # yield f"An error occurred: {e}", ""

+ # # Adjust the height of the status_output and result_output components
+ # with gr.Blocks(css="#status_output { height: 300px; overflow: auto; } #result_output { height: 300px; overflow: auto; }") as demo:
 # gr.Markdown("# Semantic Deduplication")

 # deduplication_type = gr.Radio(

 # compute_button = gr.Button("Compute")

+ # status_output = gr.Markdown(elem_id="status_output")
+ # result_output = gr.Markdown(elem_id="result_output")

 # def update_visibility(deduplication_type_value):
 # if deduplication_type_value == "Cross-dataset":
 # return gr.update(visible=True)

 # demo.launch()