File size: 9,777 Bytes
9622166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
from datasets import load_dataset
from tokenizers import Tokenizer
from torch.utils.data import DataLoader, Dataset
import torch
from transformer_chat import TransformerChatbot
import pandas as pd
import random

# Load the ATIS training split from the Hugging Face Hub; records are read
# later in this script through their 'text' and 'intent' fields.
raw_dataset = load_dataset("tuetschek/atis", split="train")

# Load a serialized `tokenizers` tokenizer from a local JSON file. The special
# tokens [CLS], [SEP] and [PAD] are looked up by name later in this script,
# so presumably this file defines them — TODO confirm.
tokenizer = Tokenizer.from_file('tokenizer.json')

# Create synthetic responses for ATIS queries for training purposes
def create_response_for_intent(intent, text):
    """Create synthetic responses for ATIS intents"""
    responses = {
        'atis_flight': [
            "I can help you with flight information. What specific details do you need?",
            "I'll search for flights matching your criteria. Please provide departure and arrival cities.",
            "Let me find available flights for you. When would you like to travel?"
        ],
        'atis_flight_no': [
            "I can help you with flight number information. Please provide the flight number.",
            "Let me search for details about that flight number.",
            "I'll look up information for that specific flight."
        ],
        'atis_airfare': [
            "I can help you find airfare information. What's your travel route?",
            "Let me search for the best airfare options for your trip.",
            "I'll check current airfare prices for your destination."
        ],
        'atis_airline': [
            "I can help you with airline information. Which airline are you looking for?",
            "Let me provide information about that airline.",
            "I'll search for details about the airline you mentioned."
        ],
        'atis_abbreviation': [
            "I can help you with airport abbreviations. Which abbreviation do you need?",
            "Let me explain that airport abbreviation for you.",
            "I'll provide the full name for that airport code."
        ],
        'atis_airport': [
            "I can help you with airport information. Which airport are you looking for?",
            "Let me provide details about that airport.",
            "I'll search for information about the airport you mentioned."
        ],
        'atis_distance': [
            "I can help you calculate distances between airports. Which airports are you interested in?",
            "Let me calculate the distance for you.",
            "I'll provide distance information between those locations."
        ],
        'atis_ground_service': [
            "I can help you with ground transportation services. What type of service do you need?",
            "Let me find ground transportation options for you.",
            "I'll search for available ground services at your destination."
        ],
        'atis_aircraft': [
            "I can help you with aircraft information. What type of aircraft are you looking for?",
            "Let me provide details about that aircraft type.",
            "I'll search for information about the aircraft you mentioned."
        ],
        'atis_capacity': [
            "I can help you with capacity information. What specific capacity details do you need?",
            "Let me check the capacity for that flight or aircraft.",
            "I'll provide capacity information for your query."
        ],
        'atis_quantity': [
            "I can help you with quantity information. What specific quantity are you looking for?",
            "Let me check the quantity for that item or service.",
            "I'll provide quantity information for your request."
        ],
        'atis_meal': [
            "I can help you with meal information. What type of meal service are you looking for?",
            "Let me check meal options for your flight.",
            "I'll provide information about meal services available."
        ],
        'atis_cheapest': [
            "I can help you find the cheapest options. What's your travel route?",
            "Let me search for the most affordable options for your trip.",
            "I'll find the cheapest flights or services for you."
        ],
        'atis_restriction': [
            "I can help you with travel restrictions. What type of restrictions are you asking about?",
            "Let me check the restrictions for your travel plans.",
            "I'll provide information about travel restrictions."
        ],
        'atis_day_name': [
            "I can help you with day information. What specific day are you looking for?",
            "Let me check the schedule for that day.",
            "I'll provide information about flights or services on that day."
        ]
    }
    
    # Get base responses for the intent calssification datasets
    base_responses = responses.get(intent, [
        "I can help you with that. Please provide more details.",
        "Let me assist you with your request.",
        "I'll help you find the information you need."
    ])
    
    # For variety
    if "flight" in text.lower():
        base_responses.extend([
            "I can help you book a flight. What are your travel dates?",
            "Let me search for available flights for you.",
            "I'll help you find the best flight options."
        ])
    
    return random.choice(base_responses)

# Create training data with question-answer pairs
def create_training_pairs(*, dataset=None, tok=None, response_fn=None):
    """Build tokenized (question, synthetic response) pairs for training.

    For each record, ``'text'`` is the question and a reply is generated from
    its ``'intent'`` and ``'text'``. Both sides are tokenized and wrapped as
    ``[CLS] ... [SEP]``.

    Args:
        dataset: Iterable of records with ``'text'`` and ``'intent'`` keys.
            Defaults to the module-level ``raw_dataset``.
        tok: Tokenizer providing ``encode(str).ids`` and ``token_to_id(str)``.
            Defaults to the module-level ``tokenizer``.
        response_fn: Callable ``(intent, text) -> str`` producing the reply.
            Defaults to ``create_response_for_intent``.

    Returns:
        A list of dicts with keys ``question_ids``, ``response_ids``,
        ``question_len`` and ``response_len``.

    Raises:
        ValueError: If the tokenizer does not define ``[CLS]`` or ``[SEP]``.
    """
    if dataset is None:
        dataset = raw_dataset
    if tok is None:
        tok = tokenizer
    if response_fn is None:
        response_fn = create_response_for_intent

    # Resolve special ids once. token_to_id returns None for unknown tokens,
    # which would otherwise be silently inserted into every sequence and only
    # fail much later inside torch.tensor.
    cls_id = tok.token_to_id("[CLS]")
    sep_id = tok.token_to_id("[SEP]")
    missing = [name for name, tid in (("[CLS]", cls_id), ("[SEP]", sep_id)) if tid is None]
    if missing:
        raise ValueError(f"Tokenizer is missing special token(s): {', '.join(missing)}")

    def _wrap(text):
        # Tokenize `text` and surround the ids with [CLS] ... [SEP].
        # NOTE(review): if the tokenizer's post-processor already adds special
        # tokens in encode(), they end up duplicated here — confirm.
        return [cls_id, *tok.encode(text).ids, sep_id]

    training_data = []
    for item in dataset:
        question = item['text']
        response = response_fn(item['intent'], question)

        question_ids = _wrap(question)
        response_ids = _wrap(response)

        training_data.append({
            'question_ids': question_ids,
            'response_ids': response_ids,
            'question_len': len(question_ids),
            'response_len': len(response_ids),
        })

    return training_data

