patchsae-demo / app.py
hyesulim's picture
perf: improve delay
c574085 verified
raw
history blame
37 kB
import gzip
import os
import pickle
from glob import glob
from functools import lru_cache
import concurrent.futures
import threading
import time
import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from PIL import Image, ImageDraw
from plotly.subplots import make_subplots
# Display and model geometry
IMAGE_SIZE = 400  # edge length (pixels) of the square thumbnails
DATASET_LIST = [
    "imagenet",
    "oxford_flowers",
    "ucf101",
    "caltech101",
    "dtd",
    "eurosat",
]
GRID_NUM = 14  # number of grid cells per image side
pkl_root = "./data/out"  # root directory of the per-model activation pickles

# Module-level stores, filled at startup or lazily while serving requests
preloaded_data = {}
data_dict = {}
sae_data_dict = {}

# Memoisation dictionaries keyed by strings built from the request arguments
activation_cache = {}
segmask_cache = {}
top_images_cache = {}

# Guards access to the shared cache dictionaries
data_lock = threading.Lock()
# Load data more efficiently
def load_all_data(image_root, pkl_root):
"""Load all necessary data with optimized caching"""
# Load image data
image_files = glob(f"{image_root}/*")
data_dict = {}
# Use thread pool for parallel image loading
def load_image_data(image_file):
image_name = os.path.basename(image_file).split(".")[0]
# Only load thumbnail for initial display, load full image on demand
thumbnail = Image.open(image_file).resize((IMAGE_SIZE, IMAGE_SIZE))
return image_name, {
"image": thumbnail,
"image_path": image_file,
}
# Load images in parallel
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
results = executor.map(load_image_data, image_files)
for image_name, data in results:
data_dict[image_name] = data
# Load SAE data with minimal processing
sae_data_dict = {}
# Load mean acts only once
with open("./data/sae_data/mean_acts.pkl", "rb") as f:
sae_data_dict["mean_acts"] = pickle.load(f)
# NOTE(review): in this copy of the file this wiring section sits between the
# two halves of load_all_data; it presumably belongs at the end of the file,
# inside the `with gr.Blocks(...) as demo:` context - confirm before running.
# Update all components when radio selection changes
radio_choices.change(
    fn=update_all,
    inputs=[image_selector, radio_choices, toggle_btn, model_selector],
    # Order must match the 10-tuple returned by update_all
    outputs=[
        seg_mask_display,
        seg_mask_display_maple,
        top_image_1,
        top_image_2,
        top_image_3,
        act_value_1,
        act_value_2,
        act_value_3,
        markdown_display,
        markdown_display_2,
    ],
    # Client-side hook: cancels the pending timer, then forwards the unchanged
    # inputs after 100 ms.
    # NOTE(review): `_js` is the Gradio 3 spelling while `gr.on` below is a
    # newer API - confirm the targeted Gradio version.
    _js="""
function(img, radio, toggle, model) {
// Add a small delay to prevent rapid UI updates
clearTimeout(window._radioTimeout);
return new Promise((resolve) => {
window._radioTimeout = setTimeout(() => {
resolve([img, radio, toggle, model]);
}, 100);
});
}
"""
)
# Update components when toggle button changes
# Only the CLIP mask, the three reference images and their activation texts
# are refreshed here; the MaPLE mask is not among the outputs.
toggle_btn.change(
    fn=show_activation_heatmap_clip,
    inputs=[image_selector, radio_choices, toggle_btn],
    # Order must match the 7-tuple returned by show_activation_heatmap_clip
    outputs=[
        seg_mask_display,
        top_image_1,
        top_image_2,
        top_image_3,
        act_value_1,
        act_value_2,
        act_value_3,
    ],
    # Same 100 ms client-side delay as above, using its own timer handle
    _js="""
function(img, radio, toggle) {
// Add a small delay to prevent rapid UI updates
clearTimeout(window._toggleTimeout);
return new Promise((resolve) => {
window._toggleTimeout = setTimeout(() => {
resolve([img, radio, toggle]);
}, 100);
});
}
"""
)
# Initialize UI with default values
default_options = get_init_radio_options(default_image_name, model_options[0])
if default_options:
    default_option = default_options[0]
    # Set initial values to avoid blank UI at start
    # NOTE(review): nesting under the `if` is inferred (the lambda reads
    # default_option, which is only bound in this branch) - confirm.
    gr.on(
        gr.Blocks.load,
        fn=lambda: update_all(
            default_image_name,
            default_option,
            False,
            model_options[0]
        ),
        outputs=[
            seg_mask_display,
            seg_mask_display_maple,
            top_image_1,
            top_image_2,
            top_image_3,
            act_value_1,
            act_value_2,
            act_value_3,
            markdown_display,
            markdown_display_2,
        ],
    )
