mknolan committed
Commit 79f6e49 · verified · 1 Parent(s): 6131f9b

Implement pure in-memory flash_attn mock to fix __spec__ error
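Background on the error being fixed: transformers probes optional packages with importlib.util.find_spec(), and the previous mock cached a plain object (not a module object) in sys.modules. For an entry already in sys.modules, find_spec() reads its __spec__ attribute and raises ValueError when it is missing. A minimal sketch of the failure mode (illustrative, not part of app.py):

import importlib.util
import sys

# The old approach in spirit: a plain instance with no __spec__ attribute
class FlashAttnMock:
    __version__ = "0.0.0-disabled-mock"

sys.modules["flash_attn"] = FlashAttnMock()

try:
    importlib.util.find_spec("flash_attn")  # how transformers probes availability
except ValueError as e:
    print(e)  # -> flash_attn.__spec__ is not set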

Files changed (1):
  app.py +160 -52
app.py CHANGED
@@ -4,8 +4,10 @@ import sys
 import gradio as gr
 from PIL import Image
 import traceback
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from transformers.generation import GenerationConfig
+import types
+import importlib.util
+import importlib.machinery
+import importlib.abc
 
 print("=" * 50)
 print("InternVL2 IMAGE & TEXT ANALYSIS")
@@ -29,74 +31,180 @@ if torch.cuda.is_available():
 else:
     print("CUDA is not available. This application requires GPU acceleration.")
 
-# Create a mock function for flash_attn modules
-def setup_flash_attn_mock():
-    # Disable flash attention in transformers
-    os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
-
-    # First, check if flash_attn is already imported
-    if "flash_attn" in sys.modules:
-        print("flash_attn module already imported - no mocking needed")
-        return
-
-    # If we should mock the module
-    print("Setting up flash_attn mock...")
-
-    # Create a proper mock that has the necessary attributes
-    class FlashAttnMock:
-        __version__ = "0.0.0-disabled-mock"
-
-        def __init__(self):
-            pass
-
-        def flash_attn_func(self, *args, **kwargs):
-            raise NotImplementedError("This is a mock flash_attn implementation")
-
-        def flash_attn_kvpacked_func(self, *args, **kwargs):
-            raise NotImplementedError("This is a mock flash_attn implementation")
-
-        def flash_attn_qkvpacked_func(self, *args, **kwargs):
-            raise NotImplementedError("This is a mock flash_attn implementation")
-
-    # Create the module with proper spec
-    import types
-    flash_attn_mock = FlashAttnMock()
-    sys.modules["flash_attn"] = flash_attn_mock
-    print("flash_attn mock set up successfully")
-
-    # Also mock the related modules that might be imported
-    sys.modules["flash_attn.flash_attn_interface"] = types.ModuleType("flash_attn.flash_attn_interface")
-    sys.modules["flash_attn.flash_attn_triton"] = types.ModuleType("flash_attn.flash_attn_triton")
-
-    # Check if it worked
+# In-memory mock implementation
+def create_in_memory_flash_attn_mock():
+    """Create a completely in-memory flash_attn mock with all required attributes"""
+    print("Setting up in-memory flash_attn mock...")
+
+    # Create a dummy module finder and loader for the mock
+    class DummyFinder(importlib.abc.MetaPathFinder):
+        def find_spec(self, fullname, path, target=None):
+            if fullname == 'flash_attn' or fullname.startswith('flash_attn.'):
+                return self.create_spec(fullname)
+            elif fullname == 'flash_attn_2_cuda':
+                return self.create_spec(fullname)
+            return None
+
+        def create_spec(self, fullname):
+            # Create a spec; only the top-level name counts as a package
+            loader = DummyLoader(fullname)
+            spec = importlib.machinery.ModuleSpec(
+                name=fullname,
+                loader=loader,
+                is_package=fullname.count('.') == 0
+            )
+            return spec
+
+    class DummyLoader(importlib.abc.Loader):
+        def __init__(self, fullname):
+            self.fullname = fullname
+
+        def create_module(self, spec):
+            module = types.ModuleType(spec.name)
+
+            # Set default attributes for any module
+            module.__spec__ = spec
+            module.__loader__ = self
+            module.__file__ = f"<{spec.name}>"
+            module.__path__ = []
+            module.__package__ = spec.name.rpartition('.')[0] if '.' in spec.name else ''
+
+            if spec.name == 'flash_attn':
+                # Add flash_attn-specific attributes
+                module.__version__ = "0.0.0-mocked"
+
+                # Add flash_attn functions
+                module.flash_attn_func = lambda *args, **kwargs: None
+                module.flash_attn_kvpacked_func = lambda *args, **kwargs: None
+                module.flash_attn_qkvpacked_func = lambda *args, **kwargs: None
+
+            return module
+
+        def exec_module(self, module):
+            # Nothing to execute
+            pass
+
+    # Remove any existing modules to avoid conflicts
+    for name in list(sys.modules.keys()):
+        if name == 'flash_attn' or name.startswith('flash_attn.') or name == 'flash_attn_2_cuda':
+            del sys.modules[name]
+
+    # Register our finder at the beginning of meta_path
+    sys.meta_path.insert(0, DummyFinder())
+
+    # Pre-create and configure the flash_attn module
+    spec = importlib.machinery.ModuleSpec(
+        name='flash_attn',
+        loader=DummyLoader('flash_attn'),
+        is_package=True
+    )
+    flash_attn = importlib.util.module_from_spec(spec)
+    sys.modules['flash_attn'] = flash_attn
+
+    # Add attributes used by transformers checks
+    flash_attn.__version__ = "0.0.0-mocked"
+
+    # Create common submodules
+    for submodule in ['flash_attn.flash_attn_interface', 'flash_attn.flash_attn_triton']:
+        parts = submodule.split('.')
+        parent_name = '.'.join(parts[:-1])
+        child_name = parts[-1]
+        parent = sys.modules[parent_name]
+
+        # Create submodule spec
+        subspec = importlib.machinery.ModuleSpec(
+            name=submodule,
+            loader=DummyLoader(submodule),
+            is_package=False
+            # ModuleSpec derives 'parent' from the dotted name; no kwarg needed
+        )
+
+        # Create and register submodule
+        module = importlib.util.module_from_spec(subspec)
+        setattr(parent, child_name, module)
+        sys.modules[submodule] = module
+
+    # Create flash_attn_2_cuda module
+    cuda_spec = importlib.machinery.ModuleSpec(
+        name='flash_attn_2_cuda',
+        loader=DummyLoader('flash_attn_2_cuda'),
+        is_package=False
+    )
+    cuda_module = importlib.util.module_from_spec(cuda_spec)
+    sys.modules['flash_attn_2_cuda'] = cuda_module
+
+    # Set environment variables to disable flash attention
+    os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
+    os.environ["TRANSFORMERS_OFFLINE"] = "1"  # Avoid online checks (also blocks fresh downloads if the model is not cached)
+
+    # Verify the mock was created successfully
     try:
         import flash_attn
-        print(f"Mock flash_attn module version: {flash_attn.__version__}")
-    except:
-        print("Warning: flash_attn mock failed to load correctly")
+        print(f"Mock flash_attn loaded successfully: {flash_attn.__version__}")
+        print(f"✓ flash_attn.__spec__ exists: {flash_attn.__spec__ is not None}")
+
+        # Explicitly check that importlib.util.find_spec resolves the mock
+        spec = importlib.util.find_spec("flash_attn")
+        print(f"✓ importlib.util.find_spec returns: {spec is not None}")
+
+        # Check that parent/child relationships work
+        import flash_attn.flash_attn_interface
+        print("✓ flash_attn.flash_attn_interface loaded")
+
+        # Check CUDA module
+        import flash_attn_2_cuda
+        print("✓ flash_attn_2_cuda loaded")
+    except Exception as e:
+        print(f"WARNING: Error verifying flash_attn mock: {e}")
+        traceback.print_exc()
+
+# Now set up the mock BEFORE importing transformers
+create_in_memory_flash_attn_mock()
+
+# Import transformers AFTER setting up the mock
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation import GenerationConfig
 
 # Create a function to load the model
 def load_model():
     try:
         print("\nLoading InternVL2 model...")
 
-        # Setup flash_attn mock
-        setup_flash_attn_mock()
-
         # Load the model and tokenizer
         model_path = "OpenGVLab/InternVL2-8B"
 
         # Print downloading status
         print("Downloading model shards. This may take some time...")
 
-        # Load the model
-        model = AutoModelForCausalLM.from_pretrained(
-            model_path,
-            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
-            low_cpu_mem_usage=True,
-            device_map="auto" if torch.cuda.is_available() else None,
-            trust_remote_code=True
-        )
+        # Load the model - with careful error handling
+        try:
+            model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+                low_cpu_mem_usage=True,
+                device_map="auto" if torch.cuda.is_available() else None,
+                trust_remote_code=True
+            )
+        except Exception as e:
+            # If we get the flash_attn error, print detailed information
+            if "flash_attn.__spec__ is not set" in str(e):
+                print("\n❌ Flash attention error detected!")
+
+                # See if our mock is still in place
+                if 'flash_attn' in sys.modules:
+                    mock = sys.modules['flash_attn']
+                    print(f"Flash mock exists: {mock}")
+                    print(f"Flash mock __spec__: {getattr(mock, '__spec__', 'NOT SET')}")
+                else:
+                    print("flash_attn module was removed from sys.modules")
+
+                # Diagnostic info
+                print("\nCurrent state of sys.meta_path:")
+                for i, finder in enumerate(sys.meta_path):
+                    print(f"  {i}: {finder.__class__.__name__}")
+
+            # Re-raise the exception
+            raise
 
         # Load tokenizer
         tokenizer = AutoTokenizer.from_pretrained(
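Why the replacement passes the check: once a real module object built from a ModuleSpec sits in sys.modules, importlib.util.find_spec() returns its __spec__ instead of raising. A stripped-down sketch of the pattern used above (loader-less here for brevity; app.py wires in a DummyLoader):

import importlib.machinery
import importlib.util
import sys

# module_from_spec() populates __spec__ (plus __loader__, __package__, and,
# for packages, __path__); __spec__ is what find_spec() reads for cached modules.
spec = importlib.machinery.ModuleSpec("flash_attn", loader=None, is_package=True)
sys.modules["flash_attn"] = importlib.util.module_from_spec(spec)

print(importlib.util.find_spec("flash_attn") is not None)  # True, no ValueError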
@@ -221,11 +329,11 @@ def create_interface():
         with gr.Column(scale=1):
             output = gr.Textbox(label="Analysis Results", lines=15)
 
-    # Example images - UPDATED with more reliable image URLs
+    # Example images - Using stable URLs from GitHub repositories
     gr.Examples(
         examples=[
-            ["https://github.com/huggingface/transformers/raw/main/docs/source/en/model_doc/blip-2_files/BobRoss.jpg", "What's in this image?"],
-            ["https://raw.githubusercontent.com/openai/CLIP/main/CLIP.png", "Describe this diagram in detail."],
+            ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/kitchen_sink/files/cheetah1.jpg", "What's in this image?"],
+            ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/kitchen_sink/files/lion.jpg", "Describe this animal."],
         ],
         inputs=[input_image, custom_prompt],
     )
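If the goal is simply to keep transformers off the flash-attention path, a complementary option (a sketch assuming transformers >= 4.36; InternVL2's trust_remote_code modeling may or may not honor the kwarg, so the mock above remains the robust fallback) is to request eager attention explicitly:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "OpenGVLab/InternVL2-8B",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map="auto",
    trust_remote_code=True,
    attn_implementation="eager",  # request the non-flash attention code path
)

When the remote code respects attn_implementation, this avoids module mocking altogether.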