Jmica commited on
Commit
929b5ae
·
1 Parent(s): d1d8139

run pre-commit

Browse files
src/f5_tts/api.py CHANGED
@@ -56,14 +56,20 @@ class F5TTS:
56
  if model_type == "F5-TTS":
57
  if not ckpt_file:
58
  if mel_spec_type == "vocos":
59
- ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=local_path))
 
 
60
  elif mel_spec_type == "bigvgan":
61
- ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=local_path))
 
 
62
  model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
63
  model_cls = DiT
64
  elif model_type == "E2-TTS":
65
  if not ckpt_file:
66
- ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=local_path))
 
 
67
  model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
68
  model_cls = UNetT
69
  else:
 
56
  if model_type == "F5-TTS":
57
  if not ckpt_file:
58
  if mel_spec_type == "vocos":
59
+ ckpt_file = str(
60
+ cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=local_path)
61
+ )
62
  elif mel_spec_type == "bigvgan":
63
+ ckpt_file = str(
64
+ cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=local_path)
65
+ )
66
  model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
67
  model_cls = DiT
68
  elif model_type == "E2-TTS":
69
  if not ckpt_file:
70
+ ckpt_file = str(
71
+ cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=local_path)
72
+ )
73
  model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
74
  model_cls = UNetT
75
  else:
src/f5_tts/infer/utils_infer.py CHANGED
@@ -96,8 +96,12 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
96
  print(f"Load vocos from local path {local_path}")
97
  repo_id = "charactr/vocos-mel-24khz"
98
  revision = None
99
- config_path = hf_hub_download(repo_id=repo_id, cache_dir=local_path, filename="config.yaml", revision=revision)
100
- model_path = hf_hub_download(repo_id=repo_id, cache_dir=local_path, filename="pytorch_model.bin", revision=revision)
 
 
 
 
101
  vocoder = Vocos.from_hparams(config_path=config_path)
102
  state_dict = torch.load(model_path, map_location="cpu")
103
  vocoder.load_state_dict(state_dict)
 
96
  print(f"Load vocos from local path {local_path}")
97
  repo_id = "charactr/vocos-mel-24khz"
98
  revision = None
99
+ config_path = hf_hub_download(
100
+ repo_id=repo_id, cache_dir=local_path, filename="config.yaml", revision=revision
101
+ )
102
+ model_path = hf_hub_download(
103
+ repo_id=repo_id, cache_dir=local_path, filename="pytorch_model.bin", revision=revision
104
+ )
105
  vocoder = Vocos.from_hparams(config_path=config_path)
106
  state_dict = torch.load(model_path, map_location="cpu")
107
  vocoder.load_state_dict(state_dict)