Spaces:
Build error
Build error
Update hymm_sp/modules/models_audio.py
Browse files
hymm_sp/modules/models_audio.py
CHANGED
@@ -6,7 +6,13 @@ import torch.nn as nn
|
|
6 |
import torch.nn.functional as F
|
7 |
from diffusers.models import ModelMixin
|
8 |
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
from .activation_layers import get_activation_layer
|
12 |
from .norm_layers import get_norm_layer
|
@@ -169,6 +175,7 @@ class DoubleStreamBlock(nn.Module):
|
|
169 |
x.view(x.shape[0] * x.shape[1], *x.shape[2:])
|
170 |
for x in [q, k, v]
|
171 |
]
|
|
|
172 |
attn = flash_attn_varlen_func(
|
173 |
q,
|
174 |
k,
|
|
|
6 |
import torch.nn.functional as F
|
7 |
from diffusers.models import ModelMixin
|
8 |
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
9 |
+
try:
|
10 |
+
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
11 |
+
except ImportError:
|
12 |
+
print("⚠️ flash_attn not available — using fallback.")
|
13 |
+
flash_attn_varlen_func = None # or create a dummy function if needed
|
14 |
+
|
15 |
+
|
16 |
|
17 |
from .activation_layers import get_activation_layer
|
18 |
from .norm_layers import get_norm_layer
|
|
|
175 |
x.view(x.shape[0] * x.shape[1], *x.shape[2:])
|
176 |
for x in [q, k, v]
|
177 |
]
|
178 |
+
|
179 |
attn = flash_attn_varlen_func(
|
180 |
q,
|
181 |
k,
|