# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""API methods for demucs
Classes
-------
`demucs.api.Separator`: The base separator class
Functions
---------
`demucs.api.save_audio`: Save an audio file
`demucs.api.list_models`: List the available models
Examples
--------
See the end of this module (if __name__ == "__main__")
"""
import subprocess
import torch as th
import torchaudio as ta
from dora.log import fatal
from pathlib import Path
from typing import Optional, Callable, Dict, Tuple, Union
from .apply import apply_model, _replace_dict
from .audio import AudioFile, convert_audio, save_audio
from .pretrained import get_model, _parse_remote_files, REMOTE_ROOT
from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo
class LoadAudioError(Exception):
pass
class LoadModelError(Exception):
pass
class _NotProvided:
pass
NotProvided = _NotProvided()
class Separator:
def __init__(
self,
model: str = "htdemucs",
repo: Optional[Path] = None,
device: str = "cuda" if th.cuda.is_available() else "cpu",
shifts: int = 1,
overlap: float = 0.25,
split: bool = True,
segment: Optional[int] = None,
jobs: int = 0,
progress: bool = False,
callback: Optional[Callable[[dict], None]] = None,
callback_arg: Optional[dict] = None,
):
"""
`class Separator`
=================
Parameters
----------
model: Pretrained model name or signature. Default is htdemucs.
repo: Folder containing all pre-trained models for use.
segment: Length (in seconds) of each segment (only available if `split` is `True`). If \
not specified, will use the command line option.
        shifts: If > 0, will shift in time `wav` by a random amount between 0 and 0.5 sec and \
            apply the opposite shift to the output. This is repeated `shifts` times and all \
            predictions are averaged. This effectively makes the model time equivariant and \
            improves SDR by up to 0.2 points. If not specified, will use the command line option.
        split: If True, the input will be broken down into small chunks (length set by \
            `segment`) and predictions will be performed individually on each chunk and then \
            concatenated. Useful for models with a large memory footprint like Tasnet. If not \
            specified, will use the command line option.
overlap: The overlap between the splits. If not specified, will use the command line \
option.
device (torch.device, str, or None): If provided, device on which to execute the \
computation, otherwise `wav.device` is assumed. When `device` is different from \
`wav.device`, only local computations will be on `device`, while the entire tracks \
will be stored on `wav.device`. If not specified, will use the command line option.
jobs: Number of jobs. This can increase memory usage but will be much faster when \
multiple cores are available. If not specified, will use the command line option.
        callback: A function that will be called when the separation of a chunk starts or \
            finishes. The argument passed to the function will be a dict. For more information, \
            please see the Callback section.
        callback_arg: A dict containing private parameters to be passed to the callback \
            function. For more information, please see the Callback section.
        progress: If True, show a progress bar.
Callback
--------
        The function will be called with a single positional argument of type `dict`. The
        `callback_arg` will be combined with information about the current separation progress.
        The progress information will override the values in `callback_arg` if the same key is
        used. To abort the separation, raise `KeyboardInterrupt`.
        Progress information contains several keys (these keys will always exist):
        - `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0.
        - `shift_idx`: The index of shifts. Starts from 0.
        - `segment_offset`: The offset of the current segment, measured in "frames" of the
          tensor, not seconds: a value of 441000 means the 441000th sample of the audio, not
          the 441000th second.
        - `state`: Could be `"start"` or `"end"`.
        - `audio_length`: Length of the audio (in "frames" of the tensor).
        - `models`: Count of submodels in the model.
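        Examples
        --------
        A minimal usage sketch; `song.mp3` is a placeholder path and the callback
        simply prints the progress dict:
        >>> def cb(info):
        ...     print(info["state"], info["segment_offset"])
        >>> separator = Separator(model="htdemucs", callback=cb)
        >>> origin, stems = separator.separate_audio_file(Path("song.mp3"))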
"""
self._name = model
self._repo = repo
self._load_model()
self.update_parameter(device=device, shifts=shifts, overlap=overlap, split=split,
segment=segment, jobs=jobs, progress=progress, callback=callback,
callback_arg=callback_arg)
def update_parameter(
self,
device: Union[str, _NotProvided] = NotProvided,
shifts: Union[int, _NotProvided] = NotProvided,
overlap: Union[float, _NotProvided] = NotProvided,
split: Union[bool, _NotProvided] = NotProvided,
segment: Optional[Union[int, _NotProvided]] = NotProvided,
jobs: Union[int, _NotProvided] = NotProvided,
progress: Union[bool, _NotProvided] = NotProvided,
callback: Optional[
Union[Callable[[dict], None], _NotProvided]
] = NotProvided,
callback_arg: Optional[Union[dict, _NotProvided]] = NotProvided,
):
"""
Update the parameters of separation.
Parameters
----------
segment: Length (in seconds) of each segment (only available if `split` is `True`). If \
not specified, will use the command line option.
        shifts: If > 0, will shift in time `wav` by a random amount between 0 and 0.5 sec and \
            apply the opposite shift to the output. This is repeated `shifts` times and all \
            predictions are averaged. This effectively makes the model time equivariant and \
            improves SDR by up to 0.2 points. If not specified, will use the command line option.
        split: If True, the input will be broken down into small chunks (length set by \
            `segment`) and predictions will be performed individually on each chunk and then \
            concatenated. Useful for models with a large memory footprint like Tasnet. If not \
            specified, will use the command line option.
overlap: The overlap between the splits. If not specified, will use the command line \
option.
device (torch.device, str, or None): If provided, device on which to execute the \
computation, otherwise `wav.device` is assumed. When `device` is different from \
`wav.device`, only local computations will be on `device`, while the entire tracks \
will be stored on `wav.device`. If not specified, will use the command line option.
jobs: Number of jobs. This can increase memory usage but will be much faster when \
multiple cores are available. If not specified, will use the command line option.
        callback: A function that will be called when the separation of a chunk starts or \
            finishes. The argument passed to the function will be a dict. For more information, \
            please see the Callback section.
        callback_arg: A dict containing private parameters to be passed to the callback \
            function. For more information, please see the Callback section.
        progress: If True, show a progress bar.
Callback
--------
        The function will be called with a single positional argument of type `dict`. The
        `callback_arg` will be combined with information about the current separation progress.
        The progress information will override the values in `callback_arg` if the same key is
        used. To abort the separation, raise `KeyboardInterrupt`.
        Progress information contains several keys (these keys will always exist):
        - `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0.
        - `shift_idx`: The index of shifts. Starts from 0.
        - `segment_offset`: The offset of the current segment, measured in "frames" of the
          tensor, not seconds: a value of 441000 means the 441000th sample of the audio, not
          the 441000th second.
        - `state`: Could be `"start"` or `"end"`.
        - `audio_length`: Length of the audio (in "frames" of the tensor).
        - `models`: Count of submodels in the model.
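        Examples
        --------
        A minimal sketch: reuse a loaded separator with different settings, e.g. more
        shifts and a larger overlap, without reloading the model:
        >>> separator.update_parameter(shifts=2, overlap=0.5)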
"""
if not isinstance(device, _NotProvided):
self._device = device
if not isinstance(shifts, _NotProvided):
self._shifts = shifts
if not isinstance(overlap, _NotProvided):
self._overlap = overlap
if not isinstance(split, _NotProvided):
self._split = split
if not isinstance(segment, _NotProvided):
self._segment = segment
if not isinstance(jobs, _NotProvided):
self._jobs = jobs
if not isinstance(progress, _NotProvided):
self._progress = progress
if not isinstance(callback, _NotProvided):
self._callback = callback
if not isinstance(callback_arg, _NotProvided):
self._callback_arg = callback_arg
def _load_model(self):
self._model = get_model(name=self._name, repo=self._repo)
if self._model is None:
raise LoadModelError("Failed to load model")
self._audio_channels = self._model.audio_channels
self._samplerate = self._model.samplerate
def _load_audio(self, track: Path):
errors = {}
wav = None
try:
wav = AudioFile(track).read(streams=0, samplerate=self._samplerate,
channels=self._audio_channels)
except FileNotFoundError:
errors["ffmpeg"] = "FFmpeg is not installed."
except subprocess.CalledProcessError:
errors["ffmpeg"] = "FFmpeg could not read the file."
if wav is None:
try:
wav, sr = ta.load(str(track))
except RuntimeError as err:
errors["torchaudio"] = err.args[0]
else:
wav = convert_audio(wav, sr, self._samplerate, self._audio_channels)
if wav is None:
raise LoadAudioError(
"\n".join(
"When trying to load using {}, got the following error: {}".format(
backend, error
)
for backend, error in errors.items()
)
)
return wav
def separate_tensor(
self, wav: th.Tensor, sr: Optional[int] = None
) -> Tuple[th.Tensor, Dict[str, th.Tensor]]:
"""
Separate a loaded tensor.
Parameters
----------
        wav: Waveform of the audio. Should have 2 dimensions: the first is each audio channel \
            and the second is the waveform of each channel. Type should be float32. \
            e.g. `tuple(wav.shape) == (2, 884000)` means the audio has 2 channels.
        sr: Sample rate of the original audio; the waveform will be resampled if it doesn't \
            match that of the model.
Returns
-------
        A tuple whose first element is the original waveform and whose second element is a dict
        mapping stem names to the separated waveforms. The original waveform will have already
        been resampled.
Notes
-----
        Use this function with caution: it does not verify the input data.
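        Examples
        --------
        A minimal sketch; `song.mp3` is a placeholder path:
        >>> import torchaudio as ta
        >>> wav, sr = ta.load("song.mp3")
        >>> origin, stems = separator.separate_tensor(wav, sr)
        >>> sorted(stems) == sorted(separator.model.sources)
        True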
"""
if sr is not None and sr != self.samplerate:
wav = convert_audio(wav, sr, self._samplerate, self._audio_channels)
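        # Normalize the input using the mono mixture as a reference; the inverse
        # scaling is applied to both the stems and the mixture after separation.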
ref = wav.mean(0)
wav -= ref.mean()
wav /= ref.std() + 1e-8
out = apply_model(
self._model,
wav[None],
segment=self._segment,
shifts=self._shifts,
split=self._split,
overlap=self._overlap,
device=self._device,
num_workers=self._jobs,
callback=self._callback,
callback_arg=_replace_dict(
self._callback_arg, ("audio_length", wav.shape[1])
),
progress=self._progress,
)
if out is None:
raise KeyboardInterrupt
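        # Undo the normalization applied above, for the stems and the mixture alike.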
out *= ref.std() + 1e-8
out += ref.mean()
wav *= ref.std() + 1e-8
wav += ref.mean()
return (wav, dict(zip(self._model.sources, out[0])))
    def separate_audio_file(self, file: Path) -> Tuple[th.Tensor, Dict[str, th.Tensor]]:
"""
Separate an audio file. The method will automatically read the file.
Parameters
----------
        file: Path of the file to be separated.
Returns
-------
        A tuple whose first element is the original waveform and whose second element is a dict
        mapping stem names to the separated waveforms. The original waveform will have already
        been resampled.
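        Examples
        --------
        A minimal sketch; `song.mp3` is a placeholder path:
        >>> origin, stems = separator.separate_audio_file(Path("song.mp3"))
        >>> list(stems)  # e.g. ['drums', 'bass', 'other', 'vocals'] for htdemucs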
"""
return self.separate_tensor(self._load_audio(file), self.samplerate)
@property
def samplerate(self):
return self._samplerate
@property
def audio_channels(self):
return self._audio_channels
@property
def model(self):
return self._model
def list_models(repo: Optional[Path] = None) -> Dict[str, Dict[str, Union[str, Path]]]:
"""
List the available models. Please remember that not all the returned models can be
successfully loaded.
Parameters
----------
repo: The repo whose models are to be listed.
Returns
-------
    A dict with two keys ("single" for single models and "bag" for bags of models). Each value
    is a dict mapping model names to their remote URL or local path.
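    Examples
    --------
    A minimal sketch (with `repo=None`, this queries the remote model index):
    >>> models = list_models()
    >>> sorted(models)
    ['bag', 'single']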
"""
model_repo: ModelOnlyRepo
if repo is None:
models = _parse_remote_files(REMOTE_ROOT / 'files.txt')
model_repo = RemoteRepo(models)
bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
else:
if not repo.is_dir():
fatal(f"{repo} must exist and be a directory.")
model_repo = LocalRepo(repo)
bag_repo = BagOnlyRepo(repo, model_repo)
return {"single": model_repo.list_model(), "bag": bag_repo.list_model()}
if __name__ == "__main__":
# Test API functions
# two-stem not supported
from .separate import get_parser
args = get_parser().parse_args()
separator = Separator(
model=args.name,
repo=args.repo,
device=args.device,
shifts=args.shifts,
overlap=args.overlap,
split=args.split,
segment=args.segment,
jobs=args.jobs,
callback=print
)
out = args.out / args.name
out.mkdir(parents=True, exist_ok=True)
for file in args.tracks:
separated = separator.separate_audio_file(file)[1]
if args.mp3:
ext = "mp3"
elif args.flac:
ext = "flac"
else:
ext = "wav"
kwargs = {
"samplerate": separator.samplerate,
"bitrate": args.mp3_bitrate,
"clip": args.clip_mode,
"as_float": args.float32,
"bits_per_sample": 24 if args.int24 else 16,
}
for stem, source in separated.items():
stem = out / args.filename.format(
track=Path(file).name.rsplit(".", 1)[0],
trackext=Path(file).name.rsplit(".", 1)[-1],
stem=stem,
ext=ext,
)
stem.parent.mkdir(parents=True, exist_ok=True)
save_audio(source, str(stem), **kwargs)