lorocksUMD's picture
Upload 32 files
e6d4b46 verified
import gzip
import html
import io
import logging
import math
import os
from functools import lru_cache
from functools import partial
from types import SimpleNamespace
from typing import Callable, List
from typing import Optional
import einops
import ftfy
import numpy as np
import regex as re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import torchaudio
import torchvision.transforms as T
from PIL import Image
from timm.models.layers import DropPath, trunc_normal_
from torchvision import transforms
import matplotlib.pyplot as plt
from iopath.common.file_io import g_pathmgr
class Attention(nn.Module):
    """Multi-head self-attention over a ``(B, N, C)`` token tensor."""

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        per_head_dim = dim // num_heads
        # A falsy ``qk_scale`` (None or 0) falls back to head_dim ** -0.5.
        # It can be set manually to stay compatible with weights trained
        # with an earlier (different) scale factor.
        self.scale = qk_scale or per_head_dim ** -0.5

        # One projection produces queries, keys and values at once.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Return the attended and projected tokens, same shape as ``x``."""
        batch, seq_len, channels = x.shape
        packed = self.qkv(x).reshape(
            batch, seq_len, 3, self.num_heads, channels // self.num_heads
        )
        # -> (3, B, heads, N, head_dim)
        packed = packed.permute(2, 0, 3, 1, 4)
        # Index instead of tuple-unpacking the tensor (TorchScript friendly).
        queries = packed[0]
        keys = packed[1]
        values = packed[2]

        weights = (queries @ keys.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(weights.softmax(dim=-1))

        merged = (weights @ values).transpose(1, 2).reshape(batch, seq_len, channels)
        return self.proj_drop(self.proj(merged))
class Mlp(nn.Module):
    """Feed-forward block: Linear -> activation -> dropout -> Linear -> dropout."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        # Falsy sizes (None or 0) default to the input width.
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        # The same dropout module is applied after both linear layers.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class MultiheadAttention(nn.MultiheadAttention):
    """``nn.MultiheadAttention`` restricted to self-attention on one input."""

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
        # ``x`` serves as query, key and value; attention weights are not
        # requested, and only the first element of the parent's result
        # (the attended output) is returned.
        result = super().forward(x, x, x, need_weights=False, attn_mask=attn_mask)
        return result[0]
class ViTAttention(Attention):
    """``Attention`` variant exposing the ``(x, attn_mask)`` call signature.

    Masks are not supported: ``attn_mask`` must be ``None``.
    """

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Run unmasked self-attention on ``x``.

        Raises:
            AssertionError: if ``attn_mask`` is not ``None``.
        """
        # Raise explicitly instead of using ``assert`` so that the check is
        # not stripped under ``python -O`` (which would silently ignore the
        # mask). The exception type is unchanged for existing callers.
        if attn_mask is not None:
            raise AssertionError("ViTAttention does not support attn_mask")
        return super().forward(x)
class BlockWithMasking(nn.Module):
    """Pre-norm transformer block: attention and MLP, each on a residual path.

    Optionally applies DropPath to both branches and scales each branch by a
    learnable LayerScale gamma ("per_channel" or "scalar").
    """

    def __init__(
        self,
        dim: int,
        attn_target: Callable,
        mlp_ratio: int = 4,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        ffn_dropout_rate: float = 0.0,
        drop_path: float = 0.0,
        layer_scale_type: str = None,
        layer_scale_init_value: float = 1e-4,
    ):
        super().__init__()
        # A factory is required so that every block builds its own attention.
        assert not isinstance(
            attn_target, nn.Module
        ), "attn_target should be a Callable. Otherwise attn_target is shared across blocks!"
        self.attn = attn_target()
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm_1 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(mlp_ratio * dim),
            act_layer=act_layer,
            drop=ffn_dropout_rate,
        )
        self.norm_2 = norm_layer(dim)
        self.layer_scale_type = layer_scale_type

        if self.layer_scale_type is None:
            # No LayerScale: the gamma parameters are not created at all.
            return

        assert self.layer_scale_type in [
            "per_channel",
            "scalar",
        ], f"Found Layer scale type {self.layer_scale_type}"
        if self.layer_scale_type == "per_channel":
            # one gamma value per channel
            gamma_shape = [1, 1, dim]
        elif self.layer_scale_type == "scalar":
            # single gamma value shared by all channels
            gamma_shape = [1, 1, 1]

        # One gamma per residual branch (attention, then MLP).
        self.layer_scale_gamma1 = nn.Parameter(
            torch.ones(size=gamma_shape) * layer_scale_init_value,
            requires_grad=True,
        )
        self.layer_scale_gamma2 = nn.Parameter(
            torch.ones(size=gamma_shape) * layer_scale_init_value,
            requires_grad=True,
        )

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
        """Apply the attention branch, then the MLP branch, both residual."""
        use_layer_scale = self.layer_scale_type is not None

        attn_branch = self.drop_path(self.attn(self.norm_1(x), attn_mask))
        if use_layer_scale:
            attn_branch = attn_branch * self.layer_scale_gamma1
        x = x + attn_branch

        mlp_branch = self.drop_path(self.mlp(self.norm_2(x)))
        if use_layer_scale:
            mlp_branch = mlp_branch * self.layer_scale_gamma2
        return x + mlp_branch
# LayerNorm factory with eps fixed to 1e-6.
_LAYER_NORM = partial(nn.LayerNorm, eps=1e-6)
class SimpleTransformer(nn.Module):
    """Stack of transformer blocks with optional pre- and post-layers."""

    def __init__(
        self,
        attn_target: Callable,
        embed_dim: int,
        num_blocks: int,
        block: Callable = BlockWithMasking,
        pre_transformer_layer: Callable = None,
        post_transformer_layer: Callable = None,
        drop_path_rate: float = 0.0,
        drop_path_type: str = "progressive",
        norm_layer: Callable = _LAYER_NORM,
        mlp_ratio: int = 4,
        ffn_dropout_rate: float = 0.0,
        layer_scale_type: str = None,  # from cait; possible values are None, "per_channel", "scalar"
        layer_scale_init_value: float = 1e-4,  # from cait; float
        weight_init_style: str = "jax",  # possible values jax or pytorch
    ):
        """
        Simple Transformer with the following features
        1. Supports masked attention
        2. Supports DropPath
        3. Supports LayerScale
        4. Supports Dropout in Attention and FFN
        5. Makes few assumptions about the input except that it is a Tensor

        Raises:
            ValueError: if ``drop_path_type`` is neither "progressive" nor
                "uniform".
        """
        super().__init__()
        self.pre_transformer_layer = pre_transformer_layer

        if drop_path_type == "progressive":
            # Rates are spaced linearly from 0 to drop_path_rate over the blocks.
            schedule = torch.linspace(0, drop_path_rate, num_blocks)
            dpr = [rate.item() for rate in schedule]
        elif drop_path_type == "uniform":
            dpr = [drop_path_rate] * num_blocks
        else:
            raise ValueError(f"Unknown drop_path_type: {drop_path_type}")

        layers = []
        for rate in dpr:
            layers.append(
                block(
                    dim=embed_dim,
                    attn_target=attn_target,
                    mlp_ratio=mlp_ratio,
                    ffn_dropout_rate=ffn_dropout_rate,
                    drop_path=rate,
                    norm_layer=norm_layer,
                    layer_scale_type=layer_scale_type,
                    layer_scale_init_value=layer_scale_init_value,
                )
            )
        self.blocks = nn.Sequential(*layers)

        self.post_transformer_layer = post_transformer_layer
        self.weight_init_style = weight_init_style
        # Re-initialise every Linear / LayerNorm registered above.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one submodule; only Linear and LayerNorm are touched."""
        if isinstance(m, nn.Linear):
            style = self.weight_init_style
            if style == "jax":
                # Based on MAE and official Jax ViT implementation
                torch.nn.init.xavier_uniform_(m.weight)
            elif style == "pytorch":
                # PyTorch ViT uses trunc_normal_
                trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(
        self,
        tokens: torch.Tensor,
        attn_mask: torch.Tensor = None,
        use_checkpoint: bool = False,
        checkpoint_every_n: int = 1,
        checkpoint_blk_ids: List[int] = None,
    ):
        """
        Inputs
        - tokens: data of shape N x L x D (or L x N x D depending on the attention implementation)
        - attn: mask of shape L x L
        - use_checkpoint: run the selected blocks under activation checkpointing
        - checkpoint_every_n: when no ids are given, every n-th block
          (starting with block 0) is checkpointed
        - checkpoint_blk_ids: explicit indices of blocks to checkpoint

        Output
        - x: data of shape N x L x D (or L x N x D depending on the attention implementation)
        """
        if self.pre_transformer_layer:
            tokens = self.pre_transformer_layer(tokens)

        if use_checkpoint and checkpoint_blk_ids is None:
            checkpoint_blk_ids = [
                idx
                for idx in range(len(self.blocks))
                if idx % checkpoint_every_n == 0
            ]
        if checkpoint_blk_ids:
            # set for fast membership tests inside the loop
            checkpoint_blk_ids = set(checkpoint_blk_ids)

        for blk_id, blk in enumerate(self.blocks):
            run_checkpointed = use_checkpoint and blk_id in checkpoint_blk_ids
            if run_checkpointed:
                tokens = checkpoint.checkpoint(
                    blk, tokens, attn_mask, use_reentrant=False
                )
            else:
                tokens = blk(tokens, attn_mask=attn_mask)

        if self.post_transformer_layer:
            tokens = self.post_transformer_layer(tokens)
        return tokens
