Implement pure in-memory flash_attn mock to fix __spec__ error
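The root cause: transformers probes for flash-attn through importlib, and the previous mock left an object without a usable __spec__ in sys.modules, so importlib.util.find_spec("flash_attn") raised ValueError: flash_attn.__spec__ is not set. Below is a minimal sketch of that failure mode; the old mock survives only in fragments in this diff, so the exact shape of the registered object is an assumption, but the find_spec behaviour is standard-library behaviour.

import sys
import importlib.util

class FakeFlashAttn:
    # function stubs, but none of the import-system metadata a real module carries
    def flash_attn_func(self, *args, **kwargs):
        raise NotImplementedError("mock")

sys.modules["flash_attn"] = FakeFlashAttn()

importlib.util.find_spec("flash_attn")
# ValueError: flash_attn.__spec__ is not set

This commit instead builds every mocked module with importlib.util.module_from_spec behind a MetaPathFinder/Loader pair, so the spec, loader, and package metadata are always present.
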
app.py
CHANGED
@@ -4,8 +4,10 @@ import sys
 import gradio as gr
 from PIL import Image
 import traceback
-
-
+import types
+import importlib.util
+import importlib.machinery
+import importlib.abc

 print("=" * 50)
 print("InternVL2 IMAGE & TEXT ANALYSIS")

@@ -29,74 +31,180 @@ if torch.cuda.is_available():
 else:
     print("CUDA is not available. This application requires GPU acceleration.")

-    def flash_attn_func(self, *args, **kwargs):
-        raise NotImplementedError("This is a mock flash_attn implementation")
+# In-memory mock implementation
+def create_in_memory_flash_attn_mock():
+    """Create a completely in-memory flash_attn mock with all required attributes"""
+    print("Setting up in-memory flash_attn mock...")
+
+    # Create a dummy module finder and loader for the mock
+    class DummyFinder(importlib.abc.MetaPathFinder):
+        def find_spec(self, fullname, path, target=None):
+            if fullname == 'flash_attn' or fullname.startswith('flash_attn.'):
+                return self.create_spec(fullname)
+            elif fullname == 'flash_attn_2_cuda':
+                return self.create_spec(fullname)
+            return None
+
+        def create_spec(self, fullname):
+            # Create a spec
+            loader = DummyLoader(fullname)
+            spec = importlib.machinery.ModuleSpec(
+                name=fullname,
+                loader=loader,
+                is_package=fullname.count('.') == 0 or fullname.split('.')[-1] == ''
+            )
+            return spec
+
+    class DummyLoader(importlib.abc.Loader):
+        def __init__(self, fullname):
+            self.fullname = fullname
+
+        def create_module(self, spec):
+            module = types.ModuleType(spec.name)
+
+            # Set default attributes for any module
+            module.__spec__ = spec
+            module.__loader__ = self
+            module.__file__ = f"<{spec.name}>"
+            module.__path__ = []
+            module.__package__ = spec.name.rpartition('.')[0] if '.' in spec.name else ''
+
+            if spec.name == 'flash_attn':
+                # Add flash_attn-specific attributes
+                module.__version__ = "0.0.0-mocked"
+
+                # Add flash_attn functions
+                module.flash_attn_func = lambda *args, **kwargs: None
+                module.flash_attn_kvpacked_func = lambda *args, **kwargs: None
+                module.flash_attn_qkvpacked_func = lambda *args, **kwargs: None
+
+            return module
+
+        def exec_module(self, module):
+            # Nothing to execute
+            pass
+
+    # Remove any existing modules to avoid conflicts
+    for name in list(sys.modules.keys()):
+        if name == 'flash_attn' or name.startswith('flash_attn.') or name == 'flash_attn_2_cuda':
+            del sys.modules[name]
+
+    # Register our finder at the beginning of meta_path
+    sys.meta_path.insert(0, DummyFinder())
+
+    # Pre-create and configure the flash_attn module
+    spec = importlib.machinery.ModuleSpec(
+        name='flash_attn',
+        loader=DummyLoader('flash_attn'),
+        is_package=True
+    )
+    flash_attn = importlib.util.module_from_spec(spec)
+    sys.modules['flash_attn'] = flash_attn
+
+    # Add attributes used by transformers checks
+    flash_attn.__version__ = "0.0.0-mocked"
+
+    # Create common submodules
+    for submodule in ['flash_attn.flash_attn_interface', 'flash_attn.flash_attn_triton']:
+        parts = submodule.split('.')
+        parent_name = '.'.join(parts[:-1])
+        child_name = parts[-1]
+        parent = sys.modules[parent_name]
+
+        # Create submodule spec (ModuleSpec takes no parent argument; the parent
+        # relationship is derived from the dotted name)
+        subspec = importlib.machinery.ModuleSpec(
+            name=submodule,
+            loader=DummyLoader(submodule),
+            is_package=False
+        )
+
+        # Create and register submodule
+        module = importlib.util.module_from_spec(subspec)
+        setattr(parent, child_name, module)
+        sys.modules[submodule] = module
+
+    # Create flash_attn_2_cuda module
+    cuda_spec = importlib.machinery.ModuleSpec(
+        name='flash_attn_2_cuda',
+        loader=DummyLoader('flash_attn_2_cuda'),
+        is_package=False
+    )
+    cuda_module = importlib.util.module_from_spec(cuda_spec)
+    sys.modules['flash_attn_2_cuda'] = cuda_module
+
+    # Set environment variables to disable flash attention
+    os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
+    os.environ["TRANSFORMERS_OFFLINE"] = "1"  # Avoid online checks
+
+    # Verify the mock was created successfully
     try:
         import flash_attn
-        print(f"Mock flash_attn
+        print(f"✓ Mock flash_attn loaded successfully: {flash_attn.__version__}")
+        print(f"✓ flash_attn.__spec__ exists: {flash_attn.__spec__ is not None}")
+
+        # Explicitly check that importlib.util.find_spec can see the mock
+        spec = importlib.util.find_spec("flash_attn")
+        print(f"✓ importlib.util.find_spec returns: {spec is not None}")
+
+        # Check that parent/child relationships work
+        import flash_attn.flash_attn_interface
+        print("✓ flash_attn.flash_attn_interface loaded")
+
+        # Check CUDA module
+        import flash_attn_2_cuda
+        print("✓ flash_attn_2_cuda loaded")
+    except Exception as e:
+        print(f"WARNING: Error verifying flash_attn mock: {e}")
+        traceback.print_exc()
+
+# Now set up the mock BEFORE importing transformers
+create_in_memory_flash_attn_mock()
+
+# Import transformers AFTER setting up mock
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation import GenerationConfig

 # Create a function to load the model
 def load_model():
     try:
         print("\nLoading InternVL2 model...")

-        # Setup flash_attn mock
-        setup_flash_attn_mock()
-
         # Load the model and tokenizer
         model_path = "OpenGVLab/InternVL2-8B"

         # Print downloading status
         print("Downloading model shards. This may take some time...")

-        # Load the model
+        # Load the model - with careful error handling
+        try:
+            model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+                low_cpu_mem_usage=True,
+                device_map="auto" if torch.cuda.is_available() else None,
+                trust_remote_code=True
+            )
+        except Exception as e:
+            # If we get the flash_attn error, print detailed information
+            if "flash_attn.__spec__ is not set" in str(e):
+                print("\n❌ Flash attention error detected!")
+
+                # See if our mock is still in place
+                if 'flash_attn' in sys.modules:
+                    mock = sys.modules['flash_attn']
+                    print(f"Flash mock exists: {mock}")
+                    print(f"Flash mock __spec__: {getattr(mock, '__spec__', 'NOT SET')}")
+                else:
+                    print("flash_attn module was removed from sys.modules")
+
+                # Diagnostic info
+                print("\nCurrent state of sys.meta_path:")
+                for i, finder in enumerate(sys.meta_path):
+                    print(f"  {i}: {finder.__class__.__name__}")
+
+            # Re-raise the exception
+            raise

         # Load tokenizer
         tokenizer = AutoTokenizer.from_pretrained(

@@ -221,11 +329,11 @@ def create_interface():
         with gr.Column(scale=1):
             output = gr.Textbox(label="Analysis Results", lines=15)

-        # Example images -
+        # Example images - Using stable URLs from GitHub repositories
         gr.Examples(
             examples=[
-                ["https://
-                ["https://raw.githubusercontent.com/
+                ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/kitchen_sink/files/cheetah1.jpg", "What's in this image?"],
+                ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/kitchen_sink/files/lion.jpg", "Describe this animal."],
             ],
             inputs=[input_image, custom_prompt],
         )
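
For reference, the finder/loader technique added above can be smoke-tested on its own, before it is wired into app.py. A minimal standalone sketch follows; StubFinder and StubLoader are illustrative names, not the helpers used in this commit.

import sys
import types
import importlib.abc
import importlib.machinery
import importlib.util

class StubFinder(importlib.abc.MetaPathFinder):
    """Answer any flash_attn* import with an empty stub module."""
    def find_spec(self, fullname, path, target=None):
        if fullname == "flash_attn" or fullname.startswith("flash_attn."):
            return importlib.machinery.ModuleSpec(
                fullname, StubLoader(), is_package=(fullname == "flash_attn")
            )
        return None

class StubLoader(importlib.abc.Loader):
    def create_module(self, spec):
        # The import machinery fills in __spec__, __loader__, __package__ afterwards
        return types.ModuleType(spec.name)

    def exec_module(self, module):
        pass  # a stub has nothing to execute

sys.meta_path.insert(0, StubFinder())

print(importlib.util.find_spec("flash_attn") is not None)  # True: no "__spec__ is not set" error
import flash_attn
import flash_attn.flash_attn_interface                     # submodules resolve the same way
print(flash_attn.__spec__ is not None)                     # True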