Spaces:
Sleeping
Sleeping
File size: 4,459 Bytes
d0f5a61 4317393 0b822c2 4317393 f0585ee 0b822c2 1a933f0 0b822c2 1a933f0 0b822c2 1a933f0 047a580 1a933f0 0b822c2 4317393 487ba13 153ddd7 d0f5a61 f0585ee 4317393 0b822c2 1a933f0 19a9827 1a933f0 0b822c2 487ba13 153ddd7 bfd905c 047a580 e2f5aeb 047a580 f34d06d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import gradio as gr
import spaces
import torch
from PIL import Image
import requests
from transformers import DetrImageProcessor
from transformers import DetrForObjectDetection
from random import choice
import matplotlib.pyplot as plt
import io
# Hugging Face Hub checkpoint used for both the image processor and the model;
# kept in one place so the two can never be loaded from different checkpoints.
MODEL_CHECKPOINT = "facebook/detr-resnet-50"
processor = DetrImageProcessor.from_pretrained(MODEL_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(MODEL_CHECKPOINT)
# Matplotlib RGB triples (components in [0, 1]) used as bounding-box edge colors.
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
def get_output_figure(pil_img, scores, labels, boxes):
    """Draw detected boxes and their "label: score" captions over the image.

    Args:
        pil_img: the image to draw on.
        scores: tensor of detection confidences (converted with ``.tolist()``).
        labels: tensor of class ids, looked up in ``model.config.id2label``.
        boxes: tensor of ``(xmin, ymin, xmax, ymax)`` boxes, drawn directly on
            the image axes (so presumably in pixel coordinates of ``pil_img``).

    Returns:
        The matplotlib figure that was drawn on. It is still registered with
        pyplot, so the caller should close it (``plt.close(fig)``) when done.
    """
    # Work on an explicit figure/axes pair instead of pyplot's "current figure":
    # plt.gcf() could hand back another request's figure when Gradio serves
    # several requests concurrently.
    fig, ax = plt.subplots(figsize=(16, 10))
    ax.imshow(pil_img)
    detections = zip(scores.tolist(), labels.tolist(), boxes.tolist())
    for i, (score, label, (xmin, ymin, xmax, ymax)) in enumerate(detections):
        # Cycle through the palette so any number of boxes gets a color.
        color = COLORS[i % len(COLORS)]
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=color, linewidth=3))
        text = f'{model.config.id2label[label]}: {score:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))
    ax.axis('off')
    return fig
