# Author: James Zhou
# [init]
# commit: 9867d34
import logging
import math
import os
from typing import Any, Mapping

import einops
import numpy as np
import torch
import torchaudio
from torch import nn
from torch.nn import functional as F

from .ast_model import AST
from .motionformer import MotionFormer
from .utils import Config
class Synchformer(nn.Module):
    """Audio-visual synchronization model.

    Extracts segment-level visual (MotionFormer) and audio (AST) features,
    projects both into a shared 768-d space, and scores them jointly with a
    global transformer whose CLS-token output encodes the offset-class logits.
    """

    def __init__(self):
        super().__init__()
        # Segment-level visual features; returns (B, S, tv, D), e.g. (B, 7, 8, 768).
        self.vfeat_extractor = MotionFormer(
            extract_features=True,
            factorize_space_time=True,
            agg_space_module="TransformerEncoderLayer",
            agg_time_module="torch.nn.Identity",
            add_global_repr=False,
        )
        # Segment-level audio features; returns (B, S, ta, D), e.g. (B, 7, 6, 768).
        self.afeat_extractor = AST(
            extract_features=True,
            max_spec_t=66,
            factorize_freq_time=True,
            agg_freq_module="TransformerEncoderLayer",
            agg_time_module="torch.nn.Identity",
            add_global_repr=False,
        )
        # Bridge the extractor latent dims into the transformer dim. Both are
        # 768 here, but the projections are kept so checkpoint weights load.
        self.vproj = nn.Linear(in_features=768, out_features=768)
        self.aproj = nn.Linear(in_features=768, out_features=768)
        self.transformer = GlobalTransformer(
            tok_pdrop=0.0, embd_pdrop=0.1, resid_pdrop=0.1, attn_pdrop=0.1, n_layer=3, n_head=8, n_embd=768
        )

    def forward(self, vis: torch.Tensor) -> torch.Tensor:
        """Extract visual features from (B, S, Tv, C, H, W) video segments.

        This was previously a verbatim copy of `extract_vfeats`; it now
        delegates to it so the logic lives in one place.
        """
        return self.extract_vfeats(vis)

    def compare_v_a(self, vis: torch.Tensor, aud: torch.Tensor) -> torch.Tensor:
        """Project and jointly score visual vs. audio features.

        Args:
            vis: (B, S, tv, D) segment-level visual features.
            aud: (B, S, ta, D) segment-level audio features.

        Returns:
            (B, cls) offset-class logits from the global transformer.
        """
        vis = self.vproj(vis)
        aud = self.aproj(aud)
        B, S, tv, D = vis.shape
        B, S, ta, D = aud.shape
        vis = vis.view(B, S * tv, D)  # (B, S*tv, D)
        aud = aud.view(B, S * ta, D)  # (B, S*ta, D)
        # self.transformer concatenates vis and aud into one sequence with aux
        # tokens, i.e. `CvvvvMaaaaaa`, and returns the logits for the CLS token.
        logits = self.transformer(vis, aud)
        return logits

    def extract_vfeats(self, vis: torch.Tensor) -> torch.Tensor:
        """Run the visual extractor on (B, S, Tv, C, H, W) input."""
        B, S, Tv, C, H, W = vis.shape
        vis = vis.permute(0, 1, 3, 2, 4, 5)  # (B, S, C, Tv, H, W)
        # Segment-level features: (B, S, tv, D), e.g. (B, 7, 8, 768).
        vis = self.vfeat_extractor(vis)
        return vis

    def extract_afeats(self, aud: torch.Tensor) -> torch.Tensor:
        """Run the audio extractor on (B, S, 1, F, Ta) spectrogram input."""
        B, S, _, Fa, Ta = aud.shape
        aud = aud.view(B, S, Fa, Ta).permute(0, 1, 3, 2)  # (B, S, Ta, F)
        # (B, S, ta, D), e.g. (B, 7, 6, 768); the extractor's second tuple
        # element (global features) is unused for sync.
        aud, _ = self.afeat_extractor(aud)
        return aud

    def compute_loss(self, logits: torch.Tensor, targets: torch.Tensor, loss_fn: str = None):
        """Cross-entropy between (B, cls) logits and (B,) class targets.

        Returns None when `targets` is None; raises NotImplementedError for an
        unrecognized `loss_fn` name.
        """
        loss = None
        if targets is not None:
            if loss_fn is None or loss_fn == "cross_entropy":
                loss = F.cross_entropy(logits, targets)
            else:
                raise NotImplementedError(f"Loss {loss_fn} not implemented")
        return loss

    def load_state_dict(self, sd: Mapping[str, Any], strict: bool = True):
        # Kept as a hook: a previous revision filtered the state dict down to
        # `vfeat_extractor` entries here before delegating.
        return super().load_state_dict(sd, strict)
