import os
import sys
import traceback

import torch
import gradio as gr
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer
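
# Dependency note: this script uses torch, transformers, gradio, and pillow;
# the preprocess_image helper defined further down additionally assumes
# torchvision is installed. No version pins are specified here.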
print("=" * 50) |
|
print("InternVL2-8B IMAGE & TEXT ANALYSIS") |
|
print("=" * 50) |
|
|
|
|
|
print(f"Python version: {sys.version}") |
|
print(f"PyTorch version: {torch.__version__}") |
|
print(f"CUDA available: {torch.cuda.is_available()}") |
|
|
|

if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU count: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

    print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"Allocated GPU memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    print(f"Reserved GPU memory: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
else:
    print("CUDA is not available. This application requires GPU acceleration.")


def load_model():
    """Load the InternVL2-8B model, its tokenizer, and generation settings."""
    try:
        print("\nLoading InternVL2-8B model...")

        # Register a dummy flash_attn module so importing the InternVL2 remote
        # code does not fail on machines where flash_attn is not installed.
        import types
        if "flash_attn" not in sys.modules:
            flash_attn_module = types.ModuleType("flash_attn")
            flash_attn_module.__version__ = "0.0.0-disabled"
            sys.modules["flash_attn"] = flash_attn_module
            print("Created dummy flash_attn module to avoid dependency error")

        model_path = "OpenGVLab/InternVL2-8B"
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )

        # InternVL2's custom chat() treats generation_config as a plain dict
        # (as in the official model card examples), so a dict is used here
        # rather than a transformers GenerationConfig object.
        generation_config = dict(
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.8,
            repetition_penalty=1.0
        )

        print("✓ Model and tokenizer loaded successfully!")
        return model, tokenizer, generation_config

    except Exception as e:
        print(f"\n❌ ERROR loading model: {str(e)}")
        traceback.print_exc()
        return None, None, None


def load_image(image_path):
    """Load an image from a path (or pass a PIL image through) as RGB."""
    if isinstance(image_path, str):
        image = Image.open(image_path).convert('RGB')
    else:
        image = image_path
    return image
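

# Minimal pixel-value preprocessing sketch for InternVL2. The official model
# card resizes 448x448 tiles, normalizes with ImageNet statistics, and applies
# dynamic tiling for high-resolution images; the tiling step is omitted here
# for brevity, so this single-tile version is a simplification rather than the
# reference pipeline. preprocess_image is a local helper name, not an
# InternVL2 API.
def preprocess_image(image, input_size=448):
    """Convert a PIL image to a (1, 3, input_size, input_size) float tensor."""
    import torchvision.transforms as T
    from torchvision.transforms.functional import InterpolationMode

    transform = T.Compose([
        T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        # ImageNet mean/std, matching the official InternVL2 preprocessing.
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    return transform(image).unsqueeze(0)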


def analyze_image(model, tokenizer, image, prompt, generation_config):
    """Run a single round of InternVL2 chat over one image."""
    try:
        # InternVL2's remote-code chat() expects preprocessed pixel_values and
        # a question string containing an "<image>" placeholder (as in the
        # official model card), rather than an OpenAI-style messages list.
        pixel_values = preprocess_image(image).to(dtype=torch.bfloat16, device=model.device)
        question = f"<image>\n{prompt}"
        response = model.chat(tokenizer, pixel_values, question, generation_config)
        return response

    except Exception as e:
        error_msg = f"Error analyzing image: {str(e)}"
        traceback.print_exc()
        return error_msg
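

# Illustrative standalone usage of the pieces above (not executed by this
# script; "sample.jpg" is a hypothetical path):
#     model, tokenizer, gen_cfg = load_model()
#     img = load_image("sample.jpg")
#     print(analyze_image(model, tokenizer, img, "Describe this image.", gen_cfg))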


def create_interface():
    """Build the Gradio UI; the model is loaded once at startup."""
    model, tokenizer, generation_config = load_model()

    if model is None:
        # Show a minimal error page so a failed model load is visible in the UI.
        with gr.Blocks(title="InternVL2 Chat - Error") as demo:
            gr.Markdown("# ❌ Error: Failed to load models")
            gr.Markdown("Please check the console for error details.")
        return demo

    prompts = [
        "Describe this image in detail.",
        "What text appears in this image? Please read and transcribe it accurately.",
        "Analyze the content of this image, including any text, pictures, and their relationships.",
        "What is the main subject of this image?",
        "Is there any text in this image? If so, what does it say?",
        "Describe the layout and visual elements of this document.",
        "Summarize the key information presented in this image."
    ]
with gr.Blocks(title="InternVL2 Image Analysis") as demo: |
|
gr.Markdown("# 🖼️ InternVL2-8B Image & Text Analyzer") |
|
gr.Markdown("### Upload an image to analyze its visual content and text") |
|
|
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
input_image = gr.Image(type="pil", label="Upload Image") |
|
prompt_input = gr.Dropdown( |
|
choices=prompts, |
|
value=prompts[0], |
|
label="Select a prompt or enter your own below", |
|
allow_custom_value=True |
|
) |
|
custom_prompt = gr.Textbox(label="Custom prompt", placeholder="Enter your custom prompt here...") |
|
analyze_btn = gr.Button("Analyze Image", variant="primary") |
|
|
|
with gr.Column(scale=1): |
|
output = gr.Textbox(label="Analysis Results", lines=15) |
|
|
|
|
|
gr.Examples( |
|
examples=[ |
|
["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/blip-image-demo.png", "What's in this image?"], |
|
["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/assets/130_vision_language_pretraining/fig_vision_language.jpg", "Describe this diagram in detail."], |
|
], |
|
inputs=[input_image, custom_prompt], |
|
) |
|
|
|
|
|
        # Mirror the dropdown selection into the custom-prompt textbox so the
        # selected preset can be edited before running.
        prompt_input.change(fn=lambda x: x, inputs=prompt_input, outputs=custom_prompt)
        def on_analyze_click(image, prompt_text, dropdown_value):
            if image is None:
                return "Please upload an image first."

            # Fall back to the dropdown selection when the custom prompt is empty.
            final_prompt = (prompt_text or "").strip() or dropdown_value

            result = analyze_image(model, tokenizer, image, final_prompt, generation_config)
            return result

        analyze_btn.click(
            fn=on_analyze_click,
            inputs=[input_image, custom_prompt, prompt_input],
            outputs=output
        )

    return demo
if __name__ == "__main__": |
|
|
|
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128" |
|
|
|
|
|
demo = create_interface() |
|
demo.launch(share=False, server_name="0.0.0.0") |