Spaces:
Runtime error
Runtime error
File size: 7,033 Bytes
42ed6fc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
class UVRWebUI:
def __init__(self, uvr: UVRInterface, online_data_path: str) -> None:
    """Initialise the web UI around an existing UVR interface object.

    Args:
        uvr: Interface object stored as ``self.uvr``.
        online_data_path: Path to the JSON file handed to
            ``get_models_url`` to build the model download URL table.
    """
    self.uvr = uvr
    self.models_url = self.get_models_url(online_data_path)
    self.define_layout()

    self.input_temp_dir = "__temp"
    self.export_path = "out"
    # exist_ok=True replaces the former exists()/mkdir() pair, which could
    # fail if the directory appeared between the check and the creation.
    os.makedirs(self.input_temp_dir, exist_ok=True)
def get_models_url(self, models_info_path: str) -> Dict[str, Dict]:
    """Build the per-architecture table of downloadable models.

    Args:
        models_info_path: Path to a JSON file containing the keys
            ``vr_download_list``, ``mdx_download_list`` and
            ``demucs_download_list``.

    Returns:
        A dict keyed by architecture type. For the VR and MDX entries each
        model name maps to ``NORMAL_REPO`` concatenated with the value from
        the JSON file; the Demucs entry is the JSON value unchanged.

    Raises:
        KeyError: If one of the expected download lists is missing.
    """
    # JSON is read as UTF-8 explicitly so the result does not depend on the
    # platform's default locale encoding.
    with open(models_info_path, "r", encoding="utf-8") as f:
        online_data = json.load(f)

    models_url = {}
    for arch, download_list_key in (
        (VR_ARCH_TYPE, "vr_download_list"),
        (MDX_ARCH_TYPE, "mdx_download_list"),
    ):
        models_url[arch] = {
            model: NORMAL_REPO + model_path
            for model, model_path in online_data[download_list_key].items()
        }
    models_url[DEMUCS_ARCH_TYPE] = online_data["demucs_download_list"]
    return models_url
def get_local_models(self, arch: str) -> List[str]:
    """List the locally available models for one architecture.

    Args:
        arch: Architecture type; must be one of the VR, MDX or Demucs
            architecture constants.

    Returns:
        File names (without extension) found in the architecture's model
        directory whose names end with that architecture's suffix.

    Raises:
        ValueError: If ``arch`` is not one of the known architecture types.
    """
    # Each architecture maps to (directory to scan, file suffix to keep).
    locations = {
        VR_ARCH_TYPE: (VR_MODELS_DIR, ".pth"),
        MDX_ARCH_TYPE: (MDX_MODELS_DIR, ".onnx"),
        DEMUCS_ARCH_TYPE: (DEMUCS_MODELS_DIR, ".yaml"),
    }
    try:
        directory, extension = locations[arch]
    except KeyError:
        raise ValueError(f"Unkown arch type: {arch}")

    names = []
    for entry in os.listdir(directory):
        if entry.endswith(extension):
            names.append(os.path.splitext(entry)[0])
    return names
def set_arch_setting_value(self, arch: str, setting1, setting2):
    """Push the two architecture-specific settings into ``root``.

    Args:
        arch: Architecture type the settings belong to.
        setting1: Window size (VR) or batch size (MDX).
        setting2: Aggression setting (VR) or volume compensation (MDX).

    For the Demucs architecture, and for any other value of ``arch``,
    nothing is written.
    """
    if arch == VR_ARCH_TYPE:
        root.window_size_var.set(setting1)
        root.aggression_setting_var.set(setting2)
        return

    if arch == MDX_ARCH_TYPE:
        root.mdx_batch_size_var.set(setting1)
        root.compensate_var.set(setting2)
        return

    # Demucs has no settings wired up here; nothing to store.
