|
from dataclasses import dataclass, field |
|
from typing import List, Dict, Optional, Any |
|
import json |
|
import glob |
|
from pathlib import Path |
|
from pipeline_config import PipelineConfig |
|
|
|
@dataclass
class SchemaGuidedDialogue:
    """
    Structured representation of a Schema-Guided dialogue
    """
    # Dialogue identifier; the processor prefixes it with "schema_guided_".
    dialogue_id: str
    # First service listed for the dialogue; None when the source listed no services.
    service_name: Optional[str]
    # Human-readable description taken from the matching service schema ("" if absent).
    service_description: Optional[str]
    # Raw schema dict for the service ({} when no schema matched the service name).
    schema: Dict[str, Any]
    # Standardized turns as produced by SchemaGuidedProcessor._process_turns.
    turns: List[Dict[str, Any]]
    # Extra bookkeeping (e.g. full services list and the original dialogue id).
    original_metadata: Dict[str, Any] = field(default_factory=dict)
|
|
|
class SchemaGuidedProcessor:
    """
    Handles processing and preparation of Schema-Guided dataset dialogues
    """

    def __init__(self, config: PipelineConfig):
        self.config = config
        # Populated lazily; kept as part of the public attribute surface.
        self.services = set()
        self.domains = set()
        # service_name -> schema dict, filled by load_dataset via _load_schemas.
        self.schemas = {}

    def load_dataset(self, base_dir, max_examples: Optional[int] = None) -> List[SchemaGuidedDialogue]:
        """
        Load and parse the Schema-Guided Dialogue dataset from a directory.

        Args:
            base_dir: Directory containing ``schema.json`` and ``dialogues_*.json``.
            max_examples: If truthy, stop after this many dialogues.

        Returns:
            List of SchemaGuidedDialogue objects (at most ``max_examples`` long).

        Raises:
            FileNotFoundError: If the schema file or dialogue files are missing.
        """
        schema_file = Path(base_dir, "schema.json")
        dialogue_files_pattern = str(Path(base_dir, "dialogues_*.json"))

        if not schema_file.exists():
            raise FileNotFoundError(f"Schema file not found at {schema_file}")

        self.schemas = self._load_schemas(schema_file)

        # Sort for deterministic processing order: glob order is filesystem-dependent.
        dialogue_files = sorted(glob.glob(dialogue_files_pattern))
        if not dialogue_files:
            raise FileNotFoundError(f"No dialogue files found matching pattern {dialogue_files_pattern}")

        print(f"Found {len(dialogue_files)} dialogue files to process.")

        processed_dialogues = []
        for file_path in dialogue_files:
            with open(file_path, 'r', encoding='utf-8') as f:
                raw_dialogues = json.load(f)

            for dialogue in raw_dialogues:
                processed_dialogues.append(self._process_single_dialogue(dialogue))
                # Check the limit per dialogue: the previous per-file check let the
                # result overshoot max_examples by up to a whole file of dialogues.
                if max_examples and len(processed_dialogues) >= max_examples:
                    return processed_dialogues

        return processed_dialogues

    def _process_single_dialogue(self, dialogue: Dict[str, Any]) -> SchemaGuidedDialogue:
        """
        Process a single dialogue JSON object into a SchemaGuidedDialogue object.

        Only the first listed service is used to resolve the schema; the full
        services list is preserved in the metadata.
        """
        dialogue_id = str(dialogue.get("dialogue_id", ""))
        services = dialogue.get("services", [])
        service_name = services[0] if services else None
        # Missing/unknown services degrade gracefully to an empty schema.
        schema = self.schemas.get(service_name, {})
        service_description = schema.get("description", "")

        turns = self._process_turns(dialogue.get("turns", []))

        metadata = {
            "services": services,
            "original_id": dialogue_id,
        }

        return SchemaGuidedDialogue(
            dialogue_id=f"schema_guided_{dialogue_id}",
            service_name=service_name,
            service_description=service_description,
            schema=schema,
            turns=turns,
            original_metadata=metadata,
        )

    def _validate_schema(self, schema: Dict[str, Any]) -> bool:
        """
        Validate that a schema carries all required top-level keys.

        Returns:
            True if valid; otherwise prints a warning and returns False.
        """
        required_keys = {"service_name", "description", "slots", "intents"}
        missing_keys = required_keys - schema.keys()
        if missing_keys:
            print(f"Warning: Missing keys in schema {schema.get('service_name', 'unknown')}: {missing_keys}")
            return False
        return True

    def _load_schemas(self, schema_path) -> Dict[str, Any]:
        """
        Load service schemas from ``schema_path`` (str or Path) and index the
        valid ones by service name. Invalid schemas are dropped with a warning.
        """
        with open(schema_path, 'r', encoding='utf-8') as f:
            schemas = json.load(f)

        return {
            schema["service_name"]: schema for schema in schemas if self._validate_schema(schema)
        }

    def _process_turns(self, turns: List[Dict]) -> List[Dict]:
        """
        Process dialogue turns into a standardized format.

        Each output turn has 'speaker' ('assistant' for SYSTEM, else 'user'),
        stripped 'text', 'original_speaker', flattened 'dialogue_acts' and
        'slots' from all frames, and a 'metadata' dict of the remaining keys.
        Turns that raise during processing are skipped with a printed error
        (deliberate best-effort behavior).
        """
        processed_turns = []

        for turn in turns:
            try:
                speaker = 'assistant' if turn.get('speaker') == 'SYSTEM' else 'user'
                text = turn.get('utterance', '').strip()

                # Flatten actions/slots across all frames into single lists.
                frames = turn.get('frames', [])
                acts = []
                slots = []
                for frame in frames:
                    if 'actions' in frame:
                        acts.extend(frame['actions'])
                    if 'slots' in frame:
                        slots.extend(frame['slots'])

                processed_turn = {
                    'speaker': speaker,
                    'text': text,
                    'original_speaker': turn.get('speaker', ''),
                    'dialogue_acts': acts,
                    'slots': slots,
                    'metadata': {k: v for k, v in turn.items()
                                 if k not in {'speaker', 'utterance', 'frames'}}
                }

                processed_turns.append(processed_turn)
            except Exception as e:
                print(f"Error processing turn: {str(e)}")
                continue

        return processed_turns

    def convert_to_pipeline_format(self, schema_dialogues: List[SchemaGuidedDialogue]) -> List[Dict]:
        """
        Convert SchemaGuidedDialogues to the format expected by the ProcessingPipeline.

        Empty/whitespace-only turns are dropped; service info and original
        metadata are merged into each dialogue's 'metadata' dict.
        """
        pipeline_dialogues = []

        for dialogue in schema_dialogues:
            processed_turns = [
                {"speaker": turn["speaker"], "text": turn["text"]}
                for turn in dialogue.turns if turn["text"].strip()
            ]

            pipeline_dialogue = {
                'dialogue_id': dialogue.dialogue_id,
                'turns': processed_turns,
                'metadata': {
                    'service_name': dialogue.service_name,
                    'service_description': dialogue.service_description,
                    'schema': dialogue.schema,
                    **dialogue.original_metadata
                }
            }

            pipeline_dialogues.append(pipeline_dialogue)

        return pipeline_dialogues
|
|