rahul7star commited on
Commit
f692523
·
verified ·
1 Parent(s): 31b7e4b

Update hymm_sp/modules/models_audio.py

Browse files
Files changed (1) hide show
  1. hymm_sp/modules/models_audio.py +8 -1
hymm_sp/modules/models_audio.py CHANGED
@@ -6,7 +6,13 @@ import torch.nn as nn
6
  import torch.nn.functional as F
7
  from diffusers.models import ModelMixin
8
  from diffusers.configuration_utils import ConfigMixin, register_to_config
9
- from flash_attn.flash_attn_interface import flash_attn_varlen_func
 
 
 
 
 
 
10
 
11
  from .activation_layers import get_activation_layer
12
  from .norm_layers import get_norm_layer
@@ -169,6 +175,7 @@ class DoubleStreamBlock(nn.Module):
169
  x.view(x.shape[0] * x.shape[1], *x.shape[2:])
170
  for x in [q, k, v]
171
  ]
 
172
  attn = flash_attn_varlen_func(
173
  q,
174
  k,
 
6
  import torch.nn.functional as F
7
  from diffusers.models import ModelMixin
8
  from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ try:
10
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
11
+ except ImportError:
12
+ print("⚠️ flash_attn not available — using fallback.")
13
+ flash_attn_varlen_func = None # or create a dummy function if needed
14
+
15
+
16
 
17
  from .activation_layers import get_activation_layer
18
  from .norm_layers import get_norm_layer
 
175
  x.view(x.shape[0] * x.shape[1], *x.shape[2:])
176
  for x in [q, k, v]
177
  ]
178
+
179
  attn = flash_attn_varlen_func(
180
  q,
181
  k,