# csc525_retrieval_based_chatbot/dialogue_augmenter.py
# (removed scraped GitHub web-UI header lines that made this file unparseable)
from typing import Dict, List
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import re
from pipeline_config import PipelineConfig
from quality_metrics import QualityMetrics
from paraphraser import Paraphraser
from back_translator import BackTranslator
import nlpaug.augmenter.word as naw
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from sklearn.metrics.pairwise import cosine_similarity
class DialogueAugmenter:
"""
Optimized dialogue augmentation with quality control and complexity management.
"""
    def __init__(self, nlp, config: PipelineConfig):
        """
        Initialize the augmenter: quality scorer, embedding model, augmenters,
        caches, and GPU memory settings.

        Args:
            nlp: Language-processing pipeline object (stored; not used in this file).
            config (PipelineConfig): Pipeline settings (thresholds, weights,
                limits, and the ``debug`` flag used throughout).
        """
        self.nlp = nlp
        self.config = config
        self.quality_metrics = QualityMetrics(config)
        # Universal Sentence Encoder used for similarity/coherence scoring.
        self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
        # Advanced augmentation techniques
        self.paraphraser = Paraphraser()
        self.back_translator = BackTranslator()
        # Basic augmentation techniques
        self.word_augmenter = naw.SynonymAug(aug_src='wordnet')
        self.spelling_augmenter = naw.SpellingAug()
        # 'advanced' augmenters are tried first; 'basic' ones are (name, augmenter)
        # pairs used as a fallback in _generate_variations_progressive.
        self.augmenters = {
            'advanced': [self.paraphraser, self.back_translator],
            'basic': [
                ('synonym', self.word_augmenter),
                ('spelling', self.spelling_augmenter)
            ]
        }
        # Initialize cache
        # In-memory caches keyed by text.
        self.embedding_cache = {}
        # NOTE(review): perplexity_cache is never read in this chunk — confirm use.
        self.perplexity_cache = {}
        # Compile regex patterns
        # NOTE(review): compiled but not referenced in this chunk — confirm use.
        self.spelling_pattern = re.compile(r'[a-zA-Z]{3,}')
        # GPU memory management
        gpus = tf.config.list_physical_devices('GPU')
        if gpus:
            try:
                # Grow GPU memory on demand instead of pre-allocating all of it.
                for gpu in gpus:
                    tf.config.experimental.set_memory_growth(gpu, True)
            except RuntimeError as e:
                # set_memory_growth must run before GPUs are initialized.
                print(e)
