import logging

import torch
from transformers.modeling_outputs import BaseModelOutputWithPooling

from .modeling_ast import ASTForAudioClassification, ASTConfig
from .motionformer import AveragePooling, BaseEncoderLayer, TemporalTransformerEncoderLayer
from .utils import check_if_file_exists_else_download


class AST(torch.nn.Module):

    def __init__(
        self,
        extract_features: bool = False,
        ckpt_path: str = None,
        feat_type: str = None,
        max_spec_t: int = None,
        factorize_freq_time: bool = None,
        agg_freq_module: str = None,
        agg_time_module: str = None,
        add_global_repr: bool = True,
        agg_segments_module: str = None,
        max_segments: int = None,
    ) -> None:
""" | |
extract_features: if True, then the model will return the features instead of head's output | |
ckpt_path: is not a path to a ckpt file, but a name of a model from the HuggingFace model hub. | |
feat_type: if extract_features is True, this parameter specifies the type of features to return | |
max_spec_t: if specified, then the model (pos emb) will be patched to support this length of spec | |
factorize_freq_time: if True, then the model will use a factorized freq/time aggregation | |
agg_freq_module: if specified, then the model will use this module for freq aggregation | |
agg_time_module: if specified, then the model will use this module for time aggregation | |
add_global_repr: if True, adds a global representation to the features (aggregation on segments) | |
agg_segments_module: if specified, then the model will use this module for segments aggregation | |
max_segments: if specified, the initialization of PE in the global agg module will use this value. | |
This should correspond to the max number of segments per video (if None, 16 is used) | |
""" | |
        super().__init__()
        self.extract_features = extract_features
        self.ckpt_path = ckpt_path
        self.max_spec_t = max_spec_t
        self.max_segments = max_segments

        # depending on whether the feat extractor was pre-trained contrastively or not, we need to
        # load the state dict differently: if the HuggingFace hub name is specified, load the model
        # from the hub, otherwise init a new model (an AVCLIP `.pt` ckpt is loaded further below)
        if ckpt_path == "MIT/ast-finetuned-audioset-10-10-0.4593":
            revision = "c1c0c66"  # fixing the revision for compatibility (v4.27.4)
            self.config = ASTConfig.from_pretrained(ckpt_path, revision=revision)
            full_model = ASTForAudioClassification.from_pretrained(ckpt_path, revision=revision)
            logging.info(f"Loaded AST from {ckpt_path}")
        else:
            self.config = ASTConfig()
            self.config.num_labels = 527  # 2 by default; AudioSet has 527 labels
            full_model = ASTForAudioClassification(self.config)
            logging.info("Initialized AST from scratch with the AST AudioSet config")
        was_pt_on_avclip = ckpt_path is not None and ckpt_path.endswith(".pt")

        # feature extractor
        self.ast = full_model.audio_spectrogram_transformer

        if self.extract_features:
            # assign `feat_type` (use default if not specified)
            self.feat_type = "last_hidden_state" if feat_type is None else feat_type

            # define adapters if needed
            self.factorize_freq_time = factorize_freq_time
            # avoiding code duplication (used only if agg_*_module is TransformerEncoderLayer)
            transf_enc_layer_kwargs = dict(
                d_model=self.config.hidden_size,
                nhead=self.config.num_attention_heads,
                dim_feedforward=self.config.intermediate_size,
                activation=torch.nn.GELU(),
                batch_first=True,
                dropout=self.config.attention_probs_dropout_prob,
                layer_norm_eps=1e-6,
                norm_first=True,
            )
            if factorize_freq_time:
                self.feat_type = "last_hidden_state"  # this feat_type supports factorization
                # frequency aggregation
                if agg_freq_module == "TransformerEncoderLayer":
                    self.freq_attn_agg = FrequencyTransformerEncoderLayer(**transf_enc_layer_kwargs)
                elif agg_freq_module == "AveragePooling":
                    self.freq_attn_agg = AveragePooling(
                        avg_pattern="BS D f t -> BS D t", then_permute_pattern="BS D t -> BS t D"
                    )
                # time aggregation
                if agg_time_module == "TransformerEncoderLayer":
                    self.temp_attn_agg = TemporalTransformerEncoderLayer(**transf_enc_layer_kwargs)
                elif agg_time_module == "AveragePooling":
                    self.temp_attn_agg = AveragePooling(avg_pattern="BS t D -> BS D")
                elif "Identity" in agg_time_module:
                    self.temp_attn_agg = torch.nn.Identity()
            # define a global aggregation layer (aggregate over segments)
            self.add_global_repr = add_global_repr
            if add_global_repr:
                if agg_segments_module == "TransformerEncoderLayer":
                    # we can reuse the same layer as for temporal factorization (B, dim_to_agg, D) -> (B, D)
                    # we need to add pos emb (PE) because previously we added the same PE for each segment
                    pos_max_len = max_segments if max_segments is not None else 16  # 16 = 10sec//0.64sec + 1
                    self.global_attn_agg = TemporalTransformerEncoderLayer(
                        add_pos_emb=True,
                        pos_emb_drop=self.config.hidden_dropout_prob,
                        pos_max_len=pos_max_len,
                        **transf_enc_layer_kwargs,
                    )
                elif agg_segments_module == "AveragePooling":
                    self.global_attn_agg = AveragePooling(avg_pattern="B S D -> B D")
        else:
            self.classifier = full_model.classifier

        # AST.device fails with AttributeError; this is a workaround
        self.device = full_model.device

        # pre-trained on 12*101+2=1214 tokens, but we have fewer (e.g. 12*6+2=74)
        self.patch_position_emb()

        if was_pt_on_avclip:
            # we need to filter the state_dict of the AVCLIP model (it has both A and V extractors)
            # and keep only the state_dict of the audio feat extractor
            check_if_file_exists_else_download(self.ckpt_path)
            ckpt = torch.load(ckpt_path, map_location="cpu")
            ckpt_weights = dict()
            for k, v in ckpt["state_dict"].items():
                if k.startswith(("module.a_encoder.", "a_encoder.")):
                    k = k.replace("module.", "").replace("a_encoder.", "")
                    ckpt_weights[k] = v
            _load_status = self.load_state_dict(ckpt_weights, strict=False)
            if len(_load_status.missing_keys) > 0 or len(_load_status.unexpected_keys) > 0:
                logging.warning(
                    f"Loading exact afeat_extractor ckpt from {self.ckpt_path} failed. \n"
                    f"Missing keys ({len(_load_status.missing_keys)}): "
                    f"{_load_status.missing_keys}, \n"
                    f"Unexpected keys ({len(_load_status.unexpected_keys)}): "
                    f"{_load_status.unexpected_keys} \n"
                    f"temp_attn_agg keys are expected to be missing if the ckpt was pre-trained contrastively."
                )
            else:
                logging.info(f"Loading afeat_extractor ckpt from {self.ckpt_path} succeeded.")

        # print the number of parameters
        logging.info(f"AST: {sum(p.numel() for p in self.parameters() if p.requires_grad):,}")

    def forward(
        self, x: torch.Tensor, for_loop: bool = False, cont_mask: torch.Tensor = None, **ast_kwargs
    ) -> tuple:
        """
        x: (B, S, T, F) where S is the number of segments, T is the number of time frames,
            and F is the number of (mel) frequency bins
        cont_mask: (B, S, T, F) where 0s are the values to be masked out
        ast_kwargs: additional arguments for the AST model

        If `for_loop=True`, features are extracted for each segment separately in a for loop;
        if `for_loop=False`, features are extracted for all segments at once.
        The for loop is slower but more memory efficient, while processing all segments at once
        is faster but uses more memory. The for loop lets one control the memory footprint by
        varying the number of videos in a batch (batch size) rather than the number of segments
        in a video.
        """
        B, S, T, F = x.shape

        if for_loop:
            assert cont_mask is None, "cont_mask is not supported with for_loop=True"
            orig_shape_s = (B, 1, T, F)
            # NOTE: x is (B, S, T, F) while forward_segments expects (BS, T, F);
            # x[:, s] is (B, T, F), which plays the role of (BS, T, F) with S=1
            x = torch.cat(
                [self.forward_segments(x[:, s], orig_shape_s, **ast_kwargs).unsqueeze(1) for s in range(S)], dim=1
            )
        else:
            orig_shape = (B, S, T, F)
            x = x.view(B * S, T, F)
            if cont_mask is not None:
                cont_mask = cont_mask.reshape(B * S, T, F)
            # AST expects a tensor of shape (B*S, T, F)
            x = self.forward_segments(x, orig_shape=orig_shape, cont_mask=cont_mask, **ast_kwargs)
            # unpack the segments (using the rest of the dims to support different shapes, e.g. (BS, D) or (BS, t, D))
            x = x.view(B, S, *x.shape[1:])
        # x is now (B, S, D), or (B, S, t, D) if `self.temp_attn_agg` is `Identity`

        global_x = None
        if self.extract_features and self.add_global_repr:  # short-circuit: add_global_repr exists only if extract_features
            assert len(x.shape) == 3, f"Local representation should be (B, S, D) {x.shape}"
            global_x = self.global_attn_agg(x)  # (B, D)

        return x, global_x  # x is (B, S, ...), global_x is (B, D) or None

    def forward_segments(self, x, orig_shape: tuple, cont_mask: torch.Tensor = None, **ast_kwargs):
        """x is (BS, T, F), where S is the number of segments; cont_mask is (BS, T, F), 0s are masked out"""
        # 'pooler_output': (B, D); or 'last_hidden_state': (B, T, D) where T is [CLS, DISTILL, <tokens>]
        # x_mask is (B, T) where 0s are the values to be masked out
        x, x_mask = self.ast(x, cont_mask=cont_mask, **ast_kwargs)

        if self.extract_features:
            x = self.get_features_by_type(x)
            if self.factorize_freq_time:
                x = self.restore_freq_temp_dims(x, orig_shape)  # (BS, D, f, t) <- (B*S, T, D)
                if cont_mask is not None:
                    # duplicate the mask along the latent dimension (D) to be compatible with the next function
                    x_mask = x_mask.unsqueeze(-1).expand(-1, -1, self.config.hidden_size)
                    x_mask = self.restore_freq_temp_dims(x_mask, orig_shape)  # (BS, D, f, t) <- (B*S, T, D)
                    # remove the latent dimension again
                    x_mask = x_mask[:, 0, :, :]
                else:
                    x_mask = None
                x = self.freq_attn_agg(x, x_mask)  # (BS, t, D)
                x = self.temp_attn_agg(x)  # (BS, D), or (BS, t, D) if self.temp_attn_agg is Identity
        else:
            x = x["pooler_output"]
            x = self.classifier(x)
        return x

    def get_features_by_type(self, x: BaseModelOutputWithPooling) -> torch.Tensor:
        if self.feat_type == "pooler_output":
            return x["pooler_output"]  # (B, D)
        elif self.feat_type == "CLS":
            return x["last_hidden_state"][:, 0, :]  # (B, D)
        elif self.feat_type == "last_hidden_state":
            return x["last_hidden_state"]  # (B, 2+T, D)
        elif self.feat_type == "last_hidden_state_no_AUX":
            return x["last_hidden_state"][:, 2:, :]  # (B, T, D) removing CLS and distill tokens
        else:
            raise ValueError(f"Unknown feature type: {self.feat_type}")

    def restore_freq_temp_dims(self, feats, orig_shape: tuple):
        """
        feats are of shape (B*S, T, D)
        where T = 2 + f * t (if feat_type == 'last_hidden_state')
        or    T =     f * t (if feat_type == 'last_hidden_state_no_AUX')
        The goal is to reshape them to (B*S, D, f, t), where f and t are the dimensions after patching.
        From `self.ast.embeddings.patch_embeddings`, it follows that feats can be reshaped with
        `feats.transpose(1, 2).view(B*S, D, f, t)`.
        (A similar function is defined for RGB features in `motionformer.py`.)
        """
        B, S, T, F = orig_shape
        D = self.config.hidden_size

        # number of patches in each dimension
        f, t = self.ast.embeddings.get_shape(self.config)

        if self.feat_type == "last_hidden_state":
            feats = feats[:, 2:, :]  # removing CLS and distill tokens

        feats = feats.permute(0, 2, 1)  # (B*S, D, T)
        feats = feats.view(B * S, D, f, t)  # (B*S, D, f, t)

        return feats

    def patch_position_emb(self):
        if self.max_spec_t is not None:
            self.config.max_length = self.max_spec_t
        f, t = self.ast.embeddings.get_shape(self.config)
        shortened = self.ast.embeddings.position_embeddings[:, : f * t + 2].clone()  # +2 for CLS and distill tokens
        self.ast.embeddings.position_embeddings = torch.nn.Parameter(shortened).to(self.device)

    def to(self, device):
        """AST.device fails with AttributeError. This is a workaround."""
        self.device = torch.device(device)
        return super().to(device)


class FrequencyTransformerEncoderLayer(BaseEncoderLayer):
    """This layer aggregates the features along the frequency axis.
    It follows the same logic as the spatio-temporal aggregation in the visual feature extractor,
    so it is recommended to check the definition of `BaseEncoderLayer` in `motionformer.py`."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
        """x: (B*S, D, f, t); if specified, x_mask is (B*S, f, t) where 0s are the values to be masked out"""
        BS, D, f, t = x.shape

        # time as a batch dimension
        x = x.permute(0, 3, 2, 1)  # (B*S, t, f, D)
        x = x.reshape(BS * t, f, D)  # .view() fails with non-contiguous memory
        # same for the mask
        if x_mask is not None:
            x_mask = x_mask.permute(0, 2, 1)  # (B*S, t, f)
            x_mask = x_mask.reshape(BS * t, f)

        # apply the encoder layer (BaseEncoderLayer.forward); it adds a CLS token and outputs its representation
        x = super().forward(x=x, x_mask=x_mask)  # (B*S*t, D)

        # reshape back to (B*S, t, D)
        x = x.view(BS, t, D)

        return x  # (B*S, t, D)
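

# A minimal usage sketch (not part of the original module): the hub name matches the one handled in
# `AST.__init__`, but the segment count, spectrogram size, and aggregation-module choices below are
# illustrative assumptions, not values prescribed by this file.
if __name__ == "__main__":
    model = AST(
        extract_features=True,
        ckpt_path="MIT/ast-finetuned-audioset-10-10-0.4593",
        max_spec_t=66,  # assumed number of time frames per segment
        factorize_freq_time=True,
        agg_freq_module="TransformerEncoderLayer",
        agg_time_module="TransformerEncoderLayer",
        add_global_repr=True,
        agg_segments_module="TransformerEncoderLayer",
    )
    B, S, T, F = 2, 14, 66, 128  # illustrative batch of log-mel spectrogram segments
    spec = torch.rand(B, S, T, F)
    segment_feats, global_feat = model(spec, for_loop=True)
    print(segment_feats.shape, None if global_feat is None else global_feat.shape)  # (B, S, D), (B, D)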