class RandInitPositionalEncoding(nn.Module):
    """Trainable positional embedding initialized from a standard normal.

    Applied additively over the whole sequence; it encodes no structural prior
    (no sinusoids, no factorization) — positions are learned from scratch.
    """

    def __init__(self, block_shape: list, n_embd: int):
        super().__init__()
        self.block_shape = block_shape
        self.n_embd = n_embd
        # one learnable vector per position; leading 1 broadcasts over batch
        self.pos_emb = nn.Parameter(torch.randn(1, *block_shape, n_embd))

    def forward(self, token_embeddings):
        # the same table is added to every batch element via broadcasting
        return self.pos_emb + token_embeddings
class GlobalTransformer(torch.nn.Module):
    """Same as in SparseSync but without the selector transformers and the head.

    Normalizes the visual and audio token streams, assembles them into one
    sequence with learnable aux tokens — `CvvvvMaaaa`, where C (OFF) acts as a
    CLS token and M (MOD) separates the modalities — runs the transformer stem,
    and (optionally) reduces the CLS position to offset-class logits.
    """

    def __init__(
        self,
        tok_pdrop=0.0,
        embd_pdrop=0.1,
        resid_pdrop=0.1,
        attn_pdrop=0.1,
        n_layer=3,
        n_head=8,
        n_embd=768,
        # was a mutable list default ([198]); a tuple avoids the shared-mutable
        # default-argument pitfall while unpacking identically
        pos_emb_block_shape=(198,),
        n_off_head_out=21,
    ) -> None:
        super().__init__()
        self.config = Config(
            embd_pdrop=embd_pdrop,
            resid_pdrop=resid_pdrop,
            attn_pdrop=attn_pdrop,
            n_layer=n_layer,
            n_head=n_head,
            n_embd=n_embd,
        )
        # per-modality input normalization
        self.vis_in_lnorm = torch.nn.LayerNorm(n_embd)
        self.aud_in_lnorm = torch.nn.LayerNorm(n_embd)
        # aux tokens: OFF is the CLS token, MOD marks the modality boundary
        self.OFF_tok = torch.nn.Parameter(torch.randn(1, 1, n_embd))
        self.MOD_tok = torch.nn.Parameter(torch.randn(1, 1, n_embd))
        # whole-token dropout (drops entire feature vectors, not single dims)
        self.tok_pdrop = tok_pdrop
        self.tok_drop_vis = torch.nn.Dropout1d(tok_pdrop)
        self.tok_drop_aud = torch.nn.Dropout1d(tok_pdrop)
        # trainable positional embedding over the full concatenated sequence
        self.pos_emb_cfg = RandInitPositionalEncoding(
            block_shape=list(pos_emb_block_shape),
            n_embd=n_embd,
        )
        # the stem
        self.drop = torch.nn.Dropout(embd_pdrop)
        self.blocks = torch.nn.Sequential(*[Block(self.config) for _ in range(n_layer)])
        # pre-output norm
        self.ln_f = torch.nn.LayerNorm(n_embd)
        # offset-classification head applied to the CLS position
        self.off_head = torch.nn.Linear(in_features=n_embd, out_features=n_off_head_out)

    def forward(self, v: torch.Tensor, a: torch.Tensor, targets=None, attempt_to_apply_heads=True):
        """Score a (B, Sv, D) visual and (B, Sa, D) audio token sequence.

        `targets` is accepted for API compatibility but unused here.
        Returns (B, n_off_head_out) logits when heads are applied, otherwise
        the full (B, 1+Sv+1+Sa, D) normalized sequence.
        """
        B, Sv, D = v.shape
        B, Sa, D = a.shape
        # broadcast the learnable special tokens to the batch size
        # (expand produces the same values einops.repeat did, without a copy)
        off_tok = self.OFF_tok.expand(B, -1, -1)
        mod_tok = self.MOD_tok.expand(B, -1, -1)
        # norm
        v, a = self.vis_in_lnorm(v), self.aud_in_lnorm(a)
        # maybe whole-token dropout
        if self.tok_pdrop > 0:
            v, a = self.tok_drop_vis(v), self.tok_drop_aud(a)
        # (B, 1+Sv+1+Sa, D)
        x = torch.cat((off_tok, v, mod_tok, a), dim=1)
        # pos emb is always built in __init__; the hasattr guard is kept for
        # compatibility with checkpoints/configs that lack it
        if hasattr(self, "pos_emb_cfg"):
            x = self.pos_emb_cfg(x)
        # dropout -> stem -> norm
        x = self.drop(x)
        x = self.blocks(x)
        x = self.ln_f(x)
        # optionally reduce to the offset logits at the CLS position
        if attempt_to_apply_heads and hasattr(self, "off_head"):
            x = self.off_head(x[:, 0, :])
        return x
class SelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # per-head key/query/value projections, fused across heads
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        # dropout on the attention weights and on the residual branch output
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # output projection back to the model dim
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head

    def forward(self, x):
        B, T, C = x.size()
        hs = C // self.n_head  # per-head feature size

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, T, C) -> (B, nh, T, hs): heads become a batch-like dim
            return t.view(B, T, self.n_head, hs).transpose(1, 2)

        q = split_heads(self.query(x))
        k = split_heads(self.key(x))
        v = split_heads(self.value(x))
        # scaled dot-product attention over the full (unmasked) sequence
        scores = torch.einsum("bhqd,bhkd->bhqk", q, k) / math.sqrt(hs)
        weights = self.attn_drop(F.softmax(scores, dim=-1))
        # weighted sum of values: (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        ctx = torch.einsum("bhqk,bhkd->bhqd", weights, v)
        # merge heads back side by side: (B, T, C)
        ctx = ctx.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_drop(self.proj(ctx))
class Block(nn.Module):
    """An unassuming pre-norm Transformer block: attention then MLP, each residual."""

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd  # conventional 4x MLP expansion
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.attn = SelfAttention(config)
        # position-wise feed-forward; module order fixes the state_dict keys
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, hidden),
            nn.GELU(),  # nice
            nn.Linear(hidden, config.n_embd),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x):
        # residual self-attention followed by residual MLP, both pre-normed
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
def make_class_grid(
    leftmost_val,
    rightmost_val,
    grid_size,
    add_extreme_offset: bool = False,
    seg_size_vframes: int = None,
    nseg: int = None,
    step_size_seg: float = None,
    vfps: float = None,
):
    """Build a float32 1-D grid of `grid_size` offset classes spanning
    [leftmost_val, rightmost_val] (inclusive, evenly spaced).

    When `add_extreme_offset` is True, one extra class is appended whose value
    is the trimmed-clip duration in seconds, computed from the segmentation
    parameters (`seg_size_vframes`, `nseg`, `step_size_seg`, `vfps`) — all four
    are then required.

    Raises:
        AssertionError: if grid_size < 3, or a required segmentation parameter
            is missing/zero when add_extreme_offset is set.
    """
    assert grid_size >= 3, f"grid_size: {grid_size} does not make sense. If =2 -> (-1,1); =1 -> (-1); =0 -> ()"
    grid = torch.from_numpy(np.linspace(leftmost_val, rightmost_val, grid_size)).float()
    if add_extreme_offset:
        # vfps was previously unchecked, so a missing value surfaced as an
        # opaque TypeError in the division below; validate it with the rest
        assert all([seg_size_vframes, nseg, step_size_seg, vfps]), (
            f"{seg_size_vframes} {nseg} {step_size_seg} {vfps}"
        )
        seg_size_sec = seg_size_vframes / vfps
        # number of segment-lengths covered by nseg overlapping segments
        trim_size_in_seg = nseg - (1 - step_size_seg) * (nseg - 1)
        extreme_value = trim_size_in_seg * seg_size_sec
        grid = torch.cat([grid, torch.tensor([extreme_value])])  # adding extreme offset to the class grid
    return grid
