Zhikang Niu, SWivid committed
Commit 1d923b1 · 1 Parent(s): 7f3b802

Add flash_attn2 support for attn_mask, minor fixes (#1066)


* add flash_attn2 support
* update flash_attn config in the F5TTS configs
* fix a minor bug in getting the length of ref_mel

---------

Co-authored-by: SWivid <[email protected]>

src/f5_tts/configs/F5TTS_Base.yaml CHANGED
@@ -31,6 +31,8 @@ model:
     text_mask_padding: False
     conv_layers: 4
     pe_attn_head: 1
+    attn_backend: torch  # torch | flash_attn
+    attn_mask_enabled: False
     checkpoint_activations: False  # recompute activations and save memory for extra compute
   mel_spec:
     target_sample_rate: 24000
src/f5_tts/configs/F5TTS_Small.yaml CHANGED
@@ -31,6 +31,8 @@ model:
     text_mask_padding: False
     conv_layers: 4
     pe_attn_head: 1
+    attn_backend: torch  # torch | flash_attn
+    attn_mask_enabled: False
     checkpoint_activations: False  # recompute activations and save memory for extra compute
   mel_spec:
     target_sample_rate: 24000
src/f5_tts/configs/F5TTS_v1_Base.yaml CHANGED
@@ -32,6 +32,8 @@ model:
     qk_norm: null  # null | rms_norm
     conv_layers: 4
     pe_attn_head: null
+    attn_backend: torch  # torch | flash_attn
+    attn_mask_enabled: False
     checkpoint_activations: False  # recompute activations and save memory for extra compute
   mel_spec:
     target_sample_rate: 24000
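For context, a minimal sketch of how the two new keys could be read from one of these configs at runtime. The OmegaConf call and the model.arch key path are assumptions based on the repo's Hydra-style configs, not part of this diff.

# Sketch only: reading the new attention options from a config (key path assumed).
from omegaconf import OmegaConf

cfg = OmegaConf.load("src/f5_tts/configs/F5TTS_v1_Base.yaml")
attn_backend = cfg.model.arch.attn_backend            # "torch" (default) or "flash_attn"
attn_mask_enabled = cfg.model.arch.attn_mask_enabled  # False by default
assert attn_backend in ("torch", "flash_attn")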
src/f5_tts/eval/eval_infer_batch.py CHANGED
@@ -148,10 +148,15 @@ def main():
         vocab_char_map=vocab_char_map,
     ).to(device)
 
-    ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
-    if not os.path.exists(ckpt_path):
+    ckpt_prefix = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}"
+    if os.path.exists(ckpt_prefix + ".pt"):
+        ckpt_path = ckpt_prefix + ".pt"
+    elif os.path.exists(ckpt_prefix + ".safetensors"):
+        ckpt_path = ckpt_prefix + ".safetensors"
+    else:
         print("Loading from self-organized training checkpoints rather than released pretrained.")
         ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt"
+
     dtype = torch.float32 if mel_spec_type == "bigvgan" else None
     model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
 
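The new resolution order (released .pt, then .safetensors, then the local training directory) could also be expressed as a standalone helper; resolve_ckpt_path is a hypothetical name used only for this sketch.

import os

def resolve_ckpt_path(rel_path: str, exp_name: str, ckpt_step: int, save_dir: str) -> str:
    # Prefer a released checkpoint under ckpts/<exp_name>/, in .pt then .safetensors form.
    prefix = f"{rel_path}/ckpts/{exp_name}/model_{ckpt_step}"
    for ext in (".pt", ".safetensors"):
        if os.path.exists(prefix + ext):
            return prefix + ext
    # Otherwise fall back to a self-organized training checkpoint.
    print("Loading from self-organized training checkpoints rather than released pretrained.")
    return f"{rel_path}/{save_dir}/model_{ckpt_step}.pt"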
src/f5_tts/eval/utils_eval.py CHANGED
@@ -126,8 +126,13 @@ def get_inference_prompt(
         else:
             text_list = text
 
+        # to mel spectrogram
+        ref_mel = mel_spectrogram(ref_audio)
+        ref_mel = ref_mel.squeeze(0)
+
         # Duration, mel frame length
-        ref_mel_len = ref_audio.shape[-1] // hop_length
+        ref_mel_len = ref_mel.shape[-1]
+
         if use_truth_duration:
             gt_audio, gt_sr = torchaudio.load(gt_wav)
             if gt_sr != target_sample_rate:
@@ -142,10 +147,6 @@ def get_inference_prompt(
             gen_text_len = len(gt_text.encode("utf-8"))
             total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
 
-        # to mel spectrogram
-        ref_mel = mel_spectrogram(ref_audio)
-        ref_mel = ref_mel.squeeze(0)
-
         # deal with batch
         assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
         assert min_tokens <= total_mel_len <= max_tokens, (
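The "minor bug" being fixed: estimating the reference length as ref_audio.shape[-1] // hop_length can disagree with the number of frames the mel transform actually produces (a centered STFT adds a frame). A quick check of the mismatch, using torchaudio with illustrative parameters rather than the repo's own mel_spectrogram module:

import torch
import torchaudio

sr, n_fft, hop_length = 24000, 1024, 256
mel_spectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=sr, n_fft=n_fft, hop_length=hop_length, n_mels=100, center=True
)
ref_audio = torch.randn(1, 3 * sr + 123)        # odd-length clip on purpose
est_len = ref_audio.shape[-1] // hop_length     # old estimate: 281
ref_mel = mel_spectrogram(ref_audio).squeeze(0)
ref_mel_len = ref_mel.shape[-1]                 # actual frame count here: 282
print(est_len, ref_mel_len)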
src/f5_tts/model/backbones/dit.py CHANGED
@@ -116,6 +116,8 @@ class DiT(nn.Module):
         qk_norm=None,
         conv_layers=0,
         pe_attn_head=None,
+        attn_backend="torch",  # "torch" | "flash_attn"
+        attn_mask_enabled=False,
         long_skip_connection=False,
         checkpoint_activations=False,
     ):
@@ -145,6 +147,8 @@ class DiT(nn.Module):
                     dropout=dropout,
                     qk_norm=qk_norm,
                     pe_attn_head=pe_attn_head,
+                    attn_backend=attn_backend,
+                    attn_mask_enabled=attn_mask_enabled,
                 )
                 for _ in range(depth)
             ]
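Constructing the backbone with the new options; the values below are illustrative and the remaining arguments are assumed to keep their dit.py defaults. The flash_attn backend requires the flash-attn package (AttnProcessor asserts it is installed).

from f5_tts.model.backbones.dit import DiT

model = DiT(
    dim=1024,
    depth=22,
    heads=16,
    dim_head=64,
    conv_layers=4,
    attn_backend="flash_attn",  # or "torch" (the config default)
    attn_mask_enabled=True,     # pass the padding mask through to the attention kernel
)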
src/f5_tts/model/cfm.py CHANGED
@@ -275,10 +275,9 @@ class CFM(nn.Module):
         else:
             drop_text = False
 
-        # if want rigorously mask out padding, record in collate_fn in dataset.py, and pass in here
-        # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
+        # applying the mask uses more memory; may need to reduce batch size or the batchsampler's long-sequence threshold
         pred = self.transformer(
-            x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text
+            x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, mask=mask
         )
 
         # flow matching loss
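The mask now forwarded to the transformer is the per-sample padding mask built from mel lengths. A minimal sketch of that construction, mirroring the lens_to_mask helper in f5_tts.model.utils (the CFM.forward call site is only paraphrased in the trailing comment):

import torch

def lens_to_mask(lens: torch.Tensor, length: int | None = None) -> torch.Tensor:
    # True for valid frames, False for padded positions.
    length = int(lens.amax()) if length is None else length
    seq = torch.arange(length, device=lens.device)
    return seq[None, :] < lens[:, None]

mel_lengths = torch.tensor([812, 640, 770])
mask = lens_to_mask(mel_lengths)  # bool tensor of shape [b, n]
# pred = self.transformer(x=φ, cond=cond, text=text, time=time, ..., mask=mask)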
src/f5_tts/model/dataset.py CHANGED
@@ -312,7 +312,7 @@ def collate_fn(batch):
     max_mel_length = mel_lengths.amax()
 
     padded_mel_specs = []
-    for spec in mel_specs:  # TODO. maybe records mask for attention here
+    for spec in mel_specs:
         padding = (0, max_mel_length - spec.size(-1))
         padded_spec = F.pad(spec, padding, value=0)
         padded_mel_specs.append(padded_spec)
@@ -324,7 +324,7 @@ def collate_fn(batch):
 
     return dict(
         mel=mel_specs,
-        mel_lengths=mel_lengths,
+        mel_lengths=mel_lengths,  # records for padding mask
         text=text,
         text_lengths=text_lengths,
    )
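A small sketch of what the collate output now carries; mel_lengths is what CFM uses to build the padding mask shown above. The per-item key names ("mel_spec", "text") and shapes are assumptions about the dataset's item format, used only to make the example self-contained.

import torch
from torch.utils.data import DataLoader
from f5_tts.model.dataset import collate_fn

items = [
    {"mel_spec": torch.randn(100, 812), "text": list("hello")},
    {"mel_spec": torch.randn(100, 640), "text": list("hi")},
]
batch = next(iter(DataLoader(items, batch_size=2, collate_fn=collate_fn)))
print(batch["mel"].shape)     # padded to the longest item in the batch
print(batch["mel_lengths"])   # per-sample frame counts, later turned into the padding mask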
src/f5_tts/model/modules.py CHANGED
@@ -6,6 +6,7 @@ nt - text sequence
 nw - raw wave length
 d - dimension
 """
+# flake8: noqa
 
 from __future__ import annotations
 
@@ -19,6 +20,8 @@ from librosa.filters import mel as librosa_mel_fn
 from torch import nn
 from x_transformers.x_transformers import apply_rotary_pos_emb
 
+from f5_tts.model.utils import is_package_available
+
 
 # raw wav to mel spec
 
@@ -175,7 +178,7 @@ class ConvPositionEmbedding(nn.Module):
             nn.Mish(),
         )
 
-    def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):  # noqa: F722
+    def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):
         if mask is not None:
             mask = mask[..., None]
             x = x.masked_fill(~mask, 0.0)
@@ -417,9 +420,9 @@ class Attention(nn.Module):
 
     def forward(
         self,
-        x: float["b n d"],  # noised input x  # noqa: F722
-        c: float["b n d"] = None,  # context c  # noqa: F722
-        mask: bool["b n"] | None = None,  # noqa: F722
+        x: float["b n d"],  # noised input x
+        c: float["b n d"] = None,  # context c
+        mask: bool["b n"] | None = None,
         rope=None,  # rotary position embedding for x
        c_rope=None,  # rotary position embedding for c
     ) -> torch.Tensor:
@@ -431,19 +434,30 @@ class Attention(nn.Module):
 
 # Attention processor
 
+if is_package_available("flash_attn"):
+    from flash_attn.bert_padding import pad_input, unpad_input
+    from flash_attn import flash_attn_varlen_func, flash_attn_func
+
 
 class AttnProcessor:
     def __init__(
         self,
         pe_attn_head: int | None = None,  # number of attention head to apply rope, None for all
+        attn_backend: str = "flash_attn",
+        attn_mask_enabled: bool = True,
     ):
+        if attn_backend == "flash_attn":
+            assert is_package_available("flash_attn"), "Please install flash-attn first."
+
         self.pe_attn_head = pe_attn_head
+        self.attn_backend = attn_backend
+        self.attn_mask_enabled = attn_mask_enabled
 
     def __call__(
         self,
         attn: Attention,
-        x: float["b n d"],  # noised input x  # noqa: F722
-        mask: bool["b n"] | None = None,  # noqa: F722
+        x: float["b n d"],  # noised input x
+        mask: bool["b n"] | None = None,
         rope=None,  # rotary position embedding
     ) -> torch.FloatTensor:
         batch_size = x.shape[0]
@@ -479,16 +493,40 @@ class AttnProcessor:
         query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
         key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
 
-        # mask. e.g. inference got a batch with different target durations, mask out the padding
-        if mask is not None:
-            attn_mask = mask
-            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
-            attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
-        else:
-            attn_mask = None
-
-        x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
-        x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        if self.attn_backend == "torch":
+            # mask. e.g. inference got a batch with different target durations, mask out the padding
+            if self.attn_mask_enabled and mask is not None:
+                attn_mask = mask
+                attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
+                attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
+            else:
+                attn_mask = None
+            x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
+            x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+        elif self.attn_backend == "flash_attn":
+            query = query.transpose(1, 2)  # [b, h, n, d] -> [b, n, h, d]
+            key = key.transpose(1, 2)
+            value = value.transpose(1, 2)
+            if self.attn_mask_enabled and mask is not None:
+                query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(query, mask)
+                key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask)
+                value, _, _, _, _ = unpad_input(value, mask)
+                x = flash_attn_varlen_func(
+                    query,
+                    key,
+                    value,
+                    q_cu_seqlens,
+                    k_cu_seqlens,
+                    q_max_seqlen_in_batch,
+                    k_max_seqlen_in_batch,
+                )
+                x = pad_input(x, indices, batch_size, q_max_seqlen_in_batch)
+                x = x.reshape(batch_size, -1, attn.heads * head_dim)
+            else:
+                x = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
+                x = x.reshape(batch_size, -1, attn.heads * head_dim)
 
         x = x.to(query.dtype)
 
         # linear proj
@@ -514,9 +552,9 @@ class JointAttnProcessor:
     def __call__(
         self,
         attn: Attention,
-        x: float["b n d"],  # noised input x  # noqa: F722
-        c: float["b nt d"] = None,  # context c, here text  # noqa: F722
-        mask: bool["b n"] | None = None,  # noqa: F722
+        x: float["b n d"],  # noised input x
+        c: float["b nt d"] = None,  # context c, here text
+        mask: bool["b n"] | None = None,
         rope=None,  # rotary position embedding for x
         c_rope=None,  # rotary position embedding for c
     ) -> torch.FloatTensor:
@@ -608,12 +646,27 @@ class JointAttnProcessor:
 
 
 class DiTBlock(nn.Module):
-    def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, qk_norm=None, pe_attn_head=None):
+    def __init__(
+        self,
+        dim,
+        heads,
+        dim_head,
+        ff_mult=4,
+        dropout=0.1,
+        qk_norm=None,
+        pe_attn_head=None,
+        attn_backend="flash_attn",
+        attn_mask_enabled=True,
+    ):
         super().__init__()
 
         self.attn_norm = AdaLayerNorm(dim)
         self.attn = Attention(
-            processor=AttnProcessor(pe_attn_head=pe_attn_head),
+            processor=AttnProcessor(
+                pe_attn_head=pe_attn_head,
+                attn_backend=attn_backend,
+                attn_mask_enabled=attn_mask_enabled,
+            ),
             dim=dim,
             heads=heads,
             dim_head=dim_head,
@@ -724,7 +777,7 @@ class TimestepEmbedding(nn.Module):
         self.time_embed = SinusPositionEmbedding(freq_embed_dim)
         self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
 
-    def forward(self, timestep: float["b"]):  # noqa: F821
+    def forward(self, timestep: float["b"]):
         time_hidden = self.time_embed(timestep)
         time_hidden = time_hidden.to(timestep.dtype)
         time = self.time_mlp(time_hidden)  # b d
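A hedged end-to-end sketch of the flash_attn varlen path added above: unpad the padded [b, n, h, d] tensors with the boolean mask, run variable-length attention over the packed tokens, then re-pad. It needs a CUDA device and fp16/bf16 tensors, and the five-value unpad_input return matches the flash-attn releases this diff targets (older versions return four values).

import torch
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import pad_input, unpad_input

b, n, h, d = 2, 512, 16, 64
q = torch.randn(b, n, h, d, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
lens = torch.tensor([512, 300], device="cuda")
mask = torch.arange(n, device="cuda")[None, :] < lens[:, None]  # [b, n] bool, True = keep

# Drop padded tokens, attend over the packed sequence, then restore padding.
q_u, indices, cu_q, max_q, _ = unpad_input(q, mask)
k_u, _, cu_k, max_k, _ = unpad_input(k, mask)
v_u = unpad_input(v, mask)[0]

out = flash_attn_varlen_func(q_u, k_u, v_u, cu_q, cu_k, max_q, max_k)  # [total_tokens, h, d]
out = pad_input(out, indices, b, max_q).reshape(b, -1, h * d)          # back to [b, n, h*d]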
src/f5_tts/model/utils.py CHANGED
@@ -35,6 +35,16 @@ def default(v, d):
     return v if exists(v) else d
 
 
+def is_package_available(package_name: str) -> bool:
+    try:
+        import importlib
+
+        package_exists = importlib.util.find_spec(package_name) is not None
+        return package_exists
+    except Exception:
+        return False
+
+
 # tensor helpers
 
 
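Typical use of the new helper, e.g. picking an attention backend at runtime (an illustrative pattern, not code from this diff):

from f5_tts.model.utils import is_package_available

attn_backend = "flash_attn" if is_package_available("flash_attn") else "torch"
print(f"using attention backend: {attn_backend}")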