def get_sinusoid_encoding_table(n_position, d_hid):
    """Build a sinusoidal position table as a float tensor of shape (1, n_position, d_hid).

    Even channels hold the sine and odd channels the cosine of
    ``position / 10000 ** (2 * (channel // 2) / d_hid)``.
    """
    # TODO: make it with torch instead of numpy
    # The per-channel divisors do not depend on the position.
    divisors = [np.power(10000, 2 * (channel // 2) / d_hid) for channel in range(d_hid)]
    table = np.array(
        [[position / divisor for divisor in divisors] for position in range(n_position)]
    )
    table[:, 0::2] = np.sin(table[:, 0::2])  # dim 2i
    table[:, 1::2] = np.cos(table[:, 1::2])  # dim 2i+1
    # Leading singleton dimension is added before returning.
    return torch.FloatTensor(table).unsqueeze(0)
def interpolate_pos_encoding_2d(target_spatial_size, pos_embed):
    """Bicubically resize a (1, N, dim) positional embedding to a new token count.

    The N source tokens are reshaped into an ``int(sqrt(N))``-sided square
    grid and scaled by ``sqrt(target_spatial_size / N)`` per side. When N
    already equals ``target_spatial_size`` the input is returned unchanged.
    """
    num_src_tokens = pos_embed.shape[1]
    if num_src_tokens == target_spatial_size:
        return pos_embed

    dim = pos_embed.shape[-1]
    # nn.functional.interpolate doesn't work with bfloat16 so we cast to float32
    pos_embed, was_cast = cast_if_src_dtype(pos_embed, torch.bfloat16, torch.float32)

    side = int(math.sqrt(num_src_tokens))
    # (1, N, dim) -> (1, dim, side, side), the layout interpolate() expects
    grid = pos_embed.reshape(1, side, side, dim).permute(0, 3, 1, 2)
    pos_embed = nn.functional.interpolate(
        grid,
        scale_factor=math.sqrt(target_spatial_size / num_src_tokens),
        mode="bicubic",
    )

    if was_cast:
        # restore the caller's original bfloat16 dtype
        pos_embed, _ = cast_if_src_dtype(pos_embed, torch.float32, torch.bfloat16)
    # back to (1, tokens, dim)
    return pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
def interpolate_pos_encoding(
    npatch_per_img,
    pos_embed,
    patches_layout,
    input_shape=None,
    first_patch_idx=1,
):
    """Resize the patch part of ``pos_embed`` to ``npatch_per_img`` tokens.

    The first ``first_patch_idx`` entries (the CLS embedding, if any) are
    kept as-is and re-attached in front of the interpolated patch
    embeddings. If the patch count already matches, ``pos_embed`` is
    returned unchanged.

    Raises:
        ValueError: if ``patches_layout[0]`` is neither 1 nor greater than 1
            (and ``input_shape`` is given).
    """
    assert first_patch_idx in (0, 1), "there is 1 CLS token or none"
    num_patch_embeds = pos_embed.shape[1] - first_patch_idx  # since it's 1 if cls_token exists
    if npatch_per_img == num_patch_embeds:
        return pos_embed

    # NOTE: a check that the layout is square used to live here but is
    # disabled:
    # assert (
    #     patches_layout[-1] == patches_layout[-2]
    # ), "Interpolation of pos embed not supported for non-square layouts"

    class_emb = pos_embed[:, :first_patch_idx]
    patch_emb = pos_embed[:, first_patch_idx:]

    spatial_only = input_shape is None or patches_layout[0] == 1
    if spatial_only:
        # simple 2D pos embedding, no temporal component
        patch_emb = interpolate_pos_encoding_2d(npatch_per_img, patch_emb)
    elif patches_layout[0] > 1:
        # pos embed has a temporal component
        assert len(input_shape) == 4, "temporal interpolation not supported"
        # we only support 2D interpolation in this case
        num_frames = patches_layout[0]
        num_spatial_tokens = patches_layout[1] * patches_layout[2]
        per_frame = patch_emb.view(1, num_frames, num_spatial_tokens, -1)
        # interpolate embedding for zeroth frame
        patch_emb = interpolate_pos_encoding_2d(
            npatch_per_img, per_frame[0, 0, ...].unsqueeze(0)
        )
    else:
        raise ValueError("This type of interpolation isn't implemented")

    return torch.cat((class_emb, patch_emb), dim=1)
def _get_pos_embedding(
    npatch_per_img,
    pos_embed,
    patches_layout,
    input_shape,
    first_patch_idx=1,
):
    """Forward all arguments to ``interpolate_pos_encoding`` and return its result."""
    return interpolate_pos_encoding(
        npatch_per_img,
        pos_embed,
        patches_layout,
        input_shape=input_shape,
        first_patch_idx=first_patch_idx,
    )
class VerboseNNModule(nn.Module):
    """
    Wrapper around nn.Module that prints registered buffers and parameter names.
    """

    @staticmethod
    def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str:
        """Format one ``(name): tensor(shape, requires_grad=...)`` line.

        NOTE: despite the annotation, ``tensor`` is indexed with ``[1]``, i.e.
        callers pass the ``(qualified_name, tensor)`` pair yielded by
        ``named_parameters()`` / ``named_buffers()``.
        """
        value = tensor[1]
        return (
            f"({name}): tensor({tuple(value.shape)}, "
            f"requires_grad={value.requires_grad})\n"
        )

    def extra_repr(self) -> str:
        """List parameters not owned by a submodule, followed by all buffers."""
        submodule_names = {qualified for qualified, _ in self.named_modules()}

        lines = []
        for entry in self.named_parameters():
            top_level = entry[0].split(".")[0]
            # Parameters living inside a named submodule are skipped here.
            if top_level not in submodule_names:
                lines.append(self.get_readable_tensor_repr(top_level, entry))
        for entry in self.named_buffers():
            top_level = entry[0].split(".")[0]
            lines.append(self.get_readable_tensor_repr(top_level, entry))
        return "".join(lines)
class PatchEmbedGeneric(nn.Module):
    """
    PatchEmbed from Hydra

    Wraps a projection stem (one module, or several chained in an
    ``nn.Sequential``) plus an optional norm layer applied to the flattened
    tokens.
    """

    def __init__(self, proj_stem, norm_layer: Optional[nn.Module] = None):
        super().__init__()

        if len(proj_stem) > 1:
            self.proj = nn.Sequential(*proj_stem)
        else:
            # Special case to be able to load pre-trained models that were
            # trained with a standard stem
            self.proj = proj_stem[0]
        self.norm_layer = norm_layer

    def get_patch_layout(self, img_size):
        """Probe the stem with a zero input to discover its output layout.

        Args:
            img_size: input size without the batch dimension; any sequence
                (list or tuple) is accepted.

        Returns:
            ``(patches_layout, num_patches, embed_dim)`` where
            ``patches_layout`` is the stem output shape after the channel
            dimension and ``num_patches`` is its product.
        """
        with torch.no_grad():
            # Unpack instead of ``[1] + img_size`` so that tuples work too
            # (list + tuple concatenation raises TypeError).
            dummy_img = torch.zeros([1, *img_size])
            dummy_out = self.proj(dummy_img)
        embed_dim = dummy_out.shape[1]
        patches_layout = tuple(dummy_out.shape[2:])
        num_patches = np.prod(patches_layout)
        return patches_layout, num_patches, embed_dim

    def forward(self, x):
        """Project ``x`` and flatten everything after the channel dim into tokens."""
        x = self.proj(x)
        # B C (T) H W -> B (T)HW C
        x = x.flatten(2).transpose(1, 2)
        if self.norm_layer is not None:
            x = self.norm_layer(x)
        return x
class SpatioTemporalPosEmbeddingHelper(VerboseNNModule):
    """Owns a positional embedding (learnable or sinusoidal) for CLS + patch tokens."""

    def __init__(
        self,
        patches_layout: List,
        num_patches: int,
        num_cls_tokens: int,
        embed_dim: int,
        learnable: bool,
    ) -> None:
        super().__init__()
        self.num_cls_tokens = num_cls_tokens
        self.patches_layout = patches_layout
        self.num_patches = num_patches
        self.num_tokens = num_cls_tokens + num_patches
        self.learnable = learnable

        if not self.learnable:
            # Fixed sinusoidal table, stored as a buffer rather than a parameter.
            table = get_sinusoid_encoding_table(self.num_tokens, embed_dim)
            self.register_buffer("pos_embed", table)
        else:
            self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
            trunc_normal_(self.pos_embed, std=0.02)

    def get_pos_embedding(self, vision_input, all_vision_tokens):
        """Return the positional embedding sized for ``all_vision_tokens``.

        The number of patch tokens is ``all_vision_tokens.size(1)`` minus the
        CLS tokens; the stored embedding is interpolated when it differs.
        """
        num_patch_tokens = all_vision_tokens.size(1) - self.num_cls_tokens
        return _get_pos_embedding(
            num_patch_tokens,
            pos_embed=self.pos_embed,
            patches_layout=self.patches_layout,
            input_shape=vision_input.shape,
            first_patch_idx=self.num_cls_tokens,
        )
class RGBDTPreprocessor(VerboseNNModule):
    """Turn vision and/or depth inputs into token sequences for a trunk.

    Each input goes through its stem, gets CLS tokens prepended, and then
    optionally receives positional and type embeddings. When both inputs are
    given, their token sequences are summed.
    """

    def __init__(
        self,
        rgbt_stem: PatchEmbedGeneric,
        depth_stem: PatchEmbedGeneric,
        img_size: List = (3, 224, 224),
        num_cls_tokens: int = 1,
        pos_embed_fn: Callable = None,
        use_type_embed: bool = False,
        init_param_style: str = "openclip",
    ) -> None:
        super().__init__()
        # The patch layout is probed on whichever stem is available.
        stem = rgbt_stem if rgbt_stem is not None else depth_stem
        (
            self.patches_layout,
            self.num_patches,
            self.embed_dim,
        ) = stem.get_patch_layout(img_size)
        self.rgbt_stem = rgbt_stem
        self.depth_stem = depth_stem
        self.use_pos_embed = pos_embed_fn is not None
        self.use_type_embed = use_type_embed
        self.num_cls_tokens = num_cls_tokens

        if self.use_pos_embed:
            self.pos_embedding_helper = pos_embed_fn(
                patches_layout=self.patches_layout,
                num_cls_tokens=num_cls_tokens,
                num_patches=self.num_patches,
                embed_dim=self.embed_dim,
            )
        if self.num_cls_tokens > 0:
            self.cls_token = nn.Parameter(
                torch.zeros(1, self.num_cls_tokens, self.embed_dim)
            )
        if self.use_type_embed:
            self.type_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim))

        self.init_parameters(init_param_style)

    @torch.no_grad()
    def init_parameters(self, init_param_style):
        """Initialise pos/CLS/type embeddings.

        Raises:
            ValueError: if ``init_param_style`` is not "openclip" or "vit".
        """
        if init_param_style == "openclip":
            # OpenCLIP style initialization
            scale = self.embed_dim ** -0.5
            if self.use_pos_embed:
                nn.init.normal_(self.pos_embedding_helper.pos_embed)
                self.pos_embedding_helper.pos_embed *= scale

            if self.num_cls_tokens > 0:
                nn.init.normal_(self.cls_token)
                self.cls_token *= scale
        elif init_param_style == "vit":
            # cls_token only exists when CLS tokens were requested; guard so
            # that num_cls_tokens == 0 does not raise AttributeError.
            if self.num_cls_tokens > 0:
                self.cls_token.data.fill_(0)
        else:
            raise ValueError(f"Unknown init {init_param_style}")

        if self.use_type_embed:
            nn.init.normal_(self.type_embed)

    def get_pos_emb_2(self, input, stem):
        """Alternative pos-embedding resize driven by the stem's output size.

        Indexes position 0 as the CLS embedding and ``1:`` as patches, so it
        only fits a positional embedding with exactly one leading CLS entry.
        """
        patches = stem.proj(input)
        target_size = patches.shape[-2:]
        original_size = list(self.pos_embedding_helper.patches_layout)[-2:]
        orig_ce = self.pos_embedding_helper.pos_embed[:, 0, :]
        orig_pe = (
            self.pos_embedding_helper.pos_embed[:, 1:, :]
            .reshape(1, *original_size, self.embed_dim)
            .permute(0, 3, 1, 2)
        )
        new_pe = F.interpolate(orig_pe, size=target_size, mode="bicubic")
        new_full_pe = torch.cat(
            [
                orig_ce.unsqueeze(1),
                new_pe.permute(0, 2, 3, 1).reshape(1, -1, self.embed_dim),
            ],
            dim=1,
        )
        return new_full_pe

    def tokenize_input_and_cls_pos(self, input, stem, mask):
        """Tokenize ``input`` with ``stem`` and add CLS / pos / type embeddings.

        ``mask`` is accepted but not used by this method.
        """
        # tokens is of shape B x L x D
        tokens = stem(input)
        assert tokens.ndim == 3
        assert tokens.shape[2] == self.embed_dim
        B = tokens.shape[0]
        if self.num_cls_tokens > 0:
            class_tokens = self.cls_token.expand(
                B, -1, -1
            )  # stole class_tokens impl from Phil Wang, thanks
            tokens = torch.cat((class_tokens, tokens), dim=1)
        if self.use_pos_embed:
            pos_embed = self.pos_embedding_helper.get_pos_embedding(input, tokens)
            # pos_embed = self.get_pos_emb_2(input, stem)
            tokens = tokens + pos_embed
        if self.use_type_embed:
            tokens = tokens + self.type_embed.expand(B, -1, -1)
        return tokens

    def forward(self, vision=None, depth=None, patch_mask=None):
        """Tokenize the given modalities.

        Returns:
            ``{"trunk": {"tokens": ...}, "head": {}}``.

        Raises:
            NotImplementedError: if ``patch_mask`` is given.
            ValueError: if neither ``vision`` nor ``depth`` is given.
        """
        if patch_mask is not None:
            raise NotImplementedError()
        if vision is None and depth is None:
            # Previously this fell through to an UnboundLocalError.
            raise ValueError("At least one of `vision` or `depth` must be provided")

        if vision is not None:
            vision_tokens = self.tokenize_input_and_cls_pos(
                vision, self.rgbt_stem, patch_mask
            )

        if depth is not None:
            depth_tokens = self.tokenize_input_and_cls_pos(
                depth, self.depth_stem, patch_mask
            )

        # aggregate tokens
        if vision is not None and depth is not None:
            final_tokens = vision_tokens + depth_tokens
        else:
            final_tokens = vision_tokens if vision is not None else depth_tokens
        return_dict = {
            "trunk": {
                "tokens": final_tokens,
            },
            "head": {},
        }
        return return_dict
class AudioPreprocessor(RGBDTPreprocessor):
    """``RGBDTPreprocessor`` that routes audio through the rgbt-stem slot."""

    def __init__(self, audio_stem: PatchEmbedGeneric, **kwargs) -> None:
        # No depth stem: the audio stem is the only tokenizer.
        super().__init__(rgbt_stem=audio_stem, depth_stem=None, **kwargs)

    def forward(self, audio=None):
        # The parent processes the audio input as its ``vision`` argument.
        tokens = super().forward(vision=audio)
        return tokens
class ThermalPreprocessor(RGBDTPreprocessor):
    """``RGBDTPreprocessor`` that routes thermal input through the rgbt-stem slot."""

    def __init__(self, thermal_stem: PatchEmbedGeneric, **kwargs) -> None:
        # No depth stem: the thermal stem is the only tokenizer.
        super().__init__(rgbt_stem=thermal_stem, depth_stem=None, **kwargs)

    def forward(self, thermal=None):
        # The parent processes the thermal input as its ``vision`` argument.
        tokens = super().forward(vision=thermal)
        return tokens
def build_causal_attention_mask(context_length):
    """Return an additive causal mask of shape (context_length, context_length).

    Entries strictly above the main diagonal are ``-inf``; the diagonal and
    everything below it are zero.
    """
    # pytorch uses additive attention mask; start from all -inf ...
    mask = torch.full((context_length, context_length), float("-inf"))
    # ... then keep only the strict upper triangle (the rest becomes 0).
    return mask.triu_(1)
class TextPreprocessor(VerboseNNModule):
    """Embed token ids, add positional embeddings and optional CLS tokens.

    With ``causal_masking`` a causal attention mask is stored as a buffer and
    handed to the trunk; with ``supply_seq_len_to_head`` the per-row argmax
    of the token ids is handed to the head as ``seq_len``.
    """

    def __init__(
        self,
        vocab_size: int,
        context_length: int,
        embed_dim: int,
        causal_masking: bool,
        supply_seq_len_to_head: bool = True,
        num_cls_tokens: int = 0,
        init_param_style: str = "openclip",
    ) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Parameter(
            torch.empty(1, self.context_length + num_cls_tokens, embed_dim)
        )
        self.causal_masking = causal_masking
        if self.causal_masking:
            mask = build_causal_attention_mask(self.context_length)
            # register the mask as a buffer, so it can be moved to the right device
            self.register_buffer("mask", mask)

        self.supply_seq_len_to_head = supply_seq_len_to_head
        self.num_cls_tokens = num_cls_tokens
        self.embed_dim = embed_dim
        if num_cls_tokens > 0:
            assert self.causal_masking is False, "Masking + CLS token isn't implemented"
            self.cls_token = nn.Parameter(
                torch.zeros(1, self.num_cls_tokens, embed_dim)
            )

        self.init_parameters(init_param_style)

    @torch.no_grad()
    def init_parameters(self, init_param_style="openclip"):
        """Initialise token, positional and CLS embeddings.

        Raises:
            ValueError: if ``init_param_style`` is not "openclip" or "vit".
        """
        # OpenCLIP style initialization
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.01)

        if init_param_style == "openclip":
            # OpenCLIP style initialization
            scale = self.embed_dim ** -0.5
            if self.num_cls_tokens > 0:
                nn.init.normal_(self.cls_token)
                self.cls_token *= scale
        elif init_param_style == "vit":
            # cls_token only exists when num_cls_tokens > 0 (default is 0);
            # guard so that "vit" init does not raise AttributeError.
            if self.num_cls_tokens > 0:
                self.cls_token.data.fill_(0)
        else:
            raise ValueError(f"Unknown init {init_param_style}")

    def forward(self, text):
        """Return ``{"trunk": {"tokens", ["attn_mask"]}, "head": {["seq_len"]}}``."""
        # text tokens are of shape B x L x D
        text_tokens = self.token_embedding(text)
        # concat CLS tokens if any
        if self.num_cls_tokens > 0:
            B = text_tokens.shape[0]
            class_tokens = self.cls_token.expand(
                B, -1, -1
            )  # stole class_tokens impl from Phil Wang, thanks
            text_tokens = torch.cat((class_tokens, text_tokens), dim=1)
        text_tokens = text_tokens + self.pos_embed
        return_dict = {
            "trunk": {
                "tokens": text_tokens,
            },
            "head": {},
        }
        if self.supply_seq_len_to_head:
            # Position of the largest token id in each row of ``text``.
            text_lengths = text.argmax(dim=-1)
            return_dict["head"] = {
                "seq_len": text_lengths,
            }
        if self.causal_masking:
            return_dict["trunk"].update({"attn_mask": self.mask})
        return return_dict
