Pringled committed
Commit 39a5b1c · 1 Parent(s): a847bef
Files changed (1):
  1. app.py +600 -281
app.py CHANGED
@@ -1,14 +1,16 @@
  import gradio as gr
  from datasets import load_dataset
  import numpy as np
- import model2vec
  from reach import Reach
  from difflib import ndiff

  # Load the model at startup
- model = model2vec.StaticModel.from_pretrained("minishlab/M2V_base_output")

- # Default dataset parameters
  default_dataset1_name = "sst2"
  default_dataset1_split = "train"
  default_dataset2_name = "sst2"
@@ -27,37 +29,39 @@ def batch_iterable(iterable, batch_size):

  def compute_embeddings(texts, batch_size, progress, desc="Computing embeddings"):
      embeddings = []
-     total_batches = (len(texts) + batch_size - 1) // batch_size
-     for i, batch_texts in enumerate(batch_iterable(texts, batch_size)):
-         batch_embeddings = model.encode(batch_texts, show_progressbar=False)
          embeddings.append(batch_embeddings)
-         progress((i + 1) / total_batches, desc=desc)
      return np.concatenate(embeddings, axis=0)

- def deduplicate(
-     embedding_matrix: np.ndarray,
-     threshold: float,
-     batch_size: int = 1024,
-     progress=None
- ) -> tuple[np.ndarray, dict[int, int]]:
      reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])

      deduplicated_indices = set(range(len(embedding_matrix)))
      duplicate_to_original_mapping = {}

      results = reach.nearest_neighbor_threshold(
          embedding_matrix,
          threshold=threshold,
          batch_size=batch_size,
-         show_progressbar=False,
      )

      total_items = len(embedding_matrix)
      for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=total_items)):
          if i not in deduplicated_indices:
              continue

          similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]
          for sim_idx in similar_indices:
              if sim_idx in deduplicated_indices:
                  deduplicated_indices.remove(sim_idx)
