|
from typing import Dict, List |
|
import numpy as np |
|
import tensorflow as tf |
|
import tensorflow_hub as hub |
|
import re |
|
from pipeline_config import PipelineConfig |
|
from quality_metrics import QualityMetrics |
|
from paraphraser import Paraphraser |
|
from back_translator import BackTranslator |
|
import nlpaug.augmenter.word as naw |
|
from concurrent.futures import ThreadPoolExecutor |
|
from functools import lru_cache |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
|
class DialogueAugmenter:
    """
    Optimized dialogue augmentation with quality control and complexity management.

    For every turn of a dialogue, candidate rewrites are generated (paraphrase and
    back-translation first, then nlpaug synonym/spelling augmenters), filtered
    against the preceding turn with Universal Sentence Encoder similarity, and
    the surviving per-turn variations are combined into new dialogues.
    """

    # Maximum number of entries kept in the per-instance embedding cache.
    _EMBEDDING_CACHE_SIZE = 1024

    # Acceptance thresholds used by ``_filter_variations_batch``.
    _MIN_COMBINED_SCORE = 0.1
    _MIN_COHERENCE = 0.05

    # Words ignored when checking that a variation shares content with the original.
    _STOP_WORDS = frozenset(
        {'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'is', 'are'}
    )

    # Word tokenizer for content/keyword matching: drops surrounding punctuation
    # ("table." -> "table") while keeping inner apostrophes ("don't").
    _WORD_PATTERN = re.compile(r"\w+(?:'\w+)*")

    # A short text that is only *part* of a predefined key (e.g. "bye" for
    # "goodbye") must be at least this long to count as a match; this stops
    # single letters such as "a" or "o" from matching almost every key.
    _MIN_PARTIAL_KEY_LENGTH = 3

    # Predefined rewrites for very short turns, keyed by a trigger word.
    _SHORT_TEXT_VARIATIONS = {
        'goodbye': [
            'Bye!', 'Farewell!', 'See you!', 'Take care!',
            'Goodbye!', 'Bye for now!', 'Until next time!',
        ],
        'hello': [
            'Hi!', 'Hey!', 'Hello!', 'Greetings!',
            'Good day!', 'Hi there!', 'Hello there!',
        ],
        'yes': [
            'Yes!', 'Correct!', 'Indeed!', 'Absolutely!',
            "That's right!", 'Definitely!', 'Sure!',
        ],
        'no': [
            'No!', 'Nope!', 'Not at all!', 'Negative!',
            'Unfortunately not!', "I'm afraid not!",
        ],
        'thanks': [
            'Thank you!', 'Thanks a lot!', 'Many thanks!',
            'I appreciate it!', 'Thank you so much!',
        ],
        'ok': [
            'Okay!', 'Alright!', 'Sure!', 'Got it!',
            'Understood!', 'Fine!', 'Great!', 'Perfect!',
            'That works!', 'Sounds good!',
        ],
        'good': [
            'Great!', 'Excellent!', 'Perfect!', 'Wonderful!',
            'Fantastic!', 'Amazing!', 'Terrific!',
        ],
    }

    # Words/phrases marking text as technical or formal (no spelling noise).
    _FORMAL_INDICATORS = {
        'technical_terms': frozenset({'api', 'config', 'database', 'server', 'system'}),
        'formal_phrases': frozenset({'please advise', 'regarding', 'furthermore', 'moreover'}),
        'professional_context': frozenset({'meeting', 'conference', 'project', 'deadline'}),
    }

    def __init__(self, nlp, config: "PipelineConfig"):
        """
        Build the augmenter and load all models.

        Args:
            nlp: NLP pipeline object, stored as ``self.nlp``.
            config: Pipeline configuration; also passed to ``QualityMetrics``.
        """
        # GPU memory growth can only be changed before TensorFlow initializes the
        # devices; loading models (e.g. hub.load below) initializes them, after
        # which set_memory_growth raises RuntimeError. So configure it first.
        self._configure_gpu_memory_growth()

        self.nlp = nlp
        self.config = config
        self.quality_metrics = QualityMetrics(config)
        self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')

        self.paraphraser = Paraphraser()
        self.back_translator = BackTranslator()

        self.word_augmenter = naw.SynonymAug(aug_src='wordnet')
        self.spelling_augmenter = naw.SpellingAug()

        # 'advanced' augmenters are tried first; 'basic' ones only fill the gap.
        self.augmenters = {
            'advanced': [self.paraphraser, self.back_translator],
            'basic': [
                ('synonym', self.word_augmenter),
                ('spelling', self.spelling_augmenter),
            ],
        }

        # text -> embedding, ordered least- to most-recently used (see _compute_embedding).
        self.embedding_cache = {}
        # NOTE(review): not read anywhere in this class.
        self.perplexity_cache = {}
        # NOTE(review): not read anywhere in this class.
        self.spelling_pattern = re.compile(r'[a-zA-Z]{3,}')

    @staticmethod
    def _configure_gpu_memory_growth() -> None:
        """Enable on-demand memory growth on every visible GPU (errors are printed)."""
        gpus = tf.config.list_physical_devices('GPU')
        if gpus:
            try:
                for gpu in gpus:
                    tf.config.experimental.set_memory_growth(gpu, True)
            except RuntimeError as e:
                print(e)

    def _compute_embedding(self, text: str) -> np.ndarray:
        """
        Return the USE embedding of ``text``, memoized per instance.

        Uses a bounded LRU dict instead of ``functools.lru_cache`` on the method,
        which would key on ``self`` and keep every instance alive.
        """
        cache = self.embedding_cache
        if text in cache:
            # Re-insert to mark the entry as most recently used.
            embedding = cache.pop(text)
            cache[text] = embedding
            return embedding

        embedding = self.use_model([text])[0].numpy()
        cache[text] = embedding
        if len(cache) > self._EMBEDDING_CACHE_SIZE:
            # Dicts iterate in insertion order, so the first key is the LRU entry.
            del cache[next(iter(cache))]
        return embedding

    def _compute_batch_embeddings(self, texts: List[str]) -> np.ndarray:
        """Compute USE embeddings for several texts in one model call."""
        return self.use_model(texts).numpy()

    def _content_words(self, text: str) -> set:
        """Lower-cased word tokens of ``text`` minus stop words, punctuation stripped."""
        return {
            word for word in self._WORD_PATTERN.findall(text.lower())
            if word not in self._STOP_WORDS
        }

    def _quick_quality_check(self, variation: str, original: str) -> bool:
        """
        Simplified preliminary quality check with minimal standards.

        Rejects a variation that is much longer than the original (more than 4x
        the word count for originals of up to three words, 3x otherwise) or that
        shares no content word with the original. The content check is skipped
        when the original has no content words at all.
        """
        debug = self.config.debug
        if debug:
            print(f"\nQuick check for variation: {variation}")

        orig_len = len(original.split())
        var_len = len(variation.split())

        if orig_len <= 3:
            if var_len > orig_len * 4:
                if debug:
                    print(f"Failed length check (short text): {var_len} vs {orig_len}")
                return False
        elif var_len > orig_len * 3:
            if debug:
                print(f"Failed length check (long text): {var_len} vs {orig_len}")
            return False

        orig_words = self._content_words(original)
        if orig_words and not orig_words & self._content_words(variation):
            if debug:
                print("Failed content check: no content words in common")
            return False

        if debug:
            print("Passed all quick checks")
        return True

    def _compute_metrics_parallel(
        self, original: str, candidates: List[str], max_workers: int = 4
    ) -> List[Dict[str, float]]:
        """
        Compute ``quality_metrics.compute_metrics(original, c)`` for each candidate
        on a thread pool; results are returned in candidate order.
        """
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(self.quality_metrics.compute_metrics, original, candidate)
                for candidate in candidates
            ]
            return [future.result() for future in futures]

    def _filter_variations_batch(self, variations: List[str], context: List[str], original_turn: str) -> List[str]:
        """
        Filter candidate variations of a turn, with detailed debug logging.

        Turns of fewer than three words ignore the candidates and use
        ``_augment_short_text`` instead. Otherwise candidates pass
        ``_quick_quality_check`` (skipped on the first turn) and are then scored
        by USE cosine similarity against the previous turn.

        Args:
            variations: Candidate rewrites of ``original_turn``.
            context: Texts of the preceding turns, oldest first.
            original_turn: Text of the turn being augmented.

        Returns:
            Accepted variation texts. Capped at ``config.max_variations_per_turn``,
            except on the short-text path, whose cap comes from ``_augment_short_text``.
        """
        if not variations:
            return []

        debug = self.config.debug
        if debug:
            print(f"\nStarting filtration of {len(variations)} variations")
            print(f"Context length: {len(context)}")
            print(f"Original turn: {original_turn}")

        if len(original_turn.split()) < 3:
            if debug:
                print("Short text detected, using predefined variations")
            short_text_variations = self._augment_short_text({'text': original_turn, 'speaker': ''})
            return [var['text'] for var in short_text_variations]

        if not context:
            preliminary_filtered = variations
            if debug:
                print("First turn - skipping preliminary filtering")
        else:
            preliminary_filtered = []
            for var in variations:
                passed = self._quick_quality_check(var, original_turn)
                if debug:
                    print(f"\nVariation: {var}")
                    print(f"Passed quick check: {passed}")
                if passed:
                    preliminary_filtered.append(var)
            if debug:
                print(f"Variations after quick check: {len(preliminary_filtered)}")

        if not preliminary_filtered:
            return []

        # Only the immediately preceding turn is used as context.
        recent_context = context[-1:]
        context_text = ' '.join(recent_context)
        max_variations = self.config.max_variations_per_turn

        if context_text:
            if debug:
                print(f"\nContext text: {context_text}")

            all_embeddings = self._compute_batch_embeddings([context_text, *preliminary_filtered])
            context_embedding = all_embeddings[0]
            variation_embeddings = all_embeddings[1:]

            context_similarities = cosine_similarity([context_embedding], variation_embeddings)[0]

            # The context window is exactly the previous turn, so the previous-turn
            # embedding needed for response coherence is the context embedding;
            # reuse it rather than running the encoder on the same text again.
            prev_embedding = context_embedding
            response_coherence = cosine_similarity([prev_embedding], variation_embeddings)[0]

            filtered_variations = []
            for variation, sim, coh in zip(preliminary_filtered, context_similarities, response_coherence):
                # NOTE(review): abs() lets negatively similar (opposing) variations
                # score like similar ones - confirm this is intended.
                combined_score = (
                    self.config.context_similarity_weight * abs(sim)
                    + self.config.response_coherence_weight * abs(coh)
                )

                if debug:
                    print(f"\nVariation: {variation}")
                    print(f"Context similarity: {sim:.3f}")
                    print(f"Response coherence: {coh:.3f}")
                    print(f"Combined score: {combined_score:.3f}")

                if combined_score >= self._MIN_COMBINED_SCORE or abs(coh) >= self._MIN_COHERENCE:
                    filtered_variations.append(variation)
                    if debug:
                        print("ACCEPTED")
                elif debug:
                    print("REJECTED")

                if len(filtered_variations) >= max_variations:
                    break
        else:
            filtered_variations = preliminary_filtered[:max_variations]

        if debug:
            print(f"\nFinal filtered variations: {len(filtered_variations)}")

        return filtered_variations

    def _generate_variations_progressive(self, text: str, needed: int) -> List[str]:
        """
        Generate up to roughly ``needed`` distinct variations of ``text``.

        Advanced augmenters (paraphraser, back-translator) run first; basic nlpaug
        augmenters run only if still short. Spelling noise is skipped for
        technical/formal text. Empty results and exact copies of ``text`` are
        dropped. Augmenter errors are printed and skipped (best effort).

        Returns:
            Unique variations in the order they were produced.
        """
        debug = self.config.debug
        # Insertion-ordered set. A plain set of str iterates in an order that
        # changes between interpreter runs (hash randomization), which made the
        # later [:N] truncation nondeterministic.
        variations: Dict[str, None] = {}

        def keep_valid(candidates) -> None:
            valid_vars = [v for v in candidates if v.strip() and v != text]
            variations.update(dict.fromkeys(valid_vars))

        if debug:
            print(f"\nAttempting to generate {needed} variations for text: {text}")

        for augmenter in self.augmenters['advanced']:
            if len(variations) >= needed:
                break
            try:
                if isinstance(augmenter, Paraphraser):
                    if debug:
                        print("Trying paraphrase augmentation...")
                    new_vars = augmenter.paraphrase(text, num_return_sequences=needed - len(variations))
                    if debug:
                        print(f"Paraphraser generated {len(new_vars)} variations")
                else:
                    if debug:
                        print("Trying back translation...")
                    new_vars = [augmenter.back_translate(text)]
                    if debug:
                        print(f"Back translator generated {len(new_vars)} variations")

                keep_valid(new_vars)
                if debug:
                    print(f"Current unique variations: {len(variations)}")
            except Exception as e:
                print(f"Error in advanced augmentation: {str(e)}")

        if len(variations) < needed:
            if debug:
                print("Not enough variations, trying basic augmenters...")

            for aug_type, augmenter in self.augmenters['basic']:
                if len(variations) >= needed:
                    break
                try:
                    if aug_type == 'spelling' and self._is_technical_or_formal_text(text):
                        if debug:
                            print("Skipping spelling augmentation for technical text")
                        continue

                    if debug:
                        print(f"Trying {aug_type} augmentation...")

                    new_vars = augmenter.augment(text, n=2)
                    # nlpaug may return a list or a single string.
                    keep_valid(new_vars if isinstance(new_vars, list) else [new_vars])

                    if debug:
                        print(f"After {aug_type}, total variations: {len(variations)}")
                except Exception as e:
                    print(f"Error in {aug_type} augmentation: {str(e)}")

        variations_list = list(variations)

        if debug:
            print(f"Final number of variations generated: {len(variations_list)}")
            if not variations_list:
                print("WARNING: No variations were generated!")

        return variations_list

    def augment_dialogue(self, dialogue: Dict) -> List[Dict]:
        """
        Create augmented versions of the dialogue with optimized processing.

        The dialogue is truncated to ``config.max_turns_per_dialogue`` turns
        without modifying the caller's dict. A turn for which no variation
        survives filtering keeps its original text, so one hard turn does not
        suppress every augmented dialogue.

        Args:
            dialogue: Dict with ``'dialogue_id'`` and ``'turns'`` (a list of dicts
                with ``'speaker'`` and ``'text'``).

        Returns:
            ``[original, *augmented]``: the (possibly truncated) original with id
            ``<id>_original``, then at most ``config.augmentation_factor``
            augmented dialogues.
        """
        debug = self.config.debug
        turns = dialogue['turns']
        original_length = len(turns)
        max_turns = self.config.max_turns_per_dialogue
        if original_length > max_turns:
            if debug:
                print(f"Truncating dialogue from {original_length} to {self.config.max_turns_per_dialogue} turns")
            turns = turns[:max_turns]

        turn_variations = []
        context = []

        for turn in turns:
            original_text = turn['text']
            variations = self._generate_variations_progressive(
                original_text,
                self.config.max_variations_per_turn,
            )
            filtered_variations = self._filter_variations_batch(
                variations,
                context,
                original_text,
            )

            turn_vars = [{'speaker': turn['speaker'], 'text': v} for v in filtered_variations]

            if debug:
                print(f"Turn {len(turn_variations)}: Generated {len(turn_vars)} variations")

            if not turn_vars:
                # Keep the original turn so the other turns can still be combined.
                turn_vars = [{'speaker': turn['speaker'], 'text': original_text}]

            turn_variations.append(turn_vars)
            context.append(original_text)

        augmented_dialogues = self._generate_dialogue_combinations(
            dialogue['dialogue_id'],
            turn_variations,
            exclude_turns=turns,
        )

        result = [{
            'dialogue_id': f"{dialogue['dialogue_id']}_original",
            'turns': turns,
        }]
        result.extend(augmented_dialogues[:self.config.augmentation_factor])

        if debug:
            print(f"Generated {len(result)-1} unique augmented dialogues")

        return result

    def _generate_dialogue_combinations(self, dialogue_id: str, turn_variations: List[List[Dict]],
                                        exclude_turns=None) -> List[Dict]:
        """
        Generate dialogue combinations using sampling.

        Walks the per-turn options depth first. At each turn the options are
        shuffled (``np.random``) and at most ``config.max_sampled_variations`` of
        them are tried. Stops after ``config.augmentation_factor`` distinct
        dialogues.

        Args:
            dialogue_id: Base id; results are named ``<id>_aug_<k>``.
            turn_variations: One list of candidate turn dicts per turn position.
            exclude_turns: Optional turn list (e.g. the original dialogue) whose
                text sequence must not be emitted.

        Returns:
            The generated dialogues, or ``[]`` if generation raised (the error is printed).
        """
        augmented_dialogues = []
        # Tuple fingerprints of emitted (or excluded) turn-text sequences.
        used_combinations = set()
        if exclude_turns is not None:
            used_combinations.add(tuple(turn['text'] for turn in exclude_turns))

        limit = self.config.augmentation_factor
        sample_size = self.config.max_sampled_variations

        def generate_dialogues(current_turns, turn_index):
            if len(augmented_dialogues) >= limit:
                return

            if turn_index == len(turn_variations):
                dialogue_fingerprint = tuple(turn['text'] for turn in current_turns)
                if dialogue_fingerprint not in used_combinations:
                    used_combinations.add(dialogue_fingerprint)
                    augmented_dialogues.append({
                        'dialogue_id': f"{dialogue_id}_aug_{len(augmented_dialogues)}",
                        'turns': current_turns.copy(),
                    })
                return

            variations = list(turn_variations[turn_index])
            np.random.shuffle(variations)

            for variation in variations[:sample_size]:
                if len(augmented_dialogues) >= limit:
                    return
                current_turns.append(variation)
                generate_dialogues(current_turns, turn_index + 1)
                current_turns.pop()

        try:
            generate_dialogues([], 0)
        except Exception as e:
            print(f"Error in dialogue generation: {str(e)}")
            return []

        return augmented_dialogues

    def _is_dialogue_duplicate(self, dialogue1: Dict, dialogue2: Dict) -> bool:
        """
        Check if two dialogues are duplicates.

        Compares turn texts position by position. Joining all turns into one
        string would wrongly equate e.g. ["a b", "c"] with ["a", "b c"].
        """
        texts1 = [turn['text'] for turn in dialogue1['turns']]
        texts2 = [turn['text'] for turn in dialogue2['turns']]
        return texts1 == texts2

    def _augment_short_text(self, turn: Dict) -> List[Dict]:
        """
        Special handling for very short texts with predefined variations.

        A predefined key matches when it starts a word of the text ("ok" matches
        "okay" but not "book"; "no" does not match "know"). It also matches when
        the whole text is part of the key and at least
        ``_MIN_PARTIAL_KEY_LENGTH`` characters long ("bye" -> "goodbye"). Without
        a match, punctuation variants of the text are used. Candidates are scored
        with ``quality_metrics`` and kept when the score is at least 0.5.

        Args:
            turn (Dict): Original dialogue turn (``'speaker'`` and ``'text'``).

        Returns:
            List[Dict]: Up to ``config.augmentation_factor`` variations in a
            deterministic order. Falls back to the original text when nothing
            passes, or when the text has no content besides punctuation.
        """
        text = turn['text']
        text_lower = text.lower().strip().rstrip('!.,?')

        if not text_lower:
            # Nothing to vary; previously '' matched every key via `'' in key`.
            return [{'speaker': turn['speaker'], 'text': text}]

        variations = []
        for key, predefined_vars in self._SHORT_TEXT_VARIATIONS.items():
            starts_a_word = re.search(r'\b' + re.escape(key), text_lower) is not None
            is_part_of_key = len(text_lower) >= self._MIN_PARTIAL_KEY_LENGTH and text_lower in key
            if starts_a_word or is_part_of_key:
                variations.extend(predefined_vars)

        if not variations:
            stripped = text.rstrip('!.,?')
            variations = [stripped + '!', stripped + '.', stripped]

        variations.extend([
            v.capitalize() for v in variations
            if v.capitalize() not in variations
        ])

        # Order-preserving dedup so the final [:N] slice is deterministic.
        unique_variations = list(dict.fromkeys(variations))
        quality_variations = []

        for var in unique_variations:
            metrics = self.quality_metrics.compute_metrics(text, var)
            quality_score = (
                0.35 * metrics['semantic_similarity'] +
                0.30 * (1.0 - metrics['perplexity'] / 100) +
                0.15 * (1.0 - metrics['grammar_errors'] / 10) +
                0.15 * metrics['content_preservation'] +
                0.10 * metrics['type_token_ratio']
            )
            if quality_score >= 0.5:
                quality_variations.append(var)

        if not quality_variations:
            quality_variations = [text]

        return [{'speaker': turn['speaker'], 'text': v} for v in quality_variations[:self.config.augmentation_factor]]

    def _is_technical_or_formal_text(self, text: str) -> bool:
        """
        Check if text is formal/technical and shouldn't have spelling variations.

        Matching is done on punctuation-stripped, lower-cased word tokens, so
        "server." matches "server" and multi-word phrases such as
        "please advise" can match.
        """
        tokens = self._WORD_PATTERN.findall(text.lower())
        # Pad with spaces so every indicator matches on whole-word boundaries.
        normalized = f" {' '.join(tokens)} "

        return any(
            f" {term} " in normalized
            for terms in self._FORMAL_INDICATORS.values()
            for term in terms
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|