Spaces:
Sleeping
Sleeping
Commit
·
047a580
1
Parent(s):
9aacea4
Updated app
Browse files
app.py
CHANGED
@@ -23,7 +23,7 @@ def get_output_figure(pil_img, scores, labels, boxes):
|
|
23 |
plt.imshow(pil_img)
|
24 |
ax = plt.gca()
|
25 |
colors = COLORS * 100
|
26 |
-
for score, label, (xmin, ymin, xmax, ymax), c in zip
|
27 |
ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3))
|
28 |
text = f'{model.config.id2label[label]}: {score:0.2f}'
|
29 |
ax.text(xmin, ymin, text, fontsize=15,
|
@@ -89,12 +89,10 @@ def detect(image):
|
|
89 |
with torch.no_grad():
|
90 |
outputs = model(**encoding)
|
91 |
|
92 |
-
#print(outputs)
|
93 |
width, height = image.size
|
94 |
postprocessed_outputs = processor.post_process_object_detection(outputs, target_sizes=[(height, width)], threshold=0.9)
|
95 |
results = postprocessed_outputs[0]
|
96 |
|
97 |
-
#print(results)
|
98 |
|
99 |
output_figure = get_output_figure(image, results['scores'], results['labels'], results['boxes'])
|
100 |
|
@@ -110,15 +108,24 @@ def detect(image):
|
|
110 |
buf.seek(0)
|
111 |
output_pil_img_attn = Image.open(buf)
|
112 |
|
113 |
-
#print(output_pil_img)
|
114 |
-
|
115 |
return output_pil_img, output_pil_img_attn
|
116 |
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
plt.imshow(pil_img)
|
24 |
ax = plt.gca()
|
25 |
colors = COLORS * 100
|
26 |
+
for score, label, (xmin, ymin, xmax, ymax), c in zip(scores.tolist(), labels.tolist(), boxes.tolist(), colors):
|
27 |
ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3))
|
28 |
text = f'{model.config.id2label[label]}: {score:0.2f}'
|
29 |
ax.text(xmin, ymin, text, fontsize=15,
|
|
|
89 |
with torch.no_grad():
|
90 |
outputs = model(**encoding)
|
91 |
|
|
|
92 |
width, height = image.size
|
93 |
postprocessed_outputs = processor.post_process_object_detection(outputs, target_sizes=[(height, width)], threshold=0.9)
|
94 |
results = postprocessed_outputs[0]
|
95 |
|
|
|
96 |
|
97 |
output_figure = get_output_figure(image, results['scores'], results['labels'], results['boxes'])
|
98 |
|
|
|
108 |
buf.seek(0)
|
109 |
output_pil_img_attn = Image.open(buf)
|
110 |
|
|
|
|
|
111 |
return output_pil_img, output_pil_img_attn
|
112 |
|
113 |
+
with gr.Blocks() as demo:
|
114 |
+
gr.Markdown("# Object detection with DETR")
|
115 |
+
gr.Markdown(
|
116 |
+
"""
|
117 |
+
This applciation uses DETR (DEtection TRansformers) to detect objects on images.
|
118 |
+
You can load an image and see the predictions for the objects detected along with the attention weights.
|
119 |
+
"""
|
120 |
+
)
|
121 |
+
|
122 |
+
gr.Interface(
|
123 |
+
fn=detect,
|
124 |
+
inputs=gr.Image(label="Input image", type="pil"),
|
125 |
+
outputs=[
|
126 |
+
gr.Image(label="Output prediction", type="pil"),
|
127 |
+
gr.Image(label="Attention weights", type="pil")
|
128 |
+
]
|
129 |
+
).launch()
|
130 |
+
|
131 |
+
demo.launch()
|