Merge branch 'dev'
- .gitignore +2 -6
- augmented_combined_dataset.json +0 -0
- back_translator.py +56 -0
- dialogue_augmenter.py +716 -0
- main.py +124 -0
- paraphraser.py +31 -0
- pipeline_config.py +58 -0
- processing_pipeline.py +176 -0
- quality_metrics.py +129 -0
- readme.md +43 -0
- requirements.txt +12 -0
- schema_guided_dialogue_processor.py +192 -0
- setup.py +100 -0
- taskmaster_processor.py +192 -0
.gitignore
CHANGED
@@ -154,9 +154,5 @@ dmypy.json
 # Cython debug symbols
 cython_debug/

-
-
-# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
-# and can be added to the global gitignore or merged into this file. For a more nuclear
-# option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+datasets/*
+!datasets/.gitkeep
augmented_combined_dataset.json
ADDED
The diff for this file is too large to render.
See raw diff
back_translator.py
ADDED
@@ -0,0 +1,56 @@
from transformers import (
    MarianMTModel,
    MarianTokenizer,
)

class BackTranslator:
    """
    Perform back-translation with a pivot language: English -> German -> Spanish -> English.
    Args:
        source_lang: Source language (default: 'en')
        pivot_lang: Pivot language (default: 'de')
        target_lang: Target language (default: 'es')
    Examples:
        back_translator = BackTranslator()
        back_translator.back_translate("Hello, how are you?")
    """
    def __init__(self, source_lang='en', pivot_lang='de', target_lang='es'):
        # Forward (English to German)
        pivot_forward_model_name = f'Helsinki-NLP/opus-mt-{source_lang}-{pivot_lang}'
        self.tokenizer_pivot_forward = MarianTokenizer.from_pretrained(pivot_forward_model_name)
        self.model_pivot_forward = MarianMTModel.from_pretrained(pivot_forward_model_name)

        # Pivot translation model (German to Spanish)
        pivot_backward_model_name = f'Helsinki-NLP/opus-mt-{pivot_lang}-{target_lang}'
        self.tokenizer_pivot_backward = MarianTokenizer.from_pretrained(pivot_backward_model_name)
        self.model_pivot_backward = MarianMTModel.from_pretrained(pivot_backward_model_name)

        # Backward (Spanish to English)
        backward_model_name = f'Helsinki-NLP/opus-mt-{target_lang}-{source_lang}'
        self.tokenizer_backward = MarianTokenizer.from_pretrained(backward_model_name)
        self.model_backward = MarianMTModel.from_pretrained(backward_model_name)

    def back_translate(self, text):
        """
        Perform back-translation through German and Spanish to generate text variations.
        Args:
            text (str): The input text to be back-translated

        Returns:
            str: The back-translated text
        """
        # 1. English to German
        encoded_pivot = self.tokenizer_pivot_forward([text], padding=True, truncation=True, return_tensors='pt')
        generated_pivot = self.model_pivot_forward.generate(**encoded_pivot)
        pivot_text = self.tokenizer_pivot_forward.batch_decode(generated_pivot, skip_special_tokens=True)[0]

        # 2. German to Spanish
        encoded_back_pivot = self.tokenizer_pivot_backward([pivot_text], padding=True, truncation=True, return_tensors='pt')
        retranslated_pivot = self.model_pivot_backward.generate(**encoded_back_pivot)
        tgt_text_back = self.tokenizer_pivot_backward.batch_decode(retranslated_pivot, skip_special_tokens=True)[0]

        # 3. Spanish to English
        encoded_back = self.tokenizer_backward([tgt_text_back], padding=True, truncation=True, return_tensors='pt')
        retranslated = self.model_backward.generate(**encoded_back)
        src_text = self.tokenizer_backward.batch_decode(retranslated, skip_special_tokens=True)[0]
        return src_text
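For reference, a minimal usage sketch of the class above (the three Helsinki-NLP checkpoints download on first use; the exact wording of the output varies by model version):

# Illustrative sketch only; model downloads happen on the first call.
from back_translator import BackTranslator

bt = BackTranslator()  # defaults: en -> de -> es -> en
variant = bt.back_translate("I need to book a table for two tonight.")
print(variant)  # a reworded English sentence; exact output varies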
dialogue_augmenter.py
ADDED
@@ -0,0 +1,716 @@
from typing import Dict, List
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import re
from pipeline_config import PipelineConfig
from quality_metrics import QualityMetrics
from paraphraser import Paraphraser
from back_translator import BackTranslator
import nlpaug.augmenter.word as naw
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from sklearn.metrics.pairwise import cosine_similarity

class DialogueAugmenter:
    """
    Optimized dialogue augmentation with quality control and complexity management.
    """
    def __init__(self, nlp, config: PipelineConfig):
        self.nlp = nlp
        self.config = config
        self.quality_metrics = QualityMetrics(config)
        self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')

        # Advanced augmentation techniques
        self.paraphraser = Paraphraser()
        self.back_translator = BackTranslator()

        # Basic augmentation techniques
        self.word_augmenter = naw.SynonymAug(aug_src='wordnet')
        self.spelling_augmenter = naw.SpellingAug()

        self.augmenters = {
            'advanced': [self.paraphraser, self.back_translator],
            'basic': [
                ('synonym', self.word_augmenter),
                ('spelling', self.spelling_augmenter)
            ]
        }

        # Initialize cache
        self.embedding_cache = {}
        self.perplexity_cache = {}

        # Compile regex patterns
        self.spelling_pattern = re.compile(r'[a-zA-Z]{3,}')

        # GPU memory management
        gpus = tf.config.list_physical_devices('GPU')
        if gpus:
            try:
                for gpu in gpus:
                    tf.config.experimental.set_memory_growth(gpu, True)
            except RuntimeError as e:
                print(e)

    @lru_cache(maxsize=1024)
    def _compute_embedding(self, text: str) -> np.ndarray:
        """Cached computation of text embedding"""
        return self.use_model([text])[0].numpy()

    def _compute_batch_embeddings(self, texts: List[str]) -> np.ndarray:
        """Compute embeddings for multiple texts at once"""
        return self.use_model(texts).numpy()

    def _quick_quality_check(self, variation: str, original: str) -> bool:
        """
        Simplified preliminary quality check with minimal standards
        """
        if self.config.debug:
            print(f"\nQuick check for variation: {variation}")

        # Only reject if length is extremely different
        orig_len = len(original.split())
        var_len = len(variation.split())

        # For very short texts (1-3 words), allow more variation
        if orig_len <= 3:
            if var_len > orig_len * 4:  # Allow up to 4x length for short texts
                if self.config.debug:
                    print(f"Failed length check (short text): {var_len} vs {orig_len}")
                return False
        else:
            if var_len > orig_len * 3:  # Allow up to 3x length for longer texts
                if self.config.debug:
                    print(f"Failed length check (long text): {var_len} vs {orig_len}")
                return False

        # Basic content check - at least one word in common (excluding stop words)
        stop_words = {'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'is', 'are'}
        orig_words = set(w.lower() for w in original.split() if w.lower() not in stop_words)
        var_words = set(w.lower() for w in variation.split() if w.lower() not in stop_words)

        if not orig_words.intersection(var_words):
            if self.config.debug:
                print("Failed content check: no content words in common")
            return False

        if self.config.debug:
            print("Passed all quick checks")
        return True

    def _compute_metrics_parallel(self, original: str, candidates: List[str]) -> List[Dict[str, float]]:
        """Compute quality metrics for multiple candidates in parallel"""
        with ThreadPoolExecutor(max_workers=4) as executor:
            futures = [
                executor.submit(self.quality_metrics.compute_metrics, original, candidate)
                for candidate in candidates
            ]
            return [future.result() for future in futures]

    def _filter_variations_batch(self, variations: List[str], context: List[str], original_turn: str) -> List[str]:
        """
        Filter variations using batched computations with detailed logging
        """
        if not variations:
            return []

        if self.config.debug:
            print(f"\nStarting filtration of {len(variations)} variations")
            print(f"Context length: {len(context)}")
            print(f"Original turn: {original_turn}")

        words = original_turn.split()
        if len(words) < 3:
            if self.config.debug:
                print("Short text detected, using predefined variations")
            short_text_variations = self._augment_short_text({'text': original_turn, 'speaker': ''})
            return [var['text'] for var in short_text_variations]

        # If this is the first turn (no context), be more lenient
        if not context:
            preliminary_filtered = variations
            if self.config.debug:
                print("First turn - skipping preliminary filtering")
        else:
            # Quick preliminary filtering against original turn
            preliminary_filtered = []
            for var in variations:
                passed = self._quick_quality_check(var, original_turn)
                if self.config.debug:
                    print(f"\nVariation: {var}")
                    print(f"Passed quick check: {passed}")
                if passed:
                    preliminary_filtered.append(var)

        if self.config.debug:
            print(f"Variations after quick check: {len(preliminary_filtered)}")

        if not preliminary_filtered:
            return []

        # Only use last turn for coherence
        recent_context = [context[-1]] if context else []
        context_text = ' '.join(recent_context) if recent_context else ''

        # Even more lenient thresholds
        min_similarity = 0.1  # Further reduced
        min_coherence = 0.05  # Further reduced

        if context_text:
            if self.config.debug:
                print(f"\nContext text: {context_text}")

            all_texts = [context_text] + preliminary_filtered
            all_embeddings = self._compute_batch_embeddings(all_texts)

            context_embedding = all_embeddings[0]
            variation_embeddings = all_embeddings[1:]

            # Vectorized similarity computation
            context_similarities = cosine_similarity([context_embedding], variation_embeddings)[0]

            # Response coherence check
            if recent_context:
                prev_embedding = self._compute_embedding(recent_context[-1])
                response_coherence = cosine_similarity([prev_embedding], variation_embeddings)[0]
            else:
                response_coherence = np.ones_like(context_similarities)

            # Combined scoring with detailed logging
            filtered_variations = []
            for i, (variation, sim, coh) in enumerate(zip(
                    preliminary_filtered, context_similarities, response_coherence)):
                # Use absolute values for scoring
                combined_score = (
                    self.config.context_similarity_weight * abs(sim) +
                    self.config.response_coherence_weight * abs(coh)
                )

                if self.config.debug:
                    print(f"\nVariation: {variation}")
                    print(f"Context similarity: {sim:.3f}")
                    print(f"Response coherence: {coh:.3f}")
                    print(f"Combined score: {combined_score:.3f}")

                # Accept if EITHER score is good enough
                if (combined_score >= min_similarity or abs(coh) >= min_coherence):
                    filtered_variations.append(variation)
                    if self.config.debug:
                        print("ACCEPTED")
                else:
                    if self.config.debug:
                        print("REJECTED")

                # If we have enough variations, stop
                if len(filtered_variations) >= self.config.max_variations_per_turn:
                    break
        else:
            filtered_variations = preliminary_filtered[:self.config.max_variations_per_turn]

        if self.config.debug:
            print(f"\nFinal filtered variations: {len(filtered_variations)}")

        return filtered_variations

    def _generate_variations_progressive(self, text: str, needed: int) -> List[str]:
        """
        Generate variations progressively until we have enough good ones
        """
        variations = set()

        if self.config.debug:
            print(f"\nAttempting to generate {needed} variations for text: {text}")

        # Try advanced augmenters first
        for augmenter in self.augmenters['advanced']:
            if len(variations) >= needed:
                break

            try:
                if isinstance(augmenter, Paraphraser):
                    if self.config.debug:
                        print("Trying paraphrase augmentation...")
                    new_vars = augmenter.paraphrase(text, num_return_sequences=needed-len(variations))
                    if self.config.debug:
                        print(f"Paraphraser generated {len(new_vars)} variations")
                else:
                    if self.config.debug:
                        print("Trying back translation...")
                    new_vars = [augmenter.back_translate(text)]
                    if self.config.debug:
                        print(f"Back translator generated {len(new_vars)} variations")

                valid_vars = [v for v in new_vars if v.strip() and v != text]
                variations.update(valid_vars)

                if self.config.debug:
                    print(f"Current unique variations: {len(variations)}")

            except Exception as e:
                print(f"Error in advanced augmentation: {str(e)}")
                continue

        # Try basic augmenters if needed
        if len(variations) < needed:
            if self.config.debug:
                print("Not enough variations, trying basic augmenters...")

            for aug_type, augmenter in self.augmenters['basic']:
                if len(variations) >= needed:
                    break

                try:
                    if aug_type == 'spelling' and self._is_technical_or_formal_text(text):
                        if self.config.debug:
                            print("Skipping spelling augmentation for technical text")
                        continue

                    if self.config.debug:
                        print(f"Trying {aug_type} augmentation...")

                    new_vars = augmenter.augment(text, n=2)
                    if isinstance(new_vars, list):
                        valid_vars = [v for v in new_vars if v.strip() and v != text]
                        variations.update(valid_vars)
                    else:
                        if new_vars.strip() and new_vars != text:
                            variations.add(new_vars)

                    if self.config.debug:
                        print(f"After {aug_type}, total variations: {len(variations)}")

                except Exception as e:
                    print(f"Error in {aug_type} augmentation: {str(e)}")
                    continue

        variations_list = list(variations)

        if self.config.debug:
            print(f"Final number of variations generated: {len(variations_list)}")
            if not variations_list:
                print("WARNING: No variations were generated!")

        return variations_list

    def augment_dialogue(self, dialogue: Dict) -> List[Dict]:
        """
        Create augmented versions of the dialogue with optimized processing
        """
        # Early dialogue length check
        original_length = len(dialogue['turns'])
        if original_length > self.config.max_turns_per_dialogue:
            if self.config.debug:
                print(f"Truncating dialogue from {original_length} to {self.config.max_turns_per_dialogue} turns")
            dialogue['turns'] = dialogue['turns'][:self.config.max_turns_per_dialogue]

        turn_variations = []
        context = []

        # Process each turn with progressive generation
        for turn in dialogue['turns']:
            original_text = turn['text']  # Store original turn text
            variations = self._generate_variations_progressive(
                original_text,
                self.config.max_variations_per_turn
            )

            # Batch filter variations with original text
            filtered_variations = self._filter_variations_batch(
                variations,
                context,
                original_text  # Pass the original turn text
            )

            # Create turn variations with speaker info
            turn_vars = [{'speaker': turn['speaker'], 'text': v} for v in filtered_variations]

            if self.config.debug:
                print(f"Turn {len(turn_variations)}: Generated {len(turn_vars)} variations")

            turn_variations.append(turn_vars)
            context.append(original_text)

        # Generate combinations with sampling
        augmented_dialogues = self._generate_dialogue_combinations(
            dialogue['dialogue_id'],
            turn_variations
        )

        # Add original dialogue
        result = [{
            'dialogue_id': f"{dialogue['dialogue_id']}_original",
            'turns': dialogue['turns']
        }]

        # Add unique augmentations
        result.extend(augmented_dialogues[:self.config.augmentation_factor])

        if self.config.debug:
            print(f"Generated {len(result)-1} unique augmented dialogues")

        return result

    def _generate_dialogue_combinations(self, dialogue_id: str, turn_variations: List[List[Dict]]) -> List[Dict]:
        """
        Generate dialogue combinations using sampling
        """
        augmented_dialogues = []
        used_combinations = set()

        def generate_dialogues(current_turns=None, turn_index=0):
            if current_turns is None:
                current_turns = []

            if len(augmented_dialogues) >= self.config.augmentation_factor:
                return

            if turn_index == len(turn_variations):
                dialogue_fingerprint = " | ".join(turn['text'] for turn in current_turns)
                if dialogue_fingerprint not in used_combinations:
                    used_combinations.add(dialogue_fingerprint)
                    augmented_dialogues.append({
                        'dialogue_id': f"{dialogue_id}_aug_{len(augmented_dialogues)}",
                        'turns': current_turns.copy()
                    })
                return

            variations = list(turn_variations[turn_index])
            np.random.shuffle(variations)

            for variation in variations[:self.config.max_sampled_variations]:
                if len(augmented_dialogues) >= self.config.augmentation_factor:
                    return
                current_turns.append(variation)
                generate_dialogues(current_turns, turn_index + 1)
                current_turns.pop()

        try:
            generate_dialogues()
        except Exception as e:
            print(f"Error in dialogue generation: {str(e)}")
            return []

        return augmented_dialogues

    def _is_dialogue_duplicate(self, dialogue1: Dict, dialogue2: Dict) -> bool:
        """
        Check if two dialogues are duplicates.
        """
        text1 = " ".join(turn['text'] for turn in dialogue1['turns'])
        text2 = " ".join(turn['text'] for turn in dialogue2['turns'])
        return text1 == text2

    # def _augment_turn(self, turn: Dict, context: List[str]) -> List[Dict]:
    #     """
    #     Generate augmented versions of the turn using multiple strategies.
    #     """
    #     text = turn['text']
    #     words = text.split()

    #     # Special handling for very short texts
    #     if len(words) < 3:
    #         return self._augment_short_text(turn)

    #     all_variations = set()

    #     # Advanced augmentations (paraphrase and back-translation)
    #     for augmenter in self.augmenters['advanced']:
    #         try:
    #             if isinstance(augmenter, Paraphraser):
    #                 variations = augmenter.paraphrase(text)
    #                 all_variations.update(variations)
    #             elif isinstance(augmenter, BackTranslator):
    #                 aug_text = augmenter.back_translate(text)
    #                 if aug_text:
    #                     all_variations.add(aug_text)
    #         except Exception as e:
    #             print(f"Error in advanced augmentation: {str(e)}")
    #             continue

    #     # Basic nlpaug augmentations
    #     for aug_type, augmenter in self.augmenters['basic']:
    #         try:
    #             if aug_type == 'spelling' and self._is_technical_or_formal_text(text):
    #                 continue

    #             aug_texts = augmenter.augment(text, n=2)
    #             if isinstance(aug_texts, list):
    #                 all_variations.update(aug_texts)
    #             else:
    #                 all_variations.add(aug_texts)
    #         except Exception as e:
    #             print(f"Error in {aug_type} augmentation: {str(e)}")
    #             continue

    #     # Remove exact duplicates and empty strings
    #     augmented_texts = [t for t in list(all_variations) if t.strip()]

    #     # Apply context filtering
    #     if context:
    #         augmented_texts = self._filter_by_context(augmented_texts, context)
    #         print(f"After context filtering: {len(augmented_texts)} variations")

    #     # Select best variations
    #     best_variations = self._select_best_augmentations(
    #         text,
    #         augmented_texts,
    #         num_to_select=self.config.augmentation_factor,
    #         min_quality_score=0.7
    #     )

    #     # Create variations with speaker info
    #     variations = [{'speaker': turn['speaker'], 'text': text} for text in best_variations]

    #     return variations

    def _augment_short_text(self, turn: Dict) -> List[Dict]:
        """
        Special handling for very short texts with predefined variations.
        Args:
            turn (Dict): Original dialogue turn

        Returns:
            List[Dict]: List of variations for the short text
        """
        text = turn['text']
        common_variations = {
            'goodbye': [
                'Bye!', 'Farewell!', 'See you!', 'Take care!',
                'Goodbye!', 'Bye for now!', 'Until next time!'
            ],
            'hello': [
                'Hi!', 'Hey!', 'Hello!', 'Greetings!',
                'Good day!', 'Hi there!', 'Hello there!'
            ],
            'yes': [
                'Yes!', 'Correct!', 'Indeed!', 'Absolutely!',
                'That\'s right!', 'Definitely!', 'Sure!'
            ],
            'no': [
                'No!', 'Nope!', 'Not at all!', 'Negative!',
                'Unfortunately not!', 'I\'m afraid not!'
            ],
            'thanks': [
                'Thank you!', 'Thanks a lot!', 'Many thanks!',
                'I appreciate it!', 'Thank you so much!'
            ],
            'ok': [
                'Okay!', 'Alright!', 'Sure!', 'Got it!',
                'Understood!', 'Fine!', 'Great!', 'Perfect!',
                'That works!', 'Sounds good!'
            ],
            'good': [
                'Great!', 'Excellent!', 'Perfect!', 'Wonderful!',
                'Fantastic!', 'Amazing!', 'Terrific!'
            ]
        }

        # Try to find matching variations
        text_lower = text.lower().rstrip('!.,?')
        variations = []

        # Check if text matches any of our predefined categories
        for key, predefined_vars in common_variations.items():
            if key in text_lower or text_lower in key:
                variations.extend(predefined_vars)

        # If no predefined variations found, generate simple variants
        if not variations:
            # Add punctuation variations
            variations = [
                text.rstrip('!.,?') + '!',
                text.rstrip('!.,?') + '.',
                text.rstrip('!.,?')
            ]

            # Add capitalization variations
            variations.extend([
                v.capitalize() for v in variations
                if v.capitalize() not in variations
            ])

        # Filter variations for uniqueness and quality
        unique_variations = list(set(variations))
        quality_variations = []

        for var in unique_variations:
            metrics = self.quality_metrics.compute_metrics(text, var)
            quality_score = (
                0.35 * metrics['semantic_similarity'] +
                0.30 * (1.0 - metrics['perplexity'] / 100) +
                0.15 * (1.0 - metrics['grammar_errors'] / 10) +
                0.15 * metrics['content_preservation'] +
                0.10 * metrics['type_token_ratio']
            )

            # More lenient quality threshold for short texts
            if quality_score >= 0.5:  # Lower threshold for short texts
                quality_variations.append(var)

        # Ensure we have at least some variations
        if not quality_variations:
            quality_variations = [text]

        # Return the variations with original speaker
        return [{'speaker': turn['speaker'], 'text': v} for v in quality_variations[:self.config.augmentation_factor]]

    def _is_technical_or_formal_text(self, text: str) -> bool:
        """
        Check if text is formal/technical and shouldn't have spelling variations.
        """
        formal_indicators = {
            'technical_terms': {'api', 'config', 'database', 'server', 'system'},
            'formal_phrases': {'please advise', 'regarding', 'furthermore', 'moreover'},
            'professional_context': {'meeting', 'conference', 'project', 'deadline'}
        }

        text_lower = text.lower()
        words = set(text_lower.split())

        for category in formal_indicators.values():
            if words.intersection(category):
                return True

        return False

    # def _filter_by_context(self, variations: List[str], context: List[str]) -> List[str]:
    #     """
    #     Filter variations based on conversation context using config parameters.
    #     """
    #     # Manage context window using config
    #     recent_context = context[-self.config.context_window_size:] if len(context) > self.config.context_window_size else context

    #     filtered_variations = []
    #     context_embedding = self.use_model([' '.join(recent_context)])[0].numpy()

    #     prev_turn = recent_context[-1] if recent_context else ''

    #     for variation in variations:
    #         var_embedding = self.use_model([variation])[0].numpy()

    #         # Overall context similarity
    #         context_similarity = cosine_similarity([context_embedding], [var_embedding])[0][0]

    #         # Direct response coherence
    #         response_coherence = 1.0
    #         if prev_turn:
    #             prev_embedding = self.use_model([prev_turn])[0].numpy()
    #             response_coherence = cosine_similarity([prev_embedding], [var_embedding])[0][0]

    #         # Use weights from config
    #         combined_similarity = (
    #             self.config.context_similarity_weight * context_similarity +
    #             self.config.response_coherence_weight * response_coherence
    #         )

    #         if (combined_similarity >= self.config.semantic_similarity_threshold and
    #                 response_coherence >= self.config.min_response_coherence):
    #             filtered_variations.append(variation)
    #             if self.config.debug:
    #                 print(f"Accepted variation: {variation}")
    #                 print(f"Context similarity: {context_similarity:.3f}")
    #                 print(f"Response coherence: {response_coherence:.3f}")
    #                 print(f"Combined score: {combined_similarity:.3f}\n")
    #         else:
    #             if self.config.debug:
    #                 print(f"Rejected variation: {variation}")
    #                 print(f"Combined score {combined_similarity:.3f} below threshold "
    #                       f"{self.config.semantic_similarity_threshold}")
    #                 print(f"Response coherence {response_coherence:.3f} below threshold "
    #                       f"{self.config.min_response_coherence}\n")

    #     return filtered_variations or variations  # Fallback to original

    # def _select_best_augmentations(self, original: str, candidates: List[str], used_variations: set = None,
    #                                num_to_select: int = 3, min_quality_score: float = 0.7) -> List[str]:
    #     """
    #     Select the best augmentations using a quality score.
    #     Args:
    #         original (str): The original text
    #         candidates (List[str]): List of candidate augmented texts
    #         used_variations (set): Set of already used variations
    #         num_to_select (int): Number of variations to select
    #         min_quality_score (float): Minimum quality score threshold
    #     """
    #     if used_variations is None:
    #         used_variations = set()

    #     candidates = [c for c in candidates if c.strip()]

    #     # Skip short text
    #     if len(original.split()) < 3:
    #         print(f"Text too short for augmentation: {original}")
    #         return [original]

    #     scored_candidates = []
    #     for candidate in candidates:
    #         if candidate in used_variations:
    #             continue

    #         metrics = self.quality_metrics.compute_metrics(original, candidate)

    #         # Add contextual penalty for inappropriate audience terms
    #         audience_terms = {'everyone', 'everybody', 'folks', 'all', 'guys', 'people'}
    #         has_audience_term = any(term in candidate.lower() for term in audience_terms)
    #         audience_penalty = 0.2 if has_audience_term else 0.0

    #         # Weighted quality score
    #         quality_score = (
    #             0.40 * metrics['semantic_similarity'] +      # Semantic preservation
    #             0.25 * (1.0 - metrics['perplexity'] / 100) + # Fluency
    #             0.15 * (1.0 - metrics['grammar_errors'] / 10) + # Grammar
    #             0.15 * metrics['content_preservation'] +     # Content preservation
    #             0.05 * metrics['type_token_ratio']           # Lexical diversity
    #         )

    #         quality_score -= audience_penalty

    #         if (metrics['semantic_similarity'] < 0.5 or  # Reject on semantic threshold miss
    #                 metrics['rouge1_f1'] < 0.2):         # Enforce minimum lexical overlap
    #             continue

    #         # Bonus points for:
    #         # Length similarity to original
    #         if 0.75 <= metrics['length_ratio'] <= 1.25:
    #             quality_score += 0.05

    #         # Correct grammar
    #         if metrics['grammar_errors'] == 0:
    #             quality_score += 0.025

    #         print(f"Candidate: {candidate}")
    #         print(f"Quality score: {quality_score:.2f}, Metrics: {metrics}")

    #         # Consider the augmentation if it meets the basic quality threshold
    #         if quality_score >= min_quality_score:
    #             print('Candidate accepted\n')
    #             scored_candidates.append((candidate, quality_score, metrics))
    #         else:
    #             print('Candidate rejected\n')

    #     # Sort by quality score with small random factor for diversity
    #     scored_candidates.sort(key=lambda x: x[1], reverse=True)

    #     selected = []
    #     for candidate, score, metrics in scored_candidates:
    #         # Check diversity against already selected
    #         if len(selected) == 0:
    #             selected.append(candidate)
    #             continue

    #         # Compute average similarity to already selected
    #         avg_similarity = np.mean([
    #             self.quality_metrics.compute_semantic_similarity(candidate, prev)
    #             for prev in selected
    #         ])

    #         # Add if sufficiently different (similarity < 0.98)
    #         if avg_similarity < 0.98:
    #             selected.append(candidate)

    #         if len(selected) >= num_to_select:
    #             break

    #     return selected
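For reference, a minimal sketch of driving DialogueAugmenter directly (assumes en_core_web_sm is installed and the models above can be downloaded; the two-turn dialogue is illustrative only, not from the datasets):

# Illustrative sketch: augment a toy dialogue end to end.
import spacy
from pipeline_config import PipelineConfig
from dialogue_augmenter import DialogueAugmenter

nlp = spacy.load("en_core_web_sm", disable=['parser', 'ner'])
augmenter = DialogueAugmenter(nlp, PipelineConfig(debug=False))

dialogue = {
    'dialogue_id': 'demo_001',
    'turns': [
        {'speaker': 'user', 'text': 'Can you find me a cheap hotel downtown?'},
        {'speaker': 'assistant', 'text': 'Sure, what dates are you planning to stay?'},
    ],
}
# Returns the original (suffixed _original) plus up to augmentation_factor variants.
for d in augmenter.augment_dialogue(dialogue):
    print(d['dialogue_id'], [t['text'] for t in d['turns']])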
main.py
ADDED
@@ -0,0 +1,124 @@
"""
CSC525 - Module 8 Option 2 - Joseph Armani
Description and References in the README.md file.
"""
import json
import tensorflow as tf
from typing import List, Dict
from pipeline_config import PipelineConfig
from processing_pipeline import ProcessingPipeline
from taskmaster_processor import TaskmasterProcessor
from schema_guided_dialogue_processor import SchemaGuidedProcessor

def combine_datasets(taskmaster_dialogues: List[Dict],
                     schema_guided_dialogues: List[Dict]) -> List[Dict]:
    """
    Combine dialogues from both datasets into a single list

    Args:
        taskmaster_dialogues: List of dialogues in pipeline format from Taskmaster
        schema_guided_dialogues: List of dialogues in pipeline format from Schema-Guided

    Returns:
        List[Dict]: Combined list of dialogues
    """
    # Ensure unique dialogue IDs
    combined_dialogues = []
    seen_ids = set()
    duplicate_count = 0  # Track duplicates for reporting

    for dialogue in taskmaster_dialogues:
        dialogue_copy = dialogue.copy()
        dialogue_id = dialogue_copy['dialogue_id']
        if dialogue_id in seen_ids:
            duplicate_count += 1
            dialogue_id = f"taskmaster_{dialogue_id}"
        seen_ids.add(dialogue_id)
        dialogue_copy['dialogue_id'] = dialogue_id
        combined_dialogues.append(dialogue_copy)

    for dialogue in schema_guided_dialogues:
        dialogue_copy = dialogue.copy()
        dialogue_id = dialogue_copy['dialogue_id']
        if dialogue_id in seen_ids:
            duplicate_count += 1
            dialogue_id = f"schema_guided_{dialogue_id}"
        seen_ids.add(dialogue_id)
        dialogue_copy['dialogue_id'] = dialogue_id
        combined_dialogues.append(dialogue_copy)

    # Log the results
    print(f"Combine Datasets: Found and resolved {duplicate_count} duplicate dialogue IDs.")
    print(f"Combine Datasets: Total dialogues combined: {len(combined_dialogues)}")

    return combined_dialogues

def main():
    # Configuration
    config = PipelineConfig(
        min_length=1,
        max_length=512,
        batch_size=32 if tf.config.list_physical_devices('GPU') else 16,
        max_turns_per_dialogue=6,
        max_variations_per_turn=3,
        max_sampled_variations=2,
        context_window_size=4,
        max_complexity_threshold=100,
        use_cache=False,
        debug=True,
        allowed_speakers=['user', 'assistant'],
        required_fields=['dialogue_id', 'turns']
    )

    try:
        # Set max_examples (Optional[int]) for testing
        max_examples = None

        # Initialize and load Taskmaster dataset
        print("Loading Taskmaster dataset")
        taskmaster_processor = TaskmasterProcessor(config, use_ontology=False)
        taskmaster_dir = './datasets/taskmaster'
        taskmaster_dialogues = taskmaster_processor.load_dataset(taskmaster_dir, max_examples=max_examples)
        taskmaster_pipeline_dialogues = taskmaster_processor.convert_to_pipeline_format(taskmaster_dialogues)
        print(f"Processed Taskmaster dialogues: {len(taskmaster_pipeline_dialogues)}")

        # Initialize and load Schema-Guided dataset
        print("Loading Schema-Guided dataset")
        schema_dialogue_processor = SchemaGuidedProcessor(config)
        schema_guided_dir = './datasets/schema_guided'
        schema_dialogues = schema_dialogue_processor.load_dataset(schema_guided_dir, max_examples=max_examples)
        schema_pipeline_dialogues = schema_dialogue_processor.convert_to_pipeline_format(schema_dialogues)
        print(f"Processed Schema-Guided dialogues: {len(schema_pipeline_dialogues)}")

        # Combine datasets
        print("Combining datasets")
        combined_dialogues = combine_datasets(taskmaster_pipeline_dialogues, schema_pipeline_dialogues)
        print(f"Combined Dialogues: {len(combined_dialogues)}")

        if not combined_dialogues:
            print("Combined dialogues are empty. Exiting.")
            return

        # Process through augmentation pipeline
        print("Processing combined dataset")
        pipeline = ProcessingPipeline(config)
        processed_dialogues = pipeline.process_dataset(combined_dialogues)

        # Save results
        output_path = 'augmented_combined_dataset.json'
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(processed_dialogues, f, indent=2, ensure_ascii=False)

        # Print statistics
        print(f"\nProcessed Statistics:")
        print(f"Total dialogues: {len(processed_dialogues)}")
        print(f"Taskmaster domains: {len(taskmaster_processor.domains)}")
        print(f"Schema-Guided services: {len(schema_dialogue_processor.services)}")
        print(f"Schema-Guided domains: {len(schema_dialogue_processor.domains)}")

    except Exception as e:
        print(f"Processing failed: {str(e)}")
        raise

if __name__ == "__main__":
    main()
paraphraser.py
ADDED
@@ -0,0 +1,31 @@
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)

class Paraphraser:
    def __init__(self, model_name='humarin/chatgpt_paraphraser_on_T5_base'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.model.eval()

    def paraphrase(self, text, num_return_sequences=5, num_beams=10, num_beam_groups=5, diversity_penalty=0.8):
        try:
            input_text = "paraphrase: " + text + " </s>"
            encoding = self.tokenizer.encode_plus(input_text, return_tensors="pt")
            input_ids = encoding["input_ids"]

            outputs = self.model.generate(
                input_ids=input_ids,
                max_length=256,
                num_beams=num_beams,
                num_beam_groups=num_beam_groups,
                num_return_sequences=num_return_sequences,
                diversity_penalty=diversity_penalty,
                early_stopping=True
            )
            paraphrases = [self.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
            return paraphrases
        except Exception as e:
            print(f"Error in paraphrasing: {e}")
            return []
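A small usage sketch for the paraphraser (diverse beam search with 5 groups over 10 beams; num_return_sequences must stay at or below num_beams, and outputs vary by model version):

# Illustrative sketch only.
from paraphraser import Paraphraser

p = Paraphraser()
for candidate in p.paraphrase("What time does the restaurant open?", num_return_sequences=3):
    print(candidate)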
pipeline_config.py
ADDED
@@ -0,0 +1,58 @@
from dataclasses import dataclass
from typing import List

@dataclass
class PipelineConfig:
    """
    Config for the pipeline
    """
    # Validation settings
    min_length: int = 1
    max_length: int = 512
    min_tokens: int = 1
    max_tokens: int = 128

    allowed_speakers: List[str] = None
    required_fields: List[str] = None

    # Text augmentation settings
    augmentation_factor: int = 4
    augmentation_techniques: List[str] = None

    max_turns_per_dialogue: int = 6
    max_variations_per_turn: int = 3
    max_sampled_variations: int = 2
    max_complexity_threshold: int = 100
    complexity_reduction_turns: int = 4

    # Quality thresholds
    semantic_similarity_threshold: float = 0.45
    grammar_error_threshold: int = 2
    rouge1_f1_threshold: float = 0.30
    rouge2_f1_threshold: float = 0.15
    perplexity_threshold: float = 50.0

    # Response coherence thresholds
    min_response_coherence: float = 0.3
    context_similarity_weight: float = 0.35
    response_coherence_weight: float = 0.65

    # Performance settings
    batch_size: int = 32
    use_cache: bool = True
    debug: bool = False

    context_window_size: int = 4

    def __post_init__(self):
        if self.allowed_speakers is None:
            self.allowed_speakers = ['user', 'assistant']
        if self.required_fields is None:
            self.required_fields = ['dialogue_id', 'turns']
        if self.augmentation_techniques is None:
            self.augmentation_techniques = ['paraphrase', 'back_translation']

        # Validate weights sum to 1.0
        if abs((self.context_similarity_weight + self.response_coherence_weight) - 1.0) > 1e-6:
            raise ValueError("Context similarity and response coherence weights must sum to 1.0")
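Because __post_init__ enforces that the two scoring weights sum to 1.0, overriding one weight without the other fails at construction time; a quick sketch:

from pipeline_config import PipelineConfig

config = PipelineConfig(context_similarity_weight=0.4, response_coherence_weight=0.6)  # OK
try:
    PipelineConfig(context_similarity_weight=0.5)  # 0.5 + default 0.65 != 1.0
except ValueError as e:
    print(e)  # Context similarity and response coherence weights must sum to 1.0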
processing_pipeline.py
ADDED
@@ -0,0 +1,176 @@
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Optional
import json
import re
import hashlib
import pickle
import spacy
from tqdm import tqdm
from pipeline_config import PipelineConfig
from dialogue_augmenter import DialogueAugmenter
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

class ProcessingPipeline:
    """
    Complete pipeline combining validation, optimization, and augmentation.
    """

    def __init__(self, config: Optional[PipelineConfig] = None):
        self.config = config or PipelineConfig()
        self.nlp = spacy.load("en_core_web_sm", disable=['parser', 'ner'])
        self.augmenter = DialogueAugmenter(self.nlp, self.config)
        self.num_threads = self.config.batch_size
        self.cache_dir = Path("./cache")
        self.cache_dir.mkdir(exist_ok=True)

    def process_dataset(self, dialogues: List[Dict]) -> List[Dict]:
        """
        Process entire dataset through the pipeline.
        """
        print(f"Processing {len(dialogues)} dialogues")
        start_time = datetime.now()

        # Check cache
        if self.config.use_cache:
            cache_path = self._get_cache_path(dialogues)
            if cache_path.exists():
                print("Loading from cache...")
                with open(cache_path, 'rb') as f:
                    return pickle.load(f)

        # Validate and clean
        valid_dialogues = self._process_validation(
            dialogues,
            self._validate_and_clean_dialogue,
            "validating and cleaning"
        )

        if not valid_dialogues:
            raise ValueError("Dialogue validation resulted in an empty dataset.")

        deduplicated_dialogues = self._deduplicate_dialogues(valid_dialogues)

        # Augment dialogues
        all_processed_dialogues = []
        for dialogue in deduplicated_dialogues:
            augmented = self.augmenter.augment_dialogue(dialogue)
            all_processed_dialogues.extend(augmented)

        # Save to cache
        if self.config.use_cache:
            with open(cache_path, 'wb') as f:
                pickle.dump(all_processed_dialogues, f)

        processing_time = datetime.now() - start_time
        print(f"Processing completed in {processing_time}")
        print(f"Generated {len(all_processed_dialogues)} total dialogues")

        return all_processed_dialogues

    def _deduplicate_dialogues(self, dialogues: List[Dict], threshold: float = 0.9) -> List[Dict]:
        """
        Deduplicate dialogues based on text similarity.
        """
        print("Deduplicating dialogues...")
        if not dialogues:
            print("No dialogues provided for deduplication.")
            return []

        # Combine turns into single text for similarity comparison
        texts = [" ".join(turn['text'] for turn in dialogue['turns']) for dialogue in dialogues]
        tfidf = TfidfVectorizer().fit_transform(texts)
        sim_matrix = cosine_similarity(tfidf)

        unique_indices = set()
        for i, row in enumerate(sim_matrix):
            if i not in unique_indices:
                similar_indices = [j for j, sim in enumerate(row) if sim > threshold and j != i]
                unique_indices.add(i)
                unique_indices.difference_update(similar_indices)

        deduplicated_dialogues = [dialogues[i] for i in unique_indices]

        print(f"Deduplication complete. Reduced from {len(dialogues)} to {len(deduplicated_dialogues)} dialogues.")
        return deduplicated_dialogues

    def _validate_and_clean_dialogue(self, dialogue: Dict) -> Optional[Dict]:
        """
        Validate and clean a single dialogue.
        """
        try:
            # Check required fields
            if not all(field in dialogue for field in self.config.required_fields):
                return None

            # Process turns
            cleaned_turns = []
            for turn in dialogue['turns']:
                if self._validate_turn(turn):
                    cleaned_turn = {
                        'speaker': turn['speaker'],
                        'text': self._clean_text(turn['text'])
                    }
                    cleaned_turns.append(cleaned_turn)

            if cleaned_turns:
                return {
                    'dialogue_id': dialogue['dialogue_id'],
                    'turns': cleaned_turns
                }

            return None

        except Exception as e:
            print(f"Error processing dialogue {dialogue.get('dialogue_id', 'unknown')}: {str(e)}")
            return None

    def _validate_turn(self, turn: Dict) -> bool:
        """
        Validate a single speaking turn.
        """
        return (
            turn['speaker'] in self.config.allowed_speakers and
            self.config.min_length <= len(turn['text']) <= self.config.max_length
        )

    def _clean_text(self, text: str) -> str:
        """
        Clean and normalize text.
        """
        # Remove excessive whitespace
        text = re.sub(r'\s+', ' ', text.strip())

        # Normalize quotes and apostrophes
        text = re.sub(r'[’´`]', "'", text)
        text = re.sub(r'[“”]', '"', text)

        # Remove control characters
        text = "".join(char for char in text if ord(char) >= 32 or char == '\n')

        return text

    def _process_validation(self, items: List, func, description: str) -> List:
        """
        Process items sequentially with a progress bar.
        """
        results = []
        print(f"Starting {description}")
        for item in tqdm(items, desc=description):
            try:
                result = func(item)
                if result is not None:
                    results.append(result)
            except Exception as e:
                print(f"Error processing item: {str(e)}")
        print(f"Completed {description}. Processed {len(results)} items successfully")
        return results

    def _get_cache_path(self, data: List[Dict]) -> Path:
        """
        Generate cache file path based on data hash.
        """
        data_str = json.dumps(data, sort_keys=True)
        hash_value = hashlib.md5(data_str.encode()).hexdigest()
        return self.cache_dir / f"cache_{hash_value}.pkl"
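A minimal end-to-end sketch of the pipeline on a toy input (with use_cache=True, a second run on identical input loads the pickled result from ./cache instead of recomputing):

# Illustrative sketch only; the single toy dialogue stands in for a real dataset.
from pipeline_config import PipelineConfig
from processing_pipeline import ProcessingPipeline

pipeline = ProcessingPipeline(PipelineConfig(use_cache=True))
dialogues = [{
    'dialogue_id': 'demo_001',
    'turns': [{'speaker': 'user', 'text': 'Hello, I need help booking a flight.'}],
}]
processed = pipeline.process_dataset(dialogues)
print(f"{len(processed)} dialogues produced")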
quality_metrics.py
ADDED
@@ -0,0 +1,129 @@
1 |
+
import torch
|
2 |
+
import tensorflow as tf
|
3 |
+
import tensorflow_hub as hub
|
4 |
+
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
|
5 |
+
import language_tool_python
|
6 |
+
from rouge_score import rouge_scorer
|
7 |
+
import spacy
|
8 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
9 |
+
import numpy as np
|
10 |
+
from typing import Dict
|
11 |
+
from pipeline_config import PipelineConfig
|
12 |
+
|
13 |
+
class QualityMetrics:
|
14 |
+
"""
|
15 |
+
Measure augmented text quality
|
16 |
+
"""
|
17 |
+
def __init__(self, config: PipelineConfig):
|
18 |
+
self.config = config
|
19 |
+
|
20 |
+
# Semantic similarity
|
21 |
+
self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
|
22 |
+
|
23 |
+
# Fluency metrics
|
24 |
+
self.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
|
25 |
+
self.model = GPT2LMHeadModel.from_pretrained('gpt2')
|
26 |
+
self.model.eval()
|
27 |
+
|
28 |
+
# Grammar
|
29 |
+
self.language_tool = language_tool_python.LanguageTool('en-US')
|
30 |
+
|
31 |
+
# Lexical similarity
|
32 |
+
self.rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
|
33 |
+
|
34 |
+
# Diversity
|
35 |
+
self.nlp = spacy.load('en_core_web_sm')
|
36 |
+
|
37 |
+
def compute_perplexity(self, text):
|
38 |
+
try:
|
39 |
+
encodings = self.tokenizer(text, return_tensors='pt')
|
40 |
+
input_ids = encodings['input_ids']
|
41 |
+
with torch.no_grad():
|
42 |
+
outputs = self.model(input_ids, labels=input_ids)
|
43 |
+
loss = outputs.loss
|
44 |
+
perplexity = torch.exp(loss)
|
45 |
+
return perplexity.item()
|
46 |
+
except Exception as e:
|
47 |
+
print(f"Error computing perplexity for text '{text}': {e}")
|
48 |
+
return float('inf') # High perplexity value == poor quality
|
49 |
+
|
50 |
+
def compute_semantic_similarity(self, text1: str, text2: str) -> float:
|
51 |
+
"""
|
52 |
+
Compute semantic similarity between two texts using the Universal Sentence Encoder.
|
53 |
+
Args:
|
54 |
+
text1 (str): First text
|
55 |
+
text2 (str): Second text
|
56 |
+
Returns:
|
57 |
+
float: Cosine similarity score between the two texts (0-1)
|
58 |
+
"""
|
59 |
+
embeddings = self.use_model([text1, text2])
|
60 |
+
emb1, emb2 = embeddings[0].numpy(), embeddings[1].numpy()
|
61 |
+
return cosine_similarity([emb1], [emb2])[0][0]
|
62 |
+
|
63 |
+
def compute_metrics(self, original: str, augmented: str) -> Dict[str, float]:
|
64 |
+
"""
|
65 |
+
Compute quality metrics
|
66 |
+
"""
|
67 |
+
metrics = {}
|
68 |
+
|
69 |
+
# 1. Semantic Preservation
|
70 |
+
embeddings = self.use_model([original, augmented])
|
71 |
+
emb_orig, emb_aug = embeddings[0].numpy(), embeddings[1].numpy()
|
72 |
+
metrics['semantic_similarity'] = cosine_similarity([emb_orig], [emb_aug])[0][0]
|
73 |
+
|
74 |
+
# 2. Fluency & Naturalness
|
75 |
+
metrics['perplexity'] = self.compute_perplexity(augmented)
|
76 |
+
metrics['grammar_errors'] = len(self.language_tool.check(augmented))
|
77 |
+
|
78 |
+
# 3. Lexical Diversity
|
79 |
+
doc_orig = self.nlp(original)
|
80 |
+
doc_aug = self.nlp(augmented)
|
81 |
+
|
82 |
+
# Type-token ratio with safety check
|
83 |
+
aug_tokens = [token.text.lower() for token in doc_aug]
|
84 |
+
metrics['type_token_ratio'] = len(set(aug_tokens)) / max(len(aug_tokens), 1)
|
85 |
+
|
86 |
+
# Content word overlap with safety checks
|
87 |
+
orig_content = set([token.text.lower() for token in doc_orig if not token.is_stop])
|
88 |
+
aug_content = set([token.text.lower() for token in doc_aug if not token.is_stop])
|
89 |
+
|
90 |
+
# Safety check for empty content sets
|
91 |
+
if len(orig_content) == 0:
|
92 |
+
metrics['content_preservation'] = 1.0 if len(aug_content) == 0 else 0.0
|
93 |
+
else:
|
94 |
+
metrics['content_preservation'] = len(orig_content.intersection(aug_content)) / len(orig_content)
|
95 |
+
|
96 |
+
# 4. Structural Preservation
|
97 |
+
rouge_scores = self.rouge.score(original, augmented)
|
98 |
+
metrics['rouge1_f1'] = rouge_scores['rouge1'].fmeasure
|
99 |
+
metrics['rouge2_f1'] = rouge_scores['rouge2'].fmeasure
|
100 |
+
metrics['rougeL_f1'] = rouge_scores['rougeL'].fmeasure
|
101 |
+
|
102 |
+
# 5. Length Preservation with safety check
|
103 |
+
orig_words = len(original.split())
|
104 |
+
aug_words = len(augmented.split())
|
105 |
+
metrics['length_ratio'] = aug_words / max(orig_words, 1)
|
106 |
+
|
107 |
+
return metrics
|
108 |
+
|
109 |
+
def meets_quality_threshold(self, metrics: Dict[str, float]) -> bool:
|
110 |
+
"""
|
111 |
+
Enhanced quality threshold checking
|
112 |
+
"""
|
113 |
+
# Core quality checks
|
114 |
+
basic_quality = (
|
115 |
+
metrics['perplexity'] <= self.config.perplexity_threshold and
|
116 |
+
metrics['semantic_similarity'] >= self.config.semantic_similarity_threshold and
|
117 |
+
metrics['grammar_errors'] <= self.config.grammar_error_threshold
|
118 |
+
)
|
119 |
+
|
120 |
+
# Length preservation check
|
121 |
+
length_ok = 0.6 <= metrics['length_ratio'] <= 1.4
|
122 |
+
|
123 |
+
# Diversity check
|
124 |
+
diversity_ok = metrics['type_token_ratio'] >= 0.4
|
125 |
+
|
126 |
+
# Content preservation check
|
127 |
+
content_ok = metrics['content_preservation'] >= 0.6
|
128 |
+
|
129 |
+
return all([basic_quality, length_ok, diversity_ok, content_ok])
|
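For reference, a minimal usage sketch of `QualityMetrics` as committed above. `PipelineConfig()` is assumed to supply default values for `perplexity_threshold`, `semantic_similarity_threshold`, and `grammar_error_threshold`; the first run also downloads the Universal Sentence Encoder and GPT-2 weights.

```python
# Usage sketch only: PipelineConfig defaults are an assumption, not confirmed here.
from pipeline_config import PipelineConfig
from quality_metrics import QualityMetrics

qm = QualityMetrics(PipelineConfig())
scores = qm.compute_metrics(
    original="I'd like to book a table for two tonight.",
    augmented="Could I reserve a table for two this evening?",
)
print(scores['semantic_similarity'], scores['perplexity'], scores['grammar_errors'])
print(qm.meets_quality_threshold(scores))
```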
readme.md
ADDED
@@ -0,0 +1,43 @@
# Retrieval-based learning chatbot

CSC525 - Module 8 Option 2 - Retrieval-based Learning Chatbot - Joseph Armani

A Python tool for generating high-quality dialogue variations.

This package automatically downloads the following models during installation:

- Universal Sentence Encoder v4 (TensorFlow Hub)
- ChatGPT Paraphraser T5-base
- Helsinki-NLP translation models (en-de, de-es, es-en)
- GPT-2 (for perplexity scoring)
- spaCy en_core_web_sm
- NLTK wordnet and averaged_perceptron_tagger_eng models

## Install package

pip install -e .

## Description

This package implements a complete pipeline for dialogue augmentation, including validation, optimization, and data augmentation.
It creates high-quality augmented versions of dialogues by applying text augmentation techniques with quality control checks.
Two approaches are used for text augmentation: paraphrasing and back-translation. The pipeline also includes quality metrics for evaluating the augmented text.
Special handling is implemented for very short texts such as greetings and farewells, which are predefined and filtered for quality.
The pipeline processes a dataset of dialogues and generates multiple high-quality augmented versions of each dialogue,
ensuring that duplicate dialogues are not generated and that the output meets quality thresholds for semantic similarity, grammar, fluency, diversity, and content preservation.
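## Example usage

The snippet below is a minimal sketch of driving the pipeline from Python. The module and class names (`pipeline_config.PipelineConfig`, `processing_pipeline.ProcessingPipeline`) come from this repository, but the constructor and method signatures shown are assumptions; `main.py` remains the supported entry point.

```python
# Sketch only: ProcessingPipeline's constructor and process_dataset are
# assumed names for illustration; see main.py for the actual entry point.
from pipeline_config import PipelineConfig
from processing_pipeline import ProcessingPipeline

config = PipelineConfig()              # assumes defaults for all thresholds
pipeline = ProcessingPipeline(config)  # assumed constructor signature

dialogues = [
    {
        'dialogue_id': 'example_1',
        'turns': [
            {'speaker': 'user', 'text': 'Can you book me a table for two?'},
            {'speaker': 'assistant', 'text': 'Sure, what time would you like?'},
        ],
    }
]

augmented = pipeline.process_dataset(dialogues)  # assumed method name
```

The dialogue dictionaries follow the `dialogue_id`/`turns` format produced by the two dataset processors below.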
## References

Accsany, P. (2024). Working with JSON data in Python. Real Python. <https://realpython.com/python-json/>
Explosion AI Team. (n.d.). spaCy · Industrial-strength natural language processing in Python. <https://spacy.io/>
GeeksforGeeks. (2024). Text augmentation techniques in NLP. GeeksforGeeks. <https://www.geeksforgeeks.org/text-augmentation-techniques-in-nlp/>
Helsinki-NLP. (2024). Opus-MT [Computer software]. GitHub. <https://github.com/Helsinki-NLP/Opus-MT>
Hugging Face. (n.d.). Transformers. Hugging Face. <https://huggingface.co/docs/transformers/en/index>
Humarin. (2023). ChatGPT paraphraser on T5-base [Computer software]. Hugging Face. <https://huggingface.co/humarin/chatgpt_paraphraser_on_T5_base>
Keita, Z. (2022). Data augmentation in NLP using back-translation with MarianMT. Towards Data Science. <https://towardsdatascience.com/data-augmentation-in-nlp-using-back-translation-with-marianmt-a8939dfea50a>
Memgraph. (2023). Cosine similarity in Python with scikit-learn. Memgraph. <https://memgraph.com/blog/cosine-similarity-python-scikit-learn>
Morris, J. (n.d.). language-tool-python (Version 2.8.1) [Computer software]. PyPI. <https://pypi.org/project/language-tool-python/>
TensorFlow. (n.d.). Universal Sentence Encoder. TensorFlow Hub. <https://www.tensorflow.org/hub/tutorials/semantic_similarity_with_tf_hub_universal_encoder>
Waheed, A. (2023). How to calculate ROUGE score in Python. Python Code. <https://thepythoncode.com/article/calculate-rouge-score-in-python>
requirements.txt
ADDED
@@ -0,0 +1,12 @@
spacy>=3.0.0                 # Text processing and tokenization
numpy>=1.19.0                # General numerical computation
tqdm>=4.64.0                 # Progress bar
torch>=1.10.0                # PyTorch, for deep learning
tensorflow>=2.6.0            # TensorFlow, for deep learning
tensorflow-hub>=0.12.0       # Pretrained model hub for TensorFlow
transformers>=4.21.0         # Hugging Face Transformers library
rouge-score>=0.1.2           # ROUGE metric for evaluation
language-tool-python>=2.7.1  # Grammar checking and text correction
scikit-learn>=1.0.0          # Machine learning tools
nlpaug>=1.1.0                # Data augmentation for NLP
nltk>=3.6.0                  # Natural language toolkit
schema_guided_dialogue_processor.py
ADDED
@@ -0,0 +1,192 @@
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Any
import json
import glob
from pathlib import Path
from pipeline_config import PipelineConfig

@dataclass
class SchemaGuidedDialogue:
    """
    Structured representation of a Schema-Guided dialogue
    """
    dialogue_id: str
    service_name: str
    service_description: Optional[str]
    schema: Dict[str, Any]
    turns: List[Dict[str, Any]]
    original_metadata: Dict[str, Any] = field(default_factory=dict)

class SchemaGuidedProcessor:
    """
    Handles processing and preparation of Schema-Guided dataset dialogues
    """
    def __init__(self, config: PipelineConfig):
        self.config = config
        self.services = set()
        self.domains = set()
        self.schemas = {}

    def load_dataset(self, base_dir, max_examples: Optional[int] = None) -> List[SchemaGuidedDialogue]:
        """
        Load and parse the Schema-Guided Dialogue dataset.

        Args:
            base_dir: Directory containing schema.json and the dialogues_*.json files
            max_examples: Optional cap on the number of dialogues to load
        """
        # Define schema and dialogue file patterns
        schema_file = Path(base_dir, "schema.json")
        dialogue_files_pattern = str(Path(base_dir, "dialogues_*.json"))

        # Check for schema file
        if not schema_file.exists():
            raise FileNotFoundError(f"Schema file not found at {schema_file}")

        # Load schemas
        self.schemas = self._load_schemas(schema_file)

        # Find and validate dialogue files
        dialogue_files = glob.glob(dialogue_files_pattern)
        if not dialogue_files:
            raise FileNotFoundError(f"No dialogue files found matching pattern {dialogue_files_pattern}")

        print(f"Found {len(dialogue_files)} dialogue files to process.")

        # Process all dialogues
        processed_dialogues = []
        for file_path in dialogue_files:
            with open(file_path, 'r', encoding='utf-8') as f:
                raw_dialogues = json.load(f)

            for dialogue in raw_dialogues:
                processed_dialogues.append(self._process_single_dialogue(dialogue))

                if max_examples and len(processed_dialogues) >= max_examples:
                    # Return here so the cap applies across all files,
                    # not just the file currently being read
                    return processed_dialogues

        return processed_dialogues

    def _process_single_dialogue(self, dialogue: Dict[str, Any]) -> SchemaGuidedDialogue:
        """
        Process a single dialogue JSON object into a SchemaGuidedDialogue object.
        """
        dialogue_id = str(dialogue.get("dialogue_id", ""))
        services = dialogue.get("services", [])
        service_name = services[0] if services else None
        schema = self.schemas.get(service_name, {})
        service_description = schema.get("description", "")

        # Process turns
        turns = self._process_turns(dialogue.get("turns", []))

        # Store metadata
        metadata = {
            "services": services,
            "original_id": dialogue_id,
        }

        return SchemaGuidedDialogue(
            dialogue_id=f"schema_guided_{dialogue_id}",
            service_name=service_name,
            service_description=service_description,
            schema=schema,
            turns=turns,
            original_metadata=metadata,
        )

    def _validate_schema(self, schema: Dict[str, Any]) -> bool:
        """
        Validate a schema
        """
        required_keys = {"service_name", "description", "slots", "intents"}
        missing_keys = required_keys - schema.keys()
        if missing_keys:
            print(f"Warning: Missing keys in schema {schema.get('service_name', 'unknown')}: {missing_keys}")
            return False
        return True

    def _load_schemas(self, schema_path: Path) -> Dict[str, Any]:
        """
        Load and process service schemas
        """
        with open(schema_path, 'r', encoding='utf-8') as f:
            schemas = json.load(f)

        # Validate and index schemas
        return {
            schema["service_name"]: schema for schema in schemas if self._validate_schema(schema)
        }

    def _process_turns(self, turns: List[Dict]) -> List[Dict]:
        """
        Process dialogue turns into standardized format
        """
        processed_turns = []

        for turn in turns:
            try:
                # Map speakers to standard format
                speaker = 'assistant' if turn.get('speaker') == 'SYSTEM' else 'user'

                # Extract utterance and clean it
                text = turn.get('utterance', '').strip()

                # Extract frames and dialogue acts
                frames = turn.get('frames', [])
                acts = []
                slots = []

                for frame in frames:
                    if 'actions' in frame:
                        acts.extend(frame['actions'])
                    if 'slots' in frame:
                        slots.extend(frame['slots'])

                # Create the processed turn
                processed_turn = {
                    'speaker': speaker,
                    'text': text,
                    'original_speaker': turn.get('speaker', ''),
                    'dialogue_acts': acts,
                    'slots': slots,
                    'metadata': {k: v for k, v in turn.items()
                                 if k not in {'speaker', 'utterance', 'frames'}}
                }

                processed_turns.append(processed_turn)
            except Exception as e:
                print(f"Error processing turn: {str(e)}")
                continue

        return processed_turns

    def convert_to_pipeline_format(self, schema_dialogues: List[SchemaGuidedDialogue]) -> List[Dict]:
        """
        Convert SchemaGuidedDialogues to the format expected by the ProcessingPipeline
        """
        pipeline_dialogues = []

        for dialogue in schema_dialogues:
            # Convert turns to the expected format, skipping empty turns
            processed_turns = [
                {"speaker": turn["speaker"], "text": turn["text"]}
                for turn in dialogue.turns if turn["text"].strip()
            ]

            # Create dialogue in pipeline format
            pipeline_dialogue = {
                'dialogue_id': dialogue.dialogue_id,
                'turns': processed_turns,
                'metadata': {
                    'service_name': dialogue.service_name,
                    'service_description': dialogue.service_description,
                    'schema': dialogue.schema,
                    **dialogue.original_metadata
                }
            }

            pipeline_dialogues.append(pipeline_dialogue)

        return pipeline_dialogues
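A usage sketch for this processor follows; the directory path is hypothetical, but `load_dataset` requires it to contain `schema.json` plus one or more `dialogues_*.json` files, as checked above.

```python
# Sketch only: "datasets/schema_guided/train" is a hypothetical path.
from pipeline_config import PipelineConfig
from schema_guided_dialogue_processor import SchemaGuidedProcessor

processor = SchemaGuidedProcessor(PipelineConfig())
schema_dialogues = processor.load_dataset("datasets/schema_guided/train", max_examples=100)
pipeline_ready = processor.convert_to_pipeline_format(schema_dialogues)
print(f"Prepared {len(pipeline_ready)} dialogues for the pipeline.")
```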
setup.py
ADDED
@@ -0,0 +1,100 @@
from setuptools import setup, find_packages
import subprocess
import sys

with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

with open("requirements.txt", "r", encoding="utf-8") as fh:
    requirements = [line.strip() for line in fh if line.strip() and not line.startswith("#")]

def setup_spacy_model():
    """
    Download the spaCy model.
    """
    subprocess.check_call([sys.executable, "-m", "spacy", "download", "en_core_web_sm"])

def setup_models():
    """
    Download the other required models.
    """
    import tensorflow_hub as hub
    from transformers import (
        AutoTokenizer,
        GPT2TokenizerFast,
        MarianTokenizer
    )

    # Download Universal Sentence Encoder
    _ = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')

    # Download paraphraser tokenizer (model weights are fetched on first use)
    _ = AutoTokenizer.from_pretrained('humarin/chatgpt_paraphraser_on_T5_base')

    # Download translation tokenizers for the back-translation chain
    source_lang, pivot_lang, target_lang = 'en', 'de', 'es'
    model_names = [
        f'Helsinki-NLP/opus-mt-{source_lang}-{pivot_lang}',
        f'Helsinki-NLP/opus-mt-{pivot_lang}-{target_lang}',
        f'Helsinki-NLP/opus-mt-{target_lang}-{source_lang}'
    ]
    for model_name in model_names:
        _ = MarianTokenizer.from_pretrained(model_name)

    # Download GPT-2 tokenizer
    _ = GPT2TokenizerFast.from_pretrained('gpt2')

def setup_nltk():
    """
    Download required NLTK data.
    """
    import nltk
    required_packages = [
        'wordnet',
        'averaged_perceptron_tagger_eng'
    ]

    for package in required_packages:
        try:
            print(f"Downloading {package}...")
            nltk.download(package)
            print(f"Successfully downloaded {package}")
        except Exception as e:
            print(f"Warning: Could not download {package}: {str(e)}")

setup(
    name="text-data-augmenter",
    version="0.1.0",
    author="Joe Armani",
    author_email="[email protected]",
    description="A tool for generating high-quality dialogue variations",
    long_description=long_description,
    long_description_content_type="text/markdown",
    packages=find_packages(),
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Text Processing :: Linguistic",
    ],
    python_requires=">=3.8",
    install_requires=requirements,
    entry_points={
        "console_scripts": [
            "dialogue-augment=dialogue_augmenter.main:main",
        ],
    },
    include_package_data=True,
    package_data={
        "dialogue_augmenter": ["data/*.json", "config/*.yaml"],
    },
)

if __name__ == '__main__':
    setup_spacy_model()
    setup_models()
    setup_nltk()
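After `pip install -e .`, a quick smoke test can confirm the pre-downloaded resources load from the local cache. This sketch simply re-loads the same resources that `setup_spacy_model` and `setup_models` fetch above.

```python
# Post-install smoke test (sketch): each call should hit the local cache.
import spacy
import tensorflow_hub as hub
from transformers import GPT2TokenizerFast, MarianTokenizer

spacy.load('en_core_web_sm')
hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
GPT2TokenizerFast.from_pretrained('gpt2')
MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
print("Core resources loaded successfully.")
```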
taskmaster_processor.py
ADDED
@@ -0,0 +1,192 @@
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Any
import json
import re
from pathlib import Path
from pipeline_config import PipelineConfig

@dataclass
class TaskmasterDialogue:
    """
    Structured representation of a Taskmaster dialogue
    """
    conversation_id: str
    instruction_id: Optional[str]
    scenario: Optional[str]
    domain: Optional[str]
    turns: List[Dict[str, Any]]
    original_metadata: Dict[str, Any] = field(default_factory=dict)

    def __str__(self):
        return f"TaskmasterDialogue(conversation_id={self.conversation_id}, turns={len(self.turns)} turns)"

    def validate(self) -> bool:
        return bool(self.conversation_id and isinstance(self.turns, list))

class TaskmasterProcessor:
    """
    Handles processing and preparation of Taskmaster dataset dialogues
    """
    def __init__(self, config: PipelineConfig, use_ontology: bool = False):
        self.config = config
        self.use_ontology = use_ontology  # Whether to load and use the ontology
        self.ontology = None              # Holds ontology data once loaded
        self.domains = set()              # Tracks unique domains
        self.scenarios = set()            # Tracks unique scenarios

    def load_dataset(self, base_dir: str, max_examples: Optional[int] = None) -> List[TaskmasterDialogue]:
        """
        Load and parse the Taskmaster JSON dataset.
        Handles self-dialogs, woz-dialogs, and ontology files.
        """
        required_files = {
            "self-dialogs": "self-dialogs.json",
            "woz-dialogs": "woz-dialogs.json",
            "ontology": "ontology.json",
        }

        # Check for required files
        missing_files = [name for name, path in required_files.items() if not Path(base_dir, path).exists()]
        if missing_files:
            raise FileNotFoundError(f"Missing required taskmaster files: {missing_files}")

        # Load ontology
        ontology_path = Path(base_dir, required_files['ontology'])
        with open(ontology_path, 'r', encoding='utf-8') as f:
            self.ontology = json.load(f)

        processed_dialogues = []
        for file_key in ["self-dialogs", "woz-dialogs"]:
            file_path = Path(base_dir, required_files[file_key])
            with open(file_path, 'r', encoding='utf-8') as f:
                raw_data = json.load(f)

            for dialogue in raw_data:
                # Extract core dialogue components
                conversation_id = dialogue.get('conversation_id', '')
                instruction_id = dialogue.get('instruction_id', None)

                if 'utterances' in dialogue:
                    turns = self._process_utterances(dialogue['utterances'])
                    scenario = dialogue.get('scenario', '')
                    domain = self._extract_domain(scenario)
                else:
                    turns = []
                    scenario = ''
                    domain = ''

                # Store metadata
                metadata = {k: v for k, v in dialogue.items()
                            if k not in {'conversation_id', 'instruction_id', 'utterances'}}

                # Create structured dialogue object
                processed_dialogue = TaskmasterDialogue(
                    conversation_id=conversation_id,
                    instruction_id=instruction_id,
                    scenario=scenario,
                    domain=domain,
                    turns=turns,
                    original_metadata=metadata
                )

                processed_dialogues.append(processed_dialogue)

                # Update domain and scenario tracking
                if domain:
                    self.domains.add(domain)
                if scenario:
                    self.scenarios.add(scenario)

                if max_examples and len(processed_dialogues) >= max_examples:
                    # Return here so the cap applies across both files,
                    # not just the file currently being read
                    return processed_dialogues

        return processed_dialogues

    def _process_utterances(self, utterances: List[Dict]) -> List[Dict]:
        """
        Process utterances into a standardized format
        """
        processed_turns = []

        for utterance in utterances:
            # Map Taskmaster speaker roles to the standard format
            speaker = 'assistant' if utterance.get('speaker') == 'ASSISTANT' else 'user'

            # Extract and clean the text
            text = utterance.get('text', '').strip()

            # Extract any segments or annotations if present
            segments = utterance.get('segments', [])

            # Create the processed turn
            turn = {
                'speaker': speaker,
                'text': text,
                'original_speaker': utterance.get('speaker', ''),
                'segments': segments,
                'metadata': {k: v for k, v in utterance.items()
                             if k not in {'speaker', 'text', 'segments'}}
            }

            processed_turns.append(turn)

        return processed_turns

    def _extract_domain(self, scenario: str) -> str:
        """
        Extract domain from scenario description
        """
        domain_patterns = {
            'restaurant': r'\b(restaurant|dining|food|reservation)\b',
            'movie': r'\b(movie|cinema|film|ticket)\b',
            'ride_share': r'\b(ride|taxi|uber|lyft)\b',
            'coffee': r'\b(coffee|café|cafe|starbucks)\b',
            'pizza': r'\b(pizza|delivery|order food)\b',
            'auto': r'\b(car|vehicle|repair|maintenance)\b',
        }

        scenario_lower = scenario.lower()

        for domain, pattern in domain_patterns.items():
            if re.search(pattern, scenario_lower):
                return domain

        return 'other'

    def convert_to_pipeline_format(self, taskmaster_dialogues: List[TaskmasterDialogue]) -> List[Dict]:
        """
        Convert TaskmasterDialogues to the format expected by the ProcessingPipeline
        """
        pipeline_dialogues = []

        for dialogue in taskmaster_dialogues:
            # Convert turns to the expected format
            processed_turns = []
            for turn in dialogue.turns:
                if turn['text'].strip():  # Skip empty turns
                    processed_turns.append({
                        'speaker': turn['speaker'],
                        'text': turn['text']
                    })

            # Create dialogue in pipeline format
            pipeline_dialogue = {
                'dialogue_id': dialogue.conversation_id,
                'turns': processed_turns,
                'metadata': {
                    'instruction_id': dialogue.instruction_id,
                    'scenario': dialogue.scenario,
                    'domain': dialogue.domain,
                    **dialogue.original_metadata
                }
            }

            pipeline_dialogues.append(pipeline_dialogue)

        return pipeline_dialogues
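Finally, a usage sketch for this processor; the base directory is hypothetical, but it must contain `self-dialogs.json`, `woz-dialogs.json`, and `ontology.json`, as validated above.

```python
# Sketch only: "datasets/taskmaster" is a hypothetical path.
from pipeline_config import PipelineConfig
from taskmaster_processor import TaskmasterProcessor

processor = TaskmasterProcessor(PipelineConfig(), use_ontology=False)
dialogues = processor.load_dataset("datasets/taskmaster", max_examples=100)
print(f"Loaded {len(dialogues)} dialogues; domains seen: {sorted(processor.domains)}")
pipeline_ready = processor.convert_to_pipeline_format(dialogues)
```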