SWivid commited on
Commit
3bfc543
·
1 Parent(s): 8c7215c

backward compatibility

Browse files
Files changed (1) hide show
  1. src/f5_tts/model/modules.py +2 -2
src/f5_tts/model/modules.py CHANGED
@@ -443,7 +443,7 @@ class AttnProcessor:
443
  def __init__(
444
  self,
445
  pe_attn_head: int | None = None, # number of attention head to apply rope, None for all
446
- attn_backend: str = "flash_attn",
447
  attn_mask_enabled: bool = True,
448
  ):
449
  if attn_backend == "flash_attn":
@@ -655,7 +655,7 @@ class DiTBlock(nn.Module):
655
  dropout=0.1,
656
  qk_norm=None,
657
  pe_attn_head=None,
658
- attn_backend="flash_attn",
659
  attn_mask_enabled=True,
660
  ):
661
  super().__init__()
 
443
  def __init__(
444
  self,
445
  pe_attn_head: int | None = None, # number of attention head to apply rope, None for all
446
+ attn_backend: str = "torch", # "torch" or "flash_attn"
447
  attn_mask_enabled: bool = True,
448
  ):
449
  if attn_backend == "flash_attn":
 
655
  dropout=0.1,
656
  qk_norm=None,
657
  pe_attn_head=None,
658
+ attn_backend="torch", # "torch" or "flash_attn"
659
  attn_mask_enabled=True,
660
  ):
661
  super().__init__()