# Aegis-ATIS-Demo / transformer_chat.py
import torch
import torch.nn as nn
import math
# Import the necessary building blocks from the built_transformer package
from built_transformer.embeddings import Embeddings
from built_transformer.encoder import Encoder, EncoderLayer
from built_transformer.decoders import Decoder, DecoderLayer
from built_transformer.positional_encodings import PositionalEncoding
from built_transformer.slot_classifier import SlotClassifier
class TransformerChatbot(nn.Module):
"""
Unified Transformer-based chatbot model that combines:
- Joint token/role/turn embeddings
- Encoder-decoder architecture with attention
- Slot-filling classification
- Generation capabilities
"""
def __init__(
self,
vocab_size: int,
d_model: int = 512,
num_heads: int = 8,
d_ff: int = 2048,
num_encoder_layers: int = 6,
num_decoder_layers: int = 6,
num_roles: int = 2,
max_turns: int = 16,
num_slots: int = 4,
dropout: float = 0.1,
max_len: int = 5000
):
super().__init__()
# Embeddings for tokens, roles, and turns
self.embed = Embeddings(
            char=vocab_size,  # Embeddings receives the vocabulary size through its 'char' argument
dimension_for_model=d_model,
num_of_roles=num_roles,
max_turns=max_turns
)
# Positional encoding
self.pos_enc = PositionalEncoding(d_model, dropout, max_len)
# Encoder stack
self.encoder = Encoder(
vocab_size=vocab_size,
dimension_of_model=d_model,
num_of_heads=num_heads,
num_layers=num_encoder_layers,
dim_feedforward=d_ff,
dropout=dropout,
max_len=max_len,
num_of_roles=num_roles,
max_turns=max_turns
)
# Decoder stack
self.decoder = Decoder(
vocab_size=vocab_size,
dimension_for_model=d_model,
num_layers=num_decoder_layers,
num_of_heads=num_heads,
dim_feedforward=d_ff,
dropout=dropout,
max_len=max_len
)
# Output projections
self.out_proj = nn.Linear(d_model, vocab_size)
self.slot_classifier = SlotClassifier(d_model, num_slots)
# Initialize parameters
self._init_parameters()
def _init_parameters(self):
        # Initialize every weight matrix (dim > 1) with Xavier uniform; 1-D parameters such as biases keep their defaults
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def load_state_dict(self, state_dict, strict=True):
        # Detect checkpoints saved in the old format (they contain 'encoder.embed.weight'); previous versions used a different embedding layout
if 'encoder.embed.weight' in state_dict:
            # Old checkpoint: adapt its weights to the new embedding structure
old_embed_weight = state_dict['encoder.embed.weight']
# Copy the old embedding weights to the new structure
state_dict['encoder.embed.lut.weight'] = old_embed_weight
            # Zero-initialize the role and turn embeddings at the legacy default sizes
            state_dict['encoder.embed.lut_roles.weight'] = torch.zeros(2, old_embed_weight.size(1))   # 2 roles (default num_roles)
            state_dict['encoder.embed.lut_turns.weight'] = torch.zeros(16, old_embed_weight.size(1))  # 16 turns (default max_turns)
state_dict['encoder.embed.norm.weight'] = torch.ones(old_embed_weight.size(1))
state_dict['encoder.embed.norm.bias'] = torch.zeros(old_embed_weight.size(1))
# Remove the old key
del state_dict['encoder.embed.weight']
return super().load_state_dict(state_dict, strict=strict)
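        # Example (hedged sketch) of loading a pre-refactor checkpoint; the
        # file name "legacy_chatbot.pt" is hypothetical, and any state dict
        # that still contains 'encoder.embed.weight' is remapped automatically:
        #
        #     model = TransformerChatbot(vocab_size=10000)
        #     state = torch.load("legacy_chatbot.pt", map_location="cpu")
        #     model.load_state_dict(state)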
def encode(self, src_tokens, src_roles, src_turns, src_mask=None):
"""
Encode source sequences with role and turn information.
Args:
src_tokens: [B, S] token IDs
src_roles: [B, S] role IDs
src_turns: [B, S] turn IDs
src_mask: padding mask [B, 1, 1, S]
Returns:
enc_out: [B, S, d_model]
"""
# Pass through encoder (embedding and positional encoding handled inside)
return self.encoder(src_tokens, src_roles, src_turns, src_mask)
def decode(
self,
tgt_tokens,
enc_out,
tgt_roles,
tgt_turns,
src_mask=None,
tgt_mask=None
):
"""
Decode target sequences with encoder context.
Args:
tgt_tokens: [B, T] target token IDs
enc_out: [B, S, d_model] encoder output
            tgt_roles: [B, T] target role IDs (currently unused by the decoder)
            tgt_turns: [B, T] target turn IDs (currently unused by the decoder)
src_mask: [B, 1, 1, S] source mask
tgt_mask: [B, 1, T, T] target mask
Returns:
logits: [B, T, vocab_size]
"""
        # The Decoder stack was constructed with vocab_size and max_len and
        # handles its own token embedding and positional encoding, so raw
        # token IDs are passed straight through; tgt_roles/tgt_turns are
        # accepted for API symmetry but are not consumed here.
        dec_out = self.decoder(tgt_tokens, enc_out, tgt_mask, src_mask)
return self.out_proj(dec_out)
def forward(
self,
src_tokens,
tgt_tokens,
src_roles,
tgt_roles,
src_turns,
tgt_turns,
src_mask=None,
tgt_mask=None
):
"""
Full forward pass combining encoding, decoding, and slot classification.
Args:
src_tokens: [B, S] source token IDs
tgt_tokens: [B, T] target token IDs
src_roles: [B, S] source role IDs
tgt_roles: [B, T] target role IDs
src_turns: [B, S] source turn IDs
tgt_turns: [B, T] target turn IDs
src_mask: [B, 1, 1, S] source mask
tgt_mask: [B, 1, T, T] target mask
Returns:
gen_logits: [B, T, vocab_size] generation logits
slot_logits: [B, num_slots] slot classification logits
"""
# Encode source sequence
enc_out = self.encode(src_tokens, src_roles, src_turns, src_mask)
# Decode target sequence
gen_logits = self.decode(
tgt_tokens,
enc_out,
tgt_roles,
tgt_turns,
src_mask,
tgt_mask
)
# Use first position of encoder output for slot classification
cls_rep = enc_out[:, 0, :]
slot_logits = self.slot_classifier(cls_rep)
return gen_logits, slot_logits
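# -----------------------------------------------------------------------------
# Minimal smoke test (a hedged sketch, not part of the training pipeline).
# It assumes the built_transformer package from this repo is importable and
# that a mask value of 1 means "attend"; the vocabulary size, batch and
# sequence lengths, and role/turn ranges below are illustrative placeholders.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    vocab_size, B, S, T = 1000, 2, 12, 7
    model = TransformerChatbot(
        vocab_size=vocab_size,
        d_model=128,           # small dimensions keep the smoke test fast
        num_heads=4,
        d_ff=256,
        num_encoder_layers=2,
        num_decoder_layers=2,
    )
    model.eval()

    # Dummy token / role / turn IDs for a two-example batch
    src_tokens = torch.randint(0, vocab_size, (B, S))
    tgt_tokens = torch.randint(0, vocab_size, (B, T))
    src_roles = torch.randint(0, 2, (B, S))    # e.g. user / system
    tgt_roles = torch.randint(0, 2, (B, T))
    src_turns = torch.randint(0, 16, (B, S))
    tgt_turns = torch.randint(0, 16, (B, T))

    # Padding mask [B, 1, 1, S] (all source positions visible here) and a
    # causal mask [B, 1, T, T] so each target position attends only to its past
    src_mask = torch.ones(B, 1, 1, S)
    tgt_mask = torch.tril(torch.ones(T, T)).unsqueeze(0).unsqueeze(1).expand(B, 1, T, T)

    with torch.no_grad():
        gen_logits, slot_logits = model(
            src_tokens, tgt_tokens,
            src_roles, tgt_roles,
            src_turns, tgt_turns,
            src_mask, tgt_mask,
        )

    print("gen_logits:", tuple(gen_logits.shape))    # expected (B, T, vocab_size)
    print("slot_logits:", tuple(slot_logits.shape))  # expected (B, num_slots)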