Add flash_attn2 support and attn_mask, minor fixes (#1066)
* add flash attn2 support
* update flash attn config in F5TTS
* fix minor bug of getting the length of ref_mel
---------
Co-authored-by: SWivid <[email protected]>
- src/f5_tts/configs/F5TTS_Base.yaml +2 -0
- src/f5_tts/configs/F5TTS_Small.yaml +2 -0
- src/f5_tts/configs/F5TTS_v1_Base.yaml +2 -0
- src/f5_tts/eval/eval_infer_batch.py +7 -2
- src/f5_tts/eval/utils_eval.py +6 -5
- src/f5_tts/model/backbones/dit.py +4 -0
- src/f5_tts/model/cfm.py +2 -3
- src/f5_tts/model/dataset.py +2 -2
- src/f5_tts/model/modules.py +74 -21
- src/f5_tts/model/utils.py +10 -0
src/f5_tts/configs/F5TTS_Base.yaml
CHANGED
@@ -31,6 +31,8 @@ model:
     text_mask_padding: False
     conv_layers: 4
     pe_attn_head: 1
+    attn_backend: torch # torch | flash_attn
+    attn_mask_enabled: False
     checkpoint_activations: False # recompute activations and save memory for extra compute
   mel_spec:
     target_sample_rate: 24000
src/f5_tts/configs/F5TTS_Small.yaml
CHANGED
@@ -31,6 +31,8 @@ model:
     text_mask_padding: False
     conv_layers: 4
     pe_attn_head: 1
+    attn_backend: torch # torch | flash_attn
+    attn_mask_enabled: False
     checkpoint_activations: False # recompute activations and save memory for extra compute
   mel_spec:
     target_sample_rate: 24000
src/f5_tts/configs/F5TTS_v1_Base.yaml
CHANGED
@@ -32,6 +32,8 @@ model:
     qk_norm: null # null | rms_norm
     conv_layers: 4
     pe_attn_head: null
+    attn_backend: torch # torch | flash_attn
+    attn_mask_enabled: False
     checkpoint_activations: False # recompute activations and save memory for extra compute
   mel_spec:
     target_sample_rate: 24000
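The two new keys select the attention backend and whether the padding mask is applied inside attention. A minimal sketch of wiring them into the model, assuming the arch block sits under model.arch and is loaded with OmegaConf as elsewhere in the repo; the extra dims passed below are placeholders, not the released settings:

from omegaconf import OmegaConf

from f5_tts.model.backbones.dit import DiT

cfg = OmegaConf.load("src/f5_tts/configs/F5TTS_v1_Base.yaml")
arch = OmegaConf.to_container(cfg.model.arch, resolve=True)

# arch now carries attn_backend ("torch" | "flash_attn") and attn_mask_enabled,
# which DiT forwards down to every DiTBlock / AttnProcessor (see the dit.py diff below).
model = DiT(**arch, text_num_embeds=256, mel_dim=100)  # placeholder vocab size / mel channels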
src/f5_tts/eval/eval_infer_batch.py
CHANGED
@@ -148,10 +148,15 @@ def main():
         vocab_char_map=vocab_char_map,
     ).to(device)
 
-    ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.safetensors"
-    if not os.path.exists(ckpt_path):
+    ckpt_prefix = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}"
+    if os.path.exists(ckpt_prefix + ".pt"):
+        ckpt_path = ckpt_prefix + ".pt"
+    elif os.path.exists(ckpt_prefix + ".safetensors"):
+        ckpt_path = ckpt_prefix + ".safetensors"
+    else:
         print("Loading from self-organized training checkpoints rather than released pretrained.")
         ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt"
+
     dtype = torch.float32 if mel_spec_type == "bigvgan" else None
     model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
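The checkpoint resolution above now prefers an exact .pt, then .safetensors, before falling back to the self-organized training path. The same logic as a standalone sketch; the helper name is hypothetical, not part of the PR:

import os


def resolve_ckpt_path(ckpt_prefix: str, fallback: str) -> str:
    # Prefer "<prefix>.pt", then "<prefix>.safetensors"; otherwise use the fallback path.
    for ext in (".pt", ".safetensors"):
        if os.path.exists(ckpt_prefix + ext):
            return ckpt_prefix + ext
    print("Loading from self-organized training checkpoints rather than released pretrained.")
    return fallback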
src/f5_tts/eval/utils_eval.py
CHANGED
@@ -126,8 +126,13 @@ def get_inference_prompt(
         else:
             text_list = text
 
+        # to mel spectrogram
+        ref_mel = mel_spectrogram(ref_audio)
+        ref_mel = ref_mel.squeeze(0)
+
         # Duration, mel frame length
-        ref_mel_len = ref_audio.shape[-1] // hop_length
+        ref_mel_len = ref_mel.shape[-1]
+
         if use_truth_duration:
             gt_audio, gt_sr = torchaudio.load(gt_wav)
             if gt_sr != target_sample_rate:
@@ -142,10 +147,6 @@ def get_inference_prompt(
             gen_text_len = len(gt_text.encode("utf-8"))
             total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
 
-        # to mel spectrogram
-        ref_mel = mel_spectrogram(ref_audio)
-        ref_mel = ref_mel.squeeze(0)
-
         # deal with batch
         assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
         assert min_tokens <= total_mel_len <= max_tokens, (
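The fix takes ref_mel_len from the mel that is actually fed to the model instead of deriving it from the raw audio length, so the estimate can no longer drift from the true frame count. A small illustration of the discrepancy, assuming a center-padded STFT front end with hop length 256; torchaudio is used purely for illustration and the repo's own mel_spectrogram may differ in detail:

import torch
import torchaudio

hop_length = 256
ref_audio = torch.randn(1, 72100)  # arbitrary length, not a multiple of hop_length

mel = torchaudio.transforms.MelSpectrogram(
    sample_rate=24000, n_fft=1024, hop_length=hop_length, n_mels=100
)(ref_audio)

estimated_len = ref_audio.shape[-1] // hop_length  # length derived from the waveform
actual_len = mel.shape[-1]                         # length of the mel actually used
print(estimated_len, actual_len)                   # 281 vs 282 with center padding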
src/f5_tts/model/backbones/dit.py
CHANGED
@@ -116,6 +116,8 @@ class DiT(nn.Module):
         qk_norm=None,
         conv_layers=0,
         pe_attn_head=None,
+        attn_backend="torch",  # "torch" | "flash_attn"
+        attn_mask_enabled=False,
         long_skip_connection=False,
         checkpoint_activations=False,
     ):
@@ -145,6 +147,8 @@ class DiT(nn.Module):
                     dropout=dropout,
                     qk_norm=qk_norm,
                     pe_attn_head=pe_attn_head,
+                    attn_backend=attn_backend,
+                    attn_mask_enabled=attn_mask_enabled,
                 )
                 for _ in range(depth)
             ]
src/f5_tts/model/cfm.py
CHANGED
@@ -275,10 +275,9 @@ class CFM(nn.Module):
         else:
             drop_text = False
 
-        # if want rigorously mask out padding, record in collate_fn in dataset.py, and pass in here
-        # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
+        # apply mask will use more memory; might adjust batchsize or batchsampler long sequence threshold
         pred = self.transformer(
-            x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text
+            x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, mask=mask
         )
 
         # flow matching loss
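With the mask now passed through to the transformer, padded frames in a training batch can be excluded from attention. A generic sketch of how such a boolean "b n" mask is built from per-sample mel lengths; the function below is illustrative, not necessarily the repo's helper:

import torch


def lengths_to_mask(lengths: torch.Tensor, max_len: int | None = None) -> torch.Tensor:
    # True marks valid frames, False marks padding; same convention as the mask above.
    max_len = int(lengths.max()) if max_len is None else max_len
    seq = torch.arange(max_len, device=lengths.device)  # (n,)
    return seq[None, :] < lengths[:, None]  # (b, n) bool


mask = lengths_to_mask(torch.tensor([3, 5]))
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])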
src/f5_tts/model/dataset.py
CHANGED
@@ -312,7 +312,7 @@ def collate_fn(batch):
     max_mel_length = mel_lengths.amax()
 
     padded_mel_specs = []
-    for spec in mel_specs:  # TODO. maybe records mask for attention here
+    for spec in mel_specs:
         padding = (0, max_mel_length - spec.size(-1))
         padded_spec = F.pad(spec, padding, value=0)
         padded_mel_specs.append(padded_spec)
@@ -324,7 +324,7 @@ def collate_fn(batch):
 
     return dict(
         mel=mel_specs,
-        mel_lengths=mel_lengths,
+        mel_lengths=mel_lengths,  # records for padding mask
         text=text,
         text_lengths=text_lengths,
     )
|
src/f5_tts/model/modules.py
CHANGED
@@ -6,6 +6,7 @@ nt - text sequence
 nw - raw wave length
 d - dimension
 """
+# flake8: noqa
 
 from __future__ import annotations
 
@@ -19,6 +20,8 @@ from librosa.filters import mel as librosa_mel_fn
 from torch import nn
 from x_transformers.x_transformers import apply_rotary_pos_emb
 
+from f5_tts.model.utils import is_package_available
+
 
 # raw wav to mel spec
 
@@ -175,7 +178,7 @@ class ConvPositionEmbedding(nn.Module):
             nn.Mish(),
         )
 
-    def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):  # noqa: F722
+    def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):
         if mask is not None:
             mask = mask[..., None]
             x = x.masked_fill(~mask, 0.0)
@@ -417,9 +420,9 @@ class Attention(nn.Module):
 
     def forward(
         self,
-        x: float["b n d"],  # noised input x  # noqa: F722
-        c: float["b n d"] = None,  # context c  # noqa: F722
-        mask: bool["b n"] | None = None,  # noqa: F722
+        x: float["b n d"],  # noised input x
+        c: float["b n d"] = None,  # context c
+        mask: bool["b n"] | None = None,
         rope=None,  # rotary position embedding for x
         c_rope=None,  # rotary position embedding for c
     ) -> torch.Tensor:
@@ -431,19 +434,30 @@
 
 # Attention processor
 
+if is_package_available("flash_attn"):
+    from flash_attn.bert_padding import pad_input, unpad_input
+    from flash_attn import flash_attn_varlen_func, flash_attn_func
+
 
 class AttnProcessor:
     def __init__(
         self,
         pe_attn_head: int | None = None,  # number of attention head to apply rope, None for all
+        attn_backend: str = "flash_attn",
+        attn_mask_enabled: bool = True,
     ):
+        if attn_backend == "flash_attn":
+            assert is_package_available("flash_attn"), "Please install flash-attn first."
+
         self.pe_attn_head = pe_attn_head
+        self.attn_backend = attn_backend
+        self.attn_mask_enabled = attn_mask_enabled
 
     def __call__(
        self,
         attn: Attention,
-        x: float["b n d"],  # noised input x  # noqa: F722
-        mask: bool["b n"] | None = None,  # noqa: F722
+        x: float["b n d"],  # noised input x
+        mask: bool["b n"] | None = None,
         rope=None,  # rotary position embedding
     ) -> torch.FloatTensor:
         batch_size = x.shape[0]
@@ -479,16 +493,40 @@ class AttnProcessor:
             query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
             key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
 
-        # mask. e.g. inference got a batch with different target durations, mask out the padding
-        if mask is not None:
-            attn_mask = mask
-            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
-            attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
-        else:
-            attn_mask = None
+        if self.attn_backend == "torch":
+            # mask. e.g. inference got a batch with different target durations, mask out the padding
+            if self.attn_mask_enabled and mask is not None:
+                attn_mask = mask
+                attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
+                attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
+            else:
+                attn_mask = None
+            x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
+            x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+        elif self.attn_backend == "flash_attn":
+            query = query.transpose(1, 2)  # [b, h, n, d] -> [b, n, h, d]
+            key = key.transpose(1, 2)
+            value = value.transpose(1, 2)
+            if self.attn_mask_enabled and mask is not None:
+                query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(query, mask)
+                key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask)
+                value, _, _, _, _ = unpad_input(value, mask)
+                x = flash_attn_varlen_func(
+                    query,
+                    key,
+                    value,
+                    q_cu_seqlens,
+                    k_cu_seqlens,
+                    q_max_seqlen_in_batch,
+                    k_max_seqlen_in_batch,
+                )
+                x = pad_input(x, indices, batch_size, q_max_seqlen_in_batch)
+                x = x.reshape(batch_size, -1, attn.heads * head_dim)
+            else:
+                x = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
+                x = x.reshape(batch_size, -1, attn.heads * head_dim)
 
-        x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
-        x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         x = x.to(query.dtype)
 
         # linear proj
@@ -514,9 +552,9 @@ class JointAttnProcessor:
     def __call__(
         self,
         attn: Attention,
-        x: float["b n d"],  # noised input x  # noqa: F722
-        c: float["b nt d"] = None,  # context c, here text  # noqa: F722
-        mask: bool["b n"] | None = None,  # noqa: F722
+        x: float["b n d"],  # noised input x
+        c: float["b nt d"] = None,  # context c, here text
+        mask: bool["b n"] | None = None,
         rope=None,  # rotary position embedding for x
         c_rope=None,  # rotary position embedding for c
     ) -> torch.FloatTensor:
@@ -608,12 +646,27 @@
 
 
 class DiTBlock(nn.Module):
-    def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, qk_norm=None, pe_attn_head=None):
+    def __init__(
+        self,
+        dim,
+        heads,
+        dim_head,
+        ff_mult=4,
+        dropout=0.1,
+        qk_norm=None,
+        pe_attn_head=None,
+        attn_backend="flash_attn",
+        attn_mask_enabled=True,
+    ):
         super().__init__()
 
         self.attn_norm = AdaLayerNorm(dim)
         self.attn = Attention(
-            processor=AttnProcessor(pe_attn_head=pe_attn_head),
+            processor=AttnProcessor(
+                pe_attn_head=pe_attn_head,
+                attn_backend=attn_backend,
+                attn_mask_enabled=attn_mask_enabled,
+            ),
             dim=dim,
             heads=heads,
             dim_head=dim_head,
@@ -724,7 +777,7 @@ class TimestepEmbedding(nn.Module):
         self.time_embed = SinusPositionEmbedding(freq_embed_dim)
         self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
 
-    def forward(self, timestep: float["b"]):  # noqa: F821
+    def forward(self, timestep: float["b"]):
         time_hidden = self.time_embed(timestep)
         time_hidden = time_hidden.to(timestep.dtype)
         time = self.time_mlp(time_hidden)  # b d
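Under the torch backend, the "b n" padding mask is broadcast to (b, heads, n_q, n_k) so that no query attends to padded key positions; under the flash_attn backend the same effect comes from unpadding to variable-length sequences for flash_attn_varlen_func, which requires CUDA and fp16/bf16 inputs. A minimal sketch of the torch-side masking convention, independent of the repo:

import torch
import torch.nn.functional as F

b, h, n, d = 2, 4, 6, 8
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

lengths = torch.tensor([6, 4])                        # second sample has 2 padded frames
mask = torch.arange(n)[None, :] < lengths[:, None]    # (b, n) bool, True = valid

attn_mask = mask[:, None, None, :].expand(b, h, n, n)  # 'b n -> b 1 1 n', then expand
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)

# Padded key positions receive zero attention weight; rows belonging to padded queries
# are still computed and are expected to be masked out further downstream.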
src/f5_tts/model/utils.py
CHANGED
@@ -35,6 +35,16 @@ def default(v, d):
     return v if exists(v) else d
 
 
+def is_package_available(package_name: str) -> bool:
+    try:
+        import importlib
+
+        package_exists = importlib.util.find_spec(package_name) is not None
+        return package_exists
+    except Exception:
+        return False
+
+
 # tensor helpers
 
 
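is_package_available keeps flash_attn an optional dependency: the import block in modules.py only runs when the package is installed, and the assertion in AttnProcessor gives a clear error otherwise. A small usage sketch:

from f5_tts.model.utils import is_package_available

# Fall back to the built-in torch backend when flash-attn is not installed.
attn_backend = "flash_attn" if is_package_available("flash_attn") else "torch"
print(f"attention backend: {attn_backend}")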