import concurrent.futures
import gzip
import os
import pickle
import threading
from functools import lru_cache
from glob import glob
from time import sleep
from typing import Dict, List, Optional, Tuple

import gradio as gr
import numpy as np
import plotly.graph_objects as go
import psutil
import torch
from PIL import Image, ImageDraw
from plotly.subplots import make_subplots

# Constants
IMAGE_SIZE = 400
DATASET_LIST = ["imagenet", "oxford_flowers", "ucf101", "caltech101", "dtd", "eurosat"]
GRID_NUM = 14
PKL_ROOT = "./data/out"


# Global cache with better type hints and error handling
class Cache:
    def __init__(self):
        self.data: Dict[str, Dict] = {
            'data_dict': {},
            'sae_data_dict': {},
            'model_data': {},
            'segmasks': {},
            'top_images': {},
            'precomputed_activations': {},
        }

    def get(self, category: str, key: str, default=None):
        try:
            return self.data[category].get(key, default)
        except KeyError:
            return default

    def set(self, category: str, key: str, value):
        try:
            self.data[category][key] = value
        except KeyError:
            self.data[category] = {key: value}

    def clear_category(self, category: str):
        if category in self.data:
            self.data[category].clear()


_CACHE = Cache()


def load_all_data(image_root: str, pkl_root: str) -> Tuple[Dict, Dict]:
    """Load all data with optimized parallel processing."""

    def load_image_file(image_file: str) -> Optional[Dict]:
        try:
            image = Image.open(image_file).resize((IMAGE_SIZE, IMAGE_SIZE))
            return {
                "image": image,
                "image_path": image_file,
            }
        except Exception as e:
            print(f"Error loading image {image_file}: {e}")
            return None

    # Load images in parallel
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_to_file = {
            executor.submit(load_image_file, image_file): image_file
            for image_file in glob(f"{image_root}/*")
        }
        for future in concurrent.futures.as_completed(future_to_file):
            try:
                image_file = future_to_file[future]
                image_name = os.path.basename(image_file).split(".")[0]
                result = future.result()
                if result:
                    _CACHE.set('data_dict', image_name, result)
            except Exception as e:
                print(f"Error processing image future: {e}")

    # Load SAE data
    try:
        with open("./data/sae_data/mean_acts.pkl", "rb") as f:
            _CACHE.set('sae_data_dict', "mean_acts", pickle.load(f))
    except Exception as e:
        print(f"Error loading mean_acts.pkl: {e}")

    # Load mean act values
    datasets = ["imagenet", "imagenet-sketch", "caltech101"]
    for dataset in datasets:
        try:
            with gzip.open(f"./data/sae_data/mean_act_values_{dataset}.pkl.gz", "rb") as f:
                if "mean_act_values" not in _CACHE.data['sae_data_dict']:
                    _CACHE.set('sae_data_dict', "mean_act_values", {})
                _CACHE.data['sae_data_dict']["mean_act_values"][dataset] = pickle.load(f)
        except Exception as e:
            print(f"Error loading mean act values for {dataset}: {e}")

    return _CACHE.data['data_dict'], _CACHE.data['sae_data_dict']


@lru_cache(maxsize=1024)
def get_data(image_name: str, model_name: str) -> np.ndarray:
    """Get model data with caching."""
    cache_key = f"{model_name}_{image_name}"
    if cache_key not in _CACHE.data['model_data']:
        try:
            data_dir = f"{PKL_ROOT}/{model_name}/{image_name}.pkl.gz"
            with gzip.open(data_dir, "rb") as f:
                _CACHE.data['model_data'][cache_key] = pickle.load(f)
        except Exception as e:
            print(f"Error loading model data for {cache_key}: {e}")
            return np.array([])
    return _CACHE.data['model_data'][cache_key]


@lru_cache(maxsize=1024)
def get_activation_distribution(image_name: str, model_type: str) -> np.ndarray:
    """Get activation distribution with memory optimization."""
    try:
        data = get_data(image_name, model_type)
        if isinstance(data, (list, tuple)):
            activation = data[0]
        else:
            activation = data
        if not isinstance(activation, np.ndarray):
            activation = np.array(activation)

        mean_acts = _CACHE.get('sae_data_dict', "mean_acts", {}).get("imagenet", np.array([]))
        if mean_acts.size > 0 and activation.size > 0:
            noisy_features_indices = np.where(mean_acts > 0.1)[0]
            if activation.ndim >= 2:
                activation[:, noisy_features_indices] = 0

        return activation
    except Exception as e:
        print(f"Error getting activation distribution: {e}")
        return np.array([])


def get_grid_loc(evt: gr.SelectData, image: Image.Image) -> Tuple[int, int, int, int]:
    """Get grid location from click event."""
    x, y = evt.index[0], evt.index[1]
    cell_width = image.width // GRID_NUM
    cell_height = image.height // GRID_NUM
    grid_x = x // cell_width
    grid_y = y // cell_height
    return grid_x, grid_y, cell_width, cell_height


def highlight_grid(evt: gr.SelectData, image_name: str) -> Optional[Image.Image]:
    """Highlight selected grid cell."""
    image = _CACHE.get('data_dict', image_name, {}).get("image")
    if image is None:
        return None

    grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
    highlighted_image = image.copy()
    draw = ImageDraw.Draw(highlighted_image)
    box = [
        grid_x * cell_width,
        grid_y * cell_height,
        (grid_x + 1) * cell_width,
        (grid_y + 1) * cell_height,
    ]
    draw.rectangle(box, outline="red", width=3)
    return highlighted_image


def plot_activations(
    all_activation: np.ndarray,
    tile_activations: Optional[np.ndarray] = None,
    grid_x: Optional[int] = None,
    grid_y: Optional[int] = None,
    top_k: int = 5,
    colors: Tuple[str, str] = ("blue", "cyan"),
    model_name: str = "CLIP",
) -> go.Figure:
    """Plot activation distributions."""
    fig = go.Figure()

    def _add_scatter_with_annotation(fig, activations, model_name, color, label):
        fig.add_trace(
            go.Scatter(
                x=np.arange(len(activations)),
                y=activations,
                mode="lines",
                name=label,
                line=dict(color=color, dash="solid"),
                showlegend=True,
            )
        )
        top_neurons = np.argsort(activations)[::-1][:top_k]
        for idx in top_neurons:
            fig.add_annotation(
                x=idx,
                y=activations[idx],
                text=str(idx),
                showarrow=True,
                arrowhead=2,
                ax=0,
                ay=-15,
                arrowcolor=color,
                opacity=0.7,
            )
        return fig

    label = f"{model_name.split('-')[-1]} Image-level"
    fig = _add_scatter_with_annotation(fig, all_activation, model_name, colors[0], label)
    if tile_activations is not None:
        label = f"{model_name.split('-')[-1]} Tile ({grid_x}, {grid_y})"
        fig = _add_scatter_with_annotation(fig, tile_activations, model_name, colors[1], label)

    fig.update_layout(
        title="Activation Distribution",
        xaxis_title="SAE latent index",
        yaxis_title="Activation Value",
        template="plotly_white",
        legend=dict(orientation="h", yanchor="middle", y=0.5, xanchor="center", x=0.5),
    )
    return fig


def get_segmask(selected_image: str, slider_value: int, model_type: str) -> Optional[np.ndarray]:
    """Get segmentation mask with caching."""
    cache_key = f"{selected_image}_{slider_value}_{model_type}"
    cached_mask = _CACHE.get('segmasks', cache_key)
    if cached_mask is not None:
        return cached_mask

    try:
        image = _CACHE.get('data_dict', selected_image, {}).get("image")
        if image is None:
            return None

        sae_act = get_data(selected_image, model_type)[0]
        temp = sae_act[:, slider_value]

        mask = torch.tensor(temp[1:].reshape(14, 14)).view(1, 1, 14, 14)
        mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][0].numpy()
        if mask.size == 0:
            return None

        mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)

        base_opacity = 30
        image_array = np.array(image)[..., :3]
        rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
        rgba_overlay[..., :3] = image_array
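        # Pixels where the normalized activation is zero are darkened below, so the
        # activated patches keep their original brightness; alpha stays fully opaque.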
        darkened_image = (image_array * (base_opacity / 255)).astype(np.uint8)
        rgba_overlay[mask == 0, :3] = darkened_image[mask == 0]
        rgba_overlay[..., 3] = 255

        _CACHE.set('segmasks', cache_key, rgba_overlay)
        return rgba_overlay
    except Exception as e:
        print(f"Error generating segmentation mask: {e}")
        return None


def get_top_images(slider_value: int, toggle_btn: bool) -> List[Image.Image]:
    """Get top images with caching."""
    cache_key = f"{slider_value}_{toggle_btn}"
    cached_images = _CACHE.get('top_images', cache_key)
    if cached_images is not None:
        return cached_images

    dataset_path = "./data/top_images_masked" if toggle_btn else "./data/top_images"
    paths = [
        os.path.join(dataset_path, dataset, f"{slider_value}.jpg")
        for dataset in ["imagenet", "imagenet-sketch", "caltech101"]
    ]
    images = [
        Image.open(path) if os.path.exists(path) else Image.new("RGB", (256, 256), (255, 255, 255))
        for path in paths
    ]

    _CACHE.set('top_images', cache_key, images)
    return images


# UI Event Handlers
def plot_activation_distribution(
    evt: Optional[gr.EventData], selected_image: str, model_name: str
) -> go.Figure:
    """Plot activation distributions for both models."""
    fig = make_subplots(
        rows=2,
        cols=1,
        shared_xaxes=True,
        subplot_titles=["CLIP Activation", f"{model_name} Activation"],
    )

    def get_activations(evt, selected_image, model_name, colors):
        activation = get_activation_distribution(selected_image, model_name)
        all_activation = activation.mean(0)

        tile_activations = None
        grid_x = None
        grid_y = None

        if evt is not None and evt._data is not None:
            image = _CACHE.get('data_dict', selected_image, {}).get("image")
            if image:
                grid_x, grid_y, _, _ = get_grid_loc(evt, image)
                token_idx = grid_y * GRID_NUM + grid_x + 1
                tile_activations = activation[token_idx]

        return plot_activations(
            all_activation,
            tile_activations,
            grid_x,
            grid_y,
            top_k=5,
            model_name=model_name,
            colors=colors,
        )

    fig_clip = get_activations(evt, selected_image, "CLIP", colors=("#00b4d8", "#90e0ef"))
    fig_maple = get_activations(evt, selected_image, model_name, colors=("#ff5a5f", "#ffcad4"))

    def _attach_fig(fig, sub_fig, row, col, yref):
        for trace in sub_fig.data:
            fig.add_trace(trace, row=row, col=col)
        for annotation in sub_fig.layout.annotations:
            annotation.update(yref=yref)
            fig.add_annotation(annotation)
        return fig

    fig = _attach_fig(fig, fig_clip, row=1, col=1, yref="y1")
    fig = _attach_fig(fig, fig_maple, row=2, col=1, yref="y2")

    fig.update_xaxes(title_text="SAE Latent Index", row=2, col=1)
    fig.update_xaxes(title_text="SAE Latent Index", row=1, col=1)
    fig.update_yaxes(title_text="Activation Value", row=1, col=1)
    fig.update_yaxes(title_text="Activation Value", row=2, col=1)
    fig.update_layout(
        template="plotly_white",
        showlegend=True,
        legend=dict(orientation="h", yanchor="bottom", y=-0.2, xanchor="center", x=0.5),
        margin=dict(l=20, r=20, t=40, b=20),
    )
    return fig


def show_activation_heatmap_clip(
    selected_image: str, slider_value: str, toggle_btn: bool
):
    """Show activation heatmap for CLIP model."""
    rgba_overlay, top_images, act_values = show_activation_heatmap(
        selected_image, slider_value, "CLIP", toggle_btn
    )
    sleep(0.1)
    return (
        rgba_overlay,
        top_images[0],
        top_images[1],
        top_images[2],
        act_values[0],
        act_values[1],
        act_values[2],
    )


def show_activation_heatmap(
    selected_image: str, slider_value: str, model_type: str, toggle_btn: bool = False
) -> Tuple[np.ndarray, List[Image.Image], List[str]]:
    """Show activation heatmap with segmentation mask and top images."""
    slider_value = int(slider_value.split("-")[-1])
    rgba_overlay = get_segmask(selected_image, slider_value, model_type)
    top_images = get_top_images(slider_value, toggle_btn)

    act_values = []
    for dataset in ["imagenet", "imagenet-sketch", "caltech101"]:
        act_value = _CACHE.get('sae_data_dict', "mean_act_values", {}).get(dataset, np.array([]))[slider_value, :5]
        act_value = [str(round(value, 3)) for value in act_value]
        act_value = " | ".join(act_value)
        out = f"#### Activation values: {act_value}"
        act_values.append(out)

    return rgba_overlay, top_images, act_values


def show_activation_heatmap_maple(
    selected_image: str, slider_value: str, model_name: str
) -> np.ndarray:
    """Show activation heatmap for MaPLE model."""
    slider_value = int(slider_value.split("-")[-1])
    rgba_overlay = get_segmask(selected_image, slider_value, model_name)
    sleep(0.1)
    return rgba_overlay


def get_init_radio_options(selected_image: str, model_name: str) -> List[str]:
    """Get initial radio options for UI."""
    clip_neuron_dict = {}
    maple_neuron_dict = {}

    def _get_top_activation(selected_image: str, model_name: str, neuron_dict: Dict, top_k: int = 5) -> Dict:
        activations = get_activation_distribution(selected_image, model_name).mean(0)
        top_neurons = list(np.argsort(activations)[::-1][:top_k])
        for top_neuron in top_neurons:
            neuron_dict[top_neuron] = activations[top_neuron]
        return dict(sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True))

    clip_neuron_dict = _get_top_activation(selected_image, "CLIP", clip_neuron_dict)
    maple_neuron_dict = _get_top_activation(selected_image, model_name, maple_neuron_dict)

    return get_radio_names(clip_neuron_dict, maple_neuron_dict)


def get_radio_names(
    clip_neuron_dict: Dict[int, float], maple_neuron_dict: Dict[int, float]
) -> List[str]:
    """Generate radio button names based on neuron activations."""
    clip_keys = list(clip_neuron_dict.keys())
    maple_keys = list(maple_neuron_dict.keys())

    common_keys = list(set(clip_keys).intersection(set(maple_keys)))
    clip_only_keys = list(set(clip_keys) - set(maple_keys))
    maple_only_keys = list(set(maple_keys) - set(clip_keys))

    common_keys.sort(key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True)
    clip_only_keys.sort(reverse=True)
    maple_only_keys.sort(reverse=True)

    out = []
    out.extend([f"common-{i}" for i in common_keys[:5]])
    out.extend([f"CLIP-{i}" for i in clip_only_keys[:5]])
    out.extend([f"MaPLE-{i}" for i in maple_only_keys[:5]])
    return out


def update_radio_options(
    evt: Optional[gr.EventData], selected_image: str, model_name: str
) -> gr.Radio:
    """Update radio options based on user interaction."""
    clip_neuron_dict = {}
    maple_neuron_dict = {}

    def _get_top_activation(evt, selected_image, model_name, neuron_dict):
        all_activation = get_activation_distribution(selected_image, model_name)
        image_activation = all_activation.mean(0)
        top_neurons = list(np.argsort(image_activation)[::-1][:5])
        for top_neuron in top_neurons:
            neuron_dict[top_neuron] = image_activation[top_neuron]

        if evt is not None and evt._data is not None and isinstance(evt._data["index"], list):
            image = _CACHE.get('data_dict', selected_image, {}).get("image")
            if image:
                grid_x, grid_y, _, _ = get_grid_loc(evt, image)
                token_idx = grid_y * GRID_NUM + grid_x + 1
                tile_activations = all_activation[token_idx]
                top_tile_neurons = list(np.argsort(tile_activations)[::-1][:5])
                for top_neuron in top_tile_neurons:
                    neuron_dict[top_neuron] = tile_activations[top_neuron]

        return dict(sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True))

    clip_neuron_dict = _get_top_activation(evt, selected_image, "CLIP", clip_neuron_dict)
    maple_neuron_dict = _get_top_activation(evt, selected_image, model_name, maple_neuron_dict)
    radio_choices = get_radio_names(clip_neuron_dict, maple_neuron_dict)
    return gr.Radio(choices=radio_choices, label="Top activating SAE latent", value=radio_choices[0])


def update_markdown(option_value: str) -> Tuple[str, str]:
    """Update markdown text based on selected option."""
    latent_idx = int(option_value.split("-")[-1])
    out_1 = f"## Segmentation mask for the selected SAE latent - {latent_idx}"
    out_2 = f"## Top reference images for the selected SAE latent - {latent_idx}"
    return out_1, out_2


def update_all(
    selected_image: str, slider_value: str, toggle_btn: bool, model_name: str
) -> Tuple:
    """Update all UI components."""
    (
        seg_mask_display,
        top_image_1,
        top_image_2,
        top_image_3,
        act_value_1,
        act_value_2,
        act_value_3,
    ) = show_activation_heatmap_clip(selected_image, slider_value, toggle_btn)
    seg_mask_display_maple = show_activation_heatmap_maple(
        selected_image, slider_value, model_name
    )
    markdown_display, markdown_display_2 = update_markdown(slider_value)

    return (
        seg_mask_display,
        seg_mask_display_maple,
        top_image_1,
        top_image_2,
        top_image_3,
        act_value_1,
        act_value_2,
        act_value_3,
        markdown_display,
        markdown_display_2,
    )


def monitor_memory_usage():
    """Monitor memory usage and clean cache if necessary."""
    process = psutil.Process()
    mem_info = process.memory_info()
    mem_percent = process.memory_percent()

    print(f"""
    Memory Usage:
    - RSS: {mem_info.rss / (1024**2):.2f} MB
    - VMS: {mem_info.vms / (1024**2):.2f} MB
    - Percent: {mem_percent:.1f}%
    - Cache Sizes: {[len(cache) for cache in _CACHE.data.values()]}
    """)

    if mem_percent > 80:
        print("Memory usage too high, clearing caches...")
        _CACHE.clear_category('segmasks')
        _CACHE.clear_category('top_images')
        _CACHE.clear_category('precomputed_activations')


def start_memory_monitor(interval: int = 300):
    """Start periodic memory monitoring."""
    monitor_memory_usage()
    timer = threading.Timer(interval, start_memory_monitor)
    # Run the timer as a daemon so the periodic monitor never blocks interpreter shutdown.
    timer.daemon = True
    timer.start()


# Initialize the application
data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=PKL_ROOT)
default_image_name = "christmas-imagenet"

# Create the Gradio interface
with gr.Blocks(
    theme=gr.themes.Citrus(),
    css="""
    .image-row .gr-image { margin: 0 !important; padding: 0 !important; }
    .image-row img { width: auto; height: 50px; }
    """,
) as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Select input image and patch on the image")
            image_selector = gr.Dropdown(
                choices=list(_CACHE.data['data_dict'].keys()),
                value=default_image_name,
                label="Select Image",
            )
            image_display = gr.Image(
                value=_CACHE.get('data_dict', default_image_name, {}).get("image"),
                type="pil",
                interactive=True,
            )

            image_selector.change(
                fn=lambda img_name: _CACHE.get('data_dict', img_name, {}).get("image"),
                inputs=image_selector,
                outputs=image_display,
            )
            image_display.select(
                fn=highlight_grid, inputs=[image_selector], outputs=[image_display]
            )

        with gr.Column():
            gr.Markdown("## SAE latent activations of CLIP and MaPLE")
            model_options = [f"MaPLE-{dataset_name}" for dataset_name in DATASET_LIST]
            model_selector = gr.Dropdown(
                choices=model_options,
                value=model_options[0],
                label="Select adapted model (MaPLe)",
            )
            init_plot = plot_activation_distribution(None, default_image_name, model_options[0])
            neuron_plot = gr.Plot(value=init_plot, show_label=False)

            image_selector.change(
                fn=plot_activation_distribution,
                inputs=[image_selector, model_selector],
                outputs=neuron_plot,
            )
            image_display.select(
                fn=plot_activation_distribution,
                inputs=[image_selector, model_selector],
                outputs=neuron_plot,
            )
            model_selector.change(
                fn=lambda img_name: _CACHE.get('data_dict', img_name, {}).get("image"),
                inputs=[image_selector],
                outputs=image_display,
            )
            model_selector.change(
                fn=plot_activation_distribution,
                inputs=[image_selector, model_selector],
                outputs=neuron_plot,
            )

    with gr.Row():
        with gr.Column():
            radio_names = get_init_radio_options(default_image_name, model_options[0])
            feature_idx = radio_names[0].split("-")[-1]
            markdown_display = gr.Markdown(
                f"## Segmentation mask for the selected SAE latent - {feature_idx}"
            )

            init_seg, init_tops, init_values = show_activation_heatmap(
                default_image_name, radio_names[0], "CLIP"
            )
            gr.Markdown("### Localize SAE latent activation using CLIP")
            seg_mask_display = gr.Image(value=init_seg, type="pil", show_label=False)

            init_seg_maple, _, _ = show_activation_heatmap(
                default_image_name, radio_names[0], model_options[0]
            )
            gr.Markdown("### Localize SAE latent activation using MaPLE")
            seg_mask_display_maple = gr.Image(value=init_seg_maple, type="pil", show_label=False)

        with gr.Column():
            gr.Markdown("## Top activating SAE latent index")
            radio_choices = gr.Radio(
                choices=radio_names,
                label="Top activating SAE latent",
                interactive=True,
                value=radio_names[0],
            )
            toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)

            markdown_display_2 = gr.Markdown(
                f"## Top reference images for the selected SAE latent - {feature_idx}"
            )

            gr.Markdown("### ImageNet")
            top_image_1 = gr.Image(value=init_tops[0], type="pil", show_label=False)
            act_value_1 = gr.Markdown(init_values[0])

            gr.Markdown("### ImageNet-Sketch")
            top_image_2 = gr.Image(value=init_tops[1], type="pil", show_label=False)
            act_value_2 = gr.Markdown(init_values[1])

            gr.Markdown("### Caltech101")
            top_image_3 = gr.Image(value=init_tops[2], type="pil", show_label=False)
            act_value_3 = gr.Markdown(init_values[2])

    # Event handlers
    image_display.select(
        fn=update_radio_options,
        inputs=[image_selector, model_selector],
        outputs=[radio_choices],
    )
    model_selector.change(
        fn=update_radio_options,
        inputs=[image_selector, model_selector],
        outputs=[radio_choices],
    )
    image_selector.select(
        fn=update_radio_options,
        inputs=[image_selector, model_selector],
        outputs=[radio_choices],
    )
    radio_choices.change(
        fn=update_all,
        inputs=[image_selector, radio_choices, toggle_btn, model_selector],
        outputs=[
            seg_mask_display,
            seg_mask_display_maple,
            top_image_1,
            top_image_2,
            top_image_3,
            act_value_1,
            act_value_2,
            act_value_3,
            markdown_display,
            markdown_display_2,
        ],
    )
    toggle_btn.change(
        fn=show_activation_heatmap_clip,
        inputs=[image_selector, radio_choices, toggle_btn],
        outputs=[
            seg_mask_display,
            top_image_1,
            top_image_2,
            top_image_3,
            act_value_1,
            act_value_2,
            act_value_3,
        ],
    )


if __name__ == "__main__":
    # Initialize memory monitoring
    start_memory_monitor()

    # Get system memory info
    mem = psutil.virtual_memory()
    total_ram_gb = mem.total / (1024**3)

    try:
        print("Starting application initialization...")

        # Precompute common data
        print("Precomputing activation patterns...")
        for image_name in _CACHE.data['data_dict'].keys():
            for model_name in ["CLIP"] + [f"MaPLE-{ds}" for ds in DATASET_LIST]:
                try:
                    activation = get_activation_distribution(image_name, model_name)
                    cache_key = f"activation_{model_name}_{image_name}"
                    _CACHE.set('precomputed_activations', cache_key, activation.mean(0))
                except Exception as e:
                    print(f"Error precomputing activation for {image_name}, {model_name}: {e}")

        print("Starting Gradio interface...")

        # Launch the app with optimized settings
        demo.queue(max_size=min(20, int(total_ram_gb)))
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
            show_error=True,
            max_threads=min(16, psutil.cpu_count()),
        )
    except Exception as e:
        print(f"Critical error during startup: {e}")
        # Attempt to clean up resources
        _CACHE.data.clear()
        raise