# Create custom dataset for training
class AtisGenerationDataset(Dataset):
    """Map-style dataset of fixed-length, padded question/response id tensors.

    Each side is truncated to ``max_length // 2`` tokens and right-padded with
    the tokenizer's ``[PAD]`` id to exactly that length.
    """

    def __init__(self, training_data, tokenizer, max_length=128):
        """
        Args:
            training_data: List of dicts with ``question_ids``,
                ``response_ids``, ``question_len`` and ``response_len`` keys
                (the format built by ``create_training_pairs``).
            tokenizer: Tokenizer providing ``token_to_id("[PAD]")``.
            max_length: Combined length budget; each side gets
                ``max_length // 2`` positions.

        Raises:
            ValueError: If the tokenizer has no ``[PAD]`` token.
        """
        self.training_data = training_data
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Resolve once from the tokenizer we were given (not the module-level
        # global), instead of on every __getitem__ call.
        self.pad_id = tokenizer.token_to_id("[PAD]")
        if self.pad_id is None:
            raise ValueError("Tokenizer has no [PAD] token")

    def __len__(self):
        return len(self.training_data)

    def _pad(self, ids, size):
        """Return a new list: ``ids`` truncated to ``size`` and right-padded with [PAD]."""
        out = list(ids[:size])
        out.extend([self.pad_id] * (size - len(out)))
        return out

    def __getitem__(self, idx):
        """Return ``(question_ids, response_ids, question_len, response_len)`` tensors.

        The lengths count real (non-pad) tokens kept after truncation, so they
        never exceed the length of the id tensors they describe.
        """
        item = self.training_data[idx]
        half = self.max_length // 2

        question_ids = self._pad(item['question_ids'], half)
        response_ids = self._pad(item['response_ids'], half)

        return (
            torch.tensor(question_ids),
            torch.tensor(response_ids),
            torch.tensor(min(item['question_len'], half)),
            torch.tensor(min(item['response_len'], half)),
        )

# Build the tokenized question/response pairs for the whole training split.
print("Creating training data...")
training_data = create_training_pairs()
print(f"Created {len(training_data)} training pairs")

# Wrap the pairs in the padded Dataset (default max_length=128, i.e. 64
# positions per side) and batch them 16 at a time in shuffled order.
atis_dataset = AtisGenerationDataset(training_data, tokenizer)
dataloader = DataLoader(atis_dataset, batch_size=16, shuffle=True)

# Instantiate TransformerChatbot (from transformer_chat) with all the
# necessary parameters; the output vocabulary matches the tokenizer's.
vocab_size = tokenizer.get_vocab_size()
model = TransformerChatbot(
    vocab_size=vocab_size,
    d_model=512,
    num_heads=8,
    d_ff=2048,
    num_encoder_layers=6,
    num_decoder_layers=6,
    num_roles=2,
    max_turns=16,
    # NOTE(review): num_slots is set to the number of distinct *intent* labels
    # (computed with a second full pass over raw_dataset) — confirm that
    # TransformerChatbot expects intents here rather than slot tags.
    num_slots=len(set(item['intent'] for item in raw_dataset)),
    dropout=0.1
)

# Train on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Training loop for generation.
# Adam optimizer; token-level cross-entropy that ignores [PAD] target positions.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.token_to_id("[PAD]"))

NUM_EPOCHS = 10      # kept small for fast training
LOG_EVERY = 100      # batches between progress prints

print("Starting training...")
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch_idx, (question_ids, response_ids, _question_lens, _response_lens) in enumerate(dataloader):
        question_ids = question_ids.to(device)
        response_ids = response_ids.to(device)

        # Single-turn data: every role and turn index is zero. One tensor is
        # reused for both sides, which relies on question and response being
        # padded to the same length (both max_length // 2 in the dataset).
        # NOTE(review): argument order assumed to match TransformerChatbot's
        # forward signature as in the original call — confirm.
        roles = torch.zeros_like(question_ids)
        turns = torch.zeros_like(question_ids)

        # Forward pass; the slot/intent head output is not trained here.
        gen_logits, _slot_logits = model(
            question_ids, response_ids,
            roles, roles,
            turns, turns,
        )

        # Teacher forcing: the logits at position t are scored against the
        # token at position t + 1 of the response.
        target_ids = response_ids[:, 1:]    # drop leading [CLS]
        gen_logits = gen_logits[:, :-1, :]  # drop last position (no next token)

        # Flatten to (tokens, vocab) / (tokens,) for CrossEntropyLoss; use the
        # logits' own vocabulary dimension rather than a global.
        loss = loss_fn(
            gen_logits.reshape(-1, gen_logits.size(-1)),
            target_ids.reshape(-1),
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        if batch_idx % LOG_EVERY == 0:
            print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {batch_loss:.4f}")

    # Average over the batches actually seen; guards against an empty loader.
    avg_loss = total_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1} Average Loss: {avg_loss:.4f}")

# Save model weights (state_dict only).
print("Saving model...")
torch.save(model.state_dict(), 'atis_transformer.pt')
print("Training completed!")