Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import random | |
| from functools import partial | |
| from pathlib import Path | |
| from typing import List | |
| import deepinv as dinv | |
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| from torchvision import transforms | |
| from evals import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDataset, Metric | |
# Device used for every dataset, physics, model and metric in the app. Prefer the GPU, but
# fall back to the CPU so the app still starts on hosts without CUDA (objects are built at
# import time below, so a hard-coded 'cuda' would crash there).
DEVICE_STR = 'cuda' if torch.cuda.is_available() else 'cpu'
### Gradio Utils
def _center_crop_slices(x, y):
    """Return (row_slice, col_slice) that center-crop the last two dims of *x* to those of *y*."""
    rows = slice((x.shape[-2] - y.shape[-2]) // 2, (x.shape[-2] + y.shape[-2]) // 2)
    cols = slice((x.shape[-1] - y.shape[-1]) // 2, (x.shape[-1] + y.shape[-1]) // 2)
    return rows, cols


def _format_metrics(metrics, pred, target):
    """Return one newline-terminated ``"<name> = <value:.4f>"`` line per metric in *metrics*."""
    return "".join(f"{metric.name} = {metric(pred, target).item():.4f}\n" for metric in metrics)


def _tensor_to_pil(img):
    """Clip a batched image tensor to [0, 1] and convert its first element to a PIL image."""
    clipped = dinv.utils.plotting.preprocess_img(img, rescale_mode="clip")
    return transforms.ToPILImage()(clipped[0].to('cpu'))


def generate_imgs(dataset: EvalDataset, idx: int,
                  model: EvalModel, baseline: BaselineModel,
                  physics: PhysicsWithGenerator, use_gen: bool,
                  metrics: List[Metric]):
    """Degrade one dataset sample, reconstruct it with two models and score both results.

    Args:
        dataset: indexable dataset; ``dataset[idx]`` is assumed to be a (C, H, W) tensor
            (e.g. (3, 256, 256) for natural images -- confirm for other datasets).
        idx: index of the sample to load.
        model: main reconstruction model, called as ``model(y=..., physics=...)``.
        baseline: baseline reconstruction model, called the same way.
        physics: wrapper called as ``physics(x, use_gen)`` to produce the measurement; its
            ``physics`` attribute is the operator handed to both models.
        use_gen: forwarded to ``physics``; the UI labels it "Generate valid physics parameters".
        metrics: metrics called as ``metric(pred, target)``, each returning a one-element tensor.

    Returns:
        ``(x, y, out, out_baseline, params_str, metrics_y, metrics_out, metrics_out_baseline)``:
        four PIL images followed by the physics parameter string and three metric strings
        (one ``"name = value"`` line per metric; ``metrics_y`` is empty when y and x differ in shape).
    """
    # Everything here is inference only: keep autograd off for the models, the metrics
    # (which may themselves be networks) and prox_l2, so no graph is ever recorded.
    with torch.no_grad():
        x = dataset[idx].unsqueeze(0)  # add a batch dimension: (C, H, W) -> (1, C, H, W)
        y = physics(x, use_gen)  # possible reduction in image size due to blurring
        out = model(y=y, physics=physics.physics)
        out_baseline = baseline(y=y, physics=physics.physics)

        if "Blur" in physics.name:
            # Score only the region covered by y: center-crop x and the reconstructions to
            # y's spatial size (dim -2 is height, dim -1 is width).
            rows, cols = _center_crop_slices(x, y)
            x = x[..., rows, cols]
            out = out[..., rows, cols]
            # The baseline may already return a y-sized image; crop it only if it does not match.
            if out_baseline.shape != out.shape:
                out_baseline = out_baseline[..., rows, cols]

        # y can only be scored against x when both have the same shape.
        metrics_y = _format_metrics(metrics, y, x) if y.shape == x.shape else ""
        metrics_out = _format_metrics(metrics, out, x)
        metrics_out_baseline = _format_metrics(metrics, out_baseline, x)

        # MRI / SR measurements do not have x's shape; map them back to image space for display.
        # NOTE(review): prox_l2 with gamma=1e4 presumably acts as a near least-squares
        # pseudo-inverse -- confirm against the deepinv physics in use.
        if physics.name == "MRI" or "SR" in physics.name:
            y_plot = physics.physics.prox_l2(physics.physics.A_adjoint(y), y, 1e4)
        else:
            y_plot = y.clone()

    return (_tensor_to_pil(x), _tensor_to_pil(y_plot),
            _tensor_to_pil(out), _tensor_to_pil(out_baseline),
            physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline)
def update_random_idx_and_generate_imgs(dataset: EvalDataset,
                                        model: EvalModel,
                                        baseline: BaselineModel,
                                        physics: PhysicsWithGenerator,
                                        use_gen: bool,
                                        metrics: List[Metric]):
    """Pick a random sample index and render it with ``generate_imgs``.

    The index is drawn uniformly from ``[0, len(dataset) - 1]``.

    Returns:
        The chosen index followed by the eight values returned by ``generate_imgs``
        (so the caller can update the index slider along with the outputs).
    """
    last_index = len(dataset) - 1
    sample_idx = random.randint(0, last_index)
    (clean_img, measured_img, recon_img, baseline_img,
     params_text, y_scores, recon_scores, baseline_scores) = generate_imgs(
        dataset, sample_idx, model, baseline, physics, use_gen, metrics)
    return (sample_idx, clean_img, measured_img, recon_img, baseline_img,
            params_text, y_scores, recon_scores, baseline_scores)
def _metrics_to_filename_part(metrics_str: str) -> str:
    """Turn a metric string such as ``"PSNR = 23.1234\\nSSIM = 0.8\\n"`` into ``"PSNR_23.1234-SSIM_0.8"``."""
    return metrics_str.replace(" = ", "_").replace("\n", "-").removesuffix("-")


def save_imgs(dataset: EvalDataset, idx: int, physics: PhysicsWithGenerator,
              model_a: EvalModel | BaselineModel, model_b: EvalModel | BaselineModel,
              x: Image.Image, y: Image.Image,
              out_a: Image.Image, out_b: Image.Image,
              y_metrics_str: str,
              out_a_metric_str: str, out_b_metric_str: str,
              *, save_dir: str | os.PathLike | None = None) -> None:
    """Save a 4-panel figure (x, y, out_a, out_b) whose file name encodes the whole setup.

    The PNG file name is ``dataset+idx+physics+params+y_metrics+model_a+a_metrics+model_b+b_metrics``,
    where params are ``name_value`` pairs and metrics are ``name_value`` pairs, each joined by ``-``.

    Args:
        dataset, idx: sample the images come from (used for the name and the first title).
        physics: physics used to produce ``y``; its ``saved_params["updatable_params"]`` go in the name.
        model_a, model_b: models that produced ``out_a`` / ``out_b`` (their ``name`` is used).
        x, y, out_a, out_b: PIL images to plot side by side.
        y_metrics_str, out_a_metric_str, out_b_metric_str: metric strings with one
            ``"name = value"`` line each, as returned by ``generate_imgs``.
        save_dir: output directory (keyword-only). Defaults to the module-level
            ``SAVE_IMG_DIR`` if one is defined, else ``./saved_imgs``. Created if missing.
    """
    if save_dir is None:
        # SAVE_IMG_DIR is not defined anywhere in this file; look it up at call time so a
        # definition elsewhere is still honoured, and otherwise fall back to a local folder.
        save_dir = globals().get("SAVE_IMG_DIR", "saved_imgs")
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    ### PROCESSES STR
    physics_params_str = "-".join(
        f"{param_name}_{param_value}"
        for param_name, param_value in physics.saved_params["updatable_params"].items()
    )
    y_metrics_str = _metrics_to_filename_part(y_metrics_str)
    out_a_metric_str = _metrics_to_filename_part(out_a_metric_str)
    out_b_metric_str = _metrics_to_filename_part(out_b_metric_str)

    save_path = save_dir / f"{dataset.name}+{idx}+{physics.name}+{physics_params_str}+{y_metrics_str}+{model_a.name}+{out_a_metric_str}+{model_b.name}+{out_b_metric_str}.png"
    titles = [f"{dataset.name}[{idx}]",
              f"y = {physics.name}(x)",
              f"{model_a.name}",
              f"{model_b.name}"]

    # PIL object -> torch.Tensor, as expected by dinv.utils.plot
    to_tensor = transforms.ToTensor()
    images = [to_tensor(img) for img in (x, y, out_a, out_b)]
    dinv.utils.plot(images, titles=titles, show=False, save_fn=save_path)
def _bind_device(factory):
    """Return *factory* with its ``device_str`` keyword argument fixed to DEVICE_STR."""
    return partial(factory, device_str=DEVICE_STR)


# Constructors pre-bound to the app-wide device, so the Gradio callbacks never pass
# device_str themselves.
get_list_metrics_on_DEVICE_STR = _bind_device(Metric.get_list_metrics)
get_eval_model_on_DEVICE_STR = _bind_device(EvalModel)
get_baseline_model_on_DEVICE_STR = _bind_device(BaselineModel)
get_dataset_on_DEVICE_STR = _bind_device(EvalDataset)
get_physics_on_DEVICE_STR = _bind_device(PhysicsWithGenerator)
| def get_physics(physics_name): | |
| if physics_name == 'MRI': | |
| baseline = get_baseline_model_on_DEVICE_STR('DPIR_MRI') | |
| elif physics_name == 'CT': | |
| baseline = get_baseline_model_on_DEVICE_STR('DPIR_CT') | |
| else: | |
| baseline = get_baseline_model_on_DEVICE_STR('DPIR') | |
| return get_physics_on_DEVICE_STR(physics_name), baseline | |
def get_model(model_name, ckpt_pth):
    """Instantiate the model called *model_name* on DEVICE_STR.

    Names listed in ``BaselineModel.all_baselines`` are built as a BaselineModel and
    *ckpt_pth* is ignored; any other name is built as an EvalModel with *ckpt_pth* passed along.
    """
    is_baseline = model_name in BaselineModel.all_baselines
    if not is_baseline:
        return get_eval_model_on_DEVICE_STR(model_name, ckpt_pth)
    return get_baseline_model_on_DEVICE_STR(model_name)
# Physics choices compatible with the currently selected dataset. Starts as every known
# physics and is narrowed by get_dataset().
# NOTE(review): this is a module-level global, so it is shared by all Gradio sessions; the
# physics radio in the UI currently lists PhysicsWithGenerator.all_physics, not this value.
AVAILABLE_PHYSICS = PhysicsWithGenerator.all_physics


def get_dataset(dataset_name):
    """Load *dataset_name* on DEVICE_STR and restrict AVAILABLE_PHYSICS to compatible physics.

    'MRI' and 'CT' datasets only allow their own physics; any other dataset is assumed to be
    a natural-image dataset and allows the motion / Gaussian blur physics.

    Returns:
        The loaded EvalDataset.
    """
    global AVAILABLE_PHYSICS
    # Comparisons must use '==': the previous '=' here was a SyntaxError that prevented
    # the whole module from loading.
    if dataset_name == 'MRI':
        AVAILABLE_PHYSICS = ['MRI']
    elif dataset_name == 'CT':
        AVAILABLE_PHYSICS = ['CT']
    else:
        AVAILABLE_PHYSICS = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard',
                             'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
    return get_dataset_on_DEVICE_STR(dataset_name)
### Gradio Blocks interface
# Custom CSS applied to the textboxes tagged with elem_classes=["fixed-textbox"]
# (physics parameters and metric readouts): fixed height, scroll on overflow, no resizing.
custom_css = """
.fixed-textbox textarea {
    height: 90px !important; /* Fixed height, intended to fit about 4 lines */
    overflow: auto; /* Show a scroll bar only when the content overflows */
    resize: none; /* Prevent the user from resizing the textbox */
}
"""
title = "Inverse problem playground"  # displayed on the gradio tab and in the gradio page
with gr.Blocks(title=title, css=custom_css) as interface:
    gr.Markdown("## " + title)

    # Per-session state. When gr.State receives a callable, Gradio calls it on each page load
    # to build that session's initial value; the lambdas defer instantiating these objects
    # until then instead of building them once here.
    model_a_placeholder = gr.State(lambda: get_eval_model_on_DEVICE_STR("unext_emb_physics_config_C", ""))
    model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DPIR"))
    dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
    physics_placeholder = gr.State(lambda: get_physics_on_DEVICE_STR("MotionBlur_easy"))
    metrics_placeholder = gr.State(get_list_metrics_on_DEVICE_STR(["PSNR"]))

    # gr.render runs dynamic_layout on page load and re-runs it whenever one of these states
    # changes, so the labels, choices and slider range always match the current selection.
    # Without the decorator the function is never called and nothing below is displayed.
    @gr.render(inputs=[dataset_placeholder, physics_placeholder, metrics_placeholder])
    def dynamic_layout(dataset, physics, metrics):
        ### LAYOUT
        dataset_name = dataset.name
        physics_name = physics.name
        metric_names = [metric.name for metric in metrics]

        # Components: Inputs/Outputs + Load EvalDataset/PhysicsWithGenerator/EvalModel/BaselineModel
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        clean = gr.Image(label=f"{dataset_name} IMAGE", interactive=False)
                        physics_params = gr.Textbox(label="Physics parameters", elem_classes=["fixed-textbox"], value=physics.display_saved_params())
                    with gr.Column():
                        y_image = gr.Image(label=f"{physics_name} IMAGE", interactive=False)
                        y_metrics = gr.Textbox(label="Metrics(y, x)", elem_classes=["fixed-textbox"])
                choose_physics = gr.Radio(choices=PhysicsWithGenerator.all_physics,
                                          label="List of PhysicsWithGenerator",
                                          value=physics_name)
                with gr.Row():
                    key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
                                               label="Updatable Parameter Key",
                                               scale=2)
                    value_text = gr.Textbox(label="Update Value", scale=2)
                    update_button = gr.Button("Manually update parameter value", scale=1)
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        model_a_out = gr.Image(label="RAM OUTPUT", interactive=False)
                        out_a_metric = gr.Textbox(label="Metrics(RAM(y, physics), x)", elem_classes=["fixed-textbox"])
                    with gr.Column():
                        model_b_out = gr.Image(label="DPIR OUTPUT", interactive=False)
                        out_b_metric = gr.Textbox(label="Metrics(DPIR(y, physics), x)", elem_classes=["fixed-textbox"])
                with gr.Row():
                    choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
                                              label="List of EvalDataset",
                                              value=dataset_name,
                                              scale=2)
                    idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index", scale=1)

        # Components: Load Metric + Load image Buttons
        with gr.Row():
            with gr.Column(scale=2):
                choose_metrics = gr.CheckboxGroup(choices=Metric.all_metrics,
                                                  value=metric_names,
                                                  label="Choose metrics you are interested")
            use_generator_button = gr.Checkbox(label="Generate valid physics parameters", scale=1)
            with gr.Column(scale=1):
                load_button = gr.Button("Load images...")
                load_random_button = gr.Button("Load randomly...")

        ### Event listeners
        # Updating a state triggers a re-render of this layout (see the gr.render decorator).
        choose_dataset.change(fn=get_dataset_on_DEVICE_STR,
                              inputs=choose_dataset,
                              outputs=dataset_placeholder)
        # Changing the physics also swaps model B for the matching DPIR baseline.
        choose_physics.change(fn=get_physics,
                              inputs=choose_physics,
                              outputs=[physics_placeholder, model_b_placeholder])
        update_button.click(fn=physics.update_and_display_params, inputs=[key_selector, value_text], outputs=physics_params)
        choose_metrics.change(fn=get_list_metrics_on_DEVICE_STR,
                              inputs=choose_metrics,
                              outputs=metrics_placeholder)
        load_button.click(fn=generate_imgs,
                          inputs=[dataset_placeholder,
                                  idx_slider,
                                  model_a_placeholder,
                                  model_b_placeholder,
                                  physics_placeholder,
                                  use_generator_button,
                                  metrics_placeholder],
                          outputs=[clean, y_image, model_a_out, model_b_out, physics_params, y_metrics, out_a_metric, out_b_metric])
        load_random_button.click(fn=update_random_idx_and_generate_imgs,
                                 inputs=[dataset_placeholder,
                                         model_a_placeholder,
                                         model_b_placeholder,
                                         physics_placeholder,
                                         use_generator_button,
                                         metrics_placeholder],
                                 outputs=[idx_slider, clean, y_image, model_a_out, model_b_out, physics_params, y_metrics, out_a_metric, out_b_metric])
# Start the Gradio server only when this file is executed as a script, so importing the
# module (e.g. from tests or other tools) does not block on a running server.
if __name__ == "__main__":
    interface.launch()