Commit · 153ddd7
Parent(s): bfd905c
Added attention

app.py CHANGED
@@ -32,6 +32,54 @@ def get_output_figure(pil_img, scores, labels, boxes):
 
     return plt.gcf()
 
+def get_output_attn_figure(image, encoding, results):
+    bboxes_scaled = results['boxes']
+    # use lists to store the outputs via up-values
+    conv_features = []
+
+    hooks = [
+        model.model.backbone.conv_encoder.register_forward_hook(
+            lambda self, input, output: conv_features.append(output)
+        )
+    ]
+
+    # propagate through the model
+    outputs = model(**encoding, output_attentions=True)
+
+    for hook in hooks:
+        hook.remove()
+
+    # keep only predictions of queries with 0.9+ confidence (excluding the no-object class)
+    probas = outputs.logits.softmax(-1)[0, :, :-1]
+    keep = probas.max(-1).values > 0.9
+
+    # don't need the list anymore
+    conv_features = conv_features[0]
+    # get cross-attention weights of the last decoder layer - shape (batch_size, num_heads, num_queries, width*height)
+    dec_attn_weights = outputs.cross_attentions[-1]
+    # average them over the 8 heads and detach from the graph
+    dec_attn_weights = torch.mean(dec_attn_weights, dim=1).detach()
+
+    # get the feature map shape
+    h, w = conv_features[-1][0].shape[-2:]
+
+    fig, axs = plt.subplots(ncols=len(bboxes_scaled), nrows=2, figsize=(22, 7))
+    colors = COLORS * 100
+    for idx, ax_i, box in zip(keep.nonzero(), axs.T, bboxes_scaled):
+        xmin, ymin, xmax, ymax = box.detach().numpy()
+        ax = ax_i[0]
+        ax.imshow(dec_attn_weights[0, idx].view(h, w))
+        ax.axis('off')
+        ax.set_title(f'query id: {idx.item()}')
+        ax = ax_i[1]
+        ax.imshow(image)
+        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False,
+                                   color='blue', linewidth=3))
+        ax.axis('off')
+        ax.set_title(model.config.id2label[probas[idx].argmax().item()])
+    fig.tight_layout()
+    return plt.gcf()
+
 
 @spaces.GPU
 def detect(image):
@@ -46,7 +94,7 @@ def detect(image):
     postprocessed_outputs = processor.post_process_object_detection(outputs, target_sizes=[(height, width)], threshold=0.9)
     results = postprocessed_outputs[0]
 
-    print(results)
+    #print(results)
 
     output_figure = get_output_figure(image, results['scores'], results['labels'], results['boxes'])
 
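For reference, post_process_object_detection converts the raw model outputs into absolute-pixel detections above the threshold; because the same 0.9 cut-off drives the keep mask in get_output_attn_figure, the two selections line up. A sketch of the result structure, continuing the assumptions above:

# continuing the sketch above: threshold the raw outputs at 0.9
postprocessed = processor.post_process_object_detection(
    outputs,
    target_sizes=[(image.height, image.width)],  # one (height, width) per image
    threshold=0.9,
)
results = postprocessed[0]
# results is a dict with 'scores', 'labels', and 'boxes'
# (boxes as (xmin, ymin, xmax, ymax) in pixel coordinates)
for score, label, box in zip(results['scores'], results['labels'], results['boxes']):
    print(f"{model.config.id2label[label.item()]}: {score:.2f} at {box.tolist()}")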
@@ -55,9 +103,16 @@ def detect(image):
     buf.seek(0)
     output_pil_img = Image.open(buf)
 
-    print(output_pil_img)
+    output_figure_attn = get_output_attn_figure(image, encoding, results)
+
+    buf = io.BytesIO()
+    output_figure_attn.savefig(buf, bbox_inches='tight')
+    buf.seek(0)
+    output_pil_img_attn = Image.open(buf)
+
+    #print(output_pil_img)
 
-    return output_pil_img,
+    return output_pil_img, output_pil_img_attn
 
 demo = gr.Interface(
     fn=detect,
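Since detect now returns two images, the gr.Interface call needs two output components. The outputs list falls outside this diff, so the following is only an assumption of how the interface might be wired up:

# hedged sketch: component types and labels are assumptions;
# only the two-output shape is implied by the new return value.
# detect is the app's function defined above.
demo = gr.Interface(
    fn=detect,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Detections"),
        gr.Image(type="pil", label="Decoder cross-attention"),
    ],
)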