# audio/demucs/augment.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Data augmentations.
"""
import random
import torch as th
from torch import nn


class Shift(nn.Module):
    """
    Randomly shift audio in time by up to `shift` samples.
    """
    def __init__(self, shift=8192, same=False):
        super().__init__()
        self.shift = shift
        self.same = same

    def forward(self, wav):
        batch, sources, channels, time = wav.size()
        length = time - self.shift
        if self.shift > 0:
            if not self.training:
                # At eval time, just trim to the same length as the training crops.
                wav = wav[..., :length]
            else:
                # Draw one random offset per source (or a single shared one if `same`),
                # then gather a `length`-long window starting at that offset.
                srcs = 1 if self.same else sources
                offsets = th.randint(self.shift, [batch, srcs, 1, 1], device=wav.device)
                offsets = offsets.expand(-1, sources, channels, -1)
                indexes = th.arange(length, device=wav.device)
                wav = wav.gather(3, indexes + offsets)
        return wav


class FlipChannels(nn.Module):
    """
    Flip left-right channels.
    """
    def forward(self, wav):
        batch, sources, channels, time = wav.size()
        if self.training and wav.size(2) == 2:
            # For each (batch, source) pair, randomly pick which channel becomes the left one.
            left = th.randint(2, (batch, sources, 1, 1), device=wav.device)
            left = left.expand(-1, -1, -1, time)
            right = 1 - left
            wav = th.cat([wav.gather(2, left), wav.gather(2, right)], dim=2)
        return wav


class FlipSign(nn.Module):
    """
    Random sign flip.
    """
    def forward(self, wav):
        batch, sources, channels, time = wav.size()
        if self.training:
            # Map {0, 1} draws to {-1, +1} multipliers, one per (batch, source) pair.
            signs = th.randint(2, (batch, sources, 1, 1), device=wav.device, dtype=th.float32)
            wav = wav * (2 * signs - 1)
        return wav


class Remix(nn.Module):
    """
    Shuffle sources to make new mixes.
    """
    def __init__(self, proba=1, group_size=4):
        """
        Shuffle sources within one batch.
        Each batch is divided into groups of size `group_size` and shuffling is done within
        each group separately. This keeps the same probability distribution no matter
        the number of GPUs. Without this grouping, using more GPUs would lead to a higher
        probability of keeping two sources from the same track together, which can hurt
        performance.
        """
        super().__init__()
        self.proba = proba
        self.group_size = group_size

    def forward(self, wav):
        batch, streams, channels, time = wav.size()
        device = wav.device

        if self.training and random.random() < self.proba:
            group_size = self.group_size or batch
            if batch % group_size != 0:
                raise ValueError(f"Batch size {batch} must be divisible by group size {group_size}")
            groups = batch // group_size
            wav = wav.view(groups, group_size, streams, channels, time)
            # One independent permutation per group and per stream: sort random keys along
            # the group dimension and gather the sources accordingly.
            permutations = th.argsort(th.rand(groups, group_size, streams, 1, 1, device=device),
                                      dim=1)
            wav = wav.gather(1, permutations.expand(-1, -1, -1, channels, time))
            wav = wav.view(batch, streams, channels, time)
        return wav
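

# Illustrative note (not from the original Demucs sources): with `group_size=4` and a batch of
# 8 items, Remix only permutes sources among items 0-3 and among items 4-7, never across the
# two groups. A minimal sketch of how it might be exercised, assuming a
# (batch, sources, channels, time) tensor:
#
#     remix = Remix(proba=1, group_size=4)
#     remix.train()
#     wav = th.randn(8, 4, 2, 44100)
#     out = remix(wav)  # out[:4] draws sources only from wav[:4], out[4:] only from wav[4:]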


class Scale(nn.Module):
    """
    Scale each source by a random gain.
    """
    def __init__(self, proba=1., min=0.25, max=1.25):
        super().__init__()
        self.proba = proba
        self.min = min
        self.max = max

    def forward(self, wav):
        batch, streams, channels, time = wav.size()
        device = wav.device
        if self.training and random.random() < self.proba:
            # One gain per (batch, stream) pair, drawn uniformly from [min, max].
            scales = th.empty(batch, streams, 1, 1, device=device).uniform_(self.min, self.max)
            wav *= scales
        return wav
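

if __name__ == "__main__":
    # Minimal smoke test (not part of the original module): chains the augmentations the way
    # a training loop might, on a tiny fake (batch, sources, channels, time) batch. The shapes
    # and hyper-parameters below are illustrative assumptions, not Demucs training defaults.
    th.manual_seed(0)
    augment = nn.Sequential(
        Shift(shift=2),
        FlipChannels(),
        FlipSign(),
        Remix(proba=1, group_size=2),
        Scale(proba=1, min=0.5, max=1.5),
    )
    augment.train()
    wav = th.randn(4, 2, 2, 16)  # batch=4, sources=2, stereo, 16 samples
    out = augment(wav)
    print(out.shape)  # Shift trims the time axis: torch.Size([4, 2, 2, 14])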