# Page header captured with this file: lorocksUMD's picture / "Upload 32 files" / e6d4b46 verified
import random
import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torchvision.transforms as T
from PIL import Image
from timm.models.layers import to_2tuple, DropPath
from timm.models.vision_transformer import Mlp, PatchEmbed, Block
import os
class Attention(nn.Module):
    """Multi-head self-attention over a sequence of tokens.

    Args:
        dim: channel width of the input and output tokens.
        num_heads: number of attention heads; each head works on
            ``dim // num_heads`` channels.
        qkv_bias: whether the fused query/key/value projection has a bias.
        qk_scale: explicit scale for the attention logits; any falsy value
            falls back to ``head_dim ** -0.5``.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        channels_per_head = dim // num_heads
        # NOTE: qk_scale can be set manually to stay compatible with weights
        # that were trained with a different scale factor.
        self.scale = qk_scale or channels_per_head ** -0.5

        # One linear layer produces q, k and v at once.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, N, C); returns the same shape."""
        batch, tokens, channels = x.shape
        heads = self.num_heads

        fused = self.qkv(x).reshape(batch, tokens, 3, heads, channels // heads)
        fused = fused.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
        # Index instead of tuple-unpacking the tensor (keeps TorchScript happy).
        query, key, value = fused[0], fused[1], fused[2]

        weights = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(weights.softmax(dim=-1))

        mixed = torch.matmul(weights, value).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))
def get_2d_sincos_pos_embed(embed_dim, grid_h_size, grid_w_size, cls_token=False):
    """Build a fixed 2-D sine/cosine positional embedding for a rectangular grid.

    Args:
        embed_dim: width of each positional vector (the helper functions
            assert the evenness they need).
        grid_h_size: number of grid rows.
        grid_w_size: number of grid columns.
        cls_token: when True, an all-zero row is prepended for a class token.

    Returns:
        numpy array of shape [grid_h_size*grid_w_size, embed_dim], or
        [1+grid_h_size*grid_w_size, embed_dim] when ``cls_token`` is True.
        Rows follow the row-major flattening of the (grid_h_size, grid_w_size) grid.
    """
    grid_h = np.arange(grid_h_size, dtype=float)
    grid_w = np.arange(grid_w_size, dtype=float)
    # here w goes first; each returned array has shape (grid_h_size, grid_w_size)
    grid = np.meshgrid(grid_w, grid_h)
    grid = np.stack(grid, axis=0)
    # Only adds a singleton axis.  The axes are listed in the order the data is
    # actually laid out (h, w); the encoder flattens each plane, so the values
    # are the same as with the previous (w, h) spelling.
    grid = grid.reshape([2, 1, grid_h_size, grid_w_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a stacked pair of coordinate planes into one embedding table.

    ``grid[0]`` and ``grid[1]`` are each encoded with ``embed_dim // 2``
    channels by ``get_1d_sincos_pos_embed_from_grid`` and the two results are
    concatenated along the channel axis, giving (number of positions, embed_dim).
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # One half of the channels per coordinate plane, plane 0 first.
    per_plane = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[plane]) for plane in (0, 1)]
    return np.concatenate(per_plane, axis=1)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Sine/cosine encoding of a set of scalar positions.

    embed_dim: output dimension for each position (must be even)
    pos: array of positions to encode; it is flattened to size (M,)
    out: (M, embed_dim) -- the first half of the channels holds the sines,
         the second half the cosines
    """
    assert embed_dim % 2 == 0

    n_freqs = embed_dim // 2
    exponents = np.arange(n_freqs, dtype=float) / (embed_dim / 2.)
    freqs = 1. / 10000 ** exponents  # (D/2,)

    # Outer product of positions and frequencies: (M, D/2)
    angles = np.outer(pos.reshape(-1), freqs)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    """Resize a checkpoint's ``pos_embed`` in place to match ``model``'s patch grid.

    ``checkpoint_model`` is a state-dict-like mapping.  If it has a
    ``'pos_embed'`` entry whose (square) patch grid differs from the model's,
    the patch part is bicubically interpolated to the new grid and written
    back; the leading extra tokens (class / distillation tokens) are kept
    unchanged.  Nothing is returned.
    """
    if 'pos_embed' not in checkpoint_model:
        return

    ckpt_embed = checkpoint_model['pos_embed']
    width = ckpt_embed.shape[-1]
    n_patches = model.patch_embed.num_patches
    n_extra = model.pos_embed.shape[-2] - n_patches

    # Both grids are treated as square: side = sqrt(number of patch tokens).
    old_side = int((ckpt_embed.shape[-2] - n_extra) ** 0.5)
    new_side = int(n_patches ** 0.5)
    if old_side == new_side:
        return

    print("Position interpolate from %dx%d to %dx%d" % (old_side, old_side, new_side, new_side))
    prefix = ckpt_embed[:, :n_extra]
    # Only the patch tokens are interpolated: (1, L, D) -> (1, D, side, side).
    patch_grid = ckpt_embed[:, n_extra:].reshape(-1, old_side, old_side, width).permute(0, 3, 1, 2)
    patch_grid = torch.nn.functional.interpolate(
        patch_grid, size=(new_side, new_side), mode='bicubic', align_corners=False)
    patch_tokens = patch_grid.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model['pos_embed'] = torch.cat((prefix, patch_tokens), dim=1)
class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping patches and embed each one.

    A convolution whose kernel size and stride both equal the patch size
    performs the per-patch linear projection.

    Attributes:
        img_size: (height, width) pair derived from ``img_size``.
        patch_size: (height, width) pair derived from ``patch_size``.
        num_patches: number of patches for an input of ``img_size``.
        proj: the patch-projection convolution.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)

        n_rows = self.img_size[0] // self.patch_size[0]
        n_cols = self.img_size[1] // self.patch_size[1]
        self.num_patches = n_cols * n_rows

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)

    def forward(self, x):
        """Map (B, C, H, W) to (B, num_tokens, embed_dim), tokens in row-major order."""
        feature_map = self.proj(x)
        return feature_map.flatten(2).transpose(1, 2)
class Block(nn.Module):
    """Pre-norm Transformer block with shared weights and per-modality LayerNorms.

    The attention and MLP sub-layers are shared by every input; only the
    LayerNorm in front of each sub-layer is selected by ``modality``:
    ``norm1``/``norm2`` for ``None``, ``norm1_a``/``norm2_a`` for ``'a'`` and
    ``norm1_v``/``norm2_v`` for ``'v'``.

    Args:
        dim: token channel width.
        num_heads: number of attention heads.
        mlp_ratio: hidden width of the MLP as a multiple of ``dim``.
        qkv_bias, qk_scale, attn_drop: forwarded to ``Attention``.
        drop: dropout used for the attention projection and inside the MLP.
        drop_path: stochastic-depth rate; 0 disables it.
        act_layer: activation class used by the MLP.
        norm_layer: normalization class used for every norm in the block.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.norm1_a = norm_layer(dim)
        self.norm1_v = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.norm2_a = norm_layer(dim)
        self.norm2_v = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, modality=None):
        """Run the block on ``x``.

        Args:
            x: token tensor accepted by ``Attention`` (B, N, C).
            modality: ``None`` (shared norms), ``'a'`` or ``'v'``.

        Raises:
            ValueError: for any other ``modality``.  Previously such a value
                fell through every branch and the input was returned
                untouched, silently skipping the block.
        """
        if modality is None:
            pre_attn, pre_mlp = self.norm1, self.norm2
        elif modality == 'a':
            pre_attn, pre_mlp = self.norm1_a, self.norm2_a
        elif modality == 'v':
            pre_attn, pre_mlp = self.norm1_v, self.norm2_v
        else:
            raise ValueError(f"Unknown modality {modality!r}; expected None, 'a' or 'v'")

        x = x + self.drop_path(self.attn(pre_attn(x)))
        x = x + self.drop_path(self.mlp(pre_mlp(x)))
        return x