@lru_cache(maxsize=1024)
def _compute_embedding(self, text: str) -> np.ndarray:
"""Cached computation of text embedding"""
return self.use_model([text])[0].numpy()
def _compute_batch_embeddings(self, texts: List[str]) -> np.ndarray:
"""Compute embeddings for multiple texts at once"""
return self.use_model(texts).numpy()
def _quick_quality_check(self, variation: str, original: str) -> bool:
"""
Simplified preliminary quality check with minimal standards
"""
if self.config.debug:
print(f"\nQuick check for variation: {variation}")
# Only reject if length is extremely different
orig_len = len(original.split())
var_len = len(variation.split())
# For very short texts (1-3 words), allow more variation
if orig_len <= 3:
if var_len > orig_len * 4: # Allow up to 4x length for short texts
if self.config.debug:
print(f"Failed length check (short text): {var_len} vs {orig_len}")
return False
else:
if var_len > orig_len * 3: # Allow up to 3x length for longer texts
if self.config.debug:
print(f"Failed length check (long text): {var_len} vs {orig_len}")
return False
# Basic content check - at least one word in common (excluding stop words)
stop_words = {'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'is', 'are'}
orig_words = set(w.lower() for w in original.split() if w.lower() not in stop_words)
var_words = set(w.lower() for w in variation.split() if w.lower() not in stop_words)
if not orig_words.intersection(var_words):
if self.config.debug:
print("Failed content check: no content words in common")
return False
if self.config.debug:
print("Passed all quick checks")
return True
def _compute_metrics_parallel(self, original: str, candidates: List[str]) -> List[Dict[str, float]]:
"""Compute quality metrics for multiple candidates in parallel"""
with ThreadPoolExecutor(max_workers=4) as executor:
futures = [
executor.submit(self.quality_metrics.compute_metrics, original, candidate)
for candidate in candidates
]
return [future.result() for future in futures]
    def _filter_variations_batch(self, variations: List[str], context: List[str], original_turn: str) -> List[str]:
        """
        Filter variations using batched computations with detailed logging.

        Pipeline:
          1. Very short originals (< 3 words) bypass filtering entirely and
             use the predefined short-text variations instead.
          2. First turn (no context): all candidates pass the preliminary stage.
          3. Otherwise: per-candidate quick sanity check, then a single batched
             embedding call scores each survivor against the most recent
             context turn; a candidate is kept if EITHER its combined
             similarity score or its coherence clears a (deliberately lenient)
             threshold.

        Args:
            variations (List[str]): Candidate texts for the current turn.
            context (List[str]): Previous original turn texts, oldest first.
            original_turn (str): Unmodified text of the current turn.

        Returns:
            List[str]: Accepted variations, capped at
            ``config.max_variations_per_turn``.
        """
        if not variations:
            return []
        if self.config.debug:
            print(f"\nStarting filtration of {len(variations)} variations")
            print(f"Context length: {len(context)}")
            print(f"Original turn: {original_turn}")
        words = original_turn.split()
        if len(words) < 3:
            if self.config.debug:
                print("Short text detected, using predefined variations")
            # Curated variations for short texts; unwrap back to plain strings.
            short_text_variations = self._augment_short_text({'text': original_turn, 'speaker': ''})
            return [var['text'] for var in short_text_variations]
        # If this is the first turn (no context), be more lenient
        if not context:
            preliminary_filtered = variations
            if self.config.debug:
                print("First turn - skipping preliminary filtering")
        else:
            # Quick preliminary filtering against original turn
            preliminary_filtered = []
            for var in variations:
                passed = self._quick_quality_check(var, original_turn)
                if self.config.debug:
                    print(f"\nVariation: {var}")
                    print(f"Passed quick check: {passed}")
                if passed:
                    preliminary_filtered.append(var)
        if self.config.debug:
            print(f"Variations after quick check: {len(preliminary_filtered)}")
        if not preliminary_filtered:
            return []
        # Only use last turn for coherence
        recent_context = [context[-1]] if context else []
        context_text = ' '.join(recent_context) if recent_context else ''
        # Even more lenient thresholds
        min_similarity = 0.1 # Further reduced
        min_coherence = 0.05 # Further reduced
        if context_text:
            if self.config.debug:
                print(f"\nContext text: {context_text}")
            # One batched embedding call covers the context plus every candidate.
            all_texts = [context_text] + preliminary_filtered
            all_embeddings = self._compute_batch_embeddings(all_texts)
            context_embedding = all_embeddings[0]
            variation_embeddings = all_embeddings[1:]
            # Vectorized similarity computation
            context_similarities = cosine_similarity([context_embedding], variation_embeddings)[0]
            # Response coherence check
            if recent_context:
                prev_embedding = self._compute_embedding(recent_context[-1])
                response_coherence = cosine_similarity([prev_embedding], variation_embeddings)[0]
            else:
                # No previous turn: treat every candidate as fully coherent.
                response_coherence = np.ones_like(context_similarities)
            # Combined scoring with detailed logging
            filtered_variations = []
            for i, (variation, sim, coh) in enumerate(zip(
                preliminary_filtered, context_similarities, response_coherence)):
                # Use absolute values for scoring
                combined_score = (
                    self.config.context_similarity_weight * abs(sim) +
                    self.config.response_coherence_weight * abs(coh)
                )
                if self.config.debug:
                    print(f"\nVariation: {variation}")
                    print(f"Context similarity: {sim:.3f}")
                    print(f"Response coherence: {coh:.3f}")
                    print(f"Combined score: {combined_score:.3f}")
                # Accept if EITHER score is good enough
                if (combined_score >= min_similarity or abs(coh) >= min_coherence):
                    filtered_variations.append(variation)
                    if self.config.debug:
                        print("ACCEPTED")
                else:
                    if self.config.debug:
                        print("REJECTED")
                # If we have enough variations, stop
                if len(filtered_variations) >= self.config.max_variations_per_turn:
                    break
        else:
            # No usable context: keep the first N survivors as-is.
            filtered_variations = preliminary_filtered[:self.config.max_variations_per_turn]
        if self.config.debug:
            print(f"\nFinal filtered variations: {len(filtered_variations)}")
        return filtered_variations
    def _generate_variations_progressive(self, text: str, needed: int) -> List[str]:
        """
        Generate up to ``needed`` unique variations of ``text``.

        Tries the 'advanced' augmenters first (paraphraser, then back
        translator) and falls back to the 'basic' nlpaug augmenters (synonym,
        spelling) only if the advanced ones did not produce enough unique
        candidates. Spelling augmentation is skipped for technical/formal
        text. Augmenter failures are logged and skipped (best-effort).

        Args:
            text (str): Source text to vary.
            needed (int): Target number of unique variations.

        Returns:
            List[str]: Unique, non-empty variations that differ from ``text``;
            may be shorter than ``needed`` (or empty) if augmenters fail.
        """
        variations = set()
        if self.config.debug:
            print(f"\nAttempting to generate {needed} variations for text: {text}")
        # Try advanced augmenters first
        for augmenter in self.augmenters['advanced']:
            if len(variations) >= needed:
                break
            try:
                if isinstance(augmenter, Paraphraser):
                    if self.config.debug:
                        print("Trying paraphrase augmentation...")
                    # Ask only for as many sequences as are still missing.
                    new_vars = augmenter.paraphrase(text, num_return_sequences=needed-len(variations))
                    if self.config.debug:
                        print(f"Paraphraser generated {len(new_vars)} variations")
                else:
                    if self.config.debug:
                        print("Trying back translation...")
                    new_vars = [augmenter.back_translate(text)]
                    if self.config.debug:
                        print(f"Back translator generated {len(new_vars)} variations")
                # Drop empty strings and exact copies of the input.
                valid_vars = [v for v in new_vars if v.strip() and v != text]
                variations.update(valid_vars)
                if self.config.debug:
                    print(f"Current unique variations: {len(variations)}")
            except Exception as e:
                print(f"Error in advanced augmentation: {str(e)}")
                continue
        # Try basic augmenters if needed
        if len(variations) < needed:
            if self.config.debug:
                print("Not enough variations, trying basic augmenters...")
            for aug_type, augmenter in self.augmenters['basic']:
                if len(variations) >= needed:
                    break
                try:
                    if aug_type == 'spelling' and self._is_technical_or_formal_text(text):
                        if self.config.debug:
                            print("Skipping spelling augmentation for technical text")
                        continue
                    if self.config.debug:
                        print(f"Trying {aug_type} augmentation...")
                    new_vars = augmenter.augment(text, n=2)
                    # nlpaug may return a list or a bare string depending on version.
                    if isinstance(new_vars, list):
                        valid_vars = [v for v in new_vars if v.strip() and v != text]
                        variations.update(valid_vars)
                    else:
                        if new_vars.strip() and new_vars != text:
                            variations.add(new_vars)
                    if self.config.debug:
                        print(f"After {aug_type}, total variations: {len(variations)}")
                except Exception as e:
                    print(f"Error in {aug_type} augmentation: {str(e)}")
                    continue
        variations_list = list(variations)
        if self.config.debug:
            print(f"Final number of variations generated: {len(variations_list)}")
        if not variations_list:
            print("WARNING: No variations were generated!")
        return variations_list
    def augment_dialogue(self, dialogue: Dict) -> List[Dict]:
        """
        Create augmented versions of the dialogue with optimized processing.

        Per turn, variations are generated progressively and filtered against
        the running context of ORIGINAL turns; whole augmented dialogues are
        then assembled by sampling turn-variation combinations. The original
        dialogue is always the first element of the result.

        Args:
            dialogue (Dict): Must contain 'dialogue_id' and 'turns' (each turn
                a dict with 'speaker' and 'text').

        Returns:
            List[Dict]: Original dialogue (id suffixed '_original') followed by
            up to ``config.augmentation_factor`` augmented dialogues.

        NOTE(review): when the dialogue is too long, ``dialogue['turns']`` is
        reassigned below — this mutates the caller's dict in place; confirm
        callers do not rely on the full turn list afterwards.
        """
        # Early dialogue length check
        original_length = len(dialogue['turns'])
        if original_length > self.config.max_turns_per_dialogue:
            if self.config.debug:
                print(f"Truncating dialogue from {original_length} to {self.config.max_turns_per_dialogue} turns")
            dialogue['turns'] = dialogue['turns'][:self.config.max_turns_per_dialogue]
        turn_variations = []
        context = []
        # Process each turn with progressive generation
        for turn in dialogue['turns']:
            original_text = turn['text'] # Store original turn text
            variations = self._generate_variations_progressive(
                original_text,
                self.config.max_variations_per_turn
            )
            # Batch filter variations with original text
            filtered_variations = self._filter_variations_batch(
                variations,
                context,
                original_text # Pass the original turn text
            )
            # Create turn variations with speaker info
            turn_vars = [{'speaker': turn['speaker'], 'text': v} for v in filtered_variations]
            if self.config.debug:
                print(f"Turn {len(turn_variations)}: Generated {len(turn_vars)} variations")
            turn_variations.append(turn_vars)
            # Context tracks the ORIGINAL turn texts, never the variations.
            context.append(original_text)
        # Generate combinations with sampling
        augmented_dialogues = self._generate_dialogue_combinations(
            dialogue['dialogue_id'],
            turn_variations
        )
        # Add original dialogue
        result = [{
            'dialogue_id': f"{dialogue['dialogue_id']}_original",
            'turns': dialogue['turns']
        }]
        # Add unique augmentations
        result.extend(augmented_dialogues[:self.config.augmentation_factor])
        if self.config.debug:
            print(f"Generated {len(result)-1} unique augmented dialogues")
        return result
    def _generate_dialogue_combinations(self, dialogue_id: str, turn_variations: List[List[Dict]]) -> List[Dict]:
        """
        Assemble whole augmented dialogues by sampling one variation per turn.

        Depth-first recursion over turns: at each depth the turn's variations
        are shuffled (via np.random, so the module-global seed applies) and
        only the first ``config.max_sampled_variations`` branches are explored.
        Completed dialogues are deduplicated by a " | "-joined text
        fingerprint, and generation stops as soon as
        ``config.augmentation_factor`` dialogues exist.

        Args:
            dialogue_id (str): Base id; augmented ids get an '_aug_<n>' suffix.
            turn_variations (List[List[Dict]]): Per-turn candidate turns
                ({'speaker', 'text'}). An empty inner list prunes all branches
                through that turn.

        Returns:
            List[Dict]: Augmented dialogues, or [] if generation raised.
        """
        augmented_dialogues = []
        used_combinations = set()
        def generate_dialogues(current_turns=None, turn_index=0):
            if current_turns is None:
                current_turns = []
            if len(augmented_dialogues) >= self.config.augmentation_factor:
                return
            if turn_index == len(turn_variations):
                # Complete dialogue: keep only previously-unseen combinations.
                dialogue_fingerprint = " | ".join(turn['text'] for turn in current_turns)
                if dialogue_fingerprint not in used_combinations:
                    used_combinations.add(dialogue_fingerprint)
                    augmented_dialogues.append({
                        'dialogue_id': f"{dialogue_id}_aug_{len(augmented_dialogues)}",
                        # copy() because current_turns keeps mutating below.
                        'turns': current_turns.copy()
                    })
                return
            variations = list(turn_variations[turn_index])
            np.random.shuffle(variations)
            for variation in variations[:self.config.max_sampled_variations]:
                if len(augmented_dialogues) >= self.config.augmentation_factor:
                    return
                current_turns.append(variation)
                generate_dialogues(current_turns, turn_index + 1)
                current_turns.pop()  # backtrack for the next sibling branch
        try:
            generate_dialogues()
        except Exception as e:
            # Best-effort: any failure discards partial results entirely.
            print(f"Error in dialogue generation: {str(e)}")
            return []
        return augmented_dialogues
def _is_dialogue_duplicate(self, dialogue1: Dict, dialogue2: Dict) -> bool:
"""
Check if two dialogues are duplicates.
"""
text1 = " ".join(turn['text'] for turn in dialogue1['turns'])
text2 = " ".join(turn['text'] for turn in dialogue2['turns'])
return text1 == text2
# def _augment_turn(self, turn: Dict, context: List[str]) -> List[Dict]:
# """
# Generate augmented versions of the turn using multiple strategies.
# """
# text = turn['text']
# words = text.split()
# # Special handling for very short texts
# if len(words) < 3:
# return self._augment_short_text(turn)
# all_variations = set()
# # Advanced augmentations (paraphrase and back-translation)
# for augmenter in self.augmenters['advanced']:
# try:
# if isinstance(augmenter, Paraphraser):
# variations = augmenter.paraphrase(text)
# all_variations.update(variations)
# elif isinstance(augmenter, BackTranslator):
# aug_text = augmenter.back_translate(text)
# if aug_text:
# all_variations.add(aug_text)
# except Exception as e:
# print(f"Error in advanced augmentation: {str(e)}")
# continue
# # Basic nlpaug augmentations
# for aug_type, augmenter in self.augmenters['basic']:
# try:
# if aug_type == 'spelling' and self._is_technical_or_formal_text(text):
# continue
# aug_texts = augmenter.augment(text, n=2)
# if isinstance(aug_texts, list):
# all_variations.update(aug_texts)
# else:
# all_variations.add(aug_texts)
# except Exception as e:
# print(f"Error in {aug_type} augmentation: {str(e)}")
# continue
# # Remove exact duplicates and empty strings
# augmented_texts = [t for t in list(all_variations) if t.strip()]
# # Apply context filtering
# if context:
# augmented_texts = self._filter_by_context(augmented_texts, context)
# print(f"After context filtering: {len(augmented_texts)} variations")
# # Select best variations
# best_variations = self._select_best_augmentations(
# text,
# augmented_texts,
# num_to_select=self.config.augmentation_factor,
# min_quality_score=0.7
# )
# # Create variations with speaker info
# variations = [{'speaker': turn['speaker'], 'text': text} for text in best_variations]
# return variations
    def _augment_short_text(self, turn: Dict) -> List[Dict]:
        """
        Special handling for very short texts with predefined variations.

        Looks the lowercased, de-punctuated text up against a table of common
        conversational phrases; if nothing matches, falls back to simple
        punctuation/capitalization variants of the input. Candidates are then
        gated by a QualityMetrics score with a lenient threshold.

        Args:
            turn (Dict): Original dialogue turn ({'speaker', 'text'}).

        Returns:
            List[Dict]: Variations carrying the original speaker, capped at
            ``config.augmentation_factor``; falls back to the original text
            if no candidate passes the quality gate.
        """
        text = turn['text']
        common_variations = {
            'goodbye': [
                'Bye!', 'Farewell!', 'See you!', 'Take care!',
                'Goodbye!', 'Bye for now!', 'Until next time!'
            ],
            'hello': [
                'Hi!', 'Hey!', 'Hello!', 'Greetings!',
                'Good day!', 'Hi there!', 'Hello there!'
            ],
            'yes': [
                'Yes!', 'Correct!', 'Indeed!', 'Absolutely!',
                'That\'s right!', 'Definitely!', 'Sure!'
            ],
            'no': [
                'No!', 'Nope!', 'Not at all!', 'Negative!',
                'Unfortunately not!', 'I\'m afraid not!'
            ],
            'thanks': [
                'Thank you!', 'Thanks a lot!', 'Many thanks!',
                'I appreciate it!', 'Thank you so much!'
            ],
            'ok': [
                'Okay!', 'Alright!', 'Sure!', 'Got it!',
                'Understood!', 'Fine!', 'Great!', 'Perfect!',
                'That works!', 'Sounds good!'
            ],
            'good': [
                'Great!', 'Excellent!', 'Perfect!', 'Wonderful!',
                'Fantastic!', 'Amazing!', 'Terrific!'
            ]
        }
        # Try to find matching variations
        text_lower = text.lower().rstrip('!.,?')
        variations = []
        # Check if text matches any of our predefined categories
        # NOTE(review): substring match in either direction is fuzzy — e.g.
        # 'no' also matches inside longer inputs; confirm this is intended.
        for key, predefined_vars in common_variations.items():
            if key in text_lower or text_lower in key:
                variations.extend(predefined_vars)
        # If no predefined variations found, generate simple variants
        if not variations:
            # Add punctuation variations
            variations = [
                text.rstrip('!.,?') + '!',
                text.rstrip('!.,?') + '.',
                text.rstrip('!.,?')
            ]
            # Add capitalization variations
            variations.extend([
                v.capitalize() for v in variations
                if v.capitalize() not in variations
            ])
        # Filter variations for uniqueness and quality
        unique_variations = list(set(variations))
        quality_variations = []
        for var in unique_variations:
            metrics = self.quality_metrics.compute_metrics(text, var)
            # NOTE(review): weights sum to 1.05, not 1.0 — confirm intended.
            quality_score = (
                0.35 * metrics['semantic_similarity'] +
                0.30 * (1.0 - metrics['perplexity'] / 100) +
                0.15 * (1.0 - metrics['grammar_errors'] / 10) +
                0.15 * metrics['content_preservation'] +
                0.10 * metrics['type_token_ratio']
            )
            # More lenient quality threshold for short texts
            if quality_score >= 0.5: # Lower threshold for short texts
                quality_variations.append(var)
        # Ensure we have at least some variations
        if not quality_variations:
            quality_variations = [text]
        # Return the variations with original speaker
        return [{'speaker': turn['speaker'], 'text': v} for v in quality_variations[:self.config.augmentation_factor]]
def _is_technical_or_formal_text(self, text: str) -> bool:
"""
Check if text is formal/technical and shouldn't have spelling variations.
"""
formal_indicators = {
'technical_terms': {'api', 'config', 'database', 'server', 'system'},
'formal_phrases': {'please advise', 'regarding', 'furthermore', 'moreover'},
'professional_context': {'meeting', 'conference', 'project', 'deadline'}
}
text_lower = text.lower()
words = set(text_lower.split())
for category in formal_indicators.values():
if words.intersection(category):
return True
return False
# def _filter_by_context(self, variations: List[str], context: List[str]) -> List[str]:
# """
# Filter variations based on conversation context using config parameters.
# """
# # Manage context window using config
# recent_context = context[-self.config.context_window_size:] if len(context) > self.config.context_window_size else context
# filtered_variations = []
# context_embedding = self.use_model([' '.join(recent_context)])[0].numpy()
# prev_turn = recent_context[-1] if recent_context else ''
# for variation in variations:
# var_embedding = self.use_model([variation])[0].numpy()
# # Overall context similarity
# context_similarity = cosine_similarity([context_embedding], [var_embedding])[0][0]
# # Direct response coherence
# response_coherence = 1.0
# if prev_turn:
# prev_embedding = self.use_model([prev_turn])[0].numpy()
# response_coherence = cosine_similarity([prev_embedding], [var_embedding])[0][0]
# # Use weights from config
# combined_similarity = (
# self.config.context_similarity_weight * context_similarity +
# self.config.response_coherence_weight * response_coherence
# )
# if (combined_similarity >= self.config.semantic_similarity_threshold and
# response_coherence >= self.config.min_response_coherence):
# filtered_variations.append(variation)
# if self.config.debug:
# print(f"Accepted variation: {variation}")
# print(f"Context similarity: {context_similarity:.3f}")
# print(f"Response coherence: {response_coherence:.3f}")
# print(f"Combined score: {combined_similarity:.3f}\n")
# else:
# if self.config.debug:
# print(f"Rejected variation: {variation}")
# print(f"Combined score {combined_similarity:.3f} below threshold "
# f"{self.config.semantic_similarity_threshold}")
# print(f"Response coherence {response_coherence:.3f} below threshold "
# f"{self.config.min_response_coherence}\n")
# return filtered_variations or variations # Fallback to original
# def _select_best_augmentations(self, original: str, candidates: List[str], used_variations: set = None,
# num_to_select: int = 3, min_quality_score: float = 0.7) -> List[str]:
# """
# Select the best augmentations using a quality score.
# Args:
# original (str): The original text
# candidates (List[str]): List of candidate augmented texts
# used_variations (set): Set of already used variations
# num_to_select (int): Number of variations to select
# min_quality_score (float): Minimum quality score threshold
# """
# if used_variations is None:
# used_variations = set()
# candidates = [c for c in candidates if c.strip()]
# # Skip short text
# if len(original.split()) < 3:
# print(f"Text too short for augmentation: {original}")
# return [original]
# scored_candidates = []
# for candidate in candidates:
# if candidate in used_variations:
# continue
# metrics = self.quality_metrics.compute_metrics(original, candidate)
# # Add contextual penalty for inappropriate audience terms
# audience_terms = {'everyone', 'everybody', 'folks', 'all', 'guys', 'people'}
# has_audience_term = any(term in candidate.lower() for term in audience_terms)
# audience_penalty = 0.2 if has_audience_term else 0.0
# # Weighted quality score
# quality_score = (
# 0.40 * metrics['semantic_similarity'] + # Semantic preservation
# 0.25 * (1.0 - metrics['perplexity'] / 100) + # Fluency
# 0.15 * (1.0 - metrics['grammar_errors'] / 10) + # Grammar
# 0.15 * metrics['content_preservation'] + # Content preservation
# 0.05 * metrics['type_token_ratio'] # Lexical diversity
# )
# quality_score -= audience_penalty
# if (metrics['semantic_similarity'] < 0.5 or # Reject on semantic threshold miss
# metrics['rouge1_f1'] < 0.2): # Enforce minimum lexical overlap
# continue
# # Bonus points for:
# # Length similarity to original
# if 0.75 <= metrics['length_ratio'] <= 1.25:
# quality_score += 0.05
# # Correct grammar
# if metrics['grammar_errors'] == 0:
# quality_score += 0.025
# print(f"Candidate: {candidate}")
# print(f"Quality score: {quality_score:.2f}, Metrics: {metrics}")
# # Consider the augmentationif meets basic quality threshold
# if quality_score >= min_quality_score:
# print('Candidate accepted\n')
# scored_candidates.append((candidate, quality_score, metrics))
# else:
# print('Candidate rejected\n')
# # Sort by quality score with small random factor for diversity
# scored_candidates.sort(key=lambda x: x[1], reverse=True)
# selected = []
# for candidate, score, metrics in scored_candidates:
# # Check diversity against already selected
# if len(selected) == 0:
# selected.append(candidate)
# continue
# # Compute average similarity to already selected
# avg_similarity = np.mean([
# self.quality_metrics.compute_semantic_similarity(candidate, prev)
# for prev in selected
# ])
# # Add if sufficiently different (similarity < 0.98)
# if avg_similarity < 0.98:
# selected.append(candidate)
# if len(selected) >= num_to_select:
# break
# return selected