sergiopaniego's picture
Updated app
36b4542
raw
history blame
2.01 kB
import gradio as gr
import spaces
import torch
from PIL import Image
import requests
from transformers import DetrImageProcessor
from transformers import DetrForObjectDetection
from random import choice
import matplotlib.pyplot as plt
import io
# The image pre-processor and the detection model are loaded once, at import
# time, from the same checkpoint id; keeping that id in a single constant
# stops the two from accidentally drifting apart.
_CHECKPOINT = "facebook/detr-resnet-50"
processor = DetrImageProcessor.from_pretrained(_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(_CHECKPOINT)
# Palette repeated across detections for bounding-box outlines: six RGB
# triples with components in [0, 1], the form matplotlib accepts for `color=`.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]
def get_output_figure(pil_img, scores, labels, boxes):
    """Draw detections on top of an image and return the matplotlib figure.

    Each detection becomes an unfilled rectangle plus a text caption of the
    form ``"<class name>: <score>"``. Outline colours cycle through
    ``COLORS``; class names are looked up in ``model.config.id2label``.

    Args:
        pil_img: Image to draw on (passed to ``Axes.imshow``).
        scores: Per-detection confidence scores; any object with
            ``.tolist()`` (e.g. a 1-D tensor).
        labels: Per-detection integer class ids, same length as ``scores``.
        boxes: Per-detection ``(xmin, ymin, xmax, ymax)`` boxes, same length
            as ``scores``. They must be in ``pil_img``'s pixel coordinates,
            because those are the data coordinates ``imshow`` sets up.

    Returns:
        A ``matplotlib.figure.Figure``. It is created without pyplot, so it
        is not kept alive by pyplot's global figure registry and needs no
        ``plt.close()``.
    """
    # Local imports keep this block independent of the file's import header.
    from itertools import cycle
    from matplotlib.figure import Figure

    # Build the figure directly rather than through pyplot. pyplot holds
    # every figure it creates until plt.close() is called, which leaks one
    # figure per request in a long-running server. It also routes drawing
    # through process-global "current figure/axes" state that concurrent
    # requests could end up sharing.
    fig = Figure(figsize=(16, 10))
    ax = fig.subplots()
    ax.imshow(pil_img)

    # cycle() repeats the palette for any number of detections; the old
    # `COLORS * 100` silently dropped boxes beyond 600.
    detections = zip(scores.tolist(), labels.tolist(), boxes.tolist(), cycle(COLORS))
    for score, label, (xmin, ymin, xmax, ymax), color in detections:
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=color, linewidth=3))
        caption = f'{model.config.id2label[label]}: {score:0.2f}'
        ax.text(xmin, ymin, caption, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))
    ax.axis('off')
    return fig
@spaces.GPU
def detect(image):
    """Run DETR object detection on *image* and return an annotated image.

    Detections scoring below 0.9 are discarded. The remaining boxes are drawn
    by ``get_output_figure``, and the figure is rasterized with ``savefig``'s
    default format (PNG unless matplotlib rcParams override it).

    Args:
        image: Input ``PIL.Image`` (the Gradio input is declared with
            ``type="pil"``), or ``None`` when the user submits without an
            image.

    Returns:
        A ``PIL.Image`` opened from the rendered figure.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Gradio passes None for an empty input; report it to the user
        # instead of failing deep inside the image processor.
        raise gr.Error("Please upload an image.")

    # Put the inputs on the same device as the model.
    # NOTE(review): nothing in this file moves `model` to CUDA, so despite
    # @spaces.GPU the forward pass runs on whatever device from_pretrained
    # chose (CPU by default). Confirm whether that is intended.
    encoding = processor(image, return_tensors='pt').to(model.device)
    with torch.no_grad():
        outputs = model(**encoding)

    # target_sizes wants (height, width) per image; PIL's .size is (w, h).
    width, height = image.size
    results = processor.post_process_object_detection(
        outputs, target_sizes=[(height, width)], threshold=0.9
    )[0]

    output_figure = get_output_figure(image, results['scores'], results['labels'], results['boxes'])
    buf = io.BytesIO()
    try:
        output_figure.savefig(buf, bbox_inches='tight')
    finally:
        # Drop the figure from pyplot's registry if it was created through
        # pyplot; otherwise every request leaks a figure. This is a no-op for
        # figures pyplot does not manage.
        plt.close(output_figure)
    buf.seek(0)
    return Image.open(buf)
# Gradio UI: one PIL image in, one annotated PIL image out.
demo = gr.Interface(
    fn=detect,
    inputs=gr.Image(label="Input image", type="pil"),
    outputs=gr.Image(label="Output image predictions", type="pil"),
)

# Start the server only when this file is executed as a script
# (e.g. `python app.py`). Importing the module, for instance from tests, then
# does not block in a web server. `demo` stays at module level so tools that
# look up the interface object can still find it.
if __name__ == "__main__":
    demo.launch()