import gradio as gr
import matplotlib.pyplot as plt
import timm
import torch
from timm import create_model
from timm.layers import PatchEmbed
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.transforms import functional as F

# CaiT-XXS24 (distilled, ImageNet-1k) in inference mode.
cait_model = create_model("cait_xxs24_224.fb_dist_in1k", pretrained=True).eval()

transform = timm.data.create_transform(
    **timm.data.resolve_data_config(cait_model.pretrained_cfg)
)
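# Note: `resolve_data_config` reads the resize / crop / normalization settings
# from the model's pretrained config, and `create_transform` turns them into a
# preprocessing pipeline (PIL image -> normalized tensor).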

# 224x224 inputs with 16x16 patches -> a 14x14 grid of patch tokens.
patch_size = 16


def create_attn_extractor(model, block_id=0):
    """Creates a feature extractor that returns the softmax attention scores
    of a given class-attention block.

    References:
        https://github.com/huggingface/pytorch-image-models/discussions/926
    """
    feature_extractor = create_feature_extractor(
        model,
        return_nodes=[f"blocks_token_only.{block_id}.attn.softmax"],
        tracer_kwargs={"leaf_modules": [PatchEmbed]},
    )
    return feature_extractor
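

# Note: calling the extractor returns a dict keyed by the node name above. For
# CaiT's class-attention blocks the query is the CLS token only, so the scores
# are expected to have shape (batch, num_heads, 1, num_patches + 1).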


def get_cls_attention_map(
    image, attn_score_dict, block_key="blocks_token_only.0.attn.softmax"
):
    """Prepares the class-attention scores of one block for visualization."""
    w_featmap = image.shape[3] // patch_size
    h_featmap = image.shape[2] // patch_size

    attention_scores = attn_score_dict[block_key]
    nh = attention_scores.shape[1]  # Number of attention heads.

    # Take the attention the CLS token pays to every patch token
    # (dropping the score the CLS token assigns to itself).
    attentions = attention_scores[0, :, 0, 1:].reshape(nh, -1)

    # Reshape the scores into the spatial grid of patches.
    attentions = attentions.reshape(nh, h_featmap, w_featmap)

    # Upsample each head's map back to the input resolution (14x16 = 224).
    attentions = F.resize(
        attentions,
        size=(h_featmap * patch_size, w_featmap * patch_size),
        interpolation=F.InterpolationMode.BICUBIC,
    )
    return attentions
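

# Rough shape walk-through for a 224x224 input (assuming the 4-head
# class-attention blocks of cait_xxs24_224):
#   (1, 4, 1, 197) softmax scores -> (4, 196) CLS-to-patch scores
#   -> (4, 14, 14) patch grid -> (4, 224, 224) upsampled maps.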


def generate_plot(processed_map):
    """Generates a class attention map plot, one panel per attention head."""
    num_heads = processed_map.shape[0]
    fig, axes = plt.subplots(nrows=1, ncols=num_heads, figsize=(13, 13))

    for i in range(num_heads):
        axes[i].imshow(processed_map[i].numpy())
        axes[i].title.set_text(f"Attention head: {i}")
        axes[i].axis("off")

    fig.tight_layout()
    return fig


def generate_class_attn_map(image, block_id=0):
    """Collates the above utilities together for generating
    a class attention map."""
    # The slider may hand the block id over as a float.
    block_id = int(block_id)

    image_tensor = transform(image).unsqueeze(0)
    feature_extractor = create_attn_extractor(cait_model, block_id)

    with torch.no_grad():
        out = feature_extractor(image_tensor)

    block_key = f"blocks_token_only.{block_id}.attn.softmax"
    processed_cls_attn_map = get_cls_attention_map(image_tensor, out, block_key)
    return generate_plot(processed_cls_attn_map)
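

# Example of using the pipeline outside the Gradio UI (assumes the ./bird.png
# file from the examples below is available locally):
#   from PIL import Image
#   fig = generate_class_attn_map(Image.open("./bird.png").convert("RGB"), block_id=0)
#   fig.savefig("cls_attn_map.png")
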
title = "Class Attention Maps"
article = "Class attention maps as investigated in [Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239) (Touvron et al.). We use the [cait_xxs24_224](https://huggingface.co/timm/cait_xxs24_224.fb_dist_in1k) variant of CaiT. One can find all the other variants [here](https://huggingface.co/models?search=cait)."
iface = gr.Interface(
    generate_class_attn_map,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Slider(0, 1, value=0, step=1, label="Block ID", info="Transformer Block ID"),
    ],
    outputs=[gr.Plot(label="Class attention maps")],
    title=title,
    article=article,
    allow_flagging="never",
    cache_examples=True,
    examples=[["./bird.png", 0]],
)
iface.launch()