import os
import random

import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torchvision.transforms as T
from PIL import Image
from timm.models.layers import to_2tuple, DropPath
from timm.models.vision_transformer import Mlp, PatchEmbed, Block
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
def get_2d_sincos_pos_embed(embed_dim, grid_h_size, grid_w_size, cls_token=False):
    """
    grid_h_size, grid_w_size: int, height and width of the patch grid
    return:
    pos_embed: [grid_h_size*grid_w_size, embed_dim] or [1+grid_h_size*grid_w_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_h_size, dtype=float)
    grid_w = np.arange(grid_w_size, dtype=float)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)
    grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0
    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega  # (D/2,)
    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)
    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
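# Illustrative note (not in the original file): for the audio branch below, get_2d_sincos_pos_embed(768, 8, 64)
# returns a (512, 768) array, one embedding per 16x16 patch of the 128x1024 spectrogram
# (8 frequency patches x 64 time patches).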
# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
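        # x: (B, C, H, W) -> conv proj: (B, embed_dim, H/P, W/P) -> flatten + transpose: (B, num_patches, embed_dim)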
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.norm1_a = norm_layer(dim)
        self.norm1_v = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.norm2_a = norm_layer(dim)
        self.norm2_v = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
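    # The attention and MLP weights are shared across modalities; only the LayerNorms differ:
    # modality=None uses the joint norms, while 'a'/'v' select the audio- or visual-specific norms.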
    def forward(self, x, modality=None):
        if modality is None:
            x = x + self.drop_path(self.attn(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        elif modality == 'a':
            x = x + self.drop_path(self.attn(self.norm1_a(x)))
            x = x + self.drop_path(self.mlp(self.norm2_a(x)))
        elif modality == 'v':
            x = x + self.drop_path(self.attn(self.norm1_v(x)))
            x = x + self.drop_path(self.mlp(self.norm2_v(x)))
        return x
# our main proposed model, for pretraining only; for fine-tuning, use the CAVMAEFT class
class CAVMAE(nn.Module):
    """ CAV-MAE Model
    """
    def __init__(self, img_size=224, audio_length=1024, patch_size=16, in_chans=3,
                 embed_dim=768, modality_specific_depth=11, num_heads=12,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False, tr_pos=False):
        super().__init__()
        print('A CAV-MAE Model')
        print('Use norm_pix_loss: ', norm_pix_loss)
        print('Learnable Positional Embedding: ', tr_pos)

        # the encoder part
        # override the timm package
        timm.models.vision_transformer.PatchEmbed = PatchEmbed
        timm.models.vision_transformer.Block = Block

        self.patch_embed_a = PatchEmbed(img_size, patch_size, 1, embed_dim)
        self.patch_embed_v = PatchEmbed(img_size, patch_size, in_chans, embed_dim)

        self.patch_embed_a.num_patches = int(audio_length * 128 / 256)
        print('Number of Audio Patches: {:d}, Visual Patches: {:d}'.format(self.patch_embed_a.num_patches,
                                                                           self.patch_embed_v.num_patches))

        self.modality_a = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.modality_v = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.pos_embed_a = nn.Parameter(torch.zeros(1, self.patch_embed_a.num_patches, embed_dim),
                                        requires_grad=tr_pos)  # fixed sin-cos embedding
        self.pos_embed_v = nn.Parameter(torch.zeros(1, self.patch_embed_v.num_patches, embed_dim),
                                        requires_grad=tr_pos)  # fixed sin-cos embedding

        # audio branch
        self.blocks_a = nn.ModuleList(
            [Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) for i in
             range(modality_specific_depth)])
        # visual branch
        self.blocks_v = nn.ModuleList(
            [Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) for i in
             range(modality_specific_depth)])
        # unified branch
        self.blocks_u = nn.ModuleList(
            [Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) for i in
             range(12 - modality_specific_depth)])

        # independent normalization layers for audio, visual, and audio-visual
        self.norm_a, self.norm_v, self.norm = norm_layer(embed_dim), norm_layer(embed_dim), norm_layer(embed_dim)

        # the decoder part
        # project to a lower dimension for the decoder
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        # token used for masking
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_modality_a = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_modality_v = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed_a = nn.Parameter(torch.zeros(1, self.patch_embed_a.num_patches, decoder_embed_dim),
                                                requires_grad=tr_pos)  # fixed sin-cos embedding
        self.decoder_pos_embed_v = nn.Parameter(torch.zeros(1, self.patch_embed_v.num_patches, decoder_embed_dim),
                                                requires_grad=tr_pos)  # fixed sin-cos embedding

        self.decoder_blocks = nn.ModuleList(
            [Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
             for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)

        # the output channel count differs between the two modalities, so use two projection heads
        self.decoder_pred_a = nn.Linear(decoder_embed_dim, patch_size ** 2 * 1, bias=True)  # decoder to patch
        self.decoder_pred_v = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True)  # decoder to patch

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()

        print('Audio Positional Embedding Shape:', self.pos_embed_a.shape)
        print('Visual Positional Embedding Shape:', self.pos_embed_v.shape)
    def initialize_weights(self):
        # initialize (and freeze) pos_embed with sin-cos embeddings (no cls token is used)
        pos_embed_a = get_2d_sincos_pos_embed(self.pos_embed_a.shape[-1], 8, int(self.patch_embed_a.num_patches / 8),
                                              cls_token=False)
        self.pos_embed_a.data.copy_(torch.from_numpy(pos_embed_a).float().unsqueeze(0))

        pos_embed_v = get_2d_sincos_pos_embed(self.pos_embed_v.shape[-1], int(self.patch_embed_v.num_patches ** .5),
                                              int(self.patch_embed_v.num_patches ** .5), cls_token=False)
        self.pos_embed_v.data.copy_(torch.from_numpy(pos_embed_v).float().unsqueeze(0))

        decoder_pos_embed_a = get_2d_sincos_pos_embed(self.decoder_pos_embed_a.shape[-1], 8,
                                                      int(self.patch_embed_a.num_patches / 8), cls_token=False)
        self.decoder_pos_embed_a.data.copy_(torch.from_numpy(decoder_pos_embed_a).float().unsqueeze(0))

        decoder_pos_embed_v = get_2d_sincos_pos_embed(self.decoder_pos_embed_v.shape[-1],
                                                      int(self.patch_embed_v.num_patches ** .5),
                                                      int(self.patch_embed_v.num_patches ** .5), cls_token=False)
        self.decoder_pos_embed_v.data.copy_(torch.from_numpy(decoder_pos_embed_v).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed_a.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        w = self.patch_embed_v.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.modality_a, std=.02)
        torch.nn.init.normal_(self.modality_v, std=.02)
        torch.nn.init.normal_(self.decoder_modality_a, std=.02)
        torch.nn.init.normal_(self.decoder_modality_v, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    def patchify(self, imgs, c, h, w, p=16):
        """
        imgs: (N, c, h*p, w*p)
        x: (N, h*w, p**2 * c)
        """
        x = imgs.reshape(shape=(imgs.shape[0], c, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * c))
        return x

    def unpatchify(self, x, c, h, w, p=16):
        """
        x: (N, h*w, p**2 * c)
        imgs: (N, c, h*p, w*p)
        """
        assert h * w == x.shape[1]
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs
    def random_masking_unstructured(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore
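    # With the default settings, the audio sequence has 512 patches, so mask_ratio=0.75 keeps 128 tokens;
    # mask is 0 for kept and 1 for masked positions, and ids_restore lets the decoder place mask tokens
    # back at their original indices.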
    def random_masking_structured(self, x, mask_ratio, t=64, f=8, mode='time'):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
        assert L == f * t
        noise = noise.reshape(N, f, t)  # the audio patch grid is in shape [f, t], not [t, f]
        if mode == 'time':
            for i in range(N):
                mask_t_list = random.sample(range(t), int(t * mask_ratio))
                for k in mask_t_list:
                    noise[i, :, k] = 1.1  # large value will be removed
        elif mode == 'freq':
            for i in range(N):
                mask_f_list = random.sample(range(f), int(f * mask_ratio))
                for k in mask_f_list:
                    noise[i, k, :] = 1.1  # large value will be removed
        elif mode == 'tf':
            for i in range(N):
                mask_t_list = random.sample(range(t), int(t * mask_ratio * 0.7))
                for k in mask_t_list:
                    noise[i, :, k] = 1.1  # large value will be removed
            for i in range(N):
                mask_f_list = random.sample(range(f), int(f * mask_ratio * 0.7))
                for k in mask_f_list:
                    noise[i, k, :] = 1.1  # large value will be removed
        noise = noise.reshape(N, L)

        # sort noise for each sample; only ids_shuffle and ids_restore need to be manipulated
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore
    def forward_encoder(self, a, v, mask_ratio_a, mask_ratio_v, mask_mode='unstructured'):
        # embed patches
        a = a.unsqueeze(1)
        a = a.transpose(2, 3)
        a = self.patch_embed_a(a)
        a = a + self.pos_embed_a
        a = a + self.modality_a

        v = self.patch_embed_v(v)
        v = v + self.pos_embed_v
        v = v + self.modality_v

        # by default, we always use unstructured masking
        if mask_mode == 'unstructured':
            a, mask_a, ids_restore_a = self.random_masking_unstructured(a, mask_ratio_a)
        # in the ablation study, we tried time/freq/tf masking, i.e. mode in ['time', 'freq', 'tf']
        else:
            a, mask_a, ids_restore_a = self.random_masking_structured(a, mask_ratio_a, t=64, f=8, mode=mask_mode)

        # the visual branch always uses unstructured masking
        v, mask_v, ids_restore_v = self.random_masking_unstructured(v, mask_ratio_v)

        # audio and visual streams, independent blocks
        for blk in self.blocks_a:
            a = blk(a)

        for blk in self.blocks_v:
            v = blk(v)

        x = torch.cat((a, v), dim=1)

        # unified stream, shared blocks_u, but independent normalization layers
        for blk in self.blocks_u:
            x = blk(x)
        x = self.norm(x)

        for blk in self.blocks_u:
            ca = blk(a, 'a')
        ca = self.norm_a(ca)

        for blk in self.blocks_u:
            cv = blk(v, 'v')
        cv = self.norm_v(cv)

        return x, mask_a, ids_restore_a, mask_v, ids_restore_v, ca, cv
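    # x carries the joint (masked) audio-visual token sequence consumed by the decoder, while ca / cv are
    # the modality-specific outputs of the shared blocks used for the contrastive objective.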
    def forward_decoder(self, x, mask_a, ids_restore_a, mask_v, ids_restore_v):
        x = self.decoder_embed(x)

        # append mask tokens to the sequence
        # mask_tokens_a has shape [B, #a_mask_tokens, mask_token_dim]; the number of masked tokens is read from
        # mask_a[0] (the first example of the batch), since all samples have the same number of masked tokens
        mask_tokens_a = self.mask_token.repeat(x.shape[0], int(mask_a[0].sum()), 1)
        a_ = torch.cat([x[:, :self.patch_embed_a.num_patches - int(mask_a[0].sum()), :], mask_tokens_a],
                       dim=1)  # no cls token
        a_ = torch.gather(a_, dim=1, index=ids_restore_a.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle

        # similar for the visual modality
        mask_tokens_v = self.mask_token.repeat(x.shape[0], int(mask_v[0].sum()), 1)
        v_ = torch.cat([x[:, self.patch_embed_a.num_patches - int(mask_a[0].sum()):, :], mask_tokens_v],
                       dim=1)  # no cls token
        v_ = torch.gather(v_, dim=1, index=ids_restore_v.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle

        # concatenate audio and visual tokens
        x = torch.cat([a_, v_], dim=1)

        decoder_pos_embed = torch.cat([self.decoder_pos_embed_a, self.decoder_pos_embed_v], dim=1)
        x = x + decoder_pos_embed

        # add modality indication tokens
        x[:, 0:self.patch_embed_a.num_patches, :] = x[:, 0:self.patch_embed_a.num_patches, :] + self.decoder_modality_a
        x[:, self.patch_embed_a.num_patches:, :] = x[:, self.patch_embed_a.num_patches:, :] + self.decoder_modality_v

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x_a = self.decoder_pred_a(x[:, :self.patch_embed_a.num_patches, :])
        x_v = self.decoder_pred_v(x[:, self.patch_embed_a.num_patches:, :])

        # return audio and video tokens
        return x_a, x_v
    def forward_contrastive(self, audio_rep, video_rep, bidirect_contrast=False):
        # calculate the nce loss between the mean audio and mean visual representations
        audio_rep = torch.nn.functional.normalize(audio_rep, dim=-1)
        video_rep = torch.nn.functional.normalize(video_rep, dim=-1)

        total = torch.mm(audio_rep, torch.transpose(video_rep, 0, 1)) / 0.05

        # by default we use the single-directional loss
        if not bidirect_contrast:
            nce = -torch.mean(torch.diag(torch.nn.functional.log_softmax(total, dim=0)))
            c_acc = torch.sum(torch.eq(torch.argmax(torch.nn.functional.softmax(total, dim=0), dim=0),
                                       torch.arange(0, total.shape[0], device=audio_rep.device))) / total.shape[0]
            return nce, c_acc
        else:
            nce_1 = -torch.mean(torch.diag(torch.nn.functional.log_softmax(total, dim=0)))
            nce_2 = -torch.mean(torch.diag(torch.nn.functional.log_softmax(total.t(), dim=0)))
            c_acc_1 = torch.sum(torch.eq(torch.argmax(torch.nn.functional.softmax(total, dim=0), dim=0),
                                         torch.arange(0, total.shape[0], device=audio_rep.device))) / total.shape[0]
            c_acc_2 = torch.sum(torch.eq(torch.argmax(torch.nn.functional.softmax(total.t(), dim=0), dim=0),
                                         torch.arange(0, total.shape[0], device=audio_rep.device))) / total.shape[0]
            nce = (nce_1 + nce_2) / 2
            c_acc = (c_acc_1 + c_acc_2) / 2
            return nce, c_acc
    def forward_mae_loss(self, input, pred, mask, modality):
        if modality == 'a':
            # for audio, adjust the shape before patchifying
            input = input.unsqueeze(1)
            input = input.transpose(2, 3)
            target = self.patchify(input, 1, int(input.shape[2] / self.patch_embed_a.patch_size[0]),
                                   int(input.shape[3] / self.patch_embed_a.patch_size[1]), 16)
        elif modality == 'v':
            target = self.patchify(input, 3, int(input.shape[2] / self.patch_embed_v.patch_size[0]),
                                   int(input.shape[3] / self.patch_embed_v.patch_size[1]), 16)

        # patch-wise normalization may slightly improve classification performance, but it makes the model
        # lose its inpainting ability
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6) ** .5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss
    def forward(self, audio, imgs, mask_ratio_a=0.75, mask_ratio_v=0.75, mae_loss_weight=1., contrast_loss_weight=0.01,
                mask_mode='unstructured'):
        # latent is used for reconstruction (mae), latent_c_{a,v} are used for contrastive learning
        latent, mask_a, ids_restore_a, mask_v, ids_restore_v, latent_c_a, latent_c_v = self.forward_encoder(
            audio, imgs, mask_ratio_a, mask_ratio_v, mask_mode=mask_mode)

        # if the mae loss is used
        if mae_loss_weight != 0:
            pred_a, pred_v = self.forward_decoder(latent, mask_a, ids_restore_a, mask_v, ids_restore_v)
            loss_mae_a = self.forward_mae_loss(audio, pred_a, mask_a, 'a')
            loss_mae_v = self.forward_mae_loss(imgs, pred_v, mask_v, 'v')
            loss_mae = mae_loss_weight * (loss_mae_a + loss_mae_v)
        else:
            loss_mae_a = torch.tensor(0.0, device=audio.device)
            loss_mae_v = torch.tensor(0.0, device=audio.device)
            loss_mae = torch.tensor(0.0, device=audio.device)

        # if the contrastive loss is used
        if contrast_loss_weight != 0:
            # note this is single-directional
            loss_c, c_acc = self.forward_contrastive(latent_c_a.mean(dim=1), latent_c_v.mean(dim=1))
            loss_c = contrast_loss_weight * loss_c
        else:
            loss_c, c_acc = torch.tensor(0.0, device=audio.device), torch.tensor(0.0, device=audio.device)

        loss = loss_mae + loss_c

        return loss, loss_mae, loss_mae_a, loss_mae_v, loss_c, mask_a, mask_v, c_acc
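    # Illustrative usage sketch (not part of the original file), assuming random inputs of the expected shapes:
    #   model = CAVMAE(audio_length=1024, modality_specific_depth=11)
    #   audio = torch.randn(2, 1024, 128)      # (batch, time frames, mel bins)
    #   imgs = torch.randn(2, 3, 224, 224)     # (batch, channels, height, width)
    #   loss, loss_mae, loss_mae_a, loss_mae_v, loss_c, mask_a, mask_v, c_acc = model(audio, imgs)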
    # used only for inpainting; ignore if inpainting is not of interest
    def forward_inpaint(self, audio, imgs, mask_ratio_a=0.75, mask_ratio_v=0.75, mask_mode='unstructured'):
        latent, mask_a, ids_restore_a, mask_v, ids_restore_v, latent_c_a, latent_c_v = self.forward_encoder(
            audio, imgs, mask_ratio_a, mask_ratio_v, mask_mode=mask_mode)
        pred_a, pred_v = self.forward_decoder(latent, mask_a, ids_restore_a, mask_v, ids_restore_v)  # [N, L, p*p*3]
        loss_pixel_a = self.forward_mae_loss(audio, pred_a, mask_a, 'a')
        loss_pixel_v = self.forward_mae_loss(imgs, pred_v, mask_v, 'v')
        return pred_a, pred_v, mask_a, mask_v, loss_pixel_a, loss_pixel_v

    # used for retrieval; ignore if retrieval is not of interest
    def forward_feat(self, a, v):
        # embed patches
        a = a.unsqueeze(1)
        a = a.transpose(2, 3)
        a = self.patch_embed_a(a)
        a = a + self.pos_embed_a
        a = a + self.modality_a

        v = self.patch_embed_v(v)
        v = v + self.pos_embed_v
        v = v + self.modality_v

        # the modality-specific streams
        for blk in self.blocks_a:
            a = blk(a)

        for blk in self.blocks_v:
            v = blk(v)

        # the shared stream, with modality-specific normalization
        for blk in self.blocks_u:
            a = blk(a, 'a')
        a = self.norm_a(a)

        for blk in self.blocks_u:
            v = blk(v, 'v')
        v = self.norm_v(v)
        return a, v
    def forward_audio(self, a):
        # embed patches
        a = a.unsqueeze(1)
        a = a.transpose(2, 3)
        a = self.patch_embed_a(a)
        a = a + self.pos_embed_a
        a = a + self.modality_a

        # the modality-specific stream
        for blk in self.blocks_a:
            a = blk(a)

        # the shared stream, with modality-specific normalization
        for blk in self.blocks_u:
            a = blk(a, 'a')
        a = self.norm_a(a)
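        # reshape the token sequence back onto the (freq, time) patch grid: (B, 768, 8, 64)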
        return a.reshape(a.shape[0], 128 // 16, 1024 // 16, 768).permute(0, 3, 1, 2)
    def forward_video(self, v):
        v = self.patch_embed_v(v)
        v = v + self.pos_embed_v
        v = v + self.modality_v

        for blk in self.blocks_v:
            v = blk(v)

        for blk in self.blocks_u:
            v = blk(v, 'v')
        v = self.norm_v(v)
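        # reshape the token sequence back onto the spatial patch grid: (B, 768, 14, 14)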
        return v.reshape(v.shape[0], 224 // 16, 224 // 16, 768).permute(0, 3, 1, 2)
# the fine-tuned CAV-MAE model
class CAVMAEFT(nn.Module):
    def __init__(self, label_dim, img_size=224, audio_length=1024, patch_size=16, in_chans=3,
                 embed_dim=768, modality_specific_depth=11, num_heads=12, mlp_ratio=4., norm_layer=nn.LayerNorm,
                 norm_pix_loss=False, tr_pos=True):
        super().__init__()
        print('Use norm_pix_loss: ', norm_pix_loss)

        # override the timm package
        timm.models.vision_transformer.PatchEmbed = PatchEmbed
        timm.models.vision_transformer.Block = Block

        self.patch_embed_a = PatchEmbed(img_size, patch_size, 1, embed_dim)
        self.patch_embed_v = PatchEmbed(img_size, patch_size, in_chans, embed_dim)

        self.patch_embed_a.num_patches = int(audio_length * 128 / 256)
        print('Number of Audio Patches: {:d}, Visual Patches: {:d}'.format(self.patch_embed_a.num_patches,
                                                                           self.patch_embed_v.num_patches))

        self.modality_a = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.modality_v = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.pos_embed_a = nn.Parameter(torch.zeros(1, self.patch_embed_a.num_patches, embed_dim),
                                        requires_grad=tr_pos)  # fixed sin-cos embedding
        self.pos_embed_v = nn.Parameter(torch.zeros(1, self.patch_embed_v.num_patches, embed_dim),
                                        requires_grad=tr_pos)  # fixed sin-cos embedding

        self.blocks_a = nn.ModuleList(
            [Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) for i in
             range(modality_specific_depth)])
        self.blocks_v = nn.ModuleList(
            [Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) for i in
             range(modality_specific_depth)])
        self.blocks_u = nn.ModuleList(
            [Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) for i in
             range(12 - modality_specific_depth)])

        self.norm_a = norm_layer(embed_dim)
        self.norm_v = norm_layer(embed_dim)
        self.norm = norm_layer(embed_dim)

        self.mlp_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, label_dim))

        self.initialize_weights()

        print('Audio Positional Embedding Shape:', self.pos_embed_a.shape)
        print('Visual Positional Embedding Shape:', self.pos_embed_v.shape)
    def get_patch_num(self, input_shape, stride):
        test_input = torch.zeros(1, 1, input_shape[0], input_shape[1])
        test_proj = torch.nn.Conv2d(1, 4, kernel_size=(16, 16), stride=(stride, stride))
        test_output = test_proj(test_input)
        print(test_output.shape)
        return test_output.shape[2], test_output.shape[3], test_output.shape[2] * test_output.shape[3]
    def initialize_weights(self):
        pos_embed_a = get_2d_sincos_pos_embed(self.pos_embed_a.shape[-1], 8, int(self.patch_embed_a.num_patches / 8),
                                              cls_token=False)
        self.pos_embed_a.data.copy_(torch.from_numpy(pos_embed_a).float().unsqueeze(0))

        pos_embed_v = get_2d_sincos_pos_embed(self.pos_embed_v.shape[-1], int(self.patch_embed_v.num_patches ** .5),
                                              int(self.patch_embed_v.num_patches ** .5), cls_token=False)
        self.pos_embed_v.data.copy_(torch.from_numpy(pos_embed_v).float().unsqueeze(0))

        w = self.patch_embed_a.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        w = self.patch_embed_v.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        torch.nn.init.normal_(self.modality_a, std=.02)
        torch.nn.init.normal_(self.modality_v, std=.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    def forward(self, a, v, mode):
        # multi-modal fine-tuning, our default method for fine-tuning
        if mode == 'multimodal':
            a = a.unsqueeze(1)
            a = a.transpose(2, 3)
            a = self.patch_embed_a(a)
            a = a + self.pos_embed_a
            a = a + self.modality_a

            v = self.patch_embed_v(v)
            v = v + self.pos_embed_v
            v = v + self.modality_v

            for blk in self.blocks_a:
                a = blk(a)

            for blk in self.blocks_v:
                v = blk(v)

            x = torch.cat((a, v), dim=1)

            for blk in self.blocks_u:
                x = blk(x)
            x = self.norm(x)

            x = x.mean(dim=1)
            x = self.mlp_head(x)
            return x
        # fine-tune with only audio (and inference with only audio when the model is fine-tuned with only audio)
        elif mode == 'audioonly':
            a = a.unsqueeze(1)
            a = a.transpose(2, 3)
            a = self.patch_embed_a(a)
            a = a + self.pos_embed_a
            a = a + self.modality_a

            for blk in self.blocks_a:
                a = blk(a)

            # note this uses the 'a' normalization; it is used in both training and inference, so it is fine
            for blk in self.blocks_u:
                a = blk(a, 'a')
            a = self.norm_a(a)

            x = a.mean(dim=1)
            x = self.mlp_head(x)
            return x
        # fine-tune with only image (and inference with only image when the model is fine-tuned with only image)
        elif mode == 'videoonly':
            v = self.patch_embed_v(v)
            v = v + self.pos_embed_v
            v = v + self.modality_v

            for blk in self.blocks_v:
                v = blk(v)

            # note this uses the 'v' normalization; it is used in both training and inference, so it is fine
            for blk in self.blocks_u:
                v = blk(v, 'v')
            v = self.norm_v(v)

            x = v.mean(dim=1)
            x = self.mlp_head(x)
            return x
        # used in case the model is fine-tuned with both modalities, but only audio is given at inference
        elif mode == 'missingaudioonly':
            a = a.unsqueeze(1)
            a = a.transpose(2, 3)
            a = self.patch_embed_a(a)
            a = a + self.pos_embed_a
            a = a + self.modality_a

            for blk in self.blocks_a:
                a = blk(a)

            # two forward passes through blocks_u: one with the unified normalization,
            # another with the modality-specific normalization
            u = a
            for blk in self.blocks_u:
                u = blk(u)  # note this uses the unified normalization
            u = self.norm(u)
            u = u.mean(dim=1)

            for blk in self.blocks_u:
                a = blk(a, 'a')  # note this uses the modality-specific normalization
            a = self.norm_a(a)
            a = a.mean(dim=1)

            # average the output of the two forward passes
            x = (u + a) / 2
            x = self.mlp_head(x)
            return x
        # used in case the model is fine-tuned with both modalities, but only image is given at inference
        elif mode == 'missingvideoonly':
            v = self.patch_embed_v(v)
            v = v + self.pos_embed_v
            v = v + self.modality_v

            for blk in self.blocks_v:
                v = blk(v)

            # two forward passes through blocks_u: one with the unified normalization,
            # another with the modality-specific normalization
            u = v
            for blk in self.blocks_u:
                u = blk(u)  # note this uses the unified normalization
            u = self.norm(u)
            u = u.mean(dim=1)

            for blk in self.blocks_u:
                v = blk(v, 'v')  # note this uses the modality-specific normalization
            v = self.norm_v(v)
            v = v.mean(dim=1)

            # average the output of the two forward passes
            x = (u + v) / 2
            x = self.mlp_head(x)
            return x
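    # Illustrative usage sketch (not part of the original file); label_dim=527 is an assumption (AudioSet classes):
    #   ft_model = CAVMAEFT(label_dim=527)
    #   logits = ft_model(torch.randn(2, 1024, 128), torch.randn(2, 3, 224, 224), mode='multimodal')  # (2, 527)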
    # for retrieval
    def forward_feat(self, a, v, mode='av'):
        # return both audio and visual features
        if mode == 'av':
            a = a.unsqueeze(1)
            a = a.transpose(2, 3)
            a = self.patch_embed_a(a)
            a = a + self.pos_embed_a
            a = a + self.modality_a

            v = self.patch_embed_v(v)
            v = v + self.pos_embed_v
            v = v + self.modality_v

            for blk in self.blocks_a:
                a = blk(a)

            for blk in self.blocks_v:
                v = blk(v)

            for blk in self.blocks_u:
                a = blk(a, 'a')
            a = self.norm_a(a)

            for blk in self.blocks_u:
                v = blk(v, 'v')
            v = self.norm_v(v)
            return a, v

        # return only audio features
        if mode == 'a':
            a = a.unsqueeze(1)
            a = a.transpose(2, 3)
            a = self.patch_embed_a(a)
            a = a + self.pos_embed_a
            a = a + self.modality_a

            for blk in self.blocks_a:
                a = blk(a)

            for blk in self.blocks_u:
                a = blk(a, 'a')
            a = self.norm_a(a)
            return a
def _wav2fbank(filename):
    waveform, sr = torchaudio.load(filename)
    # resample to 16 kHz, the sample rate the filterbank features are computed at
    waveform = torchaudio.functional.resample(
        waveform, orig_freq=sr, new_freq=16000
    )
    waveform = waveform - waveform.mean()

    fbank = torchaudio.compliance.kaldi.fbank(
        waveform,
        htk_compat=True,
        sample_frequency=16000,  # the waveform has already been resampled to 16 kHz
        use_energy=False,
        window_type='hanning',
        num_mel_bins=128,
        dither=0.0,
        frame_shift=10)

    target_length = 1024
    n_frames = fbank.shape[0]
    p = target_length - n_frames

    # cut and pad to exactly target_length frames
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[0:target_length, :]

    return fbank
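# Illustrative note (not in the original script): _wav2fbank('clip.wav') returns a (1024, 128) log-mel
# filterbank tensor, padded or truncated to 1024 frames, matching the model's audio_length.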
def pca(image_feats_list, dim=3, fit_pca=None):
    from sklearn.decomposition import PCA

    device = image_feats_list[0].device

    def flatten(tensor, target_size=None):
        if target_size is not None and fit_pca is None:
            tensor = F.interpolate(tensor, (target_size, target_size), mode="bilinear")
        B, C, H, W = tensor.shape
        return tensor.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu()

    if len(image_feats_list) > 1 and fit_pca is None:
        target_size = image_feats_list[0].shape[2]
    else:
        target_size = None

    flattened_feats = []
    for feats in image_feats_list:
        flattened_feats.append(flatten(feats, target_size))
    x = torch.cat(flattened_feats, dim=0)

    if fit_pca is None:
        fit_pca = PCA(n_components=dim).fit(x)

    reduced_feats = []
    for feats in image_feats_list:
        x_red = torch.from_numpy(fit_pca.transform(flatten(feats)))
        x_red -= x_red.min(dim=0, keepdim=True).values
        x_red /= x_red.max(dim=0, keepdim=True).values
        B, C, H, W = feats.shape
        reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2).to(device))

    return reduced_feats, fit_pca
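# Illustrative usage (mirrors the commented-out visualization block at the end of this file):
#   [red_v_feat, red_a_feat], fit_pca = pca([v_feat, a_feat])  # reduce each feature map to 3 channels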
class CAVMAEAudioFeaturizer(nn.Module):
    def __init__(self, output_path, model_name="base", model=None):
        super().__init__()
        if model is not None:
            self.model = model
        else:
            if model_name == "base":
                model_path = os.path.join(output_path, 'models/audio_model.21.pth')
            else:
                raise ValueError(f"Unknown model type {model_name}")
            audio_model = CAVMAE(
                audio_length=1024,
                modality_specific_depth=11,
                norm_pix_loss=True,
                tr_pos=False)
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            mdl_weight = torch.load(model_path, map_location=device)
            audio_model = torch.nn.DataParallel(audio_model)
            audio_model.load_state_dict(mdl_weight, strict=True)
            self.model = audio_model.module.cuda()

    def forward(self, audio, include_cls):
        cls_token = None
        patch_tokens = self.model.forward_audio(audio.squeeze(1))
        if include_cls:
            return patch_tokens, cls_token
        else:
            return patch_tokens
class CAVMAEImageFeaturizer(nn.Module):
    def __init__(self, output_path, model=None, model_name="base"):
        super().__init__()
        if model is not None:
            self.model: CAVMAE = model
        else:
            if model_name == "base":
                model_path = os.path.join(output_path, 'models/audio_model.21.pth')
            else:
                raise ValueError(f"Unknown model type {model_name}")
            audio_model = CAVMAE(
                audio_length=1024,
                modality_specific_depth=11,
                norm_pix_loss=True,
                tr_pos=False)
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            mdl_weight = torch.load(model_path, map_location=device)
            audio_model = torch.nn.DataParallel(audio_model)
            audio_model.load_state_dict(mdl_weight, strict=True)
            self.model: CAVMAE = audio_model.module.cuda()

    def forward(self, image, include_cls):
        cls_token = None
        patch_tokens = self.model.forward_video(image)
        if include_cls:
            return patch_tokens, cls_token
        else:
            return patch_tokens
if __name__ == "__main__":
    model_path = os.path.join("../../", 'models/audio_model.21.pth')
    audio_model = CAVMAE(
        audio_length=1024,
        modality_specific_depth=11,
        norm_pix_loss=True,
        tr_pos=False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mdl_weight = torch.load(model_path, map_location=device)
    audio_model = torch.nn.DataParallel(audio_model)
    audio_model.load_state_dict(mdl_weight, strict=True)
    model: CAVMAE = audio_model.module.cuda()

    image_paths = ["../../samples/dog_image.jpg", "../../samples/car_image.jpg", "../../samples/bird_image.jpg"]
    audio_paths = ["../../samples/dog_audio.wav", "../../samples/car_audio.wav", "../../samples/bird_audio.wav"]

    preprocess = T.Compose([
        T.Resize(224, interpolation=Image.BICUBIC),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(
            mean=[0.4850, 0.4560, 0.4060],
            std=[0.2290, 0.2240, 0.2250]
        )])

    images = []
    audios = []
    for image_path in image_paths:
        image = Image.open(image_path).convert("RGB")
        images.append(preprocess(image).unsqueeze(0).cuda())

    for audio_path in audio_paths:
        a = _wav2fbank(audio_path).cuda().unsqueeze(0)
        a = (a + 5.081) / (4.4849)
        audios.append(a)

    audio_feats, image_feats = model.forward_feat(
        torch.cat(audios, dim=0), torch.cat(images, dim=0))
    audio_feats = F.normalize(audio_feats.mean(1), dim=1)
    image_feats = F.normalize(image_feats.mean(1), dim=1)

    sims = torch.einsum("bc,dc->bd", image_feats, audio_feats)
    print(sims)
    print("here")
    # a_feat = F.normalize(a_feat, dim=1)
    # v_feat = F.normalize(v_feat, dim=1)

    # [red_v_feat, red_a_feat], fit_pca = pca([v_feat, a_feat])
    #
    # [red_v_feat], fit_pca = pca([v_feat])
    # [red_a_feat], fit_pca = pca([a_feat])
    #
    # import matplotlib.pyplot as plt
    #
    # fig, ax = plt.subplots(1, 2, figsize=(2 * 5, 5))
    # ax[0].imshow(red_v_feat[0].permute(1, 2, 0).cpu())
    # ax[1].imshow(red_a_feat[0].permute(1, 2, 0).cpu())
    # plt.tight_layout()
    # plt.show()
    # print("here")