|
|
|
|
|
""" |
|
@File : audio.py |
|
@Time : 2023/8/8 下午7:18 |
|
@Author : waytan |
|
@Contact : [email protected] |
|
@License : (C)Copyright 2023, Tencent |
|
@Desc : Audio |
|
""" |
|
import json |
|
import subprocess as sp |
|
import typing as tp |
|
from pathlib import Path |
|
|
|
import lameenc |
|
import julius |
|
import torch |
|
import numpy as np |
|
import torchaudio as ta |
|
|
|
from .utils import temp_filenames |
|
|
|
|
|
def _read_info(path):
    """Probe ``path`` with ``ffprobe`` and return its report as a dict.

    The report is requested in JSON form with both the container level
    (``format``) and per-stream (``streams``) sections.

    Raises:
        subprocess.CalledProcessError: if ``ffprobe`` exits with a non-zero status.
    """
    probe_command = [
        'ffprobe', "-loglevel", "panic",
        str(path), '-print_format', 'json', '-show_format', '-show_streams',
    ]
    raw_report = sp.check_output(probe_command)
    report_text = raw_report.decode('utf-8')
    return json.loads(report_text)
|
|
|
|
|
class AudioFile:
    """
    Allows to read audio from any format supported by ffmpeg, as well as resampling or
    converting to mono on the fly. See :meth:`read` for more details.

    The ffprobe report for the file is fetched lazily, on first access to
    :attr:`info`, and cached for the lifetime of the instance.
    """
    def __init__(self, path: Path):
        self.path = Path(path)
        # Cached ffprobe report, filled in by the `info` property on first use.
        self._info = None

    def __repr__(self):
        features = [("path", self.path)]
        features.append(("samplerate", self.samplerate()))
        features.append(("channels", self.channels()))
        features.append(("streams", len(self)))
        features_str = ", ".join(f"{name}={value}" for name, value in features)
        return f"AudioFile({features_str})"

    @property
    def info(self):
        """Parsed ffprobe report for the file (probed once, then cached)."""
        if self._info is None:
            self._info = _read_info(self.path)
        return self._info

    @property
    def duration(self):
        """Duration in seconds, as reported in the container `format` section."""
        return float(self.info['format']['duration'])

    @property
    def _audio_streams(self):
        """Indexes (within the ffprobe `streams` list) of the audio streams."""
        return [
            index for index, stream in enumerate(self.info["streams"])
            if stream["codec_type"] == "audio"
        ]

    def __len__(self):
        """Number of audio streams in the file."""
        return len(self._audio_streams)

    def channels(self, stream=0):
        """Number of channels of the `stream`-th audio stream."""
        return int(self.info['streams'][self._audio_streams[stream]]['channels'])

    def samplerate(self, stream=0):
        """Sample rate (Hz) of the `stream`-th audio stream."""
        return int(self.info['streams'][self._audio_streams[stream]]['sample_rate'])

    def read(self,
             seek_time=None,
             duration=None,
             streams=slice(None),
             samplerate=None,
             channels=None):
        """
        Slightly more efficient implementation than stempeg,
        in particular, this will extract all stems at once
        rather than having to loop over one file multiple times
        for each stream.

        Args:
            seek_time (float): seek time in seconds or None if no seeking is needed.
            duration (float): duration in seconds to extract or None to extract until the end.
            streams (slice, int or list): streams to extract, can be a single int, a list or
                a slice. If it is a slice or list, the output will be of size [S, C, T]
                with S the number of streams, C the number of channels and T the number of samples.
                If it is an int, the output will be [C, T].
            samplerate (int): if provided, will resample on the fly. If None, no resampling will
                be done. Original sampling rate can be obtained with :meth:`samplerate`.
            channels (int): if 1, will convert to mono. We do not rely on ffmpeg for that
                as ffmpeg automatically scale by +3dB to conserve volume when playing on speakers.
                See https://sound.stackexchange.com/a/42710.
                Our definition of mono is simply the average of the two channels. Other
                values are handled by `convert_audio_channels`.

        Returns:
            torch.Tensor: float32 audio of shape [S, C, T], or [C, T] when `streams` is an int.

        Raises:
            subprocess.CalledProcessError: if ffprobe or ffmpeg exits with a non-zero status.
        """
        # Indexing an array of stream ids lets ints, lists and slices share one code path;
        # an int index yields a numpy scalar rather than an ndarray.
        streams = np.array(range(len(self)))[streams]
        single = not isinstance(streams, np.ndarray)
        if single:
            streams = [streams]

        if duration is None:
            target_size = None
            query_duration = None
        else:
            rate = samplerate or self.samplerate()
            target_size = int(rate * duration)
            # Ask ffmpeg for one extra sample; the result is trimmed to `target_size` below.
            query_duration = float((target_size + 1) / rate)

        # NOTE(review): `temp_filenames` presumably yields one temporary path per stream and
        # removes them on exit -- confirm against `.utils`.
        with temp_filenames(len(streams)) as filenames:
            command = ['ffmpeg', '-y']
            command += ['-loglevel', 'panic']
            if seek_time:
                command += ['-ss', str(seek_time)]
            command += ['-i', str(self.path)]
            # One raw float32 output per requested stream, all produced by a single ffmpeg run.
            for stream, filename in zip(streams, filenames):
                command += ['-map', f'0:{self._audio_streams[stream]}']
                if query_duration is not None:
                    command += ['-t', str(query_duration)]
                command += ['-threads', '1']
                command += ['-f', 'f32le']
                if samplerate is not None:
                    command += ['-ar', str(samplerate)]
                command += [filename]

            sp.run(command, check=True)
            wavs = []
            for stream, filename in zip(streams, filenames):
                wav = np.fromfile(filename, dtype=np.float32)
                wav = torch.from_numpy(wav)
                # Samples are interleaved by channel: de-interleave with the channel count of
                # *this* stream (not stream 0), then transpose to [C, T].
                wav = wav.view(-1, self.channels(stream)).t()
                if channels is not None:
                    wav = convert_audio_channels(wav, channels)
                if target_size is not None:
                    wav = wav[..., :target_size]
                wavs.append(wav)
        wav = torch.stack(wavs, dim=0)
        if single:
            wav = wav[0]
        return wav
