# csc525_retrieval_based_chatbot / dialogue_augmenter.py
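"""Dialogue augmentation with quality control and complexity management."""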
from typing import Dict, List
import numpy as np
import torch
import tensorflow as tf
import tensorflow_hub as hub
import re
from pipeline_config import PipelineConfig
from quality_metrics import QualityMetrics
from paraphraser import Paraphraser
from back_translator import BackTranslator
import nlpaug.augmenter.word as naw
from concurrent.futures import ThreadPoolExecutor
from sklearn.metrics.pairwise import cosine_similarity
class DialogueAugmenter:
"""
Optimized dialogue augmentation with quality control and complexity management.
"""
def __init__(self, nlp, config: PipelineConfig):
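        # `nlp` is expected to be a spaCy-style pipeline; it is stored here but
        # not referenced elsewhere in this file.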
self.nlp = nlp
self.config = config
# Detect hardware and set appropriate batch sizes and optimization strategy
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.use_gpu = self.device.type == 'cuda'
if self.config.debug:
print(f"Using device: {self.device}")
if self.use_gpu:
print(f"GPU Device: {torch.cuda.get_device_name(0)}")
        # Configure TensorFlow GPU memory growth BEFORE loading any TF model:
        # set_memory_growth raises a RuntimeError once the GPU context has been
        # initialized (which hub.load below would otherwise trigger first).
        if self.use_gpu:
            gpus = tf.config.list_physical_devices('GPU')
            if gpus:
                try:
                    for gpu in gpus:
                        tf.config.experimental.set_memory_growth(gpu, True)
                except RuntimeError as e:
                    print(e)
        # Load base models
        self.quality_metrics = QualityMetrics(config)
        self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
        # Initialize augmentation models based on hardware
        self._initialize_augmentation_models()
        # Initialize caches
        self.embedding_cache = {}
        self.perplexity_cache = {}
        # Compile regex patterns (not currently referenced elsewhere in this file)
        self.spelling_pattern = re.compile(r'[a-zA-Z]{3,}')
def _initialize_augmentation_models(self):
"""Initialize augmentation models with appropriate device settings"""
# Advanced augmentation techniques
self.paraphraser = Paraphraser()
self.back_translator = BackTranslator()
if self.use_gpu:
# Move models to GPU if available
self.paraphraser.model = self.paraphraser.model.to(self.device)
self.back_translator.model_pivot_forward = self.back_translator.model_pivot_forward.to(self.device)
self.back_translator.model_pivot_backward = self.back_translator.model_pivot_backward.to(self.device)
self.back_translator.model_backward = self.back_translator.model_backward.to(self.device)
# Basic augmentation techniques
self.word_augmenter = naw.SynonymAug(aug_src='wordnet')
self.spelling_augmenter = naw.SpellingAug()
        self.augmenters = {
            # Advanced augmenters are model wrappers dispatched by isinstance
            # checks; basic augmenters carry a name tag so specific types
            # (e.g. spelling) can be skipped for technical text.
            'advanced': [self.paraphraser, self.back_translator],
            'basic': [
                ('synonym', self.word_augmenter),
                ('spelling', self.spelling_augmenter)
            ]
        }
    def _compute_embedding(self, text: str) -> np.ndarray:
        """Cached computation of a single text embedding.

        Uses the instance-level dict cache directly; functools.lru_cache is
        avoided here because stacking it on an instance method would duplicate
        the dict cache and pin `self` (and the loaded models) in memory.
        """
        if text in self.embedding_cache:
            return self.embedding_cache[text]
        embedding = self.use_model([text])[0].numpy()
        self.embedding_cache[text] = embedding
        return embedding
def _compute_batch_embeddings(self, texts: List[str]) -> np.ndarray:
"""Compute embeddings for multiple texts at once with hardware optimization"""
# Check cache first
uncached_texts = [t for t in texts if t not in self.embedding_cache]
if uncached_texts:
embeddings = self.use_model(uncached_texts).numpy()
# Update cache
for text, embedding in zip(uncached_texts, embeddings):
self.embedding_cache[text] = embedding
# Return all embeddings (from cache or newly computed)
return np.array([self.embedding_cache[t] for t in texts])
def _quick_quality_check(self, variation: str, original: str) -> bool:
"""
Stricter preliminary quality check while maintaining reasonable pass rates
"""
if self.config.debug:
print(f"\nQuick check for variation: {variation}")
        # Length check: allow more expansion for very short texts
        orig_len = len(original.split())
        var_len = len(variation.split())
        # For very short texts (1-3 words), allow up to 3x the original length
        if orig_len <= 3:
            if var_len > orig_len * 3:
                if self.config.debug:
                    print(f"Failed length check (short text): {var_len} vs {orig_len}")
                return False
        else:
            # Longer texts may grow to at most 2x the original length
            if var_len > orig_len * 2:
                if self.config.debug:
                    print(f"Failed length check (long text): {var_len} vs {orig_len}")
                return False
        # Content check: require at least 30% overlap of non-stop-word tokens
        stop_words = {'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'is', 'are', 'that', 'this', 'will', 'can'}
        orig_words = set(w.lower() for w in original.split() if w.lower() not in stop_words)
        var_words = set(w.lower() for w in variation.split() if w.lower() not in stop_words)
        content_overlap = len(orig_words.intersection(var_words)) / len(orig_words) if orig_words else 0
        if content_overlap < 0.3:
            if self.config.debug:
                print(f"Failed content check: overlap {content_overlap:.2f}")
            return False
if self.config.debug:
print("Passed all quick checks")
return True
def _compute_metrics_parallel(self, original: str, candidates: List[str]) -> List[Dict[str, float]]:
"""Compute quality metrics for multiple candidates in parallel"""
with ThreadPoolExecutor(max_workers=4) as executor:
futures = [
executor.submit(self.quality_metrics.compute_metrics, original, candidate)
for candidate in candidates
]
return [future.result() for future in futures]
def _filter_variations_batch(self, variations: List[str], context: List[str], original_turn: str) -> List[str]:
"""
Filter variations using batched computations with detailed logging
"""
if not variations:
return []
if self.config.debug:
print(f"\nStarting filtration of {len(variations)} variations")
print(f"Context length: {len(context)}")
print(f"Original turn: {original_turn}")
        words = original_turn.split()
        if len(words) < 3:
            # Very short turns bypass the generated variations entirely and
            # fall back to the predefined substitutions in _augment_short_text.
            if self.config.debug:
                print("Short text detected, using predefined variations")
            short_text_variations = self._augment_short_text({'text': original_turn, 'speaker': ''})
            return [var['text'] for var in short_text_variations]
# If this is the first turn (no context), be more lenient
if not context:
preliminary_filtered = variations
if self.config.debug:
print("First turn - skipping preliminary filtering")
else:
# Quick preliminary filtering against original turn
preliminary_filtered = []
for var in variations:
passed = self._quick_quality_check(var, original_turn)
if self.config.debug:
print(f"\nVariation: {var}")
print(f"Passed quick check: {passed}")
if passed:
preliminary_filtered.append(var)
if self.config.debug:
print(f"Variations after quick check: {len(preliminary_filtered)}")
if not preliminary_filtered:
return []
# Only use last turn for coherence
recent_context = [context[-1]] if context else []
context_text = ' '.join(recent_context) if recent_context else ''
        # Lenient acceptance thresholds (tuned down empirically)
        min_similarity = 0.1
        min_coherence = 0.05
if context_text:
if self.config.debug:
print(f"\nContext text: {context_text}")
all_texts = [context_text] + preliminary_filtered
all_embeddings = self._compute_batch_embeddings(all_texts)
context_embedding = all_embeddings[0]
variation_embeddings = all_embeddings[1:]
# Vectorized similarity computation
context_similarities = cosine_similarity([context_embedding], variation_embeddings)[0]
            # Response coherence check. Note: since context_text is built from
            # recent_context (the last turn only), these scores currently mirror
            # context_similarities; the else branch is unreachable here and is
            # kept only as a guard.
            if recent_context:
                prev_embedding = self._compute_embedding(recent_context[-1])
                response_coherence = cosine_similarity([prev_embedding], variation_embeddings)[0]
            else:
                response_coherence = np.ones_like(context_similarities)
# Combined scoring with detailed logging
filtered_variations = []
for i, (variation, sim, coh) in enumerate(zip(
preliminary_filtered, context_similarities, response_coherence)):
# Use absolute values for scoring
combined_score = (
self.config.context_similarity_weight * abs(sim) +
self.config.response_coherence_weight * abs(coh)
)
if self.config.debug:
print(f"\nVariation: {variation}")
print(f"Context similarity: {sim:.3f}")
print(f"Response coherence: {coh:.3f}")
print(f"Combined score: {combined_score:.3f}")
# Accept if EITHER score is good enough
if (combined_score >= min_similarity or abs(coh) >= min_coherence):
filtered_variations.append(variation)
if self.config.debug:
print("ACCEPTED")
else:
if self.config.debug:
print("REJECTED")
# If we have enough variations, stop
if len(filtered_variations) >= self.config.max_variations_per_turn:
break
else:
filtered_variations = preliminary_filtered[:self.config.max_variations_per_turn]
if self.config.debug:
print(f"\nFinal filtered variations: {len(filtered_variations)}")
return filtered_variations
def _generate_variations_progressive(self, text: str, needed: int) -> List[str]:
"""
Generate variations progressively until we have enough good ones
"""
variations = set()
if self.config.debug:
print(f"\nAttempting to generate {needed} variations for text: {text}")
# Try advanced augmenters first
for augmenter in self.augmenters['advanced']:
if len(variations) >= needed:
break
try:
if isinstance(augmenter, Paraphraser):
if self.config.debug:
print("Trying paraphrase augmentation...")
new_vars = augmenter.paraphrase(text, num_return_sequences=needed-len(variations))
if self.config.debug:
print(f"Paraphraser generated {len(new_vars)} variations")
else:
if self.config.debug:
print("Trying back translation...")
new_vars = [augmenter.back_translate(text)]
if self.config.debug:
print(f"Back translator generated {len(new_vars)} variations")
valid_vars = [v for v in new_vars if v.strip() and v != text]
variations.update(valid_vars)
if self.config.debug:
print(f"Current unique variations: {len(variations)}")
except Exception as e:
print(f"Error in advanced augmentation: {str(e)}")
continue
# Try basic augmenters if needed
if len(variations) < needed:
if self.config.debug:
print("Not enough variations, trying basic augmenters...")
for aug_type, augmenter in self.augmenters['basic']:
if len(variations) >= needed:
break
try:
if aug_type == 'spelling' and self._is_technical_or_formal_text(text):
if self.config.debug:
print("Skipping spelling augmentation for technical text")
continue
if self.config.debug:
print(f"Trying {aug_type} augmentation...")
new_vars = augmenter.augment(text, n=2)
if isinstance(new_vars, list):
valid_vars = [v for v in new_vars if v.strip() and v != text]
variations.update(valid_vars)
else:
if new_vars.strip() and new_vars != text:
variations.add(new_vars)
if self.config.debug:
print(f"After {aug_type}, total variations: {len(variations)}")
except Exception as e:
print(f"Error in {aug_type} augmentation: {str(e)}")
continue
variations_list = list(variations)
if self.config.debug:
print(f"Final number of variations generated: {len(variations_list)}")
if not variations_list:
print("WARNING: No variations were generated!")
return variations_list
def augment_dialogue(self, dialogue: Dict) -> List[Dict]:
"""
Create augmented versions of the dialogue with optimized processing
"""
# Early dialogue length check
original_length = len(dialogue['turns'])
if original_length > self.config.max_turns_per_dialogue:
if self.config.debug:
print(f"Truncating dialogue from {original_length} to {self.config.max_turns_per_dialogue} turns")
dialogue['turns'] = dialogue['turns'][:self.config.max_turns_per_dialogue]
turn_variations = []
context = []
# Process each turn with progressive generation
for turn in dialogue['turns']:
original_text = turn['text'] # Store original turn text
variations = self._generate_variations_progressive(
original_text,
self.config.max_variations_per_turn
)
# Batch filter variations with original text
filtered_variations = self._filter_variations_batch(
variations,
context,
original_text # Pass the original turn text
)
# Create turn variations with speaker info
turn_vars = [{'speaker': turn['speaker'], 'text': v} for v in filtered_variations]
if self.config.debug:
print(f"Turn {len(turn_variations)}: Generated {len(turn_vars)} variations")
turn_variations.append(turn_vars)
context.append(original_text)
# Generate combinations with sampling
augmented_dialogues = self._generate_dialogue_combinations(
dialogue['dialogue_id'],
turn_variations
)
# Add original dialogue
result = [{
'dialogue_id': f"{dialogue['dialogue_id']}_original",
'turns': dialogue['turns']
}]
# Add unique augmentations
result.extend(augmented_dialogues[:self.config.augmentation_factor])
if self.config.debug:
print(f"Generated {len(result)-1} unique augmented dialogues")
return result
def _generate_dialogue_combinations(self, dialogue_id: str, turn_variations: List[List[Dict]]) -> List[Dict]:
"""
Generate dialogue combinations using sampling
"""
augmented_dialogues = []
used_combinations = set()
def generate_dialogues(current_turns=None, turn_index=0):
if current_turns is None:
current_turns = []
if len(augmented_dialogues) >= self.config.augmentation_factor:
return
if turn_index == len(turn_variations):
dialogue_fingerprint = " | ".join(turn['text'] for turn in current_turns)
if dialogue_fingerprint not in used_combinations:
used_combinations.add(dialogue_fingerprint)
augmented_dialogues.append({
'dialogue_id': f"{dialogue_id}_aug_{len(augmented_dialogues)}",
'turns': current_turns.copy()
})
return
variations = list(turn_variations[turn_index])
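            # Shuffle in place so the bounded slice below samples a random subset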
np.random.shuffle(variations)
for variation in variations[:self.config.max_sampled_variations]:
if len(augmented_dialogues) >= self.config.augmentation_factor:
return
current_turns.append(variation)
generate_dialogues(current_turns, turn_index + 1)
current_turns.pop()
try:
generate_dialogues()
except Exception as e:
print(f"Error in dialogue generation: {str(e)}")
return []
return augmented_dialogues
def _is_dialogue_duplicate(self, dialogue1: Dict, dialogue2: Dict) -> bool:
"""
Check if two dialogues are duplicates.
"""
text1 = " ".join(turn['text'] for turn in dialogue1['turns'])
text2 = " ".join(turn['text'] for turn in dialogue2['turns'])
return text1 == text2
def _augment_short_text(self, turn: Dict) -> List[Dict]:
"""
Special handling for very short texts with predefined variations.
Args:
turn (Dict): Original dialogue turn
Returns:
List[Dict]: List of variations for the short text
"""
text = turn['text']
common_variations = {
'goodbye': [
'Bye!', 'Farewell!', 'See you!', 'Take care!',
'Goodbye!', 'Bye for now!', 'Until next time!'
],
'hello': [
'Hi!', 'Hey!', 'Hello!', 'Greetings!',
'Good day!', 'Hi there!', 'Hello there!'
],
'yes': [
'Yes!', 'Correct!', 'Indeed!', 'Absolutely!',
'That\'s right!', 'Definitely!', 'Sure!'
],
'no': [
'No!', 'Nope!', 'Not at all!', 'Negative!',
'Unfortunately not!', 'I\'m afraid not!'
],
'thanks': [
'Thank you!', 'Thanks a lot!', 'Many thanks!',
'I appreciate it!', 'Thank you so much!'
],
'ok': [
'Okay!', 'Alright!', 'Sure!', 'Got it!',
'Understood!', 'Fine!', 'Great!', 'Perfect!',
'That works!', 'Sounds good!'
],
'good': [
'Great!', 'Excellent!', 'Perfect!', 'Wonderful!',
'Fantastic!', 'Amazing!', 'Terrific!'
]
}
        # Try to find matching variations
        text_lower = text.lower().rstrip('!.,?')
        variations = []
        # Check if the text matches any predefined category. The substring test
        # is loose (e.g. 'no' also matches 'now'), which is acceptable since the
        # calling code only routes turns of fewer than three words here.
        for key, predefined_vars in common_variations.items():
            if key in text_lower or text_lower in key:
                variations.extend(predefined_vars)
# If no predefined variations found, generate simple variants
if not variations:
# Add punctuation variations
variations = [
text.rstrip('!.,?') + '!',
text.rstrip('!.,?') + '.',
text.rstrip('!.,?')
]
# Add capitalization variations
variations.extend([
v.capitalize() for v in variations
if v.capitalize() not in variations
])
# Filter variations for uniqueness and quality
unique_variations = list(set(variations))
quality_variations = []
for var in unique_variations:
metrics = self.quality_metrics.compute_metrics(text, var)
quality_score = (
0.35 * metrics['semantic_similarity'] +
0.30 * (1.0 - metrics['perplexity'] / 100) +
0.15 * (1.0 - metrics['grammar_errors'] / 10) +
0.15 * metrics['content_preservation'] +
0.10 * metrics['type_token_ratio']
)
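            # NOTE: the heuristic weights above sum to 1.05 rather than 1.0, and
            # the perplexity term can go negative when perplexity exceeds 100;
            # both only nudge scores relative to the fixed 0.5 cut-off below.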
            # Lenient quality threshold for short texts
            if quality_score >= 0.5:
                quality_variations.append(var)
# Ensure we have at least some variations
if not quality_variations:
quality_variations = [text]
# Return the variations with original speaker
return [{'speaker': turn['speaker'], 'text': v} for v in quality_variations[:self.config.augmentation_factor]]
    def _is_technical_or_formal_text(self, text: str) -> bool:
        """
        Check if text is formal/technical and shouldn't have spelling variations.
        """
        formal_indicators = {
            'technical_terms': {'api', 'config', 'database', 'server', 'system'},
            'formal_phrases': {'please advise', 'regarding', 'furthermore', 'moreover'},
            'professional_context': {'meeting', 'conference', 'project', 'deadline'}
        }
        text_lower = text.lower()
        words = set(text_lower.split())
        for category in formal_indicators.values():
            for indicator in category:
                # Multi-word indicators (e.g. 'please advise') can never appear
                # in a set of single tokens, so use a substring check for them.
                if ' ' in indicator:
                    if indicator in text_lower:
                        return True
                elif indicator in words:
                    return True
        return False
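

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes PipelineConfig() can be
# constructed with defaults exposing the fields referenced above (debug,
# max_turns_per_dialogue, max_variations_per_turn, augmentation_factor,
# max_sampled_variations, context_similarity_weight, response_coherence_weight)
# and that a spaCy model is installed for the `nlp` argument.
if __name__ == '__main__':
    import spacy

    augmenter = DialogueAugmenter(spacy.load('en_core_web_sm'), PipelineConfig())
    example = {
        'dialogue_id': 'example_1',
        'turns': [
            {'speaker': 'user', 'text': 'Hello there!'},
            {'speaker': 'assistant', 'text': 'Hi! How can I help you today?'},
        ],
    }
    for aug in augmenter.augment_dialogue(example):
        print(aug['dialogue_id'], [t['text'] for t in aug['turns']])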