import torch
import os
import sys
import gradio as gr
from PIL import Image
import traceback
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
print("=" * 50)
print("InternVL2-8B IMAGE & TEXT ANALYSIS")
print("=" * 50)
# System information
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"CUDA version: {torch.version.cuda}")
print(f"GPU count: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
# Memory info
print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
print(f"Allocated GPU memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"Reserved GPU memory: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
else:
print("CUDA is not available. This application requires GPU acceleration.")
# Create a function to load the model
def load_model():
    try:
        print("\nLoading InternVL2-8B model...")

        # Create a fake flash_attn module to avoid dependency errors
        import types
        if "flash_attn" not in sys.modules:
            flash_attn_module = types.ModuleType("flash_attn")
            flash_attn_module.__version__ = "0.0.0-disabled"
            sys.modules["flash_attn"] = flash_attn_module
            print("Created dummy flash_attn module to avoid dependency error")

        # Load the model and tokenizer
        model_path = "OpenGVLab/InternVL2-8B"
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )

        # Define generation config
        generation_config = GenerationConfig(
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.8,
            repetition_penalty=1.0
        )

        print("✓ Model and tokenizer loaded successfully!")
        return model, tokenizer, generation_config
    except Exception as e:
        print(f"\n❌ ERROR loading model: {str(e)}")
        traceback.print_exc()
        return None, None, None
# Helper function to load and process an image
def load_image(image_path, processor=None):
"""Load an image and prepare it for the model."""
if isinstance(image_path, str):
image = Image.open(image_path).convert('RGB')
else:
image = image_path
# The model handles image processing internally
return image
# Function to analyze an image with text
def analyze_image(model, tokenizer, image, prompt, generation_config):
    try:
        # Build the conversation for the model's chat interface
        messages = [
            {"role": "user", "content": f"{prompt}", "image": image}
        ]
        # Generate a response
        response = model.chat(tokenizer, messages=messages, generation_config=generation_config)
        return response
    except Exception as e:
        error_msg = f"Error analyzing image: {str(e)}"
        traceback.print_exc()
        return error_msg
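# NOTE (assumption, not part of the original script): the OpenGVLab/InternVL2-8B model
# card documents a chat API of roughly this shape:
#     response = model.chat(tokenizer, pixel_values, question, generation_config)
# where pixel_values is a preprocessed image tensor and generation_config is a plain dict.
# If the messages-based call above raises a TypeError with this checkpoint, adapting
# analyze_image to that pixel_values-based signature is a likely fix.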
# Create the Gradio interface
def create_interface():
    # Load model at startup
    model, tokenizer, generation_config = load_model()

    if model is None:
        # If model loading failed, create a simple error interface
        with gr.Blocks(title="InternVL2 Chat - Error") as demo:
            gr.Markdown("# ❌ Error: Failed to load models")
            gr.Markdown("Please check the console for error details.")
        return demo

    # Predefined prompts for analysis
    prompts = [
        "Describe this image in detail.",
        "What text appears in this image? Please read and transcribe it accurately.",
        "Analyze the content of this image, including any text, pictures, and their relationships.",
        "What is the main subject of this image?",
        "Is there any text in this image? If so, what does it say?",
        "Describe the layout and visual elements of this document.",
        "Summarize the key information presented in this image."
    ]

    # Create the full interface
    with gr.Blocks(title="InternVL2 Image Analysis") as demo:
        gr.Markdown("# 🖼️ InternVL2-8B Image & Text Analyzer")
        gr.Markdown("### Upload an image to analyze its visual content and text")

        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="Upload Image")
                prompt_input = gr.Dropdown(
                    choices=prompts,
                    value=prompts[0],
                    label="Select a prompt or enter your own below",
                    allow_custom_value=True
                )
                custom_prompt = gr.Textbox(label="Custom prompt", placeholder="Enter your custom prompt here...")
                analyze_btn = gr.Button("Analyze Image", variant="primary")
            with gr.Column(scale=1):
                output = gr.Textbox(label="Analysis Results", lines=15)

        # Example images
        gr.Examples(
            examples=[
                ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/blip-image-demo.png", "What's in this image?"],
                ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/assets/130_vision_language_pretraining/fig_vision_language.jpg", "Describe this diagram in detail."],
            ],
            inputs=[input_image, custom_prompt],
        )

        # When the prompt dropdown changes, copy its value into the custom prompt box
        prompt_input.change(fn=lambda x: x, inputs=prompt_input, outputs=custom_prompt)

        # Set up the click event for analysis
        def on_analyze_click(image, prompt_text, dropdown_prompt):
            if image is None:
                return "Please upload an image first."
            # Prefer the custom prompt; fall back to the dropdown selection
            final_prompt = prompt_text if prompt_text.strip() else dropdown_prompt
            result = analyze_image(model, tokenizer, image, final_prompt, generation_config)
            return result

        analyze_btn.click(
            fn=on_analyze_click,
            inputs=[input_image, custom_prompt, prompt_input],
            outputs=output
        )

    return demo
# Main function
if __name__ == "__main__":
    # Set environment variable for better GPU memory management
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

    # Create and launch the interface
    demo = create_interface()
    demo.launch(share=False, server_name="0.0.0.0")
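# Usage sketch (assumption: a machine or Space with a CUDA GPU and the torch,
# transformers, gradio, and pillow packages installed):
#     python app.py
# Gradio serves the UI on port 7860 by default (http://0.0.0.0:7860); pass
# server_port=... to demo.launch() to change it.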