waytan22's picture
Upload folder using huggingface_hub
e730386 verified
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@File : apply.py
@Time : 2023/8/8 下午4:22
@Author : waytan
@Contact : [email protected]
@License : (C)Copyright 2023, Tencent
@Desc : Apply
"""
from concurrent.futures import ThreadPoolExecutor
import torch
import os
import random
import typing as tp
import torch as th
from torch import nn
from torch.nn import functional as F
import tqdm
from .htdemucs import HTDemucs
from .audio import load_track, save_audio
from .utils import center_trim, DummyPoolExecutor
Model = tp.Union[HTDemucs]
class BagOfModels(nn.Module):
def __init__(self, models: tp.List[Model],
weights: tp.Optional[tp.List[tp.List[float]]] = None,
segment: tp.Optional[float] = None):
"""
Represents a bag of models with specific weights.
You should call `apply_model` rather than calling directly the forward here for
optimal performance.
Args:
models (list[nn.Module]): list of Demucs/HDemucs models.
weights (list[list[float]]): list of weights. If None, assumed to
be all ones, otherwise it should be a list of N list (N number of models),
each containing S floats (S number of sources).
segment (None or float): overrides the `segment` attribute of each model
(this is performed inplace, be careful is you reuse the models passed).
"""
super().__init__()
assert len(models) > 0
first = models[0]
for other in models:
assert other.sources == first.sources
assert other.samplerate == first.samplerate
assert other.audio_channels == first.audio_channels
if segment is not None:
other.segment = segment
self.audio_channels = first.audio_channels
self.samplerate = first.samplerate
self.sources = first.sources
self.models = nn.ModuleList(models)
if weights is None:
weights = [[1. for _ in first.sources] for _ in models]
else:
assert len(weights) == len(models)
for weight in weights:
assert len(weight) == len(first.sources)
self.weights = weights
@property
def max_allowed_segment(self) -> float:
max_allowed_segment = float('inf')
for model in self.models:
if isinstance(model, HTDemucs):
max_allowed_segment = min(max_allowed_segment, float(model.segment))
return max_allowed_segment
def forward(self, x):
raise NotImplementedError("Call `apply_model` on this.")
def separate(self, source_file, output_dir, stem=None, device=None):
wav, _ = load_track(source_file, self.audio_channels, self.samplerate)
ref = wav.mean(0)
wav -= ref.mean()
wav /= ref.std()
sources = apply_model(self, wav[None], device=device, shifts=1, split=True, overlap=0.25,
progress=True, num_workers=0, segment=None)[0]
sources *= ref.std()
sources += ref.mean()
output_paths = []
name, ext = os.path.splitext(os.path.split(source_file)[-1])
if ext != ".flac":
ext = ".flac"
kwargs = {
'samplerate': self.samplerate,
'bitrate': 320,
'clip': "rescale",
'as_float': False,
'bits_per_sample': 16,
}
if stem is None:
for source, stem in zip(sources, self.sources):
output_stem_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
save_audio(source, output_stem_path, **kwargs)
output_paths.append(output_stem_path)
else:
sources = list(sources)
output_stem_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
save_audio(sources.pop(self.sources.index(stem)), output_stem_path, **kwargs)
other_stem = torch.zeros_like(sources[0])
for i in sources:
other_stem += i
output_no_stem_path = os.path.join(output_dir, f"{name}_no_{stem}{ext}")
save_audio(other_stem, output_no_stem_path, **kwargs)
output_paths = [output_stem_path, output_no_stem_path]
return output_paths
class TensorChunk:
def __init__(self, tensor, offset=0, length=None):
total_length = tensor.shape[-1]
assert offset >= 0
assert offset < total_length
if length is None:
length = total_length - offset
else:
length = min(total_length - offset, length)
if isinstance(tensor, TensorChunk):
self.tensor = tensor.tensor
self.offset = offset + tensor.offset
else:
self.tensor = tensor
self.offset = offset
self.length = length
self.device = tensor.device
@property
def shape(self):
shape = list(self.tensor.shape)
shape[-1] = self.length
return shape
def padded(self, target_length):
delta = target_length - self.length
total_length = self.tensor.shape[-1]
assert delta >= 0
start = self.offset - delta // 2
end = start + target_length
correct_start = max(0, start)
correct_end = min(total_length, end)
pad_left = correct_start - start
pad_right = end - correct_end
out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
assert out.shape[-1] == target_length
return out
def tensor_chunk(tensor_or_chunk):
if isinstance(tensor_or_chunk, TensorChunk):
return tensor_or_chunk
else:
assert isinstance(tensor_or_chunk, th.Tensor)
return TensorChunk(tensor_or_chunk)
def apply_model(model: tp.Union[BagOfModels, Model],
mix: tp.Union[th.Tensor, TensorChunk],
shifts: int = 1, split: bool = True,
overlap: float = 0.25, transition_power: float = 1.,
progress: bool = False, device=None,
num_workers: int = 0, segment: tp.Optional[float] = None,
pool=None) -> th.Tensor:
"""
Apply model to a given mixture.
Args:
shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
and apply the oppositve shift to the output. This is repeated `shifts` time and
all predictions are averaged. This effectively makes the model time equivariant
and improves SDR by up to 0.2 points.
split (bool): if True, the input will be broken down in 8 seconds extracts
and predictions will be performed individually on each and concatenated.
Useful for model with large memory footprint like Tasnet.
progress (bool): if True, show a progress bar (requires split=True)
device (torch.device, str, or None): if provided, device on which to
execute the computation, otherwise `mix.device` is assumed.
When `device` is different from `mix.device`, only local computations will
be on `device`, while the entire tracks will be stored on `mix.device`.
num_workers (int): if non zero, device is 'cpu', how many threads to
use in parallel.
segment (float or None): override the model segment parameter.
"""
if device is None:
device = mix.device
else:
device = th.device(device)
if pool is None:
if num_workers > 0 and device.type == 'cpu':
pool = ThreadPoolExecutor(num_workers)
else:
pool = DummyPoolExecutor()
kwargs: tp.Dict[str, tp.Any] = {
'shifts': shifts,
'split': split,
'overlap': overlap,
'transition_power': transition_power,
'progress': progress,
'device': device,
'pool': pool,
'segment': segment,
}
out: tp.Union[float, th.Tensor]
if isinstance(model, BagOfModels):
# Special treatment for bag of model.
# We explicitely apply multiple times `apply_model` so that the random shifts
# are different for each model.
estimates: tp.Union[float, th.Tensor] = 0.
totals = [0.] * len(model.sources)
for sub_model, model_weights in zip(model.models, model.weights):
original_model_device = next(iter(sub_model.parameters())).device
sub_model.to(device)
out = apply_model(sub_model, mix, **kwargs)
sub_model.to(original_model_device)
for k, inst_weight in enumerate(model_weights):
out[:, k, :, :] *= inst_weight
totals[k] += inst_weight
estimates += out
del out
assert isinstance(estimates, th.Tensor)
for k in range(estimates.shape[1]):
estimates[:, k, :, :] /= totals[k]
return estimates
model.to(device)
model.eval()
assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
batch, channels, length = mix.shape
if shifts:
kwargs['shifts'] = 0
max_shift = int(0.5 * model.samplerate)
mix = tensor_chunk(mix)
assert isinstance(mix, TensorChunk)
padded_mix = mix.padded(length + 2 * max_shift)
out = 0.
for _ in range(shifts):
offset = random.randint(0, max_shift)
shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
shifted_out = apply_model(model, shifted, **kwargs)
out += shifted_out[..., max_shift - offset:]
out /= shifts
assert isinstance(out, th.Tensor)
return out
elif split:
kwargs['split'] = False
out = th.zeros(batch, len(model.sources), channels, length, device=mix.device)
sum_weight = th.zeros(length, device=mix.device)
if segment is None:
segment = model.segment
assert segment is not None and segment > 0.
segment_length: int = int(model.samplerate * segment)
stride = int((1 - overlap) * segment_length)
offsets = range(0, length, stride)
scale = float(format(stride / model.samplerate, ".2f"))
# We start from a triangle shaped weight, with maximal weight in the middle
# of the segment. Then we normalize and take to the power `transition_power`.
# Large values of transition power will lead to sharper transitions.
weight = th.cat([th.arange(1, segment_length // 2 + 1, device=device),
th.arange(segment_length - segment_length // 2, 0, -1, device=device)])
assert len(weight) == segment_length
# If the overlap < 50%, this will translate to linear transition when
# transition_power is 1.
weight = (weight / weight.max())**transition_power
futures = []
for offset in offsets:
chunk = TensorChunk(mix, offset, segment_length)
future = pool.submit(apply_model, model, chunk, **kwargs)
futures.append((future, offset))
offset += segment_length
if progress:
futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds')
for future, offset in futures:
chunk_out = future.result()
chunk_length = chunk_out.shape[-1]
out[..., offset:offset + segment_length] += (
weight[:chunk_length] * chunk_out).to(mix.device)
sum_weight[offset:offset + segment_length] += weight[:chunk_length].to(mix.device)
assert sum_weight.min() > 0
out /= sum_weight
assert isinstance(out, th.Tensor)
return out
else:
valid_length: int
if isinstance(model, HTDemucs) and segment is not None:
valid_length = int(segment * model.samplerate)
elif hasattr(model, 'valid_length'):
valid_length = model.valid_length(length) # type: ignore
else:
valid_length = length
mix = tensor_chunk(mix)
assert isinstance(mix, TensorChunk)
padded_mix = mix.padded(valid_length).to(device)
with th.no_grad():
out = model(padded_mix)
assert isinstance(out, th.Tensor)
return center_trim(out, length)