# Last commit by mknolan: a954ee1 (verified)
# "Fix model name from InternVL2-8B to InternVL2"
import torch
import os
import sys
import gradio as gr
from PIL import Image
import traceback
import types
import importlib.util
import importlib.machinery
import importlib.abc
# Startup banner.
print("=" * 50, "InternVL2 IMAGE & TEXT ANALYSIS", "=" * 50, sep="\n")
# Report interpreter / framework versions and whether CUDA is usable.
print(
    f"Python version: {sys.version}",
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    sep="\n",
)
if not torch.cuda.is_available():
    print("CUDA is not available. This application requires GPU acceleration.")
else:
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU count: {torch.cuda.device_count()}")
    for gpu_idx in range(torch.cuda.device_count()):
        print(f"GPU {gpu_idx}: {torch.cuda.get_device_name(gpu_idx)}")
    # Memory figures in GB (1e9 bytes). The total is read from device 0 only;
    # allocated/reserved are queried without a device argument.
    print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"Allocated GPU memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    print(f"Reserved GPU memory: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
# In-memory mock implementation
def create_in_memory_flash_attn_mock(offline=False):
    """Install an in-memory stand-in for the ``flash_attn`` package.

    Registers a meta-path finder that answers imports of ``flash_attn``, any
    ``flash_attn.*`` submodule, and ``flash_attn_2_cuda`` with empty placeholder
    modules. It also pre-populates ``sys.modules`` with the top-level package,
    ``flash_attn.flash_attn_interface``, ``flash_attn.flash_attn_triton`` and
    ``flash_attn_2_cuda``. The top-level module exposes
    ``__version__ = "0.0.0-mocked"`` and no-op ``flash_attn_func`` /
    ``flash_attn_kvpacked_func`` / ``flash_attn_qkvpacked_func`` callables
    that return ``None``.

    Calling the function again replaces the finder installed by an earlier
    call instead of stacking another copy on ``sys.meta_path``.

    Args:
        offline: When True, also set ``TRANSFORMERS_OFFLINE=1`` so that a
            subsequently imported transformers does not contact the Hub.
            Defaults to False, because offline mode prevents downloading model
            files that are not already cached.
    """
    print("Setting up in-memory flash_attn mock...")

    def is_mocked_name(name):
        """Return True for flash_attn, any flash_attn.* submodule, and flash_attn_2_cuda."""
        return (
            name == 'flash_attn'
            or name.startswith('flash_attn.')
            or name == 'flash_attn_2_cuda'
        )

    class DummyFinder(importlib.abc.MetaPathFinder):
        """Meta-path finder that answers every mocked name with a DummyLoader spec."""

        # Marker so later calls can find and remove this finder (each call
        # defines a fresh class, so isinstance() cannot be used for that).
        _is_flash_attn_mock = True

        def find_spec(self, fullname, path, target=None):
            if is_mocked_name(fullname):
                return self.create_spec(fullname)
            return None

        def create_spec(self, fullname):
            return importlib.machinery.ModuleSpec(
                name=fullname,
                loader=DummyLoader(fullname),
                # Dot-free names (flash_attn, flash_attn_2_cuda) are marked as packages.
                is_package=fullname.count('.') == 0 or fullname.split('.')[-1] == '',
            )

    class DummyLoader(importlib.abc.Loader):
        """Loader that fabricates an empty module; flash_attn also gets stub functions."""

        def __init__(self, fullname):
            self.fullname = fullname

        def create_module(self, spec):
            module = types.ModuleType(spec.name)
            module.__spec__ = spec
            module.__loader__ = self
            module.__file__ = f"<{spec.name}>"
            # Every mock module gets a __path__ so that nested imports such as
            # flash_attn.a.b can still be resolved through DummyFinder.
            module.__path__ = []
            module.__package__ = spec.name.rpartition('.')[0] if '.' in spec.name else ''
            if spec.name == 'flash_attn':
                module.__version__ = "0.0.0-mocked"
                # No-op stand-ins: they accept any arguments and return None.
                module.flash_attn_func = lambda *args, **kwargs: None
                module.flash_attn_kvpacked_func = lambda *args, **kwargs: None
                module.flash_attn_qkvpacked_func = lambda *args, **kwargs: None
            return module

        def exec_module(self, module):
            # Nothing to execute: create_module already populated the module.
            pass

    # Drop any real or previously mocked modules so the mock is authoritative.
    for name in [n for n in list(sys.modules) if is_mocked_name(n)]:
        del sys.modules[name]

    # Remove finders installed by earlier calls, then put ours first so it is
    # consulted before the regular finders.
    sys.meta_path[:] = [
        finder for finder in sys.meta_path
        if not getattr(finder, '_is_flash_attn_mock', False)
    ]
    sys.meta_path.insert(0, DummyFinder())

    # Pre-create and register the top-level package.
    spec = importlib.machinery.ModuleSpec(
        name='flash_attn',
        loader=DummyLoader('flash_attn'),
        is_package=True,
    )
    flash_attn = importlib.util.module_from_spec(spec)
    sys.modules['flash_attn'] = flash_attn
    flash_attn.__version__ = "0.0.0-mocked"

    # Pre-create common submodules and attach them to their parent.
    for submodule in ['flash_attn.flash_attn_interface', 'flash_attn.flash_attn_triton']:
        parent_name, _, child_name = submodule.rpartition('.')
        subspec = importlib.machinery.ModuleSpec(
            name=submodule,
            loader=DummyLoader(submodule),
            is_package=False,
        )
        module = importlib.util.module_from_spec(subspec)
        setattr(sys.modules[parent_name], child_name, module)
        sys.modules[submodule] = module

    # Pre-create the compiled-extension placeholder.
    cuda_spec = importlib.machinery.ModuleSpec(
        name='flash_attn_2_cuda',
        loader=DummyLoader('flash_attn_2_cuda'),
        is_package=False,
    )
    sys.modules['flash_attn_2_cuda'] = importlib.util.module_from_spec(cuda_spec)

    # Silence transformers' advisory warnings. Offline mode is opt-in only:
    # forcing it would stop the model from being downloaded later.
    os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
    if offline:
        os.environ["TRANSFORMERS_OFFLINE"] = "1"

    # Best-effort self-check; failures are reported, not raised.
    try:
        import flash_attn
        print(f"✓ Mock flash_attn loaded successfully: {flash_attn.__version__}")
        print(f"✓ flash_attn.__spec__ exists: {flash_attn.__spec__ is not None}")
        spec = importlib.util.find_spec("flash_attn")
        print(f"✓ importlib.util.find_spec returns: {spec is not None}")
        import flash_attn.flash_attn_interface
        print("✓ flash_attn.flash_attn_interface loaded")
        import flash_attn_2_cuda
        print("✓ flash_attn_2_cuda loaded")
    except Exception as e:
        print(f"WARNING: Error verifying flash_attn mock: {e}")
        traceback.print_exc()
