Update infer_cli.py
Browse files- infer_cli.py +47 -34
infer_cli.py
CHANGED
@@ -1,66 +1,79 @@
|
|
1 |
import argparse
|
2 |
import os
|
3 |
import sys
|
4 |
-
|
5 |
-
now_dir = os.getcwd()
|
6 |
-
sys.path.append(now_dir)
|
7 |
from dotenv import load_dotenv
|
8 |
from scipy.io import wavfile
|
9 |
|
10 |
from config import Config
|
11 |
from modules import VC
|
12 |
|
13 |
-
####
|
14 |
-
# USAGE
|
15 |
-
#
|
16 |
-
# In your Terminal or CMD or whatever
|
17 |
-
|
18 |
|
19 |
-
def arg_parse() ->
|
20 |
parser = argparse.ArgumentParser()
|
21 |
parser.add_argument("--f0up_key", type=int, default=0)
|
22 |
-
parser.add_argument("--input_path", type=str, help="input path")
|
23 |
parser.add_argument("--index_path", type=str, help="index path")
|
24 |
parser.add_argument("--f0method", type=str, default="harvest", help="harvest or pm")
|
25 |
-
parser.add_argument("--opt_path", type=str, help="
|
26 |
-
parser.add_argument("--model_name", type=str, help="
|
27 |
parser.add_argument("--index_rate", type=float, default=0.66, help="index rate")
|
28 |
-
parser.add_argument("--device", type=str, help="device")
|
29 |
-
parser.add_argument("--is_half", type=bool, help="use half
|
30 |
parser.add_argument("--filter_radius", type=int, default=3, help="filter radius")
|
31 |
-
parser.add_argument("--resample_sr", type=int, default=0, help="resample
|
32 |
-
parser.add_argument("--rms_mix_rate", type=float, default=1, help="
|
33 |
-
parser.add_argument("--protect", type=float, default=0.33, help="protect")
|
34 |
|
35 |
args = parser.parse_args()
|
36 |
-
sys.argv = sys.argv[:1]
|
37 |
-
|
38 |
return args
|
39 |
|
40 |
|
41 |
def main():
|
42 |
load_dotenv()
|
43 |
args = arg_parse()
|
|
|
44 |
config = Config()
|
45 |
config.device = args.device if args.device else config.device
|
46 |
config.is_half = args.is_half if args.is_half else config.is_half
|
|
|
47 |
vc = VC(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
vc.get_vc(args.model_name)
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
|
66 |
if __name__ == "__main__":
|
|
|
1 |
import argparse
|
2 |
import os
|
3 |
import sys
|
|
|
|
|
|
|
4 |
from dotenv import load_dotenv
|
5 |
from scipy.io import wavfile
|
6 |
|
7 |
from config import Config
|
8 |
from modules import VC
|
9 |
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
+
def _str_to_bool(value: str) -> bool:
    """Convert a command-line string such as ``"True"`` or ``"false"`` to a bool.

    argparse hands the raw string to ``type``; using ``bool`` directly would
    turn every non-empty string -- including ``"False"`` -- into ``True``.

    Raises:
        argparse.ArgumentTypeError: If *value* is not a recognised spelling.
    """
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    if normalized in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value (True or False), got {value!r}"
    )


def arg_parse(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the inference script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``.

    Returns:
        The populated namespace. ``--input_path``, ``--opt_path`` and
        ``--model_name`` are required; argparse exits with status 2 when
        one of them is missing or when a value cannot be converted.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--f0up_key", type=int, default=0)
    parser.add_argument("--input_path", type=str, help="input path", required=True)
    parser.add_argument("--index_path", type=str, help="index path")
    parser.add_argument("--f0method", type=str, default="harvest", help="harvest or pm")
    parser.add_argument("--opt_path", type=str, help="output path", required=True)
    parser.add_argument(
        "--model_name",
        type=str,
        help="model name (stored in assets/weight_root)",
        required=True,
    )
    parser.add_argument("--index_rate", type=float, default=0.66, help="index rate")
    parser.add_argument("--device", type=str, help="device (e.g., cuda or cpu)")
    # type=bool would treat "--is_half False" as True, so parse the text explicitly.
    parser.add_argument(
        "--is_half",
        type=_str_to_bool,
        help="use half precision (True or False)",
        default=False,
    )
    parser.add_argument("--filter_radius", type=int, default=3, help="filter radius")
    parser.add_argument("--resample_sr", type=int, default=0, help="resample sampling rate")
    parser.add_argument("--rms_mix_rate", type=float, default=1, help="RMS mix rate")
    parser.add_argument("--protect", type=float, default=0.33, help="protect value")

    return parser.parse_args(argv)
|
29 |
|
30 |
|
31 |
def main():
    """Run one conversion job described by the command-line options.

    Loads environment variables, parses the CLI options, applies the
    ``--device`` / ``--is_half`` overrides to a fresh ``Config``, loads the
    requested model through ``VC`` and writes the converted audio to
    ``--opt_path``. Diagnostics go to stderr and the process exits with
    status 1 when the model name is empty or processing fails.
    """
    load_dotenv()
    args = arg_parse()

    # Fail fast, before any model objects are built. argparse already
    # requires the flag, so this only triggers for an empty string.
    if not args.model_name:
        print("Error: Model name must be provided.", file=sys.stderr)
        sys.exit(1)

    # NOTE(review): sys.argv still holds this script's flags at this point;
    # if Config() parses sys.argv on its own it may reject them -- confirm.
    config = Config()
    # A CLI value only overrides the configuration when it is truthy, so an
    # unset --device / --is_half keeps whatever Config() chose.
    config.device = args.device if args.device else config.device
    config.is_half = args.is_half if args.is_half else config.is_half

    vc = VC(config)
    vc.get_vc(args.model_name)

    # Process the audio file
    try:
        # Positional arguments follow vc_single's signature, which is not
        # visible here; the leading 0 and the two None slots are passed
        # through exactly as before.
        sid, wav_opt = vc.vc_single(
            0,
            args.input_path,
            args.f0up_key,
            None,
            args.f0method,
            args.index_path,
            None,
            args.index_rate,
            args.filter_radius,
            args.resample_sr,
            args.rms_mix_rate,
            args.protect,
        )

        if sid is None:
            print("Warning: sid is None. Skipping sid-related operations.")
        else:
            print(f"Processed sid: {sid}")

        # Refuse to write when the conversion returned no audio, so the user
        # sees a clear message instead of a low-level error from the writer.
        if wav_opt is None or wav_opt[0] is None or wav_opt[1] is None:
            raise RuntimeError("conversion produced no audio data")

        # Save the processed audio; wav_opt is presumably
        # (sample_rate, samples), matching wavfile.write(filename, rate, data).
        wavfile.write(args.opt_path, wav_opt[0], wav_opt[1])
        print(f"Output saved to: {args.opt_path}")

    except Exception as e:
        # Top-level boundary of the script: report on stderr and signal
        # failure through the exit status.
        print(f"An error occurred during processing: {e}", file=sys.stderr)
        sys.exit(1)
|
77 |
|
78 |
|
79 |
if __name__ == "__main__":
|