|
|
|
|
|
def convert_audio_channels(wav, channels=2):
    """Return `wav` remixed to `channels` channels.

    `wav` is expected to have shape [..., C, T]: the channel axis is the
    second to last dimension and time is the last one.

    Raises:
        ValueError: if more channels are requested than available and the
            source is not mono.
    """
    *leading_dims, n_source, n_samples = wav.shape

    # Already in the requested layout: hand the tensor back untouched.
    if n_source == channels:
        return wav

    # Mono requested from a multi-channel source: downmix by averaging all channels.
    if channels == 1:
        return wav.mean(dim=-2, keepdim=True)

    # Multiple channels requested from a mono source: replicate the single
    # channel over every output channel.
    if n_source == 1:
        return wav.expand(*leading_dims, channels, n_samples)

    # Fewer channels than requested and not mono: no sensible upmix is defined.
    if n_source < channels:
        raise ValueError('The audio file has less channels than requested but is not mono.')

    # More channels than requested: keep only the first `channels` ones.
    return wav[..., :channels, :]
|
|
|
|
|
def convert_audio(wav, from_samplerate, to_samplerate, channels):
    """Remix `wav` to `channels` channels, then resample it.

    The channel conversion is applied first; the result is then resampled
    from `from_samplerate` to `to_samplerate` with `julius.resample_frac`.
    """
    remixed = convert_audio_channels(wav, channels)
    resampled = julius.resample_frac(remixed, from_samplerate, to_samplerate)
    return resampled
|
|
|
|
|
def i16_pcm(wav):
    """Convert audio to 16 bits integer PCM format.

    Floating point input is clipped to [-1, 1], scaled by ``2**15 - 1`` and
    cast to int16. Non floating point input is returned as is.

    The input tensor is never modified.
    """
    if wav.dtype.is_floating_point:
        # Out-of-place `clamp`: the in-place `clamp_` would silently clip the
        # caller's tensor as a side effect of a mere format conversion.
        return (wav.clamp(-1, 1) * (2**15 - 1)).short()
    return wav
|
|
|
|
|
def f32_pcm(wav):
    """Convert audio to float 32 bits PCM format.

    Floating point input is returned unchanged. Any other dtype is cast to
    float32 and divided by ``2**15 - 1``.
    """
    if not wav.dtype.is_floating_point:
        full_scale = 2**15 - 1
        return wav.float() / full_scale
    return wav
|
|
|
|
|
def as_dtype_pcm(wav, dtype=None):
    """Convert audio to either f32 pcm or i16 pcm depending on the given dtype.

    Args:
        wav (torch.Tensor): audio to convert.
        dtype (torch.dtype, optional): target dtype. A floating point dtype selects
            `f32_pcm`, anything else selects `i16_pcm`. When omitted, the dtype of
            `wav` itself is used, which matches the historical behavior of this
            function.

    Returns:
        torch.Tensor: the result of `f32_pcm(wav)` or `i16_pcm(wav)`.
    """
    # Without an explicit target there was nothing "given" to dispatch on; fall back
    # to the input dtype so existing single-argument callers are unaffected.
    target = wav.dtype if dtype is None else dtype
    if target.is_floating_point:
        return f32_pcm(wav)
    return i16_pcm(wav)
