import torch
import torch.nn as nn
import torch.nn.functional as F
from data_utils import *

from attention import SelfAttentionHead, MultiHeadAttention, FeedForwardNet, DecoderBlock


class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size, n_embed, block_size, num_heads, n_layers) -> None:
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.decoder_blocks = nn.Sequential(
            *[DecoderBlock(n_embed, num_heads, block_size=block_size) for _ in range(n_layers)]
        )
        self.ln_final = nn.LayerNorm(n_embed)

        ## self.sa_head = SelfAttentionHead(vocab_size, n_embed, block_size)
        # self.sa_heads = MultiHeadAttention(num_heads=4, head_size=n_embed//4, n_embed=n_embed, block_size=block_size)
        # self.ffn = FeedForwardNet(n_embed, dropout=0.2)

        self.lm_head = nn.Linear(n_embed, vocab_size)


    def forward(self, idx, targets=None):

        # idx and targets are both integer tensors of shape (B, T):
        # B = batch size, T = sequence length ("time steps"), at most block_size
        B, T = idx.shape
        tok_embed = self.token_embedding_table(idx)  # (B, T, n_embed)
        pos_embed = self.position_embedding_table(torch.arange(T, device=idx.device))  # (T, n_embed)
        x_in = tok_embed + pos_embed  # (B, T, n_embed); pos_embed broadcasts over the batch dimension
        # x_in = self.sa_heads(x_in)
        # x_in = self.ffn(x_in)
        x_in = self.ln_final(self.decoder_blocks(x_in))  # (B, T, n_embed)
        logits = self.lm_head(x_in)  # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
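            B, T, C = logits.shape  # C = vocab_size
            # F.cross_entropy expects the class dimension second, so flatten batch and
            # time into one axis: scores (B*T, C) against targets (B*T,).
            # Note: ignore_index=0 excludes targets equal to token id 0 from the loss.
            loss = F.cross_entropy(logits.view(B*T, C), targets.view(B*T), ignore_index=0)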

        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        # idx is (B, T) shaped array of indices in current context
        for _ in range(max_new_tokens):
            # limit the input to the last block_size tokens; block_size is read from the
            # position embedding table so generate() does not depend on a module-level global
            idx_cond = idx[:, -self.position_embedding_table.num_embeddings:]
            logits, _ = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, vocab_size)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, vocab_size)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append the sampled index to the running sequence
            idx = torch.cat([idx, idx_next], dim=1)  # (B, T+1)

        return idx
    
    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For the non-embedding count (default), the position embeddings are subtracted.
        The token embeddings stay in the count. Note that this model does not tie
        them to lm_head, which is a separate nn.Linear layer.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            # This class has no `transformer.wpe` attribute; the position
            # embeddings live in `position_embedding_table`.
            n_params -= self.position_embedding_table.weight.numel()
        return n_params

    def _init_weights(self, module):
        # Not called from __init__; to use it, run `model.apply(model._init_weights)`.
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)



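if __name__ == "__main__":
    # get_random_batch, device, n_embed, BLOCK_SIZE, n_head and n_layer all come from
    # the wildcard import of data_utils at the top of the file.
    xb, yb = get_random_batch('train')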
    xb = xb.to(device)
    yb = yb.to(device)

    m = BigramLanguageModel(vocab_size=65, n_embed=n_embed, block_size=BLOCK_SIZE, num_heads=n_head, n_layers=n_layer).to(device)
    logits, loss = m(xb, yb)
    print(logits.shape)
    print(loss)