Spaces:
Sleeping
Sleeping
Commit
·
0b822c2
1
Parent(s):
f0585ee
Updated app
Browse files
app.py
CHANGED
@@ -5,18 +5,55 @@ import torch
|
|
5 |
from PIL import Image
|
6 |
import requests
|
7 |
from transformers import DetrImageProcessor
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
-
zero = torch.Tensor([0]).cuda()
|
12 |
-
print(zero.device) # <-- 'cpu' 🤔
|
13 |
|
14 |
@spaces.GPU
|
15 |
def detect(image):
|
16 |
encoding = processor(image, return_tensors='pt')
|
17 |
print(encoding.keys())
|
18 |
-
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
demo = gr.Interface(fn=detect, inputs=gr.Image(label="Input image", type="pil"), outputs=gr.Image(label="Output image", type="pil"))
|
22 |
demo.launch()
|
|
|
5 |
from PIL import Image
|
6 |
import requests
|
7 |
from transformers import DetrImageProcessor
|
8 |
+
from transformers import DetrForObjectDetection
|
9 |
+
from random import choice
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
import io
|
12 |
+
|
13 |
|
14 |
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
|
15 |
+
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
|
16 |
+
|
17 |
+
|
18 |
+
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
|
19 |
+
[0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
|
20 |
+
|
21 |
+
def generate_output_figure(in_pil_img, in_results):
|
22 |
+
plt.figure(figsize=(16, 10))
|
23 |
+
ax = plt.gca()
|
24 |
+
|
25 |
+
for prediction in in_results:
|
26 |
+
selected_color = choice(COLORS)
|
27 |
+
|
28 |
+
x, y = prediction['box']['xmin'], prediction['box']['ymin'],
|
29 |
+
w, h = prediction['box']['xmax'] - prediction['box']['xmin'], prediction['box']['ymax'] - prediction['box']['ymin']
|
30 |
+
|
31 |
+
ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
|
32 |
+
ax.text(x, y, f"{prediction['label']}: {round(prediction['score']*100, 1)}%", fontdict=fdic)
|
33 |
+
|
34 |
+
plt.axis("off")
|
35 |
+
|
36 |
+
return plt.gcf()
|
37 |
|
|
|
|
|
38 |
|
39 |
@spaces.GPU
|
40 |
def detect(image):
|
41 |
encoding = processor(image, return_tensors='pt')
|
42 |
print(encoding.keys())
|
43 |
+
|
44 |
+
with torch.no_grad():
|
45 |
+
outputs = model(**encoding)
|
46 |
+
|
47 |
+
output_figure = generate_output_figure(image, outputs)
|
48 |
+
|
49 |
+
buf = io.BytesIO()
|
50 |
+
output_figure.savefig(buf, bbox_inches='tight')
|
51 |
+
buf.seek(0)
|
52 |
+
output_pil_img = Image.open(buf)
|
53 |
+
|
54 |
+
print(output_pil_img)
|
55 |
+
|
56 |
+
return output_pil_img
|
57 |
|
58 |
demo = gr.Interface(fn=detect, inputs=gr.Image(label="Input image", type="pil"), outputs=gr.Image(label="Output image", type="pil"))
|
59 |
demo.launch()
|