# Add a status indicator to show processing state
status_indicator = gr.Markdown("Status: Ready")
# Add a refresh button to manually reload data if needed
refresh_btn = gr.Button("Refresh Data")
def reload_data():
    """Reload images and SAE statistics from disk, streaming status messages.

    This is a generator: each yielded string replaces the text of the status
    indicator.  On success the module-level ``data_dict`` and ``sae_data_dict``
    are rebound to the freshly loaded objects and every derived cache is
    emptied, so later requests are computed from the new data instead of
    results memoised from the old data.

    Yields:
        A "Reloading" message first, then either a success message or an
        error message containing the exception text.
    """
    global data_dict, sae_data_dict
    # Update status
    yield "Status: Reloading data..."
    # Reload data
    try:
        data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
    except Exception as e:
        yield f"Status: Error reloading data - {str(e)}"
        return
    # Masks, plots and activations were derived from the previous data;
    # keeping them would serve stale results after a refresh.
    with data_lock:
        activation_cache.clear()
        segmask_cache.clear()
        top_images_cache.clear()
    for cached_fn in (
        preload_activation,
        plot_activation_distribution,
        get_segmask,
        get_top_images,
    ):
        cached_fn.cache_clear()
    yield "Status: Data reloaded successfully!"
refresh_btn.click(
    fn=reload_data,
    inputs=[],
    outputs=[status_indicator],
    queue=False
)
# Launch app with optimized settings
demo.queue(concurrency_count=3, max_size=10)  # Balanced concurrency for better performance
# Add startup message
print("Starting visualization application...")
print(f"Loaded {len(data_dict)} images and {len(sae_data_dict)} datasets")
# Launch with proper error handling
demo.launch(
    share=False,  # Don't share publicly
    debug=False,  # Disable debug mode for production
    show_error=True,  # Show errors for debugging
    quiet=False,  # Show startup messages
    favicon_path=None,  # Default favicon
    server_port=None,  # Use default port
    server_name=None,  # Use Gradio's default host; pass "0.0.0.0" to bind all interfaces
    height=None,  # Use default height
    width=None,  # Use default width
    enable_queue=True,  # Enable queue for better performance
)
# NOTE(review): the stray text "dictionary for dataset values" followed the
# closing parenthesis above. It looks like the tail of a displaced
# "# Initialize ..." comment belonging to the sae_data_dict setup that follows.
sae_data_dict["mean_act_values"] = {}
# Load dataset values in parallel
def load_dataset_values(dataset):
with gzip.open(f"./data/sae_data/mean_act_values_{dataset}.pkl.gz", "rb") as f:
return dataset, pickle.load(f)
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
futures = [
executor.submit(load_dataset_values, dataset)
for dataset in ["imagenet", "imagenet-sketch", "caltech101"]
]
for future in concurrent.futures.as_completed(futures):
dataset, data = future.result()
sae_data_dict["mean_act_values"][dataset] = data
return data_dict, sae_data_dict
# Cache activation data with LRU cache
@lru_cache(maxsize=32)
def preload_activation(image_name, model_name):
    """Read the gzip-compressed activation pickle of one image/model pair.

    The archive is looked up at ``{pkl_root}/{model_name}/{image_name}.pkl.gz``.

    Returns:
        The unpickled object, or ``None`` when opening or unpickling fails
        (the error is printed).  Because of the LRU decorator the ``None``
        result is memoised as well.
    """
    archive_path = f"{pkl_root}/{model_name}/{image_name}.pkl.gz"
    try:
        with gzip.open(archive_path, "rb") as stream:
            payload = pickle.load(stream)
    except Exception as exc:
        print(f"Error loading {archive_path}: {exc}")
        payload = None
    return payload
# Get activation with caching
def get_data(image_name, model_type):
    """Return the activation data of an image/model pair via ``activation_cache``.

    A miss is filled by calling ``preload_activation`` while ``data_lock`` is
    held; whatever that call returns (including ``None``) is stored and
    returned on later calls.
    """
    key = f"{image_name}_{model_type}"
    with data_lock:
        try:
            return activation_cache[key]
        except KeyError:
            loaded = preload_activation(image_name, model_type)
            activation_cache[key] = loaded
            return loaded
def get_activation_distribution(image_name, model_type):
    """Return the activations of an image with the noisy latents zeroed.

    The first element of the object returned by ``get_data`` is used as the
    activation array (presumably tokens x latents - confirm against the
    pickles).  Columns whose ImageNet mean activation exceeds 0.1 are set to
    zero.

    The zeroing is applied to a private copy: the object returned by
    ``get_data`` is shared through the cache and is also read by
    ``get_segmask``, so editing it in place would change what later readers
    of the cache see.

    Returns:
        The filtered activation array, or an all-zero tensor of shape
        ``(GRID_NUM * GRID_NUM + 1, 1000)`` when the data could not be loaded.
    """
    cached = get_data(image_name, model_type)
    if cached is None:
        # Return empty tensor if data loading failed
        return torch.zeros((GRID_NUM * GRID_NUM + 1, 1000))
    activation = cached[0]
    # Work on a copy so the cached array keeps its original values
    if isinstance(activation, torch.Tensor):
        activation = activation.clone()
    else:
        activation = activation.copy()
    # Filter out noisy features
    noisy_features_indices = (
        (sae_data_dict["mean_acts"]["imagenet"] > 0.1).nonzero()[0].tolist()
    )
    activation[:, noisy_features_indices] = 0
    return activation
