# Tar-7B / tok / ar_dtok / ar_model.py
# hanjiaming.0208
# init
# 146dae5
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
import os
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from .. import models
from .generate import generate as ar_generate
def find_multiple(n: int, k: int):
    """Round ``n`` up to a multiple of ``k``.

    Returns ``n`` itself when it is already divisible by ``k``; otherwise
    returns ``n`` plus whatever is missing to reach the next multiple.
    """
    remainder = n % k
    return n if remainder == 0 else n + k - remainder
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, scale_factor=10000):
    """
    Build 1-D sine/cosine position embeddings.

    embed_dim: output dimension for each position (must be even)
    pos: numpy array of positions to be encoded; flattened to size (M,)
    scale_factor: the base of the frequency progression, default is 10000
    out: (M, D) float array -- the first D/2 columns hold the sines, the
        last D/2 columns hold the cosines
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    # Frequency i is 1 / scale_factor ** (i / (D/2)) for i in [0, D/2).
    exponents = np.arange(half_dim, dtype=np.float64) / (embed_dim / 2.)
    freqs = 1. / scale_factor**exponents  # (D/2,)
    flat_pos = pos.reshape(-1)  # (M,)
    angles = flat_pos[:, None] * freqs[None, :]  # (M, D/2), outer product
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
@dataclass
class ModelArgs:
    """Hyper-parameters consumed by the modules in this file (see per-field comments)."""

    # --- transformer geometry ---
    dim: int = 4096  # model width; must be divisible by n_head
    n_layer: int = 32  # number of transformer blocks
    n_head: int = 32  # number of query heads
    n_kv_head: Optional[int] = None  # number of key/value heads; None means "same as n_head"
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None  # optional extra scale on the FFN hidden size
    rope_base: float = 10000  # not referenced by the code visible in this file
    norm_eps: float = 1e-5  # epsilon passed to every RMSNorm
    initializer_range: float = 0.02  # std of the normal init for Linear / Embedding weights

    # --- regularisation ---
    token_dropout_p: float = 0.1  # dropout on the input embeddings
    attn_dropout_p: float = 0.0  # attention dropout (applied in training mode only)
    resid_dropout_p: float = 0.1  # dropout after the attention output projection
    ffn_dropout_p: float = 0.1  # dropout after the feed-forward block
    drop_path_rate: float = 0.0  # final stochastic-depth rate; ramps linearly from 0 over layers

    # --- conditioning ---
    num_classes: int = 1000  # number of labels for model_type == 'class_cond'
    class_dropout_prob: float = 0.1  # probability of dropping the condition during training
    model_type: str = 'class_cond'  # clip_cond, indice_cond
    cond_dim: int = 1152  # input width of the 'clip_cond' projection
    cond_vocab_size: int = 8192  # condition vocabulary size for 'indice_cond'

    # --- tokens / sequence ---
    vocab_size: int = 8192  # size of the token embedding table and of the output layer
    cls_token_num: int = 1  # number of conditioning tokens kept for 'class_cond'
    max_batch_size: int = 32  # not referenced by the code visible in this file
    max_seq_len: int = 2048  # maximum number of generated tokens
    use_fixed_pe: bool = False  # True: fixed sin/cos position embedding; False: learned
    frame_prediction: bool = False  # stored on ARModel; no further use visible in this file
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned per-channel gain."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Gain starts at one, so the layer initially only normalises.
        self.weight = nn.Parameter(torch.ones(dim))

    @torch.autocast(device_type='cuda', enabled=False)
    def _norm(self, x):
        # CUDA autocast is switched off for this computation.
        mean_square = torch.mean(x * x, dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        # Normalise in float32, cast back to the dtype of `x`, then apply the gain.
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight
class MLP(nn.Module):
    """Two bias-free linear layers with a tanh-approximated GELU in between."""

    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        # A falsy size (None or 0) falls back to the input width.
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
        self.act = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(hidden_features, out_features, bias=False)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))
#################################################################################
# Drop Path Implementation #
#################################################################################
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    ``x`` is returned untouched when ``training`` is false or ``drop_prob`` is
    zero. Otherwise one Bernoulli keep/drop decision is drawn per entry of the
    first dimension and broadcast over all remaining dimensions. With
    ``scale_by_keep`` the kept samples are divided by the keep probability.
    """
    if not training or drop_prob == 0.:
        return x
    keep_prob = 1 - drop_prob
    # (B, 1, 1, ...) so the mask broadcasts for tensors of any rank, not just 2D ConvNets
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    keep_mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        keep_mask.div_(keep_prob)
    return x * keep_mask