def arch_select_update(self, arch: str) -> List[Dict]:
    """Compute the component updates for a newly selected architecture.

    Args:
        arch: Architecture type that was selected.

    Returns:
        Three updates, in order: the model dropdown, the first setting
        component and the second setting component.

    Raises:
        gr.Error: For the Demucs architecture (not implemented) or an
            unrecognised architecture value.
    """
    # NOTE(review): get_local_models raises ValueError for an unknown arch,
    # so the final gr.Error below is probably never reached - confirm.
    available = self.get_local_models(arch)

    if arch == VR_ARCH_TYPE:
        return [
            self.model_choice.update(choices=available, value=CHOOSE_MODEL, label=SELECT_VR_MODEL_MAIN_LABEL),
            self.arch_setting1.update(choices=VR_WINDOW, label=WINDOW_SIZE_MAIN_LABEL, value=root.window_size_var.get()),
            self.arch_setting2.update(choices=VR_AGGRESSION, label=AGGRESSION_SETTING_MAIN_LABEL, value=root.aggression_setting_var.get()),
        ]

    if arch == MDX_ARCH_TYPE:
        return [
            self.model_choice.update(choices=available, value=CHOOSE_MODEL, label=CHOOSE_MDX_MODEL_MAIN_LABEL),
            self.arch_setting1.update(choices=BATCH_SIZE, label=BATCHES_MDX_MAIN_LABEL, value=root.mdx_batch_size_var.get()),
            self.arch_setting2.update(choices=VOL_COMPENSATION, label=VOL_COMP_MDX_MAIN_LABEL, value=root.compensate_var.get()),
        ]

    if arch == DEMUCS_ARCH_TYPE:
        # The dropdown update is still built, then the request is rejected.
        self.model_choice.update(choices=available, value=CHOOSE_MODEL, label=CHOOSE_DEMUCS_MODEL_MAIN_LABEL)
        raise gr.Error(f"{DEMUCS_ARCH_TYPE} not implempted")

    raise gr.Error(f"Unkown arch type: {arch}")
def model_select_update(self, arch: str, model_name: str) -> List[Union[str, Dict, None]]:
    """Relabel the stem checkboxes and outputs for the chosen model.

    Args:
        arch: Architecture type of the model.
        model_name: Name picked in the model dropdown.

    Returns:
        Four entries, in order: primary-stem-only checkbox, secondary-stem-only
        checkbox, primary output and secondary output. All four are ``None``
        when the placeholder ``CHOOSE_MODEL`` is selected.

    Raises:
        gr.Error: If the assembled model data reports a falsy
            ``model_status``.
    """
    if model_name == CHOOSE_MODEL:
        return [None] * 4

    # assemble_model_data is expected to yield exactly one model here;
    # the single-element unpacking fails otherwise.
    (model,) = self.uvr.assemble_model_data(model_name, arch)
    if not model.model_status:
        raise gr.Error(f"Cannot get model data, model hash = {model.model_hash}")

    primary = model.primary_stem
    secondary = model.secondary_stem
    return [
        self.primary_stem_only.update(label=f"{primary} Only"),
        self.secondary_stem_only.update(label=f"{secondary} Only"),
        self.primary_stem_out.update(label=f"Output {primary}"),
        self.secondary_stem_out.update(label=f"Output {secondary}"),
    ]
def checkbox_set_root_value(self, checkbox: gr.Checkbox, root_attr: str):
    """Mirror a checkbox's value into a variable held on ``root``.

    Args:
        checkbox: Checkbox whose ``change`` event is subscribed to.
        root_attr: Name of the attribute on ``root`` whose ``set`` method
            receives the new checkbox value.
    """
    # getattr() is the supported way to look an attribute up by name; calling
    # __getattribute__ directly skips the normal attribute-lookup fallback.
    checkbox.change(lambda value: getattr(root, root_attr).set(value), inputs=checkbox)