@@ -65,9 +69,40 @@ def deduplicate(

      return np.array(list(deduplicated_indices)), duplicate_to_original_mapping

  def display_word_differences(x: str, y: str) -> str:
      diff = ndiff(x.split(), y.split())
-     return " ".join([word for word in diff if word.startswith(("+", "-"))])

  def perform_deduplication(
      deduplication_type,
@@ -78,18 +113,42 @@ def perform_deduplication(
      dataset2_split="",
      dataset2_text_column="",
      threshold=default_threshold,
-     progress=gr.Progress(track_tqdm=True),
  ):
      try:
          threshold = float(threshold)

          if deduplication_type == "Single dataset":
-             ds = ds_default1 if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split else load_dataset(dataset1_name, split=dataset1_split)
              texts = [example[dataset1_text_column] for example in ds]

-             embedding_matrix = compute_embeddings(texts, batch_size=64, progress=progress)
-             deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, threshold, progress=progress)

              num_duplicates = len(duplicate_to_original_mapping)
              num_total = len(texts)
              num_deduplicated = len(deduplicated_indices)
@@ -98,6 +157,7 @@ def perform_deduplication(
              result_text += f"**Number of duplicates found:** {num_duplicates}\n"
              result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"

              if num_duplicates > 0:
                  result_text += "**Examples of duplicates found:**\n\n"
                  num_examples = min(5, num_duplicates)
@@ -112,19 +172,93 @@ def perform_deduplication(
              else:
                  result_text += "No duplicates found."

-             yield result_text

      except Exception as e:
-         yield f"An error occurred: {e}"

- # Gradio interface setup
- with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
      gr.Markdown("# Semantic Deduplication")

      deduplication_type = gr.Radio(
          choices=["Single dataset", "Cross-dataset"],
          label="Deduplication Type",
-         value="Single dataset",
      )

      with gr.Row():
@@ -140,17 +274,29 @@ with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
              dataset2_split = gr.Textbox(value=default_dataset2_split, label="Dataset 2 Split")
              dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")

-     threshold = gr.Slider(minimum=0.0, maximum=1.0, value=default_threshold, label="Similarity Threshold")

      compute_button = gr.Button("Compute")

      result_output = gr.Markdown()

      def update_visibility(deduplication_type_value):
-         return gr.update(visible=True) if deduplication_type_value == "Cross-dataset" else gr.update(visible=False)

      deduplication_type.change(
-         update_visibility, inputs=deduplication_type, outputs=dataset2_inputs
      )

      compute_button.click(
@@ -163,9 +309,9 @@ with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
          dataset2_name,
          dataset2_split,
          dataset2_text_column,
-         threshold,
      ],
-     outputs=[result_output],
  )

  demo.launch()
@@ -177,7 +323,6 @@ demo.launch()
177
  # import model2vec
178
  # from reach import Reach
179
  # from difflib import ndiff
180
- # import time
181
 
182
  # # Load the model at startup
183
  # model = model2vec.StaticModel.from_pretrained("minishlab/M2V_base_output")
@@ -199,19 +344,7 @@ demo.launch()
199
  # for i in range(0, len(iterable), batch_size):
200
  # yield iterable[i:i + batch_size]
201
 
202
- # def log_time(message, start_time=None, logs=None):
203
- # """Helper function to log the start and end times."""
204
- # current_time = time.time()
205
- # if start_time is not None:
206
- # elapsed = current_time - start_time
207
- # log_message = f"{message} - Took {elapsed:.2f} seconds"
208
- # else:
209
- # log_message = f"{message} - Started"
210
-
211
- # if logs is not None:
212
- # logs.append(log_message)
213
-
214
- # def compute_embeddings(texts, batch_size, progress, logs, desc="Computing embeddings"):
215
  # embeddings = []
216
  # total_batches = (len(texts) + batch_size - 1) // batch_size
217
  # for i, batch_texts in enumerate(batch_iterable(texts, batch_size)):
@@ -224,38 +357,26 @@ demo.launch()
224
  # embedding_matrix: np.ndarray,
225
  # threshold: float,
226
  # batch_size: int = 1024,
227
- # progress=None,
228
- # logs=None
229
  # ) -> tuple[np.ndarray, dict[int, int]]:
230
- # # Building the index
231
- # log_time("Building search index", logs=logs)
232
- # reach = Reach(
233
- # vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))]
234
- # )
235
 
236
  # deduplicated_indices = set(range(len(embedding_matrix)))
237
  # duplicate_to_original_mapping = {}
238
 
239
- # # Finding nearest neighbors
240
- # log_time("Finding nearest neighbors", logs=logs)
241
  # results = reach.nearest_neighbor_threshold(
242
  # embedding_matrix,
243
  # threshold=threshold,
244
  # batch_size=batch_size,
245
- # show_progressbar=False, # Disable internal progress bar
246
  # )
247
 
248
- # # Processing duplicates with a progress bar
249
  # total_items = len(embedding_matrix)
250
- # log_time("Processing duplicates", logs=logs)
251
- # for i, similar_items in enumerate(
252
- # progress.tqdm(results, desc="Processing duplicates", total=total_items)
253
- # ):
254
  # if i not in deduplicated_indices:
255
  # continue
256
 
257
  # similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]
258
-
259
  # for sim_idx in similar_indices:
260
  # if sim_idx in deduplicated_indices:
261
  # deduplicated_indices.remove(sim_idx)
@@ -267,11 +388,6 @@ demo.launch()
267
  # diff = ndiff(x.split(), y.split())
268
  # return " ".join([word for word in diff if word.startswith(("+", "-"))])
269
 
270
- # def encode_texts(texts, progress=None, logs=None):
271
- # embedding_matrix = model.encode(texts, show_progressbar=False)
272
- # log_time("Encoding texts completed", logs=logs)
273
- # return embedding_matrix
274
-
275
  # def perform_deduplication(
276
  # deduplication_type,
277
  # dataset1_name,
@@ -283,59 +399,24 @@ demo.launch()
283
  # threshold=default_threshold,
284
  # progress=gr.Progress(track_tqdm=True),
285
  # ):
286
- # logs = [] # To store log messages
287
  # try:
288
- # # Convert threshold to float
289
  # threshold = float(threshold)
290
 
291
- # # Initialize status message
292
- # log_time("Deduplication started", logs=logs)
293
-
294
  # if deduplication_type == "Single dataset":
295
- # # Load Dataset 1
296
- # start_time = time.time()
297
- # log_time("Loading Dataset 1", logs=logs)
298
- # if (
299
- # dataset1_name == default_dataset1_name
300
- # and dataset1_split == default_dataset1_split
301
- # ):
302
- # ds = ds_default1
303
- # else:
304
- # ds = load_dataset(dataset1_name, split=dataset1_split)
305
- # log_time("Loading Dataset 1 completed", start_time=start_time, logs=logs)
306
-
307
- # # Extract texts
308
- # start_time = time.time()
309
- # log_time("Extracting texts from Dataset 1", logs=logs)
310
  # texts = [example[dataset1_text_column] for example in ds]
311
- # log_time("Extracting texts from Dataset 1 completed", start_time=start_time, logs=logs)
312
-
313
- # # Compute embeddings
314
- # start_time = time.time()
315
- # log_time("Computing embeddings for Dataset 1", logs=logs)
316
- # embedding_matrix = encode_texts(texts, progress=progress, logs=logs)
317
- # log_time("Computing embeddings for Dataset 1 completed", start_time=start_time, logs=logs)
318
-
319
- # # Deduplicate
320
- # start_time = time.time()
321
- # log_time("Deduplicating embeddings", logs=logs)
322
- # deduplicated_indices, duplicate_to_original_mapping = deduplicate(
323
- # embedding_matrix, threshold, progress=progress, logs=logs
324
- # )
325
- # log_time("Deduplication completed", start_time=start_time, logs=logs)
326
-
327
- # # Prepare the results
328
  # num_duplicates = len(duplicate_to_original_mapping)
329
  # num_total = len(texts)
330
  # num_deduplicated = len(deduplicated_indices)
331
 
332
  # result_text = f"**Total documents:** {num_total}\n"
333
  # result_text += f"**Number of duplicates found:** {num_duplicates}\n"
334
- # result_text += (
335
- # f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
336
- # )
337
 
338
- # # Show deduplicated examples
339
  # if num_duplicates > 0:
340
  # result_text += "**Examples of duplicates found:**\n\n"
341
  # num_examples = min(5, num_duplicates)
@@ -350,16 +431,12 @@ demo.launch()
350
  # else:
351
  # result_text += "No duplicates found."
352
 
353
- # log_time("Deduplication process finished", logs=logs)
354
- # full_log = "\n".join(logs) # Combine all logs into one output
355
- # yield full_log, result_text
356
 
357
  # except Exception as e:
358
- # full_log = "\n".join(logs) # Combine all logs into one output in case of an error
359
- # yield f"An error occurred: {e}", ""
360
- # raise e
361
 
362
- # # Adjust the height of the status_output component using custom CSS
363
  # with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
364
  # gr.Markdown("# Semantic Deduplication")
365
 
@@ -382,22 +459,14 @@ demo.launch()
382
  # dataset2_split = gr.Textbox(value=default_dataset2_split, label="Dataset 2 Split")
383
  # dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
384
 
385
- # threshold = gr.Slider(
386
- # minimum=0.0, maximum=1.0, value=default_threshold, label="Similarity Threshold"
387
- # )
388
 
389
  # compute_button = gr.Button("Compute")
390
 
391
- # # Use 'gr.Markdown' with 'elem_id' and custom CSS to adjust height
392
- # status_output = gr.Markdown(elem_id="status_output")
393
  # result_output = gr.Markdown()
394
 
395
- # # Function to update the visibility of dataset2_inputs
396
  # def update_visibility(deduplication_type_value):
397
- # if deduplication_type_value == "Cross-dataset":
398
- # return gr.update(visible=True)
399
- # else:
400
- # return gr.update(visible=False)
401
 
402
  # deduplication_type.change(
403
  # update_visibility, inputs=deduplication_type, outputs=dataset2_inputs
@@ -415,21 +484,19 @@ demo.launch()
415
  # dataset2_text_column,
416
  # threshold,
417
  # ],
418
- # outputs=[status_output, result_output],
419
  # )
420
 
421
  # demo.launch()
422
 
423
 
424
-
425
  # # import gradio as gr
426
  # # from datasets import load_dataset
427
  # # import numpy as np
428
- # # #from model2vec import StaticModel
429
  # # import model2vec
430
  # # from reach import Reach
431
  # # from difflib import ndiff
432
-
433
 
434
  # # # Load the model at startup
435
  # # model = model2vec.StaticModel.from_pretrained("minishlab/M2V_base_output")
@@ -446,13 +513,24 @@ demo.launch()
446
  # # ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
447
  # # ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
448
 
449
-
450
  # # def batch_iterable(iterable, batch_size):
451
  # # """Helper function to create batches from an iterable."""
452
  # # for i in range(0, len(iterable), batch_size):
453
  # # yield iterable[i:i + batch_size]
454
 
455
- # # def compute_embeddings(texts, batch_size, progress, desc="Computing embeddings"):
 
 
 
 
 
 
 
 
 
 
 
 
456
  # # embeddings = []
457
  # # total_batches = (len(texts) + batch_size - 1) // batch_size
458
  # # for i, batch_texts in enumerate(batch_iterable(texts, batch_size)):
@@ -465,10 +543,11 @@ demo.launch()
465
  # # embedding_matrix: np.ndarray,
466
  # # threshold: float,
467
  # # batch_size: int = 1024,
468
- # # progress=None
 
469
  # # ) -> tuple[np.ndarray, dict[int, int]]:
470
  # # # Building the index
471
- # # progress(0, desc="Building search index...")
472
  # # reach = Reach(
473
  # # vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))]
474
  # # )
@@ -477,7 +556,7 @@ demo.launch()
477
  # # duplicate_to_original_mapping = {}
478
 
479
  # # # Finding nearest neighbors
480
- # # progress(0, desc="Finding nearest neighbors...")
481
  # # results = reach.nearest_neighbor_threshold(
482
  # # embedding_matrix,
483
  # # threshold=threshold,
@@ -487,6 +566,7 @@ demo.launch()
487
 
488
  # # # Processing duplicates with a progress bar
489
  # # total_items = len(embedding_matrix)
 
490
  # # for i, similar_items in enumerate(
491
  # # progress.tqdm(results, desc="Processing duplicates", total=total_items)
492
  # # ):
@@ -506,9 +586,9 @@ demo.launch()
506
  # # diff = ndiff(x.split(), y.split())
507
  # # return " ".join([word for word in diff if word.startswith(("+", "-"))])
508
 
509
-
510
- # # def encode_texts(texts, progress=None):
511
  # # embedding_matrix = model.encode(texts, show_progressbar=False)
 
512
  # # return embedding_matrix
513
 
514
  # # def perform_deduplication(
@@ -522,17 +602,18 @@ demo.launch()
522
  # # threshold=default_threshold,
523
  # # progress=gr.Progress(track_tqdm=True),
524
  # # ):
 
525
  # # try:
526
  # # # Convert threshold to float
527
  # # threshold = float(threshold)
528
 
529
  # # # Initialize status message
530
- # # status = ""
531
 
532
  # # if deduplication_type == "Single dataset":
533
  # # # Load Dataset 1
534
- # # status = "Loading Dataset 1..."
535
- # # yield status, ""
536
  # # if (
537
  # # dataset1_name == default_dataset1_name
538
  # # and dataset1_split == default_dataset1_split
@@ -540,29 +621,27 @@ demo.launch()
540
  # # ds = ds_default1
541
  # # else:
542
  # # ds = load_dataset(dataset1_name, split=dataset1_split)
 
543
 
544
  # # # Extract texts
545
- # # status = "Extracting texts from Dataset 1..."
546
- # # yield status, ""
547
  # # texts = [example[dataset1_text_column] for example in ds]
 
 
548
  # # # Compute embeddings
549
- # # status = "Computing embeddings for Dataset 1..."
550
- # # yield status, ""
551
- # # embedding_matrix = encode_texts(texts, progress=progress)
552
- # # #embedding_matrix = model.encode(texts, show_progressbar=True)
553
- # # # embedding_matrix = compute_embeddings(
554
- # # # texts,
555
- # # # batch_size=64,
556
- # # # progress=progress,
557
- # # # desc="Computing embeddings for Dataset 1",
558
- # # # )
559
 
560
  # # # Deduplicate
561
- # # status = "Deduplicating embeddings..."
562
- # # yield status, ""
563
  # # deduplicated_indices, duplicate_to_original_mapping = deduplicate(
564
- # # embedding_matrix, threshold, progress=progress
565
  # # )
 
566
 
567
  # # # Prepare the results
568
  # # num_duplicates = len(duplicate_to_original_mapping)
@@ -590,141 +669,15 @@ demo.launch()
590
  # # else:
591
  # # result_text += "No duplicates found."
592
 
593
- # # # Final status
594
- # # status = "Deduplication completed."
595
- # # yield status, result_text
596
-
597
- # # elif deduplication_type == "Cross-dataset":
598
- # # # Similar code for cross-dataset deduplication
599
- # # # Load Dataset 1
600
- # # status = "Loading Dataset 1..."
601
- # # yield status, ""
602
- # # if (
603
- # # dataset1_name == default_dataset1_name
604
- # # and dataset1_split == default_dataset1_split
605
- # # ):
606
- # # ds1 = ds_default1
607
- # # else:
608
- # # ds1 = load_dataset(dataset1_name, split=dataset1_split)
609
-
610
- # # # Load Dataset 2
611
- # # status = "Loading Dataset 2..."
612
- # # yield status, ""
613
- # # if (
614
- # # dataset2_name == default_dataset2_name
615
- # # and dataset2_split == default_dataset2_split
616
- # # ):
617
- # # ds2 = ds_default2
618
- # # else:
619
- # # ds2 = load_dataset(dataset2_name, split=dataset2_split)
620
-
621
- # # # Extract texts from Dataset 1
622
- # # status = "Extracting texts from Dataset 1..."
623
- # # yield status, ""
624
- # # texts1 = [example[dataset1_text_column] for example in ds1]
625
-
626
- # # # Extract texts from Dataset 2
627
- # # status = "Extracting texts from Dataset 2..."
628
- # # yield status, ""
629
- # # texts2 = [example[dataset2_text_column] for example in ds2]
630
-
631
- # # # Compute embeddings for Dataset 1
632
- # # status = "Computing embeddings for Dataset 1..."
633
- # # yield status, ""
634
- # # embedding_matrix1 = compute_embeddings(
635
- # # texts1,
636
- # # batch_size=64,
637
- # # progress=progress,
638
- # # desc="Computing embeddings for Dataset 1",
639
- # # )
640
-
641
- # # # Compute embeddings for Dataset 2
642
- # # status = "Computing embeddings for Dataset 2..."
643
- # # yield status, ""
644
- # # embedding_matrix2 = compute_embeddings(
645
- # # texts2,
646
- # # batch_size=64,
647
- # # progress=progress,
648
- # # desc="Computing embeddings for Dataset 2",
649
- # # )
650
-
651
- # # # Deduplicate across datasets
652
- # # status = "Deduplicating embeddings across datasets..."
653
- # # yield status, ""
654
- # # duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
655
- # # embedding_matrix1, embedding_matrix2, threshold, progress=progress
656
- # # )
657
-
658
- # # num_duplicates = len(duplicate_indices_in_ds2)
659
- # # num_total_ds2 = len(texts2)
660
- # # num_unique_ds2 = num_total_ds2 - num_duplicates
661
-
662
- # # result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
663
- # # result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
664
- # # result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
665
-
666
- # # # Show deduplicated examples
667
- # # if num_duplicates > 0:
668
- # # result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
669
- # # num_examples = min(5, num_duplicates)
670
- # # for duplicate_idx in duplicate_indices_in_ds2[:num_examples]:
671
- # # original_idx = duplicate_to_original_mapping[duplicate_idx]
672
- # # original_text = texts1[original_idx]
673
- # # duplicate_text = texts2[duplicate_idx]
674
- # # differences = display_word_differences(original_text, duplicate_text)
675
- # # result_text += f"**Original text (Dataset 1):**\n{original_text}\n\n"
676
- # # result_text += f"**Duplicate text (Dataset 2):**\n{duplicate_text}\n\n"
677
- # # result_text += f"**Differences:**\n{differences}\n"
678
- # # result_text += "-" * 50 + "\n\n"
679
- # # else:
680
- # # result_text += "No duplicates found."
681
-
682
- # # # Final status
683
- # # status = "Deduplication completed."
684
- # # yield status, result_text
685
 
686
  # # except Exception as e:
 
687
  # # yield f"An error occurred: {e}", ""
688
  # # raise e
689
 
690
- # # def deduplicate_across_datasets(
691
- # # embedding_matrix_1: np.ndarray,
692
- # # embedding_matrix_2: np.ndarray,
693
- # # threshold: float,
694
- # # batch_size: int = 1024,
695
- # # progress=None
696
- # # ) -> tuple[list[int], dict[int, int]]:
697
- # # # Building the index from Dataset 1
698
- # # progress(0, desc="Building search index from Dataset 1...")
699
- # # reach = Reach(
700
- # # vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))]
701
- # # )
702
-
703
- # # duplicate_indices_in_test = []
704
- # # duplicate_to_original_mapping = {}
705
-
706
- # # # Finding nearest neighbors between datasets
707
- # # progress(0, desc="Finding nearest neighbors between datasets...")
708
- # # results = reach.nearest_neighbor_threshold(
709
- # # embedding_matrix_2,
710
- # # threshold=threshold,
711
- # # batch_size=batch_size,
712
- # # show_progressbar=False, # Disable internal progress bar
713
- # # )
714
-
715
- # # total_items = len(embedding_matrix_2)
716
- # # # Processing duplicates with a progress bar
717
- # # for i, similar_items in enumerate(
718
- # # progress.tqdm(results, desc="Processing duplicates across datasets", total=total_items)
719
- # # ):
720
- # # similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
721
-
722
- # # if similar_indices:
723
- # # duplicate_indices_in_test.append(i)
724
- # # duplicate_to_original_mapping[i] = similar_indices[0]
725
-
726
- # # return duplicate_indices_in_test, duplicate_to_original_mapping
727
-
728
  # # # Adjust the height of the status_output component using custom CSS
729
  # # with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
730
  # # gr.Markdown("# Semantic Deduplication")
@@ -785,3 +738,369 @@ demo.launch()
785
  # # )
786
 
787
  # # demo.launch()
+
  import gradio as gr
  from datasets import load_dataset
  import numpy as np
+ from model2vec import StaticModel
  from reach import Reach
  from difflib import ndiff
+ import tqdm

  # Load the model at startup
+ model = StaticModel.from_pretrained("minishlab/M2V_base_output")

+ # Update default dataset to 'sst2' and set default threshold to 0.9
  default_dataset1_name = "sst2"
  default_dataset1_split = "train"
  default_dataset2_name = "sst2"
 

  def compute_embeddings(texts, batch_size, progress, desc="Computing embeddings"):
      embeddings = []
+     for batch in progress.tqdm(batch_iterable(texts, batch_size), total=(len(texts) + batch_size - 1) // batch_size, desc=desc):
+         batch_embeddings = model.encode(batch, show_progressbar=False)
          embeddings.append(batch_embeddings)
      return np.concatenate(embeddings, axis=0)

+ def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024, progress=None) -> tuple[np.ndarray, dict[int, int]]:
+     """
+     Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.
+     """
+     # Building the index
+     progress(0, desc="Building search index...")
      reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])

      deduplicated_indices = set(range(len(embedding_matrix)))
      duplicate_to_original_mapping = {}

