import hashlib
import json
import pickle
import re
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import spacy
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

from dialogue_augmenter import DialogueAugmenter
from pipeline_config import PipelineConfig


class ProcessingPipeline:
    """
    Complete pipeline combining validation, deduplication, and augmentation.
    """
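    # Expected input schema, inferred from the validation code below. The
    # speaker names shown here are illustrative only; the allowed values come
    # from PipelineConfig.allowed_speakers:
    #
    #   {
    #       "dialogue_id": "d-001",
    #       "turns": [
    #           {"speaker": "user", "text": "Hello there"},
    #           {"speaker": "assistant", "text": "Hi, how can I help?"},
    #       ],
    #   }
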
    def __init__(self, config: Optional[PipelineConfig] = None):
        self.config = config or PipelineConfig()
        self.nlp = spacy.load("en_core_web_sm", disable=['parser', 'ner'])
        self.augmenter = DialogueAugmenter(self.nlp, self.config)
        # num_threads is taken from config.batch_size; processing in this
        # module is currently sequential and does not use it.
        self.num_threads = self.config.batch_size
        self.cache_dir = Path("./cache")
        self.cache_dir.mkdir(exist_ok=True)

    def process_dataset(self, dialogues: List[Dict]) -> List[Dict]:
        """
        Process the entire dataset through the pipeline.
        """
        print(f"Processing {len(dialogues)} dialogues")
        start_time = datetime.now()

        # Return cached results if an identical input was processed before.
        if self.config.use_cache:
            cache_path = self._get_cache_path(dialogues)
            if cache_path.exists():
                print("Loading from cache...")
                with open(cache_path, 'rb') as f:
                    return pickle.load(f)

        # Validate and clean every dialogue, dropping the ones that fail.
        valid_dialogues = self._process_validation(
            dialogues,
            self._validate_and_clean_dialogue,
            "validating and cleaning"
        )

        if not valid_dialogues:
            raise ValueError("Dialogue validation resulted in an empty dataset.")

        # Remove near-duplicate dialogues before augmentation.
        deduplicated_dialogues = self._deduplicate_dialogues(valid_dialogues)

        # Augment each remaining dialogue; augment_dialogue returns a list of
        # dialogues, so the result can grow beyond the input size.
        all_processed_dialogues = []
        for dialogue in tqdm(deduplicated_dialogues, desc="augmenting"):
            augmented = self.augmenter.augment_dialogue(dialogue)
            all_processed_dialogues.extend(augmented)

        if self.config.use_cache:
            with open(cache_path, 'wb') as f:
                pickle.dump(all_processed_dialogues, f)

        processing_time = datetime.now() - start_time
        print(f"Processing completed in {processing_time}")
        print(f"Generated {len(all_processed_dialogues)} total dialogues")

        return all_processed_dialogues

    def _deduplicate_dialogues(self, dialogues: List[Dict], threshold: float = 0.9) -> List[Dict]:
        """
        Deduplicate dialogues based on TF-IDF cosine similarity of their text.
        """
        print("Deduplicating dialogues...")
        if not dialogues:
            print("No dialogues provided for deduplication.")
            return []

        # Build one text blob per dialogue and compute pairwise similarities.
        # Note: the full similarity matrix requires O(n^2) memory.
        texts = [" ".join(turn['text'] for turn in dialogue['turns']) for dialogue in dialogues]
        tfidf = TfidfVectorizer().fit_transform(texts)
        sim_matrix = cosine_similarity(tfidf)

        # Greedy deduplication: a dialogue is dropped only if it is too similar
        # to an earlier dialogue that was kept.
        duplicate_indices = set()
        for i in range(len(dialogues)):
            if i in duplicate_indices:
                continue
            for j in range(i + 1, len(dialogues)):
                if sim_matrix[i, j] > threshold:
                    duplicate_indices.add(j)

        deduplicated_dialogues = [
            dialogue for i, dialogue in enumerate(dialogues) if i not in duplicate_indices
        ]

        print(f"Deduplication complete. Reduced from {len(dialogues)} to {len(deduplicated_dialogues)} dialogues.")
        return deduplicated_dialogues

    def _validate_and_clean_dialogue(self, dialogue: Dict) -> Optional[Dict]:
        """
        Validate and clean a single dialogue, returning None if it is unusable.
        """
        try:
            # Reject dialogues missing any required top-level field.
            if not all(field in dialogue for field in self.config.required_fields):
                return None

            # Keep only the turns that pass validation, with normalized text.
            cleaned_turns = []
            for turn in dialogue['turns']:
                if self._validate_turn(turn):
                    cleaned_turn = {
                        'speaker': turn['speaker'],
                        'text': self._clean_text(turn['text'])
                    }
                    cleaned_turns.append(cleaned_turn)

            if cleaned_turns:
                return {
                    'dialogue_id': dialogue['dialogue_id'],
                    'turns': cleaned_turns
                }

            return None

        except Exception as e:
            print(f"Error processing dialogue {dialogue.get('dialogue_id', 'unknown')}: {str(e)}")
            return None

    def _validate_turn(self, turn: Dict) -> bool:
        """
        Validate a single speaking turn.
        """
        return (
            turn['speaker'] in self.config.allowed_speakers and
            self.config.min_length <= len(turn['text']) <= self.config.max_length
        )

    def _clean_text(self, text: str) -> str:
        """
        Clean and normalize text.
        """
        # Collapse runs of whitespace into single spaces.
        text = re.sub(r'\s+', ' ', text.strip())

        # Normalize typographic quotes and accent marks to ASCII quotes.
        text = re.sub(r'[‘’´`]', "'", text)
        text = re.sub(r'[“”]', '"', text)

        # Drop non-printable control characters (newlines are kept).
        text = "".join(char for char in text if ord(char) >= 32 or char == '\n')

        return text

    def _process_validation(self, items: List, func, description: str) -> List:
        """
        Process items sequentially with a progress bar.
        """
        results = []
        print(f"Starting {description}")
        for item in tqdm(items, desc=description):
            try:
                result = func(item)
                if result is not None:
                    results.append(result)
            except Exception as e:
                print(f"Error processing item: {str(e)}")
        print(f"Completed {description}. Processed {len(results)} items successfully")
        return results

    def _get_cache_path(self, data: List[Dict]) -> Path:
        """
        Generate a cache file path based on a hash of the input data.
        """
        # MD5 is used here only as a content fingerprint for cache file names,
        # not for security.
        data_str = json.dumps(data, sort_keys=True)
        hash_value = hashlib.md5(data_str.encode()).hexdigest()
        return self.cache_dir / f"cache_{hash_value}.pkl"
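

# Minimal usage sketch, not part of the original module's API: it assumes a
# JSON file named "dialogues.json" in the working directory containing a list
# of dialogues in the schema noted at the top of the class. The file name and
# the default PipelineConfig are illustrative assumptions.
if __name__ == "__main__":
    with open("dialogues.json", "r", encoding="utf-8") as f:
        raw_dialogues = json.load(f)

    pipeline = ProcessingPipeline()
    processed = pipeline.process_dataset(raw_dialogues)
    print(f"Example run produced {len(processed)} processed dialogues")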