
YoussefAnso committed
Commit 1d3fed2 · 1 Parent(s): 3e7ee7c

Refactor attention module to improve xformers integration. Renamed availability flag to HAS_XFORMERS and added safe_memory_efficient_attention function for better handling of attention operations across devices. Updated related assertions and calls to ensure compatibility with systems lacking GPU support.

imagedream/ldm/modules/attention.py CHANGED

@@ -12,10 +12,9 @@ from .diffusionmodules.util import checkpoint
 try:
     import xformers
     import xformers.ops
-
-    XFORMERS_IS_AVAILBLE = True
-except:
-    XFORMERS_IS_AVAILBLE = False
+    HAS_XFORMERS = True
+except ImportError:
+    HAS_XFORMERS = False
 
 # CrossAttn precision handling
 import os
@@ -138,6 +137,20 @@ class SpatialSelfAttention(nn.Module):
         return x + h_
 
 
+def safe_memory_efficient_attention(q, k, v, attn_bias=None, op=None, p=0.0):
+    if q.device.type == "cuda" and HAS_XFORMERS:
+        return xformers.ops.memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=op, p=p)
+    else:
+        # Standard attention for CPU
+        scale = 1.0 / (q.shape[-1] ** 0.5)
+        attn = torch.matmul(q * scale, k.transpose(-2, -1))
+        if attn_bias is not None:
+            attn = attn + attn_bias
+        attn = torch.softmax(attn, dim=-1)
+        attn = torch.nn.functional.dropout(attn, p=p)
+        return torch.matmul(attn, v)
+
+
 class MemoryEfficientCrossAttention(nn.Module):
     # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs):
@@ -195,7 +208,7 @@ class MemoryEfficientCrossAttention(nn.Module):
         )
 
         # actually compute the attention, what we cannot get enough of
-        out = xformers.ops.memory_efficient_attention(
+        out = safe_memory_efficient_attention(
             q, k, v, attn_bias=None, op=self.attention_op
         )
 
@@ -209,7 +222,7 @@ class MemoryEfficientCrossAttention(nn.Module):
                 (k_ip, v_ip),
            )
            # actually compute the attention, what we cannot get enough of
-            out_ip = xformers.ops.memory_efficient_attention(
+            out_ip = safe_memory_efficient_attention(
                q, k_ip, v_ip, attn_bias=None, op=self.attention_op
            )
            out = out + self.ip_weight * out_ip
@@ -239,7 +252,7 @@ class BasicTransformerBlock(nn.Module):
        **kwargs
    ):
        super().__init__()
-        assert XFORMERS_IS_AVAILBLE, "xformers is not available"
+        assert HAS_XFORMERS, "xformers is not available"
        attn_cls = MemoryEfficientCrossAttention
        self.disable_self_attn = disable_self_attn
        self.attn1 = attn_cls(
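For context (not part of the commit): the CPU branch of safe_memory_efficient_attention is plain scaled dot-product attention over the (batch * heads, tokens, dim_head) tensors that the surrounding code already builds, so with dropout disabled it should match PyTorch's built-in attention. A minimal standalone sketch of that fallback, with a consistency check that is an assumption of mine and requires PyTorch >= 2.0:

    import torch
    import torch.nn.functional as F

    def fallback_attention(q, k, v, attn_bias=None, p=0.0):
        # Same math as the CPU branch of safe_memory_efficient_attention above.
        scale = 1.0 / (q.shape[-1] ** 0.5)
        attn = torch.matmul(q * scale, k.transpose(-2, -1))
        if attn_bias is not None:
            attn = attn + attn_bias
        attn = torch.softmax(attn, dim=-1)
        attn = F.dropout(attn, p=p)
        return torch.matmul(attn, v)

    # Dropout disabled so the result is deterministic and comparable.
    q, k, v = (torch.randn(2 * 8, 77, 64) for _ in range(3))
    out = fallback_attention(q, k, v)
    ref = F.scaled_dot_product_attention(q, k, v)  # PyTorch >= 2.0
    print(torch.allclose(out, ref, atol=1e-5))  # expected: True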
imagedream/ldm/modules/diffusionmodules/model.py CHANGED

@@ -11,10 +11,9 @@ from ..attention import MemoryEfficientCrossAttention
 try:
     import xformers
     import xformers.ops
-
-    XFORMERS_IS_AVAILBLE = True
-except:
-    XFORMERS_IS_AVAILBLE = False
+    HAS_XFORMERS = True
+except ImportError:
+    HAS_XFORMERS = False
     print("No module 'xformers'. Proceeding without it.")
 
 
@@ -238,7 +237,7 @@ class MemoryEfficientAttnBlock(nn.Module):
            .contiguous(),
            (q, k, v),
        )
-        out = xformers.ops.memory_efficient_attention(
+        out = safe_memory_efficient_attention(
            q, k, v, attn_bias=None, op=self.attention_op
        )
 
@@ -262,6 +261,20 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
        return x + out
 
 
+def safe_memory_efficient_attention(q, k, v, attn_bias=None, op=None, p=0.0):
+    if q.device.type == "cuda" and HAS_XFORMERS:
+        return xformers.ops.memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=op, p=p)
+    else:
+        # Standard attention for CPU
+        scale = 1.0 / (q.shape[-1] ** 0.5)
+        attn = torch.matmul(q * scale, k.transpose(-2, -1))
+        if attn_bias is not None:
+            attn = attn + attn_bias
+        attn = torch.softmax(attn, dim=-1)
+        attn = torch.nn.functional.dropout(attn, p=p)
+        return torch.matmul(attn, v)
+
+
 def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
    assert attn_type in [
        "vanilla",
@@ -270,7 +283,7 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
        "linear",
        "none",
    ], f"attn_type {attn_type} unknown"
-    if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
+    if HAS_XFORMERS and attn_type == "vanilla":
        attn_type = "vanilla-xformers"
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
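A quick way to exercise the new helper end to end (a sketch, assuming the repository root is on PYTHONPATH and the package's dependencies are installed; on a CUDA device with xformers available the same call dispatches to xformers.ops.memory_efficient_attention instead of the softmax path):

    import torch
    from imagedream.ldm.modules.attention import safe_memory_efficient_attention

    # (batch * heads, tokens, dim_head) layout, matching the call sites in the diff
    q, k, v = (torch.randn(2 * 8, 77, 64) for _ in range(3))

    out = safe_memory_efficient_attention(q, k, v)  # CPU tensors -> softmax fallback
    print(out.shape)  # torch.Size([16, 77, 64])

Note that BasicTransformerBlock still asserts HAS_XFORMERS after this commit, so the CPU branch matters mainly when xformers is installed but the tensors live on the CPU.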