+     # Finding nearest neighbors
+     progress(0, desc="Finding nearest neighbors...")
      results = reach.nearest_neighbor_threshold(
          embedding_matrix,
          threshold=threshold,
          batch_size=batch_size,
+         show_progressbar=False  # Disable internal progress bar
      )

+     # Processing duplicates with a progress bar
      total_items = len(embedding_matrix)
      for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=total_items)):
          if i not in deduplicated_indices:
              continue

          similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]
+
          for sim_idx in similar_indices:
              if sim_idx in deduplicated_indices:
                  deduplicated_indices.remove(sim_idx)
 

      return np.array(list(deduplicated_indices)), duplicate_to_original_mapping

+ def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix_2: np.ndarray, threshold: float, batch_size: int = 1024, progress=None) -> tuple[list[int], dict[int, int]]:
+     """
+     Deduplicate embeddings across two datasets and return the indices of duplicates between them.
+     """
+     # Building the index from Dataset 1
+     progress(0, desc="Building search index from Dataset 1...")
+     reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])
+
+     duplicate_indices_in_test = []
+     duplicate_to_original_mapping = {}
+
+     # Finding nearest neighbors between datasets
+     progress(0, desc="Finding nearest neighbors between datasets...")
+     results = reach.nearest_neighbor_threshold(
+         embedding_matrix_2,
+         threshold=threshold,
+         batch_size=batch_size,
+         show_progressbar=False  # Disable internal progress bar
+     )
+
+     total_items = len(embedding_matrix_2)
+     # Processing duplicates with a progress bar
+     for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates across datasets", total=total_items)):
+         similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
+
+         if similar_indices:
+             duplicate_indices_in_test.append(i)
+             duplicate_to_original_mapping[i] = similar_indices[0]
+
+     return duplicate_indices_in_test, duplicate_to_original_mapping
+
  def display_word_differences(x: str, y: str) -> str:
      diff = ndiff(x.split(), y.split())
