AustingDong commited on
Commit
369c141
·
1 Parent(s): e788822

renew app, formatted visualization

Browse files
Files changed (3) hide show
  1. app-old.py +501 -0
  2. app.py +9 -9
  3. demo/visualization.py +524 -0
app-old.py ADDED
@@ -0,0 +1,501 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoConfig, AutoModelForCausalLM
4
+ from janus.models import MultiModalityCausalLM, VLChatProcessor
5
+ from janus.utils.io import load_pil_images
6
+ from demo.cam import generate_gradcam, AttentionGuidedCAMJanus, AttentionGuidedCAMClip, AttentionGuidedCAMChartGemma, AttentionGuidedCAMLLaVA
7
+ from demo.model_utils import Clip_Utils, Janus_Utils, LLaVA_Utils, ChartGemma_Utils, add_title_to_image
8
+
9
+ import numpy as np
10
+ import matplotlib.pyplot as plt
11
+ import gc
12
+ import os
13
+ import spaces
14
+ from PIL import Image
15
+
16
+ def set_seed(model_seed = 42):
17
+ torch.manual_seed(model_seed)
18
+ np.random.seed(model_seed)
19
+ torch.cuda.manual_seed(model_seed) if torch.cuda.is_available() else None
20
+
21
+ set_seed()
22
+ clip_utils = Clip_Utils()
23
+ clip_utils.init_Clip()
24
+ model_utils, vl_gpt, tokenizer = None, None, None
25
+ model_name = "Clip"
26
+ language_model_max_layer = 24
27
+ language_model_best_layer = 8
28
+ vision_model_best_layer = 24
29
+
30
+ def clean():
31
+ global model_utils, vl_gpt, tokenizer, clip_utils
32
+ # Move models to CPU first (prevents CUDA references)
33
+ if 'vl_gpt' in globals() and vl_gpt is not None:
34
+ vl_gpt.to("cpu")
35
+ if 'clip_utils' in globals() and clip_utils is not None:
36
+ del clip_utils
37
+
38
+ # Delete all references
39
+ del model_utils, vl_gpt, tokenizer
40
+ model_utils, vl_gpt, tokenizer, clip_utils = None, None, None, None
41
+ gc.collect()
42
+
43
+ # Empty CUDA cache
44
+ if torch.cuda.is_available():
45
+ torch.cuda.empty_cache()
46
+ torch.cuda.ipc_collect() # Frees inter-process CUDA memory
47
+
48
+ # Empty MacOS Metal backend (if using Apple Silicon)
49
+ if torch.backends.mps.is_available():
50
+ torch.mps.empty_cache()
51
+
52
+ # Multimodal Understanding function
53
+ @spaces.GPU(duration=120)
54
+ def multimodal_understanding(model_type,
55
+ activation_map_method,
56
+ visual_pooling_method,
57
+ image, question, seed, top_p, temperature, target_token_idx,
58
+ visualization_layer_min, visualization_layer_max, focus, response_type, chart_type):
59
+ # Clear CUDA cache before generating
60
+ gc.collect()
61
+ if torch.cuda.is_available():
62
+ torch.cuda.empty_cache()
63
+ torch.cuda.ipc_collect()
64
+
65
+ # set seed
66
+ torch.manual_seed(seed)
67
+ np.random.seed(seed)
68
+ torch.cuda.manual_seed(seed) if torch.cuda.is_available() else None
69
+
70
+ input_text_decoded = ""
71
+ answer = ""
72
+ if model_name == "Clip":
73
+
74
+ inputs = clip_utils.prepare_inputs([question], image)
75
+
76
+
77
+ if activation_map_method == "GradCAM":
78
+ # Generate Grad-CAM
79
+ all_layers = [layer.layer_norm1 for layer in clip_utils.model.vision_model.encoder.layers]
80
+
81
+ if visualization_layer_min != visualization_layer_max:
82
+ target_layers = all_layers[visualization_layer_min-1 : visualization_layer_max-1]
83
+ else:
84
+ target_layers = [all_layers[visualization_layer_min-1]]
85
+ grad_cam = AttentionGuidedCAMClip(clip_utils.model, target_layers)
86
+ cam, outputs, grid_size = grad_cam.generate_cam(inputs, class_idx=0, visual_pooling_method=visual_pooling_method)
87
+ cam = cam.to("cpu")
88
+ cam = [generate_gradcam(cam, image, size=(224, 224))]
89
+ grad_cam.remove_hooks()
90
+ target_token_decoded = ""
91
+
92
+
93
+
94
+ else:
95
+
96
+ for param in vl_gpt.parameters():
97
+ param.requires_grad = True
98
+
99
+
100
+ prepare_inputs = model_utils.prepare_inputs(question, image)
101
+
102
+ if response_type == "answer + visualization":
103
+ if model_name.split('-')[0] == "Janus":
104
+ inputs_embeds = model_utils.generate_inputs_embeddings(prepare_inputs)
105
+ outputs = model_utils.generate_outputs(inputs_embeds, prepare_inputs, temperature, top_p)
106
+ else:
107
+ outputs = model_utils.generate_outputs(prepare_inputs, temperature, top_p)
108
+
109
+ sequences = outputs.sequences.cpu().tolist()
110
+ answer = tokenizer.decode(sequences[0], skip_special_tokens=True)
111
+ attention_raw = outputs.attentions
112
+ print("answer generated")
113
+
114
+ input_ids = prepare_inputs.input_ids[0].cpu().tolist()
115
+ input_ids_decoded = [tokenizer.decode([input_ids[i]]) for i in range(len(input_ids))]
116
+
117
+ if activation_map_method == "GradCAM":
118
+ # target_layers = vl_gpt.vision_model.vision_tower.blocks
119
+ if focus == "Visual Encoder":
120
+ if model_name.split('-')[0] == "Janus":
121
+ all_layers = [block.norm1 for block in vl_gpt.vision_model.vision_tower.blocks]
122
+ else:
123
+ all_layers = [block.layer_norm1 for block in vl_gpt.vision_tower.vision_model.encoder.layers]
124
+ else:
125
+ all_layers = [layer.self_attn for layer in vl_gpt.language_model.model.layers]
126
+
127
+ print("layer values:", visualization_layer_min, visualization_layer_max)
128
+ if visualization_layer_min != visualization_layer_max:
129
+ print("multi layers")
130
+ target_layers = all_layers[visualization_layer_min-1 : visualization_layer_max]
131
+ else:
132
+ print("single layer")
133
+ target_layers = [all_layers[visualization_layer_min-1]]
134
+
135
+
136
+ if model_name.split('-')[0] == "Janus":
137
+ gradcam = AttentionGuidedCAMJanus(vl_gpt, target_layers)
138
+ elif model_name.split('-')[0] == "LLaVA":
139
+ gradcam = AttentionGuidedCAMLLaVA(vl_gpt, target_layers)
140
+ elif model_name.split('-')[0] == "ChartGemma":
141
+ gradcam = AttentionGuidedCAMChartGemma(vl_gpt, target_layers)
142
+
143
+ start = 0
144
+ cam = []
145
+ if focus == "Visual Encoder":
146
+ if target_token_idx != -1:
147
+ cam_tensors, grid_size, start = gradcam.generate_cam(prepare_inputs, tokenizer, temperature, top_p, target_token_idx, visual_pooling_method, focus)
148
+ cam_grid = cam_tensors.reshape(grid_size, grid_size)
149
+ cam_i = generate_gradcam(cam_grid, image)
150
+ cam_i = add_title_to_image(cam_i, input_ids_decoded[start + target_token_idx])
151
+ cam = [cam_i]
152
+ else:
153
+ i = 0
154
+ cam = []
155
+ while start + i < len(input_ids_decoded):
156
+ if model_name.split('-')[0] == "Janus":
157
+ gradcam = AttentionGuidedCAMJanus(vl_gpt, target_layers)
158
+ elif model_name.split('-')[0] == "LLaVA":
159
+ gradcam = AttentionGuidedCAMLLaVA(vl_gpt, target_layers)
160
+ elif model_name.split('-')[0] == "ChartGemma":
161
+ gradcam = AttentionGuidedCAMChartGemma(vl_gpt, target_layers)
162
+ cam_tensors, grid_size, start = gradcam.generate_cam(prepare_inputs, tokenizer, temperature, top_p, i, visual_pooling_method, focus)
163
+ cam_grid = cam_tensors.reshape(grid_size, grid_size)
164
+ cam_i = generate_gradcam(cam_grid, image)
165
+ cam_i = add_title_to_image(cam_i, input_ids_decoded[start + i])
166
+ cam.append(cam_i)
167
+ gradcam.remove_hooks()
168
+ i += 1
169
+ else:
170
+ cam_tensors, grid_size, start = gradcam.generate_cam(prepare_inputs, tokenizer, temperature, top_p, target_token_idx, visual_pooling_method, focus)
171
+ if target_token_idx != -1:
172
+ input_text_decoded = input_ids_decoded[start + target_token_idx]
173
+ for i, cam_tensor in enumerate(cam_tensors):
174
+ if i == target_token_idx:
175
+ cam_grid = cam_tensor.reshape(grid_size, grid_size)
176
+ cam_i = generate_gradcam(cam_grid, image)
177
+ cam = [add_title_to_image(cam_i, input_text_decoded)]
178
+ break
179
+ else:
180
+ cam = []
181
+ for i, cam_tensor in enumerate(cam_tensors):
182
+ cam_grid = cam_tensor.reshape(grid_size, grid_size)
183
+ cam_i = generate_gradcam(cam_grid, image)
184
+ cam_i = add_title_to_image(cam_i, input_ids_decoded[start + i])
185
+
186
+ cam.append(cam_i)
187
+
188
+ gradcam.remove_hooks()
189
+
190
+
191
+ # Collect Results
192
+ RESULTS_ROOT = "./results"
193
+ FILES_ROOT = f"{RESULTS_ROOT}/{model_name}/{focus}/{chart_type}/layer{visualization_layer_min}-{visualization_layer_max}"
194
+ os.makedirs(FILES_ROOT, exist_ok=True)
195
+ if focus == "Visual Encoder":
196
+ cam[0].save(f"{FILES_ROOT}/{visual_pooling_method}.png")
197
+ else:
198
+ for i, cam_p in enumerate(cam):
199
+ cam_p.save(f"{FILES_ROOT}/{i}.png")
200
+
201
+ with open(f"{FILES_ROOT}/input_text_decoded.txt", "w") as f:
202
+ f.write(input_text_decoded)
203
+ f.close()
204
+
205
+ with open(f"{FILES_ROOT}/answer.txt", "w") as f:
206
+ f.write(answer)
207
+ f.close()
208
+
209
+
210
+
211
+ return answer, cam, input_text_decoded
212
+
213
+
214
+
215
+
216
+ # Gradio interface
217
+
218
+ def model_slider_change(model_type):
219
+ global model_utils, vl_gpt, tokenizer, clip_utils, model_name, language_model_max_layer, language_model_best_layer, vision_model_best_layer
220
+ model_name = model_type
221
+ if model_type == "Clip":
222
+ clean()
223
+ set_seed()
224
+ clip_utils = Clip_Utils()
225
+ clip_utils.init_Clip()
226
+ res = (
227
+ gr.Dropdown(choices=["Visualization only"], value="Visualization only", label="response_type"),
228
+ gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers min"),
229
+ gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers max"),
230
+ gr.Dropdown(choices=["Visual Encoder"], value="Visual Encoder", label="focus"),
231
+ gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type")
232
+ )
233
+ return res
234
+ elif model_type.split('-')[0] == "Janus":
235
+
236
+ clean()
237
+ set_seed()
238
+ model_utils = Janus_Utils()
239
+ vl_gpt, tokenizer = model_utils.init_Janus(model_type.split('-')[-1])
240
+ language_model_max_layer = 24
241
+ language_model_best_layer = 8
242
+
243
+ res = (
244
+ gr.Dropdown(choices=["Visualization only", "answer + visualization"], value="answer + visualization", label="response_type"),
245
+ gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers min"),
246
+ gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers max"),
247
+ gr.Dropdown(choices=["Visual Encoder", "Language Model"], value="Visual Encoder", label="focus"),
248
+ gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type")
249
+ )
250
+ return res
251
+
252
+ elif model_type.split('-')[0] == "LLaVA":
253
+
254
+ clean()
255
+ set_seed()
256
+ model_utils = LLaVA_Utils()
257
+ version = model_type.split('-')[1]
258
+ vl_gpt, tokenizer = model_utils.init_LLaVA(version=version)
259
+ language_model_max_layer = 32 if version == "1.5" else 28
260
+ language_model_best_layer = 10
261
+
262
+ res = (
263
+ gr.Dropdown(choices=["Visualization only", "answer + visualization"], value="answer + visualization", label="response_type"),
264
+ gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer, step=1, label="visualization layers min"),
265
+ gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer, step=1, label="visualization layers max"),
266
+ gr.Dropdown(choices=["Language Model"], value="Language Model", label="focus"),
267
+ gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type")
268
+ )
269
+ return res
270
+
271
+ elif model_type.split('-')[0] == "ChartGemma":
272
+ clean()
273
+ set_seed()
274
+ model_utils = ChartGemma_Utils()
275
+ vl_gpt, tokenizer = model_utils.init_ChartGemma()
276
+ language_model_max_layer = 18
277
+ vision_model_best_layer = 19
278
+ language_model_best_layer = 15
279
+
280
+ res = (
281
+ gr.Dropdown(choices=["Visualization only", "answer + visualization"], value="answer + visualization", label="response_type"),
282
+ gr.Slider(minimum=1, maximum=language_model_best_layer, value=language_model_best_layer, step=1, label="visualization layers min"),
283
+ gr.Slider(minimum=1, maximum=language_model_best_layer, value=language_model_best_layer, step=1, label="visualization layers max"),
284
+ gr.Dropdown(choices=["Visual Encoder", "Language Model"], value="Language Model", label="focus"),
285
+ gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type")
286
+ )
287
+ return res
288
+
289
+
290
+
291
+
292
+ def focus_change(focus):
293
+ global model_name, language_model_max_layer
294
+ if model_name == "Clip":
295
+ res = (
296
+ gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type"),
297
+ gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers min"),
298
+ gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers max")
299
+ )
300
+ return res
301
+
302
+ if focus == "Language Model":
303
+ if response_type.value == "answer + visualization":
304
+ res = (
305
+ gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type"),
306
+ gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer, step=1, label="visualization layers min"),
307
+ gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer, step=1, label="visualization layers max")
308
+ )
309
+ return res
310
+ else:
311
+ res = (
312
+ gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type"),
313
+ gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer, step=1, label="visualization layers min"),
314
+ gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer, step=1, label="visualization layers max")
315
+ )
316
+ return res
317
+
318
+ else:
319
+ if model_name.split('-')[0] == "ChartGemma":
320
+ res = (
321
+ gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type"),
322
+ gr.Slider(minimum=1, maximum=26, value=vision_model_best_layer, step=1, label="visualization layers min"),
323
+ gr.Slider(minimum=1, maximum=26, value=vision_model_best_layer, step=1, label="visualization layers max")
324
+ )
325
+ return res
326
+
327
+ else:
328
+ res = (
329
+ gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type"),
330
+ gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers min"),
331
+ gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers max")
332
+ )
333
+ return res
334
+
335
+
336
+
337
+
338
+
339
+ with gr.Blocks() as demo:
340
+ gr.Markdown(value="# Multimodal Understanding")
341
+
342
+ with gr.Row():
343
+ image_input = gr.Image(height=500, label="Image")
344
+ activation_map_output = gr.Gallery(label="Visualization", height=500, columns=1, preview=True)
345
+
346
+ with gr.Row():
347
+ chart_type = gr.Textbox(label="Chart Type")
348
+ understanding_output = gr.Textbox(label="Answer")
349
+
350
+ with gr.Row():
351
+
352
+ with gr.Column():
353
+ model_selector = gr.Dropdown(choices=["Clip", "ChartGemma-3B", "Janus-Pro-1B", "Janus-Pro-7B", "LLaVA-1.5-7B"], value="Clip", label="model")
354
+ question_input = gr.Textbox(label="Input Prompt")
355
+ und_seed_input = gr.Number(label="Seed", precision=0, value=42)
356
+ top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
357
+ temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
358
+ target_token_idx = gr.Number(label="target_token_idx (-1 means all)", precision=0, value=-1)
359
+
360
+
361
+ with gr.Column():
362
+ response_type = gr.Dropdown(choices=["Visualization only"], value="Visualization only", label="response_type")
363
+ focus = gr.Dropdown(choices=["Visual Encoder"], value="Visual Encoder", label="focus")
364
+ activation_map_method = gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="visualization type")
365
+ visual_pooling_method = gr.Dropdown(choices=["CLS", "max", "avg"], value="CLS", label="visual pooling method")
366
+
367
+
368
+ visualization_layers_min = gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers min")
369
+ visualization_layers_max = gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers max")
370
+
371
+
372
+
373
+
374
+
375
+ model_selector.change(
376
+ fn=model_slider_change,
377
+ inputs=model_selector,
378
+ outputs=[
379
+ response_type,
380
+ visualization_layers_min,
381
+ visualization_layers_max,
382
+ focus,
383
+ activation_map_method
384
+ ]
385
+ )
386
+
387
+ focus.change(
388
+ fn = focus_change,
389
+ inputs = focus,
390
+ outputs=[
391
+ activation_map_method,
392
+ visualization_layers_min,
393
+ visualization_layers_max,
394
+ ]
395
+ )
396
+
397
+ # response_type.change(
398
+ # fn = response_type_change,
399
+ # inputs = response_type,
400
+ # outputs = [activation_map_method]
401
+ # )
402
+
403
+
404
+
405
+ understanding_button = gr.Button("Submit")
406
+
407
+ understanding_target_token_decoded_output = gr.Textbox(label="Target Token Decoded")
408
+
409
+
410
+ examples_inpainting = gr.Examples(
411
+ label="Multimodal Understanding examples",
412
+ examples=[
413
+
414
+ [
415
+ "LineChart",
416
+ "What was the price of a barrel of oil in February 2020?",
417
+ "images/LineChart.png"
418
+ ],
419
+
420
+ [
421
+ "BarChart",
422
+ "What is the average internet speed in Japan?",
423
+ "images/BarChart.png"
424
+ ],
425
+
426
+ [
427
+ "StackedBar",
428
+ "What is the cost of peanuts in Seoul?",
429
+ "images/StackedBar.png"
430
+ ],
431
+
432
+ [
433
+ "100%StackedBar",
434
+ "Which country has the lowest proportion of Gold medals?",
435
+ "images/Stacked100.png"
436
+ ],
437
+
438
+ [
439
+ "PieChart",
440
+ "What is the approximate global smartphone market share of Samsung?",
441
+ "images/PieChart.png"
442
+ ],
443
+
444
+ [
445
+ "Histogram",
446
+ "What distance have customers traveled in the taxi the most?",
447
+ "images/Histogram.png"
448
+ ],
449
+
450
+ [
451
+ "Scatterplot",
452
+ "True/False: There is a negative linear relationship between the height and the weight of the 85 males.",
453
+ "images/Scatterplot.png"
454
+ ],
455
+
456
+ [
457
+ "AreaChart",
458
+ "What was the average price of pount of coffee beans in October 2019?",
459
+ "images/AreaChart.png"
460
+ ],
461
+
462
+ [
463
+ "StackedArea",
464
+ "What was the ratio of girls named 'Isla' to girls named 'Amelia' in 2012 in the UK?",
465
+ "images/StackedArea.png"
466
+ ],
467
+
468
+ [
469
+ "BubbleChart",
470
+ "Which city's metro system has the largest number of stations?",
471
+ "images/BubbleChart.png"
472
+ ],
473
+
474
+ [
475
+ "Choropleth",
476
+ "True/False: In 2020, the unemployment rate for Washington (WA) was higher than that of Wisconsin (WI).",
477
+ "images/Choropleth_New.png"
478
+ ],
479
+
480
+ [
481
+ "TreeMap",
482
+ "True/False: eBay is nested in the Software category.",
483
+ "images/TreeMap.png"
484
+ ]
485
+
486
+ ],
487
+ inputs=[chart_type, question_input, image_input],
488
+ )
489
+
490
+
491
+
492
+
493
+ understanding_button.click(
494
+ multimodal_understanding,
495
+ inputs=[model_selector, activation_map_method, visual_pooling_method, image_input, question_input, und_seed_input, top_p, temperature, target_token_idx,
496
+ visualization_layers_min, visualization_layers_max, focus, response_type, chart_type],
497
+ outputs=[understanding_output, activation_map_output, understanding_target_token_decoded_output]
498
+ )
499
+
500
+ demo.launch(share=True)
501
+ # demo.queue(concurrency_count=1, max_size=10).launch(server_name="0.0.0.0", server_port=37906, root_path="/path")
app.py CHANGED
@@ -3,7 +3,7 @@ import torch
3
  from transformers import AutoConfig, AutoModelForCausalLM