# Now set up the mock BEFORE importing transformers. The mock must already be
# registered in sys.modules / sys.meta_path when transformers (and any
# trust_remote_code model code) looks for flash_attn; presumably that lookup can
# happen at import time, which is why this call comes first — confirm if reordering.
create_in_memory_flash_attn_mock()
# Import transformers AFTER setting up the mock (order matters, see above).
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
# Model loading
def _report_flash_attn_failure():
    """Print the state of the flash_attn mock and sys.meta_path for debugging."""
    print("\n❌ Flash attention error detected!")
    if 'flash_attn' in sys.modules:
        mock = sys.modules['flash_attn']
        print(f"Flash mock exists: {mock}")
        print(f"Flash mock __spec__: {getattr(mock, '__spec__', 'NOT SET')}")
    else:
        print("flash_attn module was removed from sys.modules")
    print("\nCurrent state of sys.meta_path:")
    for index, finder in enumerate(sys.meta_path):
        print(f" {index}: {finder.__class__.__name__}")


def _load_generation_config(model_path, model):
    """Load the repo's generation config, falling back to the model's own default.

    ``GenerationConfig.from_pretrained`` is expected to raise ``OSError``
    (``EnvironmentError``) when the repo has no generation config file. Rather
    than discarding an already-loaded model for that, fall back to
    ``model.generation_config``, or to a default ``GenerationConfig()`` if that
    attribute is missing or None.
    """
    try:
        return GenerationConfig.from_pretrained(model_path, trust_remote_code=True)
    except OSError as e:
        print(f"WARNING: Could not load generation config ({e}); using defaults.")
        fallback = getattr(model, "generation_config", None)
        return fallback if fallback is not None else GenerationConfig()