class DropPath(torch.nn.Module):
    """Module wrapper around ``drop_path`` (Stochastic Depth per sample).

    The drop is only active while the module is in training mode.
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(
            x,
            drop_prob=self.drop_prob,
            training=self.training,
            scale_by_keep=self.scale_by_keep,
        )

    def extra_repr(self):
        # Shown inside the module's repr.
        rounded = round(self.drop_prob, 3)
        return f'drop_prob={rounded:0.3f}'
#################################################################################
# AR Model #
#################################################################################
class FeedForward(nn.Module):
    """Gated feed-forward block: ``dropout(w2(silu(w1(x)) * w3(x)))``."""

    def __init__(self, config: ModelArgs):
        super().__init__()
        # Two thirds of a 4x expansion, truncated to an integer.
        hidden_dim = int(2 * (4 * config.dim) / 3)
        # custom dim factor multiplier
        multiplier = config.ffn_dim_multiplier
        if multiplier is not None:
            hidden_dim = int(multiplier * hidden_dim)
        # Round up so the hidden width is a multiple of config.multiple_of.
        hidden_dim = find_multiple(hidden_dim, config.multiple_of)

        self.w1 = nn.Linear(config.dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(config.dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, config.dim, bias=False)
        self.ffn_dropout = nn.Dropout(config.ffn_dropout_p)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.ffn_dropout(self.w2(gate * up))
class KVCache(nn.Module):
    """Pre-allocated key/value buffers that are overwritten at given sequence positions."""

    def __init__(self, max_batch_size, max_seq_length, n_head, head_dim, dtype):
        super().__init__()
        cache_shape = (max_batch_size, n_head, max_seq_length, head_dim)
        # Both buffers start zero-filled and share the same shape and dtype.
        for buffer_name in ('k_cache', 'v_cache'):
            self.register_buffer(buffer_name, torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        """Write ``k_val`` / ``v_val`` at ``input_pos`` and return the full caches.

        The writes happen in place; the returned tensors are the cache
        buffers themselves, not copies.
        """
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2], f"{input_pos.shape[0]} != {k_val.shape[2]}"

        k_cache, v_cache = self.k_cache, self.v_cache
        # Values are cast to the cache dtype before being stored.
        k_cache[:, :, input_pos] = k_val.to(k_cache.dtype)
        v_cache[:, :, input_pos] = v_val.to(v_cache.dtype)
        return k_cache, v_cache
class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    Supports fewer key/value heads than query heads (each KV head is repeated
    to match) and an optional ``kv_cache`` that is attached from outside.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0
        self.dim = config.dim
        self.head_dim = config.dim // config.n_head
        self.n_head = config.n_head
        self.n_kv_head = config.n_head if config.n_kv_head is None else config.n_kv_head

        # key, query, value projections for all heads, but in a batch
        qkv_width = (self.n_head + 2 * self.n_kv_head) * self.head_dim
        self.wqkv = nn.Linear(config.dim, qkv_width, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        self.kv_cache = None

        # regularization
        self.attn_dropout_p = config.attn_dropout_p
        self.resid_dropout = nn.Dropout(config.resid_dropout_p)

    def forward(
        self, x: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None
    ):
        """Attend over ``x`` of shape (batch, seq, dim) and return a tensor of the same shape.

        ``input_pos`` is only used to address the KV cache when one is
        attached. When ``mask`` is None the attention is causal; otherwise
        ``mask`` is passed through as the attention mask.
        """
        bsz, seqlen, _ = x.shape
        kv_size = self.n_kv_head * self.head_dim

        # Split the fused projection, then reshape (B, S, H*D) -> (B, H, S, D).
        q_flat, k_flat, v_flat = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
        xq = q_flat.view(bsz, seqlen, self.n_head, self.head_dim).transpose(1, 2)
        xk = k_flat.view(bsz, seqlen, self.n_kv_head, self.head_dim).transpose(1, 2)
        xv = v_flat.view(bsz, seqlen, self.n_kv_head, self.head_dim).transpose(1, 2)

        if self.kv_cache is None:
            keys, values = xk, xv
        else:
            keys, values = self.kv_cache.update(input_pos, xk, xv)

        # Repeat each KV head so keys/values have as many heads as the queries.
        n_rep = self.n_head // self.n_kv_head
        keys = keys.repeat_interleave(n_rep, dim=1)
        values = values.repeat_interleave(n_rep, dim=1)

        output = F.scaled_dot_product_attention(
            xq, keys, values,
            attn_mask=mask,
            is_causal=mask is None,  # is_causal=False is for KV cache
            dropout_p=self.attn_dropout_p if self.training else 0)

        # (B, H, S, D) -> (B, S, H*D), then output projection and dropout.
        merged = output.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
        return self.resid_dropout(self.wo(merged))
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each in its own residual branch.

    The same ``drop_path`` module (identity when the rate is not positive) is
    applied to both residual branches.
    """

    def __init__(self, config: ModelArgs, drop_path: float):
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(
            self, x: torch.Tensor, start_pos: int, mask: Optional[torch.Tensor] = None):
        # `start_pos` is handed to the attention layer as its `input_pos` argument.
        attn_out = self.attention(self.attention_norm(x), start_pos, mask)
        h = x + self.drop_path(attn_out)
        ffn_out = self.feed_forward(self.ffn_norm(h))
        return h + self.drop_path(ffn_out)