def get_output_attn_figure(image, encoding, results, outputs):
    """Plot last-decoder-layer cross-attention for every confident query.

    For each query whose best non-"no-object" class probability exceeds 0.9,
    the top row shows that query's cross-attention weights (averaged over
    heads) reshaped to the backbone feature-map grid. The bottom row shows the
    input image with the query's box and predicted class.

    Args:
        image: the input image, drawn in the bottom row.
        encoding: processor output for ``image``; fed to the model again with
            ``output_attentions=True``.
        results: post-processed detections; only ``results['boxes']`` is read.
            NOTE(review): the boxes are paired with queries positionally, so
            this assumes ``results`` was produced with the same 0.9 threshold
            (as ``detect`` does).
        outputs: model outputs for ``encoding``; only ``logits`` is read, to
            choose the confident queries.

    Returns:
        A matplotlib figure (still registered with pyplot; the caller should
        close it). If there are no boxes, the figure contains only a
        "no objects" message.
    """
    # Keep only queries with > 0.9 confidence, excluding the trailing no-object class.
    probas = outputs.logits.softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > 0.9
    query_ids = keep.nonzero()
    bboxes_scaled = results['boxes']

    if len(bboxes_scaled) == 0:
        # plt.subplots(ncols=0) raises ValueError; show a placeholder instead of
        # failing the whole request.
        fig, ax = plt.subplots(figsize=(22, 7))
        ax.text(0.5, 0.5, 'No objects detected above the confidence threshold',
                ha='center', va='center', fontsize=20)
        ax.axis('off')
        return fig

    # Capture the backbone output with a forward hook. The list acts as a
    # mutable cell the lambda can append to.
    conv_features = []
    hook = model.model.backbone.conv_encoder.register_forward_hook(
        lambda module, inputs, output: conv_features.append(output)
    )
    try:
        # Re-run the model, this time requesting attention weights. no_grad
        # avoids building an autograd graph that is never used.
        with torch.no_grad():
            attn_outputs = model(**encoding, output_attentions=True)
    finally:
        # Always unregister the hook: the model is shared by every request.
        hook.remove()

    # Last decoder layer's cross-attention, shaped
    # (batch_size, num_heads, num_queries, height*width); average over heads.
    dec_attn_weights = attn_outputs.cross_attentions[-1].mean(dim=1)
    # Spatial size of the last backbone feature map. Presumably the hook output
    # is a list of (feature_map, mask) pairs, hence the [-1][0] indexing.
    h, w = conv_features[0][-1][0].shape[-2:]

    # squeeze=False keeps `axs` 2-D even for a single detection, so each
    # column of axs.T is always an (attention axes, image axes) pair.
    fig, axs = plt.subplots(ncols=len(bboxes_scaled), nrows=2, figsize=(22, 7),
                            squeeze=False)
    for idx, (attn_ax, img_ax), box in zip(query_ids, axs.T, bboxes_scaled):
        xmin, ymin, xmax, ymax = box.tolist()

        attn_ax.imshow(dec_attn_weights[0, idx].view(h, w))
        attn_ax.axis('off')
        attn_ax.set_title(f'query id: {idx.item()}')

        img_ax.imshow(image)
        img_ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                       fill=False, color='blue', linewidth=3))
        img_ax.axis('off')
        img_ax.set_title(model.config.id2label[probas[idx].argmax().item()])
    fig.tight_layout()
    return fig
def _figure_to_pil(fig):
    """Render ``fig`` to an in-memory PNG, close it, and return it as a PIL image."""
    buf = io.BytesIO()
    try:
        fig.savefig(buf, bbox_inches='tight')
    finally:
        # Drop the figure from pyplot's registry; otherwise every request
        # leaks figures in this long-running server process.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


@spaces.GPU
def detect(image):
    """Detect objects in ``image`` with DETR and visualise the result.

    Args:
        image: the input PIL image.

    Returns:
        A pair of PIL images: the detections drawn over the input, and the
        per-query cross-attention maps.

    Raises:
        gr.Error: if ``image`` is None (e.g. the form was submitted without an
            upload), so the UI shows a readable message.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    # NOTE(review): the model is never moved to a CUDA device in this file, so
    # inference runs on CPU even inside the @spaces.GPU context.
    encoding = processor(image, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**encoding)
    width, height = image.size
    # PIL reports (width, height); target_sizes expects (height, width).
    # This 0.9 threshold matches the one used in get_output_attn_figure.
    results = processor.post_process_object_detection(
        outputs, target_sizes=[(height, width)], threshold=0.9)[0]

    output_pil_img = _figure_to_pil(
        get_output_figure(image, results['scores'], results['labels'], results['boxes']))
    output_pil_img_attn = _figure_to_pil(
        get_output_attn_figure(image, encoding, results, outputs))
    return output_pil_img, output_pil_img_attn
# Gradio UI: a title, a short description, and an Interface wired to `detect`
# (one input image, two output images).
with gr.Blocks() as demo:
    gr.Markdown("# Object detection with DETR")
    gr.Markdown(
        "This application uses DETR (DEtection TRansformers) to detect objects in images.\n"
        "You can load an image and see the predictions for the objects detected "
        "along with the attention weights."
    )
    gr.Interface(
        fn=detect,
        inputs=gr.Image(label="Input image", type="pil"),
        outputs=[
            gr.Image(label="Output prediction", type="pil"),
            gr.Image(label="Attention weights", type="pil"),
        ],
    )

# Launch only when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    demo.launch(show_error=True)
|