hyesulim commited on
Commit
bdd4e6b
·
verified ·
1 Parent(s): c6fcf0b

test: add timeout handling

Browse files
Files changed (1) hide show
  1. app.py +141 -184
app.py CHANGED
@@ -609,191 +609,148 @@ def load_all_data(image_root, pkl_root):
609
  default_image_name = "christmas-imagenet"
610
 
611
 
612
- with gr.Blocks(
613
- theme=gr.themes.Citrus(),
614
- css="""
615
- .image-row .gr-image { margin: 0 !important; padding: 0 !important; }
616
- .image-row img { width: auto; height: 50px; } /* Set a uniform height for all images */
617
- """,
618
- ) as demo:
619
- with gr.Row():
620
- with gr.Column():
621
- # Left View: Image selection and click handling
622
- gr.Markdown("## Select input image and patch on the image")
623
- image_selector = gr.Dropdown(
624
- choices=list(data_dict.keys()),
625
- value=default_image_name,
626
- label="Select Image",
627
- )
628
- image_display = gr.Image(
629
- value=data_dict[default_image_name]["image"],
630
- type="pil",
631
- interactive=True,
632
- )
633
-
634
- # Update image display when a new image is selected
635
- image_selector.change(
636
- fn=lambda img_name: data_dict[img_name]["image"],
637
- inputs=image_selector,
638
- outputs=image_display,
639
- )
640
- image_display.select(
641
- fn=highlight_grid, inputs=[image_selector], outputs=[image_display]
642
- )
643
-
644
- with gr.Column():
645
- gr.Markdown("## SAE latent activations of CLIP and MaPLE")
646
- model_options = [f"MaPLE-{dataset_name}" for dataset_name in DATASET_LIST]
647
- model_selector = gr.Dropdown(
648
- choices=model_options,
649
- value=model_options[0],
650
- label="Select adapted model (MaPLe)",
651
- )
652
- init_plot = plot_activation_distribution(
653
- None, default_image_name, model_options[0]
654
- )
655
- neuron_plot = gr.Plot(
656
- label="Neuron Activation", value=init_plot, show_label=False
657
- )
658
-
659
- image_selector.change(
660
- fn=plot_activation_distribution,
661
- inputs=[image_selector, model_selector],
662
- outputs=neuron_plot,
663
- )
664
- image_display.select(
665
- fn=plot_activation_distribution,
666
- inputs=[image_selector, model_selector],
667
- outputs=neuron_plot,
668
- )
669
- model_selector.change(
670
- fn=load_image, inputs=[image_selector], outputs=image_display
671
- )
672
- model_selector.change(
673
- fn=plot_activation_distribution,
674
- inputs=[image_selector, model_selector],
675
- outputs=neuron_plot,
676
- )
677
-
678
- with gr.Row():
679
- with gr.Column():
680
- radio_names = get_init_radio_options(default_image_name, model_options[0])
681
-
682
- feautre_idx = radio_names[0].split("-")[-1]
683
- markdown_display = gr.Markdown(
684
- f"## Segmentation mask for the selected SAE latent - {feautre_idx}"
685
- )
686
- init_seg, init_tops, init_values = show_activation_heatmap(
687
- default_image_name, radio_names[0], "CLIP"
688
- )
689
-
690
- gr.Markdown("### Localize SAE latent activation using CLIP")
691
- seg_mask_display = gr.Image(value=init_seg, type="pil", show_label=False)
692
- init_seg_maple, _, _ = show_activation_heatmap(
693
- default_image_name, radio_names[0], model_options[0]
694
- )
695
- gr.Markdown("### Localize SAE latent activation using MaPLE")
696
- seg_mask_display_maple = gr.Image(
697
- value=init_seg_maple, type="pil", show_label=False
698
- )
699
-
700
- with gr.Column():
701
- gr.Markdown("## Top activating SAE latent index")
702
-
703
- radio_choices = gr.Radio(
704
- choices=radio_names,
705
- label="Top activating SAE latent",
706
- interactive=True,
707
- value=radio_names[0],
708
- )
709
- toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)
710
-
711
- markdown_display_2 = gr.Markdown(
712
- f"## Top reference images for the selected SAE latent - {feautre_idx}"
713
- )
714
-
715
- gr.Markdown("### ImageNet")
716
- top_image_1 = gr.Image(
717
- value=init_tops[0], type="pil", label="ImageNet", show_label=False
718
- )
719
- act_value_1 = gr.Markdown(init_values[0])
720
-
721
- gr.Markdown("### ImageNet-Sketch")
722
- top_image_2 = gr.Image(
723
- value=init_tops[1],
724
- type="pil",
725
- label="ImageNet-Sketch",
726
- show_label=False,
727
- )
728
- act_value_2 = gr.Markdown(init_values[1])
729
-
730
- gr.Markdown("### Caltech101")
731
- top_image_3 = gr.Image(
732
- value=init_tops[2], type="pil", label="Caltech101", show_label=False
733
- )
734
- act_value_3 = gr.Markdown(init_values[2])
735
-
736
- image_display.select(
737
- fn=update_radio_options,
738
- inputs=[image_selector, model_selector],
739
- outputs=[radio_choices],
740
- )
741
-
742
- model_selector.change(
743
- fn=update_radio_options,
744
- inputs=[image_selector, model_selector],
745
- outputs=[radio_choices],
746
- )
747
-
748
- image_selector.select(
749
- fn=update_radio_options,
750
- inputs=[image_selector, model_selector],
751
- outputs=[radio_choices],
752
- )
753
-
754
- radio_choices.change(
755
- fn=update_all,
756
- inputs=[image_selector, radio_choices, toggle_btn, model_selector],
757
- outputs=[
758
- seg_mask_display,
759
- seg_mask_display_maple,
760
- top_image_1,
761
- top_image_2,
762
- top_image_3,
763
- act_value_1,
764
- act_value_2,
765
- act_value_3,
766
- markdown_display,
767
- markdown_display_2,
768
- ],
769
- )
770
-
771
- toggle_btn.change(
772
- fn=show_activation_heatmap_clip,
773
- inputs=[image_selector, radio_choices, toggle_btn],
774
- outputs=[
775
- seg_mask_display,
776
- top_image_1,
777
- top_image_2,
778
- top_image_3,
779
- act_value_1,
780
- act_value_2,
781
- act_value_3,
782
- ],
783
- )
784
 
