|
|
|
|
|
import logging |
|
import math |
|
import os |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torchvision.transforms as T |
|
from accelerate import init_empty_weights |
|
|
|
from .attention import flash_attention |
|
from .tokenizers import HuggingfaceTokenizer |
|
from .xlm_roberta import XLMRoberta |
|
|
|
from utils.safetensors_utils import load_safetensors |
|
|
|
__all__ = [ |
|
"XLMRobertaCLIP", |
|
"clip_xlm_roberta_vit_h_14", |
|
"CLIPModel", |
|
] |
|
|
|
|
|
def pos_interpolate(pos, seq_len): |
|
if pos.size(1) == seq_len: |
|
return pos |
|
else: |
|
src_grid = int(math.sqrt(pos.size(1))) |
|
tar_grid = int(math.sqrt(seq_len)) |
|
n = pos.size(1) - src_grid * src_grid |
|
return torch.cat( |
|
[ |
|
pos[:, :n], |
|
F.interpolate( |
|
pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(0, 3, 1, 2), |
|
size=(tar_grid, tar_grid), |
|
mode="bicubic", |
|
align_corners=False, |
|
) |
|
.flatten(2) |
|
.transpose(1, 2), |
|
], |
|
dim=1, |
|
) |
|
|
|
|
|
class QuickGELU(nn.Module): |
|
|
|
def forward(self, x): |
|
return x * torch.sigmoid(1.702 * x) |
|
|
|
|
|
class LayerNorm(nn.LayerNorm): |
|
|
|
def forward(self, x): |
|
return super().forward(x.float()).type_as(x) |
|
|
|
|
|
class SelfAttention(nn.Module): |
|
|
|
def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0): |
|
assert dim % num_heads == 0 |
|
super().__init__() |
|
self.dim = dim |
|
self.num_heads = num_heads |
|
self.head_dim = dim // num_heads |
|
self.causal = causal |
|
self.attn_dropout = attn_dropout |
|
self.proj_dropout = proj_dropout |
|
|
|
|
|
self.to_qkv = nn.Linear(dim, dim * 3) |
|
self.proj = nn.Linear(dim, dim) |
|
|
|
def forward(self, x): |
|
""" |
|
x: [B, L, C]. |
|
""" |
|
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim |
|
|
|
|
|
q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2) |
|
|
|
|
|
p = self.attn_dropout if self.training else 0.0 |
|
|
|
|
|
q = q.transpose(1, 2) |
|
k = k.transpose(1, 2) |
|
v = v.transpose(1, 2) |
|
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=p, is_causal=self.causal) |
|
|
|
x = x.transpose(1, 2).contiguous() |
|
x = x.reshape(b, s, c) |
|
|
|
|
|
x = self.proj(x) |
|
x = F.dropout(x, self.proj_dropout, self.training) |
|
return x |
|
|
|
|
|
class SwiGLU(nn.Module): |
|
|
|
def __init__(self, dim, mid_dim): |
|
super().__init__() |
|
self.dim = dim |
|
self.mid_dim = mid_dim |
|
|
|
|
|
self.fc1 = nn.Linear(dim, mid_dim) |
|
self.fc2 = nn.Linear(dim, mid_dim) |
|
self.fc3 = nn.Linear(mid_dim, dim) |
|
|
|
def forward(self, x): |
|
x = F.silu(self.fc1(x)) * self.fc2(x) |
|
x = self.fc3(x) |
|
return x |
|
|
|
|
|
class AttentionBlock(nn.Module): |
|
|
|
def __init__( |
|
self, |
|
dim, |
|
mlp_ratio, |
|
num_heads, |
|
post_norm=False, |
|
causal=False, |
|
activation="quick_gelu", |
|
attn_dropout=0.0, |
|
proj_dropout=0.0, |
|
norm_eps=1e-5, |
|
): |
|
assert activation in ["quick_gelu", "gelu", "swi_glu"] |
|
super().__init__() |
|
self.dim = dim |
|
self.mlp_ratio = mlp_ratio |
|
self.num_heads = num_heads |
|
self.post_norm = post_norm |
|
self.causal = causal |
|
self.norm_eps = norm_eps |
|
|
|
|
|
self.norm1 = LayerNorm(dim, eps=norm_eps) |
|
self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout) |
|
self.norm2 = LayerNorm(dim, eps=norm_eps) |
|
if activation == "swi_glu": |
|
self.mlp = SwiGLU(dim, int(dim * mlp_ratio)) |
|
else: |
|
self.mlp = nn.Sequential( |
|
nn.Linear(dim, int(dim * mlp_ratio)), |
|
QuickGELU() if activation == "quick_gelu" else nn.GELU(), |
|
nn.Linear(int(dim * mlp_ratio), dim), |
|
nn.Dropout(proj_dropout), |
|
) |
|
|
|
def forward(self, x): |
|
if self.post_norm: |
|
x = x + self.norm1(self.attn(x)) |
|
x = x + self.norm2(self.mlp(x)) |
|
else: |
|
x = x + self.attn(self.norm1(x)) |
|
x = x + self.mlp(self.norm2(x)) |
|
return x |
|
|
|
|
|
class AttentionPool(nn.Module): |
|
|
|
def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_dropout=0.0, norm_eps=1e-5): |
|
assert dim % num_heads == 0 |
|
super().__init__() |
|
self.dim = dim |
|
self.mlp_ratio = mlp_ratio |
|
self.num_heads = num_heads |
|
self.head_dim = dim // num_heads |
|
self.proj_dropout = proj_dropout |
|
self.norm_eps = norm_eps |
|
|
|
|
|
gain = 1.0 / math.sqrt(dim) |
|
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) |
|
self.to_q = nn.Linear(dim, dim) |
|
self.to_kv = nn.Linear(dim, dim * 2) |
|
self.proj = nn.Linear(dim, dim) |
|
self.norm = LayerNorm(dim, eps=norm_eps) |
|
self.mlp = nn.Sequential( |
|
nn.Linear(dim, int(dim * mlp_ratio)), |
|
QuickGELU() if activation == "quick_gelu" else nn.GELU(), |
|
nn.Linear(int(dim * mlp_ratio), dim), |
|
nn.Dropout(proj_dropout), |
|
) |
|
|
|
def forward(self, x): |
|
""" |
|
x: [B, L, C]. |
|
""" |
|
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim |
|
|
|
|
|
q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1) |
|
k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2) |
|
|
|
|
|
|
|
x = flash_attention(q, k, v, version=2) |
|
x = x.reshape(b, 1, c) |
|
|
|
|
|
x = self.proj(x) |
|
x = F.dropout(x, self.proj_dropout, self.training) |
|
|
|
|
|
x = x + self.mlp(self.norm(x)) |
|
return x[:, 0] |
|
|
|
|
|
class VisionTransformer(nn.Module): |
|
|
|
def __init__( |
|
self, |
|
image_size=224, |
|
patch_size=16, |
|
dim=768, |
|
mlp_ratio=4, |
|
out_dim=512, |
|
num_heads=12, |
|
num_layers=12, |
|
pool_type="token", |
|
pre_norm=True, |
|
post_norm=False, |
|
activation="quick_gelu", |
|
attn_dropout=0.0, |
|
proj_dropout=0.0, |
|
embedding_dropout=0.0, |
|
norm_eps=1e-5, |
|
): |
|
if image_size % patch_size != 0: |
|
print("[WARNING] image_size is not divisible by patch_size", flush=True) |
|
assert pool_type in ("token", "token_fc", "attn_pool") |
|
out_dim = out_dim or dim |
|
super().__init__() |
|
self.image_size = image_size |
|
self.patch_size = patch_size |
|
self.num_patches = (image_size // patch_size) ** 2 |
|
self.dim = dim |
|
self.mlp_ratio = mlp_ratio |
|
self.out_dim = out_dim |
|
self.num_heads = num_heads |
|
self.num_layers = num_layers |
|
self.pool_type = pool_type |
|
self.post_norm = post_norm |
|
self.norm_eps = norm_eps |
|
|
|
|
|
gain = 1.0 / math.sqrt(dim) |
|
self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm) |
|
if pool_type in ("token", "token_fc"): |
|
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) |
|
self.pos_embedding = nn.Parameter( |
|
gain * torch.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim) |
|
) |
|
self.dropout = nn.Dropout(embedding_dropout) |
|
|
|
|
|
self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None |
|
self.transformer = nn.Sequential( |
|
*[ |
|
AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps) |
|
for _ in range(num_layers) |
|
] |
|
) |
|
self.post_norm = LayerNorm(dim, eps=norm_eps) |
|
|
|
|
|
if pool_type == "token": |
|
self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) |
|
elif pool_type == "token_fc": |
|
self.head = nn.Linear(dim, out_dim) |
|
elif pool_type == "attn_pool": |
|
self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps) |
|
|
|
def forward(self, x, interpolation=False, use_31_block=False): |
|
b = x.size(0) |
|
|
|
|
|
x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) |
|
if self.pool_type in ("token", "token_fc"): |
|
x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1) |
|
if interpolation: |
|
e = pos_interpolate(self.pos_embedding, x.size(1)) |
|
else: |
|
e = self.pos_embedding |
|
x = self.dropout(x + e) |
|
if self.pre_norm is not None: |
|
x = self.pre_norm(x) |
|
|
|
|
|
if use_31_block: |
|
x = self.transformer[:-1](x) |
|
return x |
|
else: |
|
x = self.transformer(x) |
|
return x |
|
|
|
|
|
class XLMRobertaWithHead(XLMRoberta): |
|
|
|
def __init__(self, **kwargs): |
|
self.out_dim = kwargs.pop("out_dim") |
|
super().__init__(**kwargs) |
|
|
|
|
|
mid_dim = (self.dim + self.out_dim) // 2 |
|
self.head = nn.Sequential(nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), nn.Linear(mid_dim, self.out_dim, bias=False)) |
|
|
|
def forward(self, ids): |
|
|
|
x = super().forward(ids) |
|
|
|
|
|
mask = ids.ne(self.pad_id).unsqueeze(-1).to(x) |
|
x = (x * mask).sum(dim=1) / mask.sum(dim=1) |
|
|
|
|
|
x = self.head(x) |
|
return x |
|
|
|
|
|
class XLMRobertaCLIP(nn.Module): |
|
|
|
def __init__( |
|
self, |
|
embed_dim=1024, |
|
image_size=224, |
|
patch_size=14, |
|
vision_dim=1280, |
|
vision_mlp_ratio=4, |
|
vision_heads=16, |
|
vision_layers=32, |
|
vision_pool="token", |
|
vision_pre_norm=True, |
|
vision_post_norm=False, |
|
activation="gelu", |
|
vocab_size=250002, |
|
max_text_len=514, |
|
type_size=1, |
|
pad_id=1, |
|
text_dim=1024, |
|
text_heads=16, |
|
text_layers=24, |
|
text_post_norm=True, |
|
text_dropout=0.1, |
|
attn_dropout=0.0, |
|
proj_dropout=0.0, |
|
embedding_dropout=0.0, |
|
norm_eps=1e-5, |
|
): |
|
super().__init__() |
|
self.embed_dim = embed_dim |
|
self.image_size = image_size |
|
self.patch_size = patch_size |
|
self.vision_dim = vision_dim |
|
self.vision_mlp_ratio = vision_mlp_ratio |
|
self.vision_heads = vision_heads |
|
self.vision_layers = vision_layers |
|
self.vision_pre_norm = vision_pre_norm |
|
self.vision_post_norm = vision_post_norm |
|
self.activation = activation |
|
self.vocab_size = vocab_size |
|
self.max_text_len = max_text_len |
|
self.type_size = type_size |
|
self.pad_id = pad_id |
|
self.text_dim = text_dim |
|
self.text_heads = text_heads |
|
self.text_layers = text_layers |
|
self.text_post_norm = text_post_norm |
|
self.norm_eps = norm_eps |
|
|
|
|
|
self.visual = VisionTransformer( |
|
image_size=image_size, |
|
patch_size=patch_size, |
|
dim=vision_dim, |
|
mlp_ratio=vision_mlp_ratio, |
|
out_dim=embed_dim, |
|
num_heads=vision_heads, |
|
num_layers=vision_layers, |
|
pool_type=vision_pool, |
|
pre_norm=vision_pre_norm, |
|
post_norm=vision_post_norm, |
|
activation=activation, |
|
attn_dropout=attn_dropout, |
|
proj_dropout=proj_dropout, |
|
embedding_dropout=embedding_dropout, |
|
norm_eps=norm_eps, |
|
) |
|
self.textual = XLMRobertaWithHead( |
|
vocab_size=vocab_size, |
|
max_seq_len=max_text_len, |
|
type_size=type_size, |
|
pad_id=pad_id, |
|
dim=text_dim, |
|
out_dim=embed_dim, |
|
num_heads=text_heads, |
|
num_layers=text_layers, |
|
post_norm=text_post_norm, |
|
dropout=text_dropout, |
|
) |
|
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) |
|
|
|
def forward(self, imgs, txt_ids): |
|
""" |
|
imgs: [B, 3, H, W] of torch.float32. |
|
- mean: [0.48145466, 0.4578275, 0.40821073] |
|
- std: [0.26862954, 0.26130258, 0.27577711] |
|
txt_ids: [B, L] of torch.long. |
|
Encoded by data.CLIPTokenizer. |
|
""" |
|
xi = self.visual(imgs) |
|
xt = self.textual(txt_ids) |
|
return xi, xt |
|
|
|
def param_groups(self): |
|
groups = [ |
|
{"params": [p for n, p in self.named_parameters() if "norm" in n or n.endswith("bias")], "weight_decay": 0.0}, |
|
{"params": [p for n, p in self.named_parameters() if not ("norm" in n or n.endswith("bias"))]}, |
|
] |
|
return groups |
|
|
|
|
|
def _clip( |
|
pretrained=False, |
|
pretrained_name=None, |
|
model_cls=XLMRobertaCLIP, |
|
return_transforms=False, |
|
return_tokenizer=False, |
|
tokenizer_padding="eos", |
|
dtype=torch.float32, |
|
device="cpu", |
|
**kwargs, |
|
): |
|
|
|
|
|
model = model_cls(**kwargs) |
|
|
|
|
|
|
|
output = (model,) |
|
|
|
|
|
if return_transforms: |
|
|
|
if "siglip" in pretrained_name.lower(): |
|
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] |
|
else: |
|
mean = [0.48145466, 0.4578275, 0.40821073] |
|
std = [0.26862954, 0.26130258, 0.27577711] |
|
|
|
|
|
transforms = T.Compose( |
|
[ |
|
T.Resize((model.image_size, model.image_size), interpolation=T.InterpolationMode.BICUBIC), |
|
T.ToTensor(), |
|
T.Normalize(mean=mean, std=std), |
|
] |
|
) |
|
output += (transforms,) |
|
return output[0] if len(output) == 1 else output |
|
|
|
|
|
def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-roberta-large-vit-huge-14", **kwargs): |
|
cfg = dict( |
|
embed_dim=1024, |
|
image_size=224, |
|
patch_size=14, |
|
vision_dim=1280, |
|
vision_mlp_ratio=4, |
|
vision_heads=16, |
|
vision_layers=32, |
|
vision_pool="token", |
|
activation="gelu", |
|
vocab_size=250002, |
|
max_text_len=514, |
|
type_size=1, |
|
pad_id=1, |
|
text_dim=1024, |
|
text_heads=16, |
|
text_layers=24, |
|
text_post_norm=True, |
|
text_dropout=0.1, |
|
attn_dropout=0.0, |
|
proj_dropout=0.0, |
|
embedding_dropout=0.0, |
|
) |
|
cfg.update(**kwargs) |
|
return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg) |
|
|
|
|
|
class CLIPModel: |
|
|
|
def __init__(self, dtype, device, checkpoint_path=None, tokenizer_path=None, weight_path=None): |
|
self.dtype = dtype |
|
self.device = device |
|
self.checkpoint_path = checkpoint_path |
|
self.tokenizer_path = tokenizer_path |
|
self.weight_path = weight_path |
|
|
|
|
|
with init_empty_weights(): |
|
self.model, self.transforms = clip_xlm_roberta_vit_h_14( |
|
pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device |
|
) |
|
self.model = self.model.eval().requires_grad_(False) |
|
|
|
logging.info(f"loading {weight_path}") |
|
if os.path.splitext(weight_path)[-1] == ".safetensors": |
|
sd = load_safetensors(weight_path, device=device, disable_mmap=True, dtype=dtype) |
|
else: |
|
sd = torch.load(weight_path, map_location=device, weights_only=True) |
|
info = self.model.load_state_dict(sd, strict=True, assign=True) |
|
self.model = self.model.to(dtype=dtype, device=device) |
|
logging.info(f"weights loaded from {weight_path}: {info}") |
|
|
|
|
|
if tokenizer_path is None: |
|
tokenizer_path = "Wan-AI/Wan2.1-I2V-14B-720P" |
|
subfolder = "xlm-roberta-large" |
|
else: |
|
subfolder = None |
|
|
|
self.tokenizer = HuggingfaceTokenizer( |
|
name=tokenizer_path, seq_len=self.model.max_text_len - 2, clean="whitespace", subfolder=subfolder |
|
) |
|
|
|
def visual(self, videos): |
|
|
|
size = (self.model.image_size,) * 2 |
|
videos = torch.cat([F.interpolate(u.transpose(0, 1), size=size, mode="bicubic", align_corners=False) for u in videos]) |
|
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5)) |
|
|
|
|
|
|
|
out = self.model.visual(videos, use_31_block=True) |
|
return out |
|
|