csc525_retrieval_based_chatbot / taskmaster_processor.py
JoeArmani
sentence transformer
64e7c31
raw
history blame
9.35 kB
import os
import re
import json
from pathlib import Path
from typing import List, Dict, Optional, Any
from dataclasses import dataclass, field
from logger_config import config_logger
# Module-level logger for this file, created via the project's logger_config helper.
logger = config_logger(__name__)
@dataclass
class TaskmasterDialogue:
    """
    One Taskmaster-1 conversation after its utterances have been cleaned.

    Fields are positional in declaration order; only ``original_metadata``
    has a default (a fresh empty dict per instance).
    """

    conversation_id: str                 # identifier taken from the source record
    instruction_id: Optional[str]
    scenario: Optional[str]
    domain: Optional[str]                # coarse label, e.g. 'restaurant' or 'other'
    turns: List[Dict[str, Any]]          # e.g. [{'speaker': 'user', 'text': '...'}, ...]
    original_metadata: Dict[str, Any] = field(default_factory=dict)

    def __str__(self) -> str:
        turn_count = len(self.turns)
        return f"TaskmasterDialogue(conversation_id={self.conversation_id}, turns={turn_count} turns)"

    def validate(self) -> bool:
        """Return True only if ``conversation_id`` is truthy and ``turns`` is a list."""
        if not self.conversation_id:
            return False
        return isinstance(self.turns, list)
class RawDataProcessingConfig:
    """
    Plain settings holder for raw dataset processing.

    Args:
        debug: When True, processing code emits verbose statistics via logging.
        max_length: Length limit setting; not referenced by the processor code
            in this file (NOTE(review): confirm its units/consumer).
        min_turns: Minimum number of cleaned turns a dialogue needs to be kept.
        min_user_words: Minimum word count every user turn must reach for the
            dialogue to be kept.
    """

    def __init__(
        self,
        debug: bool = True,
        max_length: int = 512,
        min_turns: int = 4,
        min_user_words: int = 3
    ):
        # Assigned left to right, so attribute order matches the signature.
        self.debug, self.max_length, self.min_turns, self.min_user_words = (
            debug,
            max_length,
            min_turns,
            min_user_words,
        )
