import re
import string
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache  # noqa: F401  (no longer used in this module; kept for importers)
from typing import Dict, List

import nlpaug.augmenter.word as naw
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from sklearn.metrics.pairwise import cosine_similarity

from back_translator import BackTranslator
from paraphraser import Paraphraser
from pipeline_config import PipelineConfig
from quality_metrics import QualityMetrics


class DialogueAugmenter:
    """
    Optimized dialogue augmentation with quality control and complexity management.

    Each turn is rewritten with paraphrasing / back-translation (tried first) and
    nlpaug synonym / spelling augmentation (fallback). Candidates are filtered
    against the original turn and the preceding turn, and the surviving per-turn
    variations are combined into new dialogues.
    """

    # Maximum number of sentence embeddings memoized per instance.
    _EMBEDDING_CACHE_SIZE = 1024

    # Words ignored by the content-overlap check in _quick_quality_check.
    _STOP_WORDS = frozenset({
        'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'is', 'are',
    })

    # Indicators that a text is technical/formal, in which case spelling noise is
    # not injected. Multi-word entries are matched as whole-word phrases.
    _FORMAL_INDICATORS = {
        'technical_terms': frozenset({'api', 'config', 'database', 'server', 'system'}),
        'formal_phrases': frozenset({'please advise', 'regarding', 'furthermore', 'moreover'}),
        'professional_context': frozenset({'meeting', 'conference', 'project', 'deadline'}),
    }

    # Curated replacements for very short turns, keyed by category.
    _SHORT_TEXT_VARIATIONS = {
        'goodbye': [
            'Bye!', 'Farewell!', 'See you!', 'Take care!',
            'Goodbye!', 'Bye for now!', 'Until next time!',
        ],
        'hello': [
            'Hi!', 'Hey!', 'Hello!', 'Greetings!',
            'Good day!', 'Hi there!', 'Hello there!',
        ],
        'yes': [
            'Yes!', 'Correct!', 'Indeed!', 'Absolutely!',
            "That's right!", 'Definitely!', 'Sure!',
        ],
        'no': [
            'No!', 'Nope!', 'Not at all!', 'Negative!',
            'Unfortunately not!', "I'm afraid not!",
        ],
        'thanks': [
            'Thank you!', 'Thanks a lot!', 'Many thanks!',
            'I appreciate it!', 'Thank you so much!',
        ],
        'ok': [
            'Okay!', 'Alright!', 'Sure!', 'Got it!',
            'Understood!', 'Fine!', 'Great!', 'Perfect!',
            'That works!', 'Sounds good!',
        ],
        'good': [
            'Great!', 'Excellent!', 'Perfect!', 'Wonderful!',
            'Fantastic!', 'Amazing!', 'Terrific!',
        ],
    }

    # Whole words (lower-case, punctuation stripped) that select each category
    # of _SHORT_TEXT_VARIATIONS.
    _SHORT_TEXT_TRIGGERS = {
        'goodbye': frozenset({'goodbye', 'bye'}),
        'hello': frozenset({'hello', 'hi', 'hey'}),
        'yes': frozenset({'yes'}),
        'no': frozenset({'no'}),
        'thanks': frozenset({'thanks', 'thank'}),
        'ok': frozenset({'ok', 'okay'}),
        'good': frozenset({'good'}),
    }

    def __init__(self, nlp, config: PipelineConfig):
        self.nlp = nlp
        self.config = config

        # GPU memory growth must be configured before TensorFlow initializes the
        # GPUs, i.e. before any model is loaded; afterwards TF raises RuntimeError
        # and the setting silently never takes effect.
        self._configure_gpu_memory_growth()

        self.quality_metrics = QualityMetrics(config)
        self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')

        # Advanced augmentation techniques
        self.paraphraser = Paraphraser()
        self.back_translator = BackTranslator()

        # Basic augmentation techniques
        self.word_augmenter = naw.SynonymAug(aug_src='wordnet')
        self.spelling_augmenter = naw.SpellingAug()

        self.augmenters = {
            'advanced': [self.paraphraser, self.back_translator],
            'basic': [
                ('synonym', self.word_augmenter),
                ('spelling', self.spelling_augmenter),
            ],
        }

        # Per-instance caches. embedding_cache backs _compute_embedding;
        # perplexity_cache is not read in this class (NOTE(review): confirm external use).
        self.embedding_cache = {}
        self.perplexity_cache = {}

        # Compile regex patterns
        self.spelling_pattern = re.compile(r'[a-zA-Z]{3,}')

    @staticmethod
    def _configure_gpu_memory_growth() -> None:
        """Ask TensorFlow to allocate GPU memory on demand instead of reserving it all up front."""
        gpus = tf.config.list_physical_devices('GPU')
        if gpus:
            try:
                for gpu in gpus:
                    tf.config.experimental.set_memory_growth(gpu, True)
            except RuntimeError as e:
                # Raised when the GPUs were already initialized (e.g. by an earlier
                # instance in this process); memory growth is best-effort.
                print(e)

    def _compute_embedding(self, text: str) -> np.ndarray:
        """
        Return the sentence embedding of ``text``, memoized per instance.

        Uses ``self.embedding_cache`` as a small LRU: a hit moves the entry to the
        most-recently-used position, and the least recently used entry is evicted
        once more than ``_EMBEDDING_CACHE_SIZE`` texts are stored. (A per-instance
        dict replaces ``functools.lru_cache``, which on a method keys on ``self``
        and keeps every instance, with its loaded models, alive.)

        The returned array is the cached object; callers must not mutate it.
        """
        cache = self.embedding_cache
        if text in cache:
            embedding = cache.pop(text)
            cache[text] = embedding  # re-insert as most recently used
            return embedding

        embedding = self.use_model([text])[0].numpy()
        cache[text] = embedding
        if len(cache) > self._EMBEDDING_CACHE_SIZE:
            # Dicts preserve insertion order, so the first key is the LRU entry.
            del cache[next(iter(cache))]
        return embedding

    def _compute_batch_embeddings(self, texts: List[str]) -> np.ndarray:
        """Compute embeddings for multiple texts with a single model call (one row per text)."""
        return self.use_model(texts).numpy()

    @staticmethod
    def _tokenize(text: str) -> List[str]:
        """Lower-case ``text``, split on whitespace and strip surrounding punctuation; drop empty tokens."""
        tokens = (word.strip(string.punctuation) for word in text.lower().split())
        return [token for token in tokens if token]

    def _quick_quality_check(self, variation: str, original: str) -> bool:
        """
        Cheap preliminary plausibility check of ``variation`` against ``original``.

        Rejects the variation when it is much longer than the original (more than
        4x the word count for originals of up to 3 words, more than 3x otherwise)
        or when it shares no content word with the original (case-insensitive,
        surrounding punctuation ignored, stop words excluded).
        """
        debug = self.config.debug
        if debug:
            print(f"\nQuick check for variation: {variation}")

        orig_len = len(original.split())
        var_len = len(variation.split())

        # Very short texts (1-3 words) are allowed more relative growth.
        if orig_len <= 3:
            if var_len > orig_len * 4:
                if debug:
                    print(f"Failed length check (short text): {var_len} vs {orig_len}")
                return False
        elif var_len > orig_len * 3:
            if debug:
                print(f"Failed length check (long text): {var_len} vs {orig_len}")
            return False

        # Content check: at least one non-stop word in common. Tokens are
        # punctuation-stripped so "hotel." and "hotel" count as the same word.
        orig_words = set(self._tokenize(original)) - self._STOP_WORDS
        var_words = set(self._tokenize(variation)) - self._STOP_WORDS
        if orig_words.isdisjoint(var_words):
            if debug:
                print("Failed content check: no content words in common")
            return False

        if debug:
            print("Passed all quick checks")
        return True

    def _compute_metrics_parallel(self, original: str, candidates: List[str],
                                  max_workers: int = 4) -> List[Dict[str, float]]:
        """
        Compute ``quality_metrics.compute_metrics(original, c)`` for each candidate on a thread pool.

        Results are returned in the same order as ``candidates``; an exception
        raised for any candidate propagates to the caller.
        """
        if not candidates:
            return []
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(self.quality_metrics.compute_metrics, original, candidate)
                for candidate in candidates
            ]
            return [future.result() for future in futures]

    def _filter_variations_batch(self, variations: List[str], context: List[str],
                                 original_turn: str) -> List[str]:
        """
        Filter candidate variations of one turn against the original and the previous turn.

        Args:
            variations: Candidate rewrites of ``original_turn``.
            context: Original texts of the preceding turns, oldest first.
            original_turn: The turn being augmented.

        Returns:
            Up to ``config.max_variations_per_turn`` accepted variations. Turns with
            fewer than three words bypass filtering and instead receive the curated
            variations from ``_augment_short_text`` (capped at
            ``config.augmentation_factor``). A blank turn yields no variations.
        """
        debug = self.config.debug
        max_keep = self.config.max_variations_per_turn

        if not original_turn.strip():
            return []

        # Curated short-text variations do not depend on the generated candidates,
        # so they are served even when generation produced nothing.
        if len(original_turn.split()) < 3:
            if debug:
                print("Short text detected, using predefined variations")
            short_text_variations = self._augment_short_text({'text': original_turn, 'speaker': ''})
            return [var['text'] for var in short_text_variations]

        if not variations:
            return []

        if debug:
            print(f"\nStarting filtration of {len(variations)} variations")
            print(f"Context length: {len(context)}")
            print(f"Original turn: {original_turn}")

        if not context:
            # First turn: be lenient and skip the preliminary check.
            preliminary_filtered = list(variations)
            if debug:
                print("First turn - skipping preliminary filtering")
        else:
            preliminary_filtered = []
            for var in variations:
                passed = self._quick_quality_check(var, original_turn)
                if debug:
                    print(f"\nVariation: {var}")
                    print(f"Passed quick check: {passed}")
                if passed:
                    preliminary_filtered.append(var)
            if debug:
                print(f"Variations after quick check: {len(preliminary_filtered)}")

        if not preliminary_filtered:
            return []

        # Only the immediately preceding turn is used as context.
        context_text = context[-1] if context else ''
        if context_text:
            filtered_variations = self._score_against_context(
                preliminary_filtered, context_text, max_keep)
        else:
            filtered_variations = preliminary_filtered[:max_keep]

        if debug:
            print(f"\nFinal filtered variations: {len(filtered_variations)}")
        return filtered_variations

    def _score_against_context(self, candidates: List[str], context_text: str,
                               max_keep: int) -> List[str]:
        """
        Keep candidates whose embedding is (leniently) related to ``context_text``.

        A candidate is accepted when the weighted combination of |context
        similarity| and |response coherence| reaches 0.1, or when |response
        coherence| alone reaches 0.05. Stops after ``max_keep`` acceptances.
        """
        debug = self.config.debug
        # Deliberately lenient thresholds.
        min_similarity = 0.1
        min_coherence = 0.05

        if debug:
            print(f"\nContext text: {context_text}")

        all_embeddings = self._compute_batch_embeddings([context_text] + candidates)
        context_embedding = all_embeddings[0]
        variation_embeddings = all_embeddings[1:]

        # Vectorized similarity computation
        context_similarities = cosine_similarity([context_embedding], variation_embeddings)[0]
        # The context window is a single turn, so the "previous turn" used for
        # response coherence is the context text itself: reuse its embedding
        # rather than encoding the same string a second time.
        response_coherence = context_similarities

        accepted = []
        for variation, sim, coh in zip(candidates, context_similarities, response_coherence):
            # Absolute values: anti-correlated embeddings still count as related.
            combined_score = (
                self.config.context_similarity_weight * abs(sim)
                + self.config.response_coherence_weight * abs(coh)
            )

            if debug:
                print(f"\nVariation: {variation}")
                print(f"Context similarity: {sim:.3f}")
                print(f"Response coherence: {coh:.3f}")
                print(f"Combined score: {combined_score:.3f}")

            # Accept if EITHER score is good enough
            if combined_score >= min_similarity or abs(coh) >= min_coherence:
                accepted.append(variation)
                if debug:
                    print("ACCEPTED")
            elif debug:
                print("REJECTED")

            if len(accepted) >= max_keep:
                break
        return accepted

    @staticmethod
    def _collect_variations(found: Dict[str, None], candidates, original: str) -> None:
        """
        Add usable candidates to ``found``, a dict used as an insertion-ordered set.

        ``candidates`` may be a list of strings, a single string, or None, since
        augmenters differ in what they return. Non-strings, blank strings and
        strings equal to ``original`` (ignoring surrounding whitespace) are skipped.
        """
        if candidates is None:
            return
        if isinstance(candidates, str):
            candidates = [candidates]
        target = original.strip()
        for candidate in candidates:
            if isinstance(candidate, str):
                cleaned = candidate.strip()
                if cleaned and cleaned != target:
                    found.setdefault(candidate, None)

    def _generate_variations_progressive(self, text: str, needed: int) -> List[str]:
        """
        Generate variations progressively until at least ``needed`` unique ones exist.

        Advanced augmenters (paraphrase, back-translation) are tried first; the
        nlpaug synonym/spelling augmenters only run if those fell short, and
        spelling noise is skipped for technical/formal text. A failing augmenter
        is reported and skipped.

        Returns:
            Unique, non-blank variations differing from ``text``, in production
            order (advanced augmenters first, so they survive later truncation).
            May exceed ``needed`` because one augmenter call can overshoot.
        """
        debug = self.config.debug
        found: Dict[str, None] = {}  # insertion-ordered set -> deterministic output order

        if debug:
            print(f"\nAttempting to generate {needed} variations for text: {text}")

        # Try advanced augmenters first
        for augmenter in self.augmenters['advanced']:
            if len(found) >= needed:
                break
            try:
                if isinstance(augmenter, Paraphraser):
                    if debug:
                        print("Trying paraphrase augmentation...")
                    new_vars = augmenter.paraphrase(
                        text, num_return_sequences=needed - len(found)) or []
                    if debug:
                        print(f"Paraphraser generated {len(new_vars)} variations")
                else:
                    if debug:
                        print("Trying back translation...")
                    new_vars = [augmenter.back_translate(text)]
                    if debug:
                        print(f"Back translator generated {len(new_vars)} variations")
                self._collect_variations(found, new_vars, text)
                if debug:
                    print(f"Current unique variations: {len(found)}")
            except Exception as e:
                # Best effort: one failing augmenter must not abort the turn.
                print(f"Error in advanced augmentation: {str(e)}")

        # Try basic augmenters if needed
        if len(found) < needed:
            if debug:
                print("Not enough variations, trying basic augmenters...")
            for aug_type, augmenter in self.augmenters['basic']:
                if len(found) >= needed:
                    break
                try:
                    if aug_type == 'spelling' and self._is_technical_or_formal_text(text):
                        if debug:
                            print("Skipping spelling augmentation for technical text")
                        continue
                    if debug:
                        print(f"Trying {aug_type} augmentation...")
                    self._collect_variations(found, augmenter.augment(text, n=2), text)
                    if debug:
                        print(f"After {aug_type}, total variations: {len(found)}")
                except Exception as e:
                    print(f"Error in {aug_type} augmentation: {str(e)}")

        variations_list = list(found)
        if debug:
            print(f"Final number of variations generated: {len(variations_list)}")
        if not variations_list:
            print("WARNING: No variations were generated!")
        return variations_list

    def augment_dialogue(self, dialogue: Dict) -> List[Dict]:
        """
        Create augmented versions of ``dialogue``.

        Reads ``dialogue['dialogue_id']`` and ``dialogue['turns']`` (each turn a
        dict with ``'speaker'`` and ``'text'``). Dialogues longer than
        ``config.max_turns_per_dialogue`` are truncated for processing; the input
        dict itself is not modified.

        Returns:
            The (possibly truncated) original as ``<id>_original`` followed by up
            to ``config.augmentation_factor`` augmented dialogues. If any turn ends
            up with no accepted variation, no augmented dialogues are produced.
        """
        debug = self.config.debug
        max_turns = self.config.max_turns_per_dialogue

        turns = dialogue['turns']
        if len(turns) > max_turns:
            if debug:
                print(f"Truncating dialogue from {len(turns)} to {max_turns} turns")
            turns = turns[:max_turns]  # local slice: do not mutate the caller's dialogue

        turn_variations = []
        context = []
        for turn in turns:
            original_text = turn['text']
            variations = self._generate_variations_progressive(
                original_text, self.config.max_variations_per_turn)
            filtered_variations = self._filter_variations_batch(
                variations, context, original_text)

            turn_vars = [{'speaker': turn['speaker'], 'text': v} for v in filtered_variations]
            if debug:
                print(f"Turn {len(turn_variations)}: Generated {len(turn_vars)} variations")
            turn_variations.append(turn_vars)
            # Context is built from original texts, not from chosen variations.
            context.append(original_text)

        augmented_dialogues = self._generate_dialogue_combinations(
            dialogue['dialogue_id'], turn_variations)

        result = [{
            'dialogue_id': f"{dialogue['dialogue_id']}_original",
            'turns': turns,
        }]
        result.extend(augmented_dialogues[:self.config.augmentation_factor])

        if debug:
            print(f"Generated {len(result)-1} unique augmented dialogues")
        return result

    def _generate_dialogue_combinations(self, dialogue_id: str,
                                        turn_variations: List[List[Dict]]) -> List[Dict]:
        """
        Build up to ``config.augmentation_factor`` distinct dialogues from per-turn variations.

        Depth-first walk over the turns, trying at most
        ``config.max_sampled_variations`` variations per turn after shuffling them
        with ``np.random`` (so results depend on NumPy's global RNG state).
        Dialogues with identical turn texts are emitted only once. Any error is
        reported and an empty list returned.
        """
        limit = self.config.augmentation_factor
        per_turn = self.config.max_sampled_variations
        augmented_dialogues = []
        used_combinations = set()

        def generate_dialogues(current_turns, turn_index):
            if len(augmented_dialogues) >= limit:
                return

            if turn_index == len(turn_variations):
                # A tuple of texts cannot collide the way a separator-joined
                # string can when a turn itself contains the separator.
                fingerprint = tuple(turn['text'] for turn in current_turns)
                if fingerprint not in used_combinations:
                    used_combinations.add(fingerprint)
                    augmented_dialogues.append({
                        'dialogue_id': f"{dialogue_id}_aug_{len(augmented_dialogues)}",
                        # Copy the turn dicts so emitted dialogues never share them.
                        'turns': [dict(turn) for turn in current_turns],
                    })
                return

            variations = list(turn_variations[turn_index])
            np.random.shuffle(variations)
            for variation in variations[:per_turn]:
                if len(augmented_dialogues) >= limit:
                    return
                current_turns.append(variation)
                generate_dialogues(current_turns, turn_index + 1)
                current_turns.pop()

        try:
            generate_dialogues([], 0)
        except Exception as e:
            print(f"Error in dialogue generation: {str(e)}")
            return []
        return augmented_dialogues

    def _is_dialogue_duplicate(self, dialogue1: Dict, dialogue2: Dict) -> bool:
        """
        Check if two dialogues are duplicates, i.e. have the same sequence of turn texts.

        Speakers are ignored. Comparing per-turn lists (rather than one joined
        string) keeps turn boundaries significant, so ["a b", "c"] != ["a", "b c"].
        """
        texts1 = [turn['text'] for turn in dialogue1['turns']]
        texts2 = [turn['text'] for turn in dialogue2['turns']]
        return texts1 == texts2

    def _augment_short_text(self, turn: Dict) -> List[Dict]:
        """
        Special handling for very short texts with predefined variations.

        Args:
            turn (Dict): Original dialogue turn; ``'text'`` and ``'speaker'`` are read.

        Returns:
            List[Dict]: Up to ``config.augmentation_factor`` variations with the
            original speaker. Falls back to the original text when no candidate
            reaches the quality threshold (or the text is blank).
        """
        text = turn['text']
        speaker = turn['speaker']

        if not text.strip():
            return [{'speaker': speaker, 'text': text}]

        # Whole-word matching: substring matching made "I know" select "no" and
        # "Go" select "goodbye".
        words = set(self._tokenize(text))
        candidates = []
        for category, triggers in self._SHORT_TEXT_TRIGGERS.items():
            if not words.isdisjoint(triggers):
                candidates.extend(self._SHORT_TEXT_VARIATIONS[category])

        # No category matched: fall back to punctuation / capitalization variants.
        if not candidates:
            base = text.rstrip('!.,?')
            candidates = [base + '!', base + '.', base]
            candidates.extend([
                v.capitalize() for v in candidates
                if v.capitalize() not in candidates
            ])

        # Stable de-duplication keeps the truncation below deterministic; blank
        # candidates and the unchanged input are not useful augmentations.
        unique_variations = [
            v for v in dict.fromkeys(candidates)
            if v.strip() and v != text
        ]

        quality_variations = []
        for var in unique_variations:
            metrics = self.quality_metrics.compute_metrics(text, var)
            quality_score = (
                0.35 * metrics['semantic_similarity']
                + 0.30 * (1.0 - metrics['perplexity'] / 100)
                + 0.15 * (1.0 - metrics['grammar_errors'] / 10)
                + 0.15 * metrics['content_preservation']
                + 0.10 * metrics['type_token_ratio']
            )
            # More lenient quality threshold for short texts
            if quality_score >= 0.5:
                quality_variations.append(var)

        # Ensure we have at least some variations
        if not quality_variations:
            quality_variations = [text]

        return [{'speaker': speaker, 'text': v}
                for v in quality_variations[:self.config.augmentation_factor]]

    def _is_technical_or_formal_text(self, text: str) -> bool:
        """
        Check if text is formal/technical and shouldn't have spelling variations.

        Single-word indicators are matched against punctuation-stripped tokens;
        multi-word indicators (e.g. "please advise") are matched as whole-word
        phrases, which a plain set-of-words lookup could never find.
        """
        tokens = self._tokenize(text)
        token_set = set(tokens)
        padded_text = f" {' '.join(tokens)} "

        for indicators in self._FORMAL_INDICATORS.values():
            for indicator in indicators:
                if ' ' in indicator:
                    if f" {indicator} " in padded_text:
                        return True
                elif indicator in token_set:
                    return True
        return False