JoeArmani committed
Commit bc503de · Parent: febdb1e

update gpu processing

Files changed (4)
  1. .gitignore +3 -0
  2. dialogue_augmenter.py +69 -232
  3. main.py +7 -19
  4. processing_pipeline.py +143 -40
.gitignore CHANGED
@@ -156,3 +156,6 @@ cython_debug/
 
 datasets/*
 !datasets/.gitkeep
+
+processed_outputs/*
+!processed_outputs/.gitkeep
dialogue_augmenter.py CHANGED
@@ -1,5 +1,6 @@
 from typing import Dict, List
 import numpy as np
+import torch
 import tensorflow as tf
 import tensorflow_hub as hub
 import re
@@ -19,13 +20,53 @@ class DialogueAugmenter:
     def __init__(self, nlp, config: PipelineConfig):
         self.nlp = nlp
         self.config = config
+
+        # Detect hardware and set appropriate batch sizes and optimization strategy
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.use_gpu = torch.cuda.is_available()
+
+        if self.config.debug:
+            print(f"Using device: {self.device}")
+            if self.use_gpu:
+                print(f"GPU Device: {torch.cuda.get_device_name(0)}")
+
+        # Load base models
         self.quality_metrics = QualityMetrics(config)
         self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
 
+        # Initialize augmentation models based on hardware
+        self._initialize_augmentation_models()
+
+        # Initialize caches
+        self.embedding_cache = {}
+        self.perplexity_cache = {}
+
+        # Compile regex patterns
+        self.spelling_pattern = re.compile(r'[a-zA-Z]{3,}')
+
+        # GPU memory management if available
+        if self.use_gpu:
+            gpus = tf.config.list_physical_devices('GPU')
+            if gpus:
+                try:
+                    for gpu in gpus:
+                        tf.config.experimental.set_memory_growth(gpu, True)
+                except RuntimeError as e:
+                    print(e)
+
+    def _initialize_augmentation_models(self):
+        """Initialize augmentation models with appropriate device settings"""
         # Advanced augmentation techniques
         self.paraphraser = Paraphraser()
         self.back_translator = BackTranslator()
 
+        if self.use_gpu:
+            # Move models to GPU if available
+            self.paraphraser.model = self.paraphraser.model.to(self.device)
+            self.back_translator.model_pivot_forward = self.back_translator.model_pivot_forward.to(self.device)
+            self.back_translator.model_pivot_backward = self.back_translator.model_pivot_backward.to(self.device)
+            self.back_translator.model_backward = self.back_translator.model_backward.to(self.device)
+
         # Basic augmentation techniques
         self.word_augmenter = naw.SynonymAug(aug_src='wordnet')
         self.spelling_augmenter = naw.SpellingAug()
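
Note on the hunk above: the commit runs two frameworks on one device. The PyTorch models (paraphraser, back-translator) are moved to CUDA with `.to(self.device)`, while the Universal Sentence Encoder runs under TensorFlow, which by default pre-allocates nearly all GPU memory at startup and would starve the PyTorch side. Below is a minimal sketch of that coexistence pattern in isolation (editor's example, not code from the repository; assumes both `torch` and `tensorflow` are installed):

```python
import torch
import tensorflow as tf

# Prefer CUDA for the PyTorch side when a GPU is present.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# With memory growth enabled, TensorFlow allocates GPU memory on demand
# instead of claiming the whole card, leaving room for PyTorch models.
for gpu in tf.config.list_physical_devices('GPU'):
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Memory growth must be set before TensorFlow initializes the GPU.
        print(e)
```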
@@ -37,63 +78,62 @@ class DialogueAugmenter:
                 ('spelling', self.spelling_augmenter)
             ]
         }
-
-        # Initialize cache
-        self.embedding_cache = {}
-        self.perplexity_cache = {}
-
-        # Compile regex patterns
-        self.spelling_pattern = re.compile(r'[a-zA-Z]{3,}')
-
-        # GPU memory management
-        gpus = tf.config.list_physical_devices('GPU')
-        if gpus:
-            try:
-                for gpu in gpus:
-                    tf.config.experimental.set_memory_growth(gpu, True)
-            except RuntimeError as e:
-                print(e)
 
     @lru_cache(maxsize=1024)
     def _compute_embedding(self, text: str) -> np.ndarray:
         """Cached computation of text embedding"""
-        return self.use_model([text])[0].numpy()
+        if text in self.embedding_cache:
+            return self.embedding_cache[text]
+        embedding = self.use_model([text])[0].numpy()
+        self.embedding_cache[text] = embedding
+        return embedding
 
     def _compute_batch_embeddings(self, texts: List[str]) -> np.ndarray:
-        """Compute embeddings for multiple texts at once"""
-        return self.use_model(texts).numpy()
+        """Compute embeddings for multiple texts at once with hardware optimization"""
+        # Check cache first
+        uncached_texts = [t for t in texts if t not in self.embedding_cache]
+        if uncached_texts:
+            embeddings = self.use_model(uncached_texts).numpy()
+            # Update cache
+            for text, embedding in zip(uncached_texts, embeddings):
+                self.embedding_cache[text] = embedding
+
+        # Return all embeddings (from cache or newly computed)
+        return np.array([self.embedding_cache[t] for t in texts])
 
     def _quick_quality_check(self, variation: str, original: str) -> bool:
         """
-        Simplified preliminary quality check with minimal standards
+        Stricter preliminary quality check while maintaining reasonable pass rates
         """
         if self.config.debug:
             print(f"\nQuick check for variation: {variation}")
-
-        # Only reject if length is extremely different
+
+        # Stricter length check
         orig_len = len(original.split())
         var_len = len(variation.split())
 
-        # For very short texts (1-3 words), allow more variation
+        # For very short texts (1-3 words), still allow more variation
         if orig_len <= 3:
-            if var_len > orig_len * 4:  # Allow up to 4x length for short texts
+            if var_len > orig_len * 3:  # Reduced from 4x to 3x
                 if self.config.debug:
                     print(f"Failed length check (short text): {var_len} vs {orig_len}")
                 return False
         else:
-            if var_len > orig_len * 3:  # Allow up to 3x length for longer texts
+            if var_len > orig_len * 2:  # Reduced from 3x to 2x
                 if self.config.debug:
                     print(f"Failed length check (long text): {var_len} vs {orig_len}")
                 return False
 
-        # Basic content check - at least one word in common (excluding stop words)
-        stop_words = {'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'is', 'are'}
+        # Enhanced content check - more words in common
+        stop_words = {'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'is', 'are', 'that', 'this', 'will', 'can'}
         orig_words = set(w.lower() for w in original.split() if w.lower() not in stop_words)
         var_words = set(w.lower() for w in variation.split() if w.lower() not in stop_words)
 
-        if not orig_words.intersection(var_words):
+        # Require more content word overlap
+        content_overlap = len(orig_words.intersection(var_words)) / len(orig_words) if orig_words else 0
+        if content_overlap < 0.3:  # Increased from no minimum to 30% overlap
             if self.config.debug:
-                print("Failed content check: no content words in common")
+                print(f"Failed content check: overlap {content_overlap:.2f}")
             return False
 
         if self.config.debug:
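
Note on `_quick_quality_check`: the tightened gate combines a length-ratio cap (up to 3x for originals of three words or fewer, 2x otherwise) with a minimum 30% overlap on non-stop-word content. A self-contained sketch of the same logic, with a couple of illustrative assertions (the helper name and the example sentences are the editor's, not the repository's):

```python
# Stop words mirror the set used in the hunk above.
STOP_WORDS = {'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to',
              'for', 'is', 'are', 'that', 'this', 'will', 'can'}

def quick_quality_check(variation: str, original: str) -> bool:
    orig_words = original.split()
    var_words = variation.split()

    # Length gate: short originals (<= 3 words) may grow up to 3x, longer ones 2x.
    max_ratio = 3 if len(orig_words) <= 3 else 2
    if len(var_words) > len(orig_words) * max_ratio:
        return False

    # Content gate: at least 30% of the original's content words must survive.
    orig_content = {w.lower() for w in orig_words if w.lower() not in STOP_WORDS}
    var_content = {w.lower() for w in var_words if w.lower() not in STOP_WORDS}
    overlap = len(orig_content & var_content) / len(orig_content) if orig_content else 0
    return overlap >= 0.3

# "table" and "two" survive the paraphrase: 2/3 content overlap passes.
assert quick_quality_check("reserve a table for two people", "book a table for two")
# Zero content overlap fails the 30% threshold.
assert not quick_quality_check("completely unrelated sentence here", "book a table")
```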
@@ -401,69 +441,6 @@
         text1 = " ".join(turn['text'] for turn in dialogue1['turns'])
         text2 = " ".join(turn['text'] for turn in dialogue2['turns'])
         return text1 == text2
-
-    # def _augment_turn(self, turn: Dict, context: List[str]) -> List[Dict]:
-    #     """
-    #     Generate augmented versions of the turn using multiple strategies.
-    #     """
-    #     text = turn['text']
-    #     words = text.split()
-
-    #     # Special handling for very short texts
-    #     if len(words) < 3:
-    #         return self._augment_short_text(turn)
-
-    #     all_variations = set()
-
-    #     # Advanced augmentations (paraphrase and back-translation)
-    #     for augmenter in self.augmenters['advanced']:
-    #         try:
-    #             if isinstance(augmenter, Paraphraser):
-    #                 variations = augmenter.paraphrase(text)
-    #                 all_variations.update(variations)
-    #             elif isinstance(augmenter, BackTranslator):
-    #                 aug_text = augmenter.back_translate(text)
-    #                 if aug_text:
-    #                     all_variations.add(aug_text)
-    #         except Exception as e:
-    #             print(f"Error in advanced augmentation: {str(e)}")
-    #             continue
-
-    #     # Basic nlpaug augmentations
-    #     for aug_type, augmenter in self.augmenters['basic']:
-    #         try:
-    #             if aug_type == 'spelling' and self._is_technical_or_formal_text(text):
-    #                 continue
-
-    #             aug_texts = augmenter.augment(text, n=2)
-    #             if isinstance(aug_texts, list):
-    #                 all_variations.update(aug_texts)
-    #             else:
-    #                 all_variations.add(aug_texts)
-    #         except Exception as e:
-    #             print(f"Error in {aug_type} augmentation: {str(e)}")
-    #             continue
-
-    #     # Remove exact duplicates and empty strings
-    #     augmented_texts = [t for t in list(all_variations) if t.strip()]
-
-    #     # Apply context filtering
-    #     if context:
-    #         augmented_texts = self._filter_by_context(augmented_texts, context)
-    #         print(f"After context filtering: {len(augmented_texts)} variations")
-
-    #     # Select best variations
-    #     best_variations = self._select_best_augmentations(
-    #         text,
-    #         augmented_texts,
-    #         num_to_select=self.config.augmentation_factor,
-    #         min_quality_score=0.7
-    #     )
-
-    #     # Create variations with speaker info
-    #     variations = [{'speaker': turn['speaker'], 'text': text} for text in best_variations]
-
-    #     return variations
 
     def _augment_short_text(self, turn: Dict) -> List[Dict]:
         """
@@ -574,143 +551,3 @@
             return True
 
         return False
-
-    # def _filter_by_context(self, variations: List[str], context: List[str]) -> List[str]:
-    #     """
-    #     Filter variations based on conversation context using config parameters.
-    #     """
-    #     # Manage context window using config
-    #     recent_context = context[-self.config.context_window_size:] if len(context) > self.config.context_window_size else context
-
-    #     filtered_variations = []
-    #     context_embedding = self.use_model([' '.join(recent_context)])[0].numpy()
-
-    #     prev_turn = recent_context[-1] if recent_context else ''
-
-    #     for variation in variations:
-    #         var_embedding = self.use_model([variation])[0].numpy()
-
-    #         # Overall context similarity
-    #         context_similarity = cosine_similarity([context_embedding], [var_embedding])[0][0]
-
-    #         # Direct response coherence
-    #         response_coherence = 1.0
-    #         if prev_turn:
-    #             prev_embedding = self.use_model([prev_turn])[0].numpy()
-    #             response_coherence = cosine_similarity([prev_embedding], [var_embedding])[0][0]
-
-    #         # Use weights from config
-    #         combined_similarity = (
-    #             self.config.context_similarity_weight * context_similarity +
-    #             self.config.response_coherence_weight * response_coherence
-    #         )
-
-    #         if (combined_similarity >= self.config.semantic_similarity_threshold and
-    #             response_coherence >= self.config.min_response_coherence):
-    #             filtered_variations.append(variation)
-    #             if self.config.debug:
-    #                 print(f"Accepted variation: {variation}")
-    #                 print(f"Context similarity: {context_similarity:.3f}")
-    #                 print(f"Response coherence: {response_coherence:.3f}")
-    #                 print(f"Combined score: {combined_similarity:.3f}\n")
-    #         else:
-    #             if self.config.debug:
-    #                 print(f"Rejected variation: {variation}")
-    #                 print(f"Combined score {combined_similarity:.3f} below threshold "
-    #                       f"{self.config.semantic_similarity_threshold}")
-    #                 print(f"Response coherence {response_coherence:.3f} below threshold "
-    #                       f"{self.config.min_response_coherence}\n")
-
-    #     return filtered_variations or variations  # Fallback to original
-
-    # def _select_best_augmentations(self, original: str, candidates: List[str], used_variations: set = None,
-    #                                num_to_select: int = 3, min_quality_score: float = 0.7) -> List[str]:
-    #     """
-    #     Select the best augmentations using a quality score.
-    #     Args:
-    #         original (str): The original text
-    #         candidates (List[str]): List of candidate augmented texts
-    #         used_variations (set): Set of already used variations
-    #         num_to_select (int): Number of variations to select
-    #         min_quality_score (float): Minimum quality score threshold
-    #     """
-    #     if used_variations is None:
-    #         used_variations = set()
-
-    #     candidates = [c for c in candidates if c.strip()]
-
-    #     # Skip short text
-    #     if len(original.split()) < 3:
-    #         print(f"Text too short for augmentation: {original}")
-    #         return [original]
-
-    #     scored_candidates = []
-    #     for candidate in candidates:
-    #         if candidate in used_variations:
-    #             continue
-
-    #         metrics = self.quality_metrics.compute_metrics(original, candidate)
-
-    #         # Add contextual penalty for inappropriate audience terms
-    #         audience_terms = {'everyone', 'everybody', 'folks', 'all', 'guys', 'people'}
-    #         has_audience_term = any(term in candidate.lower() for term in audience_terms)
-    #         audience_penalty = 0.2 if has_audience_term else 0.0
-
-    #         # Weighted quality score
-    #         quality_score = (
-    #             0.40 * metrics['semantic_similarity'] +          # Semantic preservation
-    #             0.25 * (1.0 - metrics['perplexity'] / 100) +     # Fluency
-    #             0.15 * (1.0 - metrics['grammar_errors'] / 10) +  # Grammar
-    #             0.15 * metrics['content_preservation'] +         # Content preservation
-    #             0.05 * metrics['type_token_ratio']               # Lexical diversity
-    #         )
-
-    #         quality_score -= audience_penalty
-
-    #         if (metrics['semantic_similarity'] < 0.5 or  # Reject on semantic threshold miss
-    #             metrics['rouge1_f1'] < 0.2):             # Enforce minimum lexical overlap
-    #             continue
-
-    #         # Bonus points for:
-    #         # Length similarity to original
-    #         if 0.75 <= metrics['length_ratio'] <= 1.25:
-    #             quality_score += 0.05
-
-    #         # Correct grammar
-    #         if metrics['grammar_errors'] == 0:
-    #             quality_score += 0.025
-
-    #         print(f"Candidate: {candidate}")
-    #         print(f"Quality score: {quality_score:.2f}, Metrics: {metrics}")
-
-    #         # Consider the augmentation if it meets the basic quality threshold
-    #         if quality_score >= min_quality_score:
-    #             print('Candidate accepted\n')
-    #             scored_candidates.append((candidate, quality_score, metrics))
-    #         else:
-    #             print('Candidate rejected\n')
-
-    #     # Sort by quality score with small random factor for diversity
-    #     scored_candidates.sort(key=lambda x: x[1], reverse=True)
-
-    #     selected = []
-    #     for candidate, score, metrics in scored_candidates:
-    #         # Check diversity against already selected
-    #         if len(selected) == 0:
-    #             selected.append(candidate)
-    #             continue
-
-    #         # Compute average similarity to already selected
-    #         avg_similarity = np.mean([
-    #             self.quality_metrics.compute_semantic_similarity(candidate, prev)
-    #             for prev in selected
-    #         ])
-
-    #         # Add if sufficiently different (similarity < 0.98)
-    #         if avg_similarity < 0.98:
-    #             selected.append(candidate)
-
-    #         if len(selected) >= num_to_select:
-    #             break
-
-    #     return selected
main.py CHANGED
@@ -65,28 +65,26 @@ def main():
         context_window_size=4,
         max_complexity_threshold=100,
         use_cache=False,
-        debug=True,
+        debug=False,
         allowed_speakers=['user', 'assistant'],
         required_fields=['dialogue_id', 'turns']
     )
 
     try:
         # Set max_examples (Optional[int]) for testing
-        max_examples = None
+        max_examples = 5
 
         # Initialize and load Taskmaster dataset
         print("Loading Taskmaster dataset")
         taskmaster_processor = TaskmasterProcessor(config, use_ontology=False)
-        taskmaster_dir = './datasets/taskmaster'
-        taskmaster_dialogues = taskmaster_processor.load_dataset(taskmaster_dir, max_examples=max_examples)
+        taskmaster_dialogues = taskmaster_processor.load_dataset('./datasets/taskmaster', max_examples=max_examples)
         taskmaster_pipeline_dialogues = taskmaster_processor.convert_to_pipeline_format(taskmaster_dialogues)
         print(f"Processed Taskmaster dialogues: {len(taskmaster_pipeline_dialogues)}")
 
         # Initialize and load Schema-Guided dataset
         print("Loading Schema-Guided dataset")
        schema_dialogue_processor = SchemaGuidedProcessor(config)
-        schema_guided_dir = './datasets/schema_guided'
-        schema_dialogues = schema_dialogue_processor.load_dataset(schema_guided_dir, max_examples=max_examples)
+        schema_dialogues = schema_dialogue_processor.load_dataset('./datasets/schema_guided', max_examples=max_examples)
         schema_pipeline_dialogues = schema_dialogue_processor.convert_to_pipeline_format(schema_dialogues)
         print(f"Processed Schema-Guided dialogues: {len(schema_pipeline_dialogues)}")
 
@@ -102,19 +100,9 @@ def main():
         # Process through augmentation pipeline
         print("Processing combined dataset")
         pipeline = ProcessingPipeline(config)
-        processed_dialogues = pipeline.process_dataset(combined_dialogues)
-
-        # Save results
-        output_path = 'augmented_combined_dataset.json'
-        with open(output_path, 'w', encoding='utf-8') as f:
-            json.dump(processed_dialogues, f, indent=2, ensure_ascii=False)
-
-        # Print statistics
-        print(f"\nProcessed Statistics:")
-        print(f"Total dialogues: {len(processed_dialogues)}")
-        print(f"Taskmaster domains: {len(taskmaster_processor.domains)}")
-        print(f"Schema-Guided services: {len(schema_dialogue_processor.services)}")
-        print(f"Schema-Guided domains: {len(schema_dialogue_processor.domains)}")
+        output_path = pipeline.process_dataset(combined_dialogues)
+        print(f"Processing complete. Results saved to {output_path}")
+        pipeline.cleanup()
 
     except Exception as e:
         print(f"Processing failed: {str(e)}")
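Note on the main.py changes: `process_dataset` now returns the path of the combined JSON file rather than the dialogues themselves, so the old inline statistics block no longer has the data in memory and was dropped. If those counts are still wanted, they can be recomputed from disk; a minimal sketch, assuming the output is a flat JSON list of dialogue dicts as written by `combine_results` (the `print_stats` helper is hypothetical, not part of the commit):

```python
import json
from pathlib import Path

def print_stats(output_path: Path) -> None:
    # Load the combined dataset that the pipeline wrote to disk.
    with open(output_path, 'r', encoding='utf-8') as f:
        dialogues = json.load(f)
    print(f"Total dialogues: {len(dialogues)}")

# Usage after the pipeline run:
#   output_path = pipeline.process_dataset(combined_dialogues)
#   print_stats(output_path)
```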
processing_pipeline.py CHANGED
@@ -4,13 +4,15 @@ from typing import List, Dict, Optional
 import json
 import re
 import hashlib
-import pickle
 import spacy
+import torch
 from tqdm import tqdm
 from pipeline_config import PipelineConfig
 from dialogue_augmenter import DialogueAugmenter
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics.pairwise import cosine_similarity
+from concurrent.futures import ProcessPoolExecutor
+from typing import Set
 
 class ProcessingPipeline:
     """
@@ -24,50 +26,151 @@ class ProcessingPipeline:
         self.num_threads = self.config.batch_size
         self.cache_dir = Path("./cache")
         self.cache_dir.mkdir(exist_ok=True)
-
-    def process_dataset(self, dialogues: List[Dict]) -> List[Dict]:
-        """
-        Process entire dataset through the pipeline.
-        """
-        print(f"Processing {len(dialogues)} dialogues")
-        start_time = datetime.now()
-
-        # Check cache
-        if self.config.use_cache:
-            cache_path = self._get_cache_path(dialogues)
-            if cache_path.exists():
-                print("Loading from cache...")
-                with open(cache_path, 'rb') as f:
-                    return pickle.load(f)
-
-        # Validate and clean
-        valid_dialogues = self._process_validation(
-            dialogues,
-            self._validate_and_clean_dialogue,
-            "validating and cleaning"
-        )
+        self.output_dir = Path("processed_outputs")
+        self.output_dir.mkdir(exist_ok=True)
+        self.checkpoint_file = self.output_dir / "processing_checkpoint.json"
+        self.batch_size = self.config.batch_size
+        self.use_gpu = torch.cuda.is_available()
+        self.batch_size = 32 if self.use_gpu else 8
+        self.use_multiprocessing = not self.use_gpu
 
-        if not valid_dialogues:
-            raise ValueError("Dialogue validation resulted in an empty dataset.")
-
-        deduplicated_dialogues = self._deduplicate_dialogues(valid_dialogues)
+        if self.config.debug:
+            print(f"ProcessingPipeline initialized with:")
+            print(f"- GPU available: {self.use_gpu}")
+            print(f"- Batch size: {self.batch_size}")
+            print(f"- Using multiprocessing: {self.use_multiprocessing}")
+
+    def _save_batch(self, batch_results: List[Dict], batch_num: int) -> Path:
+        """Save a batch of results to a separate JSON file"""
+        batch_file = self.output_dir / f"batch_{batch_num:04d}.json"
+        with open(batch_file, 'w') as f:
+            json.dump(batch_results, f)
+        return batch_file
+
+    def _load_checkpoint(self) -> set:
+        """Load set of processed dialogue IDs from checkpoint"""
+        if self.checkpoint_file.exists():
+            with open(self.checkpoint_file, 'r') as f:
+                return set(json.load(f))
+        return set()
 
-        # Augment dialogues
-        all_processed_dialogues = []
-        for dialogue in deduplicated_dialogues:
-            augmented = self.augmenter.augment_dialogue(dialogue)
-            all_processed_dialogues.extend(augmented)
+    def _update_checkpoint(self, processed_ids: set):
+        """Update checkpoint with newly processed IDs"""
+        with open(self.checkpoint_file, 'w') as f:
+            json.dump(list(processed_ids), f)
 
-        # Save to cache
-        if self.config.use_cache:
-            with open(cache_path, 'wb') as f:
-                pickle.dump(all_processed_dialogues, f)
+    def _process_batch(self, batch: List[Dict]) -> List[Dict]:
+        """Process batch with optimized model calls"""
+        results = []
+        try:
+            if self.use_gpu:
+                results = self.augmenter.process_batch(batch)
+            else:
+                # Collect all texts that need processing
+                all_texts = []
+                text_to_dialogue_map = {}
+                for dialogue in batch:
+                    for turn in dialogue['turns']:
+                        all_texts.append(turn['text'])
+                        text_to_dialogue_map[turn['text']] = dialogue['dialogue_id']
+
+                # Batch process embeddings
+                embeddings = self.augmenter._compute_batch_embeddings(all_texts)
+
+                # Process dialogues with cached embeddings
+                for dialogue in batch:
+                    try:
+                        augmented = self.augmenter.augment_dialogue(dialogue)
+                        results.extend(augmented)
+                    except Exception as e:
+                        print(f"Error processing dialogue {dialogue.get('dialogue_id', 'unknown')}: {str(e)}")
+                        continue
+        except Exception as e:
+            print(f"Error processing batch: {str(e)}")
+        return results
+
+    def combine_results(self) -> Path:
+        """Combine all batch files into final output"""
+        all_results = []
+        batch_files = sorted(self.output_dir.glob("batch_*.json"))
+
+        print(f"Combining {len(batch_files)} batch files...")
+        for batch_file in tqdm(batch_files):
+            with open(batch_file, 'r') as f:
+                batch_data = json.load(f)
+                all_results.extend(batch_data)
+
+        # Save combined results
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        final_output = self.output_dir / f"augmented_dataset_{timestamp}.json"
+        with open(final_output, 'w') as f:
+            json.dump(all_results, f)
+
+        if self.config.debug:
+            print(f"Combined {len(all_results)} dialogues into {final_output}")
+
+        return final_output
 
-        processing_time = datetime.now() - start_time
-        print(f"Processing completed in {processing_time}")
-        print(f"Generated {len(all_processed_dialogues)} total dialogues")
+    def process_dataset(self, dialogues: List[Dict]) -> Path:
+        """Process dataset with hardware-appropriate optimizations and progress tracking"""
+        processed_ids = self._load_checkpoint()
+
+        # Filter out already processed dialogues
+        remaining_dialogues = [d for d in dialogues
+                               if d['dialogue_id'] not in processed_ids]
+
+        total_dialogues = len(dialogues)
+        remaining_count = len(remaining_dialogues)
+        processed_count = total_dialogues - remaining_count
+
+        print("\nDataset Processing Status:")
+        print(f"Total dialogues in dataset: {total_dialogues}")
+        print(f"Previously processed: {processed_count}")
+        print(f"Remaining to process: {remaining_count}")
+        print("-" * 50)
+
+        # Process in batches with progress bar
+        for batch_num in tqdm(range(0, len(remaining_dialogues), self.batch_size),
+                              desc="Processing batches",
+                              total=(len(remaining_dialogues) + self.batch_size - 1) // self.batch_size):
+            batch = remaining_dialogues[batch_num:batch_num + self.batch_size]
+            current_position = processed_count + batch_num + len(batch)
+
+            total_progress = (current_position / total_dialogues) * 100
+            batch_progress = (batch_num + 1) / ((len(remaining_dialogues) + self.batch_size - 1) // self.batch_size) * 100
+
+            print(f"\rProgress: {current_position}/{total_dialogues} dialogues "
+                  f"({total_progress:.1f}% complete) - "
+                  f"Batch {batch_num//self.batch_size + 1} of "
+                  f"{(len(remaining_dialogues) + self.batch_size - 1) // self.batch_size}", end="")
+
+            # Process batch
+            batch_results = self._process_batch(batch)
+
+            if batch_results:
+                self._save_batch(batch_results, batch_num)
+                batch_ids = {d['dialogue_id'] for d in batch}
+                processed_ids.update(batch_ids)
+                self._update_checkpoint(processed_ids)
+
+        print("\n" + "-" * 50)
+        print("Processing complete. Combining results...")
+        return self.combine_results()
 
-        return all_processed_dialogues
+    def cleanup(self):
+        """Clean up intermediate batch files after successful processing"""
+        batch_files = list(self.output_dir.glob("batch_*.json"))
+        for file in batch_files:
+            try:
+                file.unlink()
+            except Exception as e:
+                print(f"Error deleting {file}: {e}")
+
+        if self.checkpoint_file.exists():
+            try:
+                self.checkpoint_file.unlink()
+            except Exception as e:
+                print(f"Error deleting checkpoint file: {e}")
 
     def _deduplicate_dialogues(self, dialogues: List[Dict], threshold: float = 0.9) -> List[Dict]:
         """