Spaces:
Running
Running
File size: 10,960 Bytes
189668a 8d780f0 189668a 8d780f0 189668a 8d780f0 189668a 8d780f0 189668a 8d780f0 189668a 8d780f0 189668a 8d780f0 189668a 8d780f0 189668a 8d780f0 189668a 8d780f0 189668a 8d780f0 189668a 8d780f0 189668a 8d780f0 189668a 8d780f0 189668a 8d780f0 189668a 8d780f0 189668a 8d780f0 189668a 8d780f0 189668a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 |
# model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import (
LlamaRotaryEmbedding,
LlamaRMSNorm,
)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate query and key tensors with Rotary Position Embedding (RoPE).

    Each input ``t`` (``q`` and ``k``) is mapped to
    ``t * cos + rotate_half(t) * sin``, after inserting a singleton axis
    into ``cos`` and ``sin`` at ``unsqueeze_dim`` so they broadcast.

    Args:
        q (`torch.Tensor`): Query tensor.
        k (`torch.Tensor`): Key tensor.
        cos (`torch.Tensor`): Cosine half of the rotary embedding.
        sin (`torch.Tensor`): Sine half of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and ignored; kept only for call-site compatibility.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis at which a size-1 dimension is inserted into ``cos`` and
            ``sin``. If ``cos``/``sin`` are ``[batch_size, seq_len, head_dim]``
            and ``q``/``k`` are ``[batch_size, heads, seq_len, head_dim]``,
            use 1; if ``q``/``k`` are ``[batch_size, seq_len, heads, head_dim]``,
            use 2.

    Returns:
        `tuple(torch.Tensor)`: ``(q_rotated, k_rotated)``.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        # Same formula for both inputs: t*cos + rotate_half(t)*sin.
        return t * cos_b + rotate_half(t) * sin_b

    return _rotate(q), _rotate(k)
