addit / addit_attention_store.py
YoadTew's picture
Add application file
504c7e8
raw
history blame
13.8 kB
# Copyright (C) 2025 NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the LICENSE file
# located at the root directory.
import torch
from skimage import filters
import cv2
import torch.nn.functional as F
from skimage.filters import threshold_li, threshold_yen, threshold_multiotsu
import numpy as np
from visualization_utils import show_tensors
import matplotlib.pyplot as plt
def text_to_tokens(text, tokenizer):
    """Tokenize *text* and decode every id of the first sequence on its own.

    Args:
        text: The string to tokenize.
        tokenizer: Object called as ``tokenizer(text, padding=..., return_tensors="pt")``
            whose result exposes ``input_ids``, and providing ``decode(id)``.

    Returns:
        A list with one decoded piece per id in ``input_ids[0]``, in order.
    """
    encoded = tokenizer(text, padding="longest", return_tensors="pt")
    first_sequence_ids = encoded.input_ids[0]
    return list(map(tokenizer.decode, first_sequence_ids))
def flatten_list(l):
    """Concatenate the items of each iterable in *l* into one flat list.

    Only one level of nesting is removed.
    """
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat
def gaussian_blur(heatmap, kernel_size=7, sigma=0):
    """Blur a 2-D heatmap tensor with OpenCV's Gaussian filter.

    Args:
        heatmap: Tensor of shape (H, W). It is moved to the CPU and converted
            to NumPy, so it must be a dtype NumPy/OpenCV accept.
        kernel_size: Side length of the square Gaussian kernel.
        sigma: Gaussian sigma forwarded to ``cv2.GaussianBlur`` (0 lets
            OpenCV derive it from the kernel size).

    Returns:
        A new CPU tensor holding the blurred heatmap.
    """
    square_kernel = (kernel_size, kernel_size)
    blurred = cv2.GaussianBlur(heatmap.cpu().numpy(), square_kernel, sigma)
    return torch.tensor(blurred)
def min_max_norm(x):
    """Rescale *x* linearly so its minimum maps to 0 and its maximum to 1.

    Works on any array-like supporting ``min()``, ``max()`` and broadcasting
    arithmetic (torch tensors, NumPy arrays). A constant input has zero
    range; instead of producing NaNs from 0/0 it is mapped to all zeros.
    """
    shifted = x - x.min()
    value_range = x.max() - x.min()
    if value_range == 0:
        # Avoid 0/0 -> NaN for constant inputs. Dividing by 1 keeps the
        # output dtype identical to the normal true-division path.
        value_range = 1
    return shifted / value_range