# our main proposed model, for pretraining only, for finetuning, use CAVMAEFT class
class CAVMAE(nn.Module):
    """CAV-MAE model for pretraining (for fine-tuning use ``CAVMAEFT``).

    Audio and image inputs are patch-embedded, randomly masked, run through
    modality-specific Transformer stacks (``blocks_a`` / ``blocks_v``) and then
    through the shared stack ``blocks_u``.  The joint pass feeds the
    masked-autoencoder decoder; separate per-modality passes through
    ``blocks_u`` (using the modality-specific LayerNorms) feed the contrastive
    loss.
    """

    # Number of patch rows of the audio grid: the audio positional embedding is
    # built for an 8 x (num_patches / 8) grid.
    _AUDIO_FREQ_PATCHES = 8

    def __init__(self, img_size=224, audio_length=1024, patch_size=16, in_chans=3,
                 embed_dim=768, modality_specific_depth=11, num_heads=12,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False, tr_pos=False):
        """Build encoder and decoder.

        Args:
            img_size, patch_size, in_chans: geometry of the visual input.
            audio_length: number of audio frames; sets the audio patch count.
            embed_dim, num_heads, mlp_ratio: encoder Transformer settings.
            modality_specific_depth: blocks per modality-specific stack; the
                shared stack gets ``12 - modality_specific_depth`` blocks.
            decoder_embed_dim, decoder_depth, decoder_num_heads: decoder settings.
            norm_layer: normalization class used throughout.
            norm_pix_loss: normalize each target patch before the MAE loss.
            tr_pos: whether positional embeddings are trainable.
        """
        super().__init__()
        print('A CAV-MAE Model')
        print('Use norm_pix_loss: ', norm_pix_loss)
        print('Learnable Positional Embedding: ', tr_pos)

        # the encoder part
        # override the timm package
        timm.models.vision_transformer.PatchEmbed = PatchEmbed
        timm.models.vision_transformer.Block = Block

        self.patch_embed_a = PatchEmbed(img_size, patch_size, 1, embed_dim)
        self.patch_embed_v = PatchEmbed(img_size, patch_size, in_chans, embed_dim)

        # The audio patch count is derived from the audio length, not from img_size.
        self.patch_embed_a.num_patches = int(audio_length * 128 / 256)
        print('Number of Audio Patches: {:d}, Visual Patches: {:d}'.format(self.patch_embed_a.num_patches,
                                                                           self.patch_embed_v.num_patches))

        self.modality_a = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.modality_v = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.pos_embed_a = nn.Parameter(torch.zeros(1, self.patch_embed_a.num_patches, embed_dim),
                                        requires_grad=tr_pos)  # fixed sin-cos embedding
        self.pos_embed_v = nn.Parameter(torch.zeros(1, self.patch_embed_v.num_patches, embed_dim),
                                        requires_grad=tr_pos)  # fixed sin-cos embedding

        # audio-branch
        self.blocks_a = nn.ModuleList(
            [Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
             for _ in range(modality_specific_depth)])
        # visual-branch
        self.blocks_v = nn.ModuleList(
            [Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
             for _ in range(modality_specific_depth)])
        # unified branch
        self.blocks_u = nn.ModuleList(
            [Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
             for _ in range(12 - modality_specific_depth)])

        # independent normalization layer for audio, visual, and audio-visual
        self.norm_a, self.norm_v, self.norm = norm_layer(embed_dim), norm_layer(embed_dim), norm_layer(embed_dim)

        # the decoder part
        # Project to lower dimension for the decoder
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        # token used for masking
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_modality_a = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_modality_v = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed_a = nn.Parameter(torch.zeros(1, self.patch_embed_a.num_patches, decoder_embed_dim),
                                                requires_grad=tr_pos)  # fixed sin-cos embedding
        self.decoder_pos_embed_v = nn.Parameter(torch.zeros(1, self.patch_embed_v.num_patches, decoder_embed_dim),
                                                requires_grad=tr_pos)  # fixed sin-cos embedding

        self.decoder_blocks = nn.ModuleList(
            [Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None,
                   norm_layer=norm_layer)
             for _ in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)

        # project channel is different for two modality, use two projection head
        self.decoder_pred_a = nn.Linear(decoder_embed_dim, patch_size ** 2 * 1, bias=True)  # decoder to patch
        self.decoder_pred_v = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True)  # decoder to patch

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()

        print('Audio Positional Embedding Shape:', self.pos_embed_a.shape)
        print('Visual Positional Embedding Shape:', self.pos_embed_v.shape)

    def initialize_weights(self):
        """Fill positional embeddings with sin-cos tables and initialize all weights."""
        n_freq = self._AUDIO_FREQ_PATCHES
        n_time = int(self.patch_embed_a.num_patches / n_freq)
        v_side = int(self.patch_embed_v.num_patches ** .5)

        # Positional embeddings: audio uses an n_freq x n_time grid, the visual
        # grid is treated as square.
        for param, (rows, cols) in ((self.pos_embed_a, (n_freq, n_time)),
                                    (self.pos_embed_v, (v_side, v_side)),
                                    (self.decoder_pos_embed_a, (n_freq, n_time)),
                                    (self.decoder_pos_embed_v, (v_side, v_side))):
            table = get_2d_sincos_pos_embed(param.shape[-1], rows, cols, cls_token=False)
            param.data.copy_(torch.from_numpy(table).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        for patch_embed in (self.patch_embed_a, self.patch_embed_v):
            w = patch_embed.proj.weight.data
            torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        for token in (self.modality_a, self.modality_v,
                      self.decoder_modality_a, self.decoder_modality_v, self.mask_token):
            torch.nn.init.normal_(token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform for Linear weights, zero biases, unit LayerNorm scale."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs, c, h, w, p=16):
        """Split images into flattened patches.

        imgs: (N, c, h*p, w*p)
        returns: (N, h*w, p**2 * c)
        """
        x = imgs.reshape(shape=(imgs.shape[0], c, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * c))
        return x

    def unpatchify(self, x, c, h, w, p=16):
        """Inverse of ``patchify``.

        x: (N, h*w, p**2 * c)
        returns: (N, c, h*p, w*p)
        """
        assert h * w == x.shape[1]
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    @staticmethod
    def _mask_by_noise(x, noise, len_keep):
        """Keep the ``len_keep`` tokens with the smallest noise in each sample.

        Returns ``(x_masked, mask, ids_restore)`` where ``mask`` is 0 for kept
        and 1 for removed tokens (in original token order) and ``ids_restore``
        undoes the shuffle.
        """
        N, L, D = x.shape
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask in shuffled order, then unshuffle it
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)
        return x_masked, mask, ids_restore

    def random_masking_unstructured(self, x, mask_ratio):
        """Per-sample random masking by argsort of uniform noise.

        x: [N, L, D] sequence; ``int(L * (1 - mask_ratio))`` tokens are kept.
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))
        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
        return self._mask_by_noise(x, noise, len_keep)

    def random_masking_structured(self, x, mask_ratio, t=64, f=8, mode='time'):
        """Per-sample masking that removes whole time columns and/or frequency rows.

        x: [N, L, D] with ``L == f * t``; tokens are laid out as [f, t].
        mode: 'time', 'freq' or 'tf'.  Any other value leaves the noise
        untouched, which amounts to unstructured masking.
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
        assert L == f * t
        noise = noise.reshape(N, f, t)  # the audio patch is in shape [f,t], not [t,f]

        # Entries set to 1.1 exceed every random value, so they sort last and are removed.
        if mode == 'time':
            for i in range(N):
                for k in random.sample(range(t), int(t * mask_ratio)):
                    noise[i, :, k] = 1.1
        elif mode == 'freq':
            for i in range(N):
                for k in random.sample(range(f), int(f * mask_ratio)):
                    noise[i, k, :] = 1.1
        elif mode == 'tf':
            for i in range(N):
                for k in random.sample(range(t), int(t * mask_ratio * 0.7)):
                    noise[i, :, k] = 1.1
            for i in range(N):
                for k in random.sample(range(f), int(f * mask_ratio * 0.7)):
                    noise[i, k, :] = 1.1
        noise = noise.reshape(N, L)

        return self._mask_by_noise(x, noise, len_keep)

    def _embed_audio(self, a):
        """Patch-embed audio and add the positional and modality embeddings."""
        a = a.unsqueeze(1)
        a = a.transpose(2, 3)
        a = self.patch_embed_a(a)
        a = a + self.pos_embed_a
        a = a + self.modality_a
        return a

    def _embed_video(self, v):
        """Patch-embed images and add the positional and modality embeddings."""
        v = self.patch_embed_v(v)
        v = v + self.pos_embed_v
        v = v + self.modality_v
        return v

    def _run_unified(self, x, modality=None):
        """Chain ``x`` through every shared block, then apply the matching final norm."""
        final_norm = {None: self.norm, 'a': self.norm_a, 'v': self.norm_v}[modality]
        for blk in self.blocks_u:
            x = blk(x, modality)
        return final_norm(x)

    def forward_encoder(self, a, v, mask_ratio_a, mask_ratio_v, mask_mode='unstructured'):
        """Embed, mask and encode both modalities.

        Returns:
            ``(x, mask_a, ids_restore_a, mask_v, ids_restore_v, ca, cv)`` where
            ``x`` is the joint audio+visual encoding used for reconstruction and
            ``ca`` / ``cv`` are the per-modality encodings used for the
            contrastive loss.
        """
        # embed patches
        a = self._embed_audio(a)
        v = self._embed_video(v)

        # by default, we always use unstructured masking
        if mask_mode == 'unstructured':
            a, mask_a, ids_restore_a = self.random_masking_unstructured(a, mask_ratio_a)
        # in ablation study, we tried time/freq/tf masking. mode in ['freq', 'time', 'tf']
        else:
            a, mask_a, ids_restore_a = self.random_masking_structured(a, mask_ratio_a, t=64, f=8, mode=mask_mode)

        # visual branch always use unstructured masking
        v, mask_v, ids_restore_v = self.random_masking_unstructured(v, mask_ratio_v)

        # audio and visual stream, independent blocks
        for blk in self.blocks_a:
            a = blk(a)
        for blk in self.blocks_v:
            v = blk(v)

        # unified stream, shared blocks_u, but independent normalization layers
        x = self._run_unified(torch.cat((a, v), dim=1))

        # The per-modality passes are chained through all shared blocks, as in
        # forward_feat.  The previous loop fed every block the same input (so
        # only the last block counted) and left ca/cv undefined when blocks_u
        # was empty; with a single shared block the result is unchanged.
        ca = self._run_unified(a, 'a')
        cv = self._run_unified(v, 'v')

        return x, mask_a, ids_restore_a, mask_v, ids_restore_v, ca, cv

    def forward_decoder(self, x, mask_a, ids_restore_a, mask_v, ids_restore_v):
        """Reconstruct all audio and visual patches from the visible-token encoding.

        Returns ``(x_a, x_v)``: per-patch predictions for audio and visual tokens.
        """
        x = self.decoder_embed(x)
        dim = x.shape[2]
        n_audio = self.patch_embed_a.num_patches

        # The number of masked tokens is read from the first sample of the
        # batch; all samples should have the same number of masked tokens.
        n_masked_a = int(mask_a[0].sum())
        n_masked_v = int(mask_v[0].sum())
        n_visible_a = n_audio - n_masked_a

        # append mask tokens to the visible audio tokens, then unshuffle (no cls token)
        mask_tokens_a = self.mask_token.repeat(x.shape[0], n_masked_a, 1)
        a_ = torch.cat([x[:, :n_visible_a, :], mask_tokens_a], dim=1)
        a_ = torch.gather(a_, dim=1, index=ids_restore_a.unsqueeze(-1).repeat(1, 1, dim))

        # similar for the visual modality
        mask_tokens_v = self.mask_token.repeat(x.shape[0], n_masked_v, 1)
        v_ = torch.cat([x[:, n_visible_a:, :], mask_tokens_v], dim=1)
        v_ = torch.gather(v_, dim=1, index=ids_restore_v.unsqueeze(-1).repeat(1, 1, dim))

        # concatenate audio and visual tokens
        x = torch.cat([a_, v_], dim=1)

        decoder_pos_embed = torch.cat([self.decoder_pos_embed_a, self.decoder_pos_embed_v], dim=1)
        x = x + decoder_pos_embed

        # add modality indication tokens (built out-of-place instead of writing into slices)
        x = torch.cat([x[:, :n_audio, :] + self.decoder_modality_a,
                       x[:, n_audio:, :] + self.decoder_modality_v], dim=1)

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x_a = self.decoder_pred_a(x[:, :n_audio, :])
        x_v = self.decoder_pred_v(x[:, n_audio:, :])

        # return audio and video tokens
        return x_a, x_v

    @staticmethod
    def _nce_and_accuracy(logits):
        """InfoNCE loss and top-1 matching accuracy for a square similarity matrix.

        The softmax runs over dim 0 and the diagonal holds the matching pairs.
        """
        targets = torch.arange(0, logits.shape[0], device=logits.device)
        nce = -torch.mean(torch.diag(F.log_softmax(logits, dim=0)))
        acc = torch.sum(torch.eq(torch.argmax(F.softmax(logits, dim=0), dim=0), targets)) / logits.shape[0]
        return nce, acc

    def forward_contrastive(self, audio_rep, video_rep, bidirect_contrast=False):
        """Contrastive (NCE) loss between pooled audio and visual representations.

        Both inputs are (B, D); they are L2-normalized and compared with a
        temperature of 0.05.  Returns ``(nce, c_acc)``.  With
        ``bidirect_contrast`` both directions are averaged.
        """
        audio_rep = F.normalize(audio_rep, dim=-1)
        video_rep = F.normalize(video_rep, dim=-1)

        total = torch.mm(audio_rep, torch.transpose(video_rep, 0, 1)) / 0.05

        # by default we use single directional
        if not bidirect_contrast:
            return self._nce_and_accuracy(total)

        nce_1, c_acc_1 = self._nce_and_accuracy(total)
        nce_2, c_acc_2 = self._nce_and_accuracy(total.t())
        nce = (nce_1 + nce_2) / 2
        c_acc = (c_acc_1 + c_acc_2) / 2
        return nce, c_acc

    def forward_mae_loss(self, input, pred, mask, modality):
        """Mean squared reconstruction error over the removed patches.

        Args:
            input: the raw audio or image batch that was encoded.
            pred: per-patch predictions from ``forward_decoder``.
            mask: [N, L], 1 for removed patches, 0 for kept ones.
            modality: ``'a'`` or ``'v'``.

        Raises:
            ValueError: for any other ``modality`` (previously this surfaced as
                an unbound-variable error).
        """
        if modality == 'a':
            # for audio, need to adjust the shape
            input = input.unsqueeze(1)
            input = input.transpose(2, 3)
            patch_size = self.patch_embed_a.patch_size
        elif modality == 'v':
            patch_size = self.patch_embed_v.patch_size
        else:
            raise ValueError(f"Unknown modality {modality!r}; expected 'a' or 'v'")

        # Channel count and patch size come from the data and the patch
        # embedding rather than being fixed to 3 and 16 (patchify assumes
        # square patches).
        target = self.patchify(input, input.shape[1],
                               int(input.shape[2] / patch_size[0]),
                               int(input.shape[3] / patch_size[1]),
                               patch_size[0])

        # patch-wise normalization might minorly improve the classification performance,
        # but will make the model lose inpainting function
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6) ** .5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward(self, audio, imgs, mask_ratio_a=0.75, mask_ratio_v=0.75, mae_loss_weight=1., contrast_loss_weight=0.01,
                mask_mode='unstructured'):
        """Compute the pretraining losses.

        Returns:
            ``(loss, loss_mae, loss_mae_a, loss_mae_v, loss_c, mask_a, mask_v, c_acc)``.
            A loss whose weight is 0 is skipped and reported as a zero tensor.
        """
        # latent is used for reconstruction (mae), latent_c_{a,v} are used for contrastive learning
        latent, mask_a, ids_restore_a, mask_v, ids_restore_v, latent_c_a, latent_c_v = self.forward_encoder(
            audio, imgs, mask_ratio_a, mask_ratio_v, mask_mode=mask_mode)

        # if mae loss is used
        if mae_loss_weight != 0:
            pred_a, pred_v = self.forward_decoder(latent, mask_a, ids_restore_a, mask_v, ids_restore_v)
            loss_mae_a = self.forward_mae_loss(audio, pred_a, mask_a, 'a')
            loss_mae_v = self.forward_mae_loss(imgs, pred_v, mask_v, 'v')
            loss_mae = mae_loss_weight * (loss_mae_a + loss_mae_v)
        else:
            loss_mae_a = torch.tensor(0.0, device=audio.device)
            loss_mae_v = torch.tensor(0.0, device=audio.device)
            loss_mae = torch.tensor(0.0, device=audio.device)

        # if contrastive loss is used
        if contrast_loss_weight != 0:
            # note this is single directional
            loss_c, c_acc = self.forward_contrastive(latent_c_a.mean(dim=1), latent_c_v.mean(dim=1))
            loss_c = contrast_loss_weight * loss_c
        else:
            loss_c, c_acc = torch.tensor(0.0, device=audio.device), torch.tensor(0.0, device=audio.device)

        loss = loss_mae + loss_c

        return loss, loss_mae, loss_mae_a, loss_mae_v, loss_c, mask_a, mask_v, c_acc

    # used only for inpainting, ignore if inpainting is not of interest
    def forward_inpaint(self, audio, imgs, mask_ratio_a=0.75, mask_ratio_v=0.75, mask_mode='unstructured'):
        """Return reconstructions, masks and per-modality reconstruction losses."""
        latent, mask_a, ids_restore_a, mask_v, ids_restore_v, latent_c_a, latent_c_v = self.forward_encoder(
            audio, imgs, mask_ratio_a, mask_ratio_v, mask_mode=mask_mode)
        pred_a, pred_v = self.forward_decoder(latent, mask_a, ids_restore_a, mask_v, ids_restore_v)  # [N, L, p*p*3]
        loss_pixel_a = self.forward_mae_loss(audio, pred_a, mask_a, 'a')
        loss_pixel_v = self.forward_mae_loss(imgs, pred_v, mask_v, 'v')
        return pred_a, pred_v, mask_a, mask_v, loss_pixel_a, loss_pixel_v

    # used for retrieval, ignore if retrieval is not of interest
    def forward_feat(self, a, v):
        """Encode both modalities without masking; returns token features ``(a, v)``."""
        # embed patches
        a = self._embed_audio(a)
        v = self._embed_video(v)

        # the modality-specific stream
        for blk in self.blocks_a:
            a = blk(a)
        for blk in self.blocks_v:
            v = blk(v)

        # use modality specific normalization
        a = self._run_unified(a, 'a')
        v = self._run_unified(v, 'v')
        return a, v

    def forward_audio(self, a):
        """Encode audio without masking and return a (B, C, 8, T) feature map.

        The token axis is unfolded into 8 rows and ``num_tokens // 8`` columns;
        the channel count is taken from the features instead of being fixed.
        """
        a = self._embed_audio(a)

        # the modality-specific stream
        for blk in self.blocks_a:
            a = blk(a)

        # use modality specific normalization
        a = self._run_unified(a, 'a')

        return a.reshape(a.shape[0], self._AUDIO_FREQ_PATCHES, -1, a.shape[-1]).permute(0, 3, 1, 2)

    def forward_video(self, v):
        """Encode images without masking and return a (B, C, S, S) feature map.

        ``S`` is the square root of the token count (the visual grid is
        treated as square, as in ``initialize_weights``).
        """
        v = self._embed_video(v)

        for blk in self.blocks_v:
            v = blk(v)

        v = self._run_unified(v, 'v')

        side = int(v.shape[1] ** .5)
        return v.reshape(v.shape[0], side, side, v.shape[-1]).permute(0, 3, 1, 2)
