Hev832 commited on
Commit
42ed6fc
·
verified ·
1 Parent(s): 6b5f4a0

Create module_infer.py

Browse files
Files changed (1) hide show
  1. module_infer.py +135 -0
module_infer.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class UVRWebUI:
2
+ def __init__(self, uvr: UVRInterface, online_data_path: str) -> None:
3
+ self.uvr = uvr
4
+ self.models_url = self.get_models_url(online_data_path)
5
+ self.define_layout()
6
+
7
+ self.input_temp_dir = "__temp"
8
+ self.export_path = "out"
9
+ if not os.path.exists(self.input_temp_dir):
10
+ os.mkdir(self.input_temp_dir)
11
+
12
+ def get_models_url(self, models_info_path: str) -> Dict[str, Dict]:
13
+ with open(models_info_path, "r") as f:
14
+ online_data = json.loads(f.read())
15
+ models_url = {}
16
+ for arch, download_list_key in zip([VR_ARCH_TYPE, MDX_ARCH_TYPE], ["vr_download_list", "mdx_download_list"]):
17
+ models_url[arch] = {model: NORMAL_REPO+model_path for model, model_path in online_data[download_list_key].items()}
18
+ models_url[DEMUCS_ARCH_TYPE] = online_data["demucs_download_list"]
19
+ return models_url
20
+
21
+ def get_local_models(self, arch: str) -> List[str]:
22
+ model_config = {
23
+ VR_ARCH_TYPE: (VR_MODELS_DIR, ".pth"),
24
+ MDX_ARCH_TYPE: (MDX_MODELS_DIR, ".onnx"),
25
+ DEMUCS_ARCH_TYPE: (DEMUCS_MODELS_DIR, ".yaml"),
26
+ }
27
+ try:
28
+ model_dir, suffix = model_config[arch]
29
+ except KeyError:
30
+ raise ValueError(f"Unkown arch type: {arch}")
31
+ return [os.path.splitext(f)[0] for f in os.listdir(model_dir) if f.endswith(suffix)]
32
+
33
+ def set_arch_setting_value(self, arch: str, setting1, setting2):
34
+ if arch == VR_ARCH_TYPE:
35
+ root.window_size_var.set(setting1)
36
+ root.aggression_setting_var.set(setting2)
37
+ elif arch == MDX_ARCH_TYPE:
38
+ root.mdx_batch_size_var.set(setting1)
39
+ root.compensate_var.set(setting2)
40
+ elif arch == DEMUCS_ARCH_TYPE:
41
+ pass
42
+
43
+ def arch_select_update(self, arch: str) -> List[Dict]:
44
+ choices = self.get_local_models(arch)
45
+ if arch == VR_ARCH_TYPE:
46
+ model_update = self.model_choice.update(choices=choices, value=CHOOSE_MODEL, label=SELECT_VR_MODEL_MAIN_LABEL)
47
+ setting1_update = self.arch_setting1.update(choices=VR_WINDOW, label=WINDOW_SIZE_MAIN_LABEL, value=root.window_size_var.get())
48
+ setting2_update = self.arch_setting2.update(choices=VR_AGGRESSION, label=AGGRESSION_SETTING_MAIN_LABEL, value=root.aggression_setting_var.get())
49
+ elif arch == MDX_ARCH_TYPE:
50
+ model_update = self.model_choice.update(choices=choices, value=CHOOSE_MODEL, label=CHOOSE_MDX_MODEL_MAIN_LABEL)
51
+ setting1_update = self.arch_setting1.update(choices=BATCH_SIZE, label=BATCHES_MDX_MAIN_LABEL, value=root.mdx_batch_size_var.get())
52
+ setting2_update = self.arch_setting2.update(choices=VOL_COMPENSATION, label=VOL_COMP_MDX_MAIN_LABEL, value=root.compensate_var.get())
53
+ elif arch == DEMUCS_ARCH_TYPE:
54
+ model_update = self.model_choice.update(choices=choices, value=CHOOSE_MODEL, label=CHOOSE_DEMUCS_MODEL_MAIN_LABEL)
55
+ raise gr.Error(f"{DEMUCS_ARCH_TYPE} not implempted")
56
+ else:
57
+ raise gr.Error(f"Unkown arch type: {arch}")
58
+ return [model_update, setting1_update, setting2_update]
59
+
60
+ def model_select_update(self, arch: str, model_name: str) -> List[Union[str, Dict, None]]:
61
+ if model_name == CHOOSE_MODEL:
62
+ return [None for _ in range(4)]
63
+ model, = self.uvr.assemble_model_data(model_name, arch)
64
+ if not model.model_status:
65
+ raise gr.Error(f"Cannot get model data, model hash = {model.model_hash}")
66
+
67
+ stem1_check_update = self.primary_stem_only.update(label=f"{model.primary_stem} Only")
68
+ stem2_check_update = self.secondary_stem_only.update(label=f"{model.secondary_stem} Only")
69
+ stem1_out_update = self.primary_stem_out.update(label=f"Output {model.primary_stem}")
70
+ stem2_out_update = self.secondary_stem_out.update(label=f"Output {model.secondary_stem}")
71
+
72
+ return [stem1_check_update, stem2_check_update, stem1_out_update, stem2_out_update]
73
+
74
+ def checkbox_set_root_value(self, checkbox: gr.Checkbox, root_attr: str):
75
+ checkbox.change(lambda value: root.__getattribute__(root_attr).set(value), inputs=checkbox)
76
+
77
+ def set_checkboxes_exclusive(self, checkboxes: List[gr.Checkbox], pure_callbacks: List[Callable], exclusive_value=True):
78
+ def exclusive_onchange(i, callback_i):
79
+ def new_onchange(*check_values):
80
+ if check_values[i] == exclusive_value:
81
+ return_values = []
82
+ for j, value_j in enumerate(check_values):
83
+ if j != i and value_j == exclusive_value:
84
+ return_values.append(not exclusive_value)
85
+ else:
86
+ return_values.append(value_j)
87
+ else:
88
+ return_values = check_values
89
+ callback_i(check_values[i])
90
+ return return_values
91
+ return new_onchange
92
+
93
+ for i, (checkbox, callback) in enumerate(zip(checkboxes, pure_callbacks)):
94
+ checkbox.change(exclusive_onchange(i, callback), inputs=checkboxes, outputs=checkboxes)
95
+
96
+ def process(self, input_audio, input_filename, model_name, arch, setting1, setting2, progress=gr.Progress()):
97
+ def set_progress_func(step, inference_iterations=0):
98
+ progress_curr = step + inference_iterations
99
+ progress(progress_curr)
100
+
101
+ sampling_rate, audio = input_audio
102
+ audio = (audio / np.iinfo(audio.dtype).max).astype(np.float32)
103
+ if len(audio.shape) > 1:
104
+ audio = librosa.to_mono(audio.transpose(1, 0))
105
+ input_path = os.path.join(self.input_temp_dir, input_filename)
106
+ soundfile.write(input_path, audio, sampling_rate, format="wav")
107
+
108
+ self.set_arch_setting_value(arch, setting1, setting2)
109
+
110
+ seperator = uvr.process(
111
+ model_name=model_name,
112
+ arch_type=arch,
113
+ audio_file=input_path,
114
+ export_path=self.export_path,
115
+ is_model_sample_mode=root.model_sample_mode_var.get(),
116
+ set_progress_func=set_progress_func,
117
+ )
118
+
119
+ primary_audio = None
120
+ secondary_audio = None
121
+ msg = ""
122
+ if not seperator.is_secondary_stem_only:
123
+ primary_stem_path = os.path.join(seperator.export_path, f"{seperator.audio_file_base}_({seperator.primary_stem}).wav")
124
+ audio, rate = soundfile.read(primary_stem_path)
125
+ primary_audio = (rate, audio)
126
+ msg += f"{seperator.primary_stem} saved at {primary_stem_path}\n"
127
+ if not seperator.is_primary_stem_only:
128
+ secondary_stem_path = os.path.join(seperator.export_path, f"{seperator.audio_file_base}_({seperator.secondary_stem}).wav")
129
+ audio, rate = soundfile.read(secondary_stem_path)
130
+ secondary_audio = (rate, audio)
131
+ msg += f"{seperator.secondary_stem} saved at {secondary_stem_path}\n"
132
+
133
+ os.remove(input_path)
134
+
135
+ return primary_audio, secondary_audio, msg