Spaces:
Sleeping
Sleeping
File size: 1,999 Bytes
d0f5a61 4317393 0b822c2 4317393 f0585ee 0b822c2 1a933f0 0b822c2 1a933f0 0b822c2 1a933f0 0b822c2 4317393 d0f5a61 f0585ee 4317393 0b822c2 1a933f0 19a9827 1a933f0 0b822c2 d0f5a61 f0585ee d0f5a61 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import gradio as gr
import spaces
import torch
from PIL import Image
import requests
from transformers import DetrImageProcessor
from transformers import DetrForObjectDetection
from random import choice
import matplotlib.pyplot as plt
import io
# Checkpoint shared by the image processor and the detection model, so the
# two are always loaded from the same pretrained weights/config.
_CHECKPOINT = "facebook/detr-resnet-50"

processor = DetrImageProcessor.from_pretrained(_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(_CHECKPOINT)

# Palette of colour triples handed to matplotlib (``color=``) when drawing
# detection boxes; entries are reused once every colour has been assigned.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]
def get_output_figure(pil_img, scores, labels, boxes):
    """Draw detection boxes and class captions on top of an image.

    Args:
        pil_img: Image shown as the figure background (passed to ``imshow``).
        scores: Per-detection confidence scores; must support ``.tolist()``.
        labels: Per-detection class ids, resolved to names through
            ``model.config.id2label``; must support ``.tolist()``.
        boxes: Per-detection ``(xmin, ymin, xmax, ymax)`` boxes expressed in
            the coordinate system of ``pil_img``; must support ``.tolist()``.

    Returns:
        The matplotlib ``Figure`` containing the annotated image. The caller
        owns the figure and should close it (``plt.close(fig)``) after use.
    """
    # Work on an explicit figure/axes pair rather than pyplot's implicit
    # "current figure": the implicit state is global, so overlapping requests
    # could otherwise draw into (or return) each other's figure.
    fig, ax = plt.subplots(figsize=(16, 10))
    ax.imshow(pil_img)

    detections = zip(scores.tolist(), labels.tolist(), boxes.tolist())
    for index, (score, label, (xmin, ymin, xmax, ymax)) in enumerate(detections):
        # Wrap around the palette so every detection gets a colour, however
        # many there are (no silent truncation of the detection list).
        color = COLORS[index % len(COLORS)]
        ax.add_patch(
            plt.Rectangle(
                (xmin, ymin),
                xmax - xmin,
                ymax - ymin,
                fill=False,
                color=color,
                linewidth=3,
            )
        )
        text = f'{model.config.id2label[label]}: {score:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))

    ax.axis('off')
    return fig
@spaces.GPU
def detect(image):
    """Run object detection on an image and return an annotated rendering.

    Args:
        image: Input ``PIL.Image`` from the Gradio image component, or
            ``None`` when the user submitted without providing an image.

    Returns:
        A ``PIL.Image`` of the input with a box and ``"<label>: <score>"``
        caption for every detection kept at the 0.9 score threshold.

    Raises:
        gr.Error: If no input image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please provide an input image.")

    # Keep the inputs on whatever device the model lives on (no-op on CPU).
    encoding = processor(image, return_tensors='pt').to(model.device)
    print(encoding.keys())

    with torch.no_grad():
        outputs = model(**encoding)

    # PIL reports (width, height); post-processing expects (height, width).
    width, height = image.size
    postprocessed_outputs = processor.post_process_object_detection(
        outputs, target_sizes=[(height, width)], threshold=0.9
    )
    results = postprocessed_outputs[0]
    print(results)

    output_figure = get_output_figure(
        image, results['scores'], results['labels'], results['boxes']
    )
    buf = io.BytesIO()
    try:
        output_figure.savefig(buf, bbox_inches='tight')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this each request would leak one figure.
        plt.close(output_figure)

    buf.seek(0)
    output_pil_img = Image.open(buf)
    print(output_pil_img)
    return output_pil_img
# Gradio UI: a single PIL image in, the annotated PIL image out.
demo = gr.Interface(
    fn=detect,
    inputs=gr.Image(label="Input image", type="pil"),
    outputs=gr.Image(label="Output image", type="pil"),
)

# Start the server only when executed as a script, so importing this module
# (e.g. from a test or another app) does not launch the UI as a side effect.
if __name__ == "__main__":
    demo.launch()