def get_grid_loc(evt, image):
    """Map a click position to the grid cell that contains it.

    Args:
        evt: Event object whose ``_data["index"]`` holds the click ``x`` and
            ``y`` coordinates as its first two items.
        image: Object exposing ``width`` and ``height`` (the clicked image).

    Returns:
        ``(grid_x, grid_y, cell_width, cell_height)``.  Both grid indices are
        clamped to ``0 .. GRID_NUM - 1``.  The cell size comes from integer
        division, so the cells do not cover the last few pixels when the image
        size is not a multiple of ``GRID_NUM``; without clamping a click there
        would land in a cell outside the grid.
    """
    # Get click coordinates
    x, y = evt._data["index"][0], evt._data["index"][1]
    # Never let a cell collapse to zero pixels (image smaller than the grid)
    cell_width = max(1, image.width // GRID_NUM)
    cell_height = max(1, image.height // GRID_NUM)
    last_cell = GRID_NUM - 1
    grid_x = min(max(x // cell_width, 0), last_cell)
    grid_y = min(max(y // cell_height, 0), last_cell)
    return grid_x, grid_y, cell_width, cell_height
def highlight_grid(evt: gr.EventData, image_name):
    """Return a copy of the selected image with the clicked grid cell outlined.

    Args:
        evt: Event payload of the click.  The ``gr.EventData`` type hint is
            what makes Gradio pass the event object in this position instead
            of treating the parameter as a regular input component value.
        image_name: Key into ``data_dict`` naming the displayed image.

    Returns:
        A new image; the stored thumbnail itself is left unmodified.
    """
    image = data_dict[image_name]["image"]
    grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
    # Draw on a copy so the shared thumbnail stays clean for the next click
    highlighted_image = image.copy()
    draw = ImageDraw.Draw(highlighted_image)
    left = grid_x * cell_width
    top = grid_y * cell_height
    box = [left, top, left + cell_width, top + cell_height]
    draw.rectangle(box, outline="red", width=3)
    return highlighted_image
def load_image(img_name):
    """Return the image stored under ``img_name`` in ``data_dict``."""
    entry = data_dict[img_name]
    return entry["image"]
# Optimized plotting with less annotations
def plot_activations(
    all_activation,
    tile_activations=None,
    grid_x=None,
    grid_y=None,
    top_k=5,
    colors=("blue", "cyan"),
    model_name="CLIP",
):
    """Draw the image-level (and optionally tile-level) activation curves.

    Args:
        all_activation: 1-D array of image-level activations per latent.
        tile_activations: Optional 1-D array for one tile; drawn as a second
            curve when given.
        grid_x, grid_y: Tile coordinates, used only in the tile legend label.
        top_k: Number of largest activations annotated per curve.
        colors: Pair of colors (image-level curve, tile curve).
        model_name: The part after the last ``-`` is used in the legend labels.

    Returns:
        The populated ``go.Figure``.
    """
    figure = go.Figure()
    short_name = model_name.split('-')[-1]

    def _draw_series(values, color, label):
        # Plot only entries above the noise floor; fall back to the full
        # range when nothing exceeds it.
        visible = np.where(np.abs(values) > 1e-5)[0]
        if len(visible) == 0:
            visible = np.arange(len(values))
        curve = go.Scatter(
            x=visible,
            y=values[visible],
            mode="lines",
            name=label,
            line=dict(color=color, dash="solid"),
            showlegend=True,
        )
        figure.add_trace(curve)
        # Label the top_k largest activations with their latent index
        for latent in np.argsort(values)[::-1][:top_k]:
            figure.add_annotation(
                x=latent,
                y=values[latent],
                text=str(latent),
                showarrow=True,
                arrowhead=2,
                ax=0,
                ay=-15,
                arrowcolor=color,
                opacity=0.7,
            )

    series = [(all_activation, colors[0], f"{short_name} Image-level")]
    if tile_activations is not None:
        tile_label = f"{short_name} Tile ({grid_x}, {grid_y})"
        series.append((tile_activations, colors[1], tile_label))
    for values, color, label in series:
        _draw_series(values, color, label)

    legend_style = dict(
        orientation="h", yanchor="middle", y=0.5, xanchor="center", x=0.5
    )
    figure.update_layout(
        title="Activation Distribution",
        xaxis_title="SAE latent index",
        yaxis_title="Activation Value",
        template="plotly_white",
        legend=legend_style,
    )
    return figure
def get_activations(evt, selected_image, model_name, colors):
    """Build the activation figure of one model for the selected image.

    The image-level curve is the mean of the activation array over its first
    axis.  When ``evt`` carries click data, the row of the clicked tile is
    added as a second curve, provided its token index is inside the array.

    Returns:
        The figure produced by ``plot_activations``.
    """
    activation = get_activation_distribution(selected_image, model_name)
    image_level = activation.mean(0)

    tile_level, col, row = None, None, None
    has_click = evt is not None and evt._data is not None
    if has_click:
        thumbnail = data_dict[selected_image]["image"]
        col, row, _cell_w, _cell_h = get_grid_loc(evt, thumbnail)
        # +1 skips the first row of the activation array (row 0 is not a tile)
        token_idx = row * GRID_NUM + col + 1
        # Ensure token_idx is within bounds
        if token_idx < activation.shape[0]:
            tile_level = activation[token_idx]

    return plot_activations(
        image_level,
        tile_level,
        col,
        row,
        top_k=5,
        model_name=model_name,
        colors=colors,
    )
# Cache plot results
@lru_cache(maxsize=16)
def plot_activation_distribution(evt_data, selected_image, model_name):
    """Build the two-row CLIP / adapted-model activation figure (memoised).

    Args:
        evt_data: Hashable description of the click, or ``None`` for no click.
            Because of the LRU cache it must be hashable, so it is normally a
            flat ``(x, y)`` tuple.  A mapping with an ``"index"`` entry is
            accepted as well.  An empty value is treated as "no click".
        selected_image: Key into ``data_dict`` naming the image.
        model_name: Name of the adapted model shown in the second row.

    Returns:
        A plotly figure with the CLIP curves in row 1 and the adapted-model
        curves in row 2.  The same figure object is returned for repeated
        calls with equal arguments.
    """
    # Downstream helpers read the click as evt._data["index"], so a flat
    # coordinate tuple has to be wrapped back into that mapping shape;
    # storing the tuple itself would fail on the "index" lookup.
    if not evt_data:
        evt = None
    else:
        if hasattr(evt_data, "keys"):
            payload = evt_data
        else:
            payload = {"index": list(evt_data)}
        evt = type('obj', (object,), {'_data': payload})
    fig = make_subplots(
        rows=2,
        cols=1,
        shared_xaxes=True,
        subplot_titles=["CLIP Activation", f"{model_name} Activation"],
    )
    fig_clip = get_activations(
        evt, selected_image, "CLIP", colors=("#00b4d8", "#90e0ef")
    )
    fig_maple = get_activations(
        evt, selected_image, model_name, colors=("#ff5a5f", "#ffcad4")
    )

    def _attach_fig(fig, sub_fig, row, col, yref):
        # Copy traces into the requested subplot and re-anchor the
        # annotations to that subplot's y axis.
        for trace in sub_fig.data:
            fig.add_trace(trace, row=row, col=col)
        for annotation in sub_fig.layout.annotations:
            annotation.update(yref=yref)
            fig.add_annotation(annotation)
        return fig

    fig = _attach_fig(fig, fig_clip, row=1, col=1, yref="y1")
    fig = _attach_fig(fig, fig_maple, row=2, col=1, yref="y2")
    # Optimize layout with minimal settings
    fig.update_xaxes(title_text="SAE Latent Index", row=2, col=1)
    fig.update_xaxes(title_text="SAE Latent Index", row=1, col=1)
    fig.update_yaxes(title_text="Activation Value", row=1, col=1)
    fig.update_yaxes(title_text="Activation Value", row=2, col=1)
    fig.update_layout(
        template="plotly_white",
        showlegend=True,
        legend=dict(orientation="h", yanchor="bottom", y=-0.2, xanchor="center", x=0.5),
        margin=dict(l=20, r=20, t=40, b=20),
    )
    return fig
# Cache segmentation masks
@lru_cache(maxsize=32)
def get_segmask(selected_image, slider_value, model_type):
    """Build an RGBA overlay that darkens regions where one SAE latent is weak.

    Args:
        selected_image: Key into ``data_dict`` naming the image.
        slider_value: Integer column index of the latent in the activation
            array.
        model_type: Model name forwarded to ``get_data``.

    Returns:
        A ``uint8`` array with four channels.  On success it is the image
        with pixels whose normalised activation is below 0.1 darkened and
        the alpha channel fully opaque.  On any failure an all-zero array is
        returned instead (image-sized when the image is known, otherwise
        ``IMAGE_SIZE`` x ``IMAGE_SIZE``).

    Results are memoised twice: by the LRU decorator (which also keeps the
    all-zero fallbacks) and, for successful overlays, in ``segmask_cache``.
    """
    try:
        # Check if image exists
        if selected_image not in data_dict:
            print(f"Image {selected_image} not found in data dictionary")
            # Return blank mask with IMAGE_SIZE dimensions
            return np.zeros((IMAGE_SIZE, IMAGE_SIZE, 4), dtype=np.uint8)
        # Use cache if available
        cache_key = f"{selected_image}_{slider_value}_{model_type}"
        with data_lock:
            if cache_key in segmask_cache:
                return segmask_cache[cache_key]
        # Get image
        image = data_dict[selected_image]["image"]
        # Get activation data
        sae_act = get_data(selected_image, model_type)
        if sae_act is None:
            # Return blank mask if data loading failed
            return np.zeros((image.height, image.width, 4), dtype=np.uint8)
        # Handle array shape issues
        try:
            # Check array shape and dimensions
            if isinstance(sae_act, tuple) and len(sae_act) > 0:
                # First element of tuple
                act_data = sae_act[0]
            else:
                # Direct array
                act_data = sae_act
            # Check if slider_value is within bounds
            if slider_value >= act_data.shape[1]:
                print(f"Slider value {slider_value} out of bounds for activation shape {act_data.shape}")
                return np.zeros((image.height, image.width, 4), dtype=np.uint8)
            # Get activation for specific latent
            temp = act_data[:, slider_value]
            # Skip first token (CLS token) and reshape to grid
            if len(temp) > 1:  # Ensure we have enough tokens
                # The reshape requires exactly GRID_NUM * GRID_NUM remaining
                # entries; any other count raises and is handled below.
                mask = torch.Tensor(temp[1:].reshape(GRID_NUM, GRID_NUM)).view(1, 1, GRID_NUM, GRID_NUM)
                # Upsample to image dimensions
                mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][0].numpy()
                # Normalize mask values between 0 and 1
                mask_min, mask_max = mask.min(), mask.max()
                if mask_max > mask_min:  # Avoid division by zero
                    mask = (mask - mask_min) / (mask_max - mask_min)
                else:
                    # Constant activation: everything counts as "low"
                    mask = np.zeros_like(mask)
            else:
                # Not enough tokens
                print(f"Not enough tokens in activation data: {len(temp)}")
                return np.zeros((image.height, image.width, 4), dtype=np.uint8)
        except Exception as e:
            print(f"Error processing activation data: {e}")
            print(f"Shape info - sae_act: {type(sae_act)}, slider_value: {slider_value}")
            return np.zeros((image.height, image.width, 4), dtype=np.uint8)
        # Create RGBA overlay
        try:
            # Set base opacity for darkened areas
            base_opacity = 30
            # Convert image to numpy array
            image_array = np.array(image)
            # Handle grayscale images
            if len(image_array.shape) == 2:
                # Convert grayscale to RGB
                image_array = np.stack([image_array] * 3, axis=-1)
            elif image_array.shape[2] == 4:
                # Use only RGB channels
                image_array = image_array[..., :3]
            # Create overlay
            rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
            rgba_overlay[..., :3] = image_array
            # Use vectorized operations for better performance
            # Scales every pixel by 30/255 to obtain the dimmed version
            darkened_image = (image_array * (base_opacity / 255)).astype(np.uint8)
            # Create mask for darkened areas
            mask_threshold = 0.1  # Adjust threshold if needed
            mask_zero = mask < mask_threshold
            # Apply darkening only to low-activation areas
            rgba_overlay[mask_zero, :3] = darkened_image[mask_zero]
            # Set alpha channel
            rgba_overlay[..., 3] = 255  # Fully opaque
            # Cache result for future use
            with data_lock:
                segmask_cache[cache_key] = rgba_overlay
            return rgba_overlay
        except Exception as e:
            print(f"Error creating overlay: {e}")
            return np.zeros((image.height, image.width, 4), dtype=np.uint8)
    except Exception as e:
        print(f"Unexpected error in get_segmask: {e}")
        # Return a blank image of standard size
        return np.zeros((IMAGE_SIZE, IMAGE_SIZE, 4), dtype=np.uint8)
# Cache top images
@lru_cache(maxsize=32)
def get_top_images(slider_value, toggle_btn):
    """Return the three reference images of a latent, one per dataset.

    Images are read from ``./data/top_images_masked`` when ``toggle_btn`` is
    truthy and from ``./data/top_images`` otherwise, in the order imagenet,
    imagenet-sketch, caltech101.  A missing file is replaced by a white
    256x256 placeholder.  The resulting list is also stored in
    ``top_images_cache``.
    """
    key = f"{slider_value}_{toggle_btn}"
    if key in top_images_cache:
        return top_images_cache[key]

    root = "./data/top_images_masked" if toggle_btn else "./data/top_images"
    file_name = f"{slider_value}.jpg"

    images = []
    for dataset in ("imagenet", "imagenet-sketch", "caltech101"):
        candidate = os.path.join(root, dataset, file_name)
        if os.path.exists(candidate):
            images.append(Image.open(candidate))
        else:
            images.append(Image.new("RGB", (256, 256), (255, 255, 255)))

    # Cache result
    top_images_cache[key] = images
    return images
def show_activation_heatmap(selected_image, slider_value, model_type, toggle_btn=False):
    """Compute the mask, reference images and activation texts for one latent.

    Args:
        selected_image: Key into ``data_dict`` naming the image.
        slider_value: Radio label such as ``"CLIP-123"``; the integer after
            the last ``-`` is the latent index.  When empty, the first option
            from ``get_init_radio_options`` is used instead.
        model_type: Model name forwarded to ``get_segmask``.
        toggle_btn: Forwarded to ``get_top_images`` to select masked images.

    Returns:
        ``(rgba_overlay, top_images, act_values)`` where ``top_images`` is a
        list of three images and ``act_values`` a list of three markdown
        strings.  Placeholders (all-zero mask, white images, explanatory
        text) are returned whenever a step fails or exceeds its 5 second
        budget.
    """
    def _blank_mask():
        return np.zeros((IMAGE_SIZE, IMAGE_SIZE, 4), dtype=np.uint8)

    def _blank_images():
        return [Image.new("RGB", (256, 256), (255, 255, 255)) for _ in range(3)]

    try:
        # Parse slider value safely
        if not slider_value:
            # Fallback to the first option if no slider value
            radio_options = get_init_radio_options(selected_image, model_type)
            if not radio_options:
                # Create placeholder data if no options available
                return (
                    _blank_mask(),
                    _blank_images(),
                    ["#### Activation values: No data available"] * 3
                )
            slider_value = radio_options[0]
        # Extract the integer value
        try:
            slider_value_int = int(slider_value.split("-")[-1])
        except (ValueError, IndexError):
            print(f"Error parsing slider value: {slider_value}")
            slider_value_int = 0
        # The executor is shut down with wait=False on purpose: leaving a
        # `with` block would wait for both workers to finish and make the
        # timeouts below ineffective.
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)
        try:
            # Start both tasks
            segmask_future = executor.submit(get_segmask, selected_image, slider_value_int, model_type)
            top_images_future = executor.submit(get_top_images, slider_value_int, toggle_btn)
            # Get results with timeout to prevent hanging
            # (concurrent.futures.TimeoutError is an Exception subclass)
            try:
                rgba_overlay = segmask_future.result(timeout=5)
            except Exception as e:
                print(f"Error or timeout generating segmentation mask: {e}")
                rgba_overlay = _blank_mask()
            try:
                top_images = top_images_future.result(timeout=5)
            except Exception as e:
                print(f"Error or timeout getting top images: {e}")
                top_images = _blank_images()
        finally:
            executor.shutdown(wait=False)
        # Prepare activation values with error handling
        act_values = []
        for dataset in ["imagenet", "imagenet-sketch", "caltech101"]:
            try:
                if dataset in sae_data_dict["mean_act_values"]:
                    values = sae_data_dict["mean_act_values"][dataset]
                    if slider_value_int < values.shape[0]:
                        # First five stored values of this latent
                        act_value = values[slider_value_int, :5]
                        act_value = [str(round(value, 3)) for value in act_value]
                        act_value = " | ".join(act_value)
                        out = f"#### Activation values: {act_value}"
                    else:
                        out = f"#### Activation values: Index out of range"
                else:
                    out = f"#### Activation values: Dataset not available"
            except Exception as e:
                print(f"Error getting activation values for {dataset}: {e}")
                out = f"#### Activation values: Error retrieving data"
            act_values.append(out)
        return rgba_overlay, top_images, act_values
    except Exception as e:
        print(f"Error in show_activation_heatmap: {e}")
        # Return placeholder data in case of error
        return (
            _blank_mask(),
            _blank_images(),
            ["#### Activation values: Error occurred"] * 3
        )
def show_activation_heatmap_clip(selected_image, slider_value, toggle_btn):
    """Run ``show_activation_heatmap`` for CLIP and flatten its result.

    Returns:
        A 7-tuple: the overlay, the three reference images, then the three
        activation-value strings.
    """
    overlay, images, texts = show_activation_heatmap(
        selected_image, slider_value, "CLIP", toggle_btn
    )
    first_img, second_img, third_img = images[0], images[1], images[2]
    first_txt, second_txt, third_txt = texts[0], texts[1], texts[2]
    return (
        overlay,
        first_img,
        second_img,
        third_img,
        first_txt,
        second_txt,
        third_txt,
    )
def show_activation_heatmap_maple(selected_image, slider_value, model_name):
    """Return the segmentation overlay of the adapted (MaPLE) model.

    Args:
        selected_image: Key into ``data_dict`` naming the image.
        slider_value: Radio label such as ``"MaPLE-123"``; the integer after
            the last ``-`` is the latent index.
        model_name: Adapted-model name forwarded to ``get_segmask``.

    Returns:
        The overlay from ``get_segmask``, or an all-zero
        ``IMAGE_SIZE`` x ``IMAGE_SIZE`` RGBA array when no latent can be
        determined.

    Empty and unparseable selections are handled the same way as in
    ``show_activation_heatmap`` so both mask panels stay in step instead of
    this one raising.
    """
    if not slider_value:
        # Same fallback lookup as the CLIP panel (it passes "CLIP" as the
        # model), so both panels resolve to the same latent.
        radio_options = get_init_radio_options(selected_image, "CLIP")
        if not radio_options:
            return np.zeros((IMAGE_SIZE, IMAGE_SIZE, 4), dtype=np.uint8)
        slider_value = radio_options[0]
    try:
        slider_value_int = int(slider_value.split("-")[-1])
    except (ValueError, IndexError):
        print(f"Error parsing slider value: {slider_value}")
        slider_value_int = 0
    return get_segmask(selected_image, slider_value_int, model_name)
# Optimize radio options generation
def get_init_radio_options(selected_image, model_name):
    """List the radio labels of the strongest image-level latents.

    The five latents with the largest mean activation are collected for CLIP
    and for ``model_name`` (in two worker threads) and merged into labels by
    ``get_radio_names``.
    """
    def _top_latents(model, top_k=5):
        mean_act = get_activation_distribution(selected_image, model).mean(0)
        ranked = list(np.argsort(mean_act)[::-1][:top_k])
        scores = {}
        for latent in ranked:
            scores[latent] = mean_act[latent]
        by_score = sorted(scores.items(), key=lambda pair: pair[1], reverse=True)
        return dict(by_score)

    # Process in parallel
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
        clip_job = pool.submit(_top_latents, "CLIP")
        maple_job = pool.submit(_top_latents, model_name)
        clip_scores = clip_job.result()
        maple_scores = maple_job.result()

    return get_radio_names(clip_scores, maple_scores)
def get_radio_names(clip_neuron_dict, maple_neuron_dict):
    """Turn two latent -> activation mappings into ordered radio labels.

    Latents present in both mappings come first (``common-<id>``, ranked by
    the larger of their two activations), then CLIP-only latents
    (``CLIP-<id>``), then MaPLE-only latents (``MaPLE-<id>``).  Each group is
    sorted by descending activation and cut to at most five entries.
    """
    clip_ids = set(list(clip_neuron_dict.keys()))
    maple_ids = set(list(maple_neuron_dict.keys()))

    def _best_of_both(latent):
        return max(clip_neuron_dict.get(latent, 0), maple_neuron_dict.get(latent, 0))

    groups = (
        ("common", clip_ids.intersection(maple_ids), _best_of_both),
        ("CLIP", clip_ids - maple_ids, lambda latent: clip_neuron_dict.get(latent, 0)),
        ("MaPLE", maple_ids - clip_ids, lambda latent: maple_neuron_dict.get(latent, 0)),
    )

    names = []
    for prefix, members, score in groups:
        ranked = list(members)
        ranked.sort(key=score, reverse=True)
        # Limit number of choices to improve performance
        for latent in ranked[:5]:
            names.append(f"{prefix}-{latent}")
    return names
def update_radio_options(evt: gr.EventData, selected_image, model_name):
    """Rebuild the latent radio component after a click or a selection change.

    Args:
        evt: Event payload.  The ``gr.EventData`` type hint is what makes
            Gradio pass the event object in this position instead of treating
            the parameter as a regular input component value.
        selected_image: Key into ``data_dict`` naming the image.
        model_name: Name of the adapted model compared against CLIP.

    Returns:
        A ``gr.Radio`` whose choices are the top image-level latents and,
        when the event carries a click position, the top latents of the
        clicked tile.  The first choice is preselected (``None`` when there
        are no choices).
    """
    def _has_click_position():
        # Events without a list under "index" (e.g. dropdown changes) must
        # not be treated as clicks; a plain lookup would raise for them.
        payload = getattr(evt, "_data", None)
        try:
            return isinstance(payload["index"], list)
        except (KeyError, TypeError):
            return False

    def _get_top_actvation(evt, selected_image, model_name):
        neuron_dict = {}
        all_activation = get_activation_distribution(selected_image, model_name)
        image_activation = all_activation.mean(0)
        # Get top activations from image-level
        top_neurons = list(np.argsort(image_activation)[::-1][:5])
        for top_neuron in top_neurons:
            neuron_dict[top_neuron] = image_activation[top_neuron]
        # Get top activations from tile-level if available
        if _has_click_position():
            image = data_dict[selected_image]["image"]
            grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
            token_idx = grid_y * GRID_NUM + grid_x + 1
            # Ensure token_idx is within bounds
            if token_idx < all_activation.shape[0]:
                tile_activations = all_activation[token_idx]
                top_tile_neurons = list(np.argsort(tile_activations)[::-1][:5])
                for top_neuron in top_tile_neurons:
                    # Keep the larger of the image-level and tile-level value
                    neuron_dict[top_neuron] = max(
                        neuron_dict.get(top_neuron, 0),
                        tile_activations[top_neuron]
                    )
        # Sort by activation value
        return dict(sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True))

    # Process in parallel
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        future_clip = executor.submit(_get_top_actvation, evt, selected_image, "CLIP")
        future_maple = executor.submit(_get_top_actvation, evt, selected_image, model_name)
        clip_neuron_dict = future_clip.result()
        maple_neuron_dict = future_maple.result()
    # Get radio choices
    radio_choices = get_radio_names(clip_neuron_dict, maple_neuron_dict)
    # Create radio component
    radio = gr.Radio(
        choices=radio_choices,
        label="Top activating SAE latent",
        value=radio_choices[0] if radio_choices else None
    )
    return radio
def update_markdown(option_value):
    """Build the two section headings for the selected latent.

    Args:
        option_value: Radio label such as ``"CLIP-123"``; the integer after
            the last ``-`` is shown in the headings.  May be ``None`` or
            empty when nothing is selected.

    Returns:
        ``(mask_heading, reference_heading)``.  When no latent index can be
        read from ``option_value`` the headings are returned without an
        index instead of raising.
    """
    mask_heading = "## Segmentation mask for the selected SAE latent"
    refs_heading = "## Top reference images for the selected SAE latent"
    if not option_value:
        return mask_heading, refs_heading
    try:
        latent_idx = int(option_value.split("-")[-1])
    except (AttributeError, ValueError):
        return mask_heading, refs_heading
    out_1 = f"{mask_heading} - {latent_idx}"
    out_2 = f"{refs_heading} - {latent_idx}"
    return out_1, out_2
def update_all(selected_image, slider_value, toggle_btn, model_name):
    """Recompute every output that depends on the selected latent.

    The CLIP panel (mask, reference images, activation texts) and the MaPLE
    mask are computed in two worker threads; the headings are built
    afterwards.

    Returns:
        A 10-tuple: CLIP mask, MaPLE mask, three reference images, three
        activation texts, and the two headings.
    """
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
        clip_job = pool.submit(
            show_activation_heatmap_clip, selected_image, slider_value, toggle_btn
        )
        maple_job = pool.submit(
            show_activation_heatmap_maple, selected_image, slider_value, model_name
        )
        # Collect the CLIP result first, then the MaPLE mask
        clip_mask, img_1, img_2, img_3, txt_1, txt_2, txt_3 = clip_job.result()
        maple_mask = maple_job.result()

    heading_mask, heading_refs = update_markdown(slider_value)
    return (
        clip_mask,
        maple_mask,
        img_1,
        img_2,
        img_3,
        txt_1,
        txt_2,
        txt_3,
        heading_mask,
        heading_refs,
    )
# Initialize data - load at startup
# Rebinds the module-level data_dict / sae_data_dict that the helper
# functions read.
data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
# Image shown on first load; must be a key of data_dict.
default_image_name = "christmas-imagenet"
# Define UI with lazy loading
# NOTE(review): the nesting of the Row/Column contexts below is reconstructed
# from statement order because this copy of the file lost its indentation -
# confirm against the rendered layout.
with gr.Blocks(
    theme=gr.themes.Citrus(),
    css="""
.image-row .gr-image { margin: 0 !important; padding: 0 !important; }
.image-row img { width: auto; height: 50px; } /* Set a uniform height for all images */
""",
) as demo:
    with gr.Row():
        with gr.Column():
            # Left View: Image selection and click handling
            gr.Markdown("## Select input image and patch on the image")
            image_selector = gr.Dropdown(
                choices=list(data_dict.keys()),
                value=default_image_name,
                label="Select Image",
            )
            image_display = gr.Image(
                value=load_image(default_image_name),
                type="pil",
                interactive=True,
            )
            # Update image display when a new image is selected (with debounce)
            # The client-side hook cancels the pending timer and forwards the
            # unchanged image name after 100 ms.
            image_selector.change(
                fn=load_image,
                inputs=image_selector,
                outputs=image_display,
                _js="""
function(img_name) {
// Simple debounce
clearTimeout(window._imageSelectTimeout);
return new Promise((resolve) => {
window._imageSelectTimeout = setTimeout(() => {
resolve(img_name);
}, 100);
});
}
"""
            )
            # Handle grid highlighting
            image_display.select(
                fn=highlight_grid,
                inputs=[image_selector],
                outputs=[image_display]
            )
        with gr.Column():
            gr.Markdown("## SAE latent activations of CLIP and MaPLE")
            # One adapted model per dataset in DATASET_LIST
            model_options = [f"MaPLE-{dataset_name}" for dataset_name in DATASET_LIST]
            model_selector = gr.Dropdown(
                choices=model_options,
                value=model_options[0],
                label="Select adapted model (MaPLe)",
            )
            # Initialize with a placeholder plot to avoid delays
            neuron_plot = gr.Plot(
                label="Neuron Activation",
                show_label=False
            )
            # Add event handlers with proper data flow
            # The gr.EventData hint lets Gradio pass the click event as `evt`.
            def update_plot(evt: gr.EventData, selected_image, model_name):
                if hasattr(evt, '_data') and evt._data is not None:
                    # The click is converted to a hashable value because
                    # plot_activation_distribution is wrapped in lru_cache.
                    # NOTE(review): get_grid_loc reads evt._data["index"] as a
                    # flat [x, y] pair; map(tuple, ...) over two numbers would
                    # raise TypeError - confirm the payload shape.
                    return plot_activation_distribution(
                        tuple(map(tuple, evt._data.get('index', []))),
                        selected_image,
                        model_name
                    )
                return plot_activation_distribution(None, selected_image, model_name)
            # Load initial plot after UI is rendered
            # Redraws the plot without a tile curve whenever the image or the
            # model selection changes.
            gr.on(
                [image_selector.change, model_selector.change],
                fn=lambda img, model: plot_activation_distribution(None, img, model),
                inputs=[image_selector, model_selector],
                outputs=neuron_plot,
            )
            # Update plot on image click
            image_display.select(
                fn=update_plot,
                inputs=[image_selector, model_selector],
                outputs=neuron_plot,
            )
    with gr.Row():
        with gr.Column():
            # Initialize radio options
            # NOTE(review): radio_names is not referenced again in the visible
            # part of the file.
            radio_names = gr.State(value=get_init_radio_options(default_image_name, model_options[0]))
            # Initialize markdown displays
            markdown_display = gr.Markdown(f"## Segmentation mask for the selected SAE latent")
            # Initialize segmentation displays
            gr.Markdown("### Localize SAE latent activation using CLIP")
            seg_mask_display = gr.Image(type="pil", show_label=False)
            gr.Markdown("### Localize SAE latent activation using MaPLE")
            seg_mask_display_maple = gr.Image(type="pil", show_label=False)
        with gr.Column():
            gr.Markdown("## Top activating SAE latent index")
            # Initialize radio component
            radio_choices = gr.Radio(
                label="Top activating SAE latent",
                interactive=True,
            )
            # Initialize as soon as UI loads
            # The lambda calls get_init_radio_options twice (choices, value).
            # NOTE(review): gr.on / gr.Blocks.load / gr.Radio.update come from
            # different Gradio API generations - confirm the targeted version.
            gr.on(
                gr.Blocks.load,
                fn=lambda: gr.Radio.update(
                    choices=get_init_radio_options(default_image_name, model_options[0]),
                    value=get_init_radio_options(default_image_name, model_options[0])[0]
                ),
                outputs=radio_choices
            )
            toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)
            markdown_display_2 = gr.Markdown(f"## Top reference images for the selected SAE latent")
            # Initialize image displays
            gr.Markdown("### ImageNet")
            top_image_1 = gr.Image(type="pil", label="ImageNet", show_label=False)
            act_value_1 = gr.Markdown()
            gr.Markdown("### ImageNet-Sketch")
            top_image_2 = gr.Image(type="pil", label="ImageNet-Sketch", show_label=False)
            act_value_2 = gr.Markdown()
            gr.Markdown("### Caltech101")
            top_image_3 = gr.Image(type="pil", label="Caltech101", show_label=False)
            act_value_3 = gr.Markdown()
    # Update radio options on image interaction
    image_display.select(
        fn=update_radio_options,
        inputs=[image_selector, model_selector],
        outputs=radio_choices,
    )
    # Update radio options on model change
    model_selector.change(
        fn=update_radio_options,
        inputs=[image_selector, model_selector],
        outputs=radio_choices,
    )
    # Update radio options on image selection
    image_selector.change(
        fn=update_radio_options,
        inputs=[image_selector, model_selector],
        outputs=radio_choices,
    )
# Initialize