|
import math |
|
import random |
|
import torch |
|
|
|
from torch import nn |
|
from typing import Tuple |
|
|
|
class PadCrop(nn.Module):
    """Crop or right-pad a 2-D signal so its last dimension is ``n_samples``.

    Inputs longer than ``n_samples`` are cropped (at a random start when
    ``randomize`` is true, otherwise from index 0). Shorter inputs are copied
    to the front of a zero tensor of the target length.
    """

    def __init__(self, n_samples, randomize=True):
        super().__init__()
        self.n_samples = n_samples
        self.randomize = randomize

    def __call__(self, signal):
        """Return a ``[channels, n_samples]`` tensor built from ``signal``.

        ``signal`` must have exactly two dimensions, since its shape is
        unpacked into two values.
        """
        channels, length = signal.shape
        target = self.n_samples

        if self.randomize:
            # The largest valid start keeps the whole window inside the
            # signal; for short signals the only candidate is 0.
            last_start = max(0, length - target)
            first = torch.randint(0, last_start + 1, []).item()
        else:
            first = 0

        # Same dtype/device as the input, pre-filled with zeros (the padding).
        padded = signal.new_zeros([channels, target])
        filled = min(length, target)
        padded[:, :filled] = signal[:, first:first + target]
        return padded
|
|
|
class PadCrop_Normalized_T(nn.Module):
    """Crop or right-pad a 2-D signal and report where the window came from.

    Besides the fixed-length chunk, the call returns the window position as
    fractions of the (padded) source length, as whole seconds, and a mask
    marking which output samples hold real data rather than zero padding.
    """

    def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
        super().__init__()
        self.n_samples = n_samples
        self.sample_rate = sample_rate
        self.randomize = randomize

    def __call__(
        self, source: torch.Tensor
    ) -> Tuple[torch.Tensor, float, float, int, int, torch.Tensor]:
        """Extract a ``[channels, self.n_samples]`` chunk from ``source``.

        Args:
            source: Tensor with exactly two dimensions; the second one is
                treated as the sample axis.

        Returns:
            A 6-tuple ``(chunk, t_start, t_end, seconds_start, seconds_total,
            padding_mask)``:

            * ``chunk`` -- the cropped / zero-padded signal.
            * ``t_start``, ``t_end`` -- window start and end divided by
              ``max(source length, self.n_samples)``.
            * ``seconds_start`` -- ``floor(offset / sample_rate)``.
            * ``seconds_total`` -- ``ceil(source length / sample_rate)``.
            * ``padding_mask`` -- 1-D tensor of length ``self.n_samples``
              holding 1 where ``chunk`` was filled from ``source`` and 0
              where it is padding.
        """
        n_channels, n_samples = source.shape

        # Largest offset that still leaves a full window inside the source
        # (0 when the source is not longer than the window).
        upper_bound = max(0, n_samples - self.n_samples)

        # A random offset only makes sense when there is something to crop.
        offset = 0
        if self.randomize and n_samples > self.n_samples:
            offset = random.randint(0, upper_bound)

        # upper_bound + self.n_samples == max(n_samples, self.n_samples), so
        # both fractions fall in [0, 1].
        t_start = offset / (upper_bound + self.n_samples)
        t_end = (offset + self.n_samples) / (upper_bound + self.n_samples)

        # Zero-initialised output with the same dtype/device as the source.
        chunk = source.new_zeros([n_channels, self.n_samples])

        # Copy the window; anything past the end of a short source stays zero.
        chunk[:, :min(n_samples, self.n_samples)] = source[:, offset:offset + self.n_samples]

        seconds_start = math.floor(offset / self.sample_rate)
        seconds_total = math.ceil(n_samples / self.sample_rate)

        # 1 for samples copied from the source, 0 for padded positions.
        padding_mask = torch.zeros([self.n_samples])
        padding_mask[:min(n_samples, self.n_samples)] = 1

        return (
            chunk,
            t_start,
            t_end,
            seconds_start,
            seconds_total,
            padding_mask,
        )
|
|
|
class PhaseFlipper(nn.Module):
    """Negate the input signal with probability ``p``."""

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def __call__(self, signal):
        """Return ``-signal`` when a uniform draw falls below ``p``, else ``signal``."""
        # One draw from the stdlib ``random`` generator per call.
        should_flip = random.random() < self.p
        if should_flip:
            return -signal
        return signal
|
|
|
class Mono(nn.Module):
    """Average a multi-dimensional signal over its first dimension."""

    def __call__(self, signal):
        """Return the mean over dim 0 (kept as size 1); 1-D or scalar input is returned as is."""
        if len(signal.shape) <= 1:
            # Nothing to mix down: hand back the very same object.
            return signal
        return torch.mean(signal, dim=0, keepdim=True)
|
|
|
class Stereo(nn.Module):
    """Coerce a 1-D or 2-D signal to two rows along the first dimension."""

    def __call__(self, signal):
        """Return a two-row version of ``signal``.

        * 1-D input is stacked twice into shape ``[2, length]``.
        * 2-D input with one row has that row duplicated.
        * 2-D input with more than two rows keeps only the first two.
        * Anything else (already two rows, zero rows, or more than two
          dimensions) is returned unchanged.
        """
        rank = len(signal.shape)

        if rank == 1:
            return signal.unsqueeze(0).repeat(2, 1)

        if rank == 2:
            rows = signal.shape[0]
            if rows == 1:
                return signal.repeat(2, 1)
            if rows > 2:
                return signal[:2, :]

        return signal
|
|