hyesulim committed on
Commit bf85231 · verified · 1 Parent(s): c574085

fix: revert to base

Files changed (1):
  1. app.py +335 -696

app.py CHANGED
@@ -2,10 +2,7 @@ import gzip
  import os
  import pickle
  from glob import glob
- from functools import lru_cache
- import concurrent.futures
- import threading
- import time
 
  import gradio as gr
  import numpy as np
@@ -14,259 +11,47 @@ import torch
  from PIL import Image, ImageDraw
  from plotly.subplots import make_subplots
 
- # Constants
  IMAGE_SIZE = 400
  DATASET_LIST = ["imagenet", "oxford_flowers", "ucf101", "caltech101", "dtd", "eurosat"]
  GRID_NUM = 14
  pkl_root = "./data/out"
-
- # Global cache for preloaded data
  preloaded_data = {}
- data_dict = {}
- sae_data_dict = {}
- activation_cache = {}
- segmask_cache = {}
- top_images_cache = {}
 
- # Thread lock for thread-safe operations
- data_lock = threading.Lock()
 
- # Load data more efficiently
- def load_all_data(image_root, pkl_root):
-     """Load all necessary data with optimized caching"""
-     # Load image data
-     image_files = glob(f"{image_root}/*")
-     data_dict = {}
-
-     # Use thread pool for parallel image loading
-     def load_image_data(image_file):
-         image_name = os.path.basename(image_file).split(".")[0]
-         # Only load thumbnail for initial display, load full image on demand
-         thumbnail = Image.open(image_file).resize((IMAGE_SIZE, IMAGE_SIZE))
-         return image_name, {
-             "image": thumbnail,
-             "image_path": image_file,
-         }
-
-     # Load images in parallel
-     with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
-         results = executor.map(load_image_data, image_files)
-         for image_name, data in results:
-             data_dict[image_name] = data
-
-     # Load SAE data with minimal processing
-     sae_data_dict = {}
-
-     # Load mean acts only once
-     with open("./data/sae_data/mean_acts.pkl", "rb") as f:
-         sae_data_dict["mean_acts"] = pickle.load(f)
-
- # Update all components when radio selection changes
- radio_choices.change(
-     fn=update_all,
-     inputs=[image_selector, radio_choices, toggle_btn, model_selector],
-     outputs=[
-         seg_mask_display,
-         seg_mask_display_maple,
-         top_image_1,
-         top_image_2,
-         top_image_3,
-         act_value_1,
-         act_value_2,
-         act_value_3,
-         markdown_display,
-         markdown_display_2,
-     ],
-     _js="""
-     function(img, radio, toggle, model) {
-         // Add a small delay to prevent rapid UI updates
-         clearTimeout(window._radioTimeout);
-         return new Promise((resolve) => {
-             window._radioTimeout = setTimeout(() => {
-                 resolve([img, radio, toggle, model]);
-             }, 100);
-         });
-     }
-     """
- )
-
- # Update components when toggle button changes
- toggle_btn.change(
-     fn=show_activation_heatmap_clip,
-     inputs=[image_selector, radio_choices, toggle_btn],
-     outputs=[
-         seg_mask_display,
-         top_image_1,
-         top_image_2,
-         top_image_3,
-         act_value_1,
-         act_value_2,
-         act_value_3,
-     ],
-     _js="""
-     function(img, radio, toggle) {
-         // Add a small delay to prevent rapid UI updates
-         clearTimeout(window._toggleTimeout);
-         return new Promise((resolve) => {
-             window._toggleTimeout = setTimeout(() => {
-                 resolve([img, radio, toggle]);
-             }, 100);
-         });
-     }
-     """
- )
 
- # Initialize UI with default values
- default_options = get_init_radio_options(default_image_name, model_options[0])
- if default_options:
-     default_option = default_options[0]
-
- # Set initial values to avoid blank UI at start
- gr.on(
-     gr.Blocks.load,
-     fn=lambda: update_all(
-         default_image_name,
-         default_option,
-         False,
-         model_options[0]
-     ),
-     outputs=[
-         seg_mask_display,
-         seg_mask_display_maple,
-         top_image_1,
-         top_image_2,
-         top_image_3,
-         act_value_1,
-         act_value_2,
-         act_value_3,
-         markdown_display,
-         markdown_display_2,
-     ],
- )
 
- # Add a status indicator to show processing state
- status_indicator = gr.Markdown("Status: Ready")
-
- # Add a refresh button to manually reload data if needed
- refresh_btn = gr.Button("Refresh Data")
-
- def reload_data():
-     global data_dict, sae_data_dict
-
-     # Update status
-     yield "Status: Reloading data..."
-
-     # Reload data
-     try:
-         data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
-         yield "Status: Data reloaded successfully!"
-     except Exception as e:
-         yield f"Status: Error reloading data - {str(e)}"
-
- refresh_btn.click(
-     fn=reload_data,
-     inputs=[],
-     outputs=[status_indicator],
-     queue=False
- )
-
- # Launch app with optimized settings
- demo.queue(concurrency_count=3, max_size=10)  # Balanced concurrency for better performance
-
- # Add startup message
- print("Starting visualization application...")
- print(f"Loaded {len(data_dict)} images and {len(sae_data_dict)} datasets")
-
- # Launch with proper error handling
- demo.launch(
-     share=False,  # Don't share publicly
-     debug=False,  # Disable debug mode for production
-     show_error=True,  # Show errors for debugging
-     quiet=False,  # Show startup messages
-     favicon_path=None,  # Default favicon
-     server_port=None,  # Use default port
-     server_name=None,  # Bind to all interfaces
-     height=None,  # Use default height
-     width=None,  # Use default width
-     enable_queue=True,  # Enable queue for better performance
- )
- # Create dictionary for dataset values
- sae_data_dict["mean_act_values"] = {}
-
- # Load dataset values in parallel
- def load_dataset_values(dataset):
-     with gzip.open(f"./data/sae_data/mean_act_values_{dataset}.pkl.gz", "rb") as f:
-         return dataset, pickle.load(f)
-
- with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
-     futures = [
-         executor.submit(load_dataset_values, dataset)
-         for dataset in ["imagenet", "imagenet-sketch", "caltech101"]
-     ]
-     for future in concurrent.futures.as_completed(futures):
-         dataset, data = future.result()
-         sae_data_dict["mean_act_values"][dataset] = data
-
- return data_dict, sae_data_dict
 
- # Cache activation data with LRU cache
- @lru_cache(maxsize=32)
- def preload_activation(image_name, model_name):
-     """Preload and cache activation data for a specific image and model"""
-     image_file = f"{pkl_root}/{model_name}/{image_name}.pkl.gz"
-
-     try:
-         with gzip.open(image_file, "rb") as f:
-             return pickle.load(f)
-     except Exception as e:
-         print(f"Error loading {image_file}: {e}")
-         return None
-
- # Get activation with caching
- def get_data(image_name, model_type):
-     """Get activation data with caching for better performance"""
-     cache_key = f"{image_name}_{model_type}"
-
-     with data_lock:
-         if cache_key not in activation_cache:
-             activation_cache[cache_key] = preload_activation(image_name, model_type)
-
-     return activation_cache[cache_key]
-
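
The block removed above memoizes gzip+pickle loading with `functools.lru_cache` so repeated UI events do not re-read the same file from disk. A minimal, self-contained sketch of that pattern (the path in the usage comment is hypothetical):

```python
import gzip
import pickle
from functools import lru_cache

@lru_cache(maxsize=32)
def load_pickle_gz(path: str):
    """Read a gzipped pickle once; later calls with the same path return the cached object."""
    with gzip.open(path, "rb") as f:
        return pickle.load(f)

# Usage (hypothetical file):
# act = load_pickle_gz("./data/out/CLIP/christmas-imagenet.pkl.gz")
# act = load_pickle_gz("./data/out/CLIP/christmas-imagenet.pkl.gz")  # served from cache
```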
- def get_activation_distribution(image_name, model_type):
-     """Get activation distribution with noise filtering"""
-     activation = get_data(image_name, model_type)
-
-     if activation is None:
-         # Return empty tensor if data loading failed
-         return torch.zeros((GRID_NUM * GRID_NUM + 1, 1000))
-
-     activation = activation[0]
-
-     # Filter out noisy features
      noisy_features_indices = (
          (sae_data_dict["mean_acts"]["imagenet"] > 0.1).nonzero()[0].tolist()
      )
      activation[:, noisy_features_indices] = 0
-
      return activation
 
  def get_grid_loc(evt, image):
-     """Get grid location from click event"""
      # Get click coordinates
      x, y = evt._data["index"][0], evt._data["index"][1]
-
      cell_width = image.width // GRID_NUM
      cell_height = image.height // GRID_NUM
-
      grid_x = x // cell_width
      grid_y = y // cell_height
      return grid_x, grid_y, cell_width, cell_height
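
`get_grid_loc` maps a pixel click to one cell of the 14×14 patch grid with integer division; downstream code turns that cell into a token index, where index 0 is the CLS token. A standalone sketch of the same arithmetic (the click coordinates are made up):

```python
GRID_NUM = 14
width = height = 400                       # displayed image size, as in IMAGE_SIZE

x, y = 230, 95                             # hypothetical click position in pixels
cell_w, cell_h = width // GRID_NUM, height // GRID_NUM
grid_x, grid_y = x // cell_w, y // cell_h

# Row-major patch index, shifted by one to skip the CLS token at position 0.
token_idx = grid_y * GRID_NUM + grid_x + 1
print(grid_x, grid_y, token_idx)           # -> 8 3 51
```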

- def highlight_grid(evt, image_name):
-     """Highlight grid cell on click"""
      image = data_dict[image_name]["image"]
      grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
-
      highlighted_image = image.copy()
      draw = ImageDraw.Draw(highlighted_image)
      box = [
@@ -276,14 +61,16 @@ def highlight_grid(evt, image_name):
          (grid_y + 1) * cell_height,
      ]
      draw.rectangle(box, outline="red", width=3)
-
      return highlighted_image
 
  def load_image(img_name):
-     """Load image by name"""
-     return data_dict[img_name]["image"]
 
- # Optimized plotting with less annotations
  def plot_activations(
      all_activation,
      tile_activations=None,
@@ -293,28 +80,19 @@
      colors=("blue", "cyan"),
      model_name="CLIP",
  ):
-     """Plot activations with optimized rendering"""
      fig = go.Figure()
-
      def _add_scatter_with_annotation(fig, activations, model_name, color, label):
-         # Only plot non-zero values to reduce points
-         non_zero_indices = np.where(np.abs(activations) > 1e-5)[0]
-         if len(non_zero_indices) == 0:
-             # If all values are near zero, use full array
-             non_zero_indices = np.arange(len(activations))
-
          fig.add_trace(
              go.Scatter(
-                 x=non_zero_indices,
-                 y=activations[non_zero_indices],
                  mode="lines",
                  name=label,
                  line=dict(color=color, dash="solid"),
                  showlegend=True,
              )
          )
-
-         # Only annotate the top_k activations
          top_neurons = np.argsort(activations)[::-1][:top_k]
          for idx in top_neurons:
              fig.add_annotation(
@@ -329,46 +107,45 @@
                  opacity=0.7,
              )
          return fig
-
-     label = f"{model_name.split('-')[-1]} Image-level"
      fig = _add_scatter_with_annotation(
          fig, all_activation, model_name, colors[0], label
      )
-
      if tile_activations is not None:
-         label = f"{model_name.split('-')[-1]} Tile ({grid_x}, {grid_y})"
          fig = _add_scatter_with_annotation(
              fig, tile_activations, model_name, colors[1], label
          )
-
-     # Optimize layout with minimal settings
      fig.update_layout(
          title="Activation Distribution",
          xaxis_title="SAE latent index",
          yaxis_title="Activation Value",
          template="plotly_white",
-         legend=dict(orientation="h", yanchor="middle", y=0.5, xanchor="center", x=0.5),
      )
-
      return fig
 
- def get_activations(evt, selected_image, model_name, colors):
-     """Get activations for plotting"""
      activation = get_activation_distribution(selected_image, model_name)
      all_activation = activation.mean(0)
-
      tile_activations = None
      grid_x = None
      grid_y = None
-
-     if evt is not None and evt._data is not None:
-         image = data_dict[selected_image]["image"]
-         grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
-         token_idx = grid_y * GRID_NUM + grid_x + 1
-         # Ensure token_idx is within bounds
-         if token_idx < activation.shape[0]:
              tile_activations = activation[token_idx]
-
      fig = plot_activations(
          all_activation,
          tile_activations,
@@ -378,291 +155,124 @@ def get_activations(evt, selected_image, model_name, colors):
          model_name=model_name,
          colors=colors,
      )
-
      return fig
 
- # Cache plot results
- @lru_cache(maxsize=16)
- def plot_activation_distribution(evt_data, selected_image, model_name):
-     """Plot activation distribution with caching"""
-     # Convert event data to hashable format for caching
-     if evt_data is not None:
-         evt = type('obj', (object,), {'_data': evt_data})
-     else:
-         evt = None
-
      fig = make_subplots(
          rows=2,
          cols=1,
          shared_xaxes=True,
          subplot_titles=["CLIP Activation", f"{model_name} Activation"],
      )
-
      fig_clip = get_activations(
          evt, selected_image, "CLIP", colors=("#00b4d8", "#90e0ef")
      )
      fig_maple = get_activations(
          evt, selected_image, model_name, colors=("#ff5a5f", "#ffcad4")
      )
-
      def _attach_fig(fig, sub_fig, row, col, yref):
          for trace in sub_fig.data:
              fig.add_trace(trace, row=row, col=col)
-
          for annotation in sub_fig.layout.annotations:
              annotation.update(yref=yref)
              fig.add_annotation(annotation)
          return fig
-
      fig = _attach_fig(fig, fig_clip, row=1, col=1, yref="y1")
      fig = _attach_fig(fig, fig_maple, row=2, col=1, yref="y2")
-
-     # Optimize layout with minimal settings
      fig.update_xaxes(title_text="SAE Latent Index", row=2, col=1)
      fig.update_xaxes(title_text="SAE Latent Index", row=1, col=1)
      fig.update_yaxes(title_text="Activation Value", row=1, col=1)
      fig.update_yaxes(title_text="Activation Value", row=2, col=1)
      fig.update_layout(
          template="plotly_white",
          showlegend=True,
          legend=dict(orientation="h", yanchor="bottom", y=-0.2, xanchor="center", x=0.5),
          margin=dict(l=20, r=20, t=40, b=20),
      )
-
      return fig
 
- # Cache segmentation masks
- @lru_cache(maxsize=32)
  def get_segmask(selected_image, slider_value, model_type):
-     """Generate segmentation mask with caching"""
      try:
-         # Check if image exists
-         if selected_image not in data_dict:
-             print(f"Image {selected_image} not found in data dictionary")
-             # Return blank mask with IMAGE_SIZE dimensions
-             return np.zeros((IMAGE_SIZE, IMAGE_SIZE, 4), dtype=np.uint8)
-
-         # Use cache if available
-         cache_key = f"{selected_image}_{slider_value}_{model_type}"
-         with data_lock:
-             if cache_key in segmask_cache:
-                 return segmask_cache[cache_key]
-
-         # Get image
-         image = data_dict[selected_image]["image"]
-
-         # Get activation data
-         sae_act = get_data(selected_image, model_type)
-
-         if sae_act is None:
-             # Return blank mask if data loading failed
-             return np.zeros((image.height, image.width, 4), dtype=np.uint8)
-
-         # Handle array shape issues
-         try:
-             # Check array shape and dimensions
-             if isinstance(sae_act, tuple) and len(sae_act) > 0:
-                 # First element of tuple
-                 act_data = sae_act[0]
-             else:
-                 # Direct array
-                 act_data = sae_act
-
-             # Check if slider_value is within bounds
-             if slider_value >= act_data.shape[1]:
-                 print(f"Slider value {slider_value} out of bounds for activation shape {act_data.shape}")
-                 return np.zeros((image.height, image.width, 4), dtype=np.uint8)
-
-             # Get activation for specific latent
-             temp = act_data[:, slider_value]
-
-             # Skip first token (CLS token) and reshape to grid
-             if len(temp) > 1:  # Ensure we have enough tokens
-                 mask = torch.Tensor(temp[1:].reshape(GRID_NUM, GRID_NUM)).view(1, 1, GRID_NUM, GRID_NUM)
-
-                 # Upsample to image dimensions
-                 mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][0].numpy()
-
-                 # Normalize mask values between 0 and 1
-                 mask_min, mask_max = mask.min(), mask.max()
-                 if mask_max > mask_min:  # Avoid division by zero
-                     mask = (mask - mask_min) / (mask_max - mask_min)
-                 else:
-                     mask = np.zeros_like(mask)
-             else:
-                 # Not enough tokens
-                 print(f"Not enough tokens in activation data: {len(temp)}")
-                 return np.zeros((image.height, image.width, 4), dtype=np.uint8)
-
-         except Exception as e:
-             print(f"Error processing activation data: {e}")
-             print(f"Shape info - sae_act: {type(sae_act)}, slider_value: {slider_value}")
-             return np.zeros((image.height, image.width, 4), dtype=np.uint8)
-
-         # Create RGBA overlay
-         try:
-             # Set base opacity for darkened areas
-             base_opacity = 30
-
-             # Convert image to numpy array
-             image_array = np.array(image)
-
-             # Handle grayscale images
-             if len(image_array.shape) == 2:
-                 # Convert grayscale to RGB
-                 image_array = np.stack([image_array] * 3, axis=-1)
-             elif image_array.shape[2] == 4:
-                 # Use only RGB channels
-                 image_array = image_array[..., :3]
-
-             # Create overlay
-             rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
-             rgba_overlay[..., :3] = image_array
-
-             # Use vectorized operations for better performance
-             darkened_image = (image_array * (base_opacity / 255)).astype(np.uint8)
-
-             # Create mask for darkened areas
-             mask_threshold = 0.1  # Adjust threshold if needed
-             mask_zero = mask < mask_threshold
-
-             # Apply darkening only to low-activation areas
-             rgba_overlay[mask_zero, :3] = darkened_image[mask_zero]
-
-             # Set alpha channel
-             rgba_overlay[..., 3] = 255  # Fully opaque
-
-             # Cache result for future use
-             with data_lock:
-                 segmask_cache[cache_key] = rgba_overlay
-
-             return rgba_overlay
-
-         except Exception as e:
-             print(f"Error creating overlay: {e}")
-             return np.zeros((image.height, image.width, 4), dtype=np.uint8)
-
      except Exception as e:
-         print(f"Unexpected error in get_segmask: {e}")
-         # Return a blank image of standard size
-         return np.zeros((IMAGE_SIZE, IMAGE_SIZE, 4), dtype=np.uint8)
 
- # Cache top images
- @lru_cache(maxsize=32)
  def get_top_images(slider_value, toggle_btn):
-     """Get top images with caching"""
-     cache_key = f"{slider_value}_{toggle_btn}"
-
-     if cache_key in top_images_cache:
-         return top_images_cache[cache_key]
-
      def _get_images(dataset_path):
          top_image_paths = [
              os.path.join(dataset_path, "imagenet", f"{slider_value}.jpg"),
              os.path.join(dataset_path, "imagenet-sketch", f"{slider_value}.jpg"),
              os.path.join(dataset_path, "caltech101", f"{slider_value}.jpg"),
          ]
-
-         top_images = []
-         for path in top_image_paths:
-             if os.path.exists(path):
-                 top_images.append(Image.open(path))
-             else:
-                 top_images.append(Image.new("RGB", (256, 256), (255, 255, 255)))
-
          return top_images
-
      if toggle_btn:
          top_images = _get_images("./data/top_images_masked")
      else:
          top_images = _get_images("./data/top_images")
-
-     # Cache result
-     top_images_cache[cache_key] = top_images
-
      return top_images
 
  def show_activation_heatmap(selected_image, slider_value, model_type, toggle_btn=False):
-     """Show activation heatmap with optimized processing"""
-     try:
-         # Parse slider value safely
-         if not slider_value:
-             # Fallback to the first option if no slider value
-             radio_options = get_init_radio_options(selected_image, model_type)
-             if not radio_options:
-                 # Create placeholder data if no options available
-                 return (
-                     np.zeros((IMAGE_SIZE, IMAGE_SIZE, 4), dtype=np.uint8),
-                     [Image.new("RGB", (256, 256), (255, 255, 255)) for _ in range(3)],
-                     ["#### Activation values: No data available"] * 3
-                 )
-             slider_value = radio_options[0]
-
-         # Extract the integer value
-         try:
-             slider_value_int = int(slider_value.split("-")[-1])
-         except (ValueError, IndexError):
-             print(f"Error parsing slider value: {slider_value}")
-             slider_value_int = 0
-
-         # Process in parallel with thread pool and add timeout
-         results = []
-         with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
-             # Start both tasks
-             segmask_future = executor.submit(get_segmask, selected_image, slider_value_int, model_type)
-             top_images_future = executor.submit(get_top_images, slider_value_int, toggle_btn)
-
-             # Get results with timeout to prevent hanging
-             try:
-                 rgba_overlay = segmask_future.result(timeout=5)
-             except (concurrent.futures.TimeoutError, Exception) as e:
-                 print(f"Error or timeout generating segmentation mask: {e}")
-                 rgba_overlay = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 4), dtype=np.uint8)
-
-             try:
-                 top_images = top_images_future.result(timeout=5)
-             except (concurrent.futures.TimeoutError, Exception) as e:
-                 print(f"Error or timeout getting top images: {e}")
-                 top_images = [Image.new("RGB", (256, 256), (255, 255, 255)) for _ in range(3)]
-
-         # Prepare activation values with error handling
-         act_values = []
-         for dataset in ["imagenet", "imagenet-sketch", "caltech101"]:
-             try:
-                 if dataset in sae_data_dict["mean_act_values"]:
-                     values = sae_data_dict["mean_act_values"][dataset]
-                     if slider_value_int < values.shape[0]:
-                         act_value = values[slider_value_int, :5]
-                         act_value = [str(round(value, 3)) for value in act_value]
-                         act_value = " | ".join(act_value)
-                         out = f"#### Activation values: {act_value}"
-                     else:
-                         out = f"#### Activation values: Index out of range"
-                 else:
-                     out = f"#### Activation values: Dataset not available"
-             except Exception as e:
-                 print(f"Error getting activation values for {dataset}: {e}")
-                 out = f"#### Activation values: Error retrieving data"
-
-             act_values.append(out)
-
-         return rgba_overlay, top_images, act_values
-
-     except Exception as e:
-         print(f"Error in show_activation_heatmap: {e}")
-         # Return placeholder data in case of error
-         return (
-             np.zeros((IMAGE_SIZE, IMAGE_SIZE, 4), dtype=np.uint8),
-             [Image.new("RGB", (256, 256), (255, 255, 255)) for _ in range(3)],
-             ["#### Activation values: Error occurred"] * 3
-         )
 
  def show_activation_heatmap_clip(selected_image, slider_value, toggle_btn):
-     """Show CLIP activation heatmap"""
      rgba_overlay, top_images, act_values = show_activation_heatmap(
          selected_image, slider_value, "CLIP", toggle_btn
      )
-
      return (
          rgba_overlay,
          top_images[0],
@@ -673,19 +283,18 @@ def show_activation_heatmap_clip(selected_image, slider_value, toggle_btn):
          act_values[2],
      )
 
  def show_activation_heatmap_maple(selected_image, slider_value, model_name):
-     """Show MaPLE activation heatmap"""
-     slider_value_int = int(slider_value.split("-")[-1])
-     rgba_overlay = get_segmask(selected_image, slider_value_int, model_name)
-
      return rgba_overlay
 
- # Optimize radio options generation
  def get_init_radio_options(selected_image, model_name):
-     """Get initial radio options with optimized processing"""
      clip_neuron_dict = {}
      maple_neuron_dict = {}
-
      def _get_top_actvation(selected_image, model_name, neuron_dict, top_k=5):
          activations = get_activation_distribution(selected_image, model_name).mean(0)
          top_neurons = list(np.argsort(activations)[::-1][:top_k])
@@ -695,138 +304,127 @@ def get_init_radio_options(selected_image, model_name):
              sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True)
          )
          return sorted_dict
-
-     # Process in parallel
-     with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
-         future_clip = executor.submit(_get_top_actvation, selected_image, "CLIP", {})
-         future_maple = executor.submit(_get_top_actvation, selected_image, model_name, {})
-
-         clip_neuron_dict = future_clip.result()
-         maple_neuron_dict = future_maple.result()
-
      radio_choices = get_radio_names(clip_neuron_dict, maple_neuron_dict)
-
      return radio_choices
 
  def get_radio_names(clip_neuron_dict, maple_neuron_dict):
-     """Get radio button names based on neuron activations"""
      clip_keys = list(clip_neuron_dict.keys())
      maple_keys = list(maple_neuron_dict.keys())
-
-     # Use set operations for better performance
      common_keys = list(set(clip_keys).intersection(set(maple_keys)))
-     clip_only_keys = list(set(clip_keys) - set(maple_keys))
-     maple_only_keys = list(set(maple_keys) - set(clip_keys))
-
-     # Sort keys by activation values
      common_keys.sort(
-         key=lambda x: max(clip_neuron_dict.get(x, 0), maple_neuron_dict.get(x, 0)),
-         reverse=True
      )
-     clip_only_keys.sort(key=lambda x: clip_neuron_dict.get(x, 0), reverse=True)
-     maple_only_keys.sort(key=lambda x: maple_neuron_dict.get(x, 0), reverse=True)
-
-     # Limit number of choices to improve performance
      out = []
      out.extend([f"common-{i}" for i in common_keys[:5]])
      out.extend([f"CLIP-{i}" for i in clip_only_keys[:5]])
      out.extend([f"MaPLE-{i}" for i in maple_only_keys[:5]])
-
      return out
 
- def update_radio_options(evt, selected_image, model_name):
-     """Update radio options based on user interaction"""
-     def _get_top_actvation(evt, selected_image, model_name):
-         neuron_dict = {}
          all_activation = get_activation_distribution(selected_image, model_name)
          image_activation = all_activation.mean(0)
-
-         # Get top activations from image-level
-         top_neurons = list(np.argsort(image_activation)[::-1][:5])
-         for top_neuron in top_neurons:
-             neuron_dict[top_neuron] = image_activation[top_neuron]
-
-         # Get top activations from tile-level if available
-         if evt is not None and evt._data is not None and isinstance(evt._data["index"], list):
-             image = data_dict[selected_image]["image"]
-             grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
-             token_idx = grid_y * GRID_NUM + grid_x + 1
-
-             # Ensure token_idx is within bounds
-             if token_idx < all_activation.shape[0]:
                  tile_activations = all_activation[token_idx]
-                 top_tile_neurons = list(np.argsort(tile_activations)[::-1][:5])
-                 for top_neuron in top_tile_neurons:
-                     neuron_dict[top_neuron] = max(
-                         neuron_dict.get(top_neuron, 0),
-                         tile_activations[top_neuron]
-                     )
-
-         # Sort by activation value
-         return dict(sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True))
-
-     # Process in parallel
-     with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
-         future_clip = executor.submit(_get_top_actvation, evt, selected_image, "CLIP")
-         future_maple = executor.submit(_get_top_actvation, evt, selected_image, model_name)
-
-         clip_neuron_dict = future_clip.result()
-         maple_neuron_dict = future_maple.result()
-
-     # Get radio choices
-     radio_choices = get_radio_names(clip_neuron_dict, maple_neuron_dict)
-
-     # Create radio component
-     radio = gr.Radio(
-         choices=radio_choices,
-         label="Top activating SAE latent",
-         value=radio_choices[0] if radio_choices else None
      )
-
-     return radio
 
  def update_markdown(option_value):
-     """Update markdown text"""
      latent_idx = int(option_value.split("-")[-1])
      out_1 = f"## Segmentation mask for the selected SAE latent - {latent_idx}"
      out_2 = f"## Top reference images for the selected SAE latent - {latent_idx}"
      return out_1, out_2
 
  def update_all(selected_image, slider_value, toggle_btn, model_name):
-     """Update all UI components in optimized way"""
-     # Use a thread pool to parallelize operations
-     with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
-         # Start both tasks
-         clip_future = executor.submit(
-             show_activation_heatmap_clip,
-             selected_image,
-             slider_value,
-             toggle_btn
-         )
-
-         maple_future = executor.submit(
-             show_activation_heatmap_maple,
-             selected_image,
-             slider_value,
-             model_name
-         )
-
-         # Get results
-         (
-             seg_mask_display,
-             top_image_1,
-             top_image_2,
-             top_image_3,
-             act_value_1,
-             act_value_2,
-             act_value_3,
-         ) = clip_future.result()
-
-         seg_mask_display_maple = maple_future.result()
-
-     # Update markdown
      markdown_display, markdown_display_2 = update_markdown(slider_value)
-
      return (
          seg_mask_display,
          seg_mask_display_maple,
@@ -840,17 +438,42 @@ def update_all(selected_image, slider_value, toggle_btn, model_name):
          markdown_display_2,
      )
 
- # Initialize data - load at startup
  data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
  default_image_name = "christmas-imagenet"
 
- # Define UI with lazy loading
  with gr.Blocks(
      theme=gr.themes.Citrus(),
      css="""
      .image-row .gr-image { margin: 0 !important; padding: 0 !important; }
      .image-row img { width: auto; height: 50px; } /* Set a uniform height for all images */
-     """,
  ) as demo:
      with gr.Row():
          with gr.Column():
@@ -862,36 +485,21 @@ with gr.Blocks(
              label="Select Image",
          )
              image_display = gr.Image(
-                 value=load_image(default_image_name),
                  type="pil",
                  interactive=True,
              )
-
-             # Update image display when a new image is selected (with debounce)
              image_selector.change(
-                 fn=load_image,
                  inputs=image_selector,
                  outputs=image_display,
-                 _js="""
-                 function(img_name) {
-                     // Simple debounce
-                     clearTimeout(window._imageSelectTimeout);
-                     return new Promise((resolve) => {
-                         window._imageSelectTimeout = setTimeout(() => {
-                             resolve(img_name);
-                         }, 100);
-                     });
-                 }
-                 """
              )
-
-             # Handle grid highlighting
              image_display.select(
-                 fn=highlight_grid,
-                 inputs=[image_selector],
-                 outputs=[image_display]
              )
-
          with gr.Column():
              gr.Markdown("## SAE latent activations of CLIP and MaPLE")
              model_options = [f"MaPLE-{dataset_name}" for dataset_name in DATASET_LIST]
@@ -900,108 +508,139 @@
              value=model_options[0],
              label="Select adapted model (MaPLe)",
          )
-
-         # Initialize with a placeholder plot to avoid delays
          neuron_plot = gr.Plot(
-             label="Neuron Activation",
-             show_label=False
          )
-
-         # Add event handlers with proper data flow
-         def update_plot(evt, selected_image, model_name):
-             if hasattr(evt, '_data') and evt._data is not None:
-                 return plot_activation_distribution(
-                     tuple(map(tuple, evt._data.get('index', []))),
-                     selected_image,
-                     model_name
-                 )
-             return plot_activation_distribution(None, selected_image, model_name)
-
-         # Load initial plot after UI is rendered
-         gr.on(
-             [image_selector.change, model_selector.change],
-             fn=lambda img, model: plot_activation_distribution(None, img, model),
              inputs=[image_selector, model_selector],
              outputs=neuron_plot,
          )
-
-         # Update plot on image click
          image_display.select(
-             fn=update_plot,
              inputs=[image_selector, model_selector],
             outputs=neuron_plot,
          )
 
      with gr.Row():
          with gr.Column():
-             # Initialize radio options
-             radio_names = gr.State(value=get_init_radio_options(default_image_name, model_options[0]))
-
-             # Initialize markdown displays
-             markdown_display = gr.Markdown(f"## Segmentation mask for the selected SAE latent")
-
-             # Initialize segmentation displays
              gr.Markdown("### Localize SAE latent activation using CLIP")
-             seg_mask_display = gr.Image(type="pil", show_label=False)
-
             gr.Markdown("### Localize SAE latent activation using MaPLE")
-             seg_mask_display_maple = gr.Image(type="pil", show_label=False)
-
          with gr.Column():
              gr.Markdown("## Top activating SAE latent index")
-
-             # Initialize radio component
              radio_choices = gr.Radio(
                  label="Top activating SAE latent",
                  interactive=True,
              )
-
-             # Initialize as soon as UI loads
-             gr.on(
-                 gr.Blocks.load,
-                 fn=lambda: gr.Radio.update(
-                     choices=get_init_radio_options(default_image_name, model_options[0]),
-                     value=get_init_radio_options(default_image_name, model_options[0])[0]
-                 ),
-                 outputs=radio_choices
-             )
-
             toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)
-
-             markdown_display_2 = gr.Markdown(f"## Top reference images for the selected SAE latent")
-
-             # Initialize image displays
             gr.Markdown("### ImageNet")
-             top_image_1 = gr.Image(type="pil", label="ImageNet", show_label=False)
-             act_value_1 = gr.Markdown()
-
             gr.Markdown("### ImageNet-Sketch")
-             top_image_2 = gr.Image(type="pil", label="ImageNet-Sketch", show_label=False)
-             act_value_2 = gr.Markdown()
-
             gr.Markdown("### Caltech101")
-             top_image_3 = gr.Image(type="pil", label="Caltech101", show_label=False)
-             act_value_3 = gr.Markdown()
-
-     # Update radio options on image interaction
      image_display.select(
          fn=update_radio_options,
          inputs=[image_selector, model_selector],
-         outputs=radio_choices,
      )
-
-     # Update radio options on model change
      model_selector.change(
          fn=update_radio_options,
          inputs=[image_selector, model_selector],
-         outputs=radio_choices,
      )
-
-     # Update radio options on image selection
-     image_selector.change(
          fn=update_radio_options,
          inputs=[image_selector, model_selector],
-         outputs=radio_choices,
      )
-
-     # Initialize
 
  import os
  import pickle
  from glob import glob
+ from time import sleep
 
  import gradio as gr
  import numpy as np
 
  from PIL import Image, ImageDraw
  from plotly.subplots import make_subplots
 
  IMAGE_SIZE = 400
  DATASET_LIST = ["imagenet", "oxford_flowers", "ucf101", "caltech101", "dtd", "eurosat"]
  GRID_NUM = 14
  pkl_root = "./data/out"
 
  preloaded_data = {}
 
 
+ def preload_activation(image_name):
+     for model in ["CLIP"] + [f"MaPLE-{ds}" for ds in DATASET_LIST]:
+         image_file = f"{pkl_root}/{model}/{image_name}.pkl.gz"
+         with gzip.open(image_file, "rb") as f:
+             preloaded_data[model] = pickle.load(f)
 
+ def get_activation_distribution(image_name: str, model_type: str):
+     activation = get_data(image_name, model_type)[0]
 
      noisy_features_indices = (
          (sae_data_dict["mean_acts"]["imagenet"] > 0.1).nonzero()[0].tolist()
      )
      activation[:, noisy_features_indices] = 0
+
      return activation
 
+
  def get_grid_loc(evt, image):
      # Get click coordinates
      x, y = evt._data["index"][0], evt._data["index"][1]
+
      cell_width = image.width // GRID_NUM
      cell_height = image.height // GRID_NUM
+
      grid_x = x // cell_width
      grid_y = y // cell_height
      return grid_x, grid_y, cell_width, cell_height
 
+
+ def highlight_grid(evt: gr.EventData, image_name):
      image = data_dict[image_name]["image"]
      grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
+
      highlighted_image = image.copy()
      draw = ImageDraw.Draw(highlighted_image)
      box = [
 
          (grid_y + 1) * cell_height,
      ]
      draw.rectangle(box, outline="red", width=3)
+
      return highlighted_image
 
+
  def load_image(img_name):
+     return Image.open(data_dict[img_name]["image_path"]).resize(
+         (IMAGE_SIZE, IMAGE_SIZE)
+     )
+
 
  def plot_activations(
      all_activation,
      tile_activations=None,
 
      colors=("blue", "cyan"),
      model_name="CLIP",
  ):
      fig = go.Figure()
+
      def _add_scatter_with_annotation(fig, activations, model_name, color, label):
          fig.add_trace(
              go.Scatter(
+                 x=np.arange(len(activations)),
+                 y=activations,
                  mode="lines",
                  name=label,
                  line=dict(color=color, dash="solid"),
                  showlegend=True,
              )
          )
          top_neurons = np.argsort(activations)[::-1][:top_k]
          for idx in top_neurons:
              fig.add_annotation(
 
                  opacity=0.7,
              )
          return fig
+
+     label = f"{model_name.split('-')[-0]} Image-level"
      fig = _add_scatter_with_annotation(
          fig, all_activation, model_name, colors[0], label
      )
      if tile_activations is not None:
+         label = f"{model_name.split('-')[-0]} Tile ({grid_x}, {grid_y})"
          fig = _add_scatter_with_annotation(
              fig, tile_activations, model_name, colors[1], label
          )
+
      fig.update_layout(
          title="Activation Distribution",
          xaxis_title="SAE latent index",
          yaxis_title="Activation Value",
          template="plotly_white",
      )
+     fig.update_layout(
+         legend=dict(orientation="h", yanchor="middle", y=0.5, xanchor="center", x=0.5)
+     )
+
      return fig
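
`plot_activations` draws the whole latent vector as one line trace and annotates only the `top_k` largest entries found with `np.argsort`. The same idiom on synthetic data (the array values are invented):

```python
import numpy as np
import plotly.graph_objects as go

activations = np.random.default_rng(0).random(512)   # stand-in for SAE activations

fig = go.Figure(
    go.Scatter(x=np.arange(len(activations)), y=activations, mode="lines")
)
# Label the five strongest latents, mirroring np.argsort(activations)[::-1][:top_k].
for idx in np.argsort(activations)[::-1][:5]:
    fig.add_annotation(
        x=int(idx), y=float(activations[idx]), text=str(int(idx)), showarrow=True
    )
# fig.show()
```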
 
+
+ def get_activations(evt: gr.EventData, selected_image: str, model_name: str, colors):
      activation = get_activation_distribution(selected_image, model_name)
      all_activation = activation.mean(0)
+
      tile_activations = None
      grid_x = None
      grid_y = None
+
+     if evt is not None:
+         if evt._data is not None:
+             image = data_dict[selected_image]["image"]
+             grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
+             token_idx = grid_y * GRID_NUM + grid_x + 1
              tile_activations = activation[token_idx]
+
      fig = plot_activations(
          all_activation,
          tile_activations,
 
          model_name=model_name,
          colors=colors,
      )
      return fig
 
+
+ def plot_activation_distribution(
+     evt: gr.EventData, selected_image: str, model_name: str
+ ):
      fig = make_subplots(
          rows=2,
          cols=1,
          shared_xaxes=True,
          subplot_titles=["CLIP Activation", f"{model_name} Activation"],
      )
+
      fig_clip = get_activations(
          evt, selected_image, "CLIP", colors=("#00b4d8", "#90e0ef")
      )
      fig_maple = get_activations(
          evt, selected_image, model_name, colors=("#ff5a5f", "#ffcad4")
      )
+
      def _attach_fig(fig, sub_fig, row, col, yref):
          for trace in sub_fig.data:
              fig.add_trace(trace, row=row, col=col)
+
          for annotation in sub_fig.layout.annotations:
              annotation.update(yref=yref)
              fig.add_annotation(annotation)
          return fig
+
      fig = _attach_fig(fig, fig_clip, row=1, col=1, yref="y1")
      fig = _attach_fig(fig, fig_maple, row=2, col=1, yref="y2")
+
      fig.update_xaxes(title_text="SAE Latent Index", row=2, col=1)
      fig.update_xaxes(title_text="SAE Latent Index", row=1, col=1)
      fig.update_yaxes(title_text="Activation Value", row=1, col=1)
      fig.update_yaxes(title_text="Activation Value", row=2, col=1)
      fig.update_layout(
+         # height=500,
+         # title="Activation Distributions",
          template="plotly_white",
          showlegend=True,
          legend=dict(orientation="h", yanchor="bottom", y=-0.2, xanchor="center", x=0.5),
          margin=dict(l=20, r=20, t=40, b=20),
      )
+
      return fig
 
+
  def get_segmask(selected_image, slider_value, model_type):
+     image = data_dict[selected_image]["image"]
+     sae_act = get_data(selected_image, model_type)[0]
+     temp = sae_act[:, slider_value]
      try:
+         mask = torch.Tensor(temp[1:,].reshape(14, 14)).view(1, 1, 14, 14)
      except Exception as e:
+         print(sae_act.shape, slider_value)
+     mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][
+         0
+     ].numpy()
+     mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)
+
+     base_opacity = 30
+     image_array = np.array(image)[..., :3]
+     rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
+     rgba_overlay[..., :3] = image_array[..., :3]
+
+     darkened_image = (image_array[..., :3] * (base_opacity / 255)).astype(np.uint8)
+     rgba_overlay[mask == 0, :3] = darkened_image[mask == 0]
+     rgba_overlay[..., 3] = 255  # Fully opaque
+
+     return rgba_overlay
+
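
The recipe in `get_segmask` is: reshape the per-token activations into the 14×14 patch grid (dropping the CLS token), upsample to image resolution, min-max normalize, then darken pixels where the mask is zero. A sketch of the same steps on synthetic inputs (random activations, flat gray stand-in image):

```python
import numpy as np
import torch

GRID_NUM, H, W = 14, 400, 400

tokens = np.random.rand(GRID_NUM * GRID_NUM + 1)          # fake activations, CLS first
mask = torch.tensor(tokens[1:].reshape(GRID_NUM, GRID_NUM), dtype=torch.float32)
mask = torch.nn.functional.interpolate(mask[None, None], (H, W))[0, 0].numpy()
mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)  # normalize to [0, 1]

image = np.full((H, W, 3), 128, dtype=np.uint8)           # flat gray test image
overlay = np.zeros((H, W, 4), dtype=np.uint8)
overlay[..., :3] = image
overlay[mask == 0, :3] = (image * (30 / 255)).astype(np.uint8)[mask == 0]  # darken dead areas
overlay[..., 3] = 255                                     # fully opaque
```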
 
  def get_top_images(slider_value, toggle_btn):
      def _get_images(dataset_path):
          top_image_paths = [
              os.path.join(dataset_path, "imagenet", f"{slider_value}.jpg"),
              os.path.join(dataset_path, "imagenet-sketch", f"{slider_value}.jpg"),
              os.path.join(dataset_path, "caltech101", f"{slider_value}.jpg"),
          ]
+         top_images = [
+             (
+                 Image.open(path)
+                 if os.path.exists(path)
+                 else Image.new("RGB", (256, 256), (255, 255, 255))
+             )
+             for path in top_image_paths
+         ]
          return top_images
+
      if toggle_btn:
          top_images = _get_images("./data/top_images_masked")
      else:
          top_images = _get_images("./data/top_images")
      return top_images
 
+
  def show_activation_heatmap(selected_image, slider_value, model_type, toggle_btn=False):
+     slider_value = int(slider_value.split("-")[-1])
+     rgba_overlay = get_segmask(selected_image, slider_value, model_type)
+     top_images = get_top_images(slider_value, toggle_btn)
+
+     act_values = []
+     for dataset in ["imagenet", "imagenet-sketch", "caltech101"]:
+         act_value = sae_data_dict["mean_act_values"][dataset][slider_value, :5]
+         act_value = [str(round(value, 3)) for value in act_value]
+         act_value = " | ".join(act_value)
+         out = f"#### Activation values: {act_value}"
+         act_values.append(out)
+
+     return rgba_overlay, top_images, act_values
+
 
  def show_activation_heatmap_clip(selected_image, slider_value, toggle_btn):
      rgba_overlay, top_images, act_values = show_activation_heatmap(
          selected_image, slider_value, "CLIP", toggle_btn
      )
+     sleep(0.1)
      return (
          rgba_overlay,
          top_images[0],
 
          act_values[2],
      )
 
+
  def show_activation_heatmap_maple(selected_image, slider_value, model_name):
+     slider_value = int(slider_value.split("-")[-1])
+     rgba_overlay = get_segmask(selected_image, slider_value, model_name)
+     sleep(0.1)
      return rgba_overlay
 
+
  def get_init_radio_options(selected_image, model_name):
      clip_neuron_dict = {}
      maple_neuron_dict = {}
+
      def _get_top_actvation(selected_image, model_name, neuron_dict, top_k=5):
          activations = get_activation_distribution(selected_image, model_name).mean(0)
          top_neurons = list(np.argsort(activations)[::-1][:top_k])
 
              sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True)
          )
          return sorted_dict
+
+     clip_neuron_dict = _get_top_actvation(selected_image, "CLIP", clip_neuron_dict)
+     maple_neuron_dict = _get_top_actvation(
+         selected_image, model_name, maple_neuron_dict
+     )
+
      radio_choices = get_radio_names(clip_neuron_dict, maple_neuron_dict)
+
      return radio_choices
 
+
  def get_radio_names(clip_neuron_dict, maple_neuron_dict):
      clip_keys = list(clip_neuron_dict.keys())
      maple_keys = list(maple_neuron_dict.keys())
+
      common_keys = list(set(clip_keys).intersection(set(maple_keys)))
+     clip_only_keys = list(set(clip_keys) - (set(maple_keys)))
+     maple_only_keys = list(set(maple_keys) - (set(clip_keys)))
+
      common_keys.sort(
+         key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True
      )
+     clip_only_keys.sort(reverse=True)
+     maple_only_keys.sort(reverse=True)
+
      out = []
      out.extend([f"common-{i}" for i in common_keys[:5]])
      out.extend([f"CLIP-{i}" for i in clip_only_keys[:5]])
      out.extend([f"MaPLE-{i}" for i in maple_only_keys[:5]])
+
      return out
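
`get_radio_names` builds the choice list with set intersection/difference over the two top-k dictionaries and prefixes each latent index with `common-`, `CLIP-`, or `MaPLE-`. A tiny sketch with invented activation values:

```python
clip_acts = {12: 0.9, 7: 0.8, 3: 0.4}    # latent index -> activation (made up)
maple_acts = {12: 0.7, 5: 0.6}

common = sorted(set(clip_acts) & set(maple_acts),
                key=lambda k: max(clip_acts[k], maple_acts[k]), reverse=True)
clip_only = sorted(set(clip_acts) - set(maple_acts), reverse=True)
maple_only = sorted(set(maple_acts) - set(clip_acts), reverse=True)

choices = [f"common-{i}" for i in common[:5]]
choices += [f"CLIP-{i}" for i in clip_only[:5]]
choices += [f"MaPLE-{i}" for i in maple_only[:5]]
print(choices)   # ['common-12', 'CLIP-7', 'CLIP-3', 'MaPLE-5']
```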
 
+
+ def update_radio_options(evt: gr.EventData, selected_image, model_name):
+     def _sort_and_save_top_k(activations, neuron_dict, top_k=5):
+         top_neurons = list(np.argsort(activations)[::-1][:top_k])
+         for top_neuron in top_neurons:
+             neuron_dict[top_neuron] = activations[top_neuron]
+
+     def _get_top_actvation(evt, selected_image, model_name, neuron_dict):
          all_activation = get_activation_distribution(selected_image, model_name)
          image_activation = all_activation.mean(0)
+         _sort_and_save_top_k(image_activation, neuron_dict)
+
+         if evt is not None:
+             if evt._data is not None and isinstance(evt._data["index"], list):
+                 image = data_dict[selected_image]["image"]
+                 grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
+                 token_idx = grid_y * GRID_NUM + grid_x + 1
                  tile_activations = all_activation[token_idx]
+                 _sort_and_save_top_k(tile_activations, neuron_dict)
+
+         sorted_dict = dict(
+             sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True)
+         )
+         return sorted_dict
+
+     clip_neuron_dict = {}
+     maple_neuron_dict = {}
+     clip_neuron_dict = _get_top_actvation(evt, selected_image, "CLIP", clip_neuron_dict)
+     maple_neuron_dict = _get_top_actvation(
+         evt, selected_image, model_name, maple_neuron_dict
+     )
+
+     clip_keys = list(clip_neuron_dict.keys())
+     maple_keys = list(maple_neuron_dict.keys())
+
+     common_keys = list(set(clip_keys).intersection(set(maple_keys)))
+     clip_only_keys = list(set(clip_keys) - (set(maple_keys)))
+     maple_only_keys = list(set(maple_keys) - (set(clip_keys)))
+
+     common_keys.sort(
+         key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True
+     )
+     clip_only_keys.sort(reverse=True)
+     maple_only_keys.sort(reverse=True)
+
+     out = []
+     out.extend([f"common-{i}" for i in common_keys[:5]])
+     out.extend([f"CLIP-{i}" for i in clip_only_keys[:5]])
+     out.extend([f"MaPLE-{i}" for i in maple_only_keys[:5]])
+
+     radio_choices = gr.Radio(
+         choices=out, label="Top activating SAE latent", value=out[0]
+     )
+     sleep(0.1)
+     return radio_choices
+
 
  def update_markdown(option_value):
      latent_idx = int(option_value.split("-")[-1])
      out_1 = f"## Segmentation mask for the selected SAE latent - {latent_idx}"
      out_2 = f"## Top reference images for the selected SAE latent - {latent_idx}"
      return out_1, out_2
 
+
+ def get_data(image_name, model_name):
+     pkl_root = "./data/out"
+     data_dir = f"{pkl_root}/{model_name}/{image_name}.pkl.gz"
+     with gzip.open(data_dir, "rb") as f:
+         data = pickle.load(f)
+     out = data
+
+     return out
+
+
  def update_all(selected_image, slider_value, toggle_btn, model_name):
+     (
+         seg_mask_display,
+         top_image_1,
+         top_image_2,
+         top_image_3,
+         act_value_1,
+         act_value_2,
+         act_value_3,
+     ) = show_activation_heatmap_clip(selected_image, slider_value, toggle_btn)
+     seg_mask_display_maple = show_activation_heatmap_maple(
+         selected_image, slider_value, model_name
+     )
      markdown_display, markdown_display_2 = update_markdown(slider_value)
+
      return (
          seg_mask_display,
          seg_mask_display_maple,
 
          markdown_display_2,
      )
 
+
+ def load_all_data(image_root, pkl_root):
+     image_files = glob(f"{image_root}/*")
+     data_dict = {}
+     for image_file in image_files:
+         image_name = os.path.basename(image_file).split(".")[0]
+         if image_file not in data_dict:
+             data_dict[image_name] = {
+                 "image": Image.open(image_file).resize((IMAGE_SIZE, IMAGE_SIZE)),
+                 "image_path": image_file,
+             }
+
+     sae_data_dict = {}
+     with open("./data/sae_data/mean_acts.pkl", "rb") as f:
+         data = pickle.load(f)
+         sae_data_dict["mean_acts"] = data
+
+     sae_data_dict["mean_act_values"] = {}
+     for dataset in ["imagenet", "imagenet-sketch", "caltech101"]:
+         with gzip.open(f"./data/sae_data/mean_act_values_{dataset}.pkl.gz", "rb") as f:
+             data = pickle.load(f)
+             sae_data_dict["mean_act_values"][dataset] = data
+
+     return data_dict, sae_data_dict
+
+
  data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
  default_image_name = "christmas-imagenet"
 
+
472
  theme=gr.themes.Citrus(),
473
  css="""
474
  .image-row .gr-image { margin: 0 !important; padding: 0 !important; }
475
  .image-row img { width: auto; height: 50px; } /* Set a uniform height for all images */
476
+ """,
477
  ) as demo:
478
  with gr.Row():
479
  with gr.Column():
 
485
  label="Select Image",
486
  )
487
  image_display = gr.Image(
488
+ value=data_dict[default_image_name]["image"],
489
  type="pil",
490
  interactive=True,
491
  )
492
+
493
+ # Update image display when a new image is selected
494
  image_selector.change(
495
+ fn=lambda img_name: data_dict[img_name]["image"],
496
  inputs=image_selector,
497
  outputs=image_display,
 
 
 
 
 
 
 
 
 
 
 
498
  )
 
 
499
  image_display.select(
500
+ fn=highlight_grid, inputs=[image_selector], outputs=[image_display]
 
 
501
  )
502
+
503
  with gr.Column():
504
  gr.Markdown("## SAE latent activations of CLIP and MaPLE")
505
  model_options = [f"MaPLE-{dataset_name}" for dataset_name in DATASET_LIST]
 
508
              value=model_options[0],
              label="Select adapted model (MaPLe)",
          )
+         init_plot = plot_activation_distribution(
+             None, default_image_name, model_options[0]
+         )
          neuron_plot = gr.Plot(
+             label="Neuron Activation", value=init_plot, show_label=False
          )
+
+         image_selector.change(
+             fn=plot_activation_distribution,
              inputs=[image_selector, model_selector],
              outputs=neuron_plot,
          )
          image_display.select(
+             fn=plot_activation_distribution,
+             inputs=[image_selector, model_selector],
+             outputs=neuron_plot,
+         )
+         model_selector.change(
+             fn=load_image, inputs=[image_selector], outputs=image_display
+         )
+         model_selector.change(
+             fn=plot_activation_distribution,
              inputs=[image_selector, model_selector],
              outputs=neuron_plot,
          )
 
      with gr.Row():
          with gr.Column():
+             radio_names = get_init_radio_options(default_image_name, model_options[0])
+
+             feautre_idx = radio_names[0].split("-")[-1]
+             markdown_display = gr.Markdown(
+                 f"## Segmentation mask for the selected SAE latent - {feautre_idx}"
+             )
+             init_seg, init_tops, init_values = show_activation_heatmap(
+                 default_image_name, radio_names[0], "CLIP"
+             )
+
              gr.Markdown("### Localize SAE latent activation using CLIP")
+             seg_mask_display = gr.Image(value=init_seg, type="pil", show_label=False)
+             init_seg_maple, _, _ = show_activation_heatmap(
+                 default_image_name, radio_names[0], model_options[0]
+             )
              gr.Markdown("### Localize SAE latent activation using MaPLE")
+             seg_mask_display_maple = gr.Image(
+                 value=init_seg_maple, type="pil", show_label=False
+             )
+
          with gr.Column():
              gr.Markdown("## Top activating SAE latent index")
+
              radio_choices = gr.Radio(
+                 choices=radio_names,
                  label="Top activating SAE latent",
                  interactive=True,
+                 value=radio_names[0],
              )
              toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)
+
+             markdown_display_2 = gr.Markdown(
+                 f"## Top reference images for the selected SAE latent - {feautre_idx}"
+             )
+
              gr.Markdown("### ImageNet")
+             top_image_1 = gr.Image(
+                 value=init_tops[0], type="pil", label="ImageNet", show_label=False
+             )
+             act_value_1 = gr.Markdown(init_values[0])
+
              gr.Markdown("### ImageNet-Sketch")
+             top_image_2 = gr.Image(
+                 value=init_tops[1],
+                 type="pil",
+                 label="ImageNet-Sketch",
+                 show_label=False,
+             )
+             act_value_2 = gr.Markdown(init_values[1])
+
              gr.Markdown("### Caltech101")
+             top_image_3 = gr.Image(
+                 value=init_tops[2], type="pil", label="Caltech101", show_label=False
+             )
+             act_value_3 = gr.Markdown(init_values[2])
+
      image_display.select(
          fn=update_radio_options,
          inputs=[image_selector, model_selector],
+         outputs=[radio_choices],
      )
+
      model_selector.change(
          fn=update_radio_options,
          inputs=[image_selector, model_selector],
+         outputs=[radio_choices],
      )
+
+     image_selector.select(
          fn=update_radio_options,
          inputs=[image_selector, model_selector],
+         outputs=[radio_choices],
      )
+
+     radio_choices.change(
+         fn=update_all,
+         inputs=[image_selector, radio_choices, toggle_btn, model_selector],
+         outputs=[
+             seg_mask_display,
+             seg_mask_display_maple,
+             top_image_1,
+             top_image_2,
+             top_image_3,
+             act_value_1,
+             act_value_2,
+             act_value_3,
+             markdown_display,
+             markdown_display_2,
+         ],
+     )
+
+     toggle_btn.change(
+         fn=show_activation_heatmap_clip,
+         inputs=[image_selector, radio_choices, toggle_btn],
+         outputs=[
+             seg_mask_display,
+             top_image_1,
+             top_image_2,
+             top_image_3,
+             act_value_1,
+             act_value_2,
+             act_value_3,
+         ],
+     )
+
+ # Launch the app
+ # demo.queue()
+ demo.launch()