Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from transformers import Wav2Vec2Processor, HubertModel, HubertConfig | |
| from transformers.pytorch_utils import Conv1D | |
class HubertAudioTransform:
    """Callable that turns raw audio into the ``input_values`` tensor for HuBERT.

    The underlying ``Wav2Vec2Processor`` is loaded from the
    ``facebook/hubert-large-ls960-ft`` checkpoint identifier.
    """

    def __init__(self):
        checkpoint = "facebook/hubert-large-ls960-ft"
        self.processor = Wav2Vec2Processor.from_pretrained(checkpoint)

    def __call__(self, audio):
        """Process ``audio`` and return the processor's ``input_values``.

        The processor is asked for PyTorch tensors at a 16000 Hz sampling
        rate. Dimension 0 of the result is squeezed, which only removes it
        when it has size 1.
        """
        encoded = self.processor(audio, return_tensors="pt", sampling_rate=16000)
        values = encoded.input_values
        return values.squeeze(0)
def copy_conv(l):
    """Unfinished helper; as written it only attempts to construct a ``Conv1D``.

    NOTE(review): the argument ``l`` is never read and nothing is returned, so
    the function evaluates to ``None`` if the constructor call succeeds. The
    name suggests it was meant to copy a convolution layer -- confirm the
    intent before relying on it.
    """
    # NOTE(review): ``Conv1D`` is called with no arguments here -- TODO confirm
    # against the transformers.pytorch_utils.Conv1D constructor signature; this
    # may raise at call time.
    new_l = Conv1D()
class Hubert(nn.Module):
    """Wrapper around a pretrained ``HubertModel`` that emits a 4-D feature map.

    The final hidden states of the wrapped model are rearranged so that the
    hidden dimension comes second and a singleton axis is inserted at
    position 2.
    """

    def __init__(self):
        super().__init__()
        checkpoint = "facebook/hubert-large-ls960-ft"
        # Load only the configuration. Instantiating a full model just to read
        # its ``.config`` (and then deleting it) loads the weights twice.
        config = HubertConfig.from_pretrained(checkpoint)
        # Override the layer-norm epsilon before the weights are loaded.
        # NOTE(review): presumably for numerical stability -- confirm.
        config.layer_norm_eps = 1e-4
        self.model = HubertModel.from_pretrained(checkpoint, config=config)
        # Empty placeholder; nothing in this class reads it.
        # NOTE(review): presumably expected by external callers -- confirm.
        self.config = {}

    def forward(self, audio, include_cls):
        """Encode ``audio`` and return the rearranged last hidden state.

        Args:
            audio: Input passed straight to the wrapped model.
                Assumed to be a batch of raw waveforms -- TODO confirm.
            include_cls: When truthy, return a ``(tokens, None)`` pair;
                this model provides no separate class token, so the second
                element is always ``None``.

        Returns:
            The last hidden state permuted with ``(0, 2, 1)`` and given an
            extra axis at position 2, or that tensor paired with ``None``
            when ``include_cls`` is truthy.
        """
        outputs = self.model(audio)
        # Swap the last two axes, then insert a singleton axis at position 2.
        patch_tokens = outputs.last_hidden_state.permute(0, 2, 1).unsqueeze(2)
        if include_cls:
            return patch_tokens, None
        return patch_tokens

    def get_last_params(self):
        """Return an iterator over the parameters of the final encoder layer."""
        return self.model.encoder.layers[-1].parameters()
if __name__ == "__main__":
    # Demo: embed one audio clip and visualise a PCA projection of the features.
    import librosa
    from shared import pca, remove_axes
    import matplotlib.pyplot as plt
    from pytorch_lightning import seed_everything

    # Fall back to the CPU when no GPU is present instead of failing on .to("cuda").
    device = "cuda" if torch.cuda.is_available() else "cpu"

    audio, _ = librosa.load("../../samples/example.wav", sr=16000)
    audio = torch.from_numpy(audio).unsqueeze(0).to(device)

    # Inference only: make evaluation mode explicit for the wrapper and its children.
    model = Hubert().to(device).eval()

    # No gradients are needed anywhere below, so the forward pass also runs
    # under no_grad to avoid building an unused autograd graph.
    with torch.no_grad():
        # Call the module itself rather than .forward so nn.Module hooks run.
        embeddings = model(audio, include_cls=False)
        print(embeddings.shape)

        # Seed right before the PCA step, as in the original ordering.
        seed_everything(0)
        [pca_feats], _ = pca([embeddings])
        # Stretch axis 2 to 25 entries; this relies on that axis having size 1
        # (or already 25) for the broadcast to be valid.
        pca_feats = torch.broadcast_to(
            pca_feats, (pca_feats.shape[0], pca_feats.shape[1], 25, pca_feats.shape[3]))

    fig, axes = plt.subplots(2, 1, figsize=(10, 7))
    # Only the second subplot is drawn; the first is left empty.
    axes[1].imshow(pca_feats.cpu().squeeze(0).permute(1, 2, 0))
    remove_axes(axes)
    plt.tight_layout()
    plt.show()
    print("here")