class LabelEmbedder(nn.Module):
    """
    Embedding table for integer labels, with optional label dropout for
    classifier-free guidance. Dropped labels are replaced by the index
    ``num_classes``.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        # One extra row (index ``num_classes``) is allocated only when dropout is enabled.
        table_rows = num_classes + 1 if dropout_prob > 0 else num_classes
        self.embedding_table = nn.Embedding(table_rows, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Replace selected labels by ``num_classes``.

        With ``force_drop_ids`` the entries equal to 1 are dropped; otherwise
        each entry along the first dimension is dropped with probability
        ``dropout_prob``.
        """
        if force_drop_ids is not None:
            drop_ids = force_drop_ids == 1
        else:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        return torch.where(drop_ids, self.num_classes, labels)

    def forward(self, labels, train, force_drop_ids=None):
        random_drop = train and self.dropout_prob > 0
        if random_drop or force_drop_ids is not None:
            labels = self.token_drop(labels, force_drop_ids)

        # replace all negative labels with the last class (unconditional class)
        labels = torch.where(labels < 0, self.num_classes, labels)
        return self.embedding_table(labels)
class ARModel(nn.Module):
    """Autoregressive transformer over discrete tokens with a conditioning prefix.

    The prefix is built according to ``config.model_type``:

    * ``'class_cond'``  -- labels embedded by a ``LabelEmbedder`` (``cls_embedding``);
    * ``'clip_cond'``   -- continuous features projected by a linear layer (``clip_proj``);
    * ``'indice_cond'`` -- discrete indices embedded by a ``LabelEmbedder`` (``clip_proj``).

    Absolute position embeddings (fixed sin/cos buffer or learned parameter)
    are added to the input embeddings before the transformer blocks.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.n_layer = config.n_layer
        self.max_seq_length = config.max_seq_len
        self.num_classes = config.num_classes
        self.model_type = config.model_type
        self.cls_token_num = config.cls_token_num
        self.is_sampling = False
        self.frame_prediction = config.frame_prediction

        if self.model_type == 'class_cond':
            self.cls_embedding = LabelEmbedder(config.num_classes, config.dim, config.class_dropout_prob)
        elif self.model_type == 'clip_cond':
            self.clip_proj = nn.Linear(config.cond_dim, config.dim)
        elif self.model_type == 'indice_cond':
            # The extra row (index cond_vocab_size) is the value written by
            # clip_embedding when a condition is dropped during training.
            self.clip_proj = LabelEmbedder(config.cond_vocab_size + 1, config.dim, 0.0)
        else:
            raise Exception("please check model type")

        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.tok_dropout = nn.Dropout(config.token_dropout_p)

        # transformer blocks; the drop-path rate ramps linearly from 0 to drop_path_rate
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.n_layer)]
        self.layers = torch.nn.ModuleList()
        for layer_id in range(config.n_layer):
            self.layers.append(TransformerBlock(config, dpr[layer_id]))

        # output layer
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        # absolute position embedding, shape (1, max_seq_len + cls_token_num - 1, dim)
        pe_len = config.max_seq_len + config.cls_token_num - 1
        if config.use_fixed_pe:
            self.register_buffer('abs_pe', torch.zeros(1, pe_len, config.dim))
            abs_pe = get_1d_sincos_pos_embed_from_grid(embed_dim=config.dim, pos=np.arange(pe_len))
            self.abs_pe.copy_(torch.from_numpy(abs_pe).float().reshape_as(self.abs_pe))
            print("Using fixed absolute PE")
        else:
            self.abs_pe = nn.Parameter(torch.randn(1, pe_len, config.dim) * 0.02)
            print("Using learned absolute PE")

        self.initialize_weights()

    def initialize_weights(self):
        """Apply the normal init to all Linear/Embedding layers, then zero the output projection."""
        # Initialize nn.Linear and nn.Embedding
        self.apply(self._init_weights)

        # Zero-out output layers:
        if hasattr(self.output, 'weight') and isinstance(self.output.weight, nn.Parameter):
            nn.init.constant_(self.output.weight, 0)

    def _init_weights(self, module):
        """Per-module initializer: N(0, initializer_range) weights, zero biases."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)

    @property
    def device(self):
        """Device of the first parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the first parameter."""
        return next(self.parameters()).dtype

    @contextmanager
    def sampling(self):
        """Context manager that sets ``is_sampling`` for its duration (reset even on error)."""
        self.is_sampling = True
        try:
            yield
        finally:
            self.is_sampling = False

    def setup_caches(self, max_batch_size, max_seq_length, dtype):
        """Attach a fresh KV cache to every layer and build the boolean causal mask.

        ``max_seq_length`` must equal ``max_seq_len + cls_token_num``; it is
        then rounded up to a multiple of 8 for the cache and the mask.
        """
        assert max_seq_length == self.max_seq_length + self.cls_token_num, f'{max_seq_length} != {self.max_seq_length} + {self.cls_token_num=}'

        head_dim = self.config.dim // self.config.n_head
        # The attention layers write `n_kv_head` heads into the cache, so the cache
        # must be sized with the KV head count (equal to n_head when n_kv_head is None).
        n_kv_head = self.config.n_head if self.config.n_kv_head is None else self.config.n_kv_head
        max_seq_length = find_multiple(max_seq_length, 8)

        for b in self.layers:
            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, n_kv_head, head_dim, dtype)

        # NOTE(review): the mask is created without an explicit device, i.e. on the
        # current default device -- confirm callers set the device context accordingly.
        causal_mask = torch.tril(torch.ones(max_seq_length, max_seq_length, dtype=torch.bool))
        self.causal_mask = causal_mask.unsqueeze(0).repeat(max_batch_size, 1, 1)

    def reset_caches(self):
        """Detach the KV cache from every layer."""
        for b in self.layers:
            b.attention.kv_cache = None

    def clip_embedding(self, x):
        """Embed the condition for ``'clip_cond'`` / ``'indice_cond'`` models.

        In training mode with ``class_dropout_prob > 0``, entries along the
        first dimension are randomly dropped: zeroed for ``'clip_cond'``,
        set to ``cond_vocab_size`` for ``'indice_cond'``. The drop is applied
        to a copy, so the tensor passed in by the caller is left unchanged.
        For any other model type ``x`` is returned as is.
        """
        drop_cond = self.training and self.config.class_dropout_prob > 0
        if self.model_type == 'clip_cond':
            if drop_cond:
                drop_ids = torch.rand(x.shape[0], device=x.device) < self.config.class_dropout_prob
                x = x.clone()  # never overwrite the caller's tensor
                x[drop_ids] = 0.
            x = self.clip_proj(x.to(self.dtype))  # Linear
        elif self.model_type == 'indice_cond':
            if drop_cond:
                drop_ids = torch.rand(x.shape[0], device=x.device) < self.config.class_dropout_prob
                x = x.clone()  # never overwrite the caller's tensor
                x[drop_ids] = self.config.cond_vocab_size
            x = self.clip_proj(x, train=self.training)  # Embedding
        return x

    def _embed_condition(self, cond_idx):
        """Return the conditioning prefix embeddings for the configured model type."""
        if self.model_type == 'class_cond':
            # (b,) labels -> (b, 1, d), truncated to at most cls_token_num tokens
            return self.cls_embedding(cond_idx, train=self.training).unsqueeze(1)[:, :self.cls_token_num]
        # __init__ only admits 'clip_cond' and 'indice_cond' besides 'class_cond'
        return self.clip_embedding(cond_idx)

    def forward(
        self,
        idx: Optional[torch.Tensor],  # (b, n)
        cond_idx: Optional[torch.Tensor],  # cond_idx_or_embed
        input_pos: Optional[torch.Tensor] = None,
        targets: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        valid: Optional[torch.Tensor] = None,
    ):
        """Run the model and optionally compute the cross-entropy loss.

        Three call patterns are handled:

        * ``idx`` and ``cond_idx`` given -- training or naive inference: the
          condition embeddings are prepended to the token embeddings;
        * only ``cond_idx`` given -- prefill during cached inference;
        * only ``idx`` given -- cached decoding of further tokens.

        The last two require ``setup_caches`` to have been called, because the
        attention mask is taken from ``self.causal_mask`` at ``input_pos``.

        When ``cond_idx`` is given, the logits are cut so that they start at
        the last conditioning position. ``valid`` (one weight per sample,
        presumably 0/1 -- confirm against the caller) selects which samples
        contribute to the loss and requires ``targets``.

        Returns:
            ``(logits, loss)``; ``loss`` is None when neither ``valid`` nor
            ``targets`` is given.
        """
        num_cond_tokens = None
        if idx is not None and cond_idx is not None:  # training or naive inference
            cond_embeddings = self._embed_condition(cond_idx)
            num_cond_tokens = cond_embeddings.shape[1]
            token_embeddings = self.tok_embeddings(idx)  # (b, n, d)
            token_embeddings = torch.cat((cond_embeddings, token_embeddings), dim=1)  # (b, cls_token_num + n, d)
            h = self.tok_dropout(token_embeddings)
        else:
            if cond_idx is not None:  # prefill in inference
                token_embeddings = self._embed_condition(cond_idx)
                num_cond_tokens = token_embeddings.shape[1]
            else:  # decode_n_tokens(kv cache) in inference
                token_embeddings = self.tok_embeddings(idx)

            bs = token_embeddings.shape[0]
            mask = self.causal_mask[:bs, None, input_pos]
            h = self.tok_dropout(token_embeddings)

        # While sampling, positions are addressed explicitly; otherwise the
        # sequence is assumed to start at position 0.
        if self.is_sampling:
            h = h + self.abs_pe[:, input_pos]
        else:
            h = h + self.abs_pe[:, :h.shape[1]]

        # transformer blocks
        for layer in self.layers:
            h = layer(h, input_pos, mask)

        # output layers
        h = self.norm(h)
        logits = self.output(h)

        if cond_idx is not None:
            # Keep the predictions from the last conditioning token onwards. The prefix
            # length is read from the embeddings because class labels are 1-D and have
            # no `size(1)`; for 'clip_cond' / 'indice_cond' it equals `cond_idx.size(1)`.
            logits = logits[:, num_cond_tokens - 1:].contiguous()

        # if we are given some desired targets also calculate the loss
        loss = None
        if valid is not None:
            loss_all = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
            # Broadcast the per-sample weight to every target position of that sample.
            valid_all = valid[:, None].repeat(1, targets.shape[1]).view(-1)
            loss = (loss_all * valid_all).sum() / max(valid_all.sum(), 1)
        elif targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

    @torch.inference_mode()
    def sample(
        self,
        c,
        cfg_scale=2.0,
        cfg_interval=-1,
        temperature=1.0,
        top_k=0,
        top_p=1.0,
        seq_length=None,
    ):
        """Generate sequences for condition ``c`` via ``ar_generate``.

        ``seq_length`` defaults to ``max_seq_length``; the remaining arguments
        are forwarded unchanged to ``ar_generate``, which runs with
        ``is_sampling`` set. Executed under ``torch.inference_mode``.
        """
        seq_length = self.max_seq_length if seq_length is None else seq_length
        with self.sampling():
            sampled_seqs = ar_generate(
                self, c, seq_length,
                cfg_scale=cfg_scale, cfg_interval=cfg_interval,
                temperature=temperature, top_k=top_k,
                top_p=top_p, sample_logits=True,
            )
        return sampled_seqs

    @classmethod
    def from_checkpoint(cls, ckpt, load_state_dict=True):
        """Build a model from a checkpoint dict or from the path to one.

        The model is created by ``models.make`` from the ``"model"`` entry of
        the checkpoint; ``load_state_dict`` is passed on as ``load_sd``.
        """
        if isinstance(ckpt, str):
            assert os.path.exists(ckpt), f"checkpoint {ckpt} does not exist"
            # NOTE(security): torch.load unpickles the file, which can execute
            # arbitrary code -- only load checkpoints from trusted sources.
            ckpt = torch.load(ckpt, map_location=lambda storage, loc: storage)
        else:
            assert isinstance(
                ckpt, dict
            ), "checkpoint must be a dict or a path to a checkpoint"
        model = models.make(ckpt["model"], load_sd=load_state_dict)
        return model
