from dataclasses import dataclass, field
from typing import List, Dict, Optional, Any, Union
import json
import glob
from pathlib import Path

from pipeline_config import PipelineConfig


@dataclass
class SchemaGuidedDialogue:
    """
    Structured representation of a Schema-Guided dialogue.

    ``service_name`` is ``None`` when the source dialogue lists no services,
    in which case ``schema`` is ``{}`` and ``service_description`` is ``""``.
    """
    dialogue_id: str
    service_name: Optional[str]
    service_description: Optional[str]
    schema: Dict[str, Any]
    turns: List[Dict[str, Any]]
    original_metadata: Dict[str, Any] = field(default_factory=dict)


class SchemaGuidedProcessor:
    """
    Handles processing and preparation of Schema-Guided dataset dialogues.

    Typical flow: ``load_dataset`` reads ``schema.json`` plus the
    ``dialogues_*.json`` files of one directory into ``SchemaGuidedDialogue``
    objects, and ``convert_to_pipeline_format`` turns those into plain dicts.
    """

    # Raw-turn keys that _process_turns maps onto dedicated fields; every other
    # key of a raw turn is preserved under the processed turn's 'metadata'.
    _CONSUMED_TURN_KEYS = frozenset({'speaker', 'utterance', 'frames'})

    def __init__(self, config: PipelineConfig):
        """
        Args:
            config: Pipeline configuration; stored as ``self.config``.
        """
        self.config = config
        # NOTE(review): `services` and `domains` are never populated by this
        # class; confirm whether external callers fill them or drop them.
        self.services = set()
        self.domains = set()
        # service_name -> schema dict; (re)filled by load_dataset().
        self.schemas = {}

    def load_dataset(
        self,
        base_dir: Union[str, Path],
        max_examples: Optional[int] = None,
    ) -> List[SchemaGuidedDialogue]:
        """
        Load and parse one directory of a Schema-Guided Dialogue dataset.

        ``base_dir`` must contain ``schema.json`` and at least one
        ``dialogues_*.json`` file. Dialogue files are processed in sorted
        filename order, so the output (and any ``max_examples`` cut-off) is
        deterministic. Replaces ``self.schemas`` with the schemas loaded
        from ``schema.json``.

        Args:
            base_dir: Directory holding ``schema.json`` and ``dialogues_*.json``.
            max_examples: Maximum total number of dialogues returned across
                all files. ``None`` or ``0`` disables the cap.

        Returns:
            The processed dialogues.

        Raises:
            FileNotFoundError: If ``schema.json`` is missing or no file
                matches ``dialogues_*.json``.
        """
        base_path = Path(base_dir)
        schema_file = base_path / "schema.json"
        dialogue_files_pattern = str(base_path / "dialogues_*.json")

        if not schema_file.exists():
            raise FileNotFoundError(f"Schema file not found at {schema_file}")

        self.schemas = self._load_schemas(schema_file)

        # Path.glob treats base_path literally (glob metacharacters in the
        # directory name are not expanded); glob.glob's result order is
        # filesystem-dependent, so sort for reproducible results.
        dialogue_files = sorted(base_path.glob("dialogues_*.json"))
        if not dialogue_files:
            raise FileNotFoundError(f"No dialogue files found matching pattern {dialogue_files_pattern}")

        print(f"Found {len(dialogue_files)} dialogue files to process.")

        # Same truthiness semantics as before: None or 0 means "no limit".
        limit = max_examples if max_examples else None

        processed_dialogues = []
        for file_path in dialogue_files:
            # The cap is global across files: stop before reading another file
            # once it is reached, instead of only leaving the current file.
            if limit is not None and len(processed_dialogues) >= limit:
                break
            with open(file_path, 'r', encoding='utf-8') as f:
                raw_dialogues = json.load(f)
            for dialogue in raw_dialogues:
                if limit is not None and len(processed_dialogues) >= limit:
                    break
                processed_dialogues.append(self._process_single_dialogue(dialogue))

        return processed_dialogues

    def _process_single_dialogue(self, dialogue: Dict[str, Any]) -> SchemaGuidedDialogue:
        """
        Process a single dialogue JSON object into a SchemaGuidedDialogue object.

        Only the first entry of ``dialogue["services"]`` selects the schema; the
        full list is kept in ``original_metadata["services"]``. With no services,
        or a service absent from ``self.schemas``, the schema is ``{}`` and the
        description ``""``. The returned id is prefixed with ``schema_guided_``.
        """
        dialogue_id = str(dialogue.get("dialogue_id", ""))
        services = dialogue.get("services", [])
        service_name = services[0] if services else None
        schema = self.schemas.get(service_name, {})
        service_description = schema.get("description", "")

        turns = self._process_turns(dialogue.get("turns", []))

        metadata = {
            "services": services,
            "original_id": dialogue_id,
        }

        return SchemaGuidedDialogue(
            dialogue_id=f"schema_guided_{dialogue_id}",
            service_name=service_name,
            service_description=service_description,
            schema=schema,
            turns=turns,
            original_metadata=metadata,
        )

    def _validate_schema(self, schema: Dict[str, Any]) -> bool:
        """
        Return True if ``schema`` has all of service_name, description, slots
        and intents; otherwise print a warning naming the missing keys and
        return False.
        """
        required_keys = {"service_name", "description", "slots", "intents"}
        missing_keys = required_keys - schema.keys()
        if missing_keys:
            print(f"Warning: Missing keys in schema {schema.get('service_name', 'unknown')}: {missing_keys}")
            return False
        return True

    def _load_schemas(self, schema_path: Union[str, Path]) -> Dict[str, Any]:
        """
        Load the schema list from ``schema_path`` and index it by service name.

        Schemas failing ``_validate_schema`` are dropped. If two schemas share
        a ``service_name``, the later one wins.
        """
        with open(schema_path, 'r', encoding='utf-8') as f:
            schemas = json.load(f)

        return {
            schema["service_name"]: schema
            for schema in schemas
            if self._validate_schema(schema)
        }

    def _process_turns(self, turns: List[Dict]) -> List[Dict]:
        """
        Convert raw turns into the standardized turn format.

        Each output turn has: ``speaker`` ('assistant' for SYSTEM, otherwise
        'user'), stripped ``text``, ``original_speaker``, ``dialogue_acts`` and
        ``slots`` (the concatenated ``actions``/``slots`` of all frames), and
        ``metadata`` (every remaining raw key). Turns that raise during
        conversion are reported and skipped.
        """
        processed_turns = []
        for turn in turns:
            # Deliberate best-effort: one malformed turn must not abort the
            # whole dialogue, so it is reported and skipped.
            try:
                speaker = 'assistant' if turn.get('speaker') == 'SYSTEM' else 'user'
                text = turn.get('utterance', '').strip()

                acts = []
                slots = []
                for frame in turn.get('frames', []):
                    if 'actions' in frame:
                        acts.extend(frame['actions'])
                    if 'slots' in frame:
                        slots.extend(frame['slots'])

                processed_turns.append({
                    'speaker': speaker,
                    'text': text,
                    'original_speaker': turn.get('speaker', ''),
                    'dialogue_acts': acts,
                    'slots': slots,
                    'metadata': {
                        k: v for k, v in turn.items()
                        if k not in self._CONSUMED_TURN_KEYS
                    },
                })
            except Exception as e:
                print(f"Error processing turn: {str(e)}")
                continue

        return processed_turns

    def convert_to_pipeline_format(self, schema_dialogues: List[SchemaGuidedDialogue]) -> List[Dict]:
        """
        Convert SchemaGuidedDialogues to the format expected by the ProcessingPipeline.

        Each result is ``{'dialogue_id', 'turns', 'metadata'}``. Turns keep
        only ``speaker`` and ``text``, and turns whose text is blank are dropped.
        ``metadata`` holds the service name, description and schema, merged
        with ``original_metadata``. That dict is spread last, so its keys would
        override these on a name clash.
        """
        pipeline_dialogues = []
        for dialogue in schema_dialogues:
            processed_turns = [
                {"speaker": turn["speaker"], "text": turn["text"]}
                for turn in dialogue.turns
                if turn["text"].strip()
            ]

            pipeline_dialogues.append({
                'dialogue_id': dialogue.dialogue_id,
                'turns': processed_turns,
                'metadata': {
                    'service_name': dialogue.service_name,
                    'service_description': dialogue.service_description,
                    'schema': dialogue.schema,
                    **dialogue.original_metadata
                }
            })

        return pipeline_dialogues