Spaces:
Running
on
Zero
Running
on
Zero
Commit
Β·
65411d8
1
Parent(s):
f4635f5
Add segmentation demo
Browse files- .gitignore +4 -0
- .pre-commit-config.yaml +49 -0
- README.md +10 -5
- app.py +225 -0
- requirements.txt +23 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.idea/
|
2 |
+
.DS_Store
|
3 |
+
node_modules/
|
4 |
+
src/cache/
|
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
default_language_version:
|
2 |
+
python: python3
|
3 |
+
repos:
|
4 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
5 |
+
rev: v5.0.0
|
6 |
+
hooks:
|
7 |
+
- id: check-added-large-files
|
8 |
+
args: ["--maxkb=30000"]
|
9 |
+
- id: check-ast
|
10 |
+
- id: check-byte-order-marker
|
11 |
+
- id: check-builtin-literals
|
12 |
+
- id: check-case-conflict
|
13 |
+
- id: check-merge-conflict
|
14 |
+
- id: check-symlinks
|
15 |
+
- id: check-toml
|
16 |
+
- id: check-yaml
|
17 |
+
- id: debug-statements
|
18 |
+
- id: destroyed-symlinks
|
19 |
+
- id: end-of-file-fixer
|
20 |
+
- id: fix-byte-order-marker
|
21 |
+
- id: mixed-line-ending
|
22 |
+
- id: file-contents-sorter
|
23 |
+
files: "envs/requirements*.txt"
|
24 |
+
- id: trailing-whitespace
|
25 |
+
- repo: https://github.com/astral-sh/ruff-pre-commit
|
26 |
+
rev: v0.11.8
|
27 |
+
hooks:
|
28 |
+
# run the linter
|
29 |
+
- id: ruff
|
30 |
+
args: [--fix]
|
31 |
+
# run the formatter
|
32 |
+
- id: ruff-format
|
33 |
+
- repo: https://github.com/pre-commit/mirrors-prettier
|
34 |
+
rev: v4.0.0-alpha.8
|
35 |
+
hooks:
|
36 |
+
- id: prettier
|
37 |
+
args:
|
38 |
+
- --print-width=120
|
39 |
+
- --prose-wrap=always
|
40 |
+
- --tab-width=2
|
41 |
+
- repo: https://github.com/codespell-project/codespell
|
42 |
+
rev: v2.4.1
|
43 |
+
hooks:
|
44 |
+
- id: codespell
|
45 |
+
name: codespell
|
46 |
+
description: Checks for common misspellings in text files.
|
47 |
+
entry: codespell
|
48 |
+
language: python
|
49 |
+
types: [text]
|
README.md
CHANGED
@@ -1,13 +1,18 @@
|
|
1 |
---
|
2 |
title: CineMA
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 5.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
10 |
license: mit
|
|
|
11 |
---
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
1 |
---
|
2 |
title: CineMA
|
3 |
+
emoji: π«
|
4 |
+
colorFrom: red
|
5 |
+
colorTo: purple
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 5.30.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
+
python_version: "3.10.13"
|
11 |
license: mit
|
12 |
+
short_description: A Foundation Model for Cine CMR Images π₯π«
|
13 |
---
|
14 |
|
15 |
+
# CineMA: A Foundation Model for Cine Cardiac MRI
|
16 |
+
|
17 |
+
This is a demo of CineMA, a foundation model for cine cardiac MRI. For more details, checkout our
|
18 |
+
[GitHub](https://github.com/mathpluscode/CineMA).
|
app.py
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import gradio as gr
|
3 |
+
from huggingface_hub import hf_hub_download
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import SimpleITK as sitk # noqa: N813
|
6 |
+
import torch
|
7 |
+
from monai.transforms import Compose, ScaleIntensityd, SpatialPadd
|
8 |
+
from cinema import ConvUNetR
|
9 |
+
from pathlib import Path
|
10 |
+
import spaces
|
11 |
+
|
12 |
+
# cache directories
|
13 |
+
cache_dir = Path("/tmp/.cinema")
|
14 |
+
cache_dir.mkdir(parents=True, exist_ok=True)
|
15 |
+
|
16 |
+
|
17 |
+
@spaces.GPU
|
18 |
+
def inferece(
|
19 |
+
images: torch.Tensor,
|
20 |
+
view: str,
|
21 |
+
transform: Compose,
|
22 |
+
model: ConvUNetR,
|
23 |
+
progress=gr.Progress(),
|
24 |
+
) -> np.ndarray:
|
25 |
+
# set device and dtype
|
26 |
+
dtype, device = torch.float32, torch.device("cpu")
|
27 |
+
if torch.cuda.is_available():
|
28 |
+
torch.cuda.empty_cache()
|
29 |
+
device = torch.device("cuda")
|
30 |
+
if torch.cuda.is_bf16_supported():
|
31 |
+
dtype = torch.bfloat16
|
32 |
+
|
33 |
+
# inference
|
34 |
+
model.to(device)
|
35 |
+
n_slices, n_frames = images.shape[-2:]
|
36 |
+
labels_list = []
|
37 |
+
for t in range(0, n_frames):
|
38 |
+
progress((t + 1) / n_frames, desc=f"Processing frame {t + 1} / {n_frames}...")
|
39 |
+
batch = transform({view: torch.from_numpy(images[None, ..., t])})
|
40 |
+
batch = {
|
41 |
+
k: v[None, ...].to(device=device, dtype=torch.float32)
|
42 |
+
for k, v in batch.items()
|
43 |
+
}
|
44 |
+
with (
|
45 |
+
torch.no_grad(),
|
46 |
+
torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()),
|
47 |
+
):
|
48 |
+
logits = model(batch)[view]
|
49 |
+
labels_list.append(torch.argmax(logits, dim=1)[0, ..., :n_slices])
|
50 |
+
labels = torch.stack(labels_list, dim=-1).detach().cpu().numpy()
|
51 |
+
return labels
|
52 |
+
|
53 |
+
|
54 |
+
def run_inference(trained_dataset, seed, image_id, t_step, progress=gr.Progress()):
|
55 |
+
# Fixed parameters
|
56 |
+
view = "sax"
|
57 |
+
split = "train" if image_id <= 100 else "test"
|
58 |
+
trained_dataset = {
|
59 |
+
"ACDC": "acdc",
|
60 |
+
"M&MS": "mnms",
|
61 |
+
"M&MS2": "mnms2",
|
62 |
+
}[str(trained_dataset)]
|
63 |
+
|
64 |
+
# Download and load model
|
65 |
+
progress(0, desc="Downloading model and data...")
|
66 |
+
image_path = hf_hub_download(
|
67 |
+
repo_id="mathpluscode/ACDC",
|
68 |
+
repo_type="dataset",
|
69 |
+
filename=f"{split}/patient{image_id:03d}/patient{image_id:03d}_sax_t.nii.gz",
|
70 |
+
cache_dir=cache_dir,
|
71 |
+
)
|
72 |
+
|
73 |
+
model = ConvUNetR.from_finetuned(
|
74 |
+
repo_id="mathpluscode/CineMA",
|
75 |
+
model_filename=f"finetuned/segmentation/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
|
76 |
+
config_filename=f"finetuned/segmentation/{trained_dataset}_{view}/config.yaml",
|
77 |
+
cache_dir=cache_dir,
|
78 |
+
)
|
79 |
+
|
80 |
+
# Load and process data
|
81 |
+
transform = Compose(
|
82 |
+
[
|
83 |
+
ScaleIntensityd(keys=view),
|
84 |
+
SpatialPadd(
|
85 |
+
keys=view,
|
86 |
+
spatial_size=(192, 192, 16),
|
87 |
+
method="end",
|
88 |
+
lazy=True,
|
89 |
+
allow_missing_keys=True,
|
90 |
+
),
|
91 |
+
]
|
92 |
+
)
|
93 |
+
|
94 |
+
images = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(image_path)))
|
95 |
+
images = images[..., ::t_step]
|
96 |
+
labels = inferece(images, view, transform, model, progress)
|
97 |
+
|
98 |
+
progress(1, desc="Plotting results...")
|
99 |
+
# Create segmentation visualization
|
100 |
+
n_slices, n_frames = labels.shape[-2:]
|
101 |
+
fig1, axs = plt.subplots(n_frames, n_slices, figsize=(n_slices, n_frames), dpi=300)
|
102 |
+
for t in range(n_frames):
|
103 |
+
for z in range(n_slices):
|
104 |
+
axs[t, z].imshow(images[..., z, t], cmap="gray")
|
105 |
+
axs[t, z].imshow(
|
106 |
+
(labels[..., z, t, None] == 1)
|
107 |
+
* np.array([108 / 255, 142 / 255, 191 / 255, 0.6])
|
108 |
+
)
|
109 |
+
axs[t, z].imshow(
|
110 |
+
(labels[..., z, t, None] == 2)
|
111 |
+
* np.array([214 / 255, 182 / 255, 86 / 255, 0.6])
|
112 |
+
)
|
113 |
+
axs[t, z].imshow(
|
114 |
+
(labels[..., z, t, None] == 3)
|
115 |
+
* np.array([130 / 255, 179 / 255, 102 / 255, 0.6])
|
116 |
+
)
|
117 |
+
axs[t, z].set_xticks([])
|
118 |
+
axs[t, z].set_yticks([])
|
119 |
+
if z == 0:
|
120 |
+
axs[t, z].set_ylabel(f"t = {t * t_step}")
|
121 |
+
fig1.suptitle(f"Subject {image_id} in {split} split")
|
122 |
+
axs[0, n_slices // 2].set_title("SAX Slices")
|
123 |
+
fig1.tight_layout()
|
124 |
+
plt.subplots_adjust(wspace=0, hspace=0)
|
125 |
+
|
126 |
+
# Create volume plot
|
127 |
+
xs = np.arange(n_frames) * t_step
|
128 |
+
rv_volumes = np.sum(labels == 1, axis=(0, 1, 2)) * 10 / 1000
|
129 |
+
myo_volumes = np.sum(labels == 2, axis=(0, 1, 2)) * 10 / 1000
|
130 |
+
lv_volumes = np.sum(labels == 3, axis=(0, 1, 2)) * 10 / 1000
|
131 |
+
lvef = (max(lv_volumes) - min(lv_volumes)) / max(lv_volumes) * 100
|
132 |
+
rvef = (max(rv_volumes) - min(rv_volumes)) / max(rv_volumes) * 100
|
133 |
+
|
134 |
+
fig2, ax = plt.subplots(figsize=(4, 4), dpi=120)
|
135 |
+
ax.plot(xs, rv_volumes, color="#6C8EBF", label="RV")
|
136 |
+
ax.plot(xs, myo_volumes, color="#D6B656", label="MYO")
|
137 |
+
ax.plot(xs, lv_volumes, color="#82B366", label="LV")
|
138 |
+
ax.set_xlabel("Frame")
|
139 |
+
ax.set_ylabel("Volume (ml)")
|
140 |
+
ax.set_title(f"LVEF = {lvef:.2f}%, RVEF = {rvef:.2f}%")
|
141 |
+
ax.legend(loc="lower right")
|
142 |
+
fig2.tight_layout()
|
143 |
+
|
144 |
+
return fig1, fig2
|
145 |
+
|
146 |
+
|
147 |
+
# Create the Gradio interface
|
148 |
+
theme = gr.themes.Ocean(
|
149 |
+
primary_hue="red",
|
150 |
+
secondary_hue="purple",
|
151 |
+
)
|
152 |
+
with gr.Blocks(
|
153 |
+
theme=theme, title="CineMA: A Foundation Model for Cine Cardiac MRI"
|
154 |
+
) as demo:
|
155 |
+
gr.Markdown(
|
156 |
+
"""
|
157 |
+
# CineMA: A Foundation Model for Cine Cardiac MRI π₯π«
|
158 |
+
|
159 |
+
Below is an example of ejection fraction prediction inference. For more examples, checkout our [GitHub](https://github.com/mathpluscode/CineMA).
|
160 |
+
"""
|
161 |
+
)
|
162 |
+
|
163 |
+
with gr.Row():
|
164 |
+
with gr.Column(scale=0.4):
|
165 |
+
gr.Markdown("## Description")
|
166 |
+
gr.Markdown("""
|
167 |
+
Please adjust the settings on the right panels and click the button to run the inference.
|
168 |
+
|
169 |
+
### Data
|
170 |
+
|
171 |
+
The available data is from ACDC. All images have been resampled to 1 mm Γ 1 mm Γ 10 mm and centre-cropped to 192 mm Γ 192 mm for each SAX slice.
|
172 |
+
Image 1 - 100 are from the training set, and image 101 - 150 are from the test set.
|
173 |
+
|
174 |
+
### Model
|
175 |
+
|
176 |
+
The available models are finetuned on different datasets ([ACDC](https://www.creatis.insa-lyon.fr/Challenge/acdc/), [M&Ms](https://www.ub.edu/mnms/), and [M&Ms2](https://www.ub.edu/mnms-2/)). For each dataset, there are 3 models finetuned on different seeds: 0, 1, 2. The default model is the one finetuned on ACDC dataset with seed 0.
|
177 |
+
|
178 |
+
### Visualization
|
179 |
+
|
180 |
+
The left panel shows the segmentation of ventricles and myocardium every n time steps across all SAX slices.
|
181 |
+
The right panel plots the ventricle and mycoardium volumes across all inference time frames.
|
182 |
+
""")
|
183 |
+
with gr.Column(scale=0.3):
|
184 |
+
gr.Markdown("## Data Settings")
|
185 |
+
image_id = gr.Slider(
|
186 |
+
minimum=1,
|
187 |
+
maximum=150,
|
188 |
+
step=1,
|
189 |
+
label="Choose an ACDC image, ID is between 1 and 150",
|
190 |
+
value=1,
|
191 |
+
)
|
192 |
+
t_step = gr.Slider(
|
193 |
+
minimum=1,
|
194 |
+
maximum=10,
|
195 |
+
step=1,
|
196 |
+
label="Choose the gap between time frames",
|
197 |
+
value=2,
|
198 |
+
)
|
199 |
+
with gr.Column(scale=0.3):
|
200 |
+
gr.Markdown("## Model Setting")
|
201 |
+
trained_dataset = gr.Dropdown(
|
202 |
+
choices=["ACDC", "M&MS", "M&MS2"],
|
203 |
+
label="Choose which dataset the segmentation model was finetuned on",
|
204 |
+
value="ACDC",
|
205 |
+
)
|
206 |
+
seed = gr.Slider(
|
207 |
+
minimum=0,
|
208 |
+
maximum=2,
|
209 |
+
step=1,
|
210 |
+
label="Choose which seed the finetuning used",
|
211 |
+
value=0,
|
212 |
+
)
|
213 |
+
run_button = gr.Button("Run segmentation inference", variant="primary")
|
214 |
+
|
215 |
+
with gr.Row():
|
216 |
+
segmentation_plot = gr.Plot(label="Ventricle and Myocardium Segmentation")
|
217 |
+
volume_plot = gr.Plot(label="Ejection Fraction Prediction")
|
218 |
+
|
219 |
+
run_button.click(
|
220 |
+
fn=run_inference,
|
221 |
+
inputs=[trained_dataset, seed, image_id, t_step],
|
222 |
+
outputs=[segmentation_plot, volume_plot],
|
223 |
+
)
|
224 |
+
|
225 |
+
demo.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
altair==5.5.0
|
2 |
+
streamlit==1.45.1
|
3 |
+
SimpleITK==2.5.0
|
4 |
+
einops==0.8.1
|
5 |
+
hydra-core==1.3.2
|
6 |
+
matplotlib==3.10.1
|
7 |
+
monai==1.4.0
|
8 |
+
nbmake==1.5.5
|
9 |
+
nibabel==5.3.2
|
10 |
+
numpy==1.26.4
|
11 |
+
pandas==2.2.3
|
12 |
+
plotly==6.0.1
|
13 |
+
pydicom==3.0.1
|
14 |
+
safetensors==0.5.3
|
15 |
+
scikit-image==0.25.2
|
16 |
+
scikit-learn==1.6.1
|
17 |
+
scipy==1.15.2
|
18 |
+
spaces==0.36.0
|
19 |
+
timm==1.0.15
|
20 |
+
wandb==0.19.11
|
21 |
+
git+https://github.com/mathpluscode/CineMA#egg=cinema
|
22 |
+
--extra-index-url https://download.pytorch.org/whl/cu113
|
23 |
+
torch==2.5.1
|