import json
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional

from data_augmentation.pipeline_config import PipelineConfig

@dataclass
class TaskmasterDialogue:
    """
    Structured representation of a Taskmaster dialogue.
    """
    conversation_id: str
    instruction_id: Optional[str]
    scenario: Optional[str]
    domain: Optional[str]
    turns: List[Dict[str, Any]]
    original_metadata: Dict[str, Any] = field(default_factory=dict)

    def __str__(self):
        return (f"TaskmasterDialogue(conversation_id={self.conversation_id}, "
                f"turns={len(self.turns)})")

    def validate(self) -> bool:
        return bool(self.conversation_id and isinstance(self.turns, list))
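
# Illustrative only: a hypothetical TaskmasterDialogue as produced by the
# loader below (all values are made up for demonstration):
#
# TaskmasterDialogue(
#     conversation_id="dlg_00042",
#     instruction_id="restaurant-table-1",
#     scenario="Book a table for two at an Italian restaurant",
#     domain="restaurant",
#     turns=[{"speaker": "user", "text": "Hi, I need a table for two."}],
# )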
class TaskmasterProcessor:
    """
    Handles processing and preparation of Taskmaster dataset dialogues.
    """
    # Note: plain annotations only. Using dataclasses.field() here would be
    # inert, since this class is not a dataclass; __init__ sets the values.
    config: PipelineConfig
    use_ontology: bool                  # Whether to load and use the ontology
    ontology: Optional[Dict[str, Any]]  # Holds ontology data once loaded
    domains: set                        # Tracks unique domains
    scenarios: set                      # Tracks unique scenarios

    def __init__(self, config: PipelineConfig, use_ontology: bool = False):
        self.config = config
        self.use_ontology = use_ontology
        self.ontology = None
        self.domains = set()
        self.scenarios = set()
    def load_dataset(self, base_dir: str, max_examples: Optional[int] = None) -> List[TaskmasterDialogue]:
        """
        Load and parse the Taskmaster JSON dataset.
        Handles self-dialogs, woz-dialogs, and ontology files.
        """
        required_files = {
            "self-dialogs": "self-dialogs.json",
            "woz-dialogs": "woz-dialogs.json",
            "ontology": "ontology.json",
        }

        # Check for required files before reading any of them
        missing_files = [name for name, path in required_files.items()
                         if not Path(base_dir, path).exists()]
        if missing_files:
            raise FileNotFoundError(f"Missing required Taskmaster files: {missing_files}")

        # Load the ontology only when requested
        if self.use_ontology:
            ontology_path = Path(base_dir, required_files["ontology"])
            with open(ontology_path, 'r', encoding='utf-8') as f:
                self.ontology = json.load(f)

        processed_dialogues: List[TaskmasterDialogue] = []
        for file_key in ["self-dialogs", "woz-dialogs"]:
            file_path = Path(base_dir, required_files[file_key])
            with open(file_path, 'r', encoding='utf-8') as f:
                raw_data = json.load(f)

            for dialogue in raw_data:
                # Extract core dialogue components
                conversation_id = dialogue.get('conversation_id', '')
                instruction_id = dialogue.get('instruction_id', None)

                if 'utterances' in dialogue:
                    turns = self._process_utterances(dialogue['utterances'])
                    scenario = dialogue.get('scenario', '')
                    domain = self._extract_domain(scenario)
                else:
                    turns = []
                    scenario = ''
                    domain = ''

                # Preserve any remaining fields as metadata
                metadata = {k: v for k, v in dialogue.items()
                            if k not in {'conversation_id', 'instruction_id', 'utterances'}}

                # Create the structured dialogue object
                processed_dialogue = TaskmasterDialogue(
                    conversation_id=conversation_id,
                    instruction_id=instruction_id,
                    scenario=scenario,
                    domain=domain,
                    turns=turns,
                    original_metadata=metadata
                )
                processed_dialogues.append(processed_dialogue)

                # Update domain and scenario tracking
                if domain:
                    self.domains.add(domain)
                if scenario:
                    self.scenarios.add(scenario)

                if max_examples and len(processed_dialogues) >= max_examples:
                    break

            # Also stop reading further files once the example cap is reached
            if max_examples and len(processed_dialogues) >= max_examples:
                break

        return processed_dialogues
    def _process_utterances(self, utterances: List[Dict]) -> List[Dict]:
        """
        Process utterances into a standardized format.
        """
        processed_turns = []
        for utterance in utterances:
            # Map Taskmaster speaker roles onto the pipeline's user/assistant roles
            speaker = 'assistant' if utterance.get('speaker') == 'ASSISTANT' else 'user'

            # Extract and clean the text
            text = utterance.get('text', '').strip()

            # Extract any segments or annotations if present
            segments = utterance.get('segments', [])

            # Create the processed turn
            turn = {
                'speaker': speaker,
                'text': text,
                'original_speaker': utterance.get('speaker', ''),
                'segments': segments,
                'metadata': {k: v for k, v in utterance.items()
                             if k not in {'speaker', 'text', 'segments'}}
            }
            processed_turns.append(turn)
        return processed_turns
    def _extract_domain(self, scenario: str) -> str:
        """
        Extract a coarse domain label from the scenario description.
        """
        domain_patterns = {
            'restaurant': r'\b(restaurant|dining|food|reservation)\b',
            'movie': r'\b(movie|cinema|film|ticket)\b',
            'ride_share': r'\b(ride|taxi|uber|lyft)\b',
            'coffee': r'\b(coffee|café|cafe|starbucks)\b',
            'pizza': r'\b(pizza|delivery|order food)\b',
            'auto': r'\b(car|vehicle|repair|maintenance)\b',
        }
        scenario_lower = scenario.lower()
        for domain, pattern in domain_patterns.items():
            if re.search(pattern, scenario_lower):
                return domain
        return 'other'
    def convert_to_pipeline_format(self, taskmaster_dialogues: List[TaskmasterDialogue]) -> List[Dict]:
        """
        Convert TaskmasterDialogues to the format expected by the ProcessingPipeline.
        """
        pipeline_dialogues = []
        for dialogue in taskmaster_dialogues:
            # Convert turns to the expected format, skipping empty ones
            processed_turns = []
            for turn in dialogue.turns:
                if turn['text'].strip():
                    processed_turns.append({
                        'speaker': turn['speaker'],
                        'text': turn['text']
                    })

            # Create the dialogue in pipeline format
            pipeline_dialogue = {
                'dialogue_id': dialogue.conversation_id,
                'turns': processed_turns,
                'metadata': {
                    'instruction_id': dialogue.instruction_id,
                    'scenario': dialogue.scenario,
                    'domain': dialogue.domain,
                    **dialogue.original_metadata
                }
            }
            pipeline_dialogues.append(pipeline_dialogue)
        return pipeline_dialogues
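
if __name__ == "__main__":
    # Minimal usage sketch, not part of the pipeline proper. Assumptions:
    # PipelineConfig is constructible with no required arguments, and
    # "data/taskmaster" is a hypothetical directory containing
    # self-dialogs.json, woz-dialogs.json, and ontology.json.
    config = PipelineConfig()
    processor = TaskmasterProcessor(config, use_ontology=False)
    dialogues = processor.load_dataset("data/taskmaster", max_examples=100)
    pipeline_input = processor.convert_to_pipeline_format(dialogues)
    print(f"Loaded {len(dialogues)} dialogues; domains seen: {sorted(processor.domains)}")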