import random

import torch
from datasets import load_dataset
from tokenizers import Tokenizer
from torch.utils.data import DataLoader, Dataset

from transformer_chat import TransformerChatbot


# ATIS air-travel queries with intent labels (training split).
raw_dataset = load_dataset("tuetschek/atis", split="train")

# Tokenizer trained separately and saved as tokenizer.json.
tokenizer = Tokenizer.from_file('tokenizer.json')
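# Note: tokenizer.json is assumed to define the [CLS], [SEP], and [PAD] special
# tokens used below; Tokenizer.token_to_id() returns None for a missing token,
# so the id lists built in create_training_pairs() would fail to convert to
# tensors if those tokens are absent.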


def create_response_for_intent(intent, text):
    """Create synthetic responses for ATIS intents"""
    responses = {
        'atis_flight': [
            "I can help you with flight information. What specific details do you need?",
            "I'll search for flights matching your criteria. Please provide departure and arrival cities.",
            "Let me find available flights for you. When would you like to travel?"
        ],
        'atis_flight_no': [
            "I can help you with flight number information. Please provide the flight number.",
            "Let me search for details about that flight number.",
            "I'll look up information for that specific flight."
        ],
        'atis_airfare': [
            "I can help you find airfare information. What's your travel route?",
            "Let me search for the best airfare options for your trip.",
            "I'll check current airfare prices for your destination."
        ],
        'atis_airline': [
            "I can help you with airline information. Which airline are you looking for?",
            "Let me provide information about that airline.",
            "I'll search for details about the airline you mentioned."
        ],
        'atis_abbreviation': [
            "I can help you with airport abbreviations. Which abbreviation do you need?",
            "Let me explain that airport abbreviation for you.",
            "I'll provide the full name for that airport code."
        ],
        'atis_airport': [
            "I can help you with airport information. Which airport are you looking for?",
            "Let me provide details about that airport.",
            "I'll search for information about the airport you mentioned."
        ],
        'atis_distance': [
            "I can help you calculate distances between airports. Which airports are you interested in?",
            "Let me calculate the distance for you.",
            "I'll provide distance information between those locations."
        ],
        'atis_ground_service': [
            "I can help you with ground transportation services. What type of service do you need?",
            "Let me find ground transportation options for you.",
            "I'll search for available ground services at your destination."
        ],
        'atis_aircraft': [
            "I can help you with aircraft information. What type of aircraft are you looking for?",
            "Let me provide details about that aircraft type.",
            "I'll search for information about the aircraft you mentioned."
        ],
        'atis_capacity': [
            "I can help you with capacity information. What specific capacity details do you need?",
            "Let me check the capacity for that flight or aircraft.",
            "I'll provide capacity information for your query."
        ],
        'atis_quantity': [
            "I can help you with quantity information. What specific quantity are you looking for?",
            "Let me check the quantity for that item or service.",
            "I'll provide quantity information for your request."
        ],
        'atis_meal': [
            "I can help you with meal information. What type of meal service are you looking for?",
            "Let me check meal options for your flight.",
            "I'll provide information about meal services available."
        ],
        'atis_cheapest': [
            "I can help you find the cheapest options. What's your travel route?",
            "Let me search for the most affordable options for your trip.",
            "I'll find the cheapest flights or services for you."
        ],
        'atis_restriction': [
            "I can help you with travel restrictions. What type of restrictions are you asking about?",
            "Let me check the restrictions for your travel plans.",
            "I'll provide information about travel restrictions."
        ],
        'atis_day_name': [
            "I can help you with day information. What specific day are you looking for?",
            "Let me check the schedule for that day.",
            "I'll provide information about flights or services on that day."
        ]
    }

    base_responses = responses.get(intent, [
        "I can help you with that. Please provide more details.",
        "Let me assist you with your request.",
        "I'll help you find the information you need."
    ])

    if "flight" in text.lower():
        base_responses.extend([
            "I can help you book a flight. What are your travel dates?",
            "Let me search for available flights for you.",
            "I'll help you find the best flight options."
        ])

    return random.choice(base_responses)
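
# ATIS only annotates intents and slots for the user utterances; it has no
# reference agent replies, so the templated responses above stand in as
# synthetic targets for the generation head.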


def create_training_pairs():
    training_data = []

    for item in raw_dataset:
        question = item['text']
        intent = item['intent']
        response = create_response_for_intent(intent, question)

        question_encoding = tokenizer.encode(question)
        response_encoding = tokenizer.encode(response)

        # Wrap each side with [CLS] ... [SEP] so sequences carry explicit
        # start and end markers.
        question_ids = [tokenizer.token_to_id("[CLS]")] + question_encoding.ids + [tokenizer.token_to_id("[SEP]")]
        response_ids = [tokenizer.token_to_id("[CLS]")] + response_encoding.ids + [tokenizer.token_to_id("[SEP]")]

        training_data.append({
            'question_ids': question_ids,
            'response_ids': response_ids,
            'question_len': len(question_ids),
            'response_len': len(response_ids)
        })

    return training_data
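
# Each pair keeps the raw (unpadded) id lists together with their true lengths;
# the Dataset below truncates and pads them to a fixed per-side length for batching.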


class AtisGenerationDataset(Dataset):
    def __init__(self, training_data, tokenizer, max_length=128):
        self.training_data = training_data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.pad_id = tokenizer.token_to_id("[PAD]")

    def __len__(self):
        return len(self.training_data)

    def __getitem__(self, idx):
        item = self.training_data[idx]
        half = self.max_length // 2

        # Truncate each side to half the length budget, then pad to a fixed size.
        question_ids = item['question_ids'][:half]
        response_ids = item['response_ids'][:half]

        question_ids += [self.pad_id] * (half - len(question_ids))
        response_ids += [self.pad_id] * (half - len(response_ids))

        return (
            torch.tensor(question_ids),
            torch.tensor(response_ids),
            torch.tensor(min(item['question_len'], half)),
            torch.tensor(min(item['response_len'], half))
        )


print("Creating training data...")
training_data = create_training_pairs()
print(f"Created {len(training_data)} training pairs")

atis_dataset = AtisGenerationDataset(training_data, tokenizer)
dataloader = DataLoader(atis_dataset, batch_size=16, shuffle=True)
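# Each batch yields (question_ids, response_ids, question_lens, response_lens);
# with batch_size=16 and max_length=128, the id tensors have shape (16, 64).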


vocab_size = tokenizer.get_vocab_size()
model = TransformerChatbot(
    vocab_size=vocab_size,
    d_model=512,
    num_heads=8,
    d_ff=2048,
    num_encoder_layers=6,
    num_decoder_layers=6,
    num_roles=2,
    max_turns=16,
    num_slots=len(set(item['intent'] for item in raw_dataset)),
    dropout=0.1
)
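# The sizes (d_model=512, 8 heads, d_ff=2048, 6 encoder + 6 decoder layers)
# match the base Transformer configuration; num_slots is set to the number of
# distinct ATIS intents, although the slot head is not trained in this script.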


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.token_to_id("[PAD]"))
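# ignore_index keeps [PAD] target positions out of the cross-entropy, so the
# loss is averaged only over real response tokens.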

print("Starting training...")
for epoch in range(10):
    model.train()
    total_loss = 0
    for batch_idx, (question_ids, response_ids, question_lens, response_lens) in enumerate(dataloader):
        question_ids = question_ids.to(device)
        response_ids = response_ids.to(device)

        # Role/turn embeddings: role 0 for the user question, role 1 for the
        # bot response (num_roles=2); this single-turn data always uses turn 0.
        question_roles = torch.zeros_like(question_ids)
        response_roles = torch.ones_like(response_ids)
        turns = torch.zeros_like(question_ids)

        gen_logits, slot_logits = model(
            question_ids, response_ids,
            question_roles, response_roles,
            turns, turns
        )

        # Shift for next-token prediction: position t predicts token t+1.
        target_ids = response_ids[:, 1:]
        gen_logits = gen_logits[:, :-1, :]

        gen_logits_flat = gen_logits.reshape(-1, vocab_size)
        target_ids_flat = target_ids.reshape(-1)
        loss = loss_fn(gen_logits_flat, target_ids_flat)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        if batch_idx % 100 == 0:
            print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}")

    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1} Average Loss: {avg_loss:.4f}")


print("Saving model...")
torch.save(model.state_dict(), 'atis_transformer.pt')
print("Training completed!")