Commit 68b4ce0 by SWivid · Parent: 9395289
README.md CHANGED
@@ -92,7 +92,7 @@ First, make sure you have the dependencies installed (`pip install -r requirements_gradio.txt`)
 pip install -r requirements_gradio.txt
 ```
 
-After installing the dependencies, launch the app:
+After installing the dependencies, launch the app (will load ckpt from Huggingface, you may set `ckpt_path` to local file in `gradio_app.py`):
 
 ```bash
 python gradio_app.py
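
The local-checkpoint override mentioned in the updated README line corresponds to the `ckpt_path` lines added to `load_model` in `gradio_app.py` below. As a minimal sketch of the two sources (the on-disk path is only illustrative, taken from the commented-out line in the diff):

```python
# Sketch of the two checkpoint sources the updated README refers to, mirroring the
# ckpt_path lines added to load_model() in gradio_app.py below.
from cached_path import cached_path

exp_name, ckpt_step = "F5TTS_Base", 1200000  # values used by the app

# Default: download from the Hugging Face Hub and cache locally.
ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))

# Alternative: point at a checkpoint already on disk (illustrative path).
# ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors
```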
gradio_app.py CHANGED
@@ -6,12 +6,12 @@ import gradio as gr
 import numpy as np
 import tempfile
 from einops import rearrange
-from ema_pytorch import EMA
 from vocos import Vocos
 from pydub import AudioSegment
 from model import CFM, UNetT, DiT, MMDiT
 from cached_path import cached_path
 from model.utils import (
+    load_checkpoint,
     get_tokenizer,
     convert_char_to_pinyin,
     save_spectrogram,
@@ -51,10 +51,8 @@ fix_duration = None
 
 
 def load_model(exp_name, model_cls, model_cfg, ckpt_step):
-    checkpoint = torch.load(
-        str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.pt")),
-        map_location=device,
-    )
+    ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
+    # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors
     vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
     model = CFM(
         transformer=model_cls(
@@ -71,11 +69,9 @@ def load_model(exp_name, model_cls, model_cfg, ckpt_step):
         vocab_char_map=vocab_char_map,
     ).to(device)
 
-    ema_model = EMA(model, include_online_model=False).to(device)
-    ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
-    ema_model.copy_params_from_ema_to_model()
+    model = load_checkpoint(model, ckpt_path, device, use_ema = True)
 
-    return ema_model, model
+    return model
 
 
 # load models
@@ -84,10 +80,10 @@ F5TTS_model_cfg = dict(
 )
 E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
-F5TTS_ema_model, F5TTS_base_model = load_model(
+F5TTS_ema_model = load_model(
     "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
 )
-E2TTS_ema_model, E2TTS_base_model = load_model(
+E2TTS_ema_model = load_model(
     "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
 )
 
@@ -107,10 +103,8 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence):
         ref_audio = f.name
     if exp_name == "F5-TTS":
         ema_model = F5TTS_ema_model
-        base_model = F5TTS_base_model
     elif exp_name == "E2-TTS":
         ema_model = E2TTS_ema_model
-        base_model = E2TTS_base_model
 
     if not ref_text.strip():
         gr.Info("No reference text provided, transcribing reference audio...")
@@ -151,7 +145,7 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence):
     # inference
     gr.Info(f"Generating audio using {exp_name}")
     with torch.inference_mode():
-        generated, _ = base_model.sample(
+        generated, _ = ema_model.sample(
            cond=audio,
            text=final_text_list,
            duration=duration,
@@ -243,7 +237,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
 
 
 @click.command()
-@click.option("--port", "-p", default=None, help="Port to run the app on")
+@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
 @click.option("--host", "-H", default=None, help="Host to run the app on")
 @click.option(
     "--share",
requirements_gradio.txt CHANGED
@@ -1,3 +1,4 @@
 cached_path
-pydub
-click
+click
+gradio
+pydub
test_infer_single_edit.py CHANGED
@@ -14,7 +14,7 @@ from model.utils import (
     save_spectrogram,
 )
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
+device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
 
 # --------------------- Dataset Settings -------------------- #