from typing import Dict, List
import numpy as np
import torch
import tensorflow as tf
import tensorflow_hub as hub
from pipeline_config import PipelineConfig
from quality_metrics import QualityMetrics
from paraphraser import Paraphraser
import nlpaug.augmenter.word as naw
from functools import lru_cache
from sklearn.metrics.pairwise import cosine_similarity


class DialogueAugmenter:
    """
    Optimized dialogue augmentation with quality control and complexity management.

    Generates paraphrase/synonym variations for each turn of a dialogue,
    filters them for semantic similarity and context coherence against the
    Universal Sentence Encoder embeddings, and assembles top-scoring
    augmented dialogues.
    """

    def __init__(self, nlp, config: PipelineConfig):
        """
        Args:
            nlp: spaCy-style NLP pipeline object (stored, not used directly here).
            config: pipeline configuration; this class reads `debug`,
                `max_variations_per_turn`, `max_turns_per_dialogue`,
                `augmentation_factor`, `max_sampled_variations`,
                `context_similarity_weight` and `response_coherence_weight`.
        """
        self.nlp = nlp
        self.config = config

        # Detect hardware and set appropriate batch sizes and optimization strategy
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.use_gpu = torch.cuda.is_available()

        if self.config.debug:
            print(f"Using device: {self.device}")
            if self.use_gpu:
                print(f"GPU Device: {torch.cuda.get_device_name(0)}")

        self.quality_metrics = QualityMetrics(config)
        self.semantic_similarity_threshold = 0.75

        # Load Universal Sentence Encoder for all embedding computations.
        self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')

        # Initialize augmentation models based on hardware
        self._initialize_augmentation_models()

        # text -> np.ndarray embedding cache shared by all embedding helpers.
        self.embedding_cache = {}

        # GPU memory management if available: let TF grow allocations instead
        # of grabbing all GPU memory up front (it shares the GPU with torch).
        if self.use_gpu:
            gpus = tf.config.list_physical_devices('GPU')
            if gpus:
                try:
                    for gpu in gpus:
                        tf.config.experimental.set_memory_growth(gpu, True)
                except RuntimeError as e:
                    print(e)

    def _initialize_augmentation_models(self):
        """Initialize augmentation models with appropriate device settings."""
        # Advanced augmentation techniques
        self.paraphraser = Paraphraser()
        if self.use_gpu:
            # Move model to GPU if available
            self.paraphraser.model = self.paraphraser.model.to(self.device)

        # Basic augmentation techniques
        self.word_augmenter = naw.SynonymAug(aug_src='wordnet')

        self.augmenters = {
            'advanced': [
                self.paraphraser,
            ],
            'basic': [
                ('synonym', self.word_augmenter),
            ]
        }

    # NOTE: intentionally NOT decorated with @lru_cache. An lru_cache on an
    # instance method keys on `self` and keeps the instance (and its models)
    # alive forever; the explicit embedding_cache dict below already provides
    # the memoization and is shared with the batch path.
    def _compute_embedding(self, text: str) -> np.ndarray:
        """Cached computation of a single text embedding."""
        if text in self.embedding_cache:
            return self.embedding_cache[text]
        embedding = self.use_model([text])[0].numpy()
        self.embedding_cache[text] = embedding
        return embedding

    def _compute_batch_embeddings(self, texts: List[str]) -> np.ndarray:
        """Compute embeddings for multiple texts at once with hardware optimization."""
        # Check cache first
        uncached_texts = [t for t in texts if t not in self.embedding_cache]

        if uncached_texts:
            embeddings = self.use_model(uncached_texts).numpy()
            # Update cache
            for text, embedding in zip(uncached_texts, embeddings):
                self.embedding_cache[text] = embedding

        # Return all embeddings (from cache or newly computed)
        return np.array([self.embedding_cache[t] for t in texts])

    def _quick_quality_check(self, variation: str, original: str) -> bool:
        """
        Preliminary quality check while maintaining reasonable pass rates.

        Cheap word-count and content-overlap heuristics applied before any
        embedding work. Returns True if the variation should proceed to the
        (more expensive) semantic filtering stage.
        """
        if self.config.debug:
            print(f"\nQuick check for variation: {variation}")

        orig_len = len(original.split())
        var_len = len(variation.split())

        # For very short texts (<= 3 words), still allow more variation
        if orig_len <= 3:
            if var_len > orig_len * 3:
                if self.config.debug:
                    print(f"Failed length check (short text): {var_len} vs {orig_len}")
                return False
        else:
            if var_len > orig_len * 2:
                if self.config.debug:
                    print(f"Failed length check (long text): {var_len} vs {orig_len}")
                return False

        # Adjust content overlap check based on length
        stop_words = {'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at',
                      'to', 'for', 'is', 'are', 'that', 'this', 'will', 'can'}
        orig_words = set(w.lower() for w in original.split() if w.lower() not in stop_words)
        var_words = set(w.lower() for w in variation.split() if w.lower() not in stop_words)

        # If very short turn (less than 5 words), skip the content overlap check
        if orig_len >= 5:
            content_overlap = len(orig_words.intersection(var_words)) / len(orig_words) if orig_words else 0
            if content_overlap < 0.2:
                if self.config.debug:
                    print(f"Failed content check: overlap {content_overlap:.2f}")
                return False
        else:
            if self.config.debug:
                print("Short turn detected (<5 words), skipping content overlap check")

        if self.config.debug:
            print("Passed all quick checks")
        return True

    def _filter_variations_batch(self, variations: List[str], context: List[str],
                                 original_turn: str) -> List[str]:
        """
        Filter variations using batched computations with detailed logging.

        Pipeline: quick heuristic check -> semantic similarity vs. the
        original turn -> (if context exists) combined context-similarity /
        response-coherence scoring against the most recent context turn.
        """
        if not variations:
            return []

        if self.config.debug:
            print(f"\nStarting filtration of {len(variations)} variations")
            print(f"Context length: {len(context)}")
            print(f"Original turn: {original_turn}")

        words = original_turn.split()
        orig_len = len(words)
        # If very short text, consider adjusting thresholds
        is_very_short = orig_len < 5

        if len(words) < 3:
            if self.config.debug:
                print("Short text detected, using predefined variations")
            short_text_variations = self._augment_short_text(
                {'text': original_turn, 'speaker': ''})
            return [var['text'] for var in short_text_variations]

        # If this is the first turn (no context), be more lenient
        if not context:
            preliminary_filtered = variations
            if self.config.debug:
                print("First turn - skipping preliminary filtering")
        else:
            # Quick preliminary filtering against original turn
            preliminary_filtered = []
            for var in variations:
                passed = self._quick_quality_check(var, original_turn)
                if self.config.debug:
                    print(f"\nVariation: {var}")
                    print(f"Passed quick check: {passed}")
                if passed:
                    preliminary_filtered.append(var)

        if self.config.debug:
            print(f"Variations after quick check: {len(preliminary_filtered)}")

        if not preliminary_filtered:
            return []

        # Compute embeddings for original and variations
        original_embedding = self._compute_embedding(original_turn)
        variation_embeddings = self._compute_batch_embeddings(preliminary_filtered)

        # Compute similarities
        sims = cosine_similarity([original_embedding], variation_embeddings)[0]

        # If very short turn, slightly lower the semantic similarity threshold
        dynamic_sem_threshold = self.semantic_similarity_threshold
        if is_very_short:
            dynamic_sem_threshold = max(0.7, self.semantic_similarity_threshold - 0.05)

        # Filter by semantic similarity threshold
        refined_filtered = []
        for var, sim in zip(preliminary_filtered, sims):
            if sim >= dynamic_sem_threshold:
                refined_filtered.append(var)
            else:
                if self.config.debug:
                    print(f"Variation '{var}' discarded due to low semantic similarity: {sim:.3f}")

        if not refined_filtered:
            return []

        # Relax context coherence thresholds further if desired
        # We already have min_similarity = 0.1, min_coherence = 0.05
        # Let's lower them slightly more if the turn is very short:
        if is_very_short:
            min_similarity = 0.05
            min_coherence = 0.02
        else:
            min_similarity = 0.1
            min_coherence = 0.05

        # Only use last turn for coherence
        recent_context = [context[-1]] if context else []
        context_text = ' '.join(recent_context) if recent_context else ''

        if context_text:
            if self.config.debug:
                print(f"\nContext text: {context_text}")

            all_texts = [context_text] + refined_filtered
            all_embeddings = self._compute_batch_embeddings(all_texts)
            context_embedding = all_embeddings[0]
            variation_embeddings = all_embeddings[1:]

            # Vectorized similarity computation
            context_similarities = cosine_similarity([context_embedding], variation_embeddings)[0]

            # Response coherence check
            if recent_context:
                prev_embedding = self._compute_embedding(recent_context[-1])
                response_coherence = cosine_similarity([prev_embedding], variation_embeddings)[0]
            else:
                response_coherence = np.ones_like(context_similarities)

            filtered_variations = []
            for i, (variation, sim, coh) in enumerate(zip(
                    refined_filtered, context_similarities, response_coherence)):
                combined_score = (
                    self.config.context_similarity_weight * abs(sim) +
                    self.config.response_coherence_weight * abs(coh)
                )
                if self.config.debug:
                    print(f"\nVariation: {variation}")
                    print(f"Context similarity: {sim:.3f}")
                    print(f"Response coherence: {coh:.3f}")
                    print(f"Combined score: {combined_score:.3f}")

                # Accept if EITHER score is good enough
                if (combined_score >= min_similarity or abs(coh) >= min_coherence):
                    filtered_variations.append(variation)
                    if self.config.debug:
                        print("ACCEPTED")
                else:
                    if self.config.debug:
                        print("REJECTED")

                # If we have enough variations, stop
                if len(filtered_variations) >= self.config.max_variations_per_turn:
                    break
        else:
            filtered_variations = refined_filtered[:self.config.max_variations_per_turn]

        if self.config.debug:
            print(f"\nFinal filtered variations: {len(filtered_variations)}")

        return filtered_variations

    def _generate_variations_progressive(self, text: str, needed: int) -> List[str]:
        """
        Generate variations progressively until we have enough good ones.

        Tries the advanced (paraphrase) augmenters first, then falls back to
        the basic (synonym) augmenters if more variations are still needed.
        """
        variations = set()

        if self.config.debug:
            print(f"\nAttempting to generate {needed} variations for text: {text}")

        # Fine-tune paraphraser here if needed: fewer beams, less diversity already done
        for augmenter in self.augmenters['advanced']:
            if len(variations) >= needed:
                break
            try:
                if isinstance(augmenter, Paraphraser):
                    if self.config.debug:
                        print("Trying paraphrase augmentation...")
                    new_vars = augmenter.paraphrase(
                        text,
                        num_return_sequences=needed - len(variations),
                        device=self.device if self.use_gpu else None,
                        num_beams=4,  # even fewer beams for more faithful paraphrases
                        num_beam_groups=1,
                        diversity_penalty=0.0
                    )
                    if self.config.debug:
                        print(f"Paraphraser generated {len(new_vars)} variations")
                    valid_vars = [v for v in new_vars if v.strip() and v != text]
                    variations.update(valid_vars)
                    if self.config.debug:
                        print(f"Current unique variations: {len(variations)}")
            except Exception as e:
                print(f"Error in advanced augmentation: {str(e)}")
                continue

        # Try basic augmenters if needed
        if len(variations) < needed:
            if self.config.debug:
                print("Not enough variations, trying basic augmenters...")
            for aug_type, augmenter in self.augmenters['basic']:
                if len(variations) >= needed:
                    break
                try:
                    if self.config.debug:
                        print(f"Trying {aug_type} augmentation...")
                    new_vars = augmenter.augment(text, n=2)
                    # nlpaug may return either a list or a single string.
                    if isinstance(new_vars, list):
                        valid_vars = [v for v in new_vars if v.strip() and v != text]
                        variations.update(valid_vars)
                    else:
                        if new_vars.strip() and new_vars != text:
                            variations.add(new_vars)
                    if self.config.debug:
                        print(f"After {aug_type}, total variations: {len(variations)}")
                except Exception as e:
                    print(f"Error in {aug_type} augmentation: {str(e)}")
                    continue

        variations_list = list(variations)
        if self.config.debug:
            print(f"Final number of variations generated: {len(variations_list)}")
            if not variations_list:
                print("WARNING: No variations were generated!")
        return variations_list

    def augment_dialogue(self, dialogue: Dict) -> List[Dict]:
        """
        Create augmented versions of the dialogue with optimized processing.

        Args:
            dialogue: dict with keys 'dialogue_id' and 'turns' (each turn a
                dict with 'speaker' and 'text'). The input dict is NOT
                mutated, even when truncation is needed.

        Returns:
            A list whose first element is the (possibly truncated) original
            dialogue tagged "_original", followed by up to
            `config.augmentation_factor` augmented dialogues.
        """
        # Early dialogue length check. Truncate into a LOCAL copy rather than
        # mutating the caller's dict (the original code rewrote
        # dialogue['turns'] in place, a surprising side effect).
        turns = dialogue['turns']
        if len(turns) > self.config.max_turns_per_dialogue:
            if self.config.debug:
                print(f"Truncating dialogue from {len(turns)} to {self.config.max_turns_per_dialogue} turns")
            turns = turns[:self.config.max_turns_per_dialogue]
        working_dialogue = {**dialogue, 'turns': turns}

        turn_variations = []
        context = []

        # Process each turn with progressive generation
        for turn in turns:
            original_text = turn['text']  # Store original turn text

            variations = self._generate_variations_progressive(
                original_text,
                self.config.max_variations_per_turn
            )

            # Batch filter variations with original text
            filtered_variations = self._filter_variations_batch(
                variations,
                context,
                original_text  # Pass the original turn text
            )

            # Create turn variations with speaker info
            turn_vars = [{'speaker': turn['speaker'], 'text': v}
                         for v in filtered_variations]

            if self.config.debug:
                print(f"Turn {len(turn_variations)}: Generated {len(turn_vars)} variations")

            turn_variations.append(turn_vars)
            context.append(original_text)

        # Generate combinations with sampling
        augmented_dialogues = self._generate_dialogue_combinations(
            working_dialogue['dialogue_id'],
            turn_variations,
            working_dialogue
        )

        # Add original dialogue
        result = [{
            'dialogue_id': f"{working_dialogue['dialogue_id']}_original",
            'turns': turns
        }]

        # Add unique augmentations
        result.extend(augmented_dialogues[:self.config.augmentation_factor])

        if self.config.debug:
            print(f"Generated {len(result) - 1} unique augmented dialogues")

        return result

    def _variation_score(self, original: str, variation: str) -> float:
        """
        Compute a single numeric score for a variation to guide selection.

        Weighted blend of semantic similarity and content preservation from
        QualityMetrics. Higher is better.
        """
        metrics = self.quality_metrics.compute_metrics(original, variation)
        # Primarily semantic similarity, with a slight boost for content preservation.
        score = metrics['semantic_similarity'] * 0.7 + metrics['content_preservation'] * 0.3
        return score

    def _dialogue_quality_score(self, dialogue: Dict, original_dialogue: Dict) -> float:
        """
        Compute a quality score for the entire augmented dialogue.

        Averages per-turn cosine similarity between the augmented turn and
        the corresponding original turn. Returns 0.0 for empty dialogues.
        """
        original_texts = [t['text'] for t in original_dialogue['turns']]
        aug_texts = [t['text'] for t in dialogue['turns']]

        scores = []
        for orig, aug in zip(original_texts, aug_texts):
            # Simple cosine similarity of cached embeddings for scoring.
            emb_orig = self._compute_embedding(orig)
            emb_aug = self._compute_embedding(aug)
            sim = (emb_orig @ emb_aug) / (np.linalg.norm(emb_orig) * np.linalg.norm(emb_aug))
            scores.append(sim)

        # Could also incorporate diversity checks, content overlap, etc.
        return float(np.mean(scores)) if scores else 0.0

    def _generate_dialogue_combinations(self, dialogue_id: str,
                                        turn_variations: List[List[Dict]],
                                        original_dialogue: Dict) -> List[Dict]:
        """
        Generate dialogue combinations using a controlled approach:
        - Include the original turn as a fallback variation for each turn.
        - Sort variations by a quality score.
        - Ensure a balanced augmentation by requiring at least some turns
          to be augmented.
        - Over-generate and then select top dialogues by quality.
        """
        # Over-generate factor: create more candidates than needed
        over_generate_factor = self.config.augmentation_factor * 2

        # Add the original turn as a fallback variation for each turn if not present
        for i, turn_variants in enumerate(turn_variations):
            original_turn_text = original_dialogue['turns'][i]['text']

            if not any(v['text'] == original_turn_text for v in turn_variants):
                turn_variants.append({
                    'speaker': original_dialogue['turns'][i]['speaker'],
                    'text': original_turn_text
                })

            # Sort variations best-first by quality score
            turn_variants.sort(
                key=lambda v: self._variation_score(original_turn_text, v['text']),
                reverse=True)

        augmented_dialogues = []
        used_combinations = set()

        def generate_candidates(current_turns=None, turn_index=0):
            # Depth-first assembly of candidate dialogues, one turn at a time.
            if current_turns is None:
                current_turns = []

            if len(augmented_dialogues) >= over_generate_factor:
                return

            if turn_index == len(turn_variations):
                # Completed a candidate dialogue
                dialogue_fingerprint = " | ".join(turn['text'] for turn in current_turns)
                if dialogue_fingerprint not in used_combinations:
                    used_combinations.add(dialogue_fingerprint)
                    # Check if we have enough augmented turns
                    aug_count = sum(
                        1 for orig, curr in zip(original_dialogue['turns'], current_turns)
                        if orig['text'] != curr['text'])
                    # Require at least half the turns to be augmented
                    if aug_count >= max(1, len(turn_variations) // 2):
                        augmented_dialogues.append({
                            'dialogue_id': f"{dialogue_id}_aug_{len(augmented_dialogues)}",
                            'turns': current_turns.copy()
                        })
                return

            turn_candidates = turn_variations[turn_index]

            # Normally, this shouldn't happen since we always add the original
            # turn above; add it now as a last-resort fallback.
            if not turn_candidates:
                original_text = original_dialogue['turns'][turn_index]['text']
                turn_candidates.append({
                    'speaker': original_dialogue['turns'][turn_index]['speaker'],
                    'text': original_text
                })
            # After the fallback, if still empty for some reason, just return.
            if not turn_candidates:
                return

            # Strategy:
            # 1. Always try the top variation (most semantically similar).
            # 2. If available and allowed, pick a mid-ranked variation for diversity.
            # 3. Include the original turn if not selected yet.
            num_vars = min(self.config.max_sampled_variations, len(turn_candidates))

            # Always include top variation
            candidates_to_pick = [turn_candidates[0]]

            # If we have more than 2 variations and can pick more, add a middle
            # variation for diversity
            if len(turn_candidates) > 2 and num_vars > 1:
                mid_index = len(turn_candidates) // 2
                candidates_to_pick.append(turn_candidates[mid_index])

            # If we still have room, try adding the original turn if not included
            if num_vars > len(candidates_to_pick):
                original_turn_text = original_dialogue['turns'][turn_index]['text']
                orig_candidate = next(
                    (v for v in turn_candidates if v['text'] == original_turn_text), None)
                if orig_candidate and orig_candidate not in candidates_to_pick:
                    candidates_to_pick.append(orig_candidate)

            # Shuffle candidates to produce different dialogues
            np.random.shuffle(candidates_to_pick)

            for variation in candidates_to_pick:
                if len(augmented_dialogues) >= over_generate_factor:
                    return
                current_turns.append(variation)
                generate_candidates(current_turns, turn_index + 1)
                current_turns.pop()

        try:
            generate_candidates()
        except Exception as e:
            print(f"Error in dialogue generation: {str(e)}")
            return []

        # Score the over-generated set and pick the top ones
        scored_dialogues = []
        for d in augmented_dialogues:
            score = self._dialogue_quality_score(d, original_dialogue)
            scored_dialogues.append((score, d))
        scored_dialogues.sort(key=lambda x: x[0], reverse=True)

        # Pick top `augmentation_factor` dialogues
        final_dialogues = [d for _, d in scored_dialogues[:self.config.augmentation_factor]]
        return final_dialogues

    def _is_dialogue_duplicate(self, dialogue1: Dict, dialogue2: Dict) -> bool:
        """Check if two dialogues are duplicates (identical concatenated text)."""
        text1 = " ".join(turn['text'] for turn in dialogue1['turns'])
        text2 = " ".join(turn['text'] for turn in dialogue2['turns'])
        return text1 == text2

    def _augment_short_text(self, turn: Dict) -> List[Dict]:
        """
        Special handling for very short texts with predefined variations.

        If predefined variations are found, return them directly. Otherwise,
        produce simple punctuation and capitalization variants. Heavy quality
        checks are skipped: these variations are safe and minimal.
        """
        text = turn['text']
        common_variations = {
            'goodbye': [
                'Bye!', 'Farewell!', 'See you!', 'Take care!',
                'Goodbye!', 'Bye for now!', 'Until next time!'
            ],
            'hello': [
                'Hi!', 'Hey!', 'Hello!', 'Greetings!', 'Good day!',
                'Hi there!', 'Hello there!'
            ],
            'yes': [
                'Yes!', 'Correct!', 'Indeed!', 'Absolutely!',
                'That\'s right!', 'Definitely!', 'Sure!'
            ],
            'no': [
                'No!', 'Nope!', 'Not at all!', 'Negative!',
                'Unfortunately not!', 'I\'m afraid not!'
            ],
            'thanks': [
                'Thank you!', 'Thanks a lot!', 'Many thanks!',
                'I appreciate it!', 'Thank you so much!'
            ],
            'ok': [
                'Okay!', 'Alright!', 'Sure!', 'Got it!', 'Understood!',
                'Fine!', 'Great!', 'Perfect!', 'That works!', 'Sounds good!'
            ],
            'good': [
                'Great!', 'Excellent!', 'Perfect!', 'Wonderful!',
                'Fantastic!', 'Amazing!', 'Terrific!'
            ]
        }

        text_lower = text.lower().rstrip('!.,?')

        # Check if text matches any predefined category
        variations = []
        for key, predefined_vars in common_variations.items():
            if key in text_lower or text_lower in key:
                variations.extend(predefined_vars)

        if not variations:
            # Generate simple punctuation and capitalization variations
            base = text.rstrip('!.,?')
            variations = [
                base + '!',
                base + '.',
                base
            ]
            # Add capitalization variations
            capitalized = [v.capitalize() for v in variations
                           if v.capitalize() not in variations]
            variations.extend(capitalized)

        # Order-preserving dedupe: list(set(...)) made the [:k] slice below
        # nondeterministic across runs.
        unique_variations = list(dict.fromkeys(variations))

        # Directly return these variations, as they are minimal and trusted
        result_variations = unique_variations[:self.config.augmentation_factor]
        return [{'speaker': turn['speaker'], 'text': v} for v in result_variations]

    def process_batch(self, batch: List[Dict]) -> List[Dict]:
        """Process multiple dialogues at once to maximize GPU utilization."""
        results = []

        # Pre-compute embeddings for all texts in batch
        all_texts = []
        for dialogue in batch:
            for turn in dialogue['turns']:
                all_texts.append(turn['text'])

        # Batch compute embeddings (also warms the shared embedding cache)
        if all_texts:
            embeddings = self._compute_batch_embeddings(all_texts)
            for text, embedding in zip(all_texts, embeddings):
                self.embedding_cache[text] = embedding

        # Process each dialogue using cached embeddings
        for dialogue in batch:
            try:
                augmented = self.augment_dialogue(dialogue)
                results.extend(augmented)
            except Exception as e:
                print(f"Error processing dialogue {dialogue.get('dialogue_id', 'unknown')}: {e}")
                continue

        return results