Bredvige committed on
Commit
d510ab9
·
verified ·
1 Parent(s): 7cd70b4

Update infer_cli.py

Browse files
Files changed (1) hide show
  1. infer_cli.py +47 -34
infer_cli.py CHANGED
@@ -1,66 +1,79 @@
1
  import argparse
2
  import os
3
  import sys
4
-
5
- now_dir = os.getcwd()
6
- sys.path.append(now_dir)
7
  from dotenv import load_dotenv
8
  from scipy.io import wavfile
9
 
10
  from config import Config
11
  from modules import VC
12
 
13
- ####
14
- # USAGE
15
- #
16
- # In your Terminal or CMD or whatever
17
-
18
 
19
def arg_parse() -> argparse.Namespace:
    """Parse the command-line options for one inference run.

    All options are optional at the parser level; options without an explicit
    default are ``None`` when omitted.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: Raised by ``argparse`` when an option value is invalid.
    """

    def str_to_bool(text: str) -> bool:
        # ``type=bool`` calls ``bool("False")``, which is True for every
        # non-empty string, so the option text is interpreted explicitly.
        lowered = text.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError("expected a boolean, got %r" % text)

    parser = argparse.ArgumentParser()
    parser.add_argument("--f0up_key", type=int, default=0)
    parser.add_argument("--input_path", type=str, help="input path")
    parser.add_argument("--index_path", type=str, help="index path")
    parser.add_argument("--f0method", type=str, default="harvest", help="harvest or pm")
    parser.add_argument("--opt_path", type=str, help="opt path")
    parser.add_argument("--model_name", type=str, help="store in assets/weight_root")
    parser.add_argument("--index_rate", type=float, default=0.66, help="index rate")
    parser.add_argument("--device", type=str, help="device")
    parser.add_argument("--is_half", type=str_to_bool, help="use half -> True")
    parser.add_argument("--filter_radius", type=int, default=3, help="filter radius")
    parser.add_argument("--resample_sr", type=int, default=0, help="resample sr")
    parser.add_argument("--rms_mix_rate", type=float, default=1, help="rms mix rate")
    parser.add_argument("--protect", type=float, default=0.33, help="protect")

    args = parser.parse_args()
    # Drop the consumed options so only the program name remains in sys.argv.
    # NOTE(review): presumably this keeps later argv readers from seeing
    # these flags - confirm against Config.
    sys.argv = sys.argv[:1]

    return args
39
 
40
 