4
  from janus.models import MultiModalityCausalLM, VLChatProcessor
5
  from janus.utils.io import load_pil_images
6
- from demo.cam import generate_gradcam, AttentionGuidedCAMJanus, AttentionGuidedCAMClip, AttentionGuidedCAMChartGemma, AttentionGuidedCAMLLaVA
7
  from demo.model_utils import Clip_Utils, Janus_Utils, LLaVA_Utils, ChartGemma_Utils, add_title_to_image
8
 
9
  import numpy as np
@@ -82,8 +82,8 @@ def multimodal_understanding(model_type,
82
  target_layers = all_layers[visualization_layer_min-1 : visualization_layer_max-1]
83
  else:
84
  target_layers = [all_layers[visualization_layer_min-1]]
85
- grad_cam = AttentionGuidedCAMClip(clip_utils.model, target_layers)
86
- cam, outputs, grid_size = grad_cam.generate_cam(inputs, class_idx=0, visual_pooling_method=visual_pooling_method)
87
  cam = cam.to("cpu")
88
  cam = [generate_gradcam(cam, image, size=(224, 224))]
89
  grad_cam.remove_hooks()
@@ -134,11 +134,11 @@ def multimodal_understanding(model_type,
134
 
135
 
136
  if model_name.split('-')[0] == "Janus":
137
- gradcam = AttentionGuidedCAMJanus(vl_gpt, target_layers)
138
  elif model_name.split('-')[0] == "LLaVA":
139
- gradcam = AttentionGuidedCAMLLaVA(vl_gpt, target_layers)
140
  elif model_name.split('-')[0] == "ChartGemma":
141
- gradcam = AttentionGuidedCAMChartGemma(vl_gpt, target_layers)
142
 
143
  start = 0
144
  cam = []
@@ -154,11 +154,11 @@ def multimodal_understanding(model_type,
154
  cam = []
155
  while start + i < len(input_ids_decoded):
156
  if model_name.split('-')[0] == "Janus":
157
- gradcam = AttentionGuidedCAMJanus(vl_gpt, target_layers)
158
  elif model_name.split('-')[0] == "LLaVA":
159
- gradcam = AttentionGuidedCAMLLaVA(vl_gpt, target_layers)
160
  elif model_name.split('-')[0] == "ChartGemma":
161
- gradcam = AttentionGuidedCAMChartGemma(vl_gpt, target_layers)
162
  cam_tensors, grid_size, start = gradcam.generate_cam(prepare_inputs, tokenizer, temperature, top_p, i, visual_pooling_method, focus)
163
  cam_grid = cam_tensors.reshape(grid_size, grid_size)
164
  cam_i = generate_gradcam(cam_grid, image)
 
3
  from transformers import AutoConfig, AutoModelForCausalLM
4
  from janus.models import MultiModalityCausalLM, VLChatProcessor
5
  from janus.utils.io import load_pil_images
6
+ from demo.visualization import generate_gradcam, VisualizationJanus, VisualizationClip, VisualizationChartGemma, VisualizationLLaVA
7
  from demo.model_utils import Clip_Utils, Janus_Utils, LLaVA_Utils, ChartGemma_Utils, add_title_to_image
8
 
9
  import numpy as np
 
82
  target_layers = all_layers[visualization_layer_min-1 : visualization_layer_max-1]
83
  else:
84
  target_layers = [all_layers[visualization_layer_min-1]]
85
+ grad_cam = VisualizationClip(clip_utils.model, target_layers)
86
+ cam, outputs, grid_size = grad_cam.generate_cam(inputs, target_token_idx=0, visual_pooling_method=visual_pooling_method)
87
  cam = cam.to("cpu")
88
  cam = [generate_gradcam(cam, image, size=(224, 224))]
89
  grad_cam.remove_hooks()
 
134
 
135
 
136
  if model_name.split('-')[0] == "Janus":
137
+ gradcam = VisualizationJanus(vl_gpt, target_layers)
138
  elif model_name.split('-')[0] == "LLaVA":
139
+ gradcam = VisualizationLLaVA(vl_gpt, target_layers)
140
  elif model_name.split('-')[0] == "ChartGemma":
141
+ gradcam = VisualizationChartGemma(vl_gpt, target_layers)
142
 
143
  start = 0
144
  cam = []
 
154
  cam = []
155
  while start + i < len(input_ids_decoded):
156
  if model_name.split('-')[0] == "Janus":
157
+ gradcam = VisualizationJanus(vl_gpt, target_layers)
158
  elif model_name.split('-')[0] == "LLaVA":
159
+ gradcam = VisualizationLLaVA(vl_gpt, target_layers)
160
  elif model_name.split('-')[0] == "ChartGemma":
161
+ gradcam = VisualizationChartGemma(vl_gpt, target_layers)
162
  cam_tensors, grid_size, start = gradcam.generate_cam(prepare_inputs, tokenizer, temperature, top_p, i, visual_pooling_method, focus)
163
  cam_grid = cam_tensors.reshape(grid_size, grid_size)
164
  cam_i = generate_gradcam(cam_grid, image)
demo/visualization.py ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import types
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import matplotlib.pyplot as plt
7
+ from PIL import Image
8
+ from torch import nn
9
+ import spaces
10
+ from demo.modify_llama import *
11
+
12
+
13
+ class Visualization:
14
+ def __init__(self, model, register=True):
15
+ self.model = model
16
+ self.gradients = []
17
+ self.activations = []
18
+ self.hooks = []
19
+ if register:
20
+ self._register_hooks()
21
+
22
+ def _register_hooks(self):
23
+ for layer in self.target_layers:
24
+ self.hooks.append(layer.register_forward_hook(self._forward_hook))
25
+ self.hooks.append(layer.register_backward_hook(self._backward_hook))
26
+
27
+ def _forward_hook(self, module, input, output):
28
+ self.activations.append(output)
29
+
30
+ def _backward_hook(self, module, grad_in, grad_out):
31
+ self.gradients.append(grad_out[0])
32
+
33
+ def _modify_layers(self):
34
+ for layer in self.target_layers:
35
+ setattr(layer, "attn_gradients", None)
36
+ setattr(layer, "attention_map", None)
37
+
38
+ layer.save_attn_gradients = types.MethodType(save_attn_gradients, layer)
39
+ layer.get_attn_gradients = types.MethodType(get_attn_gradients, layer)
40
+ layer.save_attn_map = types.MethodType(save_attn_map, layer)
41
+ layer.get_attn_map = types.MethodType(get_attn_map, layer)
42
+
43
+ def _forward_activate_hooks(self, module, input, output):
44
+ attn_output, attn_weights = output # Unpack outputs
45
+ print("attn_output shape:", attn_output.shape)
46
+ print("attn_weights shape:", attn_weights.shape)
47
+ module.save_attn_map(attn_weights)
48
+ attn_weights.register_hook(module.save_attn_gradients)
49
+
50
+ def _register_hooks_activations(self):
51
+ for layer in self.target_layers:
52
+ if hasattr(layer, "q_proj"): # is an attention layer
53
+ self.hooks.append(layer.register_forward_hook(self._forward_activate_hooks))
54
+
55
+
56
+ def remove_hooks(self):
57
+ for hook in self.hooks:
58
+ hook.remove()
59
+
60
+ def setup_grads(self):
61
+ torch.autograd.set_detect_anomaly(True)
62
+ for param in self.model.parameters():
63
+ param.requires_grad = False
64
+
65
+ for layer in self.target_layers:
66
+ for param in layer.parameters():
67
+ param.requires_grad = True
68
+
69
+ def forward_backward(self):
70
+ raise NotImplementedError
71
+
72
+ def grad_cam_vis(self):
73
+ self.model.zero_grad()
74
+ cam_sum = None
75
+ for act, grad in zip(self.activations, self.gradients):
76
+
77
+ act = F.relu(act[0])
78
+
79
+ grad_weights = grad.mean(dim=-1, keepdim=True)
80
+
81
+ print("act shape", act.shape)
82
+ print("grad_weights shape", grad_weights.shape)
83
+
84
+ # cam = (act * grad_weights).sum(dim=-1)
85
+ cam, _ = (act * grad_weights).max(dim=-1)
86
+
87
+ print("cam_shape: ", cam.shape)
88
+
89
+ # Sum across all layers
90
+ if cam_sum is None:
91
+ cam_sum = cam
92
+ else:
93
+ cam_sum += cam
94
+
95
+ cam_sum = F.relu(cam_sum)
96
+ return cam_sum
97
+
98
+
99
+
100
+ def grad_cam_llm(self, mean_inside=False):
101
+
102
+ cam_sum = None
103
+ for act, grad in zip(self.activations, self.gradients):
104
+
105
+ if mean_inside:
106
+ act = act.mean(dim=1)
107
+ grad = F.relu(grad.mean(dim=1))
108
+ cam = act * grad
109
+ else:
110
+ cam = act * grad
111
+ cam = act * grad.sum(dim=1)
112
+
113
+ print(cam.shape)
114
+
115
+ # Sum across all layers
116
+ if cam_sum is None:
117
+ cam_sum = cam
118
+ else:
119
+ cam_sum += cam
120
+
121
+ cam_sum = F.relu(cam_sum)
122
+ return cam_sum
123
+
124
+ def attention_map(self):
125
+ raise NotImplementedError
126
+
127
+ def attn_guided_cam(self):
128
+
129
+ cams = []
130
+ for act, grad in zip(self.activations, self.gradients):
131
+ print("act shape", act.shape)
132
+ print("grad shape", grad.shape)
133
+
134
+ grad = F.relu(grad)
135
+
136
+ # cam = grad
137
+ cam = act * grad # shape: [1, heads, seq_len, seq_len]
138
+ cam = cam.sum(dim=1) # shape: [1, seq_len, seq_len]
139
+ cam = cam.to(torch.float32).detach().cpu()
140
+ cams.append(cam)
141
+ return cams
142
+
143
+
144
+ def process(self, cam_sum, thresholding=True, remove_cls=True, normalize=True):
145
+
146
+ cam_sum = cam_sum.to(torch.float32)
147
+
148
+ # thresholding
149
+ if thresholding:
150
+ percentile = torch.quantile(cam_sum, 0.2) # Adjust threshold dynamically
151
+ cam_sum[cam_sum < percentile] = 0
152
+
153
+ # Remove CLS
154
+ if remove_cls:
155
+ cam_sum = cam_sum[0, 1:]
156
+
157
+ num_patches = cam_sum.shape[-1] # Last dimension of CAM output
158
+ grid_size = int(num_patches ** 0.5)
159
+ print(f"Detected grid size: {grid_size}x{grid_size}")
160
+ cam_sum = cam_sum.view(grid_size, grid_size).detach()
161
+
162
+ # Normalize
163
+ if normalize:
164
+ cam_sum = (cam_sum - cam_sum.min()) / (cam_sum.max() - cam_sum.min())
165
+
166
+ return cam_sum, grid_size
167
+
168
+ def process_multiple(self, cam_sum, start_idx, images_seq_mask, thresholding=True, normalize=True):
169
+ cam_sum = cam_sum.to(torch.float32)
170
+ # thresholding
171
+ if thresholding:
172
+ percentile = torch.quantile(cam_sum, 0.2) # Adjust threshold dynamically
173
+ cam_sum[cam_sum < percentile] = 0
174
+
175
+
176
+ # cam_sum shape: [1, seq_len, seq_len]
177
+ cam_sum_lst = []
178
+ cam_sum_raw = cam_sum
179
+ start = start_idx
180
+ for i in range(start, cam_sum_raw.shape[1]):
181
+ cam_sum = cam_sum_raw[:, i, :] # shape: [1: seq_len]
182
+ cam_sum = cam_sum[images_seq_mask].unsqueeze(0) # shape: [1, img_seq_len]
183
+ print("cam_sum shape: ", cam_sum.shape)
184
+ num_patches = cam_sum.shape[-1] # Last dimension of CAM output
185
+ grid_size = int(num_patches ** 0.5)
186
+ print(f"Detected grid size: {grid_size}x{grid_size}")
187
+
188
+ cam_sum = cam_sum.view(grid_size, grid_size)
189
+ if normalize:
190
+ cam_sum = (cam_sum - cam_sum.min()) / (cam_sum.max() - cam_sum.min())
191
+ cam_sum = cam_sum.detach().to("cpu")
192
+ cam_sum_lst.append(cam_sum)
193
+ return cam_sum_lst, grid_size
194
+
195
+ def process_multiple_withsum(self, cams, start_idx, images_seq_mask, normalize=False):
196
+ cam_sum_lst = []
197
+ for i in range(start_idx, cams[0].shape[1]):
198
+ cam_sum = None
199
+ for layer, cam_l in enumerate(cams):
200
+ cam_l_i = cam_l[0, i, :] # shape: [1: seq_len]
201
+
202
+ cam_l_i = cam_l_i[images_seq_mask].unsqueeze(0) # shape: [1, img_seq_len]
203
+
204
+ num_patches = cam_l_i.shape[-1] # Last dimension of CAM output
205
+ grid_size = int(num_patches ** 0.5)
206
+ # print(f"Detected grid size: {grid_size}x{grid_size}")
207
+
208
+ # Fix the reshaping step dynamically
209
+ cam_reshaped = cam_l_i.view(grid_size, grid_size)
210
+
211
+ if normalize:
212
+ cam_reshaped = (cam_reshaped - cam_reshaped.min()) / (cam_reshaped.max() - cam_reshaped.min())
213
+ if cam_sum == None:
214
+ cam_sum = cam_reshaped
215
+ else:
216
+ cam_sum += cam_reshaped
217
+
218
+ cam_sum = (cam_sum - cam_sum.min()) / (cam_sum.max() - cam_sum.min())
219
+ cam_sum_lst.append(cam_sum)
220
+ return cam_sum_lst, grid_size
221
+
222
+ def generate_cam(self, input_tensor, target_token_idx=None):
223
+ raise NotImplementedError
224
+
225
+
226
+
227
+
228
+ class VisualizationClip(Visualization):
229
+ def __init__(self, model, target_layers):
230
+ self.target_layers = target_layers
231
+ super().__init__(model)
232
+
233
+ @spaces.GPU(duration=120)
234
+ def forward_backward(self, input_tensor, visual_pooling_method, target_token_idx):
235
+ output_full = self.model(**input_tensor)
236
+
237
+ if target_token_idx is None:
238
+ target_token_idx = torch.argmax(output_full.logits, dim=1).item()
239
+
240
+ if visual_pooling_method == "CLS":
241
+ output = output_full.image_embeds
242
+ elif visual_pooling_method == "avg":
243
+ output = self.model.visual_projection(output_full.vision_model_output.last_hidden_state).mean(dim=1)
244
+ else:
245
+ output, _ = self.model.visual_projection(output_full.vision_model_output.last_hidden_state).max(dim=1)
246
+
247
+
248
+ output.backward(output_full.text_embeds[target_token_idx:target_token_idx+1], retain_graph=True)
249
+ return output_full
250
+
251
+
252
+ @spaces.GPU(duration=120)
253
+ def generate_cam(self, input_tensor, target_token_idx=None, visual_pooling_method="CLS"):
254
+ """ Generates Grad-CAM heatmap for ViT. """
255
+ self.setup_grads()
256
+ # Forward Backward pass
257
+ output_full = self.forward_backward(input_tensor, visual_pooling_method, target_token_idx)
258
+
259
+ cam_sum = self.grad_cam_vis()
260
+ cam_sum, grid_size = self.process(cam_sum)
261
+
262
+ return cam_sum, output_full, grid_size
263
+
264
+
265
+
266
+
267
+
268
+
269
+
270
+
271
+
272
+
273
+
274
+
275
+
276
+
277
+
278
+
279
+
280
+
281
+
282
+
283
+
284
+
285
+
286
+
287
+ class VisualizationJanus(Visualization):
288
+ def __init__(self, model, target_layers):
289
+ self.target_layers = target_layers
290
+ super().__init__(model)
291
+ self._modify_layers()
292
+ self._register_hooks_activations()
293
+
294
+ def forward_backward(self, input_tensor, tokenizer, temperature, top_p, target_token_idx=None, visual_pooling_method="CLS", focus="Visual Encoder"):
295
+ # Forward
296
+ image_embeddings, inputs_embeddings, outputs = self.model(input_tensor, tokenizer, temperature, top_p)
297
+ input_ids = input_tensor.input_ids
298
+
299
+ if focus == "Visual Encoder":
300
+
301
+ start_idx = 620
302
+ self.model.zero_grad()
303
+
304
+ loss = outputs.logits.max(dim=-1).values[0, start_idx + target_token_idx]
305
+ loss.backward()
306
+
307
+ elif focus == "Language Model":
308
+ self.model.zero_grad()
309
+ loss = outputs.logits.max(dim=-1).values.sum()
310
+ loss.backward()
311
+
312
+ self.activations = [layer.get_attn_map() for layer in self.target_layers]
313
+ self.gradients = [layer.get_attn_gradients() for layer in self.target_layers]
314
+
315
+ @spaces.GPU(duration=120)
316
+ def generate_cam(self, input_tensor, tokenizer, temperature, top_p, target_token_idx=None, visual_pooling_method="CLS", focus="Visual Encoder"):
317
+
318
+ self.setup_grads()
319
+
320
+ # Forward Backward pass
321
+ self.forward_backward(input_tensor, tokenizer, temperature, top_p, target_token_idx, visual_pooling_method, focus)
322
+
323
+ start_idx = 620
324
+ if focus == "Visual Encoder":
325
+
326
+ cam_sum = self.grad_cam_vis()
327
+ cam_sum, grid_size = self.process(cam_sum)
328
+ return cam_sum, grid_size, start_idx
329
+
330
+ elif focus == "Language Model":
331
+
332
+ cam_sum = self.grad_cam_llm(mean_inside=True)
333
+
334
+ images_seq_mask = input_tensor.images_seq_mask
335
+
336
+ cam_sum_lst, grid_size = self.process_multiple(cam_sum, start_idx, images_seq_mask)
337
+
338
+ return cam_sum_lst, grid_size, start_idx
339
+
340
+
341
+
342
+
343
+
344
+
345
+
346
+
347
+
348
+ class VisualizationLLaVA(Visualization):
349
+ def __init__(self, model, target_layers):
350
+ self.target_layers = target_layers
351
+ super().__init__(model, register=False)
352
+ self._modify_layers()
353
+ self._register_hooks_activations()
354
+
355
+ def forward_backward(self, inputs):
356
+ # Forward pass
357
+ outputs_raw = self.model(**inputs)
358
+
359
+ self.model.zero_grad()
360
+ print("outputs_raw", outputs_raw)
361
+
362
+ loss = outputs_raw.logits.max(dim=-1).values.sum()
363
+ loss.backward()
364
+ self.activations = [layer.get_attn_map() for layer in self.target_layers]
365
+ self.gradients = [layer.get_attn_gradients() for layer in self.target_layers]
366
+
367
+ @spaces.GPU(duration=120)
368
+ def generate_cam(self, inputs, tokenizer, temperature, top_p, target_token_idx=None, visual_pooling_method="CLS", focus="Visual Encoder"):
369
+
370
+ self.setup_grads()
371
+ self.forward_backward(inputs)
372
+
373
+ # get image masks
374
+ images_seq_mask = []
375
+ last = 0
376
+ for i in range(inputs["input_ids"].shape[1]):
377
+ decoded_token = tokenizer.decode(inputs["input_ids"][0][i].item())
378
+ if (decoded_token == "<image>"):
379
+ images_seq_mask.append(True)
380
+ last = i
381
+ else:
382
+ images_seq_mask.append(False)
383
+
384
+
385
+ # Aggregate activations and gradients from ALL layers
386
+ start_idx = last + 1
387
+ cams = self.attn_guided_cam()
388
+ cam_sum_lst, grid_size = self.process_multiple_withsum(cams, start_idx, images_seq_mask)
389
+
390
+ return cam_sum_lst, grid_size, start_idx
391
+
392
+
393
+
394
+
395
+
396
+
397
+ class VisualizationChartGemma(Visualization):
398
+ def __init__(self, model, target_layers):
399
+ self.target_layers = target_layers
400
+ super().__init__(model, register=True)
401
+ self._modify_layers()
402
+ self._register_hooks_activations()
403
+
404
+ def forward_backward(self, inputs, focus, start_idx, target_token_idx):
405
+ outputs_raw = self.model(**inputs, output_hidden_states=True)
406
+ if focus == "Visual Encoder":
407
+
408
+ self.model.zero_grad()
409
+
410
+ loss = outputs_raw.logits.max(dim=-1).values[0, start_idx + target_token_idx]
411
+ loss.backward()
412
+
413
+ elif focus == "Language Model":
414
+ self.model.zero_grad()
415
+ if target_token_idx == -1:
416
+ loss = outputs_raw.logits.max(dim=-1).values.sum()
417
+ else:
418
+ loss = outputs_raw.logits.max(dim=-1).values[0, start_idx + target_token_idx]
419
+ loss.backward()
420
+ self.activations = [layer.get_attn_map() for layer in self.target_layers]
421
+ self.gradients = [layer.get_attn_gradients() for layer in self.target_layers]
422
+
423
+ @spaces.GPU(duration=120)
424
+ def generate_cam(self, inputs, tokenizer, temperature, top_p, target_token_idx=None, visual_pooling_method="CLS", focus="Visual Encoder"):
425
+
426
+ # Forward pass
427
+ self.setup_grads()
428
+
429
+ # get image masks
430
+ images_seq_mask = []
431
+ last = 0
432
+ for i in range(inputs["input_ids"].shape[1]):
433
+ decoded_token = tokenizer.decode(inputs["input_ids"][0][i].item())
434
+ if (decoded_token == "<image>"):
435
+ images_seq_mask.append(True)
436
+ last = i
437
+ else:
438
+ images_seq_mask.append(False)
439
+ start_idx = last + 1
440
+
441
+
442
+ self.forward_backward(inputs, focus, start_idx, target_token_idx)
443
+ if focus == "Visual Encoder":
444
+
445
+ cam_sum = self.grad_cam_vis()
446
+ cam_sum, grid_size = self.process(cam_sum, remove_cls=False)
447
+
448
+ return cam_sum, grid_size, start_idx
449
+
450
+ elif focus == "Language Model":
451
+
452
+ cams = self.attn_guided_cam()
453
+ cam_sum_lst, grid_size = self.process_multiple_withsum(cams, start_idx, images_seq_mask)
454
+
455
+ # cams shape: [layers, 1, seq_len, seq_len]
456
+
457
+
458
+
459
+ return cam_sum_lst, grid_size, start_idx
460
+
461
+
462
+
463
+
464
+
465
+
466
+
467
+
468
+
469
+
470
+ def generate_gradcam(
471
+ cam,
472
+ image,
473
+ size = (384, 384),
474
+ alpha=0.5,
475
+ colormap=cv2.COLORMAP_JET,
476
+ aggregation='mean',
477
+ normalize=False
478
+ ):
479
+ """
480
+ Generates a Grad-CAM heatmap overlay on top of the input image.
481
+
482
+ Parameters:
483
+ attributions (torch.Tensor): A tensor of shape (C, H, W) representing the
484
+ intermediate activations or gradients at the target layer.
485
+ image (PIL.Image): The original image.
486
+ alpha (float): The blending factor for the heatmap overlay (default 0.5).
487
+ colormap (int): OpenCV colormap to apply (default cv2.COLORMAP_JET).
488
+ aggregation (str): How to aggregate across channels; either 'mean' or 'sum'.
489
+
490
+ Returns:
491
+ PIL.Image: The image overlaid with the Grad-CAM heatmap.
492
+ """
493
+ # print("Generating Grad-CAM with shape:", cam.shape)
494
+
495
+ if normalize:
496
+ cam_min, cam_max = cam.min(), cam.max()
497
+ cam = cam - cam_min
498
+ cam = cam / (cam_max - cam_min)
499
+ # Convert tensor to numpy array
500
+ cam = torch.nn.functional.interpolate(cam.unsqueeze(0).unsqueeze(0), size=size, mode='bilinear').squeeze()
501
+ cam_np = cam.squeeze().detach().cpu().numpy()
502
+
503
+ # Apply Gaussian blur for smoother heatmaps
504
+ cam_np = cv2.GaussianBlur(cam_np, (5,5), sigmaX=0.8)
505
+
506
+ # Resize the cam to match the image size
507
+ width, height = size
508
+ cam_resized = cv2.resize(cam_np, (width, height))
509
+
510
+ # Convert the normalized map to a heatmap (0-255 uint8)
511
+ heatmap = np.uint8(255 * cam_resized)
512
+ heatmap = cv2.applyColorMap(heatmap, colormap)
513
+ # OpenCV produces heatmaps in BGR, so convert to RGB for consistency
514
+ heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
515
+
516
+ # Convert original image to a numpy array
517
+ image_np = np.array(image)
518
+ image_np = cv2.resize(image_np, (width, height))
519
+
520
+ # Blend the heatmap with the original image
521
+ overlay = cv2.addWeighted(image_np, 1 - alpha, heatmap, alpha, 0)
522
+
523
+ return Image.fromarray(overlay)
524
+