# model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import (
    LlamaRotaryEmbedding,
    LlamaRMSNorm,
)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], set unsqueeze_dim=2.

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
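
# A minimal, illustrative sketch of how the two helpers above are typically used.
# The shapes and values are examples chosen here, not taken from this file, and the
# positional LlamaRotaryEmbedding(dim) constructor assumes an older transformers API:
#
#     rotary = LlamaRotaryEmbedding(8)                       # head_dim = 8
#     x = torch.randn(1, 4, 16)                              # (batch, seq_len, hidden)
#     position_ids = torch.arange(4).unsqueeze(0)            # (batch, seq_len)
#     cos, sin = rotary(x, position_ids)                     # each (batch, seq_len, head_dim)
#     q = k = torch.randn(1, 2, 4, 8)                        # (batch, heads, seq_len, head_dim)
#     q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)    # same shapes as q and k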


class CausalAttention(nn.Module):
    def __init__(self, hidden_size, num_attention_heads, num_key_value_heads):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = hidden_size // num_attention_heads
        self.num_key_value_groups = num_attention_heads // num_key_value_heads
        # SDPA already scales by 1/sqrt(head_dim) by default; kept here for reference.
        self.scaling = self.head_dim ** -0.5
        self.is_causal = True

        # Query, key, value projections; k/v use fewer heads (grouped-query attention)
        self.q_proj = nn.Linear(hidden_size, self.head_dim * num_attention_heads, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.head_dim * num_key_value_heads, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.head_dim * num_key_value_heads, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, hidden_states, attention_mask=None, position_embeddings=None):
        # attention_mask is accepted for interface compatibility but unused here:
        # causal masking is handled by is_causal=True below.
        batch, seq_len = hidden_states.shape[:-1]
        hidden_shape = (batch, seq_len, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Fused attention; may dispatch to a flash-attention kernel when available.
        # enable_gqa (grouped-query attention) requires PyTorch >= 2.5.
        y = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            is_causal=True,
            enable_gqa=True,
        )
        # Re-assemble all head outputs side by side
        y = y.transpose(1, 2).contiguous().view(batch, seq_len, self.hidden_size)
        # Output projection
        y = self.o_proj(y)
        return y
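
# Illustrative usage of CausalAttention (the dimensions are example values, not from this file):
#
#     attn = CausalAttention(hidden_size=576, num_attention_heads=9, num_key_value_heads=3)
#     rotary = LlamaRotaryEmbedding(576 // 9)                # head_dim = 64
#     x = torch.randn(2, 16, 576)                            # (batch, seq_len, hidden)
#     pos = torch.arange(16).unsqueeze(0).expand(2, -1)
#     out = attn(x, position_embeddings=rotary(x, pos))      # -> (2, 16, 576)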


class MLP(nn.Module):
    # Inspired by LlamaMLP: a gated (SwiGLU-style) feed-forward block.
    # Unused arguments from the original signature (attention head counts, eps) were dropped.
    def __init__(self, hidden_size, intermediate_size, activation_fn):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = activation_fn

    def forward(self, x):
        # down(act(gate(x)) * up(x))
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj
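
# Example (illustrative values, not from this file): with hidden_size=576 and intermediate_size=1536,
#
#     mlp = MLP(576, 1536, F.silu)
#     y = mlp(torch.randn(2, 16, 576))    # -> (2, 16, 576)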


class TransformerBlock(nn.Module):
    def __init__(self, hidden_size, num_attention_heads, num_key_value_heads, intermediate_size, eps, activation_fn):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = hidden_size // num_attention_heads
        assert self.head_dim * num_attention_heads == hidden_size, "hidden_size must be divisible by num_attention_heads"
        assert num_attention_heads % num_key_value_heads == 0, "num_attention_heads must be divisible by num_key_value_heads (grouped-query attention)"

        self.layer_norm_1 = LlamaRMSNorm(self.hidden_size, eps=eps)
        self.attn = CausalAttention(hidden_size, num_attention_heads, num_key_value_heads)
        # Feed-forward layer
        self.feed_forward = MLP(hidden_size, intermediate_size, activation_fn)
        self.layer_norm_2 = LlamaRMSNorm(self.hidden_size, eps=eps)

    def forward(self, hidden_states, attention_mask=None, position_embeddings=None):
        # Pre-norm attention block with a residual connection
        residual = hidden_states
        hidden_states = self.layer_norm_1(hidden_states)
        attention_output = self.attn(hidden_states, position_embeddings=position_embeddings)
        hidden_states = residual + attention_output

        # Pre-norm feed-forward block with a residual connection
        residual = hidden_states
        hidden_states = self.layer_norm_2(hidden_states)
        feed_forward_output = self.feed_forward(hidden_states)
        hidden_states = residual + feed_forward_output
        return hidden_states


class SmollM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.vocab_size = config['vocab_size']
        self.hidden_size = config['hidden_size']
        self.num_hidden_layers = config['num_hidden_layers']
        self.num_attention_heads = config['num_attention_heads']
        self.num_key_value_heads = config['num_key_value_heads']
        self.max_position_embeddings = config['max_position_embeddings']
        self.intermediate_size = config['intermediate_size']
        self.initializer_range = config['initializer_range']
        self.eps = config['rms_norm_eps']
        self.head_dim = self.hidden_size // self.num_attention_heads

        self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)
        # Note: this positional form works on older transformers releases; recent versions
        # construct LlamaRotaryEmbedding from a config object instead.
        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim)
        self.layers = nn.ModuleList([
            TransformerBlock(
                hidden_size=self.hidden_size,
                num_attention_heads=self.num_attention_heads,
                num_key_value_heads=self.num_key_value_heads,
                intermediate_size=self.intermediate_size,
                eps=self.eps,
                activation_fn=F.silu,  # SiLU, matching hidden_act='silu' in Llama/SmolLM-style configs
            ) for _ in range(self.num_hidden_layers)
        ])
        self.layer_norm = LlamaRMSNorm(self.hidden_size, eps=self.eps)

        # Language modeling head, weight-tied to the token embedding
        self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight

        self._init_weights()

    def forward(self, input_ids, attention_mask=None):
        batch_size, seq_length = input_ids.size()
        position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        hidden_states = self.embedding(input_ids)
        # Rotary cos/sin are computed once and shared across all layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask=attention_mask, position_embeddings=position_embeddings)

        hidden_states = self.layer_norm(hidden_states)
        logits = self.lm_head(hidden_states)
        return logits

    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
            elif isinstance(module, LlamaRMSNorm):
                # This model uses LlamaRMSNorm (no bias), not nn.LayerNorm
                nn.init.ones_(module.weight)
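

if __name__ == "__main__":
    # Quick smoke test. The config values below are illustrative (small enough for a CPU run)
    # and are NOT taken from any released SmolLM checkpoint. It assumes a transformers version
    # where LlamaRotaryEmbedding accepts head_dim positionally, and PyTorch >= 2.5 for enable_gqa.
    config = {
        'vocab_size': 49152,
        'hidden_size': 576,
        'num_hidden_layers': 2,
        'num_attention_heads': 9,
        'num_key_value_heads': 3,
        'max_position_embeddings': 2048,
        'intermediate_size': 1536,
        'initializer_range': 0.02,
        'rms_norm_eps': 1e-5,
    }
    model = SmollM(config)
    input_ids = torch.randint(0, config['vocab_size'], (2, 16))
    logits = model(input_ids)
    print(logits.shape)  # expected: torch.Size([2, 16, 49152])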