# csc525_retrieval_based_chatbot / dialogue_augmenter.py
# JoeArmani
# summarization, reranker, environment setup, and response quality checker
# f7b283c
# raw
# history blame
# 30.1 kB
from typing import Dict, List
import numpy as np
import torch
import tensorflow as tf
import tensorflow_hub as hub
from pipeline_config import PipelineConfig
from quality_metrics import QualityMetrics
from paraphraser import Paraphraser
import nlpaug.augmenter.word as naw
from functools import lru_cache
from sklearn.metrics.pairwise import cosine_similarity
class DialogueAugmenter:
    """
    Optimized dialogue augmentation with quality control and complexity management.

    For every turn of a dialogue, candidate variations are generated (paraphraser
    first, WordNet synonym replacement as a fallback), filtered against the
    original turn and the preceding turn, and then recombined into whole
    augmented dialogues that are ranked by their average turn-level embedding
    similarity to the original dialogue.
    """

    # Stop words ignored by the content-overlap check in _quick_quality_check.
    _STOP_WORDS = frozenset({
        'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
        'is', 'are', 'that', 'this', 'will', 'can',
    })

    # Punctuation stripped from the edges of tokens before comparing words,
    # so that e.g. "hotel," and "hotel" count as the same word.
    _EDGE_PUNCT = ".,!?;:\"'()[]{}"

    # Curated variations for very short turns, keyed by category.
    _SHORT_TEXT_VARIATIONS = {
        'goodbye': [
            'Bye!', 'Farewell!', 'See you!', 'Take care!',
            'Goodbye!', 'Bye for now!', 'Until next time!'
        ],
        'hello': [
            'Hi!', 'Hey!', 'Hello!', 'Greetings!',
            'Good day!', 'Hi there!', 'Hello there!'
        ],
        'yes': [
            'Yes!', 'Correct!', 'Indeed!', 'Absolutely!',
            'That\'s right!', 'Definitely!', 'Sure!'
        ],
        'no': [
            'No!', 'Nope!', 'Not at all!', 'Negative!',
            'Unfortunately not!', 'I\'m afraid not!'
        ],
        'thanks': [
            'Thank you!', 'Thanks a lot!', 'Many thanks!',
            'I appreciate it!', 'Thank you so much!'
        ],
        'ok': [
            'Okay!', 'Alright!', 'Sure!', 'Got it!',
            'Understood!', 'Fine!', 'Great!', 'Perfect!',
            'That works!', 'Sounds good!'
        ],
        'good': [
            'Great!', 'Excellent!', 'Perfect!', 'Wonderful!',
            'Fantastic!', 'Amazing!', 'Terrific!'
        ],
    }

    # Normalised (lower-case, punctuation-stripped) words/phrases that select
    # each category above. Matching is by whole phrase or whole word, never by
    # substring, so "I know" does not match "no" and "goodbye" does not match "good".
    _SHORT_TEXT_ALIASES = {
        'goodbye': ('goodbye', 'bye', 'good bye'),
        'hello': ('hello', 'hi', 'hey'),
        'yes': ('yes', 'yeah', 'yep'),
        'no': ('no', 'nope'),
        'thanks': ('thanks', 'thank you', 'thx'),
        'ok': ('ok', 'okay', 'alright'),
        'good': ('good',),
    }

    def __init__(self, nlp, config: "PipelineConfig"):
        """
        Detect hardware, then load the embedding and augmentation models.

        Args:
            nlp: NLP pipeline object; stored as ``self.nlp`` (not used elsewhere
                in this class).
            config: Pipeline configuration. Attributes read by this class:
                ``debug``, ``max_variations_per_turn``, ``max_turns_per_dialogue``,
                ``augmentation_factor``, ``max_sampled_variations``,
                ``context_similarity_weight`` and ``response_coherence_weight``.
        """
        self.nlp = nlp
        self.config = config

        # Detect hardware and set appropriate batch sizes and optimization strategy
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.use_gpu = torch.cuda.is_available()
        if self.config.debug:
            print(f"Using device: {self.device}")
            if self.use_gpu:
                print(f"GPU Device: {torch.cuda.get_device_name(0)}")

        # TensorFlow only accepts memory-growth settings before its GPUs are
        # initialised, and loading TF models (hub.load below, possibly also
        # QualityMetrics) initialises them -- so this must run first.
        if self.use_gpu:
            self._enable_tf_memory_growth()

        self.quality_metrics = QualityMetrics(config)
        self.semantic_similarity_threshold = 0.75

        # text -> embedding cache shared by the single and batched encoders
        self.embedding_cache: Dict[str, np.ndarray] = {}

        # Universal Sentence Encoder used for every similarity computation
        self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')

        # Initialize augmentation models based on hardware
        self._initialize_augmentation_models()

    @staticmethod
    def _enable_tf_memory_growth() -> None:
        """Let TensorFlow allocate GPU memory on demand instead of all up front."""
        gpus = tf.config.list_physical_devices('GPU')
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as e:
            # Raised when the GPUs were already initialised by earlier TF work.
            print(e)

    def _initialize_augmentation_models(self):
        """
        Create the paraphraser and the WordNet synonym augmenter.

        The paraphraser's model is moved to ``self.device`` when CUDA is
        available. ``self.augmenters`` groups them into 'advanced' (tried first)
        and 'basic' ((name, augmenter) pairs used only as a fallback).
        """
        self.paraphraser = Paraphraser()
        if self.use_gpu:
            self.paraphraser.model = self.paraphraser.model.to(self.device)

        self.word_augmenter = naw.SynonymAug(aug_src='wordnet')

        self.augmenters = {
            'advanced': [self.paraphraser],
            'basic': [('synonym', self.word_augmenter)],
        }

    def _compute_embedding(self, text: str) -> np.ndarray:
        """
        Return the sentence embedding of ``text``, using ``self.embedding_cache``.

        The cache is a per-instance dict (no ``lru_cache`` on the method, which
        would keep the instance alive and duplicate this cache).
        """
        embedding = self.embedding_cache.get(text)
        if embedding is None:
            embedding = self.use_model([text])[0].numpy()
            self.embedding_cache[text] = embedding
        return embedding

    def _compute_batch_embeddings(self, texts: List[str]) -> np.ndarray:
        """
        Return embeddings for ``texts`` (one row per input, in order).

        Only texts missing from the cache are encoded, each distinct text once,
        in a single model call; results are stored in ``self.embedding_cache``.
        """
        # dict.fromkeys removes duplicates while keeping first-seen order
        uncached = list(dict.fromkeys(t for t in texts if t not in self.embedding_cache))
        if uncached:
            embeddings = self.use_model(uncached).numpy()
            for text, embedding in zip(uncached, embeddings):
                self.embedding_cache[text] = embedding
        return np.array([self.embedding_cache[t] for t in texts])

    @classmethod
    def _content_words(cls, text: str) -> set:
        """Lower-cased words of ``text`` with edge punctuation and stop words removed."""
        tokens = (w.strip(cls._EDGE_PUNCT).lower() for w in text.split())
        return {w for w in tokens if w and w not in cls._STOP_WORDS}

    def _quick_quality_check(self, variation: str, original: str) -> bool:
        """
        Cheap pre-filter for a variation, tuned to keep reasonable pass rates.

        Rejects a variation when it is too long (more than 3x the original for
        originals of <= 3 words, more than 2x otherwise) or, for originals of
        5+ words, when fewer than 20% of the original's content words survive.

        Returns:
            True if the variation passes.
        """
        debug = self.config.debug
        if debug:
            print(f"\nQuick check for variation: {variation}")

        orig_len = len(original.split())
        var_len = len(variation.split())

        # Very short texts (<= 3 words) are allowed more length variation
        if orig_len <= 3:
            if var_len > orig_len * 3:
                if debug:
                    print(f"Failed length check (short text): {var_len} vs {orig_len}")
                return False
        elif var_len > orig_len * 2:
            if debug:
                print(f"Failed length check (long text): {var_len} vs {orig_len}")
            return False

        if orig_len < 5:
            if debug:
                print("Short turn detected (<5 words), skipping content overlap check")
        else:
            orig_words = self._content_words(original)
            if not orig_words:
                # Nothing but stop words: there is no content to preserve.
                if debug:
                    print("No content words in original, skipping content overlap check")
            else:
                var_words = self._content_words(variation)
                content_overlap = len(orig_words & var_words) / len(orig_words)
                if content_overlap < 0.2:
                    if debug:
                        print(f"Failed content check: overlap {content_overlap:.2f}")
                    return False

        if debug:
            print("Passed all quick checks")
        return True

    def _filter_variations_batch(self, variations: List[str], context: List[str], original_turn: str) -> List[str]:
        """
        Filter candidate variations of one turn, with detailed debug logging.

        Stages:
          1. Turns with fewer than 3 words skip filtering entirely and receive
             curated variations from _augment_short_text.
          2. _quick_quality_check against the original (skipped on the first
             turn, i.e. when ``context`` is empty).
          3. Embedding similarity to the original turn (threshold 0.75, or
             0.7 for turns under 5 words).
          4. Coherence with the previous turn, when there is one.

        Args:
            variations: Candidate rewrites of ``original_turn``.
            context: Original texts of the preceding turns, oldest first.
            original_turn: The turn being augmented.

        Returns:
            Accepted variations: at most ``config.max_variations_per_turn``, except
            on the short-text path, which is capped at ``config.augmentation_factor``.
        """
        debug = self.config.debug
        if debug:
            print(f"\nStarting filtration of {len(variations)} variations")
            print(f"Context length: {len(context)}")
            print(f"Original turn: {original_turn}")

        words = original_turn.split()

        # The curated short-text variations do not depend on the generated
        # candidates, so this path applies even when nothing was generated.
        if len(words) < 3:
            if debug:
                print("Short text detected, using predefined variations")
            short_text_variations = self._augment_short_text({'text': original_turn, 'speaker': ''})
            return [var['text'] for var in short_text_variations]

        if not variations:
            return []

        # Turns under 5 words get slightly more lenient thresholds below
        is_very_short = len(words) < 5

        # Stage 2: be lenient on the first turn (no context)
        if not context:
            preliminary_filtered = list(variations)
            if debug:
                print("First turn - skipping preliminary filtering")
        else:
            preliminary_filtered = []
            for var in variations:
                passed = self._quick_quality_check(var, original_turn)
                if debug:
                    print(f"\nVariation: {var}")
                    print(f"Passed quick check: {passed}")
                if passed:
                    preliminary_filtered.append(var)
            if debug:
                print(f"Variations after quick check: {len(preliminary_filtered)}")

        if not preliminary_filtered:
            return []

        # Stage 3: semantic similarity to the original turn
        original_embedding = self._compute_embedding(original_turn)
        variation_embeddings = self._compute_batch_embeddings(preliminary_filtered)
        sims = cosine_similarity([original_embedding], variation_embeddings)[0]

        sem_threshold = self.semantic_similarity_threshold
        if is_very_short:
            sem_threshold = max(0.7, self.semantic_similarity_threshold - 0.05)

        refined_filtered = []
        for var, sim in zip(preliminary_filtered, sims):
            if sim >= sem_threshold:
                refined_filtered.append(var)
            elif debug:
                print(f"Variation '{var}' discarded due to low semantic similarity: {sim:.3f}")

        if not refined_filtered:
            return []

        # Stage 4: only the last turn is used as context for coherence
        context_text = context[-1] if context else ''
        if context_text:
            filtered_variations = self._filter_by_context(refined_filtered, context_text, is_very_short)
        else:
            filtered_variations = refined_filtered[:self.config.max_variations_per_turn]

        if debug:
            print(f"\nFinal filtered variations: {len(filtered_variations)}")
        return filtered_variations

    def _filter_by_context(self, candidates: List[str], context_text: str, is_very_short: bool) -> List[str]:
        """
        Keep candidates coherent with the previous turn ``context_text``.

        A candidate is accepted if its weighted combined score reaches the
        similarity threshold OR its coherence alone reaches the coherence
        threshold. Both thresholds are relaxed for very short turns. Stops after
        ``config.max_variations_per_turn`` acceptances.
        """
        debug = self.config.debug
        if is_very_short:
            min_similarity, min_coherence = 0.05, 0.02
        else:
            min_similarity, min_coherence = 0.1, 0.05

        if debug:
            print(f"\nContext text: {context_text}")

        all_embeddings = self._compute_batch_embeddings([context_text] + candidates)
        context_similarities = cosine_similarity(all_embeddings[:1], all_embeddings[1:])[0]
        # The context is exactly the previous turn, so coherence with that turn
        # is the same cosine similarity; reuse it instead of re-encoding.
        response_coherence = context_similarities

        accepted: List[str] = []
        for variation, sim, coh in zip(candidates, context_similarities, response_coherence):
            combined_score = (
                self.config.context_similarity_weight * abs(sim) +
                self.config.response_coherence_weight * abs(coh)
            )
            if debug:
                print(f"\nVariation: {variation}")
                print(f"Context similarity: {sim:.3f}")
                print(f"Response coherence: {coh:.3f}")
                print(f"Combined score: {combined_score:.3f}")

            # Accept if EITHER score is good enough
            if combined_score >= min_similarity or abs(coh) >= min_coherence:
                accepted.append(variation)
                if debug:
                    print("ACCEPTED")
            elif debug:
                print("REJECTED")

            if len(accepted) >= self.config.max_variations_per_turn:
                break
        return accepted

    def _generate_variations_progressive(self, text: str, needed: int) -> List[str]:
        """
        Generate up to roughly ``needed`` distinct variations of ``text``.

        The paraphraser is tried first (4 beams, no diversity penalty, for
        faithful paraphrases); the synonym augmenter is used only if that falls
        short. Blank results and results equal to ``text`` (ignoring surrounding
        whitespace) are dropped. Errors from an augmenter are printed and the
        next augmenter is tried.

        Returns:
            Unique variations in the order they were first produced.
        """
        debug = self.config.debug
        # dict used as an insertion-ordered set so results are reproducible
        variations: Dict[str, None] = {}
        reference = text.strip()

        def add_all(candidates: List[str]) -> None:
            for candidate in candidates:
                stripped = candidate.strip()
                if stripped and stripped != reference:
                    variations.setdefault(candidate, None)

        if debug:
            print(f"\nAttempting to generate {needed} variations for text: {text}")

        for augmenter in self.augmenters['advanced']:
            if len(variations) >= needed:
                break
            if not isinstance(augmenter, Paraphraser):
                continue
            try:
                if debug:
                    print("Trying paraphrase augmentation...")
                new_vars = augmenter.paraphrase(
                    text,
                    num_return_sequences=needed - len(variations),
                    device=self.device if self.use_gpu else None,
                    num_beams=4,  # few beams for more faithful paraphrases
                    num_beam_groups=1,
                    diversity_penalty=0.0
                )
                if debug:
                    print(f"Paraphraser generated {len(new_vars)} variations")
                add_all(new_vars)
                if debug:
                    print(f"Current unique variations: {len(variations)}")
            except Exception as e:
                print(f"Error in advanced augmentation: {str(e)}")

        if len(variations) < needed:
            if debug:
                print("Not enough variations, trying basic augmenters...")
            for aug_type, augmenter in self.augmenters['basic']:
                if len(variations) >= needed:
                    break
                try:
                    if debug:
                        print(f"Trying {aug_type} augmentation...")
                    new_vars = augmenter.augment(text, n=2)
                    # The augmenter may return either a list or a single string
                    add_all(new_vars if isinstance(new_vars, list) else [new_vars])
                    if debug:
                        print(f"After {aug_type}, total variations: {len(variations)}")
                except Exception as e:
                    print(f"Error in {aug_type} augmentation: {str(e)}")

        variations_list = list(variations)
        if debug:
            print(f"Final number of variations generated: {len(variations_list)}")
        if not variations_list:
            print("WARNING: No variations were generated!")
        return variations_list

    def augment_dialogue(self, dialogue: Dict) -> List[Dict]:
        """
        Create augmented versions of ``dialogue``.

        The caller's dialogue is never modified: when it has more than
        ``config.max_turns_per_dialogue`` turns, a truncated shallow copy is
        augmented instead.

        Args:
            dialogue: Mapping with 'dialogue_id' and 'turns' (a list of dicts
                with at least 'speaker' and 'text').

        Returns:
            The (possibly truncated) original with id ``<id>_original``,
            followed by up to ``config.augmentation_factor`` augmented dialogues.
        """
        debug = self.config.debug
        max_turns = self.config.max_turns_per_dialogue
        turns = dialogue['turns']
        if len(turns) > max_turns:
            if debug:
                print(f"Truncating dialogue from {len(turns)} to {max_turns} turns")
            turns = turns[:max_turns]
        working_dialogue = {**dialogue, 'turns': turns}

        turn_variations: List[List[Dict]] = []
        context: List[str] = []
        for turn in turns:
            original_text = turn['text']
            variations = self._generate_variations_progressive(
                original_text,
                self.config.max_variations_per_turn
            )
            filtered_variations = self._filter_variations_batch(variations, context, original_text)
            turn_vars = [{'speaker': turn['speaker'], 'text': v} for v in filtered_variations]
            if debug:
                print(f"Turn {len(turn_variations)}: Generated {len(turn_vars)} variations")
            turn_variations.append(turn_vars)
            # Context always holds ORIGINAL texts, not the chosen variations
            context.append(original_text)

        augmented_dialogues = self._generate_dialogue_combinations(
            dialogue['dialogue_id'],
            turn_variations,
            working_dialogue
        )

        result = [{
            'dialogue_id': f"{dialogue['dialogue_id']}_original",
            'turns': turns
        }]
        result.extend(augmented_dialogues[:self.config.augmentation_factor])

        if debug:
            print(f"Generated {len(result) - 1} unique augmented dialogues")
        return result

    def _variation_score(self, original: str, variation: str) -> float:
        """
        Score one variation for ranking (higher is better).

        Weighted 0.7 / 0.3 combination of the 'semantic_similarity' and
        'content_preservation' values from ``QualityMetrics.compute_metrics``.
        """
        metrics = self.quality_metrics.compute_metrics(original, variation)
        return metrics['semantic_similarity'] * 0.7 + metrics['content_preservation'] * 0.3

    def _dialogue_quality_score(self, dialogue: Dict, original_dialogue: Dict) -> float:
        """
        Mean turn-by-turn cosine similarity between ``dialogue`` and the original.

        Turns are paired positionally. A pair involving a zero-norm embedding
        scores 0.0 instead of producing NaN. Returns 0.0 when there are no pairs.
        """
        scores = []
        for orig_turn, aug_turn in zip(original_dialogue['turns'], dialogue['turns']):
            emb_orig = self._compute_embedding(orig_turn['text'])
            emb_aug = self._compute_embedding(aug_turn['text'])
            denom = np.linalg.norm(emb_orig) * np.linalg.norm(emb_aug)
            scores.append(float(emb_orig @ emb_aug) / denom if denom else 0.0)
        return float(np.mean(scores)) if scores else 0.0

    def _generate_dialogue_combinations(self, dialogue_id: str, turn_variations: List[List[Dict]], original_dialogue: Dict) -> List[Dict]:
        """
        Combine per-turn variations into whole augmented dialogues.

        For each turn, the augmented variations (those differing from the
        original text) are ranked by _variation_score. Up to
        ``config.max_sampled_variations`` candidates are then taken in this
        priority: the top-ranked variation, a mid-ranked one for diversity, and
        the original turn. The original turn is always used when a turn has no
        variations. It is ranked separately because it would otherwise always
        score highest against itself and crowd out the real variations.

        Candidate dialogues are enumerated depth-first, in shuffled candidate
        order, until twice ``config.augmentation_factor`` have been collected.
        A dialogue is kept only if it is new and at least half (min. 1) of its
        turns are augmented. The best ``config.augmentation_factor`` dialogues by
        _dialogue_quality_score are returned. ``turn_variations`` is not modified.
        """
        over_generate_factor = self.config.augmentation_factor * 2
        original_turns = original_dialogue['turns']
        num_turns = len(turn_variations)
        budget = max(1, self.config.max_sampled_variations)

        # Pre-compute the candidate set for every turn once
        picks_per_turn: List[List[Dict]] = []
        for turn_index, variants in enumerate(turn_variations):
            original_turn = original_turns[turn_index]
            original_text = original_turn['text']
            ranked = [v for v in variants if v['text'] != original_text]
            ranked.sort(key=lambda v, ref=original_text: self._variation_score(ref, v['text']), reverse=True)

            picks: List[Dict] = []
            if ranked:
                picks.append(ranked[0])
            if len(ranked) > 1 and len(picks) < budget:
                picks.append(ranked[len(ranked) // 2])
            if not picks or len(picks) < budget:
                picks.append({'speaker': original_turn['speaker'], 'text': original_text})
            picks_per_turn.append(picks)

        required_augmented = max(1, num_turns // 2)
        augmented_dialogues: List[Dict] = []
        seen = set()

        def generate_candidates(current_turns: List[Dict], turn_index: int) -> None:
            if len(augmented_dialogues) >= over_generate_factor:
                return
            if turn_index == num_turns:
                # Tuple of texts: collision-free, unlike a joined string
                fingerprint = tuple(t['text'] for t in current_turns)
                if fingerprint not in seen:
                    seen.add(fingerprint)
                    aug_count = sum(
                        1 for orig, curr in zip(original_turns, current_turns)
                        if orig['text'] != curr['text']
                    )
                    if aug_count >= required_augmented:
                        augmented_dialogues.append({
                            'dialogue_id': f"{dialogue_id}_aug_{len(augmented_dialogues)}",
                            'turns': list(current_turns)
                        })
                return

            candidates = list(picks_per_turn[turn_index])
            # Shuffle so different runs explore different combinations first
            np.random.shuffle(candidates)
            for variation in candidates:
                if len(augmented_dialogues) >= over_generate_factor:
                    return
                current_turns.append(variation)
                generate_candidates(current_turns, turn_index + 1)
                current_turns.pop()

        try:
            generate_candidates([], 0)
        except Exception as e:
            print(f"Error in dialogue generation: {str(e)}")
            return []

        scored = [(self._dialogue_quality_score(d, original_dialogue), d) for d in augmented_dialogues]
        scored.sort(key=lambda item: item[0], reverse=True)
        return [d for _, d in scored[:self.config.augmentation_factor]]

    def _is_dialogue_duplicate(self, dialogue1: Dict, dialogue2: Dict) -> bool:
        """
        Return True if both dialogues have the same turn texts, turn by turn.

        Turns are compared individually, so "a b" + "c" is not a duplicate of
        "a" + "b c".
        """
        return [t['text'] for t in dialogue1['turns']] == [t['text'] for t in dialogue2['turns']]

    def _augment_short_text(self, turn: Dict) -> List[Dict]:
        """
        Variations for a very short turn, without heavy quality checks.

        The normalised text (lower-cased, edge punctuation stripped) selects
        curated categories. An exact phrase match is tried first (e.g.
        "thank you", "bye"); otherwise whole-word matches are used (e.g.
        "ok thanks"). Without a match, punctuation and capitalization variants
        are produced instead.

        Results keep first-seen order and exclude blanks and the unchanged
        input. At most ``config.augmentation_factor`` are returned, each tagged
        with the turn's speaker. Text with no characters besides '!.,?' yields [].
        """
        text = turn['text']
        words = [w.strip(self._EDGE_PUNCT) for w in text.lower().split()]
        words = [w for w in words if w]
        phrase = ' '.join(words)

        aliases_by_category = self._SHORT_TEXT_ALIASES
        categories = [k for k, aliases in aliases_by_category.items() if phrase in aliases]
        if not categories:
            categories = [
                k for k, aliases in aliases_by_category.items()
                if any(alias in words for alias in aliases)
            ]
        variations = [v for k in categories for v in self._SHORT_TEXT_VARIATIONS[k]]

        if not variations:
            base = text.rstrip('!.,?')
            if not base.strip():
                return []
            variations = [base + '!', base + '.', base]
            capitalized = [v.capitalize() for v in variations if v.capitalize() not in variations]
            variations.extend(capitalized)

        # Order-preserving de-duplication keeps the output deterministic
        unique_variations = [v for v in dict.fromkeys(variations) if v.strip() and v != text]
        result_variations = unique_variations[:self.config.augmentation_factor]
        return [{'speaker': turn['speaker'], 'text': v} for v in result_variations]

    def process_batch(self, batch: List[Dict]) -> List[Dict]:
        """
        Augment several dialogues, warming the embedding cache in one model call.

        Every turn text in the batch is encoded up front so that later
        similarity checks hit ``self.embedding_cache``. A dialogue whose
        augmentation raises is reported and skipped.

        Returns:
            The concatenated results of augment_dialogue for each dialogue.
        """
        all_texts = [turn['text'] for dialogue in batch for turn in dialogue['turns']]
        if all_texts:
            # Populates self.embedding_cache as a side effect
            self._compute_batch_embeddings(all_texts)

        results: List[Dict] = []
        for dialogue in batch:
            try:
                results.extend(self.augment_dialogue(dialogue))
            except Exception as e:
                print(f"Error processing dialogue {dialogue.get('dialogue_id', 'unknown')}: {e}")
        return results