# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""API methods for demucs



Classes

-------

`demucs.api.Separator`: The base separator class



Functions

---------

`demucs.api.save_audio`: Save an audio

`demucs.api.list_models`: Get models list



Examples

--------

See the end of this module (if __name__ == "__main__")
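
A minimal usage sketch ("track.mp3" is a hypothetical input file):

>>> from demucs.api import Separator, save_audio
>>> separator = Separator(model="htdemucs")
>>> origin, separated = separator.separate_audio_file("track.mp3")
>>> for stem, source in separated.items():
...     save_audio(source, f"{stem}.wav", samplerate=separator.samplerate)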

"""

import subprocess

import torch as th
import torchaudio as ta

from dora.log import fatal
from pathlib import Path
from typing import Optional, Callable, Dict, Tuple, Union

from .apply import apply_model, _replace_dict
from .audio import AudioFile, convert_audio, save_audio
from .pretrained import get_model, _parse_remote_files, REMOTE_ROOT
from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo


class LoadAudioError(Exception):
    pass


class LoadModelError(Exception):
    pass


class _NotProvided:
    pass


NotProvided = _NotProvided()


class Separator:
    def __init__(
        self,
        model: str = "htdemucs",
        repo: Optional[Path] = None,
        device: str = "cuda" if th.cuda.is_available() else "cpu",
        shifts: int = 1,
        overlap: float = 0.25,
        split: bool = True,
        segment: Optional[int] = None,
        jobs: int = 0,
        progress: bool = False,
        callback: Optional[Callable[[dict], None]] = None,
        callback_arg: Optional[dict] = None,
    ):
        """

        `class Separator`

        =================



        Parameters

        ----------

        model: Pretrained model name or signature. Default is htdemucs.

        repo: Folder containing all pre-trained models for use.

        segment: Length (in seconds) of each segment (only available if `split` is `True`). If \

            not specified, will use the command line option.

        shifts: If > 0, will shift in time `wav` by a random amount between 0 and 0.5 sec and \

            apply the oppositve shift to the output. This is repeated `shifts` time and all \

            predictions are averaged. This effectively makes the model time equivariant and \

            improves SDR by up to 0.2 points. If not specified, will use the command line option.

        split: If True, the input will be broken down into small chunks (length set by `segment`) \

            and predictions will be performed individually on each and concatenated. Useful for \

            model with large memory footprint like Tasnet. If not specified, will use the command \

            line option.

        overlap: The overlap between the splits. If not specified, will use the command line \

            option.

        device (torch.device, str, or None): If provided, device on which to execute the \

            computation, otherwise `wav.device` is assumed. When `device` is different from \

            `wav.device`, only local computations will be on `device`, while the entire tracks \

            will be stored on `wav.device`. If not specified, will use the command line option.

        jobs: Number of jobs. This can increase memory usage but will be much faster when \

            multiple cores are available. If not specified, will use the command line option.

        callback: A function will be called when the separation of a chunk starts or finished. \

            The argument passed to the function will be a dict. For more information, please see \

            the Callback section.

        callback_arg: A dict containing private parameters to be passed to callback function. For \

            more information, please see the Callback section.

        progress: If true, show a progress bar.



        Callback

        --------

        The function will be called with only one positional parameter whose type is `dict`. The

        `callback_arg` will be combined with information of current separation progress. The

        progress information will override the values in `callback_arg` if same key has been used.

        To abort the separation, raise `KeyboardInterrupt`.



        Progress information contains several keys (These keys will always exist):

        - `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0.

        - `shift_idx`: The index of shifts. Starts from 0.

        - `segment_offset`: The offset of current segment. If the number is 441000, it doesn't

            mean that it is at the 441000 second of the audio, but the "frame" of the tensor.

        - `state`: Could be `"start"` or `"end"`.

        - `audio_length`: Length of the audio (in "frame" of the tensor).

        - `models`: Count of submodels in the model.
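
        Examples
        --------
        A minimal callback sketch using the keys documented above (the printed
        message is illustrative only):

        >>> def report(info: dict):
        ...     if info["state"] == "end":
        ...         done = info["segment_offset"]
        ...         total = info["audio_length"]
        ...         print(f"finished a chunk at frame {done} of {total}")
        >>> separator = Separator(model="htdemucs", callback=report)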

        """
        self._name = model
        self._repo = repo
        self._load_model()
        self.update_parameter(device=device, shifts=shifts, overlap=overlap, split=split,
                              segment=segment, jobs=jobs, progress=progress, callback=callback,
                              callback_arg=callback_arg)

    def update_parameter(
        self,
        device: Union[str, _NotProvided] = NotProvided,
        shifts: Union[int, _NotProvided] = NotProvided,
        overlap: Union[float, _NotProvided] = NotProvided,
        split: Union[bool, _NotProvided] = NotProvided,
        segment: Optional[Union[int, _NotProvided]] = NotProvided,
        jobs: Union[int, _NotProvided] = NotProvided,
        progress: Union[bool, _NotProvided] = NotProvided,
        callback: Optional[Union[Callable[[dict], None], _NotProvided]] = NotProvided,
        callback_arg: Optional[Union[dict, _NotProvided]] = NotProvided,
    ):
        """

        Update the parameters of separation.



        Parameters

        ----------

        segment: Length (in seconds) of each segment (only available if `split` is `True`). If \

            not specified, will use the command line option.

        shifts: If > 0, will shift in time `wav` by a random amount between 0 and 0.5 sec and \

            apply the oppositve shift to the output. This is repeated `shifts` time and all \

            predictions are averaged. This effectively makes the model time equivariant and \

            improves SDR by up to 0.2 points. If not specified, will use the command line option.

        split: If True, the input will be broken down into small chunks (length set by `segment`) \

            and predictions will be performed individually on each and concatenated. Useful for \

            model with large memory footprint like Tasnet. If not specified, will use the command \

            line option.

        overlap: The overlap between the splits. If not specified, will use the command line \

            option.

        device (torch.device, str, or None): If provided, device on which to execute the \

            computation, otherwise `wav.device` is assumed. When `device` is different from \

            `wav.device`, only local computations will be on `device`, while the entire tracks \

            will be stored on `wav.device`. If not specified, will use the command line option.

        jobs: Number of jobs. This can increase memory usage but will be much faster when \

            multiple cores are available. If not specified, will use the command line option.

        callback: A function will be called when the separation of a chunk starts or finished. \

            The argument passed to the function will be a dict. For more information, please see \

            the Callback section.

        callback_arg: A dict containing private parameters to be passed to callback function. For \

            more information, please see the Callback section.

        progress: If true, show a progress bar.



        Callback

        --------

        The function will be called with only one positional parameter whose type is `dict`. The

        `callback_arg` will be combined with information of current separation progress. The

        progress information will override the values in `callback_arg` if same key has been used.

        To abort the separation, raise `KeyboardInterrupt`.



        Progress information contains several keys (These keys will always exist):

        - `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0.

        - `shift_idx`: The index of shifts. Starts from 0.

        - `segment_offset`: The offset of current segment. If the number is 441000, it doesn't

            mean that it is at the 441000 second of the audio, but the "frame" of the tensor.

        - `state`: Could be `"start"` or `"end"`.

        - `audio_length`: Length of the audio (in "frame" of the tensor).

        - `models`: Count of submodels in the model.
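
        Examples
        --------
        A quick sketch, assuming `separator` is an existing instance (the
        values are illustrative only):

        >>> separator.update_parameter(shifts=2, overlap=0.5, progress=True)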

        """
        if not isinstance(device, _NotProvided):
            self._device = device
        if not isinstance(shifts, _NotProvided):
            self._shifts = shifts
        if not isinstance(overlap, _NotProvided):
            self._overlap = overlap
        if not isinstance(split, _NotProvided):
            self._split = split
        if not isinstance(segment, _NotProvided):
            self._segment = segment
        if not isinstance(jobs, _NotProvided):
            self._jobs = jobs
        if not isinstance(progress, _NotProvided):
            self._progress = progress
        if not isinstance(callback, _NotProvided):
            self._callback = callback
        if not isinstance(callback_arg, _NotProvided):
            self._callback_arg = callback_arg

    def _load_model(self):
        self._model = get_model(name=self._name, repo=self._repo)
        if self._model is None:
            raise LoadModelError("Failed to load model")
        self._audio_channels = self._model.audio_channels
        self._samplerate = self._model.samplerate

    def _load_audio(self, track: Path):
        errors = {}
        wav = None

        # Try FFmpeg (through AudioFile) first; if it is missing or cannot
        # decode the file, fall back to torchaudio below.
        try:
            wav = AudioFile(track).read(streams=0, samplerate=self._samplerate,
                                        channels=self._audio_channels)
        except FileNotFoundError:
            errors["ffmpeg"] = "FFmpeg is not installed."
        except subprocess.CalledProcessError:
            errors["ffmpeg"] = "FFmpeg could not read the file."

        if wav is None:
            try:
                wav, sr = ta.load(str(track))
            except RuntimeError as err:
                errors["torchaudio"] = err.args[0]
            else:
                wav = convert_audio(wav, sr, self._samplerate, self._audio_channels)

        if wav is None:
            raise LoadAudioError(
                "\n".join(
                    "When trying to load using {}, got the following error: {}".format(
                        backend, error
                    )
                    for backend, error in errors.items()
                )
            )
        return wav

    def separate_tensor(
        self, wav: th.Tensor, sr: Optional[int] = None
    ) -> Tuple[th.Tensor, Dict[str, th.Tensor]]:
        """

        Separate a loaded tensor.



        Parameters

        ----------

        wav: Waveform of the audio. Should have 2 dimensions, the first is each audio channel, \

            while the second is the waveform of each channel. Type should be float32. \

            e.g. `tuple(wav.shape) == (2, 884000)` means the audio has 2 channels.

        sr: Sample rate of the original audio, the wave will be resampled if it doesn't match the \

            model.



        Returns

        -------

        A tuple, whose first element is the original wave and second element is a dict, whose keys

        are the name of stems and values are separated waves. The original wave will have already

        been resampled.



        Notes

        -----

        Use this function with cautiousness. This function does not provide data verifying.
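
        Examples
        --------
        A minimal sketch, assuming `separator` is an existing instance and
        `mixture.wav` is an existing stereo file:

        >>> import torchaudio as ta
        >>> wav, sr = ta.load("mixture.wav")
        >>> origin, separated = separator.separate_tensor(wav, sr)
        >>> stems = sorted(separated)  # e.g. ['bass', 'drums', 'other', 'vocals']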

        """
        if sr is not None and sr != self.samplerate:
            wav = convert_audio(wav, sr, self._samplerate, self._audio_channels)
        # Normalize to zero mean and unit std before applying the model; the
        # normalization is undone on both the output and the input below.
        ref = wav.mean(0)
        wav -= ref.mean()
        wav /= ref.std() + 1e-8
        out = apply_model(
            self._model,
            wav[None],
            segment=self._segment,
            shifts=self._shifts,
            split=self._split,
            overlap=self._overlap,
            device=self._device,
            num_workers=self._jobs,
            callback=self._callback,
            callback_arg=_replace_dict(
                self._callback_arg, ("audio_length", wav.shape[1])
            ),
            progress=self._progress,
        )
        if out is None:
            # apply_model returns None when the separation was aborted (e.g.
            # by the callback), so re-raise the interrupt here.
            raise KeyboardInterrupt
        out *= ref.std() + 1e-8
        out += ref.mean()
        wav *= ref.std() + 1e-8
        wav += ref.mean()
        return (wav, dict(zip(self._model.sources, out[0])))

    def separate_audio_file(self, file: Path):
        """

        Separate an audio file. The method will automatically read the file.



        Parameters

        ----------

        wav: Path of the file to be separated.



        Returns

        -------

        A tuple, whose first element is the original wave and second element is a dict, whose keys

        are the name of stems and values are separated waves. The original wave will have already

        been resampled.
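
        Examples
        --------
        A quick sketch, assuming `separator` is an existing instance
        (`track.mp3` is a hypothetical input file):

        >>> origin, separated = separator.separate_audio_file(Path("track.mp3"))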

        """
        return self.separate_tensor(self._load_audio(file), self.samplerate)

    @property
    def samplerate(self):
        return self._samplerate

    @property
    def audio_channels(self):
        return self._audio_channels

    @property
    def model(self):
        return self._model


def list_models(repo: Optional[Path] = None) -> Dict[str, Dict[str, Union[str, Path]]]:
    """

    List the available models. Please remember that not all the returned models can be

    successfully loaded.



    Parameters

    ----------

    repo: The repo whose models are to be listed.



    Returns

    -------

    A dict with two keys ("single" for single models and "bag" for bag of models). The values are

    lists whose components are strs.
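
    Examples
    --------
    A quick sketch listing the models available in the default remote repo:

    >>> models = list_models()
    >>> singles = sorted(models["single"])  # model names; values are URLs or paths
    >>> bags = sorted(models["bag"])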

    """
    model_repo: ModelOnlyRepo
    if repo is None:
        models = _parse_remote_files(REMOTE_ROOT / 'files.txt')
        model_repo = RemoteRepo(models)
        bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
    else:
        if not repo.is_dir():
            fatal(f"{repo} must exist and be a directory.")
        model_repo = LocalRepo(repo)
        bag_repo = BagOnlyRepo(repo, model_repo)
    return {"single": model_repo.list_model(), "bag": bag_repo.list_model()}


if __name__ == "__main__":
    # Test API functions
    # two-stem not supported

    from .separate import get_parser

    args = get_parser().parse_args()
    separator = Separator(
        model=args.name,
        repo=args.repo,
        device=args.device,
        shifts=args.shifts,
        overlap=args.overlap,
        split=args.split,
        segment=args.segment,
        jobs=args.jobs,
        callback=print
    )
    out = args.out / args.name
    out.mkdir(parents=True, exist_ok=True)
    for file in args.tracks:
        separated = separator.separate_audio_file(file)[1]
        if args.mp3:
            ext = "mp3"
        elif args.flac:
            ext = "flac"
        else:
            ext = "wav"
        kwargs = {
            "samplerate": separator.samplerate,
            "bitrate": args.mp3_bitrate,
            "clip": args.clip_mode,
            "as_float": args.float32,
            "bits_per_sample": 24 if args.int24 else 16,
        }
        for stem, source in separated.items():
            stem = out / args.filename.format(
                track=Path(file).name.rsplit(".", 1)[0],
                trackext=Path(file).name.rsplit(".", 1)[-1],
                stem=stem,
                ext=ext,
            )
            stem.parent.mkdir(parents=True, exist_ok=True)
            save_audio(source, str(stem), **kwargs)