class Im2Video(nn.Module):
    """Convert an image into a trivial video."""

    def __init__(self, time_dim=2):
        super().__init__()
        self.time_dim = time_dim

    def forward(self, x):
        """Insert a singleton time axis into 4-D input; pass 5-D input through.

        Raises:
            ValueError: for any other number of dimensions.
        """
        if x.ndim == 5:
            return x
        if x.ndim == 4:
            # B, C, H, W -> B, C, T, H, W
            return x.unsqueeze(self.time_dim)
        raise ValueError(f"Dimension incorrect {x.shape}")
class PadIm2Video(Im2Video):
    """Expand a single-frame video to ``ntimes`` frames.

    ``pad_type`` selects how the extra frames are produced: "repeat" copies
    the single frame, "zero" appends zero frames after it. Inputs whose time
    axis is not of length 1 are returned unchanged.
    """

    def __init__(self, ntimes, pad_type, time_dim=2):
        super().__init__(time_dim=time_dim)
        assert ntimes > 0
        assert pad_type in ["zero", "repeat"]
        self.ntimes = ntimes
        self.pad_type = pad_type

    def forward(self, x):
        x = super().forward(x)
        if x.shape[self.time_dim] != 1:
            return x

        if self.pad_type == "repeat":
            repeats = [1] * x.ndim
            repeats[self.time_dim] = self.ntimes
            x = x.repeat(repeats)
        elif self.pad_type == "zero":
            # F.pad lists (left, right) pairs starting from the LAST
            # dimension, so the right-pad slot of axis d is
            # 2 * (ndim - 1 - d) + 1. The previous ``2 * time_dim + 1`` only
            # coincided with this for time_dim == 2 on 5-D input.
            time_axis = self.time_dim % x.ndim
            padarg = [0, 0] * x.ndim
            padarg[2 * (x.ndim - 1 - time_axis) + 1] = self.ntimes - x.shape[time_axis]
            x = nn.functional.pad(x, padarg)
        return x
