Pringled committed
Commit c2eeff5 · 1 Parent(s): b9fcd2c
Files changed (1)
  1. app.py +33 -61
app.py CHANGED
@@ -26,15 +26,19 @@ def batch_iterable(iterable, batch_size):
     for i in range(0, len(iterable), batch_size):
         yield iterable[i:i + batch_size]
 
-def log_time(message, start_time=None):
+def log_time(message, start_time=None, logs=None):
     """Helper function to log the start and end times."""
     current_time = time.time()
     if start_time is not None:
         elapsed = current_time - start_time
-        return f"{message} - Took {elapsed:.2f} seconds"
-    return f"{message} - Started"
-
-def compute_embeddings(texts, batch_size, progress, desc="Computing embeddings"):
+        log_message = f"{message} - Took {elapsed:.2f} seconds"
+    else:
+        log_message = f"{message} - Started"
+
+    if logs is not None:
+        logs.append(log_message)
+
+def compute_embeddings(texts, batch_size, progress, logs, desc="Computing embeddings"):
     embeddings = []
     total_batches = (len(texts) + batch_size - 1) // batch_size
     for i, batch_texts in enumerate(batch_iterable(texts, batch_size)):
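The substantive change in this hunk: `log_time` no longer returns a status string for the caller to `yield`; it formats the message and appends it to a caller-supplied `logs` list. A minimal sketch of the new call pattern (assuming `log_time` as defined above; the sleep is a stand-in for real work):

```python
import time

logs = []
start = time.time()
log_time("Loading Dataset 1", logs=logs)   # appends "Loading Dataset 1 - Started"
time.sleep(0.1)                            # stand-in for the actual loading step
log_time("Loading Dataset 1 completed", start_time=start, logs=logs)

print("\n".join(logs))
# Loading Dataset 1 - Started
# Loading Dataset 1 completed - Took 0.10 seconds
```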
@@ -47,10 +51,11 @@ def deduplicate(
     embedding_matrix: np.ndarray,
     threshold: float,
     batch_size: int = 1024,
-    progress=None
+    progress=None,
+    logs=None
 ) -> tuple[np.ndarray, dict[int, int]]:
     # Building the index
-    progress(0, desc="Building search index...")
+    log_time("Building search index", logs=logs)
     reach = Reach(
         vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))]
     )
@@ -59,7 +64,7 @@ def deduplicate(
     duplicate_to_original_mapping = {}
 
     # Finding nearest neighbors
-    progress(0, desc="Finding nearest neighbors...")
+    log_time("Finding nearest neighbors", logs=logs)
     results = reach.nearest_neighbor_threshold(
         embedding_matrix,
         threshold=threshold,
@@ -69,6 +74,7 @@ def deduplicate(
 
     # Processing duplicates with a progress bar
     total_items = len(embedding_matrix)
+    log_time("Processing duplicates", logs=logs)
     for i, similar_items in enumerate(
         progress.tqdm(results, desc="Processing duplicates", total=total_items)
     ):
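These three `deduplicate` hunks swap the `progress(0, desc=...)` status calls for `log_time` entries; the neighbor search itself is unchanged. The duplicate-marking loop body falls outside the diff, so the following is a hedged sketch of the whole routine: the `Reach` calls mirror the ones visible above, while the loop body is a plausible reconstruction, not the committed code:

```python
import numpy as np
from reach import Reach

def deduplicate_sketch(
    embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024
) -> tuple[np.ndarray, dict[int, int]]:
    # Index each row under its position so neighbor hits map back to row indices.
    reach = Reach(
        vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))]
    )
    # For every vector, collect all neighbors whose similarity clears the threshold.
    results = reach.nearest_neighbor_threshold(
        embedding_matrix, threshold=threshold, batch_size=batch_size, show_progressbar=False
    )
    keep = set(range(len(embedding_matrix)))
    duplicate_to_original = {}
    for i, similar_items in enumerate(results):
        if i not in keep:
            continue  # already marked as a duplicate of an earlier item
        for item, _score in similar_items:
            j = int(item)
            if j != i and j in keep:
                keep.discard(j)
                duplicate_to_original[j] = i
    return np.array(sorted(keep)), duplicate_to_original
```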
@@ -88,8 +94,9 @@ def display_word_differences(x: str, y: str) -> str:
     diff = ndiff(x.split(), y.split())
     return " ".join([word for word in diff if word.startswith(("+", "-"))])
 
-def encode_texts(texts, progress=None):
+def encode_texts(texts, progress=None, logs=None):
     embedding_matrix = model.encode(texts, show_progressbar=False)
+    log_time("Encoding texts completed", logs=logs)
     return embedding_matrix
 
 def perform_deduplication(
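`encode_texts` delegates to a module-level `model` whose construction sits outside this diff. For exercising the function locally, a stand-in exposing the same `.encode(texts, show_progressbar=...)` surface suffices (hypothetical, purely illustrative):

```python
import numpy as np

class FakeModel:
    """Stand-in embedding model returning random unit vectors."""
    def encode(self, texts, show_progressbar=False):
        vecs = np.random.rand(len(texts), 256).astype(np.float32)
        return vecs / np.linalg.norm(vecs, axis=1, keepdims=True)

model = FakeModel()
logs = []
embedding_matrix = encode_texts(["one sentence", "another"], logs=logs)
print(logs)  # ['Encoding texts completed - Started']
```

Note the printed suffix: since the new `log_time` call passes no `start_time`, the "completed" message is tagged "- Started".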
@@ -103,18 +110,18 @@ def perform_deduplication(
     threshold=default_threshold,
     progress=gr.Progress(track_tqdm=True),
 ):
+    logs = []  # To store log messages
     try:
         # Convert threshold to float
         threshold = float(threshold)
 
         # Initialize status message
-        status = ""
+        log_time("Deduplication started", logs=logs)
 
         if deduplication_type == "Single dataset":
             # Load Dataset 1
             start_time = time.time()
-            status = log_time("Loading Dataset 1")
-            yield status, ""
+            log_time("Loading Dataset 1", logs=logs)
             if (
                 dataset1_name == default_dataset1_name
                 and dataset1_split == default_dataset1_split
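`progress=gr.Progress(track_tqdm=True)` is untouched here, but it is what preserves per-step feedback now that the intermediate `yield`s are removed: Gradio injects the tracker, and every `progress.tqdm(...)` loop is mirrored in the UI (with `track_tqdm=True`, plain tqdm bars from library code are mirrored too). A minimal standalone sketch of that pattern:

```python
import gradio as gr

def shout(words, progress=gr.Progress(track_tqdm=True)):
    out = []
    # progress.tqdm wraps any iterable and renders the loop as a progress bar.
    for word in progress.tqdm(words.split(), desc="Processing"):
        out.append(word.upper())
    return " ".join(out)

demo = gr.Interface(fn=shout, inputs="text", outputs="text")
demo.launch()
```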
@@ -122,34 +129,27 @@ def perform_deduplication(
                 ds = ds_default1
             else:
                 ds = load_dataset(dataset1_name, split=dataset1_split)
-            status = log_time("Loading Dataset 1 completed", start_time)
-            yield status, ""
+            log_time("Loading Dataset 1 completed", start_time=start_time, logs=logs)
 
             # Extract texts
             start_time = time.time()
-            status = log_time("Extracting texts from Dataset 1")
-            yield status, ""
+            log_time("Extracting texts from Dataset 1", logs=logs)
             texts = [example[dataset1_text_column] for example in ds]
-            status = log_time("Extracting texts from Dataset 1 completed", start_time)
-            yield status, ""
+            log_time("Extracting texts from Dataset 1 completed", start_time=start_time, logs=logs)
 
             # Compute embeddings
             start_time = time.time()
-            status = log_time("Computing embeddings for Dataset 1")
-            yield status, ""
-            embedding_matrix = encode_texts(texts, progress=progress)
-            status = log_time("Computing embeddings for Dataset 1 completed", start_time)
-            yield status, ""
+            log_time("Computing embeddings for Dataset 1", logs=logs)
+            embedding_matrix = encode_texts(texts, progress=progress, logs=logs)
+            log_time("Computing embeddings for Dataset 1 completed", start_time=start_time, logs=logs)
 
             # Deduplicate
             start_time = time.time()
-            status = log_time("Deduplicating embeddings")
-            yield status, ""
+            log_time("Deduplicating embeddings", logs=logs)
             deduplicated_indices, duplicate_to_original_mapping = deduplicate(
-                embedding_matrix, threshold, progress=progress
+                embedding_matrix, threshold, progress=progress, logs=logs
             )
-            status = log_time("Deduplication completed", start_time)
-            yield status, ""
+            log_time("Deduplication completed", start_time=start_time, logs=logs)
 
             # Prepare the results
             num_duplicates = len(duplicate_to_original_mapping)
@@ -177,41 +177,12 @@ def perform_deduplication(
             else:
                 result_text += "No duplicates found."
 
-            # Final status
-            status = log_time("Deduplication process finished")
-            yield status, result_text
-
-        elif deduplication_type == "Cross-dataset":
-            # Similar code for cross-dataset deduplication with time logging
-            start_time = time.time()
-            status = log_time("Loading Dataset 1")
-            yield status, ""
-            if (
-                dataset1_name == default_dataset1_name
-                and dataset1_split == default_dataset1_split
-            ):
-                ds1 = ds_default1
-            else:
-                ds1 = load_dataset(dataset1_name, split=dataset1_split)
-            status = log_time("Loading Dataset 1 completed", start_time)
-            yield status, ""
-
-            start_time = time.time()
-            status = log_time("Loading Dataset 2")
-            yield status, ""
-            if (
-                dataset2_name == default_dataset2_name
-                and dataset2_split == default_dataset2_split
-            ):
-                ds2 = ds_default2
-            else:
-                ds2 = load_dataset(dataset2_name, split=dataset2_split)
-            status = log_time("Loading Dataset 2 completed", start_time)
-            yield status, ""
-
-            # Similar time logging for embedding computations and deduplication steps
+            log_time("Deduplication process finished", logs=logs)
+            full_log = "\n".join(logs)  # Combine all logs into one output
+            yield full_log, result_text
 
     except Exception as e:
+        full_log = "\n".join(logs)  # Combine all logs into one output in case of an error
         yield f"An error occurred: {e}", ""
         raise e
 
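One detail worth flagging in the new error branch: `full_log` is assembled but the subsequent `yield` emits only the error string, so the accumulated log is dropped on failure. A variant that surfaces it might look like this (an editorial sketch, not part of the commit):

```python
    except Exception as e:
        full_log = "\n".join(logs)
        # Surface the log gathered so far, so the UI shows where the failure happened.
        yield f"{full_log}\nAn error occurred: {e}", ""
        raise e
```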
@@ -276,6 +247,7 @@ with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
 
 demo.launch()
 
+
 # import gradio as gr
 # from datasets import load_dataset
 # import numpy as np
 