Spaces:
Runtime error
Runtime error
update
Browse files- ldm/modules/attention.py +2 -2
ldm/modules/attention.py
CHANGED
|
@@ -211,7 +211,7 @@ class CrossAttention(nn.Module):
|
|
| 211 |
dim_head=64,
|
| 212 |
dropout=0.0,
|
| 213 |
# backend=None,
|
| 214 |
-
backend=SDPBackend.FLASH_ATTENTION, # FA implemented by torch.
|
| 215 |
**kwargs,
|
| 216 |
):
|
| 217 |
super().__init__()
|
|
@@ -228,7 +228,7 @@ class CrossAttention(nn.Module):
|
|
| 228 |
self.to_out = nn.Sequential(
|
| 229 |
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
| 230 |
)
|
| 231 |
-
self.backend = backend
|
| 232 |
|
| 233 |
def forward(
|
| 234 |
self,
|
|
|
|
| 211 |
dim_head=64,
|
| 212 |
dropout=0.0,
|
| 213 |
# backend=None,
|
| 214 |
+
# backend=SDPBackend.FLASH_ATTENTION, # FA implemented by torch.
|
| 215 |
**kwargs,
|
| 216 |
):
|
| 217 |
super().__init__()
|
|
|
|
| 228 |
self.to_out = nn.Sequential(
|
| 229 |
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
| 230 |
)
|
| 231 |
+
# self.backend = backend
|
| 232 |
|
| 233 |
def forward(
|
| 234 |
self,
|