	| # Copied and modified from https://github.com/csteinmetz1/auraloss/blob/main/auraloss/freq.py under Apache License 2.0 | |
| # You can find the license at LICENSES/LICENSE_AURALOSS.txt | |
| import torch | |
| import numpy as np | |
| from typing import List, Any | |
| import scipy.signal | |
def apply_reduction(losses, reduction="none"):
    """Reduce a tensor of per-element losses.

    Args:
        losses (Tensor): Collection of loss values.
        reduction (str): One of "none" (pass-through), "mean" or "sum".
            Default: "none"

    Returns:
        Tensor: The reduced losses.
    """
    if reduction == "sum":
        return losses.sum()
    if reduction == "mean":
        return losses.mean()
    return losses
def compute_direction(w, x, y, z):
    """Compute per-sample direction angles from B-format (ambisonic) channels.

    Args:
        w: Omnidirectional channel. NOTE(review): unused here — presumably
            kept for API symmetry with the channel layout; confirm.
        x, y, z: Directional channels.

    Returns:
        Tuple of (azimuth, elevation) tensors, each with a new dim inserted
        at position 1.
    """
    azimuth = torch.atan2(y, x)
    elevation = torch.atan2(torch.sqrt(x**2 + y**2), z)
    return azimuth.unsqueeze(1), elevation.unsqueeze(1)
def get_window(win_type: str, win_length: int):
    """Return a window function as a 1D torch tensor.

    Args:
        win_type (str): Window type. Can either be one of the window functions provided in PyTorch
            ['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
            or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html).
        win_length (int): Window length.

    Returns:
        win: The window as a 1D torch tensor.
    """
    try:
        win = getattr(torch, win_type)(win_length)
    except (AttributeError, TypeError):
        # Bug fix: was a bare `except:`, which also swallowed KeyboardInterrupt
        # and hid unrelated errors. Only a missing/uncallable torch window
        # should trigger the SciPy fallback.
        win = torch.from_numpy(scipy.signal.windows.get_window(win_type, win_length))
    return win
