File size: 30,050 Bytes
3190e1e
 
bc503de
3190e1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc503de
 
 
 
 
 
 
 
 
300fe5d
bc503de
3190e1e
300fe5d
 
 
3190e1e
 
bc503de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3190e1e
300fe5d
bc503de
300fe5d
bc503de
 
3190e1e
 
 
 
300fe5d
 
 
3190e1e
 
 
 
 
 
 
 
bc503de
 
 
 
 
3190e1e
 
bc503de
 
 
 
 
 
 
 
 
 
 
3190e1e
 
 
300fe5d
3190e1e
 
 
bc503de
3190e1e
 
300fe5d
 
3190e1e
300fe5d
3190e1e
 
 
 
300fe5d
3190e1e
 
 
300fe5d
 
bc503de
3190e1e
 
300fe5d
 
 
 
 
 
 
 
 
3190e1e
300fe5d
 
3190e1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300fe5d
 
 
 
 
3190e1e
 
 
 
 
300fe5d
3190e1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300fe5d
3190e1e
 
300fe5d
3190e1e
 
300fe5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3190e1e
 
 
300fe5d
3190e1e
 
 
300fe5d
 
3190e1e
300fe5d
3190e1e
 
300fe5d
3190e1e
 
300fe5d
3190e1e
 
 
 
 
 
300fe5d
3190e1e
 
300fe5d
3190e1e
 
 
 
300fe5d
3190e1e
 
 
 
 
300fe5d
3190e1e
 
 
 
 
 
 
 
300fe5d
3190e1e
 
 
 
300fe5d
 
3190e1e
 
300fe5d
3190e1e
 
 
 
300fe5d
 
3190e1e
 
300fe5d
3190e1e
 
300fe5d
 
3190e1e
 
 
300fe5d
3190e1e
 
 
 
300fe5d
 
 
 
 
 
 
 
3190e1e
 
300fe5d
3190e1e
 
300fe5d
3190e1e
 
300fe5d
3190e1e
 
 
300fe5d
3190e1e
 
 
 
300fe5d
3190e1e
 
 
300fe5d
3190e1e
 
 
300fe5d
3190e1e
 
 
 
 
 
 
300fe5d
3190e1e
 
300fe5d
3190e1e
 
 
300fe5d
3190e1e
300fe5d
3190e1e
 
 
 
300fe5d
3190e1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300fe5d
 
3190e1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300fe5d
 
 
 
 
 
 
 
 
 
 
 
 
3190e1e
300fe5d
 
 
3190e1e
300fe5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3190e1e
 
300fe5d
 
3190e1e
 
 
300fe5d
3190e1e
 
 
300fe5d
3190e1e
 
 
300fe5d
 
 
 
 
 
 
 
 
3190e1e
300fe5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3190e1e
 
300fe5d
3190e1e
300fe5d
3190e1e
300fe5d
3190e1e
 
 
300fe5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3190e1e
 
 
 
 
 
 
 
 
 
 
 
300fe5d
 
 
3190e1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300fe5d
3190e1e
300fe5d
3190e1e
 
 
 
 
 
300fe5d
 
3190e1e
300fe5d
 
 
3190e1e
 
 
300fe5d
 
3190e1e
300fe5d
3190e1e
 
300fe5d
 
 
 
3190e1e
300fe5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3190e1e
300fe5d
 
 
 
 
 
 
 
3190e1e
300fe5d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
from typing import Dict, List
import numpy as np
import torch
import tensorflow as tf
import tensorflow_hub as hub
from pipeline_config import PipelineConfig
from quality_metrics import QualityMetrics
from paraphraser import Paraphraser
import nlpaug.augmenter.word as naw
from functools import lru_cache
from sklearn.metrics.pairwise import cosine_similarity

