hyesulim committed
Commit c08c98f · verified · 1 Parent(s): fdf6b95

test: add memory monitor
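
The monitor added at the bottom of app.py samples the process's resident memory with psutil and reschedules itself on a threading.Timer. A minimal, self-contained sketch of that pattern (assuming psutil is installed; the daemon flag is an addition here so the timer thread cannot keep the process alive on exit):

    import threading
    import psutil

    def monitor_memory(interval_s: float = 300.0) -> None:
        """Print this process's resident set size, then reschedule."""
        rss_mb = psutil.Process().memory_info().rss / (1024 ** 2)
        print(f"RSS: {rss_mb:.2f} MB")
        timer = threading.Timer(interval_s, monitor_memory, args=(interval_s,))
        timer.daemon = True  # assumption: monitoring should not block interpreter exit
        timer.start()

    monitor_memory()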

Files changed (1): app.py (+1068 -138)
app.py CHANGED
@@ -609,148 +609,1078 @@ def load_all_data(image_root, pkl_root):
 default_image_name = "christmas-imagenet"
 
 
-def safe_operation(func, *args, timeout=10, default=None):
-    """Execute function with timeout and return default value on failure."""
-    try:
-        with concurrent.futures.ThreadPoolExecutor() as executor:
-            future = executor.submit(func, *args)
-            return future.result(timeout=timeout)
-    except concurrent.futures.TimeoutError:
-        print(f"Operation timed out: {func.__name__}")
-        return default
-    except Exception as e:
-        print(f"Operation failed: {func.__name__}, Error: {e}")
-        return default
-
-def create_interface():
-    with gr.Blocks(
-        theme=gr.themes.Citrus(),
-        css="""
-        .image-row .gr-image { margin: 0 !important; padding: 0 !important; }
-        .image-row img { width: auto; height: 50px; }
-        """,
-    ) as demo:
-        # State management for preventing duplicate operations
-        state = gr.State({
-            'last_update': 0,
-            'processing': False
-        })
-
-        with gr.Row():
-            with gr.Column():
-                gr.Markdown("## Select input image and patch on the image")
-                image_selector = gr.Dropdown(
-                    choices=list(_CACHE['data_dict'].keys()),
-                    value=default_image_name,
-                    label="Select Image",
-                )
-                image_display = gr.Image(
-                    value=_CACHE['data_dict'][default_image_name]["image"],
-                    type="pil",
-                    interactive=True,
-                )
-
-                def safe_update_image(img_name):
-                    return safe_operation(
-                        lambda: _CACHE['data_dict'][img_name]["image"],
-                        timeout=5,
-                        default=Image.new('RGB', (IMAGE_SIZE, IMAGE_SIZE), 'gray')
-                    )
-
-                image_selector.change(
-                    fn=safe_update_image,
-                    inputs=image_selector,
-                    outputs=image_display,
-                )
-
-            with gr.Column():
-                gr.Markdown("## SAE latent activations of CLIP and MaPLE")
-                model_options = [f"MaPLE-{dataset_name}" for dataset_name in DATASET_LIST]
-                model_selector = gr.Dropdown(
-                    choices=model_options,
-                    value=model_options[0],
-                    label="Select adapted model (MaPLe)",
-                )
-
-                def safe_plot_activation(evt, selected_image, model_name):
-                    try:
-                        return safe_operation(
-                            plot_activation_distribution,
-                            evt, selected_image, model_name,
-                            timeout=10,
-                            default=go.Figure()  # Return empty figure on timeout
-                        )
-                    except Exception as e:
-                        print(f"Error in plot_activation: {e}")
-                        return go.Figure()
-
-                neuron_plot = gr.Plot(
-                    label="Neuron Activation",
-                    value=safe_plot_activation(None, default_image_name, model_options[0]),
-                    show_label=False,
-                )
-
-        def debounced_update(*args, minimum_interval=1.0):
-            """Prevent updates that are too close together"""
-            current_time = time.time()
-            if current_time - state.value['last_update'] < minimum_interval:
-                raise gr.Error("Please wait a moment before updating again")
-            state.value['last_update'] = current_time
-            return safe_plot_activation(*args)
-
-        # Add error handling and debouncing to all event handlers
-        image_selector.change(
-            fn=debounced_update,
-            inputs=[image_selector, model_selector],
-            outputs=neuron_plot,
-        )
-        image_display.select(
-            fn=debounced_update,
-            inputs=[image_selector, model_selector],
-            outputs=neuron_plot,
-        )
-        model_selector.change(
-            fn=debounced_update,
-            inputs=[image_selector, model_selector],
-            outputs=neuron_plot,
-        )
-
-        # Rest of your interface code with similar error handling...
-
-    return demo
-
-demo = create_interface()
-
 if __name__ == "__main__":
-    # Configure logging
-    import logging
-    logging.basicConfig(
-        level=logging.INFO,
-        format='%(asctime)s - %(levelname)s - %(message)s'
     )
-
-    # Initialize the interface with error handling
     try:
-        demo = create_interface()
-        demo.queue()  # Limit concurrent operations
-        demo.launch(
-            server_name="0.0.0.0",
-            server_port=7860,
-            share=False,
-            show_error=True,
-            max_threads=4,  # Reduced for better stability
-            prevent_thread_lock=True,
-            # Add health check endpoint
-            root_path="/healthz",
-            # Add automatic cleanup
-            _teardown=lambda: (
-                clear_caches(),
-                logging.info("Caches cleared during teardown")
-            ),
-            # Configure timeouts
-            api_open_timeout=60,
-            api_call_timeout=300,
-        )
     except Exception as e:
-        logging.error(f"Failed to launch interface: {e}")
-        raise
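The rewrite below drops the safe_operation timeout wrapper and the debounced handlers in favor of preloading and caching; note that the added code also declares the interface twice and places the first __main__ guard (new lines 790-799) ahead of the imports it depends on, as committed. The recurring pattern in the new helpers is functools.lru_cache layered over a module-level dict; a stripped-down sketch of that double cache, where load_payload is a hypothetical stand-in for the gzip/pickle loading done by get_data:

    from functools import lru_cache

    _CACHE = {"model_data": {}}

    def load_payload(cache_key: str):
        # Hypothetical stand-in for the gzip.open + pickle.load in get_data().
        raise NotImplementedError

    @lru_cache(maxsize=1024)
    def get_cached(image_name: str, model_name: str):
        # lru_cache memoizes on the argument tuple; the dict keeps the payload
        # reachable even after an lru_cache eviction.
        key = f"{model_name}_{image_name}"
        if key not in _CACHE["model_data"]:
            _CACHE["model_data"][key] = load_payload(key)
        return _CACHE["model_data"][key]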
+with gr.Blocks(
+    theme=gr.themes.Citrus(),
+    css="""
+    .image-row .gr-image { margin: 0 !important; padding: 0 !important; }
+    .image-row img { width: auto; height: 50px; } /* Set a uniform height for all images */
+    """,
+) as demo:
+    with gr.Row():
+        with gr.Column():
+            # Left View: Image selection and click handling
+            gr.Markdown("## Select input image and patch on the image")
+            image_selector = gr.Dropdown(
+                choices=list(data_dict.keys()),
+                value=default_image_name,
+                label="Select Image",
+            )
+            image_display = gr.Image(
+                value=data_dict[default_image_name]["image"],
+                type="pil",
+                interactive=True,
+            )
+
+            # Update image display when a new image is selected
+            image_selector.change(
+                fn=lambda img_name: data_dict[img_name]["image"],
+                inputs=image_selector,
+                outputs=image_display,
+            )
+            image_display.select(
+                fn=highlight_grid, inputs=[image_selector], outputs=[image_display]
+            )
+
+        with gr.Column():
+            gr.Markdown("## SAE latent activations of CLIP and MaPLE")
+            model_options = [f"MaPLE-{dataset_name}" for dataset_name in DATASET_LIST]
+            model_selector = gr.Dropdown(
+                choices=model_options,
+                value=model_options[0],
+                label="Select adapted model (MaPLe)",
+            )
+            init_plot = plot_activation_distribution(
+                None, default_image_name, model_options[0]
+            )
+            neuron_plot = gr.Plot(
+                label="Neuron Activation", value=init_plot, show_label=False
+            )
+
+            image_selector.change(
+                fn=plot_activation_distribution,
+                inputs=[image_selector, model_selector],
+                outputs=neuron_plot,
+            )
+            image_display.select(
+                fn=plot_activation_distribution,
+                inputs=[image_selector, model_selector],
+                outputs=neuron_plot,
+            )
+            model_selector.change(
+                fn=load_image, inputs=[image_selector], outputs=image_display
+            )
+            model_selector.change(
+                fn=plot_activation_distribution,
+                inputs=[image_selector, model_selector],
+                outputs=neuron_plot,
+            )
+
+    with gr.Row():
+        with gr.Column():
+            radio_names = get_init_radio_options(default_image_name, model_options[0])
+
+            feautre_idx = radio_names[0].split("-")[-1]
+            markdown_display = gr.Markdown(
+                f"## Segmentation mask for the selected SAE latent - {feautre_idx}"
+            )
+            init_seg, init_tops, init_values = show_activation_heatmap(
+                default_image_name, radio_names[0], "CLIP"
+            )
+
+            gr.Markdown("### Localize SAE latent activation using CLIP")
+            seg_mask_display = gr.Image(value=init_seg, type="pil", show_label=False)
+            init_seg_maple, _, _ = show_activation_heatmap(
+                default_image_name, radio_names[0], model_options[0]
+            )
+            gr.Markdown("### Localize SAE latent activation using MaPLE")
+            seg_mask_display_maple = gr.Image(
+                value=init_seg_maple, type="pil", show_label=False
+            )
+
+        with gr.Column():
+            gr.Markdown("## Top activating SAE latent index")
+
+            radio_choices = gr.Radio(
+                choices=radio_names,
+                label="Top activating SAE latent",
+                interactive=True,
+                value=radio_names[0],
+            )
+            toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)
+
+            markdown_display_2 = gr.Markdown(
+                f"## Top reference images for the selected SAE latent - {feautre_idx}"
+            )
+
+            gr.Markdown("### ImageNet")
+            top_image_1 = gr.Image(
+                value=init_tops[0], type="pil", label="ImageNet", show_label=False
+            )
+            act_value_1 = gr.Markdown(init_values[0])
+
+            gr.Markdown("### ImageNet-Sketch")
+            top_image_2 = gr.Image(
+                value=init_tops[1],
+                type="pil",
+                label="ImageNet-Sketch",
+                show_label=False,
+            )
+            act_value_2 = gr.Markdown(init_values[1])
+
+            gr.Markdown("### Caltech101")
+            top_image_3 = gr.Image(
+                value=init_tops[2], type="pil", label="Caltech101", show_label=False
+            )
+            act_value_3 = gr.Markdown(init_values[2])
+
+    image_display.select(
+        fn=update_radio_options,
+        inputs=[image_selector, model_selector],
+        outputs=[radio_choices],
+    )
+
+    model_selector.change(
+        fn=update_radio_options,
+        inputs=[image_selector, model_selector],
+        outputs=[radio_choices],
+    )
+
+    image_selector.select(
+        fn=update_radio_options,
+        inputs=[image_selector, model_selector],
+        outputs=[radio_choices],
+    )
+
+    radio_choices.change(
+        fn=update_all,
+        inputs=[image_selector, radio_choices, toggle_btn, model_selector],
+        outputs=[
+            seg_mask_display,
+            seg_mask_display_maple,
+            top_image_1,
+            top_image_2,
+            top_image_3,
+            act_value_1,
+            act_value_2,
+            act_value_3,
+            markdown_display,
+            markdown_display_2,
+        ],
+    )
+
+    toggle_btn.change(
+        fn=show_activation_heatmap_clip,
+        inputs=[image_selector, radio_choices, toggle_btn],
+        outputs=[
+            seg_mask_display,
+            top_image_1,
+            top_image_2,
+            top_image_3,
+            act_value_1,
+            act_value_2,
+            act_value_3,
+        ],
+    )
+
+    # Launch the app
+    # demo.queue()
+    # demo.launch()
+
+
 if __name__ == "__main__":
+    demo.queue()  # Enable queuing for better handling of concurrent users
+    demo.launch(
+        server_name="0.0.0.0",  # Allow external access
+        server_port=7860,
+        share=False,  # Set to True if you want to create a public URL
+        show_error=True,
+        # Optimize concurrency
+        max_threads=8,  # Adjust based on your CPU cores
     )
+import gzip
+import os
+import pickle
+from glob import glob
+from time import sleep
+
+from functools import lru_cache
+import concurrent.futures
+from typing import Dict, Tuple, List
+
+import gradio as gr
+import numpy as np
+import plotly.graph_objects as go
+import torch
+from PIL import Image, ImageDraw
+from plotly.subplots import make_subplots
+
+IMAGE_SIZE = 400
+DATASET_LIST = ["imagenet", "oxford_flowers", "ucf101", "caltech101", "dtd", "eurosat"]
+GRID_NUM = 14
+pkl_root = "./data/out"
+preloaded_data = {}
+
+
+# Global cache for data
+_CACHE = {
+    'data_dict': {},
+    'sae_data_dict': {},
+    'model_data': {},
+    'segmasks': {},
+    'top_images': {}
+}
+
+def load_all_data(image_root: str, pkl_root: str) -> Tuple[Dict, Dict]:
+    """Load all data with optimized parallel processing."""
+    # Load images in parallel
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        image_files = glob(f"{image_root}/*")
+        future_to_file = {
+            executor.submit(_load_image_file, image_file): image_file
+            for image_file in image_files
+        }
+
+        for future in concurrent.futures.as_completed(future_to_file):
+            image_file = future_to_file[future]
+            image_name = os.path.basename(image_file).split(".")[0]
+            result = future.result()
+            if result is not None:
+                _CACHE['data_dict'][image_name] = result
+
+    # Load SAE data
+    with open("./data/sae_data/mean_acts.pkl", "rb") as f:
+        _CACHE['sae_data_dict']["mean_acts"] = pickle.load(f)
+
+    # Load mean act values in parallel
+    datasets = ["imagenet", "imagenet-sketch", "caltech101"]
+    _CACHE['sae_data_dict']["mean_act_values"] = {}
+
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        future_to_dataset = {
+            executor.submit(_load_mean_act_values, dataset): dataset
+            for dataset in datasets
+        }
+
+        for future in concurrent.futures.as_completed(future_to_dataset):
+            dataset = future_to_dataset[future]
+            result = future.result()
+            if result is not None:
+                _CACHE['sae_data_dict']["mean_act_values"][dataset] = result
+
+    return _CACHE['data_dict'], _CACHE['sae_data_dict']
+
+def _load_image_file(image_file: str) -> Dict:
+    """Helper function to load a single image file."""
     try:
+        image = Image.open(image_file).resize((IMAGE_SIZE, IMAGE_SIZE))
+        return {
+            "image": image,
+            "image_path": image_file,
+        }
+    except Exception as e:
+        print(f"Error loading {image_file}: {e}")
+        return None
+
+def _load_mean_act_values(dataset: str) -> np.ndarray:
+    """Helper function to load mean act values for a dataset."""
+    try:
+        with gzip.open(f"./data/sae_data/mean_act_values_{dataset}.pkl.gz", "rb") as f:
+            return pickle.load(f)
     except Exception as e:
+        print(f"Error loading mean act values for {dataset}: {e}")
+        return None
+
+@lru_cache(maxsize=1024)
+def get_data(image_name: str, model_name: str) -> np.ndarray:
+    """Cached function to get model data."""
+    cache_key = f"{model_name}_{image_name}"
+    if cache_key not in _CACHE['model_data']:
+        data_dir = f"{pkl_root}/{model_name}/{image_name}.pkl.gz"
+        with gzip.open(data_dir, "rb") as f:
+            _CACHE['model_data'][cache_key] = pickle.load(f)
+    return _CACHE['model_data'][cache_key]
+
+@lru_cache(maxsize=1024)
+def get_activation_distribution(image_name: str, model_type: str) -> np.ndarray:
+    """Cached function to get activation distribution."""
+    activation = get_data(image_name, model_type)[0]
+    noisy_features_indices = (
+        (_CACHE['sae_data_dict']["mean_acts"]["imagenet"] > 0.1).nonzero()[0].tolist()
+    )
+    activation[:, noisy_features_indices] = 0
+    return activation
+
+@lru_cache(maxsize=1024)
+def get_segmask(selected_image: str, slider_value: int, model_type: str) -> np.ndarray:
+    """Cached function to get segmentation mask."""
+    cache_key = f"{selected_image}_{slider_value}_{model_type}"
+    if cache_key not in _CACHE['segmasks']:
+        image = _CACHE['data_dict'][selected_image]["image"]
+        sae_act = get_data(selected_image, model_type)[0]
+        temp = sae_act[:, slider_value]
+
+        mask = torch.Tensor(temp[1:].reshape(14, 14)).view(1, 1, 14, 14)
+        mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][0].numpy()
+        mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)
+
+        base_opacity = 30
+        image_array = np.array(image)[..., :3]
+        rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
+        rgba_overlay[..., :3] = image_array[..., :3]
+
+        darkened_image = (image_array[..., :3] * (base_opacity / 255)).astype(np.uint8)
+        rgba_overlay[mask == 0, :3] = darkened_image[mask == 0]
+        rgba_overlay[..., 3] = 255
+
+        _CACHE['segmasks'][cache_key] = rgba_overlay
+
+    return _CACHE['segmasks'][cache_key]
+
+@lru_cache(maxsize=1024)
+def get_top_images(slider_value: int, toggle_btn: bool) -> List[Image.Image]:
+    """Cached function to get top images."""
+    cache_key = f"{slider_value}_{toggle_btn}"
+    if cache_key not in _CACHE['top_images']:
+        dataset_path = "./data/top_images_masked" if toggle_btn else "./data/top_images"
+        paths = [
+            os.path.join(dataset_path, dataset, f"{slider_value}.jpg")
+            for dataset in ["imagenet", "imagenet-sketch", "caltech101"]
+        ]
+
+        _CACHE['top_images'][cache_key] = [
+            Image.open(path) if os.path.exists(path) else Image.new("RGB", (256, 256), (255, 255, 255))
+            for path in paths
+        ]
+
+    return _CACHE['top_images'][cache_key]
+
+
+# def preload_activation(image_name):
+#     for model in ["CLIP"] + [f"MaPLE-{ds}" for ds in DATASET_LIST]:
+#         image_file = f"{pkl_root}/{model}/{image_name}.pkl.gz"
+#         with gzip.open(image_file, "rb") as f:
+#             preloaded_data[model] = pickle.load(f)
+
+
+# def get_activation_distribution(image_name: str, model_type: str):
+#     activation = get_data(image_name, model_type)[0]
+
+#     noisy_features_indices = (
+#         (sae_data_dict["mean_acts"]["imagenet"] > 0.1).nonzero()[0].tolist()
+#     )
+#     activation[:, noisy_features_indices] = 0
+
+#     return activation
+
+
+def get_grid_loc(evt, image):
+    # Get click coordinates
+    x, y = evt._data["index"][0], evt._data["index"][1]
+
+    cell_width = image.width // GRID_NUM
+    cell_height = image.height // GRID_NUM
+
+    grid_x = x // cell_width
+    grid_y = y // cell_height
+    return grid_x, grid_y, cell_width, cell_height
+
+
+def highlight_grid(evt: gr.EventData, image_name):
+    image = data_dict[image_name]["image"]
+    grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
+
+    highlighted_image = image.copy()
+    draw = ImageDraw.Draw(highlighted_image)
+    box = [
+        grid_x * cell_width,
+        grid_y * cell_height,
+        (grid_x + 1) * cell_width,
+        (grid_y + 1) * cell_height,
+    ]
+    draw.rectangle(box, outline="red", width=3)
+
+    return highlighted_image
+
+
+def load_image(img_name):
+    return Image.open(data_dict[img_name]["image_path"]).resize(
+        (IMAGE_SIZE, IMAGE_SIZE)
+    )
+
+
+def plot_activations(
+    all_activation,
+    tile_activations=None,
+    grid_x=None,
+    grid_y=None,
+    top_k=5,
+    colors=("blue", "cyan"),
+    model_name="CLIP",
+):
+    fig = go.Figure()
+
+    def _add_scatter_with_annotation(fig, activations, model_name, color, label):
+        fig.add_trace(
+            go.Scatter(
+                x=np.arange(len(activations)),
+                y=activations,
+                mode="lines",
+                name=label,
+                line=dict(color=color, dash="solid"),
+                showlegend=True,
+            )
+        )
+        top_neurons = np.argsort(activations)[::-1][:top_k]
+        for idx in top_neurons:
+            fig.add_annotation(
+                x=idx,
+                y=activations[idx],
+                text=str(idx),
+                showarrow=True,
+                arrowhead=2,
+                ax=0,
+                ay=-15,
+                arrowcolor=color,
+                opacity=0.7,
+            )
+        return fig
+
+    label = f"{model_name.split('-')[-0]} Image-level"
+    fig = _add_scatter_with_annotation(
+        fig, all_activation, model_name, colors[0], label
+    )
+    if tile_activations is not None:
+        label = f"{model_name.split('-')[-0]} Tile ({grid_x}, {grid_y})"
+        fig = _add_scatter_with_annotation(
+            fig, tile_activations, model_name, colors[1], label
+        )
+
+    fig.update_layout(
+        title="Activation Distribution",
+        xaxis_title="SAE latent index",
+        yaxis_title="Activation Value",
+        template="plotly_white",
+    )
+    fig.update_layout(
+        legend=dict(orientation="h", yanchor="middle", y=0.5, xanchor="center", x=0.5)
+    )
+
+    return fig
+
+
+def get_activations(evt: gr.EventData, selected_image: str, model_name: str, colors):
+    activation = get_activation_distribution(selected_image, model_name)
+    all_activation = activation.mean(0)
+
+    tile_activations = None
+    grid_x = None
+    grid_y = None
+
+    if evt is not None:
+        if evt._data is not None:
+            image = data_dict[selected_image]["image"]
+            grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
+            token_idx = grid_y * GRID_NUM + grid_x + 1
+            tile_activations = activation[token_idx]
+
+    fig = plot_activations(
+        all_activation,
+        tile_activations,
+        grid_x,
+        grid_y,
+        top_k=5,
+        model_name=model_name,
+        colors=colors,
+    )
+    return fig
+
+
+def plot_activation_distribution(
+    evt: gr.EventData, selected_image: str, model_name: str
+):
+    fig = make_subplots(
+        rows=2,
+        cols=1,
+        shared_xaxes=True,
+        subplot_titles=["CLIP Activation", f"{model_name} Activation"],
+    )
+
+    fig_clip = get_activations(
+        evt, selected_image, "CLIP", colors=("#00b4d8", "#90e0ef")
+    )
+    fig_maple = get_activations(
+        evt, selected_image, model_name, colors=("#ff5a5f", "#ffcad4")
+    )
+
+    def _attach_fig(fig, sub_fig, row, col, yref):
+        for trace in sub_fig.data:
+            fig.add_trace(trace, row=row, col=col)
+
+        for annotation in sub_fig.layout.annotations:
+            annotation.update(yref=yref)
+            fig.add_annotation(annotation)
+        return fig
+
+    fig = _attach_fig(fig, fig_clip, row=1, col=1, yref="y1")
+    fig = _attach_fig(fig, fig_maple, row=2, col=1, yref="y2")
+
+    fig.update_xaxes(title_text="SAE Latent Index", row=2, col=1)
+    fig.update_xaxes(title_text="SAE Latent Index", row=1, col=1)
+    fig.update_yaxes(title_text="Activation Value", row=1, col=1)
+    fig.update_yaxes(title_text="Activation Value", row=2, col=1)
+    fig.update_layout(
+        # height=500,
+        # title="Activation Distributions",
+        template="plotly_white",
+        showlegend=True,
+        legend=dict(orientation="h", yanchor="bottom", y=-0.2, xanchor="center", x=0.5),
+        margin=dict(l=20, r=20, t=40, b=20),
+    )
+
+    return fig
+
+
+# def get_segmask(selected_image, slider_value, model_type):
+#     image = data_dict[selected_image]["image"]
+#     sae_act = get_data(selected_image, model_type)[0]
+#     temp = sae_act[:, slider_value]
+#     try:
+#         mask = torch.Tensor(temp[1:,].reshape(14, 14)).view(1, 1, 14, 14)
+#     except Exception as e:
+#         print(sae_act.shape, slider_value)
+#     mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][
+#         0
+#     ].numpy()
+#     mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)
+
+#     base_opacity = 30
+#     image_array = np.array(image)[..., :3]
+#     rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
+#     rgba_overlay[..., :3] = image_array[..., :3]
+
+#     darkened_image = (image_array[..., :3] * (base_opacity / 255)).astype(np.uint8)
+#     rgba_overlay[mask == 0, :3] = darkened_image[mask == 0]
+#     rgba_overlay[..., 3] = 255  # Fully opaque
+
+#     return rgba_overlay
+
+
+# def get_top_images(slider_value, toggle_btn):
+#     def _get_images(dataset_path):
+#         top_image_paths = [
+#             os.path.join(dataset_path, "imagenet", f"{slider_value}.jpg"),
+#             os.path.join(dataset_path, "imagenet-sketch", f"{slider_value}.jpg"),
+#             os.path.join(dataset_path, "caltech101", f"{slider_value}.jpg"),
+#         ]
+#         top_images = [
+#             (
+#                 Image.open(path)
+#                 if os.path.exists(path)
+#                 else Image.new("RGB", (256, 256), (255, 255, 255))
+#             )
+#             for path in top_image_paths
+#         ]
+#         return top_images
+
+#     if toggle_btn:
+#         top_images = _get_images("./data/top_images_masked")
+#     else:
+#         top_images = _get_images("./data/top_images")
+#     return top_images
+
+
+def show_activation_heatmap(selected_image, slider_value, model_type, toggle_btn=False):
+    slider_value = int(slider_value.split("-")[-1])
+    rgba_overlay = get_segmask(selected_image, slider_value, model_type)
+    top_images = get_top_images(slider_value, toggle_btn)
+
+    act_values = []
+    for dataset in ["imagenet", "imagenet-sketch", "caltech101"]:
+        act_value = sae_data_dict["mean_act_values"][dataset][slider_value, :5]
+        act_value = [str(round(value, 3)) for value in act_value]
+        act_value = " | ".join(act_value)
+        out = f"#### Activation values: {act_value}"
+        act_values.append(out)
+
+    return rgba_overlay, top_images, act_values
+
+
+def show_activation_heatmap_clip(selected_image, slider_value, toggle_btn):
+    rgba_overlay, top_images, act_values = show_activation_heatmap(
+        selected_image, slider_value, "CLIP", toggle_btn
+    )
+    sleep(0.1)
+    return (
+        rgba_overlay,
+        top_images[0],
+        top_images[1],
+        top_images[2],
+        act_values[0],
+        act_values[1],
+        act_values[2],
+    )
+
+
+def show_activation_heatmap_maple(selected_image, slider_value, model_name):
+    slider_value = int(slider_value.split("-")[-1])
+    rgba_overlay = get_segmask(selected_image, slider_value, model_name)
+    sleep(0.1)
+    return rgba_overlay
+
+
+def get_init_radio_options(selected_image, model_name):
+    clip_neuron_dict = {}
+    maple_neuron_dict = {}
+
+    def _get_top_actvation(selected_image, model_name, neuron_dict, top_k=5):
+        activations = get_activation_distribution(selected_image, model_name).mean(0)
+        top_neurons = list(np.argsort(activations)[::-1][:top_k])
+        for top_neuron in top_neurons:
+            neuron_dict[top_neuron] = activations[top_neuron]
+        sorted_dict = dict(
+            sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True)
+        )
+        return sorted_dict
+
+    clip_neuron_dict = _get_top_actvation(selected_image, "CLIP", clip_neuron_dict)
+    maple_neuron_dict = _get_top_actvation(
+        selected_image, model_name, maple_neuron_dict
+    )
+
+    radio_choices = get_radio_names(clip_neuron_dict, maple_neuron_dict)
+
+    return radio_choices
+
+
+def get_radio_names(clip_neuron_dict, maple_neuron_dict):
+    clip_keys = list(clip_neuron_dict.keys())
+    maple_keys = list(maple_neuron_dict.keys())
+
+    common_keys = list(set(clip_keys).intersection(set(maple_keys)))
+    clip_only_keys = list(set(clip_keys) - (set(maple_keys)))
+    maple_only_keys = list(set(maple_keys) - (set(clip_keys)))
+
+    common_keys.sort(
+        key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True
+    )
+    clip_only_keys.sort(reverse=True)
+    maple_only_keys.sort(reverse=True)
+
+    out = []
+    out.extend([f"common-{i}" for i in common_keys[:5]])
+    out.extend([f"CLIP-{i}" for i in clip_only_keys[:5]])
+    out.extend([f"MaPLE-{i}" for i in maple_only_keys[:5]])
+
+    return out
+
+
+def update_radio_options(evt: gr.EventData, selected_image, model_name):
+    def _sort_and_save_top_k(activations, neuron_dict, top_k=5):
+        top_neurons = list(np.argsort(activations)[::-1][:top_k])
+        for top_neuron in top_neurons:
+            neuron_dict[top_neuron] = activations[top_neuron]
+
+    def _get_top_actvation(evt, selected_image, model_name, neuron_dict):
+        all_activation = get_activation_distribution(selected_image, model_name)
+        image_activation = all_activation.mean(0)
+        _sort_and_save_top_k(image_activation, neuron_dict)
+
+        if evt is not None:
+            if evt._data is not None and isinstance(evt._data["index"], list):
+                image = data_dict[selected_image]["image"]
+                grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
+                token_idx = grid_y * GRID_NUM + grid_x + 1
+                tile_activations = all_activation[token_idx]
+                _sort_and_save_top_k(tile_activations, neuron_dict)
+
+        sorted_dict = dict(
+            sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True)
+        )
+        return sorted_dict
+
+    clip_neuron_dict = {}
+    maple_neuron_dict = {}
+    clip_neuron_dict = _get_top_actvation(evt, selected_image, "CLIP", clip_neuron_dict)
+    maple_neuron_dict = _get_top_actvation(
+        evt, selected_image, model_name, maple_neuron_dict
+    )
+
+    clip_keys = list(clip_neuron_dict.keys())
+    maple_keys = list(maple_neuron_dict.keys())
+
+    common_keys = list(set(clip_keys).intersection(set(maple_keys)))
+    clip_only_keys = list(set(clip_keys) - (set(maple_keys)))
+    maple_only_keys = list(set(maple_keys) - (set(clip_keys)))
+
+    common_keys.sort(
+        key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True
+    )
+    clip_only_keys.sort(reverse=True)
+    maple_only_keys.sort(reverse=True)
+
+    out = []
+    out.extend([f"common-{i}" for i in common_keys[:5]])
+    out.extend([f"CLIP-{i}" for i in clip_only_keys[:5]])
+    out.extend([f"MaPLE-{i}" for i in maple_only_keys[:5]])
+
+    radio_choices = gr.Radio(
+        choices=out, label="Top activating SAE latent", value=out[0]
+    )
+    sleep(0.1)
+    return radio_choices
+
+
+def update_markdown(option_value):
+    latent_idx = int(option_value.split("-")[-1])
+    out_1 = f"## Segmentation mask for the selected SAE latent - {latent_idx}"
+    out_2 = f"## Top reference images for the selected SAE latent - {latent_idx}"
+    return out_1, out_2
+
+
+def get_data(image_name, model_name):
+    pkl_root = "./data/out"
+    data_dir = f"{pkl_root}/{model_name}/{image_name}.pkl.gz"
+    with gzip.open(data_dir, "rb") as f:
+        data = pickle.load(f)
+    out = data
+
+    return out
+
+
+def update_all(selected_image, slider_value, toggle_btn, model_name):
+    (
+        seg_mask_display,
+        top_image_1,
+        top_image_2,
+        top_image_3,
+        act_value_1,
+        act_value_2,
+        act_value_3,
+    ) = show_activation_heatmap_clip(selected_image, slider_value, toggle_btn)
+    seg_mask_display_maple = show_activation_heatmap_maple(
+        selected_image, slider_value, model_name
+    )
+    markdown_display, markdown_display_2 = update_markdown(slider_value)
+
+    return (
+        seg_mask_display,
+        seg_mask_display_maple,
+        top_image_1,
+        top_image_2,
+        top_image_3,
+        act_value_1,
+        act_value_2,
+        act_value_3,
+        markdown_display,
+        markdown_display_2,
+    )
+
+
+def load_all_data(image_root, pkl_root):
+    image_files = glob(f"{image_root}/*")
+    data_dict = {}
+    for image_file in image_files:
+        image_name = os.path.basename(image_file).split(".")[0]
+        if image_file not in data_dict:
+            data_dict[image_name] = {
+                "image": Image.open(image_file).resize((IMAGE_SIZE, IMAGE_SIZE)),
+                "image_path": image_file,
+            }
+
+    sae_data_dict = {}
+    with open("./data/sae_data/mean_acts.pkl", "rb") as f:
+        data = pickle.load(f)
+        sae_data_dict["mean_acts"] = data
+
+    sae_data_dict["mean_act_values"] = {}
+    for dataset in ["imagenet", "imagenet-sketch", "caltech101"]:
+        with gzip.open(f"./data/sae_data/mean_act_values_{dataset}.pkl.gz", "rb") as f:
+            data = pickle.load(f)
+            sae_data_dict["mean_act_values"][dataset] = data
+
+    return data_dict, sae_data_dict
+
+
+def preload_all_model_data():
+    """Preload all model data into memory at startup"""
+    print("Preloading model data...")
+    for image_name in data_dict.keys():
+        for model_name in ["CLIP"] + [f"MaPLE-{ds}" for ds in DATASET_LIST]:
+            try:
+                data = get_data(image_name, model_name)
+                cache_key = f"{model_name}_{image_name}"
+                _CACHE['model_data'][cache_key] = data
+            except Exception as e:
+                print(f"Error preloading {cache_key}: {e}")
+
+# Add to initialization
+preload_all_model_data()
+
+def precompute_activations():
+    """Precompute and cache common activation patterns"""
+    print("Precomputing activations...")
+    for image_name in data_dict.keys():
+        for model_name in ["CLIP"] + [f"MaPLE-{ds}" for ds in DATASET_LIST]:
+            activation = get_activation_distribution(image_name, model_name)
+            cache_key = f"activation_{model_name}_{image_name}"
+            _CACHE['precomputed_activations'][cache_key] = activation.mean(0)
+
+# Add to _CACHE initialization
+_CACHE['precomputed_activations'] = {}
+
+# Add to initialization
+precompute_activations()
+
+def precompute_segmasks():
+    """Precompute common segmentation masks"""
+    print("Precomputing segmentation masks...")
+    for image_name in data_dict.keys():
+        for model_type in ["CLIP"] + [f"MaPLE-{ds}" for ds in DATASET_LIST]:
+            for slider_value in range(0, 100):  # Adjust range as needed
+                try:
+                    mask = get_segmask(image_name, slider_value, model_type)
+                    cache_key = f"{image_name}_{slider_value}_{model_type}"
+                    _CACHE['segmasks'][cache_key] = mask
+                except Exception as e:
+                    print(f"Error precomputing mask {cache_key}: {e}")
+
+# Add to initialization
+precompute_segmasks()
+
+
+data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
+default_image_name = "christmas-imagenet"
+
+
+with gr.Blocks(
+    theme=gr.themes.Citrus(),
+    css="""
+    .image-row .gr-image { margin: 0 !important; padding: 0 !important; }
+    .image-row img { width: auto; height: 50px; } /* Set a uniform height for all images */
+    """,
+) as demo:
+    with gr.Row():
+        with gr.Column():
+            # Left View: Image selection and click handling
+            gr.Markdown("## Select input image and patch on the image")
+            image_selector = gr.Dropdown(
+                choices=list(data_dict.keys()),
+                value=default_image_name,
+                label="Select Image",
+            )
+            image_display = gr.Image(
+                value=data_dict[default_image_name]["image"],
+                type="pil",
+                interactive=True,
+            )
+
+            # Update image display when a new image is selected
+            image_selector.change(
+                fn=lambda img_name: data_dict[img_name]["image"],
+                inputs=image_selector,
+                outputs=image_display,
+            )
+            image_display.select(
+                fn=highlight_grid, inputs=[image_selector], outputs=[image_display]
+            )
+
+        with gr.Column():
+            gr.Markdown("## SAE latent activations of CLIP and MaPLE")
+            model_options = [f"MaPLE-{dataset_name}" for dataset_name in DATASET_LIST]
+            model_selector = gr.Dropdown(
+                choices=model_options,
+                value=model_options[0],
+                label="Select adapted model (MaPLe)",
+            )
+            init_plot = plot_activation_distribution(
+                None, default_image_name, model_options[0]
+            )
+            neuron_plot = gr.Plot(
+                label="Neuron Activation", value=init_plot, show_label=False
+            )
+
+            image_selector.change(
+                fn=plot_activation_distribution,
+                inputs=[image_selector, model_selector],
+                outputs=neuron_plot,
+            )
+            image_display.select(
+                fn=plot_activation_distribution,
+                inputs=[image_selector, model_selector],
+                outputs=neuron_plot,
+            )
+            model_selector.change(
+                fn=load_image, inputs=[image_selector], outputs=image_display
+            )
+            model_selector.change(
+                fn=plot_activation_distribution,
+                inputs=[image_selector, model_selector],
+                outputs=neuron_plot,
+            )
+
+    with gr.Row():
+        with gr.Column():
+            radio_names = get_init_radio_options(default_image_name, model_options[0])
+
+            feautre_idx = radio_names[0].split("-")[-1]
+            markdown_display = gr.Markdown(
+                f"## Segmentation mask for the selected SAE latent - {feautre_idx}"
+            )
+            init_seg, init_tops, init_values = show_activation_heatmap(
+                default_image_name, radio_names[0], "CLIP"
+            )
+
+            gr.Markdown("### Localize SAE latent activation using CLIP")
+            seg_mask_display = gr.Image(value=init_seg, type="pil", show_label=False)
+            init_seg_maple, _, _ = show_activation_heatmap(
+                default_image_name, radio_names[0], model_options[0]
+            )
+            gr.Markdown("### Localize SAE latent activation using MaPLE")
+            seg_mask_display_maple = gr.Image(
+                value=init_seg_maple, type="pil", show_label=False
+            )
+
+        with gr.Column():
+            gr.Markdown("## Top activating SAE latent index")
+
+            radio_choices = gr.Radio(
+                choices=radio_names,
+                label="Top activating SAE latent",
+                interactive=True,
+                value=radio_names[0],
+            )
+            toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)
+
+            markdown_display_2 = gr.Markdown(
+                f"## Top reference images for the selected SAE latent - {feautre_idx}"
+            )
+
+            gr.Markdown("### ImageNet")
+            top_image_1 = gr.Image(
+                value=init_tops[0], type="pil", label="ImageNet", show_label=False
+            )
+            act_value_1 = gr.Markdown(init_values[0])
+
+            gr.Markdown("### ImageNet-Sketch")
+            top_image_2 = gr.Image(
+                value=init_tops[1],
+                type="pil",
+                label="ImageNet-Sketch",
+                show_label=False,
+            )
+            act_value_2 = gr.Markdown(init_values[1])
+
+            gr.Markdown("### Caltech101")
+            top_image_3 = gr.Image(
+                value=init_tops[2], type="pil", label="Caltech101", show_label=False
+            )
+            act_value_3 = gr.Markdown(init_values[2])
+
+    image_display.select(
+        fn=update_radio_options,
+        inputs=[image_selector, model_selector],
+        outputs=[radio_choices],
+    )
+
+    model_selector.change(
+        fn=update_radio_options,
+        inputs=[image_selector, model_selector],
+        outputs=[radio_choices],
+    )
+
+    image_selector.select(
+        fn=update_radio_options,
+        inputs=[image_selector, model_selector],
+        outputs=[radio_choices],
+    )
+
+    radio_choices.change(
+        fn=update_all,
+        inputs=[image_selector, radio_choices, toggle_btn, model_selector],
+        outputs=[
+            seg_mask_display,
+            seg_mask_display_maple,
+            top_image_1,
+            top_image_2,
+            top_image_3,
+            act_value_1,
+            act_value_2,
+            act_value_3,
+            markdown_display,
+            markdown_display_2,
+        ],
+    )
+
+    toggle_btn.change(
+        fn=show_activation_heatmap_clip,
+        inputs=[image_selector, radio_choices, toggle_btn],
+        outputs=[
+            seg_mask_display,
+            top_image_1,
+            top_image_2,
+            top_image_3,
+            act_value_1,
+            act_value_2,
+            act_value_3,
+        ],
+    )
+
+    # Launch the app
+    # demo.queue()
+    # demo.launch()
+
+
+# if __name__ == "__main__":
+#     demo.queue()  # Enable queuing for better handling of concurrent users
+#     demo.launch(
+#         server_name="0.0.0.0",  # Allow external access
+#         server_port=7860,
+#         share=False,  # Set to True if you want to create a public URL
+#         show_error=True,
+#         # Optimize concurrency
+#         max_threads=8,  # Adjust based on your CPU cores
+#     )
+
+if __name__ == "__main__":
+    import psutil
+
+    # Get system memory info
+    mem = psutil.virtual_memory()
+    total_ram_gb = mem.total / (1024**3)
+
+    # Configure cache sizes based on available RAM
+    cache_size = int(total_ram_gb * 100)  # Rough estimate: 100 entries per GB
+
+    # Memory monitoring function
+    def monitor_memory_usage():
+        """Monitor and log memory usage"""
+        process = psutil.Process()
+        mem_info = process.memory_info()
+        print(f"""
+        Memory Usage:
+        - RSS: {mem_info.rss / (1024**2):.2f} MB
+        - VMS: {mem_info.vms / (1024**2):.2f} MB
+        - Cache Size: {len(_CACHE['model_data'])} entries
+        """)
+
+    # Start periodic monitoring
+    def start_memory_monitor():
+        threading.Timer(300.0, start_memory_monitor).start()  # Every 5 minutes
+        monitor_memory_usage()
+
+    # Start the monitoring
+    import threading
+    start_memory_monitor()
+
+    # Launch the app with memory-optimized settings
+    demo.queue(max_size=min(20, int(total_ram_gb)))  # Scale queue with RAM
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=False,
+        show_error=True,
+        max_threads=min(16, psutil.cpu_count()),  # Scale threads with CPU
+        websocket_ping_timeout=60,
+        preventive_refresh=True,
+        memory_limit_mb=int(total_ram_gb * 1024 * 0.8)  # Use up to 80% of RAM
+    )
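
One caution on the final launch call: websocket_ping_timeout, preventive_refresh, and memory_limit_mb are not documented parameters of Gradio's Blocks.launch(), so recent Gradio releases can be expected to reject them. A sketch of the same call restricted to documented arguments, keeping the RAM- and CPU-based scaling:

    demo.queue(max_size=min(20, int(total_ram_gb)))
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        max_threads=min(16, psutil.cpu_count()),
    )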