class SumAndDifference(torch.nn.Module):
    """Sum and difference signal extraction module for stereo audio."""

    def __init__(self):
        """Initialize sum and difference extraction module."""
        super(SumAndDifference, self).__init__()

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (B, #channels, #samples).
        Returns:
            Tensor: Sum signal (B, 1, #samples).
            Tensor: Difference signal (B, 1, #samples).
        Raises:
            ValueError: If the input is not 2-channel (stereo).
        """
        if not (x.size(1) == 2):  # inputs must be stereo
            raise ValueError(f"Input must be stereo: {x.size(1)} channel(s).")

        sum_sig = self.sum(x).unsqueeze(1)
        diff_sig = self.diff(x).unsqueeze(1)

        return sum_sig, diff_sig

    # Bug fix: these were plain functions defined in the class body without
    # `self`; calling them as `self.sum(x)` bound the instance to the `x`
    # parameter and raised TypeError. They are stateless, so make them static.
    @staticmethod
    def sum(x):
        # L + R
        return x[:, 0, :] + x[:, 1, :]

    @staticmethod
    def diff(x):
        # L - R
        return x[:, 0, :] - x[:, 1, :]
class FIRFilter(torch.nn.Module):
    """FIR pre-emphasis filtering module.

    Applies a fixed (non-trainable) FIR filter to prediction and target so that
    downstream losses operate on perceptually weighted signals.

    Args:
        filter_type (str): Shape of the desired FIR filter ("hp", "fd", "aw"). Default: "hp"
        coef (float): Coefficient value for the filter tap (only applicable for "hp" and "fd"). Default: 0.85
        fs (int): Sampling rate used when designing the "aw" filter. Default: 44100
        ntaps (int): Number of FIR filter taps for constructing A-weighting filters. Default: 101
        plot (bool): Plot the magnitude response of the filter. Default: False

    Based upon the perceptual loss pre-emphasis filters proposed by
    [Wright & Välimäki, 2019](https://arxiv.org/abs/1911.08922).

    A-weighting filter - "aw"
    First-order highpass - "hp"
    Folded differentiator - "fd"

    Note that the default coefficient value of 0.85 is optimized for
    a sampling rate of 44.1 kHz; consider adjusting this value at different sampling rates.
    """

    def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, plot=False):
        """Initialize FIR pre-emphasis filtering module."""
        super(FIRFilter, self).__init__()
        self.filter_type = filter_type
        self.coef = coef
        self.fs = fs
        self.ntaps = ntaps
        self.plot = plot

        # NOTE(review): redundant — scipy.signal is already imported at module level.
        import scipy.signal

        # An odd tap count gives a single center tap, so padding by ntaps // 2
        # preserves signal length in forward().
        if ntaps % 2 == 0:
            raise ValueError(f"ntaps must be odd (ntaps={ntaps}).")

        if filter_type == "hp":
            # First-order highpass: y[n] = x[n] - coef * x[n-1] (3-tap kernel; last tap is zero).
            self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1)
            self.fir.weight.requires_grad = False
            self.fir.weight.data = torch.tensor([1, -coef, 0]).view(1, 1, -1)
        elif filter_type == "fd":
            # Folded differentiator: y[n] = x[n] - coef * x[n-2].
            self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1)
            self.fir.weight.requires_grad = False
            self.fir.weight.data = torch.tensor([1, 0, -coef]).view(1, 1, -1)
        elif filter_type == "aw":
            # Definition of analog A-weighting filter according to IEC/CD 1672.
            f1 = 20.598997
            f2 = 107.65265
            f3 = 737.86223
            f4 = 12194.217
            A1000 = 1.9997

            # Analog transfer function: numerator and denominator polynomials in s.
            NUMs = [(2 * np.pi * f4) ** 2 * (10 ** (A1000 / 20)), 0, 0, 0, 0]
            DENs = np.polymul(
                [1, 4 * np.pi * f4, (2 * np.pi * f4) ** 2],
                [1, 4 * np.pi * f1, (2 * np.pi * f1) ** 2],
            )
            DENs = np.polymul(
                np.polymul(DENs, [1, 2 * np.pi * f3]), [1, 2 * np.pi * f2]
            )

            # convert analog filter to digital filter
            b, a = scipy.signal.bilinear(NUMs, DENs, fs=fs)

            # compute the digital filter frequency response
            w_iir, h_iir = scipy.signal.freqz(b, a, worN=512, fs=fs)

            # then we fit to 101 tap FIR filter with least squares
            taps = scipy.signal.firls(ntaps, w_iir, abs(h_iir), fs=fs)

            # now implement this digital FIR filter as a Conv1d layer
            self.fir = torch.nn.Conv1d(
                1, 1, kernel_size=ntaps, bias=False, padding=ntaps // 2
            )
            self.fir.weight.requires_grad = False
            self.fir.weight.data = torch.tensor(taps.astype("float32")).view(1, 1, -1)

        if plot:
            # local import: plotting helpers pull in optional dependencies
            from .plotting import compare_filters
            compare_filters(b, a, taps, fs=fs)

    def forward(self, input, target):
        """Calculate forward propagation.

        Args:
            input (Tensor): Predicted signal (B, #channels, #samples).
            target (Tensor): Groundtruth signal (B, #channels, #samples).
        Returns:
            Tensor: Filtered input signal.
            Tensor: Filtered target signal.
        """
        # NOTE(review): the padding of self.ntaps // 2 matches the "aw" kernel,
        # but the "hp"/"fd" kernels have only 3 taps — with the default
        # ntaps=101 this grows the output length for those filter types.
        # Confirm this is intended before relying on output shapes.
        input = torch.nn.functional.conv1d(
            input, self.fir.weight.data, padding=self.ntaps // 2
        )
        target = torch.nn.functional.conv1d(
            target, self.fir.weight.data, padding=self.ntaps // 2
        )
        return input, target
class SpectralConvergenceLoss(torch.nn.Module):
    """Spectral convergence loss.

    Measures the Frobenius-norm error of the predicted magnitude spectrum
    relative to the norm of the target spectrum, averaged over the batch.

    See [Arik et al., 2018](https://arxiv.org/abs/1808.06719).
    """

    def __init__(self):
        super(SpectralConvergenceLoss, self).__init__()

    def forward(self, x_mag, y_mag):
        residual = torch.norm(y_mag - x_mag, p="fro", dim=[-1, -2])
        reference = torch.norm(y_mag, p="fro", dim=[-1, -2])
        return (residual / reference).mean()
class STFTMagnitudeLoss(torch.nn.Module):
    """STFT magnitude loss module.

    See [Arik et al., 2018](https://arxiv.org/abs/1808.06719)
    and [Engel et al., 2020](https://arxiv.org/abs/2001.04643v1)

    Log-magnitudes are calculated with `log(log_fac*x + log_eps)`, where `log_fac` controls the
    compression strength (larger value results in more compression), and `log_eps` can be used
    to control the range of the compressed output values (e.g., `log_eps>=1` ensures positive
    output values). The default values `log_fac=1` and `log_eps=0` correspond to plain log-compression.

    Args:
        log (bool, optional): Log-scale the STFT magnitudes,
            or use linear scale. Default: True
        log_eps (float, optional): Constant value added to the magnitudes before evaluating the logarithm.
            Default: 0.0
        log_fac (float, optional): Constant multiplication factor for the magnitudes before evaluating the logarithm.
            Default: 1.0
        distance (str, optional): Distance function ["L1", "L2"]. Default: "L1"
        reduction (str, optional): Reduction of the loss elements. Default: "mean"
    """

    def __init__(self, log=True, log_eps=0.0, log_fac=1.0, distance="L1", reduction="mean"):
        super(STFTMagnitudeLoss, self).__init__()

        self.log = log
        self.log_eps = log_eps
        self.log_fac = log_fac

        if distance == "L1":
            self.distance = torch.nn.L1Loss(reduction=reduction)
        elif distance == "L2":
            self.distance = torch.nn.MSELoss(reduction=reduction)
        else:
            raise ValueError(f"Invalid distance: '{distance}'.")

    def forward(self, x_mag, y_mag):
        # Optionally compress both spectra before measuring the distance.
        if self.log:
            x_mag, y_mag = (
                torch.log(self.log_fac * m + self.log_eps) for m in (x_mag, y_mag)
            )
        return self.distance(x_mag, y_mag)
class STFTLoss(torch.nn.Module):
    """STFT loss module.

    See [Yamamoto et al. 2019](https://arxiv.org/abs/1904.04472).

    Args:
        fft_size (int, optional): FFT size in samples. Default: 1024
        hop_size (int, optional): Hop size of the FFT in samples. Default: 256
        win_length (int, optional): Length of the FFT analysis window. Default: 1024
        window (str, optional): Window to apply before FFT, can either be one of the window function provided in PyTorch
            ['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
            or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html).
            Default: 'hann_window'
        w_sc (float, optional): Weight of the spectral convergence loss term. Default: 1.0
        w_log_mag (float, optional): Weight of the log magnitude loss term. Default: 1.0
        w_lin_mag (float, optional): Weight of the linear magnitude loss term. Default: 0.0
        w_phs (float, optional): Weight of the spectral phase loss term. Default: 0.0
        sample_rate (int, optional): Sample rate. Required when scale = 'mel'. Default: None
        scale (str, optional): Optional frequency scaling method, options include:
            ['mel', 'chroma']
            Default: None
        n_bins (int, optional): Number of scaling frequency bins. Default: None.
        perceptual_weighting (bool, optional): Apply perceptual A-weighting (Sample rate must be supplied). Default: False
        scale_invariance (bool, optional): Perform an optimal scaling of the target. Default: False
        eps (float, optional): Small epsilon value for stablity. Default: 1e-8
        output (str, optional): Format of the loss returned.
            'loss' : Return only the raw, aggregate loss term.
            'full' : Return the raw loss, plus intermediate loss terms.
            Default: 'loss'
        reduction (str, optional): Specifies the reduction to apply to the output:
            'none': no reduction will be applied,
            'mean': the sum of the output will be divided by the number of elements in the output,
            'sum': the output will be summed.
            Default: 'mean'
        mag_distance (str, optional): Distance function ["L1", "L2"] for the magnitude loss terms.
        device (str, optional): Place the filterbanks on specified device. Default: None

    Returns:
        loss:
            Aggregate loss term. Only returned if output='loss'. By default.
        loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss:
            Aggregate and intermediate loss terms. Only returned if output='full'.
    """

    def __init__(
        self,
        fft_size: int = 1024,
        hop_size: int = 256,
        win_length: int = 1024,
        window: str = "hann_window",
        w_sc: float = 1.0,
        w_log_mag: float = 1.0,
        w_lin_mag: float = 0.0,
        w_phs: float = 0.0,
        sample_rate: float = None,
        scale: str = None,
        n_bins: int = None,
        perceptual_weighting: bool = False,
        scale_invariance: bool = False,
        eps: float = 1e-8,
        output: str = "loss",
        reduction: str = "mean",
        mag_distance: str = "L1",
        device: Any = None,
        **kwargs
    ):
        super().__init__()
        self.fft_size = fft_size
        self.hop_size = hop_size
        self.win_length = win_length
        self.window = get_window(window, win_length)
        self.w_sc = w_sc
        self.w_log_mag = w_log_mag
        self.w_lin_mag = w_lin_mag
        self.w_phs = w_phs
        self.sample_rate = sample_rate
        self.scale = scale
        self.n_bins = n_bins
        self.perceptual_weighting = perceptual_weighting
        self.scale_invariance = scale_invariance
        self.eps = eps
        self.output = output
        self.reduction = reduction
        self.mag_distance = mag_distance
        self.device = device
        # phase loss is only evaluated when its weight is nonzero (torch.angle is expensive)
        self.phs_used = bool(self.w_phs)

        self.spectralconv = SpectralConvergenceLoss()
        self.logstft = STFTMagnitudeLoss(
            log=True,
            reduction=reduction,
            distance=mag_distance,
            **kwargs,
        )
        self.linstft = STFTMagnitudeLoss(
            log=False,
            reduction=reduction,
            distance=mag_distance,
            **kwargs,
        )

        # setup mel/chroma filterbank
        if scale is not None:
            try:
                import librosa.filters
            except Exception as e:
                print(e)
                print("Try `pip install auraloss[all]`.")

            if self.scale == "mel":
                assert sample_rate is not None  # Must set sample rate to use mel scale
                assert n_bins <= fft_size  # Must be more FFT bins than Mel bins
                fb = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=n_bins)
            elif self.scale == "chroma":
                assert sample_rate is not None  # Must set sample rate to use chroma scale
                assert n_bins <= fft_size  # Must be more FFT bins than chroma bins
                fb = librosa.filters.chroma(
                    sr=sample_rate, n_fft=fft_size, n_chroma=n_bins
                )
            else:
                raise ValueError(
                    f"Invalid scale: {self.scale}. Must be 'mel' or 'chroma'."
                )

            # Bug fix: previously only the mel branch converted the filterbank
            # to a torch tensor with a leading batch dim; the chroma branch
            # registered a raw numpy array, which breaks register_buffer and
            # the matmul in forward().
            fb = torch.tensor(fb).unsqueeze(0)
            self.register_buffer("fb", fb)

        if scale is not None and device is not None:
            self.fb = self.fb.to(self.device)  # move filterbank to device

        if self.perceptual_weighting:
            if sample_rate is None:
                raise ValueError(
                    "`sample_rate` must be supplied when `perceptual_weighting = True`."
                )
            self.prefilter = FIRFilter(filter_type="aw", fs=sample_rate)

    def stft(self, x):
        """Perform STFT.

        Args:
            x (Tensor): Input signal tensor (B, T).

        Returns:
            Tensor: x_mag, x_phs
                Magnitude and phase spectra (B, fft_size // 2 + 1, frames).
                `x_phs` is None when the phase loss is unused.
        """
        x_stft = torch.stft(
            x,
            self.fft_size,
            self.hop_size,
            self.win_length,
            self.window,
            return_complex=True,
        )
        # clamp before sqrt for numerical stability near zero magnitude
        x_mag = torch.sqrt(
            torch.clamp((x_stft.real**2) + (x_stft.imag**2), min=self.eps)
        )

        # torch.angle is expensive, so it is only evaluated if the values are used in the loss
        if self.phs_used:
            x_phs = torch.angle(x_stft)
        else:
            x_phs = None

        return x_mag, x_phs

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        bs, chs, seq_len = input.size()

        if self.perceptual_weighting:  # apply optional A-weighting via FIR filter
            # since FIRFilter only supports mono audio we move channels to batch dim
            input = input.view(bs * chs, 1, -1)
            target = target.view(bs * chs, 1, -1)

            # now apply the filter to both
            self.prefilter.to(input.device)
            input, target = self.prefilter(input, target)

            # now move the channels back
            input = input.view(bs, chs, -1)
            target = target.view(bs, chs, -1)

        # compute the magnitude and phase spectra of input and target
        self.window = self.window.to(input.device)
        x_mag, x_phs = self.stft(input.view(-1, input.size(-1)))
        y_mag, y_phs = self.stft(target.view(-1, target.size(-1)))

        # apply relevant transforms
        if self.scale is not None:
            self.fb = self.fb.to(input.device)
            x_mag = torch.matmul(self.fb, x_mag)
            y_mag = torch.matmul(self.fb, y_mag)

        # normalize scales: optimal scalar alpha that matches target to input
        if self.scale_invariance:
            alpha = (x_mag * y_mag).sum([-2, -1]) / ((y_mag**2).sum([-2, -1]))
            y_mag = y_mag * alpha.unsqueeze(-1)

        # compute loss terms (skip disabled terms entirely)
        sc_mag_loss = self.spectralconv(x_mag, y_mag) if self.w_sc else 0.0
        log_mag_loss = self.logstft(x_mag, y_mag) if self.w_log_mag else 0.0
        lin_mag_loss = self.linstft(x_mag, y_mag) if self.w_lin_mag else 0.0
        phs_loss = torch.nn.functional.mse_loss(x_phs, y_phs) if self.phs_used else 0.0

        # combine loss terms
        loss = (
            (self.w_sc * sc_mag_loss)
            + (self.w_log_mag * log_mag_loss)
            + (self.w_lin_mag * lin_mag_loss)
            + (self.w_phs * phs_loss)
        )

        loss = apply_reduction(loss, reduction=self.reduction)

        if self.output == "loss":
            return loss
        elif self.output == "full":
            return loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss
        else:
            # Bug fix: previously fell through and silently returned None.
            raise ValueError(f"Invalid output format: '{self.output}'.")
class MultiResolutionSTFTLoss(torch.nn.Module):
    """Multi resolution STFT loss module.

    Averages an `STFTLoss` over several STFT resolutions.

    See [Yamamoto et al., 2019](https://arxiv.org/abs/1910.11480)

    Args:
        fft_sizes (list): List of FFT sizes.
        hop_sizes (list): List of hop sizes.
        win_lengths (list): List of window lengths.
        window (str, optional): Window to apply before FFT, options include:
            ['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
            Default: 'hann_window'
        w_sc (float, optional): Weight of the spectral convergence loss term. Default: 1.0
        w_log_mag (float, optional): Weight of the log magnitude loss term. Default: 1.0
        w_lin_mag (float, optional): Weight of the linear magnitude loss term. Default: 0.0
        w_phs (float, optional): Weight of the spectral phase loss term. Default: 0.0
        sample_rate (int, optional): Sample rate. Required when scale = 'mel'. Default: None
        scale (str, optional): Optional frequency scaling method, options include:
            ['mel', 'chroma']
            Default: None
        n_bins (int, optional): Number of mel frequency bins. Required when scale = 'mel'. Default: None.
        scale_invariance (bool, optional): Perform an optimal scaling of the target. Default: False
    """

    def __init__(
        self,
        fft_sizes: List[int] = [1024, 2048, 512],
        hop_sizes: List[int] = [120, 240, 50],
        win_lengths: List[int] = [600, 1200, 240],
        window: str = "hann_window",
        w_sc: float = 1.0,
        w_log_mag: float = 1.0,
        w_lin_mag: float = 0.0,
        w_phs: float = 0.0,
        sample_rate: float = None,
        scale: str = None,
        n_bins: int = None,
        perceptual_weighting: bool = False,
        scale_invariance: bool = False,
        **kwargs,
    ):
        super().__init__()
        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)  # must define all
        self.fft_sizes = fft_sizes
        self.hop_sizes = hop_sizes
        self.win_lengths = win_lengths

        # one STFTLoss per resolution; all share the same weights and options
        self.stft_losses = torch.nn.ModuleList(
            [
                STFTLoss(
                    fs,
                    ss,
                    wl,
                    window,
                    w_sc,
                    w_log_mag,
                    w_lin_mag,
                    w_phs,
                    sample_rate,
                    scale,
                    n_bins,
                    perceptual_weighting,
                    scale_invariance,
                    **kwargs,
                )
                for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths)
            ]
        )

    def forward(self, x, y):
        """Average the per-resolution STFT losses between x and y.

        Returns either the aggregate loss, or (when the sub-losses were built
        with output='full') the aggregate loss plus lists of the intermediate
        terms, one entry per resolution.
        """
        mrstft_loss = 0.0
        sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss = [], [], [], []

        for f in self.stft_losses:
            if f.output == "full":  # extract just first term
                tmp_loss = f(x, y)
                mrstft_loss += tmp_loss[0]
                sc_mag_loss.append(tmp_loss[1])
                log_mag_loss.append(tmp_loss[2])
                lin_mag_loss.append(tmp_loss[3])
                phs_loss.append(tmp_loss[4])
            else:
                mrstft_loss += f(x, y)

        mrstft_loss /= len(self.stft_losses)

        # Bug fix: previously inspected the leaked loop variable `f` after the
        # loop; all sub-losses share the same `output`, so check the first.
        if self.stft_losses[0].output == "loss":
            return mrstft_loss
        else:
            return mrstft_loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss
class SumAndDifferenceSTFTLoss(torch.nn.Module):
    """Sum and difference stereo STFT loss module.

    See [Steinmetz et al., 2020](https://arxiv.org/abs/2010.10291)

    Args:
        fft_sizes (List[int]): List of FFT sizes.
        hop_sizes (List[int]): List of hop sizes.
        win_lengths (List[int]): List of window lengths.
        window (str, optional): Window function type.
        w_sum (float, optional): Weight of the sum loss component. Default: 1.0
        w_diff (float, optional): Weight of the difference loss component. Default: 1.0
        output (str, optional): Format of the loss returned.
            'loss' : Return only the raw, aggregate loss term.
            'full' : Return the raw loss, plus intermediate loss terms.
            Default: 'loss'
        **kwargs: Forwarded to the underlying MultiResolutionSTFTLoss.
    """

    def __init__(
        self,
        fft_sizes: List[int],
        hop_sizes: List[int],
        win_lengths: List[int],
        window: str = "hann_window",
        w_sum: float = 1.0,
        w_diff: float = 1.0,
        output: str = "loss",
        **kwargs,
    ):
        super().__init__()
        self.sd = SumAndDifference()
        self.w_sum = w_sum
        self.w_diff = w_diff
        self.output = output
        self.mrstft = MultiResolutionSTFTLoss(
            fft_sizes,
            hop_sizes,
            win_lengths,
            window,
            **kwargs,
        )

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        """This loss function assumes batched input of stereo audio in the time domain.

        Args:
            input (torch.Tensor): Input tensor with shape (batch size, 2, seq_len).
            target (torch.Tensor): Target tensor with shape (batch size, 2, seq_len).

        Returns:
            loss (torch.Tensor): Aggregate loss term. Only returned if output='loss'.
            loss (torch.Tensor), sum_loss (torch.Tensor), diff_loss (torch.Tensor):
                Aggregate and intermediate loss terms. Only returned if output='full'.
        """
        assert input.shape == target.shape  # must have same shape
        bs, chs, seq_len = input.size()

        # compute sum (mid) and difference (side) signals for both
        input_sum, input_diff = self.sd(input)
        target_sum, target_diff = self.sd(target)

        # compute error in STFT domain
        sum_loss = self.mrstft(input_sum, target_sum)
        diff_loss = self.mrstft(input_diff, target_diff)
        loss = ((self.w_sum * sum_loss) + (self.w_diff * diff_loss)) / 2

        if self.output == "loss":
            return loss
        elif self.output == "full":
            return loss, sum_loss, diff_loss
        else:
            # Bug fix: previously fell through and silently returned None.
            raise ValueError(f"Invalid output format: '{self.output}'.")
class SpatialSTFTLoss(torch.nn.Module):
    """Spatial (direction) STFT loss for 4-channel B-format audio.

    Converts the four input channels (W, X, Y, Z) of prediction and target
    into azimuth/elevation angle signals via `compute_direction`, then
    compares those angle signals with a multi-resolution STFT loss.

    Args:
        fft_sizes (List[int]): List of FFT sizes.
        hop_sizes (List[int]): List of hop sizes.
        win_lengths (List[int]): List of window lengths.
        window (str, optional): Window function type.
        w_phi (float, optional): Weight of the azimuth loss component. Default: 1.0
        w_theta (float, optional): Weight of the elevation loss component. Default: 1.0
        output (str, optional): Format of the loss returned.
            'loss' : Return only the raw, aggregate loss term.
            'full' : Return the raw loss, plus intermediate loss terms.
            Default: 'loss'
        **kwargs: Forwarded to the underlying MultiResolutionSTFTLoss.
    """

    def __init__(
        self,
        fft_sizes: List[int],
        hop_sizes: List[int],
        win_lengths: List[int],
        window: str = "hann_window",
        w_phi: float = 1.0,
        w_theta: float = 1.0,
        output: str = "loss",
        **kwargs,
    ):
        super().__init__()
        self.w_phi = w_phi
        self.w_theta = w_theta
        self.output = output
        self.mrstft = MultiResolutionSTFTLoss(
            fft_sizes,
            hop_sizes,
            win_lengths,
            window,
            **kwargs,
        )

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        """This loss function assumes batched 4-channel audio in the time domain.

        Args:
            input (torch.Tensor): Input tensor with shape (batch size, 4, seq_len).
            target (torch.Tensor): Target tensor with shape (batch size, 4, seq_len).

        Returns:
            loss (torch.Tensor): Aggregate loss term. Only returned if output='loss'.
            loss (torch.Tensor), phi_loss (torch.Tensor), theta_loss (torch.Tensor):
                Aggregate and intermediate loss terms. Only returned if output='full'.
        """
        assert input.shape == target.shape  # must have same shape
        bs, chs, seq_len = input.size()

        # split the B-format channels of both signals
        w_o, x_o, y_o, z_o = input[:, 0], input[:, 1], input[:, 2], input[:, 3]
        w_r, x_r, y_r, z_r = target[:, 0], target[:, 1], target[:, 2], target[:, 3]

        # derive azimuth/elevation angle signals
        phi_o, theta_o = compute_direction(w_o, x_o, y_o, z_o)
        phi_r, theta_r = compute_direction(w_r, x_r, y_r, z_r)

        # compute error in STFT domain
        phi_loss = self.mrstft(phi_o, phi_r)
        theta_loss = self.mrstft(theta_o, theta_r)
        loss = ((self.w_phi * phi_loss) + (self.w_theta * theta_loss)) / 2

        if self.output == "loss":
            return loss
        elif self.output == "full":
            # Bug fix: previously returned the undefined names
            # `sum_loss`/`diff_loss` (NameError at runtime).
            return loss, phi_loss, theta_loss
        else:
            raise ValueError(f"Invalid output format: '{self.output}'.")