import re
import json
from pathlib import Path
from typing import List, Dict, Optional, Any
from dataclasses import dataclass, field


@dataclass
class TaskmasterDialogue:
    conversation_id: str
    instruction_id: Optional[str]
    scenario: Optional[str]
    domain: Optional[str]
    turns: List[Dict[str, Any]]
    original_metadata: Dict[str, Any] = field(default_factory=dict)

    def __str__(self):
        return f"TaskmasterDialogue(conversation_id={self.conversation_id}, turns={len(self.turns)})"

    def validate(self) -> bool:
        return bool(self.conversation_id and isinstance(self.turns, list))


class PipelineConfig:
    """
    Example config structure. Adjust to your real config usage.
    """
    def __init__(
        self,
        debug: bool = True,
        min_turns: int = 2,
        min_user_words: int = 3
    ):
        self.debug = debug
        self.min_turns = min_turns
        self.min_user_words = min_user_words


class TaskmasterProcessor:
    """
    Loads Taskmaster-1 dialogues, extracts domain from scenario,
    cleans + filters them, and outputs a pipeline-friendly format.
    """
    def __init__(self, config: PipelineConfig):
        self.config = config

    def load_taskmaster_dataset(
        self,
        base_dir: str,
        max_examples: Optional[int] = None
    ) -> List[TaskmasterDialogue]:
        """
        Load and parse Taskmaster JSON for self-dialogs & woz-dialogs (Taskmaster-1).
        Combines scenario text + conversation utterances to detect domain more robustly.
        """
        required_files = {
            "self-dialogs": "self-dialogs.json",
            "woz-dialogs": "woz-dialogs.json",
            "ontology": "ontology.json",
        }

        missing = [k for k, v in required_files.items() if not Path(base_dir, v).exists()]
        if missing:
            raise FileNotFoundError(f"Missing Taskmaster files: {missing}")

        ontology_path = Path(base_dir, required_files["ontology"])
        with open(ontology_path, 'r', encoding='utf-8') as f:
            ontology = json.load(f)
        if self.config.debug:
            print(f"[TaskmasterProcessor] Loaded ontology with {len(ontology)} top-level keys (unused).")

        dialogues: List[TaskmasterDialogue] = []

        file_keys = ["self-dialogs", "woz-dialogs"]
        for file_key in file_keys:
            file_path = Path(base_dir, required_files[file_key])
            with open(file_path, 'r', encoding='utf-8') as f:
                raw_data = json.load(f)

            for d in raw_data:
                # Check the cap before appending so a second file cannot
                # push the total past max_examples.
                if max_examples is not None and len(dialogues) >= max_examples:
                    break

                conversation_id = d.get("conversation_id", "")
                instruction_id = d.get("instruction_id")
                scenario_text = d.get("scenario", "")

                utterances = d.get("utterances", [])
                turns = self._process_utterances(utterances)

                domain = self._extract_domain(scenario_text, turns)

                new_dlg = TaskmasterDialogue(
                    conversation_id=conversation_id,
                    instruction_id=instruction_id,
                    scenario=scenario_text,
                    domain=domain,
                    turns=turns,
                    original_metadata={}
                )
                dialogues.append(new_dlg)

            # Also stop iterating over remaining files once the cap is hit.
            if max_examples is not None and len(dialogues) >= max_examples:
                break

        if self.config.debug:
            print(f"[TaskmasterProcessor] Loaded {len(dialogues)} total dialogues from Taskmaster-1.")
        return dialogues

    def _extract_domain(self, scenario: str, turns: List[Dict[str, str]]) -> str:
        """
        Combine scenario text + all turn texts to detect domain more robustly.
        """
        combined_text = scenario.lower()
        for turn in turns:
            txt = turn.get('text', '').lower()
            combined_text += " " + txt

        domain_patterns = {
            'restaurant': r'\b(restaurant|dining|food|reservation|table|menu|cuisine|eat|hungry)\b',
            'movie': r'\b(movie|cinema|film|ticket|showtime|theater|flick|screening)\b',
            'ride_share': r'\b(ride|taxi|uber|lyft|car\s?service|pickup|dropoff|driver)\b',
            'coffee': r'\b(coffee|café|cafe|starbucks|espresso|latte|mocha|americano)\b',
            'pizza': r'\b(pizza|delivery|order\s?food|pepperoni|topping|pizzeria|slice)\b',
            'auto': r'\b(car|vehicle|repair|maintenance|mechanic|oil\s?change)\b'
        }

        for dom, pattern in domain_patterns.items():
            if re.search(pattern, combined_text):
                if self.config.debug:
                    print(f"Matched domain: {dom} in scenario/turns")
                return dom

        if self.config.debug:
            print("No domain match, returning 'other'")
        return 'other'

    def _process_utterances(self, utterances: List[Dict[str, Any]]) -> List[Dict[str, str]]:
        """
        Convert raw utterances to a cleaned list of (speaker, text).
        Skip or remove lines that are numeric, too short, or empty.
        """
        cleaned_turns = []
        for utt in utterances:
            speaker = 'assistant' if utt.get('speaker') == 'ASSISTANT' else 'user'
            raw_text = utt.get('text', '').strip()

            text = self._clean_text(raw_text)

            if not text:
                continue
            if self._is_numeric_line(text):
                continue

            # Drop one-word fragments; they carry no usable dialogue signal.
            if len(text.split()) < 2:
                continue

            cleaned_turns.append({
                'speaker': speaker,
                'text': text
            })
        return cleaned_turns

    def _clean_text(self, text: str) -> str:
        """
        Basic text normalization: remove repeated punctuation, handle weird spacing, etc.
        Adjust to your needs.
        """
        # Collapse runs of whitespace into single spaces.
        text = re.sub(r'\s+', ' ', text)

        # Collapse repeated punctuation ("!!!" -> "!").
        text = re.sub(r'([!?.,])\1+', r'\1', text)
        return text.strip()

    def _is_numeric_line(self, text: str) -> bool:
        """
        Return True if the line is purely digits/punctuation/spaces,
        e.g. "4 3 13", "12345", "3.14". Adjust as needed.
        """
        pattern = r'^\s*\d+[\s\d.,]*$'
        return bool(re.match(pattern, text))

    def filter_and_convert(self, dialogues: List[TaskmasterDialogue]) -> List[Dict]:
        """
        Filter out dialogues that don't meet min turns / min user words,
        then convert them to the final pipeline format:

        {
            "dialogue_id": "...",
            "domain": "...",
            "turns": [ {"speaker": "user", "text": "..."}, ... ]
        }
        """
        results = []
        for dlg in dialogues:
            if not dlg.validate():
                continue

            # Require a minimum number of turns.
            if len(dlg.turns) < self.config.min_turns:
                continue

            # Require every user turn to meet the minimum word count.
            keep = True
            for turn in dlg.turns:
                if turn['speaker'] == 'user':
                    words_count = len(turn['text'].split())
                    if words_count < self.config.min_user_words:
                        keep = False
                        break

            if not keep:
                continue

            pipeline_dlg = {
                'dialogue_id': dlg.conversation_id,
                'domain': dlg.domain,
                'turns': dlg.turns
            }
            results.append(pipeline_dlg)

        if self.config.debug:
            print(f"[TaskmasterProcessor] Filtered down to {len(results)} dialogues after cleaning.")
        return results
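

if __name__ == "__main__":
    # Minimal usage sketch; the directory name below is an assumption,
    # so point base_dir at your local copy of the Taskmaster-1 data.
    config = PipelineConfig(debug=True, min_turns=2, min_user_words=3)
    processor = TaskmasterProcessor(config)
    dialogues = processor.load_taskmaster_dataset(base_dir="TM-1-2019", max_examples=100)
    pipeline_data = processor.filter_and_convert(dialogues)
    print(f"Prepared {len(pipeline_data)} dialogues for the pipeline.")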