#################################################################################
# LLAMA-ABS Configs #
#################################################################################
def LLAMA_ABS_XXXL(**kwargs):
    """Build an ARModel with 48 layers, 40 heads and width 2560; ``kwargs`` set further ModelArgs fields."""
    config = ModelArgs(n_layer=48, n_head=40, dim=2560, **kwargs)
    return ARModel(config)  # 3.9B
def LLAMA_ABS_XXL(**kwargs):
    """Build an ARModel with 48 layers, 24 heads and width 1536; ``kwargs`` set further ModelArgs fields."""
    config = ModelArgs(n_layer=48, n_head=24, dim=1536, **kwargs)
    return ARModel(config)  # 1.4B
def LLAMA_ABS_XL(**kwargs):
    """Build an ARModel with 36 layers, 20 heads and width 1280; ``kwargs`` set further ModelArgs fields."""
    config = ModelArgs(n_layer=36, n_head=20, dim=1280, **kwargs)
    return ARModel(config)  # 775M
def LLAMA_ABS_LP(**kwargs):
    """Build an ARModel with 30 layers, 20 heads and width 1280; ``kwargs`` set further ModelArgs fields."""
    config = ModelArgs(n_layer=30, n_head=20, dim=1280, **kwargs)
    return ARModel(config)  # 632M
