import gradio as gr from .compute import run_model from .logger import setup_logger, read_logs from .utils import load_ct_to_numpy # setup logging LOGGER = setup_logger() class WebUI: def __init__( self, model_name: str = None, cwd: str = "/home/user/app/", share: int = 1, ): # global states self.fixed_images = [] self.moving_images = [] self.pred_images = [] # @TODO: This should be dynamically set based on chosen volume size self.nb_slider_items = 128 self.model_name = model_name self.cwd = cwd self.share = share self.class_names = { "Brain": "B", "Liver": "L" } self.class_name = "Brain" self.fixed_image_path = None self.moving_image_path = None self.fixed_seg_path = None self.moving_seg_path = None # define widgets not to be rendered immediantly, but later on self.slider = gr.Slider( 1, self.nb_slider_items, value=1, step=1, label="Which 2D slice to show", ) self.run_btn = gr.Button("Run analysis", show_progress="full", elem_id="button").style( full_width=False, size="lg" ) def set_class_name(self, value): LOGGER.info(f"Changed task to: {value}") self.class_name = value def upload_file(self, files): return [f.name for f in files] def update_fixed(self, cfile): self.fixed_image_path = cfile.name return self.fixed_image_path def update_moving(self, cfile): self.moving_image_path = cfile.name return self.moving_image_path def update_fixed_seg(self, cfile): self.fixed_seg_path = cfile.name return self.fixed_seg_path def update_moving_seg(self, cfile): self.moving_seg_path = cfile.name return self.moving_seg_path def process(self): if (self.fixed_image_path is None) or (self.moving_image_path is None): raise ValueError("Please, select both a fixed and moving image before running inference.") output_path = self.cwd run_model(self.fixed_image_path, self.moving_image_path, self.fixed_seg_path, self.moving_seg_path, output_path, self.class_names[self.class_name]) # reset - to avoid using these segmentations again for new images self.fixed_seg_path = None self.moving_seg_path = None self.fixed_images = load_ct_to_numpy(self.fixed_image_path) self.moving_images = load_ct_to_numpy(self.moving_image_path) self.pred_images = load_ct_to_numpy(output_path + "pred_image.nii.gz") return self.pred_images[0] def get_fixed_image(self, k): k = int(k) - 1 out = [gr.Image.update(visible=False)] * self.nb_slider_items out[k] = gr.Image.update( self.fixed_images[k], visible=True, ) return out def get_moving_image(self, k): k = int(k) - 1 out = [gr.Image.update(visible=False)] * self.nb_slider_items out[k] = gr.Image.update( self.moving_images[k], visible=True, ) return out def get_pred_image(self, k): k = int(k) - 1 out = [gr.Image.update(visible=False)] * self.nb_slider_items out[k] = gr.Image.update( self.pred_images[k], visible=True, ) return out def run(self): css = """ #model-2d { height: 512px; margin: auto; } #upload { height: 120px; } #button { height: 120px; } #dropdown { height: 120px; } """ with gr.Blocks(css=css) as demo: with gr.Row(): with gr.Column(): file_fixed = gr.File(file_count="single", elem_id="upload", label="Select Fixed Image", show_label=True) file_fixed.upload(self.update_fixed, file_fixed, file_fixed) file_moving = gr.File(file_count="single", elem_id="upload", label="Select Moving Image", show_label=True) file_moving.upload(self.update_moving, file_moving, file_moving) #with gr.Group(): with gr.Column(): file_fixed_seg = gr.File(file_count="single", elem_id="upload", label="Select Fixed Seg Image", show_label=True) file_fixed_seg.upload(self.update_fixed_seg, file_fixed_seg, file_fixed_seg) file_moving_seg = gr.File(file_count="single", elem_id="upload", label="Select Moving Seg Image", show_label=True) file_moving_seg.upload(self.update_moving_seg, file_moving_seg, file_moving_seg) with gr.Column(): model_selector = gr.Dropdown( list(self.class_names.keys()), label="Task", info="Which task to perform image-to-registration on", multiselect=False, size="sm", default="Brain", elem_id="dropdown", ) model_selector.input( fn=lambda x: self.set_class_name(x), inputs=model_selector, outputs=None, ) self.run_btn.render() logs = gr.Textbox(label="Logs", info="Verbose from inference will be displayed below.", max_lines=8, autoscroll=True) demo.load(read_logs, None, logs, every=1) with gr.Row(): with gr.Box(): with gr.Column(): with gr.Row(): fixed_images = [] for i in range(self.nb_slider_items): visibility = True if i == 1 else False t = gr.Image( visible=visibility, elem_id="model-2d", label="fixed image", show_label=True, ).style( height=512, width=512, ) fixed_images.append(t) moving_images = [] for i in range(self.nb_slider_items): visibility = True if i == 1 else False t = gr.Image( visible=visibility, elem_id="model-2d", label="moving image", show_label=True, ).style( height=512, width=512, ) moving_images.append(t) pred_images = [] for i in range(self.nb_slider_items): if i == 0: first_pred_component = t visibility = True if i == 1 else False t = gr.Image( visible=visibility, elem_id="model-2d", label="predicted fixed image", show_label=True, ).style( height=512, width=512, ) pred_images.append(t) self.run_btn.click( fn=self.process, inputs=None, outputs=first_pred_component, ) self.slider.input( self.get_fixed_image, self.slider, fixed_images ) self.slider.input( self.get_moving_image, self.slider, moving_images ) self.slider.input( self.get_pred_image, self.slider, pred_images ) self.slider.render() # sharing app publicly -> share=True: # https://gradio.app/sharing-your-app/ # inference times > 60 seconds -> need queue(): # https://github.com/tloen/alpaca-lora/issues/60#issuecomment-1510006062 demo.queue().launch( server_name="0.0.0.0", server_port=7860, share=self.share )