mathpluscode committed
Commit 65411d8 · 1 Parent(s): f4635f5

Add segmentation demo

Files changed (5)
  1. .gitignore +4 -0
  2. .pre-commit-config.yaml +49 -0
  3. README.md +10 -5
  4. app.py +225 -0
  5. requirements.txt +23 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
+.idea/
+.DS_Store
+node_modules/
+src/cache/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,49 @@
+default_language_version:
+  python: python3
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v5.0.0
+    hooks:
+      - id: check-added-large-files
+        args: ["--maxkb=30000"]
+      - id: check-ast
+      - id: check-byte-order-marker
+      - id: check-builtin-literals
+      - id: check-case-conflict
+      - id: check-merge-conflict
+      - id: check-symlinks
+      - id: check-toml
+      - id: check-yaml
+      - id: debug-statements
+      - id: destroyed-symlinks
+      - id: end-of-file-fixer
+      - id: fix-byte-order-marker
+      - id: mixed-line-ending
+      - id: file-contents-sorter
+        files: "envs/requirements*.txt"
+      - id: trailing-whitespace
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.11.8
+    hooks:
+      # run the linter
+      - id: ruff
+        args: [--fix]
+      # run the formatter
+      - id: ruff-format
+  - repo: https://github.com/pre-commit/mirrors-prettier
+    rev: v4.0.0-alpha.8
+    hooks:
+      - id: prettier
+        args:
+          - --print-width=120
+          - --prose-wrap=always
+          - --tab-width=2
+  - repo: https://github.com/codespell-project/codespell
+    rev: v2.4.1
+    hooks:
+      - id: codespell
+        name: codespell
+        description: Checks for common misspellings in text files.
+        entry: codespell
+        language: python
+        types: [text]
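
These hooks run automatically once registered with git: `pre-commit install` enables them for every commit in the clone, and `pre-commit run --all-files` applies the whole suite to the existing tree on demand.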
README.md CHANGED
@@ -1,13 +1,18 @@
 ---
 title: CineMA
-emoji: 👁
-colorFrom: purple
-colorTo: indigo
+emoji: 🫀
+colorFrom: red
+colorTo: purple
 sdk: gradio
-sdk_version: 5.29.1
+sdk_version: 5.30.0
 app_file: app.py
 pinned: false
+python_version: "3.10.13"
 license: mit
+short_description: A Foundation Model for Cine CMR Images 🎥🫀
 ---

-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# CineMA: A Foundation Model for Cine Cardiac MRI
+
+This is a demo of CineMA, a foundation model for cine cardiac MRI. For more details, check out our
+[GitHub](https://github.com/mathpluscode/CineMA).
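
Since `app_file` points at a plain Gradio script, the demo should also run locally: assuming a Python environment matching the `python_version` pinned above, `pip install -r requirements.txt` followed by `python app.py` is the usual sketch.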
app.py ADDED
@@ -0,0 +1,225 @@
+import numpy as np
+import gradio as gr
+from huggingface_hub import hf_hub_download
+import matplotlib.pyplot as plt
+import SimpleITK as sitk  # noqa: N813
+import torch
+from monai.transforms import Compose, ScaleIntensityd, SpatialPadd
+from cinema import ConvUNetR
+from pathlib import Path
+import spaces
+
+# cache directories
+cache_dir = Path("/tmp/.cinema")
+cache_dir.mkdir(parents=True, exist_ok=True)
+
+
+@spaces.GPU
+def inference(
+    images: np.ndarray,
+    view: str,
+    transform: Compose,
+    model: ConvUNetR,
+    progress=gr.Progress(),
+) -> np.ndarray:
+    # set device and dtype
+    dtype, device = torch.float32, torch.device("cpu")
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        device = torch.device("cuda")
+        if torch.cuda.is_bf16_supported():
+            dtype = torch.bfloat16
+
+    # inference
+    model.to(device)
+    n_slices, n_frames = images.shape[-2:]
+    labels_list = []
+    for t in range(0, n_frames):
+        progress((t + 1) / n_frames, desc=f"Processing frame {t + 1} / {n_frames}...")
+        batch = transform({view: torch.from_numpy(images[None, ..., t])})
+        batch = {
+            k: v[None, ...].to(device=device, dtype=torch.float32)
+            for k, v in batch.items()
+        }
+        with (
+            torch.no_grad(),
+            torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()),
+        ):
+            logits = model(batch)[view]
+        labels_list.append(torch.argmax(logits, dim=1)[0, ..., :n_slices])
+    labels = torch.stack(labels_list, dim=-1).detach().cpu().numpy()
+    return labels
+
+
+def run_inference(trained_dataset, seed, image_id, t_step, progress=gr.Progress()):
+    # Fixed parameters
+    view = "sax"
+    split = "train" if image_id <= 100 else "test"
+    trained_dataset = {
+        "ACDC": "acdc",
+        "M&MS": "mnms",
+        "M&MS2": "mnms2",
+    }[str(trained_dataset)]
+
+    # Download and load model
+    progress(0, desc="Downloading model and data...")
+    image_path = hf_hub_download(
+        repo_id="mathpluscode/ACDC",
+        repo_type="dataset",
+        filename=f"{split}/patient{image_id:03d}/patient{image_id:03d}_sax_t.nii.gz",
+        cache_dir=cache_dir,
+    )
+
+    model = ConvUNetR.from_finetuned(
+        repo_id="mathpluscode/CineMA",
+        model_filename=f"finetuned/segmentation/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
+        config_filename=f"finetuned/segmentation/{trained_dataset}_{view}/config.yaml",
+        cache_dir=cache_dir,
+    )
+
+    # Load and process data
+    transform = Compose(
+        [
+            ScaleIntensityd(keys=view),
+            SpatialPadd(
+                keys=view,
+                spatial_size=(192, 192, 16),
+                method="end",
+                lazy=True,
+                allow_missing_keys=True,
+            ),
+        ]
+    )
+
+    images = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(image_path)))
+    images = images[..., ::t_step]
+    labels = inference(images, view, transform, model, progress)
+
+    progress(1, desc="Plotting results...")
+    # Create segmentation visualization
+    n_slices, n_frames = labels.shape[-2:]
+    fig1, axs = plt.subplots(n_frames, n_slices, figsize=(n_slices, n_frames), dpi=300)
+    for t in range(n_frames):
+        for z in range(n_slices):
+            axs[t, z].imshow(images[..., z, t], cmap="gray")
+            axs[t, z].imshow(
+                (labels[..., z, t, None] == 1)
+                * np.array([108 / 255, 142 / 255, 191 / 255, 0.6])
+            )
+            axs[t, z].imshow(
+                (labels[..., z, t, None] == 2)
+                * np.array([214 / 255, 182 / 255, 86 / 255, 0.6])
+            )
+            axs[t, z].imshow(
+                (labels[..., z, t, None] == 3)
+                * np.array([130 / 255, 179 / 255, 102 / 255, 0.6])
+            )
+            axs[t, z].set_xticks([])
+            axs[t, z].set_yticks([])
+            if z == 0:
+                axs[t, z].set_ylabel(f"t = {t * t_step}")
+    fig1.suptitle(f"Subject {image_id} in {split} split")
+    axs[0, n_slices // 2].set_title("SAX Slices")
+    fig1.tight_layout()
+    plt.subplots_adjust(wspace=0, hspace=0)
+
+    # Create volume plot: each 1 x 1 x 10 mm voxel is 10 mm^3, i.e. 0.01 ml
+    xs = np.arange(n_frames) * t_step
+    rv_volumes = np.sum(labels == 1, axis=(0, 1, 2)) * 10 / 1000
+    myo_volumes = np.sum(labels == 2, axis=(0, 1, 2)) * 10 / 1000
+    lv_volumes = np.sum(labels == 3, axis=(0, 1, 2)) * 10 / 1000
+    lvef = (max(lv_volumes) - min(lv_volumes)) / max(lv_volumes) * 100
+    rvef = (max(rv_volumes) - min(rv_volumes)) / max(rv_volumes) * 100
+
+    fig2, ax = plt.subplots(figsize=(4, 4), dpi=120)
+    ax.plot(xs, rv_volumes, color="#6C8EBF", label="RV")
+    ax.plot(xs, myo_volumes, color="#D6B656", label="MYO")
+    ax.plot(xs, lv_volumes, color="#82B366", label="LV")
+    ax.set_xlabel("Frame")
+    ax.set_ylabel("Volume (ml)")
+    ax.set_title(f"LVEF = {lvef:.2f}%, RVEF = {rvef:.2f}%")
+    ax.legend(loc="lower right")
+    fig2.tight_layout()
+
+    return fig1, fig2
+
+
+# Create the Gradio interface
+theme = gr.themes.Ocean(
+    primary_hue="red",
+    secondary_hue="purple",
+)
+with gr.Blocks(
+    theme=theme, title="CineMA: A Foundation Model for Cine Cardiac MRI"
+) as demo:
+    gr.Markdown(
+        """
+        # CineMA: A Foundation Model for Cine Cardiac MRI 🎥🫀
+
+        Below is an example of ejection fraction prediction inference. For more examples, check out our [GitHub](https://github.com/mathpluscode/CineMA).
+        """
+    )
+
+    with gr.Row():
+        with gr.Column(scale=0.4):
+            gr.Markdown("## Description")
+            gr.Markdown("""
+            Please adjust the settings on the right panels and click the button to run the inference.
+
+            ### Data
+
+            The available data is from ACDC. All images have been resampled to 1 mm × 1 mm × 10 mm and centre-cropped to 192 mm × 192 mm for each SAX slice.
+            Images 1-100 are from the training set, and images 101-150 are from the test set.
+
+            ### Model
+
+            The available models are finetuned on different datasets ([ACDC](https://www.creatis.insa-lyon.fr/Challenge/acdc/), [M&Ms](https://www.ub.edu/mnms/), and [M&Ms2](https://www.ub.edu/mnms-2/)). For each dataset, there are three models finetuned with different seeds: 0, 1, and 2. The default model is the one finetuned on the ACDC dataset with seed 0.

            ### Visualization

            The left panel shows the segmentation of the ventricles and myocardium, every n time frames, across all SAX slices.
            The right panel plots the ventricle and myocardium volumes across all inferred time frames.
+            """)
+        with gr.Column(scale=0.3):
+            gr.Markdown("## Data Settings")
+            image_id = gr.Slider(
+                minimum=1,
+                maximum=150,
+                step=1,
+                label="Choose an ACDC image (ID between 1 and 150)",
+                value=1,
+            )
+            t_step = gr.Slider(
+                minimum=1,
+                maximum=10,
+                step=1,
+                label="Choose the gap between time frames",
+                value=2,
+            )
+        with gr.Column(scale=0.3):
+            gr.Markdown("## Model Settings")
+            trained_dataset = gr.Dropdown(
+                choices=["ACDC", "M&MS", "M&MS2"],
+                label="Choose which dataset the segmentation model was finetuned on",
+                value="ACDC",
+            )
+            seed = gr.Slider(
+                minimum=0,
+                maximum=2,
+                step=1,
+                label="Choose which seed the finetuning used",
+                value=0,
+            )
+            run_button = gr.Button("Run segmentation inference", variant="primary")
+
+    with gr.Row():
+        segmentation_plot = gr.Plot(label="Ventricle and Myocardium Segmentation")
+        volume_plot = gr.Plot(label="Ejection Fraction Prediction")
+
+    run_button.click(
+        fn=run_inference,
+        inputs=[trained_dataset, seed, image_id, t_step],
+        outputs=[segmentation_plot, volume_plot],
+    )
+
+demo.launch()
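
For reuse outside the Gradio UI, here is a minimal sketch of the same segmentation pipeline on a single time frame. It assumes the `cinema` package and the checkpoint layout used in `app.py` above; the input path is a placeholder following the dataset's `patient{id:03d}_sax_t.nii.gz` pattern. Class indices follow the volume computation above (1 = RV, 2 = MYO, 3 = LV); with images resampled to 1 mm × 1 mm × 10 mm, each voxel contributes 10 mm³ (0.01 ml), and the app reports LVEF/RVEF as (max volume - min volume) / max volume × 100 over the cycle.

```python
import numpy as np
import SimpleITK as sitk
import torch
from monai.transforms import Compose, ScaleIntensityd, SpatialPadd
from cinema import ConvUNetR

# same checkpoint layout as app.py above (ACDC, SAX view, seed 0)
model = ConvUNetR.from_finetuned(
    repo_id="mathpluscode/CineMA",
    model_filename="finetuned/segmentation/acdc_sax/acdc_sax_0.safetensors",
    config_filename="finetuned/segmentation/acdc_sax/config.yaml",
)
model.eval()

# same preprocessing as app.py: intensity scaling, padding to (192, 192, 16)
transform = Compose(
    [
        ScaleIntensityd(keys="sax"),
        SpatialPadd(keys="sax", spatial_size=(192, 192, 16), method="end"),
    ]
)

# placeholder path following the ACDC filename pattern used above;
# after transposing the SimpleITK array, images has shape (x, y, z, t)
images = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage("patient001_sax_t.nii.gz")))
n_slices = images.shape[-2]

# segment the first frame only; app.py loops this over all frames
batch = transform({"sax": torch.from_numpy(images[None, ..., 0])})
batch = {k: v[None, ...].to(torch.float32) for k, v in batch.items()}
with torch.no_grad():
    logits = model(batch)["sax"]
labels = torch.argmax(logits, dim=1)[0, ..., :n_slices]  # 1 = RV, 2 = MYO, 3 = LV

# 1 x 1 x 10 mm voxels: count * 10 mm^3 / 1000 gives millilitres
lv_volume_ml = (labels == 3).sum().item() * 10 / 1000
print(f"LV volume at frame 0: {lv_volume_ml:.1f} ml")
```

Running this per frame and taking the max/min of the resulting volume curve reproduces the ejection fractions plotted by the app.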
requirements.txt ADDED
@@ -0,0 +1,23 @@
+altair==5.5.0
+streamlit==1.45.1
+SimpleITK==2.5.0
+einops==0.8.1
+hydra-core==1.3.2
+matplotlib==3.10.1
+monai==1.4.0
+nbmake==1.5.5
+nibabel==5.3.2
+numpy==1.26.4
+pandas==2.2.3
+plotly==6.0.1
+pydicom==3.0.1
+safetensors==0.5.3
+scikit-image==0.25.2
+scikit-learn==1.6.1
+scipy==1.15.2
+spaces==0.36.0
+timm==1.0.15
+wandb==0.19.11
+git+https://github.com/mathpluscode/CineMA#egg=cinema
+--extra-index-url https://download.pytorch.org/whl/cu113
+torch==2.5.1
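
A note on the last two lines: `--extra-index-url` only adds a secondary package index, so if the pinned `torch==2.5.1` is not published on the cu113 wheel index (which likely predates this release), pip will simply resolve it from PyPI instead.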