|
import os |
|
import sys |
|
import time |
|
import yt_dlp |
|
import shutil |
|
import librosa |
|
import logging |
|
import argparse |
|
import warnings |
|
import logging.handlers |
|
|
|
from soundfile import read, write |
|
from distutils.util import strtobool |
|
|
|
sys.path.append(os.getcwd()) |
|
|
|
from main.configs.config import Config |
|
from main.library.algorithm.separator import Separator |
|
|
|
config = Config() |
|
translations = config.translations |
|
dataset_temp = os.path.join("dataset_temp") |
|
logger = logging.getLogger(__name__) |
|
|
|
if logger.hasHandlers(): logger.handlers.clear() |
|
else: |
|
console_handler = logging.StreamHandler() |
|
console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S") |
|
console_handler.setFormatter(console_formatter) |
|
console_handler.setLevel(logging.INFO) |
|
file_handler = logging.handlers.RotatingFileHandler(os.path.join("assets", "logs", "create_dataset.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8') |
|
file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S") |
|
file_handler.setFormatter(file_formatter) |
|
file_handler.setLevel(logging.DEBUG) |
|
logger.addHandler(console_handler) |
|
logger.addHandler(file_handler) |
|
logger.setLevel(logging.DEBUG) |
|
|
|
def parse_arguments(): |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--input_audio", type=str, required=True) |
|
parser.add_argument("--output_dataset", type=str, default="./dataset") |
|
parser.add_argument("--sample_rate", type=int, default=44100) |
|
parser.add_argument("--clean_dataset", type=lambda x: bool(strtobool(x)), default=False) |
|
parser.add_argument("--clean_strength", type=float, default=0.7) |
|
parser.add_argument("--separator_reverb", type=lambda x: bool(strtobool(x)), default=False) |
|
parser.add_argument("--kim_vocal_version", type=int, default=2) |
|
parser.add_argument("--overlap", type=float, default=0.25) |
|
parser.add_argument("--segments_size", type=int, default=256) |
|
parser.add_argument("--mdx_hop_length", type=int, default=1024) |
|
parser.add_argument("--mdx_batch_size", type=int, default=1) |
|
parser.add_argument("--denoise_mdx", type=lambda x: bool(strtobool(x)), default=False) |
|
parser.add_argument("--skip", type=lambda x: bool(strtobool(x)), default=False) |
|
parser.add_argument("--skip_start_audios", type=str, default="0") |
|
parser.add_argument("--skip_end_audios", type=str, default="0") |
|
|
|
return parser.parse_args() |
|
|
|
def main(): |
|
pid_path = os.path.join("assets", "create_dataset_pid.txt") |
|
with open(pid_path, "w") as pid_file: |
|
pid_file.write(str(os.getpid())) |
|
|
|
args = parse_arguments() |
|
input_audio, output_dataset, sample_rate, clean_dataset, clean_strength, separator_reverb, kim_vocal_version, overlap, segments_size, hop_length, batch_size, denoise_mdx, skip, skip_start_audios, skip_end_audios = args.input_audio, args.output_dataset, args.sample_rate, args.clean_dataset, args.clean_strength, args.separator_reverb, args.kim_vocal_version, args.overlap, args.segments_size, args.mdx_hop_length, args.mdx_batch_size, args.denoise_mdx, args.skip, args.skip_start_audios, args.skip_end_audios |
|
log_data = {translations['audio_path']: input_audio, translations['output_path']: output_dataset, translations['sr']: sample_rate, translations['clear_dataset']: clean_dataset, translations['dereveb_audio']: separator_reverb, translations['segments_size']: segments_size, translations['overlap']: overlap, "Hop length": hop_length, translations['batch_size']: batch_size, translations['denoise_mdx']: denoise_mdx, translations['skip']: skip} |
|
|
|
if clean_dataset: log_data[translations['clean_strength']] = clean_strength |
|
if skip: |
|
log_data[translations['skip_start']] = skip_start_audios |
|
log_data[translations['skip_end']] = skip_end_audios |
|
|
|
for key, value in log_data.items(): |
|
logger.debug(f"{key}: {value}") |
|
|
|
if kim_vocal_version not in [1, 2]: raise ValueError(translations["version_not_valid"]) |
|
start_time = time.time() |
|
|
|
try: |
|
paths = [] |
|
|
|
if not os.path.exists(dataset_temp): os.makedirs(dataset_temp, exist_ok=True) |
|
urls = input_audio.replace(", ", ",").split(",") |
|
|
|
for url in urls: |
|
path = downloader(url, urls.index(url)) |
|
paths.append(path) |
|
|
|
if skip: |
|
skip_start_audios, skip_end_audios = skip_start_audios.replace(", ", ",").split(","), skip_end_audios.replace(", ", ",").split(",") |
|
|
|
if len(skip_start_audios) < len(paths) or len(skip_end_audios) < len(paths): |
|
logger.warning(translations["skip<audio"]) |
|
sys.exit(1) |
|
elif len(skip_start_audios) > len(paths) or len(skip_end_audios) > len(paths): |
|
logger.warning(translations["skip>audio"]) |
|
sys.exit(1) |
|
else: |
|
for audio, skip_start_audio, skip_end_audio in zip(paths, skip_start_audios, skip_end_audios): |
|
skip_start(audio, skip_start_audio) |
|
skip_end(audio, skip_end_audio) |
|
|
|
separator_paths = [] |
|
|
|
for audio in paths: |
|
vocals = separator_music_main(audio, dataset_temp, segments_size, overlap, denoise_mdx, kim_vocal_version, hop_length, batch_size, sample_rate) |
|
if separator_reverb: vocals = separator_reverb_audio(vocals, dataset_temp, segments_size, overlap, denoise_mdx, hop_length, batch_size, sample_rate) |
|
separator_paths.append(vocals) |
|
|
|
paths = separator_paths |
|
|
|
for audio_path in paths: |
|
data, sample_rate = read(audio_path) |
|
data = librosa.to_mono(data.T) |
|
|
|
if clean_dataset: |
|
from main.tools.noisereduce import reduce_noise |
|
data = reduce_noise(y=data, sr=sample_rate, prop_decrease=clean_strength, device=config.device) |
|
|
|
write(audio_path, data, sample_rate) |
|
except Exception as e: |
|
logger.error(f"{translations['create_dataset_error']}: {e}") |
|
import traceback |
|
logger.error(traceback.format_exc()) |
|
finally: |
|
for audio in paths: |
|
shutil.move(audio, output_dataset) |
|
|
|
if os.path.exists(dataset_temp): shutil.rmtree(dataset_temp, ignore_errors=True) |
|
|
|
elapsed_time = time.time() - start_time |
|
if os.path.exists(pid_path): os.remove(pid_path) |
|
logger.info(translations["create_dataset_success"].format(elapsed_time=f"{elapsed_time:.2f}")) |
|
|
|
def downloader(url, name): |
|
with warnings.catch_warnings(): |
|
warnings.simplefilter("ignore") |
|
|
|
ydl_opts = {"format": "bestaudio/best", "outtmpl": os.path.join(dataset_temp, f"{name}"), "postprocessors": [{"key": "FFmpegExtractAudio", "preferredcodec": "wav", "preferredquality": "192"}], "no_warnings": True, "noplaylist": True, "noplaylist": True, "verbose": False} |
|
logger.info(f"{translations['starting_download']}: {url}...") |
|
|
|
with yt_dlp.YoutubeDL(ydl_opts) as ydl: |
|
ydl.extract_info(url) |
|
logger.info(f"{translations['download_success']}: {url}") |
|
|
|
return os.path.join(dataset_temp, f"{name}" + ".wav") |
|
|
|
def skip_start(input_file, seconds): |
|
data, sr = read(input_file) |
|
total_duration = len(data) / sr |
|
|
|
if seconds <= 0: logger.warning(translations["=<0"]) |
|
elif seconds >= total_duration: logger.warning(translations["skip_warning"].format(seconds=seconds, total_duration=f"{total_duration:.2f}")) |
|
else: |
|
logger.info(f"{translations['skip_start']}: {input_file}...") |
|
write(input_file, data[int(seconds * sr):], sr) |
|
|
|
logger.info(translations["skip_start_audio"].format(input_file=input_file)) |
|
|
|
def skip_end(input_file, seconds): |
|
data, sr = read(input_file) |
|
total_duration = len(data) / sr |
|
|
|
if seconds <= 0: logger.warning(translations["=<0"]) |
|
elif seconds > total_duration: logger.warning(translations["skip_warning"].format(seconds=seconds, total_duration=f"{total_duration:.2f}")) |
|
else: |
|
logger.info(f"{translations['skip_end']}: {input_file}...") |
|
write(input_file, data[:-int(seconds * sr)], sr) |
|
|
|
logger.info(translations["skip_end_audio"].format(input_file=input_file)) |
|
|
|
def separator_music_main(input, output, segments_size, overlap, denoise, version, hop_length, batch_size, sample_rate): |
|
if not os.path.exists(input): |
|
logger.warning(translations["input_not_valid"]) |
|
return None |
|
|
|
if not os.path.exists(output): |
|
logger.warning(translations["output_not_valid"]) |
|
return None |
|
|
|
model = f"Kim_Vocal_{version}.onnx" |
|
output_separator = separator_main(audio_file=input, model_filename=model, output_format="wav", output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise, sample_rate=sample_rate) |
|
|
|
for f in output_separator: |
|
path = os.path.join(output, f) |
|
if not os.path.exists(path): logger.error(translations["not_found"].format(name=path)) |
|
|
|
if '_(Instrumental)_' in f: os.rename(path, os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav") |
|
elif '_(Vocals)_' in f: |
|
rename_file = os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav" |
|
os.rename(path, rename_file) |
|
|
|
return rename_file |
|
|
|
def separator_reverb_audio(input, output, segments_size, overlap, denoise, hop_length, batch_size, sample_rate): |
|
if not os.path.exists(input): |
|
logger.warning(translations["input_not_valid"]) |
|
return None |
|
|
|
if not os.path.exists(output): |
|
logger.warning(translations["output_not_valid"]) |
|
return None |
|
|
|
logger.info(f"{translations['dereverb']}: {input}...") |
|
output_dereverb = separator_main(audio_file=input, model_filename="Reverb_HQ_By_FoxJoy.onnx", output_format="wav", output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=hop_length, mdx_hop_length=batch_size, mdx_enable_denoise=denoise, sample_rate=sample_rate) |
|
|
|
for f in output_dereverb: |
|
path = os.path.join(output, f) |
|
if not os.path.exists(path): logger.error(translations["not_found"].format(name=path)) |
|
|
|
if '_(Reverb)_' in f: os.rename(path, os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav") |
|
elif '_(No Reverb)_' in f: |
|
rename_file = os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav" |
|
os.rename(path, rename_file) |
|
|
|
logger.info(f"{translations['dereverb_success']}: {rename_file}") |
|
return rename_file |
|
|
|
def separator_main(audio_file=None, model_filename="Kim_Vocal_1.onnx", output_format="wav", output_dir=".", mdx_segment_size=256, mdx_overlap=0.25, mdx_batch_size=1, mdx_hop_length=1024, mdx_enable_denoise=True, sample_rate=44100): |
|
try: |
|
separator = Separator(logger=logger, log_formatter=file_formatter, log_level=logging.INFO, output_dir=output_dir, output_format=output_format, output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=sample_rate, mdx_params={"hop_length": mdx_hop_length, "segment_size": mdx_segment_size, "overlap": mdx_overlap, "batch_size": mdx_batch_size, "enable_denoise": mdx_enable_denoise}) |
|
separator.load_model(model_filename=model_filename) |
|
return separator.separate(audio_file) |
|
except: |
|
logger.debug(translations["default_setting"]) |
|
separator = Separator(logger=logger, log_formatter=file_formatter, log_level=logging.INFO, output_dir=output_dir, output_format=output_format, output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=44100, mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": mdx_enable_denoise}) |
|
separator.load_model(model_filename=model_filename) |
|
return separator.separate(audio_file) |
|
|
|
if __name__ == "__main__": main() |