CCockrum commited on
Commit
9bd0a22
·
verified ·
1 Parent(s): 2ae1a70

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +185 -0
inference.py CHANGED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # inference.py
2
+
3
+ import os
4
+ import gc
5
+ import json
6
+ import shlex
7
+ import sys
8
+ import torch
9
+ import librosa
10
+ import numpy as np
11
+ import subprocess
12
+ import soundfile as sf
13
+ import hashlib
14
+ import random
15
+ import time
16
+ import traceback
17
+ import onnxruntime as ort
18
+ from utils import logger, remove_directory_contents, create_directories
19
+ from mdx_core import MDX, MDXModel
20
+ from effects import add_vocal_effects, add_instrumental_effects
21
+
22
+
23
+ stem_naming = {
24
+ "Vocals": "Instrumental",
25
+ "Other": "Instruments",
26
+ "Instrumental": "Vocals",
27
+ "Drums": "Drumless",
28
+ "Bass": "Bassless",
29
+ }
30
+
31
+
32
+ def run_mdx(model_params, output_dir, model_path, filename, exclude_main=False, exclude_inversion=False,
33
+ suffix=None, invert_suffix=None, denoise=False, keep_orig=True, m_threads=2, device_base="cuda"):
34
+
35
+ device = torch.device("cuda:0" if device_base == "cuda" else "cpu")
36
+ processor_num = 0 if device_base == "cuda" else -1
37
+
38
+ if device_base == "cuda":
39
+ vram_gb = torch.cuda.get_device_properties(device).total_memory / 1024**3
40
+ m_threads = 1 if vram_gb < 8 else (8 if vram_gb > 32 else 2)
41
+ logger.info(f"threads: {m_threads} vram: {vram_gb}")
42
+ else:
43
+ m_threads = 1
44
+
45
+ model_hash = MDX.get_hash(model_path)
46
+ mp = model_params.get(model_hash)
47
+
48
+ model = MDXModel(
49
+ device,
50
+ dim_f=mp["mdx_dim_f_set"],
51
+ dim_t=2 ** mp["mdx_dim_t_set"],
52
+ n_fft=mp["mdx_n_fft_scale_set"],
53
+ stem_name=mp["primary_stem"],
54
+ compensation=mp["compensate"],
55
+ )
56
+
57
+ mdx_sess = MDX(model_path, model, processor=processor_num)
58
+ wave, sr = librosa.load(filename, mono=False, sr=44100)
59
+ peak = max(np.max(wave), abs(np.min(wave)))
60
+ wave /= peak
61
+
62
+ if denoise:
63
+ wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (mdx_sess.process_wave(wave, m_threads))
64
+ wave_processed *= 0.5
65
+ else:
66
+ wave_processed = mdx_sess.process_wave(wave, m_threads)
67
+
68
+ wave_processed *= peak
69
+ stem_name = model.stem_name if suffix is None else suffix
70
+
71
+ main_filepath = None
72
+ if not exclude_main:
73
+ main_filepath = os.path.join(
74
+ output_dir,
75
+ f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav",
76
+ )
77
+ sf.write(main_filepath, wave_processed.T, sr)
78
+
79
+ invert_filepath = None
80
+ if not exclude_inversion:
81
+ diff_stem_name = stem_naming.get(stem_name) if invert_suffix is None else invert_suffix
82
+ stem_name = f"{stem_name}_diff" if diff_stem_name is None else diff_stem_name
83
+ invert_filepath = os.path.join(
84
+ output_dir,
85
+ f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav",
86
+ )
87
+ sf.write(invert_filepath, (-wave_processed.T * model.compensation) + wave.T, sr)
88
+
89
+ if not keep_orig:
90
+ os.remove(filename)
91
+
92
+ del mdx_sess, wave_processed, wave
93
+ gc.collect()
94
+ torch.cuda.empty_cache()
95
+ return main_filepath, invert_filepath
96
+
97
+
98
+ def run_mdx_beta(model_params, output_dir, model_path, filename, exclude_main=False, exclude_inversion=False,
99
+ suffix=None, invert_suffix=None, denoise=False, keep_orig=True, m_threads=1, device_base=""):
100
+
101
+ duration = librosa.get_duration(filename=filename)
102
+ if duration >= 60 and duration <= 120:
103
+ m_threads = 8
104
+ elif duration > 120:
105
+ m_threads = 16
106
+
107
+ logger.info(f"threads: {m_threads}")
108
+
109
+ device = torch.device("cpu")
110
+ processor_num = -1
111
+
112
+ model_hash = MDX.get_hash(model_path)
113
+ mp = model_params.get(model_hash)
114
+
115
+ model = MDXModel(
116
+ device,
117
+ dim_f=mp["mdx_dim_f_set"],
118
+ dim_t=2 ** mp["mdx_dim_t_set"],
119
+ n_fft=mp["mdx_n_fft_scale_set"],
120
+ stem_name=mp["primary_stem"],
121
+ compensation=mp["compensate"],
122
+ )
123
+
124
+ mdx_sess = MDX(model_path, model, processor=processor_num)
125
+ wave, sr = librosa.load(filename, mono=False, sr=44100)
126
+ peak = max(np.max(wave), abs(np.min(wave)))
127
+ wave /= peak
128
+
129
+ if denoise:
130
+ wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (mdx_sess.process_wave(wave, m_threads))
131
+ wave_processed *= 0.5
132
+ else:
133
+ wave_processed = mdx_sess.process_wave(wave, m_threads)
134
+
135
+ wave_processed *= peak
136
+ stem_name = model.stem_name if suffix is None else suffix
137
+
138
+ main_filepath = None
139
+ if not exclude_main:
140
+ main_filepath = os.path.join(
141
+ output_dir,
142
+ f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav",
143
+ )
144
+ sf.write(main_filepath, wave_processed.T, sr)
145
+
146
+ invert_filepath = None
147
+ if not exclude_inversion:
148
+ diff_stem_name = stem_naming.get(stem_name) if invert_suffix is None else invert_suffix
149
+ stem_name = f"{stem_name}_diff" if diff_stem_name is None else diff_stem_name
150
+ invert_filepath = os.path.join(
151
+ output_dir,
152
+ f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav",
153
+ )
154
+ sf.write(invert_filepath, (-wave_processed.T * model.compensation) + wave.T, sr)
155
+
156
+ if not keep_orig:
157
+ os.remove(filename)
158
+
159
+ del mdx_sess, wave_processed, wave
160
+ gc.collect()
161
+ torch.cuda.empty_cache()
162
+ return main_filepath, invert_filepath
163
+
164
+
165
+ def convert_to_stereo_and_wav(audio_path, output_dir):
166
+ wave, sr = librosa.load(audio_path, mono=False, sr=44100)
167
+
168
+ if type(wave[0]) != np.ndarray or audio_path[-4:].lower() != ".wav":
169
+ stereo_path = os.path.join(output_dir, f"{os.path.splitext(os.path.basename(audio_path))[0]}_stereo.wav")
170
+ command = shlex.split(f'ffmpeg -y -loglevel error -i "{audio_path}" -ac 2 -f wav "{stereo_path}")
171
+ subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
172
+ return stereo_path
173
+ return audio_path
174
+
175
+
176
+ def get_hash(filepath):
177
+ with open(filepath, 'rb') as f:
178
+ file_hash = hashlib.blake2b()
179
+ while chunk := f.read(8192):
180
+ file_hash.update(chunk)
181
+ return file_hash.hexdigest()[:18]
182
+
183
+
184
+ def random_sleep():
185
+ time.sleep(round(random.uniform(5.2, 7.9), 1))