samaonline committed · Commit 1b34a12 · 0 parents
.gitattributes ADDED
@@ -0,0 +1,61 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.lz4 filter=lfs diff=lfs merge=lfs -text
+ *.mds filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ # Audio files - uncompressed
+ *.pcm filter=lfs diff=lfs merge=lfs -text
+ *.sam filter=lfs diff=lfs merge=lfs -text
+ *.raw filter=lfs diff=lfs merge=lfs -text
+ # Audio files - compressed
+ *.aac filter=lfs diff=lfs merge=lfs -text
+ *.flac filter=lfs diff=lfs merge=lfs -text
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
+ *.ogg filter=lfs diff=lfs merge=lfs -text
+ *.wav filter=lfs diff=lfs merge=lfs -text
+ # Image files - uncompressed
+ *.bmp filter=lfs diff=lfs merge=lfs -text
+ *.gif filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.tiff filter=lfs diff=lfs merge=lfs -text
+ # Image files - compressed
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
+ *.webp filter=lfs diff=lfs merge=lfs -text
+ # Video files - compressed
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
+ *.webm filter=lfs diff=lfs merge=lfs -text
+ dataset/kspace/data.mdb filter=lfs diff=lfs merge=lfs -text
+ dataset/rss/data.mdb filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,124 @@
+ # Config files
+ fastmri.yaml
+
+ # Python specific
+ __pycache__/
+ .pytest_cache/
+ *.py[cod]
+ *.so
+ *.egg-info/
+ *.pyo
+ *.pyd
+
+ # Virtual environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Code in development
+ ignore_**.py
+
+ # Hidden / ignore folders
+ hidden/
+ ignore/
+ hidden_**
+
+ # Jupyter Notebook Checkpoints
+ .ipynb_checkpoints
+
+ # Data files
+ data/
+ datasets/
+ dataset
+ *.csv
+ *.tsv
+ *.h5
+ *.json
+ *.xml
+ *.parquet
+ *.pkl
+
+ # Model files
+ *.ckpt
+ *.h5
+ *.tflite
+ *.onnx
+ *.pb
+ *.pth
+ *.pt
+ *.joblib
+ *.pkl
+
+ # Logs and outputs
+ logs/
+ wandb/
+ *.log
+ *.out
+ *.txt
+ *.csv
+
+ # Test dir
+ !tests/**/*.txt
+ !tests/datasets
+
+ # Results
+ results/
+ output/
+ runs/
+ outfig/
+ figs/*.png
+
+ # SLURM
+ slurm/
+
+ # Ignore files related to experiments
+ experiments/
+
+ # Temporary files
+ *.tmp
+ *.temp
+ *.swp
+ *.swo
+
+ # VS Code specific
+ .vscode/
+ *.code-workspace
+
+ # System files
+ .DS_Store
+ Thumbs.db
+
+ # Environment files
+ *.env
+
+ # Ignore files from data processing tools
+ *.dvc
+ .dvc/
+
+ # PyTorch Lightning Logs
+ lightning_logs/
+
+ # Ignore files generated by package managers
+ Pipfile
+ Pipfile.lock
+ poetry.lock
+
+ # TensorBoard logs
+ logs/
+ events.out.tfevents.*
+
+ # Checkpoints and weights
+ checkpoints/
+ weights/
+
+ # Large file extensions
+ *.tar.gz
+ *.zip
+ *.tar
+ *.gz
+
+
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
README.md ADDED
@@ -0,0 +1,33 @@
+ ---
+ tags:
+ - medical
+ - mri
+ - neuraloperator
+ - fastmri
+ pretty_name: fastMRI Tiny
+ ---
+ # A Unified Model for Compressed Sensing MRI Across Undersampling Patterns
+
+ > [**A Unified Model for Compressed Sensing MRI Across Undersampling Patterns**](https://arxiv.org/abs/2410.16290)
+ > Armeet Singh Jatyani, Jiayun Wang, Aditi Chandrashekar, Zihui Wu, Miguel Liu-Schiaffini, Bahareh Tolooshams, Anima Anandkumar
+ > *Paper at [CVPR 2025](https://cvpr.thecvf.com/Conferences/2025/AcceptedPapers)*
+
+ This is a tiny subset of 230 fastMRI samples, used in the demo for the above [paper](https://huggingface.co/armeet/nomri) at CVPR 2025!
+
+
+ ## Citation
+
+ If you find our work helpful or use any of our models (UDNO), please cite the following:
+ ```bibtex
+ @inproceedings{jatyani2025nomri,
+ author = {Armeet Singh Jatyani* and Jiayun Wang* and Aditi Chandrashekar and Zihui Wu and Miguel Liu-Schiaffini and Bahareh Tolooshams and Anima Anandkumar},
+ title = {A Unified Model for Compressed Sensing MRI Across Undersampling Patterns},
+ booktitle = {Conference on Computer Vision and Pattern Recognition (CVPR) Proceedings},
+ abbr = {CVPR},
+ year = {2025}
+ }
+ ```
+
+ ![paper_preview](https://github.com/user-attachments/assets/7e6adaa5-a5fa-4b68-bd8c-5279f6f643d7)
+
+ https://arxiv.org/abs/2410.16290
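The `app.py` demo below fetches this dataset with `huggingface_hub.snapshot_download`; a minimal standalone sketch of doing the same (the destination directory here is only an illustrative assumption — `app.py` downloads into the working directory):

```python
# Sketch: fetch the fastMRI-tiny snapshot (repo id taken from app.py below).
from huggingface_hub import snapshot_download

local_path = snapshot_download(
    repo_id="armeet/fastmri-tiny",
    repo_type="dataset",
    local_dir="./fastmri-tiny",  # assumed destination; app.py uses "."
)
print("downloaded to", local_path)
```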
app.py ADDED
@@ -0,0 +1,264 @@
1
+ import io
2
+ import os
3
+ import sys
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+
8
+ # import spaces
9
+ # from huggingface_hub import hf_hub_download
10
+ from huggingface_hub import snapshot_download
11
+ from PIL import Image, ImageDraw, ImageFont
12
+
13
+ # Set the working directory to the root directory
14
+ # root_dir = os.path.abspath("..")
15
+ # os.chdir(root_dir)
16
+ # sys.path.insert(0, root_dir)
17
+
18
+ # download dataset & weights
19
+ snapshot_download(repo_id="armeet/fastmri-tiny", repo_type="dataset", local_dir=".")
20
+
21
+
22
+ device = "cuda"
23
+ # dataset_path = "/global/homes/p/peterwg/pscratch/datasets/mri_knee_dummy"
24
+ dataset_path = "dataset"
25
+
26
+ import matplotlib.pyplot as plt
27
+ import numpy as np
28
+ import torch
29
+ import torch.nn as nn
30
+ from torch.nn import functional as F
31
+
32
+ import fastmri
33
+ from fastmri.datasets import SliceDatasetLMDB, SliceSample
34
+ from fastmri.subsample import create_mask_for_mask_type
35
+ from models.lightning.no_varnet_module import NOVarnetModule
36
+ from models.lightning.varnet_module import VarNetModule
37
+
38
+ acceleration_to_fractions = {
39
+ 1: 1,
40
+ 2: 0.16,
41
+ 4: 0.08,
42
+ 6: 0.06,
43
+ 8: 0.04,
44
+ 16: 0.02,
45
+ 32: 0.01,
46
+ }
47
+
48
+
49
+ def create_mask_fn(center_fraction, acceleration):
50
+ mask_fn = create_mask_for_mask_type(
51
+ "equispaced_fraction",
52
+ [center_fraction],
53
+ [acceleration],
54
+ )
55
+ return mask_fn
56
+
57
+
58
+ mask_4x = create_mask_fn(acceleration_to_fractions[4], 4)
59
+ mask_6x = create_mask_fn(acceleration_to_fractions[6], 6)
60
+ mask_8x = create_mask_fn(acceleration_to_fractions[8], 8)
61
+ mask_16x = create_mask_fn(acceleration_to_fractions[16], 16)
62
+
63
+ val_dataset_4x = SliceDatasetLMDB(
64
+ "knee",
65
+ partition="val",
66
+ mask_fns=[mask_4x],
67
+ complex=False,
68
+ root=dataset_path,
69
+ crop_shape=(320, 320),
70
+ coils=15,
71
+ )
72
+
73
+ val_dataset_6x = SliceDatasetLMDB(
74
+ "knee",
75
+ partition="val",
76
+ mask_fns=[mask_6x],
77
+ complex=False,
78
+ root=dataset_path,
79
+ crop_shape=(320, 320),
80
+ coils=15,
81
+ )
82
+
83
+ val_dataset_8x = SliceDatasetLMDB(
84
+ "knee",
85
+ partition="val",
86
+ mask_fns=[mask_8x],
87
+ complex=False,
88
+ root=dataset_path,
89
+ crop_shape=(320, 320),
90
+ coils=15,
91
+ )
92
+ val_dataset_16x = SliceDatasetLMDB(
93
+ "knee",
94
+ partition="val",
95
+ mask_fns=[mask_16x],
96
+ complex=False,
97
+ root=dataset_path,
98
+ crop_shape=(320, 320),
99
+ coils=15,
100
+ )
101
+
102
+ vn = VarNetModule.load_from_checkpoint(
103
+ "vn.ckpt",
104
+ )
105
+ no = NOVarnetModule.load_from_checkpoint(
106
+ "no.ckpt",
107
+ )
108
+ no.eval()
109
+ vn.eval()
110
+
111
+ bright_samples = [42, 69, 80, 137, 139, 226, 229]
112
+
113
+
114
+ def v(x):
115
+ return x.detach().cpu().numpy().squeeze()
116
+
117
+
118
+ def viz(x, cmap="gray", vmin=0, vmax=1):
119
+ processed_data = v(x)
120
+ fig, ax = plt.subplots()
121
+ ax.imshow(processed_data, cmap=cmap, vmin=vmin, vmax=vmax)
122
+ ax.axis("off") # Turn off axes
123
+ fig.subplots_adjust(left=0, right=1, top=1, bottom=0) # Adjust margins
124
+ buf = io.BytesIO()
125
+ plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
126
+ buf.seek(0) # Rewind the buffer to the beginning
127
+ plt.show()
128
+ try:
129
+ img = Image.open(buf)
130
+ img_array = np.array(img)
131
+ except Exception as e:
132
+ print(f"Error converting image buffer to NumPy array: {e}")
133
+ img_array = None
134
+ finally:
135
+ plt.close(fig)
136
+ buf.close()
137
+ return img_array
138
+
139
+
140
+ def forward(model, idx, rate):
141
+ if rate == 4:
142
+ dataset = val_dataset_4x
143
+ elif rate == 6:
144
+ dataset = val_dataset_6x
145
+ elif rate == 8:
146
+ dataset = val_dataset_8x
147
+ elif rate == 16:
148
+ dataset = val_dataset_16x
149
+ else:
150
+ raise ValueError("Invalid rate")
151
+
152
+ sample = dataset[idx]
153
+ mask, k, target = (
154
+ sample.mask.to(device),
155
+ sample.masked_kspace.to(device),
156
+ sample.target.to(device),
157
+ )
158
+ pred = model(k.unsqueeze(0), mask.unsqueeze(0), None)
159
+
160
+ return mask, k, target, pred[0]
161
+
162
+
163
+ def update_interface(sample_id, sample_rate):
164
+ n = [None] * 6
165
+ if sample_id is None or sample_rate is None or sample_id not in bright_samples:
166
+ return n
167
+
168
+ mask, k, target, pred_vn = forward(vn, sample_id, sample_rate)
169
+ _, _, _, pred_no = forward(no, sample_id, sample_rate)
170
+
171
+ k = viz(mask[0, :, :, 0], cmap="gray", vmin=0, vmax=1)
172
+ target_res = viz(target, cmap="gray", vmin=None, vmax=None)
173
+
174
+ pred_no_res = viz(pred_no, cmap="gray", vmin=None, vmax=None)
175
+ pred_vn_res = viz(pred_vn, cmap="gray", vmin=None, vmax=None)
176
+
177
+ diff_no_res = viz(torch.abs(pred_no - target), cmap=None, vmin=None, vmax=None)
178
+ diff_vn_res = viz(torch.abs(pred_vn - target), cmap=None, vmin=None, vmax=None)
179
+
180
+ return k, target_res, pred_no_res, pred_vn_res, diff_no_res, diff_vn_res
181
+
182
+
183
+ with gr.Blocks(theme=gr.themes.Monochrome(), fill_width=True) as demo:
184
+ gr.Markdown(
185
+ "# A Unified Model for Compressed Sensing MRI Across Undersampling Patterns [CVPR 2025]"
186
+ )
187
+ gr.Markdown("""
188
+ > Armeet Singh Jatyani, Jiayun Wang, Aditi Chandrashekar, Zihui Wu, Miguel Liu-Schiaffini, Bahareh Tolooshams, Anima Anandkumar
189
+ """)
190
+ gr.Markdown(
191
+ "[![arXiv](https://img.shields.io/badge/arXiv-2410.16290-b31b1b.svg?style=flat-square&logo=arxiv)](https://arxiv.org/abs/2410.16290)"
192
+ )
193
+ gr.Markdown(
194
+ "[![](https://img.shields.io/badge/Blog-armeet.ca%2Fnomri-yellow?style=flat-square)](https://armeet.ca/nomri)"
195
+ )
196
+
197
+ gr.Markdown(
198
+ "This demo showcases the performance of our unified model for compressed sensing MRI across different acceleration rates."
199
+ )
200
+
201
+ with gr.Row():
202
+ dropdown_sample = gr.Dropdown(
203
+ choices=bright_samples,
204
+ label="Select a Sample",
205
+ info="Choose one of the available samples.",
206
+ filterable=False,
207
+ value=229,
208
+ )
209
+ with gr.Row():
210
+ dropdown_rate = gr.Radio(
211
+ choices=[16, 8, 6, 4],
212
+ value=16,
213
+ label="Select an Acceleration Rate",
214
+ info="Ex: 4x means the model is trained to reconstruct from 4x undersampled k-space data",
215
+ # filterable=False,
216
+ )
217
+
218
+ with gr.Row():
219
+ with gr.Column():
220
+ gr.Label("Undersampling Mask")
221
+ k = gr.Image(label=None, interactive=False)
222
+ with gr.Column():
223
+ gr.Label("Ground Truth")
224
+ target = gr.Image(label=None, interactive=False)
225
+ with gr.Column():
226
+ gr.Label("NO (ours)")
227
+ pred_no = gr.Image(label="Reconstruction", interactive=False)
228
+ with gr.Column():
229
+ gr.Label("VN (existing)")
230
+ pred_vn = gr.Image(label="Reconstruction", interactive=False)
231
+ with gr.Row():
232
+ with gr.Column():
233
+ pass
234
+ with gr.Column():
235
+ pass
236
+ with gr.Column():
237
+ diff_no = gr.Image(label="| Recon - GT |", interactive=False)
238
+ with gr.Column():
239
+ diff_vn = gr.Image(label="| Recon - GT |", interactive=False)
240
+
241
+ gr.Markdown("""
242
+ ```
243
+ @inproceedings{jatyani2025nomri,
244
+ author = {Armeet Singh Jatyani* and Jiayun Wang* and Aditi Chandrashekar and Zihui Wu and Miguel Liu-Schiaffini and Bahareh Tolooshams and Anima Anandkumar},
245
+ title = {A Unified Model for Compressed Sensing MRI Across Undersampling Patterns},
246
+ booktitle = {Conference on Computer Vision and Pattern Recognition (CVPR) Proceedings},
247
+ abbr = {CVPR},
248
+ year = {2025}
249
+ }
250
+ ```
251
+ """)
252
+
253
+ update_inputs = [dropdown_sample, dropdown_rate]
254
+ update_outputs = [k, target, pred_no, pred_vn, diff_no, diff_vn]
255
+
256
+ dropdown_sample.change(
257
+ fn=update_interface, inputs=update_inputs, outputs=update_outputs
258
+ )
259
+ dropdown_rate.change(
260
+ fn=update_interface, inputs=update_inputs, outputs=update_outputs
261
+ )
262
+
263
+ if __name__ == "__main__":
264
+ demo.launch(share=True)
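For reference, the `viz` helper above rasterizes a tensor through an in-memory PNG buffer rather than returning a Matplotlib figure; a self-contained sketch of that pattern on synthetic data (shapes and values are purely illustrative):

```python
# Render a 2-D tensor to a NumPy image array via an in-memory PNG,
# mirroring the viz() helper in app.py (synthetic input for illustration).
import io

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

x = torch.rand(320, 320)  # stand-in for a reconstruction
fig, ax = plt.subplots()
ax.imshow(x.numpy(), cmap="gray", vmin=0, vmax=1)
ax.axis("off")
fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
buf = io.BytesIO()
fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
plt.close(fig)
buf.seek(0)
img_array = np.array(Image.open(buf))
print(img_array.shape)  # (H, W, 4) RGBA array, ready for gr.Image
```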
environment.yml ADDED
@@ -0,0 +1,158 @@
1
+ name: no-med
2
+ channels:
3
+ - conda-forge
4
+ - defaults
5
+ dependencies:
6
+ - _libgcc_mutex=0.1=conda_forge
7
+ - _openmp_mutex=4.5=2_gnu
8
+ - asttokens=2.4.1=pyhd8ed1ab_0
9
+ - bzip2=1.0.8=h5eee18b_6
10
+ - ca-certificates=2024.8.30=hbcca054_0
11
+ - comm=0.2.2=pyhd8ed1ab_0
12
+ - debugpy=1.6.7=py312h6a678d5_0
13
+ - decorator=5.1.1=pyhd8ed1ab_0
14
+ - exceptiongroup=1.2.2=pyhd8ed1ab_0
15
+ - executing=2.1.0=pyhd8ed1ab_0
16
+ - expat=2.6.3=h6a678d5_0
17
+ - importlib-metadata=8.5.0=pyha770c72_0
18
+ - ipykernel=6.29.5=pyh3099207_0
19
+ - ipython=8.27.0=pyh707e725_0
20
+ - jedi=0.19.1=pyhd8ed1ab_0
21
+ - jupyter_client=8.6.3=pyhd8ed1ab_0
22
+ - jupyter_core=5.7.2=py312h06a4308_0
23
+ - krb5=1.21.3=h143b758_0
24
+ - ld_impl_linux-64=2.38=h1181459_1
25
+ - libedit=3.1.20230828=h5eee18b_0
26
+ - libffi=3.4.4=h6a678d5_1
27
+ - libgcc=14.1.0=h77fa898_1
28
+ - libgcc-ng=14.1.0=h69a702a_1
29
+ - libgomp=14.1.0=h77fa898_1
30
+ - libsodium=1.0.20=h4ab18f5_0
31
+ - libstdcxx=14.1.0=hc0a3c3a_1
32
+ - libstdcxx-ng=11.2.0=h1234567_1
33
+ - libuuid=1.41.5=h5eee18b_0
34
+ - matplotlib-inline=0.1.7=pyhd8ed1ab_0
35
+ - ncurses=6.4=h6a678d5_0
36
+ - nest-asyncio=1.6.0=pyhd8ed1ab_0
37
+ - openssl=3.3.2=hb9d3cd8_0
38
+ - packaging=24.1=pyhd8ed1ab_0
39
+ - parso=0.8.4=pyhd8ed1ab_0
40
+ - pexpect=4.9.0=pyhd8ed1ab_0
41
+ - pickleshare=0.7.5=py_1003
42
+ - pip=24.2=py312h06a4308_0
43
+ - prompt-toolkit=3.0.47=pyha770c72_0
44
+ - ptyprocess=0.7.0=pyhd3deb0d_0
45
+ - pure_eval=0.2.3=pyhd8ed1ab_0
46
+ - pygments=2.18.0=pyhd8ed1ab_0
47
+ - python=3.12.4=h5148396_1
48
+ - pyzmq=25.1.2=py312h6a678d5_0
49
+ - readline=8.2=h5eee18b_0
50
+ - setuptools=72.1.0=py312h06a4308_0
51
+ - six=1.16.0=pyh6c4a22f_0
52
+ - sqlite=3.45.3=h5eee18b_0
53
+ - stack_data=0.6.2=pyhd8ed1ab_0
54
+ - tk=8.6.14=h39e8969_0
55
+ - tornado=6.4.1=py312h5eee18b_0
56
+ - traitlets=5.14.3=pyhd8ed1ab_0
57
+ - typing_extensions=4.12.2=pyha770c72_0
58
+ - wcwidth=0.2.13=pyhd8ed1ab_0
59
+ - wheel=0.43.0=py312h06a4308_0
60
+ - xz=5.4.6=h5eee18b_1
61
+ - zeromq=4.3.5=ha4adb4c_5
62
+ - zipp=3.20.2=pyhd8ed1ab_0
63
+ - zlib=1.2.13=h5eee18b_1
64
+ - pip:
65
+ - aiohappyeyeballs==2.4.0
66
+ - aiohttp==3.10.5
67
+ - aiosignal==1.3.1
68
+ - antlr4-python3-runtime==4.9.3
69
+ - attrs==24.2.0
70
+ - black==24.10.0
71
+ - certifi==2024.8.30
72
+ - charset-normalizer==3.3.2
73
+ - click==8.1.7
74
+ - cloudpickle==3.0.0
75
+ - contourpy==1.3.0
76
+ - cycler==0.12.1
77
+ - docker-pycreds==0.4.0
78
+ - filelock==3.16.0
79
+ - fonttools==4.53.1
80
+ - frozenlist==1.4.1
81
+ - fsspec==2024.9.0
82
+ - gitdb==4.0.11
83
+ - gitpython==3.1.43
84
+ - h5py==3.11.0
85
+ - hydra-core==1.3.2
86
+ - hydra-submitit-launcher==1.2.0
87
+ - idna==3.8
88
+ - imageio==2.35.1
89
+ - iniconfig==2.0.0
90
+ - isort==5.13.2
91
+ - jinja2==3.1.4
92
+ - joblib==1.4.2
93
+ - kiwisolver==1.4.7
94
+ - lazy-loader==0.4
95
+ - lightning==2.4.0
96
+ - lightning-utilities==0.11.7
97
+ - llvmlite==0.43.0
98
+ - lmdb==1.5.1
99
+ - markupsafe==2.1.5
100
+ - matplotlib==3.9.2
101
+ - mpmath==1.3.0
102
+ - multidict==6.0.5
103
+ - mypy-extensions==1.0.0
104
+ - networkx==3.3
105
+ - no-med==0.0.0
106
+ - numba==0.60.0
107
+ - numpy==2.0.2
108
+ - nvidia-cublas-cu12==12.1.3.1
109
+ - nvidia-cuda-cupti-cu12==12.1.105
110
+ - nvidia-cuda-nvrtc-cu12==12.1.105
111
+ - nvidia-cuda-runtime-cu12==12.1.105
112
+ - nvidia-cudnn-cu12==9.1.0.70
113
+ - nvidia-cufft-cu12==11.0.2.54
114
+ - nvidia-curand-cu12==10.3.2.106
115
+ - nvidia-cusolver-cu12==11.4.5.107
116
+ - nvidia-cusparse-cu12==12.1.0.106
117
+ - nvidia-nccl-cu12==2.20.5
118
+ - nvidia-nvjitlink-cu12==12.6.68
119
+ - nvidia-nvtx-cu12==12.1.105
120
+ - omegaconf==2.3.0
121
+ - opencv-python==4.10.0.84
122
+ - pandas==2.2.2
123
+ - pathspec==0.12.1
124
+ - pillow==10.4.0
125
+ - platformdirs==4.3.2
126
+ - pluggy==1.5.0
127
+ - protobuf==5.28.0
128
+ - psutil==6.0.0
129
+ - pyparsing==3.1.4
130
+ - pytest==8.3.3
131
+ - python-dateutil==2.9.0.post0
132
+ - pytorch-lightning==2.4.0
133
+ - pytz==2024.1
134
+ - pywavelets==1.7.0
135
+ - pyyaml==6.0.2
136
+ - requests==2.32.3
137
+ - runstats==2.0.0
138
+ - scikit-image==0.24.0
139
+ - scipy==1.14.1
140
+ - sentry-sdk==2.13.0
141
+ - setproctitle==1.3.3
142
+ - sigpy==0.1.26
143
+ - smmap==5.0.1
144
+ - submitit==1.5.1
145
+ - sympy==1.13.2
146
+ - tabulate==0.9.0
147
+ - tifffile==2024.8.30
148
+ - toolz==1.0.0
149
+ - torch==2.4.1
150
+ - torchmetrics==1.4.1
151
+ - torchvision==0.19.1
152
+ - tqdm==4.66.5
153
+ - triton==3.0.0
154
+ - tzdata==2024.1
155
+ - urllib3==2.2.2
156
+ - wandb==0.17.9
157
+ - yarl==1.11.0
158
+ prefix: /global/homes/p/peterwg/local/miniconda3/envs/no-med
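Assuming a standard conda setup, this environment can be recreated with `conda env create -f environment.yml` and activated with `conda activate no-med`; the machine-specific `prefix:` line is ignored when the environment is created elsewhere.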
fastmri/__init__.py ADDED
@@ -0,0 +1,20 @@
+ """
+ Copyright (c) Facebook, Inc. and its affiliates.
+
+ This source code is licensed under the MIT license found in the
+ LICENSE file in the root directory of this source tree.
+ """
+
+ from .coil_combine import rss, rss_complex, mvue
+ from .fftc import fft2c_new as fft2c
+ from .fftc import fftshift
+ from .fftc import ifft2c_new as ifft2c
+ from .fftc import ifftshift, roll
+ from .losses import SSIMLoss
+ from .math_utils import (
+ complex_abs,
+ complex_abs_sq,
+ complex_conj,
+ complex_mul,
+ tensor_to_complex_np,
+ )
fastmri/coil_combine.py ADDED
@@ -0,0 +1,67 @@
+ """
+ Copyright (c) Facebook, Inc. and its affiliates.
+
+ This source code is licensed under the MIT license found in the
+ LICENSE file in the root directory of this source tree.
+ """
+
+ import torch
+
+ import fastmri
+ import sigpy as sp
+ import numpy as np
+
+
+ def rss(data: torch.Tensor, dim: int = 0) -> torch.Tensor:
+ """
+ Compute the Root Sum of Squares (RSS).
+
+ The RSS is computed assuming that `dim` is the coil dimension.
+
+ Parameters
+ ----------
+ data : torch.Tensor
+ The input tensor.
+ dim : int, optional
+ The dimension along which to apply the RSS transform (default is 0).
+
+ Returns
+ -------
+ torch.Tensor
+ The computed RSS value.
+ """
+ return torch.sqrt((data**2).sum(dim))
+
+
+ def mvue(spatial_pred, sens_maps, dim: int = 0) -> torch.Tensor:
+ """
+ Combine per-coil complex images into a single magnitude image using the
+ coil sensitivity maps (minimum-variance unbiased estimate, MVUE).
+
+ `dim` is the coil dimension; complex values are stored in the trailing
+ (real, imag) dimension, as elsewhere in this package.
+ """
+ spatial_pred = torch.view_as_complex(spatial_pred)
+ sens_maps = torch.view_as_complex(sens_maps)
+
+ numerator = torch.sum(spatial_pred * torch.conj(sens_maps), dim=dim)
+ denominator = torch.sqrt(
+ torch.sum(torch.square(torch.abs(sens_maps)), dim=dim)
+ )
+ res = numerator / denominator
+ res = torch.abs(res)
+ return res
+
+
+ def rss_complex(data: torch.Tensor, dim: int = 0) -> torch.Tensor:
+ """
+ Compute the Root Sum of Squares (RSS) for complex inputs.
+
+ The RSS is computed assuming that `dim` is the coil dimension.
+
+ Parameters
+ ----------
+ data : torch.Tensor
+ The input tensor containing complex values.
+ dim : int, optional
+ The dimension along which to apply the RSS transform (default is 0).
+
+ Returns
+ -------
+ torch.Tensor
+ The computed RSS value for complex inputs.
+ """
+ return torch.sqrt(fastmri.complex_abs_sq(data).sum(dim))
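A minimal usage sketch for the two RSS helpers, assuming the `fastmri` package from this commit is importable (shapes are illustrative):

```python
# Coil-combine synthetic multicoil data with the RSS helpers above.
import torch

import fastmri  # rss / rss_complex are re-exported in fastmri/__init__.py

coil_mags = torch.rand(15, 320, 320)               # real-valued magnitude per coil
combined = fastmri.rss(coil_mags, dim=0)           # -> (320, 320)

coil_complex = torch.rand(15, 320, 320, 2)         # complex stored as trailing (re, im)
combined_c = fastmri.rss_complex(coil_complex, dim=0)  # -> (320, 320)
print(combined.shape, combined_c.shape)
```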
fastmri/datasets.py ADDED
@@ -0,0 +1,583 @@
1
+ import random
2
+ import xml.etree.ElementTree as etree
3
+ from pathlib import Path
4
+ from typing import (
5
+ Any,
6
+ Callable,
7
+ Dict,
8
+ List,
9
+ Literal,
10
+ NamedTuple,
11
+ Optional,
12
+ Sequence,
13
+ Tuple,
14
+ Union,
15
+ )
16
+
17
+ import h5py
18
+ import lmdb
19
+ import numpy as np
20
+ import torch
21
+ import yaml
22
+ import sigpy as sp
23
+ import pandas as pd
24
+
25
+ import fastmri
26
+ import fastmri.transforms as T
27
+
28
+
29
+ class RawSample(NamedTuple):
30
+ fname: Path
31
+ slice_num: int
32
+ metadata: Dict[str, Any]
33
+
34
+
35
+ class SliceSample(NamedTuple):
36
+ masked_kspace: torch.Tensor
37
+ mask: torch.Tensor
38
+ num_low_frequencies: int
39
+ target: torch.Tensor
40
+ max_value: float
41
+ # attrs: Dict[str, Any]
42
+ fname: str
43
+ slice_num: int
44
+
45
+ class SliceSampleMVUE(NamedTuple):
46
+ masked_kspace: torch.Tensor
47
+ mask: torch.Tensor
48
+ num_low_frequencies: int
49
+ target: torch.Tensor
50
+ rss: torch.Tensor
51
+ max_value: float
52
+ # attrs: Dict[str, Any]
53
+ fname: str
54
+ slice_num: int
55
+
56
+ def et_query(
57
+ root: etree.Element,
58
+ qlist: Sequence[str],
59
+ namespace: str = "http://www.ismrm.org/ISMRMRD",
60
+ ) -> str:
61
+ """
62
+ Query an XML document using ElementTree.
63
+
64
+ This function allows querying an XML document by specifying a root and a list of nested queries.
65
+ It supports optional XML namespaces.
66
+
67
+ Parameters
68
+ ----------
69
+ root : ElementTree.Element
70
+ The root element of the XML to search through.
71
+ qlist : list of str
72
+ A list of strings for nested searches, e.g., ["Encoding", "matrixSize"].
73
+ namespace : str, optional
74
+ An optional XML namespace to prepend to the query (default is None).
75
+
76
+ Returns
77
+ -------
78
+ str
79
+ The retrieved data as a string.
80
+ """
81
+
82
+ s = "."
83
+ prefix = "ismrmrd_namespace"
84
+
85
+ ns = {prefix: namespace}
86
+
87
+ for el in qlist:
88
+ s = s + f"//{prefix}:{el}"
89
+
90
+ value = root.find(s, ns)
91
+ if value is None:
92
+ raise RuntimeError("Element not found")
93
+
94
+ return str(value.text)
95
+
96
+
97
+ class SliceDataset(torch.utils.data.Dataset):
98
+ """
99
+ A simplified PyTorch Dataset that provides access to multicoil MR image
100
+ slices from the fastMRI dataset.
101
+ """
102
+
103
+ def __init__(
104
+ self,
105
+ # root: Optional[Path | str],
106
+ body_part: Literal["knee", "brain"],
107
+ partition: Literal["train", "val", "test"],
108
+ mask_fns: Optional[List[Callable]] = None,
109
+ sample_rate: float = 1.0,
110
+ complex: bool = False,
111
+ crop_shape: Tuple[int, int] = (320, 320),
112
+ slug: str = "",
113
+ contrast: Optional[Literal["T1", "T2"]] = None,
114
+ coils: Optional[int] = None,
115
+ ):
116
+ """
117
+ Initializes the fastMRI multi-coil challenge dataset.
118
+
119
+ Samples are individual 2D slices taken from k-space volume data.
120
+
121
+ Parameters
122
+ ----------
123
+ body_part : {'knee', 'brain'}
124
+ The body part to analyze.
125
+ partition : {'train', 'val', 'test'}
126
+ The data partition type.
127
+ mask_fns : list of callable, optional
128
+ A list of masking functions to apply to samples.
129
+ If multiple are given, a mask is randomly chosen for each sample.
130
+ sample_rate : float, optional
131
+ Fraction of data to sample, by default 1.0.
132
+ complex : bool, optional
133
+ Whether the $k$-space data should return complex-valued, by default False.
134
+ If True, kspace values will be complex.
135
+ If False, kspace values will be real (shape, 2).
136
+ crop_shape : tuple of two ints, optional
137
+ The shape to center crop the k-space data, by default (320, 320).
138
+ slug : string
139
+ dataset slug name
140
+ contrast : {'T1', 'T2'}
141
+ If partition is brain, the contrast of images to use.
142
+ """
143
+
144
+ with open("fastmri.yaml", "r") as file:
145
+ config = yaml.safe_load(file)
146
+ self.contrast = contrast
147
+ self.slug = slug
148
+ self.partition = partition
149
+ self.body_part = body_part
150
+ self.root = (
151
+ Path(config.get(f"{body_part}_path")) / f"multicoil_{partition}"
152
+ )
153
+ self.mask_fns = mask_fns
154
+ self.sample_rate = sample_rate
155
+ self.raw_samples: List[RawSample] = self._load_samples()
156
+ self.complex = complex
157
+ self.crop_shape = crop_shape
158
+ self.coils = coils
159
+
160
+ def _load_samples(self):
161
+ # Gather all files in the root directory
162
+ if self.body_part == "brain" and self.contrast:
163
+ files = list(self.root.glob(f"*{self.contrast}*.h5"))
164
+ else:
165
+ files = list(self.root.glob("*.h5"))
166
+ raw_samples = []
167
+
168
+ # Load and process metadata from each file
169
+ for fname in sorted(files):
170
+ with h5py.File(fname, "r") as hf:
171
+ metadata, num_slices = self._retrieve_metadata(fname)
172
+
173
+ # Collect samples for each slice, discard first c slices, and last c slices
174
+ c = 6
175
+ for slice_num in range(num_slices):
176
+ if c <= slice_num <= num_slices - c - 1:
177
+ raw_samples.append(
178
+ RawSample(fname, slice_num, metadata)
179
+ )
180
+
181
+ # Subsample if desired
182
+ if self.sample_rate < 1.0:
183
+ raw_samples = random.sample(
184
+ raw_samples, int(len(raw_samples) * self.sample_rate)
185
+ )
186
+
187
+ return raw_samples
188
+
189
+ def _retrieve_metadata(self, fname):
190
+ with h5py.File(fname, "r") as hf:
191
+ et_root = etree.fromstring(hf["ismrmrd_header"][()])
192
+
193
+ enc = ["encoding", "encodedSpace", "matrixSize"]
194
+ enc_size = (
195
+ int(et_query(et_root, enc + ["x"])),
196
+ int(et_query(et_root, enc + ["y"])),
197
+ int(et_query(et_root, enc + ["z"])),
198
+ )
199
+ rec = ["encoding", "reconSpace", "matrixSize"]
200
+ recon_size = (
201
+ int(et_query(et_root, rec + ["x"])),
202
+ int(et_query(et_root, rec + ["y"])),
203
+ int(et_query(et_root, rec + ["z"])),
204
+ )
205
+
206
+ lims = ["encoding", "encodingLimits", "kspace_encoding_step_1"]
207
+ enc_limits_center = int(et_query(et_root, lims + ["center"]))
208
+ enc_limits_max = int(et_query(et_root, lims + ["maximum"])) + 1
209
+
210
+ padding_left = enc_size[1] // 2 - enc_limits_center
211
+ padding_right = padding_left + enc_limits_max
212
+
213
+ num_slices = hf["kspace"].shape[0]
214
+
215
+ metadata = {
216
+ "padding_left": padding_left,
217
+ "padding_right": padding_right,
218
+ "encoding_size": enc_size,
219
+ "recon_size": recon_size,
220
+ **hf.attrs,
221
+ }
222
+
223
+ return metadata, num_slices
224
+
225
+ def __len__(self):
226
+ return len(self.raw_samples)
227
+
228
+ def __getitem__(self, idx) -> SliceSample:
229
+ try:
230
+ raw_sample: RawSample = self.raw_samples[idx]
231
+ fname, slice_num, metadata = raw_sample
232
+
233
+ # load kspace and target
234
+ with h5py.File(fname, "r") as hf:
235
+ kspace = torch.tensor(hf["kspace"][()][slice_num])
236
+ if not self.complex:
237
+ kspace = torch.view_as_real(kspace)
238
+ if self.coils:
239
+ if kspace.shape[0] < self.coils:
240
+ return None
241
+ kspace = kspace[: self.coils, :, :, :]
242
+ target_key = (
243
+ "reconstruction_rss"
244
+ if self.partition in ["train", "val"]
245
+ else "reconstruction_esc"
246
+ )
247
+ target = hf.get(target_key, None)
248
+ if target is not None:
249
+ target = torch.tensor(target[()][slice_num])
250
+ if self.body_part == "brain":
251
+ target = T.center_crop(target, self.crop_shape)
252
+
253
+ # center crop to enable collating for batching
254
+ if self.complex:
255
+ # if complex, crop across dims: -2 and -1 (last 2)
256
+ raise NotImplementedError("Not implemented for complex native")
257
+ else:
258
+ # crop in image space, to not lose high-frequency information
259
+ image = fastmri.ifft2c(kspace)
260
+ image_cropped = T.complex_center_crop(image, self.crop_shape)
261
+ kspace = fastmri.fft2c(image_cropped)
262
+
263
+ # apply transform mask if there is one
264
+ if self.mask_fns:
265
+ # choose a random mask
266
+ mask_fn = random.choice(self.mask_fns)
267
+ kspace, mask, num_low_frequencies = T.apply_mask(
268
+ kspace,
269
+ mask_fn,
270
+ # seed=seed,
271
+ )
272
+ mask = mask.bool()
273
+ else:
274
+ mask = torch.ones_like(kspace, dtype=torch.bool)
275
+ num_low_frequencies = 0
276
+ sample = SliceSample(
277
+ kspace,
278
+ mask,
279
+ num_low_frequencies,
280
+ target,
281
+ metadata["max"],
282
+ fname.name,
283
+ slice_num,
284
+ )
285
+ return sample
286
+ except Exception:
287
+ return None
288
+
289
+
290
+ class SliceDatasetLMDB(torch.utils.data.Dataset):
291
+ """
292
+ A simplified PyTorch Dataset that provides access to multicoil MR image
293
+ slices from the fastMRI dataset. Loads from LMDB saved samples.
294
+ """
295
+
296
+ def __init__(
297
+ self,
298
+ body_part: Literal["knee", "brain"],
299
+ partition: Literal["train", "val", "test"],
300
+ root: Optional[Path | str] = None,
301
+ mask_fns: Optional[List[Callable]] = None,
302
+ sample_rate: float = 1.0,
303
+ complex: bool = False,
304
+ crop_shape: Tuple[int, int] = (320, 320),
305
+ slug: str = "",
306
+ coils: int = 15,
307
+ ):
308
+ """
309
+ Initializes the fastMRI multi-coil challenge dataset.
310
+
311
+ Samples are individual 2D slices taken from k-space volume data.
312
+
313
+ Parameters
314
+ ----------
315
+ body_part : {'knee', 'brain'}
316
+ The body part to analyze.
317
+ root : Path or str, optional
318
+ Root to lmdb dataset. If not provided, the root is automatically
319
+ loaded directly from fastmri.yaml config
320
+ partition : {'train', 'val', 'test'}
321
+ The data partition type.
322
+ mask_fns : list of callable, optional
323
+ A list of masking functions to apply to samples.
324
+ If multiple are given, a mask is randomly chosen for each sample.
325
+ sample_rate : float, optional
326
+ Fraction of data to sample, by default 1.0.
327
+ complex : bool, optional
328
+ Whether the $k$-space data should return complex-valued, by default False.
329
+ If True, kspace values will be complex.
330
+ If False, kspace values will be real (shape, 2).
331
+ crop_shape : tuple of two ints, optional
332
+ The shape to center crop the k-space data, by default (320, 320).
333
+ slug : string
334
+ dataset slug name
335
+ """
336
+
337
+ # set attrs
338
+ self.coils = coils
339
+ self.slug = slug
340
+ self.partition = partition
341
+ self.mask_fns = mask_fns
342
+ self.sample_rate = sample_rate
343
+ self.complex = complex
344
+ self.crop_shape = crop_shape
345
+
346
+ # load lmdb info
347
+ if root:
348
+ if isinstance(root, str):
349
+ root = Path(root)
350
+ assert root.exists(), "Provided root doesn't exist."
351
+ self.root = root
352
+ else:
353
+ with open("fastmri.yaml", "r") as file:
354
+ config = yaml.safe_load(file)
355
+ self.root = Path(config["lmdb"][f"{body_part}_{partition}_path"])
356
+ self.meta = np.load(self.root / "meta.npy")
357
+ self.kspace_env = lmdb.open(
358
+ str(self.root / "kspace"),
359
+ readonly=True,
360
+ lock=False,
361
+ create=False,
362
+ )
363
+ self.kspace_txn = self.kspace_env.begin(write=False)
364
+ self.rss_env = lmdb.open(
365
+ str(self.root / "rss"),
366
+ readonly=True,
367
+ lock=False,
368
+ create=False,
369
+ )
370
+ self.rss_txn = self.rss_env.begin(write=False)
371
+ self.length = self.kspace_txn.stat()["entries"]
372
+
373
+ def __len__(self):
374
+ return int(self.sample_rate * self.length)
375
+
376
+ def __getitem__(self, idx) -> SliceSample:
377
+ idx_key = str(idx).encode("utf-8")
378
+
379
+ # load sample data
380
+ kspace = torch.from_numpy(
381
+ np.frombuffer(self.kspace_txn.get(idx_key), dtype=np.float32)
382
+ .reshape(self.coils, 320, 320, 2)
383
+ .copy()
384
+ )
385
+ rss = torch.from_numpy(
386
+ np.frombuffer(self.rss_txn.get(idx_key), dtype=np.float32)
387
+ .reshape(320, 320)
388
+ .copy()
389
+ )
390
+
391
+ # crop in image space, to not lose high-frequency information
392
+ if self.crop_shape and self.crop_shape != (320, 320):
393
+ image = fastmri.ifft2c(kspace)
394
+ image_cropped = T.complex_center_crop(image, self.crop_shape)
395
+ kspace = fastmri.fft2c(image_cropped)
396
+ rss = T.center_crop(rss, self.crop_shape)
397
+
398
+ # load and apply mask
399
+ if self.mask_fns:
400
+ # choose a random mask
401
+ mask_fn = random.choice(self.mask_fns)
402
+ kspace, mask, num_low_frequencies = T.apply_mask(
403
+ kspace,
404
+ mask_fn, # type: ignore
405
+ )
406
+ mask = mask.bool()
407
+ else:
408
+ mask = torch.ones_like(kspace, dtype=torch.bool)
409
+ num_low_frequencies = 0
410
+
411
+ # load metadata
412
+ fname, slice_num, max_value = self.meta[idx]
413
+ fname = str(fname)
414
+ slice_num = int(slice_num)
415
+ max_value = float(max_value)
416
+
417
+ return SliceSample(
418
+ kspace,
419
+ mask,
420
+ num_low_frequencies,
421
+ rss,
422
+ max_value,
423
+ fname,
424
+ slice_num,
425
+ )
426
+
427
+
428
+ class SliceDatasetLMDB_MVUE(torch.utils.data.Dataset):
429
+ """
430
+ Loads from LMDB brain saved samples.
431
+
432
+ Modified to have MVUE targets
433
+ """
434
+
435
+ def __init__(
436
+ self,
437
+ root: Path | str,
438
+ mask_fns: Optional[List[Callable]] = None,
439
+ sample_rate: float = 1.0,
440
+ crop_shape: Tuple[int, int] = (320, 320),
441
+ slug: str = "",
442
+ coils: int = 15,
443
+ ):
444
+
445
+ # set attrs
446
+ self.coils = coils
447
+ self.slug = slug
448
+ self.mask_fns = mask_fns
449
+ self.sample_rate = sample_rate
450
+ self.complex = complex
451
+ self.crop_shape = crop_shape
452
+
453
+ # load lmdb info
454
+ if isinstance(root, str):
455
+ root = Path(root)
456
+ assert root.exists(), "Provided root doesn't exist."
457
+ self.root = root
458
+ self.meta = np.load(self.root / "meta.npy")
459
+ self.mapping = pd.read_csv("brain_mvue_map.csv")
460
+ self.kspace_env = lmdb.open(
461
+ str(self.root / "kspace"),
462
+ readonly=True,
463
+ lock=False,
464
+ create=False,
465
+ )
466
+ self.kspace_txn = self.kspace_env.begin(write=False)
467
+ self.rss_env = lmdb.open(
468
+ str(self.root / "rss"),
469
+ readonly=True,
470
+ lock=False,
471
+ create=False,
472
+ )
473
+ self.rss_txn = self.rss_env.begin(write=False)
474
+
475
+ # ray mvue dataset
476
+ self.mvue_env = lmdb.open(
477
+ str("/pscratch/sd/p/peterwg/datasets/raytemp"),
478
+ readonly=True,
479
+ lock=False,
480
+ create=False,
481
+ )
482
+ self.mvue_txn = self.mvue_env.begin(write=False)
483
+
484
+ self.length = len(self.mapping)
485
+ # self.length = self.kspace_txn.stat()["entries"]
486
+
487
+ def __len__(self):
488
+ return int(self.sample_rate * self.length)
489
+
490
+ def __getitem__(self, idx) -> SliceSampleMVUE:
491
+ # ray's index: 0-n
492
+ ray_idx = idx
493
+ # my index: lookup(ray index)
494
+ idx = int(self.mapping.iloc[ray_idx].my_index)
495
+ ray_idx_key = str(ray_idx).encode("utf-8")
496
+ idx_key = str(idx).encode("utf-8")
497
+
498
+ # load sample data
499
+ kspace = torch.from_numpy(
500
+ np.frombuffer(self.kspace_txn.get(idx_key), dtype=np.float32)
501
+ .reshape(self.coils, 320, 320, 2)
502
+ .copy()
503
+ )
504
+
505
+ # mvue_target = np.sum(
506
+ # sp.ifft(kspace, axes=(-1, -2)) * np.conj(s_maps), axis=1
507
+ # ) / np.sqrt(np.sum(np.square(np.abs(s_maps)), axis=1))
508
+ rss = torch.from_numpy(
509
+ np.frombuffer(self.rss_txn.get(idx_key), dtype=np.float32)
510
+ .reshape(320, 320)
511
+ .copy()
512
+ )
513
+
514
+ # load mvue from ray dataset
515
+ mvue = torch.from_numpy(
516
+ np.frombuffer(self.mvue_txn.get(ray_idx_key), dtype=np.complex64)
517
+ .reshape(320, 320)
518
+ .copy()
519
+ )
520
+ mvue = torch.abs(mvue)
521
+
522
+ # crop in image space, to not lose high-frequency information
523
+ if self.crop_shape and self.crop_shape != (320, 320):
524
+ image = fastmri.ifft2c(kspace)
525
+ image_cropped = T.complex_center_crop(image, self.crop_shape)
526
+ kspace = fastmri.fft2c(image_cropped)
527
+ rss = T.center_crop(rss, self.crop_shape)
528
+
529
+ # load and apply mask
530
+ if self.mask_fns:
531
+ # choose a random mask
532
+ mask_fn = random.choice(self.mask_fns)
533
+ kspace, mask, num_low_frequencies = T.apply_mask(
534
+ kspace,
535
+ mask_fn, # type: ignore
536
+ )
537
+ mask = mask.bool()
538
+ else:
539
+ mask = torch.ones_like(kspace, dtype=torch.bool)
540
+ num_low_frequencies = 0
541
+
542
+ # load metadata
543
+ fname, slice_num, max_value = self.meta[idx]
544
+ fname = str(fname)
545
+ slice_num = int(slice_num)
546
+ max_value = float(max_value)
547
+
548
+ return SliceSampleMVUE(
549
+ kspace,
550
+ mask,
551
+ num_low_frequencies,
552
+ mvue,
553
+ rss,
554
+ max_value,
555
+ fname,
556
+ slice_num,
557
+ )
558
+
559
+
560
+ # d = SliceDatasetLMDB("knee", "val", None, 1, True, (320, 320), "testdataset")
561
+ # print(len(d))
562
+ # breakpoint()
563
+
564
+ # ds = SuperSliceDatasetLMDB(
565
+ # "brain", # body_part
566
+ # "val", # partition
567
+ # None, # root
568
+ # None, # mask_fns
569
+ # 1.0, # sample_rate
570
+ # True, # complex
571
+ # (320, 320), # crop_shape
572
+ # "test-superres", # slug
573
+ # coils=16, # coils
574
+ # )
575
+ # breakpoint()
576
+
577
+ # d = SliceDataset("brain", "train", None, contrast="T2")
578
+
579
+
580
+ # # TESTING MVUE
581
+ # d = SliceDatasetLMDB_MVUE("/pscratch/sd/p/peterwg/datasets/mri_brain_train_lmdb", coils=16)
582
+ # x = d[0]
583
+ # d = SliceDatasetLMDB_MVUE("/pscratch/sd/p/peterwg/datasets/raytemp/", coils=16)
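As a usage sketch, `app.py` builds its validation splits roughly as follows; this assumes the dataset snapshot sits in `./dataset` and that `fastmri.subsample` (not part of this excerpt) provides `create_mask_for_mask_type`:

```python
# Build a 4x-undersampled knee validation dataset from the LMDB files
# (mirrors the construction in app.py; path and center fraction taken from there).
from fastmri.datasets import SliceDatasetLMDB
from fastmri.subsample import create_mask_for_mask_type

mask_4x = create_mask_for_mask_type("equispaced_fraction", [0.08], [4])
val_ds = SliceDatasetLMDB(
    "knee",
    partition="val",
    root="dataset",
    mask_fns=[mask_4x],
    complex=False,
    crop_shape=(320, 320),
    coils=15,
)
sample = val_ds[0]
print(sample.masked_kspace.shape, sample.target.shape)  # (15, 320, 320, 2), (320, 320)
```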
fastmri/evaluate.py ADDED
@@ -0,0 +1,174 @@
1
+ """
2
+ Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ import argparse
9
+ import pathlib
10
+ from argparse import ArgumentParser
11
+ from typing import Optional
12
+
13
+ import h5py
14
+ import numpy as np
15
+ from runstats import Statistics
16
+ from skimage.metrics import peak_signal_noise_ratio, structural_similarity
17
+
18
+ from fastmri import transforms
19
+
20
+
21
+ def mse(gt: np.ndarray, pred: np.ndarray) -> np.ndarray:
22
+ """Compute Mean Squared Error (MSE)"""
23
+ return np.mean((gt - pred) ** 2)
24
+
25
+
26
+ def nmse(gt: np.ndarray, pred: np.ndarray) -> np.ndarray:
27
+ """Compute Normalized Mean Squared Error (NMSE)"""
28
+ return np.array(np.linalg.norm(gt - pred) ** 2 / np.linalg.norm(gt) ** 2)
29
+
30
+
31
+ def psnr(
32
+ gt: np.ndarray, pred: np.ndarray, maxval: Optional[float] = None
33
+ ) -> np.ndarray:
34
+ """Compute Peak Signal to Noise Ratio metric (PSNR)"""
35
+ if maxval is None:
36
+ maxval = gt.max()
37
+ return peak_signal_noise_ratio(gt, pred, data_range=maxval)
38
+
39
+
40
+ def ssim(
41
+ gt: np.ndarray, pred: np.ndarray, maxval: Optional[float] = None
42
+ ) -> np.ndarray:
43
+ """Compute Structural Similarity Index Metric (SSIM)"""
44
+ if not gt.ndim == 3:
45
+ raise ValueError("Unexpected number of dimensions in ground truth.")
46
+ if not gt.ndim == pred.ndim:
47
+ raise ValueError("Ground truth dimensions does not match pred.")
48
+
49
+ maxval = gt.max() if maxval is None else maxval
50
+
51
+ ssim = np.array([0])
52
+ for slice_num in range(gt.shape[0]):
53
+ ssim = ssim + structural_similarity(
54
+ gt[slice_num], pred[slice_num], data_range=maxval
55
+ )
56
+
57
+ return ssim / gt.shape[0]
58
+
59
+
60
+ METRIC_FUNCS = dict(
61
+ MSE=mse,
62
+ NMSE=nmse,
63
+ PSNR=psnr,
64
+ SSIM=ssim,
65
+ )
66
+
67
+
68
+ class Metrics:
69
+ """
70
+ Maintains running statistics for a given collection of metrics.
71
+ """
72
+
73
+ def __init__(self, metric_funcs):
74
+ """
75
+ Parameters
76
+ ----------
77
+ metric_funcs : dict
78
+ A dictionary where the keys are metric names (as strings) and the values
79
+ are Python functions for evaluating the corresponding metrics.
80
+ """
81
+
82
+ self.metric_funcs = metric_funcs
+ self.metrics = {metric: Statistics() for metric in metric_funcs}
+
+ def push(self, target, recons):
+ # iterate the metric functions this instance was constructed with,
+ # not the module-level METRIC_FUNCS dict
+ for metric, func in self.metric_funcs.items():
+ self.metrics[metric].push(func(target, recons))
87
+
88
+ def means(self):
89
+ return {metric: stat.mean() for metric, stat in self.metrics.items()}
90
+
91
+ def stddevs(self):
92
+ return {metric: stat.stddev() for metric, stat in self.metrics.items()}
93
+
94
+ def __repr__(self):
95
+ means = self.means()
96
+ stddevs = self.stddevs()
97
+ metric_names = sorted(list(means))
98
+ return " ".join(
99
+ f"{name} = {means[name]:.4g} +/- {2 * stddevs[name]:.4g}"
100
+ for name in metric_names
101
+ )
102
+
103
+
104
+ def evaluate(args, recons_key):
105
+ metrics = Metrics(METRIC_FUNCS)
106
+
107
+ for tgt_file in args.target_path.iterdir():
108
+ with h5py.File(tgt_file, "r") as target, h5py.File(
109
+ args.predictions_path / tgt_file.name, "r"
110
+ ) as recons:
111
+ if args.acquisition and args.acquisition != target.attrs["acquisition"]:
112
+ continue
113
+
114
+ if args.acceleration and target.attrs["acceleration"] != args.acceleration:
115
+ continue
116
+
117
+ target = target[recons_key][()]
118
+ recons = recons["reconstruction"][()]
119
+ target = transforms.center_crop(
120
+ target, (target.shape[-1], target.shape[-1])
121
+ )
122
+ recons = transforms.center_crop(
123
+ recons, (target.shape[-1], target.shape[-1])
124
+ )
125
+ metrics.push(target, recons)
126
+
127
+ return metrics
128
+
129
+
130
+ if __name__ == "__main__":
131
+ parser = ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
132
+ parser.add_argument(
133
+ "--target-path",
134
+ type=pathlib.Path,
135
+ required=True,
136
+ help="Path to the ground truth data",
137
+ )
138
+ parser.add_argument(
139
+ "--predictions-path",
140
+ type=pathlib.Path,
141
+ required=True,
142
+ help="Path to reconstructions",
143
+ )
144
+ parser.add_argument(
145
+ "--challenge",
146
+ choices=["singlecoil", "multicoil"],
147
+ required=True,
148
+ help="Which challenge",
149
+ )
150
+ parser.add_argument("--acceleration", type=int, default=None)
151
+ parser.add_argument(
152
+ "--acquisition",
153
+ choices=[
154
+ "CORPD_FBK",
155
+ "CORPDFS_FBK",
156
+ "AXT1",
157
+ "AXT1PRE",
158
+ "AXT1POST",
159
+ "AXT2",
160
+ "AXFLAIR",
161
+ ],
162
+ default=None,
163
+ help=(
164
+ "If set, only volumes of the specified acquisition type are used "
165
+ "for evaluation. By default, all volumes are included."
166
+ ),
167
+ )
168
+ args = parser.parse_args()
169
+
170
+ recons_key = (
171
+ "reconstruction_rss" if args.challenge == "multicoil" else "reconstruction_esc"
172
+ )
173
+ metrics = evaluate(args, recons_key)
174
+ print(metrics)
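Beyond the CLI entry point, the metric functions can also be used directly; a quick sketch on synthetic volumes (random data, purely illustrative), assuming the full `fastmri` package — including `fastmri.transforms`, which this module imports — is on the path:

```python
# Push one synthetic ground-truth/reconstruction volume pair through the
# running-statistics helper defined above.
import numpy as np

from fastmri.evaluate import METRIC_FUNCS, Metrics

gt = np.random.rand(10, 320, 320).astype(np.float32)                 # slices x H x W
recon = (gt + 0.01 * np.random.randn(*gt.shape)).astype(np.float32)  # noisy "reconstruction"

metrics = Metrics(METRIC_FUNCS)
metrics.push(gt, recon)
print(metrics)  # MSE = ... NMSE = ... PSNR = ... SSIM = ...
```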
fastmri/fftc.py ADDED
@@ -0,0 +1,203 @@
1
+ """
2
+ Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ from typing import List, Optional
9
+
10
+ import torch
11
+ import torch.fft
12
+
13
+
14
+ def fft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
15
+ """
16
+ Apply a centered 2-dimensional Fast Fourier Transform (FFT).
17
+
18
+ Parameters
19
+ ----------
20
+ data : torch.Tensor
21
+ Complex-valued input data containing at least 3 dimensions.
22
+ Dimensions -3 and -2 are spatial dimensions, and dimension -1 has size 2.
23
+ All other dimensions are assumed to be batch dimensions.
24
+ norm : str
25
+ Normalization mode. Refer to `torch.fft.fft` for details on normalization options.
26
+
27
+ Returns
28
+ -------
29
+ torch.Tensor
30
+ The FFT of the input data.
31
+ """
32
+
33
+ if not data.shape[-1] == 2:
34
+ raise ValueError("Tensor does not have separate complex dim.")
35
+
36
+ data = ifftshift(data, dim=[-3, -2])
37
+ data = torch.view_as_real(
38
+ torch.fft.fftn( # type: ignore
39
+ torch.view_as_complex(data), dim=(-2, -1), norm=norm
40
+ )
41
+ )
42
+ data = fftshift(data, dim=[-3, -2])
43
+
44
+ return data
45
+
46
+
47
+ def ifft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
48
+ """
49
+ Apply a centered 2-dimensional Inverse Fast Fourier Transform (IFFT).
50
+
51
+ Parameters
52
+ ----------
53
+ data : torch.Tensor
54
+ Complex-valued input data containing at least 3 dimensions.
55
+ Dimensions -3 and -2 are spatial dimensions, and dimension -1 has size 2.
56
+ All other dimensions are assumed to be batch dimensions.
57
+ norm : str
58
+ Normalization mode. Refer to `torch.fft.ifft` for details on normalization options.
59
+
60
+ Returns
61
+ -------
62
+ torch.Tensor
63
+ The IFFT of the input data.
64
+ """
65
+
66
+ if not data.shape[-1] == 2:
67
+ raise ValueError("Tensor does not have separate complex dim.")
68
+
69
+ data = ifftshift(data, dim=[-3, -2])
70
+ data = torch.view_as_real(
71
+ torch.fft.ifftn( # type: ignore
72
+ torch.view_as_complex(data), dim=(-2, -1), norm=norm
73
+ )
74
+ )
75
+ data = fftshift(data, dim=[-3, -2])
76
+
77
+ return data
78
+
79
+
80
+ # Helper functions
81
+
82
+
83
+ def roll_one_dim(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
84
+ """
85
+ Roll a PyTorch tensor along a specified dimension.
86
+
87
+ This function is similar to `torch.roll` but operates on a single dimension.
88
+
89
+ Parameters
90
+ ----------
91
+ x : torch.Tensor
92
+ The input tensor to be rolled.
93
+ shift : int
94
+ Amount to roll.
95
+ dim : int
96
+ The dimension along which to roll the tensor.
97
+
98
+ Returns
99
+ -------
100
+ torch.Tensor
101
+ A tensor with the same shape as `x`, but rolled along the specified dimension.
102
+ """
103
+
104
+ shift = shift % x.size(dim)
105
+ if shift == 0:
106
+ return x
107
+
108
+ left = x.narrow(dim, 0, x.size(dim) - shift)
109
+ right = x.narrow(dim, x.size(dim) - shift, shift)
110
+
111
+ return torch.cat((right, left), dim=dim)
112
+
113
+
114
+ def roll(
115
+ x: torch.Tensor,
116
+ shift: List[int],
117
+ dim: List[int],
118
+ ) -> torch.Tensor:
119
+ """
120
+ Similar to np.roll but applies to PyTorch Tensors.
121
+
122
+ Parameters
123
+ ----------
124
+ x : torch.Tensor
125
+ A PyTorch tensor.
126
+ shift : int
127
+ Amount to roll.
128
+ dim : int
129
+ Which dimension to roll.
130
+
131
+ Returns
132
+ -------
133
+ torch.Tensor
134
+ Rolled version of x.
135
+ """
136
+
137
+ if len(shift) != len(dim):
138
+ raise ValueError("len(shift) must match len(dim)")
139
+
140
+ for s, d in zip(shift, dim):
141
+ x = roll_one_dim(x, s, d)
142
+
143
+ return x
144
+
145
+
146
+ def fftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
147
+ """
148
+ Similar to np.fft.fftshift but applies to PyTorch Tensors.
149
+
150
+ Parameters
151
+ ----------
152
+ x : torch.Tensor
153
+ A PyTorch tensor.
154
+ dim : list of int, optional
155
+ Which dimension to apply fftshift. If None, the shift is applied to all dimensions (default is None).
156
+
157
+ Returns
158
+ -------
159
+ torch.Tensor
160
+ fftshifted version of x.
161
+ """
162
+ if dim is None:
163
+ # this weird code is necessary for torch.jit.script typing
164
+ dim = [0] * (x.dim())
165
+ for i in range(1, x.dim()):
166
+ dim[i] = i
167
+
168
+ # also necessary for torch.jit.script
169
+ shift = [0] * len(dim)
170
+ for i, dim_num in enumerate(dim):
171
+ shift[i] = x.shape[dim_num] // 2
172
+
173
+ return roll(x, shift, dim)
174
+
175
+
176
+ def ifftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
177
+ """
178
+ Similar to np.fft.ifftshift but applies to PyTorch Tensors.
179
+
180
+ Parameters
181
+ ----------
182
+ x : torch.Tensor
183
+ A PyTorch tensor.
184
+ dim : list of int, optional
185
+ Which dimension to apply ifftshift. If None, the shift is applied to all dimensions (default is None).
186
+
187
+ Returns
188
+ -------
189
+ torch.Tensor
190
+ ifftshifted version of x.
191
+ """
192
+ if dim is None:
193
+ # this weird code is necessary for torch.jit.script typing
194
+ dim = [0] * (x.dim())
195
+ for i in range(1, x.dim()):
196
+ dim[i] = i
197
+
198
+ # also necessary for torch.jit.script
199
+ shift = [0] * len(dim)
200
+ for i, dim_num in enumerate(dim):
201
+ shift[i] = (x.shape[dim_num] + 1) // 2
202
+
203
+ return roll(x, shift, dim)
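A round-trip sketch for the centered FFT pair above (complex data stored as a trailing dimension of size 2, as in the rest of the package):

```python
# fft2c / ifft2c round-trip on synthetic multicoil data.
import torch

import fastmri  # fft2c_new / ifft2c_new are re-exported as fft2c / ifft2c

image = torch.randn(15, 320, 320, 2)   # coils x H x W x (re, im)
kspace = fastmri.fft2c(image)
recovered = fastmri.ifft2c(kspace)
print(torch.allclose(image, recovered, atol=1e-5))  # True up to float32 error
```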
fastmri/losses.py ADDED
@@ -0,0 +1,91 @@
1
+ """
2
+ Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ class SSIMLoss(nn.Module):
14
+ """
15
+ SSIM loss module.
16
+ """
17
+
18
+ def __init__(self, win_size: int = 7, k1: float = 0.01, k2: float = 0.03):
19
+ """
20
+ Initialize the Losses class.
21
+
22
+ Parameters
23
+ ----------
24
+ win_size : int, optional
25
+ Window size for SSIM calculation.
26
+ k1 : float, optional
27
+ k1 parameter for SSIM calculation.
28
+ k2 : float, optional
29
+ k2 parameter for SSIM calculation.
30
+ """
31
+ super().__init__()
32
+ self.win_size = win_size
33
+ self.k1, self.k2 = k1, k2
34
+ self.register_buffer("w", torch.ones(1, 1, win_size, win_size) / win_size**2)
35
+ NP = win_size**2
36
+ self.cov_norm = NP / (NP - 1)
37
+
38
+ def forward(
39
+ self,
40
+ X: torch.Tensor,
41
+ Y: torch.Tensor,
42
+ data_range: torch.Tensor,
43
+ reduced: bool = True,
44
+ ):
45
+ assert isinstance(self.w, torch.Tensor)
46
+
47
+ data_range = data_range[:, None, None, None].to(X.device)
48
+ C1 = (self.k1 * data_range) ** 2
49
+ C2 = (self.k2 * data_range) ** 2
50
+
51
+ # Compute means
52
+ ux = F.conv2d(X, self.w)
53
+ uy = F.conv2d(Y, self.w)
54
+
55
+ # Compute variances
56
+ uxx = F.conv2d(X * X, self.w)
57
+ uyy = F.conv2d(Y * Y, self.w)
58
+ uxy = F.conv2d(X * Y, self.w)
59
+
60
+ # Compute covariances
61
+ vx = self.cov_norm * (uxx - ux * ux)
62
+ vy = self.cov_norm * (uyy - uy * uy)
63
+ vxy = self.cov_norm * (uxy - ux * uy)
64
+
65
+ # Compute SSIM components
66
+ A1, A2 = 2 * ux * uy + C1, 2 * vxy + C2
67
+ B1, B2 = ux**2 + uy**2 + C1, vx + vy + C2
68
+ D = B1 * B2
69
+ S = (A1 * A2) / D
70
+
71
+ if reduced:
72
+ return 1 - S.mean()
73
+ else:
74
+ return 1 - S
75
+
76
+
77
+ if __name__ == "__main__":
78
+ # Example usage
79
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
80
+
81
+ # Create the SSIMLoss module and move it to the GPU
82
+ ssim_loss = SSIMLoss().to(device)
83
+
84
+ # Create example tensors and move them to the GPU
85
+ X = torch.randn(4, 1, 256, 256).to(device)
86
+ Y = torch.randn(4, 1, 256, 256).to(device)
87
+ data_range = torch.rand(4).to(device)
88
+
89
+ # Compute the loss
90
+ loss = ssim_loss(X, Y, data_range)
91
+ print(loss)
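
One extra note on the example above: passing reduced=False returns the per-pixel (1 - SSIM) map instead of a scalar, which is useful for visualizing where a reconstruction disagrees with the target. A small sketch continuing the example:

# Per-pixel loss map (sketch; reuses X, Y, data_range from the example above)
loss_map = ssim_loss(X, Y, data_range, reduced=False)
# valid convolution with a 7x7 window trims (win_size - 1) pixels per spatial dim
print(loss_map.shape)  # -> torch.Size([4, 1, 250, 250]) for 256x256 inputs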
fastmri/math_utils.py ADDED
@@ -0,0 +1,121 @@
1
+ """
2
+ Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
+ def complex_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
13
+ """
14
+ Complex multiplication.
15
+
16
+ Multiplies two complex tensors assuming that they are both stored as
17
+ real arrays with the last dimension being the complex dimension.
18
+
19
+ Parameters
20
+ ----------
21
+ x : torch.Tensor
22
+ A PyTorch tensor with the last dimension of size 2.
23
+ y : torch.Tensor
24
+ A PyTorch tensor with the last dimension of size 2.
25
+
26
+ Returns
27
+ -------
28
+ torch.Tensor
29
+ A PyTorch tensor with the last dimension of size 2, representing
30
+ the result of the complex multiplication.
31
+ """
32
+ if not x.shape[-1] == y.shape[-1] == 2:
33
+ raise ValueError("Tensors do not have separate complex dim.")
34
+
35
+ re = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1]
36
+ im = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0]
37
+
38
+ return torch.stack((re, im), dim=-1)
39
+
40
+
41
+ def complex_conj(x: torch.Tensor) -> torch.Tensor:
42
+ """
43
+ Complex conjugate.
44
+
45
+ Applies the complex conjugate assuming that the input array has the
46
+ last dimension as the complex dimension.
47
+
48
+ Parameters
49
+ ----------
50
+ x : torch.Tensor
51
+ A PyTorch tensor with the last dimension of size 2.
52
+
53
+ Returns
54
+ -------
55
+ torch.Tensor
56
+ A PyTorch tensor with the last dimension of size 2, representing
57
+ the complex conjugate of the input tensor.
58
+ """
59
+ if not x.shape[-1] == 2:
60
+ raise ValueError("Tensor does not have separate complex dim.")
61
+
62
+ return torch.stack((x[..., 0], -x[..., 1]), dim=-1)
63
+
64
+
65
+ def complex_abs(data: torch.Tensor) -> torch.Tensor:
66
+ """
67
+ Compute the absolute value of a complex-valued input tensor.
68
+
69
+ Parameters
70
+ ----------
71
+ data : torch.Tensor
72
+ A complex-valued tensor, where the size of the final dimension
73
+ should be 2.
74
+
75
+ Returns
76
+ -------
77
+ torch.Tensor
78
+ Absolute value of the input tensor.
79
+ """
80
+ if not data.shape[-1] == 2:
81
+ raise ValueError("Tensor does not have separate complex dim.")
82
+
83
+ return (data**2).sum(dim=-1).sqrt()
84
+
85
+
86
+ def complex_abs_sq(data: torch.Tensor) -> torch.Tensor:
87
+ """
88
+ Compute the squared absolute value of a complex tensor.
89
+
90
+ Parameters
91
+ ----------
92
+ data : torch.Tensor
93
+ A complex-valued tensor, where the size of the final dimension
94
+ should be 2.
95
+
96
+ Returns
97
+ -------
98
+ torch.Tensor
99
+ Squared absolute value of the input tensor.
100
+ """
101
+ if not data.shape[-1] == 2:
102
+ raise ValueError("Tensor does not have separate complex dim.")
103
+
104
+ return (data**2).sum(dim=-1)
105
+
106
+
107
+ def tensor_to_complex_np(data: torch.Tensor) -> np.ndarray:
108
+ """
109
+ Convert a complex PyTorch tensor to a NumPy array.
110
+
111
+ Parameters
112
+ ----------
113
+ data : torch.Tensor
114
+ Input data to be converted to a NumPy array.
115
+
116
+ Returns
117
+ -------
118
+ np.ndarray
119
+ A complex NumPy array version of the input tensor.
120
+ """
121
+ return torch.view_as_complex(data).numpy()
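
These helpers operate on real tensors with a trailing complex dimension of size 2; a quick consistency check against torch's native complex dtype (a sketch, assuming the module imports as fastmri.math_utils) is:

import torch

from fastmri.math_utils import complex_abs, complex_conj, complex_mul

a = torch.view_as_real(torch.randn(5, dtype=torch.complex64))  # shape (5, 2)
b = torch.view_as_real(torch.randn(5, dtype=torch.complex64))

# a * conj(b) computed with the real-valued helpers vs. native complex math
prod = complex_mul(a, complex_conj(b))
ref = torch.view_as_complex(a) * torch.view_as_complex(b).conj()
assert torch.allclose(torch.view_as_complex(prod), ref, atol=1e-6)
assert torch.allclose(complex_abs(a), torch.view_as_complex(a).abs(), atol=1e-6)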
fastmri/poisson_cache/poisson_16x.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22199ef1c9b045b6f747e57c4effa9a6667cfce864773ee035df8c3d2a28138f
3
+ size 819328
fastmri/poisson_cache/poisson_2x.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d6908dbd90fda83085cc9e7d3ab35f7ed215ab3f735fd39023a091a9f1632df
3
+ size 819328
fastmri/poisson_cache/poisson_32x.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e316852431f100b1c5c6749b672dfbf2dffc4f23ba4d118eddc64956b8c22f4
3
+ size 819328
fastmri/poisson_cache/poisson_4x.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c9f0c9b2c3be534b7c94b8398c2aa66c8e634d343e9b45b842450137266cbc8
3
+ size 819328
fastmri/poisson_cache/poisson_6x.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f7eeed6d470af6ef2b7594da388ea952dad44e74a33080a1bf59faf9e8973ca8
3
+ size 819328
fastmri/poisson_cache/poisson_8x.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:28ce774dc182798a3fd4358c70cb86c81f936f12336fee104c93e89765727462
3
+ size 819328
fastmri/subsample.py ADDED
@@ -0,0 +1,818 @@
1
+ """
2
+ Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ import os
9
+ from typing import Dict, Optional, Sequence, Tuple, Union
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.distributions as D
15
+ from sigpy.mri import poisson, radial, spiral
16
+
17
+
18
+ class MaskFunc:
19
+ """
20
+ An object for GRAPPA-style sampling masks.
21
+
22
+ This creates a sampling mask that densely samples the center while
23
+ subsampling outer k-space regions based on the undersampling factor.
24
+
25
+ When called, ``MaskFunc`` uses internal functions to create the mask by 1)
26
+ creating a mask for the k-space center, 2) creating a mask outside of the
27
+ k-space center, and 3) combining them into a total mask. The internals are
28
+ handled by ``sample_mask``, which calls ``calculate_center_mask`` for (1)
29
+ and ``calculate_acceleration_mask`` for (2). The combination is executed
30
+ in the ``MaskFunc`` ``__call__`` function.
31
+
32
+ If you would like to implement a new mask, simply subclass ``MaskFunc``
33
+ and overwrite the ``sample_mask`` logic. See examples in ``RandomMaskFunc``
34
+ and ``EquispacedMaskFunc``.
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ center_fractions: Sequence[float],
40
+ accelerations: Sequence[int],
41
+ allow_any_combination: bool = False,
42
+ seed: Optional[int] = None,
43
+ ):
44
+ """
45
+ Args:
46
+ center_fractions: Fraction of low-frequency columns to be retained.
47
+ If multiple values are provided, then one of these numbers is
48
+ chosen uniformly each time.
49
+ accelerations: Amount of under-sampling. This should have the same
50
+ length as center_fractions. If multiple values are provided,
51
+ then one of these is chosen uniformly each time.
52
+ allow_any_combination: Whether to allow cross combinations of
53
+ elements from ``center_fractions`` and ``accelerations``.
54
+ seed: Seed for starting the internal random number generator of the
55
+ ``MaskFunc``.
56
+ """
57
+ if (
58
+ len(center_fractions) != len(accelerations)
59
+ and not allow_any_combination
60
+ ):
61
+ raise ValueError(
62
+ "Number of center fractions should match number of"
63
+ " accelerations if allow_any_combination is False."
64
+ )
65
+
66
+ self.center_fractions = center_fractions
67
+ self.accelerations = accelerations
68
+ self.allow_any_combination = allow_any_combination
69
+ self.rng = np.random.RandomState(seed)
70
+
71
+ def __call__(
72
+ self,
73
+ shape: Sequence[int],
74
+ offset: Optional[int] = None,
75
+ seed: Optional[Union[int, Tuple[int, ...]]] = None,
76
+ ) -> Tuple[torch.Tensor, int]:
77
+ """
78
+ Sample and return a k-space mask.
79
+
80
+ Args:
81
+ shape: Shape of k-space.
82
+ offset: Offset from 0 to begin mask (for equispaced masks). If no
83
+ offset is given, then one is selected randomly.
84
+ seed: Seed for random number generator for reproducibility.
85
+
86
+ Returns:
87
+ A 2-tuple containing 1) the k-space mask and 2) the number of
88
+ center frequency lines.
89
+ """
90
+ if len(shape) < 3:
91
+ raise ValueError("Shape should have 3 or more dimensions")
92
+
93
+ center_mask, accel_mask, num_low_frequencies = self.sample_mask(
94
+ shape, offset
95
+ )
96
+ # combine masks together
97
+ return torch.max(center_mask, accel_mask), num_low_frequencies
98
+
99
+ def sample_mask(
100
+ self,
101
+ shape: Sequence[int],
102
+ offset: Optional[int],
103
+ ) -> Tuple[torch.Tensor, torch.Tensor, int]:
104
+ """
105
+ Sample a new k-space mask.
106
+
107
+ This function samples and returns two components of a k-space mask: 1)
108
+ the center mask (e.g., for sensitivity map calculation) and 2) the
109
+ acceleration mask (for the edge of k-space). Both of these masks, as
110
+ well as the integer of low frequency samples, are returned.
111
+
112
+ Args:
113
+ shape: Shape of the k-space to subsample.
114
+ offset: Offset from 0 to begin mask (for equispaced masks).
115
+
116
+ Returns:
117
+ A 3-tuple containing 1) the mask for the center of k-space, 2) the
118
+ mask for the high frequencies of k-space, and 3) the integer count
119
+ of low frequency samples.
120
+ """
121
+ num_cols = shape[-2]
122
+ center_fraction, acceleration = self.choose_acceleration()
123
+ num_low_frequencies = round(num_cols * center_fraction)
124
+ center_mask = self.reshape_mask(
125
+ self.calculate_center_mask(shape, num_low_frequencies), shape
126
+ )
127
+ acceleration_mask = self.reshape_mask(
128
+ self.calculate_acceleration_mask(
129
+ num_cols, acceleration, offset, num_low_frequencies
130
+ ),
131
+ shape,
132
+ )
133
+ return center_mask, acceleration_mask, num_low_frequencies
134
+
135
+ def reshape_mask(
136
+ self, mask: torch.Tensor, shape: Sequence[int]
137
+ ) -> torch.Tensor:
138
+ """Reshape mask to desired output shape."""
139
+ if len(mask.shape) == 1:
140
+ mask = torch.as_tensor(mask, dtype=torch.float32)
141
+ mask_num_freqs = len(mask)
142
+ mask = mask.reshape(1, 1, mask_num_freqs, 1)
143
+
144
+ return mask.expand(shape)
145
+
146
+ def reshape_mask_old(
147
+ self, mask: np.ndarray, shape: Sequence[int]
148
+ ) -> torch.Tensor:
149
+ """Reshape mask to desired output shape."""
150
+ num_cols = shape[-2]
151
+ mask_shape = [1 for s in shape]
152
+ mask_shape[-2] = num_cols
153
+
154
+ return torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32))
155
+
156
+ def calculate_acceleration_mask(
157
+ self,
158
+ num_cols: int,
159
+ acceleration: int,
160
+ offset: Optional[int],
161
+ num_low_frequencies: int,
162
+ ) -> np.ndarray:
163
+ """
164
+ Produce mask for non-central acceleration lines.
165
+
166
+ Args:
167
+ num_cols: Number of columns of k-space (2D subsampling).
168
+ acceleration: Desired acceleration rate.
169
+ offset: Offset from 0 to begin masking (for equispaced masks).
170
+ num_low_frequencies: Integer count of low-frequency lines sampled.
171
+
172
+ Returns:
173
+ A mask for the high spatial frequencies of k-space.
174
+ """
175
+ raise NotImplementedError
176
+
177
+ def calculate_center_mask(
178
+ self, shape: Sequence[int], num_low_freqs: int
179
+ ) -> np.ndarray:
180
+ """
181
+ Build center mask based on number of low frequencies.
182
+
183
+ Args:
184
+ shape: Shape of k-space to mask.
185
+ num_low_freqs: Number of low-frequency lines to sample.
186
+
187
+ Returns:
188
+ A mask for the low spatial frequencies of k-space.
189
+ """
190
+ num_cols = shape[-2]
191
+ mask = np.zeros(num_cols, dtype=np.float32)
192
+ pad = (num_cols - num_low_freqs + 1) // 2
193
+ mask[pad : pad + num_low_freqs] = 1
194
+ assert mask.sum() == num_low_freqs
195
+
196
+ return mask
197
+
198
+ def choose_acceleration(self):
199
+ """Choose acceleration based on class parameters."""
200
+ if self.allow_any_combination:
201
+ return self.rng.choice(self.center_fractions), self.rng.choice(
202
+ self.accelerations
203
+ )
204
+ else:
205
+ choice = self.rng.randint(len(self.center_fractions))
206
+ return self.center_fractions[choice], self.accelerations[choice]
207
+
208
+
209
+ class RandomMaskFunc(MaskFunc):
210
+ """
211
+ Creates a random sub-sampling mask of a given shape.
212
+
213
+ The mask selects a subset of columns from the input k-space data. If the
214
+ k-space data has N columns, the mask picks out:
215
+ 1. N_low_freqs = (N * center_fraction) columns in the center
216
+ corresponding to low-frequencies.
217
+ 2. The other columns are selected uniformly at random with a
218
+ probability equal to: prob = (N / acceleration - N_low_freqs) /
219
+ (N - N_low_freqs). This ensures that the expected number of columns
220
+ selected is equal to (N / acceleration).
221
+
222
+ It is possible to use multiple center_fractions and accelerations, in which
223
+ case one possible (center_fraction, acceleration) is chosen uniformly at
224
+ random each time the ``RandomMaskFunc`` object is called.
225
+
226
+ For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04],
227
+ then there is a 50% probability that 4-fold acceleration with 8% center
228
+ fraction is selected and a 50% probability that 8-fold acceleration with 4%
229
+ center fraction is selected.
230
+ """
231
+
232
+ def calculate_acceleration_mask(
233
+ self,
234
+ num_cols: int,
235
+ acceleration: int,
236
+ offset: Optional[int],
237
+ num_low_frequencies: int,
238
+ ) -> np.ndarray:
239
+ prob = (num_cols / acceleration - num_low_frequencies) / (
240
+ num_cols - num_low_frequencies
241
+ )
242
+
243
+ return self.rng.uniform(size=num_cols) < prob
244
+
245
+
246
+ class EquiSpacedMaskFunc(MaskFunc):
247
+ """
248
+ Sample data with equally-spaced k-space lines.
249
+
250
+ The lines are spaced exactly evenly, as is done in standard GRAPPA-style
251
+ acquisitions. This means that with a densely-sampled center,
252
+ ``acceleration`` will be greater than the true acceleration rate.
253
+ """
254
+
255
+ def calculate_acceleration_mask(
256
+ self,
257
+ num_cols: int,
258
+ acceleration: int,
259
+ offset: Optional[int],
260
+ num_low_frequencies: int,
261
+ ) -> np.ndarray:
262
+ """
263
+ Produce mask for non-central acceleration lines.
264
+
265
+ Args:
266
+ num_cols: Number of columns of k-space (2D subsampling).
267
+ acceleration: Desired acceleration rate.
268
+ offset: Offset from 0 to begin masking. If no offset is specified,
269
+ then one is selected randomly.
270
+ num_low_frequencies: Not used.
271
+
272
+ Returns:
273
+ A mask for the high spatial frequencies of k-space.
274
+ """
275
+ if offset is None:
276
+ offset = self.rng.randint(0, high=round(acceleration))
277
+
278
+ mask = np.zeros(num_cols, dtype=np.float32)
279
+ mask[offset::acceleration] = 1
280
+
281
+ return mask
282
+
283
+
284
+ class EquispacedMaskFractionFunc(MaskFunc):
285
+ """
286
+ Equispaced mask with approximate acceleration matching.
287
+
288
+ The mask selects a subset of columns from the input k-space data. If the
289
+ k-space data has N columns, the mask picks out:
290
+ 1. N_low_freqs = (N * center_fraction) columns in the center
291
+ corresponding to low-frequencies.
292
+ 2. The other columns are selected with equal spacing at a proportion
293
+ that reaches the desired acceleration rate taking into consideration
294
+ the number of low frequencies. This ensures that the expected number
295
+ of columns selected is equal to (N / acceleration)
296
+
297
+ It is possible to use multiple center_fractions and accelerations, in which
298
+ case one possible (center_fraction, acceleration) is chosen uniformly at
299
+ random each time the EquispacedMaskFunc object is called.
300
+
301
+ Note that this function may not give equispaced samples (documented in
302
+ https://github.com/facebookresearch/fastMRI/issues/54), which will require
303
+ modifications to standard GRAPPA approaches. Nonetheless, this aspect of
304
+ the function has been preserved to match the public multicoil data.
305
+ """
306
+
307
+ def calculate_acceleration_mask(
308
+ self,
309
+ num_cols: int,
310
+ acceleration: int,
311
+ offset: Optional[int],
312
+ num_low_frequencies: int,
313
+ ) -> np.ndarray:
314
+ """
315
+ Produce mask for non-central acceleration lines.
316
+
317
+ Args:
318
+ num_cols: Number of columns of k-space (2D subsampling).
319
+ acceleration: Desired acceleration rate.
320
+ offset: Offset from 0 to begin masking. If no offset is specified,
321
+ then one is selected randomly.
322
+ num_low_frequencies: Number of low frequencies. Used to adjust mask
323
+ to exactly match the target acceleration.
324
+
325
+ Returns:
326
+ A mask for the high spatial frequencies of k-space.
327
+ """
328
+ # determine acceleration rate by adjusting for the number of low frequencies
329
+ adjusted_accel = (acceleration * (num_low_frequencies - num_cols)) / (
330
+ num_low_frequencies * acceleration - num_cols
331
+ )
332
+ if offset is None:
333
+ offset = self.rng.randint(0, high=round(adjusted_accel))
334
+
335
+ mask = np.zeros(num_cols, dtype=np.float32)
336
+ accel_samples = np.arange(offset, num_cols - 1, adjusted_accel)
337
+ accel_samples = np.around(accel_samples).astype(np.uint)
338
+ mask[accel_samples] = 1.0
339
+
340
+ return mask
341
+
342
+
343
+ class MagicMaskFunc(MaskFunc):
344
+ """
345
+ Masking function for exploiting conjugate symmetry via offset-sampling.
346
+
347
+ This function applies the mask described in the following paper:
348
+
349
+ Defazio, A. (2019). Offset Sampling Improves Deep Learning based
350
+ Accelerated MRI Reconstructions by Exploiting Symmetry. arXiv preprint,
351
+ arXiv:1912.01101.
352
+
353
+ It is essentially an equispaced mask with an offset for the opposite side
354
+ of k-space. Since MRI images often exhibit approximate conjugate k-space
355
+ symmetry, this mask is generally more efficient than a standard equispaced
356
+ mask.
357
+
358
+ Similarly to ``EquispacedMaskFunc``, this mask will usually undershoot the
359
+ target acceleration rate.
360
+ """
361
+
362
+ def calculate_acceleration_mask(
363
+ self,
364
+ num_cols: int,
365
+ acceleration: int,
366
+ offset: Optional[int],
367
+ num_low_frequencies: int,
368
+ ) -> np.ndarray:
369
+ """
370
+ Produce mask for non-central acceleration lines.
371
+
372
+ Args:
373
+ num_cols: Number of columns of k-space (2D subsampling).
374
+ acceleration: Desired acceleration rate.
375
+ offset: Offset from 0 to begin masking. If no offset is specified,
376
+ then one is selected randomly.
377
+ num_low_frequencies: Not used.
378
+
379
+ Returns:
380
+ A mask for the high spatial frequencies of k-space.
381
+ """
382
+ if offset is None:
383
+ offset = self.rng.randint(0, high=acceleration)
384
+
385
+ if offset % 2 == 0:
386
+ offset_pos = offset + 1
387
+ offset_neg = offset + 2
388
+ else:
389
+ offset_pos = offset - 1 + 3
390
+ offset_neg = offset - 1 + 0
391
+
392
+ poslen = (num_cols + 1) // 2
393
+ neglen = num_cols - (num_cols + 1) // 2
394
+ mask_positive = np.zeros(poslen, dtype=np.float32)
395
+ mask_negative = np.zeros(neglen, dtype=np.float32)
396
+
397
+ mask_positive[offset_pos::acceleration] = 1
398
+ mask_negative[offset_neg::acceleration] = 1
399
+ mask_negative = np.flip(mask_negative)
400
+
401
+ mask = np.concatenate((mask_positive, mask_negative))
402
+
403
+ return np.fft.fftshift(mask) # shift mask and return
404
+
405
+
406
+ class MagicMaskFractionFunc(MagicMaskFunc):
407
+ """
408
+ Masking function for exploiting conjugate symmetry via offset-sampling.
409
+
410
+ This function applies the mask described in the following paper:
411
+
412
+ Defazio, A. (2019). Offset Sampling Improves Deep Learning based
413
+ Accelerated MRI Reconstructions by Exploiting Symmetry. arXiv preprint,
414
+ arXiv:1912.01101.
415
+
416
+ It is essentially an equispaced mask with an offset for the opposite side
417
+ of k-space. Since MRI images often exhibit approximate conjugate k-space
418
+ symmetry, this mask is generally more efficient than a standard equispaced
419
+ mask.
420
+
421
+ Similarly to ``EquispacedMaskFractionFunc``, this method exactly matches
422
+ the target acceleration by adjusting the offsets.
423
+ """
424
+
425
+ def sample_mask(
426
+ self,
427
+ shape: Sequence[int],
428
+ offset: Optional[int],
429
+ ) -> Tuple[torch.Tensor, torch.Tensor, int]:
430
+ """
431
+ Sample a new k-space mask.
432
+
433
+ This function samples and returns two components of a k-space mask: 1)
434
+ the center mask (e.g., for sensitivity map calculation) and 2) the
435
+ acceleration mask (for the edge of k-space). Both of these masks, as
436
+ well as the integer of low frequency samples, are returned.
437
+
438
+ Args:
439
+ shape: Shape of the k-space to subsample.
440
+ offset: Offset from 0 to begin mask (for equispaced masks).
441
+
442
+ Returns:
443
+ A 3-tuple containing 1) the mask for the center of k-space, 2) the
444
+ mask for the high frequencies of k-space, and 3) the integer count
445
+ of low frequency samples.
446
+ """
447
+ num_cols = shape[-2]
448
+ fraction_low_freqs, acceleration = self.choose_acceleration()
449
+
450
+ num_low_frequencies = round(num_cols * fraction_low_freqs)
451
+
452
+ # bound the number of low frequencies between 1 and target columns
453
+ target_columns_to_sample = round(num_cols / acceleration)
454
+ num_low_frequencies = max(
455
+ min(num_low_frequencies, target_columns_to_sample), 1
456
+ )
457
+
458
+ # adjust acceleration rate based on target acceleration.
459
+ adjusted_target_columns_to_sample = (
460
+ target_columns_to_sample - num_low_frequencies
461
+ )
462
+ adjusted_acceleration = 0
463
+ if adjusted_target_columns_to_sample > 0:
464
+ adjusted_acceleration = round(
465
+ num_cols / adjusted_target_columns_to_sample
466
+ )
467
+
468
+ center_mask = self.reshape_mask(
469
+ self.calculate_center_mask(shape, num_low_frequencies), shape
470
+ )
471
+ accel_mask = self.reshape_mask(
472
+ self.calculate_acceleration_mask(
473
+ num_cols, adjusted_acceleration, offset, num_low_frequencies
474
+ ),
475
+ shape,
476
+ )
477
+
478
+ return center_mask, accel_mask, num_low_frequencies
479
+
480
+
481
+ class Gaussian2DMaskFunc(MaskFunc):
482
+ """Gaussian 2D Masking
483
+
484
+ Args:
485
+ MaskFunc (_type_): _description_
486
+ """
487
+
488
+ def __init__(
489
+ self,
490
+ accelerations: Sequence[int],
491
+ stds: Sequence[float],
492
+ seed: Optional[int] = None,
493
+ ):
494
+ """initialize Gaussian 2D Mask
495
+
496
+ Args:
497
+ accelerations (Sequence[int]): list of acceleration factors, when
498
+ generating a mask, an acceleration factor from this list will be chosen
499
+ stds (Sequence[float]): list of torch.Normal scale (~std) to choose from
500
+ seed (Optional[int], optional): Seed for selecting mask parameters. Defaults to None.
501
+ """
502
+ self.rng = np.random.RandomState(seed)
503
+ self.accelerations = accelerations
504
+ self.stds = stds
505
+
506
+ def __call__(
507
+ self,
508
+ shape: Sequence[int],
509
+ offset: Optional[int] = None,
510
+ seed: Optional[Union[int, Tuple[int, ...]]] = None,
511
+ ) -> Tuple[torch.Tensor, torch.Tensor, int]:
512
+ if len(shape) < 3:
513
+ raise ValueError("Shape should have 3 or more dimensions")
514
+
515
+ acceleration = self.rng.choice(self.accelerations)
516
+ std = self.rng.choice(self.stds)
517
+
518
+ x, y = shape[-3], shape[-2]
519
+ mean_x = x // 2
520
+ mean_y = y // 2
521
+ num_samples_collected = 0
522
+
523
+ dist = D.Normal(
524
+ loc=torch.tensor([mean_x, mean_y], dtype=torch.float32),
525
+ scale=std,
526
+ )
527
+
528
+ N = (
529
+ int(1 / acceleration * x * y) + 10000
530
+ ) # add constant or won't reach desired subsampling rate
531
+ sample_x, sample_y = (
532
+ torch.zeros(N, dtype=torch.int),
533
+ torch.zeros(N, dtype=torch.int),
534
+ )
535
+
536
+ while num_samples_collected < N:
537
+ samples = dist.sample((N,)) # type: ignore
538
+ valid_samples = (
539
+ (samples[:, 0] >= 0)
540
+ & (samples[:, 0] < x)
541
+ & (samples[:, 1] >= 0)
542
+ & (samples[:, 1] < y)
543
+ )
544
+
545
+ valid_x = samples[valid_samples, 0].int()
546
+ valid_y = samples[valid_samples, 1].int()
547
+
548
+ num_to_take = min(N - num_samples_collected, valid_x.size(0))
549
+ sample_x[
550
+ num_samples_collected : num_samples_collected + num_to_take
551
+ ] = valid_x[:num_to_take]
552
+ sample_y[
553
+ num_samples_collected : num_samples_collected + num_to_take
554
+ ] = valid_y[:num_to_take]
555
+ num_samples_collected += num_to_take
556
+
557
+ mask = torch.zeros((x, y))
558
+ mask[sample_x, sample_y] = 1.0
559
+
560
+ # broadcasting mask (x, y) --> (N, x, y, C) C=2, N=batch_size
561
+ mask = mask.unsqueeze(-1) # (x, y, 1)
562
+ mask = mask.unsqueeze(0) # (1, x, y, 1)
563
+ mask = mask.expand((1, mask.shape[1], mask.shape[2], 2)).clone()
564
+
565
+ # num_low_freqs doesn't make sense so just return std (a number)
566
+ # returning None doesn't work since we can't stack for multiple batches
567
+ return mask, std
568
+
569
+
570
+ class Poisson2DMaskFunc(MaskFunc):
571
+ """
572
+ Variable Density Poisson Disk Sampling
573
+ https://sigpy.readthedocs.io/en/latest/generated/sigpy.mri.poisson.html#sigpy.mri.poisson
574
+ """
575
+
576
+ def __init__(
577
+ self,
578
+ accelerations: Sequence[int],
579
+ stds: None,
580
+ seed: Optional[int] = None,
581
+ use_cache: bool = True,
582
+ ):
583
+ """initialize VDPD (Poisson) mask
584
+
585
+ Args:
586
+ accelerations (Sequence[int]): list of acceleration factors to
587
+ choose from
588
+ stds: Dummy param. Do not pass value. Defaults to None.
589
+ seed (Optional[int], optional): Seed for selecting mask params.
590
+ Defaults to None.
591
+ """
592
+ self.rng = np.random.RandomState(seed)
593
+ self.accelerations = accelerations
594
+ self.use_cache = use_cache
595
+ if use_cache:
596
+ self.cache: Dict[int, np.ndarray] = dict()
597
+ for acc in accelerations:
598
+ # load the cached mask shipped with the repo instead of a
600
+ # machine-specific absolute path
601
+ cache_path = os.path.join(
602
+ os.path.dirname(__file__),
603
+ "poisson_cache",
604
+ f"poisson_{acc}x.npy",
605
+ )
606
+ assert os.path.exists(cache_path), f"Missing cached mask: {cache_path}"
607
+ self.cache[acc] = np.load(cache_path)
607
+
608
+ def __call__(
609
+ self,
610
+ shape: Sequence[int],
611
+ offset: Optional[int] = None,
612
+ seed: Optional[Union[int, Tuple[int, ...]]] = None,
613
+ ) -> Tuple[torch.Tensor, torch.Tensor, int]:
614
+ if self.use_cache:
615
+ acceleration = self.rng.choice(self.accelerations)
616
+ return torch.from_numpy(self.cache[acceleration]), 1.0 # type: ignore
617
+ if len(shape) < 3:
618
+ raise ValueError("Shape should have 3 or more dimensions")
619
+
620
+ acceleration = self.rng.choice(self.accelerations)
621
+ x, y = shape[-3], shape[-2]
622
+
623
+ mask = poisson(img_shape=(x, y), accel=acceleration, dtype=np.float32)
624
+ mask = torch.from_numpy(mask)
625
+
626
+ # broadcasting mask (x, y) --> (N, x, y, C) C=2, N=batch_size
627
+ mask = mask.unsqueeze(-1) # (x, y, 1)
628
+ mask = mask.unsqueeze(0) # (1, x, y, 1)
629
+ mask = mask.expand((1, mask.shape[1], mask.shape[2], 2)).clone()
630
+
631
+ # num low freqs doesn't make sense here, so we return an arbitrary value (100.0)
632
+ return mask, 100.0
633
+
634
+
635
+ class Radial2DMaskFunc(MaskFunc):
636
+ """
637
+ Radial trajectory MRI masking method.
638
+ https://sigpy.readthedocs.io/en/latest/generated/sigpy.mri.radial.html#sigpy.mri.radial
639
+ """
640
+
641
+ def __init__(
642
+ self,
643
+ accelerations: Sequence[int],
644
+ arms: Optional[Sequence[int]],
645
+ seed: Optional[int] = None,
646
+ ):
647
+ """
648
+ initialize Radial mask
649
+
650
+ Args:
651
+ accelerations (Sequence[int]): list of acceleration factors to
652
+ choose from
653
+ arms: Number of radial arms.
654
+ seed (Optional[int], optional): Seed for selecting mask params.
655
+ Defaults to None.
656
+ """
657
+ self.rng = np.random.RandomState(seed)
658
+ self.accelerations = accelerations
659
+ self.arms = arms
660
+
661
+ def __call__(
662
+ self,
663
+ shape: Sequence[int],
664
+ offset: Optional[int] = None,
665
+ seed: Optional[Union[int, Tuple[int, ...]]] = None,
666
+ ) -> Tuple[torch.Tensor, torch.Tensor, int]:
667
+ if len(shape) < 3:
668
+ raise ValueError("Shape should have 3 or more dimensions")
669
+
670
+ acceleration = self.rng.choice(self.accelerations)
671
+ x, y = shape[-3], shape[-2]
672
+ npoints = int(x * y * (1 / acceleration))
673
+ if self.arms:
674
+ arms = self.rng.choice(self.arms)
675
+ else:
676
+ points_per_arm = x // 3
677
+ arms = npoints // points_per_arm
678
+
679
+ # calculate radial parameters to satisfy acceleration factor
680
+ ntr = arms # num radial lines
681
+ nro = npoints // arms # num points on each radial line
682
+ ndim = 2 # 2D
683
+
684
+ # gen trajectory w/ shape (ntr, nro, ndim)
685
+ traj = radial(
686
+ coord_shape=[ntr, nro, ndim],
687
+ img_shape=(x, y),
688
+ golden=True,
689
+ dtype=int,
690
+ )
691
+
692
+ mask = torch.zeros(x, y, dtype=torch.float32)
693
+ x_coords = traj[..., 0].flatten() + (x // 2)
694
+ y_coords = traj[..., 1].flatten() + (y // 2)
695
+ mask[x_coords, y_coords] = 1.0
696
+
697
+ # broadcasting mask (x, y) --> (N, x, y, C) C=2, N=batch_size
698
+ mask = mask.unsqueeze(-1) # (x, y, 1)
699
+ mask = mask.unsqueeze(0) # (1, x, y, 1)
700
+ mask = mask.expand((1, mask.shape[1], mask.shape[2], 2)).clone()
701
+
702
+ # num low freqs doesn't make sense here, so we return an arbitrary value (100.0)
703
+ return mask, 100.0
704
+
705
+
706
+ class Spiral2DMaskFunc(MaskFunc):
707
+ """
708
+ Spiral trajectory MRI masking method.
709
+ https://sigpy.readthedocs.io/en/latest/generated/sigpy.mri.spiral.html#sigpy.mri.spiral
710
+ """
711
+
712
+ def __init__(
713
+ self,
714
+ accelerations: Sequence[int],
715
+ arms: Sequence[int],
716
+ seed: Optional[int] = None,
717
+ ):
718
+ """
719
+ Initialize the Spiral mask.
720
+
721
+ Args:
722
+ accelerations (Sequence[int]): list of acceleration factors to
723
+ choose from
724
+ arms: Number of spiral arms.
725
+ seed (Optional[int], optional): Seed for selecting mask params.
726
+ Defaults to None.
727
+ """
728
+ self.rng = np.random.RandomState(seed)
729
+ self.accelerations = accelerations
730
+ self.arms = arms
731
+
732
+ def __call__(
733
+ self,
734
+ shape: Sequence[int],
735
+ offset: Optional[int] = None,
736
+ seed: Optional[Union[int, Tuple[int, ...]]] = None,
737
+ ) -> Tuple[torch.Tensor, torch.Tensor, int]:
738
+ # TODO: implement
739
+ raise (NotImplementedError("Spiral2D not implemented"))
740
+ if len(shape) < 3:
741
+ raise ValueError("Shape should have 3 or more dimensions")
742
+ acceleration = self.rng.choice(self.accelerations)
743
+ arms = self.rng.choice(self.arms)
744
+ x, y = shape[-3], shape[-2]
745
+
746
+ # calculate radial parameters to satisfy acceleration factor
747
+ npoints = int(x * y * (1 / acceleration))
748
+
749
+ # gen trajectory w/ shape (ntr, nro, ndim)
750
+ traj = spiral(
751
+ N=npoints,
752
+ img_shape=(x, y),
753
+ golden=True,
754
+ dtype=int,
755
+ )
756
+
757
+ mask = torch.zeros(x, y, dtype=float)
758
+ x_coords = traj[..., 0].flatten() + (x // 2)
759
+ y_coords = traj[..., 1].flatten() + (y // 2)
760
+ mask[x_coords, y_coords] = 1.0
761
+
762
+ # broadcasting mask (x, y) --> (N, x, y, C) C=2, N=batch_size
763
+ mask = mask.unsqueeze(-1) # (x, y, 1)
764
+ mask = mask.unsqueeze(0) # (1, x, y, 1)
765
+ mask = mask.expand((1, mask.shape[1], mask.shape[2], 2)).clone()
766
+
767
+ # num low freqs doesn't make sense here, so we return an arbitrary value (100.0)
768
+ return mask, 100.0
769
+
770
+
771
+ def create_mask_for_mask_type(
772
+ mask_type_str: str,
773
+ center_fractions: Optional[Sequence],
774
+ accelerations: Sequence[int],
775
+ ) -> MaskFunc:
776
+ """
777
+ Creates a mask of the specified type.
778
+
779
+ Args:
780
+ center_fractions: What fraction of the center of k-space to include.
781
+ accelerations: What accelerations to apply.
782
+
783
+ Returns:
784
+ A mask func for the target mask type.
785
+ """
786
+ if mask_type_str == "random":
787
+ return RandomMaskFunc(center_fractions, accelerations)
788
+ elif mask_type_str == "equispaced":
789
+ return EquiSpacedMaskFunc(center_fractions, accelerations)
790
+ elif mask_type_str == "equispaced_fraction":
791
+ return EquispacedMaskFractionFunc(center_fractions, accelerations)
792
+ elif mask_type_str == "magic":
793
+ return MagicMaskFunc(center_fractions, accelerations)
794
+ elif mask_type_str == "magic_fraction":
795
+ return MagicMaskFractionFunc(center_fractions, accelerations)
796
+ elif mask_type_str == "gaussian_2d":
797
+ return Gaussian2DMaskFunc(
798
+ stds=center_fractions,
799
+ accelerations=accelerations,
800
+ )
801
+ elif mask_type_str == "poisson_2d":
802
+ return Poisson2DMaskFunc(
803
+ accelerations=accelerations,
804
+ stds=None,
805
+ )
806
+ elif mask_type_str == "radial_2d":
807
+ return Radial2DMaskFunc(
808
+ accelerations=accelerations,
809
+ arms=(
810
+ [int(arm) for arm in center_fractions]
811
+ if center_fractions
812
+ else None
813
+ ),
814
+ )
815
+ elif mask_type_str == "spiral_2d":
816
+ raise NotImplementedError("spiral_2d not implemented")
817
+ else:
818
+ raise ValueError(f"{mask_type_str} not supported")
fastmri/transforms.py ADDED
@@ -0,0 +1,974 @@
1
+ """
2
+ Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ import random
9
+ from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
10
+
11
+ import matplotlib.pyplot as plt
12
+ import numpy as np
13
+ import torch
14
+
15
+ import fastmri
16
+
17
+ from .subsample import MaskFunc
18
+
19
+
20
+ def to_tensor(data: np.ndarray) -> torch.Tensor:
21
+ """
22
+ Convert numpy array to PyTorch tensor.
23
+
24
+ For complex arrays, the real and imaginary parts are stacked along the last
25
+ dimension.
26
+
27
+ Args:
28
+ data: Input numpy array.
29
+
30
+ Returns:
31
+ PyTorch version of data.
32
+ """
33
+ if np.iscomplexobj(data):
34
+ data = np.stack((data.real, data.imag), axis=-1)
35
+
36
+ return torch.from_numpy(data)
37
+
38
+
39
+ def tensor_to_complex_np(data: torch.Tensor) -> np.ndarray:
40
+ """
41
+ Converts a complex torch tensor to numpy array.
42
+
43
+ Args:
44
+ data: Input data to be converted to numpy.
45
+
46
+ Returns:
47
+ Complex numpy version of data.
48
+ """
49
+ return torch.view_as_complex(data).numpy()
50
+
51
+
52
+ def apply_mask(
53
+ data: torch.Tensor,
54
+ mask_func: MaskFunc,
55
+ offset: Optional[int] = None,
56
+ seed: Optional[Union[int, Tuple[int, ...]]] = None,
57
+ padding: Optional[Sequence[int]] = None,
58
+ ) -> Tuple[torch.Tensor, torch.Tensor, int]:
59
+ """
60
+ Subsample given k-space by multiplying with a mask.
61
+
62
+ Args:
63
+ data: The input k-space data. This should have at least 3 dimensions,
64
+ where dimensions -3 and -2 are the spatial dimensions, and the
65
+ final dimension has size 2 (for complex values).
66
+ mask_func: A function that takes a shape (tuple of ints) and a random
67
+ number seed and returns a mask.
68
+ seed: Seed for the random number generator.
69
+ padding: Padding value to apply for mask.
70
+
71
+ Returns:
72
+ tuple containing:
73
+ masked data: Subsampled k-space data.
74
+ mask: The generated mask.
75
+ num_low_frequencies: The number of low-resolution frequency samples
76
+ in the mask.
77
+ """
78
+ shape = (1,) * len(data.shape[:-3]) + tuple(data.shape[-3:])
79
+ mask, num_low_frequencies = mask_func(shape, offset, seed)
80
+ if padding is not None:
81
+ mask[..., : padding[0], :] = 0
82
+ mask[..., padding[1] :, :] = (
83
+ 0 # padding value inclusive on right of zeros
84
+ )
85
+
86
+ masked_data = data * mask + 0.0 # the + 0.0 removes the sign of the zeros
87
+
88
+ return masked_data, mask, num_low_frequencies
89
+
90
+
91
+ def mask_center(x: torch.Tensor, mask_from: int, mask_to: int) -> torch.Tensor:
92
+ """
93
+ Initializes a mask with the center filled in.
94
+
95
+ Args:
96
+ mask_from: Part of center to start filling.
97
+ mask_to: Part of center to end filling.
98
+
99
+ Returns:
100
+ A mask with the center filled.
101
+ """
102
+ mask = torch.zeros_like(x)
103
+ mask[:, :, :, mask_from:mask_to] = x[:, :, :, mask_from:mask_to]
104
+
105
+ return mask
106
+
107
+
108
+ def batched_mask_center(
109
+ x: torch.Tensor, mask_from: torch.Tensor, mask_to: torch.Tensor
110
+ ) -> torch.Tensor:
111
+ """
112
+ Initializes a mask with the center filled in.
113
+
114
+ Can operate with different masks for each batch element.
115
+
116
+ Args:
117
+ mask_from: Part of center to start filling.
118
+ mask_to: Part of center to end filling.
119
+
120
+ Returns:
121
+ A mask with the center filled.
122
+ """
123
+ if not mask_from.shape == mask_to.shape:
124
+ raise ValueError("mask_from and mask_to must match shapes.")
125
+ if not mask_from.ndim == 1:
126
+ raise ValueError("mask_from and mask_to must have 1 dimension.")
127
+ if not mask_from.shape[0] == 1:
128
+ if (not x.shape[0] == mask_from.shape[0]) or (
129
+ not x.shape[0] == mask_to.shape[0]
130
+ ):
131
+ raise ValueError(
132
+ "mask_from and mask_to must have batch_size length."
133
+ )
134
+
135
+ if mask_from.shape[0] == 1:
136
+ mask = mask_center(x, int(mask_from), int(mask_to))
137
+ else:
138
+ mask = torch.zeros_like(x)
139
+ for i, (start, end) in enumerate(zip(mask_from, mask_to)):
140
+ mask[i, :, :, start:end] = x[i, :, :, start:end]
141
+
142
+ return mask
143
+
144
+
145
+ def center_crop(data: torch.Tensor, shape: Tuple[int, int]) -> torch.Tensor:
146
+ """
147
+ Apply a center crop to the input real image or batch of real images.
148
+
149
+ Args:
150
+ data: The input tensor to be center cropped. It should
151
+ have at least 2 dimensions and the cropping is applied along the
152
+ last two dimensions.
153
+ shape: The output shape. The shape should be smaller
154
+ than the corresponding dimensions of data.
155
+
156
+ Returns:
157
+ The center cropped image.
158
+ """
159
+ if not (0 < shape[0] <= data.shape[-2] and 0 < shape[1] <= data.shape[-1]):
160
+ raise ValueError("Invalid shapes.")
161
+
162
+ w_from = (data.shape[-2] - shape[0]) // 2
163
+ h_from = (data.shape[-1] - shape[1]) // 2
164
+ w_to = w_from + shape[0]
165
+ h_to = h_from + shape[1]
166
+
167
+ return data[..., w_from:w_to, h_from:h_to]
168
+
169
+
170
+ def complex_center_crop(
171
+ data: torch.Tensor, shape: Tuple[int, int]
172
+ ) -> torch.Tensor:
173
+ """
174
+ Apply a center crop to the input image or batch of complex images.
175
+
176
+ Args:
177
+ data: The complex input tensor to be center cropped. It should have at
178
+ least 3 dimensions and the cropping is applied along dimensions -3
179
+ and -2 and the last dimensions should have a size of 2.
180
+ shape: The output shape. The shape should be smaller than the
181
+ corresponding dimensions of data.
182
+ Returns:
183
+ The center cropped image
184
+ """
185
+ if not (0 < shape[0] <= data.shape[-3] and 0 < shape[1] <= data.shape[-2]):
186
+ raise ValueError("Invalid shapes.")
187
+
188
+ w_from = (data.shape[-3] - shape[0]) // 2
189
+ h_from = (data.shape[-2] - shape[1]) // 2
190
+ w_to = w_from + shape[0]
191
+ h_to = h_from + shape[1]
192
+
193
+ return data[..., w_from:w_to, h_from:h_to, :]
194
+
195
+
196
+ def center_crop_to_smallest(
197
+ x: torch.Tensor, y: torch.Tensor
198
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
199
+ """
200
+ Apply a center crop on the larger image to the size of the smaller.
201
+
202
+ The minimum is taken over dim=-1 and dim=-2. If x is smaller than y at
203
+ dim=-1 and y is smaller than x at dim=-2, then the returned dimension will
204
+ be a mixture of the two.
205
+
206
+ Args:
207
+ x: The first image.
208
+ y: The second image.
209
+
210
+ Returns:
211
+ tuple of tensors x and y, each cropped to the minimim size.
212
+ """
213
+ smallest_width = min(x.shape[-1], y.shape[-1])
214
+ smallest_height = min(x.shape[-2], y.shape[-2])
215
+ x = center_crop(x, (smallest_height, smallest_width))
216
+ y = center_crop(y, (smallest_height, smallest_width))
217
+
218
+ return x, y
219
+
220
+
221
+ def normalize(
222
+ data: torch.Tensor,
223
+ mean: Union[float, torch.Tensor],
224
+ stddev: Union[float, torch.Tensor],
225
+ eps: Union[float, torch.Tensor] = 0.0,
226
+ ) -> torch.Tensor:
227
+ """
228
+ Normalize the given tensor.
229
+
230
+ Applies the formula (data - mean) / (stddev + eps).
231
+
232
+ Args:
233
+ data: Input data to be normalized.
234
+ mean: Mean value.
235
+ stddev: Standard deviation.
236
+ eps: Added to stddev to prevent dividing by zero.
237
+
238
+ Returns:
239
+ Normalized tensor.
240
+ """
241
+ return (data - mean) / (stddev + eps)
242
+
243
+
244
+ def normalize_instance(
245
+ data: torch.Tensor, eps: Union[float, torch.Tensor] = 0.0
246
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
247
+ """
248
+ Normalize the given tensor with instance norm.
249
+
250
+ Applies the formula (data - mean) / (stddev + eps), where mean and stddev
251
+ are computed from the data itself.
252
+
253
+ Args:
254
+ data: Input data to be normalized
255
+ eps: Added to stddev to prevent dividing by zero.
256
+
257
+ Returns:
258
+ torch.Tensor: Normalized tensor
259
+ """
260
+ mean = data.mean()
261
+ std = data.std()
262
+
263
+ return normalize(data, mean, std, eps), mean, std
264
+
265
+
266
+ class UnetSample(NamedTuple):
267
+ """
268
+ A subsampled image for U-Net reconstruction.
269
+
270
+ Args:
271
+ image: Subsampled image after inverse FFT.
272
+ target: The target image (if applicable).
273
+ mean: Per-channel mean values used for normalization.
274
+ std: Per-channel standard deviations used for normalization.
275
+ fname: File name.
276
+ slice_num: The slice index.
277
+ max_value: Maximum image value.
278
+ """
279
+
280
+ image: torch.Tensor
281
+ target: torch.Tensor
282
+ mean: torch.Tensor
283
+ std: torch.Tensor
284
+ fname: str
285
+ slice_num: int
286
+ max_value: float
287
+
288
+
289
+ class UnetDataTransform:
290
+ """
291
+ Data Transformer for training U-Net models.
292
+ """
293
+
294
+ def __init__(
295
+ self,
296
+ which_challenge: str,
297
+ mask_func: Optional[MaskFunc] = None,
298
+ use_seed: bool = True,
299
+ ):
300
+ """
301
+ Args:
302
+ which_challenge: Challenge from ("singlecoil", "multicoil").
303
+ mask_func: Optional; A function that can create a mask of
304
+ appropriate shape.
305
+ use_seed: If true, this class computes a pseudo random number
306
+ generator seed from the filename. This ensures that the same
307
+ mask is used for all the slices of a given volume every time.
308
+ """
309
+ if which_challenge not in ("singlecoil", "multicoil"):
310
+ raise ValueError(
311
+ "Challenge should either be 'singlecoil' or 'multicoil'"
312
+ )
313
+
314
+ self.mask_func = mask_func
315
+ self.which_challenge = which_challenge
316
+ self.use_seed = use_seed
317
+
318
+ def __call__(
319
+ self,
320
+ kspace: np.ndarray,
321
+ mask: np.ndarray,
322
+ target: np.ndarray,
323
+ attrs: Dict,
324
+ fname: str,
325
+ slice_num: int,
326
+ ) -> Tuple[
327
+ torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, int, float
328
+ ]:
329
+ """
330
+ Args:
331
+ kspace: Input k-space of shape (num_coils, rows, cols) for
332
+ multi-coil data or (rows, cols) for single coil data.
333
+ mask: Mask from the test dataset.
334
+ target: Target image.
335
+ attrs: Acquisition related information stored in the HDF5 object.
336
+ fname: File name.
337
+ slice_num: Serial number of the slice.
338
+
339
+ Returns:
340
+ A tuple containing, zero-filled input image, the reconstruction
341
+ target, the mean used for normalization, the standard deviations
342
+ used for normalization, the filename, and the slice number.
343
+ """
344
+ kspace_torch = to_tensor(kspace)
345
+
346
+ # check for max value
347
+ max_value = attrs["max"] if "max" in attrs.keys() else 0.0
348
+
349
+ # apply mask
350
+ if self.mask_func:
351
+ seed = None if not self.use_seed else tuple(map(ord, fname))
352
+ # we only need first element, which is k-space after masking
353
+ masked_kspace = apply_mask(kspace_torch, self.mask_func, seed=seed)[
354
+ 0
355
+ ]
356
+ else:
357
+ masked_kspace = kspace_torch
358
+
359
+ # inverse Fourier transform to get zero filled solution
360
+ image = fastmri.ifft2c(masked_kspace)
361
+
362
+ # crop input to correct size
363
+ if target is not None:
364
+ crop_size = (target.shape[-2], target.shape[-1])
365
+ else:
366
+ crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
367
+
368
+ # check for FLAIR 203
369
+ if image.shape[-2] < crop_size[1]:
370
+ crop_size = (image.shape[-2], image.shape[-2])
371
+
372
+ image = complex_center_crop(image, crop_size)
373
+
374
+ # absolute value
375
+ image = fastmri.complex_abs(image)
376
+
377
+ # apply Root-Sum-of-Squares if multicoil data
378
+ if self.which_challenge == "multicoil":
379
+ image = fastmri.rss(image)
380
+
381
+ # normalize input
382
+ image, mean, std = normalize_instance(image, eps=1e-11)
383
+ image = image.clamp(-6, 6)
384
+
385
+ # normalize target
386
+ if target is not None:
387
+ target_torch = to_tensor(target)
388
+ target_torch = center_crop(target_torch, crop_size)
389
+ target_torch = normalize(target_torch, mean, std, eps=1e-11)
390
+ target_torch = target_torch.clamp(-6, 6)
391
+ else:
392
+ target_torch = torch.Tensor([0])
393
+
394
+ return UnetSample(
395
+ image=image,
396
+ target=target_torch,
397
+ mean=mean,
398
+ std=std,
399
+ fname=fname,
400
+ slice_num=slice_num,
401
+ max_value=max_value,
402
+ )
403
+
404
+
405
+ class VarNetSample(NamedTuple):
406
+ """
407
+ A sample of masked k-space for variational network reconstruction.
408
+
409
+ Args:
410
+ masked_kspace: k-space after applying sampling mask.
411
+ mask: The applied sampling mask.
412
+ num_low_frequencies: The number of samples for the densely-sampled
413
+ center.
414
+ target: The target image (if applicable).
415
+ fname: File name.
416
+ slice_num: The slice index.
417
+ max_value: Maximum image value.
418
+ crop_size: The size to crop the final image.
419
+ """
420
+
421
+ masked_kspace: torch.Tensor
422
+ mask: torch.Tensor
423
+ num_low_frequencies: Optional[int]
424
+ target: torch.Tensor
425
+ fname: str
426
+ slice_num: int
427
+ max_value: float
428
+ crop_size: Tuple[int, int]
429
+
430
+
431
+ class VarNetDataTransform:
432
+ """
433
+ Data Transformer for training VarNet models.
434
+ """
435
+
436
+ def __init__(
437
+ self, mask_func: Optional[MaskFunc] = None, use_seed: bool = True
438
+ ):
439
+ """
440
+ Args:
441
+ mask_func: Optional; A function that can create a mask of
442
+ appropriate shape. Defaults to None.
443
+ use_seed: If True, this class computes a pseudo random number
444
+ generator seed from the filename. This ensures that the same
445
+ mask is used for all the slices of a given volume every time.
446
+ """
447
+ self.mask_func = mask_func
448
+ self.use_seed = use_seed
449
+
450
+ def __call__(
451
+ self,
452
+ kspace: np.ndarray,
453
+ mask: np.ndarray,
454
+ target: Optional[np.ndarray],
455
+ attrs: Dict,
456
+ fname: str,
457
+ slice_num: int,
458
+ ) -> VarNetSample:
459
+ """
460
+ Args:
461
+ kspace: Input k-space of shape (num_coils, rows, cols) for
462
+ multi-coil data.
463
+ mask: Mask from the test dataset.
464
+ target: Target image.
465
+ attrs: Acquisition related information stored in the HDF5 object.
466
+ fname: File name.
467
+ slice_num: Serial number of the slice.
468
+
469
+ Returns:
470
+ A VarNetSample with the masked k-space, sampling mask, target
471
+ image, the filename, the slice number, the maximum image value
472
+ (from target), the target crop size, and the number of low
473
+ frequency lines sampled.
474
+ """
475
+ if target is not None:
476
+ target_torch = to_tensor(target)
477
+ max_value = attrs["max"]
478
+ else:
479
+ target_torch = torch.tensor(0)
480
+ max_value = 0.0
481
+
482
+ kspace_torch = to_tensor(kspace)
483
+ seed = None if not self.use_seed else tuple(map(ord, fname))
484
+ acq_start = attrs["padding_left"]
485
+ acq_end = attrs["padding_right"]
486
+
487
+ crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
488
+
489
+ if self.mask_func is not None:
490
+ masked_kspace, mask_torch, num_low_frequencies = apply_mask(
491
+ kspace_torch,
492
+ self.mask_func,
493
+ seed=seed,
494
+ padding=(acq_start, acq_end),
495
+ )
496
+
497
+ sample = VarNetSample(
498
+ masked_kspace=masked_kspace,
499
+ mask=mask_torch.to(torch.bool),
500
+ num_low_frequencies=num_low_frequencies,
501
+ target=target_torch,
502
+ fname=fname,
503
+ slice_num=slice_num,
504
+ max_value=max_value,
505
+ crop_size=crop_size,
506
+ )
507
+ else:
508
+ masked_kspace = kspace_torch
509
+ shape = np.array(kspace_torch.shape)
510
+ num_cols = shape[-2]
511
+ shape[:-3] = 1
512
+ mask_shape = [1] * len(shape)
513
+ mask_shape[-2] = num_cols
514
+ mask_torch = torch.from_numpy(
515
+ mask.reshape(*mask_shape).astype(np.float32)
516
+ )
517
+ mask_torch = mask_torch.reshape(*mask_shape)
518
+ mask_torch[:, :, :acq_start] = 0
519
+ mask_torch[:, :, acq_end:] = 0
520
+
521
+ sample = VarNetSample(
522
+ masked_kspace=masked_kspace,
523
+ mask=mask_torch.to(torch.bool),
524
+ num_low_frequencies=0,
525
+ target=target_torch,
526
+ fname=fname,
527
+ slice_num=slice_num,
528
+ max_value=max_value,
529
+ crop_size=crop_size,
530
+ )
531
+
532
+ # whether to crop samples for batch processing
533
+ batch_crop = False
534
+
535
+ def save_img(x, fname):
536
+ slice_kspace2 = x
537
+ slice_image = fastmri.ifft2c(
538
+ slice_kspace2
539
+ ) # Apply Inverse Fourier Transform to get the complex image
540
+ slice_image_abs = fastmri.complex_abs(
541
+ slice_image
542
+ ) # Compute absolute value to get a real image
543
+ slice_image_rss = fastmri.rss(slice_image_abs, dim=0)
544
+ plt.imsave(f"{fname}.png", torch.abs(slice_image_rss), cmap="gray")
545
+
546
+ def save_raw_img(x, fname):
547
+ # slice_kspace2 = x
548
+ # slice_image = fastmri.ifft2c(
549
+ # slice_kspace2
550
+ # ) # Apply Inverse Fourier Transform to get the complex image
551
+ # slice_image_abs = fastmri.complex_abs(
552
+ # slice_image
553
+ # ) # Compute absolute value to get a real image
554
+ x = fastmri.rss(x, dim=0)[:, :, 0]
555
+
556
+ plt.imsave(f"{fname}.png", torch.abs(x))
557
+
558
+ if batch_crop:
559
+ # crop kspace data to minx, miny size (640, 320 cols)
560
+ square_crop = (attrs["recon_size"][0], attrs["recon_size"][1])
561
+ # print(square_crop)
562
+ cropped_kspace = fastmri.fft2c(
563
+ complex_center_crop(
564
+ fastmri.ifft2c(sample.masked_kspace), square_crop
565
+ )
566
+ )
567
+ cropped_kspace = complex_center_crop(cropped_kspace, (320, 320))
568
+ # print(cropped_kspace.shape)
569
+ # exit(0)
570
+
571
+ # CHECK: debugging purposes
572
+ # save_img(sample.masked_kspace, "og")
573
+ # save_img(cropped_kspace, "cropped")
574
+ # save_raw_img(sample.masked_kspace, "og_kspace")
575
+ # save_raw_img(cropped_kspace, "cropped_kspace")
576
+
577
+ # exit(0)
578
+
579
+ # crop mask shape
580
+ h_from = (mask_torch.shape[-2] - 320) // 2
581
+ h_to = h_from + 320
582
+ cropped_mask = mask_torch[..., :, h_from:h_to, :]
583
+
584
+ sample = VarNetSample(
585
+ masked_kspace=cropped_kspace,
586
+ mask=cropped_mask.to(torch.bool),
587
+ num_low_frequencies=0,
588
+ target=target_torch,
589
+ fname=fname,
590
+ slice_num=slice_num,
591
+ max_value=max_value,
592
+ crop_size=crop_size,
593
+ )
594
+ return sample
595
+
596
+
597
+ class EnhancedVarNetDataTransform(VarNetDataTransform):
598
+ """
599
+ Enhanced Data Transformer for training VarNet models with additional functionality.
600
+ - allows for training on multiple patterns
601
+ """
602
+
603
+ def __init__(
604
+ self, mask_funcs: Optional[List[MaskFunc]] = None, use_seed: bool = True
605
+ ):
606
+ self.mask_funcs = mask_funcs
607
+ self.use_seed = use_seed
608
+
609
+ def __call__(
610
+ self,
611
+ kspace: np.ndarray,
612
+ mask: np.ndarray,
613
+ target: Optional[np.ndarray],
614
+ attrs: Dict,
615
+ fname: str,
616
+ slice_num: int,
617
+ ) -> VarNetSample:
618
+ """
619
+ Args:
620
+ kspace: Input k-space of shape (num_coils, rows, cols) for
621
+ multi-coil data.
622
+ mask: Mask from the test dataset.
623
+ use mask for test data see og VarNetDataTransform __call__
624
+ target: Target image.
625
+ attrs: Acquisition related information stored in the HDF5 object.
626
+ fname: File name.
627
+ slice_num: Serial number of the slice.
628
+
629
+ Returns:
630
+ A VarNetSample with the masked k-space, sampling mask, target
631
+ image, the filename, the slice number, the maximum image value
632
+ (from target), the target crop size, and the number of low
633
+ frequency lines sampled.
634
+ """
635
+ if target is not None:
636
+ target_torch = to_tensor(target)
637
+ max_value = attrs["max"]
638
+ else:
639
+ target_torch = torch.tensor(0)
640
+ max_value = 0.0
641
+
642
+ kspace_torch = to_tensor(kspace)
643
+ seed = None if not self.use_seed else tuple(map(ord, fname))
644
+ acq_start = attrs["padding_left"]
645
+ acq_end = attrs["padding_right"]
646
+
647
+ crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
648
+
649
+ # randomly choose one of the provided masking functions (mask_funcs must be non-empty)
650
+ mask_func = random.choice(self.mask_funcs)
651
+
652
+ masked_kspace, mask_torch, num_low_frequencies = apply_mask(
653
+ kspace_torch,
654
+ mask_func,
655
+ seed=seed,
656
+ padding=(acq_start, acq_end),
657
+ )
658
+
659
+ # print(masked_kspace.shape)
660
+ # print(mask_torch.shape)
661
+
662
+ # torch.save(masked_kspace, f"masked_kspace_{slice_num}.pkl")
663
+ # torch.save(mask_torch, f"mask_torch_{slice_num}.pkl")
664
+
665
+ sample = VarNetSample(
666
+ masked_kspace=masked_kspace,
667
+ mask=mask_torch.to(torch.bool),
668
+ num_low_frequencies=num_low_frequencies,
669
+ target=target_torch,
670
+ fname=fname,
671
+ slice_num=slice_num,
672
+ max_value=max_value,
673
+ crop_size=crop_size,
674
+ )
675
+
676
+ # whether to crop samples for batch processing
677
+ batch_crop = False
678
+
679
+ if batch_crop:
680
+ # crop k-space to the recon size by round-tripping through image space
681
+ square_crop = (attrs["recon_size"][0], attrs["recon_size"][1])
682
+ # print(square_crop)
683
+ cropped_kspace = fastmri.fft2c(
684
+ complex_center_crop(
685
+ fastmri.ifft2c(sample.masked_kspace), square_crop
686
+ )
687
+ )
688
+ # cropped_kspace = complex_center_crop(cropped_kspace, (640, 320))
689
+
690
+ # exit(0)
691
+
692
+ # crop mask shape
693
+ h_from = (mask_torch.shape[-2] - 320) // 2
694
+ h_to = h_from + 320
695
+ cropped_mask = mask_torch[..., :, h_from:h_to, :]
696
+
697
+ sample = VarNetSample(
698
+ masked_kspace=cropped_kspace,
699
+ mask=cropped_mask.to(torch.bool),
700
+ num_low_frequencies=0,
701
+ target=target_torch,
702
+ fname=fname,
703
+ slice_num=slice_num,
704
+ max_value=max_value,
705
+ crop_size=crop_size,
706
+ )
707
+
708
+ return sample
709
+
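A minimal usage sketch for EnhancedVarNetDataTransform: it is built with a list of mask functions and picks one at random per call. The RandomMaskFunc class and its import path below follow the upstream fastMRI API and are assumptions here; any MaskFunc-compatible callables work.

from fastmri.data.subsample import RandomMaskFunc

transform = EnhancedVarNetDataTransform(
    mask_funcs=[
        RandomMaskFunc(center_fractions=[0.08], accelerations=[4]),
        RandomMaskFunc(center_fractions=[0.04], accelerations=[8]),
    ],
    use_seed=True,  # seed for apply_mask is derived from fname
)
# sample = transform(kspace, mask, target, attrs, fname, slice_num)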
710
+
711
+ class MiniCoilSample(NamedTuple):
712
+ """
713
+ A sample of masked coil-compressed k-space for reconstruction.
714
+
715
+ Args:
716
+ kspace: the original k-space before masking.
717
+ masked_kspace: k-space after applying sampling mask.
718
+ mask: The applied sampling mask.
719
+ num_low_frequencies: The number of samples for the densely-sampled
720
+ center.
721
+ target: The target image (if applicable).
722
+ fname: File name.
723
+ slice_num: The slice index.
724
+ max_value: Maximum image value.
725
+ crop_size: The size to crop the final image.
726
+ """
727
+
728
+ kspace: torch.Tensor
729
+ masked_kspace: torch.Tensor
730
+ mask: torch.Tensor
731
+ target: torch.Tensor
732
+ fname: str
733
+ slice_num: int
734
+ max_value: float
735
+ crop_size: Tuple[int, int]
736
+
737
+
738
+ class MiniCoilTransform:
739
+ """
740
+ Multi-coil compressed transform, for faster prototyping.
741
+ """
742
+
743
+ def __init__(
744
+ self,
745
+ mask_func: Optional[MaskFunc] = None,
746
+ use_seed: Optional[bool] = True,
747
+ crop_size: Optional[tuple] = None,
748
+ num_compressed_coils: Optional[int] = None,
749
+ ):
750
+ """
751
+ Args:
752
+ mask_func: Optional; A function that can create a mask of
753
+ appropriate shape. Defaults to None.
754
+ use_seed: If True, this class computes a pseudo random number
755
+ generator seed from the filename. This ensures that the same
756
+ mask is used for all the slices of a given volume every time.
757
+ crop_size: Image dimensions for mini MR images.
758
+ num_compressed_coils: Number of coils to output from coil
759
+ compression.
760
+ """
761
+ self.mask_func = mask_func
762
+ self.use_seed = use_seed
763
+ self.crop_size = crop_size
764
+ self.num_compressed_coils = num_compressed_coils
765
+
766
+ def __call__(self, kspace, mask, target, attrs, fname, slice_num):
767
+ """
768
+ Args:
769
+ kspace: Input k-space of shape (num_coils, rows, cols) for
770
+ multi-coil data.
771
+ mask: Mask from the test dataset. Not used if mask_func is defined.
772
+ target: Target image.
773
+ attrs: Acquisition related information stored in the HDF5 object.
774
+ fname: File name.
775
+ slice_num: Serial number of the slice.
776
+
777
+ Returns:
778
+ tuple containing:
779
+ kspace: original kspace (used for active acquisition only).
780
+ masked_kspace: k-space after applying sampling mask. If there
781
+ is no mask or mask_func, returns same as kspace.
782
+ mask: The applied sampling mask
783
+ target: The target image (if applicable). The target is built
784
+ from the RSS of all coils, pre-compression.
785
+ fname: File name.
786
+ slice_num: The slice index.
787
+ max_value: Maximum image value.
788
+ crop_size: The size to crop the final image.
789
+ """
790
+ if target is not None:
791
+ target = to_tensor(target)
792
+ max_value = attrs["max"]
793
+ else:
794
+ target = torch.tensor(0)
795
+ max_value = 0.0
796
+
797
+ if self.crop_size is None:
798
+ crop_size = torch.tensor(
799
+ [attrs["recon_size"][0], attrs["recon_size"][1]]
800
+ )
801
+ else:
802
+ if isinstance(self.crop_size, tuple) or isinstance(
803
+ self.crop_size, list
804
+ ):
805
+ assert len(self.crop_size) == 2
806
+ if self.crop_size[0] is None or self.crop_size[1] is None:
807
+ crop_size = torch.tensor(
808
+ [attrs["recon_size"][0], attrs["recon_size"][1]]
809
+ )
810
+ else:
811
+ crop_size = torch.tensor(self.crop_size)
812
+ elif isinstance(self.crop_size, int):
813
+ crop_size = torch.tensor((self.crop_size, self.crop_size))
814
+ else:
815
+ raise ValueError(
816
+ "`crop_size` should be None, tuple, list, or int, not:"
817
+ f" {type(self.crop_size)}"
818
+ )
819
+
820
+ if self.num_compressed_coils is None:
821
+ num_compressed_coils = kspace.shape[0]
822
+ else:
823
+ num_compressed_coils = self.num_compressed_coils
824
+
825
+ seed = None if not self.use_seed else tuple(map(ord, fname))
826
+ acq_start = 0
827
+ acq_end = crop_size[1]
828
+
829
+ # new cropping section
830
+ square_crop = (attrs["recon_size"][0], attrs["recon_size"][1])
831
+ kspace = fastmri.fft2c(
832
+ complex_center_crop(fastmri.ifft2c(to_tensor(kspace)), square_crop)
833
+ ).numpy()
834
+ kspace = complex_center_crop(kspace, crop_size)
835
+
836
+ # we calculate the target before coil compression. This causes the mini
837
+ # simulation to be one where we have a 15-coil, low-resolution image
838
+ # and our reconstructor has an SVD coil approximation. This is a little
839
+ # more realistic than computing the target after SVD compression.
840
+ target = fastmri.rss_complex(fastmri.ifft2c(to_tensor(kspace)))
841
+ max_value = target.max()
842
+
843
+ # apply coil compression
844
+ new_shape = (num_compressed_coils,) + kspace.shape[1:]
845
+ kspace = np.reshape(kspace, (kspace.shape[0], -1))
846
+ left_vec, _, _ = np.linalg.svd(
847
+ kspace, compute_uv=True, full_matrices=False
848
+ )
849
+ kspace = np.reshape(
850
+ np.array(np.matrix(left_vec[:, :num_compressed_coils]).H @ kspace),
851
+ new_shape,
852
+ )
853
+ kspace = to_tensor(kspace)
854
+
855
+ # Mask kspace
856
+ if self.mask_func:
857
+ masked_kspace, mask, _ = apply_mask(
858
+ kspace, self.mask_func, seed, (acq_start, acq_end)
859
+ )
860
+ mask = mask.byte()
861
+ elif mask is not None:
862
+ masked_kspace = kspace
863
+ shape = np.array(kspace.shape)
864
+ num_cols = shape[-2]
865
+ shape[:-3] = 1
866
+ mask_shape = [1] * len(shape)
867
+ mask_shape[-2] = num_cols
868
+ mask = torch.from_numpy(
869
+ mask.reshape(*mask_shape).astype(np.float32)
870
+ )
871
+ mask = mask.reshape(*mask_shape)
872
+ mask = mask.byte()
873
+ else:
874
+ masked_kspace = kspace
875
+ shape = np.array(kspace.shape)
876
+ num_cols = shape[-2]
877
+
878
+ return MiniCoilSample(
879
+ kspace,
880
+ masked_kspace,
881
+ mask,
882
+ target,
883
+ fname,
884
+ slice_num,
885
+ max_value,
886
+ crop_size,
887
+ )
888
+
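The coil-compression step in MiniCoilTransform can be read as a stand-alone operation: flatten each coil's k-space into a row, take an SVD, and project onto the leading left singular vectors. A small numpy sketch of just that step (the actual transform works on arrays with a trailing (real, imag) dimension; shapes here are illustrative):

import numpy as np

def compress_coils(kspace, num_out):
    # kspace: (coils, H, W) -> (num_out, H, W) via SVD projection
    coils = kspace.shape[0]
    flat = kspace.reshape(coils, -1)
    left, _, _ = np.linalg.svd(flat, full_matrices=False)
    compressed = left[:, :num_out].conj().T @ flat
    return compressed.reshape((num_out,) + kspace.shape[1:])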
889
+
890
+ """
891
+ sens maps & feature transformations
892
+ - expand
893
+ - reduce
894
+ - batch -> chan
895
+ - chan -> batch
896
+ """
897
+
898
+
899
+ def sens_expand(x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
900
+ """
901
+ Calculates F(x * sens_maps)
902
+
903
+ Parameters
904
+ ----------
905
+ x : torch.Tensor
906
+ Single-channel image of shape (..., H, W, 2)
907
+ sens_maps : torch.Tensor
908
+ Sensitivity maps (image space)
909
+
910
+ Returns
911
+ -------
912
+ torch.Tensor
913
+ Result of the operation F(x * sens_maps)
914
+ """
915
+ return fastmri.fft2c(fastmri.complex_mul(x, sens_maps))
916
+
917
+
918
+ def sens_reduce(k: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
919
+ """
920
+ Calculates F^{-1}(k) * conj(sens_maps), summed over the coil dimension,
921
+ where conj(sens_maps) is the element-wise complex conjugate
922
+
923
+ Parameters
924
+ ----------
925
+ k : torch.Tensor
926
+ Multi-channel k-space of shape (B, C, H, W, 2)
927
+ sens_maps : torch.Tensor
928
+ Sensitivity maps (image space)
929
+
930
+ Returns
931
+ -------
932
+ torch.Tensor
933
+ Coil-combined image: F^{-1}(k) * conj(sens_maps), summed over coils
934
+ """
935
+ return fastmri.complex_mul(
936
+ fastmri.ifft2c(k), fastmri.complex_conj(sens_maps)
937
+ ).sum(dim=1, keepdim=True)
938
+
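sens_expand and sens_reduce form the forward / adjoint pair of the coil-sensitivity model used by the cascades: expand a single-channel image into per-coil k-space, and collapse per-coil k-space back to a coil-combined image. A shape-level sketch (sizes are illustrative):

import torch

b, c, h, w = 1, 8, 320, 320
x = torch.randn(b, 1, h, w, 2)           # single-channel complex image
sens_maps = torch.randn(b, c, h, w, 2)   # per-coil sensitivity maps

k = sens_expand(x, sens_maps)            # (b, c, h, w, 2) multi-coil k-space
x_back = sens_reduce(k, sens_maps)       # (b, 1, h, w, 2) coil-combined image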
939
+
940
+ def chans_to_batch_dim(x: torch.Tensor) -> Tuple[torch.Tensor, int]:
941
+ """Reshapes batched multi-channel samples into multiple single channel samples.
942
+
943
+ Parameters
944
+ ----------
945
+ x : torch.Tensor
946
+ x has shape (b, c, h, w, 2)
947
+
948
+ Returns
949
+ -------
950
+ Tuple[torch.Tensor, int]
951
+ tensor of shape (b * c, 1, h, w, 2), b
952
+ """
953
+ b, c, h, w, comp = x.shape
954
+ return x.view(b * c, 1, h, w, comp), b
955
+
956
+
957
+ def batch_chans_to_chan_dim(x: torch.Tensor, batch_size: int) -> torch.Tensor:
958
+ """Reshapes batched independent samples into original multi-channel samples.
959
+
960
+ Parameters
961
+ ----------
962
+ x : torch.Tensor
963
+ tensor of shape (b * c, 1, h, w, 2)
964
+ batch_size : int
965
+ batch size
966
+
967
+ Returns
968
+ -------
969
+ torch.Tensor
970
+ original multi-channel tensor of shape (b, c, h, w, 2)
971
+ """
972
+ bc, _, h, w, comp = x.shape
973
+ c = bc // batch_size
974
+ return x.view(batch_size, c, h, w, comp)
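The two reshaping helpers above are exact inverses, which lets a module that only accepts single-channel inputs be wrapped between them. A round-trip sketch:

import torch

x = torch.randn(2, 8, 64, 64, 2)             # (b, c, h, w, 2)
flat, b = chans_to_batch_dim(x)              # (b * c, 1, h, w, 2)
restored = batch_chans_to_chan_dim(flat, b)  # back to (b, c, h, w, 2)
assert torch.equal(x, restored)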
models/lightning/mri_module.py ADDED
@@ -0,0 +1,402 @@
1
+ """
2
+ Modified for use in <TODO: paper name>
3
+ - minified and removed extraneous abstractions
4
+ - updated to latest version of lightning
5
+
6
+ Copyright (c) Facebook, Inc. and its affiliates.
7
+
8
+ This source code is licensed under the MIT license found in the
9
+ LICENSE file in the root directory of this source tree.
10
+ """
11
+
12
+ from collections import defaultdict
13
+ from io import BytesIO
14
+ import pathlib
15
+ import os
16
+ from argparse import ArgumentParser
17
+ from collections import defaultdict
18
+
19
+ import numpy as np
20
+ import wandb
21
+ import lightning as L
22
+ import torch
23
+ from torchmetrics.metric import Metric
24
+ import matplotlib
25
+ import matplotlib.pyplot as plt
26
+ from PIL import Image
27
+
28
+ matplotlib.use("Agg")
29
+
30
+ from fastmri import evaluate
31
+
32
+
33
+ class DistributedMetricSum(Metric):
34
+ def __init__(self, dist_sync_on_step=True):
35
+ super().__init__(dist_sync_on_step=dist_sync_on_step)
36
+
37
+ self.add_state(
38
+ "quantity", default=torch.tensor(0.0), dist_reduce_fx="sum"
39
+ )
40
+
41
+ def update(self, batch: torch.Tensor): # type: ignore
42
+ self.quantity += batch
43
+
44
+ def compute(self):
45
+ return self.quantity
46
+
47
+
48
+ class MriModule(L.LightningModule):
49
+ """
50
+ Abstract super class for deep learning reconstruction models.
51
+
52
+ This is a subclass of the LightningModule class from lightning,
53
+ with some additional functionality specific to fastMRI:
54
+ - Evaluating reconstructions
55
+ - Visualization
56
+
57
+ To implement a new reconstruction model, inherit from this class and
58
+ implement the following methods:
59
+ - training_step: Define what happens in one step of training
60
+ - validation_step: Define what happens in one step of validation
61
+ - test_step: Define what happens in one step of testing
62
+ - configure_optimizers: Create and return the optimizers
63
+
64
+ Other methods from LightningModule can be overridden as needed.
65
+ """
66
+
67
+ def __init__(self, num_log_images: int = 16):
68
+ """
69
+ Initialize the MRI module.
70
+
71
+ Parameters
72
+ ----------
73
+ num_log_images : int, optional
74
+ Number of images to log. Defaults to 16.
75
+ """
76
+ super().__init__()
77
+
78
+ self.num_log_images = num_log_images
79
+ self.val_log_indices = [1, 2, 3, 4, 5]
80
+ self.val_batch_results = []
81
+
82
+ self.NMSE = DistributedMetricSum()
83
+ self.SSIM = DistributedMetricSum()
84
+ self.PSNR = DistributedMetricSum()
85
+ self.ValLoss = DistributedMetricSum()
86
+ self.TotExamples = DistributedMetricSum()
87
+ self.TotSliceExamples = DistributedMetricSum()
88
+
89
+ def log_image(self, name, image):
90
+ if self.logger is not None:
91
+ self.logger.log_image(
92
+ key=f"{name}", images=[image], caption=[str(self.global_step)]
93
+ )
94
+
95
+ def on_validation_batch_end(
96
+ self, outputs, batch, batch_idx, dataloader_idx=0
97
+ ):
98
+ # breakpoint()
99
+ val_logs = outputs
100
+
101
+ mse_vals = defaultdict(dict)
102
+ target_norms = defaultdict(dict)
103
+ ssim_vals = defaultdict(dict)
104
+ max_vals = dict()
105
+
106
+ for i, fname in enumerate(val_logs["fname"]):
107
+ if i == 0 and batch_idx in self.val_log_indices:
108
+ key = f"val_images_idx_{batch_idx}"
109
+ target = val_logs["target"][i].unsqueeze(0)
110
+ output = val_logs["output"][i].unsqueeze(0)
111
+ error = torch.abs(target - output)
112
+ output = output / output.max()
113
+ target = target / target.max()
114
+ error = error / error.max()
115
+ self.log_image(f"{key}/target", target)
116
+ self.log_image(f"{key}/reconstruction", output)
117
+ self.log_image(f"{key}/error", error)
118
+ slice_num = int(val_logs["slice_num"][i].cpu())
119
+
120
+ maxval = val_logs["max_value"][i].cpu().numpy()
121
+ output = val_logs["output"][i].cpu().numpy()
122
+ target = val_logs["target"][i].cpu().numpy()
123
+ mse_vals[fname][slice_num] = torch.tensor(
124
+ evaluate.mse(target, output)
125
+ ).view(1)
126
+ target_norms[fname][slice_num] = torch.tensor(
127
+ evaluate.mse(target, np.zeros_like(target))
128
+ ).view(1)
129
+ ssim_vals[fname][slice_num] = torch.tensor(
130
+ evaluate.ssim(
131
+ target[None, ...], output[None, ...], maxval=maxval
132
+ )
133
+ ).view(1)
134
+ max_vals[fname] = maxval
135
+
136
+ self.val_batch_results.append(
137
+ {
138
+ "slug": val_logs["slug"],
139
+ "val_loss": val_logs["val_loss"],
140
+ "mse_vals": dict(mse_vals),
141
+ "target_norms": dict(target_norms),
142
+ "ssim_vals": dict(ssim_vals),
143
+ "max_vals": max_vals,
144
+ }
145
+ )
146
+
147
+ def on_validation_epoch_end(self):
148
+ val_logs = self.val_batch_results
149
+
150
+ dataset_metrics = defaultdict(
151
+ lambda: {
152
+ "losses": [],
153
+ "mse_vals": defaultdict(dict),
154
+ "target_norms": defaultdict(dict),
155
+ "ssim_vals": defaultdict(dict),
156
+ "max_vals": dict(),
157
+ }
158
+ )
159
+
160
+ # use dict updates to handle duplicate slices
161
+ for val_log in val_logs:
162
+ slug = val_log["slug"]
163
+ dataset_metrics[slug]["losses"].append(val_log["val_loss"].view(-1))
164
+
165
+ for k in val_log["mse_vals"].keys():
166
+ dataset_metrics[slug]["mse_vals"][k].update(
167
+ val_log["mse_vals"][k]
168
+ )
169
+ for k in val_log["target_norms"].keys():
170
+ dataset_metrics[slug]["target_norms"][k].update(
171
+ val_log["target_norms"][k]
172
+ )
173
+ for k in val_log["ssim_vals"].keys():
174
+ dataset_metrics[slug]["ssim_vals"][k].update(
175
+ val_log["ssim_vals"][k]
176
+ )
177
+ for k in val_log["max_vals"]:
178
+ dataset_metrics[slug]["max_vals"][k] = val_log["max_vals"][k]
179
+
180
+ metrics_to_plot = {"psnr": [], "ssim": [], "nmse": []}
181
+ slugs = []
182
+
183
+ for slug, metrics_data in dataset_metrics.items():
184
+ mse_vals, target_norms, ssim_vals, max_vals, losses = (
185
+ metrics_data["mse_vals"],
186
+ metrics_data["target_norms"],
187
+ metrics_data["ssim_vals"],
188
+ metrics_data["max_vals"],
189
+ metrics_data["losses"],
190
+ )
191
+ # check to make sure we have all files in all metrics
192
+ assert (
193
+ mse_vals.keys()
194
+ == target_norms.keys()
195
+ == ssim_vals.keys()
196
+ == max_vals.keys()
197
+ )
198
+
199
+ # apply means across image volumes
200
+ metrics = {"nmse": 0, "ssim": 0, "psnr": 0}
201
+ metric_values = {
202
+ "nmse": [],
203
+ "ssim": [],
204
+ "psnr": [],
205
+ } # to store individual values for std
206
+ local_examples = 0
207
+
208
+ for fname in mse_vals.keys():
209
+ local_examples = local_examples + 1
210
+ mse_val = torch.mean(
211
+ torch.cat([v.view(-1) for _, v in mse_vals[fname].items()])
212
+ )
213
+ target_norm = torch.mean(
214
+ torch.cat(
215
+ [v.view(-1) for _, v in target_norms[fname].items()]
216
+ )
217
+ )
218
+ nmse = mse_val / target_norm
219
+ psnr = 20 * torch.log10(
220
+ torch.tensor(
221
+ max_vals[fname],
222
+ dtype=mse_val.dtype,
223
+ device=mse_val.device,
224
+ )
225
+ ) - 10 * torch.log10(mse_val)
226
+ ssim = torch.mean(
227
+ torch.cat([v.view(-1) for _, v in ssim_vals[fname].items()])
228
+ )
229
+
230
+ # Accumulate metric values
231
+ metrics["nmse"] += nmse
232
+ metrics["psnr"] += psnr
233
+ metrics["ssim"] += ssim
234
+
235
+ # Store individual metric values for std calculation
236
+ metric_values["nmse"].append(nmse)
237
+ metric_values["psnr"].append(psnr)
238
+ metric_values["ssim"].append(ssim)
239
+
240
+ # reduce across ddp via sum
241
+ metrics["nmse"] = self.NMSE(metrics["nmse"])
242
+ metrics["ssim"] = self.SSIM(metrics["ssim"])
243
+ metrics["psnr"] = self.PSNR(metrics["psnr"])
244
+
245
+ tot_examples = self.TotExamples(torch.tensor(local_examples))
246
+ val_loss = self.ValLoss(torch.sum(torch.cat(losses))) # type: ignore
247
+ tot_slice_examples = self.TotSliceExamples(
248
+ torch.tensor(len(losses), dtype=torch.float)
249
+ )
250
+
251
+ metrics_to_plot["nmse"].append(
252
+ (
253
+ (metrics["nmse"] / tot_examples).item(),
254
+ torch.std(torch.stack(metric_values["nmse"])).item(),
255
+ )
256
+ )
257
+ metrics_to_plot["psnr"].append(
258
+ (
259
+ (metrics["psnr"] / tot_examples).item(),
260
+ torch.std(torch.stack(metric_values["psnr"])).item(),
261
+ )
262
+ )
263
+ metrics_to_plot["ssim"].append(
264
+ (
265
+ (metrics["ssim"] / tot_examples).item(),
266
+ torch.std(torch.stack(metric_values["ssim"])).item(),
267
+ )
268
+ )
269
+ slugs.append(slug)
270
+
271
+ # Log the mean values
272
+ self.log(
273
+ f"{slug}--validation_loss",
274
+ val_loss / tot_slice_examples,
275
+ prog_bar=True,
276
+ )
277
+ for metric, value in metrics.items():
278
+ self.log(f"{slug}--val_metrics_{metric}", value / tot_examples)
279
+
280
+ # Calculate and log the standard deviation for each metric
281
+ for metric, values in metric_values.items():
282
+ std_value = torch.std(torch.stack(values))
283
+ self.log(f"{slug}--val_metrics_{metric}_std", std_value)
284
+
285
+ # generate graph
286
+ # breakpoint()
287
+ for metric_name, values in metrics_to_plot.items():
288
+ scores = [val[0] for val in values]
289
+ std_devs = [val[1] for val in values]
290
+
291
+ plt.figure(figsize=(10, 6))
292
+ plt.bar(slugs, scores, yerr=std_devs, capsize=5)
293
+ plt.xlabel("Dataset Slug")
294
+ plt.ylabel(f"{metric_name.upper()} Score")
295
+ plt.title(
296
+ f"{metric_name.upper()} per Dataset with Standard Deviation"
297
+ )
298
+ plt.xticks(rotation=45)
299
+ plt.tight_layout()
300
+
301
+ # Save the plot
302
+ buf = BytesIO()
303
+ plt.savefig(buf, format="png")
304
+ buf.seek(0)
305
+ image = Image.open(buf)
306
+ image_array = np.array(image)
307
+ self.log_image(f"summary_plot_{metric_name}", image_array)
308
+ buf.close()
309
+ plt.close()
310
+
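The per-volume reductions above come down to two scalar formulas, where mse and target_norm are a volume's mean squared error and mean squared target magnitude, and maxval is its peak value. A minimal restatement:

import math

def volume_metrics(mse, target_norm, maxval):
    nmse = mse / target_norm
    psnr = 20 * math.log10(maxval) - 10 * math.log10(mse)  # 10 * log10(maxval**2 / mse)
    return nmse, psnr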
311
+ def OLD_on_validation_epoch_end(self):
312
+ val_logs = self.val_batch_results
313
+
314
+ # aggregate losses
315
+ losses = []
316
+ mse_vals = defaultdict(dict)
317
+ target_norms = defaultdict(dict)
318
+ ssim_vals = defaultdict(dict)
319
+ max_vals = dict()
320
+
321
+ # use dict updates to handle duplicate slices
322
+ for val_log in val_logs:
323
+ losses.append(val_log["val_loss"].view(-1))
324
+
325
+ for k in val_log["mse_vals"].keys():
326
+ mse_vals[k].update(val_log["mse_vals"][k])
327
+ for k in val_log["target_norms"].keys():
328
+ target_norms[k].update(val_log["target_norms"][k])
329
+ for k in val_log["ssim_vals"].keys():
330
+ ssim_vals[k].update(val_log["ssim_vals"][k])
331
+ for k in val_log["max_vals"]:
332
+ max_vals[k] = val_log["max_vals"][k]
333
+
334
+ # check to make sure we have all files in all metrics
335
+ assert (
336
+ mse_vals.keys()
337
+ == target_norms.keys()
338
+ == ssim_vals.keys()
339
+ == max_vals.keys()
340
+ )
341
+
342
+ # apply means across image volumes
343
+ metrics = {"nmse": 0, "ssim": 0, "psnr": 0}
344
+ local_examples = 0
345
+ for fname in mse_vals.keys():
346
+ local_examples = local_examples + 1
347
+ mse_val = torch.mean(
348
+ torch.cat([v.view(-1) for _, v in mse_vals[fname].items()])
349
+ )
350
+ target_norm = torch.mean(
351
+ torch.cat([v.view(-1) for _, v in target_norms[fname].items()])
352
+ )
353
+ metrics["nmse"] = metrics["nmse"] + mse_val / target_norm
354
+ metrics["psnr"] = (
355
+ metrics["psnr"]
356
+ + 20
357
+ * torch.log10(
358
+ torch.tensor(
359
+ max_vals[fname],
360
+ dtype=mse_val.dtype,
361
+ device=mse_val.device,
362
+ )
363
+ )
364
+ - 10 * torch.log10(mse_val)
365
+ )
366
+ metrics["ssim"] = metrics["ssim"] + torch.mean(
367
+ torch.cat([v.view(-1) for _, v in ssim_vals[fname].items()])
368
+ )
369
+
370
+ # reduce across ddp via sum
371
+ metrics["nmse"] = self.NMSE(metrics["nmse"])
372
+ metrics["ssim"] = self.SSIM(metrics["ssim"])
373
+ metrics["psnr"] = self.PSNR(metrics["psnr"])
374
+
375
+ tot_examples = self.TotExamples(torch.tensor(local_examples))
376
+ val_loss = self.ValLoss(torch.sum(torch.cat(losses)))
377
+ tot_slice_examples = self.TotSliceExamples(
378
+ torch.tensor(len(losses), dtype=torch.float)
379
+ )
380
+
381
+ self.log(
382
+ "validation_loss", val_loss / tot_slice_examples, prog_bar=True
383
+ )
384
+ for metric, value in metrics.items():
385
+ self.log(f"val_metrics_{metric}", value / tot_examples)
386
+
387
+ @staticmethod
388
+ def add_model_specific_args(parent_parser): # pragma: no-cover
389
+ """
390
+ Define parameters that only apply to this model
391
+ """
392
+ parser = ArgumentParser(parents=[parent_parser], add_help=False)
393
+
394
+ # logging params
395
+ parser.add_argument(
396
+ "--num_log_images",
397
+ default=16,
398
+ type=int,
399
+ help="Number of images to log to Tensorboard",
400
+ )
401
+
402
+ return parser
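A condensed sketch of what a concrete subclass of MriModule has to provide; the modules that follow (NOSharedModule, NOVarnetModule, VarNetModule) are the real implementations, and the class name and loss below are illustrative only. The validation_step return keys are the ones on_validation_batch_end consumes.

import torch

class MyReconModule(MriModule):
    def __init__(self, model, lr=1e-3, **kwargs):
        super().__init__(**kwargs)
        self.model = model
        self.criterion = torch.nn.L1Loss()
        self.lr = lr

    def training_step(self, batch, batch_idx):
        output = self.model(batch.masked_kspace, batch.mask, batch.num_low_frequencies)
        loss = self.criterion(output, batch.target)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        output = self.model(batch.masked_kspace, batch.mask, batch.num_low_frequencies)
        return {
            "slug": list(self.trainer.val_dataloaders.keys())[dataloader_idx],
            "fname": batch.fname,
            "slice_num": batch.slice_num,
            "max_value": batch.max_value,
            "output": output,
            "target": batch.target,
            "val_loss": self.criterion(output, batch.target),
        }

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)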
models/lightning/no_shared_module.py ADDED
@@ -0,0 +1,274 @@
1
+ from argparse import ArgumentParser
2
+ from typing import Tuple
3
+
4
+ import torch
5
+
6
+ import fastmri
7
+ from fastmri import transforms
8
+ from models.no_shared import NOShared
9
+
10
+ from models.lightning.mri_module import MriModule
11
+ from type_utils import tuple_type
12
+
13
+
14
+ class NOSharedModule(MriModule):
15
+ """
16
+ NO-Shared training module.
17
+ """
18
+
19
+ def __init__(
20
+ self,
21
+ num_cascades: int = 12,
22
+ pools: int = 4,
23
+ chans: int = 18,
24
+ sens_pools: int = 4,
25
+ sens_chans: int = 8,
26
+ gno_pools: int = 4,
27
+ gno_chans: int = 16,
28
+ gno_radius_cutoff: float = 0.02,
29
+ gno_kernel_shape: Tuple[int, int] = (6, 7),
30
+ radius_cutoff: float = 0.02,
31
+ kernel_shape: Tuple[int, int] = (6, 7),
32
+ in_shape: Tuple[int, int] = (320, 320),
33
+ use_dc_term: bool = True,
34
+ lr: float = 0.0003,
35
+ lr_step_size: int = 40,
36
+ lr_gamma: float = 0.1,
37
+ weight_decay: float = 0.0,
38
+ **kwargs,
39
+ ):
40
+ """
41
+ Parameters
42
+ ----------
43
+ num_cascades : int
44
+ Number of cascades (i.e., layers) for the variational network.
45
+ pools : int
46
+ Number of downsampling and upsampling layers for the cascade U-Net.
47
+ chans : int
48
+ Number of channels for the cascade U-Net.
49
+ sens_pools : int
50
+ Number of downsampling and upsampling layers for the sensitivity map U-Net.
51
+ sens_chans : int
52
+ Number of channels for the sensitivity map U-Net.
53
+ lr : float
54
+ Learning rate.
55
+ lr_step_size : int
56
+ Learning rate step size.
57
+ lr_gamma : float
58
+ Learning rate gamma decay.
59
+ weight_decay : float
60
+ Parameter for penalizing weights norm.
61
+ """
62
+ super().__init__(**kwargs)
63
+ self.save_hyperparameters()
64
+
65
+ self.num_cascades = num_cascades
66
+ self.pools = pools
67
+ self.chans = chans
68
+ self.sens_pools = sens_pools
69
+ self.sens_chans = sens_chans
70
+ self.gno_pools = gno_pools
71
+ self.gno_chans = gno_chans
72
+ self.gno_radius_cutoff = gno_radius_cutoff
73
+ self.gno_kernel_shape = gno_kernel_shape
74
+ self.radius_cutoff = radius_cutoff
75
+ self.kernel_shape = kernel_shape
76
+ self.in_shape = in_shape
77
+ self.use_dc_term = use_dc_term
78
+ self.lr = lr
79
+ self.lr_step_size = lr_step_size
80
+ self.lr_gamma = lr_gamma
81
+ self.weight_decay = weight_decay
82
+
83
+ self.model = NOShared(
84
+ num_cascades=self.num_cascades,
85
+ sens_chans=self.sens_chans,
86
+ sens_pools=self.sens_pools,
87
+ chans=self.chans,
88
+ pools=self.pools,
89
+ gno_chans=self.gno_chans,
90
+ gno_pools=self.gno_pools,
91
+ gno_radius_cutoff=self.gno_radius_cutoff,
92
+ gno_kernel_shape=self.gno_kernel_shape,
93
+ radius_cutoff=radius_cutoff,
94
+ kernel_shape=kernel_shape,
95
+ in_shape=in_shape,
96
+ use_dc_term=use_dc_term,
97
+ )
98
+
99
+ self.criterion = fastmri.SSIMLoss()
100
+ self.num_params = sum(p.numel() for p in self.parameters())
101
+
102
+ def forward(self, masked_kspace, mask, num_low_frequencies):
103
+ return self.model(masked_kspace, mask, num_low_frequencies)
104
+
105
+ def training_step(self, batch, batch_idx):
106
+ output = self.forward(
107
+ batch.masked_kspace, batch.mask, batch.num_low_frequencies
108
+ )
109
+
110
+ target, output = transforms.center_crop_to_smallest(
111
+ batch.target, output
112
+ )
113
+ loss = self.criterion(
114
+ output.unsqueeze(1), target.unsqueeze(1), data_range=batch.max_value
115
+ )
116
+
117
+ self.log("train_loss", loss, on_step=True, on_epoch=True)
118
+ self.log("epoch", int(self.current_epoch), on_step=True, on_epoch=True)
119
+
120
+ return loss
121
+
122
+ def validation_step(self, batch, batch_idx, dataloader_idx=0):
123
+ dataloaders = self.trainer.val_dataloaders
124
+ slug = list(dataloaders.keys())[dataloader_idx]
125
+
126
+ output = self.forward(
127
+ batch.masked_kspace, batch.mask, batch.num_low_frequencies
128
+ )
129
+
130
+ target, output = transforms.center_crop_to_smallest(
131
+ batch.target, output
132
+ )
133
+
134
+ loss = self.criterion(
135
+ output.unsqueeze(1),
136
+ target.unsqueeze(1),
137
+ data_range=batch.max_value,
138
+ )
139
+
140
+ return {
141
+ "slug": slug,
142
+ "fname": batch.fname,
143
+ "slice_num": batch.slice_num,
144
+ "max_value": batch.max_value,
145
+ "output": output,
146
+ "target": target,
147
+ "val_loss": loss,
148
+ }
149
+
150
+ def configure_optimizers(self):
151
+ optim = torch.optim.Adam(
152
+ self.parameters(), lr=self.lr, weight_decay=self.weight_decay
153
+ )
154
+ scheduler = torch.optim.lr_scheduler.StepLR(
155
+ optim, self.lr_step_size, self.lr_gamma
156
+ )
157
+
158
+ return [optim], [scheduler]
159
+
160
+ @staticmethod
161
+ def add_model_specific_args(parent_parser):
162
+ """
163
+ Define parameters that only apply to this model
164
+ """
165
+ parser = ArgumentParser(parents=[parent_parser], add_help=False)
166
+ parser = MriModule.add_model_specific_args(parser)
167
+
168
+ # network params
169
+ parser.add_argument(
170
+ "--num_cascades",
171
+ default=12,
172
+ type=int,
173
+ help="Number of VarNet cascades",
174
+ )
175
+ parser.add_argument(
176
+ "--pools",
177
+ default=4,
178
+ type=int,
179
+ help="Number of U-Net pooling layers in VarNet blocks",
180
+ )
181
+ parser.add_argument(
182
+ "--chans",
183
+ default=18,
184
+ type=int,
185
+ help="Number of channels for U-Net in VarNet blocks",
186
+ )
187
+ parser.add_argument(
188
+ "--sens_pools",
189
+ default=4,
190
+ type=int,
191
+ help=(
192
+ "Number of pooling layers for sense map estimation U-Net in"
193
+ " VarNet"
194
+ ),
195
+ )
196
+ parser.add_argument(
197
+ "--sens_chans",
198
+ default=8,
199
+ type=int,
200
+ help="Number of channels for sense map estimation U-Net in VarNet",
201
+ )
202
+ parser.add_argument(
203
+ "--gno_pools",
204
+ default=4,
205
+ type=int,
206
+ help=("Number of pooling layers for GNO"),
207
+ )
208
+ parser.add_argument(
209
+ "--gno_chans",
210
+ default=16,
211
+ type=int,
212
+ help="Number of channels for GNO",
213
+ )
214
+ parser.add_argument(
215
+ "--gno_radius_cutoff",
216
+ default=0.02,
217
+ type=float,
218
+ help="GNO module radius_cutoff",
219
+ )
220
+ parser.add_argument(
221
+ "--gno_kernel_shape",
222
+ default=(6, 7),
223
+ type=tuple_type,
224
+ help="GNO module kernel_shape. Ex: (6, 7)",
225
+ )
226
+ parser.add_argument(
227
+ "--radius_cutoff",
228
+ default=0.02,
229
+ type=float,
230
+ help="DISCO module radius_cutoff",
231
+ )
232
+ parser.add_argument(
233
+ "--kernel_shape",
234
+ default=(6, 7),
235
+ type=tuple_type,
236
+ help="DISCO module kernel_shape. Ex: (6, 7)",
237
+ )
238
+ parser.add_argument(
239
+ "--in_shape",
240
+ default=(320, 320),
241
+ type=tuple_type,
242
+ help="Spatial dimensions of masked_kspace samples. Ex: (640, 320)",
243
+ )
244
+ parser.add_argument(
245
+ "--use_dc_term",
246
+ default=True,
247
+ type=bool,
248
+ help="Whether to use the DC term in the unrolled iterative update step",
249
+ )
250
+
251
+ # training params (opt)
252
+ parser.add_argument(
253
+ "--lr", default=0.0003, type=float, help="Adam learning rate"
254
+ )
255
+ parser.add_argument(
256
+ "--lr_step_size",
257
+ default=40,
258
+ type=int,
259
+ help="Epoch at which to decrease step size",
260
+ )
261
+ parser.add_argument(
262
+ "--lr_gamma",
263
+ default=0.1,
264
+ type=float,
265
+ help="Extent to which step size should be decreased",
266
+ )
267
+ parser.add_argument(
268
+ "--weight_decay",
269
+ default=0.0,
270
+ type=float,
271
+ help="Strength of weight decay regularization",
272
+ )
273
+
274
+ return parser
models/lightning/no_varnet_module.py ADDED
@@ -0,0 +1,299 @@
1
+ from argparse import ArgumentParser
2
+ from typing import Tuple
3
+
4
+ import torch
5
+
6
+ import fastmri
7
+ from fastmri import transforms
8
+ from models.no_varnet import NOVarnet
9
+
10
+ from models.lightning.mri_module import MriModule
11
+ # from type_utils import tuple_type
12
+
13
+ def tuple_type(strings):
14
+ strings = strings.replace("(", "").replace(")", "").replace(" ", "")
15
+ mapped_int = map(int, strings.split(","))
16
+ return tuple(mapped_int)
17
+
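tuple_type lets shape arguments be passed as strings on the command line; it strips parentheses and whitespace and splits on commas. For example:

assert tuple_type("(6, 7)") == (6, 7)
assert tuple_type("640,320") == (640, 320)
# CLI usage: --in_shape "(640, 320)"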
18
+
19
+ class NOVarnetModule(MriModule):
20
+ """
21
+ NO-Varnet training module.
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ num_cascades: int = 12,
27
+ pools: int = 4,
28
+ chans: int = 18,
29
+ sens_pools: int = 4,
30
+ sens_chans: int = 8,
31
+ gno_pools: int = 4,
32
+ gno_chans: int = 16,
33
+ gno_radius_cutoff: float = 0.02,
34
+ gno_kernel_shape: Tuple[int, int] = (6, 7),
35
+ radius_cutoff: float = 0.02,
36
+ kernel_shape: Tuple[int, int] = (6, 7),
37
+ in_shape: Tuple[int, int] = (320, 320),
38
+ use_dc_term: bool = True,
39
+ lr: float = 0.0003,
40
+ lr_step_size: int = 40,
41
+ lr_gamma: float = 0.1,
42
+ weight_decay: float = 0.0,
43
+ reduction_method: str = "rss",
44
+ skip_method: str = "add",
45
+ **kwargs,
46
+ ):
47
+ """
48
+ Parameters
49
+ ----------
50
+ num_cascades : int
51
+ Number of cascades (i.e., layers) for the variational network.
52
+ pools : int
53
+ Number of downsampling and upsampling layers for the cascade U-Net.
54
+ chans : int
55
+ Number of channels for the cascade U-Net.
56
+ sens_pools : int
57
+ Number of downsampling and upsampling layers for the sensitivity map U-Net.
58
+ sens_chans : int
59
+ Number of channels for the sensitivity map U-Net.
60
+ lr : float
61
+ Learning rate.
62
+ lr_step_size : int
63
+ Learning rate step size.
64
+ lr_gamma : float
65
+ Learning rate gamma decay.
66
+ weight_decay : float
67
+ Parameter for penalizing weights norm.
68
+ """
69
+ super().__init__(**kwargs)
70
+ self.save_hyperparameters()
71
+
72
+ self.num_cascades = num_cascades
73
+ self.pools = pools
74
+ self.chans = chans
75
+ self.sens_pools = sens_pools
76
+ self.sens_chans = sens_chans
77
+ self.gno_pools = gno_pools
78
+ self.gno_chans = gno_chans
79
+ self.gno_radius_cutoff = gno_radius_cutoff
80
+ self.gno_kernel_shape = gno_kernel_shape
81
+ self.radius_cutoff = radius_cutoff
82
+ self.kernel_shape = kernel_shape
83
+ self.in_shape = in_shape
84
+ self.use_dc_term = use_dc_term
85
+ self.lr = lr
86
+ self.lr_step_size = lr_step_size
87
+ self.lr_gamma = lr_gamma
88
+ self.weight_decay = weight_decay
89
+ self.reduction_method = reduction_method
90
+ self.skip_method = skip_method
91
+
92
+ self.model = NOVarnet(
93
+ num_cascades=self.num_cascades,
94
+ sens_chans=self.sens_chans,
95
+ sens_pools=self.sens_pools,
96
+ chans=self.chans,
97
+ pools=self.pools,
98
+ gno_chans=self.gno_chans,
99
+ gno_pools=self.gno_pools,
100
+ gno_radius_cutoff=self.gno_radius_cutoff,
101
+ gno_kernel_shape=self.gno_kernel_shape,
102
+ radius_cutoff=radius_cutoff,
103
+ kernel_shape=kernel_shape,
104
+ in_shape=in_shape,
105
+ use_dc_term=use_dc_term,
106
+ reduction_method=reduction_method,
107
+ skip_method=skip_method,
108
+ )
109
+
110
+ self.criterion = fastmri.SSIMLoss()
111
+ self.num_params = sum(p.numel() for p in self.parameters())
112
+
113
+ def forward(self, masked_kspace, mask, num_low_frequencies):
114
+ return self.model(masked_kspace, mask, num_low_frequencies)
115
+
116
+ def training_step(self, batch, batch_idx):
117
+ output = self.forward(
118
+ batch.masked_kspace, batch.mask, batch.num_low_frequencies
119
+ )
120
+
121
+ target, output = transforms.center_crop_to_smallest(batch.target, output)
122
+ loss = self.criterion(
123
+ output.unsqueeze(1), target.unsqueeze(1), data_range=batch.max_value
124
+ )
125
+
126
+ self.log("train_loss", loss, on_step=True, on_epoch=True)
127
+ self.log("epoch", int(self.current_epoch), on_step=True, on_epoch=True)
128
+
129
+ return loss
130
+
131
+ def validation_step(self, batch, batch_idx, dataloader_idx=0):
132
+ dataloaders = self.trainer.val_dataloaders
133
+ slug = list(dataloaders.keys())[dataloader_idx]
134
+
135
+ output = self.forward(
136
+ batch.masked_kspace, batch.mask, batch.num_low_frequencies
137
+ )
138
+
139
+ target, output = transforms.center_crop_to_smallest(batch.target, output)
140
+
141
+ loss = self.criterion(
142
+ output.unsqueeze(1),
143
+ target.unsqueeze(1),
144
+ data_range=batch.max_value,
145
+ )
146
+
147
+ return {
148
+ "slug": slug,
149
+ "fname": batch.fname,
150
+ "slice_num": batch.slice_num,
151
+ "max_value": batch.max_value,
152
+ "output": output,
153
+ "target": target,
154
+ "val_loss": loss,
155
+ }
156
+
157
+ def configure_optimizers(self):
158
+ optim = torch.optim.Adam(
159
+ self.parameters(), lr=self.lr, weight_decay=self.weight_decay
160
+ )
161
+ scheduler = torch.optim.lr_scheduler.StepLR(
162
+ optim, self.lr_step_size, self.lr_gamma
163
+ )
164
+
165
+ return [optim], [scheduler]
166
+
167
+ @staticmethod
168
+ def add_model_specific_args(parent_parser):
169
+ """
170
+ Define parameters that only apply to this model
171
+ """
172
+ parser = ArgumentParser(parents=[parent_parser], add_help=False)
173
+ parser = MriModule.add_model_specific_args(parser)
174
+
175
+ # network params
176
+ parser.add_argument(
177
+ "--num_cascades",
178
+ default=12,
179
+ type=int,
180
+ help="Number of VarNet cascades",
181
+ )
182
+ parser.add_argument(
183
+ "--pools",
184
+ default=4,
185
+ type=int,
186
+ help="Number of U-Net pooling layers in VarNet blocks",
187
+ )
188
+ parser.add_argument(
189
+ "--chans",
190
+ default=18,
191
+ type=int,
192
+ help="Number of channels for U-Net in VarNet blocks",
193
+ )
194
+ parser.add_argument(
195
+ "--sens_pools",
196
+ default=4,
197
+ type=int,
198
+ help=(
199
+ "Number of pooling layers for sense map estimation U-Net in" " VarNet"
200
+ ),
201
+ )
202
+ parser.add_argument(
203
+ "--sens_chans",
204
+ default=8,
205
+ type=int,
206
+ help="Number of channels for sense map estimation U-Net in VarNet",
207
+ )
208
+ parser.add_argument(
209
+ "--gno_pools",
210
+ default=4,
211
+ type=int,
212
+ help=("Number of pooling layers for GNO"),
213
+ )
214
+ parser.add_argument(
215
+ "--gno_chans",
216
+ default=16,
217
+ type=int,
218
+ help="Number of channels for GNO",
219
+ )
220
+ parser.add_argument(
221
+ "--gno_radius_cutoff",
222
+ default=0.02,
223
+ type=float,
224
+ required=True,
225
+ help="GNO module radius_cutoff",
226
+ )
227
+ parser.add_argument(
228
+ "--gno_kernel_shape",
229
+ default=(6, 7),
230
+ type=tuple_type,
231
+ required=True,
232
+ help="GNO module kernel_shape. Ex: (6, 7)",
233
+ )
234
+ parser.add_argument(
235
+ "--radius_cutoff",
236
+ default=0.01,
237
+ type=float,
238
+ required=True,
239
+ help="DISCO module radius_cutoff",
240
+ )
241
+ parser.add_argument(
242
+ "--kernel_shape",
243
+ default=(6, 7),
244
+ type=tuple_type,
245
+ required=True,
246
+ help="DISCO module kernel_shape. Ex: (6, 7)",
247
+ )
248
+ parser.add_argument(
249
+ "--in_shape",
250
+ default=(640, 320),
251
+ type=tuple_type,
252
+ required=True,
253
+ help="Spatial dimensions of masked_kspace samples. Ex: (640, 320)",
254
+ )
255
+ parser.add_argument(
256
+ "--use_dc_term",
257
+ default=True,
258
+ type=bool,
259
+ help="Whether to use the DC term in the unrolled iterative update step",
260
+ )
261
+
262
+ # training params (opt)
263
+ parser.add_argument(
264
+ "--lr", default=0.0003, type=float, help="Adam learning rate"
265
+ )
266
+ parser.add_argument(
267
+ "--lr_step_size",
268
+ default=40,
269
+ type=int,
270
+ help="Epoch at which to decrease step size",
271
+ )
272
+ parser.add_argument(
273
+ "--lr_gamma",
274
+ default=0.1,
275
+ type=float,
276
+ help="Extent to which step size should be decreased",
277
+ )
278
+ parser.add_argument(
279
+ "--weight_decay",
280
+ default=0.0,
281
+ type=float,
282
+ help="Strength of weight decay regularization",
283
+ )
284
+ parser.add_argument(
285
+ "--reduction_method",
286
+ default="rss",
287
+ type=str,
288
+ choices=["rss", "batch"],
289
+ help="Reduction method used to reduce multi-channel k-space data before inpainting module. Read documentation of GNO for more information.",
290
+ )
291
+ parser.add_argument(
292
+ "--skip_method",
293
+ default="add_inv",
294
+ type=str,
295
+ choices=["add_inv", "add", "concat", "replace"],
296
+ help="Method for skip connection around inpainting module.",
297
+ )
298
+
299
+ return parser
models/lightning/no_varnet_nokno_module.py ADDED
@@ -0,0 +1,294 @@
1
+ from argparse import ArgumentParser
2
+ from typing import Tuple
3
+
4
+ import torch
5
+
6
+ import fastmri
7
+ from fastmri import transforms
8
+
9
+ from models.lightning.mri_module import MriModule
10
+ from models.no_varnet_nokno import NOVarnet_no_KNO
11
+ from type_utils import tuple_type
12
+
13
+
14
+ class NOVarnet_no_KNOModule(MriModule):
15
+ """
16
+ NO-Varnet (without KNO) training module.
17
+ """
18
+
19
+ def __init__(
20
+ self,
21
+ num_cascades: int = 12,
22
+ pools: int = 4,
23
+ chans: int = 18,
24
+ sens_pools: int = 4,
25
+ sens_chans: int = 8,
26
+ gno_pools: int = 4,
27
+ gno_chans: int = 16,
28
+ gno_radius_cutoff: float = 0.02,
29
+ gno_kernel_shape: Tuple[int, int] = (6, 7),
30
+ radius_cutoff: float = 0.02,
31
+ kernel_shape: Tuple[int, int] = (6, 7),
32
+ in_shape: Tuple[int, int] = (320, 320),
33
+ use_dc_term: bool = True,
34
+ lr: float = 0.0003,
35
+ lr_step_size: int = 40,
36
+ lr_gamma: float = 0.1,
37
+ weight_decay: float = 0.0,
38
+ reduction_method: str = "rss",
39
+ skip_method: str = "add",
40
+ **kwargs,
41
+ ):
42
+ """
43
+ Parameters
44
+ ----------
45
+ num_cascades : int
46
+ Number of cascades (i.e., layers) for the variational network.
47
+ pools : int
48
+ Number of downsampling and upsampling layers for the cascade U-Net.
49
+ chans : int
50
+ Number of channels for the cascade U-Net.
51
+ sens_pools : int
52
+ Number of downsampling and upsampling layers for the sensitivity map U-Net.
53
+ sens_chans : int
54
+ Number of channels for the sensitivity map U-Net.
55
+ lr : float
56
+ Learning rate.
57
+ lr_step_size : int
58
+ Learning rate step size.
59
+ lr_gamma : float
60
+ Learning rate gamma decay.
61
+ weight_decay : float
62
+ Parameter for penalizing weights norm.
63
+ """
64
+ super().__init__(**kwargs)
65
+ self.save_hyperparameters()
66
+
67
+ self.num_cascades = num_cascades
68
+ self.pools = pools
69
+ self.chans = chans
70
+ self.sens_pools = sens_pools
71
+ self.sens_chans = sens_chans
72
+ self.gno_pools = gno_pools
73
+ self.gno_chans = gno_chans
74
+ self.gno_radius_cutoff = gno_radius_cutoff
75
+ self.gno_kernel_shape = gno_kernel_shape
76
+ self.radius_cutoff = radius_cutoff
77
+ self.kernel_shape = kernel_shape
78
+ self.in_shape = in_shape
79
+ self.use_dc_term = use_dc_term
80
+ self.lr = lr
81
+ self.lr_step_size = lr_step_size
82
+ self.lr_gamma = lr_gamma
83
+ self.weight_decay = weight_decay
84
+ self.reduction_method = reduction_method
85
+ self.skip_method = skip_method
86
+
87
+ self.model = NOVarnet_no_KNO(
88
+ num_cascades=self.num_cascades,
89
+ sens_chans=self.sens_chans,
90
+ sens_pools=self.sens_pools,
91
+ chans=self.chans,
92
+ pools=self.pools,
93
+ gno_chans=self.gno_chans,
94
+ gno_pools=self.gno_pools,
95
+ gno_radius_cutoff=self.gno_radius_cutoff,
96
+ gno_kernel_shape=self.gno_kernel_shape,
97
+ radius_cutoff=radius_cutoff,
98
+ kernel_shape=kernel_shape,
99
+ in_shape=in_shape,
100
+ use_dc_term=use_dc_term,
101
+ reduction_method=reduction_method,
102
+ skip_method=skip_method,
103
+ )
104
+
105
+ self.criterion = fastmri.SSIMLoss()
106
+ self.num_params = sum(p.numel() for p in self.parameters())
107
+
108
+ def forward(self, masked_kspace, mask, num_low_frequencies):
109
+ return self.model(masked_kspace, mask, num_low_frequencies)
110
+
111
+ def training_step(self, batch, batch_idx):
112
+ output = self.forward(
113
+ batch.masked_kspace, batch.mask, batch.num_low_frequencies
114
+ )
115
+
116
+ target, output = transforms.center_crop_to_smallest(batch.target, output)
117
+ loss = self.criterion(
118
+ output.unsqueeze(1), target.unsqueeze(1), data_range=batch.max_value
119
+ )
120
+
121
+ self.log("train_loss", loss, on_step=True, on_epoch=True)
122
+ self.log("epoch", int(self.current_epoch), on_step=True, on_epoch=True)
123
+
124
+ return loss
125
+
126
+ def validation_step(self, batch, batch_idx, dataloader_idx=0):
127
+ dataloaders = self.trainer.val_dataloaders
128
+ slug = list(dataloaders.keys())[dataloader_idx]
129
+
130
+ output = self.forward(
131
+ batch.masked_kspace, batch.mask, batch.num_low_frequencies
132
+ )
133
+
134
+ target, output = transforms.center_crop_to_smallest(batch.target, output)
135
+
136
+ loss = self.criterion(
137
+ output.unsqueeze(1),
138
+ target.unsqueeze(1),
139
+ data_range=batch.max_value,
140
+ )
141
+
142
+ return {
143
+ "slug": slug,
144
+ "fname": batch.fname,
145
+ "slice_num": batch.slice_num,
146
+ "max_value": batch.max_value,
147
+ "output": output,
148
+ "target": target,
149
+ "val_loss": loss,
150
+ }
151
+
152
+ def configure_optimizers(self):
153
+ optim = torch.optim.Adam(
154
+ self.parameters(), lr=self.lr, weight_decay=self.weight_decay
155
+ )
156
+ scheduler = torch.optim.lr_scheduler.StepLR(
157
+ optim, self.lr_step_size, self.lr_gamma
158
+ )
159
+
160
+ return [optim], [scheduler]
161
+
162
+ @staticmethod
163
+ def add_model_specific_args(parent_parser):
164
+ """
165
+ Define parameters that only apply to this model
166
+ """
167
+ parser = ArgumentParser(parents=[parent_parser], add_help=False)
168
+ parser = MriModule.add_model_specific_args(parser)
169
+
170
+ # network params
171
+ parser.add_argument(
172
+ "--num_cascades",
173
+ default=12,
174
+ type=int,
175
+ help="Number of VarNet cascades",
176
+ )
177
+ parser.add_argument(
178
+ "--pools",
179
+ default=4,
180
+ type=int,
181
+ help="Number of U-Net pooling layers in VarNet blocks",
182
+ )
183
+ parser.add_argument(
184
+ "--chans",
185
+ default=18,
186
+ type=int,
187
+ help="Number of channels for U-Net in VarNet blocks",
188
+ )
189
+ parser.add_argument(
190
+ "--sens_pools",
191
+ default=4,
192
+ type=int,
193
+ help=(
194
+ "Number of pooling layers for sense map estimation U-Net in" " VarNet"
195
+ ),
196
+ )
197
+ parser.add_argument(
198
+ "--sens_chans",
199
+ default=8,
200
+ type=float,
201
+ help="Number of channels for sense map estimation U-Net in VarNet",
202
+ )
203
+ parser.add_argument(
204
+ "--gno_pools",
205
+ default=4,
206
+ type=int,
207
+ help=("Number of pooling layers for GNO"),
208
+ )
209
+ parser.add_argument(
210
+ "--gno_chans",
211
+ default=16,
212
+ type=int,
213
+ help="Number of channels for GNO",
214
+ )
215
+ parser.add_argument(
216
+ "--gno_radius_cutoff",
217
+ default=0.02,
218
+ type=float,
219
+ required=True,
220
+ help="GNO module radius_cutoff",
221
+ )
222
+ parser.add_argument(
223
+ "--gno_kernel_shape",
224
+ default=(6, 7),
225
+ type=tuple_type,
226
+ required=True,
227
+ help="GNO module kernel_shape. Ex: (6, 7)",
228
+ )
229
+ parser.add_argument(
230
+ "--radius_cutoff",
231
+ default=0.01,
232
+ type=float,
233
+ required=True,
234
+ help="DISCO module radius_cutoff",
235
+ )
236
+ parser.add_argument(
237
+ "--kernel_shape",
238
+ default=(6, 7),
239
+ type=tuple_type,
240
+ required=True,
241
+ help="DISCO module kernel_shape. Ex: (6, 7)",
242
+ )
243
+ parser.add_argument(
244
+ "--in_shape",
245
+ default=(640, 320),
246
+ type=tuple_type,
247
+ required=True,
248
+ help="Spatial dimensions of masked_kspace samples. Ex: (640, 320)",
249
+ )
250
+ parser.add_argument(
251
+ "--use_dc_term",
252
+ default=True,
253
+ type=bool,
254
+ help="Whether to use the DC term in the unrolled iterative update step",
255
+ )
256
+
257
+ # training params (opt)
258
+ parser.add_argument(
259
+ "--lr", default=0.0003, type=float, help="Adam learning rate"
260
+ )
261
+ parser.add_argument(
262
+ "--lr_step_size",
263
+ default=40,
264
+ type=int,
265
+ help="Epoch at which to decrease step size",
266
+ )
267
+ parser.add_argument(
268
+ "--lr_gamma",
269
+ default=0.1,
270
+ type=float,
271
+ help="Extent to which step size should be decreased",
272
+ )
273
+ parser.add_argument(
274
+ "--weight_decay",
275
+ default=0.0,
276
+ type=float,
277
+ help="Strength of weight decay regularization",
278
+ )
279
+ parser.add_argument(
280
+ "--reduction_method",
281
+ default="rss",
282
+ type=str,
283
+ choices=["rss", "batch"],
284
+ help="Reduction method used to reduce multi-channel k-space data before inpainting module. Read documentation of GNO for more information.",
285
+ )
286
+ parser.add_argument(
287
+ "--skip_method",
288
+ default="add_inv",
289
+ type=str,
290
+ choices=["add_inv", "add", "concat", "replace"],
291
+ help="Method for skip connection around inpainting module.",
292
+ )
293
+
294
+ return parser
models/lightning/varnet_module.py ADDED
@@ -0,0 +1,224 @@
1
+ """
2
+ Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ from argparse import ArgumentParser
9
+
10
+ import torch
11
+
12
+ import fastmri
13
+ from fastmri import transforms
14
+ from ..varnet import VarNet
15
+ import wandb
16
+
17
+ from .mri_module import MriModule
18
+
19
+
20
+ class VarNetModule(MriModule):
21
+ """
22
+ VarNet training module.
23
+
24
+ This can be used to train variational networks from the paper:
25
+
26
+ A. Sriram et al. End-to-end variational networks for accelerated MRI
27
+ reconstruction. In International Conference on Medical Image Computing and
28
+ Computer-Assisted Intervention, 2020.
29
+
30
+ which was inspired by the earlier paper:
31
+
32
+ K. Hammernik et al. Learning a variational network for reconstruction of
33
+ accelerated MRI data. Magnetic Resonance in Medicine, 79(6):3055–3071, 2018.
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ num_cascades: int = 12,
39
+ pools: int = 4,
40
+ chans: int = 18,
41
+ sens_pools: int = 4,
42
+ sens_chans: int = 8,
43
+ lr: float = 0.0003,
44
+ lr_step_size: int = 40,
45
+ lr_gamma: float = 0.1,
46
+ weight_decay: float = 0.0,
47
+ **kwargs,
48
+ ):
49
+ """
50
+ Parameters
51
+ ----------
52
+ num_cascades : int
53
+ Number of cascades (i.e., layers) for the variational network.
54
+ pools : int
55
+ Number of downsampling and upsampling layers for the cascade U-Net.
56
+ chans : int
57
+ Number of channels for the cascade U-Net.
58
+ sens_pools : int
59
+ Number of downsampling and upsampling layers for the sensitivity map U-Net.
60
+ sens_chans : int
61
+ Number of channels for the sensitivity map U-Net.
62
+ lr : float
63
+ Learning rate.
64
+ lr_step_size : int
65
+ Learning rate step size.
66
+ lr_gamma : float
67
+ Learning rate gamma decay.
68
+ weight_decay : float
69
+ Parameter for penalizing weights norm.
70
+ num_sense_lines : int, optional
71
+ Number of low-frequency lines to use for sensitivity map computation.
72
+ Must be even or `None`. Default `None` will automatically compute the number
73
+ from masks. Default behavior may cause some slices to use more low-frequency
74
+ lines than others, when used in conjunction with e.g. the EquispacedMaskFunc
75
+ defaults. To prevent this, either set `num_sense_lines`, or set
76
+ `skip_low_freqs` and `skip_around_low_freqs` to `True` in the EquispacedMaskFunc.
77
+ Note that setting this value may lead to undesired behavior when training on
78
+ multiple accelerations simultaneously.
79
+ """
80
+ super().__init__(**kwargs)
81
+ self.save_hyperparameters()
82
+
83
+ self.num_cascades = num_cascades
84
+ self.pools = pools
85
+ self.chans = chans
86
+ self.sens_pools = sens_pools
87
+ self.sens_chans = sens_chans
88
+ self.lr = lr
89
+ self.lr_step_size = lr_step_size
90
+ self.lr_gamma = lr_gamma
91
+ self.weight_decay = weight_decay
92
+
93
+ self.varnet = VarNet(
94
+ num_cascades=self.num_cascades,
95
+ sens_chans=self.sens_chans,
96
+ sens_pools=self.sens_pools,
97
+ chans=self.chans,
98
+ pools=self.pools,
99
+ )
100
+
101
+ self.criterion = fastmri.SSIMLoss()
102
+ self.num_params = sum(p.numel() for p in self.parameters())
103
+
104
+ def forward(self, masked_kspace, mask, num_low_frequencies):
105
+ return self.varnet(masked_kspace, mask, num_low_frequencies)
106
+
107
+ def training_step(self, batch, batch_idx):
108
+ output = self.forward(
109
+ batch.masked_kspace, batch.mask, batch.num_low_frequencies
110
+ )
111
+
112
+ target, output = transforms.center_crop_to_smallest(batch.target, output)
113
+ loss = self.criterion(
114
+ output.unsqueeze(1), target.unsqueeze(1), data_range=batch.max_value
115
+ )
116
+
117
+ self.log("train_loss", loss, on_step=True, on_epoch=True)
118
+ self.log("epoch", int(self.current_epoch), on_step=True, on_epoch=True)
119
+
120
+ return loss
121
+
122
+ def validation_step(self, batch, batch_idx, dataloader_idx=0):
123
+ dataloaders = self.trainer.val_dataloaders
124
+ slug = list(dataloaders.keys())[dataloader_idx]
125
+
126
+ # breakpoint()
127
+ output = self.forward(
128
+ batch.masked_kspace, batch.mask, batch.num_low_frequencies
129
+ )
130
+
131
+ target, output = transforms.center_crop_to_smallest(batch.target, output)
132
+
133
+ loss = self.criterion(
134
+ output.unsqueeze(1),
135
+ target.unsqueeze(1),
136
+ data_range=batch.max_value,
137
+ )
138
+
139
+ return {
140
+ "slug": slug,
141
+ "fname": batch.fname,
142
+ "slice_num": batch.slice_num,
143
+ "max_value": batch.max_value,
144
+ "output": output,
145
+ "target": target,
146
+ "val_loss": loss,
147
+ }
148
+
149
+ def configure_optimizers(self):
150
+ optim = torch.optim.Adam(
151
+ self.parameters(), lr=self.lr, weight_decay=self.weight_decay
152
+ )
153
+ scheduler = torch.optim.lr_scheduler.StepLR(
154
+ optim, self.lr_step_size, self.lr_gamma
155
+ )
156
+
157
+ return [optim], [scheduler]
158
+
159
+ @staticmethod
160
+ def add_model_specific_args(parent_parser): # pragma: no-cover
161
+ """
162
+ Define parameters that only apply to this model
163
+ """
164
+ parser = ArgumentParser(parents=[parent_parser], add_help=False)
165
+ parser = MriModule.add_model_specific_args(parser)
166
+
167
+ # network params
168
+ parser.add_argument(
169
+ "--num_cascades",
170
+ default=12,
171
+ type=int,
172
+ help="Number of VarNet cascades",
173
+ )
174
+ parser.add_argument(
175
+ "--pools",
176
+ default=4,
177
+ type=int,
178
+ help="Number of U-Net pooling layers in VarNet blocks",
179
+ )
180
+ parser.add_argument(
181
+ "--chans",
182
+ default=18,
183
+ type=int,
184
+ help="Number of channels for U-Net in VarNet blocks",
185
+ )
186
+ parser.add_argument(
187
+ "--sens_pools",
188
+ default=4,
189
+ type=int,
190
+ help=(
191
+ "Number of pooling layers for sense map estimation U-Net in" " VarNet"
192
+ ),
193
+ )
194
+ parser.add_argument(
195
+ "--sens_chans",
196
+ default=8,
197
+ type=int,
198
+ help="Number of channels for sense map estimation U-Net in VarNet",
199
+ )
200
+
201
+ # training params (opt)
202
+ parser.add_argument(
203
+ "--lr", default=0.0003, type=float, help="Adam learning rate"
204
+ )
205
+ parser.add_argument(
206
+ "--lr_step_size",
207
+ default=40,
208
+ type=int,
209
+ help="Epoch at which to decrease step size",
210
+ )
211
+ parser.add_argument(
212
+ "--lr_gamma",
213
+ default=0.1,
214
+ type=float,
215
+ help="Extent to which step size should be decreased",
216
+ )
217
+ parser.add_argument(
218
+ "--weight_decay",
219
+ default=0.0,
220
+ type=float,
221
+ help="Strength of weight decay regularization",
222
+ )
223
+
224
+ return parser
models/no_shared.py ADDED
@@ -0,0 +1,467 @@
1
+ import math
2
+ from typing import List, Literal, Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ import fastmri
9
+ from fastmri.transforms import (
10
+ batched_mask_center,
11
+ batch_chans_to_chan_dim,
12
+ chans_to_batch_dim,
13
+ sens_reduce,
14
+ sens_expand,
15
+ )
16
+ from models.udno import UDNO
17
+
18
+
19
+ class NormUDNO(nn.Module):
20
+ """
21
+ Normalized UDNO model.
22
+
23
+ Inputs are normalized before the UDNO for numerically stable training.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ chans: int,
29
+ num_pool_layers: int,
30
+ radius_cutoff: float,
31
+ in_shape: Tuple[int, int],
32
+ kernel_shape: Tuple[int, int],
33
+ in_chans: int = 2,
34
+ out_chans: int = 2,
35
+ drop_prob: float = 0.0,
36
+ ):
37
+ """
38
+ Initialize the NormUDNO module.
39
+
40
+ Parameters
41
+ ----------
42
+ chans : int
43
+ Number of output channels of the first convolution layer.
44
+ num_pool_layers : int
45
+ Number of down-sampling and up-sampling layers.
46
+ in_chans : int, optional
47
+ Number of channels in the input to the U-Net model. Default is 2.
48
+ out_chans : int, optional
49
+ Number of channels in the output to the U-Net model. Default is 2.
50
+ drop_prob : float, optional
51
+ Dropout probability. Default is 0.0.
52
+ """
53
+ super().__init__()
54
+
55
+ self.udno = UDNO(
56
+ in_chans=in_chans,
57
+ out_chans=out_chans,
58
+ radius_cutoff=radius_cutoff,
59
+ chans=chans,
60
+ num_pool_layers=num_pool_layers,
61
+ drop_prob=drop_prob,
62
+ in_shape=in_shape,
63
+ kernel_shape=kernel_shape,
64
+ )
65
+
66
+ def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor:
67
+ b, c, h, w, two = x.shape
68
+ assert two == 2
69
+ return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w)
70
+
71
+ def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor:
72
+ b, c2, h, w = x.shape
73
+ assert c2 % 2 == 0
74
+ c = c2 // 2
75
+ return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous()
76
+
77
+ def norm(
78
+ self, x: torch.Tensor
79
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
80
+ # group norm
81
+ b, c, h, w = x.shape
82
+ x = x.view(b, 2, c // 2 * h * w)
83
+
84
+ mean = x.mean(dim=2).view(b, 2, 1, 1)
85
+ std = x.std(dim=2).view(b, 2, 1, 1)
86
+
87
+ x = x.view(b, c, h, w)
88
+
89
+ return (x - mean) / std, mean, std
90
+
91
+ def norm_new(
92
+ self, x: torch.Tensor
93
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
94
+ # group norm
95
+ b, c, h, w = x.shape
96
+ num_groups = 2
97
+ assert (
98
+ c % num_groups == 0
99
+ ), f"Number of channels ({c}) must be divisible by number of groups ({num_groups})."
100
+
101
+ x = x.view(b, num_groups, c // num_groups * h * w)
102
+
103
+ mean = x.mean(dim=2).view(b, num_groups, 1, 1)
104
+ std = x.std(dim=2).view(b, num_groups, 1, 1)
105
+ print(x.shape, mean.shape, std.shape)
106
+
107
+ x = x.view(b, c, h, w)
108
+ mean = (
109
+ mean.view(b, num_groups, 1, 1)
110
+ .repeat(1, c // num_groups, h, w)
111
+ .view(b, c, h, w)
112
+ )
113
+ std = (
114
+ std.view(b, num_groups, 1, 1)
115
+ .repeat(1, c // num_groups, h, w)
116
+ .view(b, c, h, w)
117
+ )
118
+
119
+ return (x - mean) / std, mean, std
120
+
121
+ def unnorm(
122
+ self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
123
+ ) -> torch.Tensor:
124
+ return x * std + mean
125
+
126
+ def pad(
127
+ self, x: torch.Tensor
128
+ ) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]:
129
+ _, _, h, w = x.shape
130
+ w_mult = ((w - 1) | 15) + 1
131
+ h_mult = ((h - 1) | 15) + 1
132
+ w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)]
133
+ h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)]
134
+ # TODO: fix this type when PyTorch fixes theirs
135
+ # the documentation lies - this actually takes a list
136
+ # https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L3457
137
+ # https://github.com/pytorch/pytorch/pull/16949
138
+ x = F.pad(x, w_pad + h_pad)
139
+
140
+ return x, (h_pad, w_pad, h_mult, w_mult)
141
+
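+ # Note (illustration only): `pad`/`unpad` round H and W up to the next multiple of 16 with the
+ # bit trick ((n - 1) | 15) + 1 and split the padding symmetrically. The same arithmetic, standalone:
+ #
+ #     import math
+ #     import torch
+ #     import torch.nn.functional as F
+ #
+ #     def next_multiple_of_16(n: int) -> int:
+ #         return ((n - 1) | 15) + 1          # (n - 1) | 15 fills the low 4 bits, +1 rounds up
+ #
+ #     x = torch.zeros(1, 2, 321, 200)
+ #     h_mult, w_mult = next_multiple_of_16(321), next_multiple_of_16(200)       # 336, 208
+ #     h_pad = [math.floor((h_mult - 321) / 2), math.ceil((h_mult - 321) / 2)]   # [7, 8]
+ #     w_pad = [math.floor((w_mult - 200) / 2), math.ceil((w_mult - 200) / 2)]   # [4, 4]
+ #     padded = F.pad(x, w_pad + h_pad)   # F.pad orders dims last-to-first: (W_l, W_r, H_t, H_b)
+ #     assert padded.shape[-2:] == (336, 208)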
142
+ def unpad(
143
+ self,
144
+ x: torch.Tensor,
145
+ h_pad: List[int],
146
+ w_pad: List[int],
147
+ h_mult: int,
148
+ w_mult: int,
149
+ ) -> torch.Tensor:
150
+ return x[
151
+ ..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]
152
+ ]
153
+
154
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
155
+ if not x.shape[-1] == 2:
156
+ raise ValueError("Last dimension must be 2 for complex.")
157
+
158
+ chans = x.shape[1]
159
+ if chans == 2:
160
+ # FIXME: hard coded skip norm/pad temporarily to avoid group norm bug
161
+ x = self.complex_to_chan_dim(x)
162
+ x = self.udno(x)
163
+ return self.chan_complex_to_last_dim(x)
164
+
165
+ # get shapes for unet and normalize
166
+ x = self.complex_to_chan_dim(x)
167
+ x, mean, std = self.norm(x)
168
+ x, pad_sizes = self.pad(x)
169
+
170
+ x = self.udno(x)
171
+
172
+ # get shapes back and unnormalize
173
+ x = self.unpad(x, *pad_sizes)
174
+ x = self.unnorm(x, mean, std)
175
+ x = self.chan_complex_to_last_dim(x)
176
+
177
+ return x
178
+
179
+
180
+ class SensitivityModel(nn.Module):
181
+ """
182
+ Learn sensitivity maps
183
+ """
184
+
185
+ def __init__(
186
+ self,
187
+ chans: int,
188
+ num_pools: int,
189
+ radius_cutoff: float,
190
+ in_shape: Tuple[int, int],
191
+ kernel_shape: Tuple[int, int],
192
+ in_chans: int = 2,
193
+ out_chans: int = 2,
194
+ drop_prob: float = 0.0,
195
+ mask_center: bool = True,
196
+ ):
197
+ """
198
+ Parameters
199
+ ----------
200
+ chans : int
201
+ Number of output channels of the first convolution layer.
202
+ num_pools : int
203
+ Number of down-sampling and up-sampling layers.
204
+ in_chans : int, optional
205
+ Number of channels in the input to the U-Net model. Default is 2.
206
+ out_chans : int, optional
207
+ Number of channels in the output to the U-Net model. Default is 2.
208
+ drop_prob : float, optional
209
+ Dropout probability. Default is 0.0.
210
+ mask_center : bool, optional
211
+ Whether to mask center of k-space for sensitivity map calculation.
212
+ Default is True.
213
+ """
214
+ super().__init__()
215
+ self.mask_center = mask_center
216
+ self.norm_udno = NormUDNO(
217
+ chans,
218
+ num_pools,
219
+ radius_cutoff,
220
+ in_shape,
221
+ kernel_shape,
222
+ in_chans=in_chans,
223
+ out_chans=out_chans,
224
+ drop_prob=drop_prob,
225
+ )
226
+
227
+ def divide_root_sum_of_squares(self, x: torch.Tensor) -> torch.Tensor:
228
+ return x / fastmri.rss_complex(x, dim=1).unsqueeze(-1).unsqueeze(1)
229
+
230
+ def get_pad_and_num_low_freqs(
231
+ self, mask: torch.Tensor, num_low_frequencies=None
232
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
233
+ if num_low_frequencies is None or any(
234
+ torch.any(t == 0) for t in num_low_frequencies
235
+ ):
236
+ # get low frequency line locations and mask them out
237
+ squeezed_mask = mask[:, 0, 0, :, 0].to(torch.int8)
238
+ cent = squeezed_mask.shape[1] // 2
239
+ # argmin finds the first zero, i.e. the number of sampled lines on that side of the centre
240
+ left = torch.argmin(squeezed_mask[:, :cent].flip(1), dim=1)
241
+ right = torch.argmin(squeezed_mask[:, cent:], dim=1)
242
+ num_low_frequencies_tensor = torch.max(
243
+ 2 * torch.min(left, right), torch.ones_like(left)
244
+ ) # force a symmetric center unless 1
245
+ else:
246
+ num_low_frequencies_tensor = num_low_frequencies * torch.ones(
247
+ mask.shape[0], dtype=mask.dtype, device=mask.device
248
+ )
249
+
250
+ pad = (mask.shape[-2] - num_low_frequencies_tensor + 1) // 2
251
+
252
+ return pad.type(torch.long), num_low_frequencies_tensor.type(torch.long)
253
+
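+ # Toy example (illustration only) of the centre-scan logic above: starting at the centre column,
+ # argmin on the flipped left half and on the right half finds the first unsampled line on each
+ # side, and twice the smaller count is taken as the symmetric fully sampled width:
+ #
+ #     import torch
+ #
+ #     # 12 k-space columns with a 4-column fully sampled centre (indices 4-7)
+ #     mask_1d = torch.tensor([[1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0]], dtype=torch.int8)
+ #     cent = mask_1d.shape[1] // 2                            # 6
+ #     left = torch.argmin(mask_1d[:, :cent].flip(1), dim=1)   # 2 sampled lines left of centre
+ #     right = torch.argmin(mask_1d[:, cent:], dim=1)          # 2 sampled lines right of centre
+ #     num_low = torch.max(2 * torch.min(left, right), torch.ones_like(left))
+ #     print(num_low)                                          # tensor([4])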
254
+ def forward(
255
+ self,
256
+ masked_kspace: torch.Tensor,
257
+ mask: torch.Tensor,
258
+ num_low_frequencies: Optional[int] = None,
259
+ ) -> torch.Tensor:
260
+ if self.mask_center:
261
+ pad, num_low_freqs = self.get_pad_and_num_low_freqs(
262
+ mask, num_low_frequencies
263
+ )
264
+ masked_kspace = batched_mask_center(
265
+ masked_kspace, pad, pad + num_low_freqs
266
+ )
267
+
268
+ # convert to image space
269
+ images, batches = chans_to_batch_dim(fastmri.ifft2c(masked_kspace))
270
+
271
+ # estimate sensitivities
272
+ return self.divide_root_sum_of_squares(
273
+ batch_chans_to_chan_dim(self.norm_udno(images), batches)
274
+ )
275
+
276
+
277
+ class VarNetBlock(nn.Module):
278
+ """
279
+ Model block for iterative refinement of k-space data.
280
+
281
+ This model applies a combination of soft data consistency with the input
282
+ model as a regularizer. A series of these blocks can be stacked to form
283
+ the full variational network.
284
+
285
+ aka Refinement Module in Fig 1
286
+ """
287
+
288
+ def __init__(self, kno: nn.Module, ino: nn.Module):
289
+ """
290
+ Args:
291
+ model: Module for "regularization" component of variational
292
+ network.
293
+ """
294
+ super().__init__()
295
+ self.kno = kno
296
+ self.ino = ino
297
+ self.dc_weight = nn.Parameter(torch.ones(1))
298
+
299
+ def forward(
300
+ self,
301
+ current_kspace: torch.Tensor,
302
+ ref_kspace: torch.Tensor,
303
+ mask: torch.Tensor,
304
+ sens_maps: torch.Tensor,
305
+ use_dc_term: bool = True,
306
+ ) -> torch.Tensor:
307
+ """
308
+ Args:
309
+ current_kspace: The current k-space data (frequency domain data)
310
+ being processed by the network. (torch.Tensor)
311
+ ref_kspace: Original subsampled k-space data (from which we are
312
+ reconstructing the image), i.e. the reference k-space. (torch.Tensor)
313
+ mask: A binary mask indicating the locations in k-space where
314
+ data consistency should be enforced. (torch.Tensor)
315
+ sens_maps: Sensitivity maps for the different coils in parallel
316
+ imaging. (torch.Tensor)
317
+ """
318
+
319
+ # model-term see orange box of Fig 1 in E2E-VarNet paper!
320
+ # multi channel k-space -> single channel image-space
321
+ b, c, h, w, _ = current_kspace.shape
322
+
323
+ # ======= kNO in measurement (k) space ========
324
+ current_kspace, b = chans_to_batch_dim(current_kspace) # reduce
325
+ current_kspace = self.kno(current_kspace) # inpaint
326
+ current_kspace = batch_chans_to_chan_dim(current_kspace, b) # expand
327
+
328
+ # ======= iNO in image (i) space ========
329
+ reduced_image = sens_reduce(current_kspace, sens_maps)
330
+ # single channel image-space
331
+ refined_image = self.ino(reduced_image)
332
+ # single channel image-space -> multi channel k-space
333
+ model_term = sens_expand(refined_image, sens_maps)
334
+
335
+ # only use first 15 channels (masked_kspace) in the update
336
+ # current_kspace = current_kspace[:, :15, :, :, :]
337
+
338
+ if not use_dc_term:
339
+ return current_kspace - model_term
340
+
341
+ """
342
+ Soft data consistency term:
343
+ - Calculates the difference between current k-space and reference k-space where the mask is true.
344
+ - Multiplies this difference by the data consistency weight.
345
+ """
346
+ # dc_term: see green box of Fig 1 in E2E-VarNet paper!
347
+ zero = torch.zeros(1, 1, 1, 1, 1).to(current_kspace)
348
+ soft_dc = (
349
+ torch.where(mask, current_kspace - ref_kspace, zero)
350
+ * self.dc_weight
351
+ )
352
+ return current_kspace - soft_dc - model_term
353
+
354
+
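+ In symbols (restating the code above, nothing new): writing \(\tilde{k}_t\) for the coil-wise kNO output of the current k-space, each cascade computes
+
+ \tilde{k}_t = \mathrm{kNO}(k_t), \qquad k_{t+1} = \tilde{k}_t - \lambda\, M \odot \big(\tilde{k}_t - k_{\mathrm{ref}}\big) - \mathcal{E}\!\big(\mathrm{iNO}(\mathcal{R}(\tilde{k}_t))\big)
+
+ where \(\mathcal{R}\)/\(\mathcal{E}\) are `sens_reduce`/`sens_expand`, \(M\) is the sampling mask and \(\lambda\) is the learned `dc_weight`; the middle (data-consistency) term is dropped when `use_dc_term` is False.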
355
+ class NOShared(nn.Module):
356
+ """
357
+ Neural Operator model with shared cascade parameters for MRI reconstruction.
358
+
359
+ Uses a variational architecture (iterative updates) with a learned sensitivity
360
+ model. All operations are resolution invariant employing neural operator
361
+ modules.
362
+ """
363
+
364
+ def __init__(
365
+ self,
366
+ num_cascades: int = 12,
367
+ sens_chans: int = 8,
368
+ sens_pools: int = 4,
369
+ chans: int = 18,
370
+ pools: int = 4,
371
+ gno_chans: int = 16,
372
+ gno_pools: int = 4,
373
+ gno_radius_cutoff: float = 0.02,
374
+ gno_kernel_shape: Tuple[int, int] = (6, 7),
375
+ radius_cutoff: float = 0.01,
376
+ kernel_shape: Tuple[int, int] = (3, 4),
377
+ in_shape: Tuple[int, int] = (320, 320),
378
+ mask_center: bool = True,
379
+ use_dc_term: bool = True,
380
+ ):
381
+ """
382
+ Parameters
383
+ ----------
384
+ num_cascades : int
385
+ Number of cascades (i.e., layers) for variational network.
386
+ sens_chans : int
387
+ Number of channels for sensitivity map U-Net.
388
+ sens_pools : int
389
+ Number of downsampling and upsampling layers for sensitivity map U-Net.
390
+ chans : int
391
+ Number of channels for cascade U-Net.
392
+ pools : int
393
+ Number of downsampling and upsampling layers for cascade U-Net.
394
+ mask_center : bool
395
+ Whether to mask center of k-space for sensitivity map calculation.
396
+ use_dc_term : bool
397
+ Whether to use the data consistency term.
398
+ """
399
+
400
+ super().__init__()
401
+
402
+ self.num_cascades = num_cascades
403
+
404
+ self.sens_net = SensitivityModel(
405
+ sens_chans,
406
+ sens_pools,
407
+ radius_cutoff,
408
+ in_shape,
409
+ kernel_shape,
410
+ mask_center=False,
411
+ )
412
+ self.kno = NormUDNO(
413
+ gno_chans,
414
+ gno_pools,
415
+ in_shape=in_shape,
416
+ radius_cutoff=gno_radius_cutoff,
417
+ kernel_shape=gno_kernel_shape,
418
+ in_chans=2,
419
+ out_chans=2,
420
+ )
421
+ self.ino = NormUDNO(
422
+ chans,
423
+ pools,
424
+ radius_cutoff,
425
+ in_shape,
426
+ kernel_shape,
427
+ in_chans=2,
428
+ out_chans=2,
429
+ )
430
+ self.cascade = VarNetBlock(self.kno, self.ino)
431
+ self.use_dc_term = use_dc_term
432
+
433
+ def forward(
434
+ self,
435
+ masked_kspace: torch.Tensor,
436
+ mask: torch.Tensor,
437
+ num_low_frequencies: Optional[int] = None,
438
+ ) -> torch.Tensor:
439
+
440
+ # (B, C, X, Y, 2)
441
+ kspace_pred = masked_kspace
442
+ # iterative update
443
+ for _ in range(self.num_cascades):
444
+ # sens model
445
+ sens_maps = self.sens_net(kspace_pred, mask, num_low_frequencies)
446
+
447
+ # kno + ino (cascade)
448
+ kspace_pred = self.cascade(
449
+ kspace_pred, masked_kspace, mask, sens_maps, self.use_dc_term
450
+ )
451
+
452
+ spatial_pred = fastmri.ifft2c(kspace_pred)
453
+ spatial_pred_abs = fastmri.complex_abs(spatial_pred)
454
+ combined_spatial = fastmri.rss(spatial_pred_abs, dim=1)
455
+
456
+ return combined_spatial
457
+
458
+
459
+ if __name__ == "__main__":
460
+ model = NOShared(
461
+ num_cascades=4,
462
+ radius_cutoff=0.02,
463
+ kernel_shape=(6, 7),
464
+ )
465
+
466
+ x = torch.rand((2, 15, 320, 320, 2))
467
+ o = model(x, x.bool(), None)
models/no_varnet.py ADDED
@@ -0,0 +1,598 @@
1
+ import math
2
+ from typing import List, Literal, Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ import fastmri
9
+ from fastmri import transforms
10
+ from models.udno import UDNO
11
+
12
+
13
+ def sens_expand(x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
14
+ """
15
+ Calculates F (x sens_maps)
16
+
17
+ Parameters
18
+ ----------
19
+ x : ndarray
20
+ Single-channel image of shape (..., H, W, 2)
21
+ sens_maps : ndarray
22
+ Sensitivity maps (image space)
23
+
24
+ Returns
25
+ -------
26
+ ndarray
27
+ Result of the operation F (x sens_maps)
28
+ """
29
+ return fastmri.fft2c(fastmri.complex_mul(x, sens_maps))
30
+
31
+
32
+ def sens_reduce(k: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
33
+ """
34
+ Calculates F^{-1}(k) * conj(sens_maps)
35
+ where conj(sens_maps) is the element-wise applied complex conjugate
36
+
37
+ Parameters
38
+ ----------
39
+ k : ndarray
40
+ Multi-channel k-space of shape (B, C, H, W, 2)
41
+ sens_maps : ndarray
42
+ Sensitivity maps (image space)
43
+
44
+ Returns
45
+ -------
46
+ ndarray
47
+ Result of the operation F^{-1}(k) * conj(sens_maps)
48
+ """
49
+ return fastmri.complex_mul(fastmri.ifft2c(k), fastmri.complex_conj(sens_maps)).sum(
50
+ dim=1, keepdim=True
51
+ )
52
+
53
+
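+ # Minimal numerical check (illustration only, using the same fastmri calls as above): when the
+ # maps are normalised so the per-pixel sum of squared magnitudes is 1 (exactly what
+ # SensitivityModel.divide_root_sum_of_squares below enforces), sens_reduce undoes sens_expand
+ # up to FFT round-off:
+ #
+ #     import torch
+ #     import fastmri
+ #
+ #     B, C, H, W = 1, 4, 32, 32
+ #     maps = torch.randn(B, C, H, W, 2)
+ #     maps = maps / fastmri.rss_complex(maps, dim=1).unsqueeze(-1).unsqueeze(1)  # sum_c |S_c|^2 = 1
+ #     x = torch.randn(B, 1, H, W, 2)               # single-channel complex image
+ #     k = sens_expand(x, maps)                     # (B, C, H, W, 2) multi-coil k-space
+ #     x_back = sens_reduce(k, maps)                # coil-combined image
+ #     print(torch.allclose(x, x_back, atol=1e-4))  # True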
54
+ def chans_to_batch_dim(x: torch.Tensor) -> Tuple[torch.Tensor, int]:
55
+ """Reshapes batched multi-channel samples into multiple single channel samples.
56
+
57
+ Parameters
58
+ ----------
59
+ x : torch.Tensor
60
+ x has shape (b, c, h, w, 2)
61
+
62
+ Returns
63
+ -------
64
+ Tuple[torch.Tensor, int]
65
+ tensor of shape (b * c, 1, h, w, 2), b
66
+ """
67
+ b, c, h, w, comp = x.shape
68
+ return x.view(b * c, 1, h, w, comp), b
69
+
70
+
71
+ def batch_chans_to_chan_dim(x: torch.Tensor, batch_size: int) -> torch.Tensor:
72
+ """Reshapes batched independent samples into original multi-channel samples.
73
+
74
+ Parameters
75
+ ----------
76
+ x : torch.Tensor
77
+ tensor of shape (b * c, 1, h, w, 2)
78
+ batch_size : int
79
+ batch size
80
+
81
+ Returns
82
+ -------
83
+ torch.Tensor
84
+ original multi-channel tensor of shape (b, c, h, w, 2)
85
+ """
86
+ bc, _, h, w, comp = x.shape
87
+ c = bc // batch_size
88
+ return x.view(batch_size, c, h, w, comp)
89
+
90
+
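+ # Quick round-trip sketch (illustration only) for the two reshape helpers above:
+ #
+ #     import torch
+ #
+ #     x = torch.randn(2, 15, 64, 32, 2)              # (B, C, H, W, 2), complex as last dim
+ #     batched, b = chans_to_batch_dim(x)             # (30, 1, 64, 32, 2), b == 2
+ #     restored = batch_chans_to_chan_dim(batched, b) # back to (2, 15, 64, 32, 2)
+ #     assert torch.equal(restored, x)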
91
+ class NormUDNO(nn.Module):
92
+ """
93
+ Normalized UDNO model.
94
+
95
+ Inputs are normalized before the UDNO for numerically stable training.
96
+ """
97
+
98
+ def __init__(
99
+ self,
100
+ chans: int,
101
+ num_pool_layers: int,
102
+ radius_cutoff: float,
103
+ in_shape: Tuple[int, int],
104
+ kernel_shape: Tuple[int, int],
105
+ in_chans: int = 2,
106
+ out_chans: int = 2,
107
+ drop_prob: float = 0.0,
108
+ ):
109
+ """
110
+ Initialize the NormUDNO module.
111
+
112
+ Parameters
113
+ ----------
114
+ chans : int
115
+ Number of output channels of the first convolution layer.
116
+ num_pool_layers : int
117
+ Number of down-sampling and up-sampling layers.
118
+ in_chans : int, optional
119
+ Number of channels in the input to the U-Net model. Default is 2.
120
+ out_chans : int, optional
121
+ Number of channels in the output to the U-Net model. Default is 2.
122
+ drop_prob : float, optional
123
+ Dropout probability. Default is 0.0.
124
+ """
125
+ super().__init__()
126
+
127
+ self.udno = UDNO(
128
+ in_chans=in_chans,
129
+ out_chans=out_chans,
130
+ radius_cutoff=radius_cutoff,
131
+ chans=chans,
132
+ num_pool_layers=num_pool_layers,
133
+ drop_prob=drop_prob,
134
+ in_shape=in_shape,
135
+ kernel_shape=kernel_shape,
136
+ )
137
+
138
+ def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor:
139
+ b, c, h, w, two = x.shape
140
+ assert two == 2
141
+ return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w)
142
+
143
+ def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor:
144
+ b, c2, h, w = x.shape
145
+ assert c2 % 2 == 0
146
+ c = c2 // 2
147
+ return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous()
148
+
149
+ def norm(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
150
+ # group norm
151
+ b, c, h, w = x.shape
152
+ x = x.view(b, 2, c // 2 * h * w)
153
+
154
+ mean = x.mean(dim=2).view(b, 2, 1, 1)
155
+ std = x.std(dim=2).view(b, 2, 1, 1)
156
+
157
+ x = x.view(b, c, h, w)
158
+
159
+ return (x - mean) / std, mean, std
160
+
161
+ def norm_new(
162
+ self, x: torch.Tensor
163
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
164
+ # FIXME: not working, wip
165
+ # group norm
166
+ b, c, h, w = x.shape
167
+ num_groups = 2
168
+ assert (
169
+ c % num_groups == 0
170
+ ), f"Number of channels ({c}) must be divisible by number of groups ({num_groups})."
171
+
172
+ x = x.view(b, num_groups, c // num_groups * h * w)
173
+
174
+ mean = x.mean(dim=2).view(b, num_groups, 1, 1)
175
+ std = x.std(dim=2).view(b, num_groups, 1, 1)
176
+ print(x.shape, mean.shape, std.shape)
177
+
178
+ x = x.view(b, c, h, w)
179
+ mean = (
180
+ mean.view(b, num_groups, 1, 1)
181
+ .repeat(1, c // num_groups, h, w)
182
+ .view(b, c, h, w)
183
+ )
184
+ std = (
185
+ std.view(b, num_groups, 1, 1)
186
+ .repeat(1, c // num_groups, h, w)
187
+ .view(b, c, h, w)
188
+ )
189
+
190
+ return (x - mean) / std, mean, std
191
+
192
+ def unnorm(
193
+ self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
194
+ ) -> torch.Tensor:
195
+ return x * std + mean
196
+
197
+ def pad(
198
+ self, x: torch.Tensor
199
+ ) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]:
200
+ _, _, h, w = x.shape
201
+ w_mult = ((w - 1) | 15) + 1
202
+ h_mult = ((h - 1) | 15) + 1
203
+ w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)]
204
+ h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)]
205
+ # TODO: fix this type when PyTorch fixes theirs
206
+ # the documentation lies - this actually takes a list
207
+ # https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L3457
208
+ # https://github.com/pytorch/pytorch/pull/16949
209
+ x = F.pad(x, w_pad + h_pad)
210
+
211
+ return x, (h_pad, w_pad, h_mult, w_mult)
212
+
213
+ def unpad(
214
+ self,
215
+ x: torch.Tensor,
216
+ h_pad: List[int],
217
+ w_pad: List[int],
218
+ h_mult: int,
219
+ w_mult: int,
220
+ ) -> torch.Tensor:
221
+ return x[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]]
222
+
223
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
224
+ if not x.shape[-1] == 2:
225
+ raise ValueError("Last dimension must be 2 for complex.")
226
+
227
+ chans = x.shape[1]
228
+ if chans == 2:
229
+ # FIXME: hard coded skip norm/pad temporarily to avoid group norm bug
230
+ x = self.complex_to_chan_dim(x)
231
+ x = self.udno(x)
232
+ return self.chan_complex_to_last_dim(x)
233
+
234
+ # get shapes for unet and normalize
235
+ x = self.complex_to_chan_dim(x)
236
+ x, mean, std = self.norm(x)
237
+ x, pad_sizes = self.pad(x)
238
+
239
+ x = self.udno(x)
240
+
241
+ # get shapes back and unnormalize
242
+ x = self.unpad(x, *pad_sizes)
243
+ x = self.unnorm(x, mean, std)
244
+ x = self.chan_complex_to_last_dim(x)
245
+
246
+ return x
247
+
248
+
249
+ class SensitivityModel(nn.Module):
250
+ """
251
+ Learn sensitivity maps
252
+ """
253
+
254
+ def __init__(
255
+ self,
256
+ chans: int,
257
+ num_pools: int,
258
+ radius_cutoff: float,
259
+ in_shape: Tuple[int, int],
260
+ kernel_shape: Tuple[int, int],
261
+ in_chans: int = 2,
262
+ out_chans: int = 2,
263
+ drop_prob: float = 0.0,
264
+ mask_center: bool = True,
265
+ ):
266
+ """
267
+ Parameters
268
+ ----------
269
+ chans : int
270
+ Number of output channels of the first convolution layer.
271
+ num_pools : int
272
+ Number of down-sampling and up-sampling layers.
273
+ in_chans : int, optional
274
+ Number of channels in the input to the U-Net model. Default is 2.
275
+ out_chans : int, optional
276
+ Number of channels in the output to the U-Net model. Default is 2.
277
+ drop_prob : float, optional
278
+ Dropout probability. Default is 0.0.
279
+ mask_center : bool, optional
280
+ Whether to mask center of k-space for sensitivity map calculation.
281
+ Default is True.
282
+ """
283
+ super().__init__()
284
+ self.mask_center = mask_center
285
+ self.norm_udno = NormUDNO(
286
+ chans,
287
+ num_pools,
288
+ radius_cutoff,
289
+ in_shape,
290
+ kernel_shape,
291
+ in_chans=in_chans,
292
+ out_chans=out_chans,
293
+ drop_prob=drop_prob,
294
+ )
295
+
296
+ def divide_root_sum_of_squares(self, x: torch.Tensor) -> torch.Tensor:
297
+ return x / fastmri.rss_complex(x, dim=1).unsqueeze(-1).unsqueeze(1)
298
+
299
+ def get_pad_and_num_low_freqs(
300
+ self, mask: torch.Tensor, num_low_frequencies: Optional[int] = None
301
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
302
+ if num_low_frequencies is None or any(
303
+ torch.any(t == 0) for t in num_low_frequencies
304
+ ):
305
+ # get low frequency line locations and mask them out
306
+ squeezed_mask = mask[:, 0, 0, :, 0].to(torch.int8)
307
+ cent = squeezed_mask.shape[1] // 2
308
+ # argmin finds the first zero, i.e. the number of sampled lines on that side of the centre
309
+ left = torch.argmin(squeezed_mask[:, :cent].flip(1), dim=1)
310
+ right = torch.argmin(squeezed_mask[:, cent:], dim=1)
311
+ num_low_frequencies_tensor = torch.max(
312
+ 2 * torch.min(left, right), torch.ones_like(left)
313
+ ) # force a symmetric center unless 1
314
+ else:
315
+ num_low_frequencies_tensor = num_low_frequencies * torch.ones(
316
+ mask.shape[0], dtype=mask.dtype, device=mask.device
317
+ )
318
+
319
+ pad = (mask.shape[-2] - num_low_frequencies_tensor + 1) // 2
320
+
321
+ return pad.type(torch.long), num_low_frequencies_tensor.type(torch.long)
322
+
323
+ def forward(
324
+ self,
325
+ masked_kspace: torch.Tensor,
326
+ mask: torch.Tensor,
327
+ num_low_frequencies: Optional[int] = None,
328
+ ) -> torch.Tensor:
329
+ if self.mask_center:
330
+ pad, num_low_freqs = self.get_pad_and_num_low_freqs(
331
+ mask, num_low_frequencies
332
+ )
333
+ masked_kspace = transforms.batched_mask_center(
334
+ masked_kspace, pad, pad + num_low_freqs
335
+ )
336
+
337
+ # convert to image space
338
+ images, batches = chans_to_batch_dim(fastmri.ifft2c(masked_kspace))
339
+
340
+ # estimate sensitivities
341
+ return self.divide_root_sum_of_squares(
342
+ batch_chans_to_chan_dim(self.norm_udno(images), batches)
343
+ )
344
+
345
+
346
+ class VarNetBlock(nn.Module):
347
+ """
348
+ Model block for iterative refinement of k-space data.
349
+
350
+ This model applies a combination of soft data consistency with the input
351
+ model as a regularizer. A series of these blocks can be stacked to form
352
+ the full variational network.
353
+
354
+ aka Refinement Module in Fig 1
355
+ """
356
+
357
+ def __init__(self, model: nn.Module):
358
+ """
359
+ Args:
360
+ model: Module for "regularization" component of variational
361
+ network.
362
+ """
363
+ super().__init__()
364
+
365
+ self.model = model
366
+ self.dc_weight = nn.Parameter(torch.ones(1))
367
+
368
+ def forward(
369
+ self,
370
+ current_kspace: torch.Tensor,
371
+ ref_kspace: torch.Tensor,
372
+ mask: torch.Tensor,
373
+ sens_maps: torch.Tensor,
374
+ use_dc_term: bool = True,
375
+ ) -> torch.Tensor:
376
+ """
377
+ Args:
378
+ current_kspace: The current k-space data (frequency domain data)
379
+ being processed by the network. (torch.Tensor)
380
+ ref_kspace: Original subsampled k-space data (from which we are
381
+ reconstructing the image), i.e. the reference k-space. (torch.Tensor)
382
+ mask: A binary mask indicating the locations in k-space where
383
+ data consistency should be enforced. (torch.Tensor)
384
+ sens_maps: Sensitivity maps for the different coils in parallel
385
+ imaging. (torch.Tensor)
386
+ """
387
+
388
+ # model-term see orange box of Fig 1 in E2E-VarNet paper!
389
+ # multi channel k-space -> single channel image-space
390
+ b, c, h, w, _ = current_kspace.shape
391
+
392
+ if c == 30:
393
+ # get kspace and inpainted kspace
394
+ kspace = current_kspace[:, :15, :, :, :]
395
+ in_kspace = current_kspace[:, 15:, :, :, :]
396
+ # convert to image space
397
+ image = sens_reduce(kspace, sens_maps)
398
+ in_image = sens_reduce(in_kspace, sens_maps)
399
+ # concatenate both onto each other
400
+ reduced_image = torch.cat([image, in_image], dim=1)
401
+ else:
402
+ reduced_image = sens_reduce(current_kspace, sens_maps)
403
+
404
+ # single channel image-space
405
+ refined_image = self.model(reduced_image)
406
+
407
+ # single channel image-space -> multi channel k-space
408
+ model_term = sens_expand(refined_image, sens_maps)
409
+
410
+ # only use first 15 channels (masked_kspace) in the update
411
+ # current_kspace = current_kspace[:, :15, :, :, :]
412
+
413
+ if not use_dc_term:
414
+ return current_kspace - model_term
415
+
416
+ """
417
+ Soft data consistency term:
418
+ - Calculates the difference between current k-space and reference k-space where the mask is true.
419
+ - Multiplies this difference by the data consistency weight.
420
+ """
421
+ # dc_term: see green box of Fig 1 in E2E-VarNet paper!
422
+ zero = torch.zeros(1, 1, 1, 1, 1).to(current_kspace)
423
+ soft_dc = torch.where(mask, current_kspace - ref_kspace, zero) * self.dc_weight
424
+ return current_kspace - soft_dc - model_term
425
+
426
+
427
+ class NOVarnet(nn.Module):
428
+ """
429
+ Neural Operator model for MRI reconstruction.
430
+
431
+ Uses a variational architecture (iterative updates) with a learned sensitivity
432
+ model. All operations are resolution invariant employing neural operator
433
+ modules (GNO, UDNO).
434
+ """
435
+
436
+ def __init__(
437
+ self,
438
+ num_cascades: int = 12,
439
+ sens_chans: int = 8,
440
+ sens_pools: int = 4,
441
+ chans: int = 18,
442
+ pools: int = 4,
443
+ gno_chans: int = 16,
444
+ gno_pools: int = 4,
445
+ gno_radius_cutoff: float = 0.02,
446
+ gno_kernel_shape: Tuple[int, int] = (6, 7),
447
+ radius_cutoff: float = 0.01,
448
+ kernel_shape: Tuple[int, int] = (3, 4),
449
+ in_shape: Tuple[int, int] = (640, 320),
450
+ mask_center: bool = True,
451
+ use_dc_term: bool = True,
452
+ reduction_method: Literal["batch", "rss"] = "rss",
453
+ skip_method: Literal["replace", "add", "add_inv", "concat"] = "add",
454
+ ):
455
+ """
456
+ Parameters
457
+ ----------
458
+ num_cascades : int
459
+ Number of cascades (i.e., layers) for variational network.
460
+ sens_chans : int
461
+ Number of channels for sensitivity map U-Net.
462
+ sens_pools : int
463
+ Number of downsampling and upsampling layers for sensitivity map U-Net.
464
+ chans : int
465
+ Number of channels for cascade U-Net.
466
+ pools : int
467
+ Number of downsampling and upsampling layers for cascade U-Net.
468
+ mask_center : bool
469
+ Whether to mask center of k-space for sensitivity map calculation.
470
+ use_dc_term : bool
471
+ Whether to use the data consistency term.
472
+ reduction_method : "batch" or "rss"
473
+ Method for reducing the multi-coil k-space to a single channel before inpainting.
474
+ "batch" reduces to single channel by stacking channels.
475
+ "rss" reduces to single channel by root sum of squares.
476
+ skip_method : "replace" or "add" or "add_inv" or "concat"
477
+ "replace" replaces the input with the output of the GNO
478
+ "add" adds the output of the GNO to the input
479
+ "add_inv" adds the output of the GNO to the input (only where samples are missing)
480
+ "concat" concatenates the output of the GNO to the input
481
+ """
482
+
483
+ super().__init__()
484
+
485
+ self.sens_net = SensitivityModel(
486
+ sens_chans,
487
+ sens_pools,
488
+ radius_cutoff,
489
+ in_shape,
490
+ kernel_shape,
491
+ mask_center=mask_center,
492
+ )
493
+ self.gno = NormUDNO(
494
+ gno_chans,
495
+ gno_pools,
496
+ in_shape=in_shape,
497
+ radius_cutoff=radius_cutoff,
498
+ kernel_shape=kernel_shape,
499
+ # radius_cutoff=gno_radius_cutoff,
500
+ # kernel_shape=gno_kernel_shape,
501
+ in_chans=2,
502
+ out_chans=2,
503
+ )
504
+ self.cascades = nn.ModuleList(
505
+ [
506
+ VarNetBlock(
507
+ NormUDNO(
508
+ chans,
509
+ pools,
510
+ radius_cutoff,
511
+ in_shape,
512
+ kernel_shape,
513
+ in_chans=(
514
+ 4 if skip_method == "concat" and cascade_idx == 0 else 2
515
+ ),
516
+ out_chans=2,
517
+ )
518
+ )
519
+ for cascade_idx in range(num_cascades)
520
+ ]
521
+ )
522
+ self.use_dc_term = use_dc_term
523
+ self.reduction_method = reduction_method
524
+ self.skip_method = skip_method
525
+
526
+ def forward(
527
+ self,
528
+ masked_kspace: torch.Tensor,
529
+ mask: torch.Tensor,
530
+ num_low_frequencies: Optional[int] = None,
531
+ ) -> torch.Tensor:
532
+
533
+ # (B, C, X, Y, 2)
534
+ sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies)
535
+
536
+ # reduce before inpainting
537
+ if self.reduction_method == "rss":
538
+ # (B, 1, H, W, 2) single channel image space
539
+ x_reduced = sens_reduce(masked_kspace, sens_maps)
540
+ # (B, 1, H, W, 2)
541
+ k_reduced = fastmri.fft2c(x_reduced)
542
+ elif self.reduction_method == "batch":
543
+ k_reduced, b = chans_to_batch_dim(masked_kspace)
544
+
545
+ # inpainting
546
+ if self.skip_method == "replace":
547
+ kspace_pred = self.gno(k_reduced)
548
+ elif self.skip_method == "add_inv":
549
+ # FIXME: this is not correct (mask has shape B, 1, H, W, 2 and self.gno(k_reduced) has shape B*C, 1, H, W, 2)
550
+ kspace_pred = k_reduced.clone() + (~mask * self.gno(k_reduced))
551
+ elif self.skip_method == "add":
552
+ kspace_pred = k_reduced.clone() + self.gno(k_reduced)
553
+ elif self.skip_method == "concat":
554
+ kspace_pred = torch.cat([k_reduced.clone(), self.gno(k_reduced)], dim=1)
555
+ else:
556
+ raise NotImplementedError("skip_method not implemented")
557
+
558
+ # expand after inpainting
559
+ if self.reduction_method == "rss":
560
+ if self.skip_method == "concat":
561
+ # kspace_pred is (B, 2, H, W, 2)
562
+ kspace = kspace_pred[:, :1, :, :, :]
563
+ in_kspace = kspace_pred[:, 1:, :, :, :]
564
+ # B, 2C, H, W, 2
565
+ kspace_pred = torch.cat(
566
+ [sens_expand(kspace, sens_maps), sens_expand(in_kspace, sens_maps)],
567
+ dim=1,
568
+ )
569
+ else:
570
+ # (B, 1, H, W, 2) -> (B, C, H, W, 2) multi-channel k space
571
+ kspace_pred = sens_expand(kspace_pred, sens_maps)
572
+ elif self.reduction_method == "batch":
573
+ # (B, C, H, W, 2) multi-channel k space
574
+ if self.skip_method == "concat":
575
+ kspace = kspace_pred[:, :1, :, :, :]
576
+ in_kspace = kspace_pred[:, 1:, :, :, :]
577
+ # B, 2C, H, W, 2
578
+ kspace_pred = torch.cat(
579
+ [
580
+ batch_chans_to_chan_dim(kspace, b),
581
+ batch_chans_to_chan_dim(in_kspace, b),
582
+ ],
583
+ dim=1,
584
+ )
585
+ else:
586
+ kspace_pred = batch_chans_to_chan_dim(kspace_pred, b)
587
+
588
+ # iterative update
589
+ for cascade in self.cascades:
590
+ kspace_pred = cascade(
591
+ kspace_pred, masked_kspace, mask, sens_maps, self.use_dc_term
592
+ )
593
+
594
+ spatial_pred = fastmri.ifft2c(kspace_pred)
595
+ spatial_pred_abs = fastmri.complex_abs(spatial_pred)
596
+ combined_spatial = fastmri.rss(spatial_pred_abs, dim=1)
597
+
598
+ return combined_spatial
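+ # The four `skip_method` options above combine the reduced k-space with the GNO output in
+ # different ways; a simplified standalone sketch (illustration only, assuming matching shapes and
+ # ignoring the `add_inv` broadcasting caveat noted in the FIXME):
+ #
+ #     import torch
+ #
+ #     def combine(k_reduced, gno_out, mask, skip_method):
+ #         if skip_method == "replace":
+ #             return gno_out                                 # keep only the inpainted k-space
+ #         if skip_method == "add":
+ #             return k_reduced + gno_out                     # residual connection
+ #         if skip_method == "add_inv":
+ #             return k_reduced + (~mask) * gno_out           # only fill unsampled locations
+ #         if skip_method == "concat":
+ #             return torch.cat([k_reduced, gno_out], dim=1)  # keep both as separate channels
+ #         raise NotImplementedError(skip_method)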
models/no_varnet_nokno.py ADDED
@@ -0,0 +1,581 @@
1
+ """
2
+ NO Varnet WITHOUT KNO for ablation
3
+ """
4
+
5
+ import math
6
+ from typing import Iterable, List, Literal, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ import fastmri
13
+ from fastmri import transforms
14
+ from fastmri.datasets import SliceDatasetLMDB, SliceSample
15
+ from models.udno import UDNO
16
+
17
+
18
+ def sens_expand(x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
19
+ """
20
+ Calculates F (x sens_maps)
21
+
22
+ Parameters
23
+ ----------
24
+ x : ndarray
25
+ Single-channel image of shape (..., H, W, 2)
26
+ sens_maps : ndarray
27
+ Sensitivity maps (image space)
28
+
29
+ Returns
30
+ -------
31
+ ndarray
32
+ Result of the operation F (x sens_maps)
33
+ """
34
+ return fastmri.fft2c(fastmri.complex_mul(x, sens_maps))
35
+
36
+
37
+ def sens_reduce(k: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
38
+ """
39
+ Calculates F^{-1}(k) * conj(sens_maps)
40
+ where conj(sens_maps) is the element-wise applied complex conjugate
41
+
42
+ Parameters
43
+ ----------
44
+ k : ndarray
45
+ Multi-channel k-space of shape (B, C, H, W, 2)
46
+ sens_maps : ndarray
47
+ Sensitivity maps (image space)
48
+
49
+ Returns
50
+ -------
51
+ ndarray
52
+ Result of the operation F^{-1}(k) * conj(sens_maps)
53
+ """
54
+ return fastmri.complex_mul(
55
+ fastmri.ifft2c(k), fastmri.complex_conj(sens_maps)
56
+ ).sum(dim=1, keepdim=True)
57
+
58
+
59
+ def chans_to_batch_dim(x: torch.Tensor) -> Tuple[torch.Tensor, int]:
60
+ """Reshapes batched multi-channel samples into multiple single channel samples.
61
+
62
+ Parameters
63
+ ----------
64
+ x : torch.Tensor
65
+ x has shape (b, c, h, w, 2)
66
+
67
+ Returns
68
+ -------
69
+ Tuple[torch.Tensor, int]
70
+ tensor of shape (b * c, 1, h, w, 2), b
71
+ """
72
+ b, c, h, w, comp = x.shape
73
+ return x.view(b * c, 1, h, w, comp), b
74
+
75
+
76
+ def batch_chans_to_chan_dim(x: torch.Tensor, batch_size: int) -> torch.Tensor:
77
+ """Reshapes batched independent samples into original multi-channel samples.
78
+
79
+ Parameters
80
+ ----------
81
+ x : torch.Tensor
82
+ tensor of shape (b * c, 1, h, w, 2)
83
+ batch_size : int
84
+ batch size
85
+
86
+ Returns
87
+ -------
88
+ torch.Tensor
89
+ original multi-channel tensor of shape (b, c, h, w, 2)
90
+ """
91
+ bc, _, h, w, comp = x.shape
92
+ c = bc // batch_size
93
+ return x.view(batch_size, c, h, w, comp)
94
+
95
+
96
+ class NormUDNO(nn.Module):
97
+ """
98
+ Normalized UDNO model.
99
+
100
+ Inputs are normalized before the UDNO for numerically stable training.
101
+ """
102
+
103
+ def __init__(
104
+ self,
105
+ chans: int,
106
+ num_pool_layers: int,
107
+ radius_cutoff: float,
108
+ in_shape: Tuple[int, int],
109
+ kernel_shape: Tuple[int, int],
110
+ in_chans: int = 2,
111
+ out_chans: int = 2,
112
+ drop_prob: float = 0.0,
113
+ ):
114
+ """
115
+ Initialize the NormUDNO module.
116
+
117
+ Parameters
118
+ ----------
119
+ chans : int
120
+ Number of output channels of the first convolution layer.
121
+ num_pool_layers : int
122
+ Number of down-sampling and up-sampling layers.
123
+ in_chans : int, optional
124
+ Number of channels in the input to the U-Net model. Default is 2.
125
+ out_chans : int, optional
126
+ Number of channels in the output to the U-Net model. Default is 2.
127
+ drop_prob : float, optional
128
+ Dropout probability. Default is 0.0.
129
+ """
130
+ super().__init__()
131
+
132
+ self.udno = UDNO(
133
+ in_chans=in_chans,
134
+ out_chans=out_chans,
135
+ radius_cutoff=radius_cutoff,
136
+ chans=chans,
137
+ num_pool_layers=num_pool_layers,
138
+ drop_prob=drop_prob,
139
+ in_shape=in_shape,
140
+ kernel_shape=kernel_shape,
141
+ )
142
+
143
+ def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor:
144
+ b, c, h, w, two = x.shape
145
+ assert two == 2
146
+ return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w)
147
+
148
+ def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor:
149
+ b, c2, h, w = x.shape
150
+ assert c2 % 2 == 0
151
+ c = c2 // 2
152
+ return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous()
153
+
154
+ def norm(
155
+ self, x: torch.Tensor
156
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
157
+ # group norm
158
+ b, c, h, w = x.shape
159
+ x = x.view(b, 2, c // 2 * h * w)
160
+
161
+ mean = x.mean(dim=2).view(b, 2, 1, 1)
162
+ std = x.std(dim=2).view(b, 2, 1, 1)
163
+
164
+ x = x.view(b, c, h, w)
165
+
166
+ return (x - mean) / std, mean, std
167
+
168
+ def norm_new(
169
+ self, x: torch.Tensor
170
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
171
+ # FIXME: not working, wip
172
+ # group norm
173
+ b, c, h, w = x.shape
174
+ num_groups = 2
175
+ assert (
176
+ c % num_groups == 0
177
+ ), f"Number of channels ({c}) must be divisible by number of groups ({num_groups})."
178
+
179
+ x = x.view(b, num_groups, c // num_groups * h * w)
180
+
181
+ mean = x.mean(dim=2).view(b, num_groups, 1, 1)
182
+ std = x.std(dim=2).view(b, num_groups, 1, 1)
183
+ print(x.shape, mean.shape, std.shape)
184
+
185
+ x = x.view(b, c, h, w)
186
+ mean = (
187
+ mean.view(b, num_groups, 1, 1)
188
+ .repeat(1, c // num_groups, h, w)
189
+ .view(b, c, h, w)
190
+ )
191
+ std = (
192
+ std.view(b, num_groups, 1, 1)
193
+ .repeat(1, c // num_groups, h, w)
194
+ .view(b, c, h, w)
195
+ )
196
+
197
+ return (x - mean) / std, mean, std
198
+
199
+ def unnorm(
200
+ self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
201
+ ) -> torch.Tensor:
202
+ return x * std + mean
203
+
204
+ def pad(
205
+ self, x: torch.Tensor
206
+ ) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]:
207
+ _, _, h, w = x.shape
208
+ w_mult = ((w - 1) | 15) + 1
209
+ h_mult = ((h - 1) | 15) + 1
210
+ w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)]
211
+ h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)]
212
+ # TODO: fix this type when PyTorch fixes theirs
213
+ # the documentation lies - this actually takes a list
214
+ # https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L3457
215
+ # https://github.com/pytorch/pytorch/pull/16949
216
+ x = F.pad(x, w_pad + h_pad)
217
+
218
+ return x, (h_pad, w_pad, h_mult, w_mult)
219
+
220
+ def unpad(
221
+ self,
222
+ x: torch.Tensor,
223
+ h_pad: List[int],
224
+ w_pad: List[int],
225
+ h_mult: int,
226
+ w_mult: int,
227
+ ) -> torch.Tensor:
228
+ return x[
229
+ ..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]
230
+ ]
231
+
232
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
233
+ if not x.shape[-1] == 2:
234
+ raise ValueError("Last dimension must be 2 for complex.")
235
+
236
+ chans = x.shape[1]
237
+ if chans == 2:
238
+ # FIXME: hard coded skip norm/pad temporarily to avoid group norm bug
239
+ x = self.complex_to_chan_dim(x)
240
+ x = self.udno(x)
241
+ return self.chan_complex_to_last_dim(x)
242
+
243
+ # get shapes for unet and normalize
244
+ x = self.complex_to_chan_dim(x)
245
+ x, mean, std = self.norm(x)
246
+ x, pad_sizes = self.pad(x)
247
+
248
+ x = self.udno(x)
249
+
250
+ # get shapes back and unnormalize
251
+ x = self.unpad(x, *pad_sizes)
252
+ x = self.unnorm(x, mean, std)
253
+ x = self.chan_complex_to_last_dim(x)
254
+
255
+ return x
256
+
257
+
258
+ class SensitivityModel(nn.Module):
259
+ """
260
+ Learn sensitivity maps
261
+ """
262
+
263
+ def __init__(
264
+ self,
265
+ chans: int,
266
+ num_pools: int,
267
+ radius_cutoff: float,
268
+ in_shape: Tuple[int, int],
269
+ kernel_shape: Tuple[int, int],
270
+ in_chans: int = 2,
271
+ out_chans: int = 2,
272
+ drop_prob: float = 0.0,
273
+ mask_center: bool = True,
274
+ ):
275
+ """
276
+ Parameters
277
+ ----------
278
+ chans : int
279
+ Number of output channels of the first convolution layer.
280
+ num_pools : int
281
+ Number of down-sampling and up-sampling layers.
282
+ in_chans : int, optional
283
+ Number of channels in the input to the U-Net model. Default is 2.
284
+ out_chans : int, optional
285
+ Number of channels in the output to the U-Net model. Default is 2.
286
+ drop_prob : float, optional
287
+ Dropout probability. Default is 0.0.
288
+ mask_center : bool, optional
289
+ Whether to mask center of k-space for sensitivity map calculation.
290
+ Default is True.
291
+ """
292
+ super().__init__()
293
+ self.mask_center = mask_center
294
+ self.norm_udno = NormUDNO(
295
+ chans,
296
+ num_pools,
297
+ radius_cutoff,
298
+ in_shape,
299
+ kernel_shape,
300
+ in_chans=in_chans,
301
+ out_chans=out_chans,
302
+ drop_prob=drop_prob,
303
+ )
304
+
305
+ def divide_root_sum_of_squares(self, x: torch.Tensor) -> torch.Tensor:
306
+ return x / fastmri.rss_complex(x, dim=1).unsqueeze(-1).unsqueeze(1)
307
+
308
+ def get_pad_and_num_low_freqs(
309
+ self, mask: torch.Tensor, num_low_frequencies: Optional[int] = None
310
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
311
+ if num_low_frequencies is None or (isinstance(num_low_frequencies, Iterable) and any(
312
+ torch.any(t == 0) for t in num_low_frequencies
313
+ )):
314
+ # get low frequency line locations and mask them out
315
+ squeezed_mask = mask[:, 0, 0, :, 0].to(torch.int8)
316
+ cent = squeezed_mask.shape[1] // 2
317
+ # argmin finds the first zero, i.e. the number of sampled lines on that side of the centre
318
+ left = torch.argmin(squeezed_mask[:, :cent].flip(1), dim=1)
319
+ right = torch.argmin(squeezed_mask[:, cent:], dim=1)
320
+ num_low_frequencies_tensor = torch.max(
321
+ 2 * torch.min(left, right), torch.ones_like(left)
322
+ ) # force a symmetric center unless 1
323
+ else:
324
+ num_low_frequencies_tensor = num_low_frequencies * torch.ones(
325
+ mask.shape[0], dtype=mask.dtype, device=mask.device
326
+ )
327
+
328
+ pad = (mask.shape[-2] - num_low_frequencies_tensor + 1) // 2
329
+
330
+ return pad.type(torch.long), num_low_frequencies_tensor.type(torch.long)
331
+
332
+ def forward(
333
+ self,
334
+ masked_kspace: torch.Tensor,
335
+ mask: torch.Tensor,
336
+ num_low_frequencies: Optional[int] = None,
337
+ ) -> torch.Tensor:
338
+ if self.mask_center:
339
+ pad, num_low_freqs = self.get_pad_and_num_low_freqs(
340
+ mask, num_low_frequencies
341
+ )
342
+ masked_kspace = transforms.batched_mask_center(
343
+ masked_kspace, pad, pad + num_low_freqs
344
+ )
345
+
346
+ # convert to image space
347
+ images, batches = chans_to_batch_dim(fastmri.ifft2c(masked_kspace))
348
+
349
+ # estimate sensitivities
350
+ return self.divide_root_sum_of_squares(
351
+ batch_chans_to_chan_dim(self.norm_udno(images), batches)
352
+ )
353
+
354
+
355
+ class VarNetBlock(nn.Module):
356
+ """
357
+ Model block for iterative refinement of k-space data.
358
+
359
+ This model applies a combination of soft data consistency with the input
360
+ model as a regularizer. A series of these blocks can be stacked to form
361
+ the full variational network.
362
+
363
+ aka Refinement Module in Fig 1
364
+ """
365
+
366
+ def __init__(self, model: nn.Module):
367
+ """
368
+ Args:
369
+ model: Module for "regularization" component of variational
370
+ network.
371
+ """
372
+ super().__init__()
373
+
374
+ self.model = model
375
+ self.dc_weight = nn.Parameter(torch.ones(1))
376
+
377
+ def forward(
378
+ self,
379
+ current_kspace: torch.Tensor,
380
+ ref_kspace: torch.Tensor,
381
+ mask: torch.Tensor,
382
+ sens_maps: torch.Tensor,
383
+ use_dc_term: bool = True,
384
+ ) -> torch.Tensor:
385
+ """
386
+ Args:
387
+ current_kspace: The current k-space data (frequency domain data)
388
+ being processed by the network. (torch.Tensor)
389
+ ref_kspace: Original subsampled k-space data (from which we are
390
+ reconstructing the image), i.e. the reference k-space. (torch.Tensor)
391
+ mask: A binary mask indicating the locations in k-space where
392
+ data consistency should be enforced. (torch.Tensor)
393
+ sens_maps: Sensitivity maps for the different coils in parallel
394
+ imaging. (torch.Tensor)
395
+ """
396
+
397
+ # model-term see orange box of Fig 1 in E2E-VarNet paper!
398
+ # multi channel k-space -> single channel image-space
399
+ b, c, h, w, _ = current_kspace.shape
400
+
401
+ if c == 30:
402
+ # get kspace and inpainted kspace
403
+ kspace = current_kspace[:, :15, :, :, :]
404
+ in_kspace = current_kspace[:, 15:, :, :, :]
405
+ # convert to image space
406
+ image = sens_reduce(kspace, sens_maps)
407
+ in_image = sens_reduce(in_kspace, sens_maps)
408
+ # concatenate both onto each other
409
+ reduced_image = torch.cat([image, in_image], dim=1)
410
+ else:
411
+ reduced_image = sens_reduce(current_kspace, sens_maps)
412
+
413
+ # single channel image-space
414
+ refined_image = self.model(reduced_image)
415
+
416
+ # single channel image-space -> multi channel k-space
417
+ model_term = sens_expand(refined_image, sens_maps)
418
+
419
+ # only use first 15 channels (masked_kspace) in the update
420
+ # current_kspace = current_kspace[:, :15, :, :, :]
421
+
422
+ if not use_dc_term:
423
+ return current_kspace - model_term
424
+
425
+ """
426
+ Soft data consistency term:
427
+ - Calculates the difference between current k-space and reference k-space where the mask is true.
428
+ - Multiplies this difference by the data consistency weight.
429
+ """
430
+ # dc_term: see green box of Fig 1 in E2E-VarNet paper!
431
+ zero = torch.zeros(1, 1, 1, 1, 1).to(current_kspace)
432
+ soft_dc = (
433
+ torch.where(mask, current_kspace - ref_kspace, zero)
434
+ * self.dc_weight
435
+ )
436
+ return current_kspace - soft_dc - model_term
437
+
438
+
439
+ class NOVarnet_no_KNO(nn.Module):
440
+ """
441
+ Neural Operator model for MRI reconstruction.
442
+
443
+ Uses a variational architecture (iterative updates) with a learned sensitivity
444
+ model. All operations are resolution invariant employing neural operator
445
+ modules (GNO, UDNO).
446
+ """
447
+
448
+ def __init__(
449
+ self,
450
+ num_cascades: int = 12,
451
+ sens_chans: int = 8,
452
+ sens_pools: int = 4,
453
+ chans: int = 18,
454
+ pools: int = 4,
455
+ gno_chans: int = 16,
456
+ gno_pools: int = 4,
457
+ gno_radius_cutoff: float = 0.02,
458
+ gno_kernel_shape: Tuple[int, int] = (6, 7),
459
+ radius_cutoff: float = 0.01,
460
+ kernel_shape: Tuple[int, int] = (3, 4),
461
+ in_shape: Tuple[int, int] = (640, 320),
462
+ mask_center: bool = True,
463
+ use_dc_term: bool = True,
464
+ reduction_method: Literal["batch", "rss"] = "rss",
465
+ skip_method: Literal["replace", "add", "add_inv", "concat"] = "add",
466
+ ):
467
+ """
468
+ Parameters
469
+ ----------
470
+ num_cascades : int
471
+ Number of cascades (i.e., layers) for variational network.
472
+ sens_chans : int
473
+ Number of channels for sensitivity map U-Net.
474
+ sens_pools : int
475
+ Number of downsampling and upsampling layers for sensitivity map U-Net.
476
+ chans : int
477
+ Number of channels for cascade U-Net.
478
+ pools : int
479
+ Number of downsampling and upsampling layers for cascade U-Net.
480
+ mask_center : bool
481
+ Whether to mask center of k-space for sensitivity map calculation.
482
+ use_dc_term : bool
483
+ Whether to use the data consistency term.
484
+ reduction_method : "batch" or "rss"
485
+ Method for reducing the multi-coil k-space to a single channel before inpainting.
486
+ "batch" reduces to single channel by stacking channels.
487
+ "rss" reduces to single channel by root sum of squares.
488
+ skip_method : "replace" or "add" or "add_inv" or "concat"
489
+ "replace" replaces the input with the output of the GNO
490
+ "add" adds the output of the GNO to the input
491
+ "add_inv" adds the output of the GNO to the input (only where samples are missing)
492
+ "concat" concatenates the output of the GNO to the input
493
+ """
494
+
495
+ super().__init__()
496
+
497
+ self.sens_net = SensitivityModel(
498
+ sens_chans,
499
+ sens_pools,
500
+ radius_cutoff,
501
+ in_shape,
502
+ kernel_shape,
503
+ mask_center=mask_center,
504
+ )
505
+ # self.gno = NormUDNO(
506
+ # gno_chans,
507
+ # gno_pools,
508
+ # in_shape=in_shape,
509
+ # radius_cutoff=radius_cutoff,
510
+ # kernel_shape=kernel_shape,
511
+ # # radius_cutoff=gno_radius_cutoff,
512
+ # # kernel_shape=gno_kernel_shape,
513
+ # in_chans=2,
514
+ # out_chans=2,
515
+ # )
516
+ self.cascades = nn.ModuleList(
517
+ [
518
+ VarNetBlock(
519
+ NormUDNO(
520
+ chans,
521
+ pools,
522
+ radius_cutoff,
523
+ in_shape,
524
+ kernel_shape,
525
+ in_chans=(
526
+ 4
527
+ if skip_method == "concat" and cascade_idx == 0
528
+ else 2
529
+ ),
530
+ out_chans=2,
531
+ )
532
+ )
533
+ for cascade_idx in range(num_cascades)
534
+ ]
535
+ )
536
+ self.use_dc_term = use_dc_term
537
+ self.reduction_method = reduction_method # not used anywhere anymore
538
+ self.skip_method = skip_method # not used anywhere anymore
539
+
540
+ def forward(
541
+ self,
542
+ masked_kspace: torch.Tensor,
543
+ mask: torch.Tensor,
544
+ num_low_frequencies: Optional[int] = None,
545
+ ) -> torch.Tensor:
546
+
547
+ # (B, C, X, Y, 2)
548
+ sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies)
549
+
550
+ kspace_pred = masked_kspace.clone()
551
+
552
+ # iterative update
553
+ for cascade in self.cascades:
554
+ kspace_pred = cascade(
555
+ kspace_pred, masked_kspace, mask, sens_maps, self.use_dc_term
556
+ )
557
+
558
+ spatial_pred = fastmri.ifft2c(kspace_pred)
559
+ spatial_pred_abs = fastmri.complex_abs(spatial_pred)
560
+ combined_spatial = fastmri.rss(spatial_pred_abs, dim=1)
561
+
562
+ return combined_spatial
563
+
564
+
565
+ if __name__ == "__main__":
566
+ ds = SliceDatasetLMDB(
567
+ "knee",
568
+ partition="train",
569
+ mask_fns=None, # type: ignore
570
+ complex=False,
571
+ sample_rate=0.5,
572
+ crop_shape=(320, 320),
573
+ coils=15,
574
+ )
575
+
576
+ sample: SliceSample = ds[0]
577
+ kspace = sample.masked_kspace
578
+ target = sample.target
579
+
580
+ model = NOVarnet_no_KNO(1)
581
+ res = model.forward(sample.masked_kspace.unsqueeze(0), sample.mask.unsqueeze(0), torch.tensor(sample.num_low_frequencies).unsqueeze(0))
models/temp/no_repeatk.py ADDED
@@ -0,0 +1,562 @@
1
+ import math
2
+ import pdb
3
+ from typing import List, Optional, Tuple, Literal
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ import fastmri
10
+ from fastmri import transforms
11
+
12
+ from models.udno import UDNO
13
+
14
+
15
+ def sens_expand(x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
16
+ """
17
+ Calculates F (x sens_maps)
18
+
19
+ Parameters
20
+ ----------
21
+ x : ndarray
22
+ Single-channel image of shape (..., H, W, 2)
23
+ sens_maps : ndarray
24
+ Sensitivity maps (image space)
25
+
26
+ Returns
27
+ -------
28
+ ndarray
29
+ Result of the operation F (x sens_maps)
30
+ """
31
+ return fastmri.fft2c(fastmri.complex_mul(x, sens_maps))
32
+
33
+
34
+ def sens_reduce(k: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
35
+ """
36
+ Calculates F^{-1}(k) * conj(sens_maps)
37
+ where conj(sens_maps) is the element-wise applied complex conjugate
38
+
39
+ Parameters
40
+ ----------
41
+ k : ndarray
42
+ Multi-channel k-space of shape (B, C, H, W, 2)
43
+ sens_maps : ndarray
44
+ Sensitivity maps (image space)
45
+
46
+ Returns
47
+ -------
48
+ ndarray
49
+ Result of the operation F^{-1}(k) * conj(sens_maps)
50
+ """
51
+ return fastmri.complex_mul(fastmri.ifft2c(k), fastmri.complex_conj(sens_maps)).sum(
52
+ dim=1, keepdim=True
53
+ )
54
+
55
+
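A small shape check may help orient the reader: `sens_expand` maps a coil-combined image back to multi-coil k-space, and `sens_reduce` is its coil-combining counterpart. The sketch below assumes the `fastmri` package is importable and uses invented tensor sizes:

    # Hypothetical shape check for sens_expand / sens_reduce (sizes are illustrative).
    import torch

    B, C, H, W = 1, 15, 64, 64
    sens_maps = torch.randn(B, C, H, W, 2)       # coil sensitivity maps, complex as (..., 2)
    kspace = torch.randn(B, C, H, W, 2)          # multi-coil k-space

    image = sens_reduce(kspace, sens_maps)       # (B, 1, H, W, 2): coil-combined image
    expanded = sens_expand(image, sens_maps)     # (B, C, H, W, 2): back to multi-coil k-space
    assert image.shape == (B, 1, H, W, 2)
    assert expanded.shape == (B, C, H, W, 2)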
56
+ def chans_to_batch_dim(x: torch.Tensor) -> Tuple[torch.Tensor, int]:
57
+ """Reshapes batched multi-channel samples into multiple single channel samples.
58
+
59
+ Parameters
60
+ ----------
61
+ x : torch.Tensor
62
+ x has shape (b, c, h, w, 2)
63
+
64
+ Returns
65
+ -------
66
+ Tuple[torch.Tensor, int]
67
+ tensor of shape (b * c, 1, h, w, 2), b
68
+ """
69
+ b, c, h, w, comp = x.shape
70
+ return x.view(b * c, 1, h, w, comp), b
71
+
72
+
73
+ def batch_chans_to_chan_dim(x: torch.Tensor, batch_size: int) -> torch.Tensor:
74
+ """Reshapes batched independent samples into original multi-channel samples.
75
+
76
+ Parameters
77
+ ----------
78
+ x : torch.Tensor
79
+ tensor of shape (b * c, 1, h, w, 2)
80
+ batch_size : int
81
+ batch size
82
+
83
+ Returns
84
+ -------
85
+ torch.Tensor
86
+ original multi-channel tensor of shape (b, c, h, w, 2)
87
+ """
88
+ bc, _, h, w, comp = x.shape
89
+ c = bc // batch_size
90
+ return x.view(batch_size, c, h, w, comp)
91
+
92
+
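The two reshaping helpers above simply fold the coil dimension into the batch dimension and back; a round-trip sketch with an invented tensor:

    # Round trip through chans_to_batch_dim / batch_chans_to_chan_dim (illustrative sizes).
    import torch

    x = torch.randn(2, 15, 64, 64, 2)                 # (b, c, h, w, 2)
    flat, b = chans_to_batch_dim(x)                    # (b * c, 1, h, w, 2), b
    restored = batch_chans_to_chan_dim(flat, b)        # (b, c, h, w, 2)
    assert flat.shape == (30, 1, 64, 64, 2)
    assert torch.equal(restored, x)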
93
+ class NormUDNO(nn.Module):
94
+ """
95
+ Normalized UDNO model.
96
+
97
+ Inputs are normalized before the UDNO for numerically stable training.
98
+ """
99
+
100
+ def __init__(
101
+ self,
102
+ chans: int,
103
+ num_pool_layers: int,
104
+ radius_cutoff: float,
105
+ in_shape: Tuple[int, int],
106
+ kernel_shape: Tuple[int, int],
107
+ in_chans: int = 2,
108
+ out_chans: int = 2,
109
+ drop_prob: float = 0.0,
110
+ ):
111
+ """
112
+ Initialize the NormUDNO model.
113
+
114
+ Parameters
115
+ ----------
116
+ chans : int
117
+ Number of output channels of the first convolution layer.
118
+ num_pool_layers : int
119
+ Number of down-sampling and up-sampling layers.
120
+ in_chans : int, optional
121
+ Number of channels in the input to the U-Net model. Default is 2.
122
+ out_chans : int, optional
123
+ Number of channels in the output to the U-Net model. Default is 2.
124
+ drop_prob : float, optional
125
+ Dropout probability. Default is 0.0.
126
+ """
127
+ super().__init__()
128
+
129
+ self.udno = UDNO(
130
+ in_chans=in_chans,
131
+ out_chans=out_chans,
132
+ radius_cutoff=radius_cutoff,
133
+ chans=chans,
134
+ num_pool_layers=num_pool_layers,
135
+ drop_prob=drop_prob,
136
+ in_shape=in_shape,
137
+ kernel_shape=kernel_shape,
138
+ )
139
+
140
+ def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor:
141
+ b, c, h, w, two = x.shape
142
+ assert two == 2
143
+ return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w)
144
+
145
+ def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor:
146
+ b, c2, h, w = x.shape
147
+ assert c2 % 2 == 0
148
+ c = c2 // 2
149
+ return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous()
150
+
151
+ def norm(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
152
+ # group norm
153
+ b, c, h, w = x.shape
154
+ x = x.view(b, 2, c // 2 * h * w)
155
+
156
+ mean = x.mean(dim=2).view(b, 2, 1, 1)
157
+ std = x.std(dim=2).view(b, 2, 1, 1)
158
+
159
+ x = x.view(b, c, h, w)
160
+
161
+ return (x - mean) / std, mean, std
162
+
163
+ def norm_new(
164
+ self, x: torch.Tensor
165
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
166
+ # FIXME: not working, wip
167
+ # group norm
168
+ b, c, h, w = x.shape
169
+ num_groups = 2
170
+ assert (
171
+ c % num_groups == 0
172
+ ), f"Number of channels ({c}) must be divisible by number of groups ({num_groups})."
173
+
174
+ x = x.view(b, num_groups, c // num_groups * h * w)
175
+
176
+ mean = x.mean(dim=2).view(b, num_groups, 1, 1)
177
+ std = x.std(dim=2).view(b, num_groups, 1, 1)
178
+ print(x.shape, mean.shape, std.shape)
179
+
180
+ x = x.view(b, c, h, w)
181
+ mean = (
182
+ mean.view(b, num_groups, 1, 1)
183
+ .repeat(1, c // num_groups, h, w)
184
+ .view(b, c, h, w)
185
+ )
186
+ std = (
187
+ std.view(b, num_groups, 1, 1)
188
+ .repeat(1, c // num_groups, h, w)
189
+ .view(b, c, h, w)
190
+ )
191
+
192
+ return (x - mean) / std, mean, std
193
+
194
+ def unnorm(
195
+ self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
196
+ ) -> torch.Tensor:
197
+ return x * std + mean
198
+
199
+ def pad(
200
+ self, x: torch.Tensor
201
+ ) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]:
202
+ _, _, h, w = x.shape
203
+ w_mult = ((w - 1) | 15) + 1
204
+ h_mult = ((h - 1) | 15) + 1
205
+ w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)]
206
+ h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)]
207
+ # TODO: fix this type when PyTorch fixes theirs
208
+ # the documentation lies - this actually takes a list
209
+ # https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L3457
210
+ # https://github.com/pytorch/pytorch/pull/16949
211
+ x = F.pad(x, w_pad + h_pad)
212
+
213
+ return x, (h_pad, w_pad, h_mult, w_mult)
214
+
215
+ def unpad(
216
+ self,
217
+ x: torch.Tensor,
218
+ h_pad: List[int],
219
+ w_pad: List[int],
220
+ h_mult: int,
221
+ w_mult: int,
222
+ ) -> torch.Tensor:
223
+ return x[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]]
224
+
225
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
226
+ if not x.shape[-1] == 2:
227
+ raise ValueError("Last dimension must be 2 for complex.")
228
+
229
+ chans = x.shape[1]
230
+ if chans == 2:
231
+ # FIXME: hard coded skip norm/pad temporarily to avoid group norm bug
232
+ x = self.complex_to_chan_dim(x)
233
+ x = self.udno(x)
234
+ return self.chan_complex_to_last_dim(x)
235
+
236
+ # get shapes for unet and normalize
237
+ x = self.complex_to_chan_dim(x)
238
+ x, mean, std = self.norm(x)
239
+ x, pad_sizes = self.pad(x)
240
+
241
+ x = self.udno(x)
242
+
243
+ # get shapes back and unnormalize
244
+ x = self.unpad(x, *pad_sizes)
245
+ x = self.unnorm(x, mean, std)
246
+ x = self.chan_complex_to_last_dim(x)
247
+
248
+ return x
249
+
250
+
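The `pad` method above rounds each spatial size up to the next multiple of 16 with the bit trick `((n - 1) | 15) + 1` and splits the padding symmetrically, so the U-shaped operator can pool four times without shape mismatches. A standalone illustration of that rule (values chosen arbitrarily):

    # Standalone illustration of the pad-to-multiple-of-16 rule in NormUDNO.pad.
    def next_multiple_of_16(n: int) -> int:
        return ((n - 1) | 15) + 1

    for n in (320, 321, 368, 15):
        m = next_multiple_of_16(n)
        left, right = (m - n) // 2, (m - n + 1) // 2
        print(n, "->", m, "with pad", (left, right))
    # 320 -> 320, 321 -> 336, 368 -> 368, 15 -> 16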
251
+ class SensitivityModel(nn.Module):
252
+ """
253
+ Learn sensitivity maps
254
+ """
255
+
256
+ def __init__(
257
+ self,
258
+ chans: int,
259
+ num_pools: int,
260
+ radius_cutoff: float,
261
+ in_shape: Tuple[int, int],
262
+ kernel_shape: Tuple[int, int],
263
+ in_chans: int = 2,
264
+ out_chans: int = 2,
265
+ drop_prob: float = 0.0,
266
+ mask_center: bool = True,
267
+ ):
268
+ """
269
+ Parameters
270
+ ----------
271
+ chans : int
272
+ Number of output channels of the first convolution layer.
273
+ num_pools : int
274
+ Number of down-sampling and up-sampling layers.
275
+ in_chans : int, optional
276
+ Number of channels in the input to the U-Net model. Default is 2.
277
+ out_chans : int, optional
278
+ Number of channels in the output to the U-Net model. Default is 2.
279
+ drop_prob : float, optional
280
+ Dropout probability. Default is 0.0.
281
+ mask_center : bool, optional
282
+ Whether to mask center of k-space for sensitivity map calculation.
283
+ Default is True.
284
+ """
285
+ super().__init__()
286
+ self.mask_center = mask_center
287
+ self.norm_udno = NormUDNO(
288
+ chans,
289
+ num_pools,
290
+ radius_cutoff,
291
+ in_shape,
292
+ kernel_shape,
293
+ in_chans=in_chans,
294
+ out_chans=out_chans,
295
+ drop_prob=drop_prob,
296
+ )
297
+
298
+ def divide_root_sum_of_squares(self, x: torch.Tensor) -> torch.Tensor:
299
+ return x / fastmri.rss_complex(x, dim=1).unsqueeze(-1).unsqueeze(1)
300
+
301
+ def get_pad_and_num_low_freqs(
302
+ self, mask: torch.Tensor, num_low_frequencies: Optional[int] = None
303
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
304
+ if num_low_frequencies is None or any(
305
+ torch.any(t == 0) for t in num_low_frequencies
306
+ ):
307
+ # get low frequency line locations and mask them out
308
+ squeezed_mask = mask[:, 0, 0, :, 0].to(torch.int8)
309
+ cent = squeezed_mask.shape[1] // 2
310
+ # running argmin returns the first non-zero
311
+ left = torch.argmin(squeezed_mask[:, :cent].flip(1), dim=1)
312
+ right = torch.argmin(squeezed_mask[:, cent:], dim=1)
313
+ num_low_frequencies_tensor = torch.max(
314
+ 2 * torch.min(left, right), torch.ones_like(left)
315
+ ) # force a symmetric center unless 1
316
+ else:
317
+ num_low_frequencies_tensor = num_low_frequencies * torch.ones(
318
+ mask.shape[0], dtype=mask.dtype, device=mask.device
319
+ )
320
+
321
+ pad = (mask.shape[-2] - num_low_frequencies_tensor + 1) // 2
322
+
323
+ return pad.type(torch.long), num_low_frequencies_tensor.type(torch.long)
324
+
325
+ def forward(
326
+ self,
327
+ masked_kspace: torch.Tensor,
328
+ mask: torch.Tensor,
329
+ num_low_frequencies: Optional[int] = None,
330
+ ) -> torch.Tensor:
331
+ if self.mask_center:
332
+ pad, num_low_freqs = self.get_pad_and_num_low_freqs(
333
+ mask, num_low_frequencies
334
+ )
335
+ masked_kspace = transforms.batched_mask_center(
336
+ masked_kspace, pad, pad + num_low_freqs
337
+ )
338
+
339
+ # convert to image space
340
+ images, batches = chans_to_batch_dim(fastmri.ifft2c(masked_kspace))
341
+
342
+ # estimate sensitivities
343
+ return self.divide_root_sum_of_squares(
344
+ batch_chans_to_chan_dim(self.norm_udno(images), batches)
345
+ )
346
+
347
+
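When `num_low_frequencies` is not given, `get_pad_and_num_low_freqs` above infers the width of the fully sampled k-space center by walking outward from the center column of the mask until it hits a zero on each side. A toy illustration of that argmin logic with a made-up 1D mask:

    # Toy illustration of the ACS-width inference in get_pad_and_num_low_freqs.
    import torch

    mask = torch.tensor([[0, 1, 0, 1, 1, 1, 1, 0, 1, 0]], dtype=torch.int8)  # (B=1, W=10)
    cent = mask.shape[1] // 2                              # 5
    left = torch.argmin(mask[:, :cent].flip(1), dim=1)     # ones before the first zero, going left
    right = torch.argmin(mask[:, cent:], dim=1)            # ones before the first zero, going right
    num_low = torch.max(2 * torch.min(left, right), torch.ones_like(left))
    print(num_low)  # tensor([4]): symmetric fully sampled center of width 4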
348
+ class VarNetBlock(nn.Module):
349
+ """
350
+ Model block for iterative refinement of k-space data.
351
+
352
+ This model applies a combination of soft data consistency with the input
353
+ model as a regularizer. A series of these blocks can be stacked to form
354
+ the full variational network.
355
+
356
+ aka Refinement Module in Fig 1
357
+ """
358
+
359
+ def __init__(self, model: nn.Module):
360
+ """
361
+ Args:
362
+ model: Module for "regularization" component of variational
363
+ network.
364
+ """
365
+ super().__init__()
366
+
367
+ self.model = model
368
+ self.dc_weight = nn.Parameter(torch.ones(1))
369
+
370
+ def forward(
371
+ self,
372
+ current_kspace: torch.Tensor,
373
+ ref_kspace: torch.Tensor,
374
+ mask: torch.Tensor,
375
+ sens_maps: torch.Tensor,
376
+ use_dc_term: bool = True,
377
+ ) -> torch.Tensor:
378
+ """
379
+ Args:
380
+ current_kspace: The current k-space data (frequency domain data)
381
+ being processed by the network. (torch.Tensor)
382
+ ref_kspace: Original subsampled k-space data (from which we are
383
+ reconstructing the image), i.e. the reference k-space. (torch.Tensor)
384
+ mask: A binary mask indicating the locations in k-space where
385
+ data consistency should be enforced. (torch.Tensor)
386
+ sens_maps: Sensitivity maps for the different coils in parallel
387
+ imaging. (torch.Tensor)
388
+ """
389
+
390
+ # model-term see orange box of Fig 1 in E2E-VarNet paper!
391
+ # multi channel k-space -> single channel image-space
392
+ b, c, h, w, _ = current_kspace.shape
393
+
394
+ if c == 30:
395
+ # get kspace and inpainted kspace
396
+ kspace = current_kspace[:, :15, :, :, :]
397
+ in_kspace = current_kspace[:, 15:, :, :, :]
398
+ # convert to image space
399
+ image = sens_reduce(kspace, sens_maps)
400
+ in_image = sens_reduce(in_kspace, sens_maps)
401
+ # concatenate both onto each other
402
+ reduced_image = torch.cat([image, in_image], dim=1)
403
+ else:
404
+ reduced_image = sens_reduce(current_kspace, sens_maps)
405
+
406
+ # single channel image-space
407
+ refined_image = self.model(reduced_image)
408
+
409
+ # single channel image-space -> multi channel k-space
410
+ model_term = sens_expand(refined_image, sens_maps)
411
+
412
+ # only use first 15 channels (masked_kspace) in the update
413
+ current_kspace = current_kspace[:, :15, :, :, :]
414
+
415
+ if not use_dc_term:
416
+ return current_kspace - model_term
417
+
418
+ """
419
+ Soft data consistency term:
420
+ - Calculates the difference between current k-space and reference k-space where the mask is true.
421
+ - Multiplies this difference by the data consistency weight.
422
+ """
423
+ # dc_term: see green box of Fig 1 in E2E-VarNet paper!
424
+ zero = torch.zeros(1, 1, 1, 1, 1).to(current_kspace)
425
+ soft_dc = torch.where(mask, current_kspace - ref_kspace, zero) * self.dc_weight
426
+ return current_kspace - soft_dc - model_term
427
+
428
+
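The return value above is the unrolled update k_new = k - dc_weight * mask * (k - k_ref) - model_term, i.e. a soft data-consistency correction plus the learned refinement. A scalar-level sketch of just the data-consistency part, with invented numbers:

    # Toy sketch of the soft data-consistency step (all values invented).
    import torch

    mask = torch.tensor([True, False, True])
    k_ref = torch.tensor([1.0, 0.0, 2.0])     # measured k-space (zero where unsampled)
    k_cur = torch.tensor([1.5, 0.7, 1.0])     # current estimate
    dc_weight = torch.tensor(0.5)             # learned weight (scalar here)

    soft_dc = torch.where(mask, k_cur - k_ref, torch.zeros(())) * dc_weight
    print(k_cur - soft_dc)  # tensor([1.2500, 0.7000, 1.5000]); the unsampled entry is untouched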
429
+ class NOVarnet(nn.Module):
430
+ """
431
+ Neural Operator model for MRI reconstruction.
432
+
433
+ Uses a variational architecture (iterative updates) with a learned sensitivity
434
+ model. All operations are resolution invariant employing neural operator
435
+ modules (GNO, UDNO).
436
+ """
437
+
438
+ def __init__(
439
+ self,
440
+ num_cascades: int = 12,
441
+ sens_chans: int = 8,
442
+ sens_pools: int = 4,
443
+ chans: int = 18,
444
+ pools: int = 4,
445
+ gno_chans: int = 16,
446
+ gno_pools: int = 4,
447
+ gno_radius_cutoff: float = 0.02,
448
+ gno_kernel_shape: Tuple[int, int] = (6, 7),
449
+ radius_cutoff: float = 0.01,
450
+ kernel_shape: Tuple[int, int] = (3, 4),
451
+ in_shape: Tuple[int, int] = (640, 320),
452
+ mask_center: bool = True,
453
+ use_dc_term: bool = True,
454
+ reduction_method: Literal["batch", "rss"] = "rss",
455
+ skip_method: Literal["replace", "add", "add_inv", "concat"] = "add",
456
+ ):
457
+ """
458
+ Parameters
459
+ ----------
460
+ num_cascades : int
461
+ Number of cascades (i.e., layers) for variational network.
462
+ sens_chans : int
463
+ Number of channels for sensitivity map U-Net.
464
+ sens_pools : int
465
+ Number of downsampling and upsampling layers for sensitivity map U-Net.
466
+ chans : int
467
+ Number of channels for cascade U-Net.
468
+ pools : int
469
+ Number of downsampling and upsampling layers for cascade U-Net.
470
+ mask_center : bool
471
+ Whether to mask center of k-space for sensitivity map calculation.
472
+ use_dc_term : bool
473
+ Whether to use the data consistency term.
474
+ reduction_method : "batch" or "rss"
475
+ Method for reducing sensitivity maps to single channel.
476
+ "batch" reduces to single channel by stacking channels.
477
+ "rss" reduces to single channel by root sum of squares.
478
+ skip_method : "replace" or "add" or "add_inv" or "concat"
479
+ "replace" replaces the input with the output of the GNO
480
+ "add" adds the output of the GNO to the input
481
+ "add_inv" adds the output of the GNO to the input (only where samples are missing)
482
+ "concat" concatenates the output of the GNO to the input
483
+ """
484
+
485
+ super().__init__()
486
+
487
+ self.sens_net = SensitivityModel(
488
+ sens_chans,
489
+ sens_pools,
490
+ radius_cutoff,
491
+ in_shape,
492
+ kernel_shape,
493
+ mask_center=mask_center,
494
+ )
495
+ self.gno = NormUDNO(
496
+ gno_chans,
497
+ gno_pools,
498
+ in_shape=in_shape,
499
+ radius_cutoff=radius_cutoff,
500
+ kernel_shape=kernel_shape,
501
+ in_chans=2,
502
+ out_chans=2,
503
+ )
504
+ self.cascades = nn.ModuleList(
505
+ [
506
+ VarNetBlock(
507
+ NormUDNO(
508
+ chans,
509
+ pools,
510
+ radius_cutoff,
511
+ in_shape,
512
+ kernel_shape,
513
+ in_chans=(
514
+ 4 if skip_method == "concat" and cascade_idx == 0 else 2
515
+ ),
516
+ out_chans=2,
517
+ )
518
+ )
519
+ for cascade_idx in range(num_cascades)
520
+ ]
521
+ )
522
+ self.use_dc_term = use_dc_term
523
+ self.reduction_method = reduction_method
524
+ self.skip_method = skip_method
525
+
526
+ print("===================================")
527
+ print("===================================")
528
+ print("initialized no repeat k ")
529
+ print("===================================")
530
+ print("===================================")
531
+
532
+ def forward(
533
+ self,
534
+ masked_kspace: torch.Tensor,
535
+ mask: torch.Tensor,
536
+ num_low_frequencies: Optional[int] = None,
537
+ ) -> torch.Tensor:
538
+
539
+ # (B, C, X, Y, 2)
540
+ sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies)
541
+
542
+ kspace_pred = masked_kspace
543
+ # iterative update
544
+ for cascade in self.cascades:
545
+ # breakpoint()
546
+ sens_maps = self.sens_net(kspace_pred, mask, num_low_frequencies)
547
+ # reduce before inpainting
548
+ kspace_pred, b = chans_to_batch_dim(kspace_pred)
549
+ # inpainting
550
+ kspace_pred = self.gno(kspace_pred)
551
+ kspace_pred = batch_chans_to_chan_dim(kspace_pred, b)
552
+
553
+ # image
554
+ kspace_pred = cascade(
555
+ kspace_pred, masked_kspace, mask, sens_maps, self.use_dc_term
556
+ )
557
+
558
+ spatial_pred = fastmri.ifft2c(kspace_pred)
559
+ spatial_pred_abs = fastmri.complex_abs(spatial_pred)
560
+ combined_spatial = fastmri.rss(spatial_pred_abs, dim=1)
561
+
562
+ return combined_spatial
models/temp/no_repeatk_module.py ADDED
@@ -0,0 +1,303 @@
1
+ from argparse import ArgumentParser
2
+ from typing import Tuple
3
+
4
+ import torch
5
+
6
+ import fastmri
7
+ from fastmri import transforms
8
+ from models.temp.no_repeatk import NOVarnet
9
+
10
+ from models.lightning.mri_module import MriModule
11
+ from type_utils import tuple_type
12
+
13
+
14
+ class NORepeatKModule(MriModule):
15
+ """
16
+ NO-Varnet repeat-k (temp) training module.
17
+ """
18
+
19
+ def __init__(
20
+ self,
21
+ num_cascades: int = 12,
22
+ pools: int = 4,
23
+ chans: int = 18,
24
+ sens_pools: int = 4,
25
+ sens_chans: int = 8,
26
+ gno_pools: int = 4,
27
+ gno_chans: int = 16,
28
+ gno_radius_cutoff: float = 0.02,
29
+ gno_kernel_shape: Tuple[int, int] = (6, 7),
30
+ radius_cutoff: float = 0.02,
31
+ kernel_shape: Tuple[int, int] = (6, 7),
32
+ in_shape: Tuple[int, int] = (320, 320),
33
+ use_dc_term: bool = True,
34
+ lr: float = 0.0003,
35
+ lr_step_size: int = 40,
36
+ lr_gamma: float = 0.1,
37
+ weight_decay: float = 0.0,
38
+ reduction_method: str = "rss",
39
+ skip_method: str = "add",
40
+ **kwargs,
41
+ ):
42
+ """
43
+ Parameters
44
+ ----------
45
+ num_cascades : int
46
+ Number of cascades (i.e., layers) for the variational network.
47
+ pools : int
48
+ Number of downsampling and upsampling layers for the cascade U-Net.
49
+ chans : int
50
+ Number of channels for the cascade U-Net.
51
+ sens_pools : int
52
+ Number of downsampling and upsampling layers for the sensitivity map U-Net.
53
+ sens_chans : int
54
+ Number of channels for the sensitivity map U-Net.
55
+ lr : float
56
+ Learning rate.
57
+ lr_step_size : int
58
+ Learning rate step size.
59
+ lr_gamma : float
60
+ Learning rate gamma decay.
61
+ weight_decay : float
62
+ Parameter for penalizing weights norm.
63
+ num_sense_lines : int, optional
64
+ Number of low-frequency lines to use for sensitivity map computation.
65
+ Must be even or `None`. Default `None` will automatically compute the number
66
+ from masks. Default behavior may cause some slices to use more low-frequency
67
+ lines than others, when used in conjunction with e.g. the EquispacedMaskFunc
68
+ defaults. To prevent this, either set `num_sense_lines`, or set
69
+ `skip_low_freqs` and `skip_around_low_freqs` to `True` in the EquispacedMaskFunc.
70
+ Note that setting this value may lead to undesired behavior when training on
71
+ multiple accelerations simultaneously.
72
+ """
73
+ super().__init__(**kwargs)
74
+ self.save_hyperparameters()
75
+
76
+ self.num_cascades = num_cascades
77
+ self.pools = pools
78
+ self.chans = chans
79
+ self.sens_pools = sens_pools
80
+ self.sens_chans = sens_chans
81
+ self.gno_pools = gno_pools
82
+ self.gno_chans = gno_chans
83
+ self.gno_radius_cutoff = gno_radius_cutoff
84
+ self.gno_kernel_shape = gno_kernel_shape
85
+ self.radius_cutoff = radius_cutoff
86
+ self.kernel_shape = kernel_shape
87
+ self.in_shape = in_shape
88
+ self.use_dc_term = use_dc_term
89
+ self.lr = lr
90
+ self.lr_step_size = lr_step_size
91
+ self.lr_gamma = lr_gamma
92
+ self.weight_decay = weight_decay
93
+ self.reduction_method = reduction_method
94
+ self.skip_method = skip_method
95
+
96
+ self.model = NOVarnet(
97
+ num_cascades=self.num_cascades,
98
+ sens_chans=self.sens_chans,
99
+ sens_pools=self.sens_pools,
100
+ chans=self.chans,
101
+ pools=self.pools,
102
+ gno_chans=self.gno_chans,
103
+ gno_pools=self.gno_pools,
104
+ gno_radius_cutoff=self.gno_radius_cutoff,
105
+ gno_kernel_shape=self.gno_kernel_shape,
106
+ radius_cutoff=radius_cutoff,
107
+ kernel_shape=kernel_shape,
108
+ in_shape=in_shape,
109
+ use_dc_term=use_dc_term,
110
+ reduction_method=reduction_method,
111
+ skip_method=skip_method,
112
+ )
113
+
114
+ self.criterion = fastmri.SSIMLoss()
115
+ self.num_params = sum(p.numel() for p in self.parameters())
116
+
117
+ def forward(self, masked_kspace, mask, num_low_frequencies):
118
+ return self.model(masked_kspace, mask, num_low_frequencies)
119
+
120
+ def training_step(self, batch, batch_idx):
121
+ output = self.forward(
122
+ batch.masked_kspace, batch.mask, batch.num_low_frequencies
123
+ )
124
+
125
+ target, output = transforms.center_crop_to_smallest(batch.target, output)
126
+ loss = self.criterion(
127
+ output.unsqueeze(1), target.unsqueeze(1), data_range=batch.max_value
128
+ )
129
+
130
+ self.log("train_loss", loss, on_step=True, on_epoch=True)
131
+ self.log("epoch", int(self.current_epoch), on_step=True, on_epoch=True)
132
+
133
+ return loss
134
+
135
+ def validation_step(self, batch, batch_idx, dataloader_idx=0):
136
+ dataloaders = self.trainer.val_dataloaders
137
+ slug = list(dataloaders.keys())[dataloader_idx]
138
+
139
+ output = self.forward(
140
+ batch.masked_kspace, batch.mask, batch.num_low_frequencies
141
+ )
142
+
143
+ target, output = transforms.center_crop_to_smallest(batch.target, output)
144
+
145
+ loss = self.criterion(
146
+ output.unsqueeze(1),
147
+ target.unsqueeze(1),
148
+ data_range=batch.max_value,
149
+ )
150
+
151
+ return {
152
+ "slug": slug,
153
+ "fname": batch.fname,
154
+ "slice_num": batch.slice_num,
155
+ "max_value": batch.max_value,
156
+ "output": output,
157
+ "target": target,
158
+ "val_loss": loss,
159
+ }
160
+
161
+ def configure_optimizers(self):
162
+ optim = torch.optim.Adam(
163
+ self.parameters(), lr=self.lr, weight_decay=self.weight_decay
164
+ )
165
+ scheduler = torch.optim.lr_scheduler.StepLR(
166
+ optim, self.lr_step_size, self.lr_gamma
167
+ )
168
+
169
+ return [optim], [scheduler]
170
+
171
+ @staticmethod
172
+ def add_model_specific_args(parent_parser):
173
+ """
174
+ Define parameters that only apply to this model
175
+ """
176
+ parser = ArgumentParser(parents=[parent_parser], add_help=False)
177
+ parser = MriModule.add_model_specific_args(parser)
178
+
179
+ # network params
180
+ parser.add_argument(
181
+ "--num_cascades",
182
+ default=12,
183
+ type=int,
184
+ help="Number of VarNet cascades",
185
+ )
186
+ parser.add_argument(
187
+ "--pools",
188
+ default=4,
189
+ type=int,
190
+ help="Number of U-Net pooling layers in VarNet blocks",
191
+ )
192
+ parser.add_argument(
193
+ "--chans",
194
+ default=18,
195
+ type=int,
196
+ help="Number of channels for U-Net in VarNet blocks",
197
+ )
198
+ parser.add_argument(
199
+ "--sens_pools",
200
+ default=4,
201
+ type=int,
202
+ help=(
203
+ "Number of pooling layers for sense map estimation U-Net in" " VarNet"
204
+ ),
205
+ )
206
+ parser.add_argument(
207
+ "--sens_chans",
208
+ default=8,
209
+ type=int,
210
+ help="Number of channels for sense map estimation U-Net in VarNet",
211
+ )
212
+ parser.add_argument(
213
+ "--gno_pools",
214
+ default=4,
215
+ type=int,
216
+ help=("Number of pooling layers for GNO"),
217
+ )
218
+ parser.add_argument(
219
+ "--gno_chans",
220
+ default=16,
221
+ type=int,
222
+ help="Number of channels for GNO",
223
+ )
224
+ parser.add_argument(
225
+ "--gno_radius_cutoff",
226
+ default=0.02,
227
+ type=float,
228
+ required=True,
229
+ help="GNO module radius_cutoff",
230
+ )
231
+ parser.add_argument(
232
+ "--gno_kernel_shape",
233
+ default=(6, 7),
234
+ type=tuple_type,
235
+ required=True,
236
+ help="GNO module kernel_shape. Ex: (6, 7)",
237
+ )
238
+ parser.add_argument(
239
+ "--radius_cutoff",
240
+ default=0.01,
241
+ type=float,
242
+ required=True,
243
+ help="DISCO module radius_cutoff",
244
+ )
245
+ parser.add_argument(
246
+ "--kernel_shape",
247
+ default=(6, 7),
248
+ type=tuple_type,
249
+ required=True,
250
+ help="DISCO module kernel_shape. Ex: (6, 7)",
251
+ )
252
+ parser.add_argument(
253
+ "--in_shape",
254
+ default=(640, 320),
255
+ type=tuple_type,
256
+ required=True,
257
+ help="Spatial dimensions of masked_kspace samples. Ex: (640, 320)",
258
+ )
259
+ parser.add_argument(
260
+ "--use_dc_term",
261
+ default=True,
262
+ type=bool,
263
+ help="Whether to use the DC term in the unrolled iterative update step",
264
+ )
265
+
266
+ # training params (opt)
267
+ parser.add_argument(
268
+ "--lr", default=0.0003, type=float, help="Adam learning rate"
269
+ )
270
+ parser.add_argument(
271
+ "--lr_step_size",
272
+ default=40,
273
+ type=int,
274
+ help="Epoch at which to decrease step size",
275
+ )
276
+ parser.add_argument(
277
+ "--lr_gamma",
278
+ default=0.1,
279
+ type=float,
280
+ help="Extent to which step size should be decreased",
281
+ )
282
+ parser.add_argument(
283
+ "--weight_decay",
284
+ default=0.0,
285
+ type=float,
286
+ help="Strength of weight decay regularization",
287
+ )
288
+ parser.add_argument(
289
+ "--reduction_method",
290
+ default="rss",
291
+ type=str,
292
+ choices=["rss", "batch"],
293
+ help="Reduction method used to reduce multi-channel k-space data before inpainting module. Read documentation of GNO for more information.",
294
+ )
295
+ parser.add_argument(
296
+ "--skip_method",
297
+ default="add_inv",
298
+ type=str,
299
+ choices=["add_inv", "add", "concat", "replace"],
300
+ help="Method for skip connection around inpainting module.",
301
+ )
302
+
303
+ return parser
models/udno.py ADDED
@@ -0,0 +1,369 @@
1
+ """
2
+ U-shaped DISCO Neural Operator
3
+ """
4
+
5
+ from typing import List, Tuple
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.nn import functional as F
10
+
11
+ from torch_harmonics_local.convolution import (
12
+ EquidistantDiscreteContinuousConv2d as DISCO2d,
13
+ )
14
+
15
+
16
+ class UDNO(nn.Module):
17
+ """
18
+ U-shaped DISCO Neural Operator in PyTorch
19
+ """
20
+
21
+ def __init__(
22
+ self,
23
+ in_chans: int,
24
+ out_chans: int,
25
+ radius_cutoff: float,
26
+ chans: int = 32,
27
+ num_pool_layers: int = 4,
28
+ drop_prob: float = 0.0,
29
+ in_shape: Tuple[int, int] = (320, 320),
30
+ kernel_shape: Tuple[int, int] = (3, 4),
31
+ ):
32
+ """
33
+ Parameters
34
+ ----------
35
+ in_chans : int
36
+ Number of channels in the input to the U-Net model.
37
+ out_chans : int
38
+ Number of channels in the output to the U-Net model.
39
+ radius_cutoff : float
40
+ Control the effective radius of the DISCO kernel. Values are
41
+ between 0.0 and 1.0. The radius_cutoff is represented as a proportion
42
+ of the normalized input space, to ensure that kernels are resolution
43
+ invariant.
44
+ chans : int, optional
45
+ Number of output channels of the first DISCO layer. Default is 32.
46
+ num_pool_layers : int, optional
47
+ Number of down-sampling and up-sampling layers. Default is 4.
48
+ drop_prob : float, optional
49
+ Dropout probability. Default is 0.0.
50
+ in_shape : Tuple[int, int]
51
+ Shape of the input to the UDNO. This is required to dynamically
52
+ compile DISCO kernels for resolution invariance.
53
+ kernel_shape : Tuple[int, int], optional
54
+ Shape of the DISCO kernel. Default is (3, 4). This corresponds to 3
55
+ rings and 4 anisotropic basis functions. Under the hood, each DISCO
56
+ kernel has (3 - 1) * 4 + 1 = 9 parameters, equivalent to a standard
57
+ 3x3 convolution kernel.
58
+
59
+ Note: This is NOT kernel_size, as under the DISCO framework,
60
+ kernels are dynamically compiled to support resolution invariance.
61
+ """
62
+ super().__init__()
63
+ assert len(in_shape) == 2, "Input shape must be 2D"
64
+
65
+ self.in_chans = in_chans
66
+ self.out_chans = out_chans
67
+ self.chans = chans
68
+ self.num_pool_layers = num_pool_layers
69
+ self.drop_prob = drop_prob
70
+ self.in_shape = in_shape
71
+ self.kernel_shape = kernel_shape
72
+
73
+ self.down_sample_layers = nn.ModuleList(
74
+ [
75
+ DISCOBlock(
76
+ in_chans,
77
+ chans,
78
+ radius_cutoff,
79
+ drop_prob,
80
+ in_shape,
81
+ kernel_shape,
82
+ )
83
+ ]
84
+ )
85
+ ch = chans
86
+ shape = (in_shape[0] // 2, in_shape[1] // 2)
87
+ radius_cutoff = radius_cutoff * 2
88
+ for _ in range(num_pool_layers - 1):
89
+ self.down_sample_layers.append(
90
+ DISCOBlock(
91
+ ch,
92
+ ch * 2,
93
+ radius_cutoff,
94
+ drop_prob,
95
+ in_shape=shape,
96
+ kernel_shape=kernel_shape,
97
+ )
98
+ )
99
+ ch *= 2
100
+ shape = (shape[0] // 2, shape[1] // 2)
101
+ radius_cutoff *= 2
102
+
103
+ # test commit
104
+
105
+ self.bottleneck = DISCOBlock(
106
+ ch,
107
+ ch * 2,
108
+ radius_cutoff,
109
+ drop_prob,
110
+ in_shape=shape,
111
+ kernel_shape=kernel_shape,
112
+ )
113
+
114
+ self.up = nn.ModuleList()
115
+ self.up_transpose = nn.ModuleList()
116
+ for _ in range(num_pool_layers - 1):
117
+ self.up_transpose.append(
118
+ TransposeDISCOBlock(
119
+ ch * 2,
120
+ ch,
121
+ radius_cutoff,
122
+ in_shape=shape,
123
+ kernel_shape=kernel_shape,
124
+ )
125
+ )
126
+ shape = (shape[0] * 2, shape[1] * 2)
127
+ radius_cutoff /= 2
128
+ self.up.append(
129
+ DISCOBlock(
130
+ ch * 2,
131
+ ch,
132
+ radius_cutoff,
133
+ drop_prob,
134
+ in_shape=shape,
135
+ kernel_shape=kernel_shape,
136
+ )
137
+ )
138
+ ch //= 2
139
+
140
+ self.up_transpose.append(
141
+ TransposeDISCOBlock(
142
+ ch * 2,
143
+ ch,
144
+ radius_cutoff,
145
+ in_shape=shape,
146
+ kernel_shape=kernel_shape,
147
+ )
148
+ )
149
+ shape = (shape[0] * 2, shape[1] * 2)
150
+ radius_cutoff /= 2
151
+ self.up.append(
152
+ nn.Sequential(
153
+ DISCOBlock(
154
+ ch * 2,
155
+ ch,
156
+ radius_cutoff,
157
+ drop_prob,
158
+ in_shape=shape,
159
+ kernel_shape=kernel_shape,
160
+ ),
161
+ nn.Conv2d(
162
+ ch, self.out_chans, kernel_size=1, stride=1
163
+ ), # 1x1 conv is always res-invariant (pixel wise channel transformation)
164
+ )
165
+ )
166
+
167
+ def forward(self, image: torch.Tensor) -> torch.Tensor:
168
+ """
169
+ Parameters
170
+ ----------
171
+ image : torch.Tensor
172
+ Input 4D tensor of shape `(N, in_chans, H, W)`.
173
+
174
+ Returns
175
+ -------
176
+ torch.Tensor
177
+ Output tensor of shape `(N, out_chans, H, W)`.
178
+ """
179
+ stack = []
180
+ output = image
181
+
182
+ # apply down-sampling layers
183
+ for layer in self.down_sample_layers:
184
+ output = layer(output)
185
+ stack.append(output)
186
+ output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0)
187
+
188
+ output = self.bottleneck(output)
189
+
190
+ # apply up-sampling layers
191
+ for transpose, disco in zip(self.up_transpose, self.up):
192
+ downsample_layer = stack.pop()
193
+ output = transpose(output)
194
+
195
+ # reflect pad on the right/bottom if needed to handle odd input dimensions
196
+ padding = [0, 0, 0, 0]
197
+ if output.shape[-1] != downsample_layer.shape[-1]:
198
+ padding[1] = 1 # padding right
199
+ if output.shape[-2] != downsample_layer.shape[-2]:
200
+ padding[3] = 1 # padding bottom
201
+ if torch.sum(torch.tensor(padding)) != 0:
202
+ output = F.pad(output, padding, "reflect")
203
+
204
+ output = torch.cat([output, downsample_layer], dim=1)
205
+ output = disco(output)
206
+
207
+ return output
208
+
209
+
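Two constructor details above are easy to miss: each DISCO kernel with `kernel_shape = (rings, basis)` carries `(rings - 1) * basis + 1` parameters, and `radius_cutoff` is doubled every time the spatial shape is halved, so the kernel keeps roughly the same footprint on the coarser grid. A small, self-contained illustration:

    # Parameter count per DISCO kernel and the radius/shape schedule across levels (illustrative).
    def disco_params(kernel_shape):
        rings, basis = kernel_shape
        return (rings - 1) * basis + 1

    print(disco_params((3, 4)))   # 9  -> comparable to a 3x3 convolution kernel
    print(disco_params((6, 7)))   # 36

    radius, shape = 0.02, (320, 320)
    for level in range(4):
        print(level, shape, round(radius, 3))
        shape = (shape[0] // 2, shape[1] // 2)
        radius *= 2  # halved resolution, doubled normalized cutoff -> similar effective footprint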
210
+ class DISCOBlock(nn.Module):
211
+ """
212
+ A DISCO Block that consists of two DISCO layers each followed by
213
+ instance normalization, LeakyReLU activation and dropout.
214
+ """
215
+
216
+ def __init__(
217
+ self,
218
+ in_chans: int,
219
+ out_chans: int,
220
+ radius_cutoff: float,
221
+ drop_prob: float,
222
+ in_shape: Tuple[int, int],
223
+ kernel_shape: Tuple[int, int] = (3, 4),
224
+ ):
225
+ """
226
+ Parameters
227
+ ----------
228
+ in_chans : int
229
+ Number of channels in the input.
230
+ out_chans : int
231
+ Number of channels in the output.
232
+ radius_cutoff : float
233
+ Control the effective radius of the DISCO kernel. Values are
234
+ between 0.0 and 1.0. The radius_cutoff is represented as a proportion
235
+ of the normalized input space, to ensure that kernels are resolution
236
+ invariant.
237
+ in_shape : Tuple[int]
238
+ Unbatched spatial 2D shape of the input to this block.
239
+ Required to dynamically compile DISCO kernels for resolution invariance.
240
+ kernel_shape : Tuple[int, int], optional
241
+ Shape of the DISCO kernel. Default is (3, 4). This corresponds to 3
242
+ rings and 4 anisotropic basis functions. Under the hood, each DISCO
243
+ kernel has (3 - 1) * 4 + 1 = 9 parameters, equivalent to a standard
244
+ 3x3 convolution kernel.
245
+
246
+ Note: This is NOT kernel_size, as under the DISCO framework,
247
+ kernels are dynamically compiled to support resolution invariance.
248
+ drop_prob : float
249
+ Dropout probability.
250
+ """
251
+ super().__init__()
252
+
253
+ self.in_chans = in_chans
254
+ self.out_chans = out_chans
255
+ self.drop_prob = drop_prob
256
+
257
+ self.layers = nn.Sequential(
258
+ DISCO2d(
259
+ in_chans,
260
+ out_chans,
261
+ kernel_shape=kernel_shape,
262
+ in_shape=in_shape,
263
+ bias=False,
264
+ radius_cutoff=radius_cutoff,
265
+ padding_mode="constant",
266
+ ),
267
+ nn.InstanceNorm2d(out_chans),
268
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
269
+ nn.Dropout2d(drop_prob),
270
+ DISCO2d(
271
+ out_chans,
272
+ out_chans,
273
+ kernel_shape=kernel_shape,
274
+ in_shape=in_shape,
275
+ bias=False,
276
+ radius_cutoff=radius_cutoff,
277
+ padding_mode="constant",
278
+ ),
279
+ nn.InstanceNorm2d(out_chans),
280
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
281
+ nn.Dropout2d(drop_prob),
282
+ )
283
+
284
+ def forward(self, image: torch.Tensor) -> torch.Tensor:
285
+ """
286
+ Parameters
287
+ ----------
288
+ image : ndarray
289
+ Input 4D tensor of shape `(N, in_chans, H, W)`.
290
+
291
+ Returns
292
+ -------
293
+ ndarray
294
+ Output tensor of shape `(N, out_chans, H, W)`.
295
+ """
296
+ return self.layers(image)
297
+
298
+
299
+ class TransposeDISCOBlock(nn.Module):
300
+ """
301
+ A transpose DISCO Block that consists of an up-sampling layer followed by a
302
+ DISCO layer, instance normalization, and LeakyReLU activation.
303
+ """
304
+
305
+ def __init__(
306
+ self,
307
+ in_chans: int,
308
+ out_chans: int,
309
+ radius_cutoff: float,
310
+ in_shape: Tuple[int, int],
311
+ kernel_shape: Tuple[int, int] = (3, 4),
312
+ ):
313
+ """
314
+ Parameters
315
+ ----------
316
+ in_chans : int
317
+ Number of channels in the input.
318
+ out_chans : int
319
+ Number of channels in the output.
320
+ radius_cutoff : float
321
+ Control the effective radius of the DISCO kernel. Values are
322
+ between 0.0 and 1.0. The radius_cutoff is represented as a proportion
323
+ of the normalized input space, to ensure that kernels are resolution
324
+ invariant.
325
+ in_shape : Tuple[int]
326
+ Unbatched spatial 2D shape of the input to this block.
327
+ Required to dynamically compile DISCO kernels for resolution invariance.
328
+ kernel_shape : Tuple[int, int], optional
329
+ Shape of the DISCO kernel. Default is (3, 4). This corresponds to 3
330
+ rings and 4 anisotropic basis functions. Under the hood, each DISCO
331
+ kernel has (3 - 1) * 4 + 1 = 9 parameters, equivalent to a standard
332
+ 3x3 convolution kernel.
333
+
334
+ Note: This is NOT kernel_size, as under the DISCO framework,
335
+ kernels are dynamically compiled to support resolution invariance
336
+ """
337
+ super().__init__()
338
+
339
+ self.in_chans = in_chans
340
+ self.out_chans = out_chans
341
+
342
+ self.layers = nn.Sequential(
343
+ nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
344
+ DISCO2d(
345
+ in_chans,
346
+ out_chans,
347
+ kernel_shape=kernel_shape,
348
+ in_shape=(2 * in_shape[0], 2 * in_shape[1]),
349
+ bias=False,
350
+ radius_cutoff=(radius_cutoff / 2),
351
+ padding_mode="constant",
352
+ ),
353
+ nn.InstanceNorm2d(out_chans),
354
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
355
+ )
356
+
357
+ def forward(self, image: torch.Tensor) -> torch.Tensor:
358
+ """
359
+ Parameters
360
+ ----------
361
+ image : torch.Tensor
362
+ Input 4D tensor of shape `(N, in_chans, H, W)`.
363
+
364
+ Returns
365
+ -------
366
+ torch.Tensor
367
+ Output tensor of shape `(N, out_chans, H*2, W*2)`.
368
+ """
369
+ return self.layers(image)
models/unet.py ADDED
@@ -0,0 +1,209 @@
1
+ """
2
+ Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ from typing import List, Tuple
9
+
10
+ import torch
11
+ from torch import nn
12
+ from torch.nn import functional as F
13
+
14
+
15
+ class Unet(nn.Module):
16
+ """
17
+ PyTorch implementation of a U-Net model.
18
+
19
+ O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks
20
+ for biomedical image segmentation. In International Conference on Medical
21
+ image computing and computer-assisted intervention, pages 234–241.
22
+ Springer, 2015.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ in_chans: int,
28
+ out_chans: int,
29
+ chans: int = 32,
30
+ num_pool_layers: int = 4,
31
+ drop_prob: float = 0.0,
32
+ ):
33
+ """
34
+ Parameters
35
+ ----------
36
+ in_chans : int
37
+ Number of channels in the input to the U-Net model.
38
+ out_chans : int
39
+ Number of channels in the output to the U-Net model.
40
+ chans : int, optional
41
+ Number of output channels of the first convolution layer. Default is 32.
42
+ num_pool_layers : int, optional
43
+ Number of down-sampling and up-sampling layers. Default is 4.
44
+ drop_prob : float, optional
45
+ Dropout probability. Default is 0.0.
46
+ """
47
+ super().__init__()
48
+
49
+ self.in_chans = in_chans
50
+ self.out_chans = out_chans
51
+ self.chans = chans
52
+ self.num_pool_layers = num_pool_layers
53
+ self.drop_prob = drop_prob
54
+
55
+ self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)])
56
+ ch = chans
57
+ for _ in range(num_pool_layers - 1):
58
+ self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob))
59
+ ch *= 2
60
+ self.conv = ConvBlock(ch, ch * 2, drop_prob)
61
+
62
+ self.up_conv = nn.ModuleList()
63
+ self.up_transpose_conv = nn.ModuleList()
64
+ for _ in range(num_pool_layers - 1):
65
+ self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch))
66
+ self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob))
67
+ ch //= 2
68
+
69
+ self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch))
70
+ self.up_conv.append(
71
+ nn.Sequential(
72
+ ConvBlock(ch * 2, ch, drop_prob),
73
+ nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1),
74
+ )
75
+ )
76
+
77
+ def forward(self, image: torch.Tensor) -> torch.Tensor:
78
+ """
79
+ Parameters
80
+ ----------
81
+ image : torch.Tensor
82
+ Input 4D tensor of shape `(N, in_chans, H, W)`.
83
+
84
+ Returns
85
+ -------
86
+ torch.Tensor
87
+ Output tensor of shape `(N, out_chans, H, W)`.
88
+ """
89
+ stack = []
90
+ output = image
91
+
92
+ # apply down-sampling layers
93
+ for layer in self.down_sample_layers:
94
+ output = layer(output)
95
+ stack.append(output)
96
+ output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0)
97
+
98
+ output = self.conv(output)
99
+
100
+ # apply up-sampling layers
101
+ for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv):
102
+ downsample_layer = stack.pop()
103
+ output = transpose_conv(output)
104
+
105
+ # reflect pad on the right/bottom if needed to handle odd input dimensions
106
+ padding = [0, 0, 0, 0]
107
+ if output.shape[-1] != downsample_layer.shape[-1]:
108
+ padding[1] = 1 # padding right
109
+ if output.shape[-2] != downsample_layer.shape[-2]:
110
+ padding[3] = 1 # padding bottom
111
+ if torch.sum(torch.tensor(padding)) != 0:
112
+ output = F.pad(output, padding, "reflect")
113
+
114
+ output = torch.cat([output, downsample_layer], dim=1)
115
+ output = conv(output)
116
+
117
+ return output
118
+
119
+
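The reflect-padding step inside `forward` above exists only to reconcile odd spatial sizes: after average-pooling an odd dimension and upsampling it again, the decoder tensor comes back one pixel short of its skip connection. A tiny sketch of that fix-up, with invented sizes:

    # Sketch of the one-pixel reflect pad used to match skip-connection shapes (sizes invented).
    import torch
    import torch.nn.functional as F

    decoder_out = torch.randn(1, 8, 20, 20)     # upsampled from a 10x10 feature map
    skip = torch.randn(1, 8, 21, 21)            # encoder feature before the odd-size pooling

    padding = [0, 0, 0, 0]                      # (left, right, top, bottom)
    if decoder_out.shape[-1] != skip.shape[-1]:
        padding[1] = 1
    if decoder_out.shape[-2] != skip.shape[-2]:
        padding[3] = 1
    decoder_out = F.pad(decoder_out, padding, "reflect")
    print(torch.cat([decoder_out, skip], dim=1).shape)  # torch.Size([1, 16, 21, 21])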
120
+ class ConvBlock(nn.Module):
121
+ """
122
+ A Convolutional Block that consists of two convolution layers each followed by
123
+ instance normalization, LeakyReLU activation and dropout.
124
+ """
125
+
126
+ def __init__(self, in_chans: int, out_chans: int, drop_prob: float):
127
+ """
128
+ Parameters
129
+ ----------
130
+ in_chans : int
131
+ Number of channels in the input.
132
+ out_chans : int
133
+ Number of channels in the output.
134
+ drop_prob : float
135
+ Dropout probability.
136
+ """
137
+ super().__init__()
138
+
139
+ self.in_chans = in_chans
140
+ self.out_chans = out_chans
141
+ self.drop_prob = drop_prob
142
+
143
+ self.layers = nn.Sequential(
144
+ nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False),
145
+ nn.InstanceNorm2d(out_chans),
146
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
147
+ nn.Dropout2d(drop_prob),
148
+ nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
149
+ nn.InstanceNorm2d(out_chans),
150
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
151
+ nn.Dropout2d(drop_prob),
152
+ )
153
+
154
+ def forward(self, image: torch.Tensor) -> torch.Tensor:
155
+ """
156
+ Parameters
157
+ ----------
158
+ image : ndarray
159
+ Input 4D tensor of shape `(N, in_chans, H, W)`.
160
+
161
+ Returns
162
+ -------
163
+ ndarray
164
+ Output tensor of shape `(N, out_chans, H, W)`.
165
+ """
166
+ return self.layers(image)
167
+
168
+
169
+ class TransposeConvBlock(nn.Module):
170
+ """
171
+ A Transpose Convolutional Block that consists of one convolution transpose
172
+ layers followed by instance normalization and LeakyReLU activation.
173
+ """
174
+
175
+ def __init__(self, in_chans: int, out_chans: int):
176
+ """
177
+ Parameters
178
+ ----------
179
+ in_chans : int
180
+ Number of channels in the input.
181
+ out_chans : int
182
+ Number of channels in the output.
183
+ """
184
+ super().__init__()
185
+
186
+ self.in_chans = in_chans
187
+ self.out_chans = out_chans
188
+
189
+ self.layers = nn.Sequential(
190
+ nn.ConvTranspose2d(
191
+ in_chans, out_chans, kernel_size=2, stride=2, bias=False
192
+ ),
193
+ nn.InstanceNorm2d(out_chans),
194
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
195
+ )
196
+
197
+ def forward(self, image: torch.Tensor) -> torch.Tensor:
198
+ """
199
+ Parameters
200
+ ----------
201
+ image : torch.Tensor
202
+ Input 4D tensor of shape `(N, in_chans, H, W)`.
203
+
204
+ Returns
205
+ -------
206
+ torch.Tensor
207
+ Output tensor of shape `(N, out_chans, H*2, W*2)`.
208
+ """
209
+ return self.layers(image)
models/varnet.py ADDED
@@ -0,0 +1,416 @@
1
+ """
2
+ Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ import math
9
+ import os
10
+ from typing import List, Optional, Tuple
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+ import fastmri
17
+ from fastmri import transforms
18
+ from models.unet import Unet
19
+
20
+
21
+ class NormUnet(nn.Module):
22
+ """
23
+ Normalized U-Net model.
24
+
25
+ This is the same as a regular U-Net, but with normalization applied to the
26
+ input before the U-Net. This keeps the values more numerically stable
27
+ during training.
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ chans: int,
33
+ num_pools: int,
34
+ in_chans: int = 2,
35
+ out_chans: int = 2,
36
+ drop_prob: float = 0.0,
37
+ ):
38
+ """
39
+
40
+ Initialize the NormUnet model.
41
+
42
+ Parameters
43
+ ----------
44
+ chans : int
45
+ Number of output channels of the first convolution layer.
46
+ num_pools : int
47
+ Number of down-sampling and up-sampling layers.
48
+ in_chans : int, optional
49
+ Number of channels in the input to the U-Net model. Default is 2.
50
+ out_chans : int, optional
51
+ Number of channels in the output to the U-Net model. Default is 2.
52
+ drop_prob : float, optional
53
+ Dropout probability. Default is 0.0.
54
+ """
55
+ super().__init__()
56
+
57
+ self.unet = Unet(
58
+ in_chans=in_chans,
59
+ out_chans=out_chans,
60
+ chans=chans,
61
+ num_pool_layers=num_pools,
62
+ drop_prob=drop_prob,
63
+ )
64
+
65
+ def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor:
66
+ b, c, h, w, two = x.shape
67
+ assert two == 2
68
+ return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w)
69
+
70
+ def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor:
71
+ b, c2, h, w = x.shape
72
+ assert c2 % 2 == 0
73
+ c = c2 // 2
74
+ return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous()
75
+
76
+ def norm(
77
+ self, x: torch.Tensor
78
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
79
+ # group norm
80
+ b, c, h, w = x.shape
81
+ x = x.view(b, 2, c // 2 * h * w)
82
+
83
+ mean = x.mean(dim=2).view(b, 2, 1, 1)
84
+ std = x.std(dim=2).view(b, 2, 1, 1)
85
+
86
+ x = x.view(b, c, h, w)
87
+
88
+ return (x - mean) / std, mean, std
89
+
90
+ def unnorm(
91
+ self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
92
+ ) -> torch.Tensor:
93
+ return x * std + mean
94
+
95
+ def pad(
96
+ self, x: torch.Tensor
97
+ ) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]:
98
+ _, _, h, w = x.shape
99
+ w_mult = ((w - 1) | 15) + 1
100
+ h_mult = ((h - 1) | 15) + 1
101
+ w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)]
102
+ h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)]
103
+ # TODO: fix this type when PyTorch fixes theirs
104
+ # the documentation lies - this actually takes a list
105
+ # https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L3457
106
+ # https://github.com/pytorch/pytorch/pull/16949
107
+ x = F.pad(x, w_pad + h_pad)
108
+
109
+ return x, (h_pad, w_pad, h_mult, w_mult)
110
+
111
+ def unpad(
112
+ self,
113
+ x: torch.Tensor,
114
+ h_pad: List[int],
115
+ w_pad: List[int],
116
+ h_mult: int,
117
+ w_mult: int,
118
+ ) -> torch.Tensor:
119
+ return x[
120
+ ..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]
121
+ ]
122
+
123
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
124
+ if not x.shape[-1] == 2:
125
+ raise ValueError("Last dimension must be 2 for complex.")
126
+
127
+ # get shapes for unet and normalize
128
+ x = self.complex_to_chan_dim(x)
129
+ x, mean, std = self.norm(x)
130
+ x, pad_sizes = self.pad(x)
131
+
132
+ x = self.unet(x)
133
+
134
+ # get shapes back and unnormalize
135
+ x = self.unpad(x, *pad_sizes)
136
+ x = self.unnorm(x, mean, std)
137
+ x = self.chan_complex_to_last_dim(x)
138
+
139
+ return x
140
+
141
+
142
+ class SensitivityModel(nn.Module):
143
+ """
144
+ Model for learning sensitivity estimation from k-space data.
145
+
146
+ This model applies an IFFT to multichannel k-space data and then a U-Net
147
+ to the coil images to estimate coil sensitivities. It can be used with the
148
+ end-to-end variational network.
149
+
150
+ Input: multi-coil k-space data
151
+ Output: multi-coil spatial domain sensitivity maps
152
+ """
153
+
154
+ def __init__(
155
+ self,
156
+ chans: int,
157
+ num_pools: int,
158
+ in_chans: int = 2,
159
+ out_chans: int = 2,
160
+ drop_prob: float = 0.0,
161
+ mask_center: bool = True,
162
+ ):
163
+ """
164
+ Parameters
165
+ ----------
166
+ chans : int
167
+ Number of output channels of the first convolution layer.
168
+ num_pools : int
169
+ Number of down-sampling and up-sampling layers.
170
+ in_chans : int, optional
171
+ Number of channels in the input to the U-Net model. Default is 2.
172
+ out_chans : int, optional
173
+ Number of channels in the output to the U-Net model. Default is 2.
174
+ drop_prob : float, optional
175
+ Dropout probability. Default is 0.0.
176
+ mask_center : bool, optional
177
+ Whether to mask center of k-space for sensitivity map calculation.
178
+ Default is True.
179
+ """
180
+ super().__init__()
181
+ self.mask_center = mask_center
182
+ self.norm_unet = NormUnet(
183
+ chans,
184
+ num_pools,
185
+ in_chans=in_chans,
186
+ out_chans=out_chans,
187
+ drop_prob=drop_prob,
188
+ )
189
+
190
+ def chans_to_batch_dim(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
191
+ b, c, h, w, comp = x.shape
192
+
193
+ return x.view(b * c, 1, h, w, comp), b
194
+
195
+ def batch_chans_to_chan_dim(
196
+ self,
197
+ x: torch.Tensor,
198
+ batch_size: int,
199
+ ) -> torch.Tensor:
200
+ bc, _, h, w, comp = x.shape
201
+ c = bc // batch_size
202
+
203
+ return x.view(batch_size, c, h, w, comp)
204
+
205
+ def divide_root_sum_of_squares(self, x: torch.Tensor) -> torch.Tensor:
206
+ return x / fastmri.rss_complex(x, dim=1).unsqueeze(-1).unsqueeze(1)
207
+
208
+ def get_pad_and_num_low_freqs(
209
+ self, mask: torch.Tensor, num_low_frequencies: Optional[int] = None
210
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
211
+ if num_low_frequencies is None or any(
212
+ torch.any(t == 0) for t in num_low_frequencies
213
+ ):
214
+ # get low frequency line locations and mask them out
215
+ squeezed_mask = mask[:, 0, 0, :, 0].to(torch.int8)
216
+ cent = squeezed_mask.shape[1] // 2
217
+ # running argmin returns the first non-zero
218
+ left = torch.argmin(squeezed_mask[:, :cent].flip(1), dim=1)
219
+ right = torch.argmin(squeezed_mask[:, cent:], dim=1)
220
+ num_low_frequencies_tensor = torch.max(
221
+ 2 * torch.min(left, right), torch.ones_like(left)
222
+ ) # force a symmetric center unless 1
223
+ else:
224
+ num_low_frequencies_tensor = num_low_frequencies * torch.ones(
225
+ mask.shape[0], dtype=mask.dtype, device=mask.device
226
+ )
227
+
228
+ pad = (mask.shape[-2] - num_low_frequencies_tensor + 1) // 2
229
+
230
+ return pad.type(torch.long), num_low_frequencies_tensor.type(torch.long)
231
+
232
+ def forward(
233
+ self,
234
+ masked_kspace: torch.Tensor,
235
+ mask: torch.Tensor,
236
+ num_low_frequencies: Optional[int] = None,
237
+ ) -> torch.Tensor:
238
+ if self.mask_center:
239
+ pad, num_low_freqs = self.get_pad_and_num_low_freqs(
240
+ mask, num_low_frequencies
241
+ )
242
+ masked_kspace = transforms.batched_mask_center(
243
+ masked_kspace, pad, pad + num_low_freqs
244
+ )
245
+
246
+ # convert to image space
247
+ images, batches = self.chans_to_batch_dim(fastmri.ifft2c(masked_kspace))
248
+
249
+ # estimate sensitivities
250
+ return self.divide_root_sum_of_squares(
251
+ self.batch_chans_to_chan_dim(self.norm_unet(images), batches)
252
+ )
253
+
254
+
255
+ class VarNet(nn.Module):
256
+ """
257
+ A full variational network model.
258
+
259
+ This model applies a combination of soft data consistency with a U-Net
260
+ regularizer. To use non-U-Net regularizers, use VarNetBlock.
261
+
262
+ Input: multi-channel k-space data
263
+ Output: single-channel RSS reconstructed image
264
+ """
265
+
266
+ def __init__(
267
+ self,
268
+ num_cascades: int = 12,
269
+ sens_chans: int = 8,
270
+ sens_pools: int = 4,
271
+ chans: int = 18,
272
+ pools: int = 4,
273
+ mask_center: bool = True,
274
+ ):
275
+ """
276
+ Parameters
277
+ ----------
278
+ num_cascades : int
279
+ Number of cascades (i.e., layers) for variational network.
280
+ sens_chans : int
281
+ Number of channels for sensitivity map U-Net.
282
+ sens_pools : int
283
+ Number of downsampling and upsampling layers for sensitivity map U-Net.
284
+ chans : int
285
+ Number of channels for cascade U-Net.
286
+ pools : int
287
+ Number of downsampling and upsampling layers for cascade U-Net.
288
+ mask_center : bool
289
+ Whether to mask center of k-space for sensitivity map calculation.
290
+ """
291
+
292
+ super().__init__()
293
+
294
+ self.sens_net = SensitivityModel(
295
+ chans=sens_chans,
296
+ num_pools=sens_pools,
297
+ mask_center=mask_center,
298
+ )
299
+ self.cascades = nn.ModuleList(
300
+ [VarNetBlock(NormUnet(chans, pools)) for _ in range(num_cascades)]
301
+ )
302
+
303
+ def forward(
304
+ self,
305
+ masked_kspace: torch.Tensor,
306
+ mask: torch.Tensor,
307
+ num_low_frequencies: Optional[int] = None,
308
+ ) -> torch.Tensor:
309
+ sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies)
310
+ kspace_pred = masked_kspace.clone()
311
+ for cascade in self.cascades:
312
+ kspace_pred = cascade(kspace_pred, masked_kspace, mask, sens_maps)
313
+
314
+ spatial_pred = fastmri.ifft2c(kspace_pred)
315
+
316
+ # ---------> FIXME: CHANGE FOR MVUE MODE
317
+ if self.training and os.getenv("MVUE") in ["yes", "1", "true", "True"]:
318
+ combined_spatial = fastmri.mvue(spatial_pred, sens_maps, dim=1)
319
+ else:
320
+ spatial_pred_abs = fastmri.complex_abs(spatial_pred)
321
+ combined_spatial = fastmri.rss(spatial_pred_abs, dim=1)
322
+ return combined_spatial
323
+
324
+
325
+ class VarNetBlock(nn.Module):
326
+ """
327
+ Model block for end-to-end variational network (refinemnt module)
328
+
329
+ This model applies a combination of soft data consistency with the input
330
+ model as a regularizer. A series of these blocks can be stacked to form
331
+ the full variational network.
332
+
333
+ Input: multi-channel k-space data
334
+ Output: multi-channel k-space data
335
+ """
336
+
337
+ def __init__(self, model: nn.Module):
338
+ """
339
+ Parameters
340
+ ----------
341
+ model : nn.Module
342
+ Module for "regularization" component of variational network.
343
+ """
344
+ super().__init__()
345
+
346
+ self.model = model
347
+ self.dc_weight = nn.Parameter(torch.ones(1))
348
+
349
+ def sens_expand(
350
+ self, x: torch.Tensor, sens_maps: torch.Tensor
351
+ ) -> torch.Tensor:
352
+ """
353
+ Calculates F (x sens_maps)
354
+ """
355
+ return fastmri.fft2c(fastmri.complex_mul(x, sens_maps))
356
+
357
+ def sens_reduce(
358
+ self, x: torch.Tensor, sens_maps: torch.Tensor
359
+ ) -> torch.Tensor:
360
+ """
361
+ Calculates F^{-1}(x) \overline{sens_maps}
362
+ where \overline{sens_maps} is the element-wise applied complex conjugate
363
+ """
364
+ return fastmri.complex_mul(
365
+ fastmri.ifft2c(x), fastmri.complex_conj(sens_maps)
366
+ ).sum(dim=1, keepdim=True)
367
+
368
+ def forward(
369
+ self,
370
+ current_kspace: torch.Tensor,
371
+ ref_kspace: torch.Tensor,
372
+ mask: torch.Tensor,
373
+ sens_maps: torch.Tensor,
374
+ ) -> torch.Tensor:
375
+ """
376
+ Parameters
377
+ ----------
378
+ current_kspace : torch.Tensor
379
+ The current k-space data (frequency domain data) being processed by the network.
380
+ ref_kspace : torch.Tensor
381
+ The reference k-space data (measured data) used for data consistency.
382
+ mask : torch.Tensor
383
+ A binary mask indicating the locations in k-space where data consistency should be enforced.
384
+ sens_maps : torch.Tensor
385
+ Sensitivity maps for the different coils in parallel imaging.
386
+
387
+ Returns
388
+ -------
389
+ torch.Tensor
390
+ The output k-space data after applying the variational network block.
391
+ """
392
+
393
+ """
394
+ Model term:
395
+ - Reduces the current k-space data using the sensitivity maps (inverse Fourier transform followed by element-wise multiplication and summation).
396
+ - Applies the neural network model to the reduced data.
397
+ - Expands the output of the model using the sensitivity maps (element-wise multiplication followed by Fourier transform).
398
+ """
399
+
400
+ model_term = self.sens_expand(
401
+ self.model(self.sens_reduce(current_kspace, sens_maps)), sens_maps
402
+ )
403
+
404
+ """
405
+ Soft data consistency term:
406
+ - Calculates the difference between current k-space and reference k-space where the mask is true.
407
+ - Multiplies this difference by the data consistency weight.
408
+ """
409
+ zero = torch.zeros(1, 1, 1, 1, 1).to(current_kspace)
410
+ soft_dc = (
411
+ torch.where(mask, current_kspace - ref_kspace, zero)
412
+ * self.dc_weight
413
+ )
414
+
415
+ # with data consistency term (removed for single cascade experiments)
416
+ return current_kspace - soft_dc - model_term
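
For orientation, a minimal usage sketch follows (editorial, not part of the commit). It assumes the rest of this module (SensitivityModel, NormUnet, the fastmri transforms) is importable, and uses illustrative fastMRI-style tensor shapes of (batch, coils, height, width, 2).

import torch

# Hypothetical end-to-end call; shapes, acceleration factor, and low-frequency count are illustrative only.
model = VarNet(num_cascades=4, sens_chans=8, sens_pools=4, chans=18, pools=4)
masked_kspace = torch.randn(1, 15, 640, 372, 2)          # undersampled multi-coil k-space
mask = torch.zeros(1, 1, 1, 372, 1, dtype=torch.bool)    # sampling mask, broadcast over rows and coils
mask[..., ::4, :] = True                                 # keep every fourth k-space column
image = model(masked_kspace, mask, num_low_frequencies=24)
print(image.shape)                                       # (1, 640, 372) RSS magnitude image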
pyproject.toml ADDED
@@ -0,0 +1,46 @@
1
+ [build-system]
2
+ requires = ["setuptools>=64.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "no-med"
7
+ version = "0.0.0"
8
+ description = "Neural operators for medical imaging."
9
+ readme = "README.md"
10
+ # readme-content-type = "text/markdown"
11
+ authors = [{ name = "Armeet Singh Jatyani", email = "[email protected]" }]
12
+ license = { text = "MIT" }
13
+ keywords = ["medical imaging", "neural operators", "AI"]
14
+ classifiers = [
15
+ "Programming Language :: Python :: 3",
16
+ "License :: OSI Approved :: MIT License",
17
+ "Operating System :: OS Independent",
18
+ ]
19
+ requires-python = ">=3.6"
20
+
21
+ [project.optional-dependencies]
22
+ dev = ["pytest>=6.0", "black", "flake8", "pdoc3"]
23
+
24
+ [tool.setuptools]
25
+ include-package-data = true
26
+ packages = { find = {} } # auto find packages
27
+
28
+ [tool.black]
29
+ line-length = 80 # Default is 88, but you can set it to 100 or 120 if needed
30
+ target-version = ['py38', 'py39', 'py310'] # Set target Python versions
31
+ include = '\.pyi?$' # Format only .py and .pyi files
32
+ skip-string-normalization = true
33
+ exclude = '''
34
+ /(
35
+ \.eggs # Exclude files generated by packaging
36
+ | \.git # Exclude version control files
37
+ | \.mypy_cache # Exclude mypy caches
38
+ | \.tox # Exclude tox environments
39
+ | \.venv # Exclude virtual environments
40
+ )/
41
+ '''
42
+ fast = true
43
+
44
+ [tool.isort]
45
+ profile = "black"
46
+ line_length = 80
pytest.ini ADDED
@@ -0,0 +1,4 @@
1
+ [pytest]
2
+ markers =
3
+ train: marks training tests (long runtime) (deselect with '-m "not train"')
4
+ addopts = -m "not train"
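
To illustrate how the marker declared above is consumed, a test opts into the long-running training suite as sketched below (the test name and body are hypothetical, not files from this commit); such tests are deselected by default through addopts and run explicitly with pytest -m train.

import pytest

@pytest.mark.train
def test_varnet_overfits_a_single_batch():
    # Deselected by the default addopts; run with `pytest -m train`.
    ...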
setup_config.py ADDED
@@ -0,0 +1,21 @@
1
+ import os
2
+
3
+ config_filename = "fastmri.yaml"
4
+ default_config_content = """brain_path: brain path here
5
+ knee_path: knee path here
6
+ log_path: log path here
7
+ checkpoint_path: checkpoint path here
8
+ """
9
+
10
+
11
+ def check_and_create_config():
12
+ if not os.path.exists(config_filename):
13
+ print(f"{config_filename} not found. Creating with default template...")
14
+ with open(config_filename, "w") as config_file:
15
+ config_file.write(default_config_content)
16
+ print(f"Default configuration file created at {config_filename}.")
17
+ else:
18
+ print(f"{config_filename} already exists. No changes made.")
19
+
20
+
21
+ check_and_create_config()
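
A sketch of how the generated template might be consumed downstream (editorial; parsing with PyYAML is an assumption and is not declared in this commit's dependencies):

import yaml  # assumption: PyYAML is available in the environment

with open("fastmri.yaml") as f:
    cfg = yaml.safe_load(f)

# After the placeholder values have been edited, the entries are plain path strings.
print(cfg["brain_path"], cfg["checkpoint_path"])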
torch_harmonics_local/__init__.py ADDED
File without changes
torch_harmonics_local/_disco_convolution.py ADDED
@@ -0,0 +1,502 @@
1
+ # coding=utf-8
2
+
3
+ # SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+
32
+ import math
33
+
34
+ import torch
35
+
36
+ # triton will only be available on CUDA installations of PyTorch
37
+ import triton
38
+ import triton.language as tl
39
+
40
+ BLOCK_SIZE_BATCH = 4
41
+ BLOCK_SIZE_NZ = 8
42
+ BLOCK_SIZE_POUT = 8
43
+
44
+
45
+ @triton.jit
46
+ def _disco_s2_contraction_kernel(
47
+ inz_ptr,
48
+ vnz_ptr,
49
+ nnz,
50
+ inz_stride_ii,
51
+ inz_stride_nz,
52
+ vnz_stride,
53
+ x_ptr,
54
+ batch_size,
55
+ nlat_in,
56
+ nlon_in,
57
+ x_stride_b,
58
+ x_stride_t,
59
+ x_stride_p,
60
+ y_ptr,
61
+ kernel_size,
62
+ nlat_out,
63
+ nlon_out,
64
+ y_stride_b,
65
+ y_stride_f,
66
+ y_stride_t,
67
+ y_stride_p,
68
+ pscale,
69
+ backward: tl.constexpr,
70
+ BLOCK_SIZE_BATCH: tl.constexpr,
71
+ BLOCK_SIZE_NZ: tl.constexpr,
72
+ BLOCK_SIZE_POUT: tl.constexpr,
73
+ ):
74
+ """
75
+ Kernel for the sparse-dense contraction for the S2 DISCO convolution.
76
+ """
77
+
78
+ pid_batch = tl.program_id(0)
79
+ pid_pout = tl.program_id(2)
80
+
81
+ # pid_nz should always be 0 as we do not account for larger grids in this dimension
82
+ pid_nz = tl.program_id(1) # should be always 0
83
+ tl.device_assert(pid_nz == 0)
84
+
85
+ # create the pointer block for pout
86
+ pout = pid_pout * BLOCK_SIZE_POUT + tl.arange(0, BLOCK_SIZE_POUT)
87
+ b = pid_batch * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)
88
+
89
+ # create pointer blocks for the psi datastructure
90
+ iinz = tl.arange(0, BLOCK_SIZE_NZ)
91
+
92
+ # get the initial pointers
93
+ fout_ptrs = inz_ptr + iinz * inz_stride_nz
94
+ tout_ptrs = inz_ptr + iinz * inz_stride_nz + inz_stride_ii
95
+ tpnz_ptrs = inz_ptr + iinz * inz_stride_nz + 2 * inz_stride_ii
96
+ vals_ptrs = vnz_ptr + iinz * vnz_stride
97
+
98
+ # iterate in a blocked fashion over the non-zero entries
99
+ for offs_nz in range(0, nnz, BLOCK_SIZE_NZ):
100
+ # load input output latitude coordinate pairs
101
+ fout = tl.load(
102
+ fout_ptrs + offs_nz * inz_stride_nz, mask=(offs_nz + iinz < nnz), other=-1
103
+ )
104
+ tout = tl.load(
105
+ tout_ptrs + offs_nz * inz_stride_nz, mask=(offs_nz + iinz < nnz), other=-1
106
+ )
107
+ tpnz = tl.load(
108
+ tpnz_ptrs + offs_nz * inz_stride_nz, mask=(offs_nz + iinz < nnz), other=-1
109
+ )
110
+
111
+ # load corresponding values
112
+ vals = tl.load(
113
+ vals_ptrs + offs_nz * vnz_stride, mask=(offs_nz + iinz < nnz), other=0.0
114
+ )
115
+
116
+ # compute the shifted longitude coordinates p+p' to read in a coalesced fashion
117
+ tnz = tpnz // nlon_in
118
+ pnz = tpnz % nlon_in
119
+
120
+ # make sure the value is not out of bounds
121
+ tl.device_assert(fout < kernel_size)
122
+ tl.device_assert(tout < nlat_out)
123
+ tl.device_assert(tnz < nlat_in)
124
+ tl.device_assert(pnz < nlon_in)
125
+
126
+ # load corresponding portion of the input array
127
+ x_ptrs = (
128
+ x_ptr
129
+ + tnz[None, :, None] * x_stride_t
130
+ + ((pnz[None, :, None] + pout[None, None, :] * pscale) % nlon_in)
131
+ * x_stride_p
132
+ + b[:, None, None] * x_stride_b
133
+ )
134
+ y_ptrs = (
135
+ y_ptr
136
+ + fout[None, :, None] * y_stride_f
137
+ + tout[None, :, None] * y_stride_t
138
+ + (pout[None, None, :] % nlon_out) * y_stride_p
139
+ + b[:, None, None] * y_stride_b
140
+ )
141
+
142
+ # precompute the mask
143
+ mask = (
144
+ (b[:, None, None] < batch_size) and (offs_nz + iinz[None, :, None] < nnz)
145
+ ) and (pout[None, None, :] < nlon_out)
146
+
147
+ # do the actual computation. Backward is essentially just the same operation with swapped tensors.
148
+ if not backward:
149
+ x = tl.load(x_ptrs, mask=mask, other=0.0)
150
+ y = vals[None, :, None] * x
151
+
152
+ # store it to the output array
153
+ tl.atomic_add(y_ptrs, y, mask=mask)
154
+ else:
155
+ y = tl.load(y_ptrs, mask=mask, other=0.0)
156
+ x = vals[None, :, None] * y
157
+
158
+ # store it to the output array
159
+ tl.atomic_add(x_ptrs, x, mask=mask)
160
+
161
+
162
+ def _disco_s2_contraction_fwd(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
163
+ """
164
+ Wrapper function for the triton implementation of the efficient DISCO convolution on the sphere.
165
+
166
+ Parameters
167
+ ----------
168
+ x: torch.Tensor
169
+ Input signal on the sphere. Expects a tensor of shape batch_size x channels x nlat_in x nlon_in).
170
+ psi : torch.Tensor
171
+ Pre-computed convolution tensor. Expects a sparse tensor of shape kernel_size x nlat_out x (nlat_in * nlon_in).
172
+ nlon_out: int
173
+ Number of longitude points the output should have.
174
+ """
175
+
176
+ # check the shapes of all input tensors
177
+ assert len(psi.shape) == 3
178
+ assert len(x.shape) == 4
179
+ assert psi.is_sparse, "Psi must be a sparse COO tensor"
180
+
181
+ # TODO: check that Psi is also coalesced
182
+
183
+ # get the dimensions of the problem
184
+ kernel_size, nlat_out, n_in = psi.shape
185
+ nnz = psi.indices().shape[-1]
186
+ batch_size, n_chans, nlat_in, nlon_in = x.shape
187
+ assert nlat_in * nlon_in == n_in
188
+
189
+ # TODO: check that Psi index vector is of type long
190
+
191
+ # make sure that the grid-points of the output grid fall onto the grid points of the input grid
192
+ assert nlon_in % nlon_out == 0
193
+ pscale = nlon_in // nlon_out
194
+
195
+ # to simplify things, we merge batch and channel dimensions
196
+ x = x.reshape(batch_size * n_chans, nlat_in, nlon_in)
197
+
198
+ # prepare the output tensor
199
+ y = torch.zeros(
200
+ batch_size * n_chans,
201
+ kernel_size,
202
+ nlat_out,
203
+ nlon_out,
204
+ device=x.device,
205
+ dtype=x.dtype,
206
+ )
207
+
208
+ # determine the grid for the computation
209
+ grid = (
210
+ triton.cdiv(batch_size * n_chans, BLOCK_SIZE_BATCH),
211
+ 1,
212
+ triton.cdiv(nlon_out, BLOCK_SIZE_POUT),
213
+ )
214
+
215
+ # launch the kernel
216
+ _disco_s2_contraction_kernel[grid](
217
+ psi.indices(),
218
+ psi.values(),
219
+ nnz,
220
+ psi.indices().stride(-2),
221
+ psi.indices().stride(-1),
222
+ psi.values().stride(-1),
223
+ x,
224
+ batch_size * n_chans,
225
+ nlat_in,
226
+ nlon_in,
227
+ x.stride(0),
228
+ x.stride(-2),
229
+ x.stride(-1),
230
+ y,
231
+ kernel_size,
232
+ nlat_out,
233
+ nlon_out,
234
+ y.stride(0),
235
+ y.stride(1),
236
+ y.stride(-2),
237
+ y.stride(-1),
238
+ pscale,
239
+ False,
240
+ BLOCK_SIZE_BATCH,
241
+ BLOCK_SIZE_NZ,
242
+ BLOCK_SIZE_POUT,
243
+ )
244
+
245
+ # reshape y back to expose the correct dimensions
246
+ y = y.reshape(batch_size, n_chans, kernel_size, nlat_out, nlon_out)
247
+
248
+ return y
249
+
250
+
251
+ def _disco_s2_contraction_bwd(grad_y: torch.Tensor, psi: torch.Tensor, nlon_in: int):
252
+ """
253
+ Backward pass for the triton implementation of the efficient DISCO convolution on the sphere.
254
+
255
+ Parameters
256
+ ----------
257
+ grad_y: torch.Tensor
258
+ Input gradient on the sphere. Expects a tensor of shape batch_size x channels x kernel_size x nlat_out x nlon_out.
259
+ psi : torch.Tensor
260
+ Pre-computed convolution tensor. Expects a sparse tensor of shape kernel_size x nlat_out x (nlat_in * nlon_in).
261
+ nlon_in: int
262
+ Number of longitude points the input used. Is required to infer the correct dimensions
263
+ """
264
+
265
+ # check the shapes of all input tensors
266
+ assert len(psi.shape) == 3
267
+ assert len(grad_y.shape) == 5
268
+ assert psi.is_sparse, "psi must be a sparse COO tensor"
269
+
270
+ # TODO: check that Psi is also coalesced
271
+
272
+ # get the dimensions of the problem
273
+ kernel_size, nlat_out, n_in = psi.shape
274
+ nnz = psi.indices().shape[-1]
275
+ assert grad_y.shape[-2] == nlat_out
276
+ assert grad_y.shape[-3] == kernel_size
277
+ assert n_in % nlon_in == 0
278
+ nlat_in = n_in // nlon_in
279
+ batch_size, n_chans, _, _, nlon_out = grad_y.shape
280
+
281
+ # make sure that the grid-points of the output grid fall onto the grid points of the input grid
282
+ assert nlon_in % nlon_out == 0
283
+ pscale = nlon_in // nlon_out
284
+
285
+ # to simplify things, we merge batch and channel dimensions
286
+ grad_y = grad_y.reshape(batch_size * n_chans, kernel_size, nlat_out, nlon_out)
287
+
288
+ # prepare the output tensor
289
+ grad_x = torch.zeros(
290
+ batch_size * n_chans, nlat_in, nlon_in, device=grad_y.device, dtype=grad_y.dtype
291
+ )
292
+
293
+ # determine the grid for the computation
294
+ grid = (
295
+ triton.cdiv(batch_size * n_chans, BLOCK_SIZE_BATCH),
296
+ 1,
297
+ triton.cdiv(nlon_out, BLOCK_SIZE_POUT),
298
+ )
299
+
300
+ # launch the kernel
301
+ _disco_s2_contraction_kernel[grid](
302
+ psi.indices(),
303
+ psi.values(),
304
+ nnz,
305
+ psi.indices().stride(-2),
306
+ psi.indices().stride(-1),
307
+ psi.values().stride(-1),
308
+ grad_x,
309
+ batch_size * n_chans,
310
+ nlat_in,
311
+ nlon_in,
312
+ grad_x.stride(0),
313
+ grad_x.stride(-2),
314
+ grad_x.stride(-1),
315
+ grad_y,
316
+ kernel_size,
317
+ nlat_out,
318
+ nlon_out,
319
+ grad_y.stride(0),
320
+ grad_y.stride(1),
321
+ grad_y.stride(-2),
322
+ grad_y.stride(-1),
323
+ pscale,
324
+ True,
325
+ BLOCK_SIZE_BATCH,
326
+ BLOCK_SIZE_NZ,
327
+ BLOCK_SIZE_POUT,
328
+ )
329
+
330
+ # reshape y back to expose the correct dimensions
331
+ grad_x = grad_x.reshape(batch_size, n_chans, nlat_in, nlon_in)
332
+
333
+ return grad_x
334
+
335
+
336
+ class _DiscoS2ContractionTriton(torch.autograd.Function):
337
+ """
338
+ Helper function to make the triton implementation work with PyTorch autograd functionality
339
+ """
340
+
341
+ @staticmethod
342
+ def forward(ctx, x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
343
+ ctx.save_for_backward(psi)
344
+ ctx.nlon_in = x.shape[-1]
345
+
346
+ return _disco_s2_contraction_fwd(x, psi, nlon_out)
347
+
348
+ @staticmethod
349
+ def backward(ctx, grad_output):
350
+ (psi,) = ctx.saved_tensors
351
+ grad_input = _disco_s2_contraction_bwd(grad_output, psi, ctx.nlon_in)
352
+ grad_x = grad_psi = None
353
+
354
+ return grad_input, None, None
355
+
356
+
357
+ class _DiscoS2TransposeContractionTriton(torch.autograd.Function):
358
+ """
359
+ Helper function to make the triton implementation work with PyTorch autograd functionality
360
+ """
361
+
362
+ @staticmethod
363
+ def forward(ctx, x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
364
+ ctx.save_for_backward(psi)
365
+ ctx.nlon_in = x.shape[-1]
366
+
367
+ return _disco_s2_contraction_bwd(x, psi, nlon_out)
368
+
369
+ @staticmethod
370
+ def backward(ctx, grad_output):
371
+ (psi,) = ctx.saved_tensors
372
+ grad_input = _disco_s2_contraction_fwd(grad_output, psi, ctx.nlon_in)
373
+ grad_x = grad_psi = None
374
+
375
+ return grad_input, None, None
376
+
377
+
378
+ def _disco_s2_contraction_triton(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
379
+ return _DiscoS2ContractionTriton.apply(x, psi, nlon_out)
380
+
381
+
382
+ def _disco_s2_transpose_contraction_triton(
383
+ x: torch.Tensor, psi: torch.Tensor, nlon_out: int
384
+ ):
385
+ return _DiscoS2TransposeContractionTriton.apply(x, psi, nlon_out)
386
+
387
+
388
+ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
389
+ """
390
+ Reference implementation of the custom contraction as described in [1]. This requires repeated
391
+ shifting of the input tensor, which can potentially be costly. For an efficient implementation
392
+ on GPU, make sure to use the custom kernel written in Triton.
393
+ """
394
+ assert len(psi.shape) == 3
395
+ assert len(x.shape) == 4
396
+ psi = psi.to(x.device)
397
+
398
+ batch_size, n_chans, nlat_in, nlon_in = x.shape
399
+ kernel_size, nlat_out, _ = psi.shape
400
+
401
+ assert psi.shape[-1] == nlat_in * nlon_in
402
+ assert nlon_in % nlon_out == 0
403
+ assert nlon_in >= nlat_out
404
+ pscale = nlon_in // nlon_out
405
+
406
+ # add a dummy dimension for nkernel and move the batch and channel dims to the end
407
+ x = x.reshape(1, batch_size * n_chans, nlat_in, nlon_in).permute(0, 2, 3, 1)
408
+ x = x.expand(kernel_size, -1, -1, -1)
409
+
410
+ y = torch.zeros(
411
+ nlon_out,
412
+ kernel_size,
413
+ nlat_out,
414
+ batch_size * n_chans,
415
+ device=x.device,
416
+ dtype=x.dtype,
417
+ )
418
+
419
+ for pout in range(nlon_out):
420
+ # sparse contraction with psi
421
+ y[pout] = torch.bmm(psi, x.reshape(kernel_size, nlat_in * nlon_in, -1))
422
+ # we need to repeatedly roll the input tensor to facilitate the shifted multiplication
423
+ x = torch.roll(x, -pscale, dims=2)
424
+
425
+ # reshape y back to expose the correct dimensions
426
+ y = y.permute(3, 1, 2, 0).reshape(
427
+ batch_size, n_chans, kernel_size, nlat_out, nlon_out
428
+ )
429
+
430
+ return y
431
+
432
+
433
+ def _disco_s2_transpose_contraction_torch(
434
+ x: torch.Tensor, psi: torch.Tensor, nlon_out: int
435
+ ):
436
+ """
437
+ Reference implementation of the custom contraction as described in [1]. This requires repeated
438
+ shifting of the input tensor, which can potentially be costly. For an efficient implementation
439
+ on GPU, make sure to use the custom kernel written in Triton.
440
+ """
441
+ assert len(psi.shape) == 3
442
+ assert len(x.shape) == 5
443
+ psi = psi.to(x.device)
444
+
445
+ batch_size, n_chans, kernel_size, nlat_in, nlon_in = x.shape
446
+ kernel_size, _, n_out = psi.shape
447
+
448
+ assert psi.shape[-2] == nlat_in
449
+ assert n_out % nlon_out == 0
450
+ nlat_out = n_out // nlon_out
451
+ assert nlon_out >= nlat_in
452
+ pscale = nlon_out // nlon_in
453
+
454
+ # we do a semi-transposition to facilitate the computation
455
+ inz = psi.indices()
456
+ tout = inz[2] // nlon_out
457
+ pout = inz[2] % nlon_out
458
+ # flip the axis of longitudes
459
+ pout = nlon_out - 1 - pout
460
+ tin = inz[1]
461
+ inz = torch.stack([inz[0], tout, tin * nlon_out + pout], dim=0)
462
+ psi_mod = torch.sparse_coo_tensor(
463
+ inz, psi.values(), size=(kernel_size, nlat_out, nlat_in * nlon_out)
464
+ )
465
+
466
+ # interleave zeros along the longitude dimension to allow for fractional offsets to be considered
467
+ x_ext = torch.zeros(
468
+ kernel_size,
469
+ nlat_in,
470
+ nlon_out,
471
+ batch_size * n_chans,
472
+ device=x.device,
473
+ dtype=x.dtype,
474
+ )
475
+ x_ext[:, :, ::pscale, :] = x.reshape(
476
+ batch_size * n_chans, kernel_size, nlat_in, nlon_in
477
+ ).permute(1, 2, 3, 0)
478
+ # we need to go backwards through the vector, so we flip the axis
479
+ x_ext = x_ext.contiguous()
480
+
481
+ y = torch.zeros(
482
+ kernel_size,
483
+ nlon_out,
484
+ nlat_out,
485
+ batch_size * n_chans,
486
+ device=x.device,
487
+ dtype=x.dtype,
488
+ )
489
+
490
+ for pout in range(nlon_out):
491
+ # we need to repeatedly roll the input tensor to facilitate the shifted multiplication
492
+ # TODO: double-check why this has to happen first
493
+ x_ext = torch.roll(x_ext, -1, dims=2)
494
+ # sparse contraction with the modified psi
495
+ y[:, pout, :, :] = torch.bmm(
496
+ psi_mod, x_ext.reshape(kernel_size, nlat_in * nlon_out, -1)
497
+ )
498
+
499
+ # sum over the kernel dimension and reshape to the correct output size
500
+ y = y.sum(dim=0).permute(2, 1, 0).reshape(batch_size, n_chans, nlat_out, nlon_out)
501
+
502
+ return y
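
As a sanity check, the Triton kernel and the torch reference contraction above can be compared on a small random problem. The sketch below is editorial, assumes a CUDA-enabled build, and uses illustrative sizes and tolerances.

import torch

kernel_size, nlat_in, nlon_in, nlat_out, nlon_out = 3, 16, 32, 16, 32
psi = torch.rand(kernel_size, nlat_out, nlat_in * nlon_in)
psi = (psi * (psi > 0.98)).to_sparse().coalesce()   # sparse COO filter tensor, ~2% nonzeros
x = torch.randn(2, 4, nlat_in, nlon_in)

y_ref = _disco_s2_contraction_torch(x, psi, nlon_out)
if torch.cuda.is_available():
    y_gpu = _disco_s2_contraction_triton(x.cuda(), psi.cuda(), nlon_out)
    # atomic adds in the kernel make the summation order-dependent, so compare with a tolerance
    print(torch.allclose(y_ref, y_gpu.cpu(), atol=1e-4))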
torch_harmonics_local/convolution.py ADDED
@@ -0,0 +1,1014 @@
1
+ # coding=utf-8
2
+
3
+ # SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+
32
+ import abc
33
+ import math
34
+ from functools import partial
35
+ from typing import List, Optional, Tuple, Union
36
+
37
+ import torch
38
+ import torch.nn as nn
39
+ import torch.nn.functional as F
40
+
41
+ from .quadrature import _precompute_grid, _precompute_latitudes
42
+
43
+ if torch.cuda.is_available():
44
+ from ._disco_convolution import (
45
+ _disco_s2_contraction_triton,
46
+ _disco_s2_transpose_contraction_triton,
47
+ )
48
+
49
+
50
+ def _compute_support_vals_isotropic(
51
+ r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float, norm: str = "s2"
52
+ ):
53
+ """
54
+ Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
55
+ """
56
+
57
+ # compute the support
58
+ dr = (r_cutoff - 0.0) / nr
59
+ ikernel = torch.arange(nr).reshape(-1, 1, 1)
60
+ ir = ikernel * dr
61
+
62
+ if norm == "none":
63
+ norm_factor = 1.0
64
+ elif norm == "2d":
65
+ norm_factor = (
66
+ math.pi * (r_cutoff * nr / (nr + 1)) ** 2
67
+ + math.pi * r_cutoff**2 * (2 * nr / (nr + 1) + 1) / (nr + 1) / 3
68
+ )
69
+ elif norm == "s2":
70
+ norm_factor = (
71
+ 2
72
+ * math.pi
73
+ * (
74
+ 1
75
+ - math.cos(r_cutoff - dr)
76
+ + math.cos(r_cutoff - dr)
77
+ + (math.sin(r_cutoff - dr) - math.sin(r_cutoff)) / dr
78
+ )
79
+ )
80
+ else:
81
+ raise ValueError(f"Unknown normalization mode {norm}.")
82
+
83
+ # find the indices where the rotated position falls into the support of the kernel
84
+ iidx = torch.argwhere(((r - ir).abs() <= dr) & (r <= r_cutoff))
85
+ vals = (
86
+ 1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr
87
+ ) / norm_factor
88
+ return iidx, vals
89
+
90
+
91
+ def _compute_support_vals_anisotropic(
92
+ r: torch.Tensor,
93
+ phi: torch.Tensor,
94
+ nr: int,
95
+ nphi: int,
96
+ r_cutoff: float,
97
+ norm: str = "s2",
98
+ ):
99
+ """
100
+ Computes the index set that falls into the anisotropic kernel's support and returns both indices and values.
101
+ """
102
+
103
+ # compute the support
104
+ dr = (r_cutoff - 0.0) / nr
105
+ dphi = 2.0 * math.pi / nphi
106
+ kernel_size = (nr - 1) * nphi + 1
107
+ ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
108
+ ir = ((ikernel - 1) // nphi + 1) * dr
109
+ iphi = ((ikernel - 1) % nphi) * dphi
110
+
111
+ if norm == "none":
112
+ norm_factor = 1.0
113
+ elif norm == "2d":
114
+ norm_factor = (
115
+ math.pi * (r_cutoff * nr / (nr + 1)) ** 2
116
+ + math.pi * r_cutoff**2 * (2 * nr / (nr + 1) + 1) / (nr + 1) / 3
117
+ )
118
+ elif norm == "s2":
119
+ norm_factor = (
120
+ 2
121
+ * math.pi
122
+ * (
123
+ 1
124
+ - math.cos(r_cutoff - dr)
125
+ + math.cos(r_cutoff - dr)
126
+ + (math.sin(r_cutoff - dr) - math.sin(r_cutoff)) / dr
127
+ )
128
+ )
129
+ else:
130
+ raise ValueError(f"Unknown normalization mode {norm}.")
131
+
132
+ # find the indices where the rotated position falls into the support of the kernel
133
+ cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
134
+ cond_phi = (
135
+ (ikernel == 0)
136
+ | ((phi - iphi).abs() <= dphi)
137
+ | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
138
+ )
139
+ iidx = torch.argwhere(cond_r & cond_phi)
140
+ vals = (
141
+ 1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr
142
+ ) / norm_factor
143
+ vals *= torch.where(
144
+ iidx[:, 0] > 0,
145
+ (
146
+ 1
147
+ - torch.minimum(
148
+ (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs(),
149
+ (
150
+ 2 * math.pi
151
+ - (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()
152
+ ),
153
+ )
154
+ / dphi
155
+ ),
156
+ 1.0,
157
+ )
158
+ return iidx, vals
159
+
160
+
161
+ def _precompute_convolution_tensor_s2(
162
+ in_shape,
163
+ out_shape,
164
+ kernel_shape,
165
+ grid_in="equiangular",
166
+ grid_out="equiangular",
167
+ theta_cutoff=0.01 * math.pi,
168
+ ):
169
+ """
170
+ Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
171
+ Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
172
+ The output tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in).
173
+
174
+ The rotation of the Euler angles uses the YZY convention, which applied to the northpole $(0,0,1)^T$ yields
175
+ $$
176
+ Y(\alpha) Z(\beta) Y(\gamma) n =
177
+ {\begin{bmatrix}
178
+ \cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
179
+ \sin(\beta)\sin(\gamma) \\
180
+ \cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
181
+ \end{bmatrix}}
182
+ $$
183
+ """
184
+
185
+ assert len(in_shape) == 2
186
+ assert len(out_shape) == 2
187
+
188
+ if len(kernel_shape) == 1:
189
+ kernel_handle = partial(
190
+ _compute_support_vals_isotropic,
191
+ nr=kernel_shape[0],
192
+ r_cutoff=theta_cutoff,
193
+ norm="s2",
194
+ )
195
+ elif len(kernel_shape) == 2:
196
+ kernel_handle = partial(
197
+ _compute_support_vals_anisotropic,
198
+ nr=kernel_shape[0],
199
+ nphi=kernel_shape[1],
200
+ r_cutoff=theta_cutoff,
201
+ norm="s2",
202
+ )
203
+ else:
204
+ raise ValueError("kernel_shape should be either one- or two-dimensional.")
205
+
206
+ nlat_in, nlon_in = in_shape
207
+ nlat_out, nlon_out = out_shape
208
+
209
+ lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
210
+ lats_in = torch.from_numpy(lats_in).float()
211
+ lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
212
+ lats_out = torch.from_numpy(lats_out).float()
213
+
214
+ # array for accumulating non-zero indices
215
+ out_idx = torch.empty([3, 0], dtype=torch.long)
216
+ out_vals = torch.empty([0], dtype=torch.long)
217
+
218
+ # compute the phi differences
219
+ # It's important not to include the 2*pi point in the longitudes, as it is equivalent to lon=0
220
+ lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
221
+
222
+ for t in range(nlat_out):
223
+ # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
224
+ alpha = -lats_out[t]
225
+ beta = lons_in
226
+ gamma = lats_in.reshape(-1, 1)
227
+
228
+ # compute cartesian coordinates of the rotated position
229
+ # This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
230
+ # and therefore applied with a negative sign
231
+ z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(
232
+ alpha
233
+ ) * torch.cos(gamma)
234
+ x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(
235
+ gamma
236
+ ) * torch.sin(alpha)
237
+ y = torch.sin(beta) * torch.sin(gamma)
238
+
239
+ # normalization is important to avoid NaNs when arccos and arctan2 are applied
240
+ # this can otherwise lead to spurious artifacts in the solution
241
+ norm = torch.sqrt(x * x + y * y + z * z)
242
+ x = x / norm
243
+ y = y / norm
244
+ z = z / norm
245
+
246
+ # compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
247
+ theta = torch.arccos(z)
248
+ phi = torch.arctan2(y, x) + torch.pi
249
+
250
+ # find the indices where the rotated position falls into the support of the kernel
251
+ iidx, vals = kernel_handle(theta, phi)
252
+
253
+ # add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
254
+ idx = torch.stack(
255
+ [
256
+ iidx[:, 0],
257
+ t * torch.ones_like(iidx[:, 0]),
258
+ iidx[:, 1] * nlon_in + iidx[:, 2],
259
+ ],
260
+ dim=0,
261
+ )
262
+
263
+ # append indices and values to the COO datastructure
264
+ out_idx = torch.cat([out_idx, idx], dim=-1)
265
+ out_vals = torch.cat([out_vals, vals], dim=-1)
266
+
267
+ return out_idx, out_vals
268
+
269
+
270
+ def _precompute_convolution_tensor_2d(
271
+ grid_in, grid_out, kernel_shape, radius_cutoff=0.01, periodic=False
272
+ ):
273
+ """
274
+ Precomputes the translated filters at positions $T^{-1}_j \omega_i = T^{-1}_j T_i \nu$. Similar to the S2 routine,
275
+ except that it assumes a non-periodic subset of the Euclidean plane.
276
+ """
277
+
278
+ # check that input arrays are valid point clouds in 2D
279
+ assert len(grid_in) == 2
280
+ assert len(grid_out) == 2
281
+ assert grid_in.shape[0] == 2
282
+ assert grid_out.shape[0] == 2
283
+
284
+ n_in = grid_in.shape[-1]
285
+ n_out = grid_out.shape[-1]
286
+
287
+ if len(kernel_shape) == 1:
288
+ kernel_handle = partial(
289
+ _compute_support_vals_isotropic,
290
+ nr=kernel_shape[0],
291
+ r_cutoff=radius_cutoff,
292
+ norm="2d",
293
+ )
294
+ elif len(kernel_shape) == 2:
295
+ kernel_handle = partial(
296
+ _compute_support_vals_anisotropic,
297
+ nr=kernel_shape[0],
298
+ nphi=kernel_shape[1],
299
+ r_cutoff=radius_cutoff,
300
+ norm="2d",
301
+ )
302
+ else:
303
+ raise ValueError("kernel_shape should be either one- or two-dimensional.")
304
+
305
+ grid_in = grid_in.reshape(2, 1, n_in)
306
+ grid_out = grid_out.reshape(2, n_out, 1)
307
+
308
+ diffs = grid_in - grid_out
309
+ if periodic:
310
+ periodic_diffs = torch.where(diffs > 0.0, diffs - 1, diffs + 1)
311
+ diffs = torch.where(diffs.abs() < periodic_diffs.abs(), diffs, periodic_diffs)
312
+
313
+ r = torch.sqrt(diffs[0] ** 2 + diffs[1] ** 2)
314
+ phi = torch.arctan2(diffs[1], diffs[0]) + torch.pi
315
+
316
+ idx, vals = kernel_handle(r, phi)
317
+ idx = idx.permute(1, 0)
318
+
319
+ return idx, vals
320
+
321
+
322
+ class DiscreteContinuousConv(nn.Module, abc.ABC):
323
+ """
324
+ Abstract base class for DISCO convolutions
325
+ """
326
+
327
+ def __init__(
328
+ self,
329
+ in_channels: int,
330
+ out_channels: int,
331
+ kernel_shape: Union[int, List[int]],
332
+ groups: Optional[int] = 1,
333
+ bias: Optional[bool] = True,
334
+ ):
335
+ super().__init__()
336
+
337
+ if isinstance(kernel_shape, int):
338
+ self.kernel_shape = [kernel_shape]
339
+ else:
340
+ self.kernel_shape = kernel_shape
341
+
342
+ if len(self.kernel_shape) == 1:
343
+ self.kernel_size = self.kernel_shape[0]
344
+ elif len(self.kernel_shape) == 2:
345
+ self.kernel_size = (self.kernel_shape[0] - 1) * self.kernel_shape[1] + 1
346
+ else:
347
+ raise ValueError("kernel_shape should be either one- or two-dimensional.")
348
+
349
+ # groups
350
+ self.groups = groups
351
+
352
+ # weight tensor
353
+ if in_channels % self.groups != 0:
354
+ raise ValueError(
355
+ "Error, the number of input channels has to be an integer multiple of the group size"
356
+ )
357
+ if out_channels % self.groups != 0:
358
+ raise ValueError(
359
+ "Error, the number of output channels has to be an integer multiple of the group size"
360
+ )
361
+ self.groupsize = in_channels // self.groups
362
+ scale = math.sqrt(1.0 / self.groupsize)
363
+ self.weight = nn.Parameter(
364
+ scale * torch.randn(out_channels, self.groupsize, self.kernel_size)
365
+ )
366
+
367
+ if bias:
368
+ self.bias = nn.Parameter(torch.zeros(out_channels))
369
+ else:
370
+ self.bias = None
371
+
372
+ @abc.abstractmethod
373
+ def forward(self, x: torch.Tensor):
374
+ raise NotImplementedError
375
+
376
+
377
+ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
378
+ """
379
+ Reference implementation of the custom contraction as described in [1]. This requires repeated
380
+ shifting of the input tensor, which can potentially be costly. For an efficient implementation
381
+ on GPU, make sure to use the custom kernel written in Triton.
382
+ """
383
+ assert len(psi.shape) == 3
384
+ assert len(x.shape) == 4
385
+ psi = psi.to(x.device)
386
+
387
+ batch_size, n_chans, nlat_in, nlon_in = x.shape
388
+ kernel_size, nlat_out, _ = psi.shape
389
+
390
+ assert psi.shape[-1] == nlat_in * nlon_in
391
+ assert nlon_in % nlon_out == 0
392
+ assert nlon_in >= nlat_out
393
+ pscale = nlon_in // nlon_out
394
+
395
+ # add a dummy dimension for nkernel and move the batch and channel dims to the end
396
+ x = x.reshape(1, batch_size * n_chans, nlat_in, nlon_in).permute(0, 2, 3, 1)
397
+ x = x.expand(kernel_size, -1, -1, -1)
398
+
399
+ y = torch.zeros(
400
+ nlon_out,
401
+ kernel_size,
402
+ nlat_out,
403
+ batch_size * n_chans,
404
+ device=x.device,
405
+ dtype=x.dtype,
406
+ )
407
+
408
+ for pout in range(nlon_out):
409
+ # sparse contraction with psi
410
+ y[pout] = torch.bmm(psi, x.reshape(kernel_size, nlat_in * nlon_in, -1))
411
+ # we need to repeatedly roll the input tensor to faciliate the shifted multiplication
412
+ x = torch.roll(x, -pscale, dims=2)
413
+
414
+ # reshape y back to expose the correct dimensions
415
+ y = y.permute(3, 1, 2, 0).reshape(
416
+ batch_size, n_chans, kernel_size, nlat_out, nlon_out
417
+ )
418
+
419
+ return y
420
+
421
+
422
+ def _disco_s2_transpose_contraction_torch(
423
+ x: torch.Tensor, psi: torch.Tensor, nlon_out: int
424
+ ):
425
+ """
426
+ Reference implementation of the custom contraction as described in [1]. This requires repeated
427
+ shifting of the input tensor, which can potentially be costly. For an efficient implementation
428
+ on GPU, make sure to use the custom kernel written in Triton.
429
+ """
430
+ assert len(psi.shape) == 3
431
+ assert len(x.shape) == 5
432
+ psi = psi.to(x.device)
433
+
434
+ batch_size, n_chans, kernel_size, nlat_in, nlon_in = x.shape
435
+ kernel_size, _, n_out = psi.shape
436
+
437
+ assert psi.shape[-2] == nlat_in
438
+ assert n_out % nlon_out == 0
439
+ nlat_out = n_out // nlon_out
440
+ assert nlon_out >= nlat_in
441
+ pscale = nlon_out // nlon_in
442
+
443
+ # we do a semi-transposition to faciliate the computation
444
+ inz = psi.indices()
445
+ tout = inz[2] // nlon_out
446
+ pout = inz[2] % nlon_out
447
+ # flip the axis of longitudes
448
+ pout = nlon_out - 1 - pout
449
+ tin = inz[1]
450
+ inz = torch.stack([inz[0], tout, tin * nlon_out + pout], dim=0)
451
+ psi_mod = torch.sparse_coo_tensor(
452
+ inz, psi.values(), size=(kernel_size, nlat_out, nlat_in * nlon_out)
453
+ )
454
+
455
+ # interleave zeros along the longitude dimension to allow for fractional offsets to be considered
456
+ x_ext = torch.zeros(
457
+ kernel_size,
458
+ nlat_in,
459
+ nlon_out,
460
+ batch_size * n_chans,
461
+ device=x.device,
462
+ dtype=x.dtype,
463
+ )
464
+ x_ext[:, :, ::pscale, :] = x.reshape(
465
+ batch_size * n_chans, kernel_size, nlat_in, nlon_in
466
+ ).permute(1, 2, 3, 0)
467
+ # we need to go backwards through the vector, so we flip the axis
468
+ x_ext = x_ext.contiguous()
469
+
470
+ y = torch.zeros(
471
+ kernel_size,
472
+ nlon_out,
473
+ nlat_out,
474
+ batch_size * n_chans,
475
+ device=x.device,
476
+ dtype=x.dtype,
477
+ )
478
+
479
+ for pout in range(nlon_out):
480
+ # we need to repeatedly roll the input tensor to faciliate the shifted multiplication
481
+ # TODO: double-check why this has to happen first
482
+ x_ext = torch.roll(x_ext, -1, dims=2)
483
+ # sparse contraction with the modified psi
484
+ y[:, pout, :, :] = torch.bmm(
485
+ psi_mod, x_ext.reshape(kernel_size, nlat_in * nlon_out, -1)
486
+ )
487
+
488
+ # sum over the kernel dimension and reshape to the correct output size
489
+ y = y.sum(dim=0).permute(2, 1, 0).reshape(batch_size, n_chans, nlat_out, nlon_out)
490
+
491
+ return y
492
+
493
+
494
+ class DiscreteContinuousConvS2(DiscreteContinuousConv):
495
+ """
496
+ Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].
497
+
498
+ [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
499
+ """
500
+
501
+ def __init__(
502
+ self,
503
+ in_channels: int,
504
+ out_channels: int,
505
+ in_shape: Tuple[int],
506
+ out_shape: Tuple[int],
507
+ kernel_shape: Union[int, List[int]],
508
+ groups: Optional[int] = 1,
509
+ grid_in: Optional[str] = "equiangular",
510
+ grid_out: Optional[str] = "equiangular",
511
+ bias: Optional[bool] = True,
512
+ theta_cutoff: Optional[float] = None,
513
+ ):
514
+ super().__init__(in_channels, out_channels, kernel_shape, groups, bias)
515
+
516
+ self.nlat_in, self.nlon_in = in_shape
517
+ self.nlat_out, self.nlon_out = out_shape
518
+
519
+ # compute theta cutoff based on the bandlimit of the input field
520
+ if theta_cutoff is None:
521
+ theta_cutoff = (
522
+ (self.kernel_shape[0] + 1) * torch.pi / float(self.nlat_in - 1)
523
+ )
524
+
525
+ if theta_cutoff <= 0.0:
526
+ raise ValueError("Error, theta_cutoff has to be positive.")
527
+
528
+ # integration weights
529
+ _, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
530
+ quad_weights = (
531
+ 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / self.nlon_in
532
+ )
533
+ self.register_buffer("quad_weights", quad_weights, persistent=False)
534
+
535
+ idx, vals = _precompute_convolution_tensor_s2(
536
+ in_shape,
537
+ out_shape,
538
+ self.kernel_shape,
539
+ grid_in=grid_in,
540
+ grid_out=grid_out,
541
+ theta_cutoff=theta_cutoff,
542
+ )
543
+
544
+ self.register_buffer("psi_idx", idx, persistent=False)
545
+ self.register_buffer("psi_vals", vals, persistent=False)
546
+
547
+ def get_psi(self):
548
+ psi = torch.sparse_coo_tensor(
549
+ self.psi_idx,
550
+ self.psi_vals,
551
+ size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in),
552
+ ).coalesce()
553
+ return psi
554
+
555
+ def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
556
+ # pre-multiply x with the quadrature weights
557
+ x = self.quad_weights * x
558
+
559
+ psi = self.get_psi()
560
+
561
+ if x.is_cuda and use_triton_kernel:
562
+ x = _disco_s2_contraction_triton(x, psi, self.nlon_out)
563
+ else:
564
+ x = _disco_s2_contraction_torch(x, psi, self.nlon_out)
565
+
566
+ # extract shape
567
+ B, C, K, H, W = x.shape
568
+ x = x.reshape(B, self.groups, self.groupsize, K, H, W)
569
+
570
+ # do weight multiplication
571
+ out = torch.einsum(
572
+ "bgckxy,gock->bgoxy",
573
+ x,
574
+ self.weight.reshape(
575
+ self.groups, -1, self.weight.shape[1], self.weight.shape[2]
576
+ ),
577
+ )
578
+ out = out.reshape(out.shape[0], -1, out.shape[-2], out.shape[-1])
579
+
580
+ if self.bias is not None:
581
+ out = out + self.bias.reshape(1, -1, 1, 1)
582
+
583
+ return out
584
+
585
+
586
+ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
587
+ """
588
+ Discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1].
589
+
590
+ [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
591
+ """
592
+
593
+ def __init__(
594
+ self,
595
+ in_channels: int,
596
+ out_channels: int,
597
+ in_shape: Tuple[int],
598
+ out_shape: Tuple[int],
599
+ kernel_shape: Union[int, List[int]],
600
+ groups: Optional[int] = 1,
601
+ grid_in: Optional[str] = "equiangular",
602
+ grid_out: Optional[str] = "equiangular",
603
+ bias: Optional[bool] = True,
604
+ theta_cutoff: Optional[float] = None,
605
+ ):
606
+ super().__init__(in_channels, out_channels, kernel_shape, groups, bias)
607
+
608
+ self.nlat_in, self.nlon_in = in_shape
609
+ self.nlat_out, self.nlon_out = out_shape
610
+
611
+ # bandlimit
612
+ if theta_cutoff is None:
613
+ theta_cutoff = (
614
+ (self.kernel_shape[0] + 1) * torch.pi / float(self.nlat_in - 1)
615
+ )
616
+
617
+ if theta_cutoff <= 0.0:
618
+ raise ValueError("Error, theta_cutoff has to be positive.")
619
+
620
+ # integration weights
621
+ _, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
622
+ quad_weights = (
623
+ 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / self.nlon_in
624
+ )
625
+ self.register_buffer("quad_weights", quad_weights, persistent=False)
626
+
627
+ # switch in_shape and out_shape since we want transpose conv
628
+ idx, vals = _precompute_convolution_tensor_s2(
629
+ out_shape,
630
+ in_shape,
631
+ self.kernel_shape,
632
+ grid_in=grid_out,
633
+ grid_out=grid_in,
634
+ theta_cutoff=theta_cutoff,
635
+ )
636
+
637
+ self.register_buffer("psi_idx", idx, persistent=False)
638
+ self.register_buffer("psi_vals", vals, persistent=False)
639
+
640
+ def get_psi(self):
641
+ psi = torch.sparse_coo_tensor(
642
+ self.psi_idx,
643
+ self.psi_vals,
644
+ size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out),
645
+ ).coalesce()
646
+ return psi
647
+
648
+ def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
649
+ # extract shape
650
+ B, C, H, W = x.shape
651
+ x = x.reshape(B, self.groups, self.groupsize, H, W)
652
+
653
+ # do weight multiplication
654
+ x = torch.einsum(
655
+ "bgcxy,gock->bgokxy",
656
+ x,
657
+ self.weight.reshape(
658
+ self.groups, -1, self.weight.shape[1], self.weight.shape[2]
659
+ ),
660
+ )
661
+ x = x.reshape(x.shape[0], -1, x.shape[-3], x.shape[-2], x.shape[-1])
662
+
663
+ # pre-multiply x with the quadrature weights
664
+ x = self.quad_weights * x
665
+
666
+ psi = self.get_psi()
667
+
668
+ if x.is_cuda and use_triton_kernel:
669
+ out = _disco_s2_transpose_contraction_triton(x, psi, self.nlon_out)
670
+ else:
671
+ out = _disco_s2_transpose_contraction_torch(x, psi, self.nlon_out)
672
+
673
+ if self.bias is not None:
674
+ out = out + self.bias.reshape(1, -1, 1, 1)
675
+
676
+ return out
677
+
678
+
679
+ class DiscreteContinuousConv2d(DiscreteContinuousConv):
680
+ """
681
+ Discrete-continuous convolutions (DISCO) on arbitrary 2d grids.
682
+
683
+ [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
684
+ """
685
+
686
+ def __init__(
687
+ self,
688
+ in_channels: int,
689
+ out_channels: int,
690
+ grid_in: torch.Tensor,
691
+ grid_out: torch.Tensor,
692
+ kernel_shape: Union[int, List[int]],
693
+ n_in: Optional[Tuple[int]] = None,
694
+ n_out: Optional[Tuple[int]] = None,
695
+ quad_weights: Optional[torch.Tensor] = None,
696
+ periodic: Optional[bool] = False,
697
+ groups: Optional[int] = 1,
698
+ bias: Optional[bool] = True,
699
+ radius_cutoff: Optional[float] = None,
700
+ ):
701
+ super().__init__(in_channels, out_channels, kernel_shape, groups, bias)
702
+
703
+ # the instantiator supports convenience constructors for the input and output grids
704
+ if isinstance(grid_in, torch.Tensor):
705
+ assert isinstance(quad_weights, torch.Tensor)
706
+ assert not periodic
707
+ elif isinstance(grid_in, str):
708
+ assert n_in is not None
709
+ assert len(n_in) == 2
710
+ x, wx = _precompute_grid(n_in[0], grid=grid_in, periodic=periodic)
711
+ y, wy = _precompute_grid(n_in[1], grid=grid_in, periodic=periodic)
712
+ x, y = torch.meshgrid(torch.from_numpy(x), torch.from_numpy(y))
713
+ wx, wy = torch.meshgrid(torch.from_numpy(wx), torch.from_numpy(wy))
714
+ grid_in = torch.stack([x.reshape(-1), y.reshape(-1)])
715
+ quad_weights = (wx * wy).reshape(-1)
716
+ else:
717
+ raise ValueError(f"Unknown grid input type of type {type(grid_in)}")
718
+
719
+ if isinstance(grid_out, torch.Tensor):
720
+ pass
721
+ elif isinstance(grid_out, str):
722
+ assert n_out is not None
723
+ assert len(n_out) == 2
724
+ x, wx = _precompute_grid(n_out[0], grid=grid_out, periodic=periodic)
725
+ y, wy = _precompute_grid(n_out[1], grid=grid_out, periodic=periodic)
726
+ x, y = torch.meshgrid(torch.from_numpy(x), torch.from_numpy(y))
727
+ grid_out = torch.stack([x.reshape(-1), y.reshape(-1)])
728
+ else:
729
+ raise ValueError(f"Unknown grid output type of type {type(grid_out)}")
730
+
731
+ # check that input arrays are valid point clouds in 2D
732
+ assert len(grid_in.shape) == 2
733
+ assert len(grid_out.shape) == 2
734
+ assert len(quad_weights.shape) == 1
735
+ assert grid_in.shape[0] == 2
736
+ assert grid_out.shape[0] == 2
737
+
738
+ self.n_in = grid_in.shape[-1]
739
+ self.n_out = grid_out.shape[-1]
740
+
741
+ # compute the cutoff radius based on the bandlimit of the input field
742
+ # TODO: this heuristic is ad-hoc! Verify that we do the right one
743
+ if radius_cutoff is None:
744
+ radius_cutoff = (
745
+ 2 * (self.kernel_shape[0] + 1) / float(math.sqrt(self.n_in) - 1)
746
+ )
747
+
748
+ if radius_cutoff <= 0.0:
749
+ raise ValueError("Error, radius_cutoff has to be positive.")
750
+
751
+ # integration weights
752
+ self.register_buffer("quad_weights", quad_weights, persistent=False)
753
+
754
+ idx, vals = _precompute_convolution_tensor_2d(
755
+ grid_in,
756
+ grid_out,
757
+ self.kernel_shape,
758
+ radius_cutoff=radius_cutoff,
759
+ periodic=periodic,
760
+ )
761
+
762
+ # to improve performance, we make psi a matrix by merging the first two dimensions
763
+ # This has to be accounted for in the forward pass
764
+ idx = torch.stack([idx[0] * self.n_out + idx[1], idx[2]], dim=0)
765
+
766
+ self.register_buffer("psi_idx", idx.contiguous(), persistent=False)
767
+ self.register_buffer("psi_vals", vals.contiguous(), persistent=False)
768
+
769
+ def get_psi(self):
770
+ psi = torch.sparse_coo_tensor(
771
+ self.psi_idx, self.psi_vals, size=(self.kernel_size * self.n_out, self.n_in)
772
+ )
773
+ return psi
774
+
775
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
776
+ # pre-multiply x with the quadrature weights
777
+ x = self.quad_weights * x
778
+
779
+ psi = self.get_psi()
780
+
781
+ # extract shape
782
+ B, C, _ = x.shape
783
+
784
+ # bring into the right shape for the bmm and perform it
785
+ x = x.reshape(B * C, self.n_in).permute(1, 0).contiguous()
786
+ x = torch.mm(psi, x)
787
+ x = x.permute(1, 0).reshape(B, C, self.kernel_size, self.n_out)
788
+ x = x.reshape(B, self.groups, self.groupsize, self.kernel_size, self.n_out)
789
+
790
+ # do weight multiplication
791
+ out = torch.einsum(
792
+ "bgckx,gock->bgox",
793
+ x,
794
+ self.weight.reshape(
795
+ self.groups, -1, self.weight.shape[1], self.weight.shape[2]
796
+ ),
797
+ )
798
+ out = out.reshape(out.shape[0], -1, out.shape[-1])
799
+
800
+ if self.bias is not None:
801
+ out = out + self.bias.reshape(1, -1, 1)
802
+
803
+ return out
804
+
805
+
806
+ class DiscreteContinuousConvTranspose2d(DiscreteContinuousConv):
807
+ """
808
+ Discrete-continuous convolutions (DISCO) on arbitrary 2d grids.
809
+
810
+ [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
811
+ """
812
+
813
+ def __init__(
814
+ self,
815
+ in_channels: int,
816
+ out_channels: int,
817
+ grid_in: torch.Tensor,
818
+ grid_out: torch.Tensor,
819
+ kernel_shape: Union[int, List[int]],
820
+ n_in: Optional[Tuple[int]] = None,
821
+ n_out: Optional[Tuple[int]] = None,
822
+ quad_weights: Optional[torch.Tensor] = None,
823
+ periodic: Optional[bool] = False,
824
+ groups: Optional[int] = 1,
825
+ bias: Optional[bool] = True,
826
+ radius_cutoff: Optional[float] = None,
827
+ ):
828
+ super().__init__(in_channels, out_channels, kernel_shape, groups, bias)
829
+
830
+ # the instantiator supports convenience constructors for the input and output grids
831
+ if isinstance(grid_in, torch.Tensor):
832
+ assert isinstance(quad_weights, torch.Tensor)
833
+ assert not periodic
834
+ elif isinstance(grid_in, str):
835
+ assert n_in is not None
836
+ assert len(n_in) == 2
837
+ x, wx = _precompute_grid(n_in[0], grid=grid_in, periodic=periodic)
838
+ y, wy = _precompute_grid(n_in[1], grid=grid_in, periodic=periodic)
839
+ x, y = torch.meshgrid(torch.from_numpy(x), torch.from_numpy(y))
840
+ wx, wy = torch.meshgrid(torch.from_numpy(wx), torch.from_numpy(wy))
841
+ grid_in = torch.stack([x.reshape(-1), y.reshape(-1)])
842
+ quad_weights = (wx * wy).reshape(-1)
843
+ else:
844
+ raise ValueError(f"Unknown grid input type of type {type(grid_in)}")
845
+
846
+ if isinstance(grid_out, torch.Tensor):
847
+ pass
848
+ elif isinstance(grid_out, str):
849
+ assert n_out is not None
850
+ assert len(n_out) == 2
851
+ x, wx = _precompute_grid(n_out[0], grid=grid_out, periodic=periodic)
852
+ y, wy = _precompute_grid(n_out[1], grid=grid_out, periodic=periodic)
853
+ x, y = torch.meshgrid(torch.from_numpy(x), torch.from_numpy(y))
854
+ grid_out = torch.stack([x.reshape(-1), y.reshape(-1)])
855
+ else:
856
+ raise ValueError(f"Unknown grid output type of type {type(grid_out)}")
857
+
858
+ # check that input arrays are valid point clouds in 2D
859
+ assert len(grid_in.shape) == 2
860
+ assert len(grid_out.shape) == 2
861
+ assert len(quad_weights.shape) == 1
862
+ assert grid_in.shape[0] == 2
863
+ assert grid_out.shape[0] == 2
864
+
865
+ self.n_in = grid_in.shape[-1]
866
+ self.n_out = grid_out.shape[-1]
867
+
868
+ # compute the cutoff radius based on the bandlimit of the input field
869
+ # TODO: this heuristic is ad-hoc! Verify that we do the right one
870
+ if radius_cutoff is None:
871
+ radius_cutoff = (
872
+ 2 * (self.kernel_shape[0] + 1) / float(math.sqrt(self.n_in) - 1)
873
+ )
874
+
875
+ if radius_cutoff <= 0.0:
876
+ raise ValueError("Error, radius_cutoff has to be positive.")
877
+
878
+ # integration weights
879
+ self.register_buffer("quad_weights", quad_weights, persistent=False)
880
+
881
+ # precompute the transposed tensor
882
+ idx, vals = _precompute_convolution_tensor_2d(
883
+ grid_out,
884
+ grid_in,
885
+ self.kernel_shape,
886
+ radius_cutoff=radius_cutoff,
887
+ periodic=periodic,
888
+ )
889
+
890
+ # to improve performance, we make psi a matrix by merging the first two dimensions
891
+ # This has to be accounted for in the forward pass
892
+ idx = torch.stack([idx[0] * self.n_out + idx[2], idx[1]], dim=0)
893
+
894
+ self.register_buffer("psi_idx", idx.contiguous(), persistent=False)
895
+ self.register_buffer("psi_vals", vals.contiguous(), persistent=False)
896
+
897
+ def get_psi(self):
898
+ psi = torch.sparse_coo_tensor(
899
+ self.psi_idx, self.psi_vals, size=(self.kernel_size * self.n_out, self.n_in)
900
+ )
901
+ return psi
902
+
903
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
904
+ # pre-multiply x with the quadrature weights
905
+ x = self.quad_weights * x
906
+
907
+ psi = self.get_psi()
908
+
909
+ # extract shape
910
+ B, C, _ = x.shape
911
+
912
+ # bring into the right shape for the bmm and perform it
913
+ x = x.reshape(B * C, self.n_in).permute(1, 0).contiguous()
914
+ x = torch.mm(psi, x)
915
+ x = x.permute(1, 0).reshape(B, C, self.kernel_size, self.n_out)
916
+ x = x.reshape(B, self.groups, self.groupsize, self.kernel_size, self.n_out)
917
+
918
+ # do weight multiplication
919
+ out = torch.einsum(
920
+ "bgckx,gock->bgox",
921
+ x,
922
+ self.weight.reshape(
923
+ self.groups, -1, self.weight.shape[1], self.weight.shape[2]
924
+ ),
925
+ )
926
+ out = out.reshape(out.shape[0], -1, out.shape[-1])
927
+
928
+ if self.bias is not None:
929
+ out = out + self.bias.reshape(1, -1, 1)
930
+
931
+ return out
932
+
933
+
934
+ class EquidistantDiscreteContinuousConv2d(DiscreteContinuousConv):
935
+ """
936
+ Discrete-continuous convolutions (DISCO) on arbitrary 2d grids.
937
+
938
+ [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
939
+ """
940
+
941
+ def __init__(
942
+ self,
943
+ in_channels: int,
944
+ out_channels: int,
945
+ kernel_shape: Union[int, List[int]],
946
+ in_shape: Tuple[int],
947
+ groups: Optional[int] = 1,
948
+ bias: Optional[bool] = True,
949
+ radius_cutoff: Optional[float] = None,
950
+ padding_mode: str = "circular",
951
+ use_min_dim: bool = True,
952
+ **kwargs,
953
+ ):
954
+ """
955
+ use_min_dim (bool, optional): Use the minimum dimension of the input
956
+ shape to compute the cutoff radius. Otherwise use the maximum
957
+ dimension. Defaults to True.
958
+ """
959
+ super().__init__(in_channels, out_channels, kernel_shape, groups, bias)
960
+
961
+ self.padding_mode = padding_mode
962
+
963
+ # compute the cutoff radius based on the assumption that the grid is [-1, 1]^2
964
+ # this still assumes a square domain
965
+ f = min if use_min_dim else max
966
+ if radius_cutoff is None:
967
+ radius_cutoff = 2 * (self.kernel_shape[0]) / float(f(*in_shape))
968
+ # 2 * 0.02 * 7 / 2 + 1 = 1.14
969
+ self.psi_local_size = math.floor(2 * radius_cutoff * f(*in_shape) / 2) + 1
970
+
971
+ # psi_local is essentially the support of the hat functions evaluated locally
972
+ x = torch.linspace(-radius_cutoff, radius_cutoff, self.psi_local_size)
973
+ x, y = torch.meshgrid(x, x, indexing="ij")
974
+ grid_in = torch.stack([x.reshape(-1), y.reshape(-1)])
975
+ grid_out = torch.Tensor([[0.0], [0.0]])
976
+
977
+ idx, vals = _precompute_convolution_tensor_2d(
978
+ grid_in,
979
+ grid_out,
980
+ self.kernel_shape,
981
+ radius_cutoff=radius_cutoff,
982
+ periodic=False,
983
+ )
984
+
985
+ psi_loc = torch.zeros(
986
+ self.kernel_size, self.psi_local_size * self.psi_local_size
987
+ )
988
+ for ie in range(len(vals)):
989
+ f = idx[0, ie]
990
+ j = idx[2, ie]
991
+ v = vals[ie]
992
+ psi_loc[f, j] = v
993
+
994
+ # compute local version of the filter matrix
995
+ psi_loc = psi_loc.reshape(
996
+ self.kernel_size, self.psi_local_size, self.psi_local_size
997
+ )
998
+ # normalization by the quadrature weights
999
+ psi_loc = 4.0 * psi_loc / float(in_shape[0] * in_shape[1])
1000
+
1001
+ self.register_buffer("psi_loc", psi_loc, persistent=False)
1002
+
1003
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1004
+
1005
+ kernel = torch.einsum("kxy,ogk->ogxy", self.psi_loc, self.weight)
1006
+
1007
+ left_pad = self.psi_local_size // 2
1008
+ right_pad = (self.psi_local_size + 1) // 2 - 1
1009
+ x = F.pad(x, (left_pad, right_pad, left_pad, right_pad), mode=self.padding_mode)
1010
+ out = F.conv2d(
1011
+ x, kernel, self.bias, stride=1, dilation=1, padding=0, groups=self.groups
1012
+ )
1013
+
1014
+ return out
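
The equidistant layer above reduces the DISCO operation to a standard padded conv2d, so it can be smoke-tested directly on a regular grid. The sketch below is illustrative only: the import path, channel counts, kernel_shape, and grid size are assumptions rather than values taken from this repository, and the kernel_size bookkeeping lives in the DiscreteContinuousConv base class that is not shown in this hunk.

import torch

# Assumed import path; the file containing these classes is not named in this hunk.
from torch_harmonics_local.convolution import EquidistantDiscreteContinuousConv2d

# Illustrative hyperparameters: a small isotropic filter basis on a 64x64 periodic grid.
conv = EquidistantDiscreteContinuousConv2d(
    in_channels=4,
    out_channels=8,
    kernel_shape=[3],         # size of the filter basis (interpreted by the base class)
    in_shape=(64, 64),        # (H, W) of the equidistant input grid
    groups=1,
    bias=True,
    padding_mode="circular",  # periodic boundaries, the default above
)

x = torch.randn(2, 4, 64, 64)  # (batch, channels, H, W)
y = conv(x)
# Circular padding of psi_local_size - 1 pixels keeps the spatial size unchanged:
# y.shape == torch.Size([2, 8, 64, 64])
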
torch_harmonics_local/quadrature.py ADDED
@@ -0,0 +1,207 @@
1
+ # coding=utf-8
2
+
3
+ # SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+
32
+ import numpy as np
33
+
34
+
35
+ def _precompute_grid(n, grid="equidistant", a=0.0, b=1.0, periodic=False):
36
+
37
+ if (grid != "equidistant") and periodic:
38
+ raise ValueError(f"Periodic grid is only supported on equidistant grids.")
39
+
40
+ # compute coordinates
41
+ if grid == "equidistant":
42
+ xlg, wlg = trapezoidal_weights(n, a=a, b=b, periodic=periodic)
43
+ elif grid == "legendre-gauss":
44
+ xlg, wlg = legendre_gauss_weights(n, a=a, b=b)
45
+ elif grid == "lobatto":
46
+ xlg, wlg = lobatto_weights(n, a=a, b=b)
47
+ elif grid == "equiangular":
48
+ xlg, wlg = clenshaw_curtiss_weights(n, a=a, b=b)
49
+ else:
50
+ raise ValueError(f"Unknown grid type {grid}")
51
+
52
+ return xlg, wlg
53
+
54
+
55
+ def _precompute_latitudes(nlat, grid="equiangular"):
56
+ r"""
57
+ Convenience routine to precompute latitudes
58
+ """
59
+
60
+ # compute coordinates
61
+ xlg, wlg = _precompute_grid(nlat, grid=grid, a=-1.0, b=1.0, periodic=False)
62
+
63
+ lats = np.flip(np.arccos(xlg)).copy()
64
+ wlg = np.flip(wlg).copy()
65
+
66
+ return lats, wlg
67
+
68
+
69
+ def trapezoidal_weights(n, a=-1.0, b=1.0, periodic=False):
70
+ r"""
71
+ Helper routine which returns equidistant nodes with trapezoidal weights
72
+ on the interval [a, b]
73
+ """
74
+
75
+ xlg = np.linspace(a, b, n)
76
+ wlg = (b - a) / (n - 1) * np.ones(n)
77
+
78
+ if not periodic:
79
+ wlg[0] *= 0.5
80
+ wlg[-1] *= 0.5
81
+
82
+ return xlg, wlg
83
+
84
+
85
+ def legendre_gauss_weights(n, a=-1.0, b=1.0):
86
+ r"""
87
+ Helper routine which returns the Legendre-Gauss nodes and weights
88
+ on the interval [a, b]
89
+ """
90
+
91
+ xlg, wlg = np.polynomial.legendre.leggauss(n)
92
+ xlg = (b - a) * 0.5 * xlg + (b + a) * 0.5
93
+ wlg = wlg * (b - a) * 0.5
94
+
95
+ return xlg, wlg
96
+
97
+
98
+ def lobatto_weights(n, a=-1.0, b=1.0, tol=1e-16, maxiter=100):
99
+ r"""
100
+ Helper routine which returns the Legendre-Gauss-Lobatto nodes and weights
101
+ on the interval [a, b]
102
+ """
103
+
104
+ wlg = np.zeros((n,))
105
+ tlg = np.zeros((n,))
106
+ tmp = np.zeros((n,))
107
+
108
+ # Vandermonde Matrix
109
+ vdm = np.zeros((n, n))
110
+
111
+ # initialize Chebyshev nodes as first guess
112
+ for i in range(n):
113
+ tlg[i] = -np.cos(np.pi * i / (n - 1))
114
+
115
+ tmp = 2.0
116
+
117
+ for i in range(maxiter):
118
+ tmp = tlg
119
+
120
+ vdm[:, 0] = 1.0
121
+ vdm[:, 1] = tlg
122
+
123
+ for k in range(2, n):
124
+ vdm[:, k] = (
125
+ (2 * k - 1) * tlg * vdm[:, k - 1] - (k - 1) * vdm[:, k - 2]
126
+ ) / k
127
+
128
+ tlg = tmp - (tlg * vdm[:, n - 1] - vdm[:, n - 2]) / (n * vdm[:, n - 1])
129
+
130
+ if max(abs(tlg - tmp).flatten()) < tol:
131
+ break
132
+
133
+ wlg = 2.0 / ((n * (n - 1)) * (vdm[:, n - 1] ** 2))
134
+
135
+ # rescale
136
+ tlg = (b - a) * 0.5 * tlg + (b + a) * 0.5
137
+ wlg = wlg * (b - a) * 0.5
138
+
139
+ return tlg, wlg
140
+
141
+
142
+ def clenshaw_curtiss_weights(n, a=-1.0, b=1.0):
143
+ r"""
144
+ Computation of the Clenshaw-Curtis quadrature nodes and weights.
145
+ This implementation follows
146
+
147
+ [1] Joerg Waldvogel, Fast Construction of the Fejer and Clenshaw-Curtis Quadrature Rules; BIT Numerical Mathematics, Vol. 43, No. 1, pp. 001–018.
148
+ """
149
+
150
+ assert n > 1
151
+
152
+ tcc = np.cos(np.linspace(np.pi, 0, n))
153
+
154
+ if n == 2:
155
+ wcc = np.array([1.0, 1.0])
156
+ else:
157
+
158
+ n1 = n - 1
159
+ N = np.arange(1, n1, 2)
160
+ l = len(N)
161
+ m = n1 - l
162
+
163
+ v = np.concatenate([2 / N / (N - 2), 1 / N[-1:], np.zeros(m)])
164
+ v = 0 - v[:-1] - v[-1:0:-1]
165
+
166
+ g0 = -np.ones(n1)
167
+ g0[l] = g0[l] + n1
168
+ g0[m] = g0[m] + n1
169
+ g = g0 / (n1**2 - 1 + (n1 % 2))
170
+ wcc = np.fft.ifft(v + g).real
171
+ wcc = np.concatenate((wcc, wcc[:1]))
172
+
173
+ # rescale
174
+ tcc = (b - a) * 0.5 * tcc + (b + a) * 0.5
175
+ wcc = wcc * (b - a) * 0.5
176
+
177
+ return tcc, wcc
178
+
179
+
180
+ def fejer2_weights(n, a=-1.0, b=1.0):
181
+ r"""
182
+ Computation of the Fejer (second rule) quadrature nodes and weights.
183
+ This implementation follows
184
+
185
+ [1] Joerg Waldvogel, Fast Construction of the Fejer and Clenshaw-Curtis Quadrature Rules; BIT Numerical Mathematics, Vol. 43, No. 1, pp. 001–018.
186
+ """
187
+
188
+ assert n > 2
189
+
190
+ tcc = np.cos(np.linspace(np.pi, 0, n))
191
+
192
+ n1 = n - 1
193
+ N = np.arange(1, n1, 2)
194
+ l = len(N)
195
+ m = n1 - l
196
+
197
+ v = np.concatenate([2 / N / (N - 2), 1 / N[-1:], np.zeros(m)])
198
+ v = 0 - v[:-1] - v[-1:0:-1]
199
+
200
+ wcc = np.fft.ifft(v).real
201
+ wcc = np.concatenate((wcc, wcc[:1]))
202
+
203
+ # rescale
204
+ tcc = (b - a) * 0.5 * tcc + (b + a) * 0.5
205
+ wcc = wcc * (b - a) * 0.5
206
+
207
+ return tcc, wcc
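
Each helper above returns nodes x and weights w such that sum(w * f(x)) approximates the integral of f over [a, b]. A minimal, self-contained sanity check follows; the import path is assumed, and the test function and number of nodes are illustrative.

import numpy as np

# Assumed import path for the module added above.
from torch_harmonics_local.quadrature import (
    trapezoidal_weights,
    legendre_gauss_weights,
    lobatto_weights,
    clenshaw_curtiss_weights,
)

# Integrate cos(x) on [-1, 1]; the exact value is 2 * sin(1).
exact = 2.0 * np.sin(1.0)
for name, rule in [
    ("trapezoidal", trapezoidal_weights),
    ("legendre-gauss", legendre_gauss_weights),
    ("lobatto", lobatto_weights),
    ("clenshaw-curtis", clenshaw_curtiss_weights),
]:
    x, w = rule(17, a=-1.0, b=1.0)
    err = abs(np.sum(w * np.cos(x)) - exact)
    print(f"{name:>15s}: error = {err:.2e}")

# The Gauss-type rules should be accurate to near machine precision here,
# while the trapezoidal rule converges only at second order in the node count.
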
type_utils.py ADDED
@@ -0,0 +1,4 @@
1
+ def tuple_type(strings):
2
+ strings = strings.replace("(", "").replace(")", "").replace(" ", "")
3
+ mapped_int = map(int, strings.split(","))
4
+ return tuple(mapped_int)
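
tuple_type reads like an argparse type converter: it strips parentheses and whitespace and parses the remainder as a comma-separated tuple of ints. A hedged usage sketch is below; the --img-shape flag is hypothetical and not taken from this repository.

import argparse

from type_utils import tuple_type

parser = argparse.ArgumentParser()
# Accepts either "320,320" or "(320,320)" on the command line.
parser.add_argument("--img-shape", type=tuple_type, default=(320, 320))

args = parser.parse_args(["--img-shape", "(640,372)"])
print(args.img_shape)  # -> (640, 372)
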