Upload model.py
model.py
CHANGED
@@ -1,202 +1,245 @@
-#
-import os
-import math
-import time
-import inspect
-from dataclasses import dataclass
 import torch
 import torch.nn as nn
-...


-...

-...
-        self.
-        # output projection
-        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
-        self.c_proj.NANGPT_SCALE_INIT = 1
-        # regularization
-        self.n_head = config.n_head
-        self.n_embd = config.n_embd
-        self.register_buffer("bias",
-                             torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size,
-                                                                                               config.block_size))

-    def forward(self,
-...
         # output projection
-        y = self.
         return y


-class MLP(nn.Module):

-...
         super().__init__()
-...
-        self.
-        self.
-        self.

     def forward(self, x):
-...
-        x = self.c_proj(x)
-        return x


-class

-...
-        super().__init__()
-        self.ln_1 = nn.LayerNorm(config.n_embd)
-        self.attn = CausalSelfAttention(config)
-        self.ln_2 = nn.LayerNorm(config.n_embd)
-        self.mlp = MLP(config)

-...
-        x = x + self.attn(self.ln_1(x))
-        x = x + self.mlp(self.ln_2(x))
-        return x


-...
-    n_layer: int = 12  # number of layers
-    n_head: int = 8  # number of heads
-    n_embd: int = 768  # embedding dimension


-...

     def __init__(self, config):
-        super().__init__()
-        self.
-...
-        self.
-...
-        self.
-...
-        #
-...
-        #
-...
-        model = GPT(config)
-        sd = model.state_dict()
-        sd_keys = sd.keys()
-        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')]  # discard this mask / buffer, not a param
-
-        # init a huggingface/transformers model
-        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
-        sd_hf = model_hf.state_dict()
-
-        # copy while ensuring all of the parameters are aligned and match in names and shapes
-        sd_keys_hf = sd_hf.keys()
-        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')]  # ignore these, just a buffer
-        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')]  # same, just the mask (buffer)
-        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
-        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
-        # this means that we have to transpose these weights when we import them
-        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
-        for k in sd_keys_hf:
-            if any(k.endswith(w) for w in transposed):
-                # special treatment for the Conv1D weights we need to transpose
-                assert sd_hf[k].shape[::-1] == sd[k].shape
-                with torch.no_grad():
-                    sd[k].copy_(sd_hf[k].t())
-            else:
-                # vanilla copy over the other parameters
-                assert sd_hf[k].shape == sd[k].shape
-                with torch.no_grad():
-                    sd[k].copy_(sd_hf[k])
-
-        return model
-
-    def generate(self, input_tensor, max_length, EOS_TOKEN_ID=50256):
-        output_ids = input_tensor  # Start with input
-        self.eval()
-        for _ in range(max_length - input_tensor.size(1)):
-            logits = self(input_tensor)  # Forward pass
-            if isinstance(logits, tuple):
-                logits = logits[0]
-            next_token = torch.argmax(logits[:, -1, :], dim=-1)  # Get the next token
-            input_tensor = torch.cat([input_tensor, next_token.unsqueeze(0)], dim=1)
-            if next_token.item() == EOS_TOKEN_ID:  # Stop if end-of-sequence token is generated
-                break
-        return input_tensor
+# model.py
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
+from transformers.models.llama.modeling_llama import (
+    LlamaRotaryEmbedding,
+    LlamaRMSNorm,
+)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
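For reference, a minimal sketch (illustrative only, not part of model.py) of how these helpers are driven in this file: LlamaRotaryEmbedding produces the (cos, sin) pair and apply_rotary_pos_emb rotates the query and key tensors. It assumes the transformers version used by this Space, where LlamaRotaryEmbedding takes the head dimension and its forward(x, position_ids) returns tensors of shape [batch, seq_len, head_dim]:

# Illustrative only, not part of model.py.
import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

batch, n_heads, n_kv_heads, seq_len, head_dim = 2, 9, 3, 16, 64
q = torch.randn(batch, n_heads, seq_len, head_dim)
k = torch.randn(batch, n_kv_heads, seq_len, head_dim)

rotary = LlamaRotaryEmbedding(head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch, -1)
cos, sin = rotary(q, position_ids)          # each assumed [batch, seq_len, head_dim]

# unsqueeze_dim=1 broadcasts (cos, sin) over the head dimension of q and k
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
assert q_rot.shape == q.shape and k_rot.shape == k.shape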
+class CausalAttention(nn.Module):
+    def __init__(self, hidden_size, num_attention_heads, num_key_value_heads):
+        super().__init__()

+        self.hidden_size = hidden_size
+        self.num_attention_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.head_dim = hidden_size // num_attention_heads

+        self.num_key_value_groups = num_attention_heads // num_key_value_heads
+        self.scaling = self.head_dim ** -0.5
+        # self.attention_dropout = attention_dropout
+        self.is_causal = True

+        # Query, Key, Value projections
+        self.q_proj = nn.Linear(hidden_size, self.head_dim * num_attention_heads, bias=False)
+        self.k_proj = nn.Linear(hidden_size, self.head_dim * num_key_value_heads, bias=False)
+        self.v_proj = nn.Linear(hidden_size, self.head_dim * num_key_value_heads, bias=False)
+        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)

+    def forward(self, hidden_states, attention_mask=None, position_embeddings=None):
+        batch, seq_len = hidden_states.shape[:-1]
+        hidden_shape = (batch, seq_len, -1, self.head_dim)

+        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

+        y = F.scaled_dot_product_attention(query_states,
+                                           key_states,
+                                           value_states,
+                                           is_causal=True,
+                                           enable_gqa=True)  # Flash attention

+        y = y.transpose(1, 2).contiguous().view(batch, seq_len, self.hidden_size)  # re-assemble all head outputs side by side
         # output projection
+        y = self.o_proj(y)
         return y


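CausalAttention projects num_attention_heads query heads but only num_key_value_heads key/value heads (grouped-query attention) and relies on enable_gqa=True in scaled_dot_product_attention to broadcast the shared KV heads across each query group. On PyTorch versions that predate the enable_gqa flag, an equivalent fallback is to repeat the KV heads explicitly; a sketch of that substitution inside forward (illustrative only, not part of model.py):

# Assumes query_states [B, n_heads, T, d] and key/value_states [B, n_kv_heads, T, d] as above.
n_rep = self.num_key_value_groups                         # n_heads // n_kv_heads
key_rep = key_states.repeat_interleave(n_rep, dim=1)      # -> [B, n_heads, T, d]
value_rep = value_states.repeat_interleave(n_rep, dim=1)
y = F.scaled_dot_product_attention(query_states, key_rep, value_rep, is_causal=True)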
+class MLP(nn.Module):  # Inspired by LlamaMLP
+    def __init__(self, hidden_size, num_attention_heads, num_key_value_heads, intermediate_size, eps, activation_fn):
         super().__init__()
+
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = activation_fn

     def forward(self, x):
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj


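This MLP is the Llama-style gated feed-forward: down_proj(act_fn(gate_proj(x)) * up_proj(x)), with SiLU passed in as act_fn. A small usage sketch (illustrative only, not part of model.py; the sizes are placeholders, not necessarily the Space's config values):

import torch
import torch.nn.functional as F

mlp = MLP(hidden_size=576, num_attention_heads=9, num_key_value_heads=3,
          intermediate_size=1536, eps=1e-5, activation_fn=F.silu)
x = torch.randn(2, 16, 576)          # [batch, seq_len, hidden_size]
out = mlp(x)                         # gated SiLU feed-forward
assert out.shape == x.shape          # the hidden size is preserved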
+class TransformerBlock(nn.Module):
+    def __init__(self, hidden_size, num_attention_heads, num_key_value_heads, intermediate_size, eps, activation_fn):
+        super(TransformerBlock, self).__init__()
+        self.hidden_size = hidden_size
+        self.num_attention_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.head_dim = hidden_size // num_attention_heads
+        assert self.head_dim * num_attention_heads == hidden_size, "Hidden size must be divisible by the number of attention heads."
+        assert self.hidden_size % self.num_key_value_heads == 0, "hidden_size must be divisible by num_key_value_heads"

+        self.layer_norm_1 = LlamaRMSNorm(self.hidden_size, eps=eps)

+        self.attn = CausalAttention(hidden_size, num_attention_heads, num_key_value_heads)

+        # Feedforward layer
+        self.feed_forward = MLP(hidden_size, num_attention_heads, num_key_value_heads, intermediate_size, eps, activation_fn)
+        self.layer_norm_2 = LlamaRMSNorm(self.hidden_size, eps=eps)

+    def forward(self, hidden_states, attention_mask=None, position_embeddings=None):
+        # Layer normalization
+        residual = hidden_states
+        hidden_states = self.layer_norm_1(hidden_states)

+        '''
+        # Query projection
+        query = self.query_proj(hidden_states)
+        query = query.view(hidden_states.size(0), hidden_states.size(1), self.num_attention_heads,
+                           self.head_dim).transpose(1, 2)

+        # Key and Value projections with shared num_key_value_heads
+        key = self.key_proj(hidden_states)
+        value = self.value_proj(hidden_states)

+        key = key.view(hidden_states.size(0), hidden_states.size(1), self.num_key_value_heads,
+                       self.head_dim).transpose(1, 2)
+        value = value.view(hidden_states.size(0), hidden_states.size(1), self.num_key_value_heads,
+                           self.head_dim).transpose(1, 2)
+
+        # Expand keys and values to match num_attention_heads
+        key = key.repeat_interleave(self.num_attention_heads // self.num_key_value_heads, dim=1)
+        value = value.repeat_interleave(self.num_attention_heads // self.num_key_value_heads, dim=1)
+
+        # Apply rotary embeddings to query and key
+        cos, sin = position_embeddings
+        query, key = apply_rotary_pos_emb(query, key, cos, sin)
+
+        # Scaled dot-product attention
+        attention_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, is_causal=True)
+
+        # Reshape back to [batch_size, seq_length, hidden_size]
+        attention_output = attention_output.transpose(1, 2).contiguous().view(hidden_states.size(0), -1,
+                                                                              self.hidden_size)
+
+        # Output projection
+        attention_output = self.out_proj(attention_output)
+        '''
+        attention_output = self.attn(hidden_states, position_embeddings=position_embeddings)
+
+        # Residual connection
+        hidden_states = residual + attention_output
+
+        # Feedforward layer
+        residual = hidden_states
+
+        # Feed-forward
+        hidden_states = self.layer_norm_2(hidden_states)
+        feed_forward_output = self.feed_forward(hidden_states)
+
+        hidden_states = residual + feed_forward_output
+
+        return hidden_states
+
+
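TransformerBlock follows the pre-norm residual pattern: x = x + Attn(RMSNorm(x)), then x = x + MLP(RMSNorm(x)). Note that attention_mask is accepted but not forwarded to CausalAttention; masking comes solely from is_causal=True inside scaled_dot_product_attention. A standalone forward sketch with placeholder sizes (illustrative only, not part of model.py; it again assumes the LlamaRotaryEmbedding signature used in SmollM below):

import torch
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

block = TransformerBlock(hidden_size=576, num_attention_heads=9, num_key_value_heads=3,
                         intermediate_size=1536, eps=1e-5, activation_fn=F.silu)
x = torch.randn(2, 16, 576)                      # [batch, seq_len, hidden_size]
rotary = LlamaRotaryEmbedding(576 // 9)          # head_dim = 64, as in SmollM.__init__
position_ids = torch.arange(16).unsqueeze(0).expand(2, -1)
position_embeddings = rotary(x, position_ids)    # (cos, sin) tuple
out = block(x, position_embeddings=position_embeddings)
assert out.shape == x.shape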
+class SmollM(nn.Module):
     def __init__(self, config):
+        super(SmollM, self).__init__()
+        self.vocab_size = config['vocab_size']
+        self.hidden_size = config['hidden_size']
+        self.num_hidden_layers = config['num_hidden_layers']
+        self.num_attention_heads = config['num_attention_heads']
+        self.num_key_value_heads = config['num_key_value_heads']
+        self.max_position_embeddings = config['max_position_embeddings']
+        self.intermediate_size = config['intermediate_size']
+        self.initializer_range = config['initializer_range']
+        self.eps = config['rms_norm_eps']
+
+        self.head_dim = self.hidden_size // self.num_attention_heads
+
+        self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)
+        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim)
+
+        self.layers = nn.ModuleList([
+            TransformerBlock(
+                hidden_size=self.hidden_size,
+                num_attention_heads=self.num_attention_heads,
+                num_key_value_heads=self.num_key_value_heads,
+                intermediate_size=self.intermediate_size,
+                eps=self.eps,
+                activation_fn=F.silu  # Activation function specified in config
+            ) for _ in range(self.num_hidden_layers)
+        ])
+
+        self.layer_norm = LlamaRMSNorm(self.hidden_size, eps=self.eps)
+
+        # Language modeling head
+        self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
+
+        # Share weights between embedding and lm_head
+        self.lm_head.weight = self.embedding.weight
+
+        self._init_weights()
+
+    def forward(self, input_ids, attention_mask=None):
+        batch_size, seq_length = input_ids.size()
+        position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
+        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+
+        embeddings = self.embedding(input_ids)
+
+        hidden_states = embeddings
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        for layer in self.layers:
+            hidden_states = layer(hidden_states, attention_mask=attention_mask, position_embeddings=position_embeddings)
+
+        hidden_states = self.layer_norm(hidden_states)
+        logits = self.lm_head(hidden_states)
+        return logits
+
+    def _init_weights(self):
+        for module in self.modules():
+            if isinstance(module, nn.Linear):
+                nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
+                if module.bias is not None:
+                    nn.init.zeros_(module.bias)
+            elif isinstance(module, nn.Embedding):
+                nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
+            elif isinstance(module, nn.LayerNorm):
+                nn.init.constant_(module.bias, 0)
+                nn.init.constant_(module.weight, 1.0)
+
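A minimal end-to-end sketch of building and running the model (illustrative only, not part of model.py). The dict below only shows the keys SmollM.__init__ reads; the values are placeholders and the Space's real config will differ. It also assumes a transformers version where LlamaRotaryEmbedding accepts the head dimension directly, as in SmollM.__init__ above:

import torch

config = {
    'vocab_size': 49152,
    'hidden_size': 576,
    'num_hidden_layers': 30,
    'num_attention_heads': 9,
    'num_key_value_heads': 3,
    'max_position_embeddings': 2048,
    'intermediate_size': 1536,
    'initializer_range': 0.02,
    'rms_norm_eps': 1e-5,
}

model = SmollM(config)
input_ids = torch.randint(0, config['vocab_size'], (2, 16))   # [batch, seq_len]
logits = model(input_ids)
print(logits.shape)   # torch.Size([2, 16, 49152])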
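The removed file's generate() helper has no counterpart in the new model.py; a comparable greedy-decoding loop around SmollM.forward could look roughly like this (sketch only; the EOS id 50256 comes from the removed GPT-2 code, and the SmolLM tokenizer's EOS id will differ):

import torch

@torch.no_grad()
def greedy_generate(model, input_ids, max_length, eos_token_id=50256):
    # input_ids: [1, prompt_len]; append one argmax token per step, as the removed
    # GPT.generate did, stopping at max_length or the end-of-sequence token.
    model.eval()
    while input_ids.size(1) < max_length:
        logits = model(input_ids)                             # [1, seq_len, vocab_size]
        next_token = torch.argmax(logits[:, -1, :], dim=-1)   # [1]
        input_ids = torch.cat([input_ids, next_token.unsqueeze(1)], dim=1)
        if next_token.item() == eos_token_id:
            break
    return input_ids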