Spaces:
Sleeping
Sleeping
import gradio as gr | |
import spaces | |
import torch | |
from PIL import Image | |
import requests | |
from transformers import DetrImageProcessor | |
from transformers import DetrForObjectDetection | |
from random import choice | |
import matplotlib.pyplot as plt | |
import io | |
# Hugging Face checkpoint shared by the preprocessor and the detection model.
_CHECKPOINT = "facebook/detr-resnet-50"

# Loaded once at import time and reused by every detection request.
processor = DetrImageProcessor.from_pretrained(_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(_CHECKPOINT)

# Box colours as RGB triples in the 0-1 range (matplotlib colour format).
# Detections are assigned colours from this palette in order.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]
def get_output_figure(pil_img, scores, labels, boxes):
    """Draw detection boxes and class labels on top of an image.

    Args:
        pil_img: Image to draw on; passed straight to ``Axes.imshow``.
        scores: Per-detection confidence scores; anything with ``.tolist()``
            (presumably a tensor from the DETR post-processor).
        labels: Per-detection integer class ids, looked up in
            ``model.config.id2label``.
        boxes: Per-detection boxes as ``(xmin, ymin, xmax, ymax)``;
            presumably in the image's pixel coordinates — confirm with caller.

    Returns:
        A ``matplotlib.figure.Figure`` holding the annotated image. The figure
        is not registered with pyplot, so it is freed once the caller drops
        it and does not need ``plt.close``.
    """
    # Use the object-oriented API instead of pyplot's global state: pyplot
    # keeps every figure alive until plt.close() (a leak in a long-running
    # server), and plt.gca() is shared state that concurrent requests could
    # draw into simultaneously.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(16, 10))
    ax = fig.subplots()
    ax.imshow(pil_img)

    detections = zip(scores.tolist(), labels.tolist(), boxes.tolist())
    for i, (score, label, (xmin, ymin, xmax, ymax)) in enumerate(detections):
        # Wrap around the palette so any number of detections gets a colour.
        color = COLORS[i % len(COLORS)]
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=color, linewidth=3))
        text = f'{model.config.id2label[label]}: {score:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))
    ax.axis('off')
    return fig
def detect(image, threshold=0.9):
    """Run DETR object detection on an image and return an annotated copy.

    Args:
        image: Input ``PIL.Image`` supplied by the Gradio ``Image`` component
            (``type="pil"``), or ``None`` when no image was submitted.
        threshold: Minimum confidence score a detection must reach to be
            kept and drawn. Defaults to 0.9.

    Returns:
        A ``PIL.Image`` of the rendered figure with boxes and labels drawn,
        or ``None`` when ``image`` is ``None``.
    """
    # Gradio passes None when the user submits without an image; return an
    # empty output instead of failing inside the processor.
    if image is None:
        return None

    encoding = processor(image, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**encoding)

    # PIL reports size as (width, height); the post-processor is given
    # (height, width) per image.
    width, height = image.size
    postprocessed_outputs = processor.post_process_object_detection(
        outputs, target_sizes=[(height, width)], threshold=threshold)
    results = postprocessed_outputs[0]

    output_figure = get_output_figure(
        image, results['scores'], results['labels'], results['boxes'])
    buf = io.BytesIO()
    try:
        output_figure.savefig(buf, bbox_inches='tight')
    finally:
        # Release the figure so a pyplot-managed one is not kept alive for
        # the lifetime of the process; harmless for untracked figures.
        plt.close(output_figure)
    buf.seek(0)
    return Image.open(buf)
# Single-image UI: upload a picture, receive the annotated detection figure.
demo = gr.Interface(
    fn=detect,
    inputs=gr.Image(label="Input image", type="pil"),
    outputs=gr.Image(label="Output image predictions", type="pil"),
)

# Start the (blocking) server only when run as a script, so importing this
# module does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch()