"""Gradio demo for cine cardiac MRI segmentation with CineMA.

The app downloads one ACDC short-axis (SAX) cine image and a fine-tuned
``ConvUNetR`` checkpoint from the Hugging Face Hub, segments every selected
time frame, and shows the segmentation together with the per-frame volumes and
the ejection fractions derived from them.
"""

from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk  # noqa: N813
import spaces
import torch
from cinema import ConvUNetR
from huggingface_hub import hf_hub_download
from monai.transforms import Compose, ScaleIntensityd, SpatialPadd

# cache directories
cache_dir = Path("/tmp/.cinema")
cache_dir.mkdir(parents=True, exist_ok=True)

# Maps the dropdown label to the dataset key used in the checkpoint paths.
DATASET_KEYS = {
    "ACDC": "acdc",
    "M&MS": "mnms",
    "M&MS2": "mnms2",
}

# RGBA overlay colour per predicted class index. The RGB parts match the hex
# line colours of the volume plot (RV #6C8EBF, MYO #D6B656, LV #82B366).
LABEL_COLOURS = {
    1: (108 / 255, 142 / 255, 191 / 255, 0.6),
    2: (214 / 255, 182 / 255, 86 / 255, 0.6),
    3: (130 / 255, 179 / 255, 102 / 255, 0.6),
}

# Volume of one voxel in mm^3; the UI text states the data is resampled to
# 1 mm x 1 mm x 10 mm.
VOXEL_VOLUME_MM3 = 10
MM3_PER_ML = 1000


@spaces.GPU
def inferece(
    images: np.ndarray,
    view: str,
    transform: Compose,
    model: ConvUNetR,
    progress=gr.Progress(),  # Gradio injects the tracker through this default
) -> np.ndarray:
    """Segment every time frame of a cine image.

    Args:
        images: NumPy array whose last two axes are read as (slices, frames).
            The caller passes an (x, y, slices, frames) array.
        view: Key under which the image is passed to ``transform`` and
            ``model``, and under which the logits are read back.
        transform: Preprocessing applied to each frame before the model.
        model: Segmentation model; it is moved to the selected device.
        progress: Gradio progress tracker, updated once per frame.

    Returns:
        Array of per-voxel class indices, one volume per frame stacked along
        the last axis, with the slice axis cropped back to the input's number
        of slices.
    """
    # set device and dtype
    dtype, device = torch.float32, torch.device("cpu")
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.empty_cache()
        device = torch.device("cuda")
        if torch.cuda.is_bf16_supported():
            dtype = torch.bfloat16

    # inference
    model.to(device)
    # Inference only: make sure layers such as dropout run in evaluation mode.
    model.eval()

    n_slices, n_frames = images.shape[-2:]
    labels_list = []
    for t in range(n_frames):
        progress((t + 1) / n_frames, desc=f"Processing frame {t + 1} / {n_frames}...")
        # Add a leading channel axis, preprocess, then add a batch axis.
        batch = transform({view: torch.from_numpy(images[None, ..., t])})
        batch = {
            k: v[None, ...].to(device=device, dtype=torch.float32)
            for k, v in batch.items()
        }
        with (
            torch.no_grad(),
            torch.autocast("cuda", dtype=dtype, enabled=use_cuda),
        ):
            logits = model(batch)[view]
        # Drop the slices that were only added by padding.
        labels_list.append(torch.argmax(logits, dim=1)[0, ..., :n_slices])
    return torch.stack(labels_list, dim=-1).detach().cpu().numpy()


def _label_overlay(label_slice: np.ndarray) -> np.ndarray:
    """Return an RGBA image colouring each known class; other pixels stay transparent."""
    overlay = np.zeros((*label_slice.shape, 4))
    for class_id, rgba in LABEL_COLOURS.items():
        overlay[label_slice == class_id] = rgba
    return overlay


def _ejection_fraction(volumes: np.ndarray) -> float:
    """Return (max - min) / max * 100, or NaN when the maximum volume is zero."""
    peak = volumes.max()
    if peak == 0:
        # Nothing was segmented for this structure; avoid a 0 / 0 warning.
        return float("nan")
    return (peak - volumes.min()) / peak * 100


