from datetime import datetime
from pathlib import Path
from typing import List, Dict, Optional
import json
import re
import hashlib
import pickle

import spacy
from tqdm import tqdm
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

from pipeline_config import PipelineConfig
from dialogue_augmenter import DialogueAugmenter


class ProcessingPipeline:
    """
    Complete pipeline combining validation, deduplication, and augmentation.
    """

    def __init__(self, config: Optional[PipelineConfig] = None):
        self.config = config or PipelineConfig()
        self.nlp = spacy.load("en_core_web_sm", disable=['parser', 'ner'])
        self.augmenter = DialogueAugmenter(self.nlp, self.config)
        self.num_threads = self.config.batch_size
        self.cache_dir = Path("./cache")
        self.cache_dir.mkdir(exist_ok=True)

    def process_dataset(self, dialogues: List[Dict]) -> List[Dict]:
        """
        Process the entire dataset through the pipeline.
        """
        print(f"Processing {len(dialogues)} dialogues")
        start_time = datetime.now()

        # Check cache
        if self.config.use_cache:
            cache_path = self._get_cache_path(dialogues)
            if cache_path.exists():
                print("Loading from cache...")
                with open(cache_path, 'rb') as f:
                    return pickle.load(f)

        # Validate and clean
        valid_dialogues = self._process_validation(
            dialogues,
            self._validate_and_clean_dialogue,
            "validating and cleaning"
        )
        if not valid_dialogues:
            raise ValueError("Dialogue validation resulted in an empty dataset.")

        deduplicated_dialogues = self._deduplicate_dialogues(valid_dialogues)

        # Augment dialogues
        all_processed_dialogues = []
        for dialogue in deduplicated_dialogues:
            augmented = self.augmenter.augment_dialogue(dialogue)
            all_processed_dialogues.extend(augmented)

        # Save to cache
        if self.config.use_cache:
            with open(cache_path, 'wb') as f:
                pickle.dump(all_processed_dialogues, f)

        processing_time = datetime.now() - start_time
        print(f"Processing completed in {processing_time}")
        print(f"Generated {len(all_processed_dialogues)} total dialogues")

        return all_processed_dialogues

    def _deduplicate_dialogues(self, dialogues: List[Dict], threshold: float = 0.9) -> List[Dict]:
        """
        Deduplicate dialogues based on text similarity.
        """
        print("Deduplicating dialogues...")
        if not dialogues:
            print("No dialogues provided for deduplication.")
            return []

        # Combine turns into a single text per dialogue for similarity comparison
        texts = [" ".join(turn['text'] for turn in dialogue['turns']) for dialogue in dialogues]
        tfidf = TfidfVectorizer().fit_transform(texts)
        sim_matrix = cosine_similarity(tfidf)

        # Keep the first dialogue in each similarity cluster and mark later
        # near-duplicates (similarity above the threshold) for removal.
        duplicate_indices = set()
        unique_indices = []
        for i, row in enumerate(sim_matrix):
            if i in duplicate_indices:
                continue
            unique_indices.append(i)
            duplicate_indices.update(
                j for j, sim in enumerate(row) if j > i and sim > threshold
            )

        deduplicated_dialogues = [dialogues[i] for i in unique_indices]
        print(f"Deduplication complete. Reduced from {len(dialogues)} to {len(deduplicated_dialogues)} dialogues.")
        return deduplicated_dialogues

    def _validate_and_clean_dialogue(self, dialogue: Dict) -> Optional[Dict]:
        """
        Validate and clean a single dialogue.
""" try: # Check required fields if not all(field in dialogue for field in self.config.required_fields): return None # Process turns cleaned_turns = [] for turn in dialogue['turns']: if self._validate_turn(turn): cleaned_turn = { 'speaker': turn['speaker'], 'text': self._clean_text(turn['text']) } cleaned_turns.append(cleaned_turn) if cleaned_turns: return { 'dialogue_id': dialogue['dialogue_id'], 'turns': cleaned_turns } return None except Exception as e: print(f"Error processing dialogue {dialogue.get('dialogue_id', 'unknown')}: {str(e)}") return None def _validate_turn(self, turn: Dict) -> bool: """ Validate a single speaking turn. """ return ( turn['speaker'] in self.config.allowed_speakers and self.config.min_length <= len(turn['text']) <= self.config.max_length ) def _clean_text(self, text: str) -> str: """ Clean and normalize text. """ # Remove excessive whitespace text = re.sub(r'\s+', ' ', text.strip()) # Normalize quotes and apostrophes text = re.sub(r'[โ€™ยด`]', "'", text) text = re.sub(r'[โ€œโ€]', '"', text) # Remove control characters text = "".join(char for char in text if ord(char) >= 32 or char == '\n') return text def _process_validation(self, items: List, func, description: str) -> List: """ Process items sequentially with a progress bar. """ results = [] print(f"Starting {description}") for item in tqdm(items, desc=description): try: result = func(item) if result is not None: results.append(result) except Exception as e: print(f"Error processing item: {str(e)}") print(f"Completed {description}. Processed {len(results)} items successfully") return results def _get_cache_path(self, data: List[Dict]) -> Path: """ Generate cache file path based on data hash. """ data_str = json.dumps(data, sort_keys=True) hash_value = hashlib.md5(data_str.encode()).hexdigest() return self.cache_dir / f"cache_{hash_value}.pkl"