# the finetuned CAV-MAE model
class CAVMAEFT(nn.Module):
    """CAV-MAE encoder with a classification head, used for fine-tuning.

    The encoder layout matches ``CAVMAE`` (modality-specific stacks followed
    by the shared stack ``blocks_u``); token features are mean-pooled and
    passed to ``mlp_head``.
    """

    def __init__(self, label_dim, img_size=224, audio_length=1024, patch_size=16, in_chans=3,
                 embed_dim=768, modality_specific_depth=11, num_heads=12, mlp_ratio=4., norm_layer=nn.LayerNorm,
                 norm_pix_loss=False, tr_pos=True):
        """Build the encoder and the classification head.

        Args:
            label_dim: number of outputs of the classification head.
            img_size, patch_size, in_chans: geometry of the visual input.
            audio_length: number of audio frames; sets the audio patch count.
            embed_dim, num_heads, mlp_ratio: Transformer settings.
            modality_specific_depth: blocks per modality-specific stack; the
                shared stack gets ``12 - modality_specific_depth`` blocks.
            norm_layer: normalization class used in the encoder.
            norm_pix_loss: only printed here; no loss is computed by this class.
            tr_pos: whether positional embeddings are trainable.
        """
        super().__init__()
        print('Use norm_pix_loss: ', norm_pix_loss)

        # override the timm package
        timm.models.vision_transformer.PatchEmbed = PatchEmbed
        timm.models.vision_transformer.Block = Block

        self.patch_embed_a = PatchEmbed(img_size, patch_size, 1, embed_dim)
        self.patch_embed_v = PatchEmbed(img_size, patch_size, in_chans, embed_dim)

        # The audio patch count is derived from the audio length, not from img_size.
        self.patch_embed_a.num_patches = int(audio_length * 128 / 256)
        print('Number of Audio Patches: {:d}, Visual Patches: {:d}'.format(self.patch_embed_a.num_patches,
                                                                           self.patch_embed_v.num_patches))

        self.modality_a = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.modality_v = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.pos_embed_a = nn.Parameter(torch.zeros(1, self.patch_embed_a.num_patches, embed_dim),
                                        requires_grad=tr_pos)  # fixed sin-cos embedding
        self.pos_embed_v = nn.Parameter(torch.zeros(1, self.patch_embed_v.num_patches, embed_dim),
                                        requires_grad=tr_pos)  # fixed sin-cos embedding

        self.blocks_a = nn.ModuleList(
            [Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
             for _ in range(modality_specific_depth)])
        self.blocks_v = nn.ModuleList(
            [Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
             for _ in range(modality_specific_depth)])
        self.blocks_u = nn.ModuleList(
            [Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
             for _ in range(12 - modality_specific_depth)])

        self.norm_a = norm_layer(embed_dim)
        self.norm_v = norm_layer(embed_dim)
        self.norm = norm_layer(embed_dim)

        self.mlp_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, label_dim))

        self.initialize_weights()

        print('Audio Positional Embedding Shape:', self.pos_embed_a.shape)
        print('Visual Positional Embedding Shape:', self.pos_embed_v.shape)

    def get_patch_num(self, input_shape, stride):
        """Count the 16x16 patches a strided convolution yields for ``input_shape``.

        Args:
            input_shape: (height, width) of a single-channel input.
            stride: stride used in both directions.

        Returns:
            ``(rows, cols, rows * cols)``.  The earlier version indexed the
            output tensor instead of its shape and squared the row count.
        """
        test_input = torch.zeros(1, 1, input_shape[0], input_shape[1])
        test_proj = torch.nn.Conv2d(1, 4, kernel_size=(16, 16), stride=(stride, stride))
        test_output = test_proj(test_input)
        print(test_output.shape)
        n_rows, n_cols = test_output.shape[2], test_output.shape[3]
        return n_rows, n_cols, n_rows * n_cols

    def initialize_weights(self):
        """Fill positional embeddings with sin-cos tables and initialize all weights."""
        pos_embed_a = get_2d_sincos_pos_embed(self.pos_embed_a.shape[-1], 8, int(self.patch_embed_a.num_patches / 8),
                                              cls_token=False)
        self.pos_embed_a.data.copy_(torch.from_numpy(pos_embed_a).float().unsqueeze(0))

        v_side = int(self.patch_embed_v.num_patches ** .5)
        pos_embed_v = get_2d_sincos_pos_embed(self.pos_embed_v.shape[-1], v_side, v_side, cls_token=False)
        self.pos_embed_v.data.copy_(torch.from_numpy(pos_embed_v).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        for patch_embed in (self.patch_embed_a, self.patch_embed_v):
            w = patch_embed.proj.weight.data
            torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        torch.nn.init.normal_(self.modality_a, std=.02)
        torch.nn.init.normal_(self.modality_v, std=.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform for Linear weights, zero biases, unit LayerNorm scale."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _encode_audio(self, a):
        """Embed audio and run it through the audio-specific blocks."""
        a = a.unsqueeze(1)
        a = a.transpose(2, 3)
        a = self.patch_embed_a(a)
        a = a + self.pos_embed_a
        a = a + self.modality_a
        for blk in self.blocks_a:
            a = blk(a)
        return a

    def _encode_video(self, v):
        """Embed images and run them through the visual-specific blocks."""
        v = self.patch_embed_v(v)
        v = v + self.pos_embed_v
        v = v + self.modality_v
        for blk in self.blocks_v:
            v = blk(v)
        return v

    def _run_unified(self, x, modality=None):
        """Chain ``x`` through the shared blocks, then apply the matching final norm."""
        final_norm = {None: self.norm, 'a': self.norm_a, 'v': self.norm_v}[modality]
        for blk in self.blocks_u:
            x = blk(x, modality)
        return final_norm(x)

    def forward(self, a, v, mode):
        """Classify from audio and/or images.

        Args:
            a: audio batch (unused by the video-only modes).
            v: image batch (unused by the audio-only modes).
            mode: one of ``'multimodal'``, ``'audioonly'``, ``'videoonly'``,
                ``'missingaudioonly'``, ``'missingvideoonly'``.

        Returns:
            Output of ``mlp_head`` on the mean-pooled token features.

        Raises:
            ValueError: for an unknown ``mode`` (previously ``None`` was returned).
        """
        # multi-modal fine-tuning, our default method for fine-tuning
        if mode == 'multimodal':
            x = torch.cat((self._encode_audio(a), self._encode_video(v)), dim=1)
            x = self._run_unified(x).mean(dim=1)

        # finetune with only audio (and inference with only audio when the model is finetuned with only audio)
        elif mode == 'audioonly':
            # note here uses the 'a' normalization, it is used in both training and inference, so it is fine
            x = self._run_unified(self._encode_audio(a), 'a').mean(dim=1)

        # finetune with only image (and inference with only image when the model is finetuned with only image)
        elif mode == 'videoonly':
            # note here uses the 'v' normalization, it is used in both training and inference, so it is fine
            x = self._run_unified(self._encode_video(v), 'v').mean(dim=1)

        # used in case that the model is finetuned with both modality, but in inference only audio is given
        elif mode == 'missingaudioonly':
            a = self._encode_audio(a)
            # two passes through blocks_u: unified normalization and modality-specific normalization
            u = self._run_unified(a).mean(dim=1)
            a = self._run_unified(a, 'a').mean(dim=1)
            # average the output of the two forward passes
            x = (u + a) / 2

        # used in case that the model is fine-tuned with both modality, but in inference only image is given
        elif mode == 'missingvideoonly':
            v = self._encode_video(v)
            # two passes through blocks_u: unified normalization and modality-specific normalization
            u = self._run_unified(v).mean(dim=1)
            v = self._run_unified(v, 'v').mean(dim=1)
            # average the output of the two forward passes
            x = (u + v) / 2

        else:
            raise ValueError(f"Unknown mode {mode!r}")

        return self.mlp_head(x)

    # for retrieval
    def forward_feat(self, a, v, mode='av'):
        """Return token features for retrieval.

        ``mode='av'`` returns ``(audio_features, visual_features)``;
        ``mode='a'`` returns the audio features only.

        Raises:
            ValueError: for any other ``mode`` (previously ``None`` was returned).
        """
        # return both audio and visual
        if mode == 'av':
            a = self._run_unified(self._encode_audio(a), 'a')
            v = self._run_unified(self._encode_video(v), 'v')
            return a, v

        # return only audio
        if mode == 'a':
            return self._run_unified(self._encode_audio(a), 'a')

        raise ValueError(f"Unknown mode {mode!r}; expected 'av' or 'a'")
def _wav2fbank(filename):
    """Load an audio file and return a 128-bin log-mel filterbank of 1024 frames.

    The waveform is resampled to 16 kHz and mean-subtracted before the
    Kaldi-style filterbank is computed.  The result is zero-padded or
    truncated along the frame axis (dim 0) to exactly 1024 frames.

    Args:
        filename: path accepted by ``torchaudio.load``.

    Returns:
        Tensor with 1024 rows (frames); 128 mel bins are requested per frame.
    """
    target_sr = 16000
    target_length = 1024

    waveform, sr = torchaudio.load(filename)
    waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=target_sr)
    waveform = waveform - waveform.mean()

    # The waveform is at target_sr after resampling, so that is the rate the
    # filterbank must be told about (the file's original rate was passed before).
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform,
        htk_compat=True,
        sample_frequency=target_sr,
        use_energy=False,
        window_type='hanning',
        num_mel_bins=128,
        dither=0.0,
        frame_shift=10)

    # cut and pad
    missing = target_length - fbank.shape[0]
    if missing > 0:
        fbank = torch.nn.ZeroPad2d((0, 0, 0, missing))(fbank)
    elif missing < 0:
        fbank = fbank[0:target_length, :]
    return fbank
