update

ldm/modules/attention.py (+39 -38)
@@ -26,40 +26,40 @@ from xformers.ops import MemoryEfficientAttentionFlashAttentionOp, MemoryEfficie
 # import apex
 # from apex.normalization import FusedRMSNorm as RMSNorm
 
-if version.parse(torch.__version__) >= version.parse("2.0.0"):
-    SDP_IS_AVAILABLE = True
-    # from torch.backends.cuda import SDPBackend, sdp_kernel
-    from torch.nn.attention import sdpa_kernel, SDPBackend
-
-    BACKEND_MAP = {
-        SDPBackend.MATH: {
-            "enable_math": True,
-            "enable_flash": False,
-            "enable_mem_efficient": False,
-        },
-        SDPBackend.FLASH_ATTENTION: {
-            "enable_math": False,
-            "enable_flash": True,
-            "enable_mem_efficient": False,
-        },
-        SDPBackend.EFFICIENT_ATTENTION: {
-            "enable_math": False,
-            "enable_flash": False,
-            "enable_mem_efficient": True,
-        },
-        None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
-    }
-else:
-    from contextlib import nullcontext
-
-    SDP_IS_AVAILABLE = False
-    sdpa_kernel = nullcontext
-    BACKEND_MAP = {}
-    logpy.warn(
-        f"No SDP backend available, likely because you are running in pytorch "
-        f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
-        f"You might want to consider upgrading."
-    )
+# if version.parse(torch.__version__) >= version.parse("2.0.0"):
+#     SDP_IS_AVAILABLE = True
+#     # from torch.backends.cuda import SDPBackend, sdp_kernel
+#     from torch.nn.attention import sdpa_kernel, SDPBackend
+
+#     BACKEND_MAP = {
+#         SDPBackend.MATH: {
+#             "enable_math": True,
+#             "enable_flash": False,
+#             "enable_mem_efficient": False,
+#         },
+#         SDPBackend.FLASH_ATTENTION: {
+#             "enable_math": False,
+#             "enable_flash": True,
+#             "enable_mem_efficient": False,
+#         },
+#         SDPBackend.EFFICIENT_ATTENTION: {
+#             "enable_math": False,
+#             "enable_flash": False,
+#             "enable_mem_efficient": True,
+#         },
+#         None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
+#     }
+# else:
+#     from contextlib import nullcontext
+
+SDP_IS_AVAILABLE = False
+#     sdpa_kernel = nullcontext
+#     BACKEND_MAP = {}
+#     logpy.warn(
+#         f"No SDP backend available, likely because you are running in pytorch "
+#         f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
+#         f"You might want to consider upgrading."
+#     )
 
 
 def exists(val):
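For context, the block commented out above was mid-migration between two context-manager APIs: the legacy torch.backends.cuda.sdp_kernel takes the enable_* keyword flags that BACKEND_MAP stores, while the newer torch.nn.attention.sdpa_kernel (added around PyTorch 2.3) takes SDPBackend values directly, which is the "new signature" referenced in the next hunk. A rough sketch of the two call styles, not taken from this repo; each style assumes a PyTorch build where that API exists, and the legacy one is deprecated in recent releases:

import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, tokens, head_dim)

# Legacy style (PyTorch 2.0-2.2 era): one keyword flag per backend,
# exactly the dictionaries BACKEND_MAP stored.
with torch.backends.cuda.sdp_kernel(
    enable_math=True, enable_flash=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v)

# New style (roughly PyTorch 2.3+): pass the allowed backend(s) directly.
from torch.nn.attention import SDPBackend, sdpa_kernel

with sdpa_kernel([SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v)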
@@ -282,11 +282,12 @@ class CrossAttention(nn.Module):
         """
         ## new
         # with sdpa_kernel(**BACKEND_MAP[self.backend]):
-        with sdpa_kernel([self.backend]):  # new signature
+        # with sdpa_kernel([self.backend]):  # new signature
         # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
-            out = F.scaled_dot_product_attention(
-                q, k, v, attn_mask=mask
-            )  # scale is dim_head ** -0.5 per default
+
+        out = F.scaled_dot_product_attention(
+            q, k, v, attn_mask=mask
+        )  # scale is dim_head ** -0.5 per default
 
         del q, k, v
         out = rearrange(out, "b h n d -> b n (h d)", h=h)
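The net effect of the second hunk is that the backend-selection context manager is gone and CrossAttention now calls F.scaled_dot_product_attention unconditionally. A minimal self-contained sketch of the resulting attention path, with shapes made up for illustration (h is the head count, as in the surrounding code):

import torch
import torch.nn.functional as F
from einops import rearrange

b, h, n, d = 2, 8, 77, 64  # hypothetical batch / heads / tokens / head_dim
q = torch.randn(b, h, n, d)
k = torch.randn(b, h, n, d)
v = torch.randn(b, h, n, d)
mask = None  # optional boolean or additive attn_mask broadcastable to (b, h, n, n)

# Default softmax scale is dim_head ** -0.5, matching the comment in the diff.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

del q, k, v
out = rearrange(out, "b h n d -> b n (h d)", h=h)  # merge heads back
assert out.shape == (b, n, h * d)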