class AttentionStore:
    """Collects attention statistics from selected layers at selected steps.

    Two kinds of data are recorded:

    * Subject-token <-> image attention per layer (``text2image_store`` and
      ``image2text_store``), summed over every call made for that layer and
      later turned into spatial maps and binary masks by
      :meth:`aggregate_attention`.
    * Per-step "attention ratios" (``attention_ratios``): how the attention
      mass splits between the key regions the code labels source, text and
      target. They are stored by :meth:`store_attention_ratios` and
      summarized by :meth:`get_attention_ratios` and
      :meth:`plot_attention_ratios`.
    """

    def __init__(self, prompts, tokenizer,
                 subject_token=None, record_attention_steps=None,
                 is_cache_attn_ratio=False, attn_ratios_steps=None):
        """
        Args:
            prompts: A prompt string or a list of prompt strings. When
                attention recording is enabled, subject token indices are
                looked up in the tokens of the *second* prompt (index 1), so
                at least two prompts are required in that case.
            tokenizer: Tokenizer passed to ``text_to_tokens``.
            subject_token: Text of the subject whose attention is recorded.
                Only used when ``record_attention_steps`` is non-empty.
            record_attention_steps: Step indices at which
                :meth:`is_record_attention` can return True. Defaults to no
                steps.
            is_cache_attn_ratio: Enables :meth:`is_cache_attn_ratio`.
            attn_ratios_steps: Step indices at which attention ratios are
                cached. Defaults to ``[5]``.
        """
        # Fresh lists per instance: a shared mutable default would leak
        # state between AttentionStore objects.
        if record_attention_steps is None:
            record_attention_steps = []
        if attn_ratios_steps is None:
            attn_ratios_steps = [5]

        self.text2image_store = {}
        self.image2text_store = {}
        # Not read by this class; presumably used by callers -- TODO confirm.
        self.count_per_layer = {}
        self.record_attention_steps = record_attention_steps
        # Only these layers pass is_record_attention (None would mean "all").
        self.record_attention_layers = [
            "transformer_blocks.13", "transformer_blocks.14", "transformer_blocks.18",
            "single_transformer_blocks.23", "single_transformer_blocks.33",
        ]
        # step index -> layer prefix -> list of stacked ratio maps
        self.attention_ratios = {}
        self._is_cache_attn_ratio = is_cache_attn_ratio
        self.attn_ratios_steps = attn_ratios_steps
        # 'text' or 'pixels': which query rows store_attention_ratios summarizes.
        self.ratio_source = 'text'
        # Not read by this class; presumably used by callers -- TODO confirm.
        self.max_tokens_to_record = 10

        if isinstance(prompts, str):
            prompts = [prompts]

        self.tokens_to_record = []
        self.token_idxs_to_record = []
        if record_attention_steps:
            tokens_per_prompt = [text_to_tokens(prompt, tokenizer) for prompt in prompts]
            # Drop the last decoded token of the subject text (presumably the
            # tokenizer's end-of-sequence token -- TODO confirm).
            self.subject_tokens = text_to_tokens(subject_token, tokenizer)[:-1]
            self.subject_tokens_idx = [tokens_per_prompt[1].index(tok) for tok in self.subject_tokens]
            self.add_token_idx = self.subject_tokens_idx[-1]

    def is_record_attention(self, layer_name, step_index):
        """Return True if ``layer_name``'s attention at ``step_index`` should be stored."""
        is_correct_layer = (self.record_attention_layers is None) or (layer_name in self.record_attention_layers)
        return (step_index in self.record_attention_steps) and is_correct_layer

    def store_attention(self, attention_probs, layer_name, batch_size, num_heads):
        """Accumulate subject-token <-> image attention for one layer call.

        Args:
            attention_probs: 3-D attention probabilities shaped
                (batch_size * num_heads, queries, keys). The first
                ``text_len`` (512) positions on each axis are treated as text
                and the remaining ones as image.
            layer_name: Key under which results are accumulated.
            batch_size: Number of batch entries folded into dim 0.
            num_heads: Number of heads folded into dim 0.

        Only batch element 0 is recorded. Each call *adds* its maps into the
        stores, so repeated calls for the same layer are summed, not averaged.
        """
        text_len = 512

        # Split the fused (batch * heads) dimension and average over heads.
        attention_probs = attention_probs.view(batch_size, num_heads, *attention_probs.shape[1:])
        attention_probs = attention_probs.mean(dim=1)

        # Text queries -> image keys; rows selected by subject token index.
        attention_probs_text2image = attention_probs[:, :text_len, text_len:]
        attention_probs_text2image = [attention_probs_text2image[0, self.subject_tokens_idx, :]]

        # Image queries -> text keys, transposed so rows are subject tokens.
        attention_probs_image2text = attention_probs[:, text_len:, :text_len].transpose(1, 2)
        attention_probs_image2text = [attention_probs_image2text[0, self.subject_tokens_idx, :]]

        if layer_name not in self.text2image_store:
            self.text2image_store[layer_name] = list(attention_probs_text2image)
            self.image2text_store[layer_name] = list(attention_probs_image2text)
        else:
            # Each direction accumulates into its own store.
            self.text2image_store[layer_name] = [
                prev + new
                for prev, new in zip(self.text2image_store[layer_name], attention_probs_text2image)
            ]
            self.image2text_store[layer_name] = [
                prev + new
                for prev, new in zip(self.image2text_store[layer_name], attention_probs_image2text)
            ]

    def is_cache_attn_ratio(self, step_index):
        """Return True if ratio caching is enabled and ``step_index`` is a ratio step."""
        return (self._is_cache_attn_ratio) and (step_index in self.attn_ratios_steps)

    def store_attention_ratios(self, attention_probs, step_index, layer_name):
        """Record how one layer's attention mass splits across key regions.

        ``attention_probs`` is averaged over dim 0 (presumably heads -- TODO
        confirm). Its key axis is then sliced as [0, 4096) "source",
        [4096, 4608) "text" and [4608, ...) "target". Image regions are
        reshaped to 64x64 maps. This layout is hard-coded.

        * ratio_source == 'pixels': for each query row from index 512 on,
          the total attention to each region, plus its attention to the
          text key at ``add_token_idx``. Result: four 64x64 maps concatenated
          along dim 1.
        * ratio_source == 'text': for each source/target key position, the
          total attention it receives from the first 512 query rows. Result:
          two 64x64 maps concatenated along dim 1.

        Results are appended under ``attention_ratios[step_index][prefix]``,
        where prefix is ``layer_name`` up to its first '.'.

        Raises:
            ValueError: if ``self.ratio_source`` is neither 'pixels' nor 'text'.
        """
        if self.ratio_source not in ('pixels', 'text'):
            raise ValueError(f"Unknown ratio_source: {self.ratio_source!r}; expected 'pixels' or 'text'.")

        layer_prefix = layer_name.split(".")[0]
        mean_probs = attention_probs.mean(dim=0)
        if self.ratio_source == 'pixels':
            extended = mean_probs[512:, :]
            source_map = extended[:, :4096].sum(dim=1).view(64, 64).float().cpu()
            text_map = extended[:, 4096:4096 + 512].sum(dim=1).view(64, 64).float().cpu()
            target_map = extended[:, 4096 + 512:].sum(dim=1).view(64, 64).float().cpu()
            token_map = extended[:, 4096 + self.add_token_idx].view(64, 64).float().cpu()
            stacked_attention_ratios = torch.cat([source_map, text_map, target_map, token_map], dim=1)
        else:
            extended = mean_probs[:512, :]
            source_map = extended[:, :4096].sum(dim=0).view(64, 64).float().cpu()
            target_map = extended[:, 4096 + 512:].sum(dim=0).view(64, 64).float().cpu()
            stacked_attention_ratios = torch.cat([source_map, target_map], dim=1)

        self.attention_ratios.setdefault(step_index, {}).setdefault(layer_prefix, []).append(stacked_attention_ratios)

    def get_attention_ratios(self, step_indices=None, display_imgs=False):
        """Summarize the stored attention ratios for each layer prefix.

        Args:
            step_indices: Steps to average over. Defaults to all recorded
                steps in insertion order. Layer prefixes are taken from the
                first step in the list.
            display_imgs: If True, print each layer prefix and show its
                min-max normalized maps via ``show_tensors``.

        Returns:
            dict mapping layer prefix -> (source_sum, text_sum, target_sum).
            For ratio_source == 'text', text_sum is 512 minus the source and
            target sums. That is the remainder of 512 text query rows,
            assuming each attention row sums to 1 -- TODO confirm.

        Raises:
            ValueError: if ``self.ratio_source`` is neither 'pixels' nor 'text'.
        """
        if step_indices is None:
            step_indices = list(self.attention_ratios.keys())

        if len(step_indices) == 1:
            steps = f"Step: {step_indices[0]}"
        else:
            steps = f"Steps: [{step_indices[0]}-{step_indices[-1]}]"

        layer_prefixes = list(self.attention_ratios[step_indices[0]].keys())
        scores_per_layer = {}
        for layer_prefix in layer_prefixes:
            # Mean over the layers within each step, then mean over the steps.
            per_step = [
                torch.stack(self.attention_ratios[step_index][layer_prefix]).mean(dim=0)
                for step_index in step_indices
                if layer_prefix in self.attention_ratios[step_index]
            ]
            ratios = torch.stack(per_step).mean(dim=0)

            if self.ratio_source == 'pixels':
                source, text, target, token = torch.split(ratios, 64, dim=1)
                source_sum = source.sum().item()
                text_sum = text.sum().item()
                target_sum = target.sum().item()
                title = f"{steps}: Source={source_sum:.2f}, Text={text_sum:.2f}, Target={target_sum:.2f}, Token={token.sum().item():.2f}"
                # Region maps share one normalization; the token map gets its own.
                ratios = torch.cat([min_max_norm(torch.cat([source, text, target], dim=1)), min_max_norm(token)], dim=1)
            elif self.ratio_source == 'text':
                source, target = torch.split(ratios, 64, dim=1)
                source_sum = source.sum().item()
                target_sum = target.sum().item()
                text_sum = 512 - (source_sum + target_sum)
                title = f"{steps}: Source={source_sum:.2f}, Target={target_sum:.2f}"
                ratios = min_max_norm(torch.cat([source, target], dim=1))
            else:
                raise ValueError(f"Unknown ratio_source: {self.ratio_source!r}; expected 'pixels' or 'text'.")

            if display_imgs:
                print(f"Layer: {layer_prefix}")
                show_tensors([ratios], [title])
            scores_per_layer[layer_prefix] = (source_sum, text_sum, target_sum)
        return scores_per_layer

    def plot_attention_ratios(self, step_indices=None):
        """Plot stacked source/text/target bars, one bar per step.

        A separate figure is drawn for each layer type. Each bar segment is
        annotated with its percentage of the bar total.

        Args:
            step_indices: Steps to plot. Defaults to every recorded step.
        """
        steps = list(self.attention_ratios.keys()) if step_indices is None else list(step_indices)

        # layer type -> {step index: (source, text, target)}
        sums_by_layer_type = {
            'transformer_blocks': {},
            'single_transformer_blocks': {}
        }
        for step in steps:
            step_scores = self.get_attention_ratios(step_indices=[step], display_imgs=False)
            for layer in self.attention_ratios[step]:
                source, text, target = step_scores[layer]
                sums_by_layer_type.setdefault(layer, {})[step] = (source, text, target)

        font_size = 12
        for layer_type, per_step in sums_by_layer_type.items():
            step_labels = list(per_step.keys())
            source_sums = [s[0] for s in per_step.values()]
            text_sums = [s[1] for s in per_step.values()]
            target_sums = [s[2] for s in per_step.values()]
            total_sums = [s + t + g for s, t, g in zip(source_sums, text_sums, target_sums)]
            target_bottom = [s + t for s, t in zip(source_sums, text_sums)]

            fig, ax = plt.subplots(figsize=(10, 6))
            indices = np.arange(len(step_labels))
            # Source at the bottom, text stacked on it, target on top.
            ax.bar(indices, source_sums, label='Source', color='#6A2C70')
            ax.bar(indices, text_sums, label='Text', color='#B83B5E', bottom=source_sums)
            ax.bar(indices, target_sums, label='Target', color='#F08A5D', bottom=target_bottom)

            for j, index in enumerate(indices):
                # (segment value, vertical centre of that segment)
                segments = (
                    (source_sums[j], source_sums[j] / 2),
                    (text_sums[j], source_sums[j] + (text_sums[j] / 2)),
                    (target_sums[j], target_bottom[j] + (target_sums[j] / 2)),
                )
                for value, y_center in segments:
                    # Guard against an all-zero bar (would divide by zero).
                    percentage = 100 * value / total_sums[j] if total_sums[j] else 0.0
                    ax.text(index, y_center, f'{percentage:.1f}%',
                            ha='center', va='center', rotation=90, color='white',
                            fontsize=font_size, fontweight='bold')

            ax.set_xlabel('Step Index')
            ax.set_ylabel('Attention Ratio')
            ax.set_title(f'Attention Ratios for {layer_type}')
            ax.set_xticks(indices)
            ax.set_xticklabels(step_labels)
            ax.legend()
            plt.show()

    @staticmethod
    def _select_threshold(np_vals, thr_type, thr_number):
        """Return the binarization threshold for one normalized map.

        Only the requested threshold is computed.
        """
        if thr_type == 'number':
            return thr_number
        if thr_type == 'yen':
            return threshold_yen(np_vals)
        if thr_type == 'multiotsu':
            thrs = threshold_multiotsu(np_vals, classes=3)
            # Heuristic: use the upper of the two 3-class thresholds only when
            # it exceeds 3.5x the lower one; otherwise use the lower one.
            return thrs[1] if thrs[1] > thrs[0] * 3.5 else thrs[0]
        otsu_thr = filters.threshold_otsu(np_vals)
        if thr_type == 'li':
            # Li threshold, seeded with the Otsu threshold.
            return threshold_li(np_vals, initial_guess=otsu_thr)
        return otsu_thr

    def aggregate_attention(self, store, target_layers=None, resolution=None,
                            gaussian_kernel=3, thr_type='otsu', thr_number=0.5):
        """Average stored attention over layers into spatial maps and masks.

        Args:
            store: A store filled by :meth:`store_attention` (e.g.
                ``self.text2image_store``). It maps layer name -> list of
                tensors shaped (tokens, pixels); store_attention stores a
                single-element list.
            target_layers: Layer names to average. None means every layer in
                ``store``.
            resolution: (H, W) to reshape the pixel axis into. None infers a
                square resolution from the pixel count.
            gaussian_kernel: Kernel size for ``gaussian_blur``. Values <= 0
                disable blurring.
            thr_type: 'otsu', 'yen', 'li' (Li seeded with Otsu), 'multiotsu'
                (3-class heuristic) or 'number' (use ``thr_number``).
            thr_number: Fixed threshold used when ``thr_type == 'number'``.
                It applies to maps already min-max normalized to [0, 1].

        Returns:
            (attention_maps, attention_masks, self.tokens_to_record).
            Each list has one tensor per list entry, shaped (tokens, H, W).
            Each mask row is its normalized map compared with ``>`` against
            the threshold. A constant map (zero range) yields an all-zero mask.

        Raises:
            ValueError: for an unsupported ``target_layers`` or ``thr_type``.
        """
        if target_layers is None:
            store_vals = list(store.values())
        elif isinstance(target_layers, list):
            store_vals = [store[x] for x in target_layers]
        else:
            raise ValueError("target_layers must be a list of layer names or None.")
        if thr_type not in ('otsu', 'yen', 'li', 'number', 'multiotsu'):
            raise ValueError(f"Unknown thr_type: {thr_type!r}.")

        # store_vals = List[layers] of List[entries] of Tensor[tokens, pixels]
        batch_size = len(store_vals[0])
        attention_maps = []
        attention_masks = []
        for i in range(batch_size):
            # Average over the layers
            agg_vals = torch.stack([layer_vals[i] for layer_vals in store_vals]).mean(dim=0)
            if resolution is None:
                size = int(agg_vals.shape[-1] ** 0.5)
                resolution = (size, size)
            agg_vals = agg_vals.view(agg_vals.shape[0], *resolution)
            if gaussian_kernel > 0:
                agg_vals = torch.stack(
                    [gaussian_blur(x.float(), kernel_size=gaussian_kernel) for x in agg_vals]
                ).to(agg_vals.dtype)

            mask_vals = agg_vals.clone()
            for j in range(mask_vals.shape[0]):
                value_range = mask_vals[j].max() - mask_vals[j].min()
                if value_range == 0:
                    # Nothing to separate; avoid 0/0 -> NaN.
                    mask_vals[j] = 0
                    continue
                mask_vals[j] = (mask_vals[j] - mask_vals[j].min()) / value_range
                thr = self._select_threshold(mask_vals[j].float().cpu().numpy(), thr_type, thr_number)
                mask_vals[j] = (mask_vals[j] > thr).to(mask_vals[j].dtype)
            attention_maps.append(agg_vals)
            attention_masks.append(mask_vals)
        return attention_maps, attention_masks, self.tokens_to_record