def pca(image_feats_list, dim=3, fit_pca=None):
    """Project feature maps onto ``dim`` principal components, rescaled to [0, 1].

    Args:
        image_feats_list: list of (B, C, H, W) tensors sharing the same C.
        dim: number of components to keep.
        fit_pca: an already fitted object exposing ``transform``; when None a
            scikit-learn ``PCA`` is fitted on all given feature maps.

    Returns:
        ``(reduced_feats, fit_pca)`` where ``reduced_feats`` is a list of
        (B, dim, H, W) tensors on the device of the first input, each
        min-max normalized per component, and ``fit_pca`` is the projection used.
    """
    device = image_feats_list[0].device

    def flatten(tensor, target_size=None):
        # Optionally resize (only while fitting), then lay pixels out as rows: (B*H*W, C).
        if target_size is not None and fit_pca is None:
            # The resized tensor must be kept; the result used to be discarded.
            tensor = F.interpolate(tensor, (target_size, target_size), mode="bilinear")
        B, C, H, W = tensor.shape
        # Uses the argument itself rather than the enclosing loop variable.
        return tensor.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu()

    # When several maps are fitted jointly they are resized to the first map's height.
    if len(image_feats_list) > 1 and fit_pca is None:
        target_size = image_feats_list[0].shape[2]
    else:
        target_size = None

    flattened_feats = [flatten(feats, target_size) for feats in image_feats_list]
    x = torch.cat(flattened_feats, dim=0)

    if fit_pca is None:
        # Imported lazily so a pre-fitted projection works without scikit-learn.
        from sklearn.decomposition import PCA
        fit_pca = PCA(n_components=dim).fit(x)

    reduced_feats = []
    for feats in image_feats_list:
        x_red = torch.from_numpy(fit_pca.transform(flatten(feats)))
        # Per-component min-max normalization.
        x_red -= x_red.min(dim=0, keepdim=True).values
        x_red /= x_red.max(dim=0, keepdim=True).values
        B, C, H, W = feats.shape
        reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2).to(device))

    return reduced_feats, fit_pca
