import torch
import os
import sys
import gradio as gr
from PIL import Image
import traceback
import types
import importlib.util
import importlib.machinery
import importlib.abc

print("=" * 50)
print("InternVL2 IMAGE & TEXT ANALYSIS")
print("=" * 50)

print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU count: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

    print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"Allocated GPU memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    print(f"Reserved GPU memory: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
else:
    print("CUDA is not available. This application requires GPU acceleration.")


def create_in_memory_flash_attn_mock():
    """Create a completely in-memory flash_attn mock with all required attributes."""
    print("Setting up in-memory flash_attn mock...")

    class DummyFinder(importlib.abc.MetaPathFinder):
        def find_spec(self, fullname, path, target=None):
            if fullname == 'flash_attn' or fullname.startswith('flash_attn.'):
                return self.create_spec(fullname)
            elif fullname == 'flash_attn_2_cuda':
                return self.create_spec(fullname)
            return None

        def create_spec(self, fullname):
            loader = DummyLoader(fullname)
            spec = importlib.machinery.ModuleSpec(
                name=fullname,
                loader=loader,
                is_package='.' not in fullname
            )
            return spec

    class DummyLoader(importlib.abc.Loader):
        def __init__(self, fullname):
            self.fullname = fullname

        def create_module(self, spec):
            module = types.ModuleType(spec.name)

            module.__spec__ = spec
            module.__loader__ = self
            module.__file__ = f"<{spec.name}>"
            module.__path__ = []
            module.__package__ = spec.name.rpartition('.')[0] if '.' in spec.name else ''

            if spec.name == 'flash_attn':
                module.__version__ = "0.0.0-mocked"

                # No-op stand-ins for the kernel entry points that model code may look up.
                module.flash_attn_func = lambda *args, **kwargs: None
                module.flash_attn_kvpacked_func = lambda *args, **kwargs: None
                module.flash_attn_qkvpacked_func = lambda *args, **kwargs: None

            return module

        def exec_module(self, module):
            # Nothing to execute; the module is fully populated in create_module.
            pass

    # Drop any previously imported (real or partial) flash_attn modules.
    for name in list(sys.modules.keys()):
        if name == 'flash_attn' or name.startswith('flash_attn.') or name == 'flash_attn_2_cuda':
            del sys.modules[name]

    # Intercept every future flash_attn import.
    sys.meta_path.insert(0, DummyFinder())

    # Pre-register the top-level mock module.
    spec = importlib.machinery.ModuleSpec(
        name='flash_attn',
        loader=DummyLoader('flash_attn'),
        is_package=True
    )
    flash_attn = importlib.util.module_from_spec(spec)
    sys.modules['flash_attn'] = flash_attn

    flash_attn.__version__ = "0.0.0-mocked"

    # Pre-register the submodules the remote model code may import.
    for submodule in ['flash_attn.flash_attn_interface', 'flash_attn.flash_attn_triton']:
        parts = submodule.split('.')
        parent_name = '.'.join(parts[:-1])
        child_name = parts[-1]
        parent = sys.modules[parent_name]

        subspec = importlib.machinery.ModuleSpec(
            name=submodule,
            loader=DummyLoader(submodule),
            is_package=False
        )

        module = importlib.util.module_from_spec(subspec)
        setattr(parent, child_name, module)
        sys.modules[submodule] = module

    # Pre-register the compiled CUDA extension name as well.
    cuda_spec = importlib.machinery.ModuleSpec(
        name='flash_attn_2_cuda',
        loader=DummyLoader('flash_attn_2_cuda'),
        is_package=False
    )
    cuda_module = importlib.util.module_from_spec(cuda_spec)
    sys.modules['flash_attn_2_cuda'] = cuda_module

    os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
    # Forcing offline mode here would block the first-run download of the model shards,
    # so it is left disabled.
    # os.environ["TRANSFORMERS_OFFLINE"] = "1"

    # Verify that the mock is importable and discoverable via find_spec.
    try:
        import flash_attn
        print(f"✓ Mock flash_attn loaded successfully: {flash_attn.__version__}")
        print(f"✓ flash_attn.__spec__ exists: {flash_attn.__spec__ is not None}")

        spec = importlib.util.find_spec("flash_attn")
        print(f"✓ importlib.util.find_spec returns: {spec is not None}")

        import flash_attn.flash_attn_interface
        print("✓ flash_attn.flash_attn_interface loaded")

        import flash_attn_2_cuda
        print("✓ flash_attn_2_cuda loaded")
    except Exception as e:
        print(f"WARNING: Error verifying flash_attn mock: {e}")
        traceback.print_exc()


# Install the mock before transformers (and the remote model code) is imported.
create_in_memory_flash_attn_mock()

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig


def load_model():
    try:
        print("\nLoading InternVL2 model...")

        # Hugging Face Hub id of the InternVL2 checkpoint to load.
        model_path = "OpenGVLab/InternVL2"

        print("Downloading model shards. This may take some time...")

        try:
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
                low_cpu_mem_usage=True,
                device_map="auto" if torch.cuda.is_available() else None,
                trust_remote_code=True
            )
        except Exception as e:
            # Print extra diagnostics if the failure points at the flash_attn mock.
            if "flash_attn.__spec__ is not set" in str(e):
                print("\n❌ Flash attention error detected!")

                if 'flash_attn' in sys.modules:
                    mock = sys.modules['flash_attn']
                    print(f"Flash mock exists: {mock}")
                    print(f"Flash mock __spec__: {getattr(mock, '__spec__', 'NOT SET')}")
                else:
                    print("flash_attn module was removed from sys.modules")

                print("\nCurrent state of sys.meta_path:")
                for i, finder in enumerate(sys.meta_path):
                    print(f"  {i}: {finder.__class__.__name__}")

            raise

        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            use_fast=False,
            trust_remote_code=True
        )

        generation_config = GenerationConfig.from_pretrained(
            model_path,
            trust_remote_code=True
        )

        print("✓ Model and tokenizer loaded successfully!")
        return model, tokenizer, generation_config

    except Exception as e:
        print(f"\n❌ ERROR loading model: {str(e)}")
        traceback.print_exc()
        return None, None, None


def load_image(image_path, processor=None):
    """Load an image and prepare it for the model."""
    if isinstance(image_path, str):
        if image_path.startswith('http'):
            import requests
            from io import BytesIO
            try:
                response = requests.get(image_path, timeout=10)
                response.raise_for_status()
                image = Image.open(BytesIO(response.content)).convert('RGB')
            except Exception as e:
                print(f"Error loading image from URL: {e}")
                # Fall back to a plain placeholder image so the caller can continue.
                image = Image.new('RGB', (224, 224), color='gray')
        else:
            image = Image.open(image_path).convert('RGB')
    else:
        # Already a PIL image (e.g. from the Gradio upload widget).
        image = image_path

    return image


def analyze_image(model, tokenizer, image, prompt, generation_config):
    try:
        # Build a chat-style prompt containing the image placeholder token.
        text_prompt = f"USER: <image>\n{prompt}\nASSISTANT:"

        inputs = tokenizer([text_prompt], return_tensors="pt")

        if torch.cuda.is_available():
            inputs = {k: v.to("cuda") for k, v in inputs.items()}

        # Pass the PIL image alongside the tokenized prompt.
        inputs["images"] = [image]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                generation_config=generation_config,
                max_new_tokens=512,
            )

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Keep only the assistant's part of the decoded conversation.
        assistant_response = generated_text.split("ASSISTANT:")[-1].strip()

        return assistant_response

    except Exception as e:
        error_msg = f"Error analyzing image: {str(e)}"
        traceback.print_exc()
        return error_msg


def create_interface():
    model, tokenizer, generation_config = load_model()

    if model is None:
        with gr.Blocks(title="InternVL2 Image Analysis - Error") as demo:
            gr.Markdown("# ❌ Error: Failed to load models")
            gr.Markdown("Please check the console for error details.")
        return demo

    prompts = [
        "Describe this image in detail.",
        "What can you tell me about this image?",
        "Is there any text in this image? If so, can you read it?",
        "What is the main subject of this image?",
        "What emotions or feelings does this image convey?",
        "Describe the composition and visual elements of this image.",
        "Summarize what you see in this image in one paragraph."
    ]

    with gr.Blocks(title="InternVL2 Image Analysis") as demo:
        gr.Markdown("# 🖼️ InternVL2 Image & Text Analyzer")
        gr.Markdown("### Upload an image and ask questions about it")

        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="Upload Image")
                prompt_input = gr.Dropdown(
                    choices=prompts,
                    value=prompts[0],
                    label="Select a prompt or enter your own below",
                    allow_custom_value=True
                )
                custom_prompt = gr.Textbox(label="Custom prompt", placeholder="Enter your custom prompt here...")
                analyze_btn = gr.Button("Analyze Image", variant="primary")

            with gr.Column(scale=1):
                output = gr.Textbox(label="Analysis Results", lines=15)

        gr.Examples(
            examples=[
                ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/kitchen_sink/files/cheetah1.jpg", "What's in this image?"],
                ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/kitchen_sink/files/lion.jpg", "Describe this animal."],
            ],
            inputs=[input_image, custom_prompt],
        )

        prompt_input.change(fn=lambda x: x, inputs=prompt_input, outputs=custom_prompt)

        def on_analyze_click(image, prompt_text, selected_prompt):
            if image is None:
                return "Please upload an image first."

            # Use the custom prompt if one was typed, otherwise fall back to the dropdown selection.
            final_prompt = prompt_text.strip() if prompt_text and prompt_text.strip() else selected_prompt

            result = analyze_image(model, tokenizer, image, final_prompt, generation_config)
            return result

        analyze_btn.click(
            fn=on_analyze_click,
            inputs=[input_image, custom_prompt, prompt_input],
            outputs=output
        )

    return demo


if __name__ == "__main__":
    # Limit CUDA allocator block size to reduce memory fragmentation with large models.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

    demo = create_interface()
    demo.launch(share=False, server_name="0.0.0.0")