785
- # Launch the app
786
- # demo.queue()
787
- # demo.launch()
788
 
789
-
790
  if __name__ == "__main__":
791
- demo.queue() # Enable queuing for better handling of concurrent users
792
- demo.launch(
793
- server_name="0.0.0.0", # Allow external access
794
- server_port=7860,
795
- share=False, # Set to True if you want to create a public URL
796
- show_error=True,
797
- # Optimize concurrency
798
- max_threads=8, # Adjust based on your CPU cores
799
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
609
  default_image_name = "christmas-imagenet"
610
 
611
 
612
+ def safe_operation(func, *args, timeout=10, default=None):
613
+ """Execute function with timeout and return default value on failure."""
614
+ try:
615
+ with concurrent.futures.ThreadPoolExecutor() as executor:
616
+ future = executor.submit(func, *args)
617
+ return future.result(timeout=timeout)
618
+ except concurrent.futures.TimeoutError:
619
+ print(f"Operation timed out: {func.__name__}")
620
+ return default
621
+ except Exception as e:
622
+ print(f"Operation failed: {func.__name__}, Error: {e}")
623
+ return default
624
+
625
+ def create_interface():
626
+ with gr.Blocks(
627
+ theme=gr.themes.Citrus(),
628
+ css="""
629
+ .image-row .gr-image { margin: 0 !important; padding: 0 !important; }
630
+ .image-row img { width: auto; height: 50px; }
631
+ """,
632
+ ) as demo:
633
+ # State management for preventing duplicate operations
634
+ state = gr.State({
635
+ 'last_update': 0,
636
+ 'processing': False
637
+ })
638
+
639
+ with gr.Row():
640
+ with gr.Column():
641
+ gr.Markdown("## Select input image and patch on the image")
642
+ image_selector = gr.Dropdown(
643
+ choices=list(_CACHE['data_dict'].keys()),
644
+ value=default_image_name,
645
+ label="Select Image",
646
+ )
647
+ image_display = gr.Image(
648
+ value=_CACHE['data_dict'][default_image_name]["image"],
649
+ type="pil",
650
+ interactive=True,
651
+ )
652
+
653
+ def safe_update_image(img_name):
654
+ return safe_operation(
655
+ lambda: _CACHE['data_dict'][img_name]["image"],
656
+ timeout=5,
657
+ default=Image.new('RGB', (IMAGE_SIZE, IMAGE_SIZE), 'gray')
658
+ )
659
+
660
+ image_selector.change(
661
+ fn=safe_update_image,
662
+ inputs=image_selector,
663
+ outputs=image_display,
664
+ )
665
+
666
+ with gr.Column():
667
+ gr.Markdown("## SAE latent activations of CLIP and MaPLE")
668
+ model_options = [f"MaPLE-{dataset_name}" for dataset_name in DATASET_LIST]
669
+ model_selector = gr.Dropdown(
670
+ choices=model_options,
671
+ value=model_options[0],
672
+ label="Select adapted model (MaPLe)",
673
+ )
674
+
675
+ def safe_plot_activation(evt, selected_image, model_name):
676
+ try:
677
+ return safe_operation(
678
+ plot_activation_distribution,
679
+ evt, selected_image, model_name,
680
+ timeout=10,
681
+ default=go.Figure() # Return empty figure on timeout
682
+ )
683
+ except Exception as e:
684
+ print(f"Error in plot_activation: {e}")
685
+ return go.Figure()
686
+
687
+ neuron_plot = gr.Plot(
688
+ label="Neuron Activation",
689
+ value=safe_plot_activation(None, default_image_name, model_options[0]),
690
+ show_label=False,
691
+ )
692
+
693
+ def debounced_update(*args, minimum_interval=1.0):
694
+ """Prevent updates that are too close together"""
695
+ current_time = time.time()
696
+ if current_time - state.value['last_update'] < minimum_interval:
697
+ raise gr.Error("Please wait a moment before updating again")
698
+ state.value['last_update'] = current_time
699
+ return safe_plot_activation(*args)
700
+
701
+ # Add error handling and debouncing to all event handlers
702
+ image_selector.change(
703
+ fn=debounced_update,
704
+ inputs=[image_selector, model_selector],
705
+ outputs=neuron_plot,
706
+ )
707
+ image_display.select(
708
+ fn=debounced_update,
709
+ inputs=[image_selector, model_selector],
710
+ outputs=neuron_plot,
711
+ )
712
+ model_selector.change(
713
+ fn=debounced_update,
714
+ inputs=[image_selector, model_selector],
715
+ outputs=neuron_plot,
716
+ )
717
+
718
+ # Rest of your interface code with similar error handling...
719
+
720
+ return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
721
 
722
+ demo = create_interface()
 
 
723
 
 
724
  if __name__ == "__main__":
725
+ # Configure logging
726
+ import logging
727
+ logging.basicConfig(
728
+ level=logging.INFO,
729
+ format='%(asctime)s - %(levelname)s - %(message)s'
 
 
 
730
  )
731
+
732
+ # Initialize the interface with error handling
733
+ try:
734
+ demo = create_interface()
735
+ demo.queue(concurrency_count=2) # Limit concurrent operations
736
+ demo.launch(
737
+ server_name="0.0.0.0",
738
+ server_port=7860,
739
+ share=False,
740
+ show_error=True,
741
+ max_threads=4, # Reduced for better stability
742
+ prevent_thread_lock=True,
743
+ # Add health check endpoint
744
+ root_path="/healthz",
745
+ # Add automatic cleanup
746
+ _teardown=lambda: (
747
+ clear_caches(),
748
+ logging.info("Caches cleared during teardown")
749
+ ),
750
+ # Configure timeouts
751
+ api_open_timeout=60,
752
+ api_call_timeout=300,
753
+ )
754
+ except Exception as e:
755
+ logging.error(f"Failed to launch interface: {e}")
756
+ raise