import torch
import torch.nn as nn
# Import the necessary layers
from built_transformer.embeddings import Embeddings
from built_transformer.encoder import Encoder, EncoderLayer
from built_transformer.decoders import Decoder, DecoderLayer
from built_transformer.positional_encodings import PositionalEncoding
from built_transformer.slot_classifier import SlotClassifier
class TransformerChatbot(nn.Module):
"""
Unified Transformer-based chatbot model that combines:
- Joint token/role/turn embeddings
- Encoder-decoder architecture with attention
- Slot-filling classification
- Generation capabilities
"""
def __init__(
self,
vocab_size: int,
d_model: int = 512,
num_heads: int = 8,
d_ff: int = 2048,
num_encoder_layers: int = 6,
num_decoder_layers: int = 6,
num_roles: int = 2,
max_turns: int = 16,
num_slots: int = 4,
dropout: float = 0.1,
max_len: int = 5000
):
super().__init__()
        # Joint token/role/turn embeddings (note: the encoder and decoder stacks below also embed their inputs internally)
self.embed = Embeddings(
            char=vocab_size,  # Embeddings takes the vocabulary size via its 'char' argument
dimension_for_model=d_model,
num_of_roles=num_roles,
max_turns=max_turns
)
# Positional encoding
self.pos_enc = PositionalEncoding(d_model, dropout, max_len)
# Encoder stack
self.encoder = Encoder(
vocab_size=vocab_size,
dimension_of_model=d_model,
num_of_heads=num_heads,
num_layers=num_encoder_layers,
dim_feedforward=d_ff,
dropout=dropout,
max_len=max_len,
num_of_roles=num_roles,
max_turns=max_turns
)
# Decoder stack
self.decoder = Decoder(
vocab_size=vocab_size,
dimension_for_model=d_model,
num_layers=num_decoder_layers,
num_of_heads=num_heads,
dim_feedforward=d_ff,
dropout=dropout,
max_len=max_len
)
# Output projections
self.out_proj = nn.Linear(d_model, vocab_size)
self.slot_classifier = SlotClassifier(d_model, num_slots)
# Initialize parameters
self._init_parameters()
def _init_parameters(self):
        # Initialize all weight matrices (parameters with dim > 1) with Xavier uniform initialization
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def load_state_dict(self, state_dict, strict=True):
        # Check whether this is an old checkpoint format (it still has 'encoder.embed.weight'); previous versions used a different embedding weight layout
if 'encoder.embed.weight' in state_dict:
# This is an old model, we need to adapt the weights
old_embed_weight = state_dict['encoder.embed.weight']
# Copy the old embedding weights to the new structure
state_dict['encoder.embed.lut.weight'] = old_embed_weight
            # Initialize role and turn embedding tables; the hard-coded sizes match the constructor defaults (num_roles=2, max_turns=16)
            state_dict['encoder.embed.lut_roles.weight'] = torch.zeros(2, old_embed_weight.size(1))
            state_dict['encoder.embed.lut_turns.weight'] = torch.zeros(16, old_embed_weight.size(1))
state_dict['encoder.embed.norm.weight'] = torch.ones(old_embed_weight.size(1))
state_dict['encoder.embed.norm.bias'] = torch.zeros(old_embed_weight.size(1))
# Remove the old key
del state_dict['encoder.embed.weight']
return super().load_state_dict(state_dict, strict=strict)
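    # A hedged usage note: loading a checkpoint saved by an older revision triggers the
    # key remapping above automatically. The path "old_chatbot.pt" and the variable
    # old_vocab_size below are examples, not files or names defined in this repo, and
    # the zero-initialized role/turn tables should be fine-tuned before the adapted
    # model is used:
    #
    #     model = TransformerChatbot(vocab_size=old_vocab_size)
    #     model.load_state_dict(torch.load("old_chatbot.pt", map_location="cpu"),
    #                           strict=False)  # strict=False if older keys are missing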
def encode(self, src_tokens, src_roles, src_turns, src_mask=None):
"""
Encode source sequences with role and turn information.
Args:
src_tokens: [B, S] token IDs
src_roles: [B, S] role IDs
src_turns: [B, S] turn IDs
src_mask: padding mask [B, 1, 1, S]
Returns:
enc_out: [B, S, d_model]
"""
# Pass through encoder (embedding and positional encoding handled inside)
return self.encoder(src_tokens, src_roles, src_turns, src_mask)
def decode(
self,
tgt_tokens,
enc_out,
tgt_roles,
tgt_turns,
src_mask=None,
tgt_mask=None
):
"""
Decode target sequences with encoder context.
Args:
tgt_tokens: [B, T] target token IDs
enc_out: [B, S, d_model] encoder output
tgt_roles: [B, T] target role IDs
tgt_turns: [B, T] target turn IDs
src_mask: [B, 1, 1, S] source mask
tgt_mask: [B, 1, T, T] target mask
Returns:
logits: [B, T, vocab_size]
"""
        # Pass through decoder (token embedding and positional encoding are handled
        # inside the decoder stack, mirroring encode()). tgt_roles and tgt_turns are
        # accepted for API symmetry but are not consumed by the current Decoder.
        dec_out = self.decoder(tgt_tokens, enc_out, tgt_mask, src_mask)
        return self.out_proj(dec_out)
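    # A hedged sketch of autoregressive generation built on encode()/decode(). The names
    # bos_id and max_new_tokens are assumptions (not defined in this repo), greedy argmax
    # is only one possible decoding strategy, and the role/turn IDs assume role 1 is the
    # assistant replying within a single turn:
    #
    #     enc_out = model.encode(src_tokens, src_roles, src_turns, src_mask)
    #     ys = torch.full((src_tokens.size(0), 1), bos_id, dtype=torch.long)
    #     for _ in range(max_new_tokens):
    #         T = ys.size(1)
    #         tgt_mask = torch.tril(torch.ones(T, T, dtype=torch.bool)).view(1, 1, T, T)
    #         logits = model.decode(ys, enc_out, torch.ones_like(ys), torch.zeros_like(ys),
    #                               src_mask, tgt_mask)
    #         next_tok = logits[:, -1, :].argmax(dim=-1, keepdim=True)
    #         ys = torch.cat([ys, next_tok], dim=1)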
def forward(
self,
src_tokens,
tgt_tokens,
src_roles,
tgt_roles,
src_turns,
tgt_turns,
src_mask=None,
tgt_mask=None
):
"""
Full forward pass combining encoding, decoding, and slot classification.
Args:
src_tokens: [B, S] source token IDs
tgt_tokens: [B, T] target token IDs
src_roles: [B, S] source role IDs
tgt_roles: [B, T] target role IDs
src_turns: [B, S] source turn IDs
tgt_turns: [B, T] target turn IDs
src_mask: [B, 1, 1, S] source mask
tgt_mask: [B, 1, T, T] target mask
Returns:
gen_logits: [B, T, vocab_size] generation logits
slot_logits: [B, num_slots] slot classification logits
"""
# Encode source sequence
enc_out = self.encode(src_tokens, src_roles, src_turns, src_mask)
# Decode target sequence
gen_logits = self.decode(
tgt_tokens,
enc_out,
tgt_roles,
tgt_turns,
src_mask,
tgt_mask
)
# Use first position of encoder output for slot classification
cls_rep = enc_out[:, 0, :]
slot_logits = self.slot_classifier(cls_rep)
return gen_logits, slot_logits
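

# A minimal, hedged smoke-test sketch. It runs a full forward pass on random token IDs;
# pad_id=0, boolean masks with True meaning "attend", and the exact mask conventions of
# the Encoder/Decoder stacks are assumptions about this repo, not guarantees.
if __name__ == "__main__":
    B, S, T = 2, 10, 8
    vocab_size, pad_id = 1000, 0

    model = TransformerChatbot(
        vocab_size=vocab_size,
        d_model=128,
        num_heads=4,
        d_ff=256,
        num_encoder_layers=2,
        num_decoder_layers=2,
    )
    model.eval()

    # Dummy dialogue batch: role 0 = user (source), role 1 = assistant (target),
    # everything placed in turn 0.
    src_tokens = torch.randint(1, vocab_size, (B, S))
    tgt_tokens = torch.randint(1, vocab_size, (B, T))
    src_roles = torch.zeros(B, S, dtype=torch.long)
    tgt_roles = torch.ones(B, T, dtype=torch.long)
    src_turns = torch.zeros(B, S, dtype=torch.long)
    tgt_turns = torch.zeros(B, T, dtype=torch.long)

    # Padding mask [B, 1, 1, S] and causal mask [1, 1, T, T] (broadcast over batch/heads).
    src_mask = (src_tokens != pad_id).unsqueeze(1).unsqueeze(2)
    tgt_mask = torch.tril(torch.ones(T, T, dtype=torch.bool)).view(1, 1, T, T)

    with torch.no_grad():
        gen_logits, slot_logits = model(
            src_tokens, tgt_tokens,
            src_roles, tgt_roles,
            src_turns, tgt_turns,
            src_mask, tgt_mask,
        )
    print(gen_logits.shape)   # expected: [B, T, vocab_size]
    print(slot_logits.shape)  # expected: [B, num_slots]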