|
|
|
|
|
""" |
|
@File : apply.py |
|
@Time : 2023/8/8 4:22 PM
|
@Author : waytan |
|
@Contact : [email protected] |
|
@License : (C)Copyright 2023, Tencent |
|
@Desc : Apply separation models to audio mixtures
|
""" |
|
|
|
import os
import random
import typing as tp
from concurrent.futures import ThreadPoolExecutor

import torch as th
import tqdm
from torch import nn
from torch.nn import functional as F
|
|
|
from .htdemucs import HTDemucs |
|
from .audio import load_track, save_audio |
|
from .utils import center_trim, DummyPoolExecutor |
|
|
|
Model = HTDemucs
|
|
|
|
|
class BagOfModels(nn.Module): |
|
def __init__(self, models: tp.List[Model], |
|
weights: tp.Optional[tp.List[tp.List[float]]] = None, |
|
segment: tp.Optional[float] = None): |
|
""" |
|
Represents a bag of models with specific weights. |
|
        You should call `apply_model` rather than calling `forward` directly, for
        optimal performance.
|
|
|
Args: |
|
            models (list[nn.Module]): list of HTDemucs models.
            weights (list[list[float]]): list of weights. If None, assumed to
                be all ones; otherwise it should be a list of N lists (N being the
                number of models), each containing S floats (S being the number of
                sources).
|
            segment (None or float): overrides the `segment` attribute of each model
                (this is performed in place; be careful if you reuse the models passed).
|
""" |
|
super().__init__() |
|
assert len(models) > 0 |
|
first = models[0] |
|
for other in models: |
|
assert other.sources == first.sources |
|
assert other.samplerate == first.samplerate |
|
assert other.audio_channels == first.audio_channels |
|
if segment is not None: |
|
other.segment = segment |
|
|
|
self.audio_channels = first.audio_channels |
|
self.samplerate = first.samplerate |
|
self.sources = first.sources |
|
self.models = nn.ModuleList(models) |
|
|
|
if weights is None: |
|
weights = [[1. for _ in first.sources] for _ in models] |
|
else: |
|
assert len(weights) == len(models) |
|
for weight in weights: |
|
assert len(weight) == len(first.sources) |
|
self.weights = weights |
|
|
|
@property |
|
def max_allowed_segment(self) -> float: |
|
max_allowed_segment = float('inf') |
|
for model in self.models: |
|
if isinstance(model, HTDemucs): |
|
max_allowed_segment = min(max_allowed_segment, float(model.segment)) |
|
return max_allowed_segment |
|
|
|
def forward(self, x): |
|
raise NotImplementedError("Call `apply_model` on this.") |
|
|
|
    def separate(self, source_file, output_dir, stem=None, device=None):
        """Separate `source_file` into stems and write them to `output_dir`.

        If `stem` is given, write only that stem and its complement
        (`no_<stem>`); otherwise write one file per source. Returns the list
        of output file paths.
        """
|
wav, _ = load_track(source_file, self.audio_channels, self.samplerate) |
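        # Normalize by the mean and std of the mono reference signal; the
        # inverse transform is applied to the separated sources below.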
|
ref = wav.mean(0) |
|
wav -= ref.mean() |
|
wav /= ref.std() |
|
sources = apply_model(self, wav[None], device=device, shifts=1, split=True, overlap=0.25, |
|
progress=True, num_workers=0, segment=None)[0] |
|
sources *= ref.std() |
|
sources += ref.mean() |
|
|
|
output_paths = [] |
|
        name = os.path.splitext(os.path.basename(source_file))[0]
        ext = ".flac"  # outputs are always written as FLAC
|
kwargs = { |
|
'samplerate': self.samplerate, |
|
'bitrate': 320, |
|
'clip': "rescale", |
|
'as_float': False, |
|
'bits_per_sample': 16, |
|
} |
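        # With no stem requested, write one file per source; otherwise write the
        # requested stem plus a complementary "no_<stem>" mix of the remaining sources.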
|
if stem is None: |
|
            for source, source_name in zip(sources, self.sources):
                output_stem_path = os.path.join(output_dir, f"{name}_{source_name}{ext}")
                save_audio(source, output_stem_path, **kwargs)
                output_paths.append(output_stem_path)
|
else: |
|
sources = list(sources) |
|
output_stem_path = os.path.join(output_dir, f"{name}_{stem}{ext}") |
|
save_audio(sources.pop(self.sources.index(stem)), output_stem_path, **kwargs) |
|
            other_stem = th.zeros_like(sources[0])
            for remaining in sources:
                other_stem += remaining
|
output_no_stem_path = os.path.join(output_dir, f"{name}_no_{stem}{ext}") |
|
save_audio(other_stem, output_no_stem_path, **kwargs) |
|
output_paths = [output_stem_path, output_no_stem_path] |
|
|
|
return output_paths |
|
|
|
|
|
class TensorChunk: |
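    """A lightweight view over the last dimension of a tensor.

    Slicing into chunks does not copy data; the underlying tensor is only
    read when `padded` is called.
    """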
|
def __init__(self, tensor, offset=0, length=None): |
|
total_length = tensor.shape[-1] |
|
assert offset >= 0 |
|
assert offset < total_length |
|
|
|
if length is None: |
|
length = total_length - offset |
|
else: |
|
length = min(total_length - offset, length) |
|
|
|
if isinstance(tensor, TensorChunk): |
|
self.tensor = tensor.tensor |
|
self.offset = offset + tensor.offset |
|
else: |
|
self.tensor = tensor |
|
self.offset = offset |
|
self.length = length |
|
self.device = tensor.device |
|
|
|
@property |
|
def shape(self): |
|
shape = list(self.tensor.shape) |
|
shape[-1] = self.length |
|
return shape |
|
|
|
def padded(self, target_length): |
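        """Return the chunk centered in a window of `target_length` samples,
        reusing real samples from the underlying tensor where available and
        zero-padding past its boundaries."""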
|
delta = target_length - self.length |
|
total_length = self.tensor.shape[-1] |
|
assert delta >= 0 |
|
|
|
start = self.offset - delta // 2 |
|
end = start + target_length |
|
|
|
correct_start = max(0, start) |
|
correct_end = min(total_length, end) |
|
|
|
pad_left = correct_start - start |
|
pad_right = end - correct_end |
|
|
|
out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right)) |
|
assert out.shape[-1] == target_length |
|
return out |
|
|
|
|
|
def tensor_chunk(tensor_or_chunk): |
|
if isinstance(tensor_or_chunk, TensorChunk): |
|
return tensor_or_chunk |
|
else: |
|
assert isinstance(tensor_or_chunk, th.Tensor) |
|
return TensorChunk(tensor_or_chunk) |
|
|
|
|
|
def apply_model(model: tp.Union[BagOfModels, Model], |
|
mix: tp.Union[th.Tensor, TensorChunk], |
|
shifts: int = 1, split: bool = True, |
|
overlap: float = 0.25, transition_power: float = 1., |
|
progress: bool = False, device=None, |
|
num_workers: int = 0, segment: tp.Optional[float] = None, |
|
pool=None) -> th.Tensor: |
|
""" |
|
Apply model to a given mixture. |
|
|
|
Args: |
|
        shifts (int): if > 0, will shift `mix` in time by a random amount between 0
            and 0.5 sec and apply the opposite shift to the output. This is repeated
            `shifts` times and all predictions are averaged. This effectively makes
            the model time equivariant and improves SDR by up to 0.2 points.
|
        split (bool): if True, the input will be broken down into chunks of `segment`
            seconds, predictions will be performed individually on each chunk, and
            the results recombined by weighted overlap-add. Useful for models with a
            large memory footprint like Tasnet.
        overlap (float): fraction of overlap between consecutive chunks when
            `split` is True.
        transition_power (float): exponent applied to the overlap-add weights;
            values above 1 give sharper transitions between chunks.
|
progress (bool): if True, show a progress bar (requires split=True) |
|
device (torch.device, str, or None): if provided, device on which to |
|
execute the computation, otherwise `mix.device` is assumed. |
|
When `device` is different from `mix.device`, only local computations will |
|
be on `device`, while the entire tracks will be stored on `mix.device`. |
|
        num_workers (int): if non-zero and `device` is 'cpu', number of threads to
            use in parallel.
        segment (float or None): override the model segment parameter.
        pool: optional executor used to process chunks; defaults to a thread pool
            when `num_workers > 0` on CPU, and a serial dummy executor otherwise.
|
""" |
|
if device is None: |
|
device = mix.device |
|
else: |
|
device = th.device(device) |
|
if pool is None: |
|
if num_workers > 0 and device.type == 'cpu': |
|
pool = ThreadPoolExecutor(num_workers) |
|
else: |
|
pool = DummyPoolExecutor() |
|
kwargs: tp.Dict[str, tp.Any] = { |
|
'shifts': shifts, |
|
'split': split, |
|
'overlap': overlap, |
|
'transition_power': transition_power, |
|
'progress': progress, |
|
'device': device, |
|
'pool': pool, |
|
'segment': segment, |
|
} |
|
out: tp.Union[float, th.Tensor] |
|
if isinstance(model, BagOfModels): |
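        # Special treatment for a bag of models: `apply_model` is called once per
        # sub-model so that the random shifts differ across models, and the
        # per-source weighted estimates are averaged at the end.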
|
|
|
|
|
|
|
estimates: tp.Union[float, th.Tensor] = 0. |
|
totals = [0.] * len(model.sources) |
|
for sub_model, model_weights in zip(model.models, model.weights): |
|
original_model_device = next(iter(sub_model.parameters())).device |
|
sub_model.to(device) |
|
|
|
out = apply_model(sub_model, mix, **kwargs) |
|
sub_model.to(original_model_device) |
|
for k, inst_weight in enumerate(model_weights): |
|
out[:, k, :, :] *= inst_weight |
|
totals[k] += inst_weight |
|
estimates += out |
|
del out |
|
|
|
assert isinstance(estimates, th.Tensor) |
|
for k in range(estimates.shape[1]): |
|
estimates[:, k, :, :] /= totals[k] |
|
return estimates |
|
|
|
model.to(device) |
|
model.eval() |
|
assert transition_power >= 1, "transition_power < 1 leads to weird behavior." |
|
batch, channels, length = mix.shape |
|
if shifts: |
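        # Randomized equivariant stabilization: pad the input, run the model on
        # randomly shifted copies, undo each shift on the output, and average.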
|
kwargs['shifts'] = 0 |
|
max_shift = int(0.5 * model.samplerate) |
|
mix = tensor_chunk(mix) |
|
assert isinstance(mix, TensorChunk) |
|
padded_mix = mix.padded(length + 2 * max_shift) |
|
out = 0. |
|
for _ in range(shifts): |
|
offset = random.randint(0, max_shift) |
|
shifted = TensorChunk(padded_mix, offset, length + max_shift - offset) |
|
shifted_out = apply_model(model, shifted, **kwargs) |
|
out += shifted_out[..., max_shift - offset:] |
|
out /= shifts |
|
assert isinstance(out, th.Tensor) |
|
return out |
|
elif split: |
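        # Chunked inference: process overlapping segments (possibly in parallel)
        # and recombine them with a weighted overlap-add.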
|
kwargs['split'] = False |
|
out = th.zeros(batch, len(model.sources), channels, length, device=mix.device) |
|
sum_weight = th.zeros(length, device=mix.device) |
|
if segment is None: |
|
segment = model.segment |
|
assert segment is not None and segment > 0. |
|
segment_length: int = int(model.samplerate * segment) |
|
stride = int((1 - overlap) * segment_length) |
|
offsets = range(0, length, stride) |
|
scale = float(format(stride / model.samplerate, ".2f")) |
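        # Start from a triangle-shaped weight, maximal in the middle of the segment,
        # then normalize and raise to `transition_power`. Larger values of
        # `transition_power` lead to sharper transitions between chunks.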
|
|
|
|
|
|
|
weight = th.cat([th.arange(1, segment_length // 2 + 1, device=device), |
|
th.arange(segment_length - segment_length // 2, 0, -1, device=device)]) |
|
assert len(weight) == segment_length |
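        # When overlap < 50%, this gives a linear transition for transition_power == 1.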
|
|
|
|
|
weight = (weight / weight.max())**transition_power |
|
futures = [] |
|
for offset in offsets: |
|
chunk = TensorChunk(mix, offset, segment_length) |
|
future = pool.submit(apply_model, model, chunk, **kwargs) |
|
futures.append((future, offset)) |
|
|
if progress: |
|
futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds') |
|
for future, offset in futures: |
|
chunk_out = future.result() |
|
chunk_length = chunk_out.shape[-1] |
|
out[..., offset:offset + segment_length] += ( |
|
weight[:chunk_length] * chunk_out).to(mix.device) |
|
sum_weight[offset:offset + segment_length] += weight[:chunk_length].to(mix.device) |
|
assert sum_weight.min() > 0 |
|
out /= sum_weight |
|
assert isinstance(out, th.Tensor) |
|
return out |
|
else: |
|
valid_length: int |
|
if isinstance(model, HTDemucs) and segment is not None: |
|
valid_length = int(segment * model.samplerate) |
|
elif hasattr(model, 'valid_length'): |
|
valid_length = model.valid_length(length) |
|
else: |
|
valid_length = length |
|
mix = tensor_chunk(mix) |
|
assert isinstance(mix, TensorChunk) |
|
padded_mix = mix.padded(valid_length).to(device) |
|
with th.no_grad(): |
|
out = model(padded_mix) |
|
assert isinstance(out, th.Tensor) |
|
return center_trim(out, length) |
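

# Example usage (a minimal sketch; assumes a list `models` of pretrained
# HTDemucs instances loaded elsewhere, e.g. from checkpoint files):
#
#     bag = BagOfModels(models)
#     stem_paths = bag.separate("mix.wav", "separated/", device="cuda")
#
# `separate` normalizes the track, runs `apply_model` with chunked
# overlap-add inference, and writes one FLAC file per stem.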
|
|