from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Set
import hashlib
import json
import re

import spacy
import torch
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

from dialogue_augmenter import DialogueAugmenter
from pipeline_config import PipelineConfig


class ProcessingPipeline:
    """
    Complete pipeline combining validation, optimization, and augmentation.
    """

    def __init__(self, config: Optional[PipelineConfig] = None):
        self.config = config or PipelineConfig()
        # The parser and NER components are not needed here, so disable them
        # for faster loading and processing.
        self.nlp = spacy.load("en_core_web_sm", disable=['parser', 'ner'])
        self.augmenter = DialogueAugmenter(self.nlp, self.config)
        # Thread count for CPU-side work mirrors the configured batch size.
        self.num_threads = self.config.batch_size

        self.cache_dir = Path("./cache")
        self.cache_dir.mkdir(exist_ok=True)
        self.output_dir = Path("processed_outputs")
        self.output_dir.mkdir(exist_ok=True)
        self.checkpoint_file = self.output_dir / "processing_checkpoint.json"

        # Hardware-appropriate batch size: larger on GPU, smaller on CPU,
        # where multiprocessing is used instead.
        self.use_gpu = torch.cuda.is_available()
        self.batch_size = 32 if self.use_gpu else 8
        self.use_multiprocessing = not self.use_gpu

        self.batch_counter = 0
        self.batch_group_number = 0

        if self.config.debug:
            print("ProcessingPipeline initialized with:")
            print(f"- GPU available: {self.use_gpu}")
            print(f"- Batch size: {self.batch_size}")
            print(f"- Using multiprocessing: {self.use_multiprocessing}")

    def _save_batch(self, batch_results: List[Dict], batch_num: int) -> Path:
        """Save a batch of results to a separate JSON file."""
        batch_file = self.output_dir / f"batch_{batch_num:04d}.json"
        with open(batch_file, 'w') as f:
            json.dump(batch_results, f)
        return batch_file

    def _load_checkpoint(self) -> Set[str]:
        """Load the set of processed dialogue IDs from the checkpoint."""
        if self.checkpoint_file.exists():
            with open(self.checkpoint_file, 'r') as f:
                return set(json.load(f))
        return set()

    def _update_checkpoint(self, processed_ids: Set[str]):
        """Update the checkpoint with newly processed IDs."""
        with open(self.checkpoint_file, 'w') as f:
            json.dump(list(processed_ids), f)

    def _process_batch(self, batch: List[Dict]) -> List[Dict]:
        """Process a batch with optimized model calls."""
        results = []
        try:
            if self.use_gpu:
                results = self.augmenter.process_batch(batch)
            else:
                # On CPU, pre-compute embeddings for all turn texts in one
                # pass so the per-dialogue augmentation below hits a warm cache.
                all_texts = []
                for dialogue in batch:
                    for turn in dialogue['turns']:
                        all_texts.append(turn['text'])

                self.augmenter._compute_batch_embeddings(all_texts)

                for dialogue in batch:
                    try:
                        augmented = self.augmenter.augment_dialogue(dialogue)
                        results.extend(augmented)
                    except Exception as e:
                        print(f"Error processing dialogue "
                              f"{dialogue.get('dialogue_id', 'unknown')}: {e}")
                        continue
        except Exception as e:
            print(f"Error processing batch: {e}")
        return results

    def _combine_intermediate_batches(self):
        """
        Combine all current batch_*.json files into a single
        batch_group_XXXX.json file, then remove the batch_*.json files.
        """
        # The [0-9] guard keeps the pattern from also matching the
        # batch_group_*.json files, which share the "batch_" prefix.
        batch_files = sorted(self.output_dir.glob("batch_[0-9]*.json"))
        if not batch_files:
            return None

        combined_data = []
        for bf in batch_files:
            with open(bf, 'r') as f:
                combined_data.extend(json.load(f))
            bf.unlink()

        self.batch_group_number += 1
        group_file = self.output_dir / f"batch_group_{self.batch_group_number:04d}.json"
        with open(group_file, 'w') as f:
            json.dump(combined_data, f)
        return group_file
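
    # On-disk layout produced by the pipeline, as implied by the methods here
    # (file names are the formats used above):
    #   processed_outputs/batch_0000.json                     per-batch results (transient)
    #   processed_outputs/batch_group_0001.json               merged every 25 batches
    #   processed_outputs/augmented_dataset_<timestamp>.json  final combined output
    #   processed_outputs/processing_checkpoint.json          resume checkpoint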

    def combine_results(self) -> Path:
        """Combine all batch_group_*.json files into the final output file."""
        all_results = []
        group_files = sorted(self.output_dir.glob("batch_group_*.json"))

        print(f"Combining {len(group_files)} group files...")
        for group_file in tqdm(group_files):
            with open(group_file, 'r') as f:
                all_results.extend(json.load(f))

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        final_output = self.output_dir / f"augmented_dataset_{timestamp}.json"
        with open(final_output, 'w') as f:
            json.dump(all_results, f)

        if self.config.debug:
            print(f"Combined {len(all_results)} dialogues into {final_output}")

        return final_output

    def process_dataset(self, dialogues: List[Dict]) -> Path:
        """Process the dataset with hardware-appropriate optimizations and progress tracking."""
        processed_ids = self._load_checkpoint()

        # Skip dialogues already recorded in the checkpoint so an interrupted
        # run can resume where it left off.
        remaining_dialogues = [d for d in dialogues
                               if d['dialogue_id'] not in processed_ids]

        total_dialogues = len(dialogues)
        remaining_count = len(remaining_dialogues)
        processed_count = total_dialogues - remaining_count

        print("\nDataset Processing Status:")
        print(f"Total dialogues in dataset: {total_dialogues}")
        print(f"Previously processed: {processed_count}")
        print(f"Remaining to process: {remaining_count}")
        print("-" * 50)

        num_batches = (remaining_count + self.batch_size - 1) // self.batch_size
        for batch_start in tqdm(range(0, remaining_count, self.batch_size),
                                desc="Processing batches",
                                total=num_batches):
            batch = remaining_dialogues[batch_start:batch_start + self.batch_size]
            current_position = processed_count + batch_start + len(batch)
            total_progress = (current_position / total_dialogues) * 100

            print('\033[K', end='')  # Clear the current terminal line.
            print(f"Processing: {current_position}/{total_dialogues} dialogues "
                  f"({total_progress:.1f}% complete)")
            print(f"Current batch: {batch_start // self.batch_size + 1} of {num_batches}")
            print("-" * 50)

            batch_results = self._process_batch(batch)

            if batch_results:
                # Save under the batch's ordinal index rather than its slice
                # offset, so files are numbered batch_0000, batch_0001, ...
                self._save_batch(batch_results, batch_start // self.batch_size)
                batch_ids = {d['dialogue_id'] for d in batch}
                processed_ids.update(batch_ids)
                self._update_checkpoint(processed_ids)

            # Every 25 batches, fold the per-batch files into a group file so
            # the output directory does not fill up with small files.
            self.batch_counter += 1
            if self.batch_counter == 25:
                self._combine_intermediate_batches()
                self.batch_counter = 0

        # Fold any remaining batch files into a final group.
        if self.batch_counter > 0:
            self._combine_intermediate_batches()
            self.batch_counter = 0

        print("\n" + "-" * 50)
        print("Processing complete. Combining results...")
        return self.combine_results()

    def cleanup(self):
        """Clean up intermediate files after successful processing."""
        # This pattern matches both the per-batch files (batch_0000.json) and
        # the merged group files (batch_group_0001.json); once the final
        # combined output is written, both kinds are intermediates.
        batch_files = list(self.output_dir.glob("batch_*.json"))
        for file in batch_files:
            try:
                file.unlink()
            except Exception as e:
                print(f"Error deleting {file}: {e}")

        if self.checkpoint_file.exists():
            try:
                self.checkpoint_file.unlink()
            except Exception as e:
                print(f"Error deleting checkpoint file: {e}")

    def _deduplicate_dialogues(self, dialogues: List[Dict], threshold: float = 0.9) -> List[Dict]:
        """
        Deduplicate dialogues based on TF-IDF cosine similarity of their
        concatenated turn texts, keeping the first occurrence of each
        near-duplicate group.
        """
        print("Deduplicating dialogues...")
        if not dialogues:
            print("No dialogues provided for deduplication.")
            return []

        texts = [" ".join(turn['text'] for turn in dialogue['turns']) for dialogue in dialogues]
        tfidf = TfidfVectorizer().fit_transform(texts)
        sim_matrix = cosine_similarity(tfidf)

        # Keep the first dialogue of each near-duplicate group and mark every
        # later dialogue above the threshold as a duplicate so it is never
        # re-admitted on a later iteration.
        keep_indices: List[int] = []
        duplicate_indices: Set[int] = set()
        for i, row in enumerate(sim_matrix):
            if i in duplicate_indices:
                continue
            keep_indices.append(i)
            duplicate_indices.update(
                j for j, sim in enumerate(row) if j > i and sim > threshold)

        deduplicated_dialogues = [dialogues[i] for i in keep_indices]

        print(f"Deduplication complete. Reduced from {len(dialogues)} "
              f"to {len(deduplicated_dialogues)} dialogues.")
        return deduplicated_dialogues
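
    # Note: cosine_similarity above materializes a dense n x n matrix, so this
    # pass is comfortable for a few thousand dialogues but memory-hungry well
    # beyond that; a chunked or approximate-nearest-neighbour variant would be
    # needed at larger scales.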

    def _validate_and_clean_dialogue(self, dialogue: Dict) -> Optional[Dict]:
        """
        Validate and clean a single dialogue. Returns None if required fields
        are missing or no turn survives validation.
        """
        try:
            if not all(field in dialogue for field in self.config.required_fields):
                return None

            cleaned_turns = []
            for turn in dialogue['turns']:
                if self._validate_turn(turn):
                    cleaned_turns.append({
                        'speaker': turn['speaker'],
                        'text': self._clean_text(turn['text'])
                    })

            if cleaned_turns:
                return {
                    'dialogue_id': dialogue['dialogue_id'],
                    'turns': cleaned_turns
                }
            return None

        except Exception as e:
            print(f"Error processing dialogue {dialogue.get('dialogue_id', 'unknown')}: {e}")
            return None

    def _validate_turn(self, turn: Dict) -> bool:
        """
        Validate a single speaking turn: the speaker must be allowed and the
        text length must fall within the configured bounds.
        """
        return (
            turn.get('speaker') in self.config.allowed_speakers and
            self.config.min_length <= len(turn.get('text', '')) <= self.config.max_length
        )

    def _clean_text(self, text: str) -> str:
        """
        Clean and normalize text.
        """
        # Collapse runs of whitespace into single spaces.
        text = re.sub(r'\s+', ' ', text.strip())

        # Normalize curly quotes, acute accents, and backticks to ASCII quotes.
        text = re.sub(r'[\u2018\u2019\u00b4`]', "'", text)
        text = re.sub(r'[\u201c\u201d]', '"', text)

        # Drop non-printable control characters (newlines are kept).
        text = "".join(char for char in text if ord(char) >= 32 or char == '\n')

        return text
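
    # Worked example of the normalization above (hypothetical input):
    #   _clean_text("  It\u2019s   \u201cfine\u201d\x07 ")  ->  'It\'s "fine"'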

    def _process_validation(self, items: List, func, description: str) -> List:
        """
        Process items sequentially with a progress bar, keeping only non-None
        results.
        """
        results = []
        print(f"Starting {description}")
        for item in tqdm(items, desc=description):
            try:
                result = func(item)
                if result is not None:
                    results.append(result)
            except Exception as e:
                print(f"Error processing item: {e}")
        print(f"Completed {description}. Processed {len(results)} items successfully.")
        return results

    def _get_cache_path(self, data: List[Dict]) -> Path:
        """
        Generate a cache file path keyed by an MD5 hash of the serialized data.
        """
        data_str = json.dumps(data, sort_keys=True)
        hash_value = hashlib.md5(data_str.encode()).hexdigest()
        return self.cache_dir / f"cache_{hash_value}.pkl"
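

# Minimal usage sketch, not part of the original pipeline: it assumes a
# "dialogues.json" file holding a list of dialogues in the schema shown at the
# top of the class, and that PipelineConfig's defaults are sensible as-is.
if __name__ == "__main__":
    pipeline = ProcessingPipeline()

    with open("dialogues.json", "r") as f:
        raw_dialogues = json.load(f)

    # Validate/clean, drop near-duplicates, then run the batched augmentation.
    cleaned = pipeline._process_validation(
        raw_dialogues, pipeline._validate_and_clean_dialogue, "validation")
    unique = pipeline._deduplicate_dialogues(cleaned)
    final_path = pipeline.process_dataset(unique)

    pipeline.cleanup()
    print(f"Final dataset written to {final_path}")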