+     return " ".join([word for word in diff if word.startswith(('+', '-'))])

  def perform_deduplication(
      deduplication_type,
 
      dataset2_split="",
      dataset2_text_column="",
      threshold=default_threshold,
+     progress=gr.Progress(track_tqdm=True)
  ):
      try:
+         # Convert threshold to float
          threshold = float(threshold)

+         # Initialize status message
+         status = ""
+
          if deduplication_type == "Single dataset":
+             # Load Dataset 1
+             status = "Loading Dataset 1..."
+             yield status, ""
+             if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
+                 ds = ds_default1
+             else:
+                 ds = load_dataset(dataset1_name, split=dataset1_split)
+
+             # Extract texts
+             status = "Extracting texts from Dataset 1..."
+             yield status, ""
              texts = [example[dataset1_text_column] for example in ds]

+             # Compute embeddings
+             status = "Computing embeddings for Dataset 1..."
+             yield status, ""
+             embedding_matrix = compute_embeddings(texts, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 1")

+             # Deduplicate
+             status = "Deduplicating embeddings..."
+             yield status, ""
+             deduplicated_indices, duplicate_to_original_mapping = deduplicate(
+                 embedding_matrix, threshold, progress=progress
+             )
+
+             # Prepare the results
              num_duplicates = len(duplicate_to_original_mapping)
              num_total = len(texts)
              num_deduplicated = len(deduplicated_indices)
 
              result_text += f"**Number of duplicates found:** {num_duplicates}\n"
              result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"

+             # Show deduplicated examples
              if num_duplicates > 0:
                  result_text += "**Examples of duplicates found:**\n\n"
                  num_examples = min(5, num_duplicates)
 
              else:
                  result_text += "No duplicates found."

+             # Final status
+             status = "Deduplication completed."
+             yield status, result_text
+
+         elif deduplication_type == "Cross-dataset":
+             # Load Dataset 1
+             status = "Loading Dataset 1..."
+             yield status, ""
+             if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
+                 ds1 = ds_default1
+             else:
+                 ds1 = load_dataset(dataset1_name, split=dataset1_split)
+
+             # Load Dataset 2
+             status = "Loading Dataset 2..."
+             yield status, ""
+             if dataset2_name == default_dataset2_name and dataset2_split == default_dataset2_split:
+                 ds2 = ds_default2
+             else:
+                 ds2 = load_dataset(dataset2_name, split=dataset2_split)
+
+             # Extract texts from Dataset 1
+             status = "Extracting texts from Dataset 1..."
+             yield status, ""
+             texts1 = [example[dataset1_text_column] for example in ds1]
+
+             # Extract texts from Dataset 2
+             status = "Extracting texts from Dataset 2..."
+             yield status, ""
+             texts2 = [example[dataset2_text_column] for example in ds2]
+
+             # Compute embeddings for Dataset 1
+             status = "Computing embeddings for Dataset 1..."
+             yield status, ""
+             embedding_matrix1 = compute_embeddings(texts1, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 1")
+
+             # Compute embeddings for Dataset 2
+             status = "Computing embeddings for Dataset 2..."
+             yield status, ""
+             embedding_matrix2 = compute_embeddings(texts2, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 2")
+
+             # Deduplicate across datasets
+             status = "Deduplicating embeddings across datasets..."
+             yield status, ""
+             duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
+                 embedding_matrix1, embedding_matrix2, threshold, progress=progress
+             )
+
+             num_duplicates = len(duplicate_indices_in_ds2)
+             num_total_ds2 = len(texts2)
+             num_unique_ds2 = num_total_ds2 - num_duplicates
+
+             result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
+             result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
+             result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
+
+             # Show deduplicated examples
+             if num_duplicates > 0:
+                 result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
+                 num_examples = min(5, num_duplicates)
+                 for duplicate_idx in duplicate_indices_in_ds2[:num_examples]:
+                     original_idx = duplicate_to_original_mapping[duplicate_idx]
+                     original_text = texts1[original_idx]
+                     duplicate_text = texts2[duplicate_idx]
+                     differences = display_word_differences(original_text, duplicate_text)
+                     result_text += f"**Original text (Dataset 1):**\n{original_text}\n\n"
+                     result_text += f"**Duplicate text (Dataset 2):**\n{duplicate_text}\n\n"
+                     result_text += f"**Differences:**\n{differences}\n"
+                     result_text += "-" * 50 + "\n\n"
+             else:
+                 result_text += "No duplicates found."
+
+             # Final status
+             status = "Deduplication completed."
+             yield status, result_text

      except Exception as e:
+         yield f"An error occurred: {e}", ""
+         raise e

+ with gr.Blocks() as demo:
      gr.Markdown("# Semantic Deduplication")

      deduplication_type = gr.Radio(
          choices=["Single dataset", "Cross-dataset"],
          label="Deduplication Type",
+         value="Single dataset"
      )

      with gr.Row():
 
              dataset2_split = gr.Textbox(value=default_dataset2_split, label="Dataset 2 Split")
              dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")

+     threshold = gr.Slider(
+         minimum=0.0,
+         maximum=1.0,
+         value=default_threshold,
+         label="Similarity Threshold"
+     )

      compute_button = gr.Button("Compute")

+     status_output = gr.Markdown()
      result_output = gr.Markdown()

+     # Function to update the visibility of dataset2_inputs
      def update_visibility(deduplication_type_value):
+         if deduplication_type_value == "Cross-dataset":
+             return gr.update(visible=True)
+         else:
+             return gr.update(visible=False)

      deduplication_type.change(
+         update_visibility,
+         inputs=deduplication_type,
+         outputs=dataset2_inputs
      )

      compute_button.click(
 
          dataset2_name,
          dataset2_split,
          dataset2_text_column,
+         threshold
      ],
+     outputs=[status_output, result_output]
  )

  demo.launch()
 
323
  # import model2vec
324
  # from reach import Reach
325
  # from difflib import ndiff
 
326
 
327
  # # Load the model at startup
328
  # model = model2vec.StaticModel.from_pretrained("minishlab/M2V_base_output")
 
344
  # for i in range(0, len(iterable), batch_size):
345
  # yield iterable[i:i + batch_size]
346
 
347
+ # def compute_embeddings(texts, batch_size, progress, desc="Computing embeddings"):
 
 
 
 
 
 
 
 
 
 
 
 
348
  # embeddings = []
349
  # total_batches = (len(texts) + batch_size - 1) // batch_size
350
  # for i, batch_texts in enumerate(batch_iterable(texts, batch_size)):
 
357
  # embedding_matrix: np.ndarray,
358
  # threshold: float,
359
  # batch_size: int = 1024,
360
+ # progress=None
 
361
  # ) -> tuple[np.ndarray, dict[int, int]]:
362
+ # reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])
 
 
 
 
363
 
364
  # deduplicated_indices = set(range(len(embedding_matrix)))
365
  # duplicate_to_original_mapping = {}
366
 
 
 
367
  # results = reach.nearest_neighbor_threshold(
368
  # embedding_matrix,
369
  # threshold=threshold,
370
  # batch_size=batch_size,
371
+ # show_progressbar=False,
372
  # )
373
 
 
374
  # total_items = len(embedding_matrix)
375
+ # for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=total_items)):
 
 
 
376
  # if i not in deduplicated_indices:
377
  # continue
378
 
379
  # similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]
 
380
  # for sim_idx in similar_indices:
381
  # if sim_idx in deduplicated_indices:
382
  # deduplicated_indices.remove(sim_idx)
 
388
  # diff = ndiff(x.split(), y.split())
389
  # return " ".join([word for word in diff if word.startswith(("+", "-"))])
390
 
 
 
 
 
 
391
  # def perform_deduplication(
392
  # deduplication_type,
393
  # dataset1_name,
 
399
  # threshold=default_threshold,
400
  # progress=gr.Progress(track_tqdm=True),
401
  # ):
 
402
  # try:
 
403
  # threshold = float(threshold)
404
 
 
 
 
405
  # if deduplication_type == "Single dataset":
406
+ # ds = ds_default1 if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split else load_dataset(dataset1_name, split=dataset1_split)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
407
  # texts = [example[dataset1_text_column] for example in ds]
408
+
409
+ # embedding_matrix = compute_embeddings(texts, batch_size=64, progress=progress)
410
+ # deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, threshold, progress=progress)
411
+
 
 
 
 
 
 
 
 
 
 
 
 
 
412
  # num_duplicates = len(duplicate_to_original_mapping)
413
  # num_total = len(texts)
414
  # num_deduplicated = len(deduplicated_indices)
415
 
416
  # result_text = f"**Total documents:** {num_total}\n"
417
  # result_text += f"**Number of duplicates found:** {num_duplicates}\n"
418
+ # result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
 
 
419
 
 
420
  # if num_duplicates > 0:
421
  # result_text += "**Examples of duplicates found:**\n\n"
422
  # num_examples = min(5, num_duplicates)
 
431
  # else:
432
  # result_text += "No duplicates found."
433
 
434
+ # yield result_text
 
 
435
 
436
  # except Exception as e:
437
+ # yield f"An error occurred: {e}"
 
 
438
 
439
+ # # Gradio interface setup
440
  # with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
441
  # gr.Markdown("# Semantic Deduplication")
442
 
 
459
  # dataset2_split = gr.Textbox(value=default_dataset2_split, label="Dataset 2 Split")
460
  # dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
461
 
462
+ # threshold = gr.Slider(minimum=0.0, maximum=1.0, value=default_threshold, label="Similarity Threshold")
 
 
463
 
464
  # compute_button = gr.Button("Compute")
465
 
 
 
466
  # result_output = gr.Markdown()
467
 
 
468
  # def update_visibility(deduplication_type_value):
469
+ # return gr.update(visible=True) if deduplication_type_value == "Cross-dataset" else gr.update(visible=False)
 
 
 
470
 
471
  # deduplication_type.change(
472
  # update_visibility, inputs=deduplication_type, outputs=dataset2_inputs
 
484
  # dataset2_text_column,
485
  # threshold,
486
  # ],
487
+ # outputs=[result_output],
488
  # )
489
 
490
  # demo.launch()
491
 
492
 
 
493
  # # import gradio as gr
494
  # # from datasets import load_dataset
495
  # # import numpy as np
 
496
  # # import model2vec
497
  # # from reach import Reach
498
  # # from difflib import ndiff
499
+ # # import time
500
 
501
  # # # Load the model at startup
502
  # # model = model2vec.StaticModel.from_pretrained("minishlab/M2V_base_output")
 
513
  # # ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
514
  # # ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
515
 
 
516
  # # def batch_iterable(iterable, batch_size):
517
  # # """Helper function to create batches from an iterable."""
518
  # # for i in range(0, len(iterable), batch_size):
519
  # # yield iterable[i:i + batch_size]
520
 
521
+ # # def log_time(message, start_time=None, logs=None):
522
+ # # """Helper function to log the start and end times."""
523
+ # # current_time = time.time()
524
+ # # if start_time is not None:
525
+ # # elapsed = current_time - start_time
526
+ # # log_message = f"{message} - Took {elapsed:.2f} seconds"
527
+ # # else:
528
+ # # log_message = f"{message} - Started"
529
+
530
+ # # if logs is not None:
531
+ # # logs.append(log_message)
532
+
533
+ # # def compute_embeddings(texts, batch_size, progress, logs, desc="Computing embeddings"):
534
  # # embeddings = []
535
  # # total_batches = (len(texts) + batch_size - 1) // batch_size
536
  # # for i, batch_texts in enumerate(batch_iterable(texts, batch_size)):
 
543
  # # embedding_matrix: np.ndarray,
544
  # # threshold: float,
545
  # # batch_size: int = 1024,
546
+ # # progress=None,
547
+ # # logs=None
548
  # # ) -> tuple[np.ndarray, dict[int, int]]:
549
  # # # Building the index
550
+ # # log_time("Building search index", logs=logs)
551
  # # reach = Reach(
552
  # # vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))]
553
  # # )
 
556
  # # duplicate_to_original_mapping = {}
557
 
558
  # # # Finding nearest neighbors
559
+ # # log_time("Finding nearest neighbors", logs=logs)
560
  # # results = reach.nearest_neighbor_threshold(
561
  # # embedding_matrix,
562
  # # threshold=threshold,
 
566
 
567
  # # # Processing duplicates with a progress bar
568
  # # total_items = len(embedding_matrix)
569
+ # # log_time("Processing duplicates", logs=logs)
570
  # # for i, similar_items in enumerate(
571
  # # progress.tqdm(results, desc="Processing duplicates", total=total_items)
572
  # # ):
 
586
  # # diff = ndiff(x.split(), y.split())
587
  # # return " ".join([word for word in diff if word.startswith(("+", "-"))])
588
 
589
+ # # def encode_texts(texts, progress=None, logs=None):
 
590
  # # embedding_matrix = model.encode(texts, show_progressbar=False)
591
+ # # log_time("Encoding texts completed", logs=logs)
592
  # # return embedding_matrix
593
 
594
  # # def perform_deduplication(
 
602
  # # threshold=default_threshold,
603
  # # progress=gr.Progress(track_tqdm=True),
604
  # # ):
605
+ # # logs = [] # To store log messages
606
  # # try:
607
  # # # Convert threshold to float
608
  # # threshold = float(threshold)
609
 
610
  # # # Initialize status message
611
+ # # log_time("Deduplication started", logs=logs)
612
 
613
  # # if deduplication_type == "Single dataset":
614
  # # # Load Dataset 1
615
+ # # start_time = time.time()
616
+ # # log_time("Loading Dataset 1", logs=logs)
617
  # # if (
618
  # # dataset1_name == default_dataset1_name
619
  # # and dataset1_split == default_dataset1_split
 
621
  # # ds = ds_default1
622
  # # else:
623
  # # ds = load_dataset(dataset1_name, split=dataset1_split)
624
+ # # log_time("Loading Dataset 1 completed", start_time=start_time, logs=logs)
625
 
626
  # # # Extract texts
627
+ # # start_time = time.time()
628
+ # # log_time("Extracting texts from Dataset 1", logs=logs)
629
  # # texts = [example[dataset1_text_column] for example in ds]
630
+ # # log_time("Extracting texts from Dataset 1 completed", start_time=start_time, logs=logs)
631
+
632
  # # # Compute embeddings
633
+ # # start_time = time.time()
634
+ # # log_time("Computing embeddings for Dataset 1", logs=logs)
635
+ # # embedding_matrix = encode_texts(texts, progress=progress, logs=logs)
636
+ # # log_time("Computing embeddings for Dataset 1 completed", start_time=start_time, logs=logs)
 
 
 
 
 
 
637
 
638
  # # # Deduplicate
639
+ # # start_time = time.time()
640
+ # # log_time("Deduplicating embeddings", logs=logs)
641
  # # deduplicated_indices, duplicate_to_original_mapping = deduplicate(
642
+ # # embedding_matrix, threshold, progress=progress, logs=logs
643
  # # )
644
+ # # log_time("Deduplication completed", start_time=start_time, logs=logs)
645
 
646
  # # # Prepare the results
647
  # # num_duplicates = len(duplicate_to_original_mapping)
 
669
  # # else:
670
  # # result_text += "No duplicates found."
671
 
672
+ # # log_time("Deduplication process finished", logs=logs)
673
+ # # full_log = "\n".join(logs) # Combine all logs into one output
674
+ # # yield full_log, result_text
 
 
 
675
 
676
  # # except Exception as e:
677
+ # # full_log = "\n".join(logs) # Combine all logs into one output in case of an error
678
  # # yield f"An error occurred: {e}", ""
679
  # # raise e
680
 
 
 
681
  # # # Adjust the height of the status_output component using custom CSS
682
  # # with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
683
  # # gr.Markdown("# Semantic Deduplication")
 
738
  # # )
739
 
740
  # # demo.launch()
741
+
742
+
743
+
744
+ # # # import gradio as gr
745
+ # # # from datasets import load_dataset
746
+ # # # import numpy as np
747
+ # # # #from model2vec import StaticModel
748
+ # # # import model2vec
749
+ # # # from reach import Reach
750
+ # # # from difflib import ndiff
751
+
752
+
753
+ # # # # Load the model at startup
754
+ # # # model = model2vec.StaticModel.from_pretrained("minishlab/M2V_base_output")
755
+
756
+ # # # # Default dataset parameters
757
+ # # # default_dataset1_name = "sst2"
758
+ # # # default_dataset1_split = "train"
759
+ # # # default_dataset2_name = "sst2"
760
+ # # # default_dataset2_split = "validation"
761
+ # # # default_text_column = "sentence"
762
+ # # # default_threshold = 0.9
763
+
764
+ # # # # Load the default datasets at startup
765
+ # # # ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
766
+ # # # ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
767
+
768
+
769
+ # # # def batch_iterable(iterable, batch_size):
770
+ # # # """Helper function to create batches from an iterable."""
771
+ # # # for i in range(0, len(iterable), batch_size):
772
+ # # # yield iterable[i:i + batch_size]
773
+
774
+ # # # def compute_embeddings(texts, batch_size, progress, desc="Computing embeddings"):
775
+ # # # embeddings = []
776
+ # # # total_batches = (len(texts) + batch_size - 1) // batch_size
777
+ # # # for i, batch_texts in enumerate(batch_iterable(texts, batch_size)):
778
+ # # # batch_embeddings = model.encode(batch_texts, show_progressbar=False)
779
+ # # # embeddings.append(batch_embeddings)
780
+ # # # progress((i + 1) / total_batches, desc=desc)
781
+ # # # return np.concatenate(embeddings, axis=0)
782
+
783
+ # # # def deduplicate(
784
+ # # # embedding_matrix: np.ndarray,
785
+ # # # threshold: float,
786
+ # # # batch_size: int = 1024,
787
+ # # # progress=None
788
+ # # # ) -> tuple[np.ndarray, dict[int, int]]:
789
+ # # # # Building the index
790
+ # # # progress(0, desc="Building search index...")
791
+ # # # reach = Reach(
792
+ # # # vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))]
793
+ # # # )
794
+
795
+ # # # deduplicated_indices = set(range(len(embedding_matrix)))
796
+ # # # duplicate_to_original_mapping = {}
797
+
798
+ # # # # Finding nearest neighbors
799
+ # # # progress(0, desc="Finding nearest neighbors...")
800
+ # # # results = reach.nearest_neighbor_threshold(
801
+ # # # embedding_matrix,
802
+ # # # threshold=threshold,
803
+ # # # batch_size=batch_size,
804
+ # # # show_progressbar=False, # Disable internal progress bar
805
+ # # # )
806
+
807
+ # # # # Processing duplicates with a progress bar
808
+ # # # total_items = len(embedding_matrix)
809
+ # # # for i, similar_items in enumerate(
810
+ # # # progress.tqdm(results, desc="Processing duplicates", total=total_items)
811
+ # # # ):
812
+ # # # if i not in deduplicated_indices:
813
+ # # # continue
814
+
815
+ # # # similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]
816
+
817
+ # # # for sim_idx in similar_indices:
818
+ # # # if sim_idx in deduplicated_indices:
819
+ # # # deduplicated_indices.remove(sim_idx)
820
+ # # # duplicate_to_original_mapping[sim_idx] = i
821
+
822
+ # # # return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
823
+
824
+ # # # def display_word_differences(x: str, y: str) -> str:
825
+ # # # diff = ndiff(x.split(), y.split())
826
+ # # # return " ".join([word for word in diff if word.startswith(("+", "-"))])
827
+
828
+
829
+ # # # def encode_texts(texts, progress=None):
830
+ # # # embedding_matrix = model.encode(texts, show_progressbar=False)
831
+ # # # return embedding_matrix
832
+
833
+ # # # def perform_deduplication(
834
+ # # # deduplication_type,
835
+ # # # dataset1_name,
836
+ # # # dataset1_split,
837
+ # # # dataset1_text_column,
838
+ # # # dataset2_name="",
839
+ # # # dataset2_split="",
840
+ # # # dataset2_text_column="",
841
+ # # # threshold=default_threshold,
842
+ # # # progress=gr.Progress(track_tqdm=True),
843
+ # # # ):
844
+ # # # try:
845
+ # # # # Convert threshold to float
846
+ # # # threshold = float(threshold)
847
+
848
+ # # # # Initialize status message
849
+ # # # status = ""
850
+
851
+ # # # if deduplication_type == "Single dataset":
852
+ # # # # Load Dataset 1
853
+ # # # status = "Loading Dataset 1..."
854
+ # # # yield status, ""
855
+ # # # if (
856
+ # # # dataset1_name == default_dataset1_name
857
+ # # # and dataset1_split == default_dataset1_split
858
+ # # # ):
859
+ # # # ds = ds_default1
860
+ # # # else:
861
+ # # # ds = load_dataset(dataset1_name, split=dataset1_split)
862
+
863
+ # # # # Extract texts
864
+ # # # status = "Extracting texts from Dataset 1..."
865
+ # # # yield status, ""
866
+ # # # texts = [example[dataset1_text_column] for example in ds]
867
+ # # # # Compute embeddings
868
+ # # # status = "Computing embeddings for Dataset 1..."
869
+ # # # yield status, ""
870
+ # # # embedding_matrix = encode_texts(texts, progress=progress)
871
+ # # # #embedding_matrix = model.encode(texts, show_progressbar=True)
872
+ # # # # embedding_matrix = compute_embeddings(
873
+ # # # # texts,
874
+ # # # # batch_size=64,
875
+ # # # # progress=progress,
876
+ # # # # desc="Computing embeddings for Dataset 1",
877
+ # # # # )
878
+
879
+ # # # # Deduplicate
880
+ # # # status = "Deduplicating embeddings..."
881
+ # # # yield status, ""
882
+ # # # deduplicated_indices, duplicate_to_original_mapping = deduplicate(
883
+ # # # embedding_matrix, threshold, progress=progress
884
+ # # # )
885
+
886
+ # # # # Prepare the results
887
+ # # # num_duplicates = len(duplicate_to_original_mapping)
888
+ # # # num_total = len(texts)
889
+ # # # num_deduplicated = len(deduplicated_indices)
890
+
891
+ # # # result_text = f"**Total documents:** {num_total}\n"
892
+ # # # result_text += f"**Number of duplicates found:** {num_duplicates}\n"
893
+ # # # result_text += (
894
+ # # # f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
895
+ # # # )
896
+
897
+ # # # # Show deduplicated examples
898
+ # # # if num_duplicates > 0:
899
+ # # # result_text += "**Examples of duplicates found:**\n\n"
900
+ # # # num_examples = min(5, num_duplicates)
901
+ # # # for duplicate_idx, original_idx in list(duplicate_to_original_mapping.items())[:num_examples]:
902
+ # # # original_text = texts[original_idx]
903
+ # # # duplicate_text = texts[duplicate_idx]
904
+ # # # differences = display_word_differences(original_text, duplicate_text)
905
+ # # # result_text += f"**Original text:**\n{original_text}\n\n"
906
+ # # # result_text += f"**Duplicate text:**\n{duplicate_text}\n\n"
907
+ # # # result_text += f"**Differences:**\n{differences}\n"
908
+ # # # result_text += "-" * 50 + "\n\n"
909
+ # # # else:
910
+ # # # result_text += "No duplicates found."
911
+
912
+ # # # # Final status
913
+ # # # status = "Deduplication completed."
914
+ # # # yield status, result_text
915
+
916
+ # # # elif deduplication_type == "Cross-dataset":
917
+ # # # # Similar code for cross-dataset deduplication
918
+ # # # # Load Dataset 1
919
+ # # # status = "Loading Dataset 1..."
920
+ # # # yield status, ""
921
+ # # # if (
922
+ # # # dataset1_name == default_dataset1_name
923
+ # # # and dataset1_split == default_dataset1_split
924
+ # # # ):
925
+ # # # ds1 = ds_default1
926
+ # # # else:
927
+ # # # ds1 = load_dataset(dataset1_name, split=dataset1_split)
928
+
929
+ # # # # Load Dataset 2
930
+ # # # status = "Loading Dataset 2..."
931
+ # # # yield status, ""
932
+ # # # if (
933
+ # # # dataset2_name == default_dataset2_name
934
+ # # # and dataset2_split == default_dataset2_split
935
+ # # # ):
936
+ # # # ds2 = ds_default2
937
+ # # # else:
938
+ # # # ds2 = load_dataset(dataset2_name, split=dataset2_split)
939
+
940
+ # # # # Extract texts from Dataset 1
941
+ # # # status = "Extracting texts from Dataset 1..."
942
+ # # # yield status, ""
943
+ # # # texts1 = [example[dataset1_text_column] for example in ds1]
944
+
945
+ # # # # Extract texts from Dataset 2
946
+ # # # status = "Extracting texts from Dataset 2..."
947
+ # # # yield status, ""
948
+ # # # texts2 = [example[dataset2_text_column] for example in ds2]
949
+
950
+ # # # # Compute embeddings for Dataset 1
951
+ # # # status = "Computing embeddings for Dataset 1..."
952
+ # # # yield status, ""
953
+ # # # embedding_matrix1 = compute_embeddings(
954
+ # # # texts1,
955
+ # # # batch_size=64,
956
+ # # # progress=progress,
957
+ # # # desc="Computing embeddings for Dataset 1",
958
+ # # # )
959
+
960
+ # # # # Compute embeddings for Dataset 2
961
+ # # # status = "Computing embeddings for Dataset 2..."
962
+ # # # yield status, ""
963
+ # # # embedding_matrix2 = compute_embeddings(
964
+ # # # texts2,
965
+ # # # batch_size=64,
966
+ # # # progress=progress,
967
+ # # # desc="Computing embeddings for Dataset 2",
968
+ # # # )
969
+
970
+ # # # # Deduplicate across datasets
971
+ # # # status = "Deduplicating embeddings across datasets..."
972
+ # # # yield status, ""
973
+ # # # duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
974
+ # # # embedding_matrix1, embedding_matrix2, threshold, progress=progress
975
+ # # # )
976
+
977
+ # # # num_duplicates = len(duplicate_indices_in_ds2)
978
+ # # # num_total_ds2 = len(texts2)
979
+ # # # num_unique_ds2 = num_total_ds2 - num_duplicates
980
+
981
+ # # # result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
982
+ # # # result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
983
+ # # # result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
984
+
985
+ # # # # Show deduplicated examples
986
+ # # # if num_duplicates > 0:
987
+ # # # result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
988
+ # # # num_examples = min(5, num_duplicates)
989
+ # # # for duplicate_idx in duplicate_indices_in_ds2[:num_examples]:
990
+ # # # original_idx = duplicate_to_original_mapping[duplicate_idx]
991
+ # # # original_text = texts1[original_idx]
992
+ # # # duplicate_text = texts2[duplicate_idx]
993
+ # # # differences = display_word_differences(original_text, duplicate_text)
994
+ # # # result_text += f"**Original text (Dataset 1):**\n{original_text}\n\n"
995
+ # # # result_text += f"**Duplicate text (Dataset 2):**\n{duplicate_text}\n\n"
996
+ # # # result_text += f"**Differences:**\n{differences}\n"
997
+ # # # result_text += "-" * 50 + "\n\n"
998
+ # # # else:
999
+ # # # result_text += "No duplicates found."
1000
+
1001
+ # # # # Final status
1002
+ # # # status = "Deduplication completed."
1003
+ # # # yield status, result_text
1004
+
1005
+ # # # except Exception as e:
1006
+ # # # yield f"An error occurred: {e}", ""
1007
+ # # # raise e
1008
+
1009
+ # # # def deduplicate_across_datasets(
1010
+ # # # embedding_matrix_1: np.ndarray,
1011
+ # # # embedding_matrix_2: np.ndarray,
1012
+ # # # threshold: float,
1013
+ # # # batch_size: int = 1024,
1014
+ # # # progress=None
1015
+ # # # ) -> tuple[list[int], dict[int, int]]:
1016
+ # # # # Building the index from Dataset 1
1017
+ # # # progress(0, desc="Building search index from Dataset 1...")
1018
+ # # # reach = Reach(
1019
+ # # # vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))]
1020
+ # # # )
1021
+
1022
+ # # # duplicate_indices_in_test = []
1023
+ # # # duplicate_to_original_mapping = {}
1024
+
1025
+ # # # # Finding nearest neighbors between datasets
1026
+ # # # progress(0, desc="Finding nearest neighbors between datasets...")
1027
+ # # # results = reach.nearest_neighbor_threshold(
1028
+ # # # embedding_matrix_2,
1029
+ # # # threshold=threshold,
1030
+ # # # batch_size=batch_size,
1031
+ # # # show_progressbar=False, # Disable internal progress bar
1032
+ # # # )
1033
+
1034
+ # # # total_items = len(embedding_matrix_2)
1035
+ # # # # Processing duplicates with a progress bar
1036
+ # # # for i, similar_items in enumerate(
1037
+ # # # progress.tqdm(results, desc="Processing duplicates across datasets", total=total_items)
1038
+ # # # ):
1039
+ # # # similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
1040
+
1041
+ # # # if similar_indices:
1042
+ # # # duplicate_indices_in_test.append(i)
1043
+ # # # duplicate_to_original_mapping[i] = similar_indices[0]
1044
+
1045
+ # # # return duplicate_indices_in_test, duplicate_to_original_mapping
1046
+
1047
+ # # # # Adjust the height of the status_output component using custom CSS
1048
+ # # # with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
1049
+ # # # gr.Markdown("# Semantic Deduplication")
1050
+
1051
+ # # # deduplication_type = gr.Radio(
1052
+ # # # choices=["Single dataset", "Cross-dataset"],
1053
+ # # # label="Deduplication Type",
1054
+ # # # value="Single dataset",
1055
+ # # # )
1056
+
1057
+ # # # with gr.Row():
1058
+ # # # dataset1_name = gr.Textbox(value=default_dataset1_name, label="Dataset 1 Name")
1059
+ # # # dataset1_split = gr.Textbox(value=default_dataset1_split, label="Dataset 1 Split")
1060
+ # # # dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
1061
+
1062
+ # # # dataset2_inputs = gr.Column(visible=False)
1063
+ # # # with dataset2_inputs:
1064
+ # # # gr.Markdown("### Dataset 2")
1065
+ # # # with gr.Row():
1066
+ # # # dataset2_name = gr.Textbox(value=default_dataset2_name, label="Dataset 2 Name")
1067
+ # # # dataset2_split = gr.Textbox(value=default_dataset2_split, label="Dataset 2 Split")
1068
+ # # # dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
1069
+
1070
+ # # # threshold = gr.Slider(
1071
+ # # # minimum=0.0, maximum=1.0, value=default_threshold, label="Similarity Threshold"
1072
+ # # # )
1073
+
1074
+ # # # compute_button = gr.Button("Compute")
1075
+
1076
+ # # # # Use 'gr.Markdown' with 'elem_id' and custom CSS to adjust height
1077
+ # # # status_output = gr.Markdown(elem_id="status_output")
1078
+ # # # result_output = gr.Markdown()
1079
+
1080
+ # # # # Function to update the visibility of dataset2_inputs
1081
+ # # # def update_visibility(deduplication_type_value):
1082
+ # # # if deduplication_type_value == "Cross-dataset":
1083
+ # # # return gr.update(visible=True)
1084
+ # # # else:
1085
+ # # # return gr.update(visible=False)
1086
+
1087
+ # # # deduplication_type.change(
1088
+ # # # update_visibility, inputs=deduplication_type, outputs=dataset2_inputs
1089
+ # # # )
1090
+
1091
+ # # # compute_button.click(
1092
+ # # # fn=perform_deduplication,
1093
+ # # # inputs=[
1094
+ # # # deduplication_type,
1095
+ # # # dataset1_name,
1096
+ # # # dataset1_split,
1097
+ # # # dataset1_text_column,
1098
+ # # # dataset2_name,
1099
+ # # # dataset2_split,
1100
+ # # # dataset2_text_column,
1101
+ # # # threshold,
1102
+ # # # ],
1103
+ # # # outputs=[status_output, result_output],
1104
+ # # # )
1105
+
1106
+ # # # demo.launch()