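"""Loader and pre-processor for the Taskmaster dialogue dataset.

Parses the raw self-dialogs/woz-dialogs JSON files into structured
TaskmasterDialogue objects and converts them into the format expected
by the ProcessingPipeline.
"""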
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Any, Set
import json
import re
from pathlib import Path

from pipeline_config import PipelineConfig


@dataclass
class TaskmasterDialogue:
    """Structured representation of a Taskmaster dialogue."""

    conversation_id: str
    instruction_id: Optional[str]
    scenario: Optional[str]
    domain: Optional[str]
    turns: List[Dict[str, Any]]
    original_metadata: Dict[str, Any] = field(default_factory=dict)

    def __str__(self) -> str:
        return f"TaskmasterDialogue(conversation_id={self.conversation_id}, turns={len(self.turns)})"

    def validate(self) -> bool:
        return bool(self.conversation_id and isinstance(self.turns, list))


class TaskmasterProcessor:
    """Handles processing and preparation of Taskmaster dataset dialogues."""

    # Attribute annotations for documentation only; this is a plain class
    # (not a dataclass), so all values are assigned in __init__. The previous
    # field(default_factory=...) defaults had no effect outside a dataclass.
    config: PipelineConfig
    use_ontology: bool
    ontology: Optional[Dict[str, Any]]
    domains: Set[str]
    scenarios: Set[str]

    def __init__(self, config: PipelineConfig, use_ontology: bool = False):
        self.config = config
        self.use_ontology = use_ontology
        self.ontology = None
        self.domains = set()
        self.scenarios = set()

    def load_dataset(self, base_dir: str, max_examples: Optional[int] = None) -> List[TaskmasterDialogue]:
        """
        Load and parse the Taskmaster JSON dataset.

        Handles the self-dialogs, woz-dialogs, and ontology files.
        """
        required_files = {
            "self-dialogs": "self-dialogs.json",
            "woz-dialogs": "woz-dialogs.json",
            "ontology": "ontology.json",
        }

        missing_files = [name for name, path in required_files.items() if not Path(base_dir, path).exists()]
        if missing_files:
            raise FileNotFoundError(f"Missing required Taskmaster files: {missing_files}")

        # Load the ontology only when requested; previously the use_ontology
        # flag was never consulted.
        if self.use_ontology:
            ontology_path = Path(base_dir, required_files['ontology'])
            with open(ontology_path, 'r', encoding='utf-8') as f:
                self.ontology = json.load(f)

        processed_dialogues: List[TaskmasterDialogue] = []
        for file_key in ["self-dialogs", "woz-dialogs"]:
            file_path = Path(base_dir, required_files[file_key])
            with open(file_path, 'r', encoding='utf-8') as f:
                raw_data = json.load(f)

            for dialogue in raw_data:
                conversation_id = dialogue.get('conversation_id', '')
                instruction_id = dialogue.get('instruction_id')

                if 'utterances' in dialogue:
                    turns = self._process_utterances(dialogue['utterances'])
                    scenario = dialogue.get('scenario', '')
                    domain = self._extract_domain(scenario)
                else:
                    turns = []
                    scenario = ''
                    domain = ''

                # Keep any remaining fields so nothing from the source is lost;
                # 'scenario' is excluded because it is stored as a first-class field.
                metadata = {k: v for k, v in dialogue.items()
                            if k not in {'conversation_id', 'instruction_id', 'utterances', 'scenario'}}

                processed_dialogue = TaskmasterDialogue(
                    conversation_id=conversation_id,
                    instruction_id=instruction_id,
                    scenario=scenario,
                    domain=domain,
                    turns=turns,
                    original_metadata=metadata
                )
                processed_dialogues.append(processed_dialogue)

                if domain:
                    self.domains.add(domain)
                if scenario:
                    self.scenarios.add(scenario)

                if max_examples and len(processed_dialogues) >= max_examples:
                    break

            # The inner break only exits the current file's loop; honor
            # max_examples across files as well.
            if max_examples and len(processed_dialogues) >= max_examples:
                break

        return processed_dialogues

    def _process_utterances(self, utterances: List[Dict]) -> List[Dict]:
        """Process utterances into a standardized format."""
        processed_turns = []
        for utterance in utterances:
            # Normalize speaker labels to the pipeline's user/assistant convention.
            speaker = 'assistant' if utterance.get('speaker') == 'ASSISTANT' else 'user'
            text = utterance.get('text', '').strip()
            segments = utterance.get('segments', [])

            turn = {
                'speaker': speaker,
                'text': text,
                'original_speaker': utterance.get('speaker', ''),
                'segments': segments,
                'metadata': {k: v for k, v in utterance.items()
                             if k not in {'speaker', 'text', 'segments'}}
            }
            processed_turns.append(turn)

        return processed_turns

    def _extract_domain(self, scenario: str) -> str:
        """
        Extract a coarse domain label from a scenario description using
        keyword patterns, e.g. "order a pizza for delivery" -> 'pizza'.
        """
        domain_patterns = {
            'restaurant': r'\b(restaurant|dining|food|reservation)\b',
            'movie': r'\b(movie|cinema|film|ticket)\b',
            'ride_share': r'\b(ride|taxi|uber|lyft)\b',
            'coffee': r'\b(coffee|café|cafe|starbucks)\b',
            'pizza': r'\b(pizza|delivery|order food)\b',
            'auto': r'\b(car|vehicle|repair|maintenance)\b',
        }

        scenario_lower = scenario.lower()
        for domain, pattern in domain_patterns.items():
            if re.search(pattern, scenario_lower):
                return domain

        return 'other'

    def convert_to_pipeline_format(self, taskmaster_dialogues: List[TaskmasterDialogue]) -> List[Dict]:
        """Convert TaskmasterDialogue objects into the format expected by the ProcessingPipeline."""
        pipeline_dialogues = []
        for dialogue in taskmaster_dialogues:
            # Keep only non-empty turns; the pipeline needs just speaker and text.
            processed_turns = []
            for turn in dialogue.turns:
                if turn['text'].strip():
                    processed_turns.append({
                        'speaker': turn['speaker'],
                        'text': turn['text']
                    })

            pipeline_dialogue = {
                'dialogue_id': dialogue.conversation_id,
                'turns': processed_turns,
                'metadata': {
                    'instruction_id': dialogue.instruction_id,
                    'scenario': dialogue.scenario,
                    'domain': dialogue.domain,
                    **dialogue.original_metadata
                }
            }
            pipeline_dialogues.append(pipeline_dialogue)

        return pipeline_dialogues
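

# Example usage: a minimal sketch of driving the processor end to end. The
# data directory path is hypothetical, and PipelineConfig() being
# constructible with defaults is an assumption; adjust both to your setup.
if __name__ == "__main__":
    config = PipelineConfig()  # assumption: default construction suffices
    processor = TaskmasterProcessor(config)
    dialogues = processor.load_dataset("data/taskmaster", max_examples=10)  # hypothetical path
    valid = [d for d in dialogues if d.validate()]
    pipeline_ready = processor.convert_to_pipeline_format(valid)
    print(f"Converted {len(pipeline_ready)} dialogues; domains seen: {sorted(processor.domains)}")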