backward compatibility
src/f5_tts/model/modules.py  CHANGED

@@ -443,7 +443,7 @@ class AttnProcessor:
     def __init__(
         self,
         pe_attn_head: int | None = None,  # number of attention head to apply rope, None for all
-        attn_backend: str = "flash_attn",
+        attn_backend: str = "torch",  # "torch" or "flash_attn"
         attn_mask_enabled: bool = True,
     ):
         if attn_backend == "flash_attn":
@@ -655,7 +655,7 @@ class DiTBlock(nn.Module):
         dropout=0.1,
         qk_norm=None,
         pe_attn_head=None,
-        attn_backend="flash_attn",
+        attn_backend="torch",  # "torch" or "flash_attn"
         attn_mask_enabled=True,
     ):
         super().__init__()
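The change only moves the default from "flash_attn" to "torch", so environments without the optional flash-attn package keep working out of the box; callers can still opt back in by passing attn_backend="flash_attn" explicitly. For context, a minimal sketch of how such a backend switch is typically consumed is shown below. This is illustrative only and not the actual F5-TTS AttnProcessor code; the tensor layout, the helper name, and the lazy import are assumptions.

```python
import torch
import torch.nn.functional as F


def attention(q, k, v, attn_backend: str = "torch", mask=None):
    # Hypothetical helper; q, k, v assumed shaped (batch, heads, seq_len, head_dim).
    if attn_backend == "flash_attn":
        # Optional dependency: only imported when explicitly requested.
        # flash_attn_func expects (batch, seq_len, heads, head_dim) and fp16/bf16 CUDA tensors.
        from flash_attn import flash_attn_func

        out = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
        return out.transpose(1, 2)
    # Default backend: PyTorch's built-in scaled dot-product attention, no extra dependency.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```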