File size: 4,692 Bytes
9622166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#integrating the neccessary classes
import torch
import torch.nn as nn
import math
from .positional_encodings import PositionalEncoding  #import other modules neccessary for 
from .multihead_attention import MultiHeadAttention
from .encoding_layers import position_wide_feed_forward, Residual_layer
from .masking_for_attention import mask
from .embeddings import Embeddings

class EncoderLayer(nn.Module):
    def __init__(self, dimension_for_model, num_of_heads, dim_feedforward, dropout = 0.1):
        '''
        dimension_for_model: the dimension desired for the model specified at the embeddings layer
        num_of_heads: the number of heads for the multi-head-attention structure to keep track of
        dim_feedforward: the dimension of the positional feed forward structure
        dropout: structure for removing model dependencies during training, improving robustness
        '''
        super().__init__()
        # Loading previously coded structures for multi-head attention
        self.self_attn = MultiHeadAttention(dimension_for_model, num_of_heads, dropout)
        self.norm1 = nn.LayerNorm(dimension_for_model)
        self.dropout1 = nn.Dropout(dropout)
        # Loading previously coded structures for position_wide_feed_forward
        self.ffn = position_wide_feed_forward(dimension_for_model, dim_feedforward, dropout)
        self.norm2 = nn.LayerNorm(dimension_for_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src: torch.Tensor, src_mask: torch.Tensor = None) -> torch.Tensor:
        # Self-attention block
        _src = src
        attn_output, _ = self.self_attn(src, src, src, mask=src_mask)  
        src = self.norm1(_src + self.dropout1(attn_output))  # changed attention output
        # Feed-forward block
        _src = src
        ff_output = self.ffn(src)
        src = self.norm2(_src + self.dropout2(ff_output))
        return src


class Encoder(nn.Module):
    """
    Stacked Transformer encoder:
      - embedding + positional encoding
      - N encoder layers
      - final layer norm
    """
    def __init__(self, vocab_size, dimension_of_model, num_of_heads, num_layers, dim_feedforward = 2048, dropout = 0.1, max_len = 5000, num_of_roles=2, max_turns=16):
        super().__init__()
        # Token/role/turn embeddings
        self.embed = Embeddings(vocab_size, dimension_for_model=dimension_of_model, num_of_roles=num_of_roles, max_turns=max_turns)
        # Positional encodings (sinusoidal or learned)
        self.pe = PositionalEncoding(dimension_of_model, dropout=dropout, max_len=max_len)
        # Stacked encoder layers
        self.layers = nn.ModuleList([
            EncoderLayer(dimension_of_model, num_of_heads, dim_feedforward, dropout)
            for _ in range(num_layers)
        ])
        # Final normalization
        self.norm = nn.LayerNorm(dimension_of_model)

    def forward(self, src_ids, roles, turns, src_mask = None) -> torch.Tensor:
        """
        Args:
          src_ids: [batch_size x seq_len] input token indices
          roles:   [batch_size x seq_len] role ids
          turns:   [batch_size x seq_len] turn ids
          src_mask: [batch_size, 1, 1, seq_len] mask to prevent attending to padding tokens
        """
        # Embed tokens, roles, and turns
        x = self.embed(src_ids, roles, turns)
        # Add positional information
        x = self.pe(x)
        # Pass through each encoder layer
        for layer in self.layers:
            x = layer(x, src_mask)
        # Final layer normalization
        return self.norm(x)
    
    def load_state_dict(self, state_dict, strict=True):
        """
        Custom state dict loading to handle backward compatibility with old model format
        """
        # Check if this is an old model format (has encoder.embed.weight)
        if 'encoder.embed.weight' in state_dict:
            # This is an old model, we need to adapt the weights
            old_embed_weight = state_dict['encoder.embed.weight']
            
            # Copy the old embedding weights to the new structure
            state_dict['encoder.embed.lut.weight'] = old_embed_weight
            state_dict['encoder.embed.lut_roles.weight'] = torch.zeros_like(old_embed_weight)
            state_dict['encoder.embed.lut_turns.weight'] = torch.zeros_like(old_embed_weight)
            state_dict['encoder.embed.norm.weight'] = torch.ones(old_embed_weight.size(1))
            state_dict['encoder.embed.norm.bias'] = torch.zeros(old_embed_weight.size(1))
            
            # Remove the old key
            del state_dict['encoder.embed.weight']
        
        return super().load_state_dict(state_dict, strict=strict)