Commit 4aec49f committed by JoeArmani
Parent(s): d7fc7a7
.gitignore CHANGED
@@ -182,10 +182,5 @@ training_data/*
  !training_data/.gitkeep
  augmented_dialogues.json
 
- checkpoints_old_REMOVE/*
- new_iteration/cache/*
- new_iteration/data_prep_iterative_models/*
- new_iteration/training_data/*
- new_iteration/processed_outputs/*
  raw_datasets/*
 
data_augmentation_code/augmentation_processing_pipeline.py DELETED
@@ -1,321 +0,0 @@
- from datetime import datetime
- from pathlib import Path
- from typing import List, Dict, Optional
- import json
- import re
- import hashlib
- import spacy
- import torch
- from tqdm import tqdm
- from data_augmentation.pipeline_config import PipelineConfig
- from data_augmentation.dialogue_augmenter import DialogueAugmenter
- from sklearn.feature_extraction.text import TfidfVectorizer
- from sklearn.metrics.pairwise import cosine_similarity
- from typing import Set
-
- class AugmentationProcessingPipeline:
-     """
-     Complete pipeline combining validation, optimization, and augmentation.
-     """
-
-     def __init__(self, config: Optional[PipelineConfig] = None):
-         self.config = config or PipelineConfig()
-         self.nlp = spacy.load("en_core_web_sm", disable=['parser', 'ner'])
-         self.augmenter = DialogueAugmenter(self.nlp, self.config)
-         self.num_threads = self.config.batch_size
-         self.cache_dir = Path("./cache")
-         self.cache_dir.mkdir(exist_ok=True)
-         self.output_dir = Path("processed_outputs")
-         self.output_dir.mkdir(exist_ok=True)
-         self.checkpoint_file = self.output_dir / "processing_checkpoint.json"
-         self.batch_size = self.config.batch_size
-         self.use_gpu = torch.cuda.is_available()
-         self.batch_size = 32 if self.use_gpu else 8
-         self.use_multiprocessing = not self.use_gpu
-
-         # Counters for grouping batches
-         self.batch_counter = 0  # Count batches since last group combine
-         self.batch_group_number = 0  # How many groups have been created
-
-         if self.config.debug:
-             print(f"ProcessingPipeline initialized with:")
-             print(f"- GPU available: {self.use_gpu}")
-             print(f"- Batch size: {self.batch_size}")
-             print(f"- Using multiprocessing: {self.use_multiprocessing}")
-
-     def _save_batch(self, batch_results: List[Dict], batch_num: int) -> Path:
-         """Save a batch of results to a separate JSON file"""
-         batch_file = self.output_dir / f"batch_{batch_num:04d}.json"
-         with open(batch_file, 'w') as f:
-             json.dump(batch_results, f)
-         return batch_file
-
-     def _load_checkpoint(self) -> set:
-         """Load set of processed dialogue IDs from checkpoint"""
-         if self.checkpoint_file.exists():
-             with open(self.checkpoint_file, 'r') as f:
-                 return set(json.load(f))
-         return set()
-
-     def _update_checkpoint(self, processed_ids: set):
-         """Update checkpoint with newly processed IDs"""
-         with open(self.checkpoint_file, 'w') as f:
-             json.dump(list(processed_ids), f)
-
-     def _process_batch(self, batch: List[Dict]) -> List[Dict]:
-         """Process batch with optimized model calls"""
-         results = []
-         try:
-             if self.use_gpu:
-                 results = self.augmenter.process_batch(batch)
-             else:
-                 # Collect all texts that need processing
-                 all_texts = []
-                 text_to_dialogue_map = {}
-                 for dialogue in batch:
-                     for turn in dialogue['turns']:
-                         all_texts.append(turn['text'])
-                         text_to_dialogue_map[turn['text']] = dialogue['dialogue_id']
-
-                 # Batch process embeddings
-                 self.augmenter._compute_batch_embeddings(all_texts)
-
-                 # Process dialogues with cached embeddings
-                 for dialogue in batch:
-                     try:
-                         augmented = self.augmenter.augment_dialogue(dialogue)
-                         results.extend(augmented)
-                     except Exception as e:
-                         print(f"Error processing dialogue {dialogue.get('dialogue_id', 'unknown')}: {str(e)}")
-                         continue
-         except Exception as e:
-             print(f"Error processing batch: {str(e)}")
-         return results
-
-     def _combine_intermediate_batches(self):
-         """
-         Combine all current batch_*.json files into a single batch_group_XXXX.json file,
-         then remove the batch_*.json files.
-         """
-         batch_files = sorted(self.output_dir.glob("batch_*.json"))
-         if not batch_files:
-             return None  # No files to combine
-
-         combined_data = []
-         for bf in batch_files:
-             with open(bf, 'r') as f:
-                 combined_data.extend(json.load(f))
-             bf.unlink()  # Remove the individual batch file after reading
-
-         self.batch_group_number += 1
-         group_file = self.output_dir / f"batch_group_{self.batch_group_number:04d}.json"
-         with open(group_file, 'w') as f:
-             json.dump(combined_data, f)
-         return group_file
-
-     def combine_results(self) -> Path:
-         """Combine all batch_group_*.json files into final output"""
-         all_results = []
-         group_files = sorted(self.output_dir.glob("batch_group_*.json"))
-
-         print(f"Combining {len(group_files)} group files...")
-         for group_file in tqdm(group_files):
-             with open(group_file, 'r') as f:
-                 group_data = json.load(f)
-                 all_results.extend(group_data)
-
-         # Save combined results
-         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-         final_output = self.output_dir / f"augmented_dataset_{timestamp}.json"
-         with open(final_output, 'w') as f:
-             json.dump(all_results, f)
-
-         if self.config.debug:
-             print(f"Combined {len(all_results)} dialogues into {final_output}")
-
-         return final_output
-
-     def process_dataset(self, dialogues: List[Dict]) -> Path:
-         """Process dataset with hardware-appropriate optimizations and progress tracking"""
-         processed_ids = self._load_checkpoint()
-
-         # Filter out already processed dialogues
-         remaining_dialogues = [d for d in dialogues
-                                if d['dialogue_id'] not in processed_ids]
-
-         total_dialogues = len(dialogues)
-         remaining_count = len(remaining_dialogues)
-         processed_count = total_dialogues - remaining_count
-
-         print("\nDataset Processing Status:")
-         print(f"Total dialogues in dataset: {total_dialogues}")
-         print(f"Previously processed: {processed_count}")
-         print(f"Remaining to process: {remaining_count}")
-         print("-" * 50)
-
-         # Process in batches with progress bar
-         for batch_num in tqdm(range(0, len(remaining_dialogues), self.batch_size),
-                               desc="Processing batches",
-                               total=(len(remaining_dialogues) + self.batch_size - 1) // self.batch_size):
-             batch = remaining_dialogues[batch_num:batch_num + self.batch_size]
-             current_position = processed_count + batch_num + len(batch)
-
-             total_progress = (current_position / total_dialogues) * 100
-
-             print('\033[K', end='')
-             print(f"Processing: {current_position}/{total_dialogues} dialogues "
-                   f"({total_progress:.1f}% complete)")
-             print(f"Current batch: {batch_num//self.batch_size + 1} of "
-                   f"{(len(remaining_dialogues) + self.batch_size - 1) // self.batch_size}")
-             print("-" * 50)
-
-             # Process batch
-             batch_results = self._process_batch(batch)
-
-             if batch_results:
-                 self._save_batch(batch_results, batch_num)
-                 batch_ids = {d['dialogue_id'] for d in batch}
-                 processed_ids.update(batch_ids)
-                 self._update_checkpoint(processed_ids)
-
-                 # Increment batch counter and combine if needed
-                 self.batch_counter += 1
-                 if self.batch_counter == 25:
-                     # Combine these 25 batches into a group file
-                     self._combine_intermediate_batches()
-                     self.batch_counter = 0  # Reset counter after grouping
-
-         # If there are leftover batches less than 25
-         # combine them into one final group file
-         if self.batch_counter > 0:
-             self._combine_intermediate_batches()
-             self.batch_counter = 0
-
-         print("\n" + "-" * 50)
-         print("Processing complete. Combining results...")
-         return self.combine_results()
-
-     def cleanup(self):
-         """Clean up intermediate files after successful processing"""
-         # Clean up any leftover batch files (should not exist if logic is correct)
-         batch_files = list(self.output_dir.glob("batch_*.json"))
-         for file in batch_files:
-             try:
-                 file.unlink()
-             except Exception as e:
-                 print(f"Error deleting {file}: {e}")
-
-         # We can also remove batch_group_*.json if desired after final combine
-         # but that might not be necessary if we want to keep them.
-
-         if self.checkpoint_file.exists():
-             try:
-                 self.checkpoint_file.unlink()
-             except Exception as e:
-                 print(f"Error deleting checkpoint file: {e}")
-
-     def _deduplicate_dialogues(self, dialogues: List[Dict], threshold: float = 0.9) -> List[Dict]:
-         """
-         Deduplicate dialogues based on text similarity.
-         """
-         print("Deduplicating dialogues...")
-         if not dialogues:
-             print("No dialogues provided for deduplication.")
-             return []
-
-         # Combine turns into single text for similarity comparison
-         texts = [" ".join(turn['text'] for turn in dialogue['turns']) for dialogue in dialogues]
-         tfidf = TfidfVectorizer().fit_transform(texts)
-         sim_matrix = cosine_similarity(tfidf)
-
-         unique_indices = set()
-         for i, row in enumerate(sim_matrix):
-             if i not in unique_indices:
-                 similar_indices = [j for j, sim in enumerate(row) if sim > threshold and j != i]
-                 unique_indices.add(i)
-                 unique_indices.difference_update(similar_indices)
-
-         deduplicated_dialogues = [dialogues[i] for i in unique_indices]
-
-         print(f"Deduplication complete. Reduced from {len(dialogues)} to {len(deduplicated_dialogues)} dialogues.")
-         return deduplicated_dialogues
-
-     def _validate_and_clean_dialogue(self, dialogue: Dict) -> Optional[Dict]:
-         """
-         Validate and clean a single dialogue.
-         """
-         try:
-             # Check required fields
-             if not all(field in dialogue for field in self.config.required_fields):
-                 return None
-
-             # Process turns
-             cleaned_turns = []
-             for turn in dialogue['turns']:
-                 if self._validate_turn(turn):
-                     cleaned_turn = {
-                         'speaker': turn['speaker'],
-                         'text': self._clean_text(turn['text'])
-                     }
-                     cleaned_turns.append(cleaned_turn)
-
-             if cleaned_turns:
-                 return {
-                     'dialogue_id': dialogue['dialogue_id'],
-                     'turns': cleaned_turns
-                 }
-
-             return None
-
-         except Exception as e:
-             print(f"Error processing dialogue {dialogue.get('dialogue_id', 'unknown')}: {str(e)}")
-             return None
-
-     def _validate_turn(self, turn: Dict) -> bool:
-         """
-         Validate a single speaking turn.
-         """
-         return (
-             turn['speaker'] in self.config.allowed_speakers and
-             self.config.min_length <= len(turn['text']) <= self.config.max_length
-         )
-
-     def _clean_text(self, text: str) -> str:
-         """
-         Clean and normalize text.
-         """
-         # Remove excessive whitespace
-         text = re.sub(r'\s+', ' ', text.strip())
-
-         # Normalize quotes and apostrophes
-         text = re.sub(r'[’´`]', "'", text)
-         text = re.sub(r'[“”]', '"', text)
-
-         # Remove control characters
-         text = "".join(char for char in text if ord(char) >= 32 or char == '\n')
-
-         return text
-
-     def _process_validation(self, items: List, func, description: str) -> List:
-         """
-         Process items sequentially with a progress bar.
-         """
-         results = []
-         print(f"Starting {description}")
-         for item in tqdm(items, desc=description):
-             try:
-                 result = func(item)
-                 if result is not None:
-                     results.append(result)
-             except Exception as e:
-                 print(f"Error processing item: {str(e)}")
-         print(f"Completed {description}. Processed {len(results)} items successfully")
-         return results
-
-     def _get_cache_path(self, data: List[Dict]) -> Path:
-         """
-         Generate cache file path based on data hash.
-         """
-         data_str = json.dumps(data, sort_keys=True)
-         hash_value = hashlib.md5(data_str.encode()).hexdigest()
-         return self.cache_dir / f"cache_{hash_value}.pkl"
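
The deleted pipeline resumes across runs by persisting processed dialogue IDs to processing_checkpoint.json and skipping them on the next invocation. A minimal sketch of that checkpoint pattern, assuming only the file names used above; the dialogue data and the stubbed-out augmentation step are illustrative:

import json
from pathlib import Path

output_dir = Path("processed_outputs")
output_dir.mkdir(exist_ok=True)
checkpoint = output_dir / "processing_checkpoint.json"

# IDs recorded on a previous run are skipped this time around
processed = set(json.loads(checkpoint.read_text())) if checkpoint.exists() else set()
dialogues = [{"dialogue_id": f"d{i}", "turns": []} for i in range(5)]  # toy data

for d in dialogues:
    if d["dialogue_id"] in processed:
        continue  # already handled before an earlier interruption
    # ... augment d here ...
    processed.add(d["dialogue_id"])
    checkpoint.write_text(json.dumps(sorted(processed)))  # persist after each dialogue
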
data_augmentation_code/back_translator.py DELETED
@@ -1,87 +0,0 @@
- from transformers import (
-     MarianMTModel,
-     MarianTokenizer,
- )
-
- # Retained for reference but removed from the final code.
- # This method did not seem helpful for this retrieval-based chatbot.
- class BackTranslator:
-     """
-     Perform Back-translation with pivot language. English -> German -> Spanish -> English
-     Args:
-         source_lang: Source language (default: 'en')
-         pivot_lang: Pivot language (default: 'de')
-         target_lang: Target language (default: 'es')
-     Examples:
-         back_translator = BackTranslator()
-         back_translator.back_translate("Hello, how are you?")
-     """
-     def __init__(self, source_lang='en', pivot_lang='de', target_lang='es'):
-         # Forward (English to German)
-         pivot_forward_model_name = f'Helsinki-NLP/opus-mt-{source_lang}-{pivot_lang}'
-         self.tokenizer_pivot_forward = MarianTokenizer.from_pretrained(pivot_forward_model_name)
-         self.model_pivot_forward = MarianMTModel.from_pretrained(pivot_forward_model_name)
-
-         # Pivot translation (German to Spanish)
-         pivot_backward_model_name = f'Helsinki-NLP/opus-mt-{pivot_lang}-{target_lang}'
-         self.tokenizer_pivot_backward = MarianTokenizer.from_pretrained(pivot_backward_model_name)
-         self.model_pivot_backward = MarianMTModel.from_pretrained(pivot_backward_model_name)
-
-         # Backward (Spanish to English)
-         backward_model_name = f'Helsinki-NLP/opus-mt-{target_lang}-{source_lang}'
-         self.tokenizer_backward = MarianTokenizer.from_pretrained(backward_model_name)
-         self.model_backward = MarianMTModel.from_pretrained(backward_model_name)
-
-         # Set models to eval mode
-         self.model_pivot_forward.eval()
-         self.model_pivot_backward.eval()
-         self.model_backward.eval()
-
-     def back_translate(self, text, device=None):
-         try:
-             # Move models to device if specified
-             if device is not None:
-                 self.model_pivot_forward = self.model_pivot_forward.to(device)
-                 self.model_pivot_backward = self.model_pivot_backward.to(device)
-                 self.model_backward = self.model_backward.to(device)
-
-             # Forward translation (English to German)
-             encoded_pivot = self.tokenizer_pivot_forward([text], padding=True,
-                                                          truncation=True, return_tensors='pt')
-             if device is not None:
-                 encoded_pivot = {k: v.to(device) for k, v in encoded_pivot.items()}
-
-             generated_pivot = self.model_pivot_forward.generate(**encoded_pivot)
-             if device is not None:
-                 generated_pivot = generated_pivot.cpu()
-             pivot_text = self.tokenizer_pivot_forward.batch_decode(generated_pivot,
-                                                                    skip_special_tokens=True)[0]
-
-             # Pivot translation (German to Spanish)
-             encoded_back_pivot = self.tokenizer_pivot_backward([pivot_text], padding=True,
-                                                                truncation=True, return_tensors='pt')
-             if device is not None:
-                 encoded_back_pivot = {k: v.to(device) for k, v in encoded_back_pivot.items()}
-
-             retranslated_pivot = self.model_pivot_backward.generate(**encoded_back_pivot)
-             if device is not None:
-                 retranslated_pivot = retranslated_pivot.cpu()
-             tgt_text_back = self.tokenizer_pivot_backward.batch_decode(retranslated_pivot,
-                                                                        skip_special_tokens=True)[0]
-
-             # Backward translation (Spanish to English)
-             encoded_back = self.tokenizer_backward([tgt_text_back], padding=True,
-                                                    truncation=True, return_tensors='pt')
-             if device is not None:
-                 encoded_back = {k: v.to(device) for k, v in encoded_back.items()}
-
-             retranslated = self.model_backward.generate(**encoded_back)
-             if device is not None:
-                 retranslated = retranslated.cpu()
-             src_text = self.tokenizer_backward.batch_decode(retranslated,
-                                                             skip_special_tokens=True)[0]
-
-             return src_text
-         except Exception as e:
-             print(f"Error in back translation: {e}")
-             return text
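
For reference, the class was invoked as in its own docstring example. A minimal usage sketch, assuming the deleted module is still on the import path and the three Helsinki-NLP MarianMT models can be downloaded from the Hub:

from data_augmentation.back_translator import BackTranslator

bt = BackTranslator()  # en -> de -> es -> en round trip
print(bt.back_translate("Hello, how are you?"))  # falls back to the input on error
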
data_augmentation_code/dialogue_augmenter.py DELETED
@@ -1,710 +0,0 @@
- from typing import Dict, List
- import numpy as np
- import torch
- import tensorflow as tf
- import tensorflow_hub as hub
- from data_augmentation.pipeline_config import PipelineConfig
- from data_augmentation.quality_metrics import QualityMetrics
- from data_augmentation.paraphraser import Paraphraser
- import nlpaug.augmenter.word as naw
- from functools import lru_cache
- from sklearn.metrics.pairwise import cosine_similarity
-
- class DialogueAugmenter:
-     """
-     Optimized dialogue augmentation with quality control and complexity management.
-     """
-     def __init__(self, nlp, config: PipelineConfig):
-         self.nlp = nlp
-         self.config = config
-
-         # Detect hardware and set appropriate batch sizes and optimization strategy
-         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-         self.use_gpu = torch.cuda.is_available()
-
-         if self.config.debug:
-             print(f"Using device: {self.device}")
-             if self.use_gpu:
-                 print(f"GPU Device: {torch.cuda.get_device_name(0)}")
-
-         self.quality_metrics = QualityMetrics(config)
-         self.semantic_similarity_threshold = 0.75
-
-         # Load model
-         self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
-
-         # Initialize augmentation models based on hardware
-         self._initialize_augmentation_models()
-
-         # Initialize caches
-         self.embedding_cache = {}
-
-         # GPU memory management if available
-         if self.use_gpu:
-             gpus = tf.config.list_physical_devices('GPU')
-             if gpus:
-                 try:
-                     for gpu in gpus:
-                         tf.config.experimental.set_memory_growth(gpu, True)
-                 except RuntimeError as e:
-                     print(e)
-
-     def _initialize_augmentation_models(self):
-         """Initialize augmentation models with appropriate device settings"""
-         # Advanced augmentation techniques
-         self.paraphraser = Paraphraser()
-         if self.use_gpu:
-             # Move model to GPU if available
-             self.paraphraser.model = self.paraphraser.model.to(self.device)
-
-         # Basic augmentation techniques
-         self.word_augmenter = naw.SynonymAug(aug_src='wordnet')
-
-         self.augmenters = {
-             'advanced': [
-                 self.paraphraser,
-             ],
-             'basic': [
-                 ('synonym', self.word_augmenter),
-             ]
-         }
-
-     @lru_cache(maxsize=1024)
-     def _compute_embedding(self, text: str) -> np.ndarray:
-         """Cached computation of text embedding"""
-         if text in self.embedding_cache:
-             return self.embedding_cache[text]
-         embedding = self.use_model([text])[0].numpy()
-         self.embedding_cache[text] = embedding
-         return embedding
-
-     def _compute_batch_embeddings(self, texts: List[str]) -> np.ndarray:
-         """Compute embeddings for multiple texts at once with hardware optimization"""
-         # Check cache first
-         uncached_texts = [t for t in texts if t not in self.embedding_cache]
-         if uncached_texts:
-             embeddings = self.use_model(uncached_texts).numpy()
-             # Update cache
-             for text, embedding in zip(uncached_texts, embeddings):
-                 self.embedding_cache[text] = embedding
-
-         # Return all embeddings (from cache or newly computed)
-         return np.array([self.embedding_cache[t] for t in texts])
-
-     def _quick_quality_check(self, variation: str, original: str) -> bool:
-         """
-         Preliminary quality check while maintaining reasonable pass rates
-         """
-         if self.config.debug:
-             print(f"\nQuick check for variation: {variation}")
-
-         orig_len = len(original.split())
-         var_len = len(variation.split())
-
-         # For very short texts (<= 3 words), still allow more variation
-         if orig_len <= 3:
-             if var_len > orig_len * 3:
-                 if self.config.debug:
-                     print(f"Failed length check (short text): {var_len} vs {orig_len}")
-                 return False
-         else:
-             if var_len > orig_len * 2:
-                 if self.config.debug:
-                     print(f"Failed length check (long text): {var_len} vs {orig_len}")
-                 return False
-
-         # Adjust content overlap check based on length
-         stop_words = {'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'is', 'are', 'that', 'this', 'will', 'can'}
-         orig_words = set(w.lower() for w in original.split() if w.lower() not in stop_words)
-         var_words = set(w.lower() for w in variation.split() if w.lower() not in stop_words)
-
-         # If very short turn (less than 5 words), skip the content overlap check
-         if orig_len >= 5:
-             content_overlap = len(orig_words.intersection(var_words)) / len(orig_words) if orig_words else 0
-             if content_overlap < 0.2:
-                 if self.config.debug:
-                     print(f"Failed content check: overlap {content_overlap:.2f}")
-                 return False
-         else:
-             if self.config.debug:
-                 print("Short turn detected (<5 words), skipping content overlap check")
-
-         if self.config.debug:
-             print("Passed all quick checks")
-         return True
-
-     def _filter_variations_batch(self, variations: List[str], context: List[str], original_turn: str) -> List[str]:
-         """
-         Filter variations using batched computations with detailed logging
-         """
-         if not variations:
-             return []
-
-         if self.config.debug:
-             print(f"\nStarting filtration of {len(variations)} variations")
-             print(f"Context length: {len(context)}")
-             print(f"Original turn: {original_turn}")
-
-         words = original_turn.split()
-         orig_len = len(words)
-
-         # If very short text, consider adjusting thresholds
-         is_very_short = orig_len < 5
-
-         if len(words) < 3:
-             if self.config.debug:
-                 print("Short text detected, using predefined variations")
-             short_text_variations = self._augment_short_text({'text': original_turn, 'speaker': ''})
-             return [var['text'] for var in short_text_variations]
-
-         # If this is the first turn (no context), be more lenient
-         if not context:
-             preliminary_filtered = variations
-             if self.config.debug:
-                 print("First turn - skipping preliminary filtering")
-         else:
-             # Quick preliminary filtering against original turn
-             preliminary_filtered = []
-             for var in variations:
-                 passed = self._quick_quality_check(var, original_turn)
-                 if self.config.debug:
-                     print(f"\nVariation: {var}")
-                     print(f"Passed quick check: {passed}")
-                 if passed:
-                     preliminary_filtered.append(var)
-
-         if self.config.debug:
-             print(f"Variations after quick check: {len(preliminary_filtered)}")
-
-         if not preliminary_filtered:
-             return []
-
-         # Compute embeddings for original and variations
-         original_embedding = self._compute_embedding(original_turn)
-         variation_embeddings = self._compute_batch_embeddings(preliminary_filtered)
-
-         # Compute similarities
-         sims = cosine_similarity([original_embedding], variation_embeddings)[0]
-
-         # If very short turn, slightly lower the semantic similarity threshold
-         dynamic_sem_threshold = self.semantic_similarity_threshold
-         if is_very_short:
-             dynamic_sem_threshold = max(0.7, self.semantic_similarity_threshold - 0.05)
-
-         # Filter by semantic similarity threshold
-         refined_filtered = []
-         for var, sim in zip(preliminary_filtered, sims):
-             if sim >= dynamic_sem_threshold:
-                 refined_filtered.append(var)
-             else:
-                 if self.config.debug:
-                     print(f"Variation '{var}' discarded due to low semantic similarity: {sim:.3f}")
-
-         if not refined_filtered:
-             return []
-
-         # Relax context coherence thresholds further if desired
-         # We already have min_similarity = 0.1, min_coherence = 0.05
-         # Let's lower them slightly more if the turn is very short:
-         if is_very_short:
-             min_similarity = 0.05
-             min_coherence = 0.02
-         else:
-             min_similarity = 0.1
-             min_coherence = 0.05
-
-         # Only use last turn for coherence
-         recent_context = [context[-1]] if context else []
-         context_text = ' '.join(recent_context) if recent_context else ''
-
-         if context_text:
-             if self.config.debug:
-                 print(f"\nContext text: {context_text}")
-
-             all_texts = [context_text] + refined_filtered
-             all_embeddings = self._compute_batch_embeddings(all_texts)
-
-             context_embedding = all_embeddings[0]
-             variation_embeddings = all_embeddings[1:]
-
-             # Vectorized similarity computation
-             context_similarities = cosine_similarity([context_embedding], variation_embeddings)[0]
-
-             # Response coherence check
-             if recent_context:
-                 prev_embedding = self._compute_embedding(recent_context[-1])
-                 response_coherence = cosine_similarity([prev_embedding], variation_embeddings)[0]
-             else:
-                 response_coherence = np.ones_like(context_similarities)
-
-             filtered_variations = []
-             for i, (variation, sim, coh) in enumerate(zip(
-                     refined_filtered, context_similarities, response_coherence)):
-                 combined_score = (
-                     self.config.context_similarity_weight * abs(sim) +
-                     self.config.response_coherence_weight * abs(coh)
-                 )
-
-                 if self.config.debug:
-                     print(f"\nVariation: {variation}")
-                     print(f"Context similarity: {sim:.3f}")
-                     print(f"Response coherence: {coh:.3f}")
-                     print(f"Combined score: {combined_score:.3f}")
-
-                 # Accept if EITHER score is good enough
-                 if (combined_score >= min_similarity or abs(coh) >= min_coherence):
-                     filtered_variations.append(variation)
-                     if self.config.debug:
-                         print("ACCEPTED")
-                 else:
-                     if self.config.debug:
-                         print("REJECTED")
-
-                 # If we have enough variations, stop
-                 if len(filtered_variations) >= self.config.max_variations_per_turn:
-                     break
-         else:
-             filtered_variations = refined_filtered[:self.config.max_variations_per_turn]
-
-         if self.config.debug:
-             print(f"\nFinal filtered variations: {len(filtered_variations)}")
-
-         return filtered_variations
-
-     def _generate_variations_progressive(self, text: str, needed: int) -> List[str]:
-         """
-         Generate variations progressively until we have enough good ones.
-         Adjust paraphraser parameters for closer paraphrases as needed.
-         """
-         variations = set()
-
-         if self.config.debug:
-             print(f"\nAttempting to generate {needed} variations for text: {text}")
-
-         # Fine-tune paraphraser here if needed: fewer beams, less diversity already done
-         for augmenter in self.augmenters['advanced']:
-             if len(variations) >= needed:
-                 break
-
-             try:
-                 if isinstance(augmenter, Paraphraser):
-                     if self.config.debug:
-                         print("Trying paraphrase augmentation...")
-                     new_vars = augmenter.paraphrase(
-                         text,
-                         num_return_sequences=needed-len(variations),
-                         device=self.device if self.use_gpu else None,
-                         num_beams=4,  # even fewer beams for more faithful paraphrases
-                         num_beam_groups=1,
-                         diversity_penalty=0.0
-                     )
-                     if self.config.debug:
-                         print(f"Paraphraser generated {len(new_vars)} variations")
-
-                     valid_vars = [v for v in new_vars if v.strip() and v != text]
-                     variations.update(valid_vars)
-
-                     if self.config.debug:
-                         print(f"Current unique variations: {len(variations)}")
-
-             except Exception as e:
-                 print(f"Error in advanced augmentation: {str(e)}")
-                 continue
-
-         # Try basic augmenters if needed
-         if len(variations) < needed:
-             if self.config.debug:
-                 print("Not enough variations, trying basic augmenters...")
-
-             for aug_type, augmenter in self.augmenters['basic']:
-                 if len(variations) >= needed:
-                     break
-
-                 try:
-                     if self.config.debug:
-                         print(f"Trying {aug_type} augmentation...")
-
-                     new_vars = augmenter.augment(text, n=2)
-                     if isinstance(new_vars, list):
-                         valid_vars = [v for v in new_vars if v.strip() and v != text]
-                         variations.update(valid_vars)
-                     else:
-                         if new_vars.strip() and new_vars != text:
-                             variations.add(new_vars)
-
-                     if self.config.debug:
-                         print(f"After {aug_type}, total variations: {len(variations)}")
-
-                 except Exception as e:
-                     print(f"Error in {aug_type} augmentation: {str(e)}")
-                     continue
-
-         variations_list = list(variations)
-
-         if self.config.debug:
-             print(f"Final number of variations generated: {len(variations_list)}")
-             if not variations_list:
-                 print("WARNING: No variations were generated!")
-
-         return variations_list
-
-     def augment_dialogue(self, dialogue: Dict) -> List[Dict]:
-         """
-         Create augmented versions of the dialogue with optimized processing
-         """
-         # Early dialogue length check
-         original_length = len(dialogue['turns'])
-         if original_length > self.config.max_turns_per_dialogue:
-             if self.config.debug:
-                 print(f"Truncating dialogue from {original_length} to {self.config.max_turns_per_dialogue} turns")
-             dialogue['turns'] = dialogue['turns'][:self.config.max_turns_per_dialogue]
-
-         turn_variations = []
-         context = []
-
-         # Process each turn with progressive generation
-         for turn in dialogue['turns']:
-             original_text = turn['text']  # Store original turn text
-             variations = self._generate_variations_progressive(
-                 original_text,
-                 self.config.max_variations_per_turn
-             )
-
-             # Batch filter variations with original text
-             filtered_variations = self._filter_variations_batch(
-                 variations,
-                 context,
-                 original_text  # Pass the original turn text
-             )
-
-             # Create turn variations with speaker info
-             turn_vars = [{'speaker': turn['speaker'], 'text': v} for v in filtered_variations]
-
-             if self.config.debug:
-                 print(f"Turn {len(turn_variations)}: Generated {len(turn_vars)} variations")
-
-             turn_variations.append(turn_vars)
-             context.append(original_text)
-
-         # Generate combinations with sampling
-         augmented_dialogues = self._generate_dialogue_combinations(
-             dialogue['dialogue_id'],
-             turn_variations,
-             dialogue
-         )
-
-         # Add original dialogue
-         result = [{
-             'dialogue_id': f"{dialogue['dialogue_id']}_original",
-             'turns': dialogue['turns']
-         }]
-
-         # Add unique augmentations
-         result.extend(augmented_dialogues[:self.config.augmentation_factor])
-
-         if self.config.debug:
-             print(f"Generated {len(result)-1} unique augmented dialogues")
-
-         return result
-
-     def _variation_score(self, original: str, variation: str) -> float:
-         """
-         Compute a single numeric score for a variation to guide selection.
-         You could use semantic similarity, content preservation, etc.
-         Higher is better.
-         """
-         metrics = self.quality_metrics.compute_metrics(original, variation)
-         # Example: Primarily semantic similarity, with a slight boost for content preservation
-         # Adjust as needed.
-         score = metrics['semantic_similarity'] * 0.7 + metrics['content_preservation'] * 0.3
-         return score
-
-     def _dialogue_quality_score(self, dialogue: Dict, original_dialogue: Dict) -> float:
-         """
-         Compute a quality score for the entire augmented dialogue.
-         For example, average semantic similarity of turns to the original turns.
-         This is done after the dialogue is formed.
-         """
-         original_texts = [t['text'] for t in original_dialogue['turns']]
-         aug_texts = [t['text'] for t in dialogue['turns']]
-
-         # Compute semantic similarity turn-by-turn and average it
-         scores = []
-         for orig, aug in zip(original_texts, aug_texts):
-             # Simple semantic similarity for scoring
-             emb_orig = self._compute_embedding(orig)
-             emb_aug = self._compute_embedding(aug)
-             sim = (emb_orig @ emb_aug) / (np.linalg.norm(emb_orig)*np.linalg.norm(emb_aug))
-             scores.append(sim)
-
-         # Could also incorporate diversity checks, content overlap, etc.
-         return float(np.mean(scores)) if scores else 0.0
-
-     def _generate_dialogue_combinations(self, dialogue_id: str, turn_variations: List[List[Dict]], original_dialogue: Dict) -> List[Dict]:
-         """
-         Generate dialogue combinations using a more controlled approach:
-         - Include the original turn as a fallback variation for each turn.
-         - Sort variations by a quality score.
-         - Ensure a balanced augmentation by requiring at least some turns to be augmented.
-         - Over-generate and then select top dialogues by quality.
-         """
-         # Over-generate factor: create more candidates than needed
-         over_generate_factor = self.config.augmentation_factor * 2
-
-         # Add the original turn as a fallback variation for each turn if not present
-         for i, turn_variants in enumerate(turn_variations):
-             original_turn_text = None
-             # Check if we previously stored original turn text with a marker or just use the original dialogue
-             # If you previously used "|ORIGINAL|" marker, handle it here. Otherwise, just get from original_dialogue.
-             original_turn_text = original_dialogue['turns'][i]['text']
-
-             # Add the original turn as a variation if not already included
-             if not any(v['text'] == original_turn_text for v in turn_variants):
-                 turn_variants.append({
-                     'speaker': original_dialogue['turns'][i]['speaker'],
-                     'text': original_turn_text
-                 })
-
-             # Sort variations by score
-             original_text = original_dialogue['turns'][i]['text']
-             turn_variants.sort(key=lambda v: self._variation_score(original_text, v['text']), reverse=True)
-
-         augmented_dialogues = []
-         used_combinations = set()
-
-         def generate_candidates(current_turns=None, turn_index=0):
-             if current_turns is None:
-                 current_turns = []
-
-             if len(augmented_dialogues) >= over_generate_factor:
-                 return
-
-             if turn_index == len(turn_variations):
-                 # Completed a candidate dialogue
-                 dialogue_fingerprint = " | ".join(turn['text'] for turn in current_turns)
-                 if dialogue_fingerprint not in used_combinations:
-                     used_combinations.add(dialogue_fingerprint)
-                     # Check if we have enough augmented turns
-                     aug_count = sum(1 for orig, curr in zip(original_dialogue['turns'], current_turns)
-                                     if orig['text'] != curr['text'])
-                     # Require at least half the turns to be augmented, for example
-                     if aug_count >= max(1, len(turn_variations)//2):
-                         augmented_dialogues.append({
-                             'dialogue_id': f"{dialogue_id}_aug_{len(augmented_dialogues)}",
-                             'turns': current_turns.copy()
-                         })
-                 return
-
-             turn_candidates = turn_variations[turn_index]
-
-             # If no variations are available for this turn, let's just return without error.
-             # Normally, this shouldn't happen since we always add the original turn above.
-             if not turn_candidates:
-                 # If you want to at least have the original turn, add it now:
-                 original_text = original_dialogue['turns'][turn_index]['text']
-                 turn_candidates.append({
-                     'speaker': original_dialogue['turns'][turn_index]['speaker'],
-                     'text': original_text
-                 })
-
-             # After the fallback, if still empty for some reason, just return.
-             if not turn_candidates:
-                 return
-
-             # Example strategy:
-             # 1. Always try the top variation (most semantically similar).
-             # 2. If available and allowed, pick a mid-ranked variation for diversity.
-             # 3. Include the original turn if not selected yet.
-
-             num_vars = min(self.config.max_sampled_variations, len(turn_candidates))
-
-             # Always include top variation
-             candidates_to_pick = [turn_candidates[0]]
-
-             # If we have more than 2 variations and can pick more, add a middle variation for diversity
-             if len(turn_candidates) > 2 and num_vars > 1:
-                 mid_index = len(turn_candidates)//2
-                 candidates_to_pick.append(turn_candidates[mid_index])
-
-             # If we still have room for another variation, try adding the original turn if not included
-             if num_vars > len(candidates_to_pick):
-                 original_turn_text = original_dialogue['turns'][turn_index]['text']
-                 orig_candidate = next((v for v in turn_candidates if v['text'] == original_turn_text), None)
-                 if orig_candidate and orig_candidate not in candidates_to_pick:
-                     candidates_to_pick.append(orig_candidate)
-
-             # Shuffle candidates to produce different dialogues
-             np.random.shuffle(candidates_to_pick)
-
-             for variation in candidates_to_pick:
-                 if len(augmented_dialogues) >= over_generate_factor:
-                     return
-                 current_turns.append(variation)
-                 generate_candidates(current_turns, turn_index + 1)
-                 current_turns.pop()
-
-         try:
-             generate_candidates()
-         except Exception as e:
-             print(f"Error in dialogue generation: {str(e)}")
-             return []
-
-         # Over-generated set of augmented dialogues is now available
-         # Let's score them and pick the top ones
-         scored_dialogues = []
-         for d in augmented_dialogues:
-             score = self._dialogue_quality_score(d, original_dialogue)
-             scored_dialogues.append((score, d))
-
-         scored_dialogues.sort(key=lambda x: x[0], reverse=True)
-         # Pick top `augmentation_factor` dialogues
-         final_dialogues = [d for _, d in scored_dialogues[:self.config.augmentation_factor]]
-
-         return final_dialogues
-
-     # def _generate_dialogue_combinations(self, dialogue_id: str, turn_variations: List[List[Dict]]) -> List[Dict]:
-     #     """
-     #     Generate dialogue combinations using sampling
-     #     """
-     #     augmented_dialogues = []
-     #     used_combinations = set()
-
-     #     def generate_dialogues(current_turns=None, turn_index=0):
-     #         if current_turns is None:
-     #             current_turns = []
-
-     #         if len(augmented_dialogues) >= self.config.augmentation_factor:
-     #             return
-
-     #         if turn_index == len(turn_variations):
-     #             dialogue_fingerprint = " | ".join(turn['text'] for turn in current_turns)
-     #             if dialogue_fingerprint not in used_combinations:
-     #                 used_combinations.add(dialogue_fingerprint)
-     #                 augmented_dialogues.append({
-     #                     'dialogue_id': f"{dialogue_id}_aug_{len(augmented_dialogues)}",
-     #                     'turns': current_turns.copy()
-     #                 })
-     #             return
-
-     #         variations = list(turn_variations[turn_index])
-     #         np.random.shuffle(variations)
-
-     #         for variation in variations[:self.config.max_sampled_variations]:
-     #             if len(augmented_dialogues) >= self.config.augmentation_factor:
-     #                 return
-     #             current_turns.append(variation)
-     #             generate_dialogues(current_turns, turn_index + 1)
-     #             current_turns.pop()
-
-     #     try:
-     #         generate_dialogues()
-     #     except Exception as e:
-     #         print(f"Error in dialogue generation: {str(e)}")
-     #         return []
-
-     #     return augmented_dialogues
-
-     def _is_dialogue_duplicate(self, dialogue1: Dict, dialogue2: Dict) -> bool:
-         """
-         Check if two dialogues are duplicates.
-         """
-         text1 = " ".join(turn['text'] for turn in dialogue1['turns'])
-         text2 = " ".join(turn['text'] for turn in dialogue2['turns'])
-         return text1 == text2
-
-     def _augment_short_text(self, turn: Dict) -> List[Dict]:
-         """
-         Special handling for very short texts with predefined variations.
-         If predefined variations are found, return them directly.
-         Otherwise, produce simple punctuation and capitalization variants.
-         Skip heavy quality checks for efficiency. These variations are safe and minimal.
-         """
-         text = turn['text']
-         common_variations = {
-             'goodbye': [
-                 'Bye!', 'Farewell!', 'See you!', 'Take care!',
-                 'Goodbye!', 'Bye for now!', 'Until next time!'
-             ],
-             'hello': [
-                 'Hi!', 'Hey!', 'Hello!', 'Greetings!',
-                 'Good day!', 'Hi there!', 'Hello there!'
-             ],
-             'yes': [
-                 'Yes!', 'Correct!', 'Indeed!', 'Absolutely!',
-                 'That\'s right!', 'Definitely!', 'Sure!'
-             ],
-             'no': [
-                 'No!', 'Nope!', 'Not at all!', 'Negative!',
-                 'Unfortunately not!', 'I\'m afraid not!'
-             ],
-             'thanks': [
-                 'Thank you!', 'Thanks a lot!', 'Many thanks!',
-                 'I appreciate it!', 'Thank you so much!'
-             ],
-             'ok': [
-                 'Okay!', 'Alright!', 'Sure!', 'Got it!',
-                 'Understood!', 'Fine!', 'Great!', 'Perfect!',
-                 'That works!', 'Sounds good!'
-             ],
-             'good': [
-                 'Great!', 'Excellent!', 'Perfect!', 'Wonderful!',
-                 'Fantastic!', 'Amazing!', 'Terrific!'
-             ]
-         }
-
-         text_lower = text.lower().rstrip('!.,?')
-         # Check if text matches any predefined category
-         variations = []
-         for key, predefined_vars in common_variations.items():
-             if key in text_lower or text_lower in key:
-                 variations.extend(predefined_vars)
-
-         if not variations:
-             # Generate simple punctuation and capitalization variations if no predefined match
-             base = text.rstrip('!.,?')
-             variations = [
-                 base + '!',
-                 base + '.',
-                 base
-             ]
-
-         # Add capitalization variations
-         capitalized = [v.capitalize() for v in variations if v.capitalize() not in variations]
-         variations.extend(capitalized)
-
-         # Ensure uniqueness
-         unique_variations = list(set(variations))
-
-         # Directly return these variations, as they are minimal and trusted
-         # No further quality checks are needed
-         result_variations = unique_variations[:self.config.augmentation_factor]
-         return [{'speaker': turn['speaker'], 'text': v} for v in result_variations]
-
-     def process_batch(self, batch: List[Dict]) -> List[Dict]:
-         """Process multiple dialogues at once to maximize GPU utilization"""
-         results = []
-
-         # Pre-compute embeddings for all texts in batch
-         all_texts = []
-         text_to_embedding = {}
-
-         for dialogue in batch:
-             for turn in dialogue['turns']:
-                 all_texts.append(turn['text'])
-
-         # Batch compute embeddings
-         if all_texts:
-             embeddings = self._compute_batch_embeddings(all_texts)
-             for text, embedding in zip(all_texts, embeddings):
-                 self.embedding_cache[text] = embedding
-
-         # Process each dialogue using cached embeddings
-         for dialogue in batch:
-             try:
-                 augmented = self.augment_dialogue(dialogue)
-                 results.extend(augmented)
-             except Exception as e:
-                 print(f"Error processing dialogue {dialogue.get('dialogue_id', 'unknown')}: {e}")
-                 continue
-
-         return results
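
The turn filter in _filter_variations_batch accepts a variation when a weighted blend of context similarity and response coherence clears min_similarity (or coherence alone clears min_coherence). A toy calculation with made-up cosine values; the 0.35/0.65 weights are the PipelineConfig defaults:

context_similarity_weight = 0.35
response_coherence_weight = 0.65
sim, coh = 0.42, 0.18  # hypothetical cosine similarities for one variation

combined = context_similarity_weight * abs(sim) + response_coherence_weight * abs(coh)
print(f"combined score = {combined:.3f}")  # 0.264 -> clears min_similarity = 0.1, accepted
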
data_augmentation_code/main.py DELETED
@@ -1,112 +0,0 @@
- """
- CSC525 - Module 8 Option 2 - Joseph Armani
- Description and References in the README.md file.
- """
- import json
- import tensorflow as tf
- from typing import List, Dict
- from data_augmentation.pipeline_config import PipelineConfig
- from data_augmentation.augmentation_processing_pipeline import AugmentationProcessingPipeline
- from data_augmentation.taskmaster_processor import TaskmasterProcessor
- from data_augmentation.schema_guided_dialogue_processor import SchemaGuidedProcessor
-
- def combine_datasets(taskmaster_dialogues: List[Dict],
-                      schema_guided_dialogues: List[Dict]) -> List[Dict]:
-     """
-     Combine dialogues from both datasets into a single list
-
-     Args:
-         taskmaster_dialogues: List of dialogues in pipeline format from Taskmaster
-         schema_guided_dialogues: List of dialogues in pipeline format from Schema-Guided
-
-     Returns:
-         List[Dict]: Combined list of dialogues
-     """
-     # Ensure unique dialogue IDs
-     combined_dialogues = []
-     seen_ids = set()
-     duplicate_count = 0  # Track duplicates for reporting
-
-     for dialogue in taskmaster_dialogues:
-         dialogue_copy = dialogue.copy()
-         dialogue_id = dialogue_copy['dialogue_id']
-         if dialogue_id in seen_ids:
-             duplicate_count += 1
-             dialogue_id = f"taskmaster_{dialogue_id}"
-         seen_ids.add(dialogue_id)
-         dialogue_copy['dialogue_id'] = dialogue_id
-         combined_dialogues.append(dialogue_copy)
-
-     for dialogue in schema_guided_dialogues:
-         dialogue_copy = dialogue.copy()
-         dialogue_id = dialogue_copy['dialogue_id']
-         if dialogue_id in seen_ids:
-             duplicate_count += 1
-             dialogue_id = f"schema_guided_{dialogue_id}"
-         seen_ids.add(dialogue_id)
-         dialogue_copy['dialogue_id'] = dialogue_id
-         combined_dialogues.append(dialogue_copy)
-
-     # Log the results
-     print(f"Combine Datasets: Found and resolved {duplicate_count} duplicate dialogue IDs.")
-     print(f"Combine Datasets: Total dialogues combined: {len(combined_dialogues)}")
-
-     return combined_dialogues
-
- def main():
-     # Configuration
-     config = PipelineConfig(
-         min_length=1,
-         max_length=512,
-         batch_size=32 if tf.config.list_physical_devices('GPU') else 16,
-         max_turns_per_dialogue=12,
-         max_variations_per_turn=4,
-         max_sampled_variations=2,
-         context_window_size=4,
-         max_complexity_threshold=100,
-         use_cache=False,
-         debug=True,
-         allowed_speakers=['user', 'assistant'],
-         required_fields=['dialogue_id', 'turns']
-     )
-
-     try:
-         # Set max_examples (Optional[int]) for testing
-         max_examples = 5
-
-         # Initialize and load Taskmaster dataset
-         print("Loading Taskmaster dataset")
-         taskmaster_processor = TaskmasterProcessor(config, use_ontology=False)
-         taskmaster_dialogues = taskmaster_processor.load_dataset('./datasets/taskmaster', max_examples=max_examples)
-         taskmaster_pipeline_dialogues = taskmaster_processor.convert_to_pipeline_format(taskmaster_dialogues)
-         print(f"Processed Taskmaster dialogues: {len(taskmaster_pipeline_dialogues)}")
-
-         # Initialize and load Schema-Guided dataset
-         print("Loading Schema-Guided dataset")
-         schema_dialogue_processor = SchemaGuidedProcessor(config)
-         schema_dialogues = schema_dialogue_processor.load_dataset('./datasets/schema_guided', max_examples=max_examples)
-         schema_pipeline_dialogues = schema_dialogue_processor.convert_to_pipeline_format(schema_dialogues)
-         print(f"Processed Schema-Guided dialogues: {len(schema_pipeline_dialogues)}")
-
-         # Combine datasets
-         print("Combining datasets")
-         combined_dialogues = combine_datasets(taskmaster_pipeline_dialogues, schema_pipeline_dialogues)
-         print(f"Combined Dialogues: {len(combined_dialogues)}")
-
-         if not combined_dialogues:
-             print("Combined dialogues are empty. Exiting.")
-             return
-
-         # Process through augmentation pipeline
-         print("Processing combined dataset")
-         pipeline = AugmentationProcessingPipeline(config)
-         output_path = pipeline.process_dataset(combined_dialogues)
-         print(f"Processing complete. Results saved to {output_path}")
-         pipeline.cleanup()
-
-     except Exception as e:
-         print(f"Processing failed: {str(e)}")
-         raise
-
- if __name__ == "__main__":
-     main()
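
combine_datasets resolves dialogue-ID collisions by prefixing rather than dropping entries. A tiny self-contained illustration of that rule, with invented IDs:

seen = {"42"}           # hypothetical ID already taken
dialogue_id = "42"
if dialogue_id in seen:
    dialogue_id = f"taskmaster_{dialogue_id}"  # same rename rule as above
seen.add(dialogue_id)
print(dialogue_id)      # taskmaster_42 -- both dialogues are kept
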
data_augmentation_code/paraphraser.py DELETED
@@ -1,42 +0,0 @@
- from transformers import (
-     AutoTokenizer,
-     AutoModelForSeq2SeqLM,
- )
-
- class Paraphraser:
-     def __init__(self, model_name='humarin/chatgpt_paraphraser_on_T5_base'):
-         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-         self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
-         self.model.eval()
-
-     def paraphrase(self, text, num_return_sequences=5, num_beams=5,
-                    num_beam_groups=1, diversity_penalty=0.0, device=None):
-         try:
-             input_text = "paraphrase: " + text + " </s>"
-             encoding = self.tokenizer.encode_plus(input_text, return_tensors="pt")
-
-             # Move input tensors to specified device if provided
-             if device is not None:
-                 input_ids = encoding["input_ids"].to(device)
-                 self.model = self.model.to(device)
-             else:
-                 input_ids = encoding["input_ids"]
-
-             outputs = self.model.generate(
-                 input_ids=input_ids,
-                 max_length=256,
-                 num_beams=num_beams,
-                 num_beam_groups=num_beam_groups,
-                 num_return_sequences=num_return_sequences,
-                 diversity_penalty=diversity_penalty,
-                 early_stopping=True
-             )
-
-             # Move outputs back to CPU for tokenizer decoding
-             outputs = outputs.cpu() if device is not None else outputs
-             paraphrases = [self.tokenizer.decode(output, skip_special_tokens=True)
-                            for output in outputs]
-             return paraphrases
-         except Exception as e:
-             print(f"Error in paraphrasing: {e}")
-             return []
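
A minimal usage sketch, assuming the deleted module is still on the import path; the humarin/chatgpt_paraphraser_on_T5_base weights download from the Hub on first use:

from data_augmentation.paraphraser import Paraphraser

p = Paraphraser()
for variant in p.paraphrase("What time does the store open?", num_return_sequences=3, num_beams=4):
    print(variant)  # returns [] on error rather than raising
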
data_augmentation_code/pipeline_config.py DELETED
@@ -1,57 +0,0 @@
- from dataclasses import dataclass
- from typing import List
-
- @dataclass
- class PipelineConfig:
-     """
-     Config for the pipeline
-     """
-     # Validation settings
-     min_length: int = 1
-     max_length: int = 512
-     min_tokens: int = 1
-     max_tokens: int = 128
-
-     allowed_speakers: List[str] = None
-     required_fields: List[str] = None
-
-     # Text augmentation settings
-     augmentation_factor: int = 4
-     augmentation_techniques: List[str] = None
-
-     max_turns_per_dialogue: int = 6
-     max_variations_per_turn: int = 3
-     max_sampled_variations: int = 2
-     max_complexity_threshold: int = 100
-     complexity_reduction_turns: int = 4
-
-     # Quality thresholds
-     semantic_similarity_threshold: float = 0.45
-     grammar_error_threshold: int = 2
-     rouge1_f1_threshold: float = 0.30
-     rouge2_f1_threshold: float = 0.15
-
-     # Response coherence thresholds
-     min_response_coherence: float = 0.3
-     context_similarity_weight: float = 0.35
-     response_coherence_weight: float = 0.65
-
-     # Performance settings
-     batch_size: int = 32
-     use_cache: bool = True
-     debug: bool = False
-
-     context_window_size: int = 4
-
-     def __post_init__(self):
-         if self.allowed_speakers is None:
-             self.allowed_speakers = ['user', 'assistant']
-         if self.required_fields is None:
-             self.required_fields = ['dialogue_id', 'turns']
-         if self.augmentation_techniques is None:
-             self.augmentation_techniques = ['paraphrase', 'back_translation']
-
-         # Validate weights sum to 1.0
-         if abs((self.context_similarity_weight + self.response_coherence_weight) - 1.0) > 1e-6:
-             raise ValueError("Context similarity and response coherence weights must sum to 1.0")
-
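
The __post_init__ guard rejects weight pairs that do not sum to 1.0. A short sketch of that check firing, assuming the deleted module is still importable:

from data_augmentation.pipeline_config import PipelineConfig

PipelineConfig(context_similarity_weight=0.35, response_coherence_weight=0.65)  # ok
try:
    PipelineConfig(context_similarity_weight=0.5, response_coherence_weight=0.6)
except ValueError as e:
    print(e)  # Context similarity and response coherence weights must sum to 1.0
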
data_augmentation_code/quality_metrics.py DELETED
@@ -1,47 +0,0 @@
- import tensorflow_hub as hub
- import spacy
- from sklearn.metrics.pairwise import cosine_similarity
- from typing import Dict
- from data_augmentation.pipeline_config import PipelineConfig
-
- class QualityMetrics:
-     """
-     Quality metrics focusing on semantic similarity and basic lexical stats.
-     """
-     def __init__(self, config: PipelineConfig):
-         self.config = config
-         self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
-         self.nlp = spacy.load('en_core_web_md')
-
-     def compute_semantic_similarity(self, text1: str, text2: str) -> float:
-         embeddings = self.use_model([text1, text2])
-         emb1, emb2 = embeddings[0].numpy(), embeddings[1].numpy()
-         return cosine_similarity([emb1], [emb2])[0][0]
-
-     def compute_metrics(self, original: str, augmented: str) -> Dict[str, float]:
-         metrics = {}
-         # Semantic similarity
-         embeddings = self.use_model([original, augmented])
-         emb_orig, emb_aug = embeddings[0].numpy(), embeddings[1].numpy()
-         metrics['semantic_similarity'] = cosine_similarity([emb_orig], [emb_aug])[0][0]
-
-         # Lexical diversity & content preservation
-         doc_orig = self.nlp(original)
-         doc_aug = self.nlp(augmented)
-
-         aug_tokens = [token.text.lower() for token in doc_aug]
-         metrics['type_token_ratio'] = len(set(aug_tokens)) / max(len(aug_tokens), 1)
-
-         orig_content = {token.text.lower() for token in doc_orig if not token.is_stop}
-         aug_content = {token.text.lower() for token in doc_aug if not token.is_stop}
-         if len(orig_content) == 0:
-             metrics['content_preservation'] = 1.0 if len(aug_content) == 0 else 0.0
-         else:
-             metrics['content_preservation'] = len(orig_content.intersection(aug_content)) / len(orig_content)
-
-         # Length ratio
-         orig_words = len(original.split())
-         aug_words = len(augmented.split())
-         metrics['length_ratio'] = aug_words / max(orig_words, 1)
-
-         return metrics
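
A dependency-free illustration of two of these metrics, using a naive whitespace split and a tiny stop-word set in place of spaCy tokens (the texts and values are illustrative, so only approximate):

original = "book a table for two tonight"
augmented = "reserve a table for two people tonight"

# content preservation: fraction of original content words kept
orig_content = set(original.split()) - {"a", "for"}
aug_content = set(augmented.split()) - {"a", "for"}
content_preservation = len(orig_content & aug_content) / len(orig_content)

# type-token ratio: unique tokens over total tokens in the augmented text
aug_tokens = augmented.split()
type_token_ratio = len(set(aug_tokens)) / max(len(aug_tokens), 1)

print(f"content_preservation = {content_preservation:.2f}")  # 0.75
print(f"type_token_ratio = {type_token_ratio:.2f}")          # 1.00
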
data_augmentation_code/schema_guided_dialogue_processor.py DELETED
@@ -1,192 +0,0 @@
- from dataclasses import dataclass, field
- from typing import List, Dict, Optional, Any
- import json
- import glob
- from pathlib import Path
- from data_augmentation.pipeline_config import PipelineConfig
-
- @dataclass
- class SchemaGuidedDialogue:
-     """
-     Structured representation of a Schema-Guided dialogue
-     """
-     dialogue_id: str
-     service_name: str
-     service_description: Optional[str]
-     schema: Dict[str, Any]
-     turns: List[Dict[str, Any]]
-     original_metadata: Dict[str, Any] = field(default_factory=dict)
-
- class SchemaGuidedProcessor:
-     """
-     Handles processing and preparation of Schema-Guided dataset dialogues
-     """
-     def __init__(self, config: PipelineConfig):
-         self.config = config
-         self.services = set()
-         self.domains = set()
-         self.schemas = {}
-
-     def load_dataset(self, base_dir, max_examples: Optional[int] = None) -> List[SchemaGuidedDialogue]:
-         """
-         Load and parse the Schema-Guided Dialogue dataset.
-
-         Args:
-             base_dir: Directory containing schema.json and dialogues_*.json files
-             max_examples: Optional cap on the number of dialogues to load
-         """
-         # Define schema and dialogue file patterns
-         schema_file = Path(base_dir, "schema.json")
-         dialogue_files_pattern = str(Path(base_dir, "dialogues_*.json"))
-
-         # Check for schema file
-         if not schema_file.exists():
-             raise FileNotFoundError(f"Schema file not found at {schema_file}")
-
-         # Load schema
-         self.schemas = self._load_schemas(schema_file)
-
-         # Find and validate dialogue files
-         dialogue_files = glob.glob(dialogue_files_pattern)
-         if not dialogue_files:
-             raise FileNotFoundError(f"No dialogue files found matching pattern {dialogue_files_pattern}")
-
-         print(f"Found {len(dialogue_files)} dialogue files to process.")
-
-         # Process all dialogues
-         processed_dialogues = []
-         for file_path in dialogue_files:
-             with open(file_path, 'r', encoding='utf-8') as f:
-                 raw_dialogues = json.load(f)
-
-             for dialogue in raw_dialogues:
-                 processed_dialogues.append(self._process_single_dialogue(dialogue))
-
-             if max_examples and len(processed_dialogues) >= max_examples:
-                 break
-
-         return processed_dialogues
-
-     def _process_single_dialogue(self, dialogue: Dict[str, Any]) -> SchemaGuidedDialogue:
-         """
-         Process a single dialogue JSON object into a SchemaGuidedDialogue object.
-         """
-         dialogue_id = str(dialogue.get("dialogue_id", ""))
-         services = dialogue.get("services", [])
-         service_name = services[0] if services else None
-         schema = self.schemas.get(service_name, {})
-         service_description = schema.get("description", "")
-
-         # Process turns
-         turns = self._process_turns(dialogue.get("turns", []))
-
-         # Store metadata
-         metadata = {
-             "services": services,
-             "original_id": dialogue_id,
-         }
-
-         return SchemaGuidedDialogue(
-             dialogue_id=f"schema_guided_{dialogue_id}",
-             service_name=service_name,
-             service_description=service_description,
-             schema=schema,
-             turns=turns,
-             original_metadata=metadata,
-         )
-
-     def _validate_schema(self, schema: Dict[str, Any]) -> bool:
-         """
-         Validate a schema
-         """
-         required_keys = {"service_name", "description", "slots", "intents"}
-         missing_keys = required_keys - schema.keys()
-         if missing_keys:
-             print(f"Warning: Missing keys in schema {schema.get('service_name', 'unknown')}: {missing_keys}")
-             return False
-         return True
-
-     def _load_schemas(self, schema_path: str) -> Dict[str, Any]:
-         """
-         Load and process service schemas
-         """
-         with open(schema_path, 'r', encoding='utf-8') as f:
-             schemas = json.load(f)
-
-         # Validate and index schemas
-         return {
-             schema["service_name"]: schema for schema in schemas if self._validate_schema(schema)
-         }
-
-     def _process_turns(self, turns: List[Dict]) -> List[Dict]:
-         """
-         Process dialogue turns into standardized format
-         """
-         processed_turns = []
-
-         for turn in turns:
-             try:
-                 # Map speakers to standard format
-                 speaker = 'assistant' if turn.get('speaker') == 'SYSTEM' else 'user'
-
-                 # Extract utterance and clean it
-                 text = turn.get('utterance', '').strip()
-
-                 # Extract frames and dialogue acts
-                 frames = turn.get('frames', [])
-                 acts = []
-                 slots = []
-
-                 for frame in frames:
-                     if 'actions' in frame:
-                         acts.extend(frame['actions'])
-                     if 'slots' in frame:
-                         slots.extend(frame['slots'])
-
-                 # Create the processed turn
-                 processed_turn = {
-                     'speaker': speaker,
-                     'text': text,
-                     'original_speaker': turn.get('speaker', ''),
-                     'dialogue_acts': acts,
-                     'slots': slots,
-                     'metadata': {k: v for k, v in turn.items()
-                                  if k not in {'speaker', 'utterance', 'frames'}}
-                 }
-
-                 processed_turns.append(processed_turn)
-             except Exception as e:
-                 print(f"Error processing turn: {str(e)}")
-                 continue
-
-         return processed_turns
-
-     def convert_to_pipeline_format(self, schema_dialogues: List[SchemaGuidedDialogue]) -> List[Dict]:
-         """
-         Convert SchemaGuidedDialogues to the format expected by the ProcessingPipeline
-         """
-         pipeline_dialogues = []
-
-         for dialogue in schema_dialogues:
-             # Convert turns to the expected format
-             processed_turns = [
-                 {"speaker": turn["speaker"], "text": turn["text"]}
-                 for turn in dialogue.turns if turn["text"].strip()
-             ]
-
-             # Create dialogue in pipeline format
-             pipeline_dialogue = {
-                 'dialogue_id': dialogue.dialogue_id,
-                 'turns': processed_turns,
-                 'metadata': {
-                     'service_name': dialogue.service_name,
-                     'service_description': dialogue.service_description,
-                     'schema': dialogue.schema,
-                     **dialogue.original_metadata
-                 }
-             }
-
-             pipeline_dialogues.append(pipeline_dialogue)
-
-         return pipeline_dialogues
-
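
A minimal sketch of the load-and-convert flow the removed SchemaGuidedProcessor supported; the raw_datasets/schema_guided path and the example cap are assumed for illustration.

# Hypothetical usage of the removed SchemaGuidedProcessor; the base
# directory and max_examples value are illustrative assumptions.
from data_augmentation.pipeline_config import PipelineConfig

processor = SchemaGuidedProcessor(config=PipelineConfig())
dialogues = processor.load_dataset("raw_datasets/schema_guided", max_examples=100)
pipeline_dialogues = processor.convert_to_pipeline_format(dialogues)
print(f"{len(pipeline_dialogues)} dialogues ready for the augmentation pipeline")
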
data_augmentation_code/taskmaster_processor.py DELETED
@@ -1,192 +0,0 @@
- from dataclasses import dataclass, field
- from typing import List, Dict, Optional, Any
- import json
- import re
- from pathlib import Path
- from data_augmentation.pipeline_config import PipelineConfig
-
- @dataclass
- class TaskmasterDialogue:
-     """
-     Structured representation of a Taskmaster dialogue
-     """
-     conversation_id: str
-     instruction_id: Optional[str]
-     scenario: Optional[str]
-     domain: Optional[str]
-     turns: List[Dict[str, Any]]
-     original_metadata: Dict[str, Any] = field(default_factory=dict)
-
-     def __str__(self):
-         return f"TaskmasterDialogue(conversation_id={self.conversation_id}, turns={len(self.turns)} turns)"
-
-     def validate(self) -> bool:
-         return bool(self.conversation_id and isinstance(self.turns, list))
-
- class TaskmasterProcessor:
-     """
-     Handles processing and preparation of Taskmaster dataset dialogues
-     """
-     config: PipelineConfig
-     use_ontology: bool = False  # Whether to load and use ontology
-     ontology: Optional[Dict[str, Any]] = None  # Holds ontology data if loaded
-     domains: set  # Tracks unique domains (plain class, not a dataclass, so no field(); set in __init__)
-     scenarios: set  # Tracks unique scenarios (set in __init__)
-
-     def __init__(self, config: PipelineConfig, use_ontology: bool = False):
-         self.config = config
-         self.use_ontology = use_ontology
-         self.ontology = None
-         self.domains = set()
-         self.scenarios = set()
-
-     def load_dataset(self, base_dir: str, max_examples: Optional[int] = None) -> List[TaskmasterDialogue]:
-         """
-         Load and parse Taskmaster JSON dataset.
-         Handles self-dialogs, woz-dialogs, and ontology files.
-         """
-         required_files = {
-             "self-dialogs": "self-dialogs.json",
-             "woz-dialogs": "woz-dialogs.json",
-             "ontology": "ontology.json",
-         }
-
-         # Check for required files
-         missing_files = [name for name, path in required_files.items() if not Path(base_dir, path).exists()]
-         if missing_files:
-             raise FileNotFoundError(f"Missing required taskmaster files: {missing_files}")
-
-         # Load ontology
-         ontology_path = Path(base_dir, required_files['ontology'])
-         with open(ontology_path, 'r', encoding='utf-8') as f:
-             self.ontology = json.load(f)
-
-         processed_dialogues = []
-         for file_key in ["self-dialogs", "woz-dialogs"]:
-             file_path = Path(base_dir, required_files[file_key])
-             with open(file_path, 'r', encoding='utf-8') as f:
-                 raw_data = json.load(f)
-
-             for dialogue in raw_data:
-                 # Extract core dialogue components
-                 conversation_id = dialogue.get('conversation_id', '')
-                 instruction_id = dialogue.get('instruction_id', None)
-
-                 if 'utterances' in dialogue:
-                     turns = self._process_utterances(dialogue['utterances'])
-                     scenario = dialogue.get('scenario', '')
-                     domain = self._extract_domain(scenario)
-                 else:
-                     turns = []
-                     scenario = ''
-                     domain = ''
-
-                 # Store metadata
-                 metadata = {k: v for k, v in dialogue.items()
-                             if k not in {'conversation_id', 'instruction_id', 'utterances'}}
-
-                 # Create structured dialogue object
-                 processed_dialogue = TaskmasterDialogue(
-                     conversation_id=conversation_id,
-                     instruction_id=instruction_id,
-                     scenario=scenario,
-                     domain=domain,
-                     turns=turns,
-                     original_metadata=metadata
-                 )
-
-                 processed_dialogues.append(processed_dialogue)
-
-                 # Update domain and scenario tracking
-                 if domain:
-                     self.domains.add(domain)
-                 if scenario:
-                     self.scenarios.add(scenario)
-
-                 if max_examples and len(processed_dialogues) >= max_examples:
-                     break
-
-         return processed_dialogues
-
-     def _process_utterances(self, utterances: List[Dict]) -> List[Dict]:
-         """
-         Process utterances into a standardized format
-         """
-         processed_turns = []
-
-         for utterance in utterances:
-             # Map Taskmaster speaker roles to the expected format
-             speaker = 'assistant' if utterance.get('speaker') == 'ASSISTANT' else 'user'
-
-             # Extract and clean the text
-             text = utterance.get('text', '').strip()
-
-             # Extract any segments or annotations if present
-             segments = utterance.get('segments', [])
-
-             # Create the processed turn
-             turn = {
-                 'speaker': speaker,
-                 'text': text,
-                 'original_speaker': utterance.get('speaker', ''),
-                 'segments': segments,
-                 'metadata': {k: v for k, v in utterance.items()
-                              if k not in {'speaker', 'text', 'segments'}}
-             }
-
-             processed_turns.append(turn)
-
-         return processed_turns
-
-     def _extract_domain(self, scenario: str) -> str:
-         """
-         Extract domain from scenario description
-         """
-         domain_patterns = {
-             'restaurant': r'\b(restaurant|dining|food|reservation)\b',
-             'movie': r'\b(movie|cinema|film|ticket)\b',
-             'ride_share': r'\b(ride|taxi|uber|lyft)\b',
-             'coffee': r'\b(coffee|café|cafe|starbucks)\b',
-             'pizza': r'\b(pizza|delivery|order food)\b',
-             'auto': r'\b(car|vehicle|repair|maintenance)\b',
-         }
-
-         scenario_lower = scenario.lower()
-
-         for domain, pattern in domain_patterns.items():
-             if re.search(pattern, scenario_lower):
-                 return domain
-
-         return 'other'
-
-     def convert_to_pipeline_format(self, taskmaster_dialogues: List[TaskmasterDialogue]) -> List[Dict]:
-         """
-         Convert TaskmasterDialogues to the format expected by the ProcessingPipeline
-         """
-         pipeline_dialogues = []
-
-         for dialogue in taskmaster_dialogues:
-             # Convert turns to the expected format
-             processed_turns = []
-             for turn in dialogue.turns:
-                 if turn['text'].strip():  # Skip empty turns
-                     processed_turns.append({
-                         'speaker': turn['speaker'],
-                         'text': turn['text']
-                     })
-
-             # Create dialogue in pipeline format
-             pipeline_dialogue = {
-                 'dialogue_id': dialogue.conversation_id,
-                 'turns': processed_turns,
-                 'metadata': {
-                     'instruction_id': dialogue.instruction_id,
-                     'scenario': dialogue.scenario,
-                     'domain': dialogue.domain,
-                     **dialogue.original_metadata
-                 }
-             }
-
-             pipeline_dialogues.append(pipeline_dialogue)
-
-         return pipeline_dialogues
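
Likewise, a minimal sketch of how the removed TaskmasterProcessor was driven; the base directory and example cap are assumed for illustration.

# Hypothetical usage of the removed TaskmasterProcessor; the base
# directory is an illustrative assumption.
from data_augmentation.pipeline_config import PipelineConfig

processor = TaskmasterProcessor(config=PipelineConfig())
dialogues = processor.load_dataset("raw_datasets/taskmaster", max_examples=100)
valid = [d for d in dialogues if d.validate()]
pipeline_dialogues = processor.convert_to_pipeline_format(valid)
print(f"Domains seen: {sorted(processor.domains)}")
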
prepare_data.py CHANGED
@@ -16,12 +16,12 @@ logger = config_logger(__name__)
 os.environ["TOKENIZERS_PARALLELISM"] = "false"

 def main():
-     MODELS_DIR = 'new_iteration/data_prep_iterative_models'
-     PROCESSED_DATA_DIR = 'new_iteration/processed_outputs'
-     CACHE_DIR = 'new_iteration/cache'
+     MODELS_DIR = 'models'
+     PROCESSED_DATA_DIR = 'processed_outputs'
+     CACHE_DIR = os.path.join(MODELS_DIR, 'query_embeddings_cache')
      TOKENIZER_DIR = os.path.join(MODELS_DIR, 'tokenizer')
      FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
-     TF_RECORD_DIR = 'new_iteration/training_data'
+     TF_RECORD_DIR = 'training_data'
      FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
      JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'taskmaster_dialogues.json')
      CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')
run_chatbot_validation.py CHANGED
@@ -44,7 +44,7 @@ def run_chatbot_validation():
      env = EnvironmentSetup()
      env.initialize()

-     MODEL_DIR = "new_iteration/data_prep_iterative_models"
+     MODEL_DIR = "models"
      FAISS_INDICES_DIR = os.path.join(MODEL_DIR, "faiss_indices")
      FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_production.index")
      FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_test.index")
new_iteration/run_taskmaster_processor.py → run_taskmaster_processor.py RENAMED
File without changes
new_iteration/taskmaster_processor.py → taskmaster_processor.py RENAMED
File without changes
tf_data_pipeline.py CHANGED
@@ -29,7 +29,7 @@ class TFDataPipeline:
      max_length: int = 512,
      neg_samples: int = 10,
      index_type: str = 'IndexFlatIP',
-     faiss_index_file_path: str = 'new_iteration/data_prep_iterative_models/faiss_indices/faiss_index_production.index',
+     faiss_index_file_path: str = 'models/faiss_indices/faiss_index_production.index',
      nlist: int = 100,
      max_retries: int = 3
  ):
unused/build_faiss_index.py DELETED
@@ -1,160 +0,0 @@
- # import os
- # import json
- # from pathlib import Path
-
- # import faiss
- # import numpy as np
- # import tensorflow as tf
- # from transformers import AutoTokenizer, TFAutoModel
- # from tqdm.auto import tqdm
-
- # from chatbot_model import ChatbotConfig, EncoderModel
- # from tf_data_pipeline import TFDataPipeline
- # from logger_config import config_logger
-
- # logger = config_logger(__name__)
- # os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
- # def sanity_check(encoder: EncoderModel, tokenizer: AutoTokenizer, config: ChatbotConfig):
- #     """
- #     Perform a quick sanity check to ensure the model is loaded correctly.
- #     """
- #     sample_response = "This is a test response."
- #     encoded_sample = tokenizer(
- #         [sample_response],
- #         padding=True,
- #         truncation=True,
- #         max_length=config.max_context_token_limit,
- #         return_tensors='tf'
- #     )
-
- #     # Get embedding
- #     sample_embedding = encoder(encoded_sample['input_ids'], training=False).numpy()
-
- #     # Check shape
- #     if sample_embedding.shape[1] != config.embedding_dim:
- #         logger.error(
- #             f"Embedding dimension mismatch: Expected {config.embedding_dim}, "
- #             f"got {sample_embedding.shape[1]}"
- #         )
- #         raise ValueError("Embedding dimension mismatch.")
- #     else:
- #         logger.info("Embedding dimension matches the configuration.")
-
- #     # Check normalization
- #     embedding_norm = np.linalg.norm(sample_embedding, axis=1)
- #     if not np.allclose(embedding_norm, 1.0, atol=1e-5):
- #         logger.error("Embeddings are not properly normalized.")
- #         raise ValueError("Embeddings are not normalized.")
- #     else:
- #         logger.info("Embeddings are properly normalized.")
-
- #     logger.info("Sanity check passed: Model loaded correctly and outputs are as expected.")
-
- # def build_faiss_index():
- #     """
- #     Rebuild the FAISS index by:
- #     1) Loading your config.json
- #     2) Initializing encoder + loading submodule & custom weights
- #     3) Loading tokenizer from disk
- #     4) Creating a TFDataPipeline
- #     5) Setting the pipeline's response_pool from a JSON file
- #     6) Using pipeline.compute_and_index_response_embeddings()
- #     7) Saving the FAISS index
- #     """
- #     # Directories
- #     MODELS_DIR = Path("models")
- #     FAISS_DIR = MODELS_DIR / "faiss_indices"
- #     FAISS_INDEX_PATH = FAISS_DIR / "faiss_index_production.index"
- #     RESPONSES_PATH = FAISS_DIR / "faiss_index_production_responses.json"
- #     TOKENIZER_DIR = MODELS_DIR / "tokenizer"
- #     SHARED_ENCODER_DIR = MODELS_DIR / "shared_encoder"
- #     CUSTOM_WEIGHTS_PATH = MODELS_DIR / "encoder_custom_weights.weights.h5"
-
- #     # 1) Load ChatbotConfig
- #     config_path = MODELS_DIR / "config.json"
- #     if config_path.exists():
- #         with open(config_path, "r", encoding="utf-8") as f:
- #             config_dict = json.load(f)
- #         config = ChatbotConfig.from_dict(config_dict)
- #         logger.info(f"Loaded ChatbotConfig from {config_path}")
- #     else:
- #         config = ChatbotConfig()
- #         logger.warning(f"No config.json found at {config_path}. Using default ChatbotConfig.")
-
- #     # 2) Initialize the EncoderModel
- #     encoder = EncoderModel(config=config)
- #     logger.info("EncoderModel instantiated (empty).")
-
- #     # Overwrite the submodule from 'shared_encoder' directory
- #     if SHARED_ENCODER_DIR.exists():
- #         logger.info(f"Loading DistilBERT submodule from {SHARED_ENCODER_DIR}...")
- #         encoder.pretrained = TFAutoModel.from_pretrained(str(SHARED_ENCODER_DIR))
- #         logger.info("Loaded HF submodule into encoder.pretrained.")
- #     else:
- #         logger.warning(f"No shared_encoder directory at {SHARED_ENCODER_DIR}. Using default pretrained model.")
-
- #     # Build model once, then load custom weights (projection, etc.)
- #     dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
- #     _ = encoder(dummy_input, training=False)  # builds the layers
-
- #     if CUSTOM_WEIGHTS_PATH.exists():
- #         logger.info(f"Loading custom top-level weights from {CUSTOM_WEIGHTS_PATH}")
- #         encoder.load_weights(str(CUSTOM_WEIGHTS_PATH))
- #         logger.info("Custom top-level weights loaded successfully.")
- #     else:
- #         logger.warning(f"Custom weights file not found at {CUSTOM_WEIGHTS_PATH}.")
-
- #     # 3) Load tokenizer
- #     if TOKENIZER_DIR.exists():
- #         logger.info(f"Loading tokenizer from {TOKENIZER_DIR}")
- #         tokenizer = AutoTokenizer.from_pretrained(str(TOKENIZER_DIR))
- #     else:
- #         logger.warning(f"No tokenizer dir at {TOKENIZER_DIR}, falling back to default HF tokenizer.")
- #         tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
-
- #     # 4) Quick sanity check
- #     sanity_check(encoder, tokenizer, config)
-
- #     # 5) Prepare a TFDataPipeline
- #     pipeline = TFDataPipeline(
- #         config=config,
- #         tokenizer=tokenizer,
- #         encoder=encoder,
- #         index_file_path=str(FAISS_INDEX_PATH),
- #         response_pool=[],
- #         max_length=config.max_context_token_limit,
- #         query_embeddings_cache={},
- #         neg_samples=config.neg_samples,
- #         index_type='IndexFlatIP',
- #         nlist=100,
- #         max_retries=config.max_retries
- #     )
-
- #     # 6) Load the existing response pool
- #     if not RESPONSES_PATH.exists():
- #         logger.error(f"Response pool JSON file not found at {RESPONSES_PATH}")
- #         raise FileNotFoundError(f"No response pool JSON at {RESPONSES_PATH}")
-
- #     with open(RESPONSES_PATH, "r", encoding="utf-8") as f:
- #         response_pool = json.load(f)
- #     logger.info(f"Loaded {len(response_pool)} responses from {RESPONSES_PATH}")
-
- #     pipeline.response_pool = response_pool  # assign to pipeline
-
- #     # 7) Build (or rebuild) the FAISS index from pipeline method
- #     # This does all the compute-embeddings + index.add in one place
- #     logger.info("Starting to compute and index response embeddings via TFDataPipeline...")
- #     pipeline.compute_and_index_response_embeddings()
-
- #     # 8) Save the rebuilt FAISS index
- #     pipeline.save_faiss_index(str(FAISS_INDEX_PATH))
-
- #     # Verify
- #     loaded_index = faiss.read_index(str(FAISS_INDEX_PATH))
- #     logger.info(f"Verified the rebuilt FAISS index has {loaded_index.ntotal} vectors.")
-
- #     return loaded_index, pipeline.response_pool
-
- # if __name__ == "__main__":
- #     build_faiss_index()
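
With this commented-out rebuild script deleted, a condensed sketch of the same flow using the retained TFDataPipeline API; keyword names mirror the commented code above and are assumptions where they differ from the current signature (note the live signature exposes faiss_index_file_path, while the old script passed index_file_path).

# Hypothetical condensed rebuild flow distilled from the deleted script;
# argument names follow the commented code above and are assumptions.
import json
from transformers import AutoTokenizer
from chatbot_model import ChatbotConfig, EncoderModel
from tf_data_pipeline import TFDataPipeline

config = ChatbotConfig()
encoder = EncoderModel(config=config)
tokenizer = AutoTokenizer.from_pretrained("models/tokenizer")

with open("models/faiss_indices/faiss_index_production_responses.json", "r", encoding="utf-8") as f:
    response_pool = json.load(f)

pipeline = TFDataPipeline(
    config=config,
    tokenizer=tokenizer,
    encoder=encoder,
    index_file_path="models/faiss_indices/faiss_index_production.index",  # assumed keyword
    response_pool=response_pool,
    max_length=config.max_context_token_limit,
    query_embeddings_cache={},
    neg_samples=config.neg_samples,
    index_type="IndexFlatIP",
    nlist=100,
    max_retries=config.max_retries,
)
pipeline.compute_and_index_response_embeddings()
pipeline.save_faiss_index("models/faiss_indices/faiss_index_production.index")
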
unused/gpu_monitor.py DELETED
@@ -1,59 +0,0 @@
- import tensorflow as tf
- from typing import List, Dict, Optional
- from dataclasses import dataclass
-
- from tqdm.auto import tqdm
-
- @dataclass
- class GPUMemoryStats:
-     total: int
-     used: int
-     free: int
-
- class GPUMemoryMonitor:
-     """Monitor GPU memory usage with safe CPU fallback."""
-     def __init__(self):
-         self.has_gpu = False
-         try:
-             gpus = tf.config.list_physical_devices('GPU')
-             self.has_gpu = len(gpus) > 0
-         except Exception:  # bare except narrowed to Exception
-             pass
-
-     def get_memory_stats(self) -> Optional[GPUMemoryStats]:
-         """Get current GPU memory statistics."""
-         if not self.has_gpu:
-             return None
-
-         try:
-             memory_info = tf.config.experimental.get_memory_info('GPU:0')
-             return GPUMemoryStats(
-                 total=memory_info['peak'],
-                 used=memory_info['current'],
-                 free=memory_info['peak'] - memory_info['current']
-             )
-         except Exception:
-             return None
-
-     def get_memory_usage(self) -> float:
-         """Get current GPU memory usage as a fraction (0.0-1.0)."""
-         if not self.has_gpu:
-             return 0.0
-         stats = self.get_memory_stats()
-         if stats is None or stats.total == 0:
-             return 0.0
-         return stats.used / stats.total
-
-     def should_reduce_batch_size(self) -> bool:
-         """Check if batch size should be reduced based on memory usage."""
-         if not self.has_gpu:
-             return False
-         usage = self.get_memory_usage()
-         return usage > 0.90
-
-     def can_increase_batch_size(self) -> bool:
-         """Check if batch size can be increased based on memory usage."""
-         if not self.has_gpu:
-             return True
-         usage = self.get_memory_usage()
-         return usage < 0.70
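
Finally, a minimal sketch of the adaptive batch-sizing loop the removed GPUMemoryMonitor enabled; the concrete batch-size bounds are illustrative assumptions.

# Hypothetical usage of the removed GPUMemoryMonitor; the batch-size
# bounds are illustrative assumptions.
monitor = GPUMemoryMonitor()
batch_size = 32
if monitor.should_reduce_batch_size():      # above ~90% GPU memory usage
    batch_size = max(batch_size // 2, 1)
elif monitor.can_increase_batch_size():     # below ~70% GPU memory usage
    batch_size = min(batch_size * 2, 256)
print(f"memory usage: {monitor.get_memory_usage():.0%}, batch size: {batch_size}")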