# Modified from github.com/openai/CLIP
@lru_cache()
def bytes_to_unicode():
    """
    Return a dict mapping every byte value (0-255) to a one-character string.

    The reversible BPE codes operate on unicode strings, so each byte needs a
    character that is neither whitespace nor a control character (the BPE
    code chokes on those). Bytes in the ranges "!".."~", "¡".."¬" and
    "®".."ÿ" map to themselves; every remaining byte is assigned, in
    increasing order, the next code point starting at 256. Using this lookup
    table avoids spending a large share of the BPE vocabulary on raw unicode
    characters just to prevent UNKs.
    """
    self_mapped = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    byte_values = list(self_mapped)
    code_points = list(self_mapped)
    already_mapped = set(self_mapped)

    shifted = 0
    for byte in range(2 ** 8):
        if byte in already_mapped:
            continue
        byte_values.append(byte)
        code_points.append(2 ** 8 + shifted)
        shifted += 1

    return {byte: chr(point) for byte, point in zip(byte_values, code_points)}
def get_pairs(word):
    """Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length strings).
    Each pair is two adjacent symbols ``(word[i], word[i + 1])``. A word with
    fewer than two symbols (including an empty one, which used to raise
    IndexError) yields an empty set.
    """
    return set(zip(word, word[1:]))
def basic_clean(text):
    """Run ``ftfy.fix_text``, unescape HTML entities twice, and strip the ends."""
    repaired = ftfy.fix_text(text)
    # Two passes, so doubly-escaped entities are resolved as well.
    unescaped = html.unescape(html.unescape(repaired))
    return unescaped.strip()
def whitespace_clean(text):
    """Collapse every whitespace run to one space and strip both ends."""
    return re.sub(r"\s+", " ", text).strip()
class SimpleTokenizer(object):
    """Byte-level BPE tokenizer that produces fixed-length id tensors."""

    def __init__(self, bpe_path: str, context_length=77):
        """Load the gzipped merges file at ``bpe_path`` and build the vocab."""
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

        with g_pathmgr.open(bpe_path, "rb") as fh:
            bpe_bytes = io.BytesIO(fh.read())
        # Close the gzip reader deterministically once the merges are read.
        with gzip.open(bpe_bytes) as gz:
            merges = gz.read().decode("utf-8").split("\n")
        # Skip the first line and keep exactly as many merges as the vocab needs.
        merges = merges[1: 49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]

        # Vocab order: byte symbols, end-of-word byte symbols, merges, specials.
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v + "</w>" for v in vocab]
        for merge in merges:
            vocab.append("".join(merge))
        vocab.extend(["<|startoftext|>", "<|endoftext|>"])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # Special tokens are pre-seeded so bpe() returns them unsplit.
        self.cache = {
            "<|startoftext|>": "<|startoftext|>",
            "<|endoftext|>": "<|endoftext|>",
        }
        self.pat = re.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE,
        )
        self.context_length = context_length

    def bpe(self, token):
        """Apply the learned merges to ``token``; return space-joined symbols."""
        if token in self.cache:
            return self.cache[token]
        # The last symbol carries the end-of-word marker.
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        pairs = get_pairs(word)

        if not pairs:
            return token + "</w>"

        while True:
            # Merge the adjacent pair with the lowest rank first.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # ``first`` does not occur again: copy the tail and stop.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean and lowercase ``text``, then return its list of BPE token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            # Map raw UTF-8 bytes to their printable stand-in characters.
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(
                self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
            )
        return bpe_tokens

    def decode(self, tokens):
        """Turn token ids back into text; ``</w>`` markers become spaces."""
        text = "".join([self.decoder[token] for token in tokens])
        text = (
            bytearray([self.byte_decoder[c] for c in text])
            .decode("utf-8", errors="replace")
            .replace("</w>", " ")
        )
        return text

    def __call__(self, texts, context_length=None):
        """Tokenize one string or a list of strings into a LongTensor.

        Each row is ``<|startoftext|>`` + ids + ``<|endoftext|>``, cut to
        ``context_length`` and zero-padded on the right. A single text yields
        a 1-D tensor, several texts a 2-D tensor.
        """
        if not context_length:
            context_length = self.context_length

        if isinstance(texts, str):
            texts = [texts]

        sot_token = self.encoder["<|startoftext|>"]
        eot_token = self.encoder["<|endoftext|>"]
        all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            # NOTE(review): plain truncation drops the end-of-text token of
            # over-long inputs - confirm this is what consumers expect.
            tokens = tokens[:context_length]
            result[i, : len(tokens)] = torch.tensor(tokens)

        if len(result) == 1:
            return result[0]
        return result