def set_checkboxes_exclusive(self, checkboxes: List[gr.Checkbox], pure_callbacks: List[Callable], exclusive_value=True):
    """Wire a group of checkboxes so that at most one holds ``exclusive_value``.

    Args:
        checkboxes: Checkboxes forming the group; all of them are used as both
            inputs and outputs of every registered handler.
        pure_callbacks: One callback per checkbox, called with that checkbox's
            new value each time it changes.
        exclusive_value: The value only one checkbox may hold at a time.
    """

    def make_handler(index, notify):
        # A factory is used so each handler keeps its own index and callback.
        def on_change(*states):
            current = states[index]
            if current == exclusive_value:
                # The changed box keeps its value; every other box currently
                # holding the exclusive value is flipped.
                result = [
                    (not exclusive_value) if (pos != index and state == exclusive_value) else state
                    for pos, state in enumerate(states)
                ]
            else:
                # Nothing to enforce: hand back the incoming values unchanged.
                result = states
            notify(current)
            return result

        return on_change

    for index, (box, notify) in enumerate(zip(checkboxes, pure_callbacks)):
        handler = make_handler(index, notify)
        box.change(handler, inputs=checkboxes, outputs=checkboxes)
def process(self, input_audio, input_filename, model_name, arch, setting1, setting2, progress=gr.Progress()):
    """Run source separation on one audio clip and return both stems.

    Args:
        input_audio: ``(sampling_rate, samples)`` pair where ``samples`` is a
            NumPy array. Multi-dimensional input is transposed and mixed down
            to mono - assumes a (frames, channels) layout; TODO confirm.
        input_filename: File name used for the temporary input file inside
            ``self.input_temp_dir``.
        model_name: Name of the model to run.
        arch: Architecture type of the model.
        setting1: First architecture-specific setting (see
            ``set_arch_setting_value``).
        setting2: Second architecture-specific setting.
        progress: Progress reporter invoked with the current progress value.

    Returns:
        ``(primary_audio, secondary_audio, msg)``. Each audio entry is a
        ``(rate, samples)`` pair, or ``None`` when that stem was not
        produced; ``msg`` lists where each produced stem was saved.
    """

    def set_progress_func(step, inference_iterations=0):
        progress(step + inference_iterations)

    def load_stem(stem_name):
        # Read one exported stem back from disk and describe where it lives.
        stem_path = os.path.join(
            separator.export_path,
            f"{separator.audio_file_base}_({stem_name}).wav",
        )
        samples, rate = soundfile.read(stem_path)
        return (rate, samples), f"{stem_name} saved at {stem_path}\n"

    sampling_rate, audio = input_audio
    if np.issubdtype(audio.dtype, np.integer):
        # Scale integer PCM by the dtype's maximum value.
        audio = (audio / np.iinfo(audio.dtype).max).astype(np.float32)
    else:
        # np.iinfo rejects non-integer dtypes; such samples are only cast.
        audio = audio.astype(np.float32)
    if audio.ndim > 1:
        audio = librosa.to_mono(audio.transpose(1, 0))

    input_path = os.path.join(self.input_temp_dir, input_filename)
    soundfile.write(input_path, audio, sampling_rate, format="wav")

    primary_audio = None
    secondary_audio = None
    msg = ""
    try:
        self.set_arch_setting_value(arch, setting1, setting2)

        # Use the interface given to the constructor rather than a
        # module-level global, consistent with model_select_update.
        separator = self.uvr.process(
            model_name=model_name,
            arch_type=arch,
            audio_file=input_path,
            export_path=self.export_path,
            is_model_sample_mode=root.model_sample_mode_var.get(),
            set_progress_func=set_progress_func,
        )

        if not separator.is_secondary_stem_only:
            primary_audio, note = load_stem(separator.primary_stem)
            msg += note
        if not separator.is_primary_stem_only:
            secondary_audio, note = load_stem(separator.secondary_stem)
            msg += note
    finally:
        # Always drop the temporary input, even when separation fails.
        os.remove(input_path)

    return primary_audio, secondary_audio, msg
|