Spaces:
Sleeping
Sleeping
Update infer/utils_infer.py
Browse files- infer/utils_infer.py +9 -2
infer/utils_infer.py
CHANGED
@@ -101,6 +101,7 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
|
|
101 |
repo_id = "charactr/vocos-mel-24khz"
|
102 |
config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
|
103 |
model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
|
|
|
104 |
vocoder = Vocos.from_hparams(config_path)
|
105 |
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
|
106 |
from vocos.feature_extractors import EncodecFeatures
|
@@ -111,13 +112,18 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
|
|
111 |
for key, value in vocoder.feature_extractor.encodec.state_dict().items()
|
112 |
}
|
113 |
state_dict.update(encodec_parameters)
|
|
|
114 |
vocoder.load_state_dict(state_dict)
|
115 |
-
|
|
|
|
|
|
|
116 |
elif vocoder_name == "bigvgan":
|
117 |
try:
|
118 |
from third_party.BigVGAN import bigvgan
|
119 |
except ImportError:
|
120 |
print("You need to follow the README to init submodule and change the BigVGAN source code.")
|
|
|
121 |
if is_local:
|
122 |
"""download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
|
123 |
vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
|
@@ -126,7 +132,8 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
|
|
126 |
vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
|
127 |
|
128 |
vocoder.remove_weight_norm()
|
129 |
-
vocoder = vocoder.eval().to(device)
|
|
|
130 |
return vocoder
|
131 |
|
132 |
|
|
|
101 |
repo_id = "charactr/vocos-mel-24khz"
|
102 |
config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
|
103 |
model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
|
104 |
+
|
105 |
vocoder = Vocos.from_hparams(config_path)
|
106 |
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
|
107 |
from vocos.feature_extractors import EncodecFeatures
|
|
|
112 |
for key, value in vocoder.feature_extractor.encodec.state_dict().items()
|
113 |
}
|
114 |
state_dict.update(encodec_parameters)
|
115 |
+
|
116 |
vocoder.load_state_dict(state_dict)
|
117 |
+
|
118 |
+
# Convert vocoder to bfloat16 if using a compatible device
|
119 |
+
vocoder = vocoder.eval().to(device).to(torch.bfloat16)
|
120 |
+
|
121 |
elif vocoder_name == "bigvgan":
|
122 |
try:
|
123 |
from third_party.BigVGAN import bigvgan
|
124 |
except ImportError:
|
125 |
print("You need to follow the README to init submodule and change the BigVGAN source code.")
|
126 |
+
|
127 |
if is_local:
|
128 |
"""download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
|
129 |
vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
|
|
|
132 |
vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
|
133 |
|
134 |
vocoder.remove_weight_norm()
|
135 |
+
vocoder = vocoder.eval().to(device).to(torch.bfloat16) # Convert to bfloat16
|
136 |
+
|
137 |
return vocoder
|
138 |
|
139 |
|