| import torch | |
| import torch.nn as nn | |
class STFTMag(nn.Module):
    """Magnitude spectrogram computed with ``torch.stft`` and a Hann window.

    Args:
        nfft: FFT size; also the length of the Hann analysis window.
        hop: hop length between successive frames, in samples.
    """

    def __init__(self,
                 nfft=1024,
                 hop=256):
        super().__init__()
        self.nfft = nfft
        self.hop = hop
        # Non-persistent buffer: it follows the module across .to()/.cuda()
        # but is not saved in the state_dict.
        self.register_buffer('window', torch.hann_window(nfft), persistent=False)

    def forward(self, x):
        """Return the STFT magnitude of ``x``.

        Args:
            x: real waveform of shape [B, T] or [T].

        Returns:
            Real tensor of shape [B, F, TT] (or [F, TT] for 1-D input), where
            F = nfft // 2 + 1. The result is computed on, and returned on,
            the CPU.
        """
        # NOTE(review): the signal is moved to the CPU as in the original
        # implementation; confirm this is still required.
        signal = x.cpu()
        # The window buffer lives on the module's device, which is not
        # necessarily the CPU after model.to(device); match it to the signal
        # so torch.stft does not fail with a device mismatch.
        window = self.window.to(signal.device)
        # Recent PyTorch requires return_complex for real inputs. The modulus
        # of the complex spectrum equals the L2 norm over the legacy trailing
        # (real, imag) axis that the previous code reduced.
        spec = torch.stft(signal,
                          self.nfft,
                          self.hop,
                          window=window,
                          return_complex=True)
        return spec.abs()