# from synchformer
def pad_or_truncate(audio: torch.Tensor, max_spec_t: int, pad_mode: str = "constant", pad_value: float = 0.0):
    """Pad or truncate the last (time) dimension of a spectrogram to exactly
    `max_spec_t` frames. Safe for batched input of any leading shape.

    Args:
        audio: (..., n_mels, time) spectrogram tensor.
        max_spec_t: target number of time frames.
        pad_mode/pad_value: forwarded to `F.pad` when padding is needed.
    """
    difference = max_spec_t - audio.shape[-1]  # safe for batched input
    if difference > 0:
        # pad only the right side of the last dim -> (..., n_mels, time+difference)
        pad_dims = (0, difference)
        audio = torch.nn.functional.pad(audio, pad_dims, pad_mode, pad_value)
    elif difference < 0:
        # previously used print(); route the diagnostic through logging instead
        logging.warning(f"Truncating spec ({audio.shape}) to max_spec_t ({max_spec_t}).")
        audio = audio[..., :max_spec_t]  # safe for batched input
    return audio
def encode_audio_with_sync(
    synchformer: Synchformer, x: torch.Tensor, mel: torchaudio.transforms.MelSpectrogram
) -> torch.Tensor:
    """Slice a (B, T) waveform batch into half-overlapping windows, compute
    normalized log-mel spectrograms per window, and extract Synchformer audio
    features from them.
    """
    b, t = x.shape
    # overlapping windows: 10240 samples long, hopping by half a window
    win_len = 10240
    hop = win_len // 2
    n_win = (t - win_len) // hop + 1
    windows = [x[:, i * hop : i * hop + win_len] for i in range(n_win)]
    # (B, S, win_len) — NOTE(review): an earlier comment said (B, S, T, C, H, W),
    # but this is raw audio, not video frames
    x = torch.stack(windows, dim=1)
    # log-mel per window; epsilon keeps log finite on silent frames
    x = torch.log(mel(x) + 1e-6)
    # the AST extractor expects exactly 66 time frames
    x = pad_or_truncate(x, 66)
    # normalization constants — presumably dataset log-mel statistics; TODO confirm
    mean = -4.2677393
    std = 4.5689974
    x = (x - mean) / (2 * std)
    # x: (B, S, 128, 66); insert a channel dim for extract_afeats
    return synchformer.extract_afeats(x.unsqueeze(2))
def read_audio(filename, expected_length=int(16000 * 4)):
    """Load an audio file as a mono 16 kHz waveform of exactly `expected_length`
    samples (default: 4 seconds).

    Raises:
        ValueError: if the (resampled) audio is shorter than `expected_length`.
    """
    waveform, sr = torchaudio.load(filename)
    waveform = waveform.mean(dim=0)  # downmix all channels to mono
    if sr != 16000:
        resampler = torchaudio.transforms.Resample(sr, 16000)
        # BUG FIX: was `resampler[sr](waveform)` — an nn.Module is not
        # subscriptable, so any non-16kHz file raised TypeError; call it directly
        waveform = resampler(waveform)
    waveform = waveform[:expected_length]
    if waveform.shape[0] != expected_length:
        raise ValueError(
            f"Audio {filename} is too short: {waveform.shape[0]} < {expected_length} samples"
        )
    waveform = waveform.squeeze()
    return waveform
if __name__ == "__main__":
    # Smoke test: build the model on GPU and load the MMAudio-provided
    # Synchformer checkpoint (path overridable via SYNCHFORMER_WEIGHTS).
    synchformer = Synchformer().cuda().eval()
    # note: the default path had a pointless f-string prefix; plain literal now
    ckpt_path = os.environ.get("SYNCHFORMER_WEIGHTS", "weights/synchformer.pth")
    synchformer.load_state_dict(
        torch.load(
            ckpt_path,
            weights_only=True,
            map_location="cpu",
        )
    )
    # mel front-end matching encode_audio_with_sync (128 mels @ 16 kHz,
    # 25 ms window / 10 ms hop)
    sync_mel_spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=16000,
        win_length=400,
        hop_length=160,
        n_fft=1024,
        n_mels=128,
    )