def _plot_segmentation(
    images: np.ndarray,
    labels: np.ndarray,
    t_step: int,
    title: str,
) -> plt.Figure:
    """Draw a frames-by-slices grid of images with the label overlay on top."""
    n_slices, n_frames = labels.shape[-2:]
    # squeeze=False keeps ``axs`` two-dimensional even for a single row/column.
    fig, axs = plt.subplots(
        n_frames,
        n_slices,
        figsize=(n_slices, n_frames),
        dpi=300,
        squeeze=False,
    )
    for t in range(n_frames):
        for z in range(n_slices):
            ax = axs[t, z]
            ax.imshow(images[..., z, t], cmap="gray")
            ax.imshow(_label_overlay(labels[..., z, t]))
            ax.set_xticks([])
            ax.set_yticks([])
            if z == 0:
                ax.set_ylabel(f"t = {t * t_step}")
    fig.suptitle(title)
    axs[0, n_slices // 2].set_title("SAX Slices")
    fig.tight_layout()
    fig.subplots_adjust(wspace=0, hspace=0)
    # Unregister from pyplot so figures do not pile up across requests; the
    # Figure object itself stays usable for rendering.
    plt.close(fig)
    return fig


def _plot_volumes(labels: np.ndarray, t_step: int) -> plt.Figure:
    """Plot per-frame volumes of the three classes, with LVEF/RVEF in the title."""
    n_frames = labels.shape[-1]
    xs = np.arange(n_frames) * t_step
    rv_volumes = np.sum(labels == 1, axis=(0, 1, 2)) * VOXEL_VOLUME_MM3 / MM3_PER_ML
    myo_volumes = np.sum(labels == 2, axis=(0, 1, 2)) * VOXEL_VOLUME_MM3 / MM3_PER_ML
    lv_volumes = np.sum(labels == 3, axis=(0, 1, 2)) * VOXEL_VOLUME_MM3 / MM3_PER_ML
    lvef = _ejection_fraction(lv_volumes)
    rvef = _ejection_fraction(rv_volumes)

    fig, ax = plt.subplots(figsize=(4, 4), dpi=120)
    ax.plot(xs, rv_volumes, color="#6C8EBF", label="RV")
    ax.plot(xs, myo_volumes, color="#D6B656", label="MYO")
    ax.plot(xs, lv_volumes, color="#82B366", label="LV")
    ax.set_xlabel("Frame")
    ax.set_ylabel("Volume (ml)")
    ax.set_title(f"LVEF = {lvef:.2f}%, RVEF = {rvef:.2f}%")
    ax.legend(loc="lower right")
    fig.tight_layout()
    plt.close(fig)
    return fig


def run_inference(trained_dataset, seed, image_id, t_step, progress=gr.Progress()):
    """Run segmentation on one ACDC image and build both result figures.

    Args:
        trained_dataset: Dropdown label, one of the keys of ``DATASET_KEYS``.
        seed: Fine-tuning seed used to pick the checkpoint file.
        image_id: ACDC patient number; IDs up to 100 are read from the
            ``train`` split, larger ones from ``test``.
        t_step: Only every ``t_step``-th time frame is processed.
        progress: Gradio progress tracker.

    Returns:
        Tuple ``(segmentation_figure, volume_figure)`` of Matplotlib figures.

    Raises:
        KeyError: If ``trained_dataset`` is not a known dropdown label.
    """
    # Fixed parameters
    view = "sax"
    # Slider values can arrive as floats; they are used in ``:03d`` formatting,
    # file names and as a slice step, all of which need integers.
    seed, image_id, t_step = int(seed), int(image_id), int(t_step)
    split = "train" if image_id <= 100 else "test"
    trained_dataset = DATASET_KEYS[str(trained_dataset)]

    # Download and load model
    progress(0, desc="Downloading model and data...")
    image_path = hf_hub_download(
        repo_id="mathpluscode/ACDC",
        repo_type="dataset",
        filename=f"{split}/patient{image_id:03d}/patient{image_id:03d}_sax_t.nii.gz",
        cache_dir=cache_dir,
    )
    model = ConvUNetR.from_finetuned(
        repo_id="mathpluscode/CineMA",
        model_filename=f"finetuned/segmentation/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
        config_filename=f"finetuned/segmentation/{trained_dataset}_{view}/config.yaml",
        cache_dir=cache_dir,
    )

    # Load and process data
    transform = Compose(
        [
            ScaleIntensityd(keys=view),
            SpatialPadd(
                keys=view,
                spatial_size=(192, 192, 16),
                method="end",
                lazy=True,
                allow_missing_keys=True,
            ),
        ]
    )
    # SimpleITK returns the array with reversed axis order; transpose it back.
    images = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(image_path)))
    images = images[..., ::t_step]
    labels = inferece(images, view, transform, model, progress)

    progress(1, desc="Plotting results...")
    fig1 = _plot_segmentation(
        images, labels, t_step, f"Subject {image_id} in {split} split"
    )
    fig2 = _plot_volumes(labels, t_step)
    return fig1, fig2


