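"""Gradio demo for Microsoft's Florence-2-large-ft vision-language model.

Supported tasks: captioning (three detail levels), object detection,
caption-to-phrase grounding, referring-expression and region segmentation,
and OCR. Detection and segmentation results are rendered onto the input
image with Matplotlib.
"""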
import os
from unittest.mock import patch
import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.dynamic_module_utils import get_imports
import torch
from PIL import Image, ImageDraw
import random
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import io
# Define colormap
colormap = ['red', 'green', 'blue', 'yellow', 'orange', 'purple', 'cyan']
# Workaround for Florence-2's remote code: it declares flash_attn as a required
# import, which fails on machines without a compatible flash-attention build.
def workaround_fixed_get_imports(filename):
    if not str(filename).endswith("/modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    if "flash_attn" in imports:
        imports.remove("flash_attn")  # flash_attn is optional at runtime; skip the import check
    return imports
def initialize_model():
# Check if CUDA (GPU) is available and set the device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Patch the get_imports function and load the model and processor
with patch("transformers.dynamic_module_utils.get_imports", workaround_fixed_get_imports):
try:
model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True).to(device).eval()
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True)
print("Model and processor loaded successfully.")
return model, processor, device
except Exception as e:
print(f"An error occurred while loading the model or processor: {e}")
return None, None, device
# Initialize the model and processor
model, processor, device = initialize_model()
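# NOTE: `model` and `processor` are module-level globals used by run_example
# below; if loading failed they are None and inference calls will raise.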
def run_example(task_prompt, image, text_input=None):
    if text_input is None:
        prompt = task_prompt
    else:
        prompt = task_prompt + text_input
    # Move inputs to the same device as the model; without this, GPU runs fail.
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    with torch.inference_mode():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height)
    )
    return parsed_answer
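# Standalone usage sketch (assumes the model and processor loaded successfully;
# the image file is one of the bundled examples referenced further below):
#   result = run_example("<CAPTION>", Image.open("Fiat-500-9-scaled.jpg"))
#   print(result["<CAPTION>"])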
def fig_to_pil(fig):
    # Render the Matplotlib figure to an in-memory PNG and return it as a PIL image.
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
    buf.seek(0)
    plt.close(fig)  # free the figure so repeated requests don't leak memory
    return Image.open(buf)
def plot_bbox_img(image, data):
    # Draw labelled bounding boxes ([x1, y1, x2, y2]) on top of the image.
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    if 'bboxes' not in data or 'labels' not in data:
        return fig_to_pil(fig)  # nothing to draw; return the bare image
    for bbox, label in zip(data['bboxes'], data['labels']):
        x1, y1, x2, y2 = bbox
        rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='indigo', facecolor='none')
        ax.add_patch(rect)
        ax.text(x1, y1, label, color='white', fontsize=10, bbox=dict(facecolor='indigo', alpha=0.8))
    ax.axis('off')
    return fig_to_pil(fig)
def draw_polygons(image, prediction, fill_mask=False):
    # Draw labelled segmentation polygons; optionally fill them semi-transparently.
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    for polygons, label in zip(prediction.get('polygons', []), prediction.get('labels', [])):
        color = random.choice(colormap)
        for polygon in polygons:
            # Florence-2 returns each polygon as a flat [x1, y1, x2, y2, ...]
            # list; reshape it into (x, y) pairs before plotting.
            if isinstance(polygon[0], (int, float)):
                polygon = [(polygon[i], polygon[i + 1]) for i in range(0, len(polygon), 2)]
            poly = patches.Polygon(polygon, edgecolor=color, facecolor=color if fill_mask else 'none', alpha=0.5 if fill_mask else 1, linewidth=2)
            ax.add_patch(poly)
            if polygon:
                ax.text(polygon[0][0], polygon[0][1], label, color='white', fontsize=10, bbox=dict(facecolor=color, alpha=0.8))
    ax.axis('off')
    return fig_to_pil(fig)
def process_image(image, task, text):
task_mapping = {
"Caption": ("<CAPTION>", lambda result: (result['<CAPTION>'], image)),
"Detailed Caption": ("<DETAILED_CAPTION>", lambda result: (result['<DETAILED_CAPTION>'], image)),
"More Detailed Caption": ("<MORE_DETAILED_CAPTION>", lambda result: (result.get('<MORE_DETAILED_CAPTION>', 'Failed to generate detailed caption'), image)),
"Caption to Phrase Grounding": ("<CAPTION_TO_PHRASE_GROUNDING>", lambda result: (str(result['<CAPTION_TO_PHRASE_GROUNDING>']), plot_bbox_img(image, result['<CAPTION_TO_PHRASE_GROUNDING>']))),
"Object Detection": ("<OD>", lambda result: (str(result['<OD>']), plot_bbox_img(image, result['<OD>']))),
"Referring Expression Segmentation": ("<REFERRING_EXPRESSION_SEGMENTATION>", lambda result: (str(result['<REFERRING_EXPRESSION_SEGMENTATION>']), draw_polygons(image, result['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True))),
"Region to Segmentation": ("<REGION_TO_SEGMENTATION>", lambda result: (str(result['<REGION_TO_SEGMENTATION>']), draw_polygons(image, result['<REGION_TO_SEGMENTATION>'], fill_mask=True))),
"OCR": ("<OCR>", lambda result: (result['<OCR>'], image)),
}
if task in task_mapping:
prompt, process_func = task_mapping[task]
print(f"Task: {task}, Prompt: {prompt}") # Debugging statement
result = run_example(prompt, image, text)
print(f"Result: {result}") # Debugging statement
return process_func(result)
else:
return "", image
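# process_image returns (text, image): detection and segmentation tasks yield
# the raw result dict as a string plus the annotated image, while caption and
# OCR tasks return the generated text and the unmodified input image.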
image_path_1 = "Fiat-500-9-scaled.jpg"
image_path_2 = "OCR_2.png"
with gr.Blocks() as demo:
gr.HTML("<h1><center>Florence-2 Vision</center></h1>")
with gr.Tab(label="Image"):
with gr.Row():
with gr.Column():
input_img = gr.Image(label="Input Picture", type="pil")
task_dropdown = gr.Dropdown(
choices=["Caption", "Detailed Caption", "More Detailed Caption", "Object Detection", "Caption to Phrase Grounding", "Referring Expression Segmentation", "Region to Segmentation", "OCR"],
label="Task", value="Caption"
)
text_input = gr.Textbox(label="Text Input (Optional)", visible=False)
gr.Examples(
examples=[
[image_path_1, "Detailed Caption", ""],
[image_path_1, "Object Detection", ""],
[image_path_1, "More Detailed Caption", ""],
[image_path_1, "Caption to Phrase Grounding", "A white car parked on the street."],
[image_path_1, "Region to Segmentation", ""],
[image_path_2, "OCR", ""]
],
inputs=[input_img, task_dropdown, text_input],
                    cache_examples=False  # run examples on demand rather than pre-computing outputs
)
submit_btn = gr.Button(value="Submit")
with gr.Column():
output_text = gr.Textbox(label="Results")
output_image = gr.Image(label="Image", type="pil")
    def update_text_input(task):
        # Show the optional text box only for tasks that accept extra text input.
        return gr.update(visible=task in ["Caption to Phrase Grounding", "Referring Expression Segmentation", "Region to Segmentation"])
task_dropdown.change(fn=update_text_input, inputs=task_dropdown, outputs=text_input)
submit_btn.click(fn=process_image, inputs=[input_img, task_dropdown, text_input], outputs=[output_text, output_image])
demo.launch()