import gradio as gr
import supervision as sv
import torch
import spaces
from utils.annotate import annotate_with_boxes
from utils.models import load_models, run_inference, CHECKPOINTS
from utils.tasks import TASK_NAMES, TASKS, OBJECT_DETECTION_TASK_NAME, \
    CAPTION_TASK_NAMES, CAPTION_TASK_NAME, DETAILED_CAPTION_TASK_NAME, \
    MORE_DETAILED_CAPTION_TASK_NAME, OCR_WITH_REGION_TASK_NAME, OCR_TASK_NAME

# The local `utils` helpers are assumed to cover: loading a model/processor per
# checkpoint (`load_models`), wrapping generation and post-processing
# (`run_inference`), drawing boxes (`annotate_with_boxes`), and mapping task
# display names to Florence-2 task prompts (`TASKS`).

MARKDOWN = """
# Better Florence-2 Playground 🔥
<div>
<a href="https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-finetune-florence-2-on-detection-dataset.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Colab" style="display:inline-block;">
</a>
<a href="https://blog.roboflow.com/florence-2/">
<img src="https://raw.githubusercontent.com/roboflow-ai/notebooks/main/assets/badges/roboflow-blogpost.svg" alt="Roboflow" style="display:inline-block;">
</a>
<a href="https://arxiv.org/abs/2311.06242">
<img src="https://img.shields.io/badge/arXiv-2311.06242-b31b1b.svg" alt="arXiv" style="display:inline-block;">
</a>
<a href="https://www.youtube.com/watch?v=i3KjYgxNH6w">
<img src="https://badges.aleen42.com/src/youtube.svg" alt="YouTube" style="display:inline-block;">
</a>
</div>
Florence-2 is a lightweight vision-language model open-sourced by Microsoft under the
MIT license. The model demonstrates strong zero-shot and fine-tuning capabilities
across tasks such as captioning, object detection, grounding, and segmentation.
The model takes images and task prompts as input, generating the desired results in
text format. It uses a DaViT vision encoder to convert images into visual token
embeddings. These are then concatenated with BERT-generated text embeddings and
processed by a transformer-based multi-modal encoder-decoder to generate the response.
"""
OBJECT_DETECTION_EXAMPLES = [
    ["microsoft/Florence-2-large-ft", OBJECT_DETECTION_TASK_NAME, "https://media.roboflow.com/notebooks/examples/dog-2.jpeg"]
]
CAPTION_EXAMPLES = [
    ["microsoft/Florence-2-large-ft", CAPTION_TASK_NAME, "https://media.roboflow.com/notebooks/examples/dog-2.jpeg"],
    ["microsoft/Florence-2-large-ft", DETAILED_CAPTION_TASK_NAME, "https://media.roboflow.com/notebooks/examples/dog-2.jpeg"],
    ["microsoft/Florence-2-large-ft", MORE_DETAILED_CAPTION_TASK_NAME, "https://media.roboflow.com/notebooks/examples/dog-2.jpeg"]
]
OCR_EXAMPLES = [
    ["microsoft/Florence-2-large-ft", OCR_TASK_NAME, "https://media.roboflow.com/notebooks/examples/handwritten-text.jpg"],
]
OCR_WITH_REGION_EXAMPLES = [
    ["microsoft/Florence-2-large-ft", OCR_WITH_REGION_TASK_NAME, "https://media.roboflow.com/notebooks/examples/handwritten-text.jpg"],
    ["microsoft/Florence-2-large-ft", OCR_WITH_REGION_TASK_NAME, "https://media.roboflow.com/inference/license_plate_1.jpg"]
]

# The Space is meant to run on GPU hardware (see the @spaces.GPU decorator
# below), so the device is pinned to CUDA rather than auto-detected.
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = "cuda"

MODELS, PROCESSORS = load_models(DEVICE)