def load_model(model_path="OpenGVLab/InternVL2"):
    """Load the InternVL2 model, tokenizer and generation config.

    Args:
        model_path: Hugging Face Hub repo id or local directory to load from.
            NOTE(review): this default replaced "OpenGVLab/InternVL2-8B" in a
            recent commit; InternVL2 checkpoints appear to be published under
            size-suffixed ids (e.g. "-8B"). Confirm that this id resolves on the Hub.

    Returns:
        ``(model, tokenizer, generation_config)`` on success, or
        ``(None, None, None)`` if loading fails (the error and traceback are
        printed).
    """
    try:
        print("\nLoading InternVL2 model...")
        print("Downloading model shards. This may take some time...")
        use_cuda = torch.cuda.is_available()
        try:
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                # bfloat16 on GPU; CPU falls back to float32.
                torch_dtype=torch.bfloat16 if use_cuda else torch.float32,
                low_cpu_mem_usage=True,
                # Let transformers place the weights on the available GPU(s).
                device_map="auto" if use_cuda else None,
                trust_remote_code=True,
            )
        except Exception as e:
            # Extra diagnostics for the failure the flash_attn mock exists to prevent.
            if "flash_attn.__spec__ is not set" in str(e):
                _report_flash_attn_failure()
            raise

        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            use_fast=False,
            trust_remote_code=True,
        )
        generation_config = _load_generation_config(model_path, model)

        print("✓ Model and tokenizer loaded successfully!")
        return model, tokenizer, generation_config
    except Exception as e:
        print(f"\n❌ ERROR loading model: {str(e)}")
        traceback.print_exc()
        return None, None, None
# Helper function to load an image
def load_image(image_path, processor=None):
    """Load an image from a URL or local path, or pass an image object through.

    Args:
        image_path: An ``http://`` / ``https://`` URL, a local file path
            (``str`` or ``os.PathLike``), or an already-loaded image object.
            Objects that are not paths or URLs are returned unchanged.
        processor: Unused; kept so existing callers keep working.

    Returns:
        An RGB PIL image for URL and path inputs. If a URL cannot be downloaded
        or decoded, returns a 224x224 gray placeholder image. Any other input is
        returned as-is.

    Errors opening a local file are not caught and propagate to the caller.
    """
    if isinstance(image_path, os.PathLike):
        image_path = os.fspath(image_path)
    if not isinstance(image_path, str):
        # Already an image object; no preprocessing is done here.
        return image_path

    # Match real URL schemes only, so local files named e.g. "http_cat.png" load from disk.
    if image_path.startswith(('http://', 'https://')):
        import requests
        from io import BytesIO
        try:
            response = requests.get(image_path, timeout=10)
            # Turn HTTP error pages into an explicit failure instead of a decode error.
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert('RGB')
        except Exception as e:
            print(f"Error loading image from URL: {e}")
            # Best-effort: fall back to a neutral placeholder image.
            return Image.new('RGB', (224, 224), color='gray')

    # The context manager closes the file handle once the pixels are converted.
    with Image.open(image_path) as img:
        return img.convert('RGB')
# Single-image question answering
def analyze_image(model, tokenizer, image, prompt, generation_config):
    """Ask the model one question about an image and return the reply text.

    Builds a ``USER: <image>\\n{prompt}\\nASSISTANT:`` prompt and tokenizes it.
    The tensors are moved to CUDA when it is available, and ``image`` is attached
    under the ``images`` key before calling ``model.generate`` (at most 512 new
    tokens). The decoded output is cut after the last ``ASSISTANT:`` marker.

    NOTE(review): confirm that the loaded InternVL2 remote code accepts a raw
    image through an ``images=`` generate kwarg. It may instead expect
    preprocessed ``pixel_values``, e.g. via a ``model.chat`` helper.

    Returns:
        The stripped assistant reply, or ``"Error analyzing image: <exc>"`` if
        any step raises (the traceback is printed).
    """
    try:
        encoded = tokenizer([f"USER: <image>\n{prompt}\nASSISTANT:"], return_tensors="pt")
        if torch.cuda.is_available():
            encoded = dict((name, tensor.to("cuda")) for name, tensor in encoded.items())
        encoded["images"] = [image]

        with torch.no_grad():
            generated_ids = model.generate(
                **encoded,
                generation_config=generation_config,
                max_new_tokens=512,
            )

        decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        # The decoded text may echo the prompt; keep only what follows the last marker.
        _, _, reply = decoded.rpartition("ASSISTANT:")
        return reply.strip()
    except Exception as exc:
        traceback.print_exc()
        return f"Error analyzing image: {exc}"
