import logging
import math
import os
from typing import Any, Mapping

import einops
import numpy as np
import torch
import torchaudio
from torch import nn
from torch.nn import functional as F

from .motionformer import MotionFormer
from .ast_model import AST
from .utils import Config


class Synchformer(nn.Module):

    def __init__(self):
        super().__init__()

        self.vfeat_extractor = MotionFormer(
            extract_features=True,
            factorize_space_time=True,
            agg_space_module="TransformerEncoderLayer",
            agg_time_module="torch.nn.Identity",
            add_global_repr=False,
        )
        self.afeat_extractor = AST(
            extract_features=True,
            max_spec_t=66,
            factorize_freq_time=True,
            agg_freq_module="TransformerEncoderLayer",
            agg_time_module="torch.nn.Identity",
            add_global_repr=False,
        )
        # project the 768-d extractor features into the transformer dim (also 768 here)
        self.vproj = nn.Linear(in_features=768, out_features=768)
        self.aproj = nn.Linear(in_features=768, out_features=768)
        self.transformer = GlobalTransformer(
            tok_pdrop=0.0, embd_pdrop=0.1, resid_pdrop=0.1, attn_pdrop=0.1, n_layer=3, n_head=8, n_embd=768
        )

    def forward(self, vis):
        B, S, Tv, C, H, W = vis.shape
        vis = vis.permute(0, 1, 3, 2, 4, 5)  # (B, S, C, Tv, H, W)
        # feat extractors return a tuple of segment-level and global features (ignored for sync)
        # (B, S, tv, D), e.g. (B, 7, 8, 768)
        vis = self.vfeat_extractor(vis)
        return vis

    def compare_v_a(self, vis: torch.Tensor, aud: torch.Tensor):
        vis = self.vproj(vis)
        aud = self.aproj(aud)

        B, S, tv, D = vis.shape
        B, S, ta, D = aud.shape
        vis = vis.view(B, S * tv, D)  # (B, S*tv, D)
        aud = aud.view(B, S * ta, D)  # (B, S*ta, D)

        # self.transformer will concatenate vis and aud into one sequence with aux tokens,
        # i.e. `CvvvvMaaaaaa`, and will return the logits for the CLS token
        logits = self.transformer(vis, aud)  # (B, cls); or (B, cls) and (B, 2) if DoubtingTransformer

        return logits

    def extract_vfeats(self, vis):
        B, S, Tv, C, H, W = vis.shape
        vis = vis.permute(0, 1, 3, 2, 4, 5)  # (B, S, C, Tv, H, W)
        # feat extractors return a tuple of segment-level and global features (ignored for sync)
        # (B, S, tv, D), e.g. (B, 7, 8, 768)
        vis = self.vfeat_extractor(vis)
        return vis

    def extract_afeats(self, aud):
        B, S, _, Fa, Ta = aud.shape
        aud = aud.view(B, S, Fa, Ta).permute(0, 1, 3, 2)  # (B, S, Ta, F)
        # (B, S, ta, D), e.g. (B, 7, 6, 768)
        aud, _ = self.afeat_extractor(aud)
        return aud

    def compute_loss(self, logits, targets, loss_fn: str = None):
        loss = None
        if targets is not None:
            if loss_fn is None or loss_fn == "cross_entropy":
                # logits: (B, cls) and targets: (B,)
                loss = F.cross_entropy(logits, targets)
            else:
                raise NotImplementedError(f"Loss {loss_fn} not implemented")
        return loss

    def load_state_dict(self, sd: Mapping[str, Any], strict: bool = True):
        # to keep only the vfeat_extractor entries, uncomment the filter below
        # sd = {k: v for k, v in sd.items() if k.startswith('vfeat_extractor')}
        return super().load_state_dict(sd, strict)
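

# Usage sketch (an illustration, not part of the original file): with the defaults above,
# per-segment features are extracted for each modality and compared for an offset prediction:
#   vfeats = model.extract_vfeats(frames)  # frames: (B, S, Tv, C, H, W) -> (B, S, tv, 768)
#   afeats = model.extract_afeats(specs)   # specs:  (B, S, 1, 128, 66)  -> (B, S, ta, 768)
#   logits = model.compare_v_a(vfeats, afeats)  # (B, 21) offset-class logits by default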


class RandInitPositionalEncoding(nn.Module):
    """Randomly initialized, trainable positional embedding. It is simply added to the
    sequence and therefore encodes no structural priors."""

    def __init__(self, block_shape: list, n_embd: int):
        super().__init__()
        self.block_shape = block_shape
        self.n_embd = n_embd
        self.pos_emb = nn.Parameter(torch.randn(1, *block_shape, n_embd))

    def forward(self, token_embeddings):
        return token_embeddings + self.pos_emb


class GlobalTransformer(torch.nn.Module):
    """Same as in SparseSync but without the selector transformers and the head"""

    def __init__(
        self,
        tok_pdrop=0.0,
        embd_pdrop=0.1,
        resid_pdrop=0.1,
        attn_pdrop=0.1,
        n_layer=3,
        n_head=8,
        n_embd=768,
        pos_emb_block_shape=[198],
        n_off_head_out=21,
    ) -> None:
        super().__init__()
        self.config = Config(
            embd_pdrop=embd_pdrop,
            resid_pdrop=resid_pdrop,
            attn_pdrop=attn_pdrop,
            n_layer=n_layer,
            n_head=n_head,
            n_embd=n_embd,
        )
        # input norm
        self.vis_in_lnorm = torch.nn.LayerNorm(n_embd)
        self.aud_in_lnorm = torch.nn.LayerNorm(n_embd)
        # aux tokens
        self.OFF_tok = torch.nn.Parameter(torch.randn(1, 1, n_embd))
        self.MOD_tok = torch.nn.Parameter(torch.randn(1, 1, n_embd))
        # whole-token dropout
        self.tok_pdrop = tok_pdrop
        self.tok_drop_vis = torch.nn.Dropout1d(tok_pdrop)
        self.tok_drop_aud = torch.nn.Dropout1d(tok_pdrop)
        # maybe add pos emb
        self.pos_emb_cfg = RandInitPositionalEncoding(
            block_shape=pos_emb_block_shape,
            n_embd=n_embd,
        )
        # the stem
        self.drop = torch.nn.Dropout(embd_pdrop)
        self.blocks = torch.nn.Sequential(*[Block(self.config) for _ in range(n_layer)])
        # pre-output norm
        self.ln_f = torch.nn.LayerNorm(n_embd)
        # maybe add a head
        self.off_head = torch.nn.Linear(in_features=n_embd, out_features=n_off_head_out)
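
    # Sizing note (an inference, not stated in the original file): the default
    # pos_emb_block_shape=[198] covers 1 OFF token + 1 MOD token + the flattened
    # segment features, consistent with e.g. 14 segments * 8 visual tokens
    # + 14 segments * 6 audio tokens + 2 = 198; other segment counts require a
    # matching positional embedding size.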

    def forward(self, v: torch.Tensor, a: torch.Tensor, targets=None, attempt_to_apply_heads=True):
        B, Sv, D = v.shape
        B, Sa, D = a.shape

        # broadcasting special tokens to the batch size
        off_tok = einops.repeat(self.OFF_tok, "1 1 d -> b 1 d", b=B)
        mod_tok = einops.repeat(self.MOD_tok, "1 1 d -> b 1 d", b=B)
        # norm
        v, a = self.vis_in_lnorm(v), self.aud_in_lnorm(a)
        # maybe whole-token dropout
        if self.tok_pdrop > 0:
            v, a = self.tok_drop_vis(v), self.tok_drop_aud(a)
        # (B, 1+Sv+1+Sa, D)
        x = torch.cat((off_tok, v, mod_tok, a), dim=1)
        # maybe add pos emb
        if hasattr(self, "pos_emb_cfg"):
            x = self.pos_emb_cfg(x)
        # dropout -> stem -> norm
        x = self.drop(x)
        x = self.blocks(x)
        x = self.ln_f(x)
        # maybe add heads
        if attempt_to_apply_heads and hasattr(self, "off_head"):
            x = self.off_head(x[:, 0, :])
        return x


class SelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        # # causal mask to ensure that attention is only applied to the left in the input sequence
        # mask = torch.tril(torch.ones(config.block_size, config.block_size))
        # if hasattr(config, "n_unmasked"):
        #     mask[:config.n_unmasked, :config.n_unmasked] = 1
        # self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
        self.n_head = config.n_head

    def forward(self, x):
        B, T, C = x.size()

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        # self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = self.attn_drop(att) @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_drop(self.proj(y))
        return y
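
    # Note (an observation, not in the original): the matmul / softmax / dropout sequence above
    # is standard scaled dot-product attention; on PyTorch >= 2.0 it could equivalently be
    # computed with F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p).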


class Block(nn.Module):
    """an unassuming Transformer block"""

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.attn = SelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),  # nice
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


def make_class_grid(
    leftmost_val,
    rightmost_val,
    grid_size,
    add_extreme_offset: bool = False,
    seg_size_vframes: int = None,
    nseg: int = None,
    step_size_seg: float = None,
    vfps: float = None,
):
    assert grid_size >= 3, f"grid_size: {grid_size} does not make sense. If =2 -> (-1,1); =1 -> (-1); =0 -> ()"
    grid = torch.from_numpy(np.linspace(leftmost_val, rightmost_val, grid_size)).float()
    if add_extreme_offset:
        assert all([seg_size_vframes, nseg, step_size_seg]), f"{seg_size_vframes} {nseg} {step_size_seg}"
        seg_size_sec = seg_size_vframes / vfps
        trim_size_in_seg = nseg - (1 - step_size_seg) * (nseg - 1)
        extreme_value = trim_size_in_seg * seg_size_sec
        grid = torch.cat([grid, torch.tensor([extreme_value])])  # adding the extreme offset to the class grid
    return grid
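

# Example (an illustration, not part of the original file): the 21-way offset head of
# GlobalTransformer corresponds to a grid such as
#   make_class_grid(-2.0, 2.0, 21)
# which yields 21 offsets from -2.0 s to 2.0 s in 0.2 s steps.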


# from synchformer
def pad_or_truncate(audio: torch.Tensor, max_spec_t: int, pad_mode: str = "constant", pad_value: float = 0.0):
    difference = max_spec_t - audio.shape[-1]  # safe for batched input
    # pad or truncate, depending on difference
    if difference > 0:
        # pad the last dim (time) -> (..., n_mels, 0+time+difference)  # safe for batched input
        pad_dims = (0, difference)
        audio = torch.nn.functional.pad(audio, pad_dims, pad_mode, pad_value)
    elif difference < 0:
        logging.warning(f"Truncating spec ({audio.shape}) to max_spec_t ({max_spec_t}).")
        audio = audio[..., :max_spec_t]  # safe for batched input
    return audio
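

# Shape check (an assumption based on the MelSpectrogram settings in __main__ below): with
# hop_length=160 and the default center=True, a 10240-sample segment yields
# 10240 // 160 + 1 = 65 spectrogram frames, so pad_or_truncate(..., 66) appends one extra
# frame filled with pad_value.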


def encode_audio_with_sync(
    synchformer: Synchformer, x: torch.Tensor, mel: torchaudio.transforms.MelSpectrogram
) -> torch.Tensor:
    b, t = x.shape

    # partition the waveform into overlapping segments
    segment_size = 10240
    step_size = 10240 // 2
    num_segments = (t - segment_size) // step_size + 1
    segments = []
    for i in range(num_segments):
        segments.append(x[:, i * step_size : i * step_size + segment_size])
    x = torch.stack(segments, dim=1)  # (B, S, segment_size)

    x = mel(x)
    x = torch.log(x + 1e-6)
    x = pad_or_truncate(x, 66)
    mean = -4.2677393
    std = 4.5689974
    x = (x - mean) / (2 * std)

    # x: (B, S, 128, 66)
    x = synchformer.extract_afeats(x.unsqueeze(2))
    return x
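

# Segment arithmetic (a note, not from the original file): at 16 kHz, segment_size=10240
# samples is 0.64 s with a 0.32 s hop, so a 4 s clip (64000 samples) is split into
# (64000 - 10240) // 5120 + 1 = 11 segments.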


def read_audio(filename, expected_length=int(16000 * 4)):
    waveform, sr = torchaudio.load(filename)
    waveform = waveform.mean(dim=0)  # downmix to mono
    if sr != 16000:
        resampler = torchaudio.transforms.Resample(sr, 16000)
        waveform = resampler(waveform)
    waveform = waveform[:expected_length]
    if waveform.shape[0] != expected_length:
        raise ValueError(f"Audio {filename} is too short")
    waveform = waveform.squeeze()
    return waveform


if __name__ == "__main__":
    synchformer = Synchformer().cuda().eval()
    # mmaudio-provided synchformer ckpt
    synchformer.load_state_dict(
        torch.load(
            os.environ.get("SYNCHFORMER_WEIGHTS", "weights/synchformer.pth"),
            weights_only=True,
            map_location="cpu",
        )
    )
    sync_mel_spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=16000,
        win_length=400,
        hop_length=160,
        n_fft=1024,
        n_mels=128,
    )