class CAVMAEAudioFeaturizer(nn.Module):
    """Wrap a ``CAVMAE`` model to produce dense audio feature maps.

    Args:
        output_path: directory containing ``models/audio_model.21.pth``; only
            used when ``model`` is not given.
        model_name: checkpoint variant; only ``"base"`` is known.
        model: an already constructed model exposing ``forward_audio``; when
            given, no checkpoint is loaded.

    Raises:
        ValueError: if ``model`` is None and ``model_name`` is not ``"base"``.
    """

    def __init__(self, output_path, model_name="base", model=None):
        super().__init__()
        if model is not None:
            self.model = model
            return

        if model_name != "base":
            raise ValueError(f"Unknown model type {model_name}")
        model_path = os.path.join(output_path, 'models/audio_model.21.pth')

        audio_model = CAVMAE(
            audio_length=1024,
            modality_specific_depth=11,
            norm_pix_loss=True,
            tr_pos=False)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
        mdl_weight = torch.load(model_path, map_location=device)
        # The model is wrapped in DataParallel before loading, then unwrapped via `.module`.
        audio_model = torch.nn.DataParallel(audio_model)
        audio_model.load_state_dict(mdl_weight, strict=True)
        # Move to the device chosen above; an unconditional .cuda() fails on CPU-only hosts.
        self.model = audio_model.module.to(device)

    def forward(self, audio, include_cls):
        """Return audio patch features; with ``include_cls`` also a ``None`` class token.

        ``audio`` has its dim 1 squeezed before being passed to ``forward_audio``.
        """
        cls_token = None  # this model has no class token
        patch_tokens = self.model.forward_audio(audio.squeeze(1))

        if include_cls:
            return patch_tokens, cls_token
        return patch_tokens
