import logging
from typing import Optional, Tuple

import torch

from transformers.modeling_outputs import BaseModelOutputWithPooling

from .modeling_ast import ASTForAudioClassification, ASTConfig
from .motionformer import AveragePooling, BaseEncoderLayer, TemporalTransformerEncoderLayer
from .utils import check_if_file_exists_else_download


class AST(torch.nn.Module):

    def __init__(
        self,
        extract_features: bool = False,
        ckpt_path: Optional[str] = None,
        feat_type: Optional[str] = None,
        max_spec_t: Optional[int] = None,
        factorize_freq_time: Optional[bool] = None,
        agg_freq_module: Optional[str] = None,
        agg_time_module: Optional[str] = None,
        add_global_repr: bool = True,
        agg_segments_module: Optional[str] = None,
        max_segments: Optional[int] = None,
    ) -> None:
        """
        extract_features: if True, the model returns features instead of the head's output
        ckpt_path: either a model name from the HuggingFace model hub or a path to a `.pt`
                   checkpoint that was pre-trained on AVCLIP
        feat_type: if extract_features is True, specifies the type of features to return
        max_spec_t: if specified, the model (pos emb) is patched to support spectrograms of this length
        factorize_freq_time: if True, the model uses factorized freq/time aggregation
        agg_freq_module: if specified, the model uses this module for freq aggregation
        agg_time_module: if specified, the model uses this module for time aggregation
        add_global_repr: if True, adds a global representation to the features (aggregation over segments)
        agg_segments_module: if specified, the model uses this module for segment aggregation
        max_segments: if specified, the positional embeddings of the global aggregation module are
                      initialized for this many positions; it should correspond to the maximum
                      number of segments per video (if None, 16 is used)
        """
        super().__init__()
        self.extract_features = extract_features
        self.ckpt_path = ckpt_path
        self.max_spec_t = max_spec_t
        self.max_segments = max_segments

        if ckpt_path == "MIT/ast-finetuned-audioset-10-10-0.4593":
            revision = "c1c0c66"  # pin the hub revision for reproducibility
            self.config = ASTConfig.from_pretrained(ckpt_path, revision=revision)
            full_model = ASTForAudioClassification.from_pretrained(ckpt_path, revision=revision)
            logging.info(f"Loaded AST from {ckpt_path}")
        else:
            self.config = ASTConfig()
            self.config.num_labels = 527  # the number of AudioSet classes
            full_model = ASTForAudioClassification(self.config)
            logging.info("Initialized AST from scratch with the AST AudioSet config")

        # a `.pt` ckpt_path indicates a checkpoint that was pre-trained on AVCLIP (loaded below)
        was_pt_on_avclip = ckpt_path is not None and ckpt_path.endswith(".pt")

        self.ast = full_model.audio_spectrogram_transformer

        if self.extract_features:
            self.feat_type = "last_hidden_state" if feat_type is None else feat_type
            self.factorize_freq_time = factorize_freq_time
            transf_enc_layer_kwargs = dict(
                d_model=self.config.hidden_size,
                nhead=self.config.num_attention_heads,
                dim_feedforward=self.config.intermediate_size,
                activation=torch.nn.GELU(),
                batch_first=True,
                dropout=self.config.attention_probs_dropout_prob,
                layer_norm_eps=1e-6,
                norm_first=True,
            )
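            # These kwargs mirror the AST backbone dims (hidden size, heads, FFN width),
            # so the aggregation transformer layers below consume the AST tokens directly.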
            if factorize_freq_time:
                self.feat_type = "last_hidden_state"  # factorization needs the full token sequence

                if agg_freq_module == "TransformerEncoderLayer":
                    self.freq_attn_agg = FrequencyTransformerEncoderLayer(**transf_enc_layer_kwargs)
                elif agg_freq_module == "AveragePooling":
                    self.freq_attn_agg = AveragePooling(
                        avg_pattern="BS D f t -> BS D t", then_permute_pattern="BS D t -> BS t D"
                    )

                if agg_time_module == "TransformerEncoderLayer":
                    self.temp_attn_agg = TemporalTransformerEncoderLayer(**transf_enc_layer_kwargs)
                elif agg_time_module == "AveragePooling":
                    self.temp_attn_agg = AveragePooling(avg_pattern="BS t D -> BS D")
                elif "Identity" in agg_time_module:
                    self.temp_attn_agg = torch.nn.Identity()

            self.add_global_repr = add_global_repr
            if add_global_repr:
                if agg_segments_module == "TransformerEncoderLayer":
                    pos_max_len = max_segments if max_segments is not None else 16
                    self.global_attn_agg = TemporalTransformerEncoderLayer(
                        add_pos_emb=True,
                        pos_emb_drop=self.config.hidden_dropout_prob,
                        pos_max_len=pos_max_len,
                        **transf_enc_layer_kwargs,
                    )
                elif agg_segments_module == "AveragePooling":
                    self.global_attn_agg = AveragePooling(avg_pattern="B S D -> B D")
        else:
            self.classifier = full_model.classifier

        self.device = full_model.device

        self.patch_position_emb()

        if was_pt_on_avclip:
            check_if_file_exists_else_download(self.ckpt_path)
            ckpt = torch.load(ckpt_path, map_location="cpu")
            # keep only the audio-encoder weights and strip the (DDP) `module.` prefix
            ckpt_weights = dict()
            for k, v in ckpt["state_dict"].items():
                if k.startswith(("module.a_encoder.", "a_encoder.")):
                    k = k.replace("module.", "").replace("a_encoder.", "")
                    ckpt_weights[k] = v
            _load_status = self.load_state_dict(ckpt_weights, strict=False)
            if len(_load_status.missing_keys) > 0 or len(_load_status.unexpected_keys) > 0:
                logging.warning(
                    f"Loading exact afeat_extractor ckpt from {self.ckpt_path} failed. \n"
                    f"Missing keys ({len(_load_status.missing_keys)}): "
                    f"{_load_status.missing_keys}, \n"
                    f"Unexpected keys ({len(_load_status.unexpected_keys)}): "
                    f"{_load_status.unexpected_keys} \n"
                    "The `temp_attn_agg` keys are expected to be missing if the ckpt was pre-trained contrastively."
                )
            else:
                logging.info(f"Loading afeat_extractor ckpt from {self.ckpt_path} succeeded.")

        logging.info(f"AST: {sum(p.numel() for p in self.parameters() if p.requires_grad):,} trainable parameters")

    def forward(
        self, x: torch.Tensor, for_loop: bool = False, cont_mask: Optional[torch.Tensor] = None, **ast_kwargs
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        x: (B, S, T, F) where S is the number of segments, T is the number of time frames,
           and F is the number of (mel) frequency bins
        cont_mask: (B, S, T, F) where 0s mark the values to be masked out
        ast_kwargs: additional arguments for the AST model
        If `for_loop=True`, features are extracted for each segment separately in a for loop;
        if `for_loop=False`, features are extracted for all segments at once.
        The for loop is slower but more memory efficient, while processing all segments at once
        is faster but uses more memory.
        The for loop also allows controlling the memory footprint via the number of videos in a
        batch (batch size) rather than the number of segments in a video.
        """
        B, S, T, F = x.shape

        if for_loop:
            assert cont_mask is None, "cont_mask is not supported with for_loop=True"
            orig_shape_s = (B, 1, T, F)
            # process one segment at a time and stack the results along the segment dim
            x = torch.cat(
                [self.forward_segments(x[:, s], orig_shape_s, **ast_kwargs).unsqueeze(1) for s in range(S)], dim=1
            )
        else:
            orig_shape = (B, S, T, F)
            x = x.view(B * S, T, F)  # flatten the segments into the batch dim
            if cont_mask is not None:
                cont_mask = cont_mask.reshape(B * S, T, F)
            x = self.forward_segments(x, orig_shape=orig_shape, cont_mask=cont_mask, **ast_kwargs)
            x = x.view(B, S, *x.shape[1:])  # unflatten the segments

        global_x = None
        if self.extract_features and self.add_global_repr:
            assert len(x.shape) == 3, f"Local representation should be (B, S, D) but got {x.shape}"
            global_x = self.global_attn_agg(x)  # (B, D)

        return x, global_x

    def forward_segments(self, x, orig_shape: tuple, cont_mask: Optional[torch.Tensor] = None, **ast_kwargs):
        """x: (BS, T, F) where S is the number of segments; cont_mask: (BS, T, F), 0s mark values to be masked out"""
        x, x_mask = self.ast(x, cont_mask=cont_mask, **ast_kwargs)

        if self.extract_features:
            x = self.get_features_by_type(x)
            if self.factorize_freq_time:
                x = self.restore_freq_temp_dims(x, orig_shape)  # (BS, D, f, t) <- (BS, T, D)
                if cont_mask is not None:
                    # the mask is token-wise: broadcast it over the hidden dim before restoring (f, t)
                    x_mask = x_mask.unsqueeze(-1).expand(-1, -1, self.config.hidden_size)
                    x_mask = self.restore_freq_temp_dims(x_mask, orig_shape)  # (BS, D, f, t) <- (BS, T, D)
                    # every channel is identical after the expand, so a single one suffices
                    x_mask = x_mask[:, 0, :, :]  # (BS, f, t)
                else:
                    x_mask = None
                x = self.freq_attn_agg(x, x_mask)  # (BS, t, D)
                x = self.temp_attn_agg(x)  # (BS, D) or, with Identity, (BS, t, D)
        else:
            x = x["pooler_output"]
            x = self.classifier(x)
        return x

    def get_features_by_type(self, x: BaseModelOutputWithPooling) -> torch.Tensor:
        if self.feat_type == "pooler_output":
            return x["pooler_output"]  # (BS, D)
        elif self.feat_type == "CLS":
            return x["last_hidden_state"][:, 0, :]  # (BS, D)
        elif self.feat_type == "last_hidden_state":
            return x["last_hidden_state"]  # (BS, T, D)
        elif self.feat_type == "last_hidden_state_no_AUX":
            return x["last_hidden_state"][:, 2:, :]  # (BS, T-2, D): drops the CLS and distillation tokens
        else:
            raise ValueError(f"Unknown feature type: {self.feat_type}")

    def restore_freq_temp_dims(self, feats, orig_shape: tuple):
        """
        feats are of shape (B*S, T, D),
        where T = 2 + f * t (if feat_type == 'last_hidden_state')
        or    T = f * t     (if feat_type == 'last_hidden_state_no_AUX').
        The goal is to reshape them to (B*S, D, f, t), where f and t are the dims after patching.
        From `self.ast.embeddings.patch_embeddings`, it follows that feats can be reshaped with
        `feats.transpose(1, 2).view(B*S, D, f, t)`.

        (A similar function is defined for RGB features in `motionformer.py`.)
        """
        B, S, T, F = orig_shape
        D = self.config.hidden_size

        f, t = self.ast.embeddings.get_shape(self.config)

        if self.feat_type == "last_hidden_state":
            feats = feats[:, 2:, :]  # drop the CLS and distillation tokens

        feats = feats.permute(0, 2, 1)  # (B*S, D, f*t)
        feats = feats.view(B * S, D, f, t)

        return feats

    def patch_position_emb(self):
        if self.max_spec_t is not None:
            self.config.max_length = self.max_spec_t
        f, t = self.ast.embeddings.get_shape(self.config)
        # keep the first f*t patch positions plus 2 for the CLS and distillation tokens
        shortened = self.ast.embeddings.position_embeddings[:, : f * t + 2].clone()
        self.ast.embeddings.position_embeddings = torch.nn.Parameter(shortened).to(self.device)

    def to(self, device):
        """`nn.Module` does not track its device, so `AST.device` would fail with an AttributeError;
        this override keeps `self.device` in sync as a workaround."""
        self.device = torch.device(device)
        return super().to(device)


class FrequencyTransformerEncoderLayer(BaseEncoderLayer):
    """This layer is used to aggregate the features along the frequency axis.
    It follows the same logic as the spatio-temporal aggregation in the visual feature extractor,
    so it is recommended to check the definition of `BaseEncoderLayer` in `motionformer.py`."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
        """x: (B*S, D, f, t); if specified, x_mask is (B*S, f, t) where 0s mark the values to be masked out"""
        BS, D, f, t = x.shape

        # fold the time axis into the batch dim so attention runs over frequency tokens only
        x = x.permute(0, 3, 2, 1)  # (BS, t, f, D)
        x = x.reshape(BS * t, f, D)

        if x_mask is not None:
            x_mask = x_mask.permute(0, 2, 1)  # (BS, t, f)
            x_mask = x_mask.reshape(BS * t, f)

        # the base layer aggregates the f frequency tokens into a single token
        x = super().forward(x=x, x_mask=x_mask)  # (BS*t, D)

        # unfold the time axis back from the batch dim
        x = x.view(BS, t, D)

        return x
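

# A minimal usage sketch (hedged): the relative imports above mean this file must run as
# part of its package, and all sizes below (S=14 segments, T=66 frames, F=128 mel bins)
# are illustrative placeholders rather than values mandated by the model.
if __name__ == "__main__":
    model = AST(
        extract_features=True,
        ckpt_path="MIT/ast-finetuned-audioset-10-10-0.4593",  # fetched from the HF hub
        max_spec_t=66,
        factorize_freq_time=True,
        agg_freq_module="TransformerEncoderLayer",
        agg_time_module="TransformerEncoderLayer",
        add_global_repr=True,
        agg_segments_module="TransformerEncoderLayer",
        max_segments=14,
    )
    x = torch.rand(2, 14, 66, 128)  # (B, S, T, F) dummy log-mel spectrograms
    segment_feats, global_feats = model(x)
    print(segment_feats.shape, global_feats.shape)  # expected: (2, 14, D) and (2, D)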