JoeArmani committed · Commit 4aec49f · 1 Parent(s): d7fc7a7

cleanup
Browse files

- .gitignore +0 -5
- data_augmentation_code/augmentation_processing_pipeline.py +0 -321
- data_augmentation_code/back_translator.py +0 -87
- data_augmentation_code/dialogue_augmenter.py +0 -710
- data_augmentation_code/main.py +0 -112
- data_augmentation_code/paraphraser.py +0 -42
- data_augmentation_code/pipeline_config.py +0 -57
- data_augmentation_code/quality_metrics.py +0 -47
- data_augmentation_code/schema_guided_dialogue_processor.py +0 -192
- data_augmentation_code/taskmaster_processor.py +0 -192
- prepare_data.py +4 -4
- run_chatbot_validation.py +1 -1
- new_iteration/run_taskmaster_processor.py → run_taskmaster_processor.py +0 -0
- new_iteration/taskmaster_processor.py → taskmaster_processor.py +0 -0
- tf_data_pipeline.py +1 -1
- unused/build_faiss_index.py +0 -160
- unused/gpu_monitor.py +0 -59
.gitignore CHANGED

@@ -182,10 +182,5 @@ training_data/*
 !training_data/.gitkeep
 augmented_dialogues.json
 
-checkpoints_old_REMOVE/*
-new_iteration/cache/*
-new_iteration/data_prep_iterative_models/*
-new_iteration/training_data/*
-new_iteration/processed_outputs/*
 raw_datasets/*
 
data_augmentation_code/augmentation_processing_pipeline.py DELETED
@@ -1,321 +0,0 @@

from datetime import datetime
from pathlib import Path
from typing import List, Dict, Optional
import json
import re
import hashlib
import spacy
import torch
from tqdm import tqdm
from data_augmentation.pipeline_config import PipelineConfig
from data_augmentation.dialogue_augmenter import DialogueAugmenter
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from typing import Set

class AugmentationProcessingPipeline:
    """
    Complete pipeline combining validation, optimization, and augmentation.
    """

    def __init__(self, config: Optional[PipelineConfig] = None):
        self.config = config or PipelineConfig()
        self.nlp = spacy.load("en_core_web_sm", disable=['parser', 'ner'])
        self.augmenter = DialogueAugmenter(self.nlp, self.config)
        self.num_threads = self.config.batch_size
        self.cache_dir = Path("./cache")
        self.cache_dir.mkdir(exist_ok=True)
        self.output_dir = Path("processed_outputs")
        self.output_dir.mkdir(exist_ok=True)
        self.checkpoint_file = self.output_dir / "processing_checkpoint.json"
        self.batch_size = self.config.batch_size
        self.use_gpu = torch.cuda.is_available()
        self.batch_size = 32 if self.use_gpu else 8
        self.use_multiprocessing = not self.use_gpu

        # Counters for grouping batches
        self.batch_counter = 0  # Count batches since last group combine
        self.batch_group_number = 0  # How many groups have been created

        if self.config.debug:
            print(f"ProcessingPipeline initialized with:")
            print(f"- GPU available: {self.use_gpu}")
            print(f"- Batch size: {self.batch_size}")
            print(f"- Using multiprocessing: {self.use_multiprocessing}")

    def _save_batch(self, batch_results: List[Dict], batch_num: int) -> Path:
        """Save a batch of results to a separate JSON file"""
        batch_file = self.output_dir / f"batch_{batch_num:04d}.json"
        with open(batch_file, 'w') as f:
            json.dump(batch_results, f)
        return batch_file

    def _load_checkpoint(self) -> set:
        """Load set of processed dialogue IDs from checkpoint"""
        if self.checkpoint_file.exists():
            with open(self.checkpoint_file, 'r') as f:
                return set(json.load(f))
        return set()

    def _update_checkpoint(self, processed_ids: set):
        """Update checkpoint with newly processed IDs"""
        with open(self.checkpoint_file, 'w') as f:
            json.dump(list(processed_ids), f)

    def _process_batch(self, batch: List[Dict]) -> List[Dict]:
        """Process batch with optimized model calls"""
        results = []
        try:
            if self.use_gpu:
                results = self.augmenter.process_batch(batch)
            else:
                # Collect all texts that need processing
                all_texts = []
                text_to_dialogue_map = {}
                for dialogue in batch:
                    for turn in dialogue['turns']:
                        all_texts.append(turn['text'])
                        text_to_dialogue_map[turn['text']] = dialogue['dialogue_id']

                # Batch process embeddings
                self.augmenter._compute_batch_embeddings(all_texts)

                # Process dialogues with cached embeddings
                for dialogue in batch:
                    try:
                        augmented = self.augmenter.augment_dialogue(dialogue)
                        results.extend(augmented)
                    except Exception as e:
                        print(f"Error processing dialogue {dialogue.get('dialogue_id', 'unknown')}: {str(e)}")
                        continue
        except Exception as e:
            print(f"Error processing batch: {str(e)}")
        return results

    def _combine_intermediate_batches(self):
        """
        Combine all current batch_*.json files into a single batch_group_XXXX.json file,
        then remove the batch_*.json files.
        """
        batch_files = sorted(self.output_dir.glob("batch_*.json"))
        if not batch_files:
            return None  # No files to combine

        combined_data = []
        for bf in batch_files:
            with open(bf, 'r') as f:
                combined_data.extend(json.load(f))
            bf.unlink()  # Remove the individual batch file after reading

        self.batch_group_number += 1
        group_file = self.output_dir / f"batch_group_{self.batch_group_number:04d}.json"
        with open(group_file, 'w') as f:
            json.dump(combined_data, f)
        return group_file

    def combine_results(self) -> Path:
        """Combine all batch_group_*.json files into final output"""
        all_results = []
        group_files = sorted(self.output_dir.glob("batch_group_*.json"))

        print(f"Combining {len(group_files)} group files...")
        for group_file in tqdm(group_files):
            with open(group_file, 'r') as f:
                group_data = json.load(f)
                all_results.extend(group_data)

        # Save combined results
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        final_output = self.output_dir / f"augmented_dataset_{timestamp}.json"
        with open(final_output, 'w') as f:
            json.dump(all_results, f)

        if self.config.debug:
            print(f"Combined {len(all_results)} dialogues into {final_output}")

        return final_output

    def process_dataset(self, dialogues: List[Dict]) -> Path:
        """Process dataset with hardware-appropriate optimizations and progress tracking"""
        processed_ids = self._load_checkpoint()

        # Filter out already processed dialogues
        remaining_dialogues = [d for d in dialogues
                               if d['dialogue_id'] not in processed_ids]

        total_dialogues = len(dialogues)
        remaining_count = len(remaining_dialogues)
        processed_count = total_dialogues - remaining_count

        print("\nDataset Processing Status:")
        print(f"Total dialogues in dataset: {total_dialogues}")
        print(f"Previously processed: {processed_count}")
        print(f"Remaining to process: {remaining_count}")
        print("-" * 50)

        # Process in batches with progress bar
        for batch_num in tqdm(range(0, len(remaining_dialogues), self.batch_size),
                              desc="Processing batches",
                              total=(len(remaining_dialogues) + self.batch_size - 1) // self.batch_size):
            batch = remaining_dialogues[batch_num:batch_num + self.batch_size]
            current_position = processed_count + batch_num + len(batch)

            total_progress = (current_position / total_dialogues) * 100

            print('\033[K', end='')
            print(f"Processing: {current_position}/{total_dialogues} dialogues "
                  f"({total_progress:.1f}% complete)")
            print(f"Current batch: {batch_num//self.batch_size + 1} of "
                  f"{(len(remaining_dialogues) + self.batch_size - 1) // self.batch_size}")
            print("-" * 50)

            # Process batch
            batch_results = self._process_batch(batch)

            if batch_results:
                self._save_batch(batch_results, batch_num)
                batch_ids = {d['dialogue_id'] for d in batch}
                processed_ids.update(batch_ids)
                self._update_checkpoint(processed_ids)

                # Increment batch counter and combine if needed
                self.batch_counter += 1
                if self.batch_counter == 25:
                    # Combine these 25 batches into a group file
                    self._combine_intermediate_batches()
                    self.batch_counter = 0  # Reset counter after grouping

        # If there are leftover batches less than 25
        # combine them into one final group file
        if self.batch_counter > 0:
            self._combine_intermediate_batches()
            self.batch_counter = 0

        print("\n" + "-" * 50)
        print("Processing complete. Combining results...")
        return self.combine_results()

    def cleanup(self):
        """Clean up intermediate files after successful processing"""
        # Clean up any leftover batch files (should not exist if logic is correct)
        batch_files = list(self.output_dir.glob("batch_*.json"))
        for file in batch_files:
            try:
                file.unlink()
            except Exception as e:
                print(f"Error deleting {file}: {e}")

        # We can also remove batch_group_*.json if desired after final combine
        # but that might not be necessary if we want to keep them.

        if self.checkpoint_file.exists():
            try:
                self.checkpoint_file.unlink()
            except Exception as e:
                print(f"Error deleting checkpoint file: {e}")

    def _deduplicate_dialogues(self, dialogues: List[Dict], threshold: float = 0.9) -> List[Dict]:
        """
        Deduplicate dialogues based on text similarity.
        """
        print("Deduplicating dialogues...")
        if not dialogues:
            print("No dialogues provided for deduplication.")
            return []

        # Combine turns into single text for similarity comparison
        texts = [" ".join(turn['text'] for turn in dialogue['turns']) for dialogue in dialogues]
        tfidf = TfidfVectorizer().fit_transform(texts)
        sim_matrix = cosine_similarity(tfidf)

        unique_indices = set()
        for i, row in enumerate(sim_matrix):
            if i not in unique_indices:
                similar_indices = [j for j, sim in enumerate(row) if sim > threshold and j != i]
                unique_indices.add(i)
                unique_indices.difference_update(similar_indices)

        deduplicated_dialogues = [dialogues[i] for i in unique_indices]

        print(f"Deduplication complete. Reduced from {len(dialogues)} to {len(deduplicated_dialogues)} dialogues.")
        return deduplicated_dialogues

    def _validate_and_clean_dialogue(self, dialogue: Dict) -> Optional[Dict]:
        """
        Validate and clean a single dialogue.
        """
        try:
            # Check required fields
            if not all(field in dialogue for field in self.config.required_fields):
                return None

            # Process turns
            cleaned_turns = []
            for turn in dialogue['turns']:
                if self._validate_turn(turn):
                    cleaned_turn = {
                        'speaker': turn['speaker'],
                        'text': self._clean_text(turn['text'])
                    }
                    cleaned_turns.append(cleaned_turn)

            if cleaned_turns:
                return {
                    'dialogue_id': dialogue['dialogue_id'],
                    'turns': cleaned_turns
                }

            return None

        except Exception as e:
            print(f"Error processing dialogue {dialogue.get('dialogue_id', 'unknown')}: {str(e)}")
            return None

    def _validate_turn(self, turn: Dict) -> bool:
        """
        Validate a single speaking turn.
        """
        return (
            turn['speaker'] in self.config.allowed_speakers and
            self.config.min_length <= len(turn['text']) <= self.config.max_length
        )

    def _clean_text(self, text: str) -> str:
        """
        Clean and normalize text.
        """
        # Remove excessive whitespace
        text = re.sub(r'\s+', ' ', text.strip())

        # Normalize quotes and apostrophes
        text = re.sub(r'[’´`]', "'", text)
        text = re.sub(r'[“”]', '"', text)

        # Remove control characters
        text = "".join(char for char in text if ord(char) >= 32 or char == '\n')

        return text

    def _process_validation(self, items: List, func, description: str) -> List:
        """
        Process items sequentially with a progress bar.
        """
        results = []
        print(f"Starting {description}")
        for item in tqdm(items, desc=description):
            try:
                result = func(item)
                if result is not None:
                    results.append(result)
            except Exception as e:
                print(f"Error processing item: {str(e)}")
        print(f"Completed {description}. Processed {len(results)} items successfully")
        return results

    def _get_cache_path(self, data: List[Dict]) -> Path:
        """
        Generate cache file path based on data hash.
        """
        data_str = json.dumps(data, sort_keys=True)
        hash_value = hashlib.md5(data_str.encode()).hexdigest()
        return self.cache_dir / f"cache_{hash_value}.pkl"
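The deleted pipeline's resume behavior comes down to a simple pattern: persist the set of processed dialogue IDs after each successful batch, and filter the input against that set on startup. A minimal, self-contained sketch of that pattern (the helper names and the per-item checkpointing cadence here are illustrative; the original checkpoints once per batch):

import json
from pathlib import Path

CHECKPOINT = Path("processing_checkpoint.json")  # hypothetical demo path

def load_processed_ids() -> set:
    """Load the set of already-processed dialogue IDs, if a checkpoint exists."""
    return set(json.loads(CHECKPOINT.read_text())) if CHECKPOINT.exists() else set()

def save_processed_ids(ids: set) -> None:
    """Persist processed IDs so an interrupted run can skip them next time."""
    CHECKPOINT.write_text(json.dumps(sorted(ids)))

dialogues = [{"dialogue_id": f"d{i:03d}"} for i in range(10)]
done = load_processed_ids()
for d in (d for d in dialogues if d["dialogue_id"] not in done):
    # ... augment the dialogue here ...
    done.add(d["dialogue_id"])
    save_processed_ids(done)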
data_augmentation_code/back_translator.py DELETED
@@ -1,87 +0,0 @@

from transformers import (
    MarianMTModel,
    MarianTokenizer,
)

# Retained for reference but removed from the final code.
# This method did not seem helpful for this retrieval-based chatbot.
class BackTranslator:
    """
    Perform Back-translation with pivot language. English -> German -> Spanish -> English
    Args:
        source_lang: Source language (default: 'en')
        pivot_lang: Pivot language (default: 'de')
        target_lang: Target language (default: 'es')
    Examples:
        back_translator = BackTranslator()
        back_translator.back_translate("Hello, how are you?")
    """
    def __init__(self, source_lang='en', pivot_lang='de', target_lang='es'):
        # Forward (English to German)
        pivot_forward_model_name = f'Helsinki-NLP/opus-mt-{source_lang}-{pivot_lang}'
        self.tokenizer_pivot_forward = MarianTokenizer.from_pretrained(pivot_forward_model_name)
        self.model_pivot_forward = MarianMTModel.from_pretrained(pivot_forward_model_name)

        # Pivot translation (German to Spanish)
        pivot_backward_model_name = f'Helsinki-NLP/opus-mt-{pivot_lang}-{target_lang}'
        self.tokenizer_pivot_backward = MarianTokenizer.from_pretrained(pivot_backward_model_name)
        self.model_pivot_backward = MarianMTModel.from_pretrained(pivot_backward_model_name)

        # Backward (Spanish to English)
        backward_model_name = f'Helsinki-NLP/opus-mt-{target_lang}-{source_lang}'
        self.tokenizer_backward = MarianTokenizer.from_pretrained(backward_model_name)
        self.model_backward = MarianMTModel.from_pretrained(backward_model_name)

        # Set models to eval mode
        self.model_pivot_forward.eval()
        self.model_pivot_backward.eval()
        self.model_backward.eval()

    def back_translate(self, text, device=None):
        try:
            # Move models to device if specified
            if device is not None:
                self.model_pivot_forward = self.model_pivot_forward.to(device)
                self.model_pivot_backward = self.model_pivot_backward.to(device)
                self.model_backward = self.model_backward.to(device)

            # Forward translation (English to German)
            encoded_pivot = self.tokenizer_pivot_forward([text], padding=True,
                                                         truncation=True, return_tensors='pt')
            if device is not None:
                encoded_pivot = {k: v.to(device) for k, v in encoded_pivot.items()}

            generated_pivot = self.model_pivot_forward.generate(**encoded_pivot)
            if device is not None:
                generated_pivot = generated_pivot.cpu()
            pivot_text = self.tokenizer_pivot_forward.batch_decode(generated_pivot,
                                                                   skip_special_tokens=True)[0]

            # Pivot translation (German to Spanish)
            encoded_back_pivot = self.tokenizer_pivot_backward([pivot_text], padding=True,
                                                               truncation=True, return_tensors='pt')
            if device is not None:
                encoded_back_pivot = {k: v.to(device) for k, v in encoded_back_pivot.items()}

            retranslated_pivot = self.model_pivot_backward.generate(**encoded_back_pivot)
            if device is not None:
                retranslated_pivot = retranslated_pivot.cpu()
            tgt_text_back = self.tokenizer_pivot_backward.batch_decode(retranslated_pivot,
                                                                       skip_special_tokens=True)[0]

            # Backward translation (Spanish to English)
            encoded_back = self.tokenizer_backward([tgt_text_back], padding=True,
                                                   truncation=True, return_tensors='pt')
            if device is not None:
                encoded_back = {k: v.to(device) for k, v in encoded_back.items()}

            retranslated = self.model_backward.generate(**encoded_back)
            if device is not None:
                retranslated = retranslated.cpu()
            src_text = self.tokenizer_backward.batch_decode(retranslated,
                                                            skip_special_tokens=True)[0]

            return src_text
        except Exception as e:
            print(f"Error in back translation: {e}")
            return text
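For reference, a minimal usage sketch of the deleted class, following the docstring's own example and adding optional GPU placement. This is illustrative only: the three MarianMT checkpoints are large downloads.

import torch

bt = BackTranslator()  # en -> de -> es -> en by default
device = torch.device("cuda") if torch.cuda.is_available() else None
print(bt.back_translate("Hello, how are you?", device=device))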
data_augmentation_code/dialogue_augmenter.py DELETED
@@ -1,710 +0,0 @@

from typing import Dict, List
import numpy as np
import torch
import tensorflow as tf
import tensorflow_hub as hub
from data_augmentation.pipeline_config import PipelineConfig
from data_augmentation.quality_metrics import QualityMetrics
from data_augmentation.paraphraser import Paraphraser
import nlpaug.augmenter.word as naw
from functools import lru_cache
from sklearn.metrics.pairwise import cosine_similarity

class DialogueAugmenter:
    """
    Optimized dialogue augmentation with quality control and complexity management.
    """
    def __init__(self, nlp, config: PipelineConfig):
        self.nlp = nlp
        self.config = config

        # Detect hardware and set appropriate batch sizes and optimization strategy
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.use_gpu = torch.cuda.is_available()

        if self.config.debug:
            print(f"Using device: {self.device}")
            if self.use_gpu:
                print(f"GPU Device: {torch.cuda.get_device_name(0)}")

        self.quality_metrics = QualityMetrics(config)
        self.semantic_similarity_threshold = 0.75

        # Load model
        self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')

        # Initialize augmentation models based on hardware
        self._initialize_augmentation_models()

        # Initialize caches
        self.embedding_cache = {}

        # GPU memory management if available
        if self.use_gpu:
            gpus = tf.config.list_physical_devices('GPU')
            if gpus:
                try:
                    for gpu in gpus:
                        tf.config.experimental.set_memory_growth(gpu, True)
                except RuntimeError as e:
                    print(e)

    def _initialize_augmentation_models(self):
        """Initialize augmentation models with appropriate device settings"""
        # Advanced augmentation techniques
        self.paraphraser = Paraphraser()
        if self.use_gpu:
            # Move model to GPU if available
            self.paraphraser.model = self.paraphraser.model.to(self.device)

        # Basic augmentation techniques
        self.word_augmenter = naw.SynonymAug(aug_src='wordnet')

        self.augmenters = {
            'advanced': [
                self.paraphraser,
            ],
            'basic': [
                ('synonym', self.word_augmenter),
            ]
        }

    @lru_cache(maxsize=1024)
    def _compute_embedding(self, text: str) -> np.ndarray:
        """Cached computation of text embedding"""
        if text in self.embedding_cache:
            return self.embedding_cache[text]
        embedding = self.use_model([text])[0].numpy()
        self.embedding_cache[text] = embedding
        return embedding

    def _compute_batch_embeddings(self, texts: List[str]) -> np.ndarray:
        """Compute embeddings for multiple texts at once with hardware optimization"""
        # Check cache first
        uncached_texts = [t for t in texts if t not in self.embedding_cache]
        if uncached_texts:
            embeddings = self.use_model(uncached_texts).numpy()
            # Update cache
            for text, embedding in zip(uncached_texts, embeddings):
                self.embedding_cache[text] = embedding

        # Return all embeddings (from cache or newly computed)
        return np.array([self.embedding_cache[t] for t in texts])

    def _quick_quality_check(self, variation: str, original: str) -> bool:
        """
        Preliminary quality check while maintaining reasonable pass rates
        """
        if self.config.debug:
            print(f"\nQuick check for variation: {variation}")

        orig_len = len(original.split())
        var_len = len(variation.split())

        # For very short texts (<= 3 words), still allow more variation
        if orig_len <= 3:
            if var_len > orig_len * 3:
                if self.config.debug:
                    print(f"Failed length check (short text): {var_len} vs {orig_len}")
                return False
        else:
            if var_len > orig_len * 2:
                if self.config.debug:
                    print(f"Failed length check (long text): {var_len} vs {orig_len}")
                return False

        # Adjust content overlap check based on length
        stop_words = {'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'is', 'are', 'that', 'this', 'will', 'can'}
        orig_words = set(w.lower() for w in original.split() if w.lower() not in stop_words)
        var_words = set(w.lower() for w in variation.split() if w.lower() not in stop_words)

        # If very short turn (less than 5 words), skip the content overlap check
        if orig_len >= 5:
            content_overlap = len(orig_words.intersection(var_words)) / len(orig_words) if orig_words else 0
            if content_overlap < 0.2:
                if self.config.debug:
                    print(f"Failed content check: overlap {content_overlap:.2f}")
                return False
        else:
            if self.config.debug:
                print("Short turn detected (<5 words), skipping content overlap check")

        if self.config.debug:
            print("Passed all quick checks")
        return True

    def _filter_variations_batch(self, variations: List[str], context: List[str], original_turn: str) -> List[str]:
        """
        Filter variations using batched computations with detailed logging
        """
        if not variations:
            return []

        if self.config.debug:
            print(f"\nStarting filtration of {len(variations)} variations")
            print(f"Context length: {len(context)}")
            print(f"Original turn: {original_turn}")

        words = original_turn.split()
        orig_len = len(words)

        # If very short text, consider adjusting thresholds
        is_very_short = orig_len < 5

        if len(words) < 3:
            if self.config.debug:
                print("Short text detected, using predefined variations")
            short_text_variations = self._augment_short_text({'text': original_turn, 'speaker': ''})
            return [var['text'] for var in short_text_variations]

        # If this is the first turn (no context), be more lenient
        if not context:
            preliminary_filtered = variations
            if self.config.debug:
                print("First turn - skipping preliminary filtering")
        else:
            # Quick preliminary filtering against original turn
            preliminary_filtered = []
            for var in variations:
                passed = self._quick_quality_check(var, original_turn)
                if self.config.debug:
                    print(f"\nVariation: {var}")
                    print(f"Passed quick check: {passed}")
                if passed:
                    preliminary_filtered.append(var)

        if self.config.debug:
            print(f"Variations after quick check: {len(preliminary_filtered)}")

        if not preliminary_filtered:
            return []

        # Compute embeddings for original and variations
        original_embedding = self._compute_embedding(original_turn)
        variation_embeddings = self._compute_batch_embeddings(preliminary_filtered)

        # Compute similarities
        sims = cosine_similarity([original_embedding], variation_embeddings)[0]

        # If very short turn, slightly lower the semantic similarity threshold
        dynamic_sem_threshold = self.semantic_similarity_threshold
        if is_very_short:
            dynamic_sem_threshold = max(0.7, self.semantic_similarity_threshold - 0.05)

        # Filter by semantic similarity threshold
        refined_filtered = []
        for var, sim in zip(preliminary_filtered, sims):
            if sim >= dynamic_sem_threshold:
                refined_filtered.append(var)
            else:
                if self.config.debug:
                    print(f"Variation '{var}' discarded due to low semantic similarity: {sim:.3f}")

        if not refined_filtered:
            return []

        # Relax context coherence thresholds further if desired
        # We already have min_similarity = 0.1, min_coherence = 0.05
        # Let's lower them slightly more if the turn is very short:
        if is_very_short:
            min_similarity = 0.05
            min_coherence = 0.02
        else:
            min_similarity = 0.1
            min_coherence = 0.05

        # Only use last turn for coherence
        recent_context = [context[-1]] if context else []
        context_text = ' '.join(recent_context) if recent_context else ''

        if context_text:
            if self.config.debug:
                print(f"\nContext text: {context_text}")

            all_texts = [context_text] + refined_filtered
            all_embeddings = self._compute_batch_embeddings(all_texts)

            context_embedding = all_embeddings[0]
            variation_embeddings = all_embeddings[1:]

            # Vectorized similarity computation
            context_similarities = cosine_similarity([context_embedding], variation_embeddings)[0]

            # Response coherence check
            if recent_context:
                prev_embedding = self._compute_embedding(recent_context[-1])
                response_coherence = cosine_similarity([prev_embedding], variation_embeddings)[0]
            else:
                response_coherence = np.ones_like(context_similarities)

            filtered_variations = []
            for i, (variation, sim, coh) in enumerate(zip(
                    refined_filtered, context_similarities, response_coherence)):
                combined_score = (
                    self.config.context_similarity_weight * abs(sim) +
                    self.config.response_coherence_weight * abs(coh)
                )

                if self.config.debug:
                    print(f"\nVariation: {variation}")
                    print(f"Context similarity: {sim:.3f}")
                    print(f"Response coherence: {coh:.3f}")
                    print(f"Combined score: {combined_score:.3f}")

                # Accept if EITHER score is good enough
                if (combined_score >= min_similarity or abs(coh) >= min_coherence):
                    filtered_variations.append(variation)
                    if self.config.debug:
                        print("ACCEPTED")
                else:
                    if self.config.debug:
                        print("REJECTED")

                # If we have enough variations, stop
                if len(filtered_variations) >= self.config.max_variations_per_turn:
                    break
        else:
            filtered_variations = refined_filtered[:self.config.max_variations_per_turn]

        if self.config.debug:
            print(f"\nFinal filtered variations: {len(filtered_variations)}")

        return filtered_variations

    def _generate_variations_progressive(self, text: str, needed: int) -> List[str]:
        """
        Generate variations progressively until we have enough good ones.
        Adjust paraphraser parameters for closer paraphrases as needed.
        """
        variations = set()

        if self.config.debug:
            print(f"\nAttempting to generate {needed} variations for text: {text}")

        # Fine-tune paraphraser here if needed: fewer beams, less diversity already done
        for augmenter in self.augmenters['advanced']:
            if len(variations) >= needed:
                break

            try:
                if isinstance(augmenter, Paraphraser):
                    if self.config.debug:
                        print("Trying paraphrase augmentation...")
                    new_vars = augmenter.paraphrase(
                        text,
                        num_return_sequences=needed-len(variations),
                        device=self.device if self.use_gpu else None,
                        num_beams=4,  # even fewer beams for more faithful paraphrases
                        num_beam_groups=1,
                        diversity_penalty=0.0
                    )
                    if self.config.debug:
                        print(f"Paraphraser generated {len(new_vars)} variations")

                    valid_vars = [v for v in new_vars if v.strip() and v != text]
                    variations.update(valid_vars)

                    if self.config.debug:
                        print(f"Current unique variations: {len(variations)}")

            except Exception as e:
                print(f"Error in advanced augmentation: {str(e)}")
                continue

        # Try basic augmenters if needed
        if len(variations) < needed:
            if self.config.debug:
                print("Not enough variations, trying basic augmenters...")

            for aug_type, augmenter in self.augmenters['basic']:
                if len(variations) >= needed:
                    break

                try:
                    if self.config.debug:
                        print(f"Trying {aug_type} augmentation...")

                    new_vars = augmenter.augment(text, n=2)
                    if isinstance(new_vars, list):
                        valid_vars = [v for v in new_vars if v.strip() and v != text]
                        variations.update(valid_vars)
                    else:
                        if new_vars.strip() and new_vars != text:
                            variations.add(new_vars)

                    if self.config.debug:
                        print(f"After {aug_type}, total variations: {len(variations)}")

                except Exception as e:
                    print(f"Error in {aug_type} augmentation: {str(e)}")
                    continue

        variations_list = list(variations)

        if self.config.debug:
            print(f"Final number of variations generated: {len(variations_list)}")
            if not variations_list:
                print("WARNING: No variations were generated!")

        return variations_list

    def augment_dialogue(self, dialogue: Dict) -> List[Dict]:
        """
        Create augmented versions of the dialogue with optimized processing
        """
        # Early dialogue length check
        original_length = len(dialogue['turns'])
        if original_length > self.config.max_turns_per_dialogue:
            if self.config.debug:
                print(f"Truncating dialogue from {original_length} to {self.config.max_turns_per_dialogue} turns")
            dialogue['turns'] = dialogue['turns'][:self.config.max_turns_per_dialogue]

        turn_variations = []
        context = []

        # Process each turn with progressive generation
        for turn in dialogue['turns']:
            original_text = turn['text']  # Store original turn text
            variations = self._generate_variations_progressive(
                original_text,
                self.config.max_variations_per_turn
            )

            # Batch filter variations with original text
            filtered_variations = self._filter_variations_batch(
                variations,
                context,
                original_text  # Pass the original turn text
            )

            # Create turn variations with speaker info
            turn_vars = [{'speaker': turn['speaker'], 'text': v} for v in filtered_variations]

            if self.config.debug:
                print(f"Turn {len(turn_variations)}: Generated {len(turn_vars)} variations")

            turn_variations.append(turn_vars)
            context.append(original_text)

        # Generate combinations with sampling
        augmented_dialogues = self._generate_dialogue_combinations(
            dialogue['dialogue_id'],
            turn_variations,
            dialogue
        )

        # Add original dialogue
        result = [{
            'dialogue_id': f"{dialogue['dialogue_id']}_original",
            'turns': dialogue['turns']
        }]

        # Add unique augmentations
        result.extend(augmented_dialogues[:self.config.augmentation_factor])

        if self.config.debug:
            print(f"Generated {len(result)-1} unique augmented dialogues")

        return result

    def _variation_score(self, original: str, variation: str) -> float:
        """
        Compute a single numeric score for a variation to guide selection.
        You could use semantic similarity, content preservation, etc.
        Higher is better.
        """
        metrics = self.quality_metrics.compute_metrics(original, variation)
        # Example: Primarily semantic similarity, with a slight boost for content preservation
        # Adjust as needed.
        score = metrics['semantic_similarity'] * 0.7 + metrics['content_preservation'] * 0.3
        return score

    def _dialogue_quality_score(self, dialogue: Dict, original_dialogue: Dict) -> float:
        """
        Compute a quality score for the entire augmented dialogue.
        For example, average semantic similarity of turns to the original turns.
        This is done after the dialogue is formed.
        """
        original_texts = [t['text'] for t in original_dialogue['turns']]
        aug_texts = [t['text'] for t in dialogue['turns']]

        # Compute semantic similarity turn-by-turn and average it
        scores = []
        for orig, aug in zip(original_texts, aug_texts):
            # Simple semantic similarity for scoring
            emb_orig = self._compute_embedding(orig)
            emb_aug = self._compute_embedding(aug)
            sim = (emb_orig @ emb_aug) / (np.linalg.norm(emb_orig)*np.linalg.norm(emb_aug))
            scores.append(sim)

        # Could also incorporate diversity checks, content overlap, etc.
        return float(np.mean(scores)) if scores else 0.0

    def _generate_dialogue_combinations(self, dialogue_id: str, turn_variations: List[List[Dict]], original_dialogue: Dict) -> List[Dict]:
        """
        Generate dialogue combinations using a more controlled approach:
        - Include the original turn as a fallback variation for each turn.
        - Sort variations by a quality score.
        - Ensure a balanced augmentation by requiring at least some turns to be augmented.
        - Over-generate and then select top dialogues by quality.
        """
        # Over-generate factor: create more candidates than needed
        over_generate_factor = self.config.augmentation_factor * 2

        # Add the original turn as a fallback variation for each turn if not present
        for i, turn_variants in enumerate(turn_variations):
            original_turn_text = None
            # Check if we previously stored original turn text with a marker or just use the original dialogue
            # If you previously used "|ORIGINAL|" marker, handle it here. Otherwise, just get from original_dialogue.
            original_turn_text = original_dialogue['turns'][i]['text']

            # Add the original turn as a variation if not already included
            if not any(v['text'] == original_turn_text for v in turn_variants):
                turn_variants.append({
                    'speaker': original_dialogue['turns'][i]['speaker'],
                    'text': original_turn_text
                })

            # Sort variations by score
            original_text = original_dialogue['turns'][i]['text']
            turn_variants.sort(key=lambda v: self._variation_score(original_text, v['text']), reverse=True)

        augmented_dialogues = []
        used_combinations = set()

        def generate_candidates(current_turns=None, turn_index=0):
            if current_turns is None:
                current_turns = []

            if len(augmented_dialogues) >= over_generate_factor:
                return

            if turn_index == len(turn_variations):
                # Completed a candidate dialogue
                dialogue_fingerprint = " | ".join(turn['text'] for turn in current_turns)
                if dialogue_fingerprint not in used_combinations:
                    used_combinations.add(dialogue_fingerprint)
                    # Check if we have enough augmented turns
                    aug_count = sum(1 for orig, curr in zip(original_dialogue['turns'], current_turns)
                                    if orig['text'] != curr['text'])
                    # Require at least half the turns to be augmented, for example
                    if aug_count >= max(1, len(turn_variations)//2):
                        augmented_dialogues.append({
                            'dialogue_id': f"{dialogue_id}_aug_{len(augmented_dialogues)}",
                            'turns': current_turns.copy()
                        })
                return

            turn_candidates = turn_variations[turn_index]

            # If no variations are available for this turn, let's just return without error.
            # Normally, this shouldn't happen since we always add the original turn above.
            if not turn_candidates:
                # If you want to at least have the original turn, add it now:
                original_text = original_dialogue['turns'][turn_index]['text']
                turn_candidates.append({
                    'speaker': original_dialogue['turns'][turn_index]['speaker'],
                    'text': original_text
                })

            # After the fallback, if still empty for some reason, just return.
            if not turn_candidates:
                return

            # Example strategy:
            # 1. Always try the top variation (most semantically similar).
            # 2. If available and allowed, pick a mid-ranked variation for diversity.
            # 3. Include the original turn if not selected yet.

            num_vars = min(self.config.max_sampled_variations, len(turn_candidates))

            # Always include top variation
            candidates_to_pick = [turn_candidates[0]]

            # If we have more than 2 variations and can pick more, add a middle variation for diversity
            if len(turn_candidates) > 2 and num_vars > 1:
                mid_index = len(turn_candidates)//2
                candidates_to_pick.append(turn_candidates[mid_index])

            # If we still have room for another variation, try adding the original turn if not included
            if num_vars > len(candidates_to_pick):
                original_turn_text = original_dialogue['turns'][turn_index]['text']
                orig_candidate = next((v for v in turn_candidates if v['text'] == original_turn_text), None)
                if orig_candidate and orig_candidate not in candidates_to_pick:
                    candidates_to_pick.append(orig_candidate)

            # Shuffle candidates to produce different dialogues
            np.random.shuffle(candidates_to_pick)

            for variation in candidates_to_pick:
                if len(augmented_dialogues) >= over_generate_factor:
                    return
                current_turns.append(variation)
                generate_candidates(current_turns, turn_index + 1)
                current_turns.pop()

        try:
            generate_candidates()
        except Exception as e:
            print(f"Error in dialogue generation: {str(e)}")
            return []

        # Over-generated set of augmented dialogues is now available
        # Let's score them and pick the top ones
        scored_dialogues = []
        for d in augmented_dialogues:
            score = self._dialogue_quality_score(d, original_dialogue)
            scored_dialogues.append((score, d))

        scored_dialogues.sort(key=lambda x: x[0], reverse=True)
        # Pick top `augmentation_factor` dialogues
        final_dialogues = [d for _, d in scored_dialogues[:self.config.augmentation_factor]]

        return final_dialogues
    # def _generate_dialogue_combinations(self, dialogue_id: str, turn_variations: List[List[Dict]]) -> List[Dict]:
    #     """
    #     Generate dialogue combinations using sampling
    #     """
    #     augmented_dialogues = []
    #     used_combinations = set()

    #     def generate_dialogues(current_turns=None, turn_index=0):
    #         if current_turns is None:
    #             current_turns = []

    #         if len(augmented_dialogues) >= self.config.augmentation_factor:
    #             return

    #         if turn_index == len(turn_variations):
    #             dialogue_fingerprint = " | ".join(turn['text'] for turn in current_turns)
    #             if dialogue_fingerprint not in used_combinations:
    #                 used_combinations.add(dialogue_fingerprint)
    #                 augmented_dialogues.append({
    #                     'dialogue_id': f"{dialogue_id}_aug_{len(augmented_dialogues)}",
    #                     'turns': current_turns.copy()
    #                 })
    #             return

    #         variations = list(turn_variations[turn_index])
    #         np.random.shuffle(variations)

    #         for variation in variations[:self.config.max_sampled_variations]:
    #             if len(augmented_dialogues) >= self.config.augmentation_factor:
    #                 return
    #             current_turns.append(variation)
    #             generate_dialogues(current_turns, turn_index + 1)
    #             current_turns.pop()

    #     try:
    #         generate_dialogues()
    #     except Exception as e:
    #         print(f"Error in dialogue generation: {str(e)}")
    #         return []

    #     return augmented_dialogues

    def _is_dialogue_duplicate(self, dialogue1: Dict, dialogue2: Dict) -> bool:
        """
        Check if two dialogues are duplicates.
        """
        text1 = " ".join(turn['text'] for turn in dialogue1['turns'])
        text2 = " ".join(turn['text'] for turn in dialogue2['turns'])
        return text1 == text2

    def _augment_short_text(self, turn: Dict) -> List[Dict]:
        """
        Special handling for very short texts with predefined variations.
        If predefined variations are found, return them directly.
        Otherwise, produce simple punctuation and capitalization variants.
        Skip heavy quality checks for efficiency. These variations are safe and minimal.
        """
        text = turn['text']
        common_variations = {
            'goodbye': [
                'Bye!', 'Farewell!', 'See you!', 'Take care!',
                'Goodbye!', 'Bye for now!', 'Until next time!'
            ],
            'hello': [
                'Hi!', 'Hey!', 'Hello!', 'Greetings!',
                'Good day!', 'Hi there!', 'Hello there!'
            ],
            'yes': [
                'Yes!', 'Correct!', 'Indeed!', 'Absolutely!',
                'That\'s right!', 'Definitely!', 'Sure!'
            ],
            'no': [
                'No!', 'Nope!', 'Not at all!', 'Negative!',
                'Unfortunately not!', 'I\'m afraid not!'
            ],
            'thanks': [
                'Thank you!', 'Thanks a lot!', 'Many thanks!',
                'I appreciate it!', 'Thank you so much!'
            ],
            'ok': [
                'Okay!', 'Alright!', 'Sure!', 'Got it!',
                'Understood!', 'Fine!', 'Great!', 'Perfect!',
                'That works!', 'Sounds good!'
            ],
            'good': [
                'Great!', 'Excellent!', 'Perfect!', 'Wonderful!',
                'Fantastic!', 'Amazing!', 'Terrific!'
            ]
        }

        text_lower = text.lower().rstrip('!.,?')
        # Check if text matches any predefined category
        variations = []
        for key, predefined_vars in common_variations.items():
            if key in text_lower or text_lower in key:
                variations.extend(predefined_vars)

        if not variations:
            # Generate simple punctuation and capitalization variations if no predefined match
            base = text.rstrip('!.,?')
            variations = [
                base + '!',
                base + '.',
                base
            ]

            # Add capitalization variations
            capitalized = [v.capitalize() for v in variations if v.capitalize() not in variations]
            variations.extend(capitalized)

        # Ensure uniqueness
        unique_variations = list(set(variations))

        # Directly return these variations, as they are minimal and trusted
        # No further quality checks are needed
        result_variations = unique_variations[:self.config.augmentation_factor]
        return [{'speaker': turn['speaker'], 'text': v} for v in result_variations]

    def process_batch(self, batch: List[Dict]) -> List[Dict]:
        """Process multiple dialogues at once to maximize GPU utilization"""
        results = []

        # Pre-compute embeddings for all texts in batch
        all_texts = []
        text_to_embedding = {}

        for dialogue in batch:
            for turn in dialogue['turns']:
                all_texts.append(turn['text'])

        # Batch compute embeddings
        if all_texts:
            embeddings = self._compute_batch_embeddings(all_texts)
            for text, embedding in zip(all_texts, embeddings):
                self.embedding_cache[text] = embedding

        # Process each dialogue using cached embeddings
        for dialogue in batch:
            try:
                augmented = self.augment_dialogue(dialogue)
                results.extend(augmented)
            except Exception as e:
                print(f"Error processing dialogue {dialogue.get('dialogue_id', 'unknown')}: {e}")
                continue

        return results
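The core of _filter_variations_batch is a cosine-similarity gate against the original turn. A minimal self-contained sketch of that gate, substituting TF-IDF vectors for the Universal Sentence Encoder the deleted class loads (the function name and the fixed 0.75 threshold are illustrative):

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def filter_by_similarity(original, variations, threshold=0.75):
    """Keep only variations whose similarity to the original clears the threshold."""
    vectors = TfidfVectorizer().fit_transform([original] + list(variations))
    sims = cosine_similarity(vectors[0], vectors[1:])[0]
    return [v for v, s in zip(variations, sims) if s >= threshold]

print(filter_by_similarity("Book a table for two",
                           ["Reserve a table for two", "What time is it?"]))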
data_augmentation_code/main.py
DELETED
@@ -1,112 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
CSC525 - Module 8 Option 2 - Joseph Armani
|
3 |
-
Description and References in the README.md file.
|
4 |
-
"""
|
5 |
-
import json
|
6 |
-
import tensorflow as tf
|
7 |
-
from typing import List, Dict
|
8 |
-
from data_augmentation.pipeline_config import PipelineConfig
|
9 |
-
from data_augmentation.augmentation_processing_pipeline import AugmentationProcessingPipeline
|
10 |
-
-from data_augmentation.taskmaster_processor import TaskmasterProcessor
-from data_augmentation.schema_guided_dialogue_processor import SchemaGuidedProcessor
-
-def combine_datasets(taskmaster_dialogues: List[Dict],
-                     schema_guided_dialogues: List[Dict]) -> List[Dict]:
-    """
-    Combine dialogues from both datasets into a single list
-
-    Args:
-        taskmaster_dialogues: List of dialogues in pipeline format from Taskmaster
-        schema_guided_dialogues: List of dialogues in pipeline format from Schema-Guided
-
-    Returns:
-        List[Dict]: Combined list of dialogues
-    """
-    # Ensure unique dialogue IDs
-    combined_dialogues = []
-    seen_ids = set()
-    duplicate_count = 0  # Track duplicates for reporting
-
-    for dialogue in taskmaster_dialogues:
-        dialogue_copy = dialogue.copy()
-        dialogue_id = dialogue_copy['dialogue_id']
-        if dialogue_id in seen_ids:
-            duplicate_count += 1
-            dialogue_id = f"taskmaster_{dialogue_id}"
-        seen_ids.add(dialogue_id)
-        dialogue_copy['dialogue_id'] = dialogue_id
-        combined_dialogues.append(dialogue_copy)
-
-    for dialogue in schema_guided_dialogues:
-        dialogue_copy = dialogue.copy()
-        dialogue_id = dialogue_copy['dialogue_id']
-        if dialogue_id in seen_ids:
-            duplicate_count += 1
-            dialogue_id = f"schema_guided_{dialogue_id}"
-        seen_ids.add(dialogue_id)
-        dialogue_copy['dialogue_id'] = dialogue_id
-        combined_dialogues.append(dialogue_copy)
-
-    # Log the results
-    print(f"Combine Datasets: Found and resolved {duplicate_count} duplicate dialogue IDs.")
-    print(f"Combine Datasets: Total dialogues combined: {len(combined_dialogues)}")
-
-    return combined_dialogues
-
-def main():
-    # Configuration
-    config = PipelineConfig(
-        min_length=1,
-        max_length=512,
-        batch_size=32 if tf.config.list_physical_devices('GPU') else 16,
-        max_turns_per_dialogue=12,
-        max_variations_per_turn=4,
-        max_sampled_variations=2,
-        context_window_size=4,
-        max_complexity_threshold=100,
-        use_cache=False,
-        debug=True,
-        allowed_speakers=['user', 'assistant'],
-        required_fields=['dialogue_id', 'turns']
-    )
-
-    try:
-        # Set max_examples (Optional[int]) for testing
-        max_examples = 5
-
-        # Initialize and load Taskmaster dataset
-        print("Loading Taskmaster dataset")
-        taskmaster_processor = TaskmasterProcessor(config, use_ontology=False)
-        taskmaster_dialogues = taskmaster_processor.load_dataset('./datasets/taskmaster', max_examples=max_examples)
-        taskmaster_pipeline_dialogues = taskmaster_processor.convert_to_pipeline_format(taskmaster_dialogues)
-        print(f"Processed Taskmaster dialogues: {len(taskmaster_pipeline_dialogues)}")
-
-        # Initialize and load Schema-Guided dataset
-        print("Loading Schema-Guided dataset")
-        schema_dialogue_processor = SchemaGuidedProcessor(config)
-        schema_dialogues = schema_dialogue_processor.load_dataset('./datasets/schema_guided', max_examples=max_examples)
-        schema_pipeline_dialogues = schema_dialogue_processor.convert_to_pipeline_format(schema_dialogues)
-        print(f"Processed Schema-Guided dialogues: {len(schema_pipeline_dialogues)}")
-
-        # Combine datasets
-        print("Combining datasets")
-        combined_dialogues = combine_datasets(taskmaster_pipeline_dialogues, schema_pipeline_dialogues)
-        print(f"Combined Dialogues: {len(combined_dialogues)}")
-
-        if not combined_dialogues:
-            print("Combined dialogues are empty. Exiting.")
-            return
-
-        # Process through augmentation pipeline
-        print("Processing combined dataset")
-        pipeline = AugmentationProcessingPipeline(config)
-        output_path = pipeline.process_dataset(combined_dialogues)
-        print(f"Processing complete. Results saved to {output_path}")
-        pipeline.cleanup()
-
-    except Exception as e:
-        print(f"Processing failed: {str(e)}")
-        raise
-
-if __name__ == "__main__":
-    main()
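
Note: the removed combine_datasets resolves ID collisions by prefixing the colliding ID with its source dataset. A minimal sketch of that behavior; the helper and sample dicts below are hypothetical, not part of the deleted file:

from typing import Dict, List

def dedupe_ids(dialogues: List[Dict], prefix: str, seen: set) -> List[Dict]:
    # Same collision rule as the deleted combine_datasets: on a repeat ID,
    # prepend the source-dataset prefix, then record the (possibly new) ID.
    out = []
    for d in dialogues:
        d = d.copy()
        if d['dialogue_id'] in seen:
            d['dialogue_id'] = f"{prefix}_{d['dialogue_id']}"
        seen.add(d['dialogue_id'])
        out.append(d)
    return out

seen: set = set()
a = dedupe_ids([{'dialogue_id': 'dlg_1', 'turns': []}], 'taskmaster', seen)
b = dedupe_ids([{'dialogue_id': 'dlg_1', 'turns': []}], 'schema_guided', seen)
print(a[0]['dialogue_id'], b[0]['dialogue_id'])  # dlg_1 schema_guided_dlg_1
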
data_augmentation_code/paraphraser.py
DELETED
@@ -1,42 +0,0 @@
-from transformers import (
-    AutoTokenizer,
-    AutoModelForSeq2SeqLM,
-)
-
-class Paraphraser:
-    def __init__(self, model_name='humarin/chatgpt_paraphraser_on_T5_base'):
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
-        self.model.eval()
-
-    def paraphrase(self, text, num_return_sequences=5, num_beams=5,
-                   num_beam_groups=1, diversity_penalty=0.0, device=None):
-        try:
-            input_text = "paraphrase: " + text + " </s>"
-            encoding = self.tokenizer.encode_plus(input_text, return_tensors="pt")
-
-            # Move input tensors to specified device if provided
-            if device is not None:
-                input_ids = encoding["input_ids"].to(device)
-                self.model = self.model.to(device)
-            else:
-                input_ids = encoding["input_ids"]
-
-            outputs = self.model.generate(
-                input_ids=input_ids,
-                max_length=256,
-                num_beams=num_beams,
-                num_beam_groups=num_beam_groups,
-                num_return_sequences=num_return_sequences,
-                diversity_penalty=diversity_penalty,
-                early_stopping=True
-            )
-
-            # Move outputs back to CPU for tokenizer decoding
-            outputs = outputs.cpu() if device is not None else outputs
-            paraphrases = [self.tokenizer.decode(output, skip_special_tokens=True)
-                           for output in outputs]
-            return paraphrases
-        except Exception as e:
-            print(f"Error in paraphrasing: {e}")
-            return []
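
Note: the removed Paraphraser is a thin wrapper around a T5 paraphrase checkpoint. A usage sketch, assuming the import path below and that the 'humarin/chatgpt_paraphraser_on_T5_base' weights can be fetched:

# Usage sketch for the deleted class; the import path is an assumption.
from data_augmentation.paraphraser import Paraphraser

p = Paraphraser()
variants = p.paraphrase(
    "I'd like to book a table for two tonight.",
    num_return_sequences=3,
    num_beams=5,
)
for v in variants:
    print(v)
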
data_augmentation_code/pipeline_config.py
DELETED
@@ -1,57 +0,0 @@
-from dataclasses import dataclass
-from typing import List
-
-@dataclass
-class PipelineConfig:
-    """
-    Config for the pipeline
-    """
-    # Validation settings
-    min_length: int = 1
-    max_length: int = 512
-    min_tokens: int = 1
-    max_tokens: int = 128
-
-    allowed_speakers: List[str] = None
-    required_fields: List[str] = None
-
-    # Text augmentation settings
-    augmentation_factor: int = 4
-    augmentation_techniques: List[str] = None
-
-    max_turns_per_dialogue: int = 6
-    max_variations_per_turn: int = 3
-    max_sampled_variations: int = 2
-    max_complexity_threshold: int = 100
-    complexity_reduction_turns: int = 4
-
-    # Quality thresholds
-    semantic_similarity_threshold: float = 0.45
-    grammar_error_threshold: int = 2
-    rouge1_f1_threshold: float = 0.30
-    rouge2_f1_threshold: float = 0.15
-
-    # Response coherence thresholds
-    min_response_coherence: float = 0.3
-    context_similarity_weight: float = 0.35
-    response_coherence_weight: float = 0.65
-
-    # Performance settings
-    batch_size: int = 32
-    use_cache: bool = True
-    debug: bool = False
-
-    context_window_size: int = 4
-
-    def __post_init__(self):
-        if self.allowed_speakers is None:
-            self.allowed_speakers = ['user', 'assistant']
-        if self.required_fields is None:
-            self.required_fields = ['dialogue_id', 'turns']
-        if self.augmentation_techniques is None:
-            self.augmentation_techniques = ['paraphrase', 'back_translation']
-
-        # Validate weights sum to 1.0
-        if abs((self.context_similarity_weight + self.response_coherence_weight) - 1.0) > 1e-6:
-            raise ValueError("Context similarity and response coherence weights must sum to 1.0")
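
Note: PipelineConfig's __post_init__ both fills mutable defaults and rejects scoring weights that do not sum to 1.0. A pared-down, runnable sketch of that validation rule:

from dataclasses import dataclass

@dataclass
class WeightedScoring:
    # Same two fields and tolerance as the deleted PipelineConfig check.
    context_similarity_weight: float = 0.35
    response_coherence_weight: float = 0.65

    def __post_init__(self):
        if abs((self.context_similarity_weight + self.response_coherence_weight) - 1.0) > 1e-6:
            raise ValueError("weights must sum to 1.0")

WeightedScoring()                                   # ok: 0.35 + 0.65 == 1.0
try:
    WeightedScoring(context_similarity_weight=0.5)  # 0.5 + 0.65 != 1.0
except ValueError as e:
    print(e)
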
data_augmentation_code/quality_metrics.py
DELETED
@@ -1,47 +0,0 @@
-import tensorflow_hub as hub
-import spacy
-from sklearn.metrics.pairwise import cosine_similarity
-from typing import Dict
-from data_augmentation.pipeline_config import PipelineConfig
-
-class QualityMetrics:
-    """
-    Quality metrics focusing on semantic similarity and basic lexical stats.
-    """
-    def __init__(self, config: PipelineConfig):
-        self.config = config
-        self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
-        self.nlp = spacy.load('en_core_web_md')
-
-    def compute_semantic_similarity(self, text1: str, text2: str) -> float:
-        embeddings = self.use_model([text1, text2])
-        emb1, emb2 = embeddings[0].numpy(), embeddings[1].numpy()
-        return cosine_similarity([emb1], [emb2])[0][0]
-
-    def compute_metrics(self, original: str, augmented: str) -> Dict[str, float]:
-        metrics = {}
-        # Semantic similarity
-        embeddings = self.use_model([original, augmented])
-        emb_orig, emb_aug = embeddings[0].numpy(), embeddings[1].numpy()
-        metrics['semantic_similarity'] = cosine_similarity([emb_orig], [emb_aug])[0][0]
-
-        # Lexical diversity & content preservation
-        doc_orig = self.nlp(original)
-        doc_aug = self.nlp(augmented)
-
-        aug_tokens = [token.text.lower() for token in doc_aug]
-        metrics['type_token_ratio'] = len(set(aug_tokens)) / max(len(aug_tokens), 1)
-
-        orig_content = {token.text.lower() for token in doc_orig if not token.is_stop}
-        aug_content = {token.text.lower() for token in doc_aug if not token.is_stop}
-        if len(orig_content) == 0:
-            metrics['content_preservation'] = 1.0 if len(aug_content) == 0 else 0.0
-        else:
-            metrics['content_preservation'] = len(orig_content.intersection(aug_content)) / len(orig_content)
-
-        # Length ratio
-        orig_words = len(original.split())
-        aug_words = len(augmented.split())
-        metrics['length_ratio'] = aug_words / max(orig_words, 1)
-
-        return metrics
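
Note: beyond the USE embedding similarity, the deleted metrics are simple set and ratio arithmetic. A dependency-free sketch; whitespace tokenization stands in for spaCy here (an assumption), so stop words are not filtered:

def lexical_metrics(original: str, augmented: str) -> dict:
    # Mirrors type_token_ratio, content_preservation, and length_ratio
    # from the deleted QualityMetrics, minus spaCy tokenization.
    orig = original.lower().split()
    aug = augmented.lower().split()
    return {
        'type_token_ratio': len(set(aug)) / max(len(aug), 1),
        'content_preservation': len(set(orig) & set(aug)) / max(len(set(orig)), 1),
        'length_ratio': len(aug) / max(len(orig), 1),
    }

print(lexical_metrics("book a table for two", "reserve a table for two people"))
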
data_augmentation_code/schema_guided_dialogue_processor.py
DELETED
@@ -1,192 +0,0 @@
-from dataclasses import dataclass, field
-from typing import List, Dict, Optional, Any
-import json
-import glob
-from pathlib import Path
-from data_augmentation.pipeline_config import PipelineConfig
-
-@dataclass
-class SchemaGuidedDialogue:
-    """
-    Structured representation of a Schema-Guided dialogue
-    """
-    dialogue_id: str
-    service_name: str
-    service_description: Optional[str]
-    schema: Dict[str, Any]
-    turns: List[Dict[str, Any]]
-    original_metadata: Dict[str, Any] = field(default_factory=dict)
-
-class SchemaGuidedProcessor:
-    """
-    Handles processing and preparation of Schema-Guided dataset dialogues
-    """
-    def __init__(self, config: PipelineConfig):
-        self.config = config
-        self.services = set()
-        self.domains = set()
-        self.schemas = {}
-
-    def load_dataset(self, base_dir, max_examples: Optional[int] = None) -> List[SchemaGuidedDialogue]:
-        """
-        Load and parse Schema-Guided Dialogue dataset
-
-        Args:
-            dialogue_path: Path to the dialogue JSON file
-            schema_path: Path to the schema JSON file
-        """
-        # Define schema and dialogue file patterns
-        schema_file = Path(base_dir, "schema.json")
-        dialogue_files_pattern = str(Path(base_dir, "dialogues_*.json"))
-
-        # Check for schema file
-        if not schema_file.exists():
-            raise FileNotFoundError(f"Schema file not found at {schema_file}")
-
-        # Load schema
-        self.schemas = self._load_schemas(schema_file)
-
-        # Find and validate dialogue files
-        dialogue_files = glob.glob(dialogue_files_pattern)
-        if not dialogue_files:
-            raise FileNotFoundError(f"No dialogue files found matching pattern {dialogue_files_pattern}")
-
-        print(f"Found {len(dialogue_files)} dialogue files to process.")
-
-        # Process all dialogues
-        processed_dialogues = []
-        for file_path in dialogue_files:
-            with open(file_path, 'r', encoding='utf-8') as f:
-                raw_dialogues = json.load(f)
-
-            for dialogue in raw_dialogues:
-                processed_dialogues.append(self._process_single_dialogue(dialogue))
-
-                if max_examples and len(processed_dialogues) >= max_examples:
-                    break
-
-        return processed_dialogues
-
-    def _process_single_dialogue(self, dialogue: Dict[str, Any]) -> SchemaGuidedDialogue:
-        """
-        Process a single dialogue JSON object into a SchemaGuidedDialogue object.
-        """
-        dialogue_id = str(dialogue.get("dialogue_id", ""))
-        services = dialogue.get("services", [])
-        service_name = services[0] if services else None
-        schema = self.schemas.get(service_name, {})
-        service_description = schema.get("description", "")
-
-        # Process turns
-        turns = self._process_turns(dialogue.get("turns", []))
-
-        # Store metadata
-        metadata = {
-            "services": services,
-            "original_id": dialogue_id,
-        }
-
-        return SchemaGuidedDialogue(
-            dialogue_id=f"schema_guided_{dialogue_id}",
-            service_name=service_name,
-            service_description=service_description,
-            schema=schema,
-            turns=turns,
-            original_metadata=metadata,
-        )
-
-    def _validate_schema(self, schema: Dict[str, Any]) -> bool:
-        """
-        Validate a schema
-        """
-        required_keys = {"service_name", "description", "slots", "intents"}
-        missing_keys = required_keys - schema.keys()
-        if missing_keys:
-            print(f"Warning: Missing keys in schema {schema.get('service_name', 'unknown')}: {missing_keys}")
-            return False
-        return True
-
-    def _load_schemas(self, schema_path: str) -> Dict[str, Any]:
-        """
-        Load and process service schemas
-        """
-        with open(schema_path, 'r', encoding='utf-8') as f:
-            schemas = json.load(f)
-
-        # Validate and index schemas
-        return {
-            schema["service_name"]: schema for schema in schemas if self._validate_schema(schema)
-        }
-
-    def _process_turns(self, turns: List[Dict]) -> List[Dict]:
-        """
-        Process dialogue turns into standardized format
-        """
-        processed_turns = []
-
-        for turn in turns:
-            try:
-                # Map speakers to standard format
-                speaker = 'assistant' if turn.get('speaker') == 'SYSTEM' else 'user'
-
-                # Extract utterance and clean it
-                text = turn.get('utterance', '').strip()
-
-                # Extract frames and dialogue acts
-                frames = turn.get('frames', [])
-                acts = []
-                slots = []
-
-                for frame in frames:
-                    if 'actions' in frame:
-                        acts.extend(frame['actions'])
-                    if 'slots' in frame:
-                        slots.extend(frame['slots'])
-
-                # Create the processed turn
-                processed_turn = {
-                    'speaker': speaker,
-                    'text': text,
-                    'original_speaker': turn.get('speaker', ''),
-                    'dialogue_acts': acts,
-                    'slots': slots,
-                    'metadata': {k: v for k, v in turn.items()
-                                 if k not in {'speaker', 'utterance', 'frames'}}
-                }
-
-                processed_turns.append(processed_turn)
-            except Exception as e:
-                print(f"Error processing turn: {str(e)}")
-                continue
-
-        return processed_turns
-
-    def convert_to_pipeline_format(self, schema_dialogues: List[SchemaGuidedDialogue]) -> List[Dict]:
-        """
-        Convert SchemaGuidedDialogues to the format expected by the ProcessingPipeline
-        """
-        pipeline_dialogues = []
-
-        for dialogue in schema_dialogues:
-            # Convert turns to the expected format
-            processed_turns = [
-                {"speaker": turn["speaker"], "text": turn["text"]}
-                for turn in dialogue.turns if turn["text"].strip()
-            ]
-
-            # Create dialogue in pipeline format
-            pipeline_dialogue = {
-                'dialogue_id': dialogue.dialogue_id,
-                'turns': processed_turns,
-                'metadata': {
-                    'service_name': dialogue.service_name,
-                    'service_description': dialogue.service_description,
-                    'schema': dialogue.schema,
-                    **dialogue.original_metadata
-                }
-            }
-
-            pipeline_dialogues.append(pipeline_dialogue)
-
-        return pipeline_dialogues
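
Note: both dataset processors converge on one pipeline format: a dialogue_id, a list of {speaker, text} turns, and a metadata dict. A hypothetical example of one converted Schema-Guided dialogue (all values invented for illustration):

example = {
    'dialogue_id': 'schema_guided_1_00000',
    'turns': [
        {'speaker': 'user', 'text': 'I need a restaurant reservation.'},
        {'speaker': 'assistant', 'text': 'Sure, which city?'},
    ],
    'metadata': {
        'service_name': 'Restaurants_1',
        'service_description': 'A service for finding and booking restaurants',
        'schema': {},  # the full service schema would appear here
        'services': ['Restaurants_1'],
        'original_id': '1_00000',
    },
}
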
data_augmentation_code/taskmaster_processor.py
DELETED
@@ -1,192 +0,0 @@
-from dataclasses import dataclass, field
-from typing import List, Dict, Optional, Any
-import json
-import re
-from pathlib import Path
-from data_augmentation.pipeline_config import PipelineConfig
-
-@dataclass
-class TaskmasterDialogue:
-    """
-    Structured representation of a Taskmaster dialogue
-    """
-    conversation_id: str
-    instruction_id: Optional[str]
-    scenario: Optional[str]
-    domain: Optional[str]
-    turns: List[Dict[str, Any]]
-    original_metadata: Dict[str, Any] = field(default_factory=dict)
-
-    def __str__(self):
-        return f"TaskmasterDialogue(conversation_id={self.conversation_id}, turns={len(self.turns)} turns)"
-
-    def validate(self) -> bool:
-        return bool(self.conversation_id and isinstance(self.turns, list))
-
-class TaskmasterProcessor:
-    """
-    Handles processing and preparation of Taskmaster dataset dialogues
-    """
-    config: PipelineConfig
-    use_ontology: bool = False  # Whether to load and use ontology
-    ontology: Optional[Dict[str, Any]] = None  # Holds ontology data if loaded
-    domains: set = field(default_factory=set)  # Tracks unique domains
-    scenarios: set = field(default_factory=set)  # Tracks unique scenarios
-
-    def __init__(self, config: PipelineConfig, use_ontology: bool = False):
-        self.config = config
-        self.use_ontology = use_ontology
-        self.ontology = None
-        self.domains = set()
-        self.scenarios = set()
-
-    def load_dataset(self, base_dir: str, max_examples: Optional[int] = None) -> List[TaskmasterDialogue]:
-        """
-        Load and parse Taskmaster JSON dataset.
-        Handles self-dialogs, woz-dialogs, and ontology files.
-        """
-        required_files = {
-            "self-dialogs": "self-dialogs.json",
-            "woz-dialogs": "woz-dialogs.json",
-            "ontology": "ontology.json",
-        }
-
-        # Check for required files
-        missing_files = [name for name, path in required_files.items() if not Path(base_dir, path).exists()]
-        if missing_files:
-            raise FileNotFoundError(f"Missing required taskmaster files: {missing_files}")
-
-        # load ontology
-        ontology_path = Path(base_dir, required_files['ontology'])
-        with open(ontology_path, 'r', encoding='utf-8') as f:
-            self.ontology = json.load(f)
-
-        processed_dialogues = []
-        for file_key in ["self-dialogs", "woz-dialogs"]:
-            file_path = Path(base_dir, required_files[file_key])
-            with open(file_path, 'r', encoding='utf-8') as f:
-                raw_data = json.load(f)
-
-            for dialogue in raw_data:
-                # Extract core dialogue components
-                conversation_id = dialogue.get('conversation_id', '')
-                instruction_id = dialogue.get('instruction_id', None)
-
-                if 'utterances' in dialogue:
-                    turns = self._process_utterances(dialogue['utterances'])
-                    scenario = dialogue.get('scenario', '')
-                    domain = self._extract_domain(scenario)
-                else:
-                    turns = []
-                    scenario = ''
-                    domain = ''
-
-                # Store metadata
-                metadata = {k: v for k, v in dialogue.items()
-                            if k not in {'conversation_id', 'instruction_id', 'utterances'}}
-
-                # Create structured dialogue object
-                processed_dialogue = TaskmasterDialogue(
-                    conversation_id=conversation_id,
-                    instruction_id=instruction_id,
-                    scenario=scenario,
-                    domain=domain,
-                    turns=turns,
-                    original_metadata=metadata
-                )
-
-                processed_dialogues.append(processed_dialogue)
-
-                # Update domain and scenario tracking
-                if domain:
-                    self.domains.add(domain)
-                if scenario:
-                    self.scenarios.add(scenario)
-
-                if max_examples and len(processed_dialogues) >= max_examples:
-                    break
-
-        return processed_dialogues
-
-    def _process_utterances(self, utterances: List[Dict]) -> List[Dict]:
-        """
-        Process utterances into a standardized format
-        """
-        processed_turns = []
-
-        for utterance in utterances:
-            # Map Taskmaster speaker roles to your expected format
-            speaker = 'assistant' if utterance.get('speaker') == 'ASSISTANT' else 'user'
-
-            # Extract and clean the text
-            text = utterance.get('text', '').strip()
-
-            # Extract any segments or annotations if present
-            segments = utterance.get('segments', [])
-
-            # Create the processed turn
-            turn = {
-                'speaker': speaker,
-                'text': text,
-                'original_speaker': utterance.get('speaker', ''),
-                'segments': segments,
-                'metadata': {k: v for k, v in utterance.items()
-                             if k not in {'speaker', 'text', 'segments'}}
-            }
-
-            processed_turns.append(turn)
-
-        return processed_turns
-
-    def _extract_domain(self, scenario: str) -> str:
-        """
-        Extract domain from scenario description
-        """
-        domain_patterns = {
-            'restaurant': r'\b(restaurant|dining|food|reservation)\b',
-            'movie': r'\b(movie|cinema|film|ticket)\b',
-            'ride_share': r'\b(ride|taxi|uber|lyft)\b',
-            'coffee': r'\b(coffee|café|cafe|starbucks)\b',
-            'pizza': r'\b(pizza|delivery|order food)\b',
-            'auto': r'\b(car|vehicle|repair|maintenance)\b',
-        }
-
-        scenario_lower = scenario.lower()
-
-        for domain, pattern in domain_patterns.items():
-            if re.search(pattern, scenario_lower):
-                return domain
-
-        return 'other'
-
-    def convert_to_pipeline_format(self, taskmaster_dialogues: List[TaskmasterDialogue]) -> List[Dict]:
-        """
-        Convert TaskmasterDialogues to the format expected by the ProcessingPipeline
-        """
-        pipeline_dialogues = []
-
-        for dialogue in taskmaster_dialogues:
-            # Convert turns to the expected format
-            processed_turns = []
-            for turn in dialogue.turns:
-                if turn['text'].strip():  # Skip empty turns
-                    processed_turns.append({
-                        'speaker': turn['speaker'],
-                        'text': turn['text']
-                    })
-
-            # Create dialogue in pipeline format
-            pipeline_dialogue = {
-                'dialogue_id': dialogue.conversation_id,
-                'turns': processed_turns,
-                'metadata': {
-                    'instruction_id': dialogue.instruction_id,
-                    'scenario': dialogue.scenario,
-                    'domain': dialogue.domain,
-                    **dialogue.original_metadata
-                }
-            }
-
-            pipeline_dialogues.append(pipeline_dialogue)
-
-        return pipeline_dialogues
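
Note: domain tagging in the deleted processor is a first-match regex scan over the scenario text. A self-contained sketch with a trimmed copy of the pattern table:

import re

# Subset of the deleted _extract_domain pattern table.
DOMAIN_PATTERNS = {
    'restaurant': r'\b(restaurant|dining|food|reservation)\b',
    'movie': r'\b(movie|cinema|film|ticket)\b',
    'ride_share': r'\b(ride|taxi|uber|lyft)\b',
}

def extract_domain(scenario: str) -> str:
    s = scenario.lower()
    for domain, pattern in DOMAIN_PATTERNS.items():
        if re.search(pattern, s):
            return domain
    return 'other'

print(extract_domain("Book a ticket for the new film tonight"))  # movie
print(extract_domain("Send a birthday card"))                    # other
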
prepare_data.py
CHANGED
@@ -16,12 +16,12 @@ logger = config_logger(__name__)
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 def main():
-    MODELS_DIR = '
-    PROCESSED_DATA_DIR = '
-    CACHE_DIR = '
+    MODELS_DIR = 'models'
+    PROCESSED_DATA_DIR = 'processed_outputs'
+    CACHE_DIR = os.path.join(MODELS_DIR, 'query_embeddings_cache')
     TOKENIZER_DIR = os.path.join(MODELS_DIR, 'tokenizer')
     FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
-    TF_RECORD_DIR = '
+    TF_RECORD_DIR = 'training_data'
     FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
     JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'taskmaster_dialogues.json')
     CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')
run_chatbot_validation.py
CHANGED
@@ -44,7 +44,7 @@ def run_chatbot_validation():
     env = EnvironmentSetup()
     env.initialize()
 
-    MODEL_DIR = "
+    MODEL_DIR = "models"
     FAISS_INDICES_DIR = os.path.join(MODEL_DIR, "faiss_indices")
     FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_production.index")
     FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_test.index")
new_iteration/run_taskmaster_processor.py → run_taskmaster_processor.py
RENAMED
File without changes

new_iteration/taskmaster_processor.py → taskmaster_processor.py
RENAMED
File without changes
tf_data_pipeline.py
CHANGED
@@ -29,7 +29,7 @@ class TFDataPipeline:
         max_length: int = 512,
         neg_samples: int = 10,
         index_type: str = 'IndexFlatIP',
-        faiss_index_file_path: str = '
+        faiss_index_file_path: str = 'models/faiss_indices/faiss_index_production.index',
         nlist: int = 100,
         max_retries: int = 3
     ):
unused/build_faiss_index.py
DELETED
@@ -1,160 +0,0 @@
-# import os
-# import json
-# from pathlib import Path
-
-# import faiss
-# import numpy as np
-# import tensorflow as tf
-# from transformers import AutoTokenizer, TFAutoModel
-# from tqdm.auto import tqdm
-
-# from chatbot_model import ChatbotConfig, EncoderModel
-# from tf_data_pipeline import TFDataPipeline
-# from logger_config import config_logger
-
-# logger = config_logger(__name__)
-# os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-# def sanity_check(encoder: EncoderModel, tokenizer: AutoTokenizer, config: ChatbotConfig):
-#     """
-#     Perform a quick sanity check to ensure the model is loaded correctly.
-#     """
-#     sample_response = "This is a test response."
-#     encoded_sample = tokenizer(
-#         [sample_response],
-#         padding=True,
-#         truncation=True,
-#         max_length=config.max_context_token_limit,
-#         return_tensors='tf'
-#     )
-
-#     # Get embedding
-#     sample_embedding = encoder(encoded_sample['input_ids'], training=False).numpy()
-
-#     # Check shape
-#     if sample_embedding.shape[1] != config.embedding_dim:
-#         logger.error(
-#             f"Embedding dimension mismatch: Expected {config.embedding_dim}, "
-#             f"got {sample_embedding.shape[1]}"
-#         )
-#         raise ValueError("Embedding dimension mismatch.")
-#     else:
-#         logger.info("Embedding dimension matches the configuration.")
-
-#     # Check normalization
-#     embedding_norm = np.linalg.norm(sample_embedding, axis=1)
-#     if not np.allclose(embedding_norm, 1.0, atol=1e-5):
-#         logger.error("Embeddings are not properly normalized.")
-#         raise ValueError("Embeddings are not normalized.")
-#     else:
-#         logger.info("Embeddings are properly normalized.")
-
-#     logger.info("Sanity check passed: Model loaded correctly and outputs are as expected.")
-
-# def build_faiss_index():
-#     """
-#     Rebuild the FAISS index by:
-#     1) Loading your config.json
-#     2) Initializing encoder + loading submodule & custom weights
-#     3) Loading tokenizer from disk
-#     4) Creating a TFDataPipeline
-#     5) Setting the pipeline's response_pool from a JSON file
-#     6) Using pipeline.compute_and_index_response_embeddings()
-#     7) Saving the FAISS index
-#     """
-#     # Directories
-#     MODELS_DIR = Path("models")
-#     FAISS_DIR = MODELS_DIR / "faiss_indices"
-#     FAISS_INDEX_PATH = FAISS_DIR / "faiss_index_production.index"
-#     RESPONSES_PATH = FAISS_DIR / "faiss_index_production_responses.json"
-#     TOKENIZER_DIR = MODELS_DIR / "tokenizer"
-#     SHARED_ENCODER_DIR = MODELS_DIR / "shared_encoder"
-#     CUSTOM_WEIGHTS_PATH = MODELS_DIR / "encoder_custom_weights.weights.h5"
-
-#     # 1) Load ChatbotConfig
-#     config_path = MODELS_DIR / "config.json"
-#     if config_path.exists():
-#         with open(config_path, "r", encoding="utf-8") as f:
-#             config_dict = json.load(f)
-#         config = ChatbotConfig.from_dict(config_dict)
-#         logger.info(f"Loaded ChatbotConfig from {config_path}")
-#     else:
-#         config = ChatbotConfig()
-#         logger.warning(f"No config.json found at {config_path}. Using default ChatbotConfig.")
-
-#     # 2) Initialize the EncoderModel
-#     encoder = EncoderModel(config=config)
-#     logger.info("EncoderModel instantiated (empty).")
-
-#     # Overwrite the submodule from 'shared_encoder' directory
-#     if SHARED_ENCODER_DIR.exists():
-#         logger.info(f"Loading DistilBERT submodule from {SHARED_ENCODER_DIR}...")
-#         encoder.pretrained = TFAutoModel.from_pretrained(str(SHARED_ENCODER_DIR))
-#         logger.info("Loaded HF submodule into encoder.pretrained.")
-#     else:
-#         logger.warning(f"No shared_encoder directory at {SHARED_ENCODER_DIR}. Using default pretrained model.")
-
-#     # Build model once, then load custom weights (projection, etc.)
-#     dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
-#     _ = encoder(dummy_input, training=False)  # builds the layers
-
-#     if CUSTOM_WEIGHTS_PATH.exists():
-#         logger.info(f"Loading custom top-level weights from {CUSTOM_WEIGHTS_PATH}")
-#         encoder.load_weights(str(CUSTOM_WEIGHTS_PATH))
-#         logger.info("Custom top-level weights loaded successfully.")
-#     else:
-#         logger.warning(f"Custom weights file not found at {CUSTOM_WEIGHTS_PATH}.")
-
-#     # 3) Load tokenizer
-#     if TOKENIZER_DIR.exists():
-#         logger.info(f"Loading tokenizer from {TOKENIZER_DIR}")
-#         tokenizer = AutoTokenizer.from_pretrained(str(TOKENIZER_DIR))
-#     else:
-#         logger.warning(f"No tokenizer dir at {TOKENIZER_DIR}, falling back to default HF tokenizer.")
-#         tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
-
-#     # 4) Quick sanity check
-#     sanity_check(encoder, tokenizer, config)
-
-#     # 5) Prepare a TFDataPipeline
-#     pipeline = TFDataPipeline(
-#         config=config,
-#         tokenizer=tokenizer,
-#         encoder=encoder,
-#         index_file_path=str(FAISS_INDEX_PATH),
-#         response_pool=[],
-#         max_length=config.max_context_token_limit,
-#         query_embeddings_cache={},
-#         neg_samples=config.neg_samples,
-#         index_type='IndexFlatIP',
-#         nlist=100,
-#         max_retries=config.max_retries
-#     )
-
-#     # 6) Load the existing response pool
-#     if not RESPONSES_PATH.exists():
-#         logger.error(f"Response pool JSON file not found at {RESPONSES_PATH}")
-#         raise FileNotFoundError(f"No response pool JSON at {RESPONSES_PATH}")
-
-#     with open(RESPONSES_PATH, "r", encoding="utf-8") as f:
-#         response_pool = json.load(f)
-#     logger.info(f"Loaded {len(response_pool)} responses from {RESPONSES_PATH}")
-
-#     pipeline.response_pool = response_pool  # assign to pipeline
-
-#     # 7) Build (or rebuild) the FAISS index from pipeline method
-#     # This does all the compute-embeddings + index.add in one place
-#     logger.info("Starting to compute and index response embeddings via TFDataPipeline...")
-#     pipeline.compute_and_index_response_embeddings()
-
-#     # 8) Save the rebuilt FAISS index
-#     pipeline.save_faiss_index(str(FAISS_INDEX_PATH))
-
-#     # Verify
-#     loaded_index = faiss.read_index(str(FAISS_INDEX_PATH))
-#     logger.info(f"Verified the rebuilt FAISS index has {loaded_index.ntotal} vectors.")
-
-#     return loaded_index, pipeline.response_pool
-
-# if __name__ == "__main__":
-#     build_faiss_index()
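
Note: IndexFlatIP ranks by inner product, which equals cosine similarity once vectors are L2-normalized; that is the property the commented-out sanity_check asserted. A minimal, self-contained FAISS sketch with random stand-in embeddings:

import faiss
import numpy as np

dim = 8
embeddings = np.random.rand(100, dim).astype('float32')
faiss.normalize_L2(embeddings)  # in-place L2 normalization

index = faiss.IndexFlatIP(dim)  # exact inner-product search
index.add(embeddings)

scores, ids = index.search(embeddings[:1], 3)
print(ids[0], scores[0])  # best hit is the query itself, score ~= 1.0
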
unused/gpu_monitor.py
DELETED
@@ -1,59 +0,0 @@
-import tensorflow as tf
-from typing import List, Dict, Optional
-from dataclasses import dataclass
-
-from tqdm.auto import tqdm
-
-@dataclass
-class GPUMemoryStats:
-    total: int
-    used: int
-    free: int
-
-class GPUMemoryMonitor:
-    """Monitor GPU memory usage with safe CPU fallback."""
-    def __init__(self):
-        self.has_gpu = False
-        try:
-            gpus = tf.config.list_physical_devices('GPU')
-            self.has_gpu = len(gpus) > 0
-        except:
-            pass
-
-    def get_memory_stats(self) -> Optional[GPUMemoryStats]:
-        """Get current GPU memory statistics."""
-        if not self.has_gpu:
-            return None
-
-        try:
-            memory_info = tf.config.experimental.get_memory_info('GPU:0')
-            return GPUMemoryStats(
-                total=memory_info['peak'],
-                used=memory_info['current'],
-                free=memory_info['peak'] - memory_info['current']
-            )
-        except:
-            return None
-
-    def get_memory_usage(self) -> float:
-        """Get current GPU memory usage as a percentage."""
-        if not self.has_gpu:
-            return 0.0
-        stats = self.get_memory_stats()
-        if stats is None or stats.total == 0:
-            return 0.0
-        return stats.used / stats.total
-
-    def should_reduce_batch_size(self) -> bool:
-        """Check if batch size should be reduced based on memory usage."""
-        if not self.has_gpu:
-            return False
-        usage = self.get_memory_usage()
-        return usage > 0.90
-
-    def can_increase_batch_size(self) -> bool:
-        """Check if batch size can be increased based on memory usage."""
-        if not self.has_gpu:
-            return True
-        usage = self.get_memory_usage()
-        return usage < 0.70
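
Note: the deleted monitor's thresholds form a hysteresis band (shrink above 90% usage, grow below 70%), which keeps an adaptive batch size from oscillating. A hypothetical consumer loop; the import path and the training-step placeholder are assumptions:

from gpu_monitor import GPUMemoryMonitor  # assumed import path (file lived in unused/)

monitor = GPUMemoryMonitor()
batch_size = 32

for step in range(100):
    # ... run one training step at the current batch_size ...
    if monitor.should_reduce_batch_size():   # memory usage > 0.90
        batch_size = max(batch_size // 2, 1)
    elif monitor.can_increase_batch_size():  # memory usage < 0.70
        batch_size = min(batch_size * 2, 256)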