from unittest.mock import patch
import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.dynamic_module_utils import get_imports
import torch
from PIL import Image
import random
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import io
# Define colormap
colormap = ['red', 'green', 'blue', 'yellow', 'orange', 'purple', 'cyan']
# Workaround to fix import issues for Florence-2 model
def workaround_fixed_get_imports(filename):
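    """Return get_imports(filename) with 'flash_attn' stripped for Florence-2's remote modeling file, so the model loads on machines without flash-attention installed."""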
if not str(filename).endswith("/modeling_florence2.py"):
return get_imports(filename)
imports = get_imports(filename)
if "flash_attn" in imports:
imports.remove("flash_attn") # Remove 'flash_attn' if it's causing issues
return imports
def initialize_model():
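    """Load the Florence-2 model and processor onto GPU if available, else CPU; returns (model, processor, device), or (None, None, device) on failure."""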
# Check if CUDA (GPU) is available and set the device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Patch the get_imports function and load the model and processor
with patch("transformers.dynamic_module_utils.get_imports", workaround_fixed_get_imports):
try:
model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True).to(device).eval()
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True)
print("Model and processor loaded successfully.")
return model, processor, device
except Exception as e:
print(f"An error occurred while loading the model or processor: {e}")
return None, None, device
# Initialize the model and processor
model, processor, device = initialize_model()
def run_example(task_prompt, image, text_input=None):
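    """Run a Florence-2 task prompt (optionally suffixed with text_input) on a PIL image and return the parsed, post-processed answer."""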
if text_input is None:
prompt = task_prompt
else:
prompt = task_prompt + text_input
    # Move inputs to the model's device so generation works on GPU as well as CPU
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    with torch.inference_mode():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
parsed_answer = processor.post_process_generation(
generated_text,
task=task_prompt,
image_size=(image.width, image.height)
)
return parsed_answer
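# Sketch of the expected result shape (inferred from how plot_bbox_img and
# draw_polygons consume it below; exact keys depend on the task token):
#   run_example("<OD>", pil_image)
#   -> {'<OD>': {'bboxes': [[x1, y1, x2, y2], ...], 'labels': [...]}}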
def fig_to_pil(fig):
    """Serialize a matplotlib figure to a PIL image and close the figure."""
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)  # Close the figure so repeated calls don't accumulate open figures
    return img
def plot_bbox_img(image, data):
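    """Render the image with labeled bounding boxes from a Florence-2 result dict ({'bboxes': [...], 'labels': [...]}) and return it as a PIL image."""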
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image)
if 'bboxes' in data and 'labels' in data:
bboxes, labels = data['bboxes'], data['labels']
else:
return fig_to_pil(fig)
for bbox, label in zip(bboxes, labels):
x1, y1, x2, y2 = bbox
rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=2, edgecolor='indigo', facecolor='none')
ax.add_patch(rect)
plt.text(x1, y1, label, color='white', fontsize=10, bbox=dict(facecolor='indigo', alpha=0.8))
ax.axis('off')
return fig_to_pil(fig)
def draw_polygons(image, prediction, fill_mask=False):
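    """Overlay labeled polygons from a Florence-2 segmentation result on the image; flat [x1, y1, x2, y2, ...] coordinate lists are paired into (x, y) points."""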
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image)
for polygons, label in zip(prediction.get('polygons', []), prediction.get('labels', [])):
color = random.choice(colormap)
for polygon in polygons:
if isinstance(polygon[0], (int, float)):
polygon = [(polygon[i], polygon[i+1]) for i in range(0, len(polygon), 2)]
poly = patches.Polygon(polygon, edgecolor=color, facecolor=color if fill_mask else 'none', alpha=0.5 if fill_mask else 1, linewidth=2)
ax.add_patch(poly)
if polygon:
plt.text(polygon[0][0], polygon[0][1], label, color='white', fontsize=10, bbox=dict(facecolor=color, alpha=0.8))
ax.axis('off')
return fig_to_pil(fig)
def process_image(image, task, text):
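    """Map the UI task name to its Florence-2 prompt, run inference, and return a (result text, output image) pair for display in Gradio."""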
task_mapping = {
"Caption": ("<CAPTION>", lambda result: (result['<CAPTION>'], image)),
"Detailed Caption": ("<DETAILED_CAPTION>", lambda result: (result['<DETAILED_CAPTION>'], image)),
"More Detailed Caption": ("<MORE_DETAILED_CAPTION>", lambda result: (result.get('<MORE_DETAILED_CAPTION>', 'Failed to generate detailed caption'), image)),
"Caption to Phrase Grounding": ("<CAPTION_TO_PHRASE_GROUNDING>", lambda result: (str(result['<CAPTION_TO_PHRASE_GROUNDING>']), plot_bbox_img(image, result['<CAPTION_TO_PHRASE_GROUNDING>']))),
"Object Detection": ("<OD>", lambda result: (str(result['<OD>']), plot_bbox_img(image, result['<OD>']))),
"Referring Expression Segmentation": ("<REFERRING_EXPRESSION_SEGMENTATION>", lambda result: (str(result['<REFERRING_EXPRESSION_SEGMENTATION>']), draw_polygons(image, result['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True))),
"Region to Segmentation": ("<REGION_TO_SEGMENTATION>", lambda result: (str(result['<REGION_TO_SEGMENTATION>']), draw_polygons(image, result['<REGION_TO_SEGMENTATION>'], fill_mask=True))),
"OCR": ("<OCR>", lambda result: (result['<OCR>'], image)),
}
if task in task_mapping:
prompt, process_func = task_mapping[task]
print(f"Task: {task}, Prompt: {prompt}") # Debugging statement
result = run_example(prompt, image, text)
print(f"Result: {result}") # Debugging statement
return process_func(result)
else:
return "", image
image_path_1 = "Fiat-500-9-scaled.jpg"
image_path_2 = "OCR_2.png"
with gr.Blocks() as demo:
gr.HTML("<h1><center>Florence-2 Vision</center></h1>")
with gr.Tab(label="Image"):
with gr.Row():
with gr.Column():
input_img = gr.Image(label="Input Picture", type="pil")
task_dropdown = gr.Dropdown(
choices=["Caption", "Detailed Caption", "More Detailed Caption", "Object Detection", "Caption to Phrase Grounding", "Referring Expression Segmentation", "Region to Segmentation", "OCR"],
label="Task", value="Caption"
)
text_input = gr.Textbox(label="Text Input (Optional)", visible=False)
gr.Examples(
examples=[
[image_path_1, "Detailed Caption", ""],
[image_path_1, "Object Detection", ""],
[image_path_1, "More Detailed Caption", ""],
[image_path_1, "Caption to Phrase Grounding", "A white car parked on the street."],
[image_path_1, "Region to Segmentation", ""],
[image_path_2, "OCR", ""]
],
inputs=[input_img, task_dropdown, text_input],
                    cache_examples=False  # Don't precompute example outputs
)
submit_btn = gr.Button(value="Submit")
with gr.Column():
output_text = gr.Textbox(label="Results")
output_image = gr.Image(label="Image", type="pil")
    def update_text_input(task):
        # Show the text box only for tasks that accept extra text input
        return gr.update(visible=task in ["Caption to Phrase Grounding", "Referring Expression Segmentation", "Region to Segmentation"])
task_dropdown.change(fn=update_text_input, inputs=task_dropdown, outputs=text_input)
submit_btn.click(fn=process_image, inputs=[input_img, task_dropdown, text_input], outputs=[output_text, output_image])
demo.launch()