class DialogueAugmenter:
    """
    Optimized dialogue augmentation with quality control and complexity management.
    """
    def __init__(self, nlp, config: PipelineConfig):
        """
        Set up hardware detection, quality metrics, the sentence encoder and
        the augmentation models.

        Args:
            nlp: Language pipeline object (stored for later use).
            config: Pipeline-wide configuration (debug flag, thresholds,
                augmentation factors, ...).
        """
        self.nlp = nlp
        self.config = config

        # Detect hardware and set appropriate batch sizes and optimization strategy
        self.use_gpu = torch.cuda.is_available()
        self.device = torch.device('cuda' if self.use_gpu else 'cpu')

        if self.config.debug:
            print(f"Using device: {self.device}")
            if self.use_gpu:
                print(f"GPU Device: {torch.cuda.get_device_name(0)}")

        self.quality_metrics = QualityMetrics(config)
        # Minimum cosine similarity for a variation to be kept (see
        # _filter_variations_batch)
        self.semantic_similarity_threshold = 0.75

        # Configure TF GPU memory growth BEFORE any TF op runs: TensorFlow
        # only honours set_memory_growth while the GPU context is still
        # uninitialized, and the hub.load() call below initializes it.
        # (Previously this ran after loading the model, where it could raise.)
        if self.use_gpu:
            gpus = tf.config.list_physical_devices('GPU')
            if gpus:
                try:
                    for gpu in gpus:
                        tf.config.experimental.set_memory_growth(gpu, True)
                except RuntimeError as e:
                    # Best effort: the context may already be initialized elsewhere
                    print(e)

        # Universal Sentence Encoder used for all embedding computations
        self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')

        # Initialize augmentation models based on hardware
        self._initialize_augmentation_models()

        # Text -> embedding cache shared by the embedding helpers
        self.embedding_cache = {}

    def _initialize_augmentation_models(self):
        """Build the paraphrase and synonym augmenters and register them.

        Populates ``self.augmenters`` with an 'advanced' tier (heavyweight
        paraphrase model) and a 'basic' tier (wordnet synonym replacement).
        """
        # Heavyweight paraphrase model; moved onto the GPU when one is present
        self.paraphraser = Paraphraser()
        if self.use_gpu:
            self.paraphraser.model = self.paraphraser.model.to(self.device)

        # Lightweight wordnet-based synonym replacement as a fallback tier
        self.word_augmenter = naw.SynonymAug(aug_src='wordnet')

        self.augmenters = {
            'advanced': [self.paraphraser],
            'basic': [('synonym', self.word_augmenter)],
        }

    def _compute_embedding(self, text: str) -> np.ndarray:
        """
        Return the sentence-encoder embedding for a single text, memoized in
        ``self.embedding_cache``.

        Note: the former ``@lru_cache`` decorator was removed — caching an
        instance method keys on ``self`` and keeps the instance (and its
        models) alive for the cache's lifetime, and it merely duplicated the
        explicit dict cache used here and by ``_compute_batch_embeddings``.
        """
        try:
            return self.embedding_cache[text]
        except KeyError:
            embedding = self.use_model([text])[0].numpy()
            self.embedding_cache[text] = embedding
            return embedding

    def _compute_batch_embeddings(self, texts: List[str]) -> np.ndarray:
        """
        Embed many texts with a single encoder call, reusing cached results.

        Only texts missing from ``self.embedding_cache`` are sent to the
        model; results are returned in the same order as ``texts``.
        """
        missing = [text for text in texts if text not in self.embedding_cache]
        if missing:
            computed = self.use_model(missing).numpy()
            # Record the fresh embeddings so future calls hit the cache
            self.embedding_cache.update(zip(missing, computed))

        # Assemble the output in input order from the (now complete) cache
        return np.array([self.embedding_cache[text] for text in texts])

    def _quick_quality_check(self, variation: str, original: str) -> bool:
        """
        Cheap pre-filter run before any embedding work.

        Rejects variations that are far too long relative to the original, or
        that share almost no content words with it. Returns True to keep the
        variation, False to discard it.
        """
        debug = self.config.debug
        if debug:
            print(f"\nQuick check for variation: {variation}")

        orig_len = len(original.split())
        var_len = len(variation.split())

        # Length gate: very short originals (<= 3 words) may grow up to 3x,
        # anything longer only up to 2x.
        if orig_len <= 3:
            limit, label = orig_len * 3, 'short text'
        else:
            limit, label = orig_len * 2, 'long text'
        if var_len > limit:
            if debug:
                print(f"Failed length check ({label}): {var_len} vs {orig_len}")
            return False

        # Content-word overlap gate, ignoring common stop words
        stop_words = {'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'is', 'are', 'that', 'this', 'will', 'can'}
        orig_words = {w.lower() for w in original.split() if w.lower() not in stop_words}
        var_words = {w.lower() for w in variation.split() if w.lower() not in stop_words}

        # Turns under 5 words skip the overlap check entirely
        if orig_len >= 5:
            overlap = len(orig_words & var_words) / len(orig_words) if orig_words else 0
            if overlap < 0.2:
                if debug:
                    print(f"Failed content check: overlap {overlap:.2f}")
                return False
        elif debug:
            print("Short turn detected (<5 words), skipping content overlap check")

        if debug:
            print("Passed all quick checks")
        return True

    def _filter_variations_batch(self, variations: List[str], context: List[str], original_turn: str) -> List[str]:
        """
        Filter candidate variations for one turn using batched embeddings.

        Pipeline: (1) short-circuit very short turns to predefined
        variations; (2) quick length/overlap checks against the original
        turn (skipped on the first turn); (3) semantic-similarity filter
        against the original; (4) context-coherence scoring against the
        previous turn, capped at ``config.max_variations_per_turn``.

        Args:
            variations: Candidate texts for this turn.
            context: Original texts of all previous turns.
            original_turn: The original text of the current turn.

        Returns:
            The surviving variations (possibly empty).
        """
        if not variations:
            return []
        
        if self.config.debug:
            print(f"\nStarting filtration of {len(variations)} variations")
            print(f"Context length: {len(context)}")
            print(f"Original turn: {original_turn}")
        
        words = original_turn.split()
        orig_len = len(words)

        # If very short text, consider adjusting thresholds
        is_very_short = orig_len < 5

        # Turns under 3 words bypass the whole pipeline: use the curated
        # short-text variations instead (already safe and minimal).
        if len(words) < 3:
            if self.config.debug:
                print("Short text detected, using predefined variations")
            short_text_variations = self._augment_short_text({'text': original_turn, 'speaker': ''})
            return [var['text'] for var in short_text_variations]

        # If this is the first turn (no context), be more lenient
        if not context:
            preliminary_filtered = variations
            if self.config.debug:
                print("First turn - skipping preliminary filtering")
        else:
            # Quick preliminary filtering against original turn
            preliminary_filtered = []
            for var in variations:
                passed = self._quick_quality_check(var, original_turn)
                if self.config.debug:
                    print(f"\nVariation: {var}")
                    print(f"Passed quick check: {passed}")
                if passed:
                    preliminary_filtered.append(var)

        if self.config.debug:
            print(f"Variations after quick check: {len(preliminary_filtered)}")

        if not preliminary_filtered:
            return []

        # Compute embeddings for original and variations
        original_embedding = self._compute_embedding(original_turn)
        variation_embeddings = self._compute_batch_embeddings(preliminary_filtered)

        # Compute similarities
        sims = cosine_similarity([original_embedding], variation_embeddings)[0]

        # If very short turn, slightly lower the semantic similarity threshold
        dynamic_sem_threshold = self.semantic_similarity_threshold
        if is_very_short:
            dynamic_sem_threshold = max(0.7, self.semantic_similarity_threshold - 0.05)

        # Filter by semantic similarity threshold
        refined_filtered = []
        for var, sim in zip(preliminary_filtered, sims):
            if sim >= dynamic_sem_threshold:
                refined_filtered.append(var)
            else:
                if self.config.debug:
                    print(f"Variation '{var}' discarded due to low semantic similarity: {sim:.3f}")

        if not refined_filtered:
            return []

        # Relax context coherence thresholds further if desired
        # We already have min_similarity = 0.1, min_coherence = 0.05
        # Let's lower them slightly more if the turn is very short:
        if is_very_short:
            min_similarity = 0.05
            min_coherence = 0.02
        else:
            min_similarity = 0.1
            min_coherence = 0.05

        # Only use last turn for coherence
        recent_context = [context[-1]] if context else []
        context_text = ' '.join(recent_context) if recent_context else ''

        if context_text:
            if self.config.debug:
                print(f"\nContext text: {context_text}")

            # Embed context and all surviving variations in one batch call
            all_texts = [context_text] + refined_filtered
            all_embeddings = self._compute_batch_embeddings(all_texts)

            context_embedding = all_embeddings[0]
            variation_embeddings = all_embeddings[1:]

            # Vectorized similarity computation
            context_similarities = cosine_similarity([context_embedding], variation_embeddings)[0]

            # Response coherence check
            # NOTE(review): with a single-turn context, context_similarities
            # and response_coherence compare against the same text.
            if recent_context:
                prev_embedding = self._compute_embedding(recent_context[-1])
                response_coherence = cosine_similarity([prev_embedding], variation_embeddings)[0]
            else:
                response_coherence = np.ones_like(context_similarities)

            filtered_variations = []
            for i, (variation, sim, coh) in enumerate(zip(
                refined_filtered, context_similarities, response_coherence)):
                # Weighted blend of context similarity and response coherence
                combined_score = (
                    self.config.context_similarity_weight * abs(sim) +
                    self.config.response_coherence_weight * abs(coh)
                )

                if self.config.debug:
                    print(f"\nVariation: {variation}")
                    print(f"Context similarity: {sim:.3f}")
                    print(f"Response coherence: {coh:.3f}")
                    print(f"Combined score: {combined_score:.3f}")

                # Accept if EITHER score is good enough
                if (combined_score >= min_similarity or abs(coh) >= min_coherence):
                    filtered_variations.append(variation)
                    if self.config.debug:
                        print("ACCEPTED")
                else:
                    if self.config.debug:
                        print("REJECTED")

                # If we have enough variations, stop
                if len(filtered_variations) >= self.config.max_variations_per_turn:
                    break
        else:
            # No context: keep the top survivors up to the per-turn cap
            filtered_variations = refined_filtered[:self.config.max_variations_per_turn]

        if self.config.debug:
            print(f"\nFinal filtered variations: {len(filtered_variations)}")

        return filtered_variations

    def _generate_variations_progressive(self, text: str, needed: int) -> List[str]:
        """
        Generate up to ``needed`` unique variations of ``text``.

        Tries the 'advanced' (paraphrase) augmenters first, then falls back
        to the 'basic' (synonym) augmenters if not enough were produced.
        Augmenter failures are logged and skipped (best effort).

        Returns:
            A list of unique, non-empty variations that differ from ``text``.
        """
        variations = set()

        if self.config.debug:
            print(f"\nAttempting to generate {needed} variations for text: {text}")

        # Fine-tune paraphraser here if needed: fewer beams, less diversity already done
        for augmenter in self.augmenters['advanced']:
            if len(variations) >= needed:
                break

            try:
                # Bug fix: initialize here so the filtering below can no
                # longer hit an UnboundLocalError when a non-Paraphraser
                # augmenter appears in the 'advanced' list.
                new_vars = []
                if isinstance(augmenter, Paraphraser):
                    if self.config.debug:
                        print("Trying paraphrase augmentation...")
                    new_vars = augmenter.paraphrase(
                        text,
                        num_return_sequences=needed - len(variations),
                        device=self.device if self.use_gpu else None,
                        num_beams=4,          # even fewer beams for more faithful paraphrases
                        num_beam_groups=1,
                        diversity_penalty=0.0
                    )
                    if self.config.debug:
                        print(f"Paraphraser generated {len(new_vars)} variations")

                # Drop empty strings and exact copies of the input
                variations.update(v for v in new_vars if v.strip() and v != text)

                if self.config.debug:
                    print(f"Current unique variations: {len(variations)}")

            except Exception as e:
                # Best effort: one failing augmenter must not abort the run
                print(f"Error in advanced augmentation: {str(e)}")
                continue

        # Try basic augmenters if needed
        if len(variations) < needed:
            if self.config.debug:
                print("Not enough variations, trying basic augmenters...")

            for aug_type, augmenter in self.augmenters['basic']:
                if len(variations) >= needed:
                    break

                try:
                    if self.config.debug:
                        print(f"Trying {aug_type} augmentation...")

                    # nlpaug may return either a list or a single string
                    new_vars = augmenter.augment(text, n=2)
                    if isinstance(new_vars, list):
                        variations.update(v for v in new_vars if v.strip() and v != text)
                    elif new_vars.strip() and new_vars != text:
                        variations.add(new_vars)

                    if self.config.debug:
                        print(f"After {aug_type}, total variations: {len(variations)}")

                except Exception as e:
                    print(f"Error in {aug_type} augmentation: {str(e)}")
                    continue

        variations_list = list(variations)

        if self.config.debug:
            print(f"Final number of variations generated: {len(variations_list)}")
            if not variations_list:
                print("WARNING: No variations were generated!")

        return variations_list

    def augment_dialogue(self, dialogue: Dict) -> List[Dict]:
        """
        Create augmented versions of a dialogue.

        Args:
            dialogue: {'dialogue_id': str, 'turns': [{'speaker', 'text'}, ...]}

        Returns:
            A list whose first element is the (possibly truncated) original
            dialogue, followed by up to ``config.augmentation_factor``
            augmented dialogues.
        """
        # Truncate overly long dialogues WITHOUT mutating the caller's dict
        # (the previous implementation rewrote dialogue['turns'] in place,
        # leaking the truncation back to the caller).
        turns = dialogue['turns']
        if len(turns) > self.config.max_turns_per_dialogue:
            if self.config.debug:
                print(f"Truncating dialogue from {len(turns)} to {self.config.max_turns_per_dialogue} turns")
            turns = turns[:self.config.max_turns_per_dialogue]
        working_dialogue = {**dialogue, 'turns': turns}

        turn_variations = []
        context = []

        # Process each turn with progressive generation; the ORIGINAL turn
        # text (not a variation) feeds the context for coherence filtering.
        for turn in turns:
            original_text = turn['text']
            variations = self._generate_variations_progressive(
                original_text,
                self.config.max_variations_per_turn
            )

            # Batch filter variations against context and the original text
            filtered_variations = self._filter_variations_batch(
                variations,
                context,
                original_text
            )

            # Re-attach speaker info to each surviving variation
            turn_vars = [{'speaker': turn['speaker'], 'text': v} for v in filtered_variations]

            if self.config.debug:
                print(f"Turn {len(turn_variations)}: Generated {len(turn_vars)} variations")

            turn_variations.append(turn_vars)
            context.append(original_text)

        # Build candidate dialogues and keep the best-scoring ones
        augmented_dialogues = self._generate_dialogue_combinations(
            working_dialogue['dialogue_id'],
            turn_variations,
            working_dialogue
        )

        # Always emit the original dialogue first
        result = [{
            'dialogue_id': f"{working_dialogue['dialogue_id']}_original",
            'turns': turns
        }]

        # Add unique augmentations up to the configured factor
        result.extend(augmented_dialogues[:self.config.augmentation_factor])

        if self.config.debug:
            print(f"Generated {len(result)-1} unique augmented dialogues")

        return result

    def _variation_score(self, original: str, variation: str) -> float:
        """
        Score a single variation against its original turn; higher is better.

        Weighted blend of the quality metrics: 70% semantic similarity,
        30% content preservation. Adjust the weights as needed.
        """
        metrics = self.quality_metrics.compute_metrics(original, variation)
        semantic_part = metrics['semantic_similarity'] * 0.7
        content_part = metrics['content_preservation'] * 0.3
        return semantic_part + content_part

    def _dialogue_quality_score(self, dialogue: Dict, original_dialogue: Dict) -> float:
        """
        Score an entire augmented dialogue against the original.

        Computed as the mean turn-wise cosine similarity between each
        augmented turn and its corresponding original turn. Returns 0.0
        for an empty dialogue.
        """
        scores = []
        for orig_turn, aug_turn in zip(original_dialogue['turns'], dialogue['turns']):
            emb_orig = self._compute_embedding(orig_turn['text'])
            emb_aug = self._compute_embedding(aug_turn['text'])
            # Plain cosine similarity between the two embedding vectors
            denom = np.linalg.norm(emb_orig) * np.linalg.norm(emb_aug)
            scores.append((emb_orig @ emb_aug) / denom)

        # Could also incorporate diversity checks, content overlap, etc.
        return float(np.mean(scores)) if scores else 0.0

    def _generate_dialogue_combinations(self, dialogue_id: str, turn_variations: List[List[Dict]], original_dialogue: Dict) -> List[Dict]:
        """
        Generate dialogue combinations using a more controlled approach:
        - Include the original turn as a fallback variation for each turn.
        - Sort variations by a quality score.
        - Ensure a balanced augmentation by requiring at least some turns to be augmented.
        - Over-generate and then select top dialogues by quality.

        Args:
            dialogue_id: Base id; candidates are named ``{id}_aug_{n}``.
            turn_variations: Per-turn lists of {'speaker', 'text'} variants
                (mutated in place: originals appended, lists sorted).
            original_dialogue: The source dialogue, used for fallbacks and
                for the "enough turns augmented" requirement.

        Returns:
            Up to ``config.augmentation_factor`` candidate dialogues, best
            quality first.
        """
        # Over-generate factor: create more candidates than needed
        over_generate_factor = self.config.augmentation_factor * 2

        # Add the original turn as a fallback variation for each turn if not present
        for i, turn_variants in enumerate(turn_variations):
            original_turn_text = None
            # Check if we previously stored original turn text with a marker or just use the original dialogue
            # If you previously used "|ORIGINAL|" marker, handle it here. Otherwise, just get from original_dialogue.
            original_turn_text = original_dialogue['turns'][i]['text']

            # Add the original turn as a variation if not already included
            if not any(v['text'] == original_turn_text for v in turn_variants):
                turn_variants.append({
                    'speaker': original_dialogue['turns'][i]['speaker'],
                    'text': original_turn_text
                })

            # Sort variations by score (best first; index 0 = top variation)
            original_text = original_dialogue['turns'][i]['text']
            turn_variants.sort(key=lambda v: self._variation_score(original_text, v['text']), reverse=True)

        # Closure state shared by the recursive generator below
        augmented_dialogues = []
        used_combinations = set()

        def generate_candidates(current_turns=None, turn_index=0):
            # Depth-first walk over per-turn candidate choices; one complete
            # path through all turns yields one candidate dialogue.
            if current_turns is None:
                current_turns = []
            
            if len(augmented_dialogues) >= over_generate_factor:
                return

            if turn_index == len(turn_variations):
                # Completed a candidate dialogue
                dialogue_fingerprint = " | ".join(turn['text'] for turn in current_turns)
                if dialogue_fingerprint not in used_combinations:
                    used_combinations.add(dialogue_fingerprint)
                    # Check if we have enough augmented turns
                    aug_count = sum(1 for orig, curr in zip(original_dialogue['turns'], current_turns) 
                                    if orig['text'] != curr['text'])
                    # Require at least half the turns to be augmented, for example
                    if aug_count >= max(1, len(turn_variations)//2):
                        augmented_dialogues.append({
                            'dialogue_id': f"{dialogue_id}_aug_{len(augmented_dialogues)}",
                            'turns': current_turns.copy()
                        })
                return

            turn_candidates = turn_variations[turn_index]

            # If no variations are available for this turn, let's just return without error.
            # Normally, this shouldn't happen since we always add the original turn above.
            if not turn_candidates:
                # If you want to at least have the original turn, add it now:
                original_text = original_dialogue['turns'][turn_index]['text']
                turn_candidates.append({
                    'speaker': original_dialogue['turns'][turn_index]['speaker'],
                    'text': original_text
                })

            # After the fallback, if still empty for some reason, just return.
            if not turn_candidates:
                return

            # Example strategy:
            # 1. Always try the top variation (most semantically similar).
            # 2. If available and allowed, pick a mid-ranked variation for diversity.
            # 3. Include the original turn if not selected yet.

            num_vars = min(self.config.max_sampled_variations, len(turn_candidates))

            # Always include top variation
            candidates_to_pick = [turn_candidates[0]]

            # If we have more than 2 variations and can pick more, add a middle variation for diversity
            if len(turn_candidates) > 2 and num_vars > 1:
                mid_index = len(turn_candidates)//2
                candidates_to_pick.append(turn_candidates[mid_index])

            # If we still have room for another variation, try adding the original turn if not included
            if num_vars > len(candidates_to_pick):
                original_turn_text = original_dialogue['turns'][turn_index]['text']
                orig_candidate = next((v for v in turn_candidates if v['text'] == original_turn_text), None)
                if orig_candidate and orig_candidate not in candidates_to_pick:
                    candidates_to_pick.append(orig_candidate)

            # Shuffle candidates to produce different dialogues
            np.random.shuffle(candidates_to_pick)

            # Recurse on each chosen candidate for this turn
            for variation in candidates_to_pick:
                if len(augmented_dialogues) >= over_generate_factor:
                    return
                current_turns.append(variation)
                generate_candidates(current_turns, turn_index + 1)
                current_turns.pop()

        try:
            generate_candidates()
        except Exception as e:
            print(f"Error in dialogue generation: {str(e)}")
            return []

        # Over-generated set of augmented dialogues is now available
        # Let's score them and pick the top ones
        scored_dialogues = []
        for d in augmented_dialogues:
            score = self._dialogue_quality_score(d, original_dialogue)
            scored_dialogues.append((score, d))
        
        scored_dialogues.sort(key=lambda x: x[0], reverse=True)
        # Pick top `augmentation_factor` dialogues
        final_dialogues = [d for _, d in scored_dialogues[:self.config.augmentation_factor]]
        
        return final_dialogues
    # def _generate_dialogue_combinations(self, dialogue_id: str, turn_variations: List[List[Dict]]) -> List[Dict]:
    #     """
    #     Generate dialogue combinations using sampling
    #     """
    #     augmented_dialogues = []
    #     used_combinations = set()
        
    #     def generate_dialogues(current_turns=None, turn_index=0):
    #         if current_turns is None:
    #             current_turns = []
            
    #         if len(augmented_dialogues) >= self.config.augmentation_factor:
    #             return

    #         if turn_index == len(turn_variations):
    #             dialogue_fingerprint = " | ".join(turn['text'] for turn in current_turns)
    #             if dialogue_fingerprint not in used_combinations:
    #                 used_combinations.add(dialogue_fingerprint)
    #                 augmented_dialogues.append({
    #                     'dialogue_id': f"{dialogue_id}_aug_{len(augmented_dialogues)}",
    #                     'turns': current_turns.copy()
    #                 })
    #             return
            
    #         variations = list(turn_variations[turn_index])
    #         np.random.shuffle(variations)
            
    #         for variation in variations[:self.config.max_sampled_variations]:
    #             if len(augmented_dialogues) >= self.config.augmentation_factor:
    #                 return
    #             current_turns.append(variation)
    #             generate_dialogues(current_turns, turn_index + 1)
    #             current_turns.pop()
        
    #     try:
    #         generate_dialogues()
    #     except Exception as e:
    #         print(f"Error in dialogue generation: {str(e)}")
    #         return []
        
    #     return augmented_dialogues

    def _is_dialogue_duplicate(self, dialogue1: Dict, dialogue2: Dict) -> bool:
        """
        Return True if two dialogues have identical turn texts, turn by turn.

        Compares the turn-text lists directly instead of space-joining them,
        which could report false positives when turn boundaries differ
        (e.g. ["a b", "c"] vs ["a", "b c"] both joined to "a b c").
        """
        texts1 = [turn['text'] for turn in dialogue1['turns']]
        texts2 = [turn['text'] for turn in dialogue2['turns']]
        return texts1 == texts2
    
    def _augment_short_text(self, turn: Dict) -> List[Dict]:
        """
        Special handling for very short texts with predefined variations.
        If predefined variations are found, return them directly.
        Otherwise, produce simple punctuation and capitalization variants.
        Skip heavy quality checks for efficiency. These variations are safe and minimal.

        Args:
            turn: {'speaker': str, 'text': str}

        Returns:
            Up to ``config.augmentation_factor`` variation dicts with the
            same speaker.
        """
        text = turn['text']
        # Curated replacements keyed by common short-utterance categories
        common_variations = {
            'goodbye': [
                'Bye!', 'Farewell!', 'See you!', 'Take care!',
                'Goodbye!', 'Bye for now!', 'Until next time!'
            ],
            'hello': [
                'Hi!', 'Hey!', 'Hello!', 'Greetings!', 
                'Good day!', 'Hi there!', 'Hello there!'
            ],
            'yes': [
                'Yes!', 'Correct!', 'Indeed!', 'Absolutely!', 
                'That\'s right!', 'Definitely!', 'Sure!'
            ],
            'no': [
                'No!', 'Nope!', 'Not at all!', 'Negative!',
                'Unfortunately not!', 'I\'m afraid not!'
            ],
            'thanks': [
                'Thank you!', 'Thanks a lot!', 'Many thanks!',
                'I appreciate it!', 'Thank you so much!'
            ],
            'ok': [
                'Okay!', 'Alright!', 'Sure!', 'Got it!',
                'Understood!', 'Fine!', 'Great!', 'Perfect!',
                'That works!', 'Sounds good!'
            ],
            'good': [
                'Great!', 'Excellent!', 'Perfect!', 'Wonderful!',
                'Fantastic!', 'Amazing!', 'Terrific!'
            ]
        }

        text_lower = text.lower().rstrip('!.,?')
        # Check if text matches any predefined category (substring either way)
        variations = []
        for key, predefined_vars in common_variations.items():
            if key in text_lower or text_lower in key:
                variations.extend(predefined_vars)
        
        if not variations:
            # Generate simple punctuation and capitalization variations if no predefined match
            base = text.rstrip('!.,?')
            variations = [
                base + '!',
                base + '.',
                base
            ]
            
            # Add capitalization variations
            capitalized = [v.capitalize() for v in variations if v.capitalize() not in variations]
            variations.extend(capitalized)
        
        # Bug fix: list(set(...)) ordered the pool by hash, so the slice
        # below selected a NONDETERMINISTIC subset. dict.fromkeys removes
        # duplicates while preserving first-seen order.
        unique_variations = list(dict.fromkeys(variations))
        
        # Directly return these variations, as they are minimal and trusted
        # No further quality checks are needed
        result_variations = unique_variations[:self.config.augmentation_factor]
        return [{'speaker': turn['speaker'], 'text': v} for v in result_variations]
    
    def process_batch(self, batch: List[Dict]) -> List[Dict]:
        """
        Augment a batch of dialogues, pre-warming the embedding cache with
        every turn text so per-dialogue processing hits the cache.

        Args:
            batch: List of dialogue dicts (see ``augment_dialogue``).

        Returns:
            Flat list of original + augmented dialogues; dialogues that
            raise are logged and skipped (best effort).
        """
        results = []

        # Collect every turn text in the batch. (Removed the unused
        # ``text_to_embedding`` dict from the previous implementation.)
        all_texts = [turn['text'] for dialogue in batch for turn in dialogue['turns']]

        # One batched encoder call; _compute_batch_embeddings already writes
        # each result into self.embedding_cache, so no manual caching needed.
        if all_texts:
            self._compute_batch_embeddings(all_texts)

        # Process each dialogue using the now-warm cache
        for dialogue in batch:
            try:
                results.extend(self.augment_dialogue(dialogue))
            except Exception as e:
                # Best effort: keep processing the rest of the batch
                print(f"Error processing dialogue {dialogue.get('dialogue_id', 'unknown')}: {e}")
                continue

        return results