from unittest.mock import patch
import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.dynamic_module_utils import get_imports
import torch
from PIL import Image
import random
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import io

# Define colormap
colormap = ['red', 'green', 'blue', 'yellow', 'orange', 'purple', 'cyan']

# Workaround to fix import issues for Florence-2 model
def workaround_fixed_get_imports(filename):
    if not str(filename).endswith("/modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    if "flash_attn" in imports:
        imports.remove("flash_attn")  # Remove 'flash_attn' if it's causing issues
    return imports
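
# Note (assumption, based on the widely used community workaround): Florence-2's
# remote code imports flash_attn only optionally, so stripping it from the import
# list here should make the model fall back to standard attention instead of
# failing at import time on machines without flash_attn installed.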

def initialize_model():
    # Check if CUDA (GPU) is available and set the device accordingly
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Patch the get_imports function and load the model and processor
    with patch("transformers.dynamic_module_utils.get_imports", workaround_fixed_get_imports):
        try:
            model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True).to(device).eval()
            processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True)
            print("Model and processor loaded successfully.")
            return model, processor, device
        except Exception as e:
            print(f"An error occurred while loading the model or processor: {e}")
            return None, None, device

# Initialize the model and processor
model, processor, device = initialize_model()

def run_example(task_prompt, image, text_input=None):
    if text_input is None:
        prompt = task_prompt
    else:
        prompt = task_prompt + text_input
    # Move inputs to the same device as the model (CPU or CUDA)
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    with torch.inference_mode():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height)
    )

    return parsed_answer
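
# For reference (illustrative shape, not verified for every task): region tasks
# return a dict keyed by the task token, e.g.
#   run_example("<OD>", image) -> {'<OD>': {'bboxes': [[x1, y1, x2, y2], ...], 'labels': ['car', ...]}}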

def fig_to_pil(fig):
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)  # Free the figure to avoid accumulating open figures across requests
    return img

def plot_bbox_img(image, data):
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    ax.axis('off')

    if 'bboxes' not in data or 'labels' not in data:
        return fig_to_pil(fig)

    for bbox, label in zip(data['bboxes'], data['labels']):
        x1, y1, x2, y2 = bbox
        rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='indigo', facecolor='none')
        ax.add_patch(rect)
        ax.text(x1, y1, label, color='white', fontsize=10, bbox=dict(facecolor='indigo', alpha=0.8))

    return fig_to_pil(fig)
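
# Illustrative usage (hypothetical values):
#   plot_bbox_img(img, {'bboxes': [[34, 160, 597, 420]], 'labels': ['car']})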

def draw_polygons(image, prediction, fill_mask=False):
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    for polygons, label in zip(prediction.get('polygons', []), prediction.get('labels', [])):
        color = random.choice(colormap)
        first_point = None
        for polygon in polygons:
            # Flatten coordinate lists like [x1, y1, x2, y2, ...] into (x, y) pairs
            if isinstance(polygon[0], (int, float)):
                polygon = [(polygon[i], polygon[i + 1]) for i in range(0, len(polygon), 2)]
            poly = patches.Polygon(polygon, edgecolor=color, facecolor=color if fill_mask else 'none', alpha=0.5 if fill_mask else 1, linewidth=2)
            ax.add_patch(poly)
            if first_point is None and polygon:
                first_point = polygon[0]
        # Place the label at the first vertex of the first polygon, if any were drawn
        # (guards against a NameError when a label has no polygons)
        if first_point is not None:
            ax.text(first_point[0], first_point[1], label, color='white', fontsize=10, bbox=dict(facecolor=color, alpha=0.8))
    ax.axis('off')
    return fig_to_pil(fig)
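
# Segmentation predictions are expected in roughly this shape (flat coordinate
# lists per polygon, as produced by the processor's post-processing):
#   {'polygons': [[[x1, y1, x2, y2, ...]]], 'labels': ['...']}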

def process_image(image, task, text):
    task_mapping = {
        "Caption": ("<CAPTION>", lambda result: (result['<CAPTION>'], image)),
        "Detailed Caption": ("<DETAILED_CAPTION>", lambda result: (result['<DETAILED_CAPTION>'], image)),
        "More Detailed Caption": ("<MORE_DETAILED_CAPTION>", lambda result: (result.get('<MORE_DETAILED_CAPTION>', 'Failed to generate detailed caption'), image)),
        "Caption to Phrase Grounding": ("<CAPTION_TO_PHRASE_GROUNDING>", lambda result: (str(result['<CAPTION_TO_PHRASE_GROUNDING>']), plot_bbox_img(image, result['<CAPTION_TO_PHRASE_GROUNDING>']))),
        "Object Detection": ("<OD>", lambda result: (str(result['<OD>']), plot_bbox_img(image, result['<OD>']))),
        "Referring Expression Segmentation": ("<REFERRING_EXPRESSION_SEGMENTATION>", lambda result: (str(result['<REFERRING_EXPRESSION_SEGMENTATION>']), draw_polygons(image, result['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True))),
        "Region to Segmentation": ("<REGION_TO_SEGMENTATION>", lambda result: (str(result['<REGION_TO_SEGMENTATION>']), draw_polygons(image, result['<REGION_TO_SEGMENTATION>'], fill_mask=True))),
        "OCR": ("<OCR>", lambda result: (result['<OCR>'], image)),
    }

    if model is None:
        return "Model failed to load; check the logs above.", image

    if task in task_mapping:
        prompt, process_func = task_mapping[task]
        print(f"Task: {task}, Prompt: {prompt}")  # Debugging statement
        result = run_example(prompt, image, text)
        print(f"Result: {result}")  # Debugging statement
        return process_func(result)
    else:
        return "", image

        
image_path_1 = "Fiat-500-9-scaled.jpg"
image_path_2 = "OCR_2.png"

with gr.Blocks() as demo:
    gr.HTML("<h1><center>Florence-2 Vision</center></h1>")
    
    with gr.Tab(label="Image"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Picture", type="pil")
                task_dropdown = gr.Dropdown(
                    choices=["Caption", "Detailed Caption", "More Detailed Caption", "Object Detection", "Caption to Phrase Grounding", "Referring Expression Segmentation", "Region to Segmentation", "OCR"],
                    label="Task", value="Caption"
                )
                text_input = gr.Textbox(label="Text Input (Optional)", visible=False)
                gr.Examples(
                    examples=[
                        [image_path_1, "Detailed Caption", ""],
                        [image_path_1, "Object Detection", ""],
                        [image_path_1, "More Detailed Caption", ""],
                        [image_path_1, "Caption to Phrase Grounding", "A white car parked on the street."],
                        [image_path_1, "Region to Segmentation", ""],
                        [image_path_2, "OCR", ""]
                    ],
                    inputs=[input_img, task_dropdown, text_input],
                    cache_examples=False  # Example outputs are not pre-computed/cached
                )
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                output_text = gr.Textbox(label="Results")
                output_image = gr.Image(label="Image", type="pil")

    def update_text_input(task):
        # Show the text box only for tasks that consume the optional text input
        return gr.update(visible=task in ["Caption to Phrase Grounding", "Referring Expression Segmentation", "Region to Segmentation"])

    task_dropdown.change(fn=update_text_input, inputs=task_dropdown, outputs=text_input)

    submit_btn.click(fn=process_image, inputs=[input_img, task_dropdown, text_input], outputs=[output_text, output_image])

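# demo.launch() starts a local server; Gradio's share=True option can create a temporary public link if needed.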
demo.launch()