# Create the Gradio interface
def _resolve_prompt(custom_prompt, dropdown_choice, fallback):
    """Return the first non-blank string among custom_prompt and dropdown_choice, else fallback."""
    for candidate in (custom_prompt, dropdown_choice):
        if isinstance(candidate, str) and candidate.strip():
            return candidate
    return fallback


def create_interface():
    """Load the model and build the Gradio app.

    Returns:
        A ``gr.Blocks`` app. If ``load_model`` fails, the app only shows an
        error message. Otherwise it offers an image upload, a prompt dropdown,
        a free-text prompt box, example images, and an "Analyze Image" button
        that runs ``analyze_image``.
    """
    model, tokenizer, generation_config = load_model()

    if model is None:
        # Model loading failed: serve a minimal error page instead.
        with gr.Blocks(title="InternVL2 Image Analysis - Error") as demo:
            gr.Markdown("# ❌ Error: Failed to load models")
            gr.Markdown("Please check the console for error details.")
        return demo

    # Predefined prompts for analysis
    prompts = [
        "Describe this image in detail.",
        "What can you tell me about this image?",
        "Is there any text in this image? If so, can you read it?",
        "What is the main subject of this image?",
        "What emotions or feelings does this image convey?",
        "Describe the composition and visual elements of this image.",
        "Summarize what you see in this image in one paragraph.",
    ]

    with gr.Blocks(title="InternVL2 Image Analysis") as demo:
        gr.Markdown("# 🖼️ InternVL2 Image & Text Analyzer")
        gr.Markdown("### Upload an image and ask questions about it")
        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="Upload Image")
                prompt_input = gr.Dropdown(
                    choices=prompts,
                    value=prompts[0],
                    label="Select a prompt or enter your own below",
                    allow_custom_value=True,
                )
                custom_prompt = gr.Textbox(label="Custom prompt", placeholder="Enter your custom prompt here...")
                analyze_btn = gr.Button("Analyze Image", variant="primary")
            with gr.Column(scale=1):
                output = gr.Textbox(label="Analysis Results", lines=15)

        # Example images - Using stable URLs from GitHub repositories
        gr.Examples(
            examples=[
                ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/kitchen_sink/files/cheetah1.jpg", "What's in this image?"],
                ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/kitchen_sink/files/lion.jpg", "Describe this animal."],
            ],
            inputs=[input_image, custom_prompt],
        )

        # Copy the dropdown selection into the editable custom prompt box.
        prompt_input.change(fn=lambda x: x, inputs=prompt_input, outputs=custom_prompt)

        def on_analyze_click(image, prompt_text, dropdown_choice):
            """Validate the inputs, pick the prompt, and run the analysis."""
            if image is None:
                return "Please upload an image first."
            # The dropdown's *value* arrives as an event input; referencing the
            # component object here would pass the component itself as the prompt.
            final_prompt = _resolve_prompt(prompt_text, dropdown_choice, prompts[0])
            return analyze_image(model, tokenizer, image, final_prompt, generation_config)

        analyze_btn.click(
            fn=on_analyze_click,
            inputs=[input_image, custom_prompt, prompt_input],
            outputs=output,
        )

    return demo
# Script entry point
if __name__ == "__main__":
    # Ask PyTorch's CUDA caching allocator not to split cached blocks larger
    # than 128 MB, to limit fragmentation.
    # NOTE(review): torch is already imported, and on GPU machines CUDA was
    # already queried at the top of this file, before this line runs. If the
    # allocator has been initialised by then, this setting may take no effect;
    # consider setting it before `import torch`.
    os.environ.update(PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:128")
    # Build the app (this loads the model) and serve it on all interfaces.
    demo = create_interface()
    demo.launch(server_name="0.0.0.0", share=False)