class CAVMAEImageFeaturizer(nn.Module):
    """Wrap a ``CAVMAE`` model to produce dense image feature maps.

    Args:
        output_path: directory containing ``models/audio_model.21.pth``; only
            used when ``model`` is not given.
        model: an already constructed model exposing ``forward_video``; when
            given, no checkpoint is loaded.
        model_name: checkpoint variant; only ``"base"`` is known.

    Raises:
        ValueError: if ``model`` is None and ``model_name`` is not ``"base"``.
    """

    def __init__(self, output_path, model=None, model_name="base"):
        super().__init__()
        if model is not None:
            self.model: CAVMAE = model
            return

        if model_name != "base":
            raise ValueError(f"Unknown model type {model_name}")
        # The same checkpoint file holds both the audio and the visual branch.
        model_path = os.path.join(output_path, 'models/audio_model.21.pth')

        audio_model = CAVMAE(
            audio_length=1024,
            modality_specific_depth=11,
            norm_pix_loss=True,
            tr_pos=False)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
        mdl_weight = torch.load(model_path, map_location=device)
        # The model is wrapped in DataParallel before loading, then unwrapped via `.module`.
        audio_model = torch.nn.DataParallel(audio_model)
        audio_model.load_state_dict(mdl_weight, strict=True)
        # Move to the device chosen above; an unconditional .cuda() fails on CPU-only hosts.
        self.model: CAVMAE = audio_model.module.to(device)

    def forward(self, image, include_cls):
        """Return image patch features; with ``include_cls`` also a ``None`` class token."""
        cls_token = None  # this model has no class token
        patch_tokens = self.model.forward_video(image)

        if include_cls:
            return patch_tokens, cls_token
        return patch_tokens
