Gregniuki committed on
Commit
d1df1d2
·
verified ·
1 Parent(s): 0d50905

Update infer/utils_infer.py

Browse files
Files changed (1) hide show
  1. infer/utils_infer.py +9 -2
infer/utils_infer.py CHANGED
@@ -101,6 +101,7 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
101
  repo_id = "charactr/vocos-mel-24khz"
102
  config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
103
  model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
 
104
  vocoder = Vocos.from_hparams(config_path)
105
  state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
106
  from vocos.feature_extractors import EncodecFeatures
@@ -111,13 +112,18 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
111
  for key, value in vocoder.feature_extractor.encodec.state_dict().items()
112
  }
113
  state_dict.update(encodec_parameters)
 
114
  vocoder.load_state_dict(state_dict)
115
- vocoder = vocoder.eval().to(device)
 
 
 
116
  elif vocoder_name == "bigvgan":
117
  try:
118
  from third_party.BigVGAN import bigvgan
119
  except ImportError:
120
  print("You need to follow the README to init submodule and change the BigVGAN source code.")
 
121
  if is_local:
122
  """download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
123
  vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
@@ -126,7 +132,8 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
126
  vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
127
 
128
  vocoder.remove_weight_norm()
129
- vocoder = vocoder.eval().to(device)
 
130
  return vocoder
131
 
132
 
 
101
  repo_id = "charactr/vocos-mel-24khz"
102
  config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
103
  model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
104
+
105
  vocoder = Vocos.from_hparams(config_path)
106
  state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
107
  from vocos.feature_extractors import EncodecFeatures
 
112
  for key, value in vocoder.feature_extractor.encodec.state_dict().items()
113
  }
114
  state_dict.update(encodec_parameters)
115
+
116
  vocoder.load_state_dict(state_dict)
117
+
118
+ # Convert vocoder to bfloat16 if using a compatible device
119
+ vocoder = vocoder.eval().to(device).to(torch.bfloat16)
120
+
121
  elif vocoder_name == "bigvgan":
122
  try:
123
  from third_party.BigVGAN import bigvgan
124
  except ImportError:
125
  print("You need to follow the README to init submodule and change the BigVGAN source code.")
126
+
127
  if is_local:
128
  """download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
129
  vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
 
132
  vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
133
 
134
  vocoder.remove_weight_norm()
135
+ vocoder = vocoder.eval().to(device).to(torch.bfloat16) # Convert to bfloat16
136
+
137
  return vocoder
138
 
139