import os
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

from .. import models
from .generate import generate as ar_generate


def find_multiple(n: int, k: int) -> int:
    """Round ``n`` up to the nearest multiple of ``k``, e.g. find_multiple(250, 256) == 256."""
    if n % k == 0:
        return n
    return n + k - (n % k)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, scale_factor=10000):
    """
    embed_dim: output dimension for each position (must be even)
    pos: a 1-D array of positions to be encoded, size (M,)
    scale_factor: base of the frequency scaling, default 10000
    returns: array of shape (M, embed_dim)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / scale_factor**omega

    pos = pos.reshape(-1)
    out = np.einsum('m,d->md', pos, omega)

    emb_sin = np.sin(out)
    emb_cos = np.cos(out)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)
    return emb
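
# Illustrative shapes (comment only, not executed): this is roughly how ARModel below builds its
# fixed positional table, where 256 stands in for max_seq_len + cls_token_num - 1:
#   pe = get_1d_sincos_pos_embed_from_grid(embed_dim=768, pos=np.arange(256))   # (256, 768)
#   pe = torch.from_numpy(pe).float().unsqueeze(0)                              # (1, 256, 768)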


@dataclass
class ModelArgs:
    """Configuration for ARModel (LLaMA-style decoder with absolute positional embeddings)."""
    dim: int = 4096
    n_layer: int = 32
    n_head: int = 32

    n_kv_head: Optional[int] = None
    multiple_of: int = 256
    ffn_dim_multiplier: Optional[float] = None
    rope_base: float = 10000
    norm_eps: float = 1e-5
    initializer_range: float = 0.02

    token_dropout_p: float = 0.1
    attn_dropout_p: float = 0.0
    resid_dropout_p: float = 0.1
    ffn_dropout_p: float = 0.1
    drop_path_rate: float = 0.0

    num_classes: int = 1000
    class_dropout_prob: float = 0.1
    model_type: str = 'class_cond'
    cond_dim: int = 1152
    cond_vocab_size: int = 8192

    vocab_size: int = 8192
    cls_token_num: int = 1

    max_batch_size: int = 32
    max_seq_len: int = 2048

    use_fixed_pe: bool = False

    frame_prediction: bool = False


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    @torch.autocast(device_type='cuda', enabled=False)
    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class MLP(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
        self.act = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(hidden_features, out_features, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor
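
# Illustrative behaviour (comment only): during training with drop_prob=0.1, each sample's
# residual branch is zeroed with probability 0.1 and the surviving samples are scaled by 1/0.9
# (scale_by_keep=True), so the output matches the input in expectation.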


class DropPath(torch.nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob,3):0.3f}'


class FeedForward(nn.Module):
    """SwiGLU feed-forward block: w2(SiLU(w1(x)) * w3(x)), with the hidden width scaled to
    2/3 of 4*dim and rounded up to a multiple of config.multiple_of."""
    def __init__(self, config: ModelArgs):
        super().__init__()
        hidden_dim = 4 * config.dim
        hidden_dim = int(2 * hidden_dim / 3)

        if config.ffn_dim_multiplier is not None:
            hidden_dim = int(config.ffn_dim_multiplier * hidden_dim)
        hidden_dim = find_multiple(hidden_dim, config.multiple_of)

        self.w1 = nn.Linear(config.dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(config.dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, config.dim, bias=False)
        self.ffn_dropout = nn.Dropout(config.ffn_dropout_p)

    def forward(self, x):
        return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))


class KVCache(nn.Module):
    """Static key/value cache for incremental decoding: `update` writes the new keys/values at
    `input_pos` and returns the full cached tensors."""
    def __init__(self, max_batch_size, max_seq_length, n_head, head_dim, dtype):
        super().__init__()
        cache_shape = (max_batch_size, n_head, max_seq_length, head_dim)
        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        assert input_pos.shape[0] == k_val.shape[2], f"{input_pos.shape[0]} != {k_val.shape[2]}"
        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val.to(k_out.dtype)
        v_out[:, :, input_pos] = v_val.to(v_out.dtype)

        return k_out, v_out
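
# Cache contract, as used by Attention.forward below (comment only): `input_pos` holds the
# sequence positions being written (length 1 for a single decoding step), k_val/v_val carry the
# new keys/values of shape (bsz, heads, len(input_pos), head_dim), and `update` returns the full
# (bsz, heads, max_seq_length, head_dim) buffers; attention over unwritten positions is blocked
# by the causal mask passed to scaled_dot_product_attention.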


class Attention(nn.Module):
    """Multi-head self-attention with optional grouped-query KV heads and an optional KV cache
    for incremental decoding."""
    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0
        self.dim = config.dim
        self.head_dim = config.dim // config.n_head
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head if config.n_kv_head is not None else config.n_head
        total_qkv_dim = (self.n_head + 2 * self.n_kv_head) * self.head_dim

        self.wqkv = nn.Linear(config.dim, total_qkv_dim, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        self.kv_cache = None

        self.attn_dropout_p = config.attn_dropout_p
        self.resid_dropout = nn.Dropout(config.resid_dropout_p)

    def forward(
        self, x: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None
    ):
        bsz, seqlen, _ = x.shape
        kv_size = self.n_kv_head * self.head_dim
        xq, xk, xv = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        xq = xq.view(bsz, seqlen, self.n_head, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_kv_head, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_kv_head, self.head_dim)

        xq, xk, xv = map(lambda t: t.transpose(1, 2), (xq, xk, xv))

        if self.kv_cache is not None:
            keys, values = self.kv_cache.update(input_pos, xk, xv)
        else:
            keys, values = xk, xv
        # Grouped-query attention: expand the KV heads to match the number of query heads.
        keys = keys.repeat_interleave(self.n_head // self.n_kv_head, dim=1)
        values = values.repeat_interleave(self.n_head // self.n_kv_head, dim=1)

        output = F.scaled_dot_product_attention(
            xq, keys, values,
            attn_mask=mask,
            is_causal=True if mask is None else False,
            dropout_p=self.attn_dropout_p if self.training else 0)

        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        output = self.resid_dropout(self.wo(output))
        return output


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and SwiGLU feed-forward, each with a residual
    connection and optional DropPath."""
    def __init__(self, config: ModelArgs, drop_path: float):
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(
            self, x: torch.Tensor, input_pos: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None):
        h = x + self.drop_path(self.attention(self.attention_norm(x), input_pos, mask))
        out = h + self.drop_path(self.feed_forward(self.ffn_norm(h)))
        return out


class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            drop_ids = force_drop_ids == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)

        # Map negative labels to the unconditional/null index (num_classes).
        labels = torch.where(labels < 0, self.num_classes, labels)
        embeddings = self.embedding_table(labels)
        return embeddings
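
# Classifier-free guidance sketch (comment only, hypothetical names): during training roughly
# dropout_prob of the labels are replaced by the extra "null" index num_classes; at sampling time
# the unconditional branch can reuse that same index, e.g.
#   null_labels = torch.full_like(labels, fill_value=embedder.num_classes)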


class ARModel(nn.Module):
    """Decoder-only autoregressive transformer over discrete token indices.

    Conditioning depends on config.model_type: 'class_cond' (class labels), 'clip_cond'
    (projected feature vectors), or 'indice_cond' (condition token indices).
    """
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.n_layer = config.n_layer
        self.max_seq_length = config.max_seq_len
        self.num_classes = config.num_classes
        self.model_type = config.model_type
        self.cls_token_num = config.cls_token_num
        self.is_sampling = False
        self.frame_prediction = config.frame_prediction

        if self.model_type == 'class_cond':
            self.cls_embedding = LabelEmbedder(config.num_classes, config.dim, config.class_dropout_prob)
        elif self.model_type == 'clip_cond':
            self.clip_proj = nn.Linear(config.cond_dim, config.dim)
        elif self.model_type == 'indice_cond':
            self.clip_proj = LabelEmbedder(config.cond_vocab_size + 1, config.dim, 0.0)
        else:
            raise ValueError(f"unknown model_type: {self.model_type}")

        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.tok_dropout = nn.Dropout(config.token_dropout_p)

        # Stochastic-depth rates increase linearly with depth.
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.n_layer)]
        self.layers = torch.nn.ModuleList()
        for layer_id in range(config.n_layer):
            self.layers.append(TransformerBlock(config, dpr[layer_id]))

        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        if config.use_fixed_pe:
            self.register_buffer('abs_pe', torch.zeros(1, config.max_seq_len + config.cls_token_num - 1, config.dim))
            abs_pe = get_1d_sincos_pos_embed_from_grid(embed_dim=config.dim, pos=np.arange(config.max_seq_len + config.cls_token_num - 1))
            self.abs_pe.copy_(torch.from_numpy(abs_pe).float().reshape_as(self.abs_pe))
            print("Using fixed absolute PE")
        else:
            self.abs_pe = nn.Parameter(torch.randn(1, config.max_seq_len + config.cls_token_num - 1, config.dim) * 0.02)
            print("Using learned absolute PE")

        self.initialize_weights()

    def initialize_weights(self):
        self.apply(self._init_weights)

        # Zero-init the output projection so the initial logits are uniform.
        if hasattr(self.output, 'weight') and isinstance(self.output.weight, nn.Parameter):
            nn.init.constant_(self.output.weight, 0)

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    @contextmanager
    def sampling(self):
        self.is_sampling = True
        try:
            yield
        finally:
            self.is_sampling = False

    def setup_caches(self, max_batch_size, max_seq_length, dtype):
        assert max_seq_length == self.max_seq_length + self.cls_token_num, f'{max_seq_length} != {self.max_seq_length} + {self.cls_token_num=}'

        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)

        for b in self.layers:
            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_head, head_dim, dtype)

        causal_mask = torch.tril(torch.ones(max_seq_length, max_seq_length, dtype=torch.bool))
        self.causal_mask = causal_mask.unsqueeze(0).repeat(max_batch_size, 1, 1)
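
    # Illustrative cache setup before sampling (hypothetical instance name and batch size):
    #   model.setup_caches(max_batch_size=8,
    #                      max_seq_length=model.max_seq_length + model.cls_token_num,
    #                      dtype=model.dtype)
    # `reset_caches()` below releases the per-layer KV buffers afterwards.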

    def reset_caches(self):
        for b in self.layers:
            b.attention.kv_cache = None

    def clip_embedding(self, x):
        if self.model_type == 'clip_cond':
            if self.training and self.config.class_dropout_prob > 0:
                # Classifier-free guidance: zero out the conditioning features for a random subset.
                drop_ids = torch.rand(x.shape[0], device=x.device) < self.config.class_dropout_prob
                x[drop_ids] = 0.
            x = self.clip_proj(x.to(self.dtype))
        elif self.model_type == 'indice_cond':
            if self.training and self.config.class_dropout_prob > 0:
                # Classifier-free guidance: replace condition indices with the null index.
                drop_ids = torch.rand(x.shape[0], device=x.device) < self.config.class_dropout_prob
                x[drop_ids] = self.config.cond_vocab_size
            x = self.clip_proj(x, train=self.training)
        return x

    def forward(
        self,
        idx: Optional[torch.Tensor],
        cond_idx: Optional[torch.Tensor],
        input_pos: Optional[torch.Tensor] = None,
        targets: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        valid: Optional[torch.Tensor] = None,
    ):
        if idx is not None and cond_idx is not None:
            # Both tokens and condition given: teacher-forcing pass with condition embeddings prepended.
            if self.model_type == 'class_cond':
                cond_embeddings = self.cls_embedding(cond_idx, train=self.training).unsqueeze(1)[:, :self.cls_token_num]
            elif self.model_type in ['clip_cond', 'indice_cond']:
                cond_embeddings = self.clip_embedding(cond_idx)
            token_embeddings = self.tok_embeddings(idx)
            token_embeddings = torch.cat((cond_embeddings, token_embeddings), dim=1)
            h = self.tok_dropout(token_embeddings)
        else:
            if cond_idx is not None:
                # Prefill step: only the condition tokens are fed.
                if self.model_type == 'class_cond':
                    token_embeddings = self.cls_embedding(cond_idx, train=self.training).unsqueeze(1)[:, :self.cls_token_num]
                elif self.model_type in ['clip_cond', 'indice_cond']:
                    token_embeddings = self.clip_embedding(cond_idx)
            else:
                # Incremental decoding step: only the newly generated tokens are fed.
                token_embeddings = self.tok_embeddings(idx)

            bs = token_embeddings.shape[0]
            mask = self.causal_mask[:bs, None, input_pos]
            h = self.tok_dropout(token_embeddings)

        if self.is_sampling:
            h = h + self.abs_pe[:, input_pos]
        else:
            h = h + self.abs_pe[:, :h.shape[1]]

        for layer in self.layers:
            h = layer(h, input_pos, mask)

        h = self.norm(h)
        logits = self.output(h)

        if cond_idx is not None:
            # Drop the logits produced at the condition positions, keeping the one at the last
            # condition token (it predicts the first content token).
            logits = logits[:, cond_idx.size(1) - 1:].contiguous()

        loss = None
        if valid is not None:
            loss_all = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
            valid_all = valid[:, None].repeat(1, targets.shape[1]).view(-1)
            loss = (loss_all * valid_all).sum() / max(valid_all.sum(), 1)
        elif targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    @torch.inference_mode()
    def sample(
        self,
        c,
        cfg_scale=2.0,
        cfg_interval=-1,
        temperature=1.0,
        top_k=0,
        top_p=1.0,
        seq_length=None,
    ):
        seq_length = self.max_seq_length if seq_length is None else seq_length
        with self.sampling():
            sampled_seqs = ar_generate(
                self, c, seq_length,
                cfg_scale=cfg_scale, cfg_interval=cfg_interval,
                temperature=temperature, top_k=top_k,
                top_p=top_p, sample_logits=True,
            )
        return sampled_seqs
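
    # Illustrative sampling call (hypothetical names and shapes; cache handling is left to
    # `ar_generate`, whose exact contract is defined in .generate):
    #   model = LLAMA_ABS_B(vocab_size=16384, max_seq_len=256).cuda().eval()
    #   labels = torch.randint(0, 1000, (4,), device='cuda')
    #   tokens = model.sample(labels, cfg_scale=4.0, top_k=0, top_p=1.0)  # expected shape (4, 256)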

    @classmethod
    def from_checkpoint(cls, ckpt, load_state_dict=True):
        if isinstance(ckpt, str):
            assert os.path.exists(ckpt), f"checkpoint {ckpt} does not exist"
            ckpt = torch.load(ckpt, map_location=lambda storage, loc: storage)
        else:
            assert isinstance(
                ckpt, dict
            ), "checkpoint must be a dict or a path to a checkpoint"
        model = models.make(ckpt["model"], load_sd=load_state_dict)
        return model


def LLAMA_ABS_XXXL(**kwargs):
    return ARModel(ModelArgs(n_layer=48, n_head=40, dim=2560, **kwargs))


def LLAMA_ABS_XXL(**kwargs):
    return ARModel(ModelArgs(n_layer=48, n_head=24, dim=1536, **kwargs))


def LLAMA_ABS_XL(**kwargs):
    return ARModel(ModelArgs(n_layer=36, n_head=20, dim=1280, **kwargs))


def LLAMA_ABS_LP(**kwargs):
    return ARModel(ModelArgs(n_layer=30, n_head=20, dim=1280, **kwargs))


def LLAMA_ABS_L(**kwargs):
    return ARModel(ModelArgs(n_layer=24, n_head=16, dim=1024, **kwargs))


def LLAMA_ABS_B(**kwargs):
    return ARModel(ModelArgs(n_layer=12, n_head=12, dim=768, **kwargs))


def LLAMA_ABS_S(**kwargs):
    return ARModel(ModelArgs(n_layer=12, n_head=6, dim=384, **kwargs))


ar_models = {
    'llama-abs-S': LLAMA_ABS_S,
    'llama-abs-B': LLAMA_ABS_B,
    'llama-abs-L': LLAMA_ABS_L,
    'llama-abs-LP': LLAMA_ABS_LP,
    'llama-abs-XL': LLAMA_ABS_XL,
    'llama-abs-XXL': LLAMA_ABS_XXL,
    'llama-abs-XXXL': LLAMA_ABS_XXXL,
}

models.models.update(ar_models)
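
# Illustrative registry usage (assumption: models.make accepts a spec dict naming one of the
# entries above plus constructor kwargs, consistent with ARModel.from_checkpoint passing
# ckpt["model"] to it):
#   spec = {'name': 'llama-abs-B', 'args': {'vocab_size': 16384, 'max_seq_len': 256}}
#   model = models.make(spec)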