class Normalize(nn.Module):
    """L2-normalise the input along a fixed dimension."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return nn.functional.normalize(x, p=2, dim=self.dim)
class LearnableLogitScaling(nn.Module):
    """Multiply the input by ``exp(log_logit_scale)``, capped at ``max_logit_scale``.

    The scale is stored in log space, either as a trainable parameter
    (``learnable=True``) or as a buffer.
    """

    def __init__(
        self,
        logit_scale_init: float = 1 / 0.07,
        learnable: bool = True,
        max_logit_scale: float = 100,
    ) -> None:
        super().__init__()
        self.max_logit_scale = max_logit_scale
        self.logit_scale_init = logit_scale_init
        self.learnable = learnable

        initial_log_scale = torch.ones([]) * np.log(self.logit_scale_init)
        if not learnable:
            self.register_buffer("log_logit_scale", initial_log_scale)
        else:
            self.log_logit_scale = nn.Parameter(initial_log_scale)

    def forward(self, x):
        # Exponentiate first, then cap the scale at the configured maximum.
        capped_scale = torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale)
        return capped_scale * x

    def extra_repr(self):
        return f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}, max_logit_scale={self.max_logit_scale}"
class EinOpsRearrange(nn.Module):
    """Module wrapper around ``einops.rearrange`` with a fixed expression."""

    def __init__(self, rearrange_expr: str, **kwargs) -> None:
        super().__init__()
        self.rearrange_expr = rearrange_expr
        # Extra keyword arguments are forwarded to einops on every call.
        self.kwargs = kwargs

    def forward(self, x):
        assert isinstance(x, torch.Tensor)
        rearranged = einops.rearrange(x, self.rearrange_expr, **self.kwargs)
        return rearranged
class IMUPreprocessor(VerboseNNModule):
    """Split an IMU signal into fixed-size patches and tokenize them.

    The last input dimension is cut into non-overlapping windows of
    ``kernel_size``; each window (across the preceding dimension) is
    flattened into one patch and passed through the stem's projection and
    norm layer.
    """

    def __init__(
        self,
        kernel_size: int,
        imu_stem: PatchEmbedGeneric,
        embed_dim: int,
        img_size: List = (6, 2000),
        num_cls_tokens: int = 1,
        pos_embed_fn: Callable = None,
        init_param_style: str = "openclip",
    ) -> None:
        super().__init__()
        self.imu_stem = imu_stem
        self.embed_dim = embed_dim
        # pos_embed_fn is only used as an on/off switch here; the embedding
        # itself is the parameter created below.
        self.use_pos_embed = pos_embed_fn is not None
        self.num_cls_tokens = num_cls_tokens
        self.kernel_size = kernel_size
        self.pos_embed = nn.Parameter(
            torch.empty(1, (img_size[1] // kernel_size) + num_cls_tokens, embed_dim)
        )

        if self.num_cls_tokens > 0:
            self.cls_token = nn.Parameter(
                torch.zeros(1, self.num_cls_tokens, self.embed_dim)
            )

        self.init_parameters(init_param_style)

    @torch.no_grad()
    def init_parameters(self, init_param_style):
        """Initialise the positional and CLS embeddings.

        Raises:
            ValueError: if ``init_param_style`` is not "openclip" or "vit".
        """
        nn.init.normal_(self.pos_embed, std=0.01)

        if init_param_style == "openclip":
            # OpenCLIP style initialization
            scale = self.embed_dim ** -0.5

            if self.num_cls_tokens > 0:
                nn.init.normal_(self.cls_token)
                self.cls_token *= scale
        elif init_param_style == "vit":
            # cls_token only exists when CLS tokens were requested; guard so
            # that num_cls_tokens == 0 does not raise AttributeError.
            if self.num_cls_tokens > 0:
                self.cls_token.data.fill_(0)
        else:
            raise ValueError(f"Unknown init {init_param_style}")

    def tokenize_input_and_cls_pos(self, input, stem):
        """Project and normalise patches, then add CLS tokens and pos embedding."""
        # tokens is of shape B x L x D
        # NOTE: calls stem.norm_layer directly, so the stem must have one.
        tokens = stem.norm_layer(stem.proj(input))
        assert tokens.ndim == 3
        assert tokens.shape[2] == self.embed_dim
        B = tokens.shape[0]
        if self.num_cls_tokens > 0:
            class_tokens = self.cls_token.expand(
                B, -1, -1
            )  # stole class_tokens impl from Phil Wang, thanks
            tokens = torch.cat((class_tokens, tokens), dim=1)
        if self.use_pos_embed:
            tokens = tokens + self.pos_embed
        return tokens

    def forward(self, imu):
        """Return ``{"trunk": {"tokens": ...}, "head": {}}`` for an IMU batch."""
        # Patchify: windows of kernel_size along the last dim, then move the
        # window index in front of dim 1 and flatten each patch.
        imu = imu.unfold(
            -1,
            self.kernel_size,
            self.kernel_size,
        ).permute(0, 2, 1, 3)
        imu = imu.reshape(imu.size(0), imu.size(1), -1)

        imu_tokens = self.tokenize_input_and_cls_pos(
            imu,
            self.imu_stem,
        )

        return_dict = {
            "trunk": {
                "tokens": imu_tokens,
            },
            "head": {},
        }
        return return_dict
def cast_if_src_dtype(
    tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype
):
    """Cast ``tensor`` to ``tgt_dtype`` only when its dtype is ``src_dtype``.

    Returns:
        ``(tensor, updated)`` where ``updated`` tells whether a cast happened.
    """
    if tensor.dtype != src_dtype:
        return tensor, False
    return tensor.to(dtype=tgt_dtype), True
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    # From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166
    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
class SelectElement(nn.Module):
    """Pick one position along dimension 1 of an input with at least 3 dims."""

    def __init__(self, index) -> None:
        super().__init__()
        self.index = index

    def forward(self, x):
        assert x.ndim >= 3
        selected = x[:, self.index, ...]
        return selected
class ReshapeSpatial(nn.Module):
    """Split tokens into a spatial grid and the token at position 0."""

    def __init__(self, shape) -> None:
        super().__init__()
        self.h, self.w = shape

    def forward(self, x):
        """Return ``(grid, first_token)``.

        ``grid`` holds the tokens from position 1 onward reshaped to
        ``(B, h, w, -1)``; ``first_token`` is ``x[:, 0, :]``.
        """
        assert x.ndim >= 3
        first_token = x[:, 0, :]
        remaining = x[:, 1:, ...]
        grid = remaining.reshape(x.shape[0], self.h, self.w, -1)
        return grid, first_token
class ReshapeAudio(nn.Module):
    """Split tokens into grouped spatial grids and the token at position 0.

    Args:
        shape: ``(h, w)`` of the grid formed by the tokens after position 0.
        group_size: size of the second output dimension; the leading input
            dimension is folded into ``(-1, group_size)``. Defaults to 5, the
            value that used to be hard-coded (presumably the number of audio
            clips per sample - confirm against the caller).
    """

    def __init__(self, shape, group_size: int = 5) -> None:
        super().__init__()
        self.h, self.w = shape
        self.group_size = group_size

    def forward(self, x):
        """Return ``(grids, first_token)``.

        ``grids`` has shape ``(-1, group_size, h, w, D)`` and ``first_token``
        is ``x[:, 0, :]``.
        """
        assert x.ndim == 3
        grids = x[:, 1:, :].reshape(-1, self.group_size, self.h, self.w, x.shape[-1])
        return grids, x[:, 0, :]
class ApplyTwice(nn.Module):
    """Apply the same wrapped module to both elements of a pair."""

    def __init__(self, module) -> None:
        super().__init__()
        self.module = module

    def forward(self, pair):
        first_out = self.module(pair[0])
        second_out = self.module(pair[1])
        return first_out, second_out
class SelectEOSAndProject(nn.Module):
    """
    Text Pooling used in OpenCLIP
    """

    def __init__(self, proj: nn.Module) -> None:
        super().__init__()
        self.proj = proj

    def forward(self, x, seq_len):
        """Pick one token per row at index ``seq_len`` and project it."""
        assert x.ndim == 3
        # x is of shape B x L x D
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        row_ids = torch.arange(x.shape[0])
        pooled = x[row_ids, seq_len]
        return self.proj(pooled)
# String identifiers for the supported modalities, exposed as attributes
# (e.g. ``ModalityType.VISION == "vision"``).
ModalityType = SimpleNamespace(
VISION="vision",
TEXT="text",
AUDIO="audio",
THERMAL="thermal",
DEPTH="depth",
IMU="imu",
)
class ImageBindModel(nn.Module):
    """Multi-modal encoder built from per-modality module dictionaries.

    Every enabled modality has four stages, each stored in an ``nn.ModuleDict``
    keyed by the modality string: a preprocessor, a transformer trunk, a head
    and a postprocessor. In this file only the vision, text and audio branches
    are instantiated; the depth, thermal and IMU branches are kept as
    commented-out code, so the corresponding constructor arguments are accepted
    (and forwarded to the builders) but currently have no effect.
    """

    def __init__(
        self,
        video_frames=2,
        kernel_size=(2, 14, 14),
        audio_kernel_size=16,
        audio_stride=10,
        out_embed_dim=768,
        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,
        audio_embed_dim=768,
        audio_num_blocks=12,
        audio_num_heads=12,
        audio_num_mel_bins=128,
        audio_target_len=204,
        audio_drop_path=0.1,
        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,
        depth_embed_dim=384,
        depth_kernel_size=16,
        depth_num_blocks=12,
        depth_num_heads=8,
        depth_drop_path=0.0,
        thermal_embed_dim=768,
        thermal_kernel_size=16,
        thermal_num_blocks=12,
        thermal_num_heads=12,
        thermal_drop_path=0.0,
        imu_embed_dim=512,
        imu_kernel_size=8,
        imu_num_blocks=6,
        imu_num_heads=8,
        imu_drop_path=0.7,
    ):
        """Build the preprocessors, trunks, heads and postprocessors.

        ``out_embed_dim`` is the output width of every head; the
        ``<modality>_embed_dim`` / ``_num_blocks`` / ``_num_heads`` /
        ``_drop_path`` arguments size the matching trunk. ``imu_kernel_size``
        is accepted but not forwarded to any builder.
        """
        super().__init__()

        self.modality_preprocessors = self._create_modality_preprocessors(
            video_frames,
            vision_embed_dim,
            kernel_size,
            text_embed_dim,
            audio_embed_dim,
            audio_kernel_size,
            audio_stride,
            audio_num_mel_bins,
            audio_target_len,
            depth_embed_dim,
            depth_kernel_size,
            thermal_embed_dim,
            thermal_kernel_size,
            imu_embed_dim,
        )

        self.modality_trunks = self._create_modality_trunks(
            vision_embed_dim,
            vision_num_blocks,
            vision_num_heads,
            text_embed_dim,
            text_num_blocks,
            text_num_heads,
            audio_embed_dim,
            audio_num_blocks,
            audio_num_heads,
            audio_drop_path,
            depth_embed_dim,
            depth_num_blocks,
            depth_num_heads,
            depth_drop_path,
            thermal_embed_dim,
            thermal_num_blocks,
            thermal_num_heads,
            thermal_drop_path,
            imu_embed_dim,
            imu_num_blocks,
            imu_num_heads,
            imu_drop_path,
        )

        self.modality_heads = self._create_modality_heads(
            out_embed_dim,
            vision_embed_dim,
            text_embed_dim,
            audio_embed_dim,
            depth_embed_dim,
            thermal_embed_dim,
            imu_embed_dim,
        )

        self.modality_postprocessors = self._create_modality_postprocessors(
            out_embed_dim
        )

    def _create_modality_preprocessors(
        self,
        video_frames=2,
        vision_embed_dim=1024,
        kernel_size=(2, 14, 14),
        text_embed_dim=768,
        audio_embed_dim=768,
        audio_kernel_size=16,
        audio_stride=10,
        audio_num_mel_bins=128,
        audio_target_len=204,
        depth_embed_dim=768,
        depth_kernel_size=16,
        thermal_embed_dim=768,
        thermal_kernel_size=16,
        imu_embed_dim=512,
    ):
        """Return an ``nn.ModuleDict`` with the vision, text and audio preprocessors."""
        # Vision stem: PadIm2Video followed by a Conv3d whose stride equals its
        # kernel size (non-overlapping patches).
        rgbt_stem = PatchEmbedGeneric(
            proj_stem=[
                PadIm2Video(pad_type="repeat", ntimes=2),
                nn.Conv3d(
                    in_channels=3,
                    kernel_size=kernel_size,
                    out_channels=vision_embed_dim,
                    stride=kernel_size,
                    bias=False,
                ),
            ]
        )
        rgbt_preprocessor = RGBDTPreprocessor(
            img_size=[3, video_frames, 224, 224],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            rgbt_stem=rgbt_stem,
            depth_stem=None,
        )

        text_preprocessor = TextPreprocessor(
            context_length=77,
            vocab_size=49408,
            embed_dim=text_embed_dim,
            causal_masking=True,
        )

        # Audio stem: single-channel Conv2d (stride may be smaller than the
        # kernel, i.e. overlapping patches) followed by a LayerNorm.
        audio_stem = PatchEmbedGeneric(
            proj_stem=[
                nn.Conv2d(
                    in_channels=1,
                    kernel_size=audio_kernel_size,
                    stride=audio_stride,
                    out_channels=audio_embed_dim,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim),
        )
        audio_preprocessor = AudioPreprocessor(
            img_size=[1, audio_num_mel_bins, audio_target_len],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            audio_stem=audio_stem,
        )

        # Disabled modalities, kept for reference:
        # depth_stem = PatchEmbedGeneric(
        #     [
        #         nn.Conv2d(
        #             kernel_size=depth_kernel_size,
        #             in_channels=1,
        #             out_channels=depth_embed_dim,
        #             stride=depth_kernel_size,
        #             bias=False,
        #         ),
        #     ],
        #     norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim),
        # )
        #
        # depth_preprocessor = RGBDTPreprocessor(
        #     img_size=[1, 224, 224],
        #     num_cls_tokens=1,
        #     pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
        #     rgbt_stem=None,
        #     depth_stem=depth_stem,
        # )
        #
        # thermal_stem = PatchEmbedGeneric(
        #     [
        #         nn.Conv2d(
        #             kernel_size=thermal_kernel_size,
        #             in_channels=1,
        #             out_channels=thermal_embed_dim,
        #             stride=thermal_kernel_size,
        #             bias=False,
        #         ),
        #     ],
        #     norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim),
        # )
        # thermal_preprocessor = ThermalPreprocessor(
        #     img_size=[1, 224, 224],
        #     num_cls_tokens=1,
        #     pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
        #     thermal_stem=thermal_stem,
        # )
        #
        # imu_stem = PatchEmbedGeneric(
        #     [
        #         nn.Linear(
        #             in_features=48,
        #             out_features=imu_embed_dim,
        #             bias=False,
        #         ),
        #     ],
        #     norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim),
        # )
        #
        # imu_preprocessor = IMUPreprocessor(
        #     img_size=[6, 2000],
        #     num_cls_tokens=1,
        #     kernel_size=8,
        #     embed_dim=imu_embed_dim,
        #     pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
        #     imu_stem=imu_stem,
        # )

        modality_preprocessors = {
            ModalityType.VISION: rgbt_preprocessor,
            ModalityType.TEXT: text_preprocessor,
            ModalityType.AUDIO: audio_preprocessor,
            # ModalityType.DEPTH: depth_preprocessor,
            # ModalityType.THERMAL: thermal_preprocessor,
            # ModalityType.IMU: imu_preprocessor,
        }

        return nn.ModuleDict(modality_preprocessors)

    def _create_modality_trunks(
        self,
        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,
        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,
        audio_embed_dim=768,
        audio_num_blocks=12,
        audio_num_heads=12,
        audio_drop_path=0.0,
        depth_embed_dim=768,
        depth_num_blocks=12,
        depth_num_heads=12,
        depth_drop_path=0.0,
        thermal_embed_dim=768,
        thermal_num_blocks=12,
        thermal_num_heads=12,
        thermal_drop_path=0.0,
        imu_embed_dim=512,
        imu_num_blocks=6,
        imu_num_heads=8,
        imu_drop_path=0.7,
    ):
        """Return an ``nn.ModuleDict`` with the vision, text and audio trunks."""

        def instantiate_trunk(
            embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path
        ):
            # The trunk is fed "b l d" tensors; they are rearranged to "l b d"
            # before the transformer blocks and back to "b l d" afterwards.
            return SimpleTransformer(
                embed_dim=embed_dim,
                num_blocks=num_blocks,
                ffn_dropout_rate=0.0,
                drop_path_rate=drop_path,
                attn_target=partial(
                    MultiheadAttention,
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    bias=True,
                    add_bias_kv=add_bias_kv,
                ),
                pre_transformer_layer=nn.Sequential(
                    nn.LayerNorm(embed_dim, eps=1e-6)
                    if pre_transformer_ln
                    else nn.Identity(),
                    EinOpsRearrange("b l d -> l b d"),
                ),
                post_transformer_layer=EinOpsRearrange("l b d -> b l d"),
            )

        modality_trunks = {}
        # Only the vision trunk uses a pre-transformer LayerNorm; only the
        # audio trunk uses add_bias_kv and a configurable drop-path rate.
        modality_trunks[ModalityType.VISION] = instantiate_trunk(
            vision_embed_dim,
            vision_num_blocks,
            vision_num_heads,
            pre_transformer_ln=True,
            add_bias_kv=False,
            drop_path=0.0,
        )
        modality_trunks[ModalityType.TEXT] = instantiate_trunk(
            text_embed_dim,
            text_num_blocks,
            text_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=False,
            drop_path=0.0,
        )
        modality_trunks[ModalityType.AUDIO] = instantiate_trunk(
            audio_embed_dim,
            audio_num_blocks,
            audio_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=audio_drop_path,
        )
        # Disabled modalities, kept for reference:
        # modality_trunks[ModalityType.DEPTH] = instantiate_trunk(
        #     depth_embed_dim,
        #     depth_num_blocks,
        #     depth_num_heads,
        #     pre_transformer_ln=False,
        #     add_bias_kv=True,
        #     drop_path=depth_drop_path,
        # )
        # modality_trunks[ModalityType.THERMAL] = instantiate_trunk(
        #     thermal_embed_dim,
        #     thermal_num_blocks,
        #     thermal_num_heads,
        #     pre_transformer_ln=False,
        #     add_bias_kv=True,
        #     drop_path=thermal_drop_path,
        # )
        # modality_trunks[ModalityType.IMU] = instantiate_trunk(
        #     imu_embed_dim,
        #     imu_num_blocks,
        #     imu_num_heads,
        #     pre_transformer_ln=False,
        #     add_bias_kv=True,
        #     drop_path=imu_drop_path,
        # )

        return nn.ModuleDict(modality_trunks)

    def _create_modality_heads(
        self,
        out_embed_dim,
        vision_embed_dim,
        text_embed_dim,
        audio_embed_dim,
        depth_embed_dim,
        thermal_embed_dim,
        imu_embed_dim,
    ):
        """Return an ``nn.ModuleDict`` with the vision, text and audio heads.

        The vision and audio heads are ``Sequential(LayerNorm, SelectElement,
        Linear)``; ``reconfigure_head`` relies on that three-element layout.
        """
        modality_heads = {}

        modality_heads[ModalityType.VISION] = nn.Sequential(
            nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Linear(vision_embed_dim, out_embed_dim, bias=False),
        )

        modality_heads[ModalityType.TEXT] = SelectEOSAndProject(
            proj=nn.Sequential(
                nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6),
                nn.Linear(text_embed_dim, out_embed_dim, bias=False),
            )
        )

        modality_heads[ModalityType.AUDIO] = nn.Sequential(
            nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Linear(audio_embed_dim, out_embed_dim, bias=False),
        )

        # Disabled modalities, kept for reference:
        # modality_heads[ModalityType.DEPTH] = nn.Sequential(
        #     nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6),
        #     SelectElement(index=0),
        #     nn.Linear(depth_embed_dim, out_embed_dim, bias=False),
        # )
        #
        # modality_heads[ModalityType.THERMAL] = nn.Sequential(
        #     nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6),
        #     SelectElement(index=0),
        #     nn.Linear(thermal_embed_dim, out_embed_dim, bias=False),
        # )
        #
        # modality_heads[ModalityType.IMU] = nn.Sequential(
        #     nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6),
        #     SelectElement(index=0),
        #     nn.Dropout(p=0.5),
        #     nn.Linear(imu_embed_dim, out_embed_dim, bias=False),
        # )

        return nn.ModuleDict(modality_heads)

    def _create_modality_postprocessors(self, out_embed_dim):
        """Return an ``nn.ModuleDict`` with the vision, text and audio postprocessors.

        ``out_embed_dim`` is accepted for signature compatibility and is not
        used by the currently enabled postprocessors.
        """
        modality_postprocessors = {}

        modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1)
        modality_postprocessors[ModalityType.TEXT] = nn.Sequential(
            Normalize(dim=-1), LearnableLogitScaling(learnable=True)
        )
        modality_postprocessors[ModalityType.AUDIO] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=20.0, learnable=False),
        )
        # Disabled modalities, kept for reference:
        # modality_postprocessors[ModalityType.DEPTH] = nn.Sequential(
        #     Normalize(dim=-1),
        #     LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
        # )
        # modality_postprocessors[ModalityType.THERMAL] = nn.Sequential(
        #     Normalize(dim=-1),
        #     LearnableLogitScaling(logit_scale_init=10.0, learnable=False),
        # )
        # modality_postprocessors[ModalityType.IMU] = nn.Sequential(
        #     Normalize(dim=-1),
        #     LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
        # )

        return nn.ModuleDict(modality_postprocessors)

    def _encode(self, modality_key, modality_value, head):
        """Run preprocessor -> trunk -> ``head`` -> postprocessor for one modality.

        The preprocessor is called with the input as a keyword argument named
        after the modality and must return a dict with ``"trunk"`` and
        ``"head"`` keyword-argument dicts.
        """
        preprocessed = self.modality_preprocessors[modality_key](
            **{modality_key: modality_value}
        )
        trunk_inputs = preprocessed["trunk"]
        head_inputs = preprocessed["head"]
        tokens = self.modality_trunks[modality_key](**trunk_inputs)
        tokens = head(tokens, **head_inputs)
        return self.modality_postprocessors[modality_key](tokens)

    def forward(self, inputs):
        """Embed every modality in ``inputs``.

        Args:
            inputs: mapping from modality key to input tensor. Entries whose
                value is ``None`` are skipped.

        Returns:
            Dict mapping each processed modality key to its postprocessed head
            output. Inputs with 5 or more dimensions are treated as
            ``(B, S, ...)`` stacks of clips: the first two dimensions are
            merged before encoding and the result is averaged over ``S``.
        """
        outputs = {}
        for modality_key, modality_value in inputs.items():
            # Check for None *before* touching tensor attributes; the original
            # guard came after `.ndim` and could never be reached.
            if modality_value is None:
                continue

            # Audio and Video inputs consist of multiple clips
            reduce_list = modality_value.ndim >= 5
            if reduce_list:
                B, S = modality_value.shape[:2]
                modality_value = modality_value.reshape(
                    B * S, *modality_value.shape[2:]
                )

            modality_value = self._encode(
                modality_key, modality_value, self.modality_heads[modality_key]
            )

            if reduce_list:
                modality_value = modality_value.reshape(B, S, -1)
                modality_value = modality_value.mean(dim=1)

            outputs[modality_key] = modality_value

        return outputs

    def reconfigure_head(self, k, v):
        """Return the head used by ``forward_features`` for modality ``k``.

        For audio and vision the middle ``SelectElement`` stage of the
        three-element head ``v`` is dropped (only ``v[0]`` and ``v[2]`` are
        kept), so the head is applied without the element selection. Any other
        head is returned unchanged. The returned ``nn.Sequential`` shares its
        submodules with ``v``.
        """
        if k in (ModalityType.AUDIO, ModalityType.VISION):
            return torch.nn.Sequential(v[0], v[2])
        return v

    def forward_features(self, inputs):
        """Like ``forward`` but using the reconfigured heads.

        Inputs with 5 or more dimensions have their first two dimensions
        merged, and, unlike ``forward``, the result is not averaged back over
        clips. Audio outputs are passed through ``ReshapeAudio((12, 19))`` and
        vision outputs through ``ReshapeSpatial((16, 16))``; callers in this
        file unpack those results as a ``(patch_tokens, cls_token)`` pair.
        Entries whose value is ``None`` are skipped.
        """
        outputs = {}
        for modality_key, modality_value in inputs.items():
            # Same ordering fix as in forward(): skip None before using `.ndim`.
            if modality_value is None:
                continue

            # Audio and Video inputs consist of multiple clips
            if modality_value.ndim >= 5:
                B, S = modality_value.shape[:2]
                modality_value = modality_value.reshape(
                    B * S, *modality_value.shape[2:]
                )

            head = self.reconfigure_head(
                modality_key, self.modality_heads[modality_key]
            )
            modality_value = self._encode(modality_key, modality_value, head)

            if modality_key == ModalityType.AUDIO:
                modality_value = ReshapeAudio((12, 19))(modality_value)
            elif modality_key == ModalityType.VISION:
                modality_value = ReshapeSpatial((16, 16))(modality_value)

            outputs[modality_key] = modality_value

        return outputs
def imagebind_huge(output_path, pretrained=False):
    """Build the "huge" ImageBindModel configuration, optionally with weights.

    Args:
        output_path: directory under which the checkpoint is looked up (and
            downloaded to, if absent) as ``models/imagebind_huge.pth``.
        pretrained: when true, load the checkpoint into the model.

    Returns:
        The constructed ``ImageBindModel`` (left on whatever device the
        constructor created it on; the checkpoint is read onto the CPU).
    """
    model = ImageBindModel(
        vision_embed_dim=1280,
        vision_num_blocks=32,
        vision_num_heads=16,
        text_embed_dim=1024,
        text_num_blocks=24,
        text_num_heads=16,
        out_embed_dim=1024,
        audio_drop_path=0.1,
        imu_drop_path=0.7,
    )

    if not pretrained:
        return model

    path = os.path.join(output_path, 'models/imagebind_huge.pth')
    if not os.path.exists(path):
        print(f"Downloading imagebind weights to {path} ...")
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.hub.download_url_to_file(
            "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth",
            path,
            progress=True,
        )

    # map_location="cpu" lets a checkpoint saved from GPU tensors load on
    # CPU-only hosts; load_state_dict copies into the model's own parameters.
    state_dict = torch.load(path, map_location="cpu")

    # strict=False tolerates key mismatches between checkpoint and model (the
    # model here only builds a subset of modality branches). Missing keys
    # would leave parameters at their random initialisation, so surface them
    # instead of ignoring them silently.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        logging.warning(
            "imagebind_huge: %d model keys were not found in %s: %s",
            len(load_result.missing_keys),
            path,
            load_result.missing_keys,
        )

    return model
DEFAULT_AUDIO_FRAME_SHIFT_MS = 10  # in milliseconds


def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length):
    """Convert a waveform to a fixed-length Kaldi filterbank spectrogram.

    Args:
        waveform: audio tensor passed to ``torchaudio.compliance.kaldi.fbank``.
            It is not modified.
        sample_rate: sampling frequency of ``waveform``.
        num_mel_bins: number of mel bins to compute.
        target_length: number of frames in the output; shorter results are
            zero-padded on the right and longer ones are truncated.

    Returns:
        Tensor of shape ``[1, num_mel_bins, target_length]``.
    """
    # Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102
    # Subtract the mean out of place: callers pass slice views of a longer
    # waveform, and an in-place update would leak into the parent tensor and
    # into any overlapping clip cut from it afterwards.
    waveform = waveform - waveform.mean()
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform,
        htk_compat=True,
        sample_frequency=sample_rate,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=num_mel_bins,
        dither=0.0,
        frame_length=25,
        frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS,
    )
    # Convert to [mel_bins, num_frames] shape
    fbank = fbank.transpose(0, 1)
    # Pad to target_length
    n_frames = fbank.size(1)
    p = target_length - n_frames
    # if p is too large (say >20%), flash a warning; an empty spectrogram is
    # always reported (and must not reach the division).
    if n_frames == 0 or abs(p) / n_frames > 0.2:
        logging.warning(
            "Large gap between audio n_frames(%d) and "
            "target_length (%d). Is the audio_target_length "
            "setting correct?",
            n_frames,
            target_length,
        )
    # cut and pad
    if p > 0:
        fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
    elif p < 0:
        fbank = fbank[:, 0:target_length]
    # Convert to [1, mel_bins, num_frames] shape, essentially like a 1
    # channel image
    fbank = fbank.unsqueeze(0)
    return fbank
def get_clip_timepoints(clip_sampler, duration):
    """Collect the ``(start, end)`` pair of every clip the sampler yields.

    ``clip_sampler`` is called as ``clip_sampler(previous_end, duration,
    annotation=None)`` starting from ``previous_end = 0.0`` and must return a
    5-tuple whose first two items are the clip start and end and whose last
    item flags the final clip. The sampler is always called at least once.
    """
    timepoints = []
    cursor = 0.0
    while True:
        clip_start, cursor, _, _, finished = clip_sampler(
            cursor, duration, annotation=None
        )
        timepoints.append((clip_start, cursor))
        if finished:
            break
    return timepoints
def load_and_transform_vision_data(image_paths, device):
    """Load images from disk and return them as one normalised batch.

    Each image is converted to RGB, resized (bicubic) to 224 on its shorter
    side, centre-cropped to 224x224, converted to a tensor, normalised with
    the mean/std below and moved to ``device``.

    Args:
        image_paths: iterable of image file paths, or ``None``.
        device: target device for the resulting tensors.

    Returns:
        ``None`` if ``image_paths`` is ``None``; otherwise the images stacked
        along a new leading dimension.
    """
    if image_paths is None:
        return None

    # The pipeline does not depend on the individual image, so build it once
    # instead of once per path.
    data_transform = transforms.Compose(
        [
            transforms.Resize(
                224, interpolation=transforms.InterpolationMode.BICUBIC
            ),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.48145466, 0.4578275, 0.40821073),
                std=(0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )

    image_outputs = []
    for image_path in image_paths:
        with open(image_path, "rb") as fopen:
            # convert() forces the pixel data to be read while the file is open.
            image = Image.open(fopen).convert("RGB")
        image_outputs.append(data_transform(image).to(device))
    return torch.stack(image_outputs, dim=0)
def load_and_transform_audio_data(
        audio_paths,
        device,
        num_mel_bins=128,
        target_length=204,
        sample_rate=16000,
        clip_duration=2,
        clips_per_video=3,
        mean=-4.268,
        std=9.138,
):
    """Load audio files and turn each into a stack of normalised mel-spectrogram clips.

    Every file is resampled to ``sample_rate`` if needed, cut into
    ``clips_per_video`` clips of ``clip_duration`` seconds by a
    ``ConstantClipsPerVideoSampler``, converted with ``waveform2melspec`` and
    normalised with ``mean`` / ``std``.

    Args:
        audio_paths: iterable of audio file paths, or ``None``.
        device: target device for the resulting tensors.
        num_mel_bins, target_length: forwarded to ``waveform2melspec``.
        sample_rate: sampling rate the audio is resampled to.
        clip_duration, clips_per_video: clip sampler configuration.
        mean, std: normalisation statistics applied to every spectrogram.

    Returns:
        ``None`` if ``audio_paths`` is ``None``; otherwise the per-file clip
        stacks stacked along a new leading dimension.
    """
    # Return before importing pytorchvideo so that a None input does not
    # require the optional dependency to be installed.
    if audio_paths is None:
        return None

    from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler

    clip_sampler = ConstantClipsPerVideoSampler(
        clip_duration=clip_duration, clips_per_video=clips_per_video
    )
    # Loop-invariant: the same statistics are applied to every clip.
    normalize = transforms.Normalize(mean=mean, std=std)

    audio_outputs = []
    for audio_path in audio_paths:
        waveform, sr = torchaudio.load(audio_path)
        if sample_rate != sr:
            waveform = torchaudio.functional.resample(
                waveform, orig_freq=sr, new_freq=sample_rate
            )
        all_clips_timepoints = get_clip_timepoints(
            clip_sampler, waveform.size(1) / sample_rate
        )

        all_clips = []
        for clip_start, clip_end in all_clips_timepoints:
            # Clip boundaries are in seconds; convert to sample indices.
            waveform_clip = waveform[
                :, int(clip_start * sample_rate): int(clip_end * sample_rate)
            ]
            waveform_melspec = waveform2melspec(
                waveform_clip, sample_rate, num_mel_bins, target_length
            )
            all_clips.append(normalize(waveform_melspec).to(device))

        audio_outputs.append(torch.stack(all_clips, dim=0))

    return torch.stack(audio_outputs, dim=0)
class UnNormalize(object):
    """Invert a channel-wise ``(x - mean) / std`` normalisation.

    Args:
        mean: per-channel means.
        std: per-channel standard deviations.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        """Return a de-normalised copy of ``image``; the input is left untouched.

        A 4-D input is treated as a ``(B, C, H, W)`` batch and the statistics
        are applied along dimension 1. Any other input is iterated along
        dimension 0, i.e. ``(C, H, W)`` for a single image. Channels beyond
        ``len(mean)`` are left unchanged.
        """
        image2 = torch.clone(image)
        # Previously dimension 0 was always used, which for a batched input
        # applied channel i's statistics to batch item i.
        channel_dim = 1 if image2.ndim == 4 else 0
        # unbind() yields views, so the in-place ops write into image2.
        for t, m, s in zip(image2.unbind(channel_dim), self.mean, self.std):
            t.mul_(s).add_(m)
        return image2