# Create the Gradio interface
theme = gr.themes.Ocean(
    primary_hue="red",
    secondary_hue="purple",
)
with gr.Blocks(
    theme=theme, title="CineMA: A Foundation Model for Cine Cardiac MRI"
) as demo:
    gr.Markdown(
        """
        # CineMA: A Foundation Model for Cine Cardiac MRI 🎥🫀

        Below is an example of ejection fraction prediction inference.
        For more examples, check out our [GitHub](https://github.com/mathpluscode/CineMA).
        """
    )

    with gr.Row():
        # Integer scales keep the 4 : 3 : 3 width ratio of the three columns.
        with gr.Column(scale=4):
            gr.Markdown("## Description")
            gr.Markdown("""
            Please adjust the settings on the right panels and click the button to run the inference.

            ### Data

            The available data is from ACDC.
            All images have been resampled to 1 mm × 1 mm × 10 mm and centre-cropped to 192 mm × 192 mm for each SAX slice.
            Image 1 - 100 are from the training set, and image 101 - 150 are from the test set.

            ### Model

            The available models are finetuned on different datasets
            ([ACDC](https://www.creatis.insa-lyon.fr/Challenge/acdc/),
            [M&Ms](https://www.ub.edu/mnms/), and
            [M&Ms2](https://www.ub.edu/mnms-2/)).
            For each dataset, there are 3 models finetuned on different seeds: 0, 1, 2.
            The default model is the one finetuned on ACDC dataset with seed 0.

            ### Visualization

            The left panel shows the segmentation of ventricles and myocardium every n time steps across all SAX slices.
            The right panel plots the ventricle and myocardium volumes across all inference time frames.
            """)
        with gr.Column(scale=3):
            gr.Markdown("## Data Settings")
            image_id = gr.Slider(
                minimum=1,
                maximum=150,
                step=1,
                label="Choose an ACDC image, ID is between 1 and 150",
                value=1,
            )
            t_step = gr.Slider(
                minimum=1,
                maximum=10,
                step=1,
                label="Choose the gap between time frames",
                value=2,
            )
        with gr.Column(scale=3):
            gr.Markdown("## Model Setting")
            trained_dataset = gr.Dropdown(
                choices=["ACDC", "M&MS", "M&MS2"],
                label="Choose which dataset the segmentation model was finetuned on",
                value="ACDC",
            )
            seed = gr.Slider(
                minimum=0,
                maximum=2,
                step=1,
                label="Choose which seed the finetuning used",
                value=0,
            )
            run_button = gr.Button("Run segmentation inference", variant="primary")

    with gr.Row():
        segmentation_plot = gr.Plot(label="Ventricle and Myocardium Segmentation")
        volume_plot = gr.Plot(label="Ejection Fraction Prediction")

    run_button.click(
        fn=run_inference,
        inputs=[trained_dataset, seed, image_id, t_step],
        outputs=[segmentation_plot, volume_plot],
    )

demo.launch()