def rotate_half(x):
    """Split the last axis of ``x`` into halves ``(a, b)`` and return ``(-b, a)``.

    The split point is ``x.shape[-1] // 2``, so for an odd-sized last axis
    the first half is the shorter one.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((second.neg(), first), dim=-1)
class CausalAttention(nn.Module):
    """Causal multi-head self-attention with grouped-query attention and RoPE.

    Queries use ``num_attention_heads`` heads; keys and values use
    ``num_key_value_heads`` heads that are shared across groups of query
    heads (``enable_gqa=True`` in ``F.scaled_dot_product_attention``).

    Args:
        hidden_size: Width of the input and output hidden states.
        num_attention_heads: Number of query heads.
        num_key_value_heads: Number of key/value heads.

    Raises:
        ValueError: If either head count is not positive, or if
            ``num_attention_heads`` is not a multiple of ``num_key_value_heads``.
    """

    def __init__(self, hidden_size, num_attention_heads, num_key_value_heads):
        super().__init__()
        if num_attention_heads <= 0 or num_key_value_heads <= 0:
            raise ValueError("num_attention_heads and num_key_value_heads must be positive")
        if num_attention_heads % num_key_value_heads != 0:
            # GQA groups query heads evenly over key/value heads.
            raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = hidden_size // num_attention_heads
        self.num_key_value_groups = num_attention_heads // num_key_value_heads
        # Kept for reference. It is not passed to SDPA, whose default scale
        # is 1/sqrt(head_dim) as well.
        self.scaling = self.head_dim ** -0.5
        self.is_causal = True
        # Query, key and value projections.
        self.q_proj = nn.Linear(hidden_size, self.head_dim * num_attention_heads, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.head_dim * num_key_value_heads, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.head_dim * num_key_value_heads, bias=False)
        # The concatenated head outputs are heads * head_dim wide. That equals
        # hidden_size only when hidden_size is divisible by the head count.
        self.o_proj = nn.Linear(self.head_dim * num_attention_heads, hidden_size, bias=False)

    def forward(self, hidden_states, attention_mask=None, position_embeddings=None):
        """Attend causally over ``hidden_states``.

        Args:
            hidden_states: Tensor of shape ``(batch, seq_len, hidden_size)``.
            attention_mask: Accepted for interface compatibility but currently
                unused. Only the causal mask is applied, so padding is NOT masked.
            position_embeddings: Required ``(cos, sin)`` pair, passed to
                ``apply_rotary_pos_emb`` for the query and key heads.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_size)``.
        """
        batch, seq_len = hidden_states.shape[:-1]
        hidden_shape = (batch, seq_len, -1, self.head_dim)
        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # SDPA may dispatch to a fused kernel (e.g. FlashAttention) when available.
        y = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            is_causal=self.is_causal,
            enable_gqa=True,
        )
        # Put all head outputs side by side: (batch, seq, heads * head_dim).
        y = y.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.o_proj(y)
class MLP(nn.Module):
    """Gated feed-forward network, modeled on Hugging Face's ``LlamaMLP``.

    Computes ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``. All three
    projections are bias-free.

    Args:
        hidden_size: Input and output width.
        num_attention_heads: Unused; accepted so the signature matches
            ``TransformerBlock``'s argument list.
        num_key_value_heads: Unused; see above.
        intermediate_size: Width of the gated hidden layer.
        eps: Unused; see above.
        activation_fn: Callable applied to the gate projection.
    """

    def __init__(self, hidden_size, num_attention_heads, num_key_value_heads, intermediate_size, eps, activation_fn):
        super(MLP, self).__init__()
        self.hidden_size, self.intermediate_size = hidden_size, intermediate_size
        # Keep creation order gate -> up -> down so parameter registration
        # (and state_dict / init order) is stable.
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = activation_fn

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        projected = self.up_proj(x)
        return self.down_proj(gated * projected)
class TransformerBlock(nn.Module):
    """Pre-norm decoder layer with two residual sub-layers.

    The layer computes ``x + attn(norm1(x))`` followed by
    ``h + feed_forward(norm2(h))``.

    Args:
        hidden_size: Width of the hidden states.
        num_attention_heads: Number of query heads.
        num_key_value_heads: Number of key/value heads shared across query heads.
        intermediate_size: Hidden width of the gated MLP.
        eps: Epsilon for both ``LlamaRMSNorm`` layers.
        activation_fn: Activation applied to the MLP gate projection.

    Raises:
        ValueError: If the head counts are not positive, if ``hidden_size``
            is not divisible by ``num_attention_heads``, or if
            ``num_attention_heads`` is not divisible by ``num_key_value_heads``.
    """

    def __init__(self, hidden_size, num_attention_heads, num_key_value_heads, intermediate_size, eps, activation_fn):
        super().__init__()
        # Validate with explicit exceptions: `assert` is stripped under `python -O`.
        if num_attention_heads <= 0 or num_key_value_heads <= 0:
            raise ValueError("num_attention_heads and num_key_value_heads must be positive")
        if hidden_size % num_attention_heads != 0:
            raise ValueError("Hidden size must be divisible by the number of attention heads.")
        if num_attention_heads % num_key_value_heads != 0:
            # Grouped-query attention needs each KV head to serve a whole
            # number of query heads. The old check tested hidden_size instead.
            raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = hidden_size // num_attention_heads
        # Submodule creation order is kept as-is (state_dict / init order).
        self.layer_norm_1 = LlamaRMSNorm(self.hidden_size, eps=eps)
        self.attn = CausalAttention(hidden_size, num_attention_heads, num_key_value_heads)
        self.feed_forward = MLP(hidden_size, num_attention_heads, num_key_value_heads, intermediate_size, eps, activation_fn)
        self.layer_norm_2 = LlamaRMSNorm(self.hidden_size, eps=eps)

    def forward(self, hidden_states, attention_mask=None, position_embeddings=None):
        """Apply the attention and feed-forward sub-layers with residuals.

        Args:
            hidden_states: Input hidden states of width ``hidden_size``.
            attention_mask: Forwarded to the attention module. NOTE: the
                ``CausalAttention`` in this file applies only a causal mask
                and does not use this argument.
            position_embeddings: ``(cos, sin)`` pair for rotary embeddings.

        Returns:
            Hidden states with the same shape as the input.
        """
        # Attention sub-layer (pre-norm + residual).
        residual = hidden_states
        attention_output = self.attn(
            self.layer_norm_1(hidden_states),
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
        )
        hidden_states = residual + attention_output

        # Feed-forward sub-layer (pre-norm + residual).
        residual = hidden_states
        feed_forward_output = self.feed_forward(self.layer_norm_2(hidden_states))
        return residual + feed_forward_output
class SmollM(nn.Module):
def __init__(self, config):
super(SmollM, self).__init__()
self.vocab_size = config['vocab_size']
self.hidden_size = config['hidden_size']
self.num_hidden_layers = config['num_hidden_layers']
self.num_attention_heads = config['num_attention_heads']
self.num_key_value_heads = config['num_key_value_heads']
self.max_position_embeddings = config['max_position_embeddings']
self.intermediate_size = config['intermediate_size']
self.initializer_range = config['initializer_range']
self.eps = config['rms_norm_eps']
self.head_dim = self.hidden_size // self.num_attention_heads
self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim)
self.layers = nn.ModuleList([
TransformerBlock(
hidden_size=self.hidden_size,
num_attention_heads=self.num_attention_heads,
num_key_value_heads=self.num_key_value_heads,
intermediate_size=self.intermediate_size,
eps=self.eps,
activation_fn=F.silu # Activation function specified in config
) for _ in range(self.num_hidden_layers)
])
self.layer_norm = LlamaRMSNorm(self.hidden_size, eps=self.eps)
# Language modeling head
self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
# Share weights between embedding and lm_head
self.lm_head.weight = self.embedding.weight
self._init_weights()
def forward(self, input_ids, attention_mask=None):
batch_size, seq_length = input_ids.size()
position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
embeddings = self.embedding(input_ids)
hidden_states = embeddings
position_embeddings = self.rotary_emb(hidden_states, position_ids)
for layer in self.layers:
hidden_states = layer(hidden_states, attention_mask=attention_mask, position_embeddings=position_embeddings)
hidden_states = self.layer_norm(hidden_states)
logits = self.lm_head(hidden_states)
return logits
def _init_weights(self):
for module in self.modules():
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
elif isinstance(module, nn.LayerNorm):
nn.init.constant_(module.bias, 0)
nn.init.constant_(module.weight, 1.0)
|