import io
import random
from unittest.mock import patch

import gradio as gr
import torch
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.dynamic_module_utils import get_imports

# Colors used when drawing segmentation polygons
colormap = ['red', 'green', 'blue', 'yellow', 'orange', 'purple', 'cyan']
# Workaround for an import issue in the Florence-2 remote code: its modeling
# file lists flash_attn as a required import, so we strip it from the import
# check to allow loading on machines without flash-attn installed.
def workaround_fixed_get_imports(filename):
    if not str(filename).endswith("/modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    if "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports
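# For reference, get_imports parses a remote-code file and returns the names of
# the third-party modules it imports (a sketch; the exact module list depends on
# the model revision):
#
#   get_imports("modeling_florence2.py")  # e.g. ['torch', ..., 'flash_attn']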
def initialize_model():
    # Use the GPU if available, otherwise fall back to the CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    # Patch get_imports while loading so the flash_attn workaround takes effect
    with patch("transformers.dynamic_module_utils.get_imports", workaround_fixed_get_imports):
        try:
            model = AutoModelForCausalLM.from_pretrained(
                "microsoft/Florence-2-large-ft", trust_remote_code=True
            ).to(device).eval()
            processor = AutoProcessor.from_pretrained(
                "microsoft/Florence-2-large-ft", trust_remote_code=True
            )
            print("Model and processor loaded successfully.")
            return model, processor, device
        except Exception as e:
            print(f"An error occurred while loading the model or processor: {e}")
            return None, None, device

# Initialize the model and processor once at startup
model, processor, device = initialize_model()
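# model and processor are module-level globals shared by the inference and UI
# callbacks below; both are None if loading failed.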
def run_example(task_prompt, image, text_input=None):
    # Task prompts such as "<CAPTION>" may be extended with free-form text
    prompt = task_prompt if text_input is None else task_prompt + text_input
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    with torch.inference_mode():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    # Convert the raw generated string into the task-specific structure
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height),
    )
    return parsed_answer
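# Example usage (a sketch, assuming the model loaded successfully and the
# example image exists on disk; the generated caption will vary):
#
#   img = Image.open("Fiat-500-9-scaled.jpg").convert("RGB")
#   result = run_example("<CAPTION>", img)
#   print(result["<CAPTION>"])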
def fig_to_pil(fig):
    # Render the Matplotlib figure to an in-memory PNG and return it as a PIL image
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)  # close the figure to avoid leaking memory across requests
    return img
def plot_bbox_img(image, data):
    # Draw labeled bounding boxes on top of the image
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    ax.axis('off')
    if 'bboxes' not in data or 'labels' not in data:
        return fig_to_pil(fig)
    for bbox, label in zip(data['bboxes'], data['labels']):
        x1, y1, x2, y2 = bbox
        rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2,
                                 edgecolor='indigo', facecolor='none')
        ax.add_patch(rect)
        ax.text(x1, y1, label, color='white', fontsize=10,
                bbox=dict(facecolor='indigo', alpha=0.8))
    return fig_to_pil(fig)
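# plot_bbox_img expects the parsed dict that Florence-2 detection-style tasks
# return: {'bboxes': [[x1, y1, x2, y2], ...], 'labels': ['car', ...]},
# with coordinates in pixels.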
def draw_polygons(image, prediction, fill_mask=False):
    # Overlay segmentation polygons and their labels on the image
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    ax.axis('off')
    for polygons, label in zip(prediction.get('polygons', []), prediction.get('labels', [])):
        color = random.choice(colormap)
        for polygon in polygons:
            # Florence-2 returns each polygon as a flat [x1, y1, x2, y2, ...] list;
            # convert it into (x, y) pairs for Matplotlib
            if polygon and isinstance(polygon[0], (int, float)):
                polygon = [(polygon[i], polygon[i + 1]) for i in range(0, len(polygon) - 1, 2)]
            if not polygon:
                continue
            poly = patches.Polygon(polygon, edgecolor=color,
                                   facecolor=color if fill_mask else 'none',
                                   alpha=0.5 if fill_mask else 1, linewidth=2)
            ax.add_patch(poly)
            ax.text(polygon[0][0], polygon[0][1], label, color='white', fontsize=10,
                    bbox=dict(facecolor=color, alpha=0.8))
    return fig_to_pil(fig)
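# draw_polygons expects the segmentation-style output:
# {'polygons': [[[x1, y1, x2, y2, ...], ...], ...], 'labels': [...]},
# i.e. one list of flat coordinate polygons per detected instance.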
def process_image(image, task, text):
    # Map each UI task to its Florence-2 prompt and a post-processing function
    task_mapping = {
        "Caption": ("<CAPTION>", lambda result: (result['<CAPTION>'], image)),
        "Detailed Caption": ("<DETAILED_CAPTION>", lambda result: (result['<DETAILED_CAPTION>'], image)),
        "More Detailed Caption": ("<MORE_DETAILED_CAPTION>", lambda result: (result.get('<MORE_DETAILED_CAPTION>', 'Failed to generate detailed caption'), image)),
        "Caption to Phrase Grounding": ("<CAPTION_TO_PHRASE_GROUNDING>", lambda result: (str(result['<CAPTION_TO_PHRASE_GROUNDING>']), plot_bbox_img(image, result['<CAPTION_TO_PHRASE_GROUNDING>']))),
        "Object Detection": ("<OD>", lambda result: (str(result['<OD>']), plot_bbox_img(image, result['<OD>']))),
        "Referring Expression Segmentation": ("<REFERRING_EXPRESSION_SEGMENTATION>", lambda result: (str(result['<REFERRING_EXPRESSION_SEGMENTATION>']), draw_polygons(image, result['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True))),
        "Region to Segmentation": ("<REGION_TO_SEGMENTATION>", lambda result: (str(result['<REGION_TO_SEGMENTATION>']), draw_polygons(image, result['<REGION_TO_SEGMENTATION>'], fill_mask=True))),
        "OCR": ("<OCR>", lambda result: (result['<OCR>'], image)),
    }
    if task not in task_mapping:
        return "", image
    prompt, process_func = task_mapping[task]
    print(f"Task: {task}, Prompt: {prompt}")  # Debugging statement
    result = run_example(prompt, image, text)
    print(f"Result: {result}")  # Debugging statement
    return process_func(result)
image_path_1 = "Fiat-500-9-scaled.jpg"
image_path_2 = "OCR_2.png"

with gr.Blocks() as demo:
    gr.HTML("<h1><center>Florence-2 Vision</center></h1>")
    with gr.Tab(label="Image"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Picture", type="pil")
                task_dropdown = gr.Dropdown(
                    choices=[
                        "Caption", "Detailed Caption", "More Detailed Caption",
                        "Object Detection", "Caption to Phrase Grounding",
                        "Referring Expression Segmentation", "Region to Segmentation", "OCR",
                    ],
                    label="Task", value="Caption"
                )
                text_input = gr.Textbox(label="Text Input (Optional)", visible=False)
                gr.Examples(
                    examples=[
                        [image_path_1, "Detailed Caption", ""],
                        [image_path_1, "Object Detection", ""],
                        [image_path_1, "More Detailed Caption", ""],
                        [image_path_1, "Caption to Phrase Grounding", "A white car parked on the street."],
                        [image_path_1, "Region to Segmentation", ""],
                        [image_path_2, "OCR", ""]
                    ],
                    inputs=[input_img, task_dropdown, text_input],
                    cache_examples=False  # no caching needed for these lightweight examples
                )
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                output_text = gr.Textbox(label="Results")
                output_image = gr.Image(label="Image", type="pil")

    # Show the optional text box only for tasks that accept a text input
    def update_text_input(task):
        return gr.update(visible=task in [
            "Caption to Phrase Grounding",
            "Referring Expression Segmentation",
            "Region to Segmentation",
        ])

    task_dropdown.change(fn=update_text_input, inputs=task_dropdown, outputs=text_input)
    submit_btn.click(fn=process_image, inputs=[input_img, task_dropdown, text_input], outputs=[output_text, output_image])

demo.launch()