import gzip
import html
import io
import logging
import math
import os
from functools import lru_cache
from functools import partial
from types import SimpleNamespace
from typing import Callable, List
from typing import Optional
import einops
import ftfy
import numpy as np
import regex as re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import torchaudio
import torchvision.transforms as T
from PIL import Image
from timm.models.layers import DropPath, trunc_normal_
from torchvision import transforms
import matplotlib.pyplot as plt
from iopath.common.file_io import g_pathmgr
class Attention(nn.Module): | |
def __init__( | |
self, | |
dim, | |
num_heads=8, | |
qkv_bias=False, | |
qk_scale=None, | |
attn_drop=0.0, | |
proj_drop=0.0, | |
): | |
super().__init__() | |
self.num_heads = num_heads | |
head_dim = dim // num_heads | |
# NOTE scale factor was wrong in my original version, | |
# can set manually to be compat with prev weights | |
self.scale = qk_scale or head_dim ** -0.5 | |
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) | |
self.attn_drop = nn.Dropout(attn_drop) | |
self.proj = nn.Linear(dim, dim) | |
self.proj_drop = nn.Dropout(proj_drop) | |
def forward(self, x): | |
B, N, C = x.shape | |
qkv = ( | |
self.qkv(x) | |
.reshape(B, N, 3, self.num_heads, C // self.num_heads) | |
.permute(2, 0, 3, 1, 4) | |
) | |
q, k, v = ( | |
qkv[0], | |
qkv[1], | |
qkv[2], | |
) # make torchscript happy (cannot use tensor as tuple) | |
attn = (q @ k.transpose(-2, -1)) * self.scale | |
attn = attn.softmax(dim=-1) | |
attn = self.attn_drop(attn) | |
x = (attn @ v).transpose(1, 2).reshape(B, N, C) | |
x = self.proj(x) | |
x = self.proj_drop(x) | |
return x | |
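# Illustrative usage sketch (not part of the original module): the attention block maps
# (B, N, C) -> (B, N, C); `dim` must be divisible by `num_heads`. Values below are assumed.
#   attn = Attention(dim=768, num_heads=12, qkv_bias=True)
#   x = torch.randn(2, 197, 768)   # batch of 2, 197 tokens, 768-dim
#   y = attn(x)                    # -> torch.Size([2, 197, 768])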
class Mlp(nn.Module): | |
def __init__( | |
self, | |
in_features, | |
hidden_features=None, | |
out_features=None, | |
act_layer=nn.GELU, | |
drop=0.0, | |
): | |
super().__init__() | |
out_features = out_features or in_features | |
hidden_features = hidden_features or in_features | |
self.fc1 = nn.Linear(in_features, hidden_features) | |
self.act = act_layer() | |
self.fc2 = nn.Linear(hidden_features, out_features) | |
self.drop = nn.Dropout(drop) | |
def forward(self, x): | |
x = self.fc1(x) | |
x = self.act(x) | |
x = self.drop(x) | |
x = self.fc2(x) | |
x = self.drop(x) | |
return x | |
class MultiheadAttention(nn.MultiheadAttention): | |
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): | |
return super().forward(x, x, x, need_weights=False, attn_mask=attn_mask)[0] | |
class ViTAttention(Attention): | |
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): | |
assert attn_mask is None | |
return super().forward(x) | |
class BlockWithMasking(nn.Module): | |
def __init__( | |
self, | |
dim: int, | |
attn_target: Callable, | |
mlp_ratio: int = 4, | |
act_layer: Callable = nn.GELU, | |
norm_layer: Callable = nn.LayerNorm, | |
ffn_dropout_rate: float = 0.0, | |
drop_path: float = 0.0, | |
layer_scale_type: str = None, | |
layer_scale_init_value: float = 1e-4, | |
): | |
super().__init__() | |
assert not isinstance( | |
attn_target, nn.Module | |
), "attn_target should be a Callable. Otherwise attn_target is shared across blocks!" | |
self.attn = attn_target() | |
if drop_path > 0.0: | |
self.drop_path = DropPath(drop_path) | |
else: | |
self.drop_path = nn.Identity() | |
self.norm_1 = norm_layer(dim) | |
mlp_hidden_dim = int(mlp_ratio * dim) | |
self.mlp = Mlp( | |
in_features=dim, | |
hidden_features=mlp_hidden_dim, | |
act_layer=act_layer, | |
drop=ffn_dropout_rate, | |
) | |
self.norm_2 = norm_layer(dim) | |
self.layer_scale_type = layer_scale_type | |
if self.layer_scale_type is not None: | |
assert self.layer_scale_type in [ | |
"per_channel", | |
"scalar", | |
], f"Found Layer scale type {self.layer_scale_type}" | |
if self.layer_scale_type == "per_channel": | |
# one gamma value per channel | |
gamma_shape = [1, 1, dim] | |
elif self.layer_scale_type == "scalar": | |
# single gamma value for all channels | |
gamma_shape = [1, 1, 1] | |
# two gammas: for each part of the fwd in the encoder | |
self.layer_scale_gamma1 = nn.Parameter( | |
torch.ones(size=gamma_shape) * layer_scale_init_value, | |
requires_grad=True, | |
) | |
self.layer_scale_gamma2 = nn.Parameter( | |
torch.ones(size=gamma_shape) * layer_scale_init_value, | |
requires_grad=True, | |
) | |
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): | |
if self.layer_scale_type is None: | |
x = x + self.drop_path(self.attn(self.norm_1(x), attn_mask)) | |
x = x + self.drop_path(self.mlp(self.norm_2(x))) | |
else: | |
x = ( | |
x | |
+ self.drop_path(self.attn(self.norm_1(x), attn_mask)) | |
* self.layer_scale_gamma1 | |
) | |
x = x + self.drop_path(self.mlp(self.norm_2(x))) * self.layer_scale_gamma2 | |
return x | |
_LAYER_NORM = partial(nn.LayerNorm, eps=1e-6) | |
class SimpleTransformer(nn.Module): | |
def __init__( | |
self, | |
attn_target: Callable, | |
embed_dim: int, | |
num_blocks: int, | |
block: Callable = BlockWithMasking, | |
pre_transformer_layer: Callable = None, | |
post_transformer_layer: Callable = None, | |
drop_path_rate: float = 0.0, | |
drop_path_type: str = "progressive", | |
norm_layer: Callable = _LAYER_NORM, | |
mlp_ratio: int = 4, | |
ffn_dropout_rate: float = 0.0, | |
layer_scale_type: str = None, # from cait; possible values are None, "per_channel", "scalar" | |
layer_scale_init_value: float = 1e-4, # from cait; float | |
weight_init_style: str = "jax", # possible values jax or pytorch | |
): | |
""" | |
Simple Transformer with the following features | |
1. Supports masked attention | |
2. Supports DropPath | |
3. Supports LayerScale | |
4. Supports Dropout in Attention and FFN | |
5. Makes few assumptions about the input except that it is a Tensor | |
""" | |
super().__init__() | |
self.pre_transformer_layer = pre_transformer_layer | |
if drop_path_type == "progressive": | |
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)] | |
elif drop_path_type == "uniform": | |
dpr = [drop_path_rate for i in range(num_blocks)] | |
else: | |
raise ValueError(f"Unknown drop_path_type: {drop_path_type}") | |
self.blocks = nn.Sequential( | |
*[ | |
block( | |
dim=embed_dim, | |
attn_target=attn_target, | |
mlp_ratio=mlp_ratio, | |
ffn_dropout_rate=ffn_dropout_rate, | |
drop_path=dpr[i], | |
norm_layer=norm_layer, | |
layer_scale_type=layer_scale_type, | |
layer_scale_init_value=layer_scale_init_value, | |
) | |
for i in range(num_blocks) | |
] | |
) | |
self.post_transformer_layer = post_transformer_layer | |
self.weight_init_style = weight_init_style | |
self.apply(self._init_weights) | |
def _init_weights(self, m): | |
if isinstance(m, nn.Linear): | |
if self.weight_init_style == "jax": | |
# Based on MAE and official Jax ViT implementation | |
torch.nn.init.xavier_uniform_(m.weight) | |
elif self.weight_init_style == "pytorch": | |
# PyTorch ViT uses trunc_normal_ | |
trunc_normal_(m.weight, std=0.02) | |
if m.bias is not None: | |
nn.init.constant_(m.bias, 0) | |
elif isinstance(m, (nn.LayerNorm)): | |
nn.init.constant_(m.bias, 0) | |
nn.init.constant_(m.weight, 1.0) | |
def forward( | |
self, | |
tokens: torch.Tensor, | |
attn_mask: torch.Tensor = None, | |
use_checkpoint: bool = False, | |
checkpoint_every_n: int = 1, | |
checkpoint_blk_ids: List[int] = None, | |
): | |
""" | |
Inputs | |
- tokens: data of shape N x L x D (or L x N x D depending on the attention implementation) | |
- attn: mask of shape L x L | |
Output | |
- x: data of shape N x L x D (or L x N x D depending on the attention implementation) | |
""" | |
if self.pre_transformer_layer: | |
tokens = self.pre_transformer_layer(tokens) | |
if use_checkpoint and checkpoint_blk_ids is None: | |
checkpoint_blk_ids = [ | |
blk_id | |
for blk_id in range(len(self.blocks)) | |
if blk_id % checkpoint_every_n == 0 | |
] | |
if checkpoint_blk_ids: | |
checkpoint_blk_ids = set(checkpoint_blk_ids) | |
for blk_id, blk in enumerate(self.blocks): | |
if use_checkpoint and blk_id in checkpoint_blk_ids: | |
tokens = checkpoint.checkpoint( | |
blk, tokens, attn_mask, use_reentrant=False | |
) | |
else: | |
tokens = blk(tokens, attn_mask=attn_mask) | |
if self.post_transformer_layer: | |
tokens = self.post_transformer_layer(tokens) | |
return tokens | |
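# Illustrative usage sketch (assumed sizes, not from the original code): build a small trunk
# with ViTAttention; tokens of shape (B, L, D) pass through with their shape unchanged.
#   trunk = SimpleTransformer(
#       attn_target=partial(ViTAttention, dim=512, num_heads=8, qkv_bias=True),
#       embed_dim=512,
#       num_blocks=2,
#   )
#   out = trunk(torch.randn(2, 16, 512))   # -> torch.Size([2, 16, 512])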
def get_sinusoid_encoding_table(n_position, d_hid): | |
"""Sinusoid position encoding table""" | |
# TODO: make it with torch instead of numpy | |
def get_position_angle_vec(position): | |
return [ | |
position / np.power(10000, 2 * (hid_j // 2) / d_hid) | |
for hid_j in range(d_hid) | |
] | |
sinusoid_table = np.array( | |
[get_position_angle_vec(pos_i) for pos_i in range(n_position)] | |
) | |
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i | |
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 | |
return torch.FloatTensor(sinusoid_table).unsqueeze(0) | |
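# Example shape (sketch, assumed sizes): get_sinusoid_encoding_table(197, 768) returns a fixed
# table of shape (1, 197, 768) -- one sin/cos positional vector per position, with a leading
# batch dimension so it can be added directly to token embeddings.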
def interpolate_pos_encoding_2d(target_spatial_size, pos_embed): | |
N = pos_embed.shape[1] | |
if N == target_spatial_size: | |
return pos_embed | |
dim = pos_embed.shape[-1] | |
# nn.functional.interpolate doesn't work with bfloat16 so we cast to float32 | |
pos_embed, updated = cast_if_src_dtype(pos_embed, torch.bfloat16, torch.float32) | |
pos_embed = nn.functional.interpolate( | |
pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute( | |
0, 3, 1, 2 | |
), | |
scale_factor=math.sqrt(target_spatial_size / N), | |
mode="bicubic", | |
) | |
if updated: | |
pos_embed, _ = cast_if_src_dtype(pos_embed, torch.float32, torch.bfloat16) | |
pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) | |
return pos_embed | |
def interpolate_pos_encoding( | |
npatch_per_img, | |
pos_embed, | |
patches_layout, | |
input_shape=None, | |
first_patch_idx=1, | |
): | |
assert first_patch_idx == 0 or first_patch_idx == 1, "there is 1 CLS token or none" | |
N = pos_embed.shape[1] - first_patch_idx # since it's 1 if cls_token exists | |
if npatch_per_img == N: | |
return pos_embed | |
# assert ( | |
# patches_layout[-1] == patches_layout[-2] | |
# ), "Interpolation of pos embed not supported for non-square layouts" | |
class_emb = pos_embed[:, :first_patch_idx] | |
pos_embed = pos_embed[:, first_patch_idx:] | |
if input_shape is None or patches_layout[0] == 1: | |
# simple 2D pos embedding, no temporal component | |
pos_embed = interpolate_pos_encoding_2d(npatch_per_img, pos_embed) | |
elif patches_layout[0] > 1: | |
# pos embed has a temporal component | |
assert len(input_shape) == 4, "temporal interpolation not supported" | |
# we only support 2D interpolation in this case | |
num_frames = patches_layout[0] | |
num_spatial_tokens = patches_layout[1] * patches_layout[2] | |
pos_embed = pos_embed.view(1, num_frames, num_spatial_tokens, -1) | |
# interpolate embedding for zeroth frame | |
pos_embed = interpolate_pos_encoding_2d( | |
npatch_per_img, pos_embed[0, 0, ...].unsqueeze(0) | |
) | |
else: | |
raise ValueError("This type of interpolation isn't implemented") | |
return torch.cat((class_emb, pos_embed), dim=1) | |
def _get_pos_embedding( | |
npatch_per_img, | |
pos_embed, | |
patches_layout, | |
input_shape, | |
first_patch_idx=1, | |
): | |
pos_embed = interpolate_pos_encoding( | |
npatch_per_img, | |
pos_embed, | |
patches_layout, | |
input_shape=input_shape, | |
first_patch_idx=first_patch_idx, | |
) | |
return pos_embed | |
class VerboseNNModule(nn.Module): | |
""" | |
Wrapper around nn.Module that prints registered buffers and parameter names. | |
""" | |
    @staticmethod
    def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str:
st = ( | |
"(" | |
+ name | |
+ "): " | |
+ "tensor(" | |
+ str(tuple(tensor[1].shape)) | |
+ ", requires_grad=" | |
+ str(tensor[1].requires_grad) | |
+ ")\n" | |
) | |
return st | |
def extra_repr(self) -> str: | |
named_modules = set() | |
for p in self.named_modules(): | |
named_modules.update([p[0]]) | |
named_modules = list(named_modules) | |
string_repr = "" | |
for p in self.named_parameters(): | |
name = p[0].split(".")[0] | |
if name not in named_modules: | |
string_repr += self.get_readable_tensor_repr(name, p) | |
for p in self.named_buffers(): | |
name = p[0].split(".")[0] | |
string_repr += self.get_readable_tensor_repr(name, p) | |
return string_repr | |
class PatchEmbedGeneric(nn.Module): | |
""" | |
PatchEmbed from Hydra | |
""" | |
def __init__(self, proj_stem, norm_layer: Optional[nn.Module] = None): | |
super().__init__() | |
if len(proj_stem) > 1: | |
self.proj = nn.Sequential(*proj_stem) | |
else: | |
# Special case to be able to load pre-trained models that were | |
# trained with a standard stem | |
self.proj = proj_stem[0] | |
self.norm_layer = norm_layer | |
def get_patch_layout(self, img_size): | |
with torch.no_grad(): | |
dummy_img = torch.zeros( | |
[ | |
1, | |
] | |
+ img_size | |
) | |
dummy_out = self.proj(dummy_img) | |
embed_dim = dummy_out.shape[1] | |
patches_layout = tuple(dummy_out.shape[2:]) | |
num_patches = np.prod(patches_layout) | |
return patches_layout, num_patches, embed_dim | |
def forward(self, x): | |
x = self.proj(x) | |
# B C (T) H W -> B (T)HW C | |
x = x.flatten(2).transpose(1, 2) | |
if self.norm_layer is not None: | |
x = self.norm_layer(x) | |
return x | |
class SpatioTemporalPosEmbeddingHelper(VerboseNNModule): | |
def __init__( | |
self, | |
patches_layout: List, | |
num_patches: int, | |
num_cls_tokens: int, | |
embed_dim: int, | |
learnable: bool, | |
) -> None: | |
super().__init__() | |
self.num_cls_tokens = num_cls_tokens | |
self.patches_layout = patches_layout | |
self.num_patches = num_patches | |
self.num_tokens = num_cls_tokens + num_patches | |
self.learnable = learnable | |
if self.learnable: | |
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim)) | |
trunc_normal_(self.pos_embed, std=0.02) | |
else: | |
self.register_buffer( | |
"pos_embed", get_sinusoid_encoding_table(self.num_tokens, embed_dim) | |
) | |
def get_pos_embedding(self, vision_input, all_vision_tokens): | |
input_shape = vision_input.shape | |
pos_embed = _get_pos_embedding( | |
all_vision_tokens.size(1) - self.num_cls_tokens, | |
pos_embed=self.pos_embed, | |
patches_layout=self.patches_layout, | |
input_shape=input_shape, | |
first_patch_idx=self.num_cls_tokens, | |
) | |
return pos_embed | |
class RGBDTPreprocessor(VerboseNNModule): | |
def __init__( | |
self, | |
rgbt_stem: PatchEmbedGeneric, | |
depth_stem: PatchEmbedGeneric, | |
img_size: List = (3, 224, 224), | |
num_cls_tokens: int = 1, | |
pos_embed_fn: Callable = None, | |
use_type_embed: bool = False, | |
init_param_style: str = "openclip", | |
) -> None: | |
super().__init__() | |
stem = rgbt_stem if rgbt_stem is not None else depth_stem | |
( | |
self.patches_layout, | |
self.num_patches, | |
self.embed_dim, | |
) = stem.get_patch_layout(img_size) | |
self.rgbt_stem = rgbt_stem | |
self.depth_stem = depth_stem | |
self.use_pos_embed = pos_embed_fn is not None | |
self.use_type_embed = use_type_embed | |
self.num_cls_tokens = num_cls_tokens | |
if self.use_pos_embed: | |
self.pos_embedding_helper = pos_embed_fn( | |
patches_layout=self.patches_layout, | |
num_cls_tokens=num_cls_tokens, | |
num_patches=self.num_patches, | |
embed_dim=self.embed_dim, | |
) | |
if self.num_cls_tokens > 0: | |
self.cls_token = nn.Parameter( | |
torch.zeros(1, self.num_cls_tokens, self.embed_dim) | |
) | |
if self.use_type_embed: | |
self.type_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) | |
self.init_parameters(init_param_style) | |
def init_parameters(self, init_param_style): | |
if init_param_style == "openclip": | |
# OpenCLIP style initialization | |
scale = self.embed_dim ** -0.5 | |
if self.use_pos_embed: | |
nn.init.normal_(self.pos_embedding_helper.pos_embed) | |
self.pos_embedding_helper.pos_embed *= scale | |
if self.num_cls_tokens > 0: | |
nn.init.normal_(self.cls_token) | |
self.cls_token *= scale | |
elif init_param_style == "vit": | |
self.cls_token.data.fill_(0) | |
else: | |
raise ValueError(f"Unknown init {init_param_style}") | |
if self.use_type_embed: | |
nn.init.normal_(self.type_embed) | |
def get_pos_emb_2(self, input, stem): | |
patches = stem.proj(input) | |
target_size = patches.shape[-2:] | |
original_size = list(self.pos_embedding_helper.patches_layout)[-2:] | |
orig_ce = self.pos_embedding_helper.pos_embed[:, 0, :] | |
orig_pe = ((self.pos_embedding_helper.pos_embed[:, 1:, :] | |
.reshape(1, *original_size, self.embed_dim)) | |
.permute(0, 3, 1, 2)) | |
new_pe = F.interpolate(orig_pe, size=target_size, mode="bicubic") | |
new_full_pe = torch.cat([orig_ce.unsqueeze(1), new_pe.permute(0, 2, 3, 1).reshape(1, -1, self.embed_dim)], | |
dim=1) | |
return new_full_pe | |
def tokenize_input_and_cls_pos(self, input, stem, mask): | |
# tokens is of shape B x L x D | |
tokens = stem(input) | |
assert tokens.ndim == 3 | |
assert tokens.shape[2] == self.embed_dim | |
B = tokens.shape[0] | |
if self.num_cls_tokens > 0: | |
class_tokens = self.cls_token.expand( | |
B, -1, -1 | |
) # stole class_tokens impl from Phil Wang, thanks | |
tokens = torch.cat((class_tokens, tokens), dim=1) | |
if self.use_pos_embed: | |
pos_embed = self.pos_embedding_helper.get_pos_embedding(input, tokens) | |
# pos_embed = self.get_pos_emb_2(input, stem) | |
tokens = tokens + pos_embed | |
if self.use_type_embed: | |
tokens = tokens + self.type_embed.expand(B, -1, -1) | |
return tokens | |
def forward(self, vision=None, depth=None, patch_mask=None): | |
if patch_mask is not None: | |
raise NotImplementedError() | |
if vision is not None: | |
vision_tokens = self.tokenize_input_and_cls_pos( | |
vision, self.rgbt_stem, patch_mask | |
) | |
if depth is not None: | |
depth_tokens = self.tokenize_input_and_cls_pos( | |
depth, self.depth_stem, patch_mask | |
) | |
# aggregate tokens | |
if vision is not None and depth is not None: | |
final_tokens = vision_tokens + depth_tokens | |
else: | |
final_tokens = vision_tokens if vision is not None else depth_tokens | |
return_dict = { | |
"trunk": { | |
"tokens": final_tokens, | |
}, | |
"head": {}, | |
} | |
return return_dict | |
class AudioPreprocessor(RGBDTPreprocessor): | |
def __init__(self, audio_stem: PatchEmbedGeneric, **kwargs) -> None: | |
super().__init__(rgbt_stem=audio_stem, depth_stem=None, **kwargs) | |
def forward(self, audio=None): | |
return super().forward(vision=audio) | |
class ThermalPreprocessor(RGBDTPreprocessor): | |
def __init__(self, thermal_stem: PatchEmbedGeneric, **kwargs) -> None: | |
super().__init__(rgbt_stem=thermal_stem, depth_stem=None, **kwargs) | |
def forward(self, thermal=None): | |
return super().forward(vision=thermal) | |
def build_causal_attention_mask(context_length): | |
# lazily create causal attention mask, with full attention between the vision tokens | |
# pytorch uses additive attention mask; fill with -inf | |
mask = torch.empty(context_length, context_length, requires_grad=False) | |
mask.fill_(float("-inf")) | |
    mask.triu_(1)  # keep -inf strictly above the diagonal; zero on and below it
return mask | |
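# Example (sketch): build_causal_attention_mask(4) produces an additive mask in which position i
# may only attend to positions <= i:
#   tensor([[0., -inf, -inf, -inf],
#           [0.,   0., -inf, -inf],
#           [0.,   0.,   0., -inf],
#           [0.,   0.,   0.,   0.]])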
class TextPreprocessor(VerboseNNModule): | |
def __init__( | |
self, | |
vocab_size: int, | |
context_length: int, | |
embed_dim: int, | |
causal_masking: bool, | |
supply_seq_len_to_head: bool = True, | |
num_cls_tokens: int = 0, | |
init_param_style: str = "openclip", | |
) -> None: | |
super().__init__() | |
self.vocab_size = vocab_size | |
self.context_length = context_length | |
self.token_embedding = nn.Embedding(vocab_size, embed_dim) | |
self.pos_embed = nn.Parameter( | |
torch.empty(1, self.context_length + num_cls_tokens, embed_dim) | |
) | |
self.causal_masking = causal_masking | |
if self.causal_masking: | |
mask = build_causal_attention_mask(self.context_length) | |
# register the mask as a buffer, so it can be moved to the right device | |
self.register_buffer("mask", mask) | |
self.supply_seq_len_to_head = supply_seq_len_to_head | |
self.num_cls_tokens = num_cls_tokens | |
self.embed_dim = embed_dim | |
if num_cls_tokens > 0: | |
assert self.causal_masking is False, "Masking + CLS token isn't implemented" | |
self.cls_token = nn.Parameter( | |
torch.zeros(1, self.num_cls_tokens, embed_dim) | |
) | |
self.init_parameters(init_param_style) | |
def init_parameters(self, init_param_style="openclip"): | |
# OpenCLIP style initialization | |
nn.init.normal_(self.token_embedding.weight, std=0.02) | |
nn.init.normal_(self.pos_embed, std=0.01) | |
if init_param_style == "openclip": | |
# OpenCLIP style initialization | |
scale = self.embed_dim ** -0.5 | |
if self.num_cls_tokens > 0: | |
nn.init.normal_(self.cls_token) | |
self.cls_token *= scale | |
elif init_param_style == "vit": | |
self.cls_token.data.fill_(0) | |
else: | |
raise ValueError(f"Unknown init {init_param_style}") | |
def forward(self, text): | |
# text tokens are of shape B x L x D | |
text_tokens = self.token_embedding(text) | |
# concat CLS tokens if any | |
if self.num_cls_tokens > 0: | |
B = text_tokens.shape[0] | |
class_tokens = self.cls_token.expand( | |
B, -1, -1 | |
) # stole class_tokens impl from Phil Wang, thanks | |
text_tokens = torch.cat((class_tokens, text_tokens), dim=1) | |
text_tokens = text_tokens + self.pos_embed | |
return_dict = { | |
"trunk": { | |
"tokens": text_tokens, | |
}, | |
"head": {}, | |
} | |
        # The EOT token has the highest id in each sequence; its index serves as the sequence length
if self.supply_seq_len_to_head: | |
text_lengths = text.argmax(dim=-1) | |
return_dict["head"] = { | |
"seq_len": text_lengths, | |
} | |
if self.causal_masking: | |
return_dict["trunk"].update({"attn_mask": self.mask}) | |
return return_dict | |
class Im2Video(nn.Module): | |
"""Convert an image into a trivial video.""" | |
def __init__(self, time_dim=2): | |
super().__init__() | |
self.time_dim = time_dim | |
def forward(self, x): | |
if x.ndim == 4: | |
# B, C, H, W -> B, C, T, H, W | |
return x.unsqueeze(self.time_dim) | |
elif x.ndim == 5: | |
return x | |
else: | |
raise ValueError(f"Dimension incorrect {x.shape}") | |
class PadIm2Video(Im2Video): | |
def __init__(self, ntimes, pad_type, time_dim=2): | |
super().__init__(time_dim=time_dim) | |
assert ntimes > 0 | |
assert pad_type in ["zero", "repeat"] | |
self.ntimes = ntimes | |
self.pad_type = pad_type | |
def forward(self, x): | |
x = super().forward(x) | |
if x.shape[self.time_dim] == 1: | |
if self.pad_type == "repeat": | |
new_shape = [1] * len(x.shape) | |
new_shape[self.time_dim] = self.ntimes | |
x = x.repeat(new_shape) | |
elif self.pad_type == "zero": | |
padarg = [0, 0] * len(x.shape) | |
padarg[2 * self.time_dim + 1] = self.ntimes - x.shape[self.time_dim] | |
x = nn.functional.pad(x, padarg) | |
return x | |
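# Illustrative sketch (assumed sizes): PadIm2Video repeats a single image along a new time axis
# so image and video inputs can share one Conv3d stem.
#   pad = PadIm2Video(ntimes=2, pad_type="repeat")
#   pad(torch.randn(1, 3, 224, 224)).shape   # -> torch.Size([1, 3, 2, 224, 224])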
# Modified from github.com/openai/CLIP | |
@lru_cache()
def bytes_to_unicode():
    """
    Returns a list of utf-8 bytes and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large number of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings,
    while avoiding mapping to whitespace/control characters the bpe code barfs on.
    """
bs = ( | |
list(range(ord("!"), ord("~") + 1)) | |
+ list(range(ord("¡"), ord("¬") + 1)) | |
+ list(range(ord("®"), ord("ÿ") + 1)) | |
) | |
cs = bs[:] | |
n = 0 | |
for b in range(2 ** 8): | |
if b not in bs: | |
bs.append(b) | |
cs.append(2 ** 8 + n) | |
n += 1 | |
cs = [chr(n) for n in cs] | |
return dict(zip(bs, cs)) | |
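# Example (sketch): the mapping is a bijection over all 256 byte values, e.g.
#   b2u = bytes_to_unicode()
#   len(b2u)         # -> 256
#   b2u[ord("a")]    # -> 'a'   (printable bytes map to themselves)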
def get_pairs(word): | |
"""Return set of symbol pairs in a word. | |
Word is represented as tuple of symbols (symbols being variable-length strings). | |
""" | |
pairs = set() | |
prev_char = word[0] | |
for char in word[1:]: | |
pairs.add((prev_char, char)) | |
prev_char = char | |
return pairs | |
def basic_clean(text): | |
text = ftfy.fix_text(text) | |
text = html.unescape(html.unescape(text)) | |
return text.strip() | |
def whitespace_clean(text): | |
text = re.sub(r"\s+", " ", text) | |
text = text.strip() | |
return text | |
class SimpleTokenizer(object): | |
def __init__(self, bpe_path: str, context_length=77): | |
self.byte_encoder = bytes_to_unicode() | |
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} | |
with g_pathmgr.open(bpe_path, "rb") as fh: | |
bpe_bytes = io.BytesIO(fh.read()) | |
merges = gzip.open(bpe_bytes).read().decode("utf-8").split("\n") | |
merges = merges[1: 49152 - 256 - 2 + 1] | |
merges = [tuple(merge.split()) for merge in merges] | |
vocab = list(bytes_to_unicode().values()) | |
vocab = vocab + [v + "</w>" for v in vocab] | |
for merge in merges: | |
vocab.append("".join(merge)) | |
vocab.extend(["<|startoftext|>", "<|endoftext|>"]) | |
self.encoder = dict(zip(vocab, range(len(vocab)))) | |
self.decoder = {v: k for k, v in self.encoder.items()} | |
self.bpe_ranks = dict(zip(merges, range(len(merges)))) | |
self.cache = { | |
"<|startoftext|>": "<|startoftext|>", | |
"<|endoftext|>": "<|endoftext|>", | |
} | |
self.pat = re.compile( | |
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", | |
re.IGNORECASE, | |
) | |
self.context_length = context_length | |
def bpe(self, token): | |
if token in self.cache: | |
return self.cache[token] | |
word = tuple(token[:-1]) + (token[-1] + "</w>",) | |
pairs = get_pairs(word) | |
if not pairs: | |
return token + "</w>" | |
while True: | |
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) | |
if bigram not in self.bpe_ranks: | |
break | |
first, second = bigram | |
new_word = [] | |
i = 0 | |
while i < len(word): | |
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second: | |
new_word.append(first + second) | |
i += 2 | |
else: | |
new_word.append(word[i]) | |
i += 1 | |
new_word = tuple(new_word) | |
word = new_word | |
if len(word) == 1: | |
break | |
else: | |
pairs = get_pairs(word) | |
word = " ".join(word) | |
self.cache[token] = word | |
return word | |
def encode(self, text): | |
bpe_tokens = [] | |
text = whitespace_clean(basic_clean(text)).lower() | |
for token in re.findall(self.pat, text): | |
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) | |
bpe_tokens.extend( | |
self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") | |
) | |
return bpe_tokens | |
def decode(self, tokens): | |
text = "".join([self.decoder[token] for token in tokens]) | |
text = ( | |
bytearray([self.byte_decoder[c] for c in text]) | |
.decode("utf-8", errors="replace") | |
.replace("</w>", " ") | |
) | |
return text | |
def __call__(self, texts, context_length=None): | |
if not context_length: | |
context_length = self.context_length | |
if isinstance(texts, str): | |
texts = [texts] | |
sot_token = self.encoder["<|startoftext|>"] | |
eot_token = self.encoder["<|endoftext|>"] | |
all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts] | |
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) | |
for i, tokens in enumerate(all_tokens): | |
tokens = tokens[:context_length] | |
result[i, : len(tokens)] = torch.tensor(tokens) | |
if len(result) == 1: | |
return result[0] | |
return result | |
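# Illustrative usage sketch (the bpe_path below is an assumption, not shipped with this file):
#   tokenizer = SimpleTokenizer(bpe_path="bpe/bpe_simple_vocab_16e6.txt.gz")
#   tokens = tokenizer(["a photo of a dog"])   # -> LongTensor of shape (77,), zero-padded
#   batch = tokenizer(["a dog", "a car"])      # -> LongTensor of shape (2, 77)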
class Normalize(nn.Module): | |
def __init__(self, dim: int) -> None: | |
super().__init__() | |
self.dim = dim | |
def forward(self, x): | |
return torch.nn.functional.normalize(x, dim=self.dim, p=2) | |
class LearnableLogitScaling(nn.Module): | |
def __init__( | |
self, | |
logit_scale_init: float = 1 / 0.07, | |
learnable: bool = True, | |
max_logit_scale: float = 100, | |
) -> None: | |
super().__init__() | |
self.max_logit_scale = max_logit_scale | |
self.logit_scale_init = logit_scale_init | |
self.learnable = learnable | |
log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init) | |
if learnable: | |
self.log_logit_scale = nn.Parameter(log_logit_scale) | |
else: | |
self.register_buffer("log_logit_scale", log_logit_scale) | |
def forward(self, x): | |
return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x | |
def extra_repr(self): | |
st = f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}, max_logit_scale={self.max_logit_scale}" | |
return st | |
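# Example (sketch): with the default init the module multiplies its input by
# exp(log(1 / 0.07)) ~= 14.29, clipped at max_logit_scale.
#   scaler = LearnableLogitScaling(learnable=False)
#   scaler(torch.ones(2))   # -> tensor([14.2857, 14.2857])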
class EinOpsRearrange(nn.Module): | |
def __init__(self, rearrange_expr: str, **kwargs) -> None: | |
super().__init__() | |
self.rearrange_expr = rearrange_expr | |
self.kwargs = kwargs | |
def forward(self, x): | |
assert isinstance(x, torch.Tensor) | |
return einops.rearrange(x, self.rearrange_expr, **self.kwargs) | |
class IMUPreprocessor(VerboseNNModule): | |
def __init__( | |
self, | |
kernel_size: int, | |
imu_stem: PatchEmbedGeneric, | |
embed_dim: int, | |
img_size: List = (6, 2000), | |
num_cls_tokens: int = 1, | |
pos_embed_fn: Callable = None, | |
init_param_style: str = "openclip", | |
) -> None: | |
super().__init__() | |
stem = imu_stem | |
self.imu_stem = imu_stem | |
self.embed_dim = embed_dim | |
self.use_pos_embed = pos_embed_fn is not None | |
self.num_cls_tokens = num_cls_tokens | |
self.kernel_size = kernel_size | |
self.pos_embed = nn.Parameter( | |
torch.empty(1, (img_size[1] // kernel_size) + num_cls_tokens, embed_dim) | |
) | |
if self.num_cls_tokens > 0: | |
self.cls_token = nn.Parameter( | |
torch.zeros(1, self.num_cls_tokens, self.embed_dim) | |
) | |
self.init_parameters(init_param_style) | |
def init_parameters(self, init_param_style): | |
nn.init.normal_(self.pos_embed, std=0.01) | |
if init_param_style == "openclip": | |
# OpenCLIP style initialization | |
scale = self.embed_dim ** -0.5 | |
if self.num_cls_tokens > 0: | |
nn.init.normal_(self.cls_token) | |
self.cls_token *= scale | |
elif init_param_style == "vit": | |
self.cls_token.data.fill_(0) | |
else: | |
raise ValueError(f"Unknown init {init_param_style}") | |
def tokenize_input_and_cls_pos(self, input, stem): | |
# tokens is of shape B x L x D | |
tokens = stem.norm_layer(stem.proj(input)) | |
assert tokens.ndim == 3 | |
assert tokens.shape[2] == self.embed_dim | |
B = tokens.shape[0] | |
if self.num_cls_tokens > 0: | |
class_tokens = self.cls_token.expand( | |
B, -1, -1 | |
) # stole class_tokens impl from Phil Wang, thanks | |
tokens = torch.cat((class_tokens, tokens), dim=1) | |
if self.use_pos_embed: | |
tokens = tokens + self.pos_embed | |
return tokens | |
def forward(self, imu): | |
# Patchify | |
imu = imu.unfold( | |
-1, | |
self.kernel_size, | |
self.kernel_size, | |
).permute(0, 2, 1, 3) | |
imu = imu.reshape(imu.size(0), imu.size(1), -1) | |
imu_tokens = self.tokenize_input_and_cls_pos( | |
imu, | |
self.imu_stem, | |
) | |
return_dict = { | |
"trunk": { | |
"tokens": imu_tokens, | |
}, | |
"head": {}, | |
} | |
return return_dict | |
def cast_if_src_dtype( | |
tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype | |
): | |
updated = False | |
if tensor.dtype == src_dtype: | |
tensor = tensor.to(dtype=tgt_dtype) | |
updated = True | |
return tensor, updated | |
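# Example (sketch): only casts when the source dtype matches, and reports whether it did.
#   t = torch.randn(2, 3, dtype=torch.bfloat16)
#   t32, updated = cast_if_src_dtype(t, torch.bfloat16, torch.float32)    # updated == True
#   t32, updated = cast_if_src_dtype(t32, torch.bfloat16, torch.float32)  # updated == False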
class QuickGELU(nn.Module): | |
# From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166 | |
def forward(self, x: torch.Tensor): | |
return x * torch.sigmoid(1.702 * x) | |
class SelectElement(nn.Module): | |
def __init__(self, index) -> None: | |
super().__init__() | |
self.index = index | |
def forward(self, x): | |
assert x.ndim >= 3 | |
return x[:, self.index, ...] | |
class ReshapeSpatial(nn.Module): | |
def __init__(self, shape) -> None: | |
super().__init__() | |
self.h, self.w = shape | |
def forward(self, x): | |
assert x.ndim >= 3 | |
return x[:, 1:, ...].reshape(x.shape[0], self.h, self.w, -1), x[:, 0, :] | |
class ReshapeAudio(nn.Module): | |
def __init__(self, shape) -> None: | |
super().__init__() | |
self.h, self.w = shape | |
def forward(self, x): | |
assert x.ndim == 3 | |
return x[:, 1:, :].reshape(-1, 5, self.h, self.w, x.shape[-1]), x[:, 0, :] | |
class ApplyTwice(nn.Module): | |
def __init__(self, module) -> None: | |
super().__init__() | |
self.module = module | |
def forward(self, pair): | |
return self.module(pair[0]), self.module(pair[1]) | |
class SelectEOSAndProject(nn.Module): | |
""" | |
Text Pooling used in OpenCLIP | |
""" | |
def __init__(self, proj: nn.Module) -> None: | |
super().__init__() | |
self.proj = proj | |
def forward(self, x, seq_len): | |
assert x.ndim == 3 | |
# x is of shape B x L x D | |
# take features from the eot embedding (eot_token is the highest number in each sequence) | |
x = x[torch.arange(x.shape[0]), seq_len] | |
x = self.proj(x) | |
return x | |
ModalityType = SimpleNamespace( | |
VISION="vision", | |
TEXT="text", | |
AUDIO="audio", | |
THERMAL="thermal", | |
DEPTH="depth", | |
IMU="imu", | |
) | |
class ImageBindModel(nn.Module): | |
def __init__( | |
self, | |
video_frames=2, | |
kernel_size=(2, 14, 14), | |
audio_kernel_size=16, | |
audio_stride=10, | |
out_embed_dim=768, | |
vision_embed_dim=1024, | |
vision_num_blocks=24, | |
vision_num_heads=16, | |
audio_embed_dim=768, | |
audio_num_blocks=12, | |
audio_num_heads=12, | |
audio_num_mel_bins=128, | |
audio_target_len=204, | |
audio_drop_path=0.1, | |
text_embed_dim=768, | |
text_num_blocks=12, | |
text_num_heads=12, | |
depth_embed_dim=384, | |
depth_kernel_size=16, | |
depth_num_blocks=12, | |
depth_num_heads=8, | |
depth_drop_path=0.0, | |
thermal_embed_dim=768, | |
thermal_kernel_size=16, | |
thermal_num_blocks=12, | |
thermal_num_heads=12, | |
thermal_drop_path=0.0, | |
imu_embed_dim=512, | |
imu_kernel_size=8, | |
imu_num_blocks=6, | |
imu_num_heads=8, | |
imu_drop_path=0.7, | |
): | |
super().__init__() | |
self.modality_preprocessors = self._create_modality_preprocessors( | |
video_frames, | |
vision_embed_dim, | |
kernel_size, | |
text_embed_dim, | |
audio_embed_dim, | |
audio_kernel_size, | |
audio_stride, | |
audio_num_mel_bins, | |
audio_target_len, | |
depth_embed_dim, | |
depth_kernel_size, | |
thermal_embed_dim, | |
thermal_kernel_size, | |
imu_embed_dim, | |
) | |
self.modality_trunks = self._create_modality_trunks( | |
vision_embed_dim, | |
vision_num_blocks, | |
vision_num_heads, | |
text_embed_dim, | |
text_num_blocks, | |
text_num_heads, | |
audio_embed_dim, | |
audio_num_blocks, | |
audio_num_heads, | |
audio_drop_path, | |
depth_embed_dim, | |
depth_num_blocks, | |
depth_num_heads, | |
depth_drop_path, | |
thermal_embed_dim, | |
thermal_num_blocks, | |
thermal_num_heads, | |
thermal_drop_path, | |
imu_embed_dim, | |
imu_num_blocks, | |
imu_num_heads, | |
imu_drop_path, | |
) | |
self.modality_heads = self._create_modality_heads( | |
out_embed_dim, | |
vision_embed_dim, | |
text_embed_dim, | |
audio_embed_dim, | |
depth_embed_dim, | |
thermal_embed_dim, | |
imu_embed_dim, | |
) | |
self.modality_postprocessors = self._create_modality_postprocessors( | |
out_embed_dim | |
) | |
def _create_modality_preprocessors( | |
self, | |
video_frames=2, | |
vision_embed_dim=1024, | |
kernel_size=(2, 14, 14), | |
text_embed_dim=768, | |
audio_embed_dim=768, | |
audio_kernel_size=16, | |
audio_stride=10, | |
audio_num_mel_bins=128, | |
audio_target_len=204, | |
depth_embed_dim=768, | |
depth_kernel_size=16, | |
thermal_embed_dim=768, | |
thermal_kernel_size=16, | |
imu_embed_dim=512, | |
): | |
rgbt_stem = PatchEmbedGeneric( | |
proj_stem=[ | |
PadIm2Video(pad_type="repeat", ntimes=2), | |
nn.Conv3d( | |
in_channels=3, | |
kernel_size=kernel_size, | |
out_channels=vision_embed_dim, | |
stride=kernel_size, | |
bias=False, | |
), | |
] | |
) | |
rgbt_preprocessor = RGBDTPreprocessor( | |
img_size=[3, video_frames, 224, 224], | |
num_cls_tokens=1, | |
pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), | |
rgbt_stem=rgbt_stem, | |
depth_stem=None, | |
) | |
text_preprocessor = TextPreprocessor( | |
context_length=77, | |
vocab_size=49408, | |
embed_dim=text_embed_dim, | |
causal_masking=True, | |
) | |
audio_stem = PatchEmbedGeneric( | |
proj_stem=[ | |
nn.Conv2d( | |
in_channels=1, | |
kernel_size=audio_kernel_size, | |
stride=audio_stride, | |
out_channels=audio_embed_dim, | |
bias=False, | |
), | |
], | |
norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim), | |
) | |
audio_preprocessor = AudioPreprocessor( | |
img_size=[1, audio_num_mel_bins, audio_target_len], | |
num_cls_tokens=1, | |
pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), | |
audio_stem=audio_stem, | |
) | |
# depth_stem = PatchEmbedGeneric( | |
# [ | |
# nn.Conv2d( | |
# kernel_size=depth_kernel_size, | |
# in_channels=1, | |
# out_channels=depth_embed_dim, | |
# stride=depth_kernel_size, | |
# bias=False, | |
# ), | |
# ], | |
# norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim), | |
# ) | |
# | |
# depth_preprocessor = RGBDTPreprocessor( | |
# img_size=[1, 224, 224], | |
# num_cls_tokens=1, | |
# pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), | |
# rgbt_stem=None, | |
# depth_stem=depth_stem, | |
# ) | |
# | |
# thermal_stem = PatchEmbedGeneric( | |
# [ | |
# nn.Conv2d( | |
# kernel_size=thermal_kernel_size, | |
# in_channels=1, | |
# out_channels=thermal_embed_dim, | |
# stride=thermal_kernel_size, | |
# bias=False, | |
# ), | |
# ], | |
# norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim), | |
# ) | |
# thermal_preprocessor = ThermalPreprocessor( | |
# img_size=[1, 224, 224], | |
# num_cls_tokens=1, | |
# pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), | |
# thermal_stem=thermal_stem, | |
# ) | |
# | |
# imu_stem = PatchEmbedGeneric( | |
# [ | |
# nn.Linear( | |
# in_features=48, | |
# out_features=imu_embed_dim, | |
# bias=False, | |
# ), | |
# ], | |
# norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim), | |
# ) | |
# | |
# imu_preprocessor = IMUPreprocessor( | |
# img_size=[6, 2000], | |
# num_cls_tokens=1, | |
# kernel_size=8, | |
# embed_dim=imu_embed_dim, | |
# pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), | |
# imu_stem=imu_stem, | |
# ) | |
modality_preprocessors = { | |
ModalityType.VISION: rgbt_preprocessor, | |
ModalityType.TEXT: text_preprocessor, | |
ModalityType.AUDIO: audio_preprocessor, | |
# ModalityType.DEPTH: depth_preprocessor, | |
# ModalityType.THERMAL: thermal_preprocessor, | |
# ModalityType.IMU: imu_preprocessor, | |
} | |
return nn.ModuleDict(modality_preprocessors) | |
def _create_modality_trunks( | |
self, | |
vision_embed_dim=1024, | |
vision_num_blocks=24, | |
vision_num_heads=16, | |
text_embed_dim=768, | |
text_num_blocks=12, | |
text_num_heads=12, | |
audio_embed_dim=768, | |
audio_num_blocks=12, | |
audio_num_heads=12, | |
audio_drop_path=0.0, | |
depth_embed_dim=768, | |
depth_num_blocks=12, | |
depth_num_heads=12, | |
depth_drop_path=0.0, | |
thermal_embed_dim=768, | |
thermal_num_blocks=12, | |
thermal_num_heads=12, | |
thermal_drop_path=0.0, | |
imu_embed_dim=512, | |
imu_num_blocks=6, | |
imu_num_heads=8, | |
imu_drop_path=0.7, | |
): | |
def instantiate_trunk( | |
embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path | |
): | |
return SimpleTransformer( | |
embed_dim=embed_dim, | |
num_blocks=num_blocks, | |
ffn_dropout_rate=0.0, | |
drop_path_rate=drop_path, | |
attn_target=partial( | |
MultiheadAttention, | |
embed_dim=embed_dim, | |
num_heads=num_heads, | |
bias=True, | |
add_bias_kv=add_bias_kv, | |
), | |
pre_transformer_layer=nn.Sequential( | |
nn.LayerNorm(embed_dim, eps=1e-6) | |
if pre_transformer_ln | |
else nn.Identity(), | |
EinOpsRearrange("b l d -> l b d"), | |
), | |
post_transformer_layer=EinOpsRearrange("l b d -> b l d"), | |
) | |
modality_trunks = {} | |
modality_trunks[ModalityType.VISION] = instantiate_trunk( | |
vision_embed_dim, | |
vision_num_blocks, | |
vision_num_heads, | |
pre_transformer_ln=True, | |
add_bias_kv=False, | |
drop_path=0.0, | |
) | |
modality_trunks[ModalityType.TEXT] = instantiate_trunk( | |
text_embed_dim, | |
text_num_blocks, | |
text_num_heads, | |
pre_transformer_ln=False, | |
add_bias_kv=False, | |
drop_path=0.0, | |
) | |
modality_trunks[ModalityType.AUDIO] = instantiate_trunk( | |
audio_embed_dim, | |
audio_num_blocks, | |
audio_num_heads, | |
pre_transformer_ln=False, | |
add_bias_kv=True, | |
drop_path=audio_drop_path, | |
) | |
# modality_trunks[ModalityType.DEPTH] = instantiate_trunk( | |
# depth_embed_dim, | |
# depth_num_blocks, | |
# depth_num_heads, | |
# pre_transformer_ln=False, | |
# add_bias_kv=True, | |
# drop_path=depth_drop_path, | |
# ) | |
# modality_trunks[ModalityType.THERMAL] = instantiate_trunk( | |
# thermal_embed_dim, | |
# thermal_num_blocks, | |
# thermal_num_heads, | |
# pre_transformer_ln=False, | |
# add_bias_kv=True, | |
# drop_path=thermal_drop_path, | |
# ) | |
# modality_trunks[ModalityType.IMU] = instantiate_trunk( | |
# imu_embed_dim, | |
# imu_num_blocks, | |
# imu_num_heads, | |
# pre_transformer_ln=False, | |
# add_bias_kv=True, | |
# drop_path=imu_drop_path, | |
# ) | |
return nn.ModuleDict(modality_trunks) | |
def _create_modality_heads( | |
self, | |
out_embed_dim, | |
vision_embed_dim, | |
text_embed_dim, | |
audio_embed_dim, | |
depth_embed_dim, | |
thermal_embed_dim, | |
imu_embed_dim, | |
): | |
modality_heads = {} | |
modality_heads[ModalityType.VISION] = nn.Sequential( | |
nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6), | |
SelectElement(index=0), | |
nn.Linear(vision_embed_dim, out_embed_dim, bias=False), | |
) | |
modality_heads[ModalityType.TEXT] = SelectEOSAndProject( | |
proj=nn.Sequential( | |
nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6), | |
nn.Linear(text_embed_dim, out_embed_dim, bias=False), | |
) | |
) | |
modality_heads[ModalityType.AUDIO] = nn.Sequential( | |
nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6), | |
SelectElement(index=0), | |
nn.Linear(audio_embed_dim, out_embed_dim, bias=False), | |
) | |
# modality_heads[ModalityType.DEPTH] = nn.Sequential( | |
# nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6), | |
# SelectElement(index=0), | |
# nn.Linear(depth_embed_dim, out_embed_dim, bias=False), | |
# ) | |
# | |
# modality_heads[ModalityType.THERMAL] = nn.Sequential( | |
# nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6), | |
# SelectElement(index=0), | |
# nn.Linear(thermal_embed_dim, out_embed_dim, bias=False), | |
# ) | |
# | |
# modality_heads[ModalityType.IMU] = nn.Sequential( | |
# nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6), | |
# SelectElement(index=0), | |
# nn.Dropout(p=0.5), | |
# nn.Linear(imu_embed_dim, out_embed_dim, bias=False), | |
# ) | |
return nn.ModuleDict(modality_heads) | |
def _create_modality_postprocessors(self, out_embed_dim): | |
modality_postprocessors = {} | |
modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1) | |
modality_postprocessors[ModalityType.TEXT] = nn.Sequential( | |
Normalize(dim=-1), LearnableLogitScaling(learnable=True) | |
) | |
modality_postprocessors[ModalityType.AUDIO] = nn.Sequential( | |
Normalize(dim=-1), | |
LearnableLogitScaling(logit_scale_init=20.0, learnable=False), | |
) | |
# modality_postprocessors[ModalityType.DEPTH] = nn.Sequential( | |
# Normalize(dim=-1), | |
# LearnableLogitScaling(logit_scale_init=5.0, learnable=False), | |
# ) | |
# modality_postprocessors[ModalityType.THERMAL] = nn.Sequential( | |
# Normalize(dim=-1), | |
# LearnableLogitScaling(logit_scale_init=10.0, learnable=False), | |
# ) | |
# modality_postprocessors[ModalityType.IMU] = nn.Sequential( | |
# Normalize(dim=-1), | |
# LearnableLogitScaling(logit_scale_init=5.0, learnable=False), | |
# ) | |
return nn.ModuleDict(modality_postprocessors) | |
def forward(self, inputs): | |
outputs = {} | |
for modality_key, modality_value in inputs.items(): | |
reduce_list = ( | |
modality_value.ndim >= 5 | |
) # Audio and Video inputs consist of multiple clips | |
if reduce_list: | |
B, S = modality_value.shape[:2] | |
modality_value = modality_value.reshape( | |
B * S, *modality_value.shape[2:] | |
) | |
if modality_value is not None: | |
modality_value = self.modality_preprocessors[modality_key]( | |
**{modality_key: modality_value} | |
) | |
trunk_inputs = modality_value["trunk"] | |
head_inputs = modality_value["head"] | |
modality_value = self.modality_trunks[modality_key](**trunk_inputs) | |
modality_value = self.modality_heads[modality_key]( | |
modality_value, **head_inputs | |
) | |
modality_value = self.modality_postprocessors[modality_key]( | |
modality_value | |
) | |
if reduce_list: | |
modality_value = modality_value.reshape(B, S, -1) | |
modality_value = modality_value.mean(dim=1) | |
outputs[modality_key] = modality_value | |
return outputs | |
def reconfigure_head(self, k, v): | |
if k == ModalityType.AUDIO: | |
return torch.nn.Sequential(v[0], v[2]) | |
elif k == ModalityType.VISION: | |
return torch.nn.Sequential(v[0], v[2]) | |
else: | |
return v | |
def forward_features(self, inputs): | |
outputs = {} | |
reconfigured_heads = {k: self.reconfigure_head(k, v) for k, v in self.modality_heads.items()} | |
for modality_key, modality_value in inputs.items(): | |
reduce_list = ( | |
modality_value.ndim >= 5 | |
) # Audio and Video inputs consist of multiple clips | |
if reduce_list: | |
B, S = modality_value.shape[:2] | |
modality_value = modality_value.reshape( | |
B * S, *modality_value.shape[2:] | |
) | |
if modality_value is not None: | |
modality_value = self.modality_preprocessors[modality_key]( | |
**{modality_key: modality_value} | |
) | |
trunk_inputs = modality_value["trunk"] | |
head_inputs = modality_value["head"] | |
modality_value = self.modality_trunks[modality_key](**trunk_inputs) | |
modality_value = reconfigured_heads[modality_key]( | |
modality_value, **head_inputs | |
) | |
modality_value = self.modality_postprocessors[modality_key]( | |
modality_value | |
) | |
if modality_key == ModalityType.AUDIO: | |
modality_value = ReshapeAudio((12, 19))(modality_value) | |
elif modality_key == ModalityType.VISION: | |
modality_value = ReshapeSpatial((16, 16))(modality_value) | |
outputs[modality_key] = modality_value | |
return outputs | |
def imagebind_huge(output_path, pretrained=False): | |
model = ImageBindModel( | |
vision_embed_dim=1280, | |
vision_num_blocks=32, | |
vision_num_heads=16, | |
text_embed_dim=1024, | |
text_num_blocks=24, | |
text_num_heads=16, | |
out_embed_dim=1024, | |
audio_drop_path=0.1, | |
imu_drop_path=0.7, | |
) | |
if pretrained: | |
path = os.path.join(output_path, 'models/imagebind_huge.pth') | |
if not os.path.exists(path): | |
print(f"Downloading imagebind weights to {path} ...") | |
os.makedirs(os.path.dirname(path), exist_ok=True) | |
torch.hub.download_url_to_file( | |
"https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth", | |
path, | |
progress=True, | |
) | |
model.load_state_dict(torch.load(path), strict=False) | |
return model | |
DEFAULT_AUDIO_FRAME_SHIFT_MS = 10 # in milliseconds | |
def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length): | |
# Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102 | |
waveform -= waveform.mean() | |
fbank = torchaudio.compliance.kaldi.fbank( | |
waveform, | |
htk_compat=True, | |
sample_frequency=sample_rate, | |
use_energy=False, | |
window_type="hanning", | |
num_mel_bins=num_mel_bins, | |
dither=0.0, | |
frame_length=25, | |
frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS, | |
) | |
# Convert to [mel_bins, num_frames] shape | |
fbank = fbank.transpose(0, 1) | |
# Pad to target_length | |
n_frames = fbank.size(1) | |
p = target_length - n_frames | |
    # if p is too large (say >20%), log a warning
if abs(p) / n_frames > 0.2: | |
logging.warning( | |
"Large gap between audio n_frames(%d) and " | |
"target_length (%d). Is the audio_target_length " | |
"setting correct?", | |
n_frames, | |
target_length, | |
) | |
# cut and pad | |
if p > 0: | |
fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0) | |
elif p < 0: | |
fbank = fbank[:, 0:target_length] | |
# Convert to [1, mel_bins, num_frames] shape, essentially like a 1 | |
# channel image | |
fbank = fbank.unsqueeze(0) | |
return fbank | |
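# Example (sketch, assumed sizes): a 2 s mono clip at 16 kHz with a 10 ms frame shift yields
# roughly 198 frames, which is then padded/cropped to `target_length`.
#   wav = torch.randn(1, 32000)
#   spec = waveform2melspec(wav, 16000, num_mel_bins=128, target_length=204)
#   spec.shape   # -> torch.Size([1, 128, 204])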
def get_clip_timepoints(clip_sampler, duration): | |
# Read out all clips in this video | |
all_clips_timepoints = [] | |
is_last_clip = False | |
end = 0.0 | |
while not is_last_clip: | |
start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None) | |
all_clips_timepoints.append((start, end)) | |
return all_clips_timepoints | |
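# Example (sketch): with ConstantClipsPerVideoSampler(clip_duration=2, clips_per_video=3) and a
# 10 s recording, this returns three 2-second (start, end) windows spread across the file.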
def load_and_transform_vision_data(image_paths, device): | |
if image_paths is None: | |
return None | |
    image_outputs = []
for image_path in image_paths: | |
data_transform = transforms.Compose( | |
[ | |
transforms.Resize( | |
224, interpolation=transforms.InterpolationMode.BICUBIC | |
), | |
transforms.CenterCrop(224), | |
transforms.ToTensor(), | |
transforms.Normalize( | |
mean=(0.48145466, 0.4578275, 0.40821073), | |
std=(0.26862954, 0.26130258, 0.27577711), | |
), | |
] | |
) | |
with open(image_path, "rb") as fopen: | |
image = Image.open(fopen).convert("RGB") | |
image = data_transform(image).to(device) | |
        image_outputs.append(image)
    return torch.stack(image_outputs, dim=0)
def load_and_transform_audio_data( | |
audio_paths, | |
device, | |
num_mel_bins=128, | |
target_length=204, | |
sample_rate=16000, | |
clip_duration=2, | |
clips_per_video=3, | |
mean=-4.268, | |
std=9.138, | |
): | |
from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler | |
if audio_paths is None: | |
return None | |
audio_outputs = [] | |
clip_sampler = ConstantClipsPerVideoSampler( | |
clip_duration=clip_duration, clips_per_video=clips_per_video | |
) | |
for audio_path in audio_paths: | |
waveform, sr = torchaudio.load(audio_path) | |
if sample_rate != sr: | |
waveform = torchaudio.functional.resample( | |
waveform, orig_freq=sr, new_freq=sample_rate | |
) | |
all_clips_timepoints = get_clip_timepoints( | |
clip_sampler, waveform.size(1) / sample_rate | |
) | |
all_clips = [] | |
for clip_timepoints in all_clips_timepoints: | |
waveform_clip = waveform[ | |
:, | |
int(clip_timepoints[0] * sample_rate): int( | |
clip_timepoints[1] * sample_rate | |
), | |
] | |
waveform_melspec = waveform2melspec( | |
waveform_clip, sample_rate, num_mel_bins, target_length | |
) | |
all_clips.append(waveform_melspec) | |
normalize = transforms.Normalize(mean=mean, std=std) | |
all_clips = [normalize(ac).to(device) for ac in all_clips] | |
all_clips = torch.stack(all_clips, dim=0) | |
audio_outputs.append(all_clips) | |
return torch.stack(audio_outputs, dim=0) | |
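# Example (sketch): for two audio files with the defaults above, the returned tensor has shape
# (num_files, clips_per_video, 1, num_mel_bins, target_length) = (2, 3, 1, 128, 204).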
class UnNormalize(object): | |
def __init__(self, mean, std): | |
self.mean = mean | |
self.std = std | |
def __call__(self, image): | |
image2 = torch.clone(image) | |
for t, m, s in zip(image2, self.mean, self.std): | |
t.mul_(s).add_(m) | |
return image2 | |
norm = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) | |
unnorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) | |
class TorchPCA(object): | |
def __init__(self, n_components): | |
self.n_components = n_components | |
def fit(self, X): | |
self.mean_ = X.mean(dim=0) | |
unbiased = X - self.mean_.unsqueeze(0) | |
U, S, V = torch.pca_lowrank(unbiased, q=self.n_components, center=False, niter=4) | |
self.components_ = V.T | |
self.singular_values_ = S | |
return self | |
def transform(self, X): | |
t0 = X - self.mean_.unsqueeze(0) | |
projected = t0 @ self.components_.T | |
return projected | |
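# Illustrative sketch (assumed sizes): TorchPCA mirrors the sklearn fit/transform interface on
# torch tensors.
#   feats = torch.randn(1000, 64)
#   pca3 = TorchPCA(n_components=3).fit(feats)
#   low = pca3.transform(feats)   # -> torch.Size([1000, 3])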
def pca(image_feats_list, dim=3, fit_pca=None): | |
# from sklearn.decomposition import PCA | |
device = image_feats_list[0].device | |
    def flatten(tensor, target_size=None):
        if target_size is not None and fit_pca is None:
            tensor = F.interpolate(tensor, (target_size, target_size), mode="bilinear")
        B, C, H, W = tensor.shape
        return tensor.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu()
if len(image_feats_list) > 1 and fit_pca is None: | |
target_size = image_feats_list[0].shape[2] | |
else: | |
target_size = None | |
flattened_feats = [] | |
for feats in image_feats_list: | |
flattened_feats.append(flatten(feats, target_size)) | |
x = torch.cat(flattened_feats, dim=0) | |
if fit_pca is None: | |
# fit_pca = PCA(n_components=dim, svd_solver='arpack').fit(np.nan_to_num(x.detach().numpy())) | |
fit_pca = TorchPCA(n_components=dim).fit(x) | |
reduced_feats = [] | |
for feats in image_feats_list: | |
# x_red = torch.from_numpy(fit_pca.transform(flatten(feats))) | |
x_red = fit_pca.transform(flatten(feats)) | |
x_red -= x_red.min(dim=0, keepdim=True).values | |
x_red /= x_red.max(dim=0, keepdim=True).values | |
B, C, H, W = feats.shape | |
reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2).to(device)) | |
return reduced_feats, fit_pca | |
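# Illustrative sketch (assumed sizes): pca() expects a list of (B, C, H, W) feature maps and
# returns the same maps reduced to `dim` channels in [0, 1], plus the fitted PCA for reuse.
#   feats = [torch.randn(1, 64, 16, 16)]
#   [rgb_like], fit = pca(feats, dim=3)   # rgb_like: torch.Size([1, 3, 16, 16])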
def my_load_audio(audio_file): | |
loaded_waveform, obs_sr = torchaudio.load(audio_file) | |
loaded_waveform = loaded_waveform[0] | |
neg_waveform, neg_obs_sr = None, None | |
from data.AVDatasets import prep_waveform | |
(waveform, | |
spectrogram, | |
audio_length, | |
total_length, | |
original_length, | |
mask, | |
pos_mask) = prep_waveform( | |
loaded_waveform, | |
obs_sr, | |
10, | |
128, | |
-4.268, | |
9.138, | |
16000, | |
True, | |
False, | |
False, | |
neg_waveform, | |
neg_obs_sr, | |
False, | |
) | |
patch_size = 204 | |
n_tiles = spectrogram.shape[1] // patch_size | |
assert n_tiles == 5 | |
patches = [] | |
for i in range(n_tiles): | |
patches.append(spectrogram[:, i * patch_size:(i + 1) * patch_size, :]) | |
patches = torch.cat(patches, dim=0).permute(0, 2, 1).unsqueeze(1) | |
return patches | |
class ImageBindImageFeaturizer(nn.Module): | |
def __init__(self, output_path, model=None): | |
super().__init__() | |
if model is not None: | |
self.model = model | |
else: | |
self.model = imagebind_huge(output_path, pretrained=True).cuda() | |
def forward(self, image, include_cls): | |
inputs = { | |
ModalityType.VISION: image, | |
} | |
patch_tokens, cls_tokens = self.model.forward_features(inputs)[ModalityType.VISION] | |
patch_tokens = patch_tokens.permute(0, 3, 1, 2) | |
if include_cls: | |
return patch_tokens, cls_tokens | |
else: | |
return patch_tokens | |
class ImageBindAudioFeaturizer(nn.Module): | |
def __init__(self, output_path, model=None): | |
super().__init__() | |
if model is not None: | |
self.model = model | |
else: | |
self.model = imagebind_huge(output_path, pretrained=True).cuda() | |
def forward(self, spec, include_cls): | |
patch_size = 204 | |
n_tiles = spec.shape[2] // patch_size | |
assert n_tiles == 5 | |
patches = [] | |
for i in range(n_tiles): | |
patches.append(spec[:, :, i * patch_size:(i + 1) * patch_size, :]) | |
patches = torch.cat(patches, dim=1).permute(0, 1, 3, 2).unsqueeze(2) | |
inputs = { | |
ModalityType.AUDIO: patches, | |
} | |
patch_tokens, cls_token = self.model.forward_features(inputs)[ModalityType.AUDIO] | |
patch_tokens = patch_tokens.permute(0, 4, 2, 1, 3) | |
b, c, h, p, w = patch_tokens.shape | |
patch_tokens = patch_tokens.reshape(b, c, h, w * p) | |
cls_token = cls_token.reshape(b, p, -1).mean(1) | |
if include_cls: | |
return patch_tokens, cls_token | |
else: | |
return patch_tokens | |
if __name__ == "__main__": | |
image_paths = ["../../samples/dog_image.jpg", "../../samples/car_image.jpg", "../../samples/bird_image.jpg"] | |
audio_paths = ["../../samples/dog_audio.wav", "../../samples/car_audio.wav", "../../samples/bird_audio.wav"] | |
device = "cuda:0" if torch.cuda.is_available() else "cpu" | |
# Instantiate model | |
model = imagebind_huge("../../", pretrained=True) | |
model.eval() | |
model.to(device) | |
    audio_inputs = torch.cat([my_load_audio(af).unsqueeze(0) for af in audio_paths], dim=0).to(device)
# Load data | |
inputs = { | |
ModalityType.VISION: load_and_transform_vision_data(image_paths, device), | |
# ModalityType.AUDIO: load_and_transform_audio_data(audio_paths, device, clip_duration=2, clips_per_video=5), | |
ModalityType.AUDIO: audio_inputs, | |
} | |
with torch.no_grad(): | |
embeddings = model.forward_features(inputs) | |
cls_tokens = model.forward(inputs) | |
audio_cls_token = embeddings["audio"][1].reshape(3, 5, -1).mean(1) | |
sims1 = torch.einsum( | |
"bc,dc->bd", | |
embeddings["vision"][1], | |
audio_cls_token) | |
print(torch.softmax(sims1, dim=1).cpu().numpy()) | |
# | |
# sims2 = torch.einsum( | |
# "bc,dc->bd", | |
# embeddings["vision"].mean(1).mean(1), | |
# embeddings["audio"].mean(1).mean(1).mean(1) | |
# ) | |
# | |
# print(torch.softmax(sims2, dim=1).cpu().numpy()) | |
# | |
# | |
# img_num = 0 | |
# img_feats = F.normalize(embeddings["vision"].permute(0, 3, 1, 2), dim=1) | |
# [red_img_feats], fit_pca = pca([img_feats]) | |
# | |
# fig, axes = plt.subplots(2, 2, figsize=(4 * 2, 4 * 2)) | |
# axes[0][0].imshow(unnorm(inputs["vision"][0].unsqueeze(0))[0].permute(1, 2, 0).detach().cpu()) | |
# axes[0][1].imshow(unnorm(inputs["vision"][1].unsqueeze(0))[0].permute(1, 2, 0).detach().cpu()) | |
# axes[1][0].imshow(red_img_feats[0].permute(1, 2, 0).detach().cpu()) | |
# axes[1][1].imshow(red_img_feats[1].permute(1, 2, 0).detach().cpu()) | |
# plt.tight_layout() | |
# plt.show() | |
# | |
audio_embs = F.normalize(embeddings["audio"][0], dim=-1) | |
b, n, h, w, c = audio_embs.shape | |
audio_embs = audio_embs.permute(0, 4, 2, 1, 3).reshape(b, c, h, w * n) | |
b, n, c, h, w = inputs["audio"].shape | |
audio_inputs = inputs["audio"].permute(0, 2, 3, 1, 4).reshape(b, c, h, w * n) | |
print("here") | |
for img_num in range(3): | |
[red_audio], fit_pca = pca([audio_embs[img_num].unsqueeze(0)]) | |
fig, axes = plt.subplots(2, 1, figsize=(10 * 1, 4 * 2)) | |
axes[0].imshow(audio_inputs[img_num, 0].detach().cpu()) | |
axes[1].imshow(red_audio[0].permute(1, 2, 0).detach().cpu()) | |
plt.tight_layout() | |
plt.show() | |
print("here") | |