# Module-level normalisation transform and its inverse, built from the same
# per-channel mean/std triples.
# NOTE(review): load_and_transform_vision_data in this file normalises with
# different statistics, so `unnorm` is not an exact inverse of that pipeline.
norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
unnorm = UnNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
class TorchPCA(object):
    """Minimal PCA on torch tensors, backed by ``torch.pca_lowrank``.

    After ``fit`` the instance exposes ``mean_``, ``components_`` (one
    component per row) and ``singular_values_``.
    """

    def __init__(self, n_components):
        self.n_components = n_components

    def fit(self, X):
        """Estimate the mean and the leading components of ``X``; return ``self``."""
        self.mean_ = X.mean(dim=0)
        centered = X - self.mean_[None]
        # The data is centred by hand above, hence center=False here.
        _, singular_values, right_vectors = torch.pca_lowrank(
            centered, q=self.n_components, center=False, niter=4
        )
        self.components_ = right_vectors.T
        self.singular_values_ = singular_values
        return self

    def transform(self, X):
        """Centre ``X`` with the fitted mean and project it onto the components."""
        centered = X - self.mean_[None]
        return torch.matmul(centered, self.components_.T)
def pca(image_feats_list, dim=3, fit_pca=None):
    """Project feature maps to ``dim`` channels with PCA and rescale to [0, 1].

    Args:
        image_feats_list: non-empty list of ``(B, C, H, W)`` feature tensors.
        dim: number of output channels (principal components).
        fit_pca: an already fitted object exposing ``transform``; when
            ``None`` a ``TorchPCA`` is fitted on all given feature maps.

    Returns:
        ``(reduced_feats, fit_pca)`` where ``reduced_feats`` is a list of
        ``(B, dim, H, W)`` tensors on the device of the first input, each
        min-max rescaled per component, and ``fit_pca`` is the PCA object used.
    """
    device = image_feats_list[0].device

    def flatten(tensor, target_size=None):
        # (B, C, H, W) -> (B * H * W, C) on the CPU, optionally resizing first.
        if target_size is not None:
            # The resized tensor must be kept: the result used to be
            # discarded, and the return statement read the enclosing loop
            # variable instead of this argument.
            tensor = F.interpolate(tensor, (target_size, target_size), mode="bilinear")
        B, C, H, W = tensor.shape
        return tensor.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu()

    if fit_pca is None:
        # When fitting on several maps, resize them to a common (square) size
        # taken from the height of the first map.
        target_size = image_feats_list[0].shape[2] if len(image_feats_list) > 1 else None
        x = torch.cat([flatten(feats, target_size) for feats in image_feats_list], dim=0)
        fit_pca = TorchPCA(n_components=dim).fit(x)

    reduced_feats = []
    for feats in image_feats_list:
        # Projection always uses each map at its native resolution.
        x_red = fit_pca.transform(flatten(feats))
        x_red -= x_red.min(dim=0, keepdim=True).values
        # Guard against 0 / 0 = NaN when a component is constant.
        x_red /= x_red.max(dim=0, keepdim=True).values.clamp_min(1e-12)
        B, C, H, W = feats.shape
        reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2).to(device))

    return reduced_feats, fit_pca
