# Copyright (C) 2025 NVIDIA Corporation.  All rights reserved.
#
# This work is licensed under the LICENSE file
# located at the root directory.

import torch
from skimage import filters
import cv2
import torch.nn.functional as F
from skimage.filters import threshold_li, threshold_yen, threshold_multiotsu
import numpy as np
from visualization_utils import show_tensors
import matplotlib.pyplot as plt

def text_to_tokens(text, tokenizer):
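    """Tokenize `text` and return the decoded string of each resulting token id."""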
    return [tokenizer.decode(x) for x in tokenizer(text, padding="longest", return_tensors="pt").input_ids[0]]

def flatten_list(l):
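    """Flatten a list of lists into a single flat list."""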
    return [item for sublist in l for item in sublist]

def gaussian_blur(heatmap, kernel_size=7, sigma=0):
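    """Apply an OpenCV Gaussian blur to a 2D heatmap tensor and return it as a CPU tensor."""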
    # Shape of heatmap: (H, W)
    heatmap = heatmap.cpu().numpy()
    heatmap = cv2.GaussianBlur(heatmap, (kernel_size, kernel_size), sigma)
    heatmap = torch.tensor(heatmap)
    
    return heatmap

def min_max_norm(x):
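    """Linearly rescale `x` to the [0, 1] range."""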
    return (x - x.min()) / (x.max() - x.min())

class AttentionStore:
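    """Accumulates cross-attention statistics during sampling.

    Keeps text->image and image->text attention maps for the subject tokens
    (averaged over heads and summed over the recorded steps), plus per-step
    attention ratios between source image, text, and target image tokens.
    The sequence layout assumes 512 text tokens followed by the image tokens.
    """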
    def __init__(self, prompts, tokenizer, 
                 subject_token=None, record_attention_steps=[],
                 is_cache_attn_ratio=False, attn_ratios_steps=[5]):
        
        self.text2image_store = {}
        self.image2text_store = {}
        self.count_per_layer = {}

        self.record_attention_steps = record_attention_steps
        self.record_attention_layers = ["transformer_blocks.13","transformer_blocks.14", "transformer_blocks.18", "single_transformer_blocks.23", "single_transformer_blocks.33"]

        self.attention_ratios = {}
        self._is_cache_attn_ratio = is_cache_attn_ratio
        self.attn_ratios_steps = attn_ratios_steps
        self.ratio_source = 'text'

        self.max_tokens_to_record = 10

        if isinstance(prompts, str):
            prompts = [prompts]
            batch_size = 1
        else:
            batch_size = len(prompts)

        tokens_per_prompt = []

        for prompt in prompts:
            tokens = text_to_tokens(prompt, tokenizer)
            tokens_per_prompt.append(tokens)

        self.tokens_to_record = []
        self.token_idxs_to_record = []

        if len(record_attention_steps) > 0:
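            # Subject tokens are looked up in the second prompt (index 1); the trailing special token is dropped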
            self.subject_tokens = flatten_list([text_to_tokens(x, tokenizer)[:-1] for x in [subject_token]])
            self.subject_tokens_idx = [tokens_per_prompt[1].index(x) for x in self.subject_tokens]
            self.add_token_idx = self.subject_tokens_idx[-1]

    def is_record_attention(self, layer_name, step_index):
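        """Return True if attention should be recorded for this layer at this denoising step."""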
        is_correct_layer = (self.record_attention_layers is None) or (layer_name in self.record_attention_layers)

        record_attention = (step_index in self.record_attention_steps) and is_correct_layer

        return record_attention

    def store_attention(self, attention_probs, layer_name, batch_size, num_heads):
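        """Accumulate head-averaged text->image and image->text attention for the subject tokens.

        `attention_probs` is expected to have shape (batch * heads, seq, seq),
        with the first 512 positions holding the text tokens and the remaining
        positions the image tokens.
        """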
        text_len = 512
        timesteps = len(self.record_attention_steps)
        
        # Split batch and heads
        attention_probs = attention_probs.view(batch_size, num_heads, *attention_probs.shape[1:])

        # Mean over the heads
        attention_probs = attention_probs.mean(dim=1)

        # Attention: text -> image
        attention_probs_text2image = attention_probs[:, :text_len, text_len:]
        attention_probs_text2image = [attention_probs_text2image[0, self.subject_tokens_idx, :]]

        # Attention: image -> text
        attention_probs_image2text = attention_probs[:, text_len:, :text_len].transpose(1,2)
        attention_probs_image2text = [attention_probs_image2text[0, self.subject_tokens_idx, :]]

        if layer_name not in self.text2image_store:
            self.text2image_store[layer_name] = [x for x in attention_probs_text2image]
            self.image2text_store[layer_name] = [x for x in attention_probs_image2text]
        else:
            self.text2image_store[layer_name] = [self.text2image_store[layer_name][i] + x for i, x in enumerate(attention_probs_text2image)]
            self.image2text_store[layer_name] = [self.image2text_store[layer_name][i] + x for i, x in enumerate(attention_probs_image2text)]
    
    def is_cache_attn_ratio(self, step_index):
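        """Return True if attention ratios should be cached at this denoising step."""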
        return (self._is_cache_attn_ratio) and (step_index in self.attn_ratios_steps)
    
    def store_attention_ratios(self, attention_probs, step_index, layer_name):
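        """Cache per-pixel attention ratios for one layer at one step.

        The key dimension is assumed to be laid out as [source image (4096),
        text (512), target image (4096)]. With `ratio_source == 'text'` the 512
        text queries are used; with `'pixels'` the target image queries.
        """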
        layer_prefix = layer_name.split(".")[0]
        
        if self.ratio_source == 'pixels':
            extended_attention_probs = attention_probs.mean(dim=0)[512:, :]
            extended_attention_probs_source = extended_attention_probs[:,:4096].sum(dim=1).view(64,64).float().cpu()
            extended_attention_probs_text = extended_attention_probs[:,4096:4096+512].sum(dim=1).view(64,64).float().cpu()
            extended_attention_probs_target = extended_attention_probs[:,4096+512:].sum(dim=1).view(64,64).float().cpu()
            token_attention = extended_attention_probs[:,4096+self.add_token_idx].view(64,64).float().cpu()

            stacked_attention_ratios = torch.cat([extended_attention_probs_source, extended_attention_probs_text, extended_attention_probs_target, token_attention], dim=1)
        elif self.ratio_source == 'text':
            extended_attention_probs = attention_probs.mean(dim=0)[:512, :]
            extended_attention_probs_source = extended_attention_probs[:,:4096].sum(dim=0).view(64,64).float().cpu()
            extended_attention_probs_target = extended_attention_probs[:,4096+512:].sum(dim=0).view(64,64).float().cpu()

            stacked_attention_ratios = torch.cat([extended_attention_probs_source, extended_attention_probs_target], dim=1)

        if step_index not in self.attention_ratios:
            self.attention_ratios[step_index] = {}

        if layer_prefix not in self.attention_ratios[step_index]:
            self.attention_ratios[step_index][layer_prefix] = []

        self.attention_ratios[step_index][layer_prefix].append(stacked_attention_ratios)

    def get_attention_ratios(self, step_indices=None, display_imgs=False):
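        """Average the cached ratios over `step_indices` and return per-layer (source, text, target) attention sums."""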
        ratios = []

        if step_indices is None:
            step_indices = list(self.attention_ratios.keys())

        if len(step_indices) == 1:
            steps = f"Step: {step_indices[0]}"
        else:
            steps = f"Steps: [{step_indices[0]}-{step_indices[-1]}]"

        layer_prefixes = list(self.attention_ratios[step_indices[0]].keys())
        scores_per_layer = {}
        
        for layer_prefix in layer_prefixes:
            ratios = []

            for step_index in step_indices:
                if layer_prefix in self.attention_ratios[step_index]:
                    step_ratios = self.attention_ratios[step_index][layer_prefix]
                    step_ratios = torch.stack(step_ratios).mean(dim=0)
                    ratios.append(step_ratios)
            
            # Mean over the steps
            ratios = torch.stack(ratios).mean(dim=0)

            if self.ratio_source == 'pixels':
                source, text, target, token = torch.split(ratios, 64, dim=1)
                source_sum = source.sum().item()
                text_sum = text.sum().item()
                target_sum = target.sum().item()
                token_sum = token.sum().item()
                title = f"{steps}: Source={source_sum:.2f}, Text={text_sum:.2f}, Target={target_sum:.2f}, Token={token_sum:.2f}"
                ratios = min_max_norm(torch.cat([source, text, target], dim=1))
                token = min_max_norm(token)
                ratios = torch.cat([ratios, token], dim=1)
            elif self.ratio_source == 'text':
                source, target = torch.split(ratios, 64, dim=1)
                source_sum = source.sum().item()
                target_sum = target.sum().item()
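                # Each of the 512 text queries sums to 1, so the text->text mass is the remainder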
                text_sum = 512 - (source_sum + target_sum)

                title = f"{steps}: Source={source_sum:.2f}, Target={target_sum:.2f}"
                ratios = min_max_norm(torch.cat([source, target], dim=1))
            
            if display_imgs:
                print(f"Layer: {layer_prefix}")
                show_tensors([ratios], [title])

            scores_per_layer[layer_prefix] = (source_sum, text_sum, target_sum)

        return scores_per_layer

    def plot_attention_ratios(self, step_indices=None):
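        """Plot, for every cached step, a stacked bar of the source/text/target attention ratios per layer type."""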
        steps = list(self.attention_ratios.keys())
        score_per_layer = {
            'transformer_blocks': {},
            'single_transformer_blocks': {}
        }

        for i in steps:
            scores_per_layer = self.get_attention_ratios(step_indices=[i], display_imgs=False)

            for layer in self.attention_ratios[i]:
                source, text, target = scores_per_layer[layer]
                score_per_layer[layer][i] = (source, text, target)

        for layer_type in score_per_layer:
            x = list(score_per_layer[layer_type].keys())
            source_sums = [v[0] for v in score_per_layer[layer_type].values()]
            text_sums = [v[1] for v in score_per_layer[layer_type].values()]
            target_sums = [v[2] for v in score_per_layer[layer_type].values()]

            # Calculate the total sums for each stack (source + text + target)
            total_sums = [source_sums[j] + text_sums[j] + target_sums[j] for j in range(len(source_sums))]

            # Create stacked bar plots
            fig, ax = plt.subplots(figsize=(10, 6))
            indices = np.arange(len(x))

            # Plot source at the bottom
            ax.bar(indices, source_sums, label='Source', color='#6A2C70')

            # Plot text stacked on source
            ax.bar(indices, text_sums, label='Text', color='#B83B5E', bottom=source_sums)

            # Plot target stacked on text + source
            target_bottom = [source_sums[j] + text_sums[j] for j in range(len(source_sums))]
            ax.bar(indices, target_sums, label='Target', color='#F08A5D', bottom=target_bottom)

            # Annotate bars with percentage values
            for j, index in enumerate(indices):

                font_size = 12

                # Source percentage
                source_percentage = 100 * source_sums[j] / total_sums[j]
                ax.text(index, source_sums[j] / 2, f'{source_percentage:.1f}%', 
                        ha='center', va='center', rotation=90, color='white', 
                        fontsize=font_size, fontweight='bold')

                # Text percentage
                text_percentage = 100 * text_sums[j] / total_sums[j]
                ax.text(index, source_sums[j] + (text_sums[j] / 2), f'{text_percentage:.1f}%', 
                        ha='center', va='center', rotation=90, color='white', 
                        fontsize=font_size, fontweight='bold')

                # Target percentage
                target_percentage = 100 * target_sums[j] / total_sums[j]
                ax.text(index, source_sums[j] + text_sums[j] + (target_sums[j] / 2), f'{target_percentage:.1f}%', 
                        ha='center', va='center', rotation=90, color='white', 
                        fontsize=font_size, fontweight='bold')


            ax.set_xlabel('Step Index')
            ax.set_ylabel('Attention Ratio')
            ax.set_title(f'Attention Ratios for {layer_type}')
            ax.set_xticks(indices)
            ax.set_xticklabels(x)

            plt.legend()
            plt.show()

    def aggregate_attention(self, store, target_layers=None, resolution=None,
                            gaussian_kernel=3, thr_type='otsu', thr_number=0.5):
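        """Aggregate stored attention maps over layers and threshold them into binary masks.

        Returns per-batch lists of (optionally blurred) attention maps, their
        thresholded masks, and the recorded tokens. The threshold is chosen via
        `thr_type` (Otsu, Yen, Li, multi-Otsu, or a fixed number).
        """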
        if target_layers is None:
            store_vals = list(store.values())
        elif isinstance(target_layers, list):
            store_vals = [store[x] for x in target_layers]
        else:
            raise ValueError("target_layers must be a list of layer names or None.")

        # store vals = List[layers] of Tensor[batch_size, text_tokens, image_tokens]
        batch_size = len(store_vals[0])
        
        attention_maps = []
        attention_masks = []

        for i in range(batch_size):
            # Average over the layers
            agg_vals = torch.stack([x[i] for x in store_vals]).mean(dim=0)

            if resolution is None:
                size = int(agg_vals.shape[-1] ** 0.5)
                resolution = (size, size)
            
            agg_vals = agg_vals.view(agg_vals.shape[0], *resolution)

            if gaussian_kernel > 0:
                agg_vals = torch.stack([gaussian_blur(x.float(), kernel_size=gaussian_kernel) for x in agg_vals]).to(agg_vals.dtype)

            mask_vals = agg_vals.clone()

            for j in range(mask_vals.shape[0]):
                mask_vals[j] = (mask_vals[j] - mask_vals[j].min()) / (mask_vals[j].max() - mask_vals[j].min())
                np_vals = mask_vals[j].float().cpu().numpy()

                otsu_thr = filters.threshold_otsu(np_vals)
                li_thr = threshold_li(np_vals, initial_guess=otsu_thr)
                yen_thr = threshold_yen(np_vals)

                if thr_type == 'otsu':
                    thr = otsu_thr
                elif thr_type == 'yen':
                    thr = yen_thr
                elif thr_type == 'li':
                    thr = li_thr
                elif thr_type == 'number':
                    thr = thr_number
                elif thr_type == 'multiotsu':
                    thrs = threshold_multiotsu(np_vals, classes=3)

                    if thrs[1] > thrs[0] * 3.5:
                        thr = thrs[1]
                    else:
                        thr = thrs[0]

                    # Take the closest threshold to otsu_thr
                    # thr = thrs[np.argmin(np.abs(thrs - otsu_thr))]
                
                # alpha = 0.8
                # thr  = (alpha * thr + (1-alpha) * mask_vals[j].max())
                
                mask_vals[j] = (mask_vals[j] > thr).to(mask_vals[j].dtype)

            attention_maps.append(agg_vals)
            attention_masks.append(mask_vals)

        return attention_maps, attention_masks, self.tokens_to_record