class TaskmasterProcessor:
    """
    Load Taskmaster-1 dialogues and extract a coarse domain for each one.
    Clean, filter, and convert them to the pipeline format.
    """

    # Domain keyword patterns, compiled once and tried in insertion order.
    # The first pattern found anywhere in the combined text wins, so overlaps
    # are resolved by position (e.g. "food" is claimed by 'restaurant' before
    # 'pizza' is checked, and "car service" by 'ride_share' before 'auto').
    _DOMAIN_PATTERNS = {
        'restaurant': re.compile(r'\b(restaurant|dining|food|reservation|table|menu|cuisine|eat|hungry)\b'),
        'movie': re.compile(r'\b(movie|cinema|film|ticket|showtime|theater|flick|screening)\b'),
        'ride_share': re.compile(r'\b(ride|taxi|uber|lyft|car\s?service|pickup|dropoff|driver)\b'),
        'coffee': re.compile(r'\b(coffee|café|cafe|starbucks|espresso|latte|mocha|americano)\b'),
        'pizza': re.compile(r'\b(pizza|delivery|order\s?food|pepperoni|topping|pizzeria|slice)\b'),
        'auto': re.compile(r'\b(car|vehicle|repair|maintenance|mechanic|oil\s?change)\b'),
    }
    # Any run of whitespace (spaces, tabs, newlines) -> a single space.
    _WHITESPACE_RE = re.compile(r'\s+')
    # A run of one repeated mark among ! ? . , -> that mark once ("!!!" -> "!").
    _REPEATED_PUNCT_RE = re.compile(r'([!?.,])\1+')
    # Optional leading whitespace, one digit, then only digits/whitespace/'.'/','.
    # Matches exactly the same strings as the former
    # r'^[\s]*[\d]+([\s\d.,]+)*[\s]*$', but without its nested quantifier,
    # which backtracked exponentially on e.g. a long digit run followed by a letter.
    _NUMERIC_LINE_RE = re.compile(r'^\s*\d[\s\d.,]*$')
    # Utterances (from either speaker) with fewer words than this are dropped.
    _MIN_UTTERANCE_WORDS = 3

    def __init__(self, config: RawDataProcessingConfig):
        """
        Args:
            config: Processing settings; this class reads its ``debug``,
                ``min_turns`` and ``min_user_words`` attributes.
        """
        self.config = config

    def load_taskmaster_dataset(
        self,
        base_dir: str,
        max_examples: Optional[int] = None
    ) -> List[TaskmasterDialogue]:
        """
        Load & parse Taskmaster-1 JSON for self-dialogs & woz-dialogs.

        Args:
            base_dir: Directory containing self-dialogs.json, woz-dialogs.json
                and ontology.json.
            max_examples: Stop once this many dialogues have been collected
                across both files (self-dialogs first). None or 0 means no limit.

        Returns:
            TaskmasterDialogue objects with cleaned turns and a detected domain.

        Raises:
            FileNotFoundError: If any of the three required files is missing.
        """
        required_files = {
            "self-dialogs": "self-dialogs.json",
            "woz-dialogs": "woz-dialogs.json",
            "ontology": "ontology.json",
        }

        missing = [k for k, v in required_files.items() if not Path(base_dir, v).exists()]
        if missing:
            raise FileNotFoundError(f"Missing Taskmaster files: {missing}")

        # The ontology is not used for processing; only its size is reported.
        ontology_path = Path(base_dir, required_files["ontology"])
        with open(ontology_path, 'r', encoding='utf-8') as f:
            ontology = json.load(f)
        if self.config.debug:
            logger.info(f"[TaskmasterProcessor] Loaded ontology with {len(ontology.keys())} top-level keys (unused).")

        dialogues: List[TaskmasterDialogue] = []
        for file_key in ("self-dialogs", "woz-dialogs"):
            # Check the limit before opening the next file: previously the
            # inner `break` let the outer loop read woz-dialogs.json anyway and
            # append one dialogue too many.
            if max_examples and len(dialogues) >= max_examples:
                break

            file_path = Path(base_dir, required_files[file_key])
            with open(file_path, 'r', encoding='utf-8') as f:
                raw_data = json.load(f)

            for d in raw_data:
                conversation_id = d.get("conversation_id", "")
                scenario_text = d.get("scenario", "")
                # `or []` also tolerates an explicit null "utterances" value.
                turns = self._process_utterances(
                    d.get("utterances") or [],
                    conversation_id=conversation_id,
                )
                domain = self._extract_domain(scenario_text, turns)

                dialogues.append(TaskmasterDialogue(
                    conversation_id=conversation_id,
                    instruction_id=d.get("instruction_id", None),
                    scenario=scenario_text,
                    domain=domain,
                    turns=turns,
                    original_metadata={}
                ))

                if max_examples and len(dialogues) >= max_examples:
                    break

        if self.config.debug:
            logger.info(f"[TaskmasterProcessor] Loaded {len(dialogues)} total dialogues from Taskmaster-1.")
        return dialogues

    def _extract_domain(self, scenario: str, turns: List[Dict[str, str]]) -> str:
        """
        Guess a coarse domain from the scenario text plus every turn's text.

        Patterns in ``_DOMAIN_PATTERNS`` are tried in order against the
        lower-cased combined text; the first match wins.

        Args:
            scenario: Scenario text; None is treated as empty.
            turns: Turn dicts; a missing 'text' key counts as empty.

        Returns:
            The first matching domain key, or 'other' if none matches.
        """
        # Build with a single join instead of repeated `+=` (quadratic).
        parts = [scenario or ""]
        parts.extend(turn.get('text', '') for turn in turns)
        combined_text = " ".join(parts).lower()

        for domain, pattern in self._DOMAIN_PATTERNS.items():
            if pattern.search(combined_text):
                if self.config.debug:
                    logger.info(f"Matched domain: {domain} in scenario/turns")
                return domain

        if self.config.debug:
            logger.info("No domain match, returning 'other'")
        return 'other'

    def _clean_text(self, text: str) -> str:
        """
        Normalize an utterance.

        Collapses whitespace runs to one space, shrinks runs of the same mark
        among ! ? . , to a single mark (so "..." also becomes "."), and strips
        leading/trailing whitespace.
        """
        text = self._WHITESPACE_RE.sub(' ', text)
        text = self._REPEATED_PUNCT_RE.sub(r'\1', text)
        return text.strip()

    def _is_numeric_line(self, text: str) -> bool:
        """
        Return True if the line starts (after optional whitespace) with a digit
        and contains only digits, whitespace, '.' and ','.

        Examples: "4 3 13" and "1,200.50" are numeric; "..." and "4 pm" are not.
        """
        return bool(self._NUMERIC_LINE_RE.match(text))

    def filter_and_convert(self, dialogues: List[TaskmasterDialogue]) -> List[Dict]:
        """
        Drop dialogues that fail validation or length requirements, then convert
        the rest to the pipeline format:

        {
            "dialogue_id": "...",
            "domain": "...",
            "turns": [ {"speaker": "user", "text": "..."}, ... ]
        }

        A dialogue is dropped if ``validate()`` fails, if it has fewer than
        ``config.min_turns`` turns, or if ANY user turn has fewer than
        ``config.min_user_words`` words. The output dicts share each
        dialogue's ``turns`` list rather than copying it.
        """
        total = len(dialogues)
        invalid = too_few_turns = short_user_turns = 0
        results = []

        for dlg in dialogues:
            if not dlg.validate():
                invalid += 1
                continue

            if len(dlg.turns) < self.config.min_turns:
                too_few_turns += 1
                continue

            # One short user turn is enough to reject the whole dialogue.
            if any(
                turn['speaker'] == 'user'
                and len(turn['text'].split()) < self.config.min_user_words
                for turn in dlg.turns
            ):
                short_user_turns += 1
                continue

            results.append({
                'dialogue_id': dlg.conversation_id,
                'domain': dlg.domain,
                'turns': dlg.turns
            })

        if self.config.debug:
            # Guard the rate: an empty input previously raised ZeroDivisionError.
            rate = ((total - len(results)) / total) * 100 if total else 0.0
            logger.info("\nFiltering Statistics:")
            logger.info(f"Total dialogues: {total}")
            logger.info(f"Invalid dialogues: {invalid}")
            logger.info(f"Too few turns: {too_few_turns}")
            logger.info(f"Short user turns: {short_user_turns}")
            logger.info(f"Remaining dialogues: {len(results)}")
            logger.info(f"Filtering rate: {rate:.1f}%\n")
        return results

    def _process_utterances(
        self,
        utterances: List[Dict[str, Any]],
        conversation_id: Optional[str] = None
    ) -> List[Dict[str, str]]:
        """
        Clean raw utterances into pipeline turns, logging drop statistics.

        The speaker becomes 'assistant' when the raw value is exactly
        'ASSISTANT' and 'user' otherwise. Text goes through ``_clean_text``.
        Utterances that end up empty, purely numeric (``_is_numeric_line``),
        or shorter than ``_MIN_UTTERANCE_WORDS`` words are dropped.

        Args:
            utterances: Raw utterance dicts with optional 'speaker'/'text' keys.
            conversation_id: Dialogue id used only to label the debug log. If
                empty/None, falls back to the first utterance's
                'conversation_id' key, or 'unknown'.

        Returns:
            List of {'speaker': ..., 'text': ...} dicts, in input order.
        """
        min_words = self._MIN_UTTERANCE_WORDS
        total = len(utterances)
        empty = numeric = too_short = 0
        cleaned_turns = []

        for utt in utterances:
            speaker = 'assistant' if utt.get('speaker') == 'ASSISTANT' else 'user'
            # `or ''` also tolerates an explicit null "text" value;
            # _clean_text strips, so no separate strip is needed.
            text = self._clean_text(utt.get('text') or '')

            if not text:
                empty += 1
            elif self._is_numeric_line(text):
                numeric += 1
            elif len(text.split()) < min_words:
                too_short += 1
            else:
                cleaned_turns.append({
                    'speaker': speaker,
                    'text': text
                })

        if self.config.debug and total > 0:
            label = conversation_id or utterances[0].get('conversation_id', 'unknown')
            logger.info(f"\nUtterance Cleaning Statistics (Dialogue {label}):")
            logger.info(f"Total utterances: {total}")
            logger.info(f"Empty/blank: {empty}")
            logger.info(f"Numeric only: {numeric}")
            logger.info(f"Too short (<{min_words} words): {too_short}")
            logger.info(f"Remaining turns: {len(cleaned_turns)}")
            logger.info(f"Filtering rate: {((total - len(cleaned_turns)) / total) * 100:.1f}%\n")
        return cleaned_turns