JoeArmani committed · Commit bc503de · Parent(s): febdb1e

update gpu processing

Files changed:
- .gitignore +3 -0
- dialogue_augmenter.py +69 -232
- main.py +7 -19
- processing_pipeline.py +143 -40
.gitignore CHANGED
@@ -156,3 +156,6 @@ cython_debug/
 
 datasets/*
 !datasets/.gitkeep
+
+processed_outputs/*
+!processed_outputs/.gitkeep
dialogue_augmenter.py CHANGED
@@ -1,5 +1,6 @@
 from typing import Dict, List
 import numpy as np
+import torch
 import tensorflow as tf
 import tensorflow_hub as hub
 import re
@@ -19,13 +20,53 @@ class DialogueAugmenter:
     def __init__(self, nlp, config: PipelineConfig):
         self.nlp = nlp
         self.config = config
+
+        # Detect hardware and set appropriate batch sizes and optimization strategy
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.use_gpu = torch.cuda.is_available()
+
+        if self.config.debug:
+            print(f"Using device: {self.device}")
+            if self.use_gpu:
+                print(f"GPU Device: {torch.cuda.get_device_name(0)}")
+
+        # Load base models
         self.quality_metrics = QualityMetrics(config)
         self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
 
+        # Initialize augmentation models based on hardware
+        self._initialize_augmentation_models()
+
+        # Initialize caches
+        self.embedding_cache = {}
+        self.perplexity_cache = {}
+
+        # Compile regex patterns
+        self.spelling_pattern = re.compile(r'[a-zA-Z]{3,}')
+
+        # GPU memory management if available
+        if self.use_gpu:
+            gpus = tf.config.list_physical_devices('GPU')
+            if gpus:
+                try:
+                    for gpu in gpus:
+                        tf.config.experimental.set_memory_growth(gpu, True)
+                except RuntimeError as e:
+                    print(e)
+
+    def _initialize_augmentation_models(self):
+        """Initialize augmentation models with appropriate device settings"""
         # Advanced augmentation techniques
         self.paraphraser = Paraphraser()
         self.back_translator = BackTranslator()
 
+        if self.use_gpu:
+            # Move models to GPU if available
+            self.paraphraser.model = self.paraphraser.model.to(self.device)
+            self.back_translator.model_pivot_forward = self.back_translator.model_pivot_forward.to(self.device)
+            self.back_translator.model_pivot_backward = self.back_translator.model_pivot_backward.to(self.device)
+            self.back_translator.model_backward = self.back_translator.model_backward.to(self.device)
+
         # Basic augmentation techniques
         self.word_augmenter = naw.SynonymAug(aug_src='wordnet')
         self.spelling_augmenter = naw.SpellingAug()
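Note on the hunk above: the commit detects CUDA once in __init__ and moves each augmentation model's weights onto that device with .to(self.device). A minimal standalone sketch of the same pattern, assuming a Hugging Face seq2seq model stands in for this repo's Paraphraser (the 't5-small' checkpoint and the prompt are illustrative only): the key detail is that input tensors must be moved to the same device as the weights before calling generate().

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = AutoTokenizer.from_pretrained('t5-small')
model = AutoModelForSeq2SeqLM.from_pretrained('t5-small').to(device)  # weights on GPU when available

# Inputs must live on the same device as the model weights.
inputs = tokenizer(["paraphrase: book a table for two"], return_tensors='pt').to(device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))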
@@ -37,63 +78,62 @@ class DialogueAugmenter:
                 ('spelling', self.spelling_augmenter)
             ]
         }
-
-        # Initialize cache
-        self.embedding_cache = {}
-        self.perplexity_cache = {}
-
-        # Compile regex patterns
-        self.spelling_pattern = re.compile(r'[a-zA-Z]{3,}')
-
-        # GPU memory management
-        gpus = tf.config.list_physical_devices('GPU')
-        if gpus:
-            try:
-                for gpu in gpus:
-                    tf.config.experimental.set_memory_growth(gpu, True)
-            except RuntimeError as e:
-                print(e)
 
     @lru_cache(maxsize=1024)
     def _compute_embedding(self, text: str) -> np.ndarray:
         """Cached computation of text embedding"""
-
+        if text in self.embedding_cache:
+            return self.embedding_cache[text]
+        embedding = self.use_model([text])[0].numpy()
+        self.embedding_cache[text] = embedding
+        return embedding
 
     def _compute_batch_embeddings(self, texts: List[str]) -> np.ndarray:
-        """Compute embeddings for multiple texts at once"""
-
+        """Compute embeddings for multiple texts at once with hardware optimization"""
+        # Check cache first
+        uncached_texts = [t for t in texts if t not in self.embedding_cache]
+        if uncached_texts:
+            embeddings = self.use_model(uncached_texts).numpy()
+            # Update cache
+            for text, embedding in zip(uncached_texts, embeddings):
+                self.embedding_cache[text] = embedding
+
+        # Return all embeddings (from cache or newly computed)
+        return np.array([self.embedding_cache[t] for t in texts])
 
     def _quick_quality_check(self, variation: str, original: str) -> bool:
         """
-
+        Stricter preliminary quality check while maintaining reasonable pass rates
         """
         if self.config.debug:
             print(f"\nQuick check for variation: {variation}")
-
-        #
+
+        # Stricter length check
         orig_len = len(original.split())
         var_len = len(variation.split())
 
-        # For very short texts (1-3 words), allow more variation
+        # For very short texts (1-3 words), still allow more variation
         if orig_len <= 3:
-            if var_len > orig_len * 4:
+            if var_len > orig_len * 3:  # Reduced from 4x to 3x
                 if self.config.debug:
                     print(f"Failed length check (short text): {var_len} vs {orig_len}")
                 return False
         else:
-            if var_len > orig_len * 3:
+            if var_len > orig_len * 2:  # Reduced from 3x to 2x
                 if self.config.debug:
                     print(f"Failed length check (long text): {var_len} vs {orig_len}")
                 return False
 
-        #
-        stop_words = {'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'is', 'are'}
+        # Enhanced content check - more words in common
+        stop_words = {'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'is', 'are', 'that', 'this', 'will', 'can'}
         orig_words = set(w.lower() for w in original.split() if w.lower() not in stop_words)
         var_words = set(w.lower() for w in variation.split() if w.lower() not in stop_words)
 
-
+        # Require more content word overlap
+        content_overlap = len(orig_words.intersection(var_words)) / len(orig_words) if orig_words else 0
+        if content_overlap < 0.3:  # Increased from no minimum to 30% overlap
             if self.config.debug:
-                print("Failed content check:
+                print(f"Failed content check: overlap {content_overlap:.2f}")
             return False
 
         if self.config.debug:
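The strengthened content check above is easy to verify by hand. A standalone sketch of the same overlap computation, with the stop-word set abbreviated from the one in the diff:

# Standalone sketch of the content-overlap gate (abbreviated stop-word set).
stop_words = {'a', 'an', 'the', 'for', 'to', 'is', 'are'}

def content_overlap(original: str, variation: str) -> float:
    orig = {w.lower() for w in original.split() if w.lower() not in stop_words}
    var = {w.lower() for w in variation.split() if w.lower() not in stop_words}
    return len(orig & var) / len(orig) if orig else 0.0

# Content words: {book, table, two} vs {reserve, table, two, people};
# overlap = 2/3, about 0.67, which clears the new 0.3 threshold.
print(content_overlap("book a table for two", "reserve a table for two people"))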
@@ -401,69 +441,6 @@ class DialogueAugmenter:
         text1 = " ".join(turn['text'] for turn in dialogue1['turns'])
         text2 = " ".join(turn['text'] for turn in dialogue2['turns'])
         return text1 == text2
-
-    # def _augment_turn(self, turn: Dict, context: List[str]) -> List[Dict]:
-    #     """
-    #     Generate augmented versions of the turn using multiple strategies.
-    #     """
-    #     text = turn['text']
-    #     words = text.split()
-
-    #     # Special handling for very short texts
-    #     if len(words) < 3:
-    #         return self._augment_short_text(turn)
-
-    #     all_variations = set()
-
-    #     # Advanced augmentations (paraphrase and back-translation)
-    #     for augmenter in self.augmenters['advanced']:
-    #         try:
-    #             if isinstance(augmenter, Paraphraser):
-    #                 variations = augmenter.paraphrase(text)
-    #                 all_variations.update(variations)
-    #             elif isinstance(augmenter, BackTranslator):
-    #                 aug_text = augmenter.back_translate(text)
-    #                 if aug_text:
-    #                     all_variations.add(aug_text)
-    #         except Exception as e:
-    #             print(f"Error in advanced augmentation: {str(e)}")
-    #             continue
-
-    #     # Basic nlpaug augmentations
-    #     for aug_type, augmenter in self.augmenters['basic']:
-    #         try:
-    #             if aug_type == 'spelling' and self._is_technical_or_formal_text(text):
-    #                 continue
-
-    #             aug_texts = augmenter.augment(text, n=2)
-    #             if isinstance(aug_texts, list):
-    #                 all_variations.update(aug_texts)
-    #             else:
-    #                 all_variations.add(aug_texts)
-    #         except Exception as e:
-    #             print(f"Error in {aug_type} augmentation: {str(e)}")
-    #             continue
-
-    #     # Remove exact duplicates and empty strings
-    #     augmented_texts = [t for t in list(all_variations) if t.strip()]
-
-    #     # Apply context filtering
-    #     if context:
-    #         augmented_texts = self._filter_by_context(augmented_texts, context)
-    #         print(f"After context filtering: {len(augmented_texts)} variations")
-
-    #     # Select best variations
-    #     best_variations = self._select_best_augmentations(
-    #         text,
-    #         augmented_texts,
-    #         num_to_select=self.config.augmentation_factor,
-    #         min_quality_score=0.7
-    #     )
-
-    #     # Create variations with speaker info
-    #     variations = [{'speaker': turn['speaker'], 'text': text} for text in best_variations]
-
-    #     return variations
 
     def _augment_short_text(self, turn: Dict) -> List[Dict]:
         """
@@ -574,143 +551,3 @@ class DialogueAugmenter:
             return True
 
         return False
-
-    # def _filter_by_context(self, variations: List[str], context: List[str]) -> List[str]:
-    #     """
-    #     Filter variations based on conversation context using config parameters.
-    #     """
-    #     # Manage context window using config
-    #     recent_context = context[-self.config.context_window_size:] if len(context) > self.config.context_window_size else context
-
-    #     filtered_variations = []
-    #     context_embedding = self.use_model([' '.join(recent_context)])[0].numpy()
-
-    #     prev_turn = recent_context[-1] if recent_context else ''
-
-    #     for variation in variations:
-    #         var_embedding = self.use_model([variation])[0].numpy()
-
-    #         # Overall context similarity
-    #         context_similarity = cosine_similarity([context_embedding], [var_embedding])[0][0]
-
-    #         # Direct response coherence
-    #         response_coherence = 1.0
-    #         if prev_turn:
-    #             prev_embedding = self.use_model([prev_turn])[0].numpy()
-    #             response_coherence = cosine_similarity([prev_embedding], [var_embedding])[0][0]
-
-    #         # Use weights from config
-    #         combined_similarity = (
-    #             self.config.context_similarity_weight * context_similarity +
-    #             self.config.response_coherence_weight * response_coherence
-    #         )
-
-    #         if (combined_similarity >= self.config.semantic_similarity_threshold and
-    #             response_coherence >= self.config.min_response_coherence):
-    #             filtered_variations.append(variation)
-    #             if self.config.debug:
-    #                 print(f"Accepted variation: {variation}")
-    #                 print(f"Context similarity: {context_similarity:.3f}")
-    #                 print(f"Response coherence: {response_coherence:.3f}")
-    #                 print(f"Combined score: {combined_similarity:.3f}\n")
-    #         else:
-    #             if self.config.debug:
-    #                 print(f"Rejected variation: {variation}")
-    #                 print(f"Combined score {combined_similarity:.3f} below threshold "
-    #                       f"{self.config.semantic_similarity_threshold}")
-    #                 print(f"Response coherence {response_coherence:.3f} below threshold "
-    #                       f"{self.config.min_response_coherence}\n")
-
-    #     return filtered_variations or variations  # Fallback to original
-
-    # def _select_best_augmentations(self, original: str, candidates: List[str], used_variations: set = None,
-    #                                num_to_select: int = 3, min_quality_score: float = 0.7) -> List[str]:
-    #     """
-    #     Select the best augmentations using a quality score.
-    #     Args:
-    #         original (str): The original text
-    #         candidates (List[str]): List of candidate augmented texts
-    #         used_variations (set): Set of already used variations
-    #         num_to_select (int): Number of variations to select
-    #         min_quality_score (float): Minimum quality score threshold
-    #     """
-    #     if used_variations is None:
-    #         used_variations = set()
-
-    #     candidates = [c for c in candidates if c.strip()]
-
-    #     # Skip short text
-    #     if len(original.split()) < 3:
-    #         print(f"Text too short for augmentation: {original}")
-    #         return [original]
-
-    #     scored_candidates = []
-    #     for candidate in candidates:
-    #         if candidate in used_variations:
-    #             continue
-
-    #         metrics = self.quality_metrics.compute_metrics(original, candidate)
-
-    #         # Add contextual penalty for inappropriate audience terms
-    #         audience_terms = {'everyone', 'everybody', 'folks', 'all', 'guys', 'people'}
-    #         has_audience_term = any(term in candidate.lower() for term in audience_terms)
-    #         audience_penalty = 0.2 if has_audience_term else 0.0
-
-    #         # Weighted quality score
-    #         quality_score = (
-    #             0.40 * metrics['semantic_similarity'] +         # Semantic preservation
-    #             0.25 * (1.0 - metrics['perplexity'] / 100) +    # Fluency
-    #             0.15 * (1.0 - metrics['grammar_errors'] / 10) + # Grammar
-    #             0.15 * metrics['content_preservation'] +        # Content preservation
-    #             0.05 * metrics['type_token_ratio']              # Lexical diversity
-    #         )
-
-    #         quality_score -= audience_penalty
-
-    #         if (metrics['semantic_similarity'] < 0.5 or  # Reject on semantic threshold miss
-    #             metrics['rouge1_f1'] < 0.2):             # Enforce minimum lexical overlap
-    #             continue
-
-    #         # Bonus points for:
-    #         # Length similarity to original
-    #         if 0.75 <= metrics['length_ratio'] <= 1.25:
-    #             quality_score += 0.05
-
-    #         # Correct grammar
-    #         if metrics['grammar_errors'] == 0:
-    #             quality_score += 0.025
-
-    #         print(f"Candidate: {candidate}")
-    #         print(f"Quality score: {quality_score:.2f}, Metrics: {metrics}")
-
-    #         # Consider the augmentation if it meets the basic quality threshold
-    #         if quality_score >= min_quality_score:
-    #             print('Candidate accepted\n')
-    #             scored_candidates.append((candidate, quality_score, metrics))
-    #         else:
-    #             print('Candidate rejected\n')
-
-    #     # Sort by quality score with small random factor for diversity
-    #     scored_candidates.sort(key=lambda x: x[1], reverse=True)
-
-    #     selected = []
-    #     for candidate, score, metrics in scored_candidates:
-    #         # Check diversity against already selected
-    #         if len(selected) == 0:
-    #             selected.append(candidate)
-    #             continue
-
-    #         # Compute average similarity to already selected
-    #         avg_similarity = np.mean([
-    #             self.quality_metrics.compute_semantic_similarity(candidate, prev)
-    #             for prev in selected
-    #         ])
-
-    #         # Add if sufficiently different (similarity < 0.98)
-    #         if avg_similarity < 0.98:
-    #             selected.append(candidate)
-
-    #         if len(selected) >= num_to_select:
-    #             break
-
-    #     return selected
main.py CHANGED
@@ -65,28 +65,26 @@ def main():
         context_window_size=4,
         max_complexity_threshold=100,
         use_cache=False,
-        debug=
+        debug=False,
         allowed_speakers=['user', 'assistant'],
         required_fields=['dialogue_id', 'turns']
     )
 
     try:
         # Set max_examples (Optional[int]) for testing
-        max_examples =
+        max_examples = 5
 
         # Initialize and load Taskmaster dataset
         print("Loading Taskmaster dataset")
         taskmaster_processor = TaskmasterProcessor(config, use_ontology=False)
-
-        taskmaster_dialogues = taskmaster_processor.load_dataset(taskmaster_dir, max_examples=max_examples)
+        taskmaster_dialogues = taskmaster_processor.load_dataset('./datasets/taskmaster', max_examples=max_examples)
         taskmaster_pipeline_dialogues = taskmaster_processor.convert_to_pipeline_format(taskmaster_dialogues)
         print(f"Processed Taskmaster dialogues: {len(taskmaster_pipeline_dialogues)}")
 
         # Initialize and load Schema-Guided dataset
         print("Loading Schema-Guided dataset")
         schema_dialogue_processor = SchemaGuidedProcessor(config)
-
-        schema_dialogues = schema_dialogue_processor.load_dataset(schema_guided_dir, max_examples=max_examples)
+        schema_dialogues = schema_dialogue_processor.load_dataset('./datasets/schema_guided', max_examples=max_examples)
         schema_pipeline_dialogues = schema_dialogue_processor.convert_to_pipeline_format(schema_dialogues)
         print(f"Processed Schema-Guided dialogues: {len(schema_pipeline_dialogues)}")
 
@@ -102,19 +100,9 @@ def main():
         # Process through augmentation pipeline
         print("Processing combined dataset")
         pipeline = ProcessingPipeline(config)
-
-
-
-        output_path = 'augmented_combined_dataset.json'
-        with open(output_path, 'w', encoding='utf-8') as f:
-            json.dump(processed_dialogues, f, indent=2, ensure_ascii=False)
-
-        # Print statistics
-        print(f"\nProcessed Statistics:")
-        print(f"Total dialogues: {len(processed_dialogues)}")
-        print(f"Taskmaster domains: {len(taskmaster_processor.domains)}")
-        print(f"Schema-Guided services: {len(schema_dialogue_processor.services)}")
-        print(f"Schema-Guided domains: {len(schema_dialogue_processor.domains)}")
+        output_path = pipeline.process_dataset(combined_dialogues)
+        print(f"Processing complete. Results saved to {output_path}")
+        pipeline.cleanup()
 
     except Exception as e:
         print(f"Processing failed: {str(e)}")
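One observation on the new main() flow: cleanup() removes both the batch files and the checkpoint, so calling it unconditionally right after process_dataset() discards resume data even when the combined output did not materialize. A hypothetical guard, not part of this commit (names are illustrative):

from pathlib import Path

def run_pipeline(pipeline, combined_dialogues):
    # Hypothetical wrapper: only drop batch files and the checkpoint once
    # the combined output file is confirmed to exist on disk.
    output_path = pipeline.process_dataset(combined_dialogues)
    if Path(output_path).exists():
        pipeline.cleanup()  # resume data is no longer needed
    return output_path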
processing_pipeline.py CHANGED
@@ -4,13 +4,15 @@ from typing import List, Dict, Optional
 import json
 import re
 import hashlib
-import pickle
 import spacy
+import torch
 from tqdm import tqdm
 from pipeline_config import PipelineConfig
 from dialogue_augmenter import DialogueAugmenter
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics.pairwise import cosine_similarity
+from concurrent.futures import ProcessPoolExecutor
+from typing import Set
 
 class ProcessingPipeline:
     """
@@ -24,50 +26,151 @@
         self.num_threads = self.config.batch_size
         self.cache_dir = Path("./cache")
         self.cache_dir.mkdir(exist_ok=True)
-
-        ""
-
-        # Check cache
-        if self.config.use_cache:
-            cache_path = self._get_cache_path(dialogues)
-            if cache_path.exists():
-                print("Loading from cache...")
-                with open(cache_path, 'rb') as f:
-                    return pickle.load(f)
-
-        # Validate and clean
-        valid_dialogues = self._process_validation(
-            dialogues,
-            self._validate_and_clean_dialogue,
-            "validating and cleaning"
-        )
-
-        if
-
-            all_processed_dialogues.extend(augmented)
-
+        self.output_dir = Path("processed_outputs")
+        self.output_dir.mkdir(exist_ok=True)
+        self.checkpoint_file = self.output_dir / "processing_checkpoint.json"
+        self.batch_size = self.config.batch_size
+        self.use_gpu = torch.cuda.is_available()
+        self.batch_size = 32 if self.use_gpu else 8
+        self.use_multiprocessing = not self.use_gpu
 
+        if self.config.debug:
+            print(f"ProcessingPipeline initialized with:")
+            print(f"- GPU available: {self.use_gpu}")
+            print(f"- Batch size: {self.batch_size}")
+            print(f"- Using multiprocessing: {self.use_multiprocessing}")
+
+    def _save_batch(self, batch_results: List[Dict], batch_num: int) -> Path:
+        """Save a batch of results to a separate JSON file"""
+        batch_file = self.output_dir / f"batch_{batch_num:04d}.json"
+        with open(batch_file, 'w') as f:
+            json.dump(batch_results, f)
+        return batch_file
+
+    def _load_checkpoint(self) -> set:
+        """Load set of processed dialogue IDs from checkpoint"""
+        if self.checkpoint_file.exists():
+            with open(self.checkpoint_file, 'r') as f:
+                return set(json.load(f))
+        return set()
+
+    def _update_checkpoint(self, processed_ids: set):
+        """Update checkpoint with newly processed IDs"""
+        with open(self.checkpoint_file, 'w') as f:
+            json.dump(list(processed_ids), f)
+
+    def _process_batch(self, batch: List[Dict]) -> List[Dict]:
+        """Process batch with optimized model calls"""
+        results = []
+        try:
+            if self.use_gpu:
+                results = self.augmenter.process_batch(batch)
+            else:
+                # Collect all texts that need processing
+                all_texts = []
+                text_to_dialogue_map = {}
+                for dialogue in batch:
+                    for turn in dialogue['turns']:
+                        all_texts.append(turn['text'])
+                        text_to_dialogue_map[turn['text']] = dialogue['dialogue_id']
+
+                # Batch process embeddings
+                embeddings = self.augmenter._compute_batch_embeddings(all_texts)
+
+                # Process dialogues with cached embeddings
+                for dialogue in batch:
+                    try:
+                        augmented = self.augmenter.augment_dialogue(dialogue)
+                        results.extend(augmented)
+                    except Exception as e:
+                        print(f"Error processing dialogue {dialogue.get('dialogue_id', 'unknown')}: {str(e)}")
+                        continue
+        except Exception as e:
+            print(f"Error processing batch: {str(e)}")
+        return results
+
+    def combine_results(self) -> Path:
+        """Combine all batch files into final output"""
+        all_results = []
+        batch_files = sorted(self.output_dir.glob("batch_*.json"))
+
+        print(f"Combining {len(batch_files)} batch files...")
+        for batch_file in tqdm(batch_files):
+            with open(batch_file, 'r') as f:
+                batch_data = json.load(f)
+                all_results.extend(batch_data)
+
+        # Save combined results
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        final_output = self.output_dir / f"augmented_dataset_{timestamp}.json"
+        with open(final_output, 'w') as f:
+            json.dump(all_results, f)
+
+        if self.config.debug:
+            print(f"Combined {len(all_results)} dialogues into {final_output}")
+
+        return final_output
 
+    def process_dataset(self, dialogues: List[Dict]) -> Path:
+        """Process dataset with hardware-appropriate optimizations and progress tracking"""
+        processed_ids = self._load_checkpoint()
+
+        # Filter out already processed dialogues
+        remaining_dialogues = [d for d in dialogues
+                               if d['dialogue_id'] not in processed_ids]
+
+        total_dialogues = len(dialogues)
+        remaining_count = len(remaining_dialogues)
+        processed_count = total_dialogues - remaining_count
+
+        print("\nDataset Processing Status:")
+        print(f"Total dialogues in dataset: {total_dialogues}")
+        print(f"Previously processed: {processed_count}")
+        print(f"Remaining to process: {remaining_count}")
+        print("-" * 50)
+
+        # Process in batches with progress bar
+        for batch_num in tqdm(range(0, len(remaining_dialogues), self.batch_size),
+                              desc="Processing batches",
+                              total=(len(remaining_dialogues) + self.batch_size - 1) // self.batch_size):
+            batch = remaining_dialogues[batch_num:batch_num + self.batch_size]
+            current_position = processed_count + batch_num + len(batch)
+
+            total_progress = (current_position / total_dialogues) * 100
+            batch_progress = (batch_num + 1) / ((len(remaining_dialogues) + self.batch_size - 1) // self.batch_size) * 100
+
+            print(f"\rProgress: {current_position}/{total_dialogues} dialogues "
+                  f"({total_progress:.1f}% complete) - "
+                  f"Batch {batch_num//self.batch_size + 1} of "
+                  f"{(len(remaining_dialogues) + self.batch_size - 1) // self.batch_size}", end="")
+
+            # Process batch
+            batch_results = self._process_batch(batch)
+
+            if batch_results:
+                self._save_batch(batch_results, batch_num)
+                batch_ids = {d['dialogue_id'] for d in batch}
+                processed_ids.update(batch_ids)
+                self._update_checkpoint(processed_ids)
+
+        print("\n" + "-" * 50)
+        print("Processing complete. Combining results...")
+        return self.combine_results()
+
+    def cleanup(self):
+        """Clean up intermediate batch files after successful processing"""
+        batch_files = list(self.output_dir.glob("batch_*.json"))
+        for file in batch_files:
+            try:
+                file.unlink()
+            except Exception as e:
+                print(f"Error deleting {file}: {e}")
+
+        if self.checkpoint_file.exists():
+            try:
+                self.checkpoint_file.unlink()
+            except Exception as e:
+                print(f"Error deleting checkpoint file: {e}")
 
     def _deduplicate_dialogues(self, dialogues: List[Dict], threshold: float = 0.9) -> List[Dict]:
         """
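Last note: the diff adds the ProcessPoolExecutor import and sets self.use_multiprocessing on CPU-only hosts, but none of the hunks shown here actually use either. A hypothetical sketch of how CPU batches could be fanned out if that path were wired up (every name below is illustrative; workers must be top-level functions so they can be pickled, and each process would have to construct its own augmenter):

from concurrent.futures import ProcessPoolExecutor
from typing import Dict, List

def process_one_batch(batch: List[Dict]) -> List[Dict]:
    # Stand-in for per-process augmentation work; a real worker would build
    # its own DialogueAugmenter here, since loaded models do not pickle.
    return [{**d, "augmented": True} for d in batch]

def process_batches_cpu(batches: List[List[Dict]], num_workers: int = 4) -> List[Dict]:
    results: List[Dict] = []
    with ProcessPoolExecutor(max_workers=num_workers) as pool:
        for batch_result in pool.map(process_one_batch, batches):
            results.extend(batch_result)
    return results

if __name__ == "__main__":
    print(process_batches_cpu([[{"dialogue_id": "A"}], [{"dialogue_id": "B"}]]))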