|
|
|
|
|
def encode_mp3(wav, path, samplerate=44100, bitrate=320, verbose=False):
    """Encode `wav` as mp3 with `lameenc` and write it to `path`.

    Args:
        wav: 2D audio tensor; its first dimension is used as the channel count.
        path: destination file, opened in binary write mode.
        samplerate: input sample rate passed to the encoder.
        bitrate: bit rate passed to the encoder.
        verbose: when False, the encoder is silenced.
    """
    n_channels, _ = wav.shape
    pcm = i16_pcm(wav)

    mp3_encoder = lameenc.Encoder()
    mp3_encoder.set_bit_rate(bitrate)
    mp3_encoder.set_in_sample_rate(samplerate)
    mp3_encoder.set_channels(n_channels)
    mp3_encoder.set_quality(2)
    if not verbose:
        mp3_encoder.silence()

    # Transpose to [T, C] so that the serialized bytes are interleaved by channel.
    interleaved = pcm.data.cpu().transpose(0, 1).numpy()
    encoded = mp3_encoder.encode(interleaved.tobytes())
    # Flushing returns whatever the encoder still holds after the last frame.
    encoded += mp3_encoder.flush()

    with open(path, "wb") as mp3_file:
        mp3_file.write(encoded)
|
|
|
|
|
def prevent_clip(wav, mode='rescale'):
    """
    Apply one of several strategies for avoiding raw clipping.

    Args:
        wav: floating point audio tensor (required unless `mode` disables processing).
        mode: one of
            - None or 'none': return `wav` untouched.
            - 'rescale': divide by ``1.01 * peak`` when that exceeds 1, otherwise by 1.
            - 'clamp': clamp every sample to [-0.99, 0.99].
            - 'tanh': pass the signal through `torch.tanh`.

    Raises:
        ValueError: if `mode` is not one of the values above.
    """
    if mode in (None, 'none'):
        return wav

    assert wav.dtype.is_floating_point, "too late for clipping"

    if mode == 'rescale':
        # Only ever attenuate: the divisor is never smaller than 1.
        return wav / max(1.01 * wav.abs().max(), 1)
    if mode == 'clamp':
        return wav.clamp(-0.99, 0.99)
    if mode == 'tanh':
        return torch.tanh(wav)
    raise ValueError(f"Invalid mode {mode}")
|
|
|
|
|
def save_audio(wav: torch.Tensor,
               path: tp.Union[str, Path],
               samplerate: int,
               bitrate: int = 320,
               clip: tp.Union[str] = 'rescale',
               bits_per_sample: tp.Union[int] = 16,
               as_float: bool = False):
    """Write `wav` to `path`, applying the `clip` strategy beforehand.

    The output format is picked from the (case-insensitive) file suffix:

    - ``.mp3``: encoded with `encode_mp3` at the given `bitrate`.
    - ``.wav``: saved through torchaudio, as 32 bits float PCM when `as_float`
      is set, otherwise as signed integer PCM with `bits_per_sample` bits.
    - ``.flac``: saved through torchaudio with `bits_per_sample` bits.

    Raises:
        ValueError: for any other suffix.
    """
    wav = prevent_clip(wav, mode=clip)
    path = Path(path)
    suffix = path.suffix.lower()

    if suffix == ".mp3":
        encode_mp3(wav, path, samplerate, bitrate, verbose=True)
        return

    if suffix == ".wav":
        if as_float:
            # Float output is always written with 32 bits per sample.
            bits_per_sample = 32
        encoding = 'PCM_F' if as_float else 'PCM_S'
        ta.save(str(path), wav, sample_rate=samplerate,
                encoding=encoding, bits_per_sample=bits_per_sample)
        return

    if suffix == ".flac":
        ta.save(str(path), wav, sample_rate=samplerate, bits_per_sample=bits_per_sample)
        return

    raise ValueError(f"Invalid suffix for path: {suffix}")
|
|
|
|
|
def load_track(track, audio_channels, samplerate):
    """Load an audio track, trying ffmpeg first and torchaudio as a fallback.

    Args:
        track: path of the audio file to load.
        audio_channels (int): number of channels of the returned audio.
        samplerate (int): sample rate of the returned audio.

    Returns:
        tuple: ``(wav, errors)``. `wav` is the loaded tensor, or None when both
        backends failed. `errors` maps a backend name (``'ffmpeg'``,
        ``'torchaudio'``) to the reason it failed; it is empty when ffmpeg succeeded.
    """
    errors = {}
    wav = None

    try:
        wav = AudioFile(track).read(
            streams=0,
            samplerate=samplerate,
            channels=audio_channels)
    except FileNotFoundError:
        # subprocess raises this when the ffprobe/ffmpeg executable itself cannot be
        # found; record it so that the torchaudio fallback still gets a chance.
        errors['ffmpeg'] = 'FFmpeg is not installed.'
    except sp.CalledProcessError:
        errors['ffmpeg'] = 'FFmpeg could not read the file.'

    if wav is None:
        try:
            wav, sr = ta.load(str(track))
        except RuntimeError as err:
            # Guard against an exception raised without any argument.
            errors['torchaudio'] = err.args[0] if err.args else str(err)
        else:
            wav = convert_audio(wav, sr, samplerate, audio_channels)

    return wav, errors