Sync from GitHub repo
This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.
- app.py +1 -0
- inference-cli.py +4 -3
- model/cfm.py +4 -2
- model/modules.py +1 -0
- model/trainer.py +7 -2
- model/utils.py +11 -11
- requirements.txt +1 -0
- speech_edit.py +4 -3
app.py
CHANGED

@@ -173,6 +173,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence,
             sway_sampling_coef=sway_sampling_coef,
         )
 
+        generated = generated.to(torch.float32)
         generated = generated[:, ref_audio_len:, :]
         generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
         generated_wave = vocos.decode(generated_mel_spec.cpu())
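The only functional change in app.py is the cast back to float32 before vocoding: under fp16 inference the sampler now returns a half-precision mel, while the Vocos vocoder stays in float32 on CPU. A minimal sketch of that step, with made-up shapes and reference length for illustration (vocos itself is loaded elsewhere in app.py):

```python
import torch
from einops import rearrange

ref_audio_len = 256                                           # hypothetical prompt length in mel frames
generated = torch.randn(1, 812, 100, dtype=torch.float16)     # (batch, frames, n_mels), as the fp16 sampler returns

generated = generated.to(torch.float32)                       # the added line: match the fp32 vocoder
generated = generated[:, ref_audio_len:, :]                   # drop the reference-prompt frames
generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")   # Vocos expects (batch, n_mels, frames)
# generated_wave = vocos.decode(generated_mel_spec.cpu())     # vocos: the pretrained Vocos vocoder from app.py
```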
inference-cli.py
CHANGED

@@ -145,9 +145,9 @@ def load_model(model_cls, model_cfg, ckpt_path,file_vocab):
     else:
         tokenizer="custom"
 
-    print("\nvocab : ",vocab_file,tokenizer)
-    print("tokenizer : ",tokenizer)
-    print("model : ",ckpt_path,"\n")
+    print("\nvocab : ", vocab_file,tokenizer)
+    print("tokenizer : ", tokenizer)
+    print("model : ", ckpt_path,"\n")
 
     vocab_char_map, vocab_size = get_tokenizer(file_vocab, tokenizer)
     model = CFM(

@@ -265,6 +265,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, model,ckpt_file,file_voca
             sway_sampling_coef=sway_sampling_coef,
         )
 
+        generated = generated.to(torch.float32)
         generated = generated[:, ref_audio_len:, :]
         generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
         generated_wave = vocos.decode(generated_mel_spec.cpu())
model/cfm.py
CHANGED

@@ -99,6 +99,8 @@ class CFM(nn.Module):
     ):
         self.eval()
 
+        cond = cond.half()
+
         # raw wave
 
         if cond.ndim == 2:

@@ -175,7 +177,7 @@ class CFM(nn.Module):
         for dur in duration:
             if exists(seed):
                 torch.manual_seed(seed)
-            y0.append(torch.randn(dur, self.num_channels, device = self.device))
+            y0.append(torch.randn(dur, self.num_channels, device = self.device, dtype=step_cond.dtype))
         y0 = pad_sequence(y0, padding_value = 0, batch_first = True)
 
         t_start = 0

@@ -186,7 +188,7 @@ class CFM(nn.Module):
             y0 = (1 - t_start) * y0 + t_start * test_cond
             steps = int(steps * (1 - t_start))
 
-        t = torch.linspace(t_start, 1, steps, device = self.device)
+        t = torch.linspace(t_start, 1, steps, device = self.device, dtype=step_cond.dtype)
         if sway_sampling_coef is not None:
             t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
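The cfm.py changes half-precision the conditioning and create the sampler's noise and timestep grid directly in that dtype, so the whole ODE solve stays in fp16. The sway-sampling line that follows is unchanged but central to how the steps are spent; a standalone sketch of the schedule it produces (the function name and example coefficient are illustrative, not part of the repo; the repo builds the grid in step_cond.dtype):

```python
import torch

def sway_timesteps(steps: int, sway_sampling_coef, t_start: float = 0.0) -> torch.Tensor:
    # Uniform grid on [t_start, 1], as in CFM.sample above (built there in step_cond.dtype).
    t = torch.linspace(t_start, 1, steps)
    if sway_sampling_coef is not None:
        # Sway sampling: a negative coefficient bends the grid toward t = 0,
        # spending more solver steps where the flow field changes fastest.
        t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
    return t

print(sway_timesteps(8, sway_sampling_coef=-1.0))   # points cluster near 0 and spread out toward 1
```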
model/modules.py
CHANGED

@@ -571,5 +571,6 @@ class TimestepEmbedding(nn.Module):
 
     def forward(self, timestep: float['b']):
         time_hidden = self.time_embed(timestep)
+        time_hidden = time_hidden.to(timestep.dtype)
         time = self.time_mlp(time_hidden)  # b d
         return time
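The added line keeps the time embedding in the caller's dtype before it reaches time_mlp, whose weights are halved by model.half() in model/utils.py. A minimal sketch of the mismatch this avoids (the tensors here are stand-ins, not the module's real internals):

```python
import torch

# fp16 weights, as produced by model.half() during checkpoint loading
w = torch.randn(1024, 256, dtype=torch.float16)

timestep = torch.rand(4, dtype=torch.float16)               # fp16 timesteps from the fp16 sampler
time_hidden = torch.randn(4, 256, dtype=torch.float32)      # stand-in for self.time_embed(timestep), computed in fp32

time_hidden = time_hidden.to(timestep.dtype)                # the added cast
assert time_hidden.dtype == w.dtype                         # now safe to feed the half-precision time_mlp
```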
model/trainer.py
CHANGED

@@ -45,7 +45,8 @@ class Trainer:
         wandb_resume_id: str = None,
         last_per_steps = None,
         accelerate_kwargs: dict = dict(),
-        ema_kwargs: dict = dict()
+        ema_kwargs: dict = dict(),
+        bnb_optimizer: bool = False,
     ):
 
         ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)

@@ -107,7 +108,11 @@ class Trainer:
 
         self.duration_predictor = duration_predictor
 
-        self.optimizer = AdamW(model.parameters(), lr=learning_rate)
+        if bnb_optimizer:
+            import bitsandbytes as bnb
+            self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
+        else:
+            self.optimizer = AdamW(model.parameters(), lr=learning_rate)
         self.model, self.optimizer = self.accelerator.prepare(
             self.model, self.optimizer
         )
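The Trainer gains a bnb_optimizer flag and builds its optimizer accordingly; bitsandbytes' 8-bit AdamW keeps the optimizer state in 8-bit, roughly quartering its memory footprint while remaining a drop-in for torch's AdamW. A minimal sketch of the same switch as a free function (the helper name and example learning rate are illustrative; bitsandbytes>0.37.0 is the pin added to requirements.txt below):

```python
import torch
from torch.optim import AdamW

def build_optimizer(model: torch.nn.Module, learning_rate: float, bnb_optimizer: bool = False):
    if bnb_optimizer:
        import bitsandbytes as bnb                       # imported lazily, as in the Trainer
        return bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
    return AdamW(model.parameters(), lr=learning_rate)

# optimizer = build_optimizer(model, learning_rate=1e-4, bnb_optimizer=True)
# model, optimizer = accelerator.prepare(model, optimizer)   # the Trainer then wraps both with accelerate
```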
model/utils.py
CHANGED

@@ -557,23 +557,23 @@ def repetition_found(text, length = 2, tolerance = 10):
 # load model checkpoint for inference
 
 def load_checkpoint(model, ckpt_path, device, use_ema = True):
-
+    model = model.half()
 
     ckpt_type = ckpt_path.split(".")[-1]
     if ckpt_type == "safetensors":
         from safetensors.torch import load_file
-        checkpoint = load_file(ckpt_path, device=device)
+        checkpoint = load_file(ckpt_path)
     else:
-        checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
+        checkpoint = torch.load(ckpt_path, weights_only=True)
 
-    if use_ema == True:
-        ema_model = EMA(model, include_online_model = False).to(device)
+    if use_ema:
         if ckpt_type == "safetensors":
-            ema_model.load_state_dict(checkpoint)
-        else:
-            ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
-        ema_model.copy_params_from_ema_to_model()
+            checkpoint = {'ema_model_state_dict': checkpoint}
+        checkpoint['model_state_dict'] = {k.replace("ema_model.", ""): v for k, v in checkpoint['ema_model_state_dict'].items() if k not in ["initted", "step"]}
+        model.load_state_dict(checkpoint['model_state_dict'])
     else:
+        if ckpt_type == "safetensors":
+            checkpoint = {'model_state_dict': checkpoint}
         model.load_state_dict(checkpoint['model_state_dict'])
-
-    return model
+
+    return model.to(device)
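load_checkpoint now halves the model up front, unpacks EMA checkpoints into plain model weights itself rather than going through the EMA wrapper class, and only moves the model to the target device at the end. A condensed sketch of the new EMA path (the helper name is illustrative; the "ema_model." prefix and the "initted"/"step" bookkeeping keys are as in the diff):

```python
import torch
from safetensors.torch import load_file

def load_ema_weights(model: torch.nn.Module, ckpt_path: str, device: str = "cpu") -> torch.nn.Module:
    model = model.half()                            # fp16 inference, as in the diff
    state = load_file(ckpt_path)                    # flat dict of tensors from a safetensors EMA checkpoint
    # EMA checkpoints store weights under an "ema_model." prefix plus bookkeeping
    # tensors ("initted", "step"); strip both before load_state_dict.
    state = {k.replace("ema_model.", ""): v
             for k, v in state.items()
             if k not in ["initted", "step"]}
    model.load_state_dict(state)
    return model.to(device)                         # device placement deferred to the end
```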
requirements.txt
CHANGED

@@ -1,4 +1,5 @@
 accelerate>=0.33.0
+bitsandbytes>0.37.0
 cached_path
 click
 datasets
speech_edit.py
CHANGED

@@ -49,7 +49,7 @@ elif exp_name == "E2TTS_Base":
     model_cls = UNetT
     model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
 
-ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.
+ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
 output_dir = "tests"
 
 # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]

@@ -172,12 +172,13 @@ with torch.inference_mode():
     print(f"Generated mel: {generated.shape}")
 
     # Final result
+    generated = generated.to(torch.float32)
     generated = generated[:, ref_audio_len:, :]
     generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
     generated_wave = vocos.decode(generated_mel_spec.cpu())
    if rms < target_rms:
         generated_wave = generated_wave * rms / target_rms
 
-    save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/
-    torchaudio.save(f"{output_dir}/
+    save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
+    torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate)
     print(f"Generated wav: {generated_wave.shape}")