JoeArmani commited on
Commit
2a3cfd8
·
2 Parent(s): cf97979 098eba4

Merge branch 'dev'

Browse files
.gitignore CHANGED
@@ -154,9 +154,5 @@ dmypy.json
154
  # Cython debug symbols
155
  cython_debug/
156
 
157
- # PyCharm
158
- # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
- # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
- # and can be added to the global gitignore or merged into this file. For a more nuclear
161
- # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
- #.idea/
 
154
  # Cython debug symbols
155
  cython_debug/
156
 
157
+ datasets/*
158
+ !datasets/.gitkeep
 
 
 
 
augmented_combined_dataset.json ADDED
The diff for this file is too large to render. See raw diff
 
back_translator.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import (
2
+ MarianMTModel,
3
+ MarianTokenizer,
4
+ )
5
+
6
+ class BackTranslator:
7
+ """
8
+ Perform Back-translation with pivot language. English -> German -> Spanish -> English
9
+ Args:
10
+ source_lang: Source language (default: 'en')
11
+ pivot_lang: Pivot language (default: 'de')
12
+ target_lang: Target language (default: 'es')
13
+ Examples:
14
+ back_translator = BackTranslator()
15
+ back_translator.back_translate("Hello, how are you?")
16
+ """
17
+ def __init__(self, source_lang='en', pivot_lang='de', target_lang='es'):
18
+ # Forward (English to German)
19
+ pivot_forward_model_name = f'Helsinki-NLP/opus-mt-{source_lang}-{pivot_lang}'
20
+ self.tokenizer_pivot_forward = MarianTokenizer.from_pretrained(pivot_forward_model_name)
21
+ self.model_pivot_forward = MarianMTModel.from_pretrained(pivot_forward_model_name)
22
+
23
+ # Pivot translation model (German to Spanish)
24
+ pivot_backward_model_name = f'Helsinki-NLP/opus-mt-{pivot_lang}-{target_lang}'
25
+ self.tokenizer_pivot_backward = MarianTokenizer.from_pretrained(pivot_backward_model_name)
26
+ self.model_pivot_backward = MarianMTModel.from_pretrained(pivot_backward_model_name)
27
+
28
+ # Backward (Spanish to English)
29
+ backward_model_name = f'Helsinki-NLP/opus-mt-{target_lang}-{source_lang}'
30
+ self.tokenizer_backward = MarianTokenizer.from_pretrained(backward_model_name)
31
+ self.model_backward = MarianMTModel.from_pretrained(backward_model_name)
32
+
33
+ def back_translate(self, text):
34
+ """
35
+ Perform back-translation through German and Spanish to generate text variations.
36
+ Args:
37
+ text (str): The input text to be back-translated
38
+
39
+ Returns:
40
+ str: The back-translated text
41
+ """
42
+ # 1. English to German
43
+ encoded_pivot = self.tokenizer_pivot_forward([text], padding=True, truncation=True, return_tensors='pt')
44
+ generated_pivot = self.model_pivot_forward.generate(**encoded_pivot)
45
+ pivot_text = self.tokenizer_pivot_forward.batch_decode(generated_pivot, skip_special_tokens=True)[0]
46
+
47
+ # 2. German to Spanish
48
+ encoded_back_pivot = self.tokenizer_pivot_backward([pivot_text], padding=True, truncation=True, return_tensors='pt')
49
+ retranslated_pivot = self.model_pivot_backward.generate(**encoded_back_pivot)
50
+ tgt_text_back = self.tokenizer_pivot_backward.batch_decode(retranslated_pivot, skip_special_tokens=True)[0]
51
+
52
+ # 3. Spanish to English
53
+ encoded_back = self.tokenizer_backward([tgt_text_back], padding=True, truncation=True, return_tensors='pt')
54
+ retranslated = self.model_backward.generate(**encoded_back)
55
+ src_text = self.tokenizer_backward.batch_decode(retranslated, skip_special_tokens=True)[0]
56
+ return src_text
dialogue_augmenter.py ADDED
@@ -0,0 +1,716 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ import tensorflow_hub as hub
5
+ import re
6
+ from pipeline_config import PipelineConfig
7
+ from quality_metrics import QualityMetrics
8
+ from paraphraser import Paraphraser
9
+ from back_translator import BackTranslator
10
+ import nlpaug.augmenter.word as naw
11
+ from concurrent.futures import ThreadPoolExecutor
12
+ from functools import lru_cache
13
+ from sklearn.metrics.pairwise import cosine_similarity
14
+
15
+ class DialogueAugmenter:
16
+ """
17
+ Optimized dialogue augmentation with quality control and complexity management.
18
+ """
19
+ def __init__(self, nlp, config: PipelineConfig):
20
+ self.nlp = nlp
21
+ self.config = config
22
+ self.quality_metrics = QualityMetrics(config)
23
+ self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
24
+
25
+ # Advanced augmentation techniques
26
+ self.paraphraser = Paraphraser()
27
+ self.back_translator = BackTranslator()
28
+
29
+ # Basic augmentation techniques
30
+ self.word_augmenter = naw.SynonymAug(aug_src='wordnet')
31
+ self.spelling_augmenter = naw.SpellingAug()
32
+
33
+ self.augmenters = {
34
+ 'advanced': [self.paraphraser, self.back_translator],
35
+ 'basic': [
36
+ ('synonym', self.word_augmenter),
37
+ ('spelling', self.spelling_augmenter)
38
+ ]
39
+ }
40
+
41
+ # Initialize cache
42
+ self.embedding_cache = {}
43
+ self.perplexity_cache = {}
44
+
45
+ # Compile regex patterns
46
+ self.spelling_pattern = re.compile(r'[a-zA-Z]{3,}')
47
+
48
+ # GPU memory management
49
+ gpus = tf.config.list_physical_devices('GPU')
50
+ if gpus:
51
+ try:
52
+ for gpu in gpus:
53
+ tf.config.experimental.set_memory_growth(gpu, True)
54
+ except RuntimeError as e:
55
+ print(e)
56
+
57
+ @lru_cache(maxsize=1024)
58
+ def _compute_embedding(self, text: str) -> np.ndarray:
59
+ """Cached computation of text embedding"""
60
+ return self.use_model([text])[0].numpy()
61
+
62
+ def _compute_batch_embeddings(self, texts: List[str]) -> np.ndarray:
63
+ """Compute embeddings for multiple texts at once"""
64
+ return self.use_model(texts).numpy()
65
+
66
+ def _quick_quality_check(self, variation: str, original: str) -> bool:
67
+ """
68
+ Simplified preliminary quality check with minimal standards
69
+ """
70
+ if self.config.debug:
71
+ print(f"\nQuick check for variation: {variation}")
72
+
73
+ # Only reject if length is extremely different
74
+ orig_len = len(original.split())
75
+ var_len = len(variation.split())
76
+
77
+ # For very short texts (1-3 words), allow more variation
78
+ if orig_len <= 3:
79
+ if var_len > orig_len * 4: # Allow up to 4x length for short texts
80
+ if self.config.debug:
81
+ print(f"Failed length check (short text): {var_len} vs {orig_len}")
82
+ return False
83
+ else:
84
+ if var_len > orig_len * 3: # Allow up to 3x length for longer texts
85
+ if self.config.debug:
86
+ print(f"Failed length check (long text): {var_len} vs {orig_len}")
87
+ return False
88
+
89
+ # Basic content check - at least one word in common (excluding stop words)
90
+ stop_words = {'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'is', 'are'}
91
+ orig_words = set(w.lower() for w in original.split() if w.lower() not in stop_words)
92
+ var_words = set(w.lower() for w in variation.split() if w.lower() not in stop_words)
93
+
94
+ if not orig_words.intersection(var_words):
95
+ if self.config.debug:
96
+ print("Failed content check: no content words in common")
97
+ return False
98
+
99
+ if self.config.debug:
100
+ print("Passed all quick checks")
101
+ return True
102
+
103
+ def _compute_metrics_parallel(self, original: str, candidates: List[str]) -> List[Dict[str, float]]:
104
+ """Compute quality metrics for multiple candidates in parallel"""
105
+ with ThreadPoolExecutor(max_workers=4) as executor:
106
+ futures = [
107
+ executor.submit(self.quality_metrics.compute_metrics, original, candidate)
108
+ for candidate in candidates
109
+ ]
110
+ return [future.result() for future in futures]
111
+
112
+ def _filter_variations_batch(self, variations: List[str], context: List[str], original_turn: str) -> List[str]:
113
+ """
114
+ Filter variations using batched computations with detailed logging
115
+ """
116
+ if not variations:
117
+ return []
118
+
119
+ if self.config.debug:
120
+ print(f"\nStarting filtration of {len(variations)} variations")
121
+ print(f"Context length: {len(context)}")
122
+ print(f"Original turn: {original_turn}")
123
+
124
+ words = original_turn.split()
125
+ if len(words) < 3:
126
+ if self.config.debug:
127
+ print("Short text detected, using predefined variations")
128
+ short_text_variations = self._augment_short_text({'text': original_turn, 'speaker': ''})
129
+ return [var['text'] for var in short_text_variations]
130
+
131
+ # If this is the first turn (no context), be more lenient
132
+ if not context:
133
+ preliminary_filtered = variations
134
+ if self.config.debug:
135
+ print("First turn - skipping preliminary filtering")
136
+ else:
137
+ # Quick preliminary filtering against original turn
138
+ preliminary_filtered = []
139
+ for var in variations:
140
+ passed = self._quick_quality_check(var, original_turn)
141
+ if self.config.debug:
142
+ print(f"\nVariation: {var}")
143
+ print(f"Passed quick check: {passed}")
144
+ if passed:
145
+ preliminary_filtered.append(var)
146
+
147
+ if self.config.debug:
148
+ print(f"Variations after quick check: {len(preliminary_filtered)}")
149
+
150
+ if not preliminary_filtered:
151
+ return []
152
+
153
+ # Only use last turn for coherence
154
+ recent_context = [context[-1]] if context else []
155
+ context_text = ' '.join(recent_context) if recent_context else ''
156
+
157
+ # Even more lenient thresholds
158
+ min_similarity = 0.1 # Further reduced
159
+ min_coherence = 0.05 # Further reduced
160
+
161
+ if context_text:
162
+ if self.config.debug:
163
+ print(f"\nContext text: {context_text}")
164
+
165
+ all_texts = [context_text] + preliminary_filtered
166
+ all_embeddings = self._compute_batch_embeddings(all_texts)
167
+
168
+ context_embedding = all_embeddings[0]
169
+ variation_embeddings = all_embeddings[1:]
170
+
171
+ # Vectorized similarity computation
172
+ context_similarities = cosine_similarity([context_embedding], variation_embeddings)[0]
173
+
174
+ # Response coherence check
175
+ if recent_context:
176
+ prev_embedding = self._compute_embedding(recent_context[-1])
177
+ response_coherence = cosine_similarity([prev_embedding], variation_embeddings)[0]
178
+ else:
179
+ response_coherence = np.ones_like(context_similarities)
180
+
181
+ # Combined scoring with detailed logging
182
+ filtered_variations = []
183
+ for i, (variation, sim, coh) in enumerate(zip(
184
+ preliminary_filtered, context_similarities, response_coherence)):
185
+ # Use absolute values for scoring
186
+ combined_score = (
187
+ self.config.context_similarity_weight * abs(sim) +
188
+ self.config.response_coherence_weight * abs(coh)
189
+ )
190
+
191
+ if self.config.debug:
192
+ print(f"\nVariation: {variation}")
193
+ print(f"Context similarity: {sim:.3f}")
194
+ print(f"Response coherence: {coh:.3f}")
195
+ print(f"Combined score: {combined_score:.3f}")
196
+
197
+ # Accept if EITHER score is good enough
198
+ if (combined_score >= min_similarity or abs(coh) >= min_coherence):
199
+ filtered_variations.append(variation)
200
+ if self.config.debug:
201
+ print("ACCEPTED")
202
+ else:
203
+ if self.config.debug:
204
+ print("REJECTED")
205
+
206
+ # If we have enough variations, stop
207
+ if len(filtered_variations) >= self.config.max_variations_per_turn:
208
+ break
209
+ else:
210
+ filtered_variations = preliminary_filtered[:self.config.max_variations_per_turn]
211
+
212
+ if self.config.debug:
213
+ print(f"\nFinal filtered variations: {len(filtered_variations)}")
214
+
215
+ return filtered_variations
216
+
217
+ def _generate_variations_progressive(self, text: str, needed: int) -> List[str]:
218
+ """
219
+ Generate variations progressively until we have enough good ones
220
+ """
221
+ variations = set()
222
+
223
+ if self.config.debug:
224
+ print(f"\nAttempting to generate {needed} variations for text: {text}")
225
+
226
+ # Try advanced augmenters first
227
+ for augmenter in self.augmenters['advanced']:
228
+ if len(variations) >= needed:
229
+ break
230
+
231
+ try:
232
+ if isinstance(augmenter, Paraphraser):
233
+ if self.config.debug:
234
+ print("Trying paraphrase augmentation...")
235
+ new_vars = augmenter.paraphrase(text, num_return_sequences=needed-len(variations))
236
+ if self.config.debug:
237
+ print(f"Paraphraser generated {len(new_vars)} variations")
238
+ else:
239
+ if self.config.debug:
240
+ print("Trying back translation...")
241
+ new_vars = [augmenter.back_translate(text)]
242
+ if self.config.debug:
243
+ print(f"Back translator generated {len(new_vars)} variations")
244
+
245
+ valid_vars = [v for v in new_vars if v.strip() and v != text]
246
+ variations.update(valid_vars)
247
+
248
+ if self.config.debug:
249
+ print(f"Current unique variations: {len(variations)}")
250
+
251
+ except Exception as e:
252
+ print(f"Error in advanced augmentation: {str(e)}")
253
+ continue
254
+
255
+ # Try basic augmenters if needed
256
+ if len(variations) < needed:
257
+ if self.config.debug:
258
+ print("Not enough variations, trying basic augmenters...")
259
+
260
+ for aug_type, augmenter in self.augmenters['basic']:
261
+ if len(variations) >= needed:
262
+ break
263
+
264
+ try:
265
+ if aug_type == 'spelling' and self._is_technical_or_formal_text(text):
266
+ if self.config.debug:
267
+ print("Skipping spelling augmentation for technical text")
268
+ continue
269
+
270
+ if self.config.debug:
271
+ print(f"Trying {aug_type} augmentation...")
272
+
273
+ new_vars = augmenter.augment(text, n=2)
274
+ if isinstance(new_vars, list):
275
+ valid_vars = [v for v in new_vars if v.strip() and v != text]
276
+ variations.update(valid_vars)
277
+ else:
278
+ if new_vars.strip() and new_vars != text:
279
+ variations.add(new_vars)
280
+
281
+ if self.config.debug:
282
+ print(f"After {aug_type}, total variations: {len(variations)}")
283
+
284
+ except Exception as e:
285
+ print(f"Error in {aug_type} augmentation: {str(e)}")
286
+ continue
287
+
288
+ variations_list = list(variations)
289
+
290
+ if self.config.debug:
291
+ print(f"Final number of variations generated: {len(variations_list)}")
292
+ if not variations_list:
293
+ print("WARNING: No variations were generated!")
294
+
295
+ return variations_list
296
+
297
+ def augment_dialogue(self, dialogue: Dict) -> List[Dict]:
298
+ """
299
+ Create augmented versions of the dialogue with optimized processing
300
+ """
301
+ # Early dialogue length check
302
+ original_length = len(dialogue['turns'])
303
+ if original_length > self.config.max_turns_per_dialogue:
304
+ if self.config.debug:
305
+ print(f"Truncating dialogue from {original_length} to {self.config.max_turns_per_dialogue} turns")
306
+ dialogue['turns'] = dialogue['turns'][:self.config.max_turns_per_dialogue]
307
+
308
+ turn_variations = []
309
+ context = []
310
+
311
+ # Process each turn with progressive generation
312
+ for turn in dialogue['turns']:
313
+ original_text = turn['text'] # Store original turn text
314
+ variations = self._generate_variations_progressive(
315
+ original_text,
316
+ self.config.max_variations_per_turn
317
+ )
318
+
319
+ # Batch filter variations with original text
320
+ filtered_variations = self._filter_variations_batch(
321
+ variations,
322
+ context,
323
+ original_text # Pass the original turn text
324
+ )
325
+
326
+ # Create turn variations with speaker info
327
+ turn_vars = [{'speaker': turn['speaker'], 'text': v} for v in filtered_variations]
328
+
329
+ if self.config.debug:
330
+ print(f"Turn {len(turn_variations)}: Generated {len(turn_vars)} variations")
331
+
332
+ turn_variations.append(turn_vars)
333
+ context.append(original_text)
334
+
335
+ # Generate combinations with sampling
336
+ augmented_dialogues = self._generate_dialogue_combinations(
337
+ dialogue['dialogue_id'],
338
+ turn_variations
339
+ )
340
+
341
+ # Add original dialogue
342
+ result = [{
343
+ 'dialogue_id': f"{dialogue['dialogue_id']}_original",
344
+ 'turns': dialogue['turns']
345
+ }]
346
+
347
+ # Add unique augmentations
348
+ result.extend(augmented_dialogues[:self.config.augmentation_factor])
349
+
350
+ if self.config.debug:
351
+ print(f"Generated {len(result)-1} unique augmented dialogues")
352
+
353
+ return result
354
+
355
+ def _generate_dialogue_combinations(self, dialogue_id: str, turn_variations: List[List[Dict]]) -> List[Dict]:
356
+ """
357
+ Generate dialogue combinations using sampling
358
+ """
359
+ augmented_dialogues = []
360
+ used_combinations = set()
361
+
362
+ def generate_dialogues(current_turns=None, turn_index=0):
363
+ if current_turns is None:
364
+ current_turns = []
365
+
366
+ if len(augmented_dialogues) >= self.config.augmentation_factor:
367
+ return
368
+
369
+ if turn_index == len(turn_variations):
370
+ dialogue_fingerprint = " | ".join(turn['text'] for turn in current_turns)
371
+ if dialogue_fingerprint not in used_combinations:
372
+ used_combinations.add(dialogue_fingerprint)
373
+ augmented_dialogues.append({
374
+ 'dialogue_id': f"{dialogue_id}_aug_{len(augmented_dialogues)}",
375
+ 'turns': current_turns.copy()
376
+ })
377
+ return
378
+
379
+ variations = list(turn_variations[turn_index])
380
+ np.random.shuffle(variations)
381
+
382
+ for variation in variations[:self.config.max_sampled_variations]:
383
+ if len(augmented_dialogues) >= self.config.augmentation_factor:
384
+ return
385
+ current_turns.append(variation)
386
+ generate_dialogues(current_turns, turn_index + 1)
387
+ current_turns.pop()
388
+
389
+ try:
390
+ generate_dialogues()
391
+ except Exception as e:
392
+ print(f"Error in dialogue generation: {str(e)}")
393
+ return []
394
+
395
+ return augmented_dialogues
396
+
397
+ def _is_dialogue_duplicate(self, dialogue1: Dict, dialogue2: Dict) -> bool:
398
+ """
399
+ Check if two dialogues are duplicates.
400
+ """
401
+ text1 = " ".join(turn['text'] for turn in dialogue1['turns'])
402
+ text2 = " ".join(turn['text'] for turn in dialogue2['turns'])
403
+ return text1 == text2
404
+
405
+ # def _augment_turn(self, turn: Dict, context: List[str]) -> List[Dict]:
406
+ # """
407
+ # Generate augmented versions of the turn using multiple strategies.
408
+ # """
409
+ # text = turn['text']
410
+ # words = text.split()
411
+
412
+ # # Special handling for very short texts
413
+ # if len(words) < 3:
414
+ # return self._augment_short_text(turn)
415
+
416
+ # all_variations = set()
417
+
418
+ # # Advanced augmentations (paraphrase and back-translation)
419
+ # for augmenter in self.augmenters['advanced']:
420
+ # try:
421
+ # if isinstance(augmenter, Paraphraser):
422
+ # variations = augmenter.paraphrase(text)
423
+ # all_variations.update(variations)
424
+ # elif isinstance(augmenter, BackTranslator):
425
+ # aug_text = augmenter.back_translate(text)
426
+ # if aug_text:
427
+ # all_variations.add(aug_text)
428
+ # except Exception as e:
429
+ # print(f"Error in advanced augmentation: {str(e)}")
430
+ # continue
431
+
432
+ # # Basic nlpaug augmentations
433
+ # for aug_type, augmenter in self.augmenters['basic']:
434
+ # try:
435
+ # if aug_type == 'spelling' and self._is_technical_or_formal_text(text):
436
+ # continue
437
+
438
+ # aug_texts = augmenter.augment(text, n=2)
439
+ # if isinstance(aug_texts, list):
440
+ # all_variations.update(aug_texts)
441
+ # else:
442
+ # all_variations.add(aug_texts)
443
+ # except Exception as e:
444
+ # print(f"Error in {aug_type} augmentation: {str(e)}")
445
+ # continue
446
+
447
+ # # Remove exact duplicates and empty strings
448
+ # augmented_texts = [t for t in list(all_variations) if t.strip()]
449
+
450
+ # # Apply context filtering
451
+ # if context:
452
+ # augmented_texts = self._filter_by_context(augmented_texts, context)
453
+ # print(f"After context filtering: {len(augmented_texts)} variations")
454
+
455
+ # # Select best variations
456
+ # best_variations = self._select_best_augmentations(
457
+ # text,
458
+ # augmented_texts,
459
+ # num_to_select=self.config.augmentation_factor,
460
+ # min_quality_score=0.7
461
+ # )
462
+
463
+ # # Create variations with speaker info
464
+ # variations = [{'speaker': turn['speaker'], 'text': text} for text in best_variations]
465
+
466
+ # return variations
467
+
468
+ def _augment_short_text(self, turn: Dict) -> List[Dict]:
469
+ """
470
+ Special handling for very short texts with predefined variations.
471
+ Args:
472
+ turn (Dict): Original dialogue turn
473
+
474
+ Returns:
475
+ List[Dict]: List of variations for the short text
476
+ """
477
+ text = turn['text']
478
+ common_variations = {
479
+ 'goodbye': [
480
+ 'Bye!', 'Farewell!', 'See you!', 'Take care!',
481
+ 'Goodbye!', 'Bye for now!', 'Until next time!'
482
+ ],
483
+ 'hello': [
484
+ 'Hi!', 'Hey!', 'Hello!', 'Greetings!',
485
+ 'Good day!', 'Hi there!', 'Hello there!'
486
+ ],
487
+ 'yes': [
488
+ 'Yes!', 'Correct!', 'Indeed!', 'Absolutely!',
489
+ 'That\'s right!', 'Definitely!', 'Sure!'
490
+ ],
491
+ 'no': [
492
+ 'No!', 'Nope!', 'Not at all!', 'Negative!',
493
+ 'Unfortunately not!', 'I\'m afraid not!'
494
+ ],
495
+ 'thanks': [
496
+ 'Thank you!', 'Thanks a lot!', 'Many thanks!',
497
+ 'I appreciate it!', 'Thank you so much!'
498
+ ],
499
+ 'ok': [
500
+ 'Okay!', 'Alright!', 'Sure!', 'Got it!',
501
+ 'Understood!', 'Fine!', 'Great!', 'Perfect!',
502
+ 'That works!', 'Sounds good!'
503
+ ],
504
+ 'good': [
505
+ 'Great!', 'Excellent!', 'Perfect!', 'Wonderful!',
506
+ 'Fantastic!', 'Amazing!', 'Terrific!'
507
+ ]
508
+ }
509
+
510
+ # Try to find matching variations
511
+ text_lower = text.lower().rstrip('!.,?')
512
+ variations = []
513
+
514
+ # Check if text matches any of our predefined categories
515
+ for key, predefined_vars in common_variations.items():
516
+ if key in text_lower or text_lower in key:
517
+ variations.extend(predefined_vars)
518
+
519
+ # If no predefined variations found, generate simple variants
520
+ if not variations:
521
+ # Add punctuation variations
522
+ variations = [
523
+ text.rstrip('!.,?') + '!',
524
+ text.rstrip('!.,?') + '.',
525
+ text.rstrip('!.,?')
526
+ ]
527
+
528
+ # Add capitalization variations
529
+ variations.extend([
530
+ v.capitalize() for v in variations
531
+ if v.capitalize() not in variations
532
+ ])
533
+
534
+ # Filter variations for uniqueness and quality
535
+ unique_variations = list(set(variations))
536
+ quality_variations = []
537
+
538
+ for var in unique_variations:
539
+ metrics = self.quality_metrics.compute_metrics(text, var)
540
+ quality_score = (
541
+ 0.35 * metrics['semantic_similarity'] +
542
+ 0.30 * (1.0 - metrics['perplexity'] / 100) +
543
+ 0.15 * (1.0 - metrics['grammar_errors'] / 10) +
544
+ 0.15 * metrics['content_preservation'] +
545
+ 0.10 * metrics['type_token_ratio']
546
+ )
547
+
548
+ # More lenient quality threshold for short texts
549
+ if quality_score >= 0.5: # Lower threshold for short texts
550
+ quality_variations.append(var)
551
+
552
+ # Ensure we have at least some variations
553
+ if not quality_variations:
554
+ quality_variations = [text]
555
+
556
+ # Return the variations with original speaker
557
+ return [{'speaker': turn['speaker'], 'text': v} for v in quality_variations[:self.config.augmentation_factor]]
558
+
559
+ def _is_technical_or_formal_text(self, text: str) -> bool:
560
+ """
561
+ Check if text is formal/technical and shouldn't have spelling variations.
562
+ """
563
+ formal_indicators = {
564
+ 'technical_terms': {'api', 'config', 'database', 'server', 'system'},
565
+ 'formal_phrases': {'please advise', 'regarding', 'furthermore', 'moreover'},
566
+ 'professional_context': {'meeting', 'conference', 'project', 'deadline'}
567
+ }
568
+
569
+ text_lower = text.lower()
570
+ words = set(text_lower.split())
571
+
572
+ for category in formal_indicators.values():
573
+ if words.intersection(category):
574
+ return True
575
+
576
+ return False
577
+
578
+ # def _filter_by_context(self, variations: List[str], context: List[str]) -> List[str]:
579
+ # """
580
+ # Filter variations based on conversation context using config parameters.
581
+ # """
582
+ # # Manage context window using config
583
+ # recent_context = context[-self.config.context_window_size:] if len(context) > self.config.context_window_size else context
584
+
585
+ # filtered_variations = []
586
+ # context_embedding = self.use_model([' '.join(recent_context)])[0].numpy()
587
+
588
+ # prev_turn = recent_context[-1] if recent_context else ''
589
+
590
+ # for variation in variations:
591
+ # var_embedding = self.use_model([variation])[0].numpy()
592
+
593
+ # # Overall context similarity
594
+ # context_similarity = cosine_similarity([context_embedding], [var_embedding])[0][0]
595
+
596
+ # # Direct response coherence
597
+ # response_coherence = 1.0
598
+ # if prev_turn:
599
+ # prev_embedding = self.use_model([prev_turn])[0].numpy()
600
+ # response_coherence = cosine_similarity([prev_embedding], [var_embedding])[0][0]
601
+
602
+ # # Use weights from config
603
+ # combined_similarity = (
604
+ # self.config.context_similarity_weight * context_similarity +
605
+ # self.config.response_coherence_weight * response_coherence
606
+ # )
607
+
608
+ # if (combined_similarity >= self.config.semantic_similarity_threshold and
609
+ # response_coherence >= self.config.min_response_coherence):
610
+ # filtered_variations.append(variation)
611
+ # if self.config.debug:
612
+ # print(f"Accepted variation: {variation}")
613
+ # print(f"Context similarity: {context_similarity:.3f}")
614
+ # print(f"Response coherence: {response_coherence:.3f}")
615
+ # print(f"Combined score: {combined_similarity:.3f}\n")
616
+ # else:
617
+ # if self.config.debug:
618
+ # print(f"Rejected variation: {variation}")
619
+ # print(f"Combined score {combined_similarity:.3f} below threshold "
620
+ # f"{self.config.semantic_similarity_threshold}")
621
+ # print(f"Response coherence {response_coherence:.3f} below threshold "
622
+ # f"{self.config.min_response_coherence}\n")
623
+
624
+ # return filtered_variations or variations # Fallback to original
625
+
626
+ # def _select_best_augmentations(self, original: str, candidates: List[str], used_variations: set = None,
627
+ # num_to_select: int = 3, min_quality_score: float = 0.7) -> List[str]:
628
+ # """
629
+ # Select the best augmentations using a quality score.
630
+ # Args:
631
+ # original (str): The original text
632
+ # candidates (List[str]): List of candidate augmented texts
633
+ # used_variations (set): Set of already used variations
634
+ # num_to_select (int): Number of variations to select
635
+ # min_quality_score (float): Minimum quality score threshold
636
+ # """
637
+ # if used_variations is None:
638
+ # used_variations = set()
639
+
640
+ # candidates = [c for c in candidates if c.strip()]
641
+
642
+ # # Skip short text
643
+ # if len(original.split()) < 3:
644
+ # print(f"Text too short for augmentation: {original}")
645
+ # return [original]
646
+
647
+ # scored_candidates = []
648
+ # for candidate in candidates:
649
+ # if candidate in used_variations:
650
+ # continue
651
+
652
+ # metrics = self.quality_metrics.compute_metrics(original, candidate)
653
+
654
+ # # Add contextual penalty for inappropriate audience terms
655
+ # audience_terms = {'everyone', 'everybody', 'folks', 'all', 'guys', 'people'}
656
+ # has_audience_term = any(term in candidate.lower() for term in audience_terms)
657
+ # audience_penalty = 0.2 if has_audience_term else 0.0
658
+
659
+ # # Weighted quality score
660
+ # quality_score = (
661
+ # 0.40 * metrics['semantic_similarity'] + # Semantic preservation
662
+ # 0.25 * (1.0 - metrics['perplexity'] / 100) + # Fluency
663
+ # 0.15 * (1.0 - metrics['grammar_errors'] / 10) + # Grammar
664
+ # 0.15 * metrics['content_preservation'] + # Content preservation
665
+ # 0.05 * metrics['type_token_ratio'] # Lexical diversity
666
+ # )
667
+
668
+ # quality_score -= audience_penalty
669
+
670
+ # if (metrics['semantic_similarity'] < 0.5 or # Reject on semantic threshold miss
671
+ # metrics['rouge1_f1'] < 0.2): # Enforce minimum lexical overlap
672
+ # continue
673
+
674
+ # # Bonus points for:
675
+ # # Length similarity to original
676
+ # if 0.75 <= metrics['length_ratio'] <= 1.25:
677
+ # quality_score += 0.05
678
+
679
+ # # Correct grammar
680
+ # if metrics['grammar_errors'] == 0:
681
+ # quality_score += 0.025
682
+
683
+ # print(f"Candidate: {candidate}")
684
+ # print(f"Quality score: {quality_score:.2f}, Metrics: {metrics}")
685
+
686
+ # # Consider the augmentationif meets basic quality threshold
687
+ # if quality_score >= min_quality_score:
688
+ # print('Candidate accepted\n')
689
+ # scored_candidates.append((candidate, quality_score, metrics))
690
+ # else:
691
+ # print('Candidate rejected\n')
692
+
693
+ # # Sort by quality score with small random factor for diversity
694
+ # scored_candidates.sort(key=lambda x: x[1], reverse=True)
695
+
696
+ # selected = []
697
+ # for candidate, score, metrics in scored_candidates:
698
+ # # Check diversity against already selected
699
+ # if len(selected) == 0:
700
+ # selected.append(candidate)
701
+ # continue
702
+
703
+ # # Compute average similarity to already selected
704
+ # avg_similarity = np.mean([
705
+ # self.quality_metrics.compute_semantic_similarity(candidate, prev)
706
+ # for prev in selected
707
+ # ])
708
+
709
+ # # Add if sufficiently different (similarity < 0.98)
710
+ # if avg_similarity < 0.98:
711
+ # selected.append(candidate)
712
+
713
+ # if len(selected) >= num_to_select:
714
+ # break
715
+
716
+ # return selected
main.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CSC525 - Module 8 Option 2 - Joseph Armani
3
+ Description and References in the README.md file.
4
+ """
5
+ import json
6
+ import tensorflow as tf
7
+ from typing import List, Dict
8
+ from pipeline_config import PipelineConfig
9
+ from processing_pipeline import ProcessingPipeline
10
+ from taskmaster_processor import TaskmasterProcessor
11
+ from schema_guided_dialogue_processor import SchemaGuidedProcessor
12
+
13
+ def combine_datasets(taskmaster_dialogues: List[Dict],
14
+ schema_guided_dialogues: List[Dict]) -> List[Dict]:
15
+ """
16
+ Combine dialogues from both datasets into a single list
17
+
18
+ Args:
19
+ taskmaster_dialogues: List of dialogues in pipeline format from Taskmaster
20
+ schema_guided_dialogues: List of dialogues in pipeline format from Schema-Guided
21
+
22
+ Returns:
23
+ List[Dict]: Combined list of dialogues
24
+ """
25
+ # Ensure unique dialogue IDs
26
+ combined_dialogues = []
27
+ seen_ids = set()
28
+ duplicate_count = 0 # Track duplicates for reporting
29
+
30
+ for dialogue in taskmaster_dialogues:
31
+ dialogue_copy = dialogue.copy()
32
+ dialogue_id = dialogue_copy['dialogue_id']
33
+ if dialogue_id in seen_ids:
34
+ duplicate_count += 1
35
+ dialogue_id = f"taskmaster_{dialogue_id}"
36
+ seen_ids.add(dialogue_id)
37
+ dialogue_copy['dialogue_id'] = dialogue_id
38
+ combined_dialogues.append(dialogue_copy)
39
+
40
+ for dialogue in schema_guided_dialogues:
41
+ dialogue_copy = dialogue.copy()
42
+ dialogue_id = dialogue_copy['dialogue_id']
43
+ if dialogue_id in seen_ids:
44
+ duplicate_count += 1
45
+ dialogue_id = f"schema_guided_{dialogue_id}"
46
+ seen_ids.add(dialogue_id)
47
+ dialogue_copy['dialogue_id'] = dialogue_id
48
+ combined_dialogues.append(dialogue_copy)
49
+
50
+ # Log the results
51
+ print(f"Combine Datasets: Found and resolved {duplicate_count} duplicate dialogue IDs.")
52
+ print(f"Combine Datasets: Total dialogues combined: {len(combined_dialogues)}")
53
+
54
+ return combined_dialogues
55
+
56
+ def main():
57
+ # Configuration
58
+ config = PipelineConfig(
59
+ min_length=1,
60
+ max_length=512,
61
+ batch_size=32 if tf.config.list_physical_devices('GPU') else 16,
62
+ max_turns_per_dialogue=6,
63
+ max_variations_per_turn=3,
64
+ max_sampled_variations=2,
65
+ context_window_size=4,
66
+ max_complexity_threshold=100,
67
+ use_cache=False,
68
+ debug=True,
69
+ allowed_speakers=['user', 'assistant'],
70
+ required_fields=['dialogue_id', 'turns']
71
+ )
72
+
73
+ try:
74
+ # Set max_examples (Optional[int]) for testing
75
+ max_examples = None
76
+
77
+ # Initialize and load Taskmaster dataset
78
+ print("Loading Taskmaster dataset")
79
+ taskmaster_processor = TaskmasterProcessor(config, use_ontology=False)
80
+ taskmaster_dir = './datasets/taskmaster'
81
+ taskmaster_dialogues = taskmaster_processor.load_dataset(taskmaster_dir, max_examples=max_examples)
82
+ taskmaster_pipeline_dialogues = taskmaster_processor.convert_to_pipeline_format(taskmaster_dialogues)
83
+ print(f"Processed Taskmaster dialogues: {len(taskmaster_pipeline_dialogues)}")
84
+
85
+ # Initialize and load Schema-Guided dataset
86
+ print("Loading Schema-Guided dataset")
87
+ schema_dialogue_processor = SchemaGuidedProcessor(config)
88
+ schema_guided_dir = './datasets/schema_guided'
89
+ schema_dialogues = schema_dialogue_processor.load_dataset(schema_guided_dir, max_examples=max_examples)
90
+ schema_pipeline_dialogues = schema_dialogue_processor.convert_to_pipeline_format(schema_dialogues)
91
+ print(f"Processed Schema-Guided dialogues: {len(schema_pipeline_dialogues)}")
92
+
93
+ # Combine datasets
94
+ print("Combining datasets")
95
+ combined_dialogues = combine_datasets(taskmaster_pipeline_dialogues, schema_pipeline_dialogues)
96
+ print(f"Combined Dialogues: {len(combined_dialogues)}")
97
+
98
+ if not combined_dialogues:
99
+ print("Combined dialogues are empty. Exiting.")
100
+ return
101
+
102
+ # Process through augmentation pipeline
103
+ print("Processing combined dataset")
104
+ pipeline = ProcessingPipeline(config)
105
+ processed_dialogues = pipeline.process_dataset(combined_dialogues)
106
+
107
+ # Save results
108
+ output_path = 'augmented_combined_dataset.json'
109
+ with open(output_path, 'w', encoding='utf-8') as f:
110
+ json.dump(processed_dialogues, f, indent=2, ensure_ascii=False)
111
+
112
+ # Print statistics
113
+ print(f"\nProcessed Statistics:")
114
+ print(f"Total dialogues: {len(processed_dialogues)}")
115
+ print(f"Taskmaster domains: {len(taskmaster_processor.domains)}")
116
+ print(f"Schema-Guided services: {len(schema_dialogue_processor.services)}")
117
+ print(f"Schema-Guided domains: {len(schema_dialogue_processor.domains)}")
118
+
119
+ except Exception as e:
120
+ print(f"Processing failed: {str(e)}")
121
+ raise
122
+
123
+ if __name__ == "__main__":
124
+ main()
paraphraser.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import (
2
+ AutoTokenizer,
3
+ AutoModelForSeq2SeqLM,
4
+ )
5
+
6
+ class Paraphraser:
7
+ def __init__(self, model_name='humarin/chatgpt_paraphraser_on_T5_base'):
8
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
10
+ self.model.eval()
11
+
12
+ def paraphrase(self, text, num_return_sequences=5, num_beams=10, num_beam_groups=5, diversity_penalty=0.8):
13
+ try:
14
+ input_text = "paraphrase: " + text + " </s>"
15
+ encoding = self.tokenizer.encode_plus(input_text, return_tensors="pt")
16
+ input_ids = encoding["input_ids"]
17
+
18
+ outputs = self.model.generate(
19
+ input_ids=input_ids,
20
+ max_length=256,
21
+ num_beams=num_beams,
22
+ num_beam_groups=num_beam_groups,
23
+ num_return_sequences=num_return_sequences,
24
+ diversity_penalty=diversity_penalty,
25
+ early_stopping=True
26
+ )
27
+ paraphrases = [self.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
28
+ return paraphrases
29
+ except Exception as e:
30
+ print(f"Error in paraphrasing: {e}")
31
+ return []
pipeline_config.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List
3
+
4
+ @dataclass
5
+ class PipelineConfig:
6
+ """
7
+ Config for the pipeline
8
+ """
9
+ # Validation settings
10
+ min_length: int = 1
11
+ max_length: int = 512
12
+ min_tokens: int = 1
13
+ max_tokens: int = 128
14
+
15
+ allowed_speakers: List[str] = None
16
+ required_fields: List[str] = None
17
+
18
+ # Text augmentation settings
19
+ augmentation_factor: int = 4
20
+ augmentation_techniques: List[str] = None
21
+
22
+ max_turns_per_dialogue: int = 6
23
+ max_variations_per_turn: int = 3
24
+ max_sampled_variations: int = 2
25
+ max_complexity_threshold: int = 100
26
+ complexity_reduction_turns: int = 4
27
+
28
+ # Quality thresholds
29
+ semantic_similarity_threshold: float = 0.45
30
+ grammar_error_threshold: int = 2
31
+ rouge1_f1_threshold: float = 0.30
32
+ rouge2_f1_threshold: float = 0.15
33
+ perplexity_threshold: float = 50.0
34
+
35
+ # Response coherence thresholds
36
+ min_response_coherence: float = 0.3
37
+ context_similarity_weight: float = 0.35
38
+ response_coherence_weight: float = 0.65
39
+
40
+ # Performance settings
41
+ batch_size: int = 32
42
+ use_cache: bool = True
43
+ debug: bool = False
44
+
45
+ context_window_size: int = 4
46
+
47
+ def __post_init__(self):
48
+ if self.allowed_speakers is None:
49
+ self.allowed_speakers = ['user', 'assistant']
50
+ if self.required_fields is None:
51
+ self.required_fields = ['dialogue_id', 'turns']
52
+ if self.augmentation_techniques is None:
53
+ self.augmentation_techniques = ['paraphrase', 'back_translation']
54
+
55
+ # Validate weights sum to 1.0
56
+ if abs((self.context_similarity_weight + self.response_coherence_weight) - 1.0) > 1e-6:
57
+ raise ValueError("Context similarity and response coherence weights must sum to 1.0")
58
+
processing_pipeline.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from pathlib import Path
3
+ from typing import List, Dict, Optional
4
+ import json
5
+ import re
6
+ import hashlib
7
+ import pickle
8
+ import spacy
9
+ from tqdm import tqdm
10
+ from pipeline_config import PipelineConfig
11
+ from dialogue_augmenter import DialogueAugmenter
12
+ from sklearn.feature_extraction.text import TfidfVectorizer
13
+ from sklearn.metrics.pairwise import cosine_similarity
14
+
15
+ class ProcessingPipeline:
16
+ """
17
+ Complete pipeline combining validation, optimization, and augmentation.
18
+ """
19
+
20
+ def __init__(self, config: Optional[PipelineConfig] = None):
21
+ self.config = config or PipelineConfig()
22
+ self.nlp = spacy.load("en_core_web_sm", disable=['parser', 'ner'])
23
+ self.augmenter = DialogueAugmenter(self.nlp, self.config)
24
+ self.num_threads = self.config.batch_size
25
+ self.cache_dir = Path("./cache")
26
+ self.cache_dir.mkdir(exist_ok=True)
27
+
28
+ def process_dataset(self, dialogues: List[Dict]) -> List[Dict]:
29
+ """
30
+ Process entire dataset through the pipeline.
31
+ """
32
+ print(f"Processing {len(dialogues)} dialogues")
33
+ start_time = datetime.now()
34
+
35
+ # Check cache
36
+ if self.config.use_cache:
37
+ cache_path = self._get_cache_path(dialogues)
38
+ if cache_path.exists():
39
+ print("Loading from cache...")
40
+ with open(cache_path, 'rb') as f:
41
+ return pickle.load(f)
42
+
43
+ # Validate and clean
44
+ valid_dialogues = self._process_validation(
45
+ dialogues,
46
+ self._validate_and_clean_dialogue,
47
+ "validating and cleaning"
48
+ )
49
+
50
+ if not valid_dialogues:
51
+ raise ValueError("Dialogue validation resulted in an empty dataset.")
52
+
53
+ deduplicated_dialogues = self._deduplicate_dialogues(valid_dialogues)
54
+
55
+ # Augment dialogues
56
+ all_processed_dialogues = []
57
+ for dialogue in deduplicated_dialogues:
58
+ augmented = self.augmenter.augment_dialogue(dialogue)
59
+ all_processed_dialogues.extend(augmented)
60
+
61
+ # Save to cache
62
+ if self.config.use_cache:
63
+ with open(cache_path, 'wb') as f:
64
+ pickle.dump(all_processed_dialogues, f)
65
+
66
+ processing_time = datetime.now() - start_time
67
+ print(f"Processing completed in {processing_time}")
68
+ print(f"Generated {len(all_processed_dialogues)} total dialogues")
69
+
70
+ return all_processed_dialogues
71
+
72
+ def _deduplicate_dialogues(self, dialogues: List[Dict], threshold: float = 0.9) -> List[Dict]:
73
+ """
74
+ Deduplicate dialogues based on text similarity.
75
+ """
76
+ print("Deduplicating dialogues...")
77
+ if not dialogues:
78
+ print("No dialogues provided for deduplication.")
79
+ return []
80
+
81
+ # Combine turns into single text for similarity comparison
82
+ texts = [" ".join(turn['text'] for turn in dialogue['turns']) for dialogue in dialogues]
83
+ tfidf = TfidfVectorizer().fit_transform(texts)
84
+ sim_matrix = cosine_similarity(tfidf)
85
+
86
+ unique_indices = set()
87
+ for i, row in enumerate(sim_matrix):
88
+ if i not in unique_indices:
89
+ similar_indices = [j for j, sim in enumerate(row) if sim > threshold and j != i]
90
+ unique_indices.add(i)
91
+ unique_indices.difference_update(similar_indices)
92
+
93
+ deduplicated_dialogues = [dialogues[i] for i in unique_indices]
94
+
95
+ print(f"Deduplication complete. Reduced from {len(dialogues)} to {len(deduplicated_dialogues)} dialogues.")
96
+ return deduplicated_dialogues
97
+
98
+ def _validate_and_clean_dialogue(self, dialogue: Dict) -> Optional[Dict]:
99
+ """
100
+ Validate and clean a single dialogue.
101
+ """
102
+ try:
103
+ # Check required fields
104
+ if not all(field in dialogue for field in self.config.required_fields):
105
+ return None
106
+
107
+ # Process turns
108
+ cleaned_turns = []
109
+ for turn in dialogue['turns']:
110
+ if self._validate_turn(turn):
111
+ cleaned_turn = {
112
+ 'speaker': turn['speaker'],
113
+ 'text': self._clean_text(turn['text'])
114
+ }
115
+ cleaned_turns.append(cleaned_turn)
116
+
117
+ if cleaned_turns:
118
+ return {
119
+ 'dialogue_id': dialogue['dialogue_id'],
120
+ 'turns': cleaned_turns
121
+ }
122
+
123
+ return None
124
+
125
+ except Exception as e:
126
+ print(f"Error processing dialogue {dialogue.get('dialogue_id', 'unknown')}: {str(e)}")
127
+ return None
128
+
129
+ def _validate_turn(self, turn: Dict) -> bool:
130
+ """
131
+ Validate a single speaking turn.
132
+ """
133
+ return (
134
+ turn['speaker'] in self.config.allowed_speakers and
135
+ self.config.min_length <= len(turn['text']) <= self.config.max_length
136
+ )
137
+
138
+ def _clean_text(self, text: str) -> str:
139
+ """
140
+ Clean and normalize text.
141
+ """
142
+ # Remove excessive whitespace
143
+ text = re.sub(r'\s+', ' ', text.strip())
144
+
145
+ # Normalize quotes and apostrophes
146
+ text = re.sub(r'[’´`]', "'", text)
147
+ text = re.sub(r'[“”]', '"', text)
148
+
149
+ # Remove control characters
150
+ text = "".join(char for char in text if ord(char) >= 32 or char == '\n')
151
+
152
+ return text
153
+
154
+ def _process_validation(self, items: List, func, description: str) -> List:
155
+ """
156
+ Process items sequentially with a progress bar.
157
+ """
158
+ results = []
159
+ print(f"Starting {description}")
160
+ for item in tqdm(items, desc=description):
161
+ try:
162
+ result = func(item)
163
+ if result is not None:
164
+ results.append(result)
165
+ except Exception as e:
166
+ print(f"Error processing item: {str(e)}")
167
+ print(f"Completed {description}. Processed {len(results)} items successfully")
168
+ return results
169
+
170
+ def _get_cache_path(self, data: List[Dict]) -> Path:
171
+ """
172
+ Generate cache file path based on data hash.
173
+ """
174
+ data_str = json.dumps(data, sort_keys=True)
175
+ hash_value = hashlib.md5(data_str.encode()).hexdigest()
176
+ return self.cache_dir / f"cache_{hash_value}.pkl"
quality_metrics.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import tensorflow as tf
3
+ import tensorflow_hub as hub
4
+ from transformers import GPT2TokenizerFast, GPT2LMHeadModel
5
+ import language_tool_python
6
+ from rouge_score import rouge_scorer
7
+ import spacy
8
+ from sklearn.metrics.pairwise import cosine_similarity
9
+ import numpy as np
10
+ from typing import Dict
11
+ from pipeline_config import PipelineConfig
12
+
13
+ class QualityMetrics:
14
+ """
15
+ Measure augmented text quality
16
+ """
17
+ def __init__(self, config: PipelineConfig):
18
+ self.config = config
19
+
20
+ # Semantic similarity
21
+ self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
22
+
23
+ # Fluency metrics
24
+ self.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
25
+ self.model = GPT2LMHeadModel.from_pretrained('gpt2')
26
+ self.model.eval()
27
+
28
+ # Grammar
29
+ self.language_tool = language_tool_python.LanguageTool('en-US')
30
+
31
+ # Lexical similarity
32
+ self.rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
33
+
34
+ # Diversity
35
+ self.nlp = spacy.load('en_core_web_sm')
36
+
37
+ def compute_perplexity(self, text):
38
+ try:
39
+ encodings = self.tokenizer(text, return_tensors='pt')
40
+ input_ids = encodings['input_ids']
41
+ with torch.no_grad():
42
+ outputs = self.model(input_ids, labels=input_ids)
43
+ loss = outputs.loss
44
+ perplexity = torch.exp(loss)
45
+ return perplexity.item()
46
+ except Exception as e:
47
+ print(f"Error computing perplexity for text '{text}': {e}")
48
+ return float('inf') # High perplexity value == poor quality
49
+
50
+ def compute_semantic_similarity(self, text1: str, text2: str) -> float:
51
+ """
52
+ Compute semantic similarity between two texts using the Universal Sentence Encoder.
53
+ Args:
54
+ text1 (str): First text
55
+ text2 (str): Second text
56
+ Returns:
57
+ float: Cosine similarity score between the two texts (0-1)
58
+ """
59
+ embeddings = self.use_model([text1, text2])
60
+ emb1, emb2 = embeddings[0].numpy(), embeddings[1].numpy()
61
+ return cosine_similarity([emb1], [emb2])[0][0]
62
+
63
+ def compute_metrics(self, original: str, augmented: str) -> Dict[str, float]:
64
+ """
65
+ Compute quality metrics
66
+ """
67
+ metrics = {}
68
+
69
+ # 1. Semantic Preservation
70
+ embeddings = self.use_model([original, augmented])
71
+ emb_orig, emb_aug = embeddings[0].numpy(), embeddings[1].numpy()
72
+ metrics['semantic_similarity'] = cosine_similarity([emb_orig], [emb_aug])[0][0]
73
+
74
+ # 2. Fluency & Naturalness
75
+ metrics['perplexity'] = self.compute_perplexity(augmented)
76
+ metrics['grammar_errors'] = len(self.language_tool.check(augmented))
77
+
78
+ # 3. Lexical Diversity
79
+ doc_orig = self.nlp(original)
80
+ doc_aug = self.nlp(augmented)
81
+
82
+ # Type-token ratio with safety check
83
+ aug_tokens = [token.text.lower() for token in doc_aug]
84
+ metrics['type_token_ratio'] = len(set(aug_tokens)) / max(len(aug_tokens), 1)
85
+
86
+ # Content word overlap with safety checks
87
+ orig_content = set([token.text.lower() for token in doc_orig if not token.is_stop])
88
+ aug_content = set([token.text.lower() for token in doc_aug if not token.is_stop])
89
+
90
+ # Safety check for empty content sets
91
+ if len(orig_content) == 0:
92
+ metrics['content_preservation'] = 1.0 if len(aug_content) == 0 else 0.0
93
+ else:
94
+ metrics['content_preservation'] = len(orig_content.intersection(aug_content)) / len(orig_content)
95
+
96
+ # 4. Structural Preservation
97
+ rouge_scores = self.rouge.score(original, augmented)
98
+ metrics['rouge1_f1'] = rouge_scores['rouge1'].fmeasure
99
+ metrics['rouge2_f1'] = rouge_scores['rouge2'].fmeasure
100
+ metrics['rougeL_f1'] = rouge_scores['rougeL'].fmeasure
101
+
102
+ # 5. Length Preservation with safety check
103
+ orig_words = len(original.split())
104
+ aug_words = len(augmented.split())
105
+ metrics['length_ratio'] = aug_words / max(orig_words, 1)
106
+
107
+ return metrics
108
+
109
+ def meets_quality_threshold(self, metrics: Dict[str, float]) -> bool:
110
+ """
111
+ Enhanced quality threshold checking
112
+ """
113
+ # Core quality checks
114
+ basic_quality = (
115
+ metrics['perplexity'] <= self.config.perplexity_threshold and
116
+ metrics['semantic_similarity'] >= self.config.semantic_similarity_threshold and
117
+ metrics['grammar_errors'] <= self.config.grammar_error_threshold
118
+ )
119
+
120
+ # Length preservation check
121
+ length_ok = 0.6 <= metrics['length_ratio'] <= 1.4
122
+
123
+ # Diversity check
124
+ diversity_ok = metrics['type_token_ratio'] >= 0.4
125
+
126
+ # Content preservation check
127
+ content_ok = metrics['content_preservation'] >= 0.6
128
+
129
+ return all([basic_quality, length_ok, diversity_ok, content_ok])
readme.md ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Retrieval-based learning chatbot
2
+
3
+ CSC525 - Module 8 Option 2 - Retrieval-based Learning Chatbot - Joseph Armani
4
+
5
+ ## TODO
6
+
7
+ A Python tool to generate high-quality dialog variations.
8
+
9
+ This package automatically downloads the following models during installation:
10
+
11
+ - Universal Sentence Encoder v4 (TensorFlow Hub)
12
+ - ChatGPT Paraphraser T5-base
13
+ - Helsinki-NLP translation models (en-de, de-es, es-en)
14
+ - GPT-2 (for perplexity scoring)
15
+ - spaCy en_core_web_sm
16
+ - nltk wordnet and averaged_perceptron_tagger_eng models
17
+
18
+ ## Install package
19
+
20
+ pip install -e .
21
+
22
+ ## Description
23
+
24
+ This Python script demonstrates a complete pipeline for dialogue augmentation, including validation, optimization, and data augmentation.
25
+ It creates high-quality augmented versions of dialogues by applying various text augmentation techniques and quality control checks.
26
+ Two approaches are used for text augmentation: paraphrasing and back-translation. The pipeline also includes quality metrics for evaluating the augmented text.
27
+ Special handling is implemented for very short text such as greetings and farewells, which are predefined and filtered for quality.
28
+ The pipeline is designed to process a dataset of dialogues and generate multiple high-quality augmented versions of each dialogue.
29
+ The pipeline ensures duplicate dialogues are not generated and that the output meets quality thresholds for semantic similarity, grammar, fluency, diversity, and content preservation.
30
+
31
+ ## References
32
+
33
+ Accsany, P. (2024). Working with JSON data in Python. Real Python. <https://realpython.com/python-json/>
34
+ Explosion AI Team. (n.d.). Spacy · industrial-strength natural language processing in python. <https://spacy.io/>
35
+ GeeksforGeeks. (2024). Text augmentation techniques in NLP. GeeksforGeeks. <https://www.geeksforgeeks.org/text-augmentation-techniques-in-nlp/>
36
+ Helsinki-NLP. (2024). Opus-MT [Computer software]. GitHub. <https://github.com/Helsinki-NLP/Opus-MT>
37
+ Hugging Face. (n.d.). Transformers. Hugging Face. <https://huggingface.co/docs/transformers/en/index>
38
+ Humarin. (2023). ChatGPT paraphraser on T5-base [Computer software]. Hugging Face. <https://huggingface.co/humarin/chatgpt_paraphraser_on_T5_base>
39
+ Keita, Z. (2022). Data augmentation in NLP using back-translation with MarianMT. Towards Data Science. <https://towardsdatascience.com/data-augmentation-in-nlp-using-back-translation-with-marianmt-a8939dfea50a>
40
+ Memgraph. (2023). Cosine similarity in Python with scikit-learn. Memgraph. <https://memgraph.com/blog/cosine-similarity-python-scikit-learn>
41
+ Morris, J. (n.d.). language-tool-python (Version 2.8.1) [Computer software]. PyPI. <https://pypi.org/project/language-tool-python/>
42
+ TensorFlow. (n.d.). Universal sentence encoder. TensorFlow Hub. <https://www.tensorflow.org/hub/tutorials/semantic_similarity_with_tf_hub_universal_encoder>
43
+ Waheed, A. (2023). How to calculate ROUGE score in Python. Python Code. <https://thepythoncode.com/article/calculate-rouge-score-in-python>
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ spacy>=3.0.0 # Text processing and tokenization
2
+ numpy>=1.19.0 # General numerical computation
3
+ tqdm>=4.64.0 # Progress bar
4
+ torch>=1.10.0 # PyTorch, for deep learning
5
+ tensorflow>=2.6.0 # TensorFlow, for deep learning
6
+ tensorflow-hub>=0.12.0 # Pretrained model hub for TensorFlow
7
+ transformers>=4.21.0 # Hugging Face Transformers library
8
+ rouge-score>=0.1.2 # ROUGE metric for evaluation
9
+ language-tool-python>=2.7.1 # Grammar checking and text correction
10
+ scikit-learn>=1.0.0 # Machine learning tools
11
+ nlpaug>=1.1.0 # Data augmentation for NLP
12
+ nltk>=3.6.0 # Natural language toolkit
schema_guided_dialogue_processor.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import List, Dict, Optional, Any
3
+ import json
4
+ import glob
5
+ from pathlib import Path
6
+ from pipeline_config import PipelineConfig
7
+
8
+ @dataclass
9
+ class SchemaGuidedDialogue:
10
+ """
11
+ Structured representation of a Schema-Guided dialogue
12
+ """
13
+ dialogue_id: str
14
+ service_name: str
15
+ service_description: Optional[str]
16
+ schema: Dict[str, Any]
17
+ turns: List[Dict[str, Any]]
18
+ original_metadata: Dict[str, Any] = field(default_factory=dict)
19
+
20
+ class SchemaGuidedProcessor:
21
+ """
22
+ Handles processing and preparation of Schema-Guided dataset dialogues
23
+ """
24
+ def __init__(self, config: PipelineConfig):
25
+ self.config = config
26
+ self.services = set()
27
+ self.domains = set()
28
+ self.schemas = {}
29
+
30
+ def load_dataset(self, base_dir, max_examples: Optional[int] = None) -> List[SchemaGuidedDialogue]:
31
+ """
32
+ Load and parse Schema-Guided Dialogue dataset
33
+
34
+ Args:
35
+ dialogue_path: Path to the dialogue JSON file
36
+ schema_path: Path to the schema JSON file
37
+ """
38
+ # Define schema and dialogue file patterns
39
+ schema_file = Path(base_dir, "schema.json")
40
+ dialogue_files_pattern = str(Path(base_dir, "dialogues_*.json"))
41
+
42
+ # Check for schema file
43
+ if not schema_file.exists():
44
+ raise FileNotFoundError(f"Schema file not found at {schema_file}")
45
+
46
+ # Load schema
47
+ self.schemas = self._load_schemas(schema_file)
48
+
49
+ # Find and validate dialogue files
50
+ dialogue_files = glob.glob(dialogue_files_pattern)
51
+ if not dialogue_files:
52
+ raise FileNotFoundError(f"No dialogue files found matching pattern {dialogue_files_pattern}")
53
+
54
+ print(f"Found {len(dialogue_files)} dialogue files to process.")
55
+
56
+ # Process all dialogues
57
+ processed_dialogues = []
58
+ for file_path in dialogue_files:
59
+ with open(file_path, 'r', encoding='utf-8') as f:
60
+ raw_dialogues = json.load(f)
61
+
62
+ for dialogue in raw_dialogues:
63
+ processed_dialogues.append(self._process_single_dialogue(dialogue))
64
+
65
+ if max_examples and len(processed_dialogues) >= max_examples:
66
+ break
67
+
68
+ return processed_dialogues
69
+
70
+ def _process_single_dialogue(self, dialogue: Dict[str, Any]) -> SchemaGuidedDialogue:
71
+ """
72
+ Process a single dialogue JSON object into a SchemaGuidedDialogue object.
73
+ """
74
+ dialogue_id = str(dialogue.get("dialogue_id", ""))
75
+ services = dialogue.get("services", [])
76
+ service_name = services[0] if services else None
77
+ schema = self.schemas.get(service_name, {})
78
+ service_description = schema.get("description", "")
79
+
80
+ # Process turns
81
+ turns = self._process_turns(dialogue.get("turns", []))
82
+
83
+ # Store metadata
84
+ metadata = {
85
+ "services": services,
86
+ "original_id": dialogue_id,
87
+ }
88
+
89
+ return SchemaGuidedDialogue(
90
+ dialogue_id=f"schema_guided_{dialogue_id}",
91
+ service_name=service_name,
92
+ service_description=service_description,
93
+ schema=schema,
94
+ turns=turns,
95
+ original_metadata=metadata,
96
+ )
97
+
98
+ def _validate_schema(self, schema: Dict[str, Any]) -> bool:
99
+ """
100
+ Validate a schema
101
+ """
102
+ required_keys = {"service_name", "description", "slots", "intents"}
103
+ missing_keys = required_keys - schema.keys()
104
+ if missing_keys:
105
+ print(f"Warning: Missing keys in schema {schema.get('service_name', 'unknown')}: {missing_keys}")
106
+ return False
107
+ return True
108
+
109
+ def _load_schemas(self, schema_path: str) -> Dict[str, Any]:
110
+ """
111
+ Load and process service schemas
112
+ """
113
+ with open(schema_path, 'r', encoding='utf-8') as f:
114
+ schemas = json.load(f)
115
+
116
+ # Validate and index schemas
117
+ return {
118
+ schema["service_name"]: schema for schema in schemas if self._validate_schema(schema)
119
+ }
120
+
121
+ def _process_turns(self, turns: List[Dict]) -> List[Dict]:
122
+ """
123
+ Process dialogue turns into standardized format
124
+ """
125
+ processed_turns = []
126
+
127
+ for turn in turns:
128
+ try:
129
+ # Map speakers to standard format
130
+ speaker = 'assistant' if turn.get('speaker') == 'SYSTEM' else 'user'
131
+
132
+ # Extract utterance and clean it
133
+ text = turn.get('utterance', '').strip()
134
+
135
+ # Extract frames and dialogue acts
136
+ frames = turn.get('frames', [])
137
+ acts = []
138
+ slots = []
139
+
140
+ for frame in frames:
141
+ if 'actions' in frame:
142
+ acts.extend(frame['actions'])
143
+ if 'slots' in frame:
144
+ slots.extend(frame['slots'])
145
+
146
+ # Create the processed turn
147
+ processed_turn = {
148
+ 'speaker': speaker,
149
+ 'text': text,
150
+ 'original_speaker': turn.get('speaker', ''),
151
+ 'dialogue_acts': acts,
152
+ 'slots': slots,
153
+ 'metadata': {k: v for k, v in turn.items()
154
+ if k not in {'speaker', 'utterance', 'frames'}}
155
+ }
156
+
157
+ processed_turns.append(processed_turn)
158
+ except Exception as e:
159
+ print(f"Error processing turn: {str(e)}")
160
+ continue
161
+
162
+ return processed_turns
163
+
164
+ def convert_to_pipeline_format(self, schema_dialogues: List[SchemaGuidedDialogue]) -> List[Dict]:
165
+ """
166
+ Convert SchemaGuidedDialogues to the format expected by the ProcessingPipeline
167
+ """
168
+ pipeline_dialogues = []
169
+
170
+ for dialogue in schema_dialogues:
171
+ # Convert turns to the expected format
172
+ processed_turns = [
173
+ {"speaker": turn["speaker"], "text": turn["text"]}
174
+ for turn in dialogue.turns if turn["text"].strip()
175
+ ]
176
+
177
+ # Create dialogue in pipeline format
178
+ pipeline_dialogue = {
179
+ 'dialogue_id': dialogue.dialogue_id,
180
+ 'turns': processed_turns,
181
+ 'metadata': {
182
+ 'service_name': dialogue.service_name,
183
+ 'service_description': dialogue.service_description,
184
+ 'schema': dialogue.schema,
185
+ **dialogue.original_metadata
186
+ }
187
+ }
188
+
189
+ pipeline_dialogues.append(pipeline_dialogue)
190
+
191
+ return pipeline_dialogues
192
+
setup.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from setuptools import setup, find_packages
2
+ import subprocess
3
+ import sys
4
+
5
+ with open("README.md", "r", encoding="utf-8") as fh:
6
+ long_description = fh.read()
7
+
8
+ with open("requirements.txt", "r", encoding="utf-8") as fh:
9
+ requirements = [line.strip() for line in fh if line.strip() and not line.startswith("#")]
10
+
11
+ def setup_spacy_model():
12
+ """
13
+ Download spaCy model.
14
+ """
15
+ subprocess.check_call([sys.executable, "-m", "spacy", "download", "en_core_web_sm"])
16
+
17
+ def setup_models():
18
+ """
19
+ Download other required models.
20
+ """
21
+ import tensorflow_hub as hub
22
+ from sklearn.feature_extraction.text import TfidfVectorizer
23
+ from transformers import (
24
+ AutoTokenizer,
25
+ GPT2TokenizerFast,
26
+ MarianTokenizer
27
+ )
28
+
29
+ # Download Universal Sentence Encoder
30
+ _ = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
31
+
32
+ # Download paraphraser model
33
+ _ = AutoTokenizer.from_pretrained('humarin/chatgpt_paraphraser_on_T5_base')
34
+
35
+ # Download translation models
36
+ source_lang, pivot_lang, target_lang = 'en', 'de', 'es'
37
+ model_names = [
38
+ f'Helsinki-NLP/opus-mt-{source_lang}-{pivot_lang}',
39
+ f'Helsinki-NLP/opus-mt-{pivot_lang}-{target_lang}',
40
+ f'Helsinki-NLP/opus-mt-{target_lang}-{source_lang}'
41
+ ]
42
+ for model_name in model_names:
43
+ _ = MarianTokenizer.from_pretrained(model_name)
44
+
45
+ # Download GPT-2
46
+ _ = GPT2TokenizerFast.from_pretrained('gpt2')
47
+
48
+ def setup_nltk():
49
+ """
50
+ Download required NLTK data.
51
+ """
52
+ import nltk
53
+ required_packages = [
54
+ 'wordnet',
55
+ 'averaged_perceptron_tagger_eng'
56
+ ]
57
+
58
+ for package in required_packages:
59
+ try:
60
+ print(f"Downloading {package}...")
61
+ nltk.download(package)
62
+ print(f"Successfully downloaded {package}")
63
+ except Exception as e:
64
+ print(f"Warning: Could not download {package}: {str(e)}")
65
+
66
+ setup(
67
+ name="text-data-augmenter",
68
+ version="0.1.0",
69
+ author="Joe Armani",
70
+ author_email="[email protected]",
71
+ description="A tool for generating high-quality dialogue variations",
72
+ packages=find_packages(),
73
+ classifiers=[
74
+ "Development Status :: 3 - Alpha",
75
+ "Intended Audience :: Science/Research",
76
+ "License :: OSI Approved :: MIT License",
77
+ "Operating System :: OS Independent",
78
+ "Programming Language :: Python :: 3",
79
+ "Programming Language :: Python :: 3.8",
80
+ "Programming Language :: Python :: 3.9",
81
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
82
+ "Topic :: Text Processing :: Linguistic",
83
+ ],
84
+ python_requires=">=3.8",
85
+ install_requires=requirements,
86
+ entry_points={
87
+ "console_scripts": [
88
+ "dialogue-augment=dialogue_augmenter.main:main",
89
+ ],
90
+ },
91
+ include_package_data=True,
92
+ package_data={
93
+ "dialogue_augmenter": ["data/*.json", "config/*.yaml"],
94
+ },
95
+ )
96
+
97
+ if __name__ == '__main__':
98
+ setup_spacy_model()
99
+ setup_models()
100
+ setup_nltk()
taskmaster_processor.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import List, Dict, Optional, Any
3
+ import json
4
+ import re
5
+ from pathlib import Path
6
+ from pipeline_config import PipelineConfig
7
+
8
+ @dataclass
9
+ class TaskmasterDialogue:
10
+ """
11
+ Structured representation of a Taskmaster dialogue
12
+ """
13
+ conversation_id: str
14
+ instruction_id: Optional[str]
15
+ scenario: Optional[str]
16
+ domain: Optional[str]
17
+ turns: List[Dict[str, Any]]
18
+ original_metadata: Dict[str, Any] = field(default_factory=dict)
19
+
20
+ def __str__(self):
21
+ return f"TaskmasterDialogue(conversation_id={self.conversation_id}, turns={len(self.turns)} turns)"
22
+
23
+ def validate(self) -> bool:
24
+ return bool(self.conversation_id and isinstance(self.turns, list))
25
+
26
+ class TaskmasterProcessor:
27
+ """
28
+ Handles processing and preparation of Taskmaster dataset dialogues
29
+ """
30
+ config: PipelineConfig
31
+ use_ontology: bool = False # Whether to load and use ontology
32
+ ontology: Optional[Dict[str, Any]] = None # Holds ontology data if loaded
33
+ domains: set = field(default_factory=set) # Tracks unique domains
34
+ scenarios: set = field(default_factory=set) # Tracks unique scenarios
35
+
36
+ def __init__(self, config: PipelineConfig, use_ontology: bool = False):
37
+ self.config = config
38
+ self.use_ontology = use_ontology
39
+ self.ontology = None
40
+ self.domains = set()
41
+ self.scenarios = set()
42
+
43
+ def load_dataset(self, base_dir: str, max_examples: Optional[int] = None) -> List[TaskmasterDialogue]:
44
+ """
45
+ Load and parse Taskmaster JSON dataset.
46
+ Handles self-dialogs, woz-dialogs, and ontology files.
47
+ """
48
+ required_files = {
49
+ "self-dialogs": "self-dialogs.json",
50
+ "woz-dialogs": "woz-dialogs.json",
51
+ "ontology": "ontology.json",
52
+ }
53
+
54
+ # Check for required files
55
+ missing_files = [name for name, path in required_files.items() if not Path(base_dir, path).exists()]
56
+ if missing_files:
57
+ raise FileNotFoundError(f"Missing required taskmaster files: {missing_files}")
58
+
59
+ # load ontology
60
+ ontology_path = Path(base_dir, required_files['ontology'])
61
+ with open(ontology_path, 'r', encoding='utf-8') as f:
62
+ self.ontology = json.load(f)
63
+
64
+ processed_dialogues = []
65
+ for file_key in ["self-dialogs", "woz-dialogs"]:
66
+ file_path = Path(base_dir, required_files[file_key])
67
+ with open(file_path, 'r', encoding='utf-8') as f:
68
+ raw_data = json.load(f)
69
+
70
+ for dialogue in raw_data:
71
+ # Extract core dialogue components
72
+ conversation_id = dialogue.get('conversation_id', '')
73
+ instruction_id = dialogue.get('instruction_id', None)
74
+
75
+ if 'utterances' in dialogue:
76
+ turns = self._process_utterances(dialogue['utterances'])
77
+ scenario = dialogue.get('scenario', '')
78
+ domain = self._extract_domain(scenario)
79
+ else:
80
+ turns = []
81
+ scenario = ''
82
+ domain = ''
83
+
84
+ # Store metadata
85
+ metadata = {k: v for k, v in dialogue.items()
86
+ if k not in {'conversation_id', 'instruction_id', 'utterances'}}
87
+
88
+ # Create structured dialogue object
89
+ processed_dialogue = TaskmasterDialogue(
90
+ conversation_id=conversation_id,
91
+ instruction_id=instruction_id,
92
+ scenario=scenario,
93
+ domain=domain,
94
+ turns=turns,
95
+ original_metadata=metadata
96
+ )
97
+
98
+ processed_dialogues.append(processed_dialogue)
99
+
100
+ # Update domain and scenario tracking
101
+ if domain:
102
+ self.domains.add(domain)
103
+ if scenario:
104
+ self.scenarios.add(scenario)
105
+
106
+ if max_examples and len(processed_dialogues) >= max_examples:
107
+ break
108
+
109
+ return processed_dialogues
110
+
111
+ def _process_utterances(self, utterances: List[Dict]) -> List[Dict]:
112
+ """
113
+ Process utterances into a standardized format
114
+ """
115
+ processed_turns = []
116
+
117
+ for utterance in utterances:
118
+ # Map Taskmaster speaker roles to your expected format
119
+ speaker = 'assistant' if utterance.get('speaker') == 'ASSISTANT' else 'user'
120
+
121
+ # Extract and clean the text
122
+ text = utterance.get('text', '').strip()
123
+
124
+ # Extract any segments or annotations if present
125
+ segments = utterance.get('segments', [])
126
+
127
+ # Create the processed turn
128
+ turn = {
129
+ 'speaker': speaker,
130
+ 'text': text,
131
+ 'original_speaker': utterance.get('speaker', ''),
132
+ 'segments': segments,
133
+ 'metadata': {k: v for k, v in utterance.items()
134
+ if k not in {'speaker', 'text', 'segments'}}
135
+ }
136
+
137
+ processed_turns.append(turn)
138
+
139
+ return processed_turns
140
+
141
+ def _extract_domain(self, scenario: str) -> str:
142
+ """
143
+ Extract domain from scenario description
144
+ """
145
+ domain_patterns = {
146
+ 'restaurant': r'\b(restaurant|dining|food|reservation)\b',
147
+ 'movie': r'\b(movie|cinema|film|ticket)\b',
148
+ 'ride_share': r'\b(ride|taxi|uber|lyft)\b',
149
+ 'coffee': r'\b(coffee|café|cafe|starbucks)\b',
150
+ 'pizza': r'\b(pizza|delivery|order food)\b',
151
+ 'auto': r'\b(car|vehicle|repair|maintenance)\b',
152
+ }
153
+
154
+ scenario_lower = scenario.lower()
155
+
156
+ for domain, pattern in domain_patterns.items():
157
+ if re.search(pattern, scenario_lower):
158
+ return domain
159
+
160
+ return 'other'
161
+
162
+ def convert_to_pipeline_format(self, taskmaster_dialogues: List[TaskmasterDialogue]) -> List[Dict]:
163
+ """
164
+ Convert TaskmasterDialogues to the format expected by the ProcessingPipeline
165
+ """
166
+ pipeline_dialogues = []
167
+
168
+ for dialogue in taskmaster_dialogues:
169
+ # Convert turns to the expected format
170
+ processed_turns = []
171
+ for turn in dialogue.turns:
172
+ if turn['text'].strip(): # Skip empty turns
173
+ processed_turns.append({
174
+ 'speaker': turn['speaker'],
175
+ 'text': turn['text']
176
+ })
177
+
178
+ # Create dialogue in pipeline format
179
+ pipeline_dialogue = {
180
+ 'dialogue_id': dialogue.conversation_id,
181
+ 'turns': processed_turns,
182
+ 'metadata': {
183
+ 'instruction_id': dialogue.instruction_id,
184
+ 'scenario': dialogue.scenario,
185
+ 'domain': dialogue.domain,
186
+ **dialogue.original_metadata
187
+ }
188
+ }
189
+
190
+ pipeline_dialogues.append(pipeline_dialogue)
191
+
192
+ return pipeline_dialogues