def my_load_audio(audio_file):
    """Load an audio file and return its spectrogram cut into five 204-frame tiles.

    Only the first channel of the file is used. The spectrogram comes from the
    project helper ``data.AVDatasets.prep_waveform``; the meaning of its
    positional arguments is defined there, not here.
    NOTE(review): the constants -4.268 / 9.138 / 16000 match the mean / std /
    sample_rate defaults of load_and_transform_audio_data - presumably the
    same quantities; confirm against prep_waveform.

    Returns:
        Tensor produced by concatenating the tiles along dim 0, swapping the
        last two dims and inserting a singleton dim at position 1.
    """
    loaded_waveform, obs_sr = torchaudio.load(audio_file)
    first_channel = loaded_waveform[0]

    from data.AVDatasets import prep_waveform

    # No negative sample is supplied.
    neg_waveform = None
    neg_obs_sr = None

    # Only the spectrogram (second item) of the 7-tuple is used.
    (_waveform,
     spectrogram,
     _audio_length,
     _total_length,
     _original_length,
     _mask,
     _pos_mask) = prep_waveform(
        first_channel,
        obs_sr,
        10,
        128,
        -4.268,
        9.138,
        16000,
        True,
        False,
        False,
        neg_waveform,
        neg_obs_sr,
        False,
    )

    patch_size = 204
    n_tiles = spectrogram.shape[1] // patch_size
    assert n_tiles == 5

    # Consecutive, non-overlapping windows along dim 1.
    tiles = [
        spectrogram[:, offset:offset + patch_size, :]
        for offset in range(0, n_tiles * patch_size, patch_size)
    ]
    return torch.cat(tiles, dim=0).permute(0, 2, 1).unsqueeze(1)
class ImageBindImageFeaturizer(nn.Module):
    """Wrap an ImageBind model to expose dense vision features.

    Args:
        output_path: passed to ``imagebind_huge`` when no model is supplied.
        model: an existing model to reuse; when ``None`` the pretrained huge
            model is built and moved to CUDA.
    """

    def __init__(self, output_path, model=None):
        super().__init__()
        if model is None:
            model = imagebind_huge(output_path, pretrained=True).cuda()
        self.model = model

    def forward(self, image, include_cls):
        """Return patch tokens, plus the class tokens when ``include_cls`` is true.

        The patch tokens returned by ``forward_features`` are permuted with
        ``(0, 3, 1, 2)``, i.e. their last dimension is moved to position 1.
        """
        features = self.model.forward_features({ModalityType.VISION: image})
        patch_tokens, cls_tokens = features[ModalityType.VISION]
        patch_tokens = patch_tokens.permute(0, 3, 1, 2)
        return (patch_tokens, cls_tokens) if include_cls else patch_tokens
class ImageBindAudioFeaturizer(nn.Module):
    """Wrap an ImageBind model to expose dense audio features.

    Args:
        output_path: passed to ``imagebind_huge`` when no model is supplied.
        model: an existing model to reuse; when ``None`` the pretrained huge
            model is built and moved to CUDA.
    """

    def __init__(self, output_path, model=None):
        super().__init__()
        if model is None:
            model = imagebind_huge(output_path, pretrained=True).cuda()
        self.model = model

    def forward(self, spec, include_cls):
        """Return audio patch tokens, plus a pooled class token if requested.

        ``spec`` is cut along dim 2 into five consecutive 204-frame tiles (an
        assertion enforces exactly five), which are concatenated along dim 1
        and encoded as separate clips. The per-tile patch tokens are then
        merged back into one map and the per-tile class tokens are averaged.
        """
        patch_size = 204
        n_tiles = spec.shape[2] // patch_size
        assert n_tiles == 5

        tiles = [
            spec[:, :, offset:offset + patch_size, :]
            for offset in range(0, n_tiles * patch_size, patch_size)
        ]
        patches = torch.cat(tiles, dim=1).permute(0, 1, 3, 2).unsqueeze(2)

        features = self.model.forward_features({ModalityType.AUDIO: patches})
        patch_tokens, cls_token = features[ModalityType.AUDIO]

        patch_tokens = patch_tokens.permute(0, 4, 2, 1, 3)
        b, c, h, p, w = patch_tokens.shape
        # Fold the tile axis (p) into the last axis.
        patch_tokens = patch_tokens.reshape(b, c, h, w * p)
        # One class token per tile -> mean over the tiles.
        cls_token = cls_token.reshape(b, p, -1).mean(1)

        return (patch_tokens, cls_token) if include_cls else patch_tokens
if __name__ == "__main__":
    # Smoke test: embed three sample images and audio clips, print the
    # image/audio similarity matrix and plot PCA views of the audio features.
    image_paths = ["../../samples/dog_image.jpg", "../../samples/car_image.jpg", "../../samples/bird_image.jpg"]
    audio_paths = ["../../samples/dog_audio.wav", "../../samples/car_audio.wav", "../../samples/bird_audio.wav"]

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Instantiate model
    model = imagebind_huge("../../", pretrained=True)
    model.eval()
    model.to(device)

    # Move to the selected device rather than unconditionally to CUDA, so the
    # CPU fallback chosen above actually works.
    audio_inputs = torch.cat([my_load_audio(af).unsqueeze(0) for af in audio_paths], dim=0).to(device)

    # Load data
    inputs = {
        ModalityType.VISION: load_and_transform_vision_data(image_paths, device),
        # ModalityType.AUDIO: load_and_transform_audio_data(audio_paths, device, clip_duration=2, clips_per_video=5),
        ModalityType.AUDIO: audio_inputs,
    }

    with torch.no_grad():
        embeddings = model.forward_features(inputs)
        cls_tokens = model.forward(inputs)

    # Three audio files with five tiles each: average the per-tile class tokens.
    audio_cls_token = embeddings["audio"][1].reshape(3, 5, -1).mean(1)

    sims1 = torch.einsum(
        "bc,dc->bd",
        embeddings["vision"][1],
        audio_cls_token)

    print(torch.softmax(sims1, dim=1).cpu().numpy())
    #
    # sims2 = torch.einsum(
    #     "bc,dc->bd",
    #     embeddings["vision"].mean(1).mean(1),
    #     embeddings["audio"].mean(1).mean(1).mean(1)
    # )
    #
    # print(torch.softmax(sims2, dim=1).cpu().numpy())
    #
    #
    # img_num = 0
    # img_feats = F.normalize(embeddings["vision"].permute(0, 3, 1, 2), dim=1)
    # [red_img_feats], fit_pca = pca([img_feats])
    #
    # fig, axes = plt.subplots(2, 2, figsize=(4 * 2, 4 * 2))
    # axes[0][0].imshow(unnorm(inputs["vision"][0].unsqueeze(0))[0].permute(1, 2, 0).detach().cpu())
    # axes[0][1].imshow(unnorm(inputs["vision"][1].unsqueeze(0))[0].permute(1, 2, 0).detach().cpu())
    # axes[1][0].imshow(red_img_feats[0].permute(1, 2, 0).detach().cpu())
    # axes[1][1].imshow(red_img_feats[1].permute(1, 2, 0).detach().cpu())
    # plt.tight_layout()
    # plt.show()
    #

    # Merge the tile axis of the audio patch features into the last axis.
    audio_embs = F.normalize(embeddings["audio"][0], dim=-1)
    b, n, h, w, c = audio_embs.shape
    audio_embs = audio_embs.permute(0, 4, 2, 1, 3).reshape(b, c, h, w * n)

    # Same merge for the input spectrogram tiles, for side-by-side display.
    b, n, c, h, w = inputs["audio"].shape
    audio_inputs = inputs["audio"].permute(0, 2, 3, 1, 4).reshape(b, c, h, w * n)

    print("here")

    for img_num in range(3):
        [red_audio], fit_pca = pca([audio_embs[img_num].unsqueeze(0)])

        fig, axes = plt.subplots(2, 1, figsize=(10 * 1, 4 * 2))
        axes[0].imshow(audio_inputs[img_num, 0].detach().cpu())
        axes[1].imshow(red_audio[0].permute(1, 2, 0).detach().cpu())
        plt.tight_layout()
        plt.show()
        print("here")