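"""Attention recording utilities.

Provides an AttentionStore that accumulates text<->image attention maps and
per-region attention ratios from a diffusion transformer during sampling, plus
helpers for smoothing, normalizing, thresholding and plotting those maps."""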
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from skimage import filters
from skimage.filters import threshold_li, threshold_multiotsu, threshold_yen

from visualization_utils import show_tensors
|
|
|
def text_to_tokens(text, tokenizer): |
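    """Tokenize `text` and decode each token id back into its string piece."""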
|
return [tokenizer.decode(x) for x in tokenizer(text, padding="longest", return_tensors="pt").input_ids[0]] |
|
|
|
def flatten_list(l): |
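    """Flatten one level of nesting, e.g. [[a, b], [c]] -> [a, b, c]."""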
|
return [item for sublist in l for item in sublist] |
|
|
|
def gaussian_blur(heatmap, kernel_size=7, sigma=0):
    """Smooth a 2D heatmap with a Gaussian kernel (sigma=0 lets OpenCV derive sigma from the kernel size)."""
    device = heatmap.device
    blurred = cv2.GaussianBlur(heatmap.cpu().numpy(), (kernel_size, kernel_size), sigma)
    return torch.from_numpy(blurred).to(device)
|
|
|
def min_max_norm(x):
    """Rescale a tensor to [0, 1]; the epsilon guards against division by zero for constant inputs."""
    return (x - x.min()) / (x.max() - x.min() + 1e-8)
|
|
|
class AttentionStore: |
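    """Records text<->image attention maps and per-region attention ratios during sampling."""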
|
def __init__(self, prompts, tokenizer, |
|
subject_token=None, record_attention_steps=[], |
|
is_cache_attn_ratio=False, attn_ratios_steps=[5]): |
|
|
|
self.text2image_store = {} |
|
self.image2text_store = {} |
|
self.count_per_layer = {} |
|
|
|
self.record_attention_steps = record_attention_steps |
|
        self.record_attention_layers = [
            "transformer_blocks.13", "transformer_blocks.14", "transformer_blocks.18",
            "single_transformer_blocks.23", "single_transformer_blocks.33",
        ]
|
|
|
self.attention_ratios = {} |
|
self._is_cache_attn_ratio = is_cache_attn_ratio |
|
self.attn_ratios_steps = attn_ratios_steps |
|
self.ratio_source = 'text' |
|
|
|
self.max_tokens_to_record = 10 |
|
|
|
if isinstance(prompts, str): |
|
prompts = [prompts] |
|
batch_size = 1 |
|
else: |
|
batch_size = len(prompts) |
|
|
|
tokens_per_prompt = [] |
|
|
|
for prompt in prompts: |
|
tokens = text_to_tokens(prompt, tokenizer) |
|
tokens_per_prompt.append(tokens) |
|
|
|
self.tokens_to_record = [] |
|
self.token_idxs_to_record = [] |
|
|
|
        if len(record_attention_steps) > 0:
            # Tokenize the subject phrase (dropping its final, typically EOS, token) and find
            # the positions of those tokens inside the second prompt.
            self.subject_tokens = flatten_list([text_to_tokens(x, tokenizer)[:-1] for x in [subject_token]])
            self.subject_tokens_idx = [tokens_per_prompt[1].index(x) for x in self.subject_tokens]
            self.add_token_idx = self.subject_tokens_idx[-1]
|
|
|
def is_record_attention(self, layer_name, step_index): |
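        """Return True if attention should be recorded for this layer at this denoising step."""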
|
is_correct_layer = (self.record_attention_layers is None) or (layer_name in self.record_attention_layers) |
|
|
|
record_attention = (step_index in self.record_attention_steps) and (is_correct_layer) |
|
|
|
return record_attention |
|
|
|
def store_attention(self, attention_probs, layer_name, batch_size, num_heads): |
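        """Accumulate the subject-token rows of the text->image and image->text attention blocks for a layer."""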
|
        text_len = 512  # number of text tokens at the start of the joint attention sequence
|
|
|
|
|
        # Un-flatten the head dimension, then average attention over heads.
        attention_probs = attention_probs.view(batch_size, num_heads, *attention_probs.shape[1:])
        attention_probs = attention_probs.mean(dim=1)
|
|
|
|
|
        # Text -> image block: rows are text tokens, columns are image tokens.
        attention_probs_text2image = attention_probs[:, :text_len, text_len:]
        attention_probs_text2image = [attention_probs_text2image[0, self.subject_tokens_idx, :]]

        # Image -> text block, transposed so that rows are text tokens again.
        attention_probs_image2text = attention_probs[:, text_len:, :text_len].transpose(1, 2)
        attention_probs_image2text = [attention_probs_image2text[0, self.subject_tokens_idx, :]]
|
|
|
        if layer_name not in self.text2image_store:
            self.text2image_store[layer_name] = [x for x in attention_probs_text2image]
            self.image2text_store[layer_name] = [x for x in attention_probs_image2text]
        else:
            # Accumulate maps across recorded steps, keeping the two stores separate.
            self.text2image_store[layer_name] = [self.text2image_store[layer_name][i] + x for i, x in enumerate(attention_probs_text2image)]
            self.image2text_store[layer_name] = [self.image2text_store[layer_name][i] + x for i, x in enumerate(attention_probs_image2text)]
|
|
|
def is_cache_attn_ratio(self, step_index): |
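        """Return True if attention ratios should be cached at this denoising step."""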
|
return (self._is_cache_attn_ratio) and (step_index in self.attn_ratios_steps) |
|
|
|
def store_attention_ratios(self, attention_probs, step_index, layer_name): |
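        """Cache the attention mass that queries assign to the source image, text and target image as 64x64 maps."""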
|
layer_prefix = layer_name.split(".")[0] |
|
|
|
        if self.ratio_source == 'pixels':
            # Image-token queries. Key columns are split into source image ([:4096]),
            # text ([4096:4096+512]) and target image (the rest); each per-query sum is
            # reshaped to a 64x64 map, and the subject token's column is kept separately.
            extended_attention_probs = attention_probs.mean(dim=0)[512:, :]
            extended_attention_probs_source = extended_attention_probs[:, :4096].sum(dim=1).view(64, 64).float().cpu()
            extended_attention_probs_text = extended_attention_probs[:, 4096:4096 + 512].sum(dim=1).view(64, 64).float().cpu()
            extended_attention_probs_target = extended_attention_probs[:, 4096 + 512:].sum(dim=1).view(64, 64).float().cpu()
            token_attention = extended_attention_probs[:, 4096 + self.add_token_idx].view(64, 64).float().cpu()

            stacked_attention_ratios = torch.cat([extended_attention_probs_source, extended_attention_probs_text, extended_attention_probs_target, token_attention], dim=1)
        elif self.ratio_source == 'text':
            # Text-token queries: sum over the 512 text rows for the source-image and
            # target-image key columns, each reshaped to a 64x64 spatial map.
            extended_attention_probs = attention_probs.mean(dim=0)[:512, :]
            extended_attention_probs_source = extended_attention_probs[:, :4096].sum(dim=0).view(64, 64).float().cpu()
            extended_attention_probs_target = extended_attention_probs[:, 4096 + 512:].sum(dim=0).view(64, 64).float().cpu()

            stacked_attention_ratios = torch.cat([extended_attention_probs_source, extended_attention_probs_target], dim=1)
|
|
|
if step_index not in self.attention_ratios: |
|
self.attention_ratios[step_index] = {} |
|
|
|
if layer_prefix not in self.attention_ratios[step_index]: |
|
self.attention_ratios[step_index][layer_prefix] = [] |
|
|
|
self.attention_ratios[step_index][layer_prefix].append(stacked_attention_ratios) |
|
|
|
def get_attention_ratios(self, step_indices=None, display_imgs=False): |
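        """Average the cached ratio maps over the requested steps and return per-layer-type (source, text, target) attention sums."""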
|
|
|
|
if step_indices is None: |
|
step_indices = list(self.attention_ratios.keys()) |
|
|
|
if len(step_indices) == 1: |
|
steps = f"Step: {step_indices[0]}" |
|
else: |
|
steps = f"Steps: [{step_indices[0]}-{step_indices[-1]}]" |
|
|
|
layer_prefixes = list(self.attention_ratios[step_indices[0]].keys()) |
|
scores_per_layer = {} |
|
|
|
for layer_prefix in layer_prefixes: |
|
ratios = [] |
|
|
|
for step_index in step_indices: |
|
if layer_prefix in self.attention_ratios[step_index]: |
|
step_ratios = self.attention_ratios[step_index][layer_prefix] |
|
step_ratios = torch.stack(step_ratios).mean(dim=0) |
|
ratios.append(step_ratios) |
|
|
|
|
|
ratios = torch.stack(ratios).mean(dim=0) |
|
|
|
            if self.ratio_source == 'pixels':
                source, text, target, token = torch.split(ratios, 64, dim=1)
                source_sum, text_sum, target_sum = source.sum().item(), text.sum().item(), target.sum().item()
                title = f"{steps}: Source={source_sum:.2f}, Text={text_sum:.2f}, Target={target_sum:.2f}, Token={token.sum().item():.2f}"
                ratios = min_max_norm(torch.cat([source, text, target], dim=1))
                token = min_max_norm(token)
                ratios = torch.cat([ratios, token], dim=1)
|
elif self.ratio_source == 'text': |
|
source, target = torch.split(ratios, 64, dim=1) |
|
                source_sum = source.sum().item()
                target_sum = target.sum().item()
                # Each of the 512 text rows sums to 1, so the remaining mass is text-to-text attention.
                text_sum = 512 - (source_sum + target_sum)
|
|
|
title = f"{steps}: Source={source_sum:.2f}, Target={target_sum:.2f}" |
|
ratios = min_max_norm(torch.cat([source, target], dim=1)) |
|
|
|
if display_imgs: |
|
print(f"Layer: {layer_prefix}") |
|
show_tensors([ratios], [title]) |
|
|
|
scores_per_layer[layer_prefix] = (source_sum, text_sum, target_sum) |
|
|
|
return scores_per_layer |
|
|
|
def plot_attention_ratios(self, step_indices=None): |
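        """Plot stacked bar charts of the source/text/target attention split per step, one figure per block type."""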
|
        steps = step_indices if step_indices is not None else list(self.attention_ratios.keys())
|
score_per_layer = { |
|
'transformer_blocks': {}, |
|
'single_transformer_blocks': {} |
|
} |
|
|
|
for i in steps: |
|
scores_per_layer = self.get_attention_ratios(step_indices=[i], display_imgs=False) |
|
|
|
for layer in self.attention_ratios[i]: |
|
source, text, target = scores_per_layer[layer] |
|
score_per_layer[layer][i] = (source, text, target) |
|
|
|
for layer_type in score_per_layer: |
|
            step_keys = list(score_per_layer[layer_type].keys())
            source_sums = [v[0] for v in score_per_layer[layer_type].values()]
            text_sums = [v[1] for v in score_per_layer[layer_type].values()]
            target_sums = [v[2] for v in score_per_layer[layer_type].values()]
|
|
|
|
|
total_sums = [source_sums[j] + text_sums[j] + target_sums[j] for j in range(len(source_sums))] |
|
|
|
|
|
fig, ax = plt.subplots(figsize=(10, 6)) |
|
            indices = np.arange(len(step_keys))
|
|
|
|
|
ax.bar(indices, source_sums, label='Source', color='#6A2C70') |
|
|
|
|
|
ax.bar(indices, text_sums, label='Text', color='#B83B5E', bottom=source_sums) |
|
|
|
|
|
target_bottom = [source_sums[j] + text_sums[j] for j in range(len(source_sums))] |
|
ax.bar(indices, target_sums, label='Target', color='#F08A5D', bottom=target_bottom) |
|
|
|
|
|
for j, index in enumerate(indices): |
|
|
|
font_size = 12 |
|
|
|
|
|
source_percentage = 100 * source_sums[j] / total_sums[j] |
|
ax.text(index, source_sums[j] / 2, f'{source_percentage:.1f}%', |
|
ha='center', va='center', rotation=90, color='white', |
|
fontsize=font_size, fontweight='bold') |
|
|
|
|
|
text_percentage = 100 * text_sums[j] / total_sums[j] |
|
ax.text(index, source_sums[j] + (text_sums[j] / 2), f'{text_percentage:.1f}%', |
|
ha='center', va='center', rotation=90, color='white', |
|
fontsize=font_size, fontweight='bold') |
|
|
|
|
|
target_percentage = 100 * target_sums[j] / total_sums[j] |
|
ax.text(index, source_sums[j] + text_sums[j] + (target_sums[j] / 2), f'{target_percentage:.1f}%', |
|
ha='center', va='center', rotation=90, color='white', |
|
fontsize=font_size, fontweight='bold') |
|
|
|
|
|
ax.set_xlabel('Step Index') |
|
ax.set_ylabel('Attention Ratio') |
|
ax.set_title(f'Attention Ratios for {layer_type}') |
|
ax.set_xticks(indices) |
|
            ax.set_xticklabels(step_keys)
|
|
|
plt.legend() |
|
plt.show() |
|
|
|
def aggregate_attention(self, store, target_layers=None, resolution=None, |
|
gaussian_kernel=3, thr_type='otsu', thr_number=0.5): |
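        """Aggregate stored attention maps across layers, optionally blur them, and binarize each map with the chosen threshold."""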
|
if target_layers is None: |
|
store_vals = list(store.values()) |
|
elif isinstance(target_layers, list): |
|
store_vals = [store[x] for x in target_layers] |
|
else: |
|
raise ValueError("target_layers must be a list of layer names or None.") |
|
|
|
|
|
batch_size = len(store_vals[0]) |
|
|
|
attention_maps = [] |
|
attention_masks = [] |
|
|
|
for i in range(batch_size): |
|
|
|
agg_vals = torch.stack([x[i] for x in store_vals]).mean(dim=0) |
|
|
|
if resolution is None: |
|
size = int(agg_vals.shape[-1] ** 0.5) |
|
resolution = (size, size) |
|
|
|
agg_vals = agg_vals.view(agg_vals.shape[0], *resolution) |
|
|
|
if gaussian_kernel > 0: |
|
agg_vals = torch.stack([gaussian_blur(x.float(), kernel_size=gaussian_kernel) for x in agg_vals]).to(agg_vals.dtype) |
|
|
|
mask_vals = agg_vals.clone() |
|
|
|
for j in range(mask_vals.shape[0]): |
|
mask_vals[j] = (mask_vals[j] - mask_vals[j].min()) / (mask_vals[j].max() - mask_vals[j].min()) |
|
np_vals = mask_vals[j].float().cpu().numpy() |
|
|
|
otsu_thr = filters.threshold_otsu(np_vals) |
|
li_thr = threshold_li(np_vals, initial_guess=otsu_thr) |
|
yen_thr = threshold_yen(np_vals) |
|
|
|
if thr_type == 'otsu': |
|
thr = otsu_thr |
|
elif thr_type == 'yen': |
|
thr = yen_thr |
|
elif thr_type == 'li': |
|
thr = li_thr |
|
elif thr_type == 'number': |
|
thr = thr_number |
|
                elif thr_type == 'multiotsu':
                    # Multi-Otsu gives two thresholds; keep the higher one only when it is
                    # well separated from the lower, otherwise fall back to the lower threshold.
                    thrs = threshold_multiotsu(np_vals, classes=3)
                    if thrs[1] > thrs[0] * 3.5:
                        thr = thrs[1]
                    else:
                        thr = thrs[0]
                else:
                    raise ValueError(f"Unknown thr_type: {thr_type}")

                mask_vals[j] = (mask_vals[j] > thr).to(mask_vals[j].dtype)
|
|
|
attention_maps.append(agg_vals) |
|
attention_masks.append(mask_vals) |
|
|
|
return attention_maps, attention_masks, self.tokens_to_record |
|
|