limhyesu98 committed · 4cf80d2
Parent(s): 5e15f55

init

Files changed:
- .gitattributes +21 -0
- README.md +12 -3
- app.py +503 -0
- requirements.txt +3 -0
.gitattributes CHANGED

```diff
@@ -1,3 +1,4 @@
+<<<<<<< HEAD
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +34,23 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+=======
+chrismas-imagnet.pkl filter=lfs diff=lfs merge=lfs -text
+dog-imagenet.pkl filter=lfs diff=lfs merge=lfs -text
+dog-mmvp.pkl filter=lfs diff=lfs merge=lfs -text
+golden_gate_bridge.pkl filter=lfs diff=lfs merge=lfs -text
+hen-imagenet-r.pkl filter=lfs diff=lfs merge=lfs -text
+hen-imagenet.pkl filter=lfs diff=lfs merge=lfs -text
+kayaking-ucf.pkl filter=lfs diff=lfs merge=lfs -text
+owl-imagenet-sketch.pkl filter=lfs diff=lfs merge=lfs -text
+owl-imagenet.pkl filter=lfs diff=lfs merge=lfs -text
+paphiopedilum-micranthum.pkl filter=lfs diff=lfs merge=lfs -text
+phalaenopsis-aphrodite.pkl filter=lfs diff=lfs merge=lfs -text
+text-1.pkl filter=lfs diff=lfs merge=lfs -text
+text-2.pkl filter=lfs diff=lfs merge=lfs -text
+text-3.pkl filter=lfs diff=lfs merge=lfs -text
+vegetation-land-eurosat.pkl filter=lfs diff=lfs merge=lfs -text
+data/sae_data/mean_act_values_caltech101.pkl.gz filter=lfs diff=lfs merge=lfs -text
+data/sae_data/mean_act_values_imagenet-sketch.pkl.gz filter=lfs diff=lfs merge=lfs -text
+data/sae_data/mean_act_values_imagenet.pkl.gz filter=lfs diff=lfs merge=lfs -text
+>>>>>>> master
```
README.md CHANGED

```diff
@@ -1,10 +1,19 @@
 ---
+<<<<<<< HEAD
 title: Patchsae Demo
-emoji:
-colorFrom:
-colorTo:
+emoji: 😻
+colorFrom: red
+colorTo: gray
 sdk: gradio
 sdk_version: 5.8.0
+=======
+title: Paper14240
+emoji: 📈
+colorFrom: blue
+colorTo: pink
+sdk: gradio
+sdk_version: 5.5.0
+>>>>>>> master
 app_file: app.py
 pinned: false
 ---
```
app.py ADDED

@@ -0,0 +1,503 @@

```python
import gzip
import os
import pickle
from glob import glob
from time import sleep

import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from PIL import Image, ImageDraw
from plotly.subplots import make_subplots

IMAGE_SIZE = 400
DATASET_LIST = ["imagenet", "oxford_flowers", "ucf101", "caltech101", "dtd", "eurosat"]
GRID_NUM = 14
pkl_root = "./data/out"
preloaded_data = {}
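
# Demo overview: the app compares SAE (sparse autoencoder) latent activations
# of a base CLIP model against MaPLe models tuned on each dataset in
# DATASET_LIST. All activations are precomputed per image and stored as
# gzipped pickles under pkl_root.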


def preload_activation(image_name):
    # Cache activations for one image across the base CLIP model and every
    # MaPLE variant.
    for model in ["CLIP"] + [f"MaPLE-{ds}" for ds in DATASET_LIST]:
        image_file = f"{pkl_root}/{model}/{image_name}.pkl.gz"
        with gzip.open(image_file, "rb") as f:
            preloaded_data[model] = pickle.load(f)


def get_activation_distribution(image_name: str, model_type: str):
    activation = get_data(image_name, model_type)[0]

    # Zero out latents whose mean activation on ImageNet exceeds 0.1; these
    # fire broadly and are treated as noisy, non-selective features.
    noisy_features_indices = (sae_data_dict["mean_acts"]["imagenet"] > 0.1).nonzero()[0].tolist()
    activation[:, noisy_features_indices] = 0

    return activation
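

# Images are shown at IMAGE_SIZE x IMAGE_SIZE and interpreted as a
# GRID_NUM x GRID_NUM grid of patches; the helpers below translate a click
# on the displayed image into grid coordinates.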


def get_grid_loc(evt, image):
    # Get click coordinates
    x, y = evt._data["index"][0], evt._data["index"][1]

    cell_width = image.width // GRID_NUM
    cell_height = image.height // GRID_NUM

    grid_x = x // cell_width
    grid_y = y // cell_height
    return grid_x, grid_y, cell_width, cell_height


def highlight_grid(evt: gr.EventData, image_name):
    image = data_dict[image_name]["image"]
    grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)

    highlighted_image = image.copy()
    draw = ImageDraw.Draw(highlighted_image)
    box = [grid_x * cell_width, grid_y * cell_height, (grid_x + 1) * cell_width, (grid_y + 1) * cell_height]
    draw.rectangle(box, outline="red", width=3)

    return highlighted_image


def load_image(img_name):
    return Image.open(data_dict[img_name]["image_path"]).resize((IMAGE_SIZE, IMAGE_SIZE))
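

# plot_activations draws each activation vector as a line over all SAE latent
# indices and arrow-annotates the indices of the top_k strongest latents.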


def plot_activations(
    all_activation, tile_activations=None, grid_x=None, grid_y=None, top_k=5, colors=("blue", "cyan"), model_name="CLIP"
):
    fig = go.Figure()

    def _add_scatter_with_annotation(fig, activations, model_name, color, label):
        fig.add_trace(
            go.Scatter(
                x=np.arange(len(activations)),
                y=activations,
                mode="lines",
                name=label,
                line=dict(color=color, dash="solid"),
                showlegend=True,
            )
        )
        # Annotate the indices of the top_k strongest latents.
        top_neurons = np.argsort(activations)[::-1][:top_k]
        for idx in top_neurons:
            fig.add_annotation(
                x=idx,
                y=activations[idx],
                text=str(idx),
                showarrow=True,
                arrowhead=2,
                ax=0,
                ay=-15,
                arrowcolor=color,
                opacity=0.7,
            )
        return fig

    label = f"{model_name.split('-')[0]} Image-level"
    fig = _add_scatter_with_annotation(fig, all_activation, model_name, colors[0], label)
    if tile_activations is not None:
        label = f"{model_name.split('-')[0]} Tile ({grid_x}, {grid_y})"
        fig = _add_scatter_with_annotation(fig, tile_activations, model_name, colors[1], label)

    fig.update_layout(
        title="Activation Distribution",
        xaxis_title="SAE latent index",
        yaxis_title="Activation Value",
        template="plotly_white",
    )
    fig.update_layout(legend=dict(orientation="h", yanchor="middle", y=0.5, xanchor="center", x=0.5))

    return fig
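

# Activation arrays are laid out as (1 + GRID_NUM**2, n_latents): row 0 is the
# CLS token and rows 1..GRID_NUM**2 are patch tokens in row-major order, hence
# token_idx = grid_y * GRID_NUM + grid_x + 1 for a clicked tile.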


def get_activations(evt: gr.EventData, selected_image: str, model_name: str, colors):
    activation = get_activation_distribution(selected_image, model_name)
    all_activation = activation.mean(0)

    tile_activations = None
    grid_x = None
    grid_y = None

    if evt is not None:
        if evt._data is not None:
            image = data_dict[selected_image]["image"]
            grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
            token_idx = grid_y * GRID_NUM + grid_x + 1
            tile_activations = activation[token_idx]

    fig = plot_activations(
        all_activation, tile_activations, grid_x, grid_y, top_k=5, model_name=model_name, colors=colors
    )
    return fig
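

# The main comparison view: a two-row subplot with CLIP activations on top and
# the selected MaPLE model's activations below, sharing the x-axis so latent
# indices line up between the two models.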


def plot_activation_distribution(evt: gr.EventData, selected_image: str, model_name: str):
    fig = make_subplots(
        rows=2,
        cols=1,
        shared_xaxes=True,
        subplot_titles=["CLIP Activation", f"{model_name} Activation"],
    )

    fig_clip = get_activations(evt, selected_image, "CLIP", colors=("#00b4d8", "#90e0ef"))
    fig_maple = get_activations(evt, selected_image, model_name, colors=("#ff5a5f", "#ffcad4"))

    def _attach_fig(fig, sub_fig, row, col, yref):
        for trace in sub_fig.data:
            fig.add_trace(trace, row=row, col=col)

        for annotation in sub_fig.layout.annotations:
            annotation.update(yref=yref)
            fig.add_annotation(annotation)
        return fig

    fig = _attach_fig(fig, fig_clip, row=1, col=1, yref="y1")
    fig = _attach_fig(fig, fig_maple, row=2, col=1, yref="y2")

    fig.update_xaxes(title_text="SAE Latent Index", row=2, col=1)
    fig.update_xaxes(title_text="SAE Latent Index", row=1, col=1)
    fig.update_yaxes(title_text="Activation Value", row=1, col=1)
    fig.update_yaxes(title_text="Activation Value", row=2, col=1)
    fig.update_layout(
        # height=500,
        # title="Activation Distributions",
        template="plotly_white",
        showlegend=True,
        legend=dict(orientation="h", yanchor="bottom", y=-0.2, xanchor="center", x=0.5),
        margin=dict(l=20, r=20, t=40, b=20),
    )

    return fig
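

# get_segmask turns one latent's per-patch activations into an image overlay:
# the patch map is upsampled to the image resolution, min-max normalized, and
# patches where the latent is inactive are darkened.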


def get_segmask(selected_image, slider_value, model_type):
    image = data_dict[selected_image]["image"]
    sae_act = get_data(selected_image, model_type)[0]
    temp = sae_act[:, slider_value]
    try:
        # Drop the CLS token (row 0) and reshape the patch activations into a grid.
        mask = torch.Tensor(temp[1:].reshape(GRID_NUM, GRID_NUM)).view(1, 1, GRID_NUM, GRID_NUM)
    except Exception:
        # Without a valid mask there is nothing to overlay; log the shapes and re-raise.
        print(sae_act.shape, slider_value)
        raise
    mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][0].numpy()
    mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)

    base_opacity = 30
    image_array = np.array(image)[..., :3]
    rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
    rgba_overlay[..., :3] = image_array[..., :3]

    darkened_image = (image_array[..., :3] * (base_opacity / 255)).astype(np.uint8)
    rgba_overlay[mask == 0, :3] = darkened_image[mask == 0]
    rgba_overlay[..., 3] = 255  # Fully opaque

    return rgba_overlay
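

# Top reference images per latent are pre-exported as {latent_idx}.jpg, one
# directory per dataset; a white placeholder stands in when a file is missing.
# The "masked" directory holds the variants used when the segmentation-mask
# toggle is on.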


def get_top_images(slider_value, toggle_btn):
    def _get_images(dataset_path):
        top_image_paths = [
            os.path.join(dataset_path, "imagenet", f"{slider_value}.jpg"),
            os.path.join(dataset_path, "imagenet-sketch", f"{slider_value}.jpg"),
            os.path.join(dataset_path, "caltech101", f"{slider_value}.jpg"),
        ]
        top_images = [
            Image.open(path) if os.path.exists(path) else Image.new("RGB", (256, 256), (255, 255, 255))
            for path in top_image_paths
        ]
        return top_images

    if toggle_btn:
        top_images = _get_images("./data/top_images_masked")
    else:
        top_images = _get_images("./data/top_images")
    return top_images
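

# Radio values look like "common-123" or "CLIP-42"; the trailing number is the
# SAE latent index. mean_act_values[dataset] is indexed [latent_idx, :5] to
# show five stored activation values alongside the reference images.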


def show_activation_heatmap(selected_image, slider_value, model_type, toggle_btn=False):
    slider_value = int(slider_value.split("-")[-1])
    rgba_overlay = get_segmask(selected_image, slider_value, model_type)
    top_images = get_top_images(slider_value, toggle_btn)

    act_values = []
    for dataset in ["imagenet", "imagenet-sketch", "caltech101"]:
        act_value = sae_data_dict["mean_act_values"][dataset][slider_value, :5]
        act_value = [str(round(value, 3)) for value in act_value]
        act_value = " | ".join(act_value)
        out = f"#### Activation values: {act_value}"
        act_values.append(out)

    return rgba_overlay, top_images, act_values


def show_activation_heatmap_clip(selected_image, slider_value, toggle_btn):
    rgba_overlay, top_images, act_values = show_activation_heatmap(selected_image, slider_value, "CLIP", toggle_btn)
    sleep(0.1)
    return (rgba_overlay, top_images[0], top_images[1], top_images[2], act_values[0], act_values[1], act_values[2])


def show_activation_heatmap_maple(selected_image, slider_value, model_name):
    slider_value = int(slider_value.split("-")[-1])
    rgba_overlay = get_segmask(selected_image, slider_value, model_name)
    sleep(0.1)
    return rgba_overlay


def get_init_radio_options(selected_image, model_name):
    clip_neuron_dict = {}
    maple_neuron_dict = {}

    def _get_top_activation(selected_image, model_name, neuron_dict, top_k=5):
        activations = get_activation_distribution(selected_image, model_name).mean(0)
        top_neurons = list(np.argsort(activations)[::-1][:top_k])
        for top_neuron in top_neurons:
            neuron_dict[top_neuron] = activations[top_neuron]
        sorted_dict = dict(sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True))
        return sorted_dict

    clip_neuron_dict = _get_top_activation(selected_image, "CLIP", clip_neuron_dict)
    maple_neuron_dict = _get_top_activation(selected_image, model_name, maple_neuron_dict)

    radio_choices = get_radio_names(clip_neuron_dict, maple_neuron_dict)

    return radio_choices
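

# Option labels encode where a latent ranked highly: "common-<idx>" (top for
# both models), "CLIP-<idx>", or "MaPLE-<idx>". Handlers downstream recover
# the latent index with split("-")[-1].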


def get_radio_names(clip_neuron_dict, maple_neuron_dict):
    clip_keys = list(clip_neuron_dict.keys())
    maple_keys = list(maple_neuron_dict.keys())

    common_keys = list(set(clip_keys).intersection(set(maple_keys)))
    clip_only_keys = list(set(clip_keys) - (set(maple_keys)))
    maple_only_keys = list(set(maple_keys) - (set(clip_keys)))

    common_keys.sort(key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True)
    clip_only_keys.sort(reverse=True)
    maple_only_keys.sort(reverse=True)

    out = []
    out.extend([f"common-{i}" for i in common_keys[:5]])
    out.extend([f"CLIP-{i}" for i in clip_only_keys[:5]])
    out.extend([f"MaPLE-{i}" for i in maple_only_keys[:5]])

    return out
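

# Click handler variant of the above: recomputes the top latents, additionally
# folding in the clicked tile's top latents, then rebuilds the radio options.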


def update_radio_options(evt: gr.EventData, selected_image, model_name):
    def _sort_and_save_top_k(activations, neuron_dict, top_k=5):
        top_neurons = list(np.argsort(activations)[::-1][:top_k])
        for top_neuron in top_neurons:
            neuron_dict[top_neuron] = activations[top_neuron]

    def _get_top_activation(evt, selected_image, model_name, neuron_dict):
        all_activation = get_activation_distribution(selected_image, model_name)
        image_activation = all_activation.mean(0)
        _sort_and_save_top_k(image_activation, neuron_dict)

        if evt is not None:
            if evt._data is not None and isinstance(evt._data["index"], list):
                image = data_dict[selected_image]["image"]
                grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
                token_idx = grid_y * GRID_NUM + grid_x + 1
                tile_activations = all_activation[token_idx]
                _sort_and_save_top_k(tile_activations, neuron_dict)

        sorted_dict = dict(sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True))
        return sorted_dict

    clip_neuron_dict = {}
    maple_neuron_dict = {}
    clip_neuron_dict = _get_top_activation(evt, selected_image, "CLIP", clip_neuron_dict)
    maple_neuron_dict = _get_top_activation(evt, selected_image, model_name, maple_neuron_dict)

    clip_keys = list(clip_neuron_dict.keys())
    maple_keys = list(maple_neuron_dict.keys())

    common_keys = list(set(clip_keys).intersection(set(maple_keys)))
    clip_only_keys = list(set(clip_keys) - (set(maple_keys)))
    maple_only_keys = list(set(maple_keys) - (set(clip_keys)))

    common_keys.sort(key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True)
    clip_only_keys.sort(reverse=True)
    maple_only_keys.sort(reverse=True)

    out = []
    out.extend([f"common-{i}" for i in common_keys[:5]])
    out.extend([f"CLIP-{i}" for i in clip_only_keys[:5]])
    out.extend([f"MaPLE-{i}" for i in maple_only_keys[:5]])

    radio_choices = gr.Radio(choices=out, label="Top activating SAE latent", value=out[0])
    sleep(0.1)
    return radio_choices


def update_markdown(option_value):
    latent_idx = int(option_value.split("-")[-1])
    out_1 = f"## Segmentation mask for the selected SAE latent - {latent_idx}"
    out_2 = f"## Top reference images for the selected SAE latent - {latent_idx}"
    return out_1, out_2


def get_data(image_name, model_name):
    pkl_root = "./data/out"
    data_dir = f"{pkl_root}/{model_name}/{image_name}.pkl.gz"
    with gzip.open(data_dir, "rb") as f:
        data = pickle.load(f)
        out = data

    return out
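

# Expected data layout: input images under ./data/image, per-image activations
# under ./data/out/<model>/<image>.pkl.gz, and SAE statistics under
# ./data/sae_data (mean_acts.pkl plus the mean_act_values_*.pkl.gz files
# tracked via LFS in .gitattributes above).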


def load_all_data(image_root, pkl_root):
    image_files = glob(f"{image_root}/*")
    data_dict = {}
    for image_file in image_files:
        image_name = os.path.basename(image_file).split(".")[0]
        # Key on the basename, which is also how the UI refers to images.
        if image_name not in data_dict:
            data_dict[image_name] = {
                "image": Image.open(image_file).resize((IMAGE_SIZE, IMAGE_SIZE)),
                "image_path": image_file,
            }

    sae_data_dict = {}
    with open("./data/sae_data/mean_acts.pkl", "rb") as f:
        data = pickle.load(f)
        sae_data_dict["mean_acts"] = data

    sae_data_dict["mean_act_values"] = {}
    for dataset in ["imagenet", "imagenet-sketch", "caltech101"]:
        with gzip.open(f"./data/sae_data/mean_act_values_{dataset}.pkl.gz", "rb") as f:
            data = pickle.load(f)
            sae_data_dict["mean_act_values"][dataset] = data

    return data_dict, sae_data_dict
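

# Loaded once at import time; the Blocks UI below is built from this state.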


data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
default_image_name = "christmas-imagenet"
| 378 | 
         
            +
             
     | 
| 379 | 
         
            +
            with gr.Blocks(
         
     | 
| 380 | 
         
            +
                theme=gr.themes.Citrus(),
         
     | 
| 381 | 
         
            +
                css="""
         
     | 
| 382 | 
         
            +
                .image-row .gr-image { margin: 0 !important; padding: 0 !important; }
         
     | 
| 383 | 
         
            +
                .image-row img { width: auto; height: 50px; } /* Set a uniform height for all images */
         
     | 
| 384 | 
         
            +
            """,
         
     | 
| 385 | 
         
            +
            ) as demo:
         
     | 
| 386 | 
         
            +
                with gr.Row():
         
     | 
| 387 | 
         
            +
                    with gr.Column():
         
     | 
| 388 | 
         
            +
                        # Left View: Image selection and click handling
         
     | 
| 389 | 
         
            +
                        gr.Markdown("## Select input image and patch on the image")
         
     | 
| 390 | 
         
            +
                        image_selector = gr.Dropdown(choices=list(data_dict.keys()), value=default_image_name, label="Select Image")
         
     | 
| 391 | 
         
            +
                        image_display = gr.Image(value=data_dict[default_image_name]["image"], type="pil", interactive=True)
         
     | 
| 392 | 
         
            +
             
     | 
| 393 | 
         
            +
                        # Update image display when a new image is selected
         
     | 
| 394 | 
         
            +
                        image_selector.change(
         
     | 
| 395 | 
         
            +
                            fn=lambda img_name: data_dict[img_name]["image"], inputs=image_selector, outputs=image_display
         
     | 
| 396 | 
         
            +
                        )
         
     | 
| 397 | 
         
            +
                        image_display.select(fn=highlight_grid, inputs=[image_selector], outputs=[image_display])
         
     | 
| 398 | 
         
            +
             
     | 
| 399 | 
         
            +
                    with gr.Column():
         
     | 
| 400 | 
         
            +
                        gr.Markdown("## SAE latent activations of CLIP and MaPLE")
         
     | 
| 401 | 
         
            +
                        model_options = [f"MaPLE-{dataset_name}" for dataset_name in DATASET_LIST]
         
     | 
| 402 | 
         
            +
                        model_selector = gr.Dropdown(
         
     | 
| 403 | 
         
            +
                            choices=model_options, value=model_options[0], label="Select adapted model (MaPLe)"
         
     | 
| 404 | 
         
            +
                        )
         
     | 
| 405 | 
         
            +
                        init_plot = plot_activation_distribution(None, default_image_name, model_options[0])
         
     | 
| 406 | 
         
            +
                        neuron_plot = gr.Plot(label="Neuron Activation", value=init_plot, show_label=False)
         
     | 
| 407 | 
         
            +
             
     | 
| 408 | 
         
            +
                        image_selector.change(
         
     | 
| 409 | 
         
            +
                            fn=plot_activation_distribution, inputs=[image_selector, model_selector], outputs=neuron_plot
         
     | 
| 410 | 
         
            +
                        )
         
     | 
| 411 | 
         
            +
                        image_display.select(
         
     | 
| 412 | 
         
            +
                            fn=plot_activation_distribution, inputs=[image_selector, model_selector], outputs=neuron_plot
         
     | 
| 413 | 
         
            +
                        )
         
     | 
| 414 | 
         
            +
                        model_selector.change(fn=load_image, inputs=[image_selector], outputs=image_display)
         
     | 
| 415 | 
         
            +
                        model_selector.change(
         
     | 
| 416 | 
         
            +
                            fn=plot_activation_distribution, inputs=[image_selector, model_selector], outputs=neuron_plot
         
     | 
| 417 | 
         
            +
                        )
         
     | 
| 418 | 
         
            +
             
     | 
| 419 | 
         
            +
    with gr.Row():
        with gr.Column():
            radio_names = get_init_radio_options(default_image_name, model_options[0])

            feature_idx = radio_names[0].split("-")[-1]
            markdown_display = gr.Markdown(f"## Segmentation mask for the selected SAE latent - {feature_idx}")
            init_seg, init_tops, init_values = show_activation_heatmap(default_image_name, radio_names[0], "CLIP")

            gr.Markdown("### Localize SAE latent activation using CLIP")
            seg_mask_display = gr.Image(value=init_seg, type="pil", show_label=False)
            init_seg_maple, _, _ = show_activation_heatmap(default_image_name, radio_names[0], model_options[0])
            gr.Markdown("### Localize SAE latent activation using MaPLE")
            seg_mask_display_maple = gr.Image(value=init_seg_maple, type="pil", show_label=False)

        with gr.Column():
            gr.Markdown("## Top activating SAE latent index")

            radio_choices = gr.Radio(
                choices=radio_names, label="Top activating SAE latent", interactive=True, value=radio_names[0]
            )
            toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)

            markdown_display_2 = gr.Markdown(f"## Top reference images for the selected SAE latent - {feature_idx}")

            gr.Markdown("### ImageNet")
            top_image_1 = gr.Image(value=init_tops[0], type="pil", label="ImageNet", show_label=False)
            act_value_1 = gr.Markdown(init_values[0])

            gr.Markdown("### ImageNet-Sketch")
            top_image_2 = gr.Image(value=init_tops[1], type="pil", label="ImageNet-Sketch", show_label=False)
            act_value_2 = gr.Markdown(init_values[1])

            gr.Markdown("### Caltech101")
            top_image_3 = gr.Image(value=init_tops[2], type="pil", label="Caltech101", show_label=False)
            act_value_3 = gr.Markdown(init_values[2])

            image_display.select(
                fn=update_radio_options, inputs=[image_selector, model_selector], outputs=[radio_choices], queue=True
            )

            model_selector.change(
                fn=update_radio_options, inputs=[image_selector, model_selector], outputs=[radio_choices], queue=True
            )

            image_selector.select(
                fn=update_radio_options, inputs=[image_selector, model_selector], outputs=[radio_choices], queue=True
            )

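        # The same radio selection drives three handlers: refresh the two
        # markdown headers, recompute the CLIP heatmap and reference images,
        # and recompute the MaPLE heatmap.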
        radio_choices.change(
            fn=update_markdown,
            inputs=[radio_choices],
            outputs=[markdown_display, markdown_display_2],
            queue=True,
        )

        radio_choices.change(
            fn=show_activation_heatmap_clip,
            inputs=[image_selector, radio_choices, toggle_btn],
            outputs=[seg_mask_display, top_image_1, top_image_2, top_image_3, act_value_1, act_value_2, act_value_3],
            queue=True,
        )

        radio_choices.change(
            fn=show_activation_heatmap_maple,
            inputs=[image_selector, radio_choices, model_selector],
            outputs=[seg_mask_display_maple],
            queue=True,
        )

        # toggle_btn.change(
        #     fn=get_top_images,
        #     inputs=[radio_choices, toggle_btn],
        #     outputs=[top_image_1, top_image_2, top_image_3],
        #     queue=True,
        # )

        toggle_btn.change(
            fn=show_activation_heatmap_clip,
            inputs=[image_selector, radio_choices, toggle_btn],
            outputs=[seg_mask_display, top_image_1, top_image_2, top_image_3, act_value_1, act_value_2, act_value_3],
            queue=True,
        )

# Launch the app
demo.launch()
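As a side note on the wiring pattern used throughout app.py: Gradio listeners such as .change and .select receive the declared input components' values as positional arguments and route the handler's return values to the output components, and a .select handler may additionally declare a gr.SelectData parameter carrying the click position. Below is a minimal, self-contained sketch of that pattern, assuming only gradio and Pillow; the images dict and CELL grid size are illustrative stand-ins, not the app's actual data_dict or patch grid.

import gradio as gr
from PIL import Image, ImageDraw

# Illustrative stand-in for the app's data_dict: name -> PIL image.
images = {
    "red": Image.new("RGB", (128, 128), "red"),
    "blue": Image.new("RGB", (128, 128), "blue"),
}
CELL = 32  # illustrative grid cell size in pixels

def highlight(name: str, evt: gr.SelectData) -> Image.Image:
    # For gr.Image, evt.index holds the clicked (x, y) pixel.
    x, y = evt.index
    out = images[name].copy()
    gx, gy = (x // CELL) * CELL, (y // CELL) * CELL
    ImageDraw.Draw(out).rectangle([gx, gy, gx + CELL, gy + CELL], outline="white", width=3)
    return out

with gr.Blocks() as demo:
    selector = gr.Dropdown(choices=list(images.keys()), value="red", label="Select Image")
    display = gr.Image(value=images["red"], type="pil", interactive=True)
    # The handler's return value is routed to the single output component.
    selector.change(fn=lambda name: images[name], inputs=selector, outputs=display)
    display.select(fn=highlight, inputs=[selector], outputs=[display])

demo.launch()

This mirrors how image_selector.change and image_display.select above drive image_display and the activation plot.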
    	
requirements.txt
ADDED

@@ -0,0 +1,3 @@
+torch
+matplotlib
+plotly
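Note that gradio itself is not listed here. On Hugging Face Spaces that is the usual arrangement: the Gradio runtime comes from the sdk and sdk_version fields in the README.md front matter, so requirements.txt only needs the extra dependencies, presumably torch for loading the .pkl activation data and matplotlib/plotly for the heatmaps and activation plots.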