# Spaces: Sleeping
import gradio as gr | |
import spaces | |
import torch | |
from PIL import Image | |
import requests | |
from transformers import DetrImageProcessor | |
from transformers import DetrForObjectDetection | |
from random import choice | |
import matplotlib.pyplot as plt | |
import io | |
# Checkpoint name shared by the image processor and the detection model, so
# preprocessing and inference always come from the same pretrained weights.
_CHECKPOINT = "facebook/detr-resnet-50"
processor = DetrImageProcessor.from_pretrained(_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(_CHECKPOINT)

# Palette of RGB triples (floats in [0, 1], as matplotlib expects) used to
# colour the detection boxes.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]
def generate_output_figure(in_pil_img, in_results):
    """Draw detection boxes and labels on top of an image.

    Args:
        in_pil_img: The image to annotate (anything ``Axes.imshow`` accepts,
            e.g. a PIL image).
        in_results: Iterable of detections. Each is a dict with a ``"label"``
            string, a ``"score"`` fraction (rendered as a percentage) and a
            ``"box"`` dict holding ``"xmin"``, ``"ymin"``, ``"xmax"`` and
            ``"ymax"`` pixel coordinates.

    Returns:
        The matplotlib ``Figure``. It is registered with pyplot, so the caller
        should ``plt.close()`` it when finished to avoid accumulating figures.
    """
    fig, ax = plt.subplots(figsize=(16, 10))
    # Draw the image first: this also sets image-pixel axis limits with the
    # y-axis pointing down, matching the box coordinates.
    ax.imshow(in_pil_img)
    for index, prediction in enumerate(in_results):
        # Cycle through the palette so consecutive boxes get distinct colours
        # and the rendering is reproducible across runs.
        color = COLORS[index % len(COLORS)]
        box = prediction["box"]
        x, y = box["xmin"], box["ymin"]
        w, h = box["xmax"] - box["xmin"], box["ymax"] - box["ymin"]
        ax.add_patch(
            plt.Rectangle((x, y), w, h, fill=False, color=color, linewidth=3)
        )
        # Semi-transparent background in the box colour keeps the label
        # readable over arbitrary image content.
        ax.text(
            x,
            y,
            f"{prediction['label']}: {prediction['score']:.1%}",
            fontsize=15,
            bbox={"facecolor": color, "alpha": 0.5},
        )
    ax.axis("off")
    return fig
def detect(image):
    """Run DETR object detection on *image* and return an annotated rendering.

    Args:
        image: Input PIL image, or ``None`` when the Gradio input is empty.

    Returns:
        A PIL image of the matplotlib rendering with boxes and labels drawn
        over the input, or ``None`` if no image was supplied.
    """
    if image is None:
        return None
    # Minimum confidence for a detection to be drawn.
    score_threshold = 0.9
    # NOTE(review): the processor normalises with 3-channel statistics, so
    # convert RGBA / greyscale uploads to RGB before preprocessing.
    image = image.convert("RGB")
    encoding = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encoding)
    # The raw model output holds logits and normalised boxes; convert them to
    # pixel-space detections. PIL's ``size`` is (width, height) while
    # ``target_sizes`` expects (height, width), hence the reversal.
    detections = processor.post_process_object_detection(
        outputs, threshold=score_threshold, target_sizes=[image.size[::-1]]
    )[0]
    # Reshape into the per-detection dicts generate_output_figure consumes.
    results = [
        {
            "score": score.item(),
            "label": model.config.id2label[label.item()],
            "box": dict(zip(("xmin", "ymin", "xmax", "ymax"), box.tolist())),
        }
        for score, label, box in zip(
            detections["scores"], detections["labels"], detections["boxes"]
        )
    ]
    output_figure = generate_output_figure(image, results)
    buf = io.BytesIO()
    try:
        output_figure.savefig(buf, bbox_inches="tight")
    finally:
        # pyplot keeps every figure alive until it is closed; without this each
        # request would leak a figure.
        plt.close(output_figure)
    buf.seek(0)
    output_pil_img = Image.open(buf)
    # Image.open is lazy; decode now so the pixels are in memory immediately.
    output_pil_img.load()
    return output_pil_img
# Gradio UI: one image in, the annotated detection rendering out.
demo = gr.Interface(
    fn=detect,
    inputs=gr.Image(label="Input image", type="pil"),
    outputs=gr.Image(label="Output image", type="pil"),
)

if __name__ == "__main__":
    # Start the server only when run as a script, so importing this module
    # (e.g. from tests or another script) does not block on launch().
    demo.launch()