import gzip import html import io import logging import math import os from functools import lru_cache from functools import partial from types import SimpleNamespace from typing import Callable, List from typing import Optional import einops import ftfy import numpy as np import regex as re import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint import torchaudio import torchvision.transforms as T from PIL import Image from timm.models.layers import DropPath, trunc_normal_ from torchvision import transforms import matplotlib.pyplot as plt from iopath.common.file_io import g_pathmgr class Attention(nn.Module): def __init__( self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, ): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads # NOTE scale factor was wrong in my original version, # can set manually to be compat with prev weights self.scale = qk_scale or head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x): B, N, C = x.shape qkv = ( self.qkv(x) .reshape(B, N, 3, self.num_heads, C // self.num_heads) .permute(2, 0, 3, 1, 4) ) q, k, v = ( qkv[0], qkv[1], qkv[2], ) # make torchscript happy (cannot use tensor as tuple) attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x class Mlp(nn.Module): def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0, ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class MultiheadAttention(nn.MultiheadAttention): def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): return super().forward(x, x, x, need_weights=False, attn_mask=attn_mask)[0] class ViTAttention(Attention): def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): assert attn_mask is None return super().forward(x) class BlockWithMasking(nn.Module): def __init__( self, dim: int, attn_target: Callable, mlp_ratio: int = 4, act_layer: Callable = nn.GELU, norm_layer: Callable = nn.LayerNorm, ffn_dropout_rate: float = 0.0, drop_path: float = 0.0, layer_scale_type: str = None, layer_scale_init_value: float = 1e-4, ): super().__init__() assert not isinstance( attn_target, nn.Module ), "attn_target should be a Callable. Otherwise attn_target is shared across blocks!" self.attn = attn_target() if drop_path > 0.0: self.drop_path = DropPath(drop_path) else: self.drop_path = nn.Identity() self.norm_1 = norm_layer(dim) mlp_hidden_dim = int(mlp_ratio * dim) self.mlp = Mlp( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=ffn_dropout_rate, ) self.norm_2 = norm_layer(dim) self.layer_scale_type = layer_scale_type if self.layer_scale_type is not None: assert self.layer_scale_type in [ "per_channel", "scalar", ], f"Found Layer scale type {self.layer_scale_type}" if self.layer_scale_type == "per_channel": # one gamma value per channel gamma_shape = [1, 1, dim] elif self.layer_scale_type == "scalar": # single gamma value for all channels gamma_shape = [1, 1, 1] # two gammas: for each part of the fwd in the encoder self.layer_scale_gamma1 = nn.Parameter( torch.ones(size=gamma_shape) * layer_scale_init_value, requires_grad=True, ) self.layer_scale_gamma2 = nn.Parameter( torch.ones(size=gamma_shape) * layer_scale_init_value, requires_grad=True, ) def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): if self.layer_scale_type is None: x = x + self.drop_path(self.attn(self.norm_1(x), attn_mask)) x = x + self.drop_path(self.mlp(self.norm_2(x))) else: x = ( x + self.drop_path(self.attn(self.norm_1(x), attn_mask)) * self.layer_scale_gamma1 ) x = x + self.drop_path(self.mlp(self.norm_2(x))) * self.layer_scale_gamma2 return x _LAYER_NORM = partial(nn.LayerNorm, eps=1e-6) class SimpleTransformer(nn.Module): def __init__( self, attn_target: Callable, embed_dim: int, num_blocks: int, block: Callable = BlockWithMasking, pre_transformer_layer: Callable = None, post_transformer_layer: Callable = None, drop_path_rate: float = 0.0, drop_path_type: str = "progressive", norm_layer: Callable = _LAYER_NORM, mlp_ratio: int = 4, ffn_dropout_rate: float = 0.0, layer_scale_type: str = None, # from cait; possible values are None, "per_channel", "scalar" layer_scale_init_value: float = 1e-4, # from cait; float weight_init_style: str = "jax", # possible values jax or pytorch ): """ Simple Transformer with the following features 1. Supports masked attention 2. Supports DropPath 3. Supports LayerScale 4. Supports Dropout in Attention and FFN 5. Makes few assumptions about the input except that it is a Tensor """ super().__init__() self.pre_transformer_layer = pre_transformer_layer if drop_path_type == "progressive": dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)] elif drop_path_type == "uniform": dpr = [drop_path_rate for i in range(num_blocks)] else: raise ValueError(f"Unknown drop_path_type: {drop_path_type}") self.blocks = nn.Sequential( *[ block( dim=embed_dim, attn_target=attn_target, mlp_ratio=mlp_ratio, ffn_dropout_rate=ffn_dropout_rate, drop_path=dpr[i], norm_layer=norm_layer, layer_scale_type=layer_scale_type, layer_scale_init_value=layer_scale_init_value, ) for i in range(num_blocks) ] ) self.post_transformer_layer = post_transformer_layer self.weight_init_style = weight_init_style self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): if self.weight_init_style == "jax": # Based on MAE and official Jax ViT implementation torch.nn.init.xavier_uniform_(m.weight) elif self.weight_init_style == "pytorch": # PyTorch ViT uses trunc_normal_ trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, (nn.LayerNorm)): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def forward( self, tokens: torch.Tensor, attn_mask: torch.Tensor = None, use_checkpoint: bool = False, checkpoint_every_n: int = 1, checkpoint_blk_ids: List[int] = None, ): """ Inputs - tokens: data of shape N x L x D (or L x N x D depending on the attention implementation) - attn: mask of shape L x L Output - x: data of shape N x L x D (or L x N x D depending on the attention implementation) """ if self.pre_transformer_layer: tokens = self.pre_transformer_layer(tokens) if use_checkpoint and checkpoint_blk_ids is None: checkpoint_blk_ids = [ blk_id for blk_id in range(len(self.blocks)) if blk_id % checkpoint_every_n == 0 ] if checkpoint_blk_ids: checkpoint_blk_ids = set(checkpoint_blk_ids) for blk_id, blk in enumerate(self.blocks): if use_checkpoint and blk_id in checkpoint_blk_ids: tokens = checkpoint.checkpoint( blk, tokens, attn_mask, use_reentrant=False ) else: tokens = blk(tokens, attn_mask=attn_mask) if self.post_transformer_layer: tokens = self.post_transformer_layer(tokens) return tokens def get_sinusoid_encoding_table(n_position, d_hid): """Sinusoid position encoding table""" # TODO: make it with torch instead of numpy def get_position_angle_vec(position): return [ position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid) ] sinusoid_table = np.array( [get_position_angle_vec(pos_i) for pos_i in range(n_position)] ) sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 return torch.FloatTensor(sinusoid_table).unsqueeze(0) def interpolate_pos_encoding_2d(target_spatial_size, pos_embed): N = pos_embed.shape[1] if N == target_spatial_size: return pos_embed dim = pos_embed.shape[-1] # nn.functional.interpolate doesn't work with bfloat16 so we cast to float32 pos_embed, updated = cast_if_src_dtype(pos_embed, torch.bfloat16, torch.float32) pos_embed = nn.functional.interpolate( pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute( 0, 3, 1, 2 ), scale_factor=math.sqrt(target_spatial_size / N), mode="bicubic", ) if updated: pos_embed, _ = cast_if_src_dtype(pos_embed, torch.float32, torch.bfloat16) pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return pos_embed def interpolate_pos_encoding( npatch_per_img, pos_embed, patches_layout, input_shape=None, first_patch_idx=1, ): assert first_patch_idx == 0 or first_patch_idx == 1, "there is 1 CLS token or none" N = pos_embed.shape[1] - first_patch_idx # since it's 1 if cls_token exists if npatch_per_img == N: return pos_embed # assert ( # patches_layout[-1] == patches_layout[-2] # ), "Interpolation of pos embed not supported for non-square layouts" class_emb = pos_embed[:, :first_patch_idx] pos_embed = pos_embed[:, first_patch_idx:] if input_shape is None or patches_layout[0] == 1: # simple 2D pos embedding, no temporal component pos_embed = interpolate_pos_encoding_2d(npatch_per_img, pos_embed) elif patches_layout[0] > 1: # pos embed has a temporal component assert len(input_shape) == 4, "temporal interpolation not supported" # we only support 2D interpolation in this case num_frames = patches_layout[0] num_spatial_tokens = patches_layout[1] * patches_layout[2] pos_embed = pos_embed.view(1, num_frames, num_spatial_tokens, -1) # interpolate embedding for zeroth frame pos_embed = interpolate_pos_encoding_2d( npatch_per_img, pos_embed[0, 0, ...].unsqueeze(0) ) else: raise ValueError("This type of interpolation isn't implemented") return torch.cat((class_emb, pos_embed), dim=1) def _get_pos_embedding( npatch_per_img, pos_embed, patches_layout, input_shape, first_patch_idx=1, ): pos_embed = interpolate_pos_encoding( npatch_per_img, pos_embed, patches_layout, input_shape=input_shape, first_patch_idx=first_patch_idx, ) return pos_embed class VerboseNNModule(nn.Module): """ Wrapper around nn.Module that prints registered buffers and parameter names. """ @staticmethod def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str: st = ( "(" + name + "): " + "tensor(" + str(tuple(tensor[1].shape)) + ", requires_grad=" + str(tensor[1].requires_grad) + ")\n" ) return st def extra_repr(self) -> str: named_modules = set() for p in self.named_modules(): named_modules.update([p[0]]) named_modules = list(named_modules) string_repr = "" for p in self.named_parameters(): name = p[0].split(".")[0] if name not in named_modules: string_repr += self.get_readable_tensor_repr(name, p) for p in self.named_buffers(): name = p[0].split(".")[0] string_repr += self.get_readable_tensor_repr(name, p) return string_repr class PatchEmbedGeneric(nn.Module): """ PatchEmbed from Hydra """ def __init__(self, proj_stem, norm_layer: Optional[nn.Module] = None): super().__init__() if len(proj_stem) > 1: self.proj = nn.Sequential(*proj_stem) else: # Special case to be able to load pre-trained models that were # trained with a standard stem self.proj = proj_stem[0] self.norm_layer = norm_layer def get_patch_layout(self, img_size): with torch.no_grad(): dummy_img = torch.zeros( [ 1, ] + img_size ) dummy_out = self.proj(dummy_img) embed_dim = dummy_out.shape[1] patches_layout = tuple(dummy_out.shape[2:]) num_patches = np.prod(patches_layout) return patches_layout, num_patches, embed_dim def forward(self, x): x = self.proj(x) # B C (T) H W -> B (T)HW C x = x.flatten(2).transpose(1, 2) if self.norm_layer is not None: x = self.norm_layer(x) return x class SpatioTemporalPosEmbeddingHelper(VerboseNNModule): def __init__( self, patches_layout: List, num_patches: int, num_cls_tokens: int, embed_dim: int, learnable: bool, ) -> None: super().__init__() self.num_cls_tokens = num_cls_tokens self.patches_layout = patches_layout self.num_patches = num_patches self.num_tokens = num_cls_tokens + num_patches self.learnable = learnable if self.learnable: self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim)) trunc_normal_(self.pos_embed, std=0.02) else: self.register_buffer( "pos_embed", get_sinusoid_encoding_table(self.num_tokens, embed_dim) ) def get_pos_embedding(self, vision_input, all_vision_tokens): input_shape = vision_input.shape pos_embed = _get_pos_embedding( all_vision_tokens.size(1) - self.num_cls_tokens, pos_embed=self.pos_embed, patches_layout=self.patches_layout, input_shape=input_shape, first_patch_idx=self.num_cls_tokens, ) return pos_embed class RGBDTPreprocessor(VerboseNNModule): def __init__( self, rgbt_stem: PatchEmbedGeneric, depth_stem: PatchEmbedGeneric, img_size: List = (3, 224, 224), num_cls_tokens: int = 1, pos_embed_fn: Callable = None, use_type_embed: bool = False, init_param_style: str = "openclip", ) -> None: super().__init__() stem = rgbt_stem if rgbt_stem is not None else depth_stem ( self.patches_layout, self.num_patches, self.embed_dim, ) = stem.get_patch_layout(img_size) self.rgbt_stem = rgbt_stem self.depth_stem = depth_stem self.use_pos_embed = pos_embed_fn is not None self.use_type_embed = use_type_embed self.num_cls_tokens = num_cls_tokens if self.use_pos_embed: self.pos_embedding_helper = pos_embed_fn( patches_layout=self.patches_layout, num_cls_tokens=num_cls_tokens, num_patches=self.num_patches, embed_dim=self.embed_dim, ) if self.num_cls_tokens > 0: self.cls_token = nn.Parameter( torch.zeros(1, self.num_cls_tokens, self.embed_dim) ) if self.use_type_embed: self.type_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) self.init_parameters(init_param_style) @torch.no_grad() def init_parameters(self, init_param_style): if init_param_style == "openclip": # OpenCLIP style initialization scale = self.embed_dim ** -0.5 if self.use_pos_embed: nn.init.normal_(self.pos_embedding_helper.pos_embed) self.pos_embedding_helper.pos_embed *= scale if self.num_cls_tokens > 0: nn.init.normal_(self.cls_token) self.cls_token *= scale elif init_param_style == "vit": self.cls_token.data.fill_(0) else: raise ValueError(f"Unknown init {init_param_style}") if self.use_type_embed: nn.init.normal_(self.type_embed) def get_pos_emb_2(self, input, stem): patches = stem.proj(input) target_size = patches.shape[-2:] original_size = list(self.pos_embedding_helper.patches_layout)[-2:] orig_ce = self.pos_embedding_helper.pos_embed[:, 0, :] orig_pe = ((self.pos_embedding_helper.pos_embed[:, 1:, :] .reshape(1, *original_size, self.embed_dim)) .permute(0, 3, 1, 2)) new_pe = F.interpolate(orig_pe, size=target_size, mode="bicubic") new_full_pe = torch.cat([orig_ce.unsqueeze(1), new_pe.permute(0, 2, 3, 1).reshape(1, -1, self.embed_dim)], dim=1) return new_full_pe def tokenize_input_and_cls_pos(self, input, stem, mask): # tokens is of shape B x L x D tokens = stem(input) assert tokens.ndim == 3 assert tokens.shape[2] == self.embed_dim B = tokens.shape[0] if self.num_cls_tokens > 0: class_tokens = self.cls_token.expand( B, -1, -1 ) # stole class_tokens impl from Phil Wang, thanks tokens = torch.cat((class_tokens, tokens), dim=1) if self.use_pos_embed: pos_embed = self.pos_embedding_helper.get_pos_embedding(input, tokens) # pos_embed = self.get_pos_emb_2(input, stem) tokens = tokens + pos_embed if self.use_type_embed: tokens = tokens + self.type_embed.expand(B, -1, -1) return tokens def forward(self, vision=None, depth=None, patch_mask=None): if patch_mask is not None: raise NotImplementedError() if vision is not None: vision_tokens = self.tokenize_input_and_cls_pos( vision, self.rgbt_stem, patch_mask ) if depth is not None: depth_tokens = self.tokenize_input_and_cls_pos( depth, self.depth_stem, patch_mask ) # aggregate tokens if vision is not None and depth is not None: final_tokens = vision_tokens + depth_tokens else: final_tokens = vision_tokens if vision is not None else depth_tokens return_dict = { "trunk": { "tokens": final_tokens, }, "head": {}, } return return_dict class AudioPreprocessor(RGBDTPreprocessor): def __init__(self, audio_stem: PatchEmbedGeneric, **kwargs) -> None: super().__init__(rgbt_stem=audio_stem, depth_stem=None, **kwargs) def forward(self, audio=None): return super().forward(vision=audio) class ThermalPreprocessor(RGBDTPreprocessor): def __init__(self, thermal_stem: PatchEmbedGeneric, **kwargs) -> None: super().__init__(rgbt_stem=thermal_stem, depth_stem=None, **kwargs) def forward(self, thermal=None): return super().forward(vision=thermal) def build_causal_attention_mask(context_length): # lazily create causal attention mask, with full attention between the vision tokens # pytorch uses additive attention mask; fill with -inf mask = torch.empty(context_length, context_length, requires_grad=False) mask.fill_(float("-inf")) mask.triu_(1) # zero out the lower diagonal return mask class TextPreprocessor(VerboseNNModule): def __init__( self, vocab_size: int, context_length: int, embed_dim: int, causal_masking: bool, supply_seq_len_to_head: bool = True, num_cls_tokens: int = 0, init_param_style: str = "openclip", ) -> None: super().__init__() self.vocab_size = vocab_size self.context_length = context_length self.token_embedding = nn.Embedding(vocab_size, embed_dim) self.pos_embed = nn.Parameter( torch.empty(1, self.context_length + num_cls_tokens, embed_dim) ) self.causal_masking = causal_masking if self.causal_masking: mask = build_causal_attention_mask(self.context_length) # register the mask as a buffer, so it can be moved to the right device self.register_buffer("mask", mask) self.supply_seq_len_to_head = supply_seq_len_to_head self.num_cls_tokens = num_cls_tokens self.embed_dim = embed_dim if num_cls_tokens > 0: assert self.causal_masking is False, "Masking + CLS token isn't implemented" self.cls_token = nn.Parameter( torch.zeros(1, self.num_cls_tokens, embed_dim) ) self.init_parameters(init_param_style) @torch.no_grad() def init_parameters(self, init_param_style="openclip"): # OpenCLIP style initialization nn.init.normal_(self.token_embedding.weight, std=0.02) nn.init.normal_(self.pos_embed, std=0.01) if init_param_style == "openclip": # OpenCLIP style initialization scale = self.embed_dim ** -0.5 if self.num_cls_tokens > 0: nn.init.normal_(self.cls_token) self.cls_token *= scale elif init_param_style == "vit": self.cls_token.data.fill_(0) else: raise ValueError(f"Unknown init {init_param_style}") def forward(self, text): # text tokens are of shape B x L x D text_tokens = self.token_embedding(text) # concat CLS tokens if any if self.num_cls_tokens > 0: B = text_tokens.shape[0] class_tokens = self.cls_token.expand( B, -1, -1 ) # stole class_tokens impl from Phil Wang, thanks text_tokens = torch.cat((class_tokens, text_tokens), dim=1) text_tokens = text_tokens + self.pos_embed return_dict = { "trunk": { "tokens": text_tokens, }, "head": {}, } # Compute sequence length after adding CLS tokens if self.supply_seq_len_to_head: text_lengths = text.argmax(dim=-1) return_dict["head"] = { "seq_len": text_lengths, } if self.causal_masking: return_dict["trunk"].update({"attn_mask": self.mask}) return return_dict class Im2Video(nn.Module): """Convert an image into a trivial video.""" def __init__(self, time_dim=2): super().__init__() self.time_dim = time_dim def forward(self, x): if x.ndim == 4: # B, C, H, W -> B, C, T, H, W return x.unsqueeze(self.time_dim) elif x.ndim == 5: return x else: raise ValueError(f"Dimension incorrect {x.shape}") class PadIm2Video(Im2Video): def __init__(self, ntimes, pad_type, time_dim=2): super().__init__(time_dim=time_dim) assert ntimes > 0 assert pad_type in ["zero", "repeat"] self.ntimes = ntimes self.pad_type = pad_type def forward(self, x): x = super().forward(x) if x.shape[self.time_dim] == 1: if self.pad_type == "repeat": new_shape = [1] * len(x.shape) new_shape[self.time_dim] = self.ntimes x = x.repeat(new_shape) elif self.pad_type == "zero": padarg = [0, 0] * len(x.shape) padarg[2 * self.time_dim + 1] = self.ntimes - x.shape[self.time_dim] x = nn.functional.pad(x, padarg) return x # Modified from github.com/openai/CLIP @lru_cache() def bytes_to_unicode(): """ Returns list of utf-8 byte and a corresponding list of unicode strings. The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. This is a signficant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and unicode strings. And avoids mapping to whitespace/control characters the bpe code barfs on. """ bs = ( list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) ) cs = bs[:] n = 0 for b in range(2 ** 8): if b not in bs: bs.append(b) cs.append(2 ** 8 + n) n += 1 cs = [chr(n) for n in cs] return dict(zip(bs, cs)) def get_pairs(word): """Return set of symbol pairs in a word. Word is represented as tuple of symbols (symbols being variable-length strings). """ pairs = set() prev_char = word[0] for char in word[1:]: pairs.add((prev_char, char)) prev_char = char return pairs def basic_clean(text): text = ftfy.fix_text(text) text = html.unescape(html.unescape(text)) return text.strip() def whitespace_clean(text): text = re.sub(r"\s+", " ", text) text = text.strip() return text class SimpleTokenizer(object): def __init__(self, bpe_path: str, context_length=77): self.byte_encoder = bytes_to_unicode() self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} with g_pathmgr.open(bpe_path, "rb") as fh: bpe_bytes = io.BytesIO(fh.read()) merges = gzip.open(bpe_bytes).read().decode("utf-8").split("\n") merges = merges[1: 49152 - 256 - 2 + 1] merges = [tuple(merge.split()) for merge in merges] vocab = list(bytes_to_unicode().values()) vocab = vocab + [v + "" for v in vocab] for merge in merges: vocab.append("".join(merge)) vocab.extend(["<|startoftext|>", "<|endoftext|>"]) self.encoder = dict(zip(vocab, range(len(vocab)))) self.decoder = {v: k for k, v in self.encoder.items()} self.bpe_ranks = dict(zip(merges, range(len(merges)))) self.cache = { "<|startoftext|>": "<|startoftext|>", "<|endoftext|>": "<|endoftext|>", } self.pat = re.compile( r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE, ) self.context_length = context_length def bpe(self, token): if token in self.cache: return self.cache[token] word = tuple(token[:-1]) + (token[-1] + "",) pairs = get_pairs(word) if not pairs: return token + "" while True: bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) if bigram not in self.bpe_ranks: break first, second = bigram new_word = [] i = 0 while i < len(word): try: j = word.index(first, i) new_word.extend(word[i:j]) i = j except: new_word.extend(word[i:]) break if word[i] == first and i < len(word) - 1 and word[i + 1] == second: new_word.append(first + second) i += 2 else: new_word.append(word[i]) i += 1 new_word = tuple(new_word) word = new_word if len(word) == 1: break else: pairs = get_pairs(word) word = " ".join(word) self.cache[token] = word return word def encode(self, text): bpe_tokens = [] text = whitespace_clean(basic_clean(text)).lower() for token in re.findall(self.pat, text): token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) bpe_tokens.extend( self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") ) return bpe_tokens def decode(self, tokens): text = "".join([self.decoder[token] for token in tokens]) text = ( bytearray([self.byte_decoder[c] for c in text]) .decode("utf-8", errors="replace") .replace("", " ") ) return text def __call__(self, texts, context_length=None): if not context_length: context_length = self.context_length if isinstance(texts, str): texts = [texts] sot_token = self.encoder["<|startoftext|>"] eot_token = self.encoder["<|endoftext|>"] all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts] result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) for i, tokens in enumerate(all_tokens): tokens = tokens[:context_length] result[i, : len(tokens)] = torch.tensor(tokens) if len(result) == 1: return result[0] return result class Normalize(nn.Module): def __init__(self, dim: int) -> None: super().__init__() self.dim = dim def forward(self, x): return torch.nn.functional.normalize(x, dim=self.dim, p=2) class LearnableLogitScaling(nn.Module): def __init__( self, logit_scale_init: float = 1 / 0.07, learnable: bool = True, max_logit_scale: float = 100, ) -> None: super().__init__() self.max_logit_scale = max_logit_scale self.logit_scale_init = logit_scale_init self.learnable = learnable log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init) if learnable: self.log_logit_scale = nn.Parameter(log_logit_scale) else: self.register_buffer("log_logit_scale", log_logit_scale) def forward(self, x): return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x def extra_repr(self): st = f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}, max_logit_scale={self.max_logit_scale}" return st class EinOpsRearrange(nn.Module): def __init__(self, rearrange_expr: str, **kwargs) -> None: super().__init__() self.rearrange_expr = rearrange_expr self.kwargs = kwargs def forward(self, x): assert isinstance(x, torch.Tensor) return einops.rearrange(x, self.rearrange_expr, **self.kwargs) class IMUPreprocessor(VerboseNNModule): def __init__( self, kernel_size: int, imu_stem: PatchEmbedGeneric, embed_dim: int, img_size: List = (6, 2000), num_cls_tokens: int = 1, pos_embed_fn: Callable = None, init_param_style: str = "openclip", ) -> None: super().__init__() stem = imu_stem self.imu_stem = imu_stem self.embed_dim = embed_dim self.use_pos_embed = pos_embed_fn is not None self.num_cls_tokens = num_cls_tokens self.kernel_size = kernel_size self.pos_embed = nn.Parameter( torch.empty(1, (img_size[1] // kernel_size) + num_cls_tokens, embed_dim) ) if self.num_cls_tokens > 0: self.cls_token = nn.Parameter( torch.zeros(1, self.num_cls_tokens, self.embed_dim) ) self.init_parameters(init_param_style) @torch.no_grad() def init_parameters(self, init_param_style): nn.init.normal_(self.pos_embed, std=0.01) if init_param_style == "openclip": # OpenCLIP style initialization scale = self.embed_dim ** -0.5 if self.num_cls_tokens > 0: nn.init.normal_(self.cls_token) self.cls_token *= scale elif init_param_style == "vit": self.cls_token.data.fill_(0) else: raise ValueError(f"Unknown init {init_param_style}") def tokenize_input_and_cls_pos(self, input, stem): # tokens is of shape B x L x D tokens = stem.norm_layer(stem.proj(input)) assert tokens.ndim == 3 assert tokens.shape[2] == self.embed_dim B = tokens.shape[0] if self.num_cls_tokens > 0: class_tokens = self.cls_token.expand( B, -1, -1 ) # stole class_tokens impl from Phil Wang, thanks tokens = torch.cat((class_tokens, tokens), dim=1) if self.use_pos_embed: tokens = tokens + self.pos_embed return tokens def forward(self, imu): # Patchify imu = imu.unfold( -1, self.kernel_size, self.kernel_size, ).permute(0, 2, 1, 3) imu = imu.reshape(imu.size(0), imu.size(1), -1) imu_tokens = self.tokenize_input_and_cls_pos( imu, self.imu_stem, ) return_dict = { "trunk": { "tokens": imu_tokens, }, "head": {}, } return return_dict def cast_if_src_dtype( tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype ): updated = False if tensor.dtype == src_dtype: tensor = tensor.to(dtype=tgt_dtype) updated = True return tensor, updated class QuickGELU(nn.Module): # From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166 def forward(self, x: torch.Tensor): return x * torch.sigmoid(1.702 * x) class SelectElement(nn.Module): def __init__(self, index) -> None: super().__init__() self.index = index def forward(self, x): assert x.ndim >= 3 return x[:, self.index, ...] class ReshapeSpatial(nn.Module): def __init__(self, shape) -> None: super().__init__() self.h, self.w = shape def forward(self, x): assert x.ndim >= 3 return x[:, 1:, ...].reshape(x.shape[0], self.h, self.w, -1), x[:, 0, :] class ReshapeAudio(nn.Module): def __init__(self, shape) -> None: super().__init__() self.h, self.w = shape def forward(self, x): assert x.ndim == 3 return x[:, 1:, :].reshape(-1, 5, self.h, self.w, x.shape[-1]), x[:, 0, :] class ApplyTwice(nn.Module): def __init__(self, module) -> None: super().__init__() self.module = module def forward(self, pair): return self.module(pair[0]), self.module(pair[1]) class SelectEOSAndProject(nn.Module): """ Text Pooling used in OpenCLIP """ def __init__(self, proj: nn.Module) -> None: super().__init__() self.proj = proj def forward(self, x, seq_len): assert x.ndim == 3 # x is of shape B x L x D # take features from the eot embedding (eot_token is the highest number in each sequence) x = x[torch.arange(x.shape[0]), seq_len] x = self.proj(x) return x ModalityType = SimpleNamespace( VISION="vision", TEXT="text", AUDIO="audio", THERMAL="thermal", DEPTH="depth", IMU="imu", ) class ImageBindModel(nn.Module): def __init__( self, video_frames=2, kernel_size=(2, 14, 14), audio_kernel_size=16, audio_stride=10, out_embed_dim=768, vision_embed_dim=1024, vision_num_blocks=24, vision_num_heads=16, audio_embed_dim=768, audio_num_blocks=12, audio_num_heads=12, audio_num_mel_bins=128, audio_target_len=204, audio_drop_path=0.1, text_embed_dim=768, text_num_blocks=12, text_num_heads=12, depth_embed_dim=384, depth_kernel_size=16, depth_num_blocks=12, depth_num_heads=8, depth_drop_path=0.0, thermal_embed_dim=768, thermal_kernel_size=16, thermal_num_blocks=12, thermal_num_heads=12, thermal_drop_path=0.0, imu_embed_dim=512, imu_kernel_size=8, imu_num_blocks=6, imu_num_heads=8, imu_drop_path=0.7, ): super().__init__() self.modality_preprocessors = self._create_modality_preprocessors( video_frames, vision_embed_dim, kernel_size, text_embed_dim, audio_embed_dim, audio_kernel_size, audio_stride, audio_num_mel_bins, audio_target_len, depth_embed_dim, depth_kernel_size, thermal_embed_dim, thermal_kernel_size, imu_embed_dim, ) self.modality_trunks = self._create_modality_trunks( vision_embed_dim, vision_num_blocks, vision_num_heads, text_embed_dim, text_num_blocks, text_num_heads, audio_embed_dim, audio_num_blocks, audio_num_heads, audio_drop_path, depth_embed_dim, depth_num_blocks, depth_num_heads, depth_drop_path, thermal_embed_dim, thermal_num_blocks, thermal_num_heads, thermal_drop_path, imu_embed_dim, imu_num_blocks, imu_num_heads, imu_drop_path, ) self.modality_heads = self._create_modality_heads( out_embed_dim, vision_embed_dim, text_embed_dim, audio_embed_dim, depth_embed_dim, thermal_embed_dim, imu_embed_dim, ) self.modality_postprocessors = self._create_modality_postprocessors( out_embed_dim ) def _create_modality_preprocessors( self, video_frames=2, vision_embed_dim=1024, kernel_size=(2, 14, 14), text_embed_dim=768, audio_embed_dim=768, audio_kernel_size=16, audio_stride=10, audio_num_mel_bins=128, audio_target_len=204, depth_embed_dim=768, depth_kernel_size=16, thermal_embed_dim=768, thermal_kernel_size=16, imu_embed_dim=512, ): rgbt_stem = PatchEmbedGeneric( proj_stem=[ PadIm2Video(pad_type="repeat", ntimes=2), nn.Conv3d( in_channels=3, kernel_size=kernel_size, out_channels=vision_embed_dim, stride=kernel_size, bias=False, ), ] ) rgbt_preprocessor = RGBDTPreprocessor( img_size=[3, video_frames, 224, 224], num_cls_tokens=1, pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), rgbt_stem=rgbt_stem, depth_stem=None, ) text_preprocessor = TextPreprocessor( context_length=77, vocab_size=49408, embed_dim=text_embed_dim, causal_masking=True, ) audio_stem = PatchEmbedGeneric( proj_stem=[ nn.Conv2d( in_channels=1, kernel_size=audio_kernel_size, stride=audio_stride, out_channels=audio_embed_dim, bias=False, ), ], norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim), ) audio_preprocessor = AudioPreprocessor( img_size=[1, audio_num_mel_bins, audio_target_len], num_cls_tokens=1, pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), audio_stem=audio_stem, ) # depth_stem = PatchEmbedGeneric( # [ # nn.Conv2d( # kernel_size=depth_kernel_size, # in_channels=1, # out_channels=depth_embed_dim, # stride=depth_kernel_size, # bias=False, # ), # ], # norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim), # ) # # depth_preprocessor = RGBDTPreprocessor( # img_size=[1, 224, 224], # num_cls_tokens=1, # pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), # rgbt_stem=None, # depth_stem=depth_stem, # ) # # thermal_stem = PatchEmbedGeneric( # [ # nn.Conv2d( # kernel_size=thermal_kernel_size, # in_channels=1, # out_channels=thermal_embed_dim, # stride=thermal_kernel_size, # bias=False, # ), # ], # norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim), # ) # thermal_preprocessor = ThermalPreprocessor( # img_size=[1, 224, 224], # num_cls_tokens=1, # pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), # thermal_stem=thermal_stem, # ) # # imu_stem = PatchEmbedGeneric( # [ # nn.Linear( # in_features=48, # out_features=imu_embed_dim, # bias=False, # ), # ], # norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim), # ) # # imu_preprocessor = IMUPreprocessor( # img_size=[6, 2000], # num_cls_tokens=1, # kernel_size=8, # embed_dim=imu_embed_dim, # pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), # imu_stem=imu_stem, # ) modality_preprocessors = { ModalityType.VISION: rgbt_preprocessor, ModalityType.TEXT: text_preprocessor, ModalityType.AUDIO: audio_preprocessor, # ModalityType.DEPTH: depth_preprocessor, # ModalityType.THERMAL: thermal_preprocessor, # ModalityType.IMU: imu_preprocessor, } return nn.ModuleDict(modality_preprocessors) def _create_modality_trunks( self, vision_embed_dim=1024, vision_num_blocks=24, vision_num_heads=16, text_embed_dim=768, text_num_blocks=12, text_num_heads=12, audio_embed_dim=768, audio_num_blocks=12, audio_num_heads=12, audio_drop_path=0.0, depth_embed_dim=768, depth_num_blocks=12, depth_num_heads=12, depth_drop_path=0.0, thermal_embed_dim=768, thermal_num_blocks=12, thermal_num_heads=12, thermal_drop_path=0.0, imu_embed_dim=512, imu_num_blocks=6, imu_num_heads=8, imu_drop_path=0.7, ): def instantiate_trunk( embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path ): return SimpleTransformer( embed_dim=embed_dim, num_blocks=num_blocks, ffn_dropout_rate=0.0, drop_path_rate=drop_path, attn_target=partial( MultiheadAttention, embed_dim=embed_dim, num_heads=num_heads, bias=True, add_bias_kv=add_bias_kv, ), pre_transformer_layer=nn.Sequential( nn.LayerNorm(embed_dim, eps=1e-6) if pre_transformer_ln else nn.Identity(), EinOpsRearrange("b l d -> l b d"), ), post_transformer_layer=EinOpsRearrange("l b d -> b l d"), ) modality_trunks = {} modality_trunks[ModalityType.VISION] = instantiate_trunk( vision_embed_dim, vision_num_blocks, vision_num_heads, pre_transformer_ln=True, add_bias_kv=False, drop_path=0.0, ) modality_trunks[ModalityType.TEXT] = instantiate_trunk( text_embed_dim, text_num_blocks, text_num_heads, pre_transformer_ln=False, add_bias_kv=False, drop_path=0.0, ) modality_trunks[ModalityType.AUDIO] = instantiate_trunk( audio_embed_dim, audio_num_blocks, audio_num_heads, pre_transformer_ln=False, add_bias_kv=True, drop_path=audio_drop_path, ) # modality_trunks[ModalityType.DEPTH] = instantiate_trunk( # depth_embed_dim, # depth_num_blocks, # depth_num_heads, # pre_transformer_ln=False, # add_bias_kv=True, # drop_path=depth_drop_path, # ) # modality_trunks[ModalityType.THERMAL] = instantiate_trunk( # thermal_embed_dim, # thermal_num_blocks, # thermal_num_heads, # pre_transformer_ln=False, # add_bias_kv=True, # drop_path=thermal_drop_path, # ) # modality_trunks[ModalityType.IMU] = instantiate_trunk( # imu_embed_dim, # imu_num_blocks, # imu_num_heads, # pre_transformer_ln=False, # add_bias_kv=True, # drop_path=imu_drop_path, # ) return nn.ModuleDict(modality_trunks) def _create_modality_heads( self, out_embed_dim, vision_embed_dim, text_embed_dim, audio_embed_dim, depth_embed_dim, thermal_embed_dim, imu_embed_dim, ): modality_heads = {} modality_heads[ModalityType.VISION] = nn.Sequential( nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6), SelectElement(index=0), nn.Linear(vision_embed_dim, out_embed_dim, bias=False), ) modality_heads[ModalityType.TEXT] = SelectEOSAndProject( proj=nn.Sequential( nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6), nn.Linear(text_embed_dim, out_embed_dim, bias=False), ) ) modality_heads[ModalityType.AUDIO] = nn.Sequential( nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6), SelectElement(index=0), nn.Linear(audio_embed_dim, out_embed_dim, bias=False), ) # modality_heads[ModalityType.DEPTH] = nn.Sequential( # nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6), # SelectElement(index=0), # nn.Linear(depth_embed_dim, out_embed_dim, bias=False), # ) # # modality_heads[ModalityType.THERMAL] = nn.Sequential( # nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6), # SelectElement(index=0), # nn.Linear(thermal_embed_dim, out_embed_dim, bias=False), # ) # # modality_heads[ModalityType.IMU] = nn.Sequential( # nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6), # SelectElement(index=0), # nn.Dropout(p=0.5), # nn.Linear(imu_embed_dim, out_embed_dim, bias=False), # ) return nn.ModuleDict(modality_heads) def _create_modality_postprocessors(self, out_embed_dim): modality_postprocessors = {} modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1) modality_postprocessors[ModalityType.TEXT] = nn.Sequential( Normalize(dim=-1), LearnableLogitScaling(learnable=True) ) modality_postprocessors[ModalityType.AUDIO] = nn.Sequential( Normalize(dim=-1), LearnableLogitScaling(logit_scale_init=20.0, learnable=False), ) # modality_postprocessors[ModalityType.DEPTH] = nn.Sequential( # Normalize(dim=-1), # LearnableLogitScaling(logit_scale_init=5.0, learnable=False), # ) # modality_postprocessors[ModalityType.THERMAL] = nn.Sequential( # Normalize(dim=-1), # LearnableLogitScaling(logit_scale_init=10.0, learnable=False), # ) # modality_postprocessors[ModalityType.IMU] = nn.Sequential( # Normalize(dim=-1), # LearnableLogitScaling(logit_scale_init=5.0, learnable=False), # ) return nn.ModuleDict(modality_postprocessors) def forward(self, inputs): outputs = {} for modality_key, modality_value in inputs.items(): reduce_list = ( modality_value.ndim >= 5 ) # Audio and Video inputs consist of multiple clips if reduce_list: B, S = modality_value.shape[:2] modality_value = modality_value.reshape( B * S, *modality_value.shape[2:] ) if modality_value is not None: modality_value = self.modality_preprocessors[modality_key]( **{modality_key: modality_value} ) trunk_inputs = modality_value["trunk"] head_inputs = modality_value["head"] modality_value = self.modality_trunks[modality_key](**trunk_inputs) modality_value = self.modality_heads[modality_key]( modality_value, **head_inputs ) modality_value = self.modality_postprocessors[modality_key]( modality_value ) if reduce_list: modality_value = modality_value.reshape(B, S, -1) modality_value = modality_value.mean(dim=1) outputs[modality_key] = modality_value return outputs def reconfigure_head(self, k, v): if k == ModalityType.AUDIO: return torch.nn.Sequential(v[0], v[2]) elif k == ModalityType.VISION: return torch.nn.Sequential(v[0], v[2]) else: return v def forward_features(self, inputs): outputs = {} reconfigured_heads = {k: self.reconfigure_head(k, v) for k, v in self.modality_heads.items()} for modality_key, modality_value in inputs.items(): reduce_list = ( modality_value.ndim >= 5 ) # Audio and Video inputs consist of multiple clips if reduce_list: B, S = modality_value.shape[:2] modality_value = modality_value.reshape( B * S, *modality_value.shape[2:] ) if modality_value is not None: modality_value = self.modality_preprocessors[modality_key]( **{modality_key: modality_value} ) trunk_inputs = modality_value["trunk"] head_inputs = modality_value["head"] modality_value = self.modality_trunks[modality_key](**trunk_inputs) modality_value = reconfigured_heads[modality_key]( modality_value, **head_inputs ) modality_value = self.modality_postprocessors[modality_key]( modality_value ) if modality_key == ModalityType.AUDIO: modality_value = ReshapeAudio((12, 19))(modality_value) elif modality_key == ModalityType.VISION: modality_value = ReshapeSpatial((16, 16))(modality_value) outputs[modality_key] = modality_value return outputs def imagebind_huge(output_path, pretrained=False): model = ImageBindModel( vision_embed_dim=1280, vision_num_blocks=32, vision_num_heads=16, text_embed_dim=1024, text_num_blocks=24, text_num_heads=16, out_embed_dim=1024, audio_drop_path=0.1, imu_drop_path=0.7, ) if pretrained: path = os.path.join(output_path, 'models/imagebind_huge.pth') if not os.path.exists(path): print(f"Downloading imagebind weights to {path} ...") os.makedirs(os.path.dirname(path), exist_ok=True) torch.hub.download_url_to_file( "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth", path, progress=True, ) model.load_state_dict(torch.load(path), strict=False) return model DEFAULT_AUDIO_FRAME_SHIFT_MS = 10 # in milliseconds def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length): # Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102 waveform -= waveform.mean() fbank = torchaudio.compliance.kaldi.fbank( waveform, htk_compat=True, sample_frequency=sample_rate, use_energy=False, window_type="hanning", num_mel_bins=num_mel_bins, dither=0.0, frame_length=25, frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS, ) # Convert to [mel_bins, num_frames] shape fbank = fbank.transpose(0, 1) # Pad to target_length n_frames = fbank.size(1) p = target_length - n_frames # if p is too large (say >20%), flash a warning if abs(p) / n_frames > 0.2: logging.warning( "Large gap between audio n_frames(%d) and " "target_length (%d). Is the audio_target_length " "setting correct?", n_frames, target_length, ) # cut and pad if p > 0: fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0) elif p < 0: fbank = fbank[:, 0:target_length] # Convert to [1, mel_bins, num_frames] shape, essentially like a 1 # channel image fbank = fbank.unsqueeze(0) return fbank def get_clip_timepoints(clip_sampler, duration): # Read out all clips in this video all_clips_timepoints = [] is_last_clip = False end = 0.0 while not is_last_clip: start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None) all_clips_timepoints.append((start, end)) return all_clips_timepoints def load_and_transform_vision_data(image_paths, device): if image_paths is None: return None image_ouputs = [] for image_path in image_paths: data_transform = transforms.Compose( [ transforms.Resize( 224, interpolation=transforms.InterpolationMode.BICUBIC ), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711), ), ] ) with open(image_path, "rb") as fopen: image = Image.open(fopen).convert("RGB") image = data_transform(image).to(device) image_ouputs.append(image) return torch.stack(image_ouputs, dim=0) def load_and_transform_audio_data( audio_paths, device, num_mel_bins=128, target_length=204, sample_rate=16000, clip_duration=2, clips_per_video=3, mean=-4.268, std=9.138, ): from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler if audio_paths is None: return None audio_outputs = [] clip_sampler = ConstantClipsPerVideoSampler( clip_duration=clip_duration, clips_per_video=clips_per_video ) for audio_path in audio_paths: waveform, sr = torchaudio.load(audio_path) if sample_rate != sr: waveform = torchaudio.functional.resample( waveform, orig_freq=sr, new_freq=sample_rate ) all_clips_timepoints = get_clip_timepoints( clip_sampler, waveform.size(1) / sample_rate ) all_clips = [] for clip_timepoints in all_clips_timepoints: waveform_clip = waveform[ :, int(clip_timepoints[0] * sample_rate): int( clip_timepoints[1] * sample_rate ), ] waveform_melspec = waveform2melspec( waveform_clip, sample_rate, num_mel_bins, target_length ) all_clips.append(waveform_melspec) normalize = transforms.Normalize(mean=mean, std=std) all_clips = [normalize(ac).to(device) for ac in all_clips] all_clips = torch.stack(all_clips, dim=0) audio_outputs.append(all_clips) return torch.stack(audio_outputs, dim=0) class UnNormalize(object): def __init__(self, mean, std): self.mean = mean self.std = std def __call__(self, image): image2 = torch.clone(image) for t, m, s in zip(image2, self.mean, self.std): t.mul_(s).add_(m) return image2 norm = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) unnorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) class TorchPCA(object): def __init__(self, n_components): self.n_components = n_components def fit(self, X): self.mean_ = X.mean(dim=0) unbiased = X - self.mean_.unsqueeze(0) U, S, V = torch.pca_lowrank(unbiased, q=self.n_components, center=False, niter=4) self.components_ = V.T self.singular_values_ = S return self def transform(self, X): t0 = X - self.mean_.unsqueeze(0) projected = t0 @ self.components_.T return projected def pca(image_feats_list, dim=3, fit_pca=None): # from sklearn.decomposition import PCA device = image_feats_list[0].device def flatten(tensor, target_size=None): if target_size is not None and fit_pca is None: F.interpolate(tensor, (target_size, target_size), mode="bilinear") B, C, H, W = tensor.shape return feats.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu() if len(image_feats_list) > 1 and fit_pca is None: target_size = image_feats_list[0].shape[2] else: target_size = None flattened_feats = [] for feats in image_feats_list: flattened_feats.append(flatten(feats, target_size)) x = torch.cat(flattened_feats, dim=0) if fit_pca is None: # fit_pca = PCA(n_components=dim, svd_solver='arpack').fit(np.nan_to_num(x.detach().numpy())) fit_pca = TorchPCA(n_components=dim).fit(x) reduced_feats = [] for feats in image_feats_list: # x_red = torch.from_numpy(fit_pca.transform(flatten(feats))) x_red = fit_pca.transform(flatten(feats)) x_red -= x_red.min(dim=0, keepdim=True).values x_red /= x_red.max(dim=0, keepdim=True).values B, C, H, W = feats.shape reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2).to(device)) return reduced_feats, fit_pca def my_load_audio(audio_file): loaded_waveform, obs_sr = torchaudio.load(audio_file) loaded_waveform = loaded_waveform[0] neg_waveform, neg_obs_sr = None, None from data.AVDatasets import prep_waveform (waveform, spectrogram, audio_length, total_length, original_length, mask, pos_mask) = prep_waveform( loaded_waveform, obs_sr, 10, 128, -4.268, 9.138, 16000, True, False, False, neg_waveform, neg_obs_sr, False, ) patch_size = 204 n_tiles = spectrogram.shape[1] // patch_size assert n_tiles == 5 patches = [] for i in range(n_tiles): patches.append(spectrogram[:, i * patch_size:(i + 1) * patch_size, :]) patches = torch.cat(patches, dim=0).permute(0, 2, 1).unsqueeze(1) return patches class ImageBindImageFeaturizer(nn.Module): def __init__(self, output_path, model=None): super().__init__() if model is not None: self.model = model else: self.model = imagebind_huge(output_path, pretrained=True).cuda() def forward(self, image, include_cls): inputs = { ModalityType.VISION: image, } patch_tokens, cls_tokens = self.model.forward_features(inputs)[ModalityType.VISION] patch_tokens = patch_tokens.permute(0, 3, 1, 2) if include_cls: return patch_tokens, cls_tokens else: return patch_tokens class ImageBindAudioFeaturizer(nn.Module): def __init__(self, output_path, model=None): super().__init__() if model is not None: self.model = model else: self.model = imagebind_huge(output_path, pretrained=True).cuda() def forward(self, spec, include_cls): patch_size = 204 n_tiles = spec.shape[2] // patch_size assert n_tiles == 5 patches = [] for i in range(n_tiles): patches.append(spec[:, :, i * patch_size:(i + 1) * patch_size, :]) patches = torch.cat(patches, dim=1).permute(0, 1, 3, 2).unsqueeze(2) inputs = { ModalityType.AUDIO: patches, } patch_tokens, cls_token = self.model.forward_features(inputs)[ModalityType.AUDIO] patch_tokens = patch_tokens.permute(0, 4, 2, 1, 3) b, c, h, p, w = patch_tokens.shape patch_tokens = patch_tokens.reshape(b, c, h, w * p) cls_token = cls_token.reshape(b, p, -1).mean(1) if include_cls: return patch_tokens, cls_token else: return patch_tokens if __name__ == "__main__": image_paths = ["../../samples/dog_image.jpg", "../../samples/car_image.jpg", "../../samples/bird_image.jpg"] audio_paths = ["../../samples/dog_audio.wav", "../../samples/car_audio.wav", "../../samples/bird_audio.wav"] device = "cuda:0" if torch.cuda.is_available() else "cpu" # Instantiate model model = imagebind_huge("../../", pretrained=True) model.eval() model.to(device) audio_inputs = torch.cat([my_load_audio(af).unsqueeze(0) for af in audio_paths], dim=0).cuda() # Load data inputs = { ModalityType.VISION: load_and_transform_vision_data(image_paths, device), # ModalityType.AUDIO: load_and_transform_audio_data(audio_paths, device, clip_duration=2, clips_per_video=5), ModalityType.AUDIO: audio_inputs, } with torch.no_grad(): embeddings = model.forward_features(inputs) cls_tokens = model.forward(inputs) audio_cls_token = embeddings["audio"][1].reshape(3, 5, -1).mean(1) sims1 = torch.einsum( "bc,dc->bd", embeddings["vision"][1], audio_cls_token) print(torch.softmax(sims1, dim=1).cpu().numpy()) # # sims2 = torch.einsum( # "bc,dc->bd", # embeddings["vision"].mean(1).mean(1), # embeddings["audio"].mean(1).mean(1).mean(1) # ) # # print(torch.softmax(sims2, dim=1).cpu().numpy()) # # # img_num = 0 # img_feats = F.normalize(embeddings["vision"].permute(0, 3, 1, 2), dim=1) # [red_img_feats], fit_pca = pca([img_feats]) # # fig, axes = plt.subplots(2, 2, figsize=(4 * 2, 4 * 2)) # axes[0][0].imshow(unnorm(inputs["vision"][0].unsqueeze(0))[0].permute(1, 2, 0).detach().cpu()) # axes[0][1].imshow(unnorm(inputs["vision"][1].unsqueeze(0))[0].permute(1, 2, 0).detach().cpu()) # axes[1][0].imshow(red_img_feats[0].permute(1, 2, 0).detach().cpu()) # axes[1][1].imshow(red_img_feats[1].permute(1, 2, 0).detach().cpu()) # plt.tight_layout() # plt.show() # audio_embs = F.normalize(embeddings["audio"][0], dim=-1) b, n, h, w, c = audio_embs.shape audio_embs = audio_embs.permute(0, 4, 2, 1, 3).reshape(b, c, h, w * n) b, n, c, h, w = inputs["audio"].shape audio_inputs = inputs["audio"].permute(0, 2, 3, 1, 4).reshape(b, c, h, w * n) print("here") for img_num in range(3): [red_audio], fit_pca = pca([audio_embs[img_num].unsqueeze(0)]) fig, axes = plt.subplots(2, 1, figsize=(10 * 1, 4 * 2)) axes[0].imshow(audio_inputs[img_num, 0].detach().cpu()) axes[1].imshow(red_audio[0].permute(1, 2, 0).detach().cpu()) plt.tight_layout() plt.show() print("here")