piyushgrover committed
Commit 189668a · verified · 1 Parent(s): 8d780f0

Upload model.py

Files changed (1)
  1. model.py +221 -178
model.py CHANGED
@@ -1,202 +1,245 @@
-# Solving for residual std scaling issue
-import os
-import math
-import time
-import inspect
-from dataclasses import dataclass
 import torch
 import torch.nn as nn
-from torch.nn import functional as F
-
-
-class CausalSelfAttention(nn.Module):
-
-    def __init__(self, config):
-        super().__init__()
-        assert config.n_embd % config.n_head == 0
-        # key, query, value projections for all heads, but in a batch
-        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
-        # output projection
-        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
-        self.c_proj.NANGPT_SCALE_INIT = 1
-        # regularization
-        self.n_head = config.n_head
-        self.n_embd = config.n_embd
-        self.register_buffer("bias",
-                             torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size,
-                                                                                               config.block_size))
-
-    def forward(self, x):
-        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
-        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
-        # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
-        # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
-        qkv = self.c_attn(x)
-        q, k, v = qkv.split(self.n_embd, dim=2)
-        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
-        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
-        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
-
-        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
-        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
-        att = F.softmax(att, dim=-1)
-        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
-
-        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
         # output projection
-        y = self.c_proj(y)
         return y
-
-
-class MLP(nn.Module):
-
-    def __init__(self, config):
         super().__init__()
-        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
-        self.gelu = nn.GELU(approximate='tanh')
-        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
-        self.c_proj.NANOGPT_SCALE_INIT = 1
-
     def forward(self, x):
-        x = self.c_fc(x)
-        x = self.gelu(x)
-        x = self.c_proj(x)
-        return x
-
-
-class Block(nn.Module):
-
-    def __init__(self, config):
-        super().__init__()
-        self.ln_1 = nn.LayerNorm(config.n_embd)
-        self.attn = CausalSelfAttention(config)
-        self.ln_2 = nn.LayerNorm(config.n_embd)
-        self.mlp = MLP(config)
-
-    def forward(self, x):
-        x = x + self.attn(self.ln_1(x))
-        x = x + self.mlp(self.ln_2(x))
-        return x
-
-
-@dataclass
-class GPTConfig:
-    block_size: int = 1024 # max sequence length
-    vocab_size: int = 50257 # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token
-    n_layer: int = 12 # number of layers
-    n_head: int = 8 # number of heads
-    n_embd: int = 768 # embedding dimension
-
-
-class GPT(nn.Module):
-
     def __init__(self, config):
-        super().__init__()
-        self.config = config
-
-        self.transformer = nn.ModuleDict(dict(
-            wte=nn.Embedding(config.vocab_size, config.n_embd),
-            wpe=nn.Embedding(config.block_size, config.n_embd),
-            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
-            ln_f=nn.LayerNorm(config.n_embd),
-        ))
-        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
-
-        # weight sharing
-        self.transformer.wte.weight = self.lm_head.weight
-
-        # weight initialization
-        self.apply(self._init_weights)
-
-    def _init_weights(self, module):
-        if isinstance(module, nn.Linear):
-            std = 0.02
-            if hasattr(module, 'NANGPT_SCALE_INIT'):
-                std *= (2 * self.config.n_layer) ** -0.5
-            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
-            if module.bias is not None:
-                torch.nn.init.zeros_(module.bias)
-        elif isinstance(module, nn.Embedding):
-            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
-
-    def forward(self, idx, targets=None):
-        # idx is of shape (B, T)
-        B, T = idx.size()
-        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
-        # forward the token and posisition embeddings
-        pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # shape (T)
-        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (T, n_embd)
-        tok_emb = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)
-        x = tok_emb + pos_emb
-        # forward the blocks of the transformer
-        for block in self.transformer.h:
-            x = block(x)
-        # forward the final layernorm and the classifier
-        x = self.transformer.ln_f(x)
-        logits = self.lm_head(x) # (B, T, vocab_size)
-        loss = None
-        if targets is not None:
-            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
-        return logits, loss
-
-    @classmethod
-    def from_pretrained(cls, model_type):
-        """Loads pretrained GPT-2 model weights from huggingface"""
-        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
-        from transformers import GPT2LMHeadModel
-        print("loading weights from pretrained gpt: %s" % model_type)
-
-        # n_layer, n_head and n_embd are determined from model_type
-        config_args = {
-            'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
-            'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
-            'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
-            'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
-        }[model_type]
-        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
-        config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
-        # create a from-scratch initialized minGPT model
-        config = GPTConfig(**config_args)
-        model = GPT(config)
-        sd = model.state_dict()
-        sd_keys = sd.keys()
-        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
-
-        # init a huggingface/transformers model
-        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
-        sd_hf = model_hf.state_dict()
-
-        # copy while ensuring all of the parameters are aligned and match in names and shapes
-        sd_keys_hf = sd_hf.keys()
-        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
-        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
-        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
-        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
-        # this means that we have to transpose these weights when we import them
-        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
-        for k in sd_keys_hf:
-            if any(k.endswith(w) for w in transposed):
-                # special treatment for the Conv1D weights we need to transpose
-                assert sd_hf[k].shape[::-1] == sd[k].shape
-                with torch.no_grad():
-                    sd[k].copy_(sd_hf[k].t())
-            else:
-                # vanilla copy over the other parameters
-                assert sd_hf[k].shape == sd[k].shape
-                with torch.no_grad():
-                    sd[k].copy_(sd_hf[k])
-
-        return model
-
-    def generate(self, input_tensor, max_length, EOS_TOKEN_ID=50256):
-        output_ids = input_tensor # Start with input
-        self.eval()
-        for _ in range(max_length - input_tensor.size(1)):
-            logits = self(input_tensor) # Forward pass
-            if isinstance(logits, tuple):
-                logits = logits[0]
-            next_token = torch.argmax(logits[:, -1, :], dim=-1) # Get the next token
-            input_tensor = torch.cat([input_tensor, next_token.unsqueeze(0)], dim=1)
-            if next_token.item() == EOS_TOKEN_ID: # Stop if end-of-sequence token is generated
-                break
-        return input_tensor
+# model.py
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
+from transformers.models.llama.modeling_llama import (
+    LlamaRotaryEmbedding,
+    LlamaRMSNorm,
+)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+class CausalAttention(nn.Module):
+    def __init__(self, hidden_size, num_attention_heads, num_key_value_heads):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.num_attention_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.head_dim = hidden_size // num_attention_heads
+
+        self.num_key_value_groups = num_attention_heads // num_key_value_heads
+        self.scaling = self.head_dim ** -0.5
+        #self.attention_dropout = attention_dropout
+        self.is_causal = True
+
+        # Query, Key, Value projections
+        self.q_proj = nn.Linear(hidden_size, self.head_dim * num_attention_heads, bias=False)
+        self.k_proj = nn.Linear(hidden_size, self.head_dim * num_key_value_heads, bias=False)
+        self.v_proj = nn.Linear(hidden_size, self.head_dim * num_key_value_heads, bias=False)
+        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)
+
+    def forward(self, hidden_states, attention_mask=None, position_embeddings=None):
+        batch, seq_len = hidden_states.shape[:-1]
+        hidden_shape = (batch, seq_len, -1, self.head_dim)
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        y = F.scaled_dot_product_attention(query_states,
+                                           key_states,
+                                           value_states,
+                                           is_causal=True,
+                                           enable_gqa=True)  # Flash attention
+
+        y = y.transpose(1, 2).contiguous().view(batch, seq_len, self.hidden_size)  # re-assemble all head outputs side by side
         # output projection
+        y = self.o_proj(y)
         return y
+
+
+class MLP(nn.Module):  ###Inspired from LLamaMLP
+    def __init__(self, hidden_size, num_attention_heads, num_key_value_heads, intermediate_size, eps, activation_fn):
         super().__init__()
+
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = activation_fn
+
     def forward(self, x):
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, hidden_size, num_attention_heads, num_key_value_heads, intermediate_size, eps, activation_fn):
+        super(TransformerBlock, self).__init__()
+        self.hidden_size = hidden_size
+        self.num_attention_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.head_dim = hidden_size // num_attention_heads
+        assert self.head_dim * num_attention_heads == hidden_size, "Hidden size must be divisible by the number of attention heads."
+        assert self.hidden_size % self.num_key_value_heads == 0, "hidden_size must be divisible by num_key_value_heads"
+
+        self.layer_norm_1 = LlamaRMSNorm(self.hidden_size, eps=eps)
+
+        self.attn = CausalAttention(hidden_size, num_attention_heads, num_key_value_heads)
+
+        # Feedforward layer
+        self.feed_forward = MLP(hidden_size, num_attention_heads, num_key_value_heads, intermediate_size, eps, activation_fn)
+        self.layer_norm_2 = LlamaRMSNorm(self.hidden_size, eps=eps)
+
+    def forward(self, hidden_states, attention_mask=None, position_embeddings=None):
+        # Layer normalization
+        residual = hidden_states
+        hidden_states = self.layer_norm_1(hidden_states)
+
+        '''
+        # Query projection
+        query = self.query_proj(hidden_states)
+        query = query.view(hidden_states.size(0), hidden_states.size(1), self.num_attention_heads,
+                           self.head_dim).transpose(1, 2)
+
+        # Key and Value projections with shared num_key_value_heads
+        key = self.key_proj(hidden_states)
+        value = self.value_proj(hidden_states)
+
+        key = key.view(hidden_states.size(0), hidden_states.size(1), self.num_key_value_heads,
+                       self.head_dim).transpose(1, 2)
+        value = value.view(hidden_states.size(0), hidden_states.size(1), self.num_key_value_heads,
+                           self.head_dim).transpose(1, 2)

+        # Expand keys and values to match num_attention_heads
+        key = key.repeat_interleave(self.num_attention_heads // self.num_key_value_heads, dim=1)
+        value = value.repeat_interleave(self.num_attention_heads // self.num_key_value_heads, dim=1)
+
+        # Apply rotary embeddings to query and key
+        cos, sin = position_embeddings
+        query, key = apply_rotary_pos_emb(query, key, cos, sin)
+
+        # Scaled dot-product attention
+        attention_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, is_causal=True)
+
+        # Reshape back to [batch_size, seq_length, hidden_size]
+        attention_output = attention_output.transpose(1, 2).contiguous().view(hidden_states.size(0), -1,
+                                                                              self.hidden_size)
+
+        # Output projection
+        attention_output = self.out_proj(attention_output)
+        '''
+        attention_output = self.attn(hidden_states, position_embeddings=position_embeddings)
+
+        # Residual connection
+        hidden_states = residual + attention_output
+
+        # Feedforward layer
+        residual = hidden_states
+
+        # Feed-forward
+        hidden_states = self.layer_norm_2(hidden_states)
+        feed_forward_output = self.feed_forward(hidden_states)
+
+        hidden_states = residual + feed_forward_output
+
+        return hidden_states
+
+
+class SmollM(nn.Module):
     def __init__(self, config):
+        super(SmollM, self).__init__()
+        self.vocab_size = config['vocab_size']
+        self.hidden_size = config['hidden_size']
+        self.num_hidden_layers = config['num_hidden_layers']
+        self.num_attention_heads = config['num_attention_heads']
+        self.num_key_value_heads = config['num_key_value_heads']
+        self.max_position_embeddings = config['max_position_embeddings']
+        self.intermediate_size = config['intermediate_size']
+        self.initializer_range = config['initializer_range']
+        self.eps = config['rms_norm_eps']
+
+        self.head_dim = self.hidden_size // self.num_attention_heads
+
+        self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)
+        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim)
+
+        self.layers = nn.ModuleList([
+            TransformerBlock(
+                hidden_size=self.hidden_size,
+                num_attention_heads=self.num_attention_heads,
+                num_key_value_heads=self.num_key_value_heads,
+                intermediate_size=self.intermediate_size,
+                eps=self.eps,
+                activation_fn=F.silu  # Activation function specified in config
+            ) for _ in range(self.num_hidden_layers)
+        ])
+
+        self.layer_norm = LlamaRMSNorm(self.hidden_size, eps=self.eps)
+
+        # Language modeling head
+        self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
+
+        # Share weights between embedding and lm_head
+        self.lm_head.weight = self.embedding.weight
+
+        self._init_weights()
+
+    def forward(self, input_ids, attention_mask=None):
+        batch_size, seq_length = input_ids.size()
+        position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
+        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+
+        embeddings = self.embedding(input_ids)
+
+        hidden_states = embeddings
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        for layer in self.layers:
+            hidden_states = layer(hidden_states, attention_mask=attention_mask, position_embeddings=position_embeddings)
+
+        hidden_states = self.layer_norm(hidden_states)
+        logits = self.lm_head(hidden_states)
+        return logits
+
+    def _init_weights(self):
+        for module in self.modules():
+            if isinstance(module, nn.Linear):
+                nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
+                if module.bias is not None:
+                    nn.init.zeros_(module.bias)
+            elif isinstance(module, nn.Embedding):
+                nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
+            elif isinstance(module, nn.LayerNorm):
+                nn.init.constant_(module.bias, 0)
+                nn.init.constant_(module.weight, 1.0)
+
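
For reference, a minimal usage sketch of the uploaded SmollM module follows. The config keys are exactly the ones SmollM.__init__ reads; the values below are illustrative placeholders (roughly SmolLM-135M-sized) and are not taken from this repository. The sketch also assumes a PyTorch build whose F.scaled_dot_product_attention supports enable_gqa (2.5+) and a transformers version whose LlamaRotaryEmbedding still accepts a head dimension directly, as model.py above expects.

# Hypothetical usage sketch -- config values are illustrative, not from this repo.
import torch

from model import SmollM  # assumes the file above is saved as model.py

config = {
    "vocab_size": 49152,             # set to match your tokenizer
    "hidden_size": 576,
    "num_hidden_layers": 30,
    "num_attention_heads": 9,
    "num_key_value_heads": 3,        # GQA: 3 KV heads shared across 9 query heads
    "max_position_embeddings": 2048,
    "intermediate_size": 1536,
    "initializer_range": 0.02,
    "rms_norm_eps": 1e-5,
}

model = SmollM(config)
model.eval()

input_ids = torch.randint(0, config["vocab_size"], (1, 16))  # (batch, seq_len)

with torch.no_grad():
    logits = model(input_ids)  # (1, 16, vocab_size)

# Greedy continuation, mirroring the argmax loop of the removed GPT.generate
for _ in range(8):
    with torch.no_grad():
        next_id = model(input_ids)[:, -1, :].argmax(dim=-1, keepdim=True)
    input_ids = torch.cat([input_ids, next_id], dim=1)

print(logits.shape, input_ids.shape)

Note that lm_head and embedding share one weight tensor (weight tying), so the greedy loop above runs against randomly initialized parameters unless a trained state dict is loaded first.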