mknolan committed (verified)
Commit f952993 · 1 Parent(s): 06d7f3d

Properly mock flash_attn module with __spec__ attribute
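Why the change matters (a sketch, not part of the commit): optional-dependency probes in libraries such as transformers typically go through importlib, and importlib.util.find_spec() raises ValueError when a module already registered in sys.modules has no usable __spec__. A bare types.ModuleType stub, which is what the previous revision installed, has __spec__ set to None, so the probe fails. A minimal reproduction, assuming standard Python 3:

import importlib.util
import sys
import types

# The old-style stub: a bare ModuleType whose __spec__ defaults to None.
sys.modules["flash_attn"] = types.ModuleType("flash_attn")

try:
    importlib.util.find_spec("flash_attn")
except ValueError as exc:
    # Typically "flash_attn.__spec__ is None"; this is the failure the commit fixes
    # by building the mock from a real spec via importlib.util.spec_from_file_location.
    print("find_spec rejects the stub:", exc)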

Files changed (1): app.py (+62 -9)
app.py CHANGED
@@ -6,6 +6,8 @@ from PIL import Image
 import traceback
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.generation import GenerationConfig
+import importlib.util
+import importlib.machinery

 print("=" * 50)
 print("InternVL2-8B IMAGE & TEXT ANALYSIS")
@@ -29,28 +31,79 @@ if torch.cuda.is_available():
 else:
     print("CUDA is not available. This application requires GPU acceleration.")

+# Create a proper flash_attn mock module before loading the model
+def setup_flash_attn_mock():
+    # Create a more complete mock for flash_attn
+    print("Setting up a proper flash_attn mock...")
+
+    # First, remove any existing flash_attn module if it exists
+    if "flash_attn" in sys.modules:
+        del sys.modules["flash_attn"]
+
+    # Create a simple Python file with flash_attn mock code
+    flash_attn_path = os.path.join(os.getcwd(), "flash_attn.py")
+    with open(flash_attn_path, "w") as f:
+        f.write("""
+# Mock flash_attn module
+__version__ = "0.0.0-disabled"
+
+def flash_attn_func(*args, **kwargs):
+    raise NotImplementedError("This is a mock flash_attn implementation")
+
+def flash_attn_kvpacked_func(*args, **kwargs):
+    raise NotImplementedError("This is a mock flash_attn implementation")
+
+def flash_attn_qkvpacked_func(*args, **kwargs):
+    raise NotImplementedError("This is a mock flash_attn implementation")
+
+# Add any other functions that might be needed
+""")
+
+    # Load the mock module properly with spec
+    spec = importlib.util.spec_from_file_location("flash_attn", flash_attn_path)
+    flash_attn_module = importlib.util.module_from_spec(spec)
+    sys.modules["flash_attn"] = flash_attn_module
+    spec.loader.exec_module(flash_attn_module)
+
+    # Now also create the flash_attn_2_cuda if needed
+    if "flash_attn_2_cuda" not in sys.modules:
+        flash_attn_2_path = os.path.join(os.getcwd(), "flash_attn_2_cuda.py")
+        with open(flash_attn_2_path, "w") as f:
+            f.write("# Mock flash_attn_2_cuda module\n")
+
+        spec_cuda = importlib.util.spec_from_file_location("flash_attn_2_cuda", flash_attn_2_path)
+        flash_attn_2_cuda_module = importlib.util.module_from_spec(spec_cuda)
+        sys.modules["flash_attn_2_cuda"] = flash_attn_2_cuda_module
+        spec_cuda.loader.exec_module(flash_attn_2_cuda_module)
+
+    print("Flash-attention mock modules set up successfully")
+
 # Create a function to load the model
 def load_model():
     try:
         print("\nLoading InternVL2-8B model...")

-        # Create a fake flash_attn module to avoid dependency errors
-        import sys
-        import types
-        if "flash_attn" not in sys.modules:
-            flash_attn_module = types.ModuleType("flash_attn")
-            flash_attn_module.__version__ = "0.0.0-disabled"
-            sys.modules["flash_attn"] = flash_attn_module
-            print("Created dummy flash_attn module to avoid dependency error")
+        # Set up proper mock modules for flash_attn
+        setup_flash_attn_mock()
+
+        # Disable flash attention in transformers by patching environment vars
+        os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
+        os.environ["TRANSFORMERS_OFFLINE"] = "1" # Avoid online checks for flash_attn

         # Load the model and tokenizer
         model_path = "OpenGVLab/InternVL2-8B"
+        print("Loading tokenizer...")
         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+        print("Loading model (this may take a while)...")
+        # Add specific flags to avoid flash_attn usage
         model = AutoModelForCausalLM.from_pretrained(
             model_path,
             torch_dtype=torch.bfloat16,
             device_map="auto",
-            trust_remote_code=True
+            trust_remote_code=True,
+            use_flash_attention_2=False, # Explicitly disable flash attention
+            attn_implementation="eager" # Use eager implementation instead
         )

         # Define generation config
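Design note: the diff imports importlib.machinery but only importlib.util is actually used, and the mock is written out as a throwaway flash_attn.py in the working directory. A variant that stays entirely in memory is possible with importlib.machinery.ModuleSpec; the sketch below is an illustration of that alternative, not the code in app.py, and the helper name is made up:

import importlib.machinery
import importlib.util
import sys

def install_in_memory_flash_attn_mock():
    # A spec with no loader is enough to give the module a real __spec__.
    spec = importlib.machinery.ModuleSpec("flash_attn", loader=None)
    module = importlib.util.module_from_spec(spec)  # sets module.__spec__ = spec
    module.__version__ = "0.0.0-disabled"

    def _unavailable(*args, **kwargs):
        raise NotImplementedError("flash_attn is mocked; use the eager attention path")

    # Mirror the entry points the file-based mock defines.
    module.flash_attn_func = _unavailable
    module.flash_attn_kvpacked_func = _unavailable
    module.flash_attn_qkvpacked_func = _unavailable

    sys.modules["flash_attn"] = module

install_in_memory_flash_attn_mock()
# With __spec__ set, importlib-based probes return a spec instead of raising.
assert importlib.util.find_spec("flash_attn") is not None

Either way, the attn_implementation="eager" flag passed to from_pretrained should keep the model off the flash-attention code path, so the mock only has to survive import-time availability checks.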