from typing import Dict, List

import numpy as np
import torch
import tensorflow as tf
import tensorflow_hub as hub
import nlpaug.augmenter.word as naw
from sklearn.metrics.pairwise import cosine_similarity

from pipeline_config import PipelineConfig
from quality_metrics import QualityMetrics
from paraphraser import Paraphraser

class DialogueAugmenter:
    """
    Optimized dialogue augmentation with quality control and complexity management.
    """

    def __init__(self, nlp, config: PipelineConfig):
        self.nlp = nlp
        self.config = config

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.use_gpu = torch.cuda.is_available()

        if self.config.debug:
            print(f"Using device: {self.device}")
            if self.use_gpu:
                print(f"GPU Device: {torch.cuda.get_device_name(0)}")

        # Memory growth must be configured before TensorFlow initializes the GPU,
        # i.e. before the Universal Sentence Encoder is loaded below; once the
        # GPU context exists, set_memory_growth raises a RuntimeError.
        if self.use_gpu:
            gpus = tf.config.list_physical_devices('GPU')
            if gpus:
                try:
                    for gpu in gpus:
                        tf.config.experimental.set_memory_growth(gpu, True)
                except RuntimeError as e:
                    print(e)

        self.quality_metrics = QualityMetrics(config)
        self.semantic_similarity_threshold = 0.75

        # Universal Sentence Encoder used for all semantic-similarity checks
        self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')

        self._initialize_augmentation_models()

        # Shared text -> embedding cache used by both single and batch lookups
        self.embedding_cache = {}

    def _initialize_augmentation_models(self):
        """Initialize augmentation models with appropriate device settings"""
        self.paraphraser = Paraphraser()
        if self.use_gpu:
            self.paraphraser.model = self.paraphraser.model.to(self.device)

        self.word_augmenter = naw.SynonymAug(aug_src='wordnet')

        self.augmenters = {
            'advanced': [
                self.paraphraser,
            ],
            'basic': [
                ('synonym', self.word_augmenter),
            ]
        }

    def _compute_embedding(self, text: str) -> np.ndarray:
        """Cached computation of a single text embedding"""
        # A plain dict cache is used instead of functools.lru_cache: an
        # lru_cache on a bound method would pin `self` in memory and
        # duplicate the entries already stored in embedding_cache.
        if text in self.embedding_cache:
            return self.embedding_cache[text]
        embedding = self.use_model([text])[0].numpy()
        self.embedding_cache[text] = embedding
        return embedding

    def _compute_batch_embeddings(self, texts: List[str]) -> np.ndarray:
        """Compute embeddings for multiple texts at once with hardware optimization"""
        # Encode only the texts that are missing from the cache
        uncached_texts = [t for t in texts if t not in self.embedding_cache]
        if uncached_texts:
            embeddings = self.use_model(uncached_texts).numpy()
            for text, embedding in zip(uncached_texts, embeddings):
                self.embedding_cache[text] = embedding

        return np.array([self.embedding_cache[t] for t in texts])

    def _quick_quality_check(self, variation: str, original: str) -> bool:
        """
        Preliminary quality check while maintaining reasonable pass rates
        """
        if self.config.debug:
            print(f"\nQuick check for variation: {variation}")

        orig_len = len(original.split())
        var_len = len(variation.split())

        # Length ratio check: allow more leeway for very short originals
        if orig_len <= 3:
            if var_len > orig_len * 3:
                if self.config.debug:
                    print(f"Failed length check (short text): {var_len} vs {orig_len}")
                return False
        else:
            if var_len > orig_len * 2:
                if self.config.debug:
                    print(f"Failed length check (normal text): {var_len} vs {orig_len}")
                return False

        # Content-word overlap check, ignoring common stop words
        stop_words = {'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'is', 'are', 'that', 'this', 'will', 'can'}
        orig_words = set(w.lower() for w in original.split() if w.lower() not in stop_words)
        var_words = set(w.lower() for w in variation.split() if w.lower() not in stop_words)

        if orig_len >= 5:
            content_overlap = len(orig_words.intersection(var_words)) / len(orig_words) if orig_words else 0
            if content_overlap < 0.2:
                if self.config.debug:
                    print(f"Failed content check: overlap {content_overlap:.2f}")
                return False
        else:
            if self.config.debug:
                print("Short turn detected (<5 words), skipping content overlap check")

        if self.config.debug:
            print("Passed all quick checks")
        return True

    def _filter_variations_batch(self, variations: List[str], context: List[str], original_turn: str) -> List[str]:
        """
        Filter variations using batched computations with detailed logging
        """
        if not variations:
            return []

        if self.config.debug:
            print(f"\nStarting filtration of {len(variations)} variations")
            print(f"Context length: {len(context)}")
            print(f"Original turn: {original_turn}")

        orig_len = len(original_turn.split())
        is_very_short = orig_len < 5

        # Very short texts get predefined variations instead of model output
        if orig_len < 3:
            if self.config.debug:
                print("Short text detected, using predefined variations")
            short_text_variations = self._augment_short_text({'text': original_turn, 'speaker': ''})
            return [var['text'] for var in short_text_variations]

        # Quick heuristic filtering (skipped for the first turn, which has no context)
        if not context:
            preliminary_filtered = variations
            if self.config.debug:
                print("First turn - skipping preliminary filtering")
        else:
            preliminary_filtered = []
            for var in variations:
                passed = self._quick_quality_check(var, original_turn)
                if self.config.debug:
                    print(f"\nVariation: {var}")
                    print(f"Passed quick check: {passed}")
                if passed:
                    preliminary_filtered.append(var)

        if self.config.debug:
            print(f"Variations after quick check: {len(preliminary_filtered)}")

        if not preliminary_filtered:
            return []

        # Semantic similarity of each variation to the original turn
        original_embedding = self._compute_embedding(original_turn)
        variation_embeddings = self._compute_batch_embeddings(preliminary_filtered)
        sims = cosine_similarity([original_embedding], variation_embeddings)[0]

        # Slightly relax the threshold for very short turns
        dynamic_sem_threshold = self.semantic_similarity_threshold
        if is_very_short:
            dynamic_sem_threshold = max(0.7, self.semantic_similarity_threshold - 0.05)

        refined_filtered = []
        for var, sim in zip(preliminary_filtered, sims):
            if sim >= dynamic_sem_threshold:
                refined_filtered.append(var)
            elif self.config.debug:
                print(f"Variation '{var}' discarded due to low semantic similarity: {sim:.3f}")

        if not refined_filtered:
            return []

        # Context-based thresholds, relaxed for very short turns
        if is_very_short:
            min_similarity = 0.05
            min_coherence = 0.02
        else:
            min_similarity = 0.1
            min_coherence = 0.05

        # Only the most recent turn is used as context
        recent_context = [context[-1]] if context else []
        context_text = ' '.join(recent_context) if recent_context else ''

        if context_text:
            if self.config.debug:
                print(f"\nContext text: {context_text}")

            all_texts = [context_text] + refined_filtered
            all_embeddings = self._compute_batch_embeddings(all_texts)

            context_embedding = all_embeddings[0]
            variation_embeddings = all_embeddings[1:]

            context_similarities = cosine_similarity([context_embedding], variation_embeddings)[0]

            # Note: with a single-turn context window the previous turn and the
            # context text are identical, so these two scores currently coincide;
            # they are kept separate for when a longer window is used.
            if recent_context:
                prev_embedding = self._compute_embedding(recent_context[-1])
                response_coherence = cosine_similarity([prev_embedding], variation_embeddings)[0]
            else:
                response_coherence = np.ones_like(context_similarities)

            filtered_variations = []
            for variation, sim, coh in zip(refined_filtered, context_similarities, response_coherence):
                combined_score = (
                    self.config.context_similarity_weight * abs(sim) +
                    self.config.response_coherence_weight * abs(coh)
                )

                if self.config.debug:
                    print(f"\nVariation: {variation}")
                    print(f"Context similarity: {sim:.3f}")
                    print(f"Response coherence: {coh:.3f}")
                    print(f"Combined score: {combined_score:.3f}")

                if combined_score >= min_similarity or abs(coh) >= min_coherence:
                    filtered_variations.append(variation)
                    if self.config.debug:
                        print("ACCEPTED")
                elif self.config.debug:
                    print("REJECTED")

                if len(filtered_variations) >= self.config.max_variations_per_turn:
                    break
        else:
            filtered_variations = refined_filtered[:self.config.max_variations_per_turn]

        if self.config.debug:
            print(f"\nFinal filtered variations: {len(filtered_variations)}")

        return filtered_variations
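
    # Illustrative arithmetic for the combined score above: with
    # context_similarity_weight = 0.6 and response_coherence_weight = 0.4
    # (example values; the real weights come from PipelineConfig), a variation
    # with context similarity 0.30 and coherence 0.20 scores
    # 0.6 * 0.30 + 0.4 * 0.20 = 0.26, clearing the default min_similarity of 0.1.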

    def _generate_variations_progressive(self, text: str, needed: int) -> List[str]:
        """
        Generate variations progressively until we have enough good ones.
        Adjust paraphraser parameters for closer paraphrases as needed.
        """
        variations = set()

        if self.config.debug:
            print(f"\nAttempting to generate {needed} variations for text: {text}")

        # Try the heavier paraphrase model first
        for augmenter in self.augmenters['advanced']:
            if len(variations) >= needed:
                break

            try:
                if isinstance(augmenter, Paraphraser):
                    if self.config.debug:
                        print("Trying paraphrase augmentation...")
                    new_vars = augmenter.paraphrase(
                        text,
                        num_return_sequences=needed - len(variations),
                        device=self.device if self.use_gpu else None,
                        num_beams=4,
                        num_beam_groups=1,
                        diversity_penalty=0.0
                    )
                    if self.config.debug:
                        print(f"Paraphraser generated {len(new_vars)} variations")

                    valid_vars = [v for v in new_vars if v.strip() and v != text]
                    variations.update(valid_vars)

                    if self.config.debug:
                        print(f"Current unique variations: {len(variations)}")

            except Exception as e:
                print(f"Error in advanced augmentation: {str(e)}")
                continue

        # Fall back to basic augmenters if the paraphraser came up short
        if len(variations) < needed:
            if self.config.debug:
                print("Not enough variations, trying basic augmenters...")

            for aug_type, augmenter in self.augmenters['basic']:
                if len(variations) >= needed:
                    break

                try:
                    if self.config.debug:
                        print(f"Trying {aug_type} augmentation...")

                    # nlpaug may return a list or a single string depending on version
                    new_vars = augmenter.augment(text, n=2)
                    if isinstance(new_vars, list):
                        valid_vars = [v for v in new_vars if v.strip() and v != text]
                        variations.update(valid_vars)
                    elif new_vars.strip() and new_vars != text:
                        variations.add(new_vars)

                    if self.config.debug:
                        print(f"After {aug_type}, total variations: {len(variations)}")

                except Exception as e:
                    print(f"Error in {aug_type} augmentation: {str(e)}")
                    continue

        variations_list = list(variations)

        if self.config.debug:
            print(f"Final number of variations generated: {len(variations_list)}")
        if not variations_list:
            print("WARNING: No variations were generated!")

        return variations_list

    def augment_dialogue(self, dialogue: Dict) -> List[Dict]:
        """
        Create augmented versions of the dialogue with optimized processing
        """
        # Cap dialogue length to bound the combinatorial search later on
        original_length = len(dialogue['turns'])
        if original_length > self.config.max_turns_per_dialogue:
            if self.config.debug:
                print(f"Truncating dialogue from {original_length} to {self.config.max_turns_per_dialogue} turns")
            dialogue['turns'] = dialogue['turns'][:self.config.max_turns_per_dialogue]

        turn_variations = []
        context = []

        # Generate and filter variations turn by turn, growing the context
        for turn in dialogue['turns']:
            original_text = turn['text']
            variations = self._generate_variations_progressive(
                original_text,
                self.config.max_variations_per_turn
            )

            filtered_variations = self._filter_variations_batch(
                variations,
                context,
                original_text
            )

            turn_vars = [{'speaker': turn['speaker'], 'text': v} for v in filtered_variations]

            if self.config.debug:
                print(f"Turn {len(turn_variations)}: Generated {len(turn_vars)} variations")

            turn_variations.append(turn_vars)
            context.append(original_text)

        augmented_dialogues = self._generate_dialogue_combinations(
            dialogue['dialogue_id'],
            turn_variations,
            dialogue
        )

        # Always keep the original dialogue as the first result
        result = [{
            'dialogue_id': f"{dialogue['dialogue_id']}_original",
            'turns': dialogue['turns']
        }]
        result.extend(augmented_dialogues[:self.config.augmentation_factor])

        if self.config.debug:
            print(f"Generated {len(result) - 1} unique augmented dialogues")

        return result

    def _variation_score(self, original: str, variation: str) -> float:
        """
        Compute a single numeric score for a variation to guide selection.
        You could use semantic similarity, content preservation, etc.
        Higher is better.
        """
        metrics = self.quality_metrics.compute_metrics(original, variation)
        score = metrics['semantic_similarity'] * 0.7 + metrics['content_preservation'] * 0.3
        return score
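
    # For example, a variation with semantic_similarity 0.9 and
    # content_preservation 0.8 scores 0.9 * 0.7 + 0.8 * 0.3 = 0.87.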

    def _dialogue_quality_score(self, dialogue: Dict, original_dialogue: Dict) -> float:
        """
        Compute a quality score for the entire augmented dialogue,
        e.g. the average semantic similarity of turns to the original turns.
        This is done after the dialogue is formed.
        """
        original_texts = [t['text'] for t in original_dialogue['turns']]
        aug_texts = [t['text'] for t in dialogue['turns']]

        scores = []
        for orig, aug in zip(original_texts, aug_texts):
            emb_orig = self._compute_embedding(orig)
            emb_aug = self._compute_embedding(aug)
            # Cosine similarity computed directly on the cached embeddings
            sim = (emb_orig @ emb_aug) / (np.linalg.norm(emb_orig) * np.linalg.norm(emb_aug))
            scores.append(sim)

        return float(np.mean(scores)) if scores else 0.0

    def _generate_dialogue_combinations(self, dialogue_id: str, turn_variations: List[List[Dict]], original_dialogue: Dict) -> List[Dict]:
        """
        Generate dialogue combinations using a more controlled approach:
        - Include the original turn as a fallback variation for each turn.
        - Sort variations by a quality score.
        - Ensure balanced augmentation by requiring at least some turns to be augmented.
        - Over-generate and then select the top dialogues by quality.
        """
        over_generate_factor = self.config.augmentation_factor * 2

        # Ensure every turn has the original text available, then rank variants
        for i, turn_variants in enumerate(turn_variations):
            original_turn_text = original_dialogue['turns'][i]['text']

            if not any(v['text'] == original_turn_text for v in turn_variants):
                turn_variants.append({
                    'speaker': original_dialogue['turns'][i]['speaker'],
                    'text': original_turn_text
                })

            turn_variants.sort(key=lambda v: self._variation_score(original_turn_text, v['text']), reverse=True)

        augmented_dialogues = []
        used_combinations = set()

        def generate_candidates(current_turns=None, turn_index=0):
            if current_turns is None:
                current_turns = []

            if len(augmented_dialogues) >= over_generate_factor:
                return

            if turn_index == len(turn_variations):
                # Complete dialogue: deduplicate and require enough augmented turns
                dialogue_fingerprint = " | ".join(turn['text'] for turn in current_turns)
                if dialogue_fingerprint not in used_combinations:
                    used_combinations.add(dialogue_fingerprint)

                    aug_count = sum(1 for orig, curr in zip(original_dialogue['turns'], current_turns)
                                    if orig['text'] != curr['text'])

                    # Require at least half of the turns to differ from the original
                    if aug_count >= max(1, len(turn_variations) // 2):
                        augmented_dialogues.append({
                            'dialogue_id': f"{dialogue_id}_aug_{len(augmented_dialogues)}",
                            'turns': current_turns.copy()
                        })
                return

            turn_candidates = turn_variations[turn_index]

            # Fall back to the original turn if no variations survived filtering
            if not turn_candidates:
                turn_candidates.append({
                    'speaker': original_dialogue['turns'][turn_index]['speaker'],
                    'text': original_dialogue['turns'][turn_index]['text']
                })

            # Pick a small, diverse subset: the best-scored variant, a mid-ranked
            # one, and the original turn if there is still room
            num_vars = min(self.config.max_sampled_variations, len(turn_candidates))
            candidates_to_pick = [turn_candidates[0]]

            if len(turn_candidates) > 2 and num_vars > 1:
                mid_index = len(turn_candidates) // 2
                candidates_to_pick.append(turn_candidates[mid_index])

            if num_vars > len(candidates_to_pick):
                original_turn_text = original_dialogue['turns'][turn_index]['text']
                orig_candidate = next((v for v in turn_candidates if v['text'] == original_turn_text), None)
                if orig_candidate and orig_candidate not in candidates_to_pick:
                    candidates_to_pick.append(orig_candidate)

            np.random.shuffle(candidates_to_pick)

            for variation in candidates_to_pick:
                if len(augmented_dialogues) >= over_generate_factor:
                    return
                current_turns.append(variation)
                generate_candidates(current_turns, turn_index + 1)
                current_turns.pop()

        try:
            generate_candidates()
        except Exception as e:
            print(f"Error in dialogue generation: {str(e)}")
            return []

        # Rank the over-generated dialogues and keep the best ones
        scored_dialogues = []
        for d in augmented_dialogues:
            score = self._dialogue_quality_score(d, original_dialogue)
            scored_dialogues.append((score, d))

        scored_dialogues.sort(key=lambda x: x[0], reverse=True)

        final_dialogues = [d for _, d in scored_dialogues[:self.config.augmentation_factor]]

        return final_dialogues

    def _is_dialogue_duplicate(self, dialogue1: Dict, dialogue2: Dict) -> bool:
        """
        Check if two dialogues are duplicates.
        """
        text1 = " ".join(turn['text'] for turn in dialogue1['turns'])
        text2 = " ".join(turn['text'] for turn in dialogue2['turns'])
        return text1 == text2

    def _augment_short_text(self, turn: Dict) -> List[Dict]:
        """
        Special handling for very short texts with predefined variations.
        If predefined variations are found, return them directly.
        Otherwise, produce simple punctuation and capitalization variants.
        Skip heavy quality checks for efficiency: these variations are safe and minimal.
        """
        text = turn['text']
        common_variations = {
            'goodbye': [
                'Bye!', 'Farewell!', 'See you!', 'Take care!',
                'Goodbye!', 'Bye for now!', 'Until next time!'
            ],
            'hello': [
                'Hi!', 'Hey!', 'Hello!', 'Greetings!',
                'Good day!', 'Hi there!', 'Hello there!'
            ],
            'yes': [
                'Yes!', 'Correct!', 'Indeed!', 'Absolutely!',
                "That's right!", 'Definitely!', 'Sure!'
            ],
            'no': [
                'No!', 'Nope!', 'Not at all!', 'Negative!',
                'Unfortunately not!', "I'm afraid not!"
            ],
            'thanks': [
                'Thank you!', 'Thanks a lot!', 'Many thanks!',
                'I appreciate it!', 'Thank you so much!'
            ],
            'ok': [
                'Okay!', 'Alright!', 'Sure!', 'Got it!',
                'Understood!', 'Fine!', 'Great!', 'Perfect!',
                'That works!', 'Sounds good!'
            ],
            'good': [
                'Great!', 'Excellent!', 'Perfect!', 'Wonderful!',
                'Fantastic!', 'Amazing!', 'Terrific!'
            ]
        }

        text_lower = text.lower().rstrip('!.,?')

        variations = []
        for key, predefined_vars in common_variations.items():
            if key in text_lower or text_lower in key:
                variations.extend(predefined_vars)

        if not variations:
            # No predefined match: fall back to simple punctuation variants
            base = text.rstrip('!.,?')
            variations = [base + '!', base + '.', base]

        # Add capitalized variants that are not already present
        capitalized = [v.capitalize() for v in variations if v.capitalize() not in variations]
        variations.extend(capitalized)

        unique_variations = list(set(variations))

        result_variations = unique_variations[:self.config.augmentation_factor]
        return [{'speaker': turn['speaker'], 'text': v} for v in result_variations]

    def process_batch(self, batch: List[Dict]) -> List[Dict]:
        """Process multiple dialogues at once to maximize GPU utilization"""
        results = []

        # Warm the embedding cache with every turn in the batch up front;
        # _compute_batch_embeddings caches each embedding as a side effect
        all_texts = []
        for dialogue in batch:
            for turn in dialogue['turns']:
                all_texts.append(turn['text'])

        if all_texts:
            self._compute_batch_embeddings(all_texts)

        for dialogue in batch:
            try:
                augmented = self.augment_dialogue(dialogue)
                results.extend(augmented)
            except Exception as e:
                print(f"Error processing dialogue {dialogue.get('dialogue_id', 'unknown')}: {e}")
                continue

        return results
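

# Minimal usage sketch (illustrative only, not part of the module): assumes a
# spaCy pipeline for `nlp`, that PipelineConfig accepts these keyword arguments,
# and that it exposes the fields referenced above (debug, max_turns_per_dialogue,
# max_variations_per_turn, max_sampled_variations, augmentation_factor,
# context_similarity_weight, response_coherence_weight).
if __name__ == "__main__":
    import spacy

    nlp = spacy.load("en_core_web_sm")
    config = PipelineConfig(debug=True)

    augmenter = DialogueAugmenter(nlp, config)
    batch = [{
        'dialogue_id': 'demo_001',
        'turns': [
            {'speaker': 'user', 'text': 'Hello, I need help with my booking.'},
            {'speaker': 'agent', 'text': 'Sure, what seems to be the problem?'},
        ],
    }]
    for d in augmenter.process_batch(batch):
        print(d['dialogue_id'], len(d['turns']))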