if __name__ == "__main__":
    # Smoke test: load the pretrained checkpoint and print the image-vs-audio
    # cosine-similarity matrix for three sample pairs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_path = os.path.join("../../", 'models/audio_model.21.pth')
    audio_model = CAVMAE(
        audio_length=1024,
        modality_specific_depth=11,
        norm_pix_loss=True,
        tr_pos=False)
    # NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
    mdl_weight = torch.load(model_path, map_location=device)
    audio_model = torch.nn.DataParallel(audio_model)
    audio_model.load_state_dict(mdl_weight, strict=True)
    # Use the selected device everywhere instead of unconditional .cuda() calls.
    model: CAVMAE = audio_model.module.to(device)
    model.eval()

    image_paths = ["../../samples/dog_image.jpg", "../../samples/car_image.jpg", "../../samples/bird_image.jpg"]
    audio_paths = ["../../samples/dog_audio.wav", "../../samples/car_audio.wav", "../../samples/bird_audio.wav"]

    # Built once instead of on every loop iteration.
    preprocess = T.Compose([
        T.Resize(224, interpolation=Image.BICUBIC),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(
            mean=[0.4850, 0.4560, 0.4060],
            std=[0.2290, 0.2240, 0.2250]
        )])

    images = []
    for image_path in image_paths:
        # The context manager closes the file once the RGB copy has been made.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        images.append(preprocess(image).unsqueeze(0).to(device))

    audios = []
    for audio_path in audio_paths:
        a = _wav2fbank(audio_path).to(device).unsqueeze(0)
        # Shift/scale constants as given; presumably dataset statistics -- TODO confirm.
        a = (a + 5.081) / (4.4849)
        audios.append(a)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        audio_feats, image_feats = model.forward_feat(
            torch.cat(audios, dim=0), torch.cat(images, dim=0))

        # Mean-pool over tokens, L2-normalize, then compare every image with every audio clip.
        audio_feats = F.normalize(audio_feats.mean(1), dim=1)
        image_feats = F.normalize(image_feats.mean(1), dim=1)

        sims = torch.einsum("bc,dc->bd", image_feats, audio_feats)

    print(sims)
# a_feat = F.normalize(a_feat, dim=1)
# v_feat = F.normalize(v_feat, dim=1)
# [red_v_feat, red_a_feat], fit_pca = pca([v_feat, a_feat])
#
# [red_v_feat], fit_pca = pca([v_feat])
# [red_a_feat], fit_pca = pca([a_feat])
#
# import matplotlib.pyplot as plt
#
# fig, ax = plt.subplots(1, 2, figsize=(2 * 5, 5))
# ax[0].imshow(red_v_feat[0].permute(1, 2, 0).cpu())
# ax[1].imshow(red_a_feat[0].permute(1, 2, 0).cpu())
# plt.tight_layout()
# plt.show()
# print("here")