James Zhou
[init]
9867d34
import math
import warnings
from pathlib import Path
import argbind
import numpy as np
import torch
from audiotools import AudioSignal
from audiotools.core import util
from tqdm import tqdm
from . import load_model
# Installed at import time: ignore every warning of category UserWarning
# (from any module) so they do not clutter the command-line output.
warnings.filterwarnings("ignore", category=UserWarning)
@argbind.bind(group="encode", positional=True, without_prefix=True)
@torch.inference_mode()
@torch.no_grad()
def encode(
    input: str,
    output: str = "",
    weights_path: str = "",
    model_tag: str = "latest",
    model_bitrate: str = "8kbps",
    n_quantizers: int = None,
    device: str = "cuda",
    model_type: str = "44khz",
    win_duration: float = 5.0,
    verbose: bool = False,
):
    """Encode audio files in input path to .dac format.

    Parameters
    ----------
    input : str
        Path to input audio file or directory
    output : str, optional
        Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
    weights_path : str, optional
        Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
        model_tag and model_type.
    model_tag : str, optional
        Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
    model_bitrate: str
        Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
    n_quantizers : int, optional
        Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate.
    device : str, optional
        Device to use, by default "cuda"
    model_type : str, optional
        The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
    win_duration : float, optional
        Window duration forwarded to the model's `compress` call, by default 5.0 (presumably in seconds).
    verbose : bool, optional
        Forwarded to the model's `compress` call, by default False.
    """
    generator = load_model(
        model_type=model_type,
        model_bitrate=model_bitrate,
        tag=model_tag,
        load_path=weights_path,
    )
    generator.to(device)
    generator.eval()

    # Find all audio files in input path. The string arguments are kept
    # intact; the Path versions get their own names.
    input_path = Path(input)
    audio_files = util.find_audio(input_path)

    output_root = Path(output)
    output_root.mkdir(parents=True, exist_ok=True)

    for audio_file in tqdm(audio_files, desc="Encoding files"):
        # Load file
        signal = AudioSignal(audio_file)

        # Encode audio to .dac format
        artifact = generator.compress(
            signal, win_duration, verbose=verbose, n_quantizers=n_quantizers
        )

        # Compute output path: mirror the file's location relative to the
        # input path underneath the output directory.
        relative_path = audio_file.relative_to(input_path)
        output_dir = output_root / relative_path.parent
        if not relative_path.name:
            # `input` was the audio file itself, so the relative path has no
            # file name; write straight into the output directory and take
            # the name from the audio file.
            output_dir = output_root
            relative_path = audio_file
        output_name = relative_path.with_suffix(".dac").name
        output_path = output_dir / output_name
        output_dir.mkdir(parents=True, exist_ok=True)

        artifact.save(output_path)
if __name__ == "__main__":
    # Script entry point: parse the command line with argbind, then call
    # `encode` with no explicit arguments inside the resulting scope.
    with argbind.scope(argbind.parse_args()):
        encode()