from datetime import datetime from pathlib import Path from typing import List, Dict, Optional import json import re import hashlib import spacy import torch from tqdm import tqdm from pipeline_config import PipelineConfig from dialogue_augmenter import DialogueAugmenter from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity from typing import Set class ProcessingPipeline: """ Complete pipeline combining validation, optimization, and augmentation. """ def __init__(self, config: Optional[PipelineConfig] = None): self.config = config or PipelineConfig() self.nlp = spacy.load("en_core_web_sm", disable=['parser', 'ner']) self.augmenter = DialogueAugmenter(self.nlp, self.config) self.num_threads = self.config.batch_size self.cache_dir = Path("./cache") self.cache_dir.mkdir(exist_ok=True) self.output_dir = Path("processed_outputs") self.output_dir.mkdir(exist_ok=True) self.checkpoint_file = self.output_dir / "processing_checkpoint.json" self.batch_size = self.config.batch_size self.use_gpu = torch.cuda.is_available() self.batch_size = 32 if self.use_gpu else 8 self.use_multiprocessing = not self.use_gpu # Counters for grouping batches self.batch_counter = 0 # Count batches since last group combine self.batch_group_number = 0 # How many groups have been created if self.config.debug: print(f"ProcessingPipeline initialized with:") print(f"- GPU available: {self.use_gpu}") print(f"- Batch size: {self.batch_size}") print(f"- Using multiprocessing: {self.use_multiprocessing}") def _save_batch(self, batch_results: List[Dict], batch_num: int) -> Path: """Save a batch of results to a separate JSON file""" batch_file = self.output_dir / f"batch_{batch_num:04d}.json" with open(batch_file, 'w') as f: json.dump(batch_results, f) return batch_file def _load_checkpoint(self) -> set: """Load set of processed dialogue IDs from checkpoint""" if self.checkpoint_file.exists(): with open(self.checkpoint_file, 'r') as f: return set(json.load(f)) return set() def _update_checkpoint(self, processed_ids: set): """Update checkpoint with newly processed IDs""" with open(self.checkpoint_file, 'w') as f: json.dump(list(processed_ids), f) def _process_batch(self, batch: List[Dict]) -> List[Dict]: """Process batch with optimized model calls""" results = [] try: if self.use_gpu: results = self.augmenter.process_batch(batch) else: # Collect all texts that need processing all_texts = [] text_to_dialogue_map = {} for dialogue in batch: for turn in dialogue['turns']: all_texts.append(turn['text']) text_to_dialogue_map[turn['text']] = dialogue['dialogue_id'] # Batch process embeddings self.augmenter._compute_batch_embeddings(all_texts) # Process dialogues with cached embeddings for dialogue in batch: try: augmented = self.augmenter.augment_dialogue(dialogue) results.extend(augmented) except Exception as e: print(f"Error processing dialogue {dialogue.get('dialogue_id', 'unknown')}: {str(e)}") continue except Exception as e: print(f"Error processing batch: {str(e)}") return results def _combine_intermediate_batches(self): """ Combine all current batch_*.json files into a single batch_group_XXXX.json file, then remove the batch_*.json files. """ batch_files = sorted(self.output_dir.glob("batch_*.json")) if not batch_files: return None # No files to combine combined_data = [] for bf in batch_files: with open(bf, 'r') as f: combined_data.extend(json.load(f)) bf.unlink() # Remove the individual batch file after reading self.batch_group_number += 1 group_file = self.output_dir / f"batch_group_{self.batch_group_number:04d}.json" with open(group_file, 'w') as f: json.dump(combined_data, f) return group_file def combine_results(self) -> Path: """Combine all batch_group_*.json files into final output""" all_results = [] group_files = sorted(self.output_dir.glob("batch_group_*.json")) print(f"Combining {len(group_files)} group files...") for group_file in tqdm(group_files): with open(group_file, 'r') as f: group_data = json.load(f) all_results.extend(group_data) # Save combined results timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") final_output = self.output_dir / f"augmented_dataset_{timestamp}.json" with open(final_output, 'w') as f: json.dump(all_results, f) if self.config.debug: print(f"Combined {len(all_results)} dialogues into {final_output}") return final_output def process_dataset(self, dialogues: List[Dict]) -> Path: """Process dataset with hardware-appropriate optimizations and progress tracking""" processed_ids = self._load_checkpoint() # Filter out already processed dialogues remaining_dialogues = [d for d in dialogues if d['dialogue_id'] not in processed_ids] total_dialogues = len(dialogues) remaining_count = len(remaining_dialogues) processed_count = total_dialogues - remaining_count print("\nDataset Processing Status:") print(f"Total dialogues in dataset: {total_dialogues}") print(f"Previously processed: {processed_count}") print(f"Remaining to process: {remaining_count}") print("-" * 50) # Process in batches with progress bar for batch_num in tqdm(range(0, len(remaining_dialogues), self.batch_size), desc="Processing batches", total=(len(remaining_dialogues) + self.batch_size - 1) // self.batch_size): batch = remaining_dialogues[batch_num:batch_num + self.batch_size] current_position = processed_count + batch_num + len(batch) total_progress = (current_position / total_dialogues) * 100 print('\033[K', end='') print(f"Processing: {current_position}/{total_dialogues} dialogues " f"({total_progress:.1f}% complete)") print(f"Current batch: {batch_num//self.batch_size + 1} of " f"{(len(remaining_dialogues) + self.batch_size - 1) // self.batch_size}") print("-" * 50) # Process batch batch_results = self._process_batch(batch) if batch_results: self._save_batch(batch_results, batch_num) batch_ids = {d['dialogue_id'] for d in batch} processed_ids.update(batch_ids) self._update_checkpoint(processed_ids) # Increment batch counter and combine if needed self.batch_counter += 1 if self.batch_counter == 25: # Combine these 25 batches into a group file self._combine_intermediate_batches() self.batch_counter = 0 # Reset counter after grouping # If there are leftover batches less than 25 # combine them into one final group file if self.batch_counter > 0: self._combine_intermediate_batches() self.batch_counter = 0 print("\n" + "-" * 50) print("Processing complete. Combining results...") return self.combine_results() def cleanup(self): """Clean up intermediate files after successful processing""" # Clean up any leftover batch files (should not exist if logic is correct) batch_files = list(self.output_dir.glob("batch_*.json")) for file in batch_files: try: file.unlink() except Exception as e: print(f"Error deleting {file}: {e}") # We can also remove batch_group_*.json if desired after final combine # but that might not be necessary if we want to keep them. if self.checkpoint_file.exists(): try: self.checkpoint_file.unlink() except Exception as e: print(f"Error deleting checkpoint file: {e}") def _deduplicate_dialogues(self, dialogues: List[Dict], threshold: float = 0.9) -> List[Dict]: """ Deduplicate dialogues based on text similarity. """ print("Deduplicating dialogues...") if not dialogues: print("No dialogues provided for deduplication.") return [] # Combine turns into single text for similarity comparison texts = [" ".join(turn['text'] for turn in dialogue['turns']) for dialogue in dialogues] tfidf = TfidfVectorizer().fit_transform(texts) sim_matrix = cosine_similarity(tfidf) unique_indices = set() for i, row in enumerate(sim_matrix): if i not in unique_indices: similar_indices = [j for j, sim in enumerate(row) if sim > threshold and j != i] unique_indices.add(i) unique_indices.difference_update(similar_indices) deduplicated_dialogues = [dialogues[i] for i in unique_indices] print(f"Deduplication complete. Reduced from {len(dialogues)} to {len(deduplicated_dialogues)} dialogues.") return deduplicated_dialogues def _validate_and_clean_dialogue(self, dialogue: Dict) -> Optional[Dict]: """ Validate and clean a single dialogue. """ try: # Check required fields if not all(field in dialogue for field in self.config.required_fields): return None # Process turns cleaned_turns = [] for turn in dialogue['turns']: if self._validate_turn(turn): cleaned_turn = { 'speaker': turn['speaker'], 'text': self._clean_text(turn['text']) } cleaned_turns.append(cleaned_turn) if cleaned_turns: return { 'dialogue_id': dialogue['dialogue_id'], 'turns': cleaned_turns } return None except Exception as e: print(f"Error processing dialogue {dialogue.get('dialogue_id', 'unknown')}: {str(e)}") return None def _validate_turn(self, turn: Dict) -> bool: """ Validate a single speaking turn. """ return ( turn['speaker'] in self.config.allowed_speakers and self.config.min_length <= len(turn['text']) <= self.config.max_length ) def _clean_text(self, text: str) -> str: """ Clean and normalize text. """ # Remove excessive whitespace text = re.sub(r'\s+', ' ', text.strip()) # Normalize quotes and apostrophes text = re.sub(r'[โ€™ยด`]', "'", text) text = re.sub(r'[โ€œโ€]', '"', text) # Remove control characters text = "".join(char for char in text if ord(char) >= 32 or char == '\n') return text def _process_validation(self, items: List, func, description: str) -> List: """ Process items sequentially with a progress bar. """ results = [] print(f"Starting {description}") for item in tqdm(items, desc=description): try: result = func(item) if result is not None: results.append(result) except Exception as e: print(f"Error processing item: {str(e)}") print(f"Completed {description}. Processed {len(results)} items successfully") return results def _get_cache_path(self, data: List[Dict]) -> Path: """ Generate cache file path based on data hash. """ data_str = json.dumps(data, sort_keys=True) hash_value = hashlib.md5(data_str.encode()).hexdigest() return self.cache_dir / f"cache_{hash_value}.pkl"