sergiopaniego HF Staff commited on
Commit
047a580
·
1 Parent(s): 9aacea4

Updated app

Browse files
Files changed (1) hide show
  1. app.py +20 -13
app.py CHANGED
@@ -23,7 +23,7 @@ def get_output_figure(pil_img, scores, labels, boxes):
23
  plt.imshow(pil_img)
24
  ax = plt.gca()
25
  colors = COLORS * 100
26
- for score, label, (xmin, ymin, xmax, ymax), c in zip (scores.tolist(), labels.tolist(), boxes.tolist(), colors):
27
  ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3))
28
  text = f'{model.config.id2label[label]}: {score:0.2f}'
29
  ax.text(xmin, ymin, text, fontsize=15,
@@ -89,12 +89,10 @@ def detect(image):
89
  with torch.no_grad():
90
  outputs = model(**encoding)
91
 
92
- #print(outputs)
93
  width, height = image.size
94
  postprocessed_outputs = processor.post_process_object_detection(outputs, target_sizes=[(height, width)], threshold=0.9)
95
  results = postprocessed_outputs[0]
96
 
97
- #print(results)
98
 
99
  output_figure = get_output_figure(image, results['scores'], results['labels'], results['boxes'])
100
 
@@ -110,15 +108,24 @@ def detect(image):
110
  buf.seek(0)
111
  output_pil_img_attn = Image.open(buf)
112
 
113
- #print(output_pil_img)
114
-
115
  return output_pil_img, output_pil_img_attn
116
 
117
- demo = gr.Interface(
118
- fn=detect,
119
- inputs=gr.Image(label="Input image", type="pil"),
120
- outputs=[
121
- gr.Image(label="Output image predictions", type="pil"),
122
- gr.Image(label="Output attention weights", type="pil")
123
- ])
124
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
23
  plt.imshow(pil_img)
24
  ax = plt.gca()
25
  colors = COLORS * 100
26
+ for score, label, (xmin, ymin, xmax, ymax), c in zip(scores.tolist(), labels.tolist(), boxes.tolist(), colors):
27
  ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3))
28
  text = f'{model.config.id2label[label]}: {score:0.2f}'
29
  ax.text(xmin, ymin, text, fontsize=15,
 
89
  with torch.no_grad():
90
  outputs = model(**encoding)
91
 
 
92
  width, height = image.size
93
  postprocessed_outputs = processor.post_process_object_detection(outputs, target_sizes=[(height, width)], threshold=0.9)
94
  results = postprocessed_outputs[0]
95
 
 
96
 
97
  output_figure = get_output_figure(image, results['scores'], results['labels'], results['boxes'])
98
 
 
108
  buf.seek(0)
109
  output_pil_img_attn = Image.open(buf)
110
 
 
 
111
  return output_pil_img, output_pil_img_attn
112
 
113
# Gradio UI: a Blocks page that hosts the DETR detection Interface.
# NOTE(review): reconstructed from the diff hunk; `gr`, `detect` are defined
# earlier in app.py (outside this hunk).
with gr.Blocks() as demo:
    gr.Markdown("# Object detection with DETR")
    gr.Markdown(
        """
        This application uses DETR (DEtection TRansformers) to detect objects on images.
        You can load an image and see the predictions for the objects detected along with the attention weights.
        """
    )

    # Render the Interface inside the Blocks page. Do NOT call .launch() on it
    # here: the original code launched the nested Interface AND then launched
    # `demo`, starting the server twice. A component built inside a
    # `with gr.Blocks()` context is rendered into that page automatically;
    # the single `demo.launch()` below starts the app.
    gr.Interface(
        fn=detect,
        inputs=gr.Image(label="Input image", type="pil"),
        outputs=[
            gr.Image(label="Output prediction", type="pil"),
            gr.Image(label="Attention weights", type="pil"),
        ],
    )

demo.launch()