sergiopaniego HF Staff committed on
Commit 1a933f0 · 1 Parent(s): 19a9827

Updated app

Files changed (1)
  1. app.py +16 -16
app.py CHANGED
@@ -18,22 +18,17 @@ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
 COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
           [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
 
-def get_figure(in_pil_img, in_results):
+def get_output_figure(pil_img, scores, labels, boxes):
     plt.figure(figsize=(16, 10))
-    plt.imshow(in_pil_img)
-    #pyplot.gcf()
+    plt.imshow(pil_img)
     ax = plt.gca()
-
-    for prediction in in_results:
-        selected_color = choice(COLORS)
-
-        x, y = prediction['box']['xmin'], prediction['box']['ymin'],
-        w, h = prediction['box']['xmax'] - prediction['box']['xmin'], prediction['box']['ymax'] - prediction['box']['ymin']
-
-        ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
-        ax.text(x, y, f"{prediction['label']}: {round(prediction['score']*100, 1)}%", fontdict=fdic)
-
-    plt.axis("off")
+    colors = COLORS * 100
+    for score, label, (xmin, ymin, xmax, ymax), c in zip(scores.tolist(), labels.tolist(), boxes.tolist(), colors):
+        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3))
+        text = f'{model.config.id2label[label]}: {score:0.2f}'
+        ax.text(xmin, ymin, text, fontsize=15,
+                bbox=dict(facecolor='yellow', alpha=0.5))
+    plt.axis('off')
 
     return plt.gcf()
 
@@ -46,9 +41,14 @@ def detect(image):
     with torch.no_grad():
         outputs = model(**encoding)
 
-    print(outputs)
+    #print(outputs)
+    width, height = image.size
+    postprocessed_outputs = processor.post_process_object_detection(outputs, target_sizes=[(height, width)], threshold=0.9)
+    results = postprocessed_outputs[0]
+
+    print(results)
 
-    output_figure = get_figure(image, outputs)
+    output_figure = get_output_figure(image, results['scores'], results['labels'], results['boxes'])
 
     buf = io.BytesIO()
     output_figure.savefig(buf, bbox_inches='tight')
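
For context, the change replaces drawing from raw DETR outputs (class logits and normalized boxes, which the old get_figure treated as if they were already a list of detections) with processor.post_process_object_detection, which returns per-image dicts of thresholded scores, labels, and pixel-space boxes. The sketch below shows how the updated pieces could fit together end to end; the processor construction, the encoding line, and the final return are assumptions, since the surrounding parts of app.py are not shown in this diff.

# Sketch only: glue code around the diffed lines is assumed, not taken from the commit.
import io

import torch
from PIL import Image
from transformers import DetrForObjectDetection, DetrImageProcessor

# Assumed: the app presumably builds a matching image processor; only the model line appears in the hunk header.
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

def detect(image):
    # Assumed preprocessing step; the diff starts after this line.
    encoding = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encoding)

    # Convert raw logits/boxes into thresholded detections in pixel coordinates.
    width, height = image.size
    results = processor.post_process_object_detection(
        outputs, target_sizes=[(height, width)], threshold=0.9
    )[0]

    # Draw the detections with the new helper from this commit and render the figure.
    fig = get_output_figure(image, results['scores'], results['labels'], results['boxes'])
    buf = io.BytesIO()
    fig.savefig(buf, bbox_inches='tight')
    buf.seek(0)
    return Image.open(buf)  # assumed return value; not visible in the diff

Passing target_sizes=[(height, width)] is what rescales DETR's normalized box predictions back to the input image's pixel dimensions, so get_output_figure can plot them directly over plt.imshow(pil_img).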