41
def main():
    """Convert one audio file as described by the command line and save it.

    Loads environment variables, parses the options, builds ``Config`` and
    ``VC``, runs ``vc.vc_single`` and writes the result with
    ``scipy.io.wavfile.write``.

    Raises:
        SystemExit: If a needed option is missing or no audio was produced.
    """
    load_dotenv()
    args = arg_parse()

    # These options have no usable default; without them the calls below
    # would receive None, so stop early with a readable message instead.
    missing = [
        name
        for name in ("input_path", "opt_path", "model_name")
        if not getattr(args, name)
    ]
    if missing:
        sys.exit(
            "Error: missing required option(s): "
            + ", ".join("--" + name for name in missing)
        )

    config = Config()
    # Command-line values override the Config values only when provided.
    config.device = args.device if args.device else config.device
    config.is_half = args.is_half if args.is_half else config.is_half

    vc = VC(config)
    vc.get_vc(args.model_name)
    _, wav_opt = vc.vc_single(
        0,
        args.input_path,
        args.f0up_key,
        None,
        args.f0method,
        args.index_path,
        None,
        args.index_rate,
        args.filter_radius,
        args.resample_sr,
        args.rms_mix_rate,
        args.protect,
    )

    # wav_opt is forwarded as (rate, data) to wavfile.write; refuse to write
    # when either part is absent.
    if wav_opt is None or wav_opt[0] is None or wav_opt[1] is None:
        sys.exit("Error: conversion returned no audio data.")

    wavfile.write(args.opt_path, wav_opt[0], wav_opt[1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
 
66
  if __name__ == "__main__":
 
1
  import argparse
2
  import os
3
  import sys
 
 
 
4
  from dotenv import load_dotenv
5
  from scipy.io import wavfile
6
 
7
  from config import Config
8
  from modules import VC
9
 
 
 
 
 
 
10
 
11
def arg_parse() -> argparse.Namespace:
    """Parse the command-line options for a single voice-conversion run.

    ``--input_path``, ``--opt_path`` and ``--model_name`` are required; all
    other options are optional.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: Raised by ``argparse`` when a required option is missing
            or an option value is invalid.
    """

    def str_to_bool(value: str) -> bool:
        # ``type=bool`` would evaluate ``bool("False")``, which is True for
        # every non-empty string, so ``--is_half False`` used to enable half
        # precision. Interpret the text explicitly instead.
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--f0up_key", type=int, default=0)
    parser.add_argument("--input_path", type=str, help="input path", required=True)
    parser.add_argument("--index_path", type=str, help="index path")
    parser.add_argument("--f0method", type=str, default="harvest", help="harvest or pm")
    parser.add_argument("--opt_path", type=str, help="output path", required=True)
    parser.add_argument(
        "--model_name",
        type=str,
        help="model name (stored in assets/weight_root)",
        required=True,
    )
    parser.add_argument("--index_rate", type=float, default=0.66, help="index rate")
    parser.add_argument("--device", type=str, help="device (e.g., cuda or cpu)")
    parser.add_argument(
        "--is_half",
        type=str_to_bool,
        help="use half precision (True or False)",
        default=False,
    )
    parser.add_argument("--filter_radius", type=int, default=3, help="filter radius")
    parser.add_argument("--resample_sr", type=int, default=0, help="resample sampling rate")
    parser.add_argument("--rms_mix_rate", type=float, default=1, help="RMS mix rate")
    parser.add_argument("--protect", type=float, default=0.33, help="protect value")

    # NOTE(review): the previous revision reset sys.argv to sys.argv[:1] after
    # parsing; if Config() also reads sys.argv, confirm dropping that is safe.
    return parser.parse_args()
29
 
30
 
31
def main():
    """Convert one audio file as described by the command line and save it.

    Loads environment variables, parses the options, builds ``Config`` and
    ``VC``, runs ``vc.vc_single`` and writes the result with
    ``scipy.io.wavfile.write``. Failures are reported on stderr.

    Raises:
        SystemExit: With status 1 if the model name is empty or if the
            conversion or the write fails.
    """
    load_dotenv()
    args = arg_parse()

    # Reject an empty model name before any model objects are built.
    if not args.model_name:
        print("Error: Model name must be provided.", file=sys.stderr)
        sys.exit(1)

    config = Config()
    # Command-line values override the Config values only when provided.
    config.device = args.device if args.device else config.device
    config.is_half = args.is_half if args.is_half else config.is_half

    vc = VC(config)
    vc.get_vc(args.model_name)

    # Process the audio file
    try:
        # NOTE(review): the first return value is named ``sid`` here; confirm
        # against VC.vc_single what it actually carries.
        sid, wav_opt = vc.vc_single(
            0,
            args.input_path,
            args.f0up_key,
            None,
            args.f0method,
            args.index_path,
            None,
            args.index_rate,
            args.filter_radius,
            args.resample_sr,
            args.rms_mix_rate,
            args.protect,
        )

        if sid is None:
            print("Warning: sid is None. Skipping sid-related operations.")
        else:
            print(f"Processed sid: {sid}")

        # wav_opt is forwarded as (rate, data) to wavfile.write; fail with a
        # clear message rather than an opaque write error when it is absent.
        if wav_opt is None or wav_opt[0] is None or wav_opt[1] is None:
            raise RuntimeError("conversion returned no audio data")

        # Save the processed audio
        wavfile.write(args.opt_path, wav_opt[0], wav_opt[1])
        print(f"Output saved to: {args.opt_path}")

    except Exception as e:
        # Top-level boundary of the CLI: report on stderr and exit non-zero.
        print(f"An error occurred during processing: {e}", file=sys.stderr)
        sys.exit(1)
77
 
78
 
79
  if __name__ == "__main__":