Pringled committed on
Commit
38ed48e
·
1 Parent(s): 75ff340
Files changed (1) hide show
  1. app.py +10 -7
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  from datasets import load_dataset
3
  import numpy as np
4
  from model2vec import StaticModel
 
5
  from reach import Reach
6
  from difflib import ndiff
7
 
@@ -24,25 +25,26 @@ ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
24
  # Import the stock tqdm under an alias so it can be subclassed for Gradio progress reporting
25
  from tqdm import tqdm as original_tqdm
26
 
 
27
  # Patch tqdm to use Gradio's progress bar
28
  def patch_tqdm_for_gradio(progress):
29
  class GradioTqdm(original_tqdm):
30
  def __init__(self, *args, **kwargs):
31
  super().__init__(*args, **kwargs)
32
  self.progress = progress
33
- # Set smaller step sizes or update more frequently based on total items
34
  self.total_batches = kwargs.get('total', len(args[0])) if len(args) > 0 else 1
35
- self.update_interval = max(1, self.total_batches // 100) # Update every 1% of progress
36
-
37
  def update(self, n=1):
38
  super().update(n)
39
- # Only update Gradio's progress every `update_interval` steps
40
  if self.n % self.update_interval == 0 or self.n == self.total_batches:
41
  self.progress(self.n / self.total_batches)
42
 
43
  return GradioTqdm
44
 
45
-
 
 
46
 
47
  # Function to patch the original encode function with our Gradio tqdm
48
  def original_encode_with_tqdm(original_encode_func, patched_tqdm):
@@ -153,8 +155,9 @@ def perform_deduplication(
153
  yield status, ""
154
  texts = [example[dataset1_text_column] for example in ds]
155
 
156
- patched_tqdm = patch_tqdm_for_gradio(progress)
157
- model.encode = original_encode_with_tqdm(model.encode, patched_tqdm)
 
158
  # Compute embeddings
159
  status = "Computing embeddings for Dataset 1..."
160
  yield status, ""
 
2
  from datasets import load_dataset
3
  import numpy as np
4
  from model2vec import StaticModel
5
+ import model2vec
6
  from reach import Reach
7
  from difflib import ndiff
8
 
 
25
  # Import the stock tqdm under an alias so it can be subclassed for Gradio progress reporting
26
  from tqdm import tqdm as original_tqdm
27
 
28
+ # Patch tqdm to use Gradio's progress bar
29
  # Patch tqdm to use Gradio's progress bar
30
  def patch_tqdm_for_gradio(progress):
31
  class GradioTqdm(original_tqdm):
32
  def __init__(self, *args, **kwargs):
33
  super().__init__(*args, **kwargs)
34
  self.progress = progress
 
35
  self.total_batches = kwargs.get('total', len(args[0])) if len(args) > 0 else 1
36
+ self.update_interval = max(1, self.total_batches // 100) # Update every 1%
37
+
38
  def update(self, n=1):
39
  super().update(n)
 
40
  if self.n % self.update_interval == 0 or self.n == self.total_batches:
41
  self.progress(self.n / self.total_batches)
42
 
43
  return GradioTqdm
44
 
45
def patch_model2vec_tqdm(progress):
    """Route model2vec's tqdm progress bars to a Gradio progress callback.

    Args:
        progress: Callable taking a completion fraction; forwarded to
            ``patch_tqdm_for_gradio`` to build the replacement tqdm class.

    Returns:
        None. The function works purely by rebinding module attributes.
    """
    import sys

    patched_tqdm = patch_tqdm_for_gradio(progress)

    # Keep the package-level rebinding the original code performed.
    model2vec.tqdm = patched_tqdm

    # ``from tqdm import tqdm`` binds the name inside the importing module,
    # so the package-level assignment above does not reach it. Rebind the
    # name in every loaded model2vec submodule that holds a tqdm *class*
    # (a bare ``import tqdm`` module reference is left untouched).
    # NOTE(review): presumably the encode loop lives in such a submodule;
    # confirm against the installed model2vec version.
    for module_name, module in list(sys.modules.items()):
        if module is None or not module_name.startswith("model2vec."):
            continue
        if isinstance(getattr(module, "tqdm", None), type):
            module.tqdm = patched_tqdm
48
 
49
  # Function to patch the original encode function with our Gradio tqdm
50
  def original_encode_with_tqdm(original_encode_func, patched_tqdm):
 
155
  yield status, ""
156
  texts = [example[dataset1_text_column] for example in ds]
157
 
158
+ #patched_tqdm = patch_tqdm_for_gradio(progress)
159
+ patch_model2vec_tqdm(progress)
160
+ #model.encode = original_encode_with_tqdm(model.encode, patched_tqdm)
161
  # Compute embeddings
162
  status = "Computing embeddings for Dataset 1..."
163
  yield status, ""