@spaces.GPU
def process(checkpoint_dropdown, task_dropdown, image_input):
    model = MODELS[checkpoint_dropdown]
    processor = PROCESSORS[checkpoint_dropdown]
    task = TASKS[task_dropdown]
    if task_dropdown in [OBJECT_DETECTION_TASK_NAME, OCR_WITH_REGION_TASK_NAME]:
        # Region-level tasks: parse the response into detections and draw boxes.
        _, response = run_inference(
            model, processor, DEVICE, image_input, task)
        detections = sv.Detections.from_lmm(
            lmm=sv.LMM.FLORENCE_2, result=response, resolution_wh=image_input.size)
        return annotate_with_boxes(image_input, detections)
    elif task_dropdown in CAPTION_TASK_NAMES or task_dropdown == OCR_TASK_NAME:
        # Text-level tasks: return the generated string for the selected task.
        _, response = run_inference(
            model, processor, DEVICE, image_input, task)
        return response[task]


# Output components are created dynamically inside the gr.render callbacks
# below, so module-level placeholders are kept and rebound via `global`.
image_output_component = None
text_output_component = None

with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        checkpoint_dropdown_component = gr.Dropdown(
            choices=CHECKPOINTS,
            value=CHECKPOINTS[0],
            label="Model", info="Select a Florence-2 model to use.")
        task_dropdown_component = gr.Dropdown(
            choices=TASK_NAMES,
            value=TASK_NAMES[0],
            label="Task", info="Select a task to perform with the model.")
    with gr.Row():
        with gr.Column():
            image_input_component = gr.Image(type='pil', label='Image Input')
            submit_button_component = gr.Button(value='Submit', variant='primary')
        with gr.Column():
            # The output component depends on the selected task: region-level
            # tasks render an annotated image, caption/OCR tasks render text.
            @gr.render(inputs=task_dropdown_component)
            def show_output(text):
                global image_output_component
                global text_output_component
                if text in [OBJECT_DETECTION_TASK_NAME, OCR_WITH_REGION_TASK_NAME]:
                    image_output_component = gr.Image(type='pil', label='Image Output')
                    submit_button_component.click(
                        fn=process,
                        inputs=[
                            checkpoint_dropdown_component,
                            task_dropdown_component,
                            image_input_component
                        ],
                        outputs=image_output_component
                    )
                elif text in CAPTION_TASK_NAMES or text == OCR_TASK_NAME:
                    text_output_component = gr.Textbox(label='Caption Output')
                    submit_button_component.click(
                        fn=process,
                        inputs=[
                            checkpoint_dropdown_component,
                            task_dropdown_component,
                            image_input_component
                        ],
                        outputs=text_output_component
                    )

    # Task-specific example galleries, re-rendered whenever the task changes.
    @gr.render(inputs=task_dropdown_component)
    def show_examples(text):
        global image_output_component
        global text_output_component
        if text == OBJECT_DETECTION_TASK_NAME:
            gr.Examples(
                fn=process,
                examples=OBJECT_DETECTION_EXAMPLES,
                inputs=[
                    checkpoint_dropdown_component,
                    task_dropdown_component,
                    image_input_component
                ],
                outputs=image_output_component
            )
        elif text in CAPTION_TASK_NAMES:
            gr.Examples(
                fn=process,
                examples=CAPTION_EXAMPLES,
                inputs=[
                    checkpoint_dropdown_component,
                    task_dropdown_component,
                    image_input_component
                ],
                outputs=text_output_component
            )
        elif text == OCR_TASK_NAME:
            gr.Examples(
                fn=process,
                examples=OCR_EXAMPLES,
                inputs=[
                    checkpoint_dropdown_component,
                    task_dropdown_component,
                    image_input_component
                ],
                outputs=text_output_component
            )
        elif text == OCR_WITH_REGION_TASK_NAME:
            gr.Examples(
                fn=process,
                examples=OCR_WITH_REGION_EXAMPLES,
                inputs=[
                    checkpoint_dropdown_component,
                    task_dropdown_component,
                    image_input_component
                ],
                outputs=image_output_component
            )

# max_threads=1 presumably keeps event handlers sequential so that only a
# single GPU inference runs at a time.
demo.launch(debug=False, show_error=True, max_threads=1)