sergiopaniego HF Staff committed on
Commit
153ddd7
·
1 Parent(s): bfd905c

Added attention

Browse files
Files changed (1) hide show
  1. app.py +58 -3
app.py CHANGED
@@ -32,6 +32,54 @@ def get_output_figure(pil_img, scores, labels, boxes):
32
 
33
  return plt.gcf()
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  @spaces.GPU
37
  def detect(image):
@@ -46,7 +94,7 @@ def detect(image):
46
  postprocessed_outputs = processor.post_process_object_detection(outputs, target_sizes=[(height, width)], threshold=0.9)
47
  results = postprocessed_outputs[0]
48
 
49
- print(results)
50
 
51
  output_figure = get_output_figure(image, results['scores'], results['labels'], results['boxes'])
52
 
@@ -55,9 +103,16 @@ def detect(image):
55
  buf.seek(0)
56
  output_pil_img = Image.open(buf)
57
 
58
- print(output_pil_img)
 
 
 
 
 
 
 
59
 
60
- return output_pil_img, output_pil_img
61
 
62
  demo = gr.Interface(
63
  fn=detect,
 
32
 
33
  return plt.gcf()
34
 
35
def get_output_attn_figure(image, encoding, results):
    """Visualize the decoder cross-attention for each detected object.

    Re-runs the model with ``output_attentions=True`` while a forward hook
    captures the backbone conv-encoder feature maps, then renders — for every
    query kept above the confidence threshold — its cross-attention heat map
    on top, and the input image with the predicted box below.

    Args:
        image: PIL image that was fed to the processor.
        encoding: processor output (pixel values, etc.) for ``image``,
            suitable for ``model(**encoding)``.
        results: post-processed detection dict; only ``results['boxes']``
            (rescaled boxes, one per kept query) is read here.

    Returns:
        The current matplotlib figure (``plt.gcf()``), one column per box.
    """
    # Hook-captured backbone feature maps land here; the hook fires once
    # per forward pass, so the list ends up with a single entry.
    conv_features = []

    hooks = [
        model.model.backbone.conv_encoder.register_forward_hook(
            lambda self, input, output: conv_features.append(output)
        )
    ]

    # Propagate through the model FIRST — the original code read
    # `outputs.logits` before `outputs` was assigned, which raises
    # UnboundLocalError. The forward pass must precede its use.
    outputs = model(**encoding, output_attentions=True)

    for hook in hooks:
        hook.remove()

    # Keep only predictions of queries with >0.9 confidence
    # (excluding the no-object class, which is the last logit).
    probas = outputs.logits.softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > 0.9

    bboxes_scaled = results['boxes']

    # Don't need the list anymore — unwrap the single hook capture.
    conv_features = conv_features[0]
    # Cross-attention weights of the last decoder layer:
    # shape (batch_size, num_heads, num_queries, width*height).
    dec_attn_weights = outputs.cross_attentions[-1]
    # Average over the attention heads and detach from the graph.
    dec_attn_weights = torch.mean(dec_attn_weights, dim=1).detach()

    # Spatial shape (h, w) of the last backbone feature map, used to
    # reshape the flattened attention back into an image-like grid.
    h, w = conv_features[-1][0].shape[-2:]

    fig, axs = plt.subplots(ncols=len(bboxes_scaled), nrows=2, figsize=(22, 7))
    for idx, ax_i, box in zip(keep.nonzero(), axs.T, bboxes_scaled):
        xmin, ymin, xmax, ymax = box.detach().numpy()
        # Top row: attention heat map for this query.
        ax = ax_i[0]
        ax.imshow(dec_attn_weights[0, idx].view(h, w))
        ax.axis('off')
        ax.set_title(f'query id: {idx.item()}')
        # Bottom row: the image with the predicted bounding box.
        ax = ax_i[1]
        ax.imshow(image)
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color='blue', linewidth=3))
        ax.axis('off')
        ax.set_title(model.config.id2label[probas[idx].argmax().item()])
    fig.tight_layout()
    return plt.gcf()
82
+
83
 
84
  @spaces.GPU
85
  def detect(image):
 
94
  postprocessed_outputs = processor.post_process_object_detection(outputs, target_sizes=[(height, width)], threshold=0.9)
95
  results = postprocessed_outputs[0]
96
 
97
+ #print(results)
98
 
99
  output_figure = get_output_figure(image, results['scores'], results['labels'], results['boxes'])
100
 
 
103
  buf.seek(0)
104
  output_pil_img = Image.open(buf)
105
 
106
+ output_figure_attn = get_output_attn_figure(image, encoding, results)
107
+
108
+ buf = io.BytesIO()
109
+ output_figure_attn.savefig(buf, bbox_inches='tight')
110
+ buf.seek(0)
111
+ output_pil_img_attn = Image.open(buf)
112
+
113
+ #print(output_pil_img)
114
 
115
+ return output_pil_img, output_pil_img_attn
116
 
117
  demo = gr.Interface(
118
  fn=detect,