def LLAMA_ABS_L(**kwargs):
    """Build an ARModel with 24 layers, 16 heads and width 1024; ``kwargs`` set further ModelArgs fields."""
    config = ModelArgs(n_layer=24, n_head=16, dim=1024, **kwargs)
    return ARModel(config)  # 343M
def LLAMA_ABS_B(**kwargs):
    """Build an ARModel with 12 layers, 12 heads and width 768; ``kwargs`` set further ModelArgs fields."""
    config = ModelArgs(n_layer=12, n_head=12, dim=768, **kwargs)
    return ARModel(config)  # 111M
def LLAMA_ABS_S(**kwargs):
    """Build an ARModel with 12 layers, 6 heads and width 384; ``kwargs`` set further ModelArgs fields."""
    config = ModelArgs(n_layer=12, n_head=6, dim=384, **kwargs)
    return ARModel(config)  # 21.7M
# Name -> factory table for the LLAMA-ABS variants, listed from smallest to largest.
ar_models = dict(
    (
        ('llama-abs-S', LLAMA_ABS_S),
        ('llama-abs-B', LLAMA_ABS_B),
        ('llama-abs-L', LLAMA_ABS_L),
        ('llama-abs-LP', LLAMA_ABS_LP),
        ('llama-abs-XL', LLAMA_ABS_XL),
        ('llama-abs-XXL', LLAMA_ABS_XXL),
        ('llama-abs-XXXL', LLAMA_ABS_XXXL),
    )
)
# Import-time side effect: add the factories to the `models.models` mapping of the
# parent package (presumably the registry that `models.make` resolves names from -- confirm).
models.models.update(ar_models)