Spaces:
Sleeping
Sleeping
Commit
·
1a933f0
1
Parent(s):
19a9827
Updated app
Browse files
app.py
CHANGED
@@ -18,22 +18,17 @@ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
|
|
18 |
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
|
19 |
[0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
|
20 |
|
21 |
-
def
|
22 |
plt.figure(figsize=(16, 10))
|
23 |
-
plt.imshow(
|
24 |
-
#pyplot.gcf()
|
25 |
ax = plt.gca()
|
26 |
-
|
27 |
-
for
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
|
34 |
-
ax.text(x, y, f"{prediction['label']}: {round(prediction['score']*100, 1)}%", fontdict=fdic)
|
35 |
-
|
36 |
-
plt.axis("off")
|
37 |
|
38 |
return plt.gcf()
|
39 |
|
@@ -46,9 +41,14 @@ def detect(image):
|
|
46 |
with torch.no_grad():
|
47 |
outputs = model(**encoding)
|
48 |
|
49 |
-
print(outputs)
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
-
output_figure =
|
52 |
|
53 |
buf = io.BytesIO()
|
54 |
output_figure.savefig(buf, bbox_inches='tight')
|
|
|
18 |
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
|
19 |
[0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
|
20 |
|
21 |
+
def get_output_figure(pil_img, scores, labels, boxes):
|
22 |
plt.figure(figsize=(16, 10))
|
23 |
+
plt.imshow(pil_img)
|
|
|
24 |
ax = plt.gca()
|
25 |
+
colors = COLORS * 100
|
26 |
+
for score, label, (xmin, ymin, xmax, ymax), c in zip (scores.tolist(), labels.tolist(), boxes.tolist(), colors):
|
27 |
+
ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3))
|
28 |
+
text = f'{model.config.id2label[label]}: {score:0.2f}'
|
29 |
+
ax.text(xmin, ymin, text, fontsize=15,
|
30 |
+
bbox=dict(facecolor='yellow', alpha=0.5))
|
31 |
+
plt.axis('off')
|
|
|
|
|
|
|
|
|
32 |
|
33 |
return plt.gcf()
|
34 |
|
|
|
41 |
with torch.no_grad():
|
42 |
outputs = model(**encoding)
|
43 |
|
44 |
+
#print(outputs)
|
45 |
+
width, height = image.size
|
46 |
+
postprocessed_outputs = processor.post_process_object_detection(outputs, target_sizes=[(height, width)], threshold=0.9)
|
47 |
+
results = postprocessed_outputs[0]
|
48 |
+
|
49 |
+
print(results)
|
50 |
|
51 |
+
output_figure = get_output_figure(image, results['scores'], results['labels'], results['boxes'])
|
52 |
|
53 |
buf = io.BytesIO()
|
54 |
output_figure.savefig(buf, bbox_inches='tight')
|