andrew3d commited on
Commit
d6b0d62
Β·
verified Β·
1 Parent(s): 003c7b3

Update app.py: add xformers stub to avoid missing dependency

Browse files
Files changed (1) hide show
  1. app.py +45 -0
app.py CHANGED
@@ -39,6 +39,7 @@ import numpy as np
39
  import tempfile
40
 
41
  import zipfile
 
42
 
43
  # ---------------------------------------------------------------------------
44
  # NOTE
@@ -74,6 +75,50 @@ def _ensure_hi3dgen_available():
74
  # Make sure the hi3dgen package is available before importing it
75
  _ensure_hi3dgen_available()
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  from hi3dgen.pipelines import Hi3DGenPipeline
78
  import trimesh
79
  MAX_SEED = np.iinfo(np.int32).max
 
39
  import tempfile
40
 
41
  import zipfile
42
+ import types
43
 
44
  # ---------------------------------------------------------------------------
45
  # NOTE
 
75
  # Make sure the hi3dgen package is available before importing it
76
  _ensure_hi3dgen_available()
77
 
78
+ # ---------------------------------------------------------------------------
79
+ # xformers stub
80
+ #
81
+ # Some modules in the Hi3DGen pipeline import `xformers.ops.memory_efficient_attention`
82
+ # to compute multi-head attention. The official `xformers` library is not
83
+ # installed in this Space (and requires GPU-only build), so we provide a
84
+ # minimal in-memory stub that exposes a compatible API backed by PyTorch's
85
+ # built-in scaled dot-product attention. This stub is lightweight and
86
+ # CPU-friendly. It registers both the `xformers` and `xformers.ops` modules
87
+ # in sys.modules so that subsequent imports succeed.
88
+ # ---------------------------------------------------------------------------
89
+
90
+ def _ensure_xformers_stub():
91
+ import sys
92
+ # If xformers is already available, do nothing.
93
+ if 'xformers.ops' in sys.modules:
94
+ return
95
+ import torch.nn.functional as F
96
+ # Create a new module object for xformers and its ops submodule
97
+ xformers_mod = types.ModuleType('xformers')
98
+ ops_mod = types.ModuleType('xformers.ops')
99
+
100
+ def memory_efficient_attention(query, key, value, attn_bias=None):
101
+ """
102
+ Fallback implementation of memory_efficient_attention for CPU environments.
103
+ This wraps torch.nn.functional.scaled_dot_product_attention.
104
+ """
105
+ # PyTorch expects the attention mask (bias) to be additive with shape
106
+ # broadcastable to (batch, num_heads, seq_len_query, seq_len_key). If
107
+ # attn_bias is provided and is non-zero, pass it through; otherwise
108
+ # supply None to avoid unnecessary allocations.
109
+ return F.scaled_dot_product_attention(query, key, value, attn_bias)
110
+
111
+ # Populate the ops module with our fallback function
112
+ ops_mod.memory_efficient_attention = memory_efficient_attention
113
+ # Expose ops as an attribute of xformers
114
+ xformers_mod.ops = ops_mod
115
+ # Register modules
116
+ sys.modules['xformers'] = xformers_mod
117
+ sys.modules['xformers.ops'] = ops_mod
118
+
119
+ # Ensure the xformers stub is registered before importing Hi3DGen
120
+ _ensure_xformers_stub()
121
+
122
  from hi3dgen.pipelines import Hi3DGenPipeline
123
  import trimesh
124
  MAX_SEED = np.iinfo(np.int32).max