File size: 1,999 Bytes
d0f5a61
 
 
 
4317393
 
 
0b822c2
 
 
 
 
4317393
f0585ee
0b822c2
 
 
 
 
 
1a933f0
0b822c2
1a933f0
0b822c2
1a933f0
 
 
 
 
 
 
0b822c2
 
4317393
d0f5a61
 
f0585ee
4317393
 
0b822c2
 
 
 
1a933f0
 
 
 
 
 
19a9827
1a933f0
0b822c2
 
 
 
 
 
 
 
 
d0f5a61
f0585ee
d0f5a61
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import gradio as gr
import spaces
import torch

from PIL import Image
import requests
from transformers import DetrImageProcessor
from transformers import DetrForObjectDetection
from random import choice
import matplotlib.pyplot as plt
import io


# Hub checkpoint used for BOTH the image preprocessor and the detector.
# Keeping the id in one place prevents the two from accidentally being
# pointed at different checkpoints when the model is swapped.
DETR_CHECKPOINT = "facebook/detr-resnet-50"

# Loaded once at import time so every request reuses the same objects.
processor = DetrImageProcessor.from_pretrained(DETR_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(DETR_CHECKPOINT)


# Palette for the detection boxes: RGB triples with components in [0, 1],
# passed straight to matplotlib as the rectangle edge colour.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]

def get_output_figure(pil_img, scores, labels, boxes):
    """Draw detection boxes and captions over an image and return the figure.

    The figure is created directly (``plt.Figure``) rather than through the
    global pyplot state machine (``plt.figure`` / ``plt.gca`` / ``plt.gcf``).
    A figure made this way is not registered with pyplot, so it is released
    by normal garbage collection once the caller drops it, and concurrent
    calls cannot pick up each other's "current" figure or axes.

    Args:
        pil_img: Image shown as the background of the plot.
        scores: Detection confidences; must provide ``.tolist()``.
        labels: Class ids used as keys into ``model.config.id2label``;
            must provide ``.tolist()``.
        boxes: One ``(xmin, ymin, xmax, ymax)`` box per detection, in the
            pixel coordinates of ``pil_img``; must provide ``.tolist()``.

    Returns:
        The matplotlib ``Figure`` holding the annotated image (axes hidden).
    """
    fig = plt.Figure(figsize=(16, 10))
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(pil_img)

    detections = zip(scores.tolist(), labels.tolist(), boxes.tolist())
    for index, (score, label, (xmin, ymin, xmax, ymax)) in enumerate(detections):
        # Wrap around the palette so no detection is ever dropped, however
        # many there are (zipping against a finite colour list would
        # silently truncate).
        color = COLORS[index % len(COLORS)]
        ax.add_patch(
            plt.Rectangle(
                (xmin, ymin),
                xmax - xmin,
                ymax - ymin,
                fill=False,
                color=color,
                linewidth=3,
            )
        )
        text = f'{model.config.id2label[label]}: {score:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))
    ax.axis('off')

    return fig


@spaces.GPU
def detect(image):
    """Run object detection on an image and return an annotated rendering.

    Args:
        image: The input ``PIL.Image`` delivered by the Gradio image input,
            or ``None`` when the user submitted without an image.

    Returns:
        A ``PIL.Image`` of the input with a box and a ``label: score``
        caption drawn for every detection scoring above 0.9.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please provide an input image.")

    # Defensive: the detector is expected to take three-channel input, so
    # normalise grayscale / RGBA / palette images before preprocessing.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # The function requests a GPU via the decorator, so actually use it when
    # one is present; fall back to CPU otherwise. ``model.to`` is a no-op
    # when the model already lives on the target device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    encoding = processor(image, return_tensors='pt').to(device)

    with torch.no_grad():
        outputs = model(**encoding)

    # PIL reports (width, height); post-processing wants (height, width).
    width, height = image.size
    postprocessed_outputs = processor.post_process_object_detection(
        outputs, target_sizes=[(height, width)], threshold=0.9
    )
    results = postprocessed_outputs[0]

    output_figure = get_output_figure(
        image, results['scores'], results['labels'], results['boxes']
    )

    buf = io.BytesIO()
    try:
        output_figure.savefig(buf, format='png', bbox_inches='tight')
    finally:
        # Release the figure if pyplot is tracking it; otherwise each request
        # would leave one more figure alive. Harmless for untracked figures.
        plt.close(output_figure)
    buf.seek(0)

    output_pil_img = Image.open(buf)
    # Image.open is lazy; decode now so the result is fully materialised.
    output_pil_img.load()

    return output_pil_img

# Gradio UI: one PIL image in, the annotated PIL image out.
demo = gr.Interface(
    fn=detect,
    inputs=gr.Image(label="Input image", type="pil"),
    outputs=gr.Image(label="Output image", type="pil"),
)

# Start the server only when this file is executed as a script, so importing
# the module (e.g. to reuse `detect` or `demo`) does not launch a web server.
if __name__ == "__main__":
    demo.launch()