# Provenance: last commit c111c20 ("style refinements") by JoeArmani; file size 8.19 kB.
import os
import re
import json
from pathlib import Path
from typing import List, Dict, Optional, Any
from dataclasses import dataclass, field
@dataclass
class TaskmasterDialogue:
    """
    Container for a single parsed Taskmaster dialogue.
    """
    # Identifier of the conversation; an empty/falsy value fails validate().
    conversation_id: str
    # Optional instruction identifier attached to the dialogue.
    instruction_id: Optional[str]
    # Optional scenario description text.
    scenario: Optional[str]
    # Optional domain label assigned to the dialogue.
    domain: Optional[str]
    # Turn dictionaries; validate() requires this to be a list.
    turns: List[Dict[str, Any]]
    # Free-form extra metadata; every instance gets its own fresh dict.
    original_metadata: Dict[str, Any] = field(default_factory=dict)

    def __str__(self):
        """
        Short summary: the conversation id and the number of turns.
        """
        turn_count = len(self.turns)
        return f"TaskmasterDialogue(conversation_id={self.conversation_id}, turns={turn_count} turns)"

    def validate(self) -> bool:
        """
        Return True when the id is non-empty and `turns` is a list.
        """
        if not self.conversation_id:
            return False
        return isinstance(self.turns, list)
class RawDataProcessingConfig:
    """
    Plain settings holder for raw dataset processing.

    Attributes:
        debug: Enables progress/diagnostic printing in code that checks it.
        max_length: Stored as given; not read by any code visible in this file.
        min_turns: Minimum number of turns a dialogue must have to be kept.
        min_user_words: Minimum word count required of each user turn.
    """

    def __init__(self, debug: bool = True, max_length: int = 512,
                 min_turns: int = 2, min_user_words: int = 3):
        # Values are stored verbatim; no validation is performed here.
        self.debug, self.max_length = debug, max_length
        self.min_turns, self.min_user_words = min_turns, min_user_words
class TaskmasterProcessor:
    """
    Load Taskmaster-1 dialogues and extract a domain for each one.
    Clean, filter, and convert them to the pipeline format.
    """

    # Domain keyword patterns, compiled once. Order matters: the first
    # pattern that matches wins (e.g. 'ride_share' is tried before 'auto',
    # so "car service" is not classified as 'auto').
    _DOMAIN_PATTERNS = (
        ('restaurant', re.compile(r'\b(restaurant|dining|food|reservation|table|menu|cuisine|eat|hungry)\b')),
        ('movie', re.compile(r'\b(movie|cinema|film|ticket|showtime|theater|flick|screening)\b')),
        ('ride_share', re.compile(r'\b(ride|taxi|uber|lyft|car\s?service|pickup|dropoff|driver)\b')),
        ('coffee', re.compile(r'\b(coffee|café|cafe|starbucks|espresso|latte|mocha|americano)\b')),
        ('pizza', re.compile(r'\b(pizza|delivery|order\s?food|pepperoni|topping|pizzeria|slice)\b')),
        ('auto', re.compile(r'\b(car|vehicle|repair|maintenance|mechanic|oil\s?change)\b')),
    )

    # Optional leading whitespace, one digit, then only whitespace, digits,
    # '.' or ','. Accepts the same strings as the previous
    # r'^[\s]*[\d]+([\s\d.,]+)*[\s]*$' but without the nested quantifier,
    # which backtracked exponentially on inputs like "1 1 1 1 ... x".
    _NUMERIC_LINE_RE = re.compile(r'\s*\d[\s\d.,]*')

    # Turns (from either speaker) with fewer words than this are dropped
    # while cleaning utterances.
    _MIN_TURN_WORDS = 3

    def __init__(self, config: RawDataProcessingConfig):
        self.config = config

    def load_taskmaster_dataset(
        self,
        base_dir: str,
        max_examples: Optional[int] = None
    ) -> List[TaskmasterDialogue]:
        """
        Load & parse Taskmaster-1 JSON for self-dialogs & woz-dialogs.

        Args:
            base_dir: Directory containing self-dialogs.json,
                woz-dialogs.json and ontology.json.
            max_examples: Upper bound on the number of dialogues returned
                across both files. None or 0 means no limit.

        Returns:
            Parsed dialogues, self-dialogs first, then woz-dialogs.

        Raises:
            FileNotFoundError: If any of the three required files is missing.
        """
        required_files = {
            "self-dialogs": "self-dialogs.json",
            "woz-dialogs": "woz-dialogs.json",
            "ontology": "ontology.json",
        }

        # Check for missing files
        missing = [k for k, v in required_files.items() if not Path(base_dir, v).exists()]
        if missing:
            raise FileNotFoundError(f"Missing Taskmaster files: {missing}")

        # Load ontology (only its size is reported; it is otherwise unused)
        ontology_path = Path(base_dir, required_files["ontology"])
        with open(ontology_path, 'r', encoding='utf-8') as f:
            ontology = json.load(f)
        if self.config.debug:
            print(f"[TaskmasterProcessor] Loaded ontology with {len(ontology)} top-level keys (unused).")

        dialogues: List[TaskmasterDialogue] = []

        # Process each file
        for file_key in ("self-dialogs", "woz-dialogs"):
            # Stop before opening the next file once the cap is reached;
            # breaking only the inner loop let the second file add one more
            # dialogue beyond max_examples.
            if max_examples and len(dialogues) >= max_examples:
                break

            file_path = Path(base_dir, required_files[file_key])
            with open(file_path, 'r', encoding='utf-8') as f:
                raw_data = json.load(f)

            for d in raw_data:
                conversation_id = d.get("conversation_id", "")
                instruction_id = d.get("instruction_id", None)
                scenario_text = d.get("scenario", "")

                # Handle utterances (a JSON null is treated as "no utterances")
                utterances = d.get("utterances") or []
                turns = self._process_utterances(utterances)

                # Detect Domain
                domain = self._extract_domain(scenario_text, turns)

                # Build the object
                dialogues.append(TaskmasterDialogue(
                    conversation_id=conversation_id,
                    instruction_id=instruction_id,
                    scenario=scenario_text,
                    domain=domain,
                    turns=turns,
                    original_metadata={}
                ))

                if max_examples and len(dialogues) >= max_examples:
                    break

        if self.config.debug:
            print(f"[TaskmasterProcessor] Loaded {len(dialogues)} total dialogues from Taskmaster-1.")
        return dialogues

    def _extract_domain(self, scenario: str, turns: List[Dict[str, str]]) -> str:
        """
        Combine scenario text + all turn texts to detect domain more robustly.

        Returns the name of the first domain whose keyword pattern matches
        the lower-cased combined text, or 'other' when none match.
        A None scenario or turn text is treated as an empty string.
        """
        parts = [(scenario or '').lower()]
        parts.extend((turn.get('text', '') or '').lower() for turn in turns)
        combined_text = " ".join(parts)

        for domain, pattern in self._DOMAIN_PATTERNS:
            if pattern.search(combined_text):
                if self.config.debug:
                    print(f"Matched domain: {domain} in scenario/turns")
                return domain

        if self.config.debug:
            print("No domain match, returning 'other'")
        return 'other'

    def _process_utterances(self, utterances: List[Dict[str, Any]]) -> List[Dict[str, str]]:
        """
        Convert "utterances" to a cleaned list of {'speaker', 'text'} dicts.
        Skip lines that are numeric, too short, or empty.

        The speaker is 'assistant' when the utterance's speaker field equals
        'ASSISTANT', otherwise 'user'.
        """
        cleaned_turns = []
        for utt in utterances:
            speaker = 'assistant' if utt.get('speaker') == 'ASSISTANT' else 'user'

            # A missing or null text is treated as empty rather than crashing.
            text = self._clean_text((utt.get('text') or '').strip())

            # Skip blank or numeric lines (e.g. "4 3 13")
            if not text or self._is_numeric_line(text):
                continue

            # Skip very short turns such as "ok" or "yes". This applies to
            # both speakers, not only to the user.
            if len(text.split()) < self._MIN_TURN_WORDS:
                continue

            cleaned_turns.append({
                'speaker': speaker,
                'text': text
            })
        return cleaned_turns

    def _clean_text(self, text: str) -> str:
        """
        Simple text normalization: collapse whitespace runs to one space,
        collapse repeated '!', '?', '.' or ',' to a single character, and
        strip the ends.
        """
        text = re.sub(r'\s+', ' ', text)
        text = re.sub(r'([!?.,])\1+', r'\1', text)
        return text.strip()

    def _is_numeric_line(self, text: str) -> bool:
        """
        Return True if the line starts (after optional whitespace) with a
        digit and otherwise contains only digits, whitespace, '.' and ',',
        e.g. "4 3 13" and similar found in Taskmaster-1 dataset.
        """
        return self._NUMERIC_LINE_RE.fullmatch(text) is not None

    def filter_and_convert(self, dialogues: List[TaskmasterDialogue]) -> List[Dict]:
        """
        Filter out dialogues that don't meet min length requirements. Convert to pipeline format.
        {
            "dialogue_id": "...",
            "domain": "...",
            "turns": [ {"speaker": "user", "text": "..."}, ... ]
        }

        A dialogue is dropped when it fails validate(), has fewer than
        config.min_turns turns, or contains any user turn with fewer than
        config.min_user_words words.
        """
        min_user_words = self.config.min_user_words
        results = []
        for dlg in dialogues:
            if not dlg.validate():
                continue

            # Skip if too few turns
            if len(dlg.turns) < self.config.min_turns:
                continue

            # Skip if any user turn is too short
            has_short_user_turn = any(
                turn['speaker'] == 'user' and len(turn['text'].split()) < min_user_words
                for turn in dlg.turns
            )
            if has_short_user_turn:
                continue

            results.append({
                'dialogue_id': dlg.conversation_id,
                'domain': dlg.domain,
                'turns': dlg.turns  # already cleaned
            })

        if self.config.debug:
            print(f"[TaskmasterProcessor] Filtered down to {len(results)} dialogues after cleaning.")
        return results