samaonline committed · Commit 1b34a12 · 0 Parent(s)
init
- .gitattributes +61 -0
- .gitignore +124 -0
- .gradio/certificate.pem +31 -0
- README.md +33 -0
- app.py +264 -0
- environment.yml +158 -0
- fastmri/__init__.py +20 -0
- fastmri/coil_combine.py +67 -0
- fastmri/datasets.py +583 -0
- fastmri/evaluate.py +174 -0
- fastmri/fftc.py +203 -0
- fastmri/losses.py +91 -0
- fastmri/math_utils.py +121 -0
- fastmri/poisson_cache/poisson_16x.npy +3 -0
- fastmri/poisson_cache/poisson_2x.npy +3 -0
- fastmri/poisson_cache/poisson_32x.npy +3 -0
- fastmri/poisson_cache/poisson_4x.npy +3 -0
- fastmri/poisson_cache/poisson_6x.npy +3 -0
- fastmri/poisson_cache/poisson_8x.npy +3 -0
- fastmri/subsample.py +818 -0
- fastmri/transforms.py +974 -0
- models/lightning/mri_module.py +402 -0
- models/lightning/no_shared_module.py +274 -0
- models/lightning/no_varnet_module.py +299 -0
- models/lightning/no_varnet_nokno_module.py +294 -0
- models/lightning/varnet_module.py +224 -0
- models/no_shared.py +467 -0
- models/no_varnet.py +598 -0
- models/no_varnet_nokno.py +581 -0
- models/temp/no_repeatk.py +562 -0
- models/temp/no_repeatk_module.py +303 -0
- models/udno.py +369 -0
- models/unet.py +209 -0
- models/varnet.py +416 -0
- pyproject.toml +46 -0
- pytest.ini +4 -0
- setup_config.py +21 -0
- torch_harmonics_local/__init__.py +0 -0
- torch_harmonics_local/_disco_convolution.py +502 -0
- torch_harmonics_local/convolution.py +1014 -0
- torch_harmonics_local/quadrature.py +207 -0
- type_utils.py +4 -0
.gitattributes
ADDED
@@ -0,0 +1,61 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.lz4 filter=lfs diff=lfs merge=lfs -text
*.mds filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
# Audio files - uncompressed
*.pcm filter=lfs diff=lfs merge=lfs -text
*.sam filter=lfs diff=lfs merge=lfs -text
*.raw filter=lfs diff=lfs merge=lfs -text
# Audio files - compressed
*.aac filter=lfs diff=lfs merge=lfs -text
*.flac filter=lfs diff=lfs merge=lfs -text
*.mp3 filter=lfs diff=lfs merge=lfs -text
*.ogg filter=lfs diff=lfs merge=lfs -text
*.wav filter=lfs diff=lfs merge=lfs -text
# Image files - uncompressed
*.bmp filter=lfs diff=lfs merge=lfs -text
*.gif filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text
*.tiff filter=lfs diff=lfs merge=lfs -text
# Image files - compressed
*.jpg filter=lfs diff=lfs merge=lfs -text
*.jpeg filter=lfs diff=lfs merge=lfs -text
*.webp filter=lfs diff=lfs merge=lfs -text
# Video files - compressed
*.mp4 filter=lfs diff=lfs merge=lfs -text
*.webm filter=lfs diff=lfs merge=lfs -text
dataset/kspace/data.mdb filter=lfs diff=lfs merge=lfs -text
dataset/rss/data.mdb filter=lfs diff=lfs merge=lfs -text
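These rules route large binaries (model checkpoints, archives, and the two LMDB `data.mdb` files) through Git LFS rather than plain git. As an illustrative sanity check, not part of the commit, `git check-attr` reports which filter a path resolves to; the sketch below assumes a checkout of this repo with the rules above in place:

```python
# Hedged sketch: query the .gitattributes rules above from a repo checkout.
# Paths matching *.ckpt or the data.mdb rules should report "filter: lfs".
import subprocess

for path in ["no.ckpt", "dataset/kspace/data.mdb", "app.py"]:
    out = subprocess.run(
        ["git", "check-attr", "filter", "--", path],
        capture_output=True, text=True, check=True,
    ).stdout.strip()
    print(out)  # e.g. "no.ckpt: filter: lfs"; app.py reports "unspecified"
```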
.gitignore
ADDED
@@ -0,0 +1,124 @@
# Config files
fastmri.yaml

# Python specific
__pycache__/
.pytest_cache/
*.py[cod]
*.so
*.egg-info/
*.pyo
*.pyd

# Virtual environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Code in development
ignore_**.py

# Hidden / ignore folders
hidden/
ignore/
hidden_**

# Jupyter Notebook Checkpoints
.ipynb_checkpoints

# Data files
data/
datasets/
dataset
*.csv
*.tsv
*.h5
*.json
*.xml
*.parquet
*.pkl

# Model files
*.ckpt
*.h5
*.tflite
*.onnx
*.pb
*.pth
*.pt
*.joblib
*.pkl

# Logs and outputs
logs/
wandb/
*.log
*.out
*.txt
*.csv

# Test dir
!tests/**/*.txt
!tests/datasets

# Results
results/
output/
runs/
outfig/
figs/*.png

# SLURM
slurm/

# Ignore files related to experiments
experiments/

# Temporary files
*.tmp
*.temp
*.swp
*.swo

# VS Code specific
.vscode/
*.code-workspace

# System files
.DS_Store
Thumbs.db

# Environment files
*.env

# Ignore files from data processing tools
*.dvc
.dvc/

# PyTorch Lightning Logs
lightning_logs/

# Ignore files generated by package managers
Pipfile
Pipfile.lock
poetry.lock

# TensorBoard logs
logs/
events.out.tfevents.*

# Checkpoints and weights
checkpoints/
weights/

# Large file extensions
*.tar.gz
*.zip
*.tar
*.gz

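The ignore list can be exercised outside of git. A minimal sketch, assuming the `pathspec` package (pinned as `pathspec==0.12.1` in `environment.yml` below); its `gitwildmatch` dialect mirrors git's semantics, including the `!` re-inclusion used for `tests/`:

```python
# Hedged sketch: test an excerpt of the rules above with pathspec.
import pathspec

rules = ["*.txt", "!tests/**/*.txt", "logs/", "*.ckpt"]  # excerpt of the file
spec = pathspec.PathSpec.from_lines("gitwildmatch", rules)

print(spec.match_file("train.txt"))           # True  -> ignored
print(spec.match_file("tests/refs/gt.txt"))   # False -> re-included by "!"
print(spec.match_file("vn.ckpt"))             # True  -> ignored
```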
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
-----BEGIN CERTIFICATE-----
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
-----END CERTIFICATE-----
README.md
ADDED
@@ -0,0 +1,33 @@
---
tags:
- medical
- mri
- neuraloperator
- fastmri
pretty_name: fastMRI Tiny
---
# A Unified Model for Compressed Sensing MRI Across Undersampling Patterns

> [**A Unified Model for Compressed Sensing MRI Across Undersampling Patterns**](https://arxiv.org/abs/2410.16290)
> Armeet Singh Jatyani, Jiayun Wang, Aditi Chandrashekar, Zihui Wu, Miguel Liu-Schiaffini, Bahareh Tolooshams, Anima Anandkumar
> *Paper at [CVPR 2025](https://cvpr.thecvf.com/Conferences/2025/AcceptedPapers)*

This is a tiny subset of 230 fastMRI samples, used in the demo for the above [paper](https://huggingface.co/armeet/nomri) at CVPR 2025!


## Citation

If you found our work helpful or used any of our models (UDNO), please cite the following:
```bibtex
@inproceedings{jatyani2025nomri,
  author    = {Armeet Singh Jatyani* and Jiayun Wang* and Aditi Chandrashekar and Zihui Wu and Miguel Liu-Schiaffini and Bahareh Tolooshams and Anima Anandkumar},
  title     = {A Unified Model for Compressed Sensing MRI Across Undersampling Patterns},
  booktitle = {Conference on Computer Vision and Pattern Recognition (CVPR) Proceedings},
  abbr      = {CVPR},
  year      = {2025}
}
```



https://arxiv.org/abs/2410.16290
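The dataset described above can be fetched with the exact call `app.py` below uses:

```python
# Same call app.py uses to pull the fastMRI-tiny dataset and weights locally.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="armeet/fastmri-tiny", repo_type="dataset", local_dir=".")
```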
app.py
ADDED
@@ -0,0 +1,264 @@
import io
import os
import sys

import gradio as gr
import numpy as np

# import spaces
# from huggingface_hub import hf_hub_download
from huggingface_hub import snapshot_download
from PIL import Image, ImageDraw, ImageFont

# Set the working directory to the root directory
# root_dir = os.path.abspath("..")
# os.chdir(root_dir)
# sys.path.insert(0, root_dir)

# download dataset & weights
snapshot_download(repo_id="armeet/fastmri-tiny", repo_type="dataset", local_dir=".")


device = "cuda"
# dataset_path = "/global/homes/p/peterwg/pscratch/datasets/mri_knee_dummy"
dataset_path = "dataset"

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

import fastmri
from fastmri.datasets import SliceDatasetLMDB, SliceSample
from fastmri.subsample import create_mask_for_mask_type
from models.lightning.no_varnet_module import NOVarnetModule
from models.lightning.varnet_module import VarNetModule

acceleration_to_fractions = {
    1: 1,
    2: 0.16,
    4: 0.08,
    6: 0.06,
    8: 0.04,
    16: 0.02,
    32: 0.01,
}


def create_mask_fn(center_fraction, acceleration):
    mask_fn = create_mask_for_mask_type(
        "equispaced_fraction",
        [center_fraction],
        [acceleration],
    )
    return mask_fn


mask_4x = create_mask_fn(acceleration_to_fractions[4], 4)
mask_6x = create_mask_fn(acceleration_to_fractions[6], 6)
mask_8x = create_mask_fn(acceleration_to_fractions[8], 8)
mask_16x = create_mask_fn(acceleration_to_fractions[16], 16)

val_dataset_4x = SliceDatasetLMDB(
    "knee",
    partition="val",
    mask_fns=[mask_4x],
    complex=False,
    root=dataset_path,
    crop_shape=(320, 320),
    coils=15,
)

val_dataset_6x = SliceDatasetLMDB(
    "knee",
    partition="val",
    mask_fns=[mask_6x],
    complex=False,
    root=dataset_path,
    crop_shape=(320, 320),
    coils=15,
)

val_dataset_8x = SliceDatasetLMDB(
    "knee",
    partition="val",
    mask_fns=[mask_8x],
    complex=False,
    root=dataset_path,
    crop_shape=(320, 320),
    coils=15,
)
val_dataset_16x = SliceDatasetLMDB(
    "knee",
    partition="val",
    mask_fns=[mask_16x],
    complex=False,
    root=dataset_path,
    crop_shape=(320, 320),
    coils=15,
)

vn = VarNetModule.load_from_checkpoint(
    "vn.ckpt",
)
no = NOVarnetModule.load_from_checkpoint(
    "no.ckpt",
)
no.eval()
vn.eval()

bright_samples = [42, 69, 80, 137, 139, 226, 229]


def v(x):
    return x.detach().cpu().numpy().squeeze()


def viz(x, cmap="gray", vmin=0, vmax=1):
    processed_data = v(x)
    fig, ax = plt.subplots()
    ax.imshow(processed_data, cmap=cmap, vmin=vmin, vmax=vmax)
    ax.axis("off")  # Turn off axes
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)  # Adjust margins
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    buf.seek(0)  # Rewind the buffer to the beginning
    plt.show()
    try:
        img = Image.open(buf)
        img_array = np.array(img)
    except Exception as e:
        print(f"Error converting image buffer to NumPy array: {e}")
        img_array = None
    finally:
        plt.close(fig)
        buf.close()
    return img_array


def forward(model, idx, rate):
    if rate == 4:
        dataset = val_dataset_4x
    elif rate == 6:
        dataset = val_dataset_6x
    elif rate == 8:
        dataset = val_dataset_8x
    elif rate == 16:
        dataset = val_dataset_16x
    else:
        raise ValueError("Invalid rate")

    sample = dataset[idx]
    mask, k, target = (
        sample.mask.to(device),
        sample.masked_kspace.to(device),
        sample.target.to(device),
    )
    pred = model(k.unsqueeze(0), mask.unsqueeze(0), None)

    return mask, k, target, pred[0]


def update_interface(sample_id, sample_rate):
    n = [None] * 6
    if sample_id is None or sample_rate is None or sample_id not in bright_samples:
        return n

    mask, k, target, pred_vn = forward(vn, sample_id, sample_rate)
    _, _, _, pred_no = forward(no, sample_id, sample_rate)

    k = viz(mask[0, :, :, 0], cmap="gray", vmin=0, vmax=1)
    target_res = viz(target, cmap="gray", vmin=None, vmax=None)

    pred_no_res = viz(pred_no, cmap="gray", vmin=None, vmax=None)
    pred_vn_res = viz(pred_vn, cmap="gray", vmin=None, vmax=None)

    diff_no_res = viz(torch.abs(pred_no - target), cmap=None, vmin=None, vmax=None)
    diff_vn_res = viz(torch.abs(pred_vn - target), cmap=None, vmin=None, vmax=None)

    return k, target_res, pred_no_res, pred_vn_res, diff_no_res, diff_vn_res


with gr.Blocks(theme=gr.themes.Monochrome(), fill_width=True) as demo:
    gr.Markdown(
        "# A Unified Model for Compressed Sensing MRI Across Undersampling Patterns [CVPR 2025]"
    )
    gr.Markdown("""
> Armeet Singh Jatyani, Jiayun Wang, Aditi Chandrashekar, Zihui Wu, Miguel Liu-Schiaffini, Bahareh Tolooshams, Anima Anandkumar
""")
    gr.Markdown(
        "[](https://arxiv.org/abs/2410.16290)"
    )
    gr.Markdown(
        "[](https://armeet.ca/nomri)"
    )

    gr.Markdown(
        "This demo showcases the performance of our unified model for compressed sensing MRI across different acceleration rates."
    )

    with gr.Row():
        dropdown_sample = gr.Dropdown(
            choices=bright_samples,
            label="Select a Sample",
            info="Choose one of the available samples.",
            filterable=False,
            value=229,
        )
    with gr.Row():
        dropdown_rate = gr.Radio(
            choices=[16, 8, 6, 4],
            value=16,
            label="Select an Acceleration Rate",
            info="Ex: 4x means the model is trained to reconstruct from 4x undersampled k-space data",
            # filterable=False,
        )

    with gr.Row():
        with gr.Column():
            gr.Label("Undersampling Mask")
            k = gr.Image(label=None, interactive=False)
        with gr.Column():
            gr.Label("Ground Truth")
            target = gr.Image(label=None, interactive=False)
        with gr.Column():
            gr.Label("NO (ours)")
            pred_no = gr.Image(label="Reconstruction", interactive=False)
        with gr.Column():
            gr.Label("VN (existing)")
            pred_vn = gr.Image(label="Reconstruction", interactive=False)
    with gr.Row():
        with gr.Column():
            pass
        with gr.Column():
            pass
        with gr.Column():
            diff_no = gr.Image(label="| Recon - GT |", interactive=False)
        with gr.Column():
            diff_vn = gr.Image(label="| Recon - GT |", interactive=False)

    gr.Markdown("""
```
@inproceedings{jatyani2025nomri,
  author    = {Armeet Singh Jatyani* and Jiayun Wang* and Aditi Chandrashekar and Zihui Wu and Miguel Liu-Schiaffini and Bahareh Tolooshams and Anima Anandkumar},
  title     = {A Unified Model for Compressed Sensing MRI Across Undersampling Patterns},
  booktitle = {Conference on Computer Vision and Pattern Recognition (CVPR) Proceedings},
  abbr      = {CVPR},
  year      = {2025}
}
```
""")

    update_inputs = [dropdown_sample, dropdown_rate]
    update_outputs = [k, target, pred_no, pred_vn, diff_no, diff_vn]

    dropdown_sample.change(
        fn=update_interface, inputs=update_inputs, outputs=update_outputs
    )
    dropdown_rate.change(
        fn=update_interface, inputs=update_inputs, outputs=update_outputs
    )

if __name__ == "__main__":
    demo.launch(share=True)
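For a headless sanity check without the Gradio UI, the objects `app.py` defines can be driven directly. A minimal sketch, assuming the checkpoints and dataset have downloaded and a CUDA device is available; it mirrors the `forward()` helper above, and the exact output structure of the Lightning modules is an assumption:

```python
# Hedged sketch: one reconstruction without the UI, reusing app.py's globals
# (no, val_dataset_16x, device). Run in the same process after importing app.
sample = val_dataset_16x[229]                     # one of bright_samples
k = sample.masked_kspace.unsqueeze(0).to(device)  # (1, 15, 320, 320, 2)
m = sample.mask.unsqueeze(0).to(device)
with torch.no_grad():
    recon = no(k, m, None)[0]                     # same call as forward() above
print(recon.shape, sample.target.shape)
```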
environment.yml
ADDED
@@ -0,0 +1,158 @@
name: no-med
channels:
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_gnu
  - asttokens=2.4.1=pyhd8ed1ab_0
  - bzip2=1.0.8=h5eee18b_6
  - ca-certificates=2024.8.30=hbcca054_0
  - comm=0.2.2=pyhd8ed1ab_0
  - debugpy=1.6.7=py312h6a678d5_0
  - decorator=5.1.1=pyhd8ed1ab_0
  - exceptiongroup=1.2.2=pyhd8ed1ab_0
  - executing=2.1.0=pyhd8ed1ab_0
  - expat=2.6.3=h6a678d5_0
  - importlib-metadata=8.5.0=pyha770c72_0
  - ipykernel=6.29.5=pyh3099207_0
  - ipython=8.27.0=pyh707e725_0
  - jedi=0.19.1=pyhd8ed1ab_0
  - jupyter_client=8.6.3=pyhd8ed1ab_0
  - jupyter_core=5.7.2=py312h06a4308_0
  - krb5=1.21.3=h143b758_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libedit=3.1.20230828=h5eee18b_0
  - libffi=3.4.4=h6a678d5_1
  - libgcc=14.1.0=h77fa898_1
  - libgcc-ng=14.1.0=h69a702a_1
  - libgomp=14.1.0=h77fa898_1
  - libsodium=1.0.20=h4ab18f5_0
  - libstdcxx=14.1.0=hc0a3c3a_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - libuuid=1.41.5=h5eee18b_0
  - matplotlib-inline=0.1.7=pyhd8ed1ab_0
  - ncurses=6.4=h6a678d5_0
  - nest-asyncio=1.6.0=pyhd8ed1ab_0
  - openssl=3.3.2=hb9d3cd8_0
  - packaging=24.1=pyhd8ed1ab_0
  - parso=0.8.4=pyhd8ed1ab_0
  - pexpect=4.9.0=pyhd8ed1ab_0
  - pickleshare=0.7.5=py_1003
  - pip=24.2=py312h06a4308_0
  - prompt-toolkit=3.0.47=pyha770c72_0
  - ptyprocess=0.7.0=pyhd3deb0d_0
  - pure_eval=0.2.3=pyhd8ed1ab_0
  - pygments=2.18.0=pyhd8ed1ab_0
  - python=3.12.4=h5148396_1
  - pyzmq=25.1.2=py312h6a678d5_0
  - readline=8.2=h5eee18b_0
  - setuptools=72.1.0=py312h06a4308_0
  - six=1.16.0=pyh6c4a22f_0
  - sqlite=3.45.3=h5eee18b_0
  - stack_data=0.6.2=pyhd8ed1ab_0
  - tk=8.6.14=h39e8969_0
  - tornado=6.4.1=py312h5eee18b_0
  - traitlets=5.14.3=pyhd8ed1ab_0
  - typing_extensions=4.12.2=pyha770c72_0
  - wcwidth=0.2.13=pyhd8ed1ab_0
  - wheel=0.43.0=py312h06a4308_0
  - xz=5.4.6=h5eee18b_1
  - zeromq=4.3.5=ha4adb4c_5
  - zipp=3.20.2=pyhd8ed1ab_0
  - zlib=1.2.13=h5eee18b_1
  - pip:
    - aiohappyeyeballs==2.4.0
    - aiohttp==3.10.5
    - aiosignal==1.3.1
    - antlr4-python3-runtime==4.9.3
    - attrs==24.2.0
    - black==24.10.0
    - certifi==2024.8.30
    - charset-normalizer==3.3.2
    - click==8.1.7
    - cloudpickle==3.0.0
    - contourpy==1.3.0
    - cycler==0.12.1
    - docker-pycreds==0.4.0
    - filelock==3.16.0
    - fonttools==4.53.1
    - frozenlist==1.4.1
    - fsspec==2024.9.0
    - gitdb==4.0.11
    - gitpython==3.1.43
    - h5py==3.11.0
    - hydra-core==1.3.2
    - hydra-submitit-launcher==1.2.0
    - idna==3.8
    - imageio==2.35.1
    - iniconfig==2.0.0
    - isort==5.13.2
    - jinja2==3.1.4
    - joblib==1.4.2
    - kiwisolver==1.4.7
    - lazy-loader==0.4
    - lightning==2.4.0
    - lightning-utilities==0.11.7
    - llvmlite==0.43.0
    - lmdb==1.5.1
    - markupsafe==2.1.5
    - matplotlib==3.9.2
    - mpmath==1.3.0
    - multidict==6.0.5
    - mypy-extensions==1.0.0
    - networkx==3.3
    - no-med==0.0.0
    - numba==0.60.0
    - numpy==2.0.2
    - nvidia-cublas-cu12==12.1.3.1
    - nvidia-cuda-cupti-cu12==12.1.105
    - nvidia-cuda-nvrtc-cu12==12.1.105
    - nvidia-cuda-runtime-cu12==12.1.105
    - nvidia-cudnn-cu12==9.1.0.70
    - nvidia-cufft-cu12==11.0.2.54
    - nvidia-curand-cu12==10.3.2.106
    - nvidia-cusolver-cu12==11.4.5.107
    - nvidia-cusparse-cu12==12.1.0.106
    - nvidia-nccl-cu12==2.20.5
    - nvidia-nvjitlink-cu12==12.6.68
    - nvidia-nvtx-cu12==12.1.105
    - omegaconf==2.3.0
    - opencv-python==4.10.0.84
    - pandas==2.2.2
    - pathspec==0.12.1
    - pillow==10.4.0
    - platformdirs==4.3.2
    - pluggy==1.5.0
    - protobuf==5.28.0
    - psutil==6.0.0
    - pyparsing==3.1.4
    - pytest==8.3.3
    - python-dateutil==2.9.0.post0
    - pytorch-lightning==2.4.0
    - pytz==2024.1
    - pywavelets==1.7.0
    - pyyaml==6.0.2
    - requests==2.32.3
    - runstats==2.0.0
    - scikit-image==0.24.0
    - scipy==1.14.1
    - sentry-sdk==2.13.0
    - setproctitle==1.3.3
    - sigpy==0.1.26
    - smmap==5.0.1
    - submitit==1.5.1
    - sympy==1.13.2
    - tabulate==0.9.0
    - tifffile==2024.8.30
    - toolz==1.0.0
    - torch==2.4.1
    - torchmetrics==1.4.1
    - torchvision==0.19.1
    - tqdm==4.66.5
    - triton==3.0.0
    - tzdata==2024.1
    - urllib3==2.2.2
    - wandb==0.17.9
    - yarl==1.11.0
prefix: /global/homes/p/peterwg/local/miniconda3/envs/no-med
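Note that `gradio` itself is absent from these pins even though `app.py` imports it; presumably the Spaces runtime supplies it. A minimal sketch to confirm the key pins resolve in the active environment:

```python
# Hedged sketch: compare a few pins above against the active environment.
from importlib.metadata import PackageNotFoundError, version

for pkg, pinned in [("torch", "2.4.1"), ("pytorch-lightning", "2.4.0"),
                    ("numpy", "2.0.2"), ("gradio", "unpinned")]:
    try:
        got = version(pkg)
    except PackageNotFoundError:
        got = "MISSING"
    print(f"{pkg}: installed {got}, expected {pinned}")
```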
fastmri/__init__.py
ADDED
@@ -0,0 +1,20 @@
"""
Copyright (c) Facebook, Inc. and its affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

from .coil_combine import rss, rss_complex, mvue
from .fftc import fft2c_new as fft2c
from .fftc import fftshift
from .fftc import ifft2c_new as ifft2c
from .fftc import ifftshift, roll
from .losses import SSIMLoss
from .math_utils import (
    complex_abs,
    complex_abs_sq,
    complex_conj,
    complex_mul,
    tensor_to_complex_np,
)
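The re-exported surface is small; a minimal sketch of the round trip it enables, using the repo's real-valued complex convention (trailing dimension of size 2):

```python
# Hedged sketch of the API re-exported above: centered FFT round trip plus
# RSS coil combination. Shapes follow the (coils, H, W, 2) convention.
import torch
import fastmri

x = torch.randn(15, 320, 320, 2)              # 15-coil "complex" image
k = fastmri.fft2c(x)                          # fft2c_new under the alias
img = fastmri.complex_abs(fastmri.ifft2c(k))  # per-coil magnitudes, (15, 320, 320)
combined = fastmri.rss(img, dim=0)            # (320, 320) coil-combined image
```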
fastmri/coil_combine.py
ADDED
@@ -0,0 +1,67 @@
"""
Copyright (c) Facebook, Inc. and its affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import torch

import fastmri
import sigpy as sp
import numpy as np


def rss(data: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """
    Compute the Root Sum of Squares (RSS).

    The RSS is computed assuming that `dim` is the coil dimension.

    Parameters
    ----------
    data : torch.Tensor
        The input tensor.
    dim : int, optional
        The dimension along which to apply the RSS transform (default is 0).

    Returns
    -------
    torch.Tensor
        The computed RSS value.
    """
    return torch.sqrt((data**2).sum(dim))


def mvue(spatial_pred, sens_maps, dim: int = 0) -> torch.Tensor:
    spatial_pred = torch.view_as_complex(spatial_pred)
    sens_maps = torch.view_as_complex(sens_maps)

    numerator = torch.sum(spatial_pred * torch.conj(sens_maps), dim=dim)
    denominator = torch.sqrt(
        torch.sum(torch.square(torch.abs(sens_maps)), dim=dim)
    )
    res = numerator / denominator
    res = torch.abs(res)
    return res


def rss_complex(data: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """
    Compute the Root Sum of Squares (RSS) for complex inputs.

    The RSS is computed assuming that `dim` is the coil dimension.

    Parameters
    ----------
    data : torch.Tensor
        The input tensor containing complex values.
    dim : int, optional
        The dimension along which to apply the RSS transform (default is 0).

    Returns
    -------
    torch.Tensor
        The computed RSS value for complex inputs.
    """
    return torch.sqrt(fastmri.complex_abs_sq(data).sum(dim))
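In equations, the two coil-combination rules implemented above are, for per-coil images $x_c$ and sensitivity maps $S_c$:

```latex
\mathrm{RSS}(x) = \sqrt{\textstyle\sum_c |x_c|^2},
\qquad
\mathrm{MVUE}(x, S) =
  \left| \frac{\sum_c x_c \, \overline{S_c}}{\sqrt{\sum_c |S_c|^2}} \right|
```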
fastmri/datasets.py
ADDED
|
@@ -0,0 +1,583 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import xml.etree.ElementTree as etree
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import (
|
| 5 |
+
Any,
|
| 6 |
+
Callable,
|
| 7 |
+
Dict,
|
| 8 |
+
List,
|
| 9 |
+
Literal,
|
| 10 |
+
NamedTuple,
|
| 11 |
+
Optional,
|
| 12 |
+
Sequence,
|
| 13 |
+
Tuple,
|
| 14 |
+
Union,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
import h5py
|
| 18 |
+
import lmdb
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
import yaml
|
| 22 |
+
import sigpy as sp
|
| 23 |
+
import pandas as pd
|
| 24 |
+
|
| 25 |
+
import fastmri
|
| 26 |
+
import fastmri.transforms as T
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class RawSample(NamedTuple):
|
| 30 |
+
fname: Path
|
| 31 |
+
slice_num: int
|
| 32 |
+
metadata: Dict[str, Any]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class SliceSample(NamedTuple):
|
| 36 |
+
masked_kspace: torch.Tensor
|
| 37 |
+
mask: torch.Tensor
|
| 38 |
+
num_low_frequencies: int
|
| 39 |
+
target: torch.Tensor
|
| 40 |
+
max_value: float
|
| 41 |
+
# attrs: Dict[str, Any]
|
| 42 |
+
fname: str
|
| 43 |
+
slice_num: int
|
| 44 |
+
|
| 45 |
+
class SliceSampleMVUE(NamedTuple):
|
| 46 |
+
masked_kspace: torch.Tensor
|
| 47 |
+
mask: torch.Tensor
|
| 48 |
+
num_low_frequencies: int
|
| 49 |
+
target: torch.Tensor
|
| 50 |
+
rss: torch.Tensor
|
| 51 |
+
max_value: float
|
| 52 |
+
# attrs: Dict[str, Any]
|
| 53 |
+
fname: str
|
| 54 |
+
slice_num: int
|
| 55 |
+
|
| 56 |
+
def et_query(
|
| 57 |
+
root: etree.Element,
|
| 58 |
+
qlist: Sequence[str],
|
| 59 |
+
namespace: str = "http://www.ismrm.org/ISMRMRD",
|
| 60 |
+
) -> str:
|
| 61 |
+
"""
|
| 62 |
+
Query an XML document using ElementTree.
|
| 63 |
+
|
| 64 |
+
This function allows querying an XML document by specifying a root and a list of nested queries.
|
| 65 |
+
It supports optional XML namespaces.
|
| 66 |
+
|
| 67 |
+
Parameters
|
| 68 |
+
----------
|
| 69 |
+
root : ElementTree.Element
|
| 70 |
+
The root element of the XML to search through.
|
| 71 |
+
qlist : list of str
|
| 72 |
+
A list of strings for nested searches, e.g., ["Encoding", "matrixSize"].
|
| 73 |
+
namespace : str, optional
|
| 74 |
+
An optional XML namespace to prepend to the query (default is None).
|
| 75 |
+
|
| 76 |
+
Returns
|
| 77 |
+
-------
|
| 78 |
+
str
|
| 79 |
+
The retrieved data as a string.
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
s = "."
|
| 83 |
+
prefix = "ismrmrd_namespace"
|
| 84 |
+
|
| 85 |
+
ns = {prefix: namespace}
|
| 86 |
+
|
| 87 |
+
for el in qlist:
|
| 88 |
+
s = s + f"//{prefix}:{el}"
|
| 89 |
+
|
| 90 |
+
value = root.find(s, ns)
|
| 91 |
+
if value is None:
|
| 92 |
+
raise RuntimeError("Element not found")
|
| 93 |
+
|
| 94 |
+
return str(value.text)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class SliceDataset(torch.utils.data.Dataset):
|
| 98 |
+
"""
|
| 99 |
+
A simplified PyTorch Dataset that provides access to multicoil MR image
|
| 100 |
+
slices from the fastMRI dataset.
|
| 101 |
+
"""
|
| 102 |
+
|
| 103 |
+
def __init__(
|
| 104 |
+
self,
|
| 105 |
+
# root: Optional[Path | str],
|
| 106 |
+
body_part: Literal["knee", "brain"],
|
| 107 |
+
partition: Literal["train", "val", "test"],
|
| 108 |
+
mask_fns: Optional[List[Callable]] = None,
|
| 109 |
+
sample_rate: float = 1.0,
|
| 110 |
+
complex: bool = False,
|
| 111 |
+
crop_shape: Tuple[int, int] = (320, 320),
|
| 112 |
+
slug: str = "",
|
| 113 |
+
contrast: Optional[Literal["T1", "T2"]] = None,
|
| 114 |
+
coils: Optional[int] = None,
|
| 115 |
+
):
|
| 116 |
+
"""
|
| 117 |
+
Initializes the fastMRI multi-coil challenge dataset.
|
| 118 |
+
|
| 119 |
+
Samples are individual 2D slices taken from k-space volume data.
|
| 120 |
+
|
| 121 |
+
Parameters
|
| 122 |
+
----------
|
| 123 |
+
body_part : {'knee', 'brain'}
|
| 124 |
+
The body part to analyze.
|
| 125 |
+
partition : {'train', 'val', 'test'}
|
| 126 |
+
The data partition type.
|
| 127 |
+
mask_fns : list of callable, optional
|
| 128 |
+
A list of masking functions to apply to samples.
|
| 129 |
+
If multiple are given, a mask is randomly chosen for each sample.
|
| 130 |
+
sample_rate : float, optional
|
| 131 |
+
Fraction of data to sample, by default 1.0.
|
| 132 |
+
complex : bool, optional
|
| 133 |
+
Whether the $k$-space data should return complex-valued, by default False.
|
| 134 |
+
If True, kspace values will be complex.
|
| 135 |
+
If False, kspace values will be real (shape, 2).
|
| 136 |
+
crop_shape : tuple of two ints, optional
|
| 137 |
+
The shape to center crop the k-space data, by default (320, 320).
|
| 138 |
+
slug : string
|
| 139 |
+
dataset slug name
|
| 140 |
+
contrast : {'T1', 'T2'}
|
| 141 |
+
If partition is brain, the contrast of images to use.
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
with open("fastmri.yaml", "r") as file:
|
| 145 |
+
config = yaml.safe_load(file)
|
| 146 |
+
self.contrast = contrast
|
| 147 |
+
self.slug = slug
|
| 148 |
+
self.partition = partition
|
| 149 |
+
self.body_part = body_part
|
| 150 |
+
self.root = (
|
| 151 |
+
Path(config.get(f"{body_part}_path")) / f"multicoil_{partition}"
|
| 152 |
+
)
|
| 153 |
+
self.mask_fns = mask_fns
|
| 154 |
+
self.sample_rate = sample_rate
|
| 155 |
+
self.raw_samples: List[RawSample] = self._load_samples()
|
| 156 |
+
self.complex = complex
|
| 157 |
+
self.crop_shape = crop_shape
|
| 158 |
+
self.coils = coils
|
| 159 |
+
|
| 160 |
+
def _load_samples(self):
|
| 161 |
+
# Gather all files in the root directory
|
| 162 |
+
if self.body_part == "brain" and self.contrast:
|
| 163 |
+
files = list(self.root.glob(f"*{self.contrast}*.h5"))
|
| 164 |
+
else:
|
| 165 |
+
files = list(self.root.glob("*.h5"))
|
| 166 |
+
raw_samples = []
|
| 167 |
+
|
| 168 |
+
# Load and process metadata from each file
|
| 169 |
+
for fname in sorted(files):
|
| 170 |
+
with h5py.File(fname, "r") as hf:
|
| 171 |
+
metadata, num_slices = self._retrieve_metadata(fname)
|
| 172 |
+
|
| 173 |
+
# Collect samples for each slice, discard first c slices, and last c slices
|
| 174 |
+
c = 6
|
| 175 |
+
for slice_num in range(num_slices):
|
| 176 |
+
if c <= slice_num <= num_slices - c - 1:
|
| 177 |
+
raw_samples.append(
|
| 178 |
+
RawSample(fname, slice_num, metadata)
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
# Subsample if desired
|
| 182 |
+
if self.sample_rate < 1.0:
|
| 183 |
+
raw_samples = random.sample(
|
| 184 |
+
raw_samples, int(len(raw_samples) * self.sample_rate)
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
return raw_samples
|
| 188 |
+
|
| 189 |
+
def _retrieve_metadata(self, fname):
|
| 190 |
+
with h5py.File(fname, "r") as hf:
|
| 191 |
+
et_root = etree.fromstring(hf["ismrmrd_header"][()])
|
| 192 |
+
|
| 193 |
+
enc = ["encoding", "encodedSpace", "matrixSize"]
|
| 194 |
+
enc_size = (
|
| 195 |
+
int(et_query(et_root, enc + ["x"])),
|
| 196 |
+
int(et_query(et_root, enc + ["y"])),
|
| 197 |
+
int(et_query(et_root, enc + ["z"])),
|
| 198 |
+
)
|
| 199 |
+
rec = ["encoding", "reconSpace", "matrixSize"]
|
| 200 |
+
recon_size = (
|
| 201 |
+
int(et_query(et_root, rec + ["x"])),
|
| 202 |
+
int(et_query(et_root, rec + ["y"])),
|
| 203 |
+
int(et_query(et_root, rec + ["z"])),
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
lims = ["encoding", "encodingLimits", "kspace_encoding_step_1"]
|
| 207 |
+
enc_limits_center = int(et_query(et_root, lims + ["center"]))
|
| 208 |
+
enc_limits_max = int(et_query(et_root, lims + ["maximum"])) + 1
|
| 209 |
+
|
| 210 |
+
padding_left = enc_size[1] // 2 - enc_limits_center
|
| 211 |
+
padding_right = padding_left + enc_limits_max
|
| 212 |
+
|
| 213 |
+
num_slices = hf["kspace"].shape[0]
|
| 214 |
+
|
| 215 |
+
metadata = {
|
| 216 |
+
"padding_left": padding_left,
|
| 217 |
+
"padding_right": padding_right,
|
| 218 |
+
"encoding_size": enc_size,
|
| 219 |
+
"recon_size": recon_size,
|
| 220 |
+
**hf.attrs,
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
return metadata, num_slices
|
| 224 |
+
|
| 225 |
+
def __len__(self):
|
| 226 |
+
return len(self.raw_samples)
|
| 227 |
+
|
| 228 |
+
def __getitem__(self, idx) -> SliceSample:
|
| 229 |
+
try:
|
| 230 |
+
raw_sample: RawSample = self.raw_samples[idx]
|
| 231 |
+
fname, slice_num, metadata = raw_sample
|
| 232 |
+
|
| 233 |
+
# load kspace and target
|
| 234 |
+
with h5py.File(fname, "r") as hf:
|
| 235 |
+
kspace = torch.tensor(hf["kspace"][()][slice_num])
|
| 236 |
+
if not self.complex:
|
| 237 |
+
kspace = torch.view_as_real(kspace)
|
| 238 |
+
if self.coils:
|
| 239 |
+
if kspace.shape[0] < self.coils:
|
| 240 |
+
return None
|
| 241 |
+
kspace = kspace[: self.coils, :, :, :]
|
| 242 |
+
target_key = (
|
| 243 |
+
"reconstruction_rss"
|
| 244 |
+
if self.partition in ["train", "val"]
|
| 245 |
+
else "reconstruction_esc"
|
| 246 |
+
)
|
| 247 |
+
target = hf.get(target_key, None)
|
| 248 |
+
if target is not None:
|
| 249 |
+
target = torch.tensor(target[()][slice_num])
|
| 250 |
+
if self.body_part == "brain":
|
| 251 |
+
target = T.center_crop(target, self.crop_shape)
|
| 252 |
+
|
| 253 |
+
# center crop to enable collating for batching
|
| 254 |
+
if self.complex:
|
| 255 |
+
# if complex, crop across dims: -2 and -1 (last 2)
|
| 256 |
+
raise NotImplementedError("Not implemented for complex native")
|
| 257 |
+
else:
|
| 258 |
+
# crop in image space, to not lose high-frequency information
|
| 259 |
+
image = fastmri.ifft2c(kspace)
|
| 260 |
+
image_cropped = T.complex_center_crop(image, self.crop_shape)
|
| 261 |
+
kspace = fastmri.fft2c(image_cropped)
|
| 262 |
+
|
| 263 |
+
# apply transform mask if there is one
|
| 264 |
+
if self.mask_fns:
|
| 265 |
+
# choose a random mask
|
| 266 |
+
mask_fn = random.choice(self.mask_fns)
|
| 267 |
+
kspace, mask, num_low_frequencies = T.apply_mask(
|
| 268 |
+
kspace,
|
| 269 |
+
mask_fn,
|
| 270 |
+
# seed=seed,
|
| 271 |
+
)
|
| 272 |
+
mask = mask.bool()
|
| 273 |
+
else:
|
| 274 |
+
mask = torch.ones_like(kspace, dtype=torch.bool)
|
| 275 |
+
num_low_frequencies = 0
|
| 276 |
+
sample = SliceSample(
|
| 277 |
+
kspace,
|
| 278 |
+
mask,
|
| 279 |
+
num_low_frequencies,
|
| 280 |
+
target,
|
| 281 |
+
metadata["max"],
|
| 282 |
+
fname.name,
|
| 283 |
+
slice_num,
|
| 284 |
+
)
|
| 285 |
+
return sample
|
| 286 |
+
except:
|
| 287 |
+
return None
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
class SliceDatasetLMDB(torch.utils.data.Dataset):
|
| 291 |
+
"""
|
| 292 |
+
A simplified PyTorch Dataset that provides access to multicoil MR image
|
| 293 |
+
slices from the fastMRI dataset. Loads from LMDB saved samples.
|
| 294 |
+
"""
|
| 295 |
+
|
| 296 |
+
def __init__(
|
| 297 |
+
self,
|
| 298 |
+
body_part: Literal["knee", "brain"],
|
| 299 |
+
partition: Literal["train", "val", "test"],
|
| 300 |
+
root: Optional[Path | str] = None,
|
| 301 |
+
mask_fns: Optional[List[Callable]] = None,
|
| 302 |
+
sample_rate: float = 1.0,
|
| 303 |
+
complex: bool = False,
|
| 304 |
+
crop_shape: Tuple[int, int] = (320, 320),
|
| 305 |
+
slug: str = "",
|
| 306 |
+
coils: int = 15,
|
| 307 |
+
):
|
| 308 |
+
"""
|
| 309 |
+
Initializes the fastMRI multi-coil challenge dataset.
|
| 310 |
+
|
| 311 |
+
Samples are individual 2D slices taken from k-space volume data.
|
| 312 |
+
|
| 313 |
+
Parameters
|
| 314 |
+
----------
|
| 315 |
+
body_part : {'knee', 'brain'}
|
| 316 |
+
The body part to analyze.
|
| 317 |
+
root : Path or str, optional
|
| 318 |
+
Root to lmdb dataset. If not provided, the root is automatically
|
| 319 |
+
loaded directly from fastmri.yaml config
|
| 320 |
+
partition : {'train', 'val', 'test'}
|
| 321 |
+
The data partition type.
|
| 322 |
+
mask_fns : list of callable, optional
|
| 323 |
+
A list of masking functions to apply to samples.
|
| 324 |
+
If multiple are given, a mask is randomly chosen for each sample.
|
| 325 |
+
sample_rate : float, optional
|
| 326 |
+
Fraction of data to sample, by default 1.0.
|
| 327 |
+
complex : bool, optional
|
| 328 |
+
Whether the $k$-space data should return complex-valued, by default False.
|
| 329 |
+
If True, kspace values will be complex.
|
| 330 |
+
If False, kspace values will be real (shape, 2).
|
| 331 |
+
crop_shape : tuple of two ints, optional
|
| 332 |
+
The shape to center crop the k-space data, by default (320, 320).
|
| 333 |
+
slug : string
|
| 334 |
+
dataset slug name
|
| 335 |
+
"""
|
| 336 |
+
|
| 337 |
+
# set attrs
|
| 338 |
+
self.coils = coils
|
| 339 |
+
self.slug = slug
|
| 340 |
+
self.partition = partition
|
| 341 |
+
self.mask_fns = mask_fns
|
| 342 |
+
self.sample_rate = sample_rate
|
| 343 |
+
self.complex = complex
|
| 344 |
+
self.crop_shape = crop_shape
|
| 345 |
+
|
| 346 |
+
# load lmdb info
|
| 347 |
+
if root:
|
| 348 |
+
if isinstance(root, str):
|
| 349 |
+
root = Path(root)
|
| 350 |
+
assert root.exists(), "Provided root doesn't exist."
|
| 351 |
+
self.root = root
|
| 352 |
+
else:
|
| 353 |
+
with open("fastmri.yaml", "r") as file:
|
| 354 |
+
config = yaml.safe_load(file)
|
| 355 |
+
self.root = Path(config["lmdb"][f"{body_part}_{partition}_path"])
|
| 356 |
+
self.meta = np.load(self.root / "meta.npy")
|
| 357 |
+
self.kspace_env = lmdb.open(
|
| 358 |
+
str(self.root / "kspace"),
|
| 359 |
+
readonly=True,
|
| 360 |
+
lock=False,
|
| 361 |
+
create=False,
|
| 362 |
+
)
|
| 363 |
+
self.kspace_txn = self.kspace_env.begin(write=False)
|
| 364 |
+
self.rss_env = lmdb.open(
|
| 365 |
+
str(self.root / "rss"),
|
| 366 |
+
readonly=True,
|
| 367 |
+
lock=False,
|
| 368 |
+
create=False,
|
| 369 |
+
)
|
| 370 |
+
self.rss_txn = self.rss_env.begin(write=False)
|
| 371 |
+
self.length = self.kspace_txn.stat()["entries"]
|
| 372 |
+
|
| 373 |
+
def __len__(self):
|
| 374 |
+
return int(self.sample_rate * self.length)
|
| 375 |
+
|
| 376 |
+
def __getitem__(self, idx) -> SliceSample:
|
| 377 |
+
idx_key = str(idx).encode("utf-8")
|
| 378 |
+
|
| 379 |
+
# load sample data
|
| 380 |
+
kspace = torch.from_numpy(
|
| 381 |
+
np.frombuffer(self.kspace_txn.get(idx_key), dtype=np.float32)
|
| 382 |
+
.reshape(self.coils, 320, 320, 2)
|
| 383 |
+
.copy()
|
| 384 |
+
)
|
| 385 |
+
rss = torch.from_numpy(
|
| 386 |
+
np.frombuffer(self.rss_txn.get(idx_key), dtype=np.float32)
|
| 387 |
+
.reshape(320, 320)
|
| 388 |
+
.copy()
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
# crop in image space, to not lose high-frequency information
|
| 392 |
+
if self.crop_shape and self.crop_shape != (320, 320):
|
| 393 |
+
image = fastmri.ifft2c(kspace)
|
| 394 |
+
image_cropped = T.complex_center_crop(image, self.crop_shape)
|
| 395 |
+
kspace = fastmri.fft2c(image_cropped)
|
| 396 |
+
rss = T.center_crop(rss, self.crop_shape)
|
| 397 |
+
|
| 398 |
+
# load and apply mask
|
| 399 |
+
if self.mask_fns:
|
| 400 |
+
# choose a random mask
|
| 401 |
+
mask_fn = random.choice(self.mask_fns)
|
| 402 |
+
kspace, mask, num_low_frequencies = T.apply_mask(
|
| 403 |
+
kspace,
|
| 404 |
+
mask_fn, # type: ignore
|
| 405 |
+
)
|
| 406 |
+
mask = mask.bool()
|
| 407 |
+
else:
|
| 408 |
+
mask = torch.ones_like(kspace, dtype=torch.bool)
|
| 409 |
+
num_low_frequencies = 0
|
| 410 |
+
|
| 411 |
+
# load metadata
|
| 412 |
+
fname, slice_num, max_value = self.meta[idx]
|
| 413 |
+
fname = str(fname)
|
| 414 |
+
slice_num = int(slice_num)
|
| 415 |
+
max_value = float(max_value)
|
| 416 |
+
|
| 417 |
+
return SliceSample(
|
| 418 |
+
kspace,
|
| 419 |
+
mask,
|
| 420 |
+
num_low_frequencies,
|
| 421 |
+
rss,
|
| 422 |
+
max_value,
|
| 423 |
+
fname,
|
| 424 |
+
slice_num,
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
class SliceDatasetLMDB_MVUE(torch.utils.data.Dataset):
|
| 429 |
+
"""
|
| 430 |
+
Loads from LMDB brain saved samples.
|
| 431 |
+
|
| 432 |
+
Modified to have MVUE targets
|
| 433 |
+
"""
|
| 434 |
+
|
| 435 |
+
def __init__(
|
| 436 |
+
self,
|
| 437 |
+
root: Path | str,
|
| 438 |
+
mask_fns: Optional[List[Callable]] = None,
|
| 439 |
+
sample_rate: float = 1.0,
|
| 440 |
+
crop_shape: Tuple[int, int] = (320, 320),
|
| 441 |
+
slug: str = "",
|
| 442 |
+
coils: int = 15,
|
| 443 |
+
):
|
| 444 |
+
|
| 445 |
+
# set attrs
|
| 446 |
+
self.coils = coils
|
| 447 |
+
self.slug = slug
|
| 448 |
+
self.mask_fns = mask_fns
|
| 449 |
+
self.sample_rate = sample_rate
|
| 450 |
+
self.complex = complex
|
| 451 |
+
self.crop_shape = crop_shape
|
| 452 |
+
|
| 453 |
+
# load lmdb info
|
| 454 |
+
if isinstance(root, str):
|
| 455 |
+
root = Path(root)
|
| 456 |
+
assert root.exists(), "Provided root doesn't exist."
|
| 457 |
+
self.root = root
|
| 458 |
+
self.meta = np.load(self.root / "meta.npy")
|
| 459 |
+
self.mapping = pd.read_csv("brain_mvue_map.csv")
|
| 460 |
+
self.kspace_env = lmdb.open(
|
| 461 |
+
str(self.root / "kspace"),
|
| 462 |
+
readonly=True,
|
| 463 |
+
lock=False,
|
| 464 |
+
create=False,
|
| 465 |
+
)
|
| 466 |
+
self.kspace_txn = self.kspace_env.begin(write=False)
|
| 467 |
+
self.rss_env = lmdb.open(
|
| 468 |
+
str(self.root / "rss"),
|
| 469 |
+
readonly=True,
|
| 470 |
+
lock=False,
|
| 471 |
+
create=False,
|
| 472 |
+
)
|
| 473 |
+
self.rss_txn = self.rss_env.begin(write=False)
|
| 474 |
+
|
| 475 |
+
# ray mvue dataset
|
| 476 |
+
self.mvue_env = lmdb.open(
|
| 477 |
+
str("/pscratch/sd/p/peterwg/datasets/raytemp"),
|
| 478 |
+
readonly=True,
|
| 479 |
+
lock=False,
|
| 480 |
+
create=False,
|
| 481 |
+
)
|
| 482 |
+
self.mvue_txn = self.mvue_env.begin(write=False)
|
| 483 |
+
|
| 484 |
+
self.length = len(self.mapping)
|
| 485 |
+
# self.length = self.kspace_txn.stat()["entries"]
|
| 486 |
+
|
| 487 |
+
def __len__(self):
|
| 488 |
+
return int(self.sample_rate * self.length)
|
| 489 |
+
|
| 490 |
+
def __getitem__(self, idx) -> SliceSampleMVUE:
|
| 491 |
+
# ray's index: 0-n
|
| 492 |
+
ray_idx = idx
|
| 493 |
+
# my index: lookup(ray index)
|
| 494 |
+
idx = int(self.mapping.iloc[ray_idx].my_index)
|
| 495 |
+
ray_idx_key = str(ray_idx).encode("utf-8")
|
| 496 |
+
idx_key = str(idx).encode("utf-8")
|
| 497 |
+
|
| 498 |
+
# load sample data
|
| 499 |
+
kspace = torch.from_numpy(
|
| 500 |
+
np.frombuffer(self.kspace_txn.get(idx_key), dtype=np.float32)
|
| 501 |
+
.reshape(self.coils, 320, 320, 2)
|
| 502 |
+
.copy()
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
# mvue_target = np.sum(
|
| 506 |
+
# sp.ifft(kspace, axes=(-1, -2)) * np.conj(s_maps), axis=1
|
| 507 |
+
# ) / np.sqrt(np.sum(np.square(np.abs(s_maps)), axis=1))
|
| 508 |
+
rss = torch.from_numpy(
|
| 509 |
+
np.frombuffer(self.rss_txn.get(idx_key), dtype=np.float32)
|
| 510 |
+
.reshape(320, 320)
|
| 511 |
+
.copy()
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
# load mvue from ray dataset
|
| 515 |
+
mvue = torch.from_numpy(
|
| 516 |
+
np.frombuffer(self.mvue_txn.get(ray_idx_key), dtype=np.complex64)
|
| 517 |
+
.reshape(320, 320)
|
| 518 |
+
.copy()
|
| 519 |
+
)
|
| 520 |
+
mvue = torch.abs(mvue)
|
| 521 |
+
|
| 522 |
+
# crop in image space, to not lose high-frequency information
|
| 523 |
+
if self.crop_shape and self.crop_shape != (320, 320):
|
| 524 |
+
image = fastmri.ifft2c(kspace)
|
| 525 |
+
image_cropped = T.complex_center_crop(image, self.crop_shape)
|
| 526 |
+
kspace = fastmri.fft2c(image_cropped)
|
| 527 |
+
rss = T.center_crop(rss, self.crop_shape)
|
| 528 |
+
|
| 529 |
+
# load and apply mask
|
| 530 |
+
if self.mask_fns:
|
| 531 |
+
# choose a random mask
|
| 532 |
+
mask_fn = random.choice(self.mask_fns)
|
| 533 |
+
kspace, mask, num_low_frequencies = T.apply_mask(
|
| 534 |
+
kspace,
|
| 535 |
+
mask_fn, # type: ignore
|
| 536 |
+
)
|
| 537 |
+
mask = mask.bool()
|
| 538 |
+
else:
|
| 539 |
+
mask = torch.ones_like(kspace, dtype=torch.bool)
|
| 540 |
+
num_low_frequencies = 0
|
| 541 |
+
|
| 542 |
+
# load metadata
|
| 543 |
+
fname, slice_num, max_value = self.meta[idx]
|
| 544 |
+
fname = str(fname)
|
| 545 |
+
slice_num = int(slice_num)
|
| 546 |
+
max_value = float(max_value)
|
| 547 |
+
|
| 548 |
+
return SliceSampleMVUE(
|
| 549 |
+
kspace,
|
| 550 |
+
mask,
|
| 551 |
+
num_low_frequencies,
|
| 552 |
+
mvue,
|
| 553 |
+
rss,
|
| 554 |
+
max_value,
|
| 555 |
+
fname,
|
| 556 |
+
slice_num,
|
| 557 |
+
)
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
# d = SliceDatasetLMDB("knee", "val", None, 1, True, (320, 320), "testdataset")
|
| 561 |
+
# print(len(d))
|
| 562 |
+
# breakpoint()
|
| 563 |
+
|
| 564 |
+
# ds = SuperSliceDatasetLMDB(
|
| 565 |
+
# "brain", # body_part
|
| 566 |
+
# "val", # partition
|
| 567 |
+
# None, # root
|
| 568 |
+
# None, # mask_fns
|
| 569 |
+
# 1.0, # sample_rate
|
| 570 |
+
# True, # complex
|
| 571 |
+
# (320, 320), # crop_shape
|
| 572 |
+
# "test-superres", # slug
|
| 573 |
+
# coils=16, # coils
|
| 574 |
+
# )
|
| 575 |
+
# breakpoint()
|
| 576 |
+
|
| 577 |
+
# d = SliceDataset("brain", "train", None, contrast="T2")
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
# # TESTING MVUE
|
| 581 |
+
# d = SliceDatasetLMDB_MVUE("/pscratch/sd/p/peterwg/datasets/mri_brain_train_lmdb", coils=16)
|
| 582 |
+
# x = d[0]
|
| 583 |
+
# d = SliceDatasetLMDB_MVUE("/pscratch/sd/p/peterwg/datasets/raytemp/", coils=16)
|
fastmri/evaluate.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
|
| 4 |
+
This source code is licensed under the MIT license found in the
|
| 5 |
+
LICENSE file in the root directory of this source tree.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import argparse
|
| 9 |
+
import pathlib
|
| 10 |
+
from argparse import ArgumentParser
|
| 11 |
+
from typing import Optional
|
| 12 |
+
|
| 13 |
+
import h5py
|
| 14 |
+
import numpy as np
|
| 15 |
+
from runstats import Statistics
|
| 16 |
+
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
|
| 17 |
+
|
| 18 |
+
from fastmri import transforms
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def mse(gt: np.ndarray, pred: np.ndarray) -> np.ndarray:
|
| 22 |
+
"""Compute Mean Squared Error (MSE)"""
|
| 23 |
+
return np.mean((gt - pred) ** 2)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def nmse(gt: np.ndarray, pred: np.ndarray) -> np.ndarray:
|
| 27 |
+
"""Compute Normalized Mean Squared Error (NMSE)"""
|
| 28 |
+
return np.array(np.linalg.norm(gt - pred) ** 2 / np.linalg.norm(gt) ** 2)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def psnr(
|
| 32 |
+
gt: np.ndarray, pred: np.ndarray, maxval: Optional[float] = None
|
| 33 |
+
) -> np.ndarray:
|
| 34 |
+
"""Compute Peak Signal to Noise Ratio metric (PSNR)"""
|
| 35 |
+
if maxval is None:
|
| 36 |
+
maxval = gt.max()
|
| 37 |
+
return peak_signal_noise_ratio(gt, pred, data_range=maxval)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def ssim(
|
| 41 |
+
gt: np.ndarray, pred: np.ndarray, maxval: Optional[float] = None
|
| 42 |
+
) -> np.ndarray:
|
| 43 |
+
"""Compute Structural Similarity Index Metric (SSIM)"""
|
| 44 |
+
if not gt.ndim == 3:
|
| 45 |
+
raise ValueError("Unexpected number of dimensions in ground truth.")
|
| 46 |
+
if not gt.ndim == pred.ndim:
|
| 47 |
+
raise ValueError("Ground truth dimensions does not match pred.")
|
| 48 |
+
|
| 49 |
+
maxval = gt.max() if maxval is None else maxval
|
| 50 |
+
|
| 51 |
+
ssim = np.array([0])
|
| 52 |
+
for slice_num in range(gt.shape[0]):
|
| 53 |
+
ssim = ssim + structural_similarity(
|
| 54 |
+
gt[slice_num], pred[slice_num], data_range=maxval
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
return ssim / gt.shape[0]
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
METRIC_FUNCS = dict(
|
| 61 |
+
MSE=mse,
|
| 62 |
+
NMSE=nmse,
|
| 63 |
+
PSNR=psnr,
|
| 64 |
+
SSIM=ssim,
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class Metrics:
|
| 69 |
+
"""
|
| 70 |
+
Maintains running statistics for a given collection of metrics.
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
def __init__(self, metric_funcs):
|
| 74 |
+
"""
|
| 75 |
+
Parameters
|
| 76 |
+
----------
|
| 77 |
+
metric_funcs : dict
|
| 78 |
+
A dictionary where the keys are metric names (as strings) and the values
|
| 79 |
+
are Python functions for evaluating the corresponding metrics.
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
self.metric_funcs = metric_funcs
|
| 82 |
+
self.metrics = {metric: Statistics() for metric in metric_funcs}
|
| 83 |
+
|
| 84 |
+
def push(self, target, recons):
|
| 85 |
+
for metric, func in self.metric_funcs.items():
|
| 86 |
+
self.metrics[metric].push(func(target, recons))
|
| 87 |
+
|
| 88 |
+
def means(self):
|
| 89 |
+
return {metric: stat.mean() for metric, stat in self.metrics.items()}
|
| 90 |
+
|
| 91 |
+
def stddevs(self):
|
| 92 |
+
return {metric: stat.stddev() for metric, stat in self.metrics.items()}
|
| 93 |
+
|
| 94 |
+
def __repr__(self):
|
| 95 |
+
means = self.means()
|
| 96 |
+
stddevs = self.stddevs()
|
| 97 |
+
metric_names = sorted(list(means))
|
| 98 |
+
return " ".join(
|
| 99 |
+
f"{name} = {means[name]:.4g} +/- {2 * stddevs[name]:.4g}"
|
| 100 |
+
for name in metric_names
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def evaluate(args, recons_key):
|
| 105 |
+
metrics = Metrics(METRIC_FUNCS)
|
| 106 |
+
|
| 107 |
+
for tgt_file in args.target_path.iterdir():
|
| 108 |
+
with h5py.File(tgt_file, "r") as target, h5py.File(
|
| 109 |
+
args.predictions_path / tgt_file.name, "r"
|
| 110 |
+
) as recons:
|
| 111 |
+
if args.acquisition and args.acquisition != target.attrs["acquisition"]:
|
| 112 |
+
continue
|
| 113 |
+
|
| 114 |
+
if args.acceleration and target.attrs["acceleration"] != args.acceleration:
|
| 115 |
+
continue
|
| 116 |
+
|
| 117 |
+
target = target[recons_key][()]
|
| 118 |
+
recons = recons["reconstruction"][()]
|
| 119 |
+
target = transforms.center_crop(
|
| 120 |
+
target, (target.shape[-1], target.shape[-1])
|
| 121 |
+
)
|
| 122 |
+
recons = transforms.center_crop(
|
| 123 |
+
recons, (target.shape[-1], target.shape[-1])
|
| 124 |
+
)
|
| 125 |
+
metrics.push(target, recons)
|
| 126 |
+
|
| 127 |
+
return metrics
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
if __name__ == "__main__":
|
| 131 |
+
parser = ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
| 132 |
+
parser.add_argument(
|
| 133 |
+
"--target-path",
|
| 134 |
+
type=pathlib.Path,
|
| 135 |
+
required=True,
|
| 136 |
+
help="Path to the ground truth data",
|
| 137 |
+
)
|
| 138 |
+
parser.add_argument(
|
| 139 |
+
"--predictions-path",
|
| 140 |
+
type=pathlib.Path,
|
| 141 |
+
required=True,
|
| 142 |
+
help="Path to reconstructions",
|
| 143 |
+
)
|
| 144 |
+
parser.add_argument(
|
| 145 |
+
"--challenge",
|
| 146 |
+
choices=["singlecoil", "multicoil"],
|
| 147 |
+
required=True,
|
| 148 |
+
help="Which challenge",
|
| 149 |
+
)
|
| 150 |
+
parser.add_argument("--acceleration", type=int, default=None)
|
| 151 |
+
parser.add_argument(
|
| 152 |
+
"--acquisition",
|
| 153 |
+
choices=[
|
| 154 |
+
"CORPD_FBK",
|
| 155 |
+
"CORPDFS_FBK",
|
| 156 |
+
"AXT1",
|
| 157 |
+
"AXT1PRE",
|
| 158 |
+
"AXT1POST",
|
| 159 |
+
"AXT2",
|
| 160 |
+
"AXFLAIR",
|
| 161 |
+
],
|
| 162 |
+
default=None,
|
| 163 |
+
help=(
|
| 164 |
+
"If set, only volumes of the specified acquisition type are used "
|
| 165 |
+
"for evaluation. By default, all volumes are included."
|
| 166 |
+
),
|
| 167 |
+
)
|
| 168 |
+
args = parser.parse_args()
|
| 169 |
+
|
| 170 |
+
recons_key = (
|
| 171 |
+
"reconstruction_rss" if args.challenge == "multicoil" else "reconstruction_esc"
|
| 172 |
+
)
|
| 173 |
+
metrics = evaluate(args, recons_key)
|
| 174 |
+
print(metrics)
|
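
A quick illustration of the Metrics API above on synthetic volumes (shape
(slices, H, W), as the ssim check requires):

    import numpy as np
    from fastmri.evaluate import METRIC_FUNCS, Metrics

    rng = np.random.default_rng(0)
    gt = rng.random((10, 320, 320)).astype(np.float32)
    pred = gt + 0.01 * rng.standard_normal(gt.shape).astype(np.float32)

    metrics = Metrics(METRIC_FUNCS)
    metrics.push(gt, pred)
    print(metrics)  # MSE = ... NMSE = ... PSNR = ... SSIM = ...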
fastmri/fftc.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
|
| 4 |
+
This source code is licensed under the MIT license found in the
|
| 5 |
+
LICENSE file in the root directory of this source tree.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import List, Optional
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.fft
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def fft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
|
| 15 |
+
"""
|
| 16 |
+
Apply a centered 2-dimensional Fast Fourier Transform (FFT).
|
| 17 |
+
|
| 18 |
+
Parameters
|
| 19 |
+
----------
|
| 20 |
+
data : torch.Tensor
|
| 21 |
+
Complex-valued input data containing at least 3 dimensions.
|
| 22 |
+
Dimensions -3 and -2 are spatial dimensions, and dimension -1 has size 2.
|
| 23 |
+
All other dimensions are assumed to be batch dimensions.
|
| 24 |
+
norm : str
|
| 25 |
+
Normalization mode. Refer to `torch.fft.fft` for details on normalization options.
|
| 26 |
+
|
| 27 |
+
Returns
|
| 28 |
+
-------
|
| 29 |
+
torch.Tensor
|
| 30 |
+
The FFT of the input data.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
if not data.shape[-1] == 2:
|
| 34 |
+
raise ValueError("Tensor does not have separate complex dim.")
|
| 35 |
+
|
| 36 |
+
data = ifftshift(data, dim=[-3, -2])
|
| 37 |
+
data = torch.view_as_real(
|
| 38 |
+
torch.fft.fftn( # type: ignore
|
| 39 |
+
torch.view_as_complex(data), dim=(-2, -1), norm=norm
|
| 40 |
+
)
|
| 41 |
+
)
|
| 42 |
+
data = fftshift(data, dim=[-3, -2])
|
| 43 |
+
|
| 44 |
+
return data
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def ifft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
|
| 48 |
+
"""
|
| 49 |
+
Apply a centered 2-dimensional Inverse Fast Fourier Transform (IFFT).
|
| 50 |
+
|
| 51 |
+
Parameters
|
| 52 |
+
----------
|
| 53 |
+
data : torch.Tensor
|
| 54 |
+
Complex-valued input data containing at least 3 dimensions.
|
| 55 |
+
Dimensions -3 and -2 are spatial dimensions, and dimension -1 has size 2.
|
| 56 |
+
All other dimensions are assumed to be batch dimensions.
|
| 57 |
+
norm : str
|
| 58 |
+
Normalization mode. Refer to `torch.fft.ifft` for details on normalization options.
|
| 59 |
+
|
| 60 |
+
Returns
|
| 61 |
+
-------
|
| 62 |
+
torch.Tensor
|
| 63 |
+
The IFFT of the input data.
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
if not data.shape[-1] == 2:
|
| 67 |
+
raise ValueError("Tensor does not have separate complex dim.")
|
| 68 |
+
|
| 69 |
+
data = ifftshift(data, dim=[-3, -2])
|
| 70 |
+
data = torch.view_as_real(
|
| 71 |
+
torch.fft.ifftn( # type: ignore
|
| 72 |
+
torch.view_as_complex(data), dim=(-2, -1), norm=norm
|
| 73 |
+
)
|
| 74 |
+
)
|
| 75 |
+
data = fftshift(data, dim=[-3, -2])
|
| 76 |
+
|
| 77 |
+
return data
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# Helper functions
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def roll_one_dim(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
|
| 84 |
+
"""
|
| 85 |
+
Roll a PyTorch tensor along a specified dimension.
|
| 86 |
+
|
| 87 |
+
This function is similar to `torch.roll` but operates on a single dimension.
|
| 88 |
+
|
| 89 |
+
Parameters
|
| 90 |
+
----------
|
| 91 |
+
x : torch.Tensor
|
| 92 |
+
The input tensor to be rolled.
|
| 93 |
+
shift : int
|
| 94 |
+
Amount to roll.
|
| 95 |
+
dim : int
|
| 96 |
+
The dimension along which to roll the tensor.
|
| 97 |
+
|
| 98 |
+
Returns
|
| 99 |
+
-------
|
| 100 |
+
torch.Tensor
|
| 101 |
+
A tensor with the same shape as `x`, but rolled along the specified dimension.
|
| 102 |
+
"""
|
| 103 |
+
|
| 104 |
+
shift = shift % x.size(dim)
|
| 105 |
+
if shift == 0:
|
| 106 |
+
return x
|
| 107 |
+
|
| 108 |
+
left = x.narrow(dim, 0, x.size(dim) - shift)
|
| 109 |
+
right = x.narrow(dim, x.size(dim) - shift, shift)
|
| 110 |
+
|
| 111 |
+
return torch.cat((right, left), dim=dim)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def roll(
|
| 115 |
+
x: torch.Tensor,
|
| 116 |
+
shift: List[int],
|
| 117 |
+
dim: List[int],
|
| 118 |
+
) -> torch.Tensor:
|
| 119 |
+
"""
|
| 120 |
+
Similar to np.roll but applies to PyTorch Tensors.
|
| 121 |
+
|
| 122 |
+
Parameters
|
| 123 |
+
----------
|
| 124 |
+
x : torch.Tensor
|
| 125 |
+
A PyTorch tensor.
|
| 126 |
+
shift : list of int
|
| 127 |
+
Amount to roll.
|
| 128 |
+
dim : list of int
|
| 129 |
+
Which dimension to roll.
|
| 130 |
+
|
| 131 |
+
Returns
|
| 132 |
+
-------
|
| 133 |
+
torch.Tensor
|
| 134 |
+
Rolled version of x.
|
| 135 |
+
"""
|
| 136 |
+
|
| 137 |
+
if len(shift) != len(dim):
|
| 138 |
+
raise ValueError("len(shift) must match len(dim)")
|
| 139 |
+
|
| 140 |
+
for s, d in zip(shift, dim):
|
| 141 |
+
x = roll_one_dim(x, s, d)
|
| 142 |
+
|
| 143 |
+
return x
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def fftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
|
| 147 |
+
"""
|
| 148 |
+
Similar to np.fft.fftshift but applies to PyTorch Tensors.
|
| 149 |
+
|
| 150 |
+
Parameters
|
| 151 |
+
----------
|
| 152 |
+
x : torch.Tensor
|
| 153 |
+
A PyTorch tensor.
|
| 154 |
+
dim : list of int, optional
|
| 155 |
+
Which dimension to apply fftshift. If None, the shift is applied to all dimensions (default is None).
|
| 156 |
+
|
| 157 |
+
Returns
|
| 158 |
+
-------
|
| 159 |
+
torch.Tensor
|
| 160 |
+
fftshifted version of x.
|
| 161 |
+
"""
|
| 162 |
+
if dim is None:
|
| 163 |
+
# this weird code is necessary for torch.jit.script typing
|
| 164 |
+
dim = [0] * (x.dim())
|
| 165 |
+
for i in range(1, x.dim()):
|
| 166 |
+
dim[i] = i
|
| 167 |
+
|
| 168 |
+
# also necessary for torch.jit.script
|
| 169 |
+
shift = [0] * len(dim)
|
| 170 |
+
for i, dim_num in enumerate(dim):
|
| 171 |
+
shift[i] = x.shape[dim_num] // 2
|
| 172 |
+
|
| 173 |
+
return roll(x, shift, dim)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def ifftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
|
| 177 |
+
"""
|
| 178 |
+
Similar to np.fft.ifftshift but applies to PyTorch Tensors.
|
| 179 |
+
|
| 180 |
+
Parameters
|
| 181 |
+
----------
|
| 182 |
+
x : torch.Tensor
|
| 183 |
+
A PyTorch tensor.
|
| 184 |
+
dim : list of int, optional
|
| 185 |
+
Which dimension to apply ifftshift. If None, the shift is applied to all dimensions (default is None).
|
| 186 |
+
|
| 187 |
+
Returns
|
| 188 |
+
-------
|
| 189 |
+
torch.Tensor
|
| 190 |
+
ifftshifted version of x.
|
| 191 |
+
"""
|
| 192 |
+
if dim is None:
|
| 193 |
+
# this weird code is necessary for torch.jit.script typing
|
| 194 |
+
dim = [0] * (x.dim())
|
| 195 |
+
for i in range(1, x.dim()):
|
| 196 |
+
dim[i] = i
|
| 197 |
+
|
| 198 |
+
# also necessary for torch.jit.script
|
| 199 |
+
shift = [0] * len(dim)
|
| 200 |
+
for i, dim_num in enumerate(dim):
|
| 201 |
+
shift[i] = (x.shape[dim_num] + 1) // 2
|
| 202 |
+
|
| 203 |
+
return roll(x, shift, dim)
|
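
A small round-trip check for the centered transforms above (complex data
stored as a real tensor with a trailing dimension of size 2):

    import torch
    from fastmri.fftc import fft2c_new, ifft2c_new

    x = torch.randn(4, 320, 320, 2)  # batch of complex images, real view
    k = fft2c_new(x)                 # centered, ortho-normalized 2D FFT
    assert torch.allclose(ifft2c_new(k), x, atol=1e-5)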
fastmri/losses.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
|
| 4 |
+
This source code is licensed under the MIT license found in the
|
| 5 |
+
LICENSE file in the root directory of this source tree.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class SSIMLoss(nn.Module):
|
| 14 |
+
"""
|
| 15 |
+
SSIM loss module.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def __init__(self, win_size: int = 7, k1: float = 0.01, k2: float = 0.03):
|
| 19 |
+
"""
|
| 20 |
+
Initialize the SSIM loss module.
|
| 21 |
+
|
| 22 |
+
Parameters
|
| 23 |
+
----------
|
| 24 |
+
win_size : int, optional
|
| 25 |
+
Window size for SSIM calculation.
|
| 26 |
+
k1 : float, optional
|
| 27 |
+
k1 parameter for SSIM calculation.
|
| 28 |
+
k2 : float, optional
|
| 29 |
+
k2 parameter for SSIM calculation.
|
| 30 |
+
"""
|
| 31 |
+
super().__init__()
|
| 32 |
+
self.win_size = win_size
|
| 33 |
+
self.k1, self.k2 = k1, k2
|
| 34 |
+
self.register_buffer("w", torch.ones(1, 1, win_size, win_size) / win_size**2)
|
| 35 |
+
NP = win_size**2
|
| 36 |
+
self.cov_norm = NP / (NP - 1)
|
| 37 |
+
|
| 38 |
+
def forward(
|
| 39 |
+
self,
|
| 40 |
+
X: torch.Tensor,
|
| 41 |
+
Y: torch.Tensor,
|
| 42 |
+
data_range: torch.Tensor,
|
| 43 |
+
reduced: bool = True,
|
| 44 |
+
):
|
| 45 |
+
assert isinstance(self.w, torch.Tensor)
|
| 46 |
+
|
| 47 |
+
data_range = data_range[:, None, None, None].to(X.device)
|
| 48 |
+
C1 = (self.k1 * data_range) ** 2
|
| 49 |
+
C2 = (self.k2 * data_range) ** 2
|
| 50 |
+
|
| 51 |
+
# Compute means
|
| 52 |
+
ux = F.conv2d(X, self.w)
|
| 53 |
+
uy = F.conv2d(Y, self.w)
|
| 54 |
+
|
| 55 |
+
# Compute variances
|
| 56 |
+
uxx = F.conv2d(X * X, self.w)
|
| 57 |
+
uyy = F.conv2d(Y * Y, self.w)
|
| 58 |
+
uxy = F.conv2d(X * Y, self.w)
|
| 59 |
+
|
| 60 |
+
# Compute covariances
|
| 61 |
+
vx = self.cov_norm * (uxx - ux * ux)
|
| 62 |
+
vy = self.cov_norm * (uyy - uy * uy)
|
| 63 |
+
vxy = self.cov_norm * (uxy - ux * uy)
|
| 64 |
+
|
| 65 |
+
# Compute SSIM components
|
| 66 |
+
A1, A2 = 2 * ux * uy + C1, 2 * vxy + C2
|
| 67 |
+
B1, B2 = ux**2 + uy**2 + C1, vx + vy + C2
|
| 68 |
+
D = B1 * B2
|
| 69 |
+
S = (A1 * A2) / D
|
| 70 |
+
|
| 71 |
+
if reduced:
|
| 72 |
+
return 1 - S.mean()
|
| 73 |
+
else:
|
| 74 |
+
return 1 - S
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
if __name__ == "__main__":
|
| 78 |
+
# Example usage
|
| 79 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 80 |
+
|
| 81 |
+
# Create the SSIMLoss module and move it to the GPU
|
| 82 |
+
ssim_loss = SSIMLoss().to(device)
|
| 83 |
+
|
| 84 |
+
# Create example tensors and move them to the GPU
|
| 85 |
+
X = torch.randn(4, 1, 256, 256).to(device)
|
| 86 |
+
Y = torch.randn(4, 1, 256, 256).to(device)
|
| 87 |
+
data_range = torch.rand(4).to(device)
|
| 88 |
+
|
| 89 |
+
# Compute the loss
|
| 90 |
+
loss = ssim_loss(X, Y, data_range)
|
| 91 |
+
print(loss)
|
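
One more sanity check beyond the example above: SSIM of an image with itself
is 1, so the loss should vanish up to floating point:

    import torch
    from fastmri.losses import SSIMLoss

    x = torch.rand(2, 1, 64, 64)
    data_range = torch.ones(2)
    loss = SSIMLoss()(x, x, data_range)
    assert loss.abs() < 1e-6  # identical inputs -> SSIM = 1 -> loss = 0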
fastmri/math_utils.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
|
| 4 |
+
This source code is licensed under the MIT license found in the
|
| 5 |
+
LICENSE file in the root directory of this source tree.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def complex_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
| 13 |
+
"""
|
| 14 |
+
Complex multiplication.
|
| 15 |
+
|
| 16 |
+
Multiplies two complex tensors assuming that they are both stored as
|
| 17 |
+
real arrays with the last dimension being the complex dimension.
|
| 18 |
+
|
| 19 |
+
Parameters
|
| 20 |
+
----------
|
| 21 |
+
x : torch.Tensor
|
| 22 |
+
A PyTorch tensor with the last dimension of size 2.
|
| 23 |
+
y : torch.Tensor
|
| 24 |
+
A PyTorch tensor with the last dimension of size 2.
|
| 25 |
+
|
| 26 |
+
Returns
|
| 27 |
+
-------
|
| 28 |
+
torch.Tensor
|
| 29 |
+
A PyTorch tensor with the last dimension of size 2, representing
|
| 30 |
+
the result of the complex multiplication.
|
| 31 |
+
"""
|
| 32 |
+
if not x.shape[-1] == y.shape[-1] == 2:
|
| 33 |
+
raise ValueError("Tensors do not have separate complex dim.")
|
| 34 |
+
|
| 35 |
+
re = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1]
|
| 36 |
+
im = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0]
|
| 37 |
+
|
| 38 |
+
return torch.stack((re, im), dim=-1)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def complex_conj(x: torch.Tensor) -> torch.Tensor:
|
| 42 |
+
"""
|
| 43 |
+
Complex conjugate.
|
| 44 |
+
|
| 45 |
+
Applies the complex conjugate assuming that the input array has the
|
| 46 |
+
last dimension as the complex dimension.
|
| 47 |
+
|
| 48 |
+
Parameters
|
| 49 |
+
----------
|
| 50 |
+
x : torch.Tensor
|
| 51 |
+
A PyTorch tensor with the last dimension of size 2.
|
| 52 |
+
|
| 53 |
+
Returns
|
| 54 |
+
-------
|
| 55 |
+
torch.Tensor
|
| 56 |
+
A PyTorch tensor with the last dimension of size 2, representing
|
| 57 |
+
the complex conjugate of the input tensor.
|
| 58 |
+
"""
|
| 59 |
+
if not x.shape[-1] == 2:
|
| 60 |
+
raise ValueError("Tensor does not have separate complex dim.")
|
| 61 |
+
|
| 62 |
+
return torch.stack((x[..., 0], -x[..., 1]), dim=-1)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def complex_abs(data: torch.Tensor) -> torch.Tensor:
|
| 66 |
+
"""
|
| 67 |
+
Compute the absolute value of a complex-valued input tensor.
|
| 68 |
+
|
| 69 |
+
Parameters
|
| 70 |
+
----------
|
| 71 |
+
data : torch.Tensor
|
| 72 |
+
A complex-valued tensor, where the size of the final dimension
|
| 73 |
+
should be 2.
|
| 74 |
+
|
| 75 |
+
Returns
|
| 76 |
+
-------
|
| 77 |
+
torch.Tensor
|
| 78 |
+
Absolute value of the input tensor.
|
| 79 |
+
"""
|
| 80 |
+
if not data.shape[-1] == 2:
|
| 81 |
+
raise ValueError("Tensor does not have separate complex dim.")
|
| 82 |
+
|
| 83 |
+
return (data**2).sum(dim=-1).sqrt()
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def complex_abs_sq(data: torch.Tensor) -> torch.Tensor:
|
| 87 |
+
"""
|
| 88 |
+
Compute the squared absolute value of a complex tensor.
|
| 89 |
+
|
| 90 |
+
Parameters
|
| 91 |
+
----------
|
| 92 |
+
data : torch.Tensor
|
| 93 |
+
A complex-valued tensor, where the size of the final dimension
|
| 94 |
+
should be 2.
|
| 95 |
+
|
| 96 |
+
Returns
|
| 97 |
+
-------
|
| 98 |
+
torch.Tensor
|
| 99 |
+
Squared absolute value of the input tensor.
|
| 100 |
+
"""
|
| 101 |
+
if not data.shape[-1] == 2:
|
| 102 |
+
raise ValueError("Tensor does not have separate complex dim.")
|
| 103 |
+
|
| 104 |
+
return (data**2).sum(dim=-1)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def tensor_to_complex_np(data: torch.Tensor) -> np.ndarray:
|
| 108 |
+
"""
|
| 109 |
+
Convert a complex PyTorch tensor to a NumPy array.
|
| 110 |
+
|
| 111 |
+
Parameters
|
| 112 |
+
----------
|
| 113 |
+
data : torch.Tensor
|
| 114 |
+
Input data to be converted to a NumPy array.
|
| 115 |
+
|
| 116 |
+
Returns
|
| 117 |
+
-------
|
| 118 |
+
np.ndarray
|
| 119 |
+
A complex NumPy array version of the input tensor.
|
| 120 |
+
"""
|
| 121 |
+
return torch.view_as_complex(data).numpy()
|
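
These real-view helpers agree with PyTorch's native complex arithmetic; a
small consistency check:

    import torch
    from fastmri.math_utils import complex_abs, complex_mul

    a, b = torch.randn(8, 8, 2), torch.randn(8, 8, 2)
    native = torch.view_as_complex(a) * torch.view_as_complex(b)
    assert torch.allclose(complex_mul(a, b), torch.view_as_real(native), atol=1e-6)
    assert torch.allclose(complex_abs(a), torch.view_as_complex(a).abs(), atol=1e-6)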
fastmri/poisson_cache/poisson_16x.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:22199ef1c9b045b6f747e57c4effa9a6667cfce864773ee035df8c3d2a28138f
|
| 3 |
+
size 819328
|
fastmri/poisson_cache/poisson_2x.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0d6908dbd90fda83085cc9e7d3ab35f7ed215ab3f735fd39023a091a9f1632df
|
| 3 |
+
size 819328
|
fastmri/poisson_cache/poisson_32x.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7e316852431f100b1c5c6749b672dfbf2dffc4f23ba4d118eddc64956b8c22f4
|
| 3 |
+
size 819328
|
fastmri/poisson_cache/poisson_4x.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4c9f0c9b2c3be534b7c94b8398c2aa66c8e634d343e9b45b842450137266cbc8
|
| 3 |
+
size 819328
|
fastmri/poisson_cache/poisson_6x.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f7eeed6d470af6ef2b7594da388ea952dad44e74a33080a1bf59faf9e8973ca8
|
| 3 |
+
size 819328
|
fastmri/poisson_cache/poisson_8x.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:28ce774dc182798a3fd4358c70cb86c81f936f12336fee104c93e89765727462
|
| 3 |
+
size 819328
|
fastmri/subsample.py
ADDED
|
@@ -0,0 +1,818 @@
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
|
| 4 |
+
This source code is licensed under the MIT license found in the
|
| 5 |
+
LICENSE file in the root directory of this source tree.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
from typing import Dict, Optional, Sequence, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import torch.distributions as D
|
| 14 |
+
|
| 15 |
+
from sigpy.mri import poisson, radial, spiral
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class MaskFunc:
|
| 19 |
+
"""
|
| 20 |
+
An object for GRAPPA-style sampling masks.
|
| 21 |
+
|
| 22 |
+
This creates a sampling mask that densely samples the center while
|
| 23 |
+
subsampling outer k-space regions based on the undersampling factor.
|
| 24 |
+
|
| 25 |
+
When called, ``MaskFunc`` uses internal functions to create the mask by 1)
|
| 26 |
+
creating a mask for the k-space center, 2) creating a mask outside of the
|
| 27 |
+
k-space center, and 3) combining them into a total mask. The internals are
|
| 28 |
+
handled by ``sample_mask``, which calls ``calculate_center_mask`` for (1)
|
| 29 |
+
and ``calculate_acceleration_mask`` for (2). The combination is executed
|
| 30 |
+
in the ``MaskFunc`` ``__call__`` function.
|
| 31 |
+
|
| 32 |
+
If you would like to implement a new mask, simply subclass ``MaskFunc``
|
| 33 |
+
and overwrite the ``sample_mask`` logic. See examples in ``RandomMaskFunc``
|
| 34 |
+
and ``EquiSpacedMaskFunc``.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
center_fractions: Sequence[float],
|
| 40 |
+
accelerations: Sequence[int],
|
| 41 |
+
allow_any_combination: bool = False,
|
| 42 |
+
seed: Optional[int] = None,
|
| 43 |
+
):
|
| 44 |
+
"""
|
| 45 |
+
Args:
|
| 46 |
+
center_fractions: Fraction of low-frequency columns to be retained.
|
| 47 |
+
If multiple values are provided, then one of these numbers is
|
| 48 |
+
chosen uniformly each time.
|
| 49 |
+
accelerations: Amount of under-sampling. This should have the same
|
| 50 |
+
length as center_fractions. If multiple values are provided,
|
| 51 |
+
then one of these is chosen uniformly each time.
|
| 52 |
+
allow_any_combination: Whether to allow cross combinations of
|
| 53 |
+
elements from ``center_fractions`` and ``accelerations``.
|
| 54 |
+
seed: Seed for starting the internal random number generator of the
|
| 55 |
+
``MaskFunc``.
|
| 56 |
+
"""
|
| 57 |
+
if (
|
| 58 |
+
len(center_fractions) != len(accelerations)
|
| 59 |
+
and not allow_any_combination
|
| 60 |
+
):
|
| 61 |
+
raise ValueError(
|
| 62 |
+
"Number of center fractions should match number of"
|
| 63 |
+
" accelerations if allow_any_combination is False."
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
self.center_fractions = center_fractions
|
| 67 |
+
self.accelerations = accelerations
|
| 68 |
+
self.allow_any_combination = allow_any_combination
|
| 69 |
+
self.rng = np.random.RandomState(seed)
|
| 70 |
+
|
| 71 |
+
def __call__(
|
| 72 |
+
self,
|
| 73 |
+
shape: Sequence[int],
|
| 74 |
+
offset: Optional[int] = None,
|
| 75 |
+
seed: Optional[Union[int, Tuple[int, ...]]] = None,
|
| 76 |
+
) -> Tuple[torch.Tensor, int]:
|
| 77 |
+
"""
|
| 78 |
+
Sample and return a k-space mask.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
shape: Shape of k-space.
|
| 82 |
+
offset: Offset from 0 to begin mask (for equispaced masks). If no
|
| 83 |
+
offset is given, then one is selected randomly.
|
| 84 |
+
seed: Seed for random number generator for reproducibility.
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
A 2-tuple containing 1) the k-space mask and 2) the number of
|
| 88 |
+
center frequency lines.
|
| 89 |
+
"""
|
| 90 |
+
if len(shape) < 3:
|
| 91 |
+
raise ValueError("Shape should have 3 or more dimensions")
|
| 92 |
+
|
| 93 |
+
center_mask, accel_mask, num_low_frequencies = self.sample_mask(
|
| 94 |
+
shape, offset
|
| 95 |
+
)
|
| 96 |
+
# combine masks together
|
| 97 |
+
return torch.max(center_mask, accel_mask), num_low_frequencies
|
| 98 |
+
|
| 99 |
+
def sample_mask(
|
| 100 |
+
self,
|
| 101 |
+
shape: Sequence[int],
|
| 102 |
+
offset: Optional[int],
|
| 103 |
+
) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
| 104 |
+
"""
|
| 105 |
+
Sample a new k-space mask.
|
| 106 |
+
|
| 107 |
+
This function samples and returns two components of a k-space mask: 1)
|
| 108 |
+
the center mask (e.g., for sensitivity map calculation) and 2) the
|
| 109 |
+
acceleration mask (for the edge of k-space). Both of these masks, as
|
| 110 |
+
well as the integer of low frequency samples, are returned.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
shape: Shape of the k-space to subsample.
|
| 114 |
+
offset: Offset from 0 to begin mask (for equispaced masks).
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
A 3-tuple containing 1) the mask for the center of k-space, 2) the
|
| 118 |
+
mask for the high frequencies of k-space, and 3) the integer count
|
| 119 |
+
of low frequency samples.
|
| 120 |
+
"""
|
| 121 |
+
num_cols = shape[-2]
|
| 122 |
+
center_fraction, acceleration = self.choose_acceleration()
|
| 123 |
+
num_low_frequencies = round(num_cols * center_fraction)
|
| 124 |
+
center_mask = self.reshape_mask(
|
| 125 |
+
self.calculate_center_mask(shape, num_low_frequencies), shape
|
| 126 |
+
)
|
| 127 |
+
acceleration_mask = self.reshape_mask(
|
| 128 |
+
self.calculate_acceleration_mask(
|
| 129 |
+
num_cols, acceleration, offset, num_low_frequencies
|
| 130 |
+
),
|
| 131 |
+
shape,
|
| 132 |
+
)
|
| 133 |
+
return center_mask, acceleration_mask, num_low_frequencies
|
| 134 |
+
|
| 135 |
+
def reshape_mask(
|
| 136 |
+
self, mask: torch.Tensor, shape: Sequence[int]
|
| 137 |
+
) -> torch.Tensor:
|
| 138 |
+
"""Reshape mask to desired output shape."""
|
| 139 |
+
if len(mask.shape) == 1:
|
| 140 |
+
mask = torch.tensor(mask)
|
| 141 |
+
mask_num_freqs = len(mask)
|
| 142 |
+
mask = mask.reshape(1, 1, mask_num_freqs, 1)
|
| 143 |
+
|
| 144 |
+
return mask.expand(shape)
|
| 145 |
+
|
| 146 |
+
def reshape_mask_old(
|
| 147 |
+
self, mask: np.ndarray, shape: Sequence[int]
|
| 148 |
+
) -> torch.Tensor:
|
| 149 |
+
"""Reshape mask to desired output shape."""
|
| 150 |
+
num_cols = shape[-2]
|
| 151 |
+
mask_shape = [1 for s in shape]
|
| 152 |
+
mask_shape[-2] = num_cols
|
| 153 |
+
|
| 154 |
+
return torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32))
|
| 155 |
+
|
| 156 |
+
def calculate_acceleration_mask(
|
| 157 |
+
self,
|
| 158 |
+
num_cols: int,
|
| 159 |
+
acceleration: int,
|
| 160 |
+
offset: Optional[int],
|
| 161 |
+
num_low_frequencies: int,
|
| 162 |
+
) -> np.ndarray:
|
| 163 |
+
"""
|
| 164 |
+
Produce mask for non-central acceleration lines.
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
num_cols: Number of columns of k-space (2D subsampling).
|
| 168 |
+
acceleration: Desired acceleration rate.
|
| 169 |
+
offset: Offset from 0 to begin masking (for equispaced masks).
|
| 170 |
+
num_low_frequencies: Integer count of low-frequency lines sampled.
|
| 171 |
+
|
| 172 |
+
Returns:
|
| 173 |
+
A mask for the high spatial frequencies of k-space.
|
| 174 |
+
"""
|
| 175 |
+
raise NotImplementedError
|
| 176 |
+
|
| 177 |
+
def calculate_center_mask(
|
| 178 |
+
self, shape: Sequence[int], num_low_freqs: int
|
| 179 |
+
) -> np.ndarray:
|
| 180 |
+
"""
|
| 181 |
+
Build center mask based on number of low frequencies.
|
| 182 |
+
|
| 183 |
+
Args:
|
| 184 |
+
shape: Shape of k-space to mask.
|
| 185 |
+
num_low_freqs: Number of low-frequency lines to sample.
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
A mask for the low spatial frequencies of k-space.
|
| 189 |
+
"""
|
| 190 |
+
num_cols = shape[-2]
|
| 191 |
+
mask = np.zeros(num_cols, dtype=np.float32)
|
| 192 |
+
pad = (num_cols - num_low_freqs + 1) // 2
|
| 193 |
+
mask[pad : pad + num_low_freqs] = 1
|
| 194 |
+
assert mask.sum() == num_low_freqs
|
| 195 |
+
|
| 196 |
+
return mask
|
| 197 |
+
|
| 198 |
+
def choose_acceleration(self):
|
| 199 |
+
"""Choose acceleration based on class parameters."""
|
| 200 |
+
if self.allow_any_combination:
|
| 201 |
+
return self.rng.choice(self.center_fractions), self.rng.choice(
|
| 202 |
+
self.accelerations
|
| 203 |
+
)
|
| 204 |
+
else:
|
| 205 |
+
choice = self.rng.randint(len(self.center_fractions))
|
| 206 |
+
return self.center_fractions[choice], self.accelerations[choice]
|
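# To illustrate the extension point from the class docstring, a minimal,
# hypothetical subclass (not part of this package) that keeps every other
# column outside the densely-sampled center, ignoring `acceleration`:
#
#     class AlternatingMaskFunc(MaskFunc):
#         def calculate_acceleration_mask(
#             self, num_cols, acceleration, offset, num_low_frequencies
#         ):
#             mask = np.zeros(num_cols, dtype=np.float32)
#             mask[::2] = 1.0  # fixed 2x spacing outside the center
#             return mask
#
#     mask_fn = AlternatingMaskFunc(center_fractions=[0.08], accelerations=[2])
#     mask, num_low = mask_fn(shape=(1, 320, 320, 2))  # mask: (1, 320, 320, 2)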
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class RandomMaskFunc(MaskFunc):
|
| 210 |
+
"""
|
| 211 |
+
Creates a random sub-sampling mask of a given shape.
|
| 212 |
+
|
| 213 |
+
The mask selects a subset of columns from the input k-space data. If the
|
| 214 |
+
k-space data has N columns, the mask picks out:
|
| 215 |
+
1. N_low_freqs = (N * center_fraction) columns in the center
|
| 216 |
+
corresponding to low-frequencies.
|
| 217 |
+
2. The other columns are selected uniformly at random with a
|
| 218 |
+
probability equal to: prob = (N / acceleration - N_low_freqs) /
|
| 219 |
+
(N - N_low_freqs). This ensures that the expected number of columns
|
| 220 |
+
selected is equal to (N / acceleration).
|
| 221 |
+
|
| 222 |
+
It is possible to use multiple center_fractions and accelerations, in which
|
| 223 |
+
case one possible (center_fraction, acceleration) is chosen uniformly at
|
| 224 |
+
random each time the ``RandomMaskFunc`` object is called.
|
| 225 |
+
|
| 226 |
+
For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04],
|
| 227 |
+
then there is a 50% probability that 4-fold acceleration with 8% center
|
| 228 |
+
fraction is selected and a 50% probability that 8-fold acceleration with 4%
|
| 229 |
+
center fraction is selected.
|
| 230 |
+
"""
|
| 231 |
+
|
| 232 |
+
def calculate_acceleration_mask(
|
| 233 |
+
self,
|
| 234 |
+
num_cols: int,
|
| 235 |
+
acceleration: int,
|
| 236 |
+
offset: Optional[int],
|
| 237 |
+
num_low_frequencies: int,
|
| 238 |
+
) -> np.ndarray:
|
| 239 |
+
prob = (num_cols / acceleration - num_low_frequencies) / (
|
| 240 |
+
num_cols - num_low_frequencies
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
return self.rng.uniform(size=num_cols) < prob
|
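# A worked instance of the probability above: with N = 368 columns, 4x
# acceleration, and an 8% center fraction, N_low = round(0.08 * 368) = 29 and
# prob = (368 / 4 - 29) / (368 - 29) = 63 / 339 ~= 0.186, so the outer mask
# adds ~63 columns in expectation and 29 + 63 = 92 = 368 / 4 columns total.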
| 244 |
+
|
| 245 |
+
|
| 246 |
+
class EquiSpacedMaskFunc(MaskFunc):
|
| 247 |
+
"""
|
| 248 |
+
Sample data with equally-spaced k-space lines.
|
| 249 |
+
|
| 250 |
+
The lines are spaced exactly evenly, as is done in standard GRAPPA-style
|
| 251 |
+
acquisitions. This means that with a densely-sampled center,
|
| 252 |
+
``acceleration`` will be greater than the true acceleration rate.
|
| 253 |
+
"""
|
| 254 |
+
|
| 255 |
+
def calculate_acceleration_mask(
|
| 256 |
+
self,
|
| 257 |
+
num_cols: int,
|
| 258 |
+
acceleration: int,
|
| 259 |
+
offset: Optional[int],
|
| 260 |
+
num_low_frequencies: int,
|
| 261 |
+
) -> np.ndarray:
|
| 262 |
+
"""
|
| 263 |
+
Produce mask for non-central acceleration lines.
|
| 264 |
+
|
| 265 |
+
Args:
|
| 266 |
+
num_cols: Number of columns of k-space (2D subsampling).
|
| 267 |
+
acceleration: Desired acceleration rate.
|
| 268 |
+
offset: Offset from 0 to begin masking. If no offset is specified,
|
| 269 |
+
then one is selected randomly.
|
| 270 |
+
num_low_frequencies: Not used.
|
| 271 |
+
|
| 272 |
+
Returns:
|
| 273 |
+
A mask for the high spatial frequencies of k-space.
|
| 274 |
+
"""
|
| 275 |
+
if offset is None:
|
| 276 |
+
offset = self.rng.randint(0, high=round(acceleration))
|
| 277 |
+
|
| 278 |
+
mask = np.zeros(num_cols, dtype=np.float32)
|
| 279 |
+
mask[offset::acceleration] = 1
|
| 280 |
+
|
| 281 |
+
return mask
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class EquispacedMaskFractionFunc(MaskFunc):
|
| 285 |
+
"""
|
| 286 |
+
Equispaced mask with approximate acceleration matching.
|
| 287 |
+
|
| 288 |
+
The mask selects a subset of columns from the input k-space data. If the
|
| 289 |
+
k-space data has N columns, the mask picks out:
|
| 290 |
+
1. N_low_freqs = (N * center_fraction) columns in the center
|
| 291 |
+
corresponding to low-frequencies.
|
| 292 |
+
2. The other columns are selected with equal spacing at a proportion
|
| 293 |
+
that reaches the desired acceleration rate taking into consideration
|
| 294 |
+
the number of low frequencies. This ensures that the expected number
|
| 295 |
+
of columns selected is equal to (N / acceleration)
|
| 296 |
+
|
| 297 |
+
It is possible to use multiple center_fractions and accelerations, in which
|
| 298 |
+
case one possible (center_fraction, acceleration) is chosen uniformly at
|
| 299 |
+
random each time the EquispacedMaskFunc object is called.
|
| 300 |
+
|
| 301 |
+
Note that this function may not give equispaced samples (documented in
|
| 302 |
+
https://github.com/facebookresearch/fastMRI/issues/54), which will require
|
| 303 |
+
modifications to standard GRAPPA approaches. Nonetheless, this aspect of
|
| 304 |
+
the function has been preserved to match the public multicoil data.
|
| 305 |
+
"""
|
| 306 |
+
|
| 307 |
+
def calculate_acceleration_mask(
|
| 308 |
+
self,
|
| 309 |
+
num_cols: int,
|
| 310 |
+
acceleration: int,
|
| 311 |
+
offset: Optional[int],
|
| 312 |
+
num_low_frequencies: int,
|
| 313 |
+
) -> np.ndarray:
|
| 314 |
+
"""
|
| 315 |
+
Produce mask for non-central acceleration lines.
|
| 316 |
+
|
| 317 |
+
Args:
|
| 318 |
+
num_cols: Number of columns of k-space (2D subsampling).
|
| 319 |
+
acceleration: Desired acceleration rate.
|
| 320 |
+
offset: Offset from 0 to begin masking. If no offset is specified,
|
| 321 |
+
then one is selected randomly.
|
| 322 |
+
num_low_frequencies: Number of low frequencies. Used to adjust mask
|
| 323 |
+
to exactly match the target acceleration.
|
| 324 |
+
|
| 325 |
+
Returns:
|
| 326 |
+
A mask for the high spatial frequencies of k-space.
|
| 327 |
+
"""
|
| 328 |
+
# determine acceleration rate by adjusting for the number of low frequencies
|
| 329 |
+
adjusted_accel = (acceleration * (num_low_frequencies - num_cols)) / (
|
| 330 |
+
num_low_frequencies * acceleration - num_cols
|
| 331 |
+
)
|
| 332 |
+
if offset is None:
|
| 333 |
+
offset = self.rng.randint(0, high=round(adjusted_accel))
|
| 334 |
+
|
| 335 |
+
mask = np.zeros(num_cols, dtype=np.float32)
|
| 336 |
+
accel_samples = np.arange(offset, num_cols - 1, adjusted_accel)
|
| 337 |
+
accel_samples = np.around(accel_samples).astype(np.uint)
|
| 338 |
+
mask[accel_samples] = 1.0
|
| 339 |
+
|
| 340 |
+
return mask
|
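# Numeric check of `adjusted_accel` above: with num_cols = 320, acceleration
# = 8, and num_low_frequencies = round(0.04 * 320) = 13, adjusted_accel =
# 8 * (13 - 320) / (13 * 8 - 320) = -2456 / -216 ~= 11.37; the 307 non-center
# columns then contribute roughly 307 / 11.37 ~= 27 lines, and 13 + 27 = 40 =
# 320 / 8, matching the target acceleration.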
| 341 |
+
|
| 342 |
+
|
| 343 |
+
class MagicMaskFunc(MaskFunc):
|
| 344 |
+
"""
|
| 345 |
+
Masking function for exploiting conjugate symmetry via offset-sampling.
|
| 346 |
+
|
| 347 |
+
This function applies the mask described in the following paper:
|
| 348 |
+
|
| 349 |
+
Defazio, A. (2019). Offset Sampling Improves Deep Learning based
|
| 350 |
+
Accelerated MRI Reconstructions by Exploiting Symmetry. arXiv preprint,
|
| 351 |
+
arXiv:1912.01101.
|
| 352 |
+
|
| 353 |
+
It is essentially an equispaced mask with an offset for the opposite side
|
| 354 |
+
of k-space. Since MRI images often exhibit approximate conjugate k-space
|
| 355 |
+
symmetry, this mask is generally more efficient than a standard equispaced
|
| 356 |
+
mask.
|
| 357 |
+
|
| 358 |
+
Similarly to ``EquispacedMaskFunc``, this mask will usually undershoot the
|
| 359 |
+
target acceleration rate.
|
| 360 |
+
"""
|
| 361 |
+
|
| 362 |
+
def calculate_acceleration_mask(
|
| 363 |
+
self,
|
| 364 |
+
num_cols: int,
|
| 365 |
+
acceleration: int,
|
| 366 |
+
offset: Optional[int],
|
| 367 |
+
num_low_frequencies: int,
|
| 368 |
+
) -> np.ndarray:
|
| 369 |
+
"""
|
| 370 |
+
Produce mask for non-central acceleration lines.
|
| 371 |
+
|
| 372 |
+
Args:
|
| 373 |
+
num_cols: Number of columns of k-space (2D subsampling).
|
| 374 |
+
acceleration: Desired acceleration rate.
|
| 375 |
+
offset: Offset from 0 to begin masking. If no offset is specified,
|
| 376 |
+
then one is selected randomly.
|
| 377 |
+
num_low_frequencies: Not used.
|
| 378 |
+
|
| 379 |
+
Returns:
|
| 380 |
+
A mask for the high spatial frequencies of k-space.
|
| 381 |
+
"""
|
| 382 |
+
if offset is None:
|
| 383 |
+
offset = self.rng.randint(0, high=acceleration)
|
| 384 |
+
|
| 385 |
+
if offset % 2 == 0:
|
| 386 |
+
offset_pos = offset + 1
|
| 387 |
+
offset_neg = offset + 2
|
| 388 |
+
else:
|
| 389 |
+
offset_pos = offset - 1 + 3
|
| 390 |
+
offset_neg = offset - 1 + 0
|
| 391 |
+
|
| 392 |
+
poslen = (num_cols + 1) // 2
|
| 393 |
+
neglen = num_cols - (num_cols + 1) // 2
|
| 394 |
+
mask_positive = np.zeros(poslen, dtype=np.float32)
|
| 395 |
+
mask_negative = np.zeros(neglen, dtype=np.float32)
|
| 396 |
+
|
| 397 |
+
mask_positive[offset_pos::acceleration] = 1
|
| 398 |
+
mask_negative[offset_neg::acceleration] = 1
|
| 399 |
+
mask_negative = np.flip(mask_negative)
|
| 400 |
+
|
| 401 |
+
mask = np.concatenate((mask_positive, mask_negative))
|
| 402 |
+
|
| 403 |
+
return np.fft.fftshift(mask) # shift mask and return
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
class MagicMaskFractionFunc(MagicMaskFunc):
|
| 407 |
+
"""
|
| 408 |
+
Masking function for exploiting conjugate symmetry via offset-sampling.
|
| 409 |
+
|
| 410 |
+
This function applies the mask described in the following paper:
|
| 411 |
+
|
| 412 |
+
Defazio, A. (2019). Offset Sampling Improves Deep Learning based
|
| 413 |
+
Accelerated MRI Reconstructions by Exploiting Symmetry. arXiv preprint,
|
| 414 |
+
arXiv:1912.01101.
|
| 415 |
+
|
| 416 |
+
It is essentially an equispaced mask with an offset for the opposite side
|
| 417 |
+
of k-space. Since MRI images often exhibit approximate conjugate k-space
|
| 418 |
+
symmetry, this mask is generally more efficient than a standard equispaced
|
| 419 |
+
mask.
|
| 420 |
+
|
| 421 |
+
Similarly to ``EquispacedMaskFractionFunc``, this method exactly matches
|
| 422 |
+
the target acceleration by adjusting the offsets.
|
| 423 |
+
"""
|
| 424 |
+
|
| 425 |
+
def sample_mask(
|
| 426 |
+
self,
|
| 427 |
+
shape: Sequence[int],
|
| 428 |
+
offset: Optional[int],
|
| 429 |
+
) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
| 430 |
+
"""
|
| 431 |
+
Sample a new k-space mask.
|
| 432 |
+
|
| 433 |
+
This function samples and returns two components of a k-space mask: 1)
|
| 434 |
+
the center mask (e.g., for sensitivity map calculation) and 2) the
|
| 435 |
+
acceleration mask (for the edge of k-space). Both of these masks, as
|
| 436 |
+
well as the integer of low frequency samples, are returned.
|
| 437 |
+
|
| 438 |
+
Args:
|
| 439 |
+
shape: Shape of the k-space to subsample.
|
| 440 |
+
offset: Offset from 0 to begin mask (for equispaced masks).
|
| 441 |
+
|
| 442 |
+
Returns:
|
| 443 |
+
A 3-tuple containing 1) the mask for the center of k-space, 2) the
|
| 444 |
+
mask for the high frequencies of k-space, and 3) the integer count
|
| 445 |
+
of low frequency samples.
|
| 446 |
+
"""
|
| 447 |
+
num_cols = shape[-2]
|
| 448 |
+
fraction_low_freqs, acceleration = self.choose_acceleration()
|
| 449 |
+
|
| 450 |
+
num_low_frequencies = round(num_cols * fraction_low_freqs)
|
| 451 |
+
|
| 452 |
+
# bound the number of low frequencies between 1 and target columns
|
| 453 |
+
target_columns_to_sample = round(num_cols / acceleration)
|
| 454 |
+
num_low_frequencies = max(
|
| 455 |
+
min(num_low_frequencies, target_columns_to_sample), 1
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
# adjust acceleration rate based on target acceleration.
|
| 459 |
+
adjusted_target_columns_to_sample = (
|
| 460 |
+
target_columns_to_sample - num_low_frequencies
|
| 461 |
+
)
|
| 462 |
+
adjusted_acceleration = 0
|
| 463 |
+
if adjusted_target_columns_to_sample > 0:
|
| 464 |
+
adjusted_acceleration = round(
|
| 465 |
+
num_cols / adjusted_target_columns_to_sample
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
center_mask = self.reshape_mask(
|
| 469 |
+
self.calculate_center_mask(shape, num_low_frequencies), shape
|
| 470 |
+
)
|
| 471 |
+
accel_mask = self.reshape_mask(
|
| 472 |
+
self.calculate_acceleration_mask(
|
| 473 |
+
num_cols, adjusted_acceleration, offset, num_low_frequencies
|
| 474 |
+
),
|
| 475 |
+
shape,
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
return center_mask, accel_mask, num_low_frequencies
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
class Gaussian2DMaskFunc(MaskFunc):
|
| 482 |
+
"""Gaussian 2D Masking
|
| 483 |
+
|
| 484 |
+
Samples k-space locations from a 2D Gaussian centered on the middle of
|
| 485 |
+
k-space, so low frequencies are sampled more densely than high ones.
|
| 486 |
+
"""
|
| 487 |
+
|
| 488 |
+
def __init__(
|
| 489 |
+
self,
|
| 490 |
+
accelerations: Sequence[int],
|
| 491 |
+
stds: Sequence[float],
|
| 492 |
+
seed: Optional[int] = None,
|
| 493 |
+
):
|
| 494 |
+
"""initialize Gaussian 2D Mask
|
| 495 |
+
|
| 496 |
+
Args:
|
| 497 |
+
accelerations (Sequence[int]): list of acceleration factors, when
|
| 498 |
+
generating a mask, an acceleration factor from this list will be chosen
|
| 499 |
+
stds (Sequence[float]): list of torch.Normal scale (~std) to choose from
|
| 500 |
+
seed (Optional[int], optional): Seed for selecting mask parameters. Defaults to None.
|
| 501 |
+
"""
|
| 502 |
+
self.rng = np.random.RandomState(seed)
|
| 503 |
+
self.accelerations = accelerations
|
| 504 |
+
self.stds = stds
|
| 505 |
+
|
| 506 |
+
def __call__(
|
| 507 |
+
self,
|
| 508 |
+
shape: Sequence[int],
|
| 509 |
+
offset: Optional[int] = None,
|
| 510 |
+
seed: Optional[Union[int, Tuple[int, ...]]] = None,
|
| 511 |
+
) -> Tuple[torch.Tensor, float]:
|
| 512 |
+
if len(shape) < 3:
|
| 513 |
+
raise ValueError("Shape should have 3 or more dimensions")
|
| 514 |
+
|
| 515 |
+
acceleration = self.rng.choice(self.accelerations)
|
| 516 |
+
std = self.rng.choice(self.stds)
|
| 517 |
+
|
| 518 |
+
x, y = shape[-3], shape[-2]
|
| 519 |
+
mean_x = x // 2
|
| 520 |
+
mean_y = y // 2
|
| 521 |
+
num_samples_collected = 0
|
| 522 |
+
|
| 523 |
+
dist = D.Normal(
|
| 524 |
+
loc=torch.tensor([mean_x, mean_y], dtype=torch.float32),
|
| 525 |
+
scale=std,
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
N = (
|
| 529 |
+
int(1 / acceleration * x * y) + 10000
|
| 530 |
+
) # pad with a constant; duplicate draws collapse, so without it the mask falls short of the target rate
|
| 531 |
+
sample_x, sample_y = (
|
| 532 |
+
torch.zeros(N, dtype=torch.int),
|
| 533 |
+
torch.zeros(N, dtype=torch.int),
|
| 534 |
+
)
|
| 535 |
+
|
| 536 |
+
while num_samples_collected < N:
|
| 537 |
+
samples = dist.sample((N,)) # type: ignore
|
| 538 |
+
valid_samples = (
|
| 539 |
+
(samples[:, 0] >= 0)
|
| 540 |
+
& (samples[:, 0] < x)
|
| 541 |
+
& (samples[:, 1] >= 0)
|
| 542 |
+
& (samples[:, 1] < y)
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
valid_x = samples[valid_samples, 0].int()
|
| 546 |
+
valid_y = samples[valid_samples, 1].int()
|
| 547 |
+
|
| 548 |
+
num_to_take = min(N - num_samples_collected, valid_x.size(0))
|
| 549 |
+
sample_x[
|
| 550 |
+
num_samples_collected : num_samples_collected + num_to_take
|
| 551 |
+
] = valid_x[:num_to_take]
|
| 552 |
+
sample_y[
|
| 553 |
+
num_samples_collected : num_samples_collected + num_to_take
|
| 554 |
+
] = valid_y[:num_to_take]
|
| 555 |
+
num_samples_collected += num_to_take
|
| 556 |
+
|
| 557 |
+
mask = torch.zeros((x, y))
|
| 558 |
+
mask[sample_x, sample_y] = 1.0
|
| 559 |
+
|
| 560 |
+
# broadcasting mask (x, y) --> (N, x, y, C) C=2, N=batch_size
|
| 561 |
+
mask = mask.unsqueeze(-1) # (x, y, 1)
|
| 562 |
+
mask = mask.unsqueeze(0) # (1, x, y, 1)
|
| 563 |
+
mask = mask.expand((1, mask.shape[1], mask.shape[2], 2)).clone()
|
| 564 |
+
|
| 565 |
+
# num_low_freqs doesn't make sense so just return std (a number)
|
| 566 |
+
# returning None doesn't work since we can't stack for multiple batches
|
| 567 |
+
return mask, std
|
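# Usage sketch (hypothetical std value): draw an approximately 4x-accelerated
# Gaussian mask for a 320x320 k-space grid.
#
#     mask_fn = Gaussian2DMaskFunc(accelerations=[4], stds=[40.0], seed=0)
#     mask, std = mask_fn(shape=(15, 320, 320, 2))  # mask: (1, 320, 320, 2)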
| 568 |
+
|
| 569 |
+
|
| 570 |
+
class Poisson2DMaskFunc(MaskFunc):
|
| 571 |
+
"""
|
| 572 |
+
Variable Density Poisson Disk Sampling
|
| 573 |
+
https://sigpy.readthedocs.io/en/latest/generated/sigpy.mri.poisson.html#sigpy.mri.poisson
|
| 574 |
+
"""
|
| 575 |
+
|
| 576 |
+
def __init__(
|
| 577 |
+
self,
|
| 578 |
+
accelerations: Sequence[int],
|
| 579 |
+
stds: None,
|
| 580 |
+
seed: Optional[int] = None,
|
| 581 |
+
use_cache: bool = True,
|
| 582 |
+
):
|
| 583 |
+
"""initialize VDPD (Poisson) mask
|
| 584 |
+
|
| 585 |
+
Args:
|
| 586 |
+
accelerations (Sequence[int]): list of acceleration factors to
|
| 587 |
+
choose from
|
| 588 |
+
stds: Dummy param. Do not pass value. Defaults to None.
|
| 589 |
+
seed (Optional[int], optional): Seed for selecting mask params.
|
| 590 |
+
Defaults to None.
|
| 591 |
+
"""
|
| 592 |
+
self.rng = np.random.RandomState(seed)
|
| 593 |
+
self.accelerations = accelerations
|
| 594 |
+
self.use_cache = use_cache
|
| 595 |
+
if use_cache:
|
| 596 |
+
self.cache: Dict[int, np.ndarray] = dict()
|
| 597 |
+
for acc in accelerations:
|
| 598 |
+
# load the precomputed masks shipped with the repo
|
| 599 |
+
assert os.path.exists(
|
| 600 |
+
f"fastmri/poisson_cache/poisson_{acc}x.npy"
|
| 601 |
+
)
|
| 602 |
+
self.cache[acc] = np.load(
|
| 603 |
+
f"fastmri/poisson_cache/poisson_{acc}x.npy"
|
| 604 |
+
)
|
| 607 |
+
|
| 608 |
+
def __call__(
|
| 609 |
+
self,
|
| 610 |
+
shape: Sequence[int],
|
| 611 |
+
offset: Optional[int] = None,
|
| 612 |
+
seed: Optional[Union[int, Tuple[int, ...]]] = None,
|
| 613 |
+
) -> Tuple[torch.Tensor, float]:
|
| 614 |
+
if self.use_cache:
|
| 615 |
+
acceleration = self.rng.choice(self.accelerations)
|
| 616 |
+
return torch.from_numpy(self.cache[acceleration]), 1.0 # type: ignore
|
| 617 |
+
if len(shape) < 3:
|
| 618 |
+
raise ValueError("Shape should have 3 or more dimensions")
|
| 619 |
+
|
| 620 |
+
acceleration = self.rng.choice(self.accelerations)
|
| 621 |
+
x, y = shape[-3], shape[-2]
|
| 622 |
+
|
| 623 |
+
mask = poisson(img_shape=(x, y), accel=acceleration, dtype=np.float32)
|
| 624 |
+
mask = torch.from_numpy(mask)
|
| 625 |
+
|
| 626 |
+
# broadcasting mask (x, y) --> (N, x, y, C) C=2, N=batch_size
|
| 627 |
+
mask = mask.unsqueeze(-1) # (x, y, 1e
|
| 628 |
+
mask = mask.unsqueeze(0) # (1, x, y, 1)
|
| 629 |
+
mask = mask.expand((1, mask.shape[1], mask.shape[2], 2)).clone()
|
| 630 |
+
|
| 631 |
+
# num low freqs doesn't make sense here, so we return arbitrary value 1.0
|
| 632 |
+
return mask, 100.0
|
| 633 |
+
|
| 634 |
+
|
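# Editor's note (assumption, not part of the original module): the cache files
# loaded above could be produced offline with sigpy's variable-density Poisson
# disc sampler. A minimal sketch, assuming sigpy is installed and that cached
# arrays are stored pre-broadcast to the (1, x, y, 2) shape __call__ returns:
def _precompute_poisson_cache(x: int = 640, y: int = 320) -> None:
    from sigpy.mri import poisson as sigpy_poisson  # assumed dependency

    for acc in (2, 4, 6, 8, 16, 32):  # matches the shipped cache files
        m = sigpy_poisson(img_shape=(x, y), accel=acc, dtype=np.float32)
        m = m[None, :, :, None].repeat(2, axis=-1)  # (1, x, y, 2)
        np.save(f"fastmri/poisson_cache/poisson_{acc}x.npy", m)

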
class Radial2DMaskFunc(MaskFunc):
    """
    Radial trajectory MRI masking method.
    https://sigpy.readthedocs.io/en/latest/generated/sigpy.mri.radial.html#sigpy.mri.radial
    """

    def __init__(
        self,
        accelerations: Sequence[int],
        arms: Optional[Sequence[int]],
        seed: Optional[int] = None,
    ):
        """
        Initialize a radial mask.

        Args:
            accelerations (Sequence[int]): list of acceleration factors to
                choose from.
            arms: Number of radial arms to choose from. If None, the arm
                count is derived from the acceleration factor.
            seed (Optional[int], optional): Seed for selecting mask params.
                Defaults to None.
        """
        self.rng = np.random.RandomState(seed)
        self.accelerations = accelerations
        self.arms = arms

    def __call__(
        self,
        shape: Sequence[int],
        offset: Optional[int] = None,
        seed: Optional[Union[int, Tuple[int, ...]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        if len(shape) < 3:
            raise ValueError("Shape should have 3 or more dimensions")

        acceleration = self.rng.choice(self.accelerations)
        x, y = shape[-3], shape[-2]
        npoints = int(x * y * (1 / acceleration))
        if self.arms:
            arms = self.rng.choice(self.arms)
        else:
            points_per_arm = x // 3
            arms = npoints // points_per_arm

        # calculate radial parameters to satisfy the acceleration factor
        ntr = arms  # number of radial lines
        nro = npoints // arms  # number of points on each radial line
        ndim = 2  # 2D

        # generate trajectory with shape (ntr, nro, ndim)
        traj = radial(
            coord_shape=[ntr, nro, ndim],
            img_shape=(x, y),
            golden=True,
            dtype=int,
        )

        mask = torch.zeros(x, y, dtype=torch.float32)
        x_coords = traj[..., 0].flatten() + (x // 2)
        y_coords = traj[..., 1].flatten() + (y // 2)
        mask[x_coords, y_coords] = 1.0

        # broadcast mask (x, y) --> (N, x, y, C) with C=2, N=batch_size
        mask = mask.unsqueeze(-1)  # (x, y, 1)
        mask = mask.unsqueeze(0)  # (1, x, y, 1)
        mask = mask.expand((1, mask.shape[1], mask.shape[2], 2)).clone()

        # num_low_freqs doesn't make sense here, so we return an arbitrary value
        return mask, 100.0


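# Editor's sketch (not in the original file): exercising the radial mask above.
# Grid size and acceleration are illustrative; the sampled fraction lands at or
# below 1/acceleration because crossing spokes share k-space points.
def _demo_radial_mask() -> None:
    func = Radial2DMaskFunc(accelerations=[4], arms=None, seed=0)
    mask, _ = func(shape=(640, 320, 2))
    print(mask.shape)  # torch.Size([1, 640, 320, 2])
    print(mask[..., 0].mean())  # <= 0.25

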
class Spiral2DMaskFunc(MaskFunc):
    """
    Spiral trajectory MRI masking method.
    https://sigpy.readthedocs.io/en/latest/generated/sigpy.mri.spiral.html#sigpy.mri.spiral
    """

    def __init__(
        self,
        accelerations: Sequence[int],
        arms: Sequence[int],
        seed: Optional[int] = None,
    ):
        """
        Initialize a spiral mask.

        Args:
            accelerations (Sequence[int]): list of acceleration factors to
                choose from.
            arms: Number of spiral arms to choose from.
            seed (Optional[int], optional): Seed for selecting mask params.
                Defaults to None.
        """
        self.rng = np.random.RandomState(seed)
        self.accelerations = accelerations
        self.arms = arms

    def __call__(
        self,
        shape: Sequence[int],
        offset: Optional[int] = None,
        seed: Optional[Union[int, Tuple[int, ...]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        # TODO: implement; the code below is an unfinished draft
        raise NotImplementedError("Spiral2D not implemented")
        if len(shape) < 3:
            raise ValueError("Shape should have 3 or more dimensions")
        acceleration = self.rng.choice(self.accelerations)
        arms = self.rng.choice(self.arms)
        x, y = shape[-3], shape[-2]

        # calculate spiral parameters to satisfy the acceleration factor
        npoints = int(x * y * (1 / acceleration))

        # generate trajectory with shape (ntr, nro, ndim)
        traj = spiral(
            N=npoints,
            img_shape=(x, y),
            golden=True,
            dtype=int,
        )

        mask = torch.zeros(x, y, dtype=torch.float32)
        x_coords = traj[..., 0].flatten() + (x // 2)
        y_coords = traj[..., 1].flatten() + (y // 2)
        mask[x_coords, y_coords] = 1.0

        # broadcast mask (x, y) --> (N, x, y, C) with C=2, N=batch_size
        mask = mask.unsqueeze(-1)  # (x, y, 1)
        mask = mask.unsqueeze(0)  # (1, x, y, 1)
        mask = mask.expand((1, mask.shape[1], mask.shape[2], 2)).clone()

        # num_low_freqs doesn't make sense here, so we return an arbitrary value
        return mask, 100.0


def create_mask_for_mask_type(
    mask_type_str: str,
    center_fractions: Optional[Sequence],
    accelerations: Sequence[int],
) -> MaskFunc:
    """
    Creates a mask function of the specified type.

    Args:
        mask_type_str: Name of the mask type (e.g. "random", "equispaced",
            "gaussian_2d", "poisson_2d", "radial_2d").
        center_fractions: What fraction of the center of k-space to include.
            Repurposed as stds for "gaussian_2d" and as arm counts for
            "radial_2d".
        accelerations: What accelerations to apply.

    Returns:
        A mask func for the target mask type.
    """
    if mask_type_str == "random":
        return RandomMaskFunc(center_fractions, accelerations)
    elif mask_type_str == "equispaced":
        return EquiSpacedMaskFunc(center_fractions, accelerations)
    elif mask_type_str == "equispaced_fraction":
        return EquispacedMaskFractionFunc(center_fractions, accelerations)
    elif mask_type_str == "magic":
        return MagicMaskFunc(center_fractions, accelerations)
    elif mask_type_str == "magic_fraction":
        return MagicMaskFractionFunc(center_fractions, accelerations)
    elif mask_type_str == "gaussian_2d":
        return Gaussian2DMaskFunc(
            stds=center_fractions,
            accelerations=accelerations,
        )
    elif mask_type_str == "poisson_2d":
        return Poisson2DMaskFunc(
            accelerations=accelerations,
            stds=None,
        )
    elif mask_type_str == "radial_2d":
        return Radial2DMaskFunc(
            accelerations=accelerations,
            arms=(
                [int(arm) for arm in center_fractions]
                if center_fractions
                else None
            ),
        )
    elif mask_type_str == "spiral_2d":
        raise NotImplementedError("spiral_2d not implemented")
    else:
        raise ValueError(f"{mask_type_str} not supported")
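# Editor's sketch (not in the original file): typical use of the factory
# above; the fraction and acceleration values are illustrative.
def _demo_mask_factory() -> None:
    func = create_mask_for_mask_type(
        "equispaced", center_fractions=[0.08], accelerations=[4]
    )
    mask, num_low_frequencies = func(shape=(1, 640, 372, 2), seed=42)
    print(mask.shape, num_low_frequencies)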
fastmri/transforms.py
ADDED
@@ -0,0 +1,974 @@
"""
Copyright (c) Facebook, Inc. and its affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import random
from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import torch

import fastmri

from .subsample import MaskFunc


def to_tensor(data: np.ndarray) -> torch.Tensor:
    """
    Convert numpy array to PyTorch tensor.

    For complex arrays, the real and imaginary parts are stacked along the
    last dimension.

    Args:
        data: Input numpy array.

    Returns:
        PyTorch version of data.
    """
    if np.iscomplexobj(data):
        data = np.stack((data.real, data.imag), axis=-1)

    return torch.from_numpy(data)


def tensor_to_complex_np(data: torch.Tensor) -> np.ndarray:
    """
    Converts a complex torch tensor to a numpy array.

    Args:
        data: Input data to be converted to numpy.

    Returns:
        Complex numpy version of data.
    """
    return torch.view_as_complex(data).numpy()


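# Editor's sketch (not in the original file): round-trip between complex numpy
# data and the stacked real/imaginary tensor layout used throughout this module.
def _demo_complex_round_trip() -> None:
    rng = np.random.default_rng(0)
    kspace = rng.normal(size=(4, 4)) + 1j * rng.normal(size=(4, 4))
    tensor = to_tensor(kspace)  # shape (4, 4, 2)
    back = tensor_to_complex_np(tensor.contiguous())
    assert np.allclose(kspace, back)

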
def apply_mask(
    data: torch.Tensor,
    mask_func: MaskFunc,
    offset: Optional[int] = None,
    seed: Optional[Union[int, Tuple[int, ...]]] = None,
    padding: Optional[Sequence[int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """
    Subsample given k-space by multiplying with a mask.

    Args:
        data: The input k-space data. This should have at least 3 dimensions,
            where dimensions -3 and -2 are the spatial dimensions, and the
            final dimension has size 2 (for complex values).
        mask_func: A function that takes a shape (tuple of ints) and a random
            number seed and returns a mask.
        offset: Offset from which to begin the mask (if supported by
            mask_func).
        seed: Seed for the random number generator.
        padding: Padding value to apply for mask.

    Returns:
        tuple containing:
            masked data: Subsampled k-space data.
            mask: The generated mask.
            num_low_frequencies: The number of low-resolution frequency
                samples in the mask.
    """
    shape = (1,) * len(data.shape[:-3]) + tuple(data.shape[-3:])
    mask, num_low_frequencies = mask_func(shape, offset, seed)
    if padding is not None:
        mask[..., : padding[0], :] = 0
        # padding value inclusive on right of zeros
        mask[..., padding[1] :, :] = 0

    masked_data = data * mask + 0.0  # the + 0.0 removes the sign of the zeros

    return masked_data, mask, num_low_frequencies


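# Editor's sketch (not in the original file): masking dummy multi-coil
# k-space. RandomMaskFunc is assumed to live in fastmri.subsample alongside
# MaskFunc; shapes are illustrative.
def _demo_apply_mask() -> None:
    from fastmri.subsample import RandomMaskFunc  # assumed location

    kspace = torch.randn(15, 640, 372, 2)  # (coils, rows, cols, complex)
    mask_func = RandomMaskFunc(center_fractions=[0.08], accelerations=[4])
    masked, mask, num_low = apply_mask(kspace, mask_func, seed=0)
    print(masked.shape, mask.shape, num_low)

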
def mask_center(x: torch.Tensor, mask_from: int, mask_to: int) -> torch.Tensor:
    """
    Initializes a mask with the center filled in.

    Args:
        x: The input tensor.
        mask_from: Part of center to start filling.
        mask_to: Part of center to end filling.

    Returns:
        A mask with the center filled.
    """
    mask = torch.zeros_like(x)
    mask[:, :, :, mask_from:mask_to] = x[:, :, :, mask_from:mask_to]

    return mask


def batched_mask_center(
    x: torch.Tensor, mask_from: torch.Tensor, mask_to: torch.Tensor
) -> torch.Tensor:
    """
    Initializes a mask with the center filled in.

    Can operate with different masks for each batch element.

    Args:
        x: The input tensor.
        mask_from: Part of center to start filling.
        mask_to: Part of center to end filling.

    Returns:
        A mask with the center filled.
    """
    if not mask_from.shape == mask_to.shape:
        raise ValueError("mask_from and mask_to must match shapes.")
    if not mask_from.ndim == 1:
        raise ValueError("mask_from and mask_to must have 1 dimension.")
    if not mask_from.shape[0] == 1:
        if (not x.shape[0] == mask_from.shape[0]) or (
            not x.shape[0] == mask_to.shape[0]
        ):
            raise ValueError(
                "mask_from and mask_to must have batch_size length."
            )

    if mask_from.shape[0] == 1:
        mask = mask_center(x, int(mask_from), int(mask_to))
    else:
        mask = torch.zeros_like(x)
        for i, (start, end) in enumerate(zip(mask_from, mask_to)):
            mask[i, :, :, start:end] = x[i, :, :, start:end]

    return mask


def center_crop(data: torch.Tensor, shape: Tuple[int, int]) -> torch.Tensor:
    """
    Apply a center crop to the input real image or batch of real images.

    Args:
        data: The input tensor to be center cropped. It should
            have at least 2 dimensions and the cropping is applied along the
            last two dimensions.
        shape: The output shape. The shape should be smaller
            than the corresponding dimensions of data.

    Returns:
        The center cropped image.
    """
    if not (0 < shape[0] <= data.shape[-2] and 0 < shape[1] <= data.shape[-1]):
        raise ValueError("Invalid shapes.")

    w_from = (data.shape[-2] - shape[0]) // 2
    h_from = (data.shape[-1] - shape[1]) // 2
    w_to = w_from + shape[0]
    h_to = h_from + shape[1]

    return data[..., w_from:w_to, h_from:h_to]


def complex_center_crop(
    data: torch.Tensor, shape: Tuple[int, int]
) -> torch.Tensor:
    """
    Apply a center crop to the input image or batch of complex images.

    Args:
        data: The complex input tensor to be center cropped. It should have at
            least 3 dimensions; the cropping is applied along dimensions -3
            and -2, and the last dimension should have a size of 2.
        shape: The output shape. The shape should be smaller than the
            corresponding dimensions of data.

    Returns:
        The center cropped image.
    """
    if not (0 < shape[0] <= data.shape[-3] and 0 < shape[1] <= data.shape[-2]):
        raise ValueError("Invalid shapes.")

    w_from = (data.shape[-3] - shape[0]) // 2
    h_from = (data.shape[-2] - shape[1]) // 2
    w_to = w_from + shape[0]
    h_to = h_from + shape[1]

    return data[..., w_from:w_to, h_from:h_to, :]


def center_crop_to_smallest(
    x: torch.Tensor, y: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply a center crop on the larger image to the size of the smaller.

    The minimum is taken over dim=-1 and dim=-2. If x is smaller than y at
    dim=-1 and y is smaller than x at dim=-2, then the returned dimension will
    be a mixture of the two.

    Args:
        x: The first image.
        y: The second image.

    Returns:
        tuple of tensors x and y, each cropped to the minimum size.
    """
    smallest_width = min(x.shape[-1], y.shape[-1])
    smallest_height = min(x.shape[-2], y.shape[-2])
    x = center_crop(x, (smallest_height, smallest_width))
    y = center_crop(y, (smallest_height, smallest_width))

    return x, y


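# Editor's sketch (not in the original file): shape behavior of the cropping
# helpers above.
def _demo_crops() -> None:
    real = torch.randn(1, 640, 372)
    cplx = torch.randn(1, 640, 372, 2)
    print(center_crop(real, (320, 320)).shape)  # (1, 320, 320)
    print(complex_center_crop(cplx, (320, 320)).shape)  # (1, 320, 320, 2)
    a, b = center_crop_to_smallest(torch.randn(3, 5), torch.randn(4, 4))
    print(a.shape, b.shape)  # both (3, 4): min over each spatial dim

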
def normalize(
    data: torch.Tensor,
    mean: Union[float, torch.Tensor],
    stddev: Union[float, torch.Tensor],
    eps: Union[float, torch.Tensor] = 0.0,
) -> torch.Tensor:
    """
    Normalize the given tensor.

    Applies the formula (data - mean) / (stddev + eps).

    Args:
        data: Input data to be normalized.
        mean: Mean value.
        stddev: Standard deviation.
        eps: Added to stddev to prevent dividing by zero.

    Returns:
        Normalized tensor.
    """
    return (data - mean) / (stddev + eps)


def normalize_instance(
    data: torch.Tensor, eps: Union[float, torch.Tensor] = 0.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Normalize the given tensor with instance norm.

    Applies the formula (data - mean) / (stddev + eps), where mean and stddev
    are computed from the data itself.

    Args:
        data: Input data to be normalized.
        eps: Added to stddev to prevent dividing by zero.

    Returns:
        tuple of (normalized tensor, mean, std).
    """
    mean = data.mean()
    std = data.std()

    return normalize(data, mean, std, eps), mean, std


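# Editor's sketch (not in the original file): instance normalization and its
# inverse.
def _demo_normalize() -> None:
    data = torch.arange(16, dtype=torch.float32).reshape(4, 4)
    normed, mean, std = normalize_instance(data, eps=1e-11)
    restored = normed * (std + 1e-11) + mean
    assert torch.allclose(restored, data)

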
class UnetSample(NamedTuple):
    """
    A subsampled image for U-Net reconstruction.

    Args:
        image: Subsampled image after inverse FFT.
        target: The target image (if applicable).
        mean: Per-channel mean values used for normalization.
        std: Per-channel standard deviations used for normalization.
        fname: File name.
        slice_num: The slice index.
        max_value: Maximum image value.
    """

    image: torch.Tensor
    target: torch.Tensor
    mean: torch.Tensor
    std: torch.Tensor
    fname: str
    slice_num: int
    max_value: float


class UnetDataTransform:
    """
    Data transformer for training U-Net models.
    """

    def __init__(
        self,
        which_challenge: str,
        mask_func: Optional[MaskFunc] = None,
        use_seed: bool = True,
    ):
        """
        Args:
            which_challenge: Challenge from ("singlecoil", "multicoil").
            mask_func: Optional; A function that can create a mask of
                appropriate shape.
            use_seed: If True, this class computes a pseudo random number
                generator seed from the filename. This ensures that the same
                mask is used for all the slices of a given volume every time.
        """
        if which_challenge not in ("singlecoil", "multicoil"):
            raise ValueError(
                "Challenge should either be 'singlecoil' or 'multicoil'"
            )

        self.mask_func = mask_func
        self.which_challenge = which_challenge
        self.use_seed = use_seed

    def __call__(
        self,
        kspace: np.ndarray,
        mask: np.ndarray,
        target: np.ndarray,
        attrs: Dict,
        fname: str,
        slice_num: int,
    ) -> Tuple[
        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, int, float
    ]:
        """
        Args:
            kspace: Input k-space of shape (num_coils, rows, cols) for
                multi-coil data or (rows, cols) for single coil data.
            mask: Mask from the test dataset.
            target: Target image.
            attrs: Acquisition related information stored in the HDF5 object.
            fname: File name.
            slice_num: Serial number of the slice.

        Returns:
            A UnetSample containing the zero-filled input image, the
            reconstruction target, the mean and standard deviation used for
            normalization, the filename, the slice number, and the maximum
            image value.
        """
        kspace_torch = to_tensor(kspace)

        # check for max value
        max_value = attrs["max"] if "max" in attrs.keys() else 0.0

        # apply mask
        if self.mask_func:
            seed = None if not self.use_seed else tuple(map(ord, fname))
            # we only need the first element, which is k-space after masking
            masked_kspace = apply_mask(
                kspace_torch, self.mask_func, seed=seed
            )[0]
        else:
            masked_kspace = kspace_torch

        # inverse Fourier transform to get zero-filled solution
        image = fastmri.ifft2c(masked_kspace)

        # crop input to correct size
        if target is not None:
            crop_size = (target.shape[-2], target.shape[-1])
        else:
            crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])

        # check for FLAIR 203
        if image.shape[-2] < crop_size[1]:
            crop_size = (image.shape[-2], image.shape[-2])

        image = complex_center_crop(image, crop_size)

        # absolute value
        image = fastmri.complex_abs(image)

        # apply Root-Sum-of-Squares if multicoil data
        if self.which_challenge == "multicoil":
            image = fastmri.rss(image)

        # normalize input
        image, mean, std = normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)

        # normalize target
        if target is not None:
            target_torch = to_tensor(target)
            target_torch = center_crop(target_torch, crop_size)
            target_torch = normalize(target_torch, mean, std, eps=1e-11)
            target_torch = target_torch.clamp(-6, 6)
        else:
            target_torch = torch.Tensor([0])

        return UnetSample(
            image=image,
            target=target_torch,
            mean=mean,
            std=std,
            fname=fname,
            slice_num=slice_num,
            max_value=max_value,
        )


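# Editor's sketch (not in the original file): wiring the transform into a
# dataset. SliceDataset and the data path are assumptions standing in for
# this repo's dataset class in fastmri/datasets.py.
def _demo_unet_transform() -> None:
    from fastmri.subsample import RandomMaskFunc  # assumed location
    from fastmri.datasets import SliceDataset  # assumed name

    transform = UnetDataTransform(
        which_challenge="singlecoil",
        mask_func=RandomMaskFunc(center_fractions=[0.08], accelerations=[4]),
    )
    dataset = SliceDataset(
        root="data/singlecoil_val", challenge="singlecoil", transform=transform
    )
    sample = dataset[0]
    print(sample.image.shape, sample.target.shape)

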
class VarNetSample(NamedTuple):
    """
    A sample of masked k-space for variational network reconstruction.

    Args:
        masked_kspace: k-space after applying sampling mask.
        mask: The applied sampling mask.
        num_low_frequencies: The number of samples for the densely-sampled
            center.
        target: The target image (if applicable).
        fname: File name.
        slice_num: The slice index.
        max_value: Maximum image value.
        crop_size: The size to crop the final image.
    """

    masked_kspace: torch.Tensor
    mask: torch.Tensor
    num_low_frequencies: Optional[int]
    target: torch.Tensor
    fname: str
    slice_num: int
    max_value: float
    crop_size: Tuple[int, int]


class VarNetDataTransform:
    """
    Data transformer for training VarNet models.
    """

    def __init__(
        self, mask_func: Optional[MaskFunc] = None, use_seed: bool = True
    ):
        """
        Args:
            mask_func: Optional; A function that can create a mask of
                appropriate shape. Defaults to None.
            use_seed: If True, this class computes a pseudo random number
                generator seed from the filename. This ensures that the same
                mask is used for all the slices of a given volume every time.
        """
        self.mask_func = mask_func
        self.use_seed = use_seed

    def __call__(
        self,
        kspace: np.ndarray,
        mask: np.ndarray,
        target: Optional[np.ndarray],
        attrs: Dict,
        fname: str,
        slice_num: int,
    ) -> VarNetSample:
        """
        Args:
            kspace: Input k-space of shape (num_coils, rows, cols) for
                multi-coil data.
            mask: Mask from the test dataset.
            target: Target image.
            attrs: Acquisition related information stored in the HDF5 object.
            fname: File name.
            slice_num: Serial number of the slice.

        Returns:
            A VarNetSample with the masked k-space, sampling mask, target
            image, the filename, the slice number, the maximum image value
            (from target), the target crop size, and the number of low
            frequency lines sampled.
        """
        if target is not None:
            target_torch = to_tensor(target)
            max_value = attrs["max"]
        else:
            target_torch = torch.tensor(0)
            max_value = 0.0

        kspace_torch = to_tensor(kspace)
        seed = None if not self.use_seed else tuple(map(ord, fname))
        acq_start = attrs["padding_left"]
        acq_end = attrs["padding_right"]

        crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])

        if self.mask_func is not None:
            masked_kspace, mask_torch, num_low_frequencies = apply_mask(
                kspace_torch,
                self.mask_func,
                seed=seed,
                padding=(acq_start, acq_end),
            )

            sample = VarNetSample(
                masked_kspace=masked_kspace,
                mask=mask_torch.to(torch.bool),
                num_low_frequencies=num_low_frequencies,
                target=target_torch,
                fname=fname,
                slice_num=slice_num,
                max_value=max_value,
                crop_size=crop_size,
            )
        else:
            masked_kspace = kspace_torch
            shape = np.array(kspace_torch.shape)
            num_cols = shape[-2]
            shape[:-3] = 1
            mask_shape = [1] * len(shape)
            mask_shape[-2] = num_cols
            mask_torch = torch.from_numpy(
                mask.reshape(*mask_shape).astype(np.float32)
            )
            mask_torch = mask_torch.reshape(*mask_shape)
            mask_torch[:, :, :acq_start] = 0
            mask_torch[:, :, acq_end:] = 0

            sample = VarNetSample(
                masked_kspace=masked_kspace,
                mask=mask_torch.to(torch.bool),
                num_low_frequencies=0,
                target=target_torch,
                fname=fname,
                slice_num=slice_num,
                max_value=max_value,
                crop_size=crop_size,
            )

        # whether to crop samples for batch processing
        batch_crop = False

        def save_img(x, fname):
            # save an RSS-reconstructed magnitude image for debugging
            slice_image = fastmri.ifft2c(x)  # inverse FFT to complex image
            slice_image_abs = fastmri.complex_abs(slice_image)  # real image
            slice_image_rss = fastmri.rss(slice_image_abs, dim=0)
            plt.imsave(f"{fname}.png", torch.abs(slice_image_rss), cmap="gray")

        def save_raw_img(x, fname):
            # save the raw k-space (real part) for debugging
            x = fastmri.rss(x, dim=0)[:, :, 0]
            plt.imsave(f"{fname}.png", torch.abs(x))

        if batch_crop:
            # crop k-space data to min-x, min-y size (640, 320 cols)
            square_crop = (attrs["recon_size"][0], attrs["recon_size"][1])
            cropped_kspace = fastmri.fft2c(
                complex_center_crop(
                    fastmri.ifft2c(sample.masked_kspace), square_crop
                )
            )
            cropped_kspace = complex_center_crop(cropped_kspace, (320, 320))

            # CHECK: debugging hooks
            # save_img(sample.masked_kspace, "og")
            # save_img(cropped_kspace, "cropped")
            # save_raw_img(sample.masked_kspace, "og_kspace")
            # save_raw_img(cropped_kspace, "cropped_kspace")

            # crop mask shape
            h_from = (mask_torch.shape[-2] - 320) // 2
            h_to = h_from + 320
            cropped_mask = mask_torch[..., :, h_from:h_to, :]

            sample = VarNetSample(
                masked_kspace=cropped_kspace,
                mask=cropped_mask.to(torch.bool),
                num_low_frequencies=0,
                target=target_torch,
                fname=fname,
                slice_num=slice_num,
                max_value=max_value,
                crop_size=crop_size,
            )
        return sample


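# Editor's sketch (not in the original file): running the transform on
# synthetic k-space with the minimal attrs it reads; all values illustrative.
def _demo_varnet_transform() -> None:
    from fastmri.subsample import create_mask_for_mask_type  # assumed location

    rng = np.random.default_rng(0)
    kspace = (
        rng.normal(size=(15, 640, 372)) + 1j * rng.normal(size=(15, 640, 372))
    ).astype(np.complex64)
    attrs = {
        "padding_left": 0,
        "padding_right": 372,
        "recon_size": (320, 320),
        "max": 1.0,
    }
    transform = VarNetDataTransform(
        mask_func=create_mask_for_mask_type("equispaced", [0.08], [4])
    )
    sample = transform(kspace, None, None, attrs, "file1.h5", 0)
    print(sample.masked_kspace.shape, sample.mask.shape)

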
class EnhancedVarNetDataTransform(VarNetDataTransform):
    """
    Enhanced data transformer for training VarNet models with additional
    functionality:
    - allows for training on multiple sampling patterns
    """

    def __init__(
        self, mask_funcs: Optional[List[MaskFunc]] = None, use_seed: bool = True
    ):
        self.mask_funcs = mask_funcs
        self.use_seed = use_seed

    def __call__(
        self,
        kspace: np.ndarray,
        mask: np.ndarray,
        target: Optional[np.ndarray],
        attrs: Dict,
        fname: str,
        slice_num: int,
    ) -> VarNetSample:
        """
        Args:
            kspace: Input k-space of shape (num_coils, rows, cols) for
                multi-coil data.
            mask: Mask from the test dataset. For test data, use the original
                VarNetDataTransform.__call__ instead.
            target: Target image.
            attrs: Acquisition related information stored in the HDF5 object.
            fname: File name.
            slice_num: Serial number of the slice.

        Returns:
            A VarNetSample with the masked k-space, sampling mask, target
            image, the filename, the slice number, the maximum image value
            (from target), the target crop size, and the number of low
            frequency lines sampled.
        """
        if target is not None:
            target_torch = to_tensor(target)
            max_value = attrs["max"]
        else:
            target_torch = torch.tensor(0)
            max_value = 0.0

        kspace_torch = to_tensor(kspace)
        seed = None if not self.use_seed else tuple(map(ord, fname))
        acq_start = attrs["padding_left"]
        acq_end = attrs["padding_right"]

        crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])

        # randomly choose one of the provided masking functions
        mask_func = random.choice(self.mask_funcs)

        masked_kspace, mask_torch, num_low_frequencies = apply_mask(
            kspace_torch,
            mask_func,
            seed=seed,
            padding=(acq_start, acq_end),
        )

        sample = VarNetSample(
            masked_kspace=masked_kspace,
            mask=mask_torch.to(torch.bool),
            num_low_frequencies=num_low_frequencies,
            target=target_torch,
            fname=fname,
            slice_num=slice_num,
            max_value=max_value,
            crop_size=crop_size,
        )

        # whether to crop samples for batch processing
        batch_crop = False

        if batch_crop:
            # crop k-space data to min-x, min-y size (640, 320 cols)
            square_crop = (attrs["recon_size"][0], attrs["recon_size"][1])
            cropped_kspace = fastmri.fft2c(
                complex_center_crop(
                    fastmri.ifft2c(sample.masked_kspace), square_crop
                )
            )
            # cropped_kspace = complex_center_crop(cropped_kspace, (640, 320))

            # crop mask shape
            h_from = (mask_torch.shape[-2] - 320) // 2
            h_to = h_from + 320
            cropped_mask = mask_torch[..., :, h_from:h_to, :]

            sample = VarNetSample(
                masked_kspace=cropped_kspace,
                mask=cropped_mask.to(torch.bool),
                num_low_frequencies=0,
                target=target_torch,
                fname=fname,
                slice_num=slice_num,
                max_value=max_value,
                crop_size=crop_size,
            )

        return sample


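# Editor's sketch (not in the original file): training across several sampling
# patterns; one mask function is drawn per call via random.choice above.
def _demo_enhanced_transform() -> None:
    from fastmri.subsample import create_mask_for_mask_type  # assumed location

    transform = EnhancedVarNetDataTransform(
        mask_funcs=[
            create_mask_for_mask_type("equispaced", [0.08], [4]),
            create_mask_for_mask_type("random", [0.08], [8]),
        ]
    )
    # transform(kspace, mask, target, attrs, fname, slice_num) -> VarNetSample

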
class MiniCoilSample(NamedTuple):
    """
    A sample of masked coil-compressed k-space for reconstruction.

    Args:
        kspace: the original k-space before masking.
        masked_kspace: k-space after applying sampling mask.
        mask: The applied sampling mask.
        target: The target image (if applicable).
        fname: File name.
        slice_num: The slice index.
        max_value: Maximum image value.
        crop_size: The size to crop the final image.
    """

    kspace: torch.Tensor
    masked_kspace: torch.Tensor
    mask: torch.Tensor
    target: torch.Tensor
    fname: str
    slice_num: int
    max_value: float
    crop_size: Tuple[int, int]


class MiniCoilTransform:
    """
    Multi-coil compressed transform, for faster prototyping.
    """

    def __init__(
        self,
        mask_func: Optional[MaskFunc] = None,
        use_seed: Optional[bool] = True,
        crop_size: Optional[tuple] = None,
        num_compressed_coils: Optional[int] = None,
    ):
        """
        Args:
            mask_func: Optional; A function that can create a mask of
                appropriate shape. Defaults to None.
            use_seed: If True, this class computes a pseudo random number
                generator seed from the filename. This ensures that the same
                mask is used for all the slices of a given volume every time.
            crop_size: Image dimensions for mini MR images.
            num_compressed_coils: Number of coils to output from coil
                compression.
        """
        self.mask_func = mask_func
        self.use_seed = use_seed
        self.crop_size = crop_size
        self.num_compressed_coils = num_compressed_coils

    def __call__(self, kspace, mask, target, attrs, fname, slice_num):
        """
        Args:
            kspace: Input k-space of shape (num_coils, rows, cols) for
                multi-coil data.
            mask: Mask from the test dataset. Not used if mask_func is defined.
            target: Target image.
            attrs: Acquisition related information stored in the HDF5 object.
            fname: File name.
            slice_num: Serial number of the slice.

        Returns:
            tuple containing:
                kspace: original k-space (used for active acquisition only).
                masked_kspace: k-space after applying sampling mask. If there
                    is no mask or mask_func, returns same as kspace.
                mask: The applied sampling mask.
                target: The target image (if applicable). The target is built
                    from the RSS of all coils pre-compression.
                fname: File name.
                slice_num: The slice index.
                max_value: Maximum image value.
                crop_size: The size to crop the final image.
        """
        if target is not None:
            target = to_tensor(target)
            max_value = attrs["max"]
        else:
            target = torch.tensor(0)
            max_value = 0.0

        if self.crop_size is None:
            crop_size = torch.tensor(
                [attrs["recon_size"][0], attrs["recon_size"][1]]
            )
        else:
            if isinstance(self.crop_size, tuple) or isinstance(
                self.crop_size, list
            ):
                assert len(self.crop_size) == 2
                if self.crop_size[0] is None or self.crop_size[1] is None:
                    crop_size = torch.tensor(
                        [attrs["recon_size"][0], attrs["recon_size"][1]]
                    )
                else:
                    crop_size = torch.tensor(self.crop_size)
            elif isinstance(self.crop_size, int):
                crop_size = torch.tensor((self.crop_size, self.crop_size))
            else:
                raise ValueError(
                    "`crop_size` should be None, tuple, list, or int, not:"
                    f" {type(self.crop_size)}"
                )

        if self.num_compressed_coils is None:
            num_compressed_coils = kspace.shape[0]
        else:
            num_compressed_coils = self.num_compressed_coils

        seed = None if not self.use_seed else tuple(map(ord, fname))
        acq_start = 0
        acq_end = crop_size[1]

        # new cropping section
        square_crop = (attrs["recon_size"][0], attrs["recon_size"][1])
        kspace = fastmri.fft2c(
            complex_center_crop(fastmri.ifft2c(to_tensor(kspace)), square_crop)
        ).numpy()
        kspace = complex_center_crop(kspace, crop_size)

        # we calculate the target before coil compression. This causes the mini
        # simulation to be one where we have a 15-coil, low-resolution image
        # and our reconstructor has an SVD coil approximation. This is a little
        # bit more realistic than doing the target after SVD compression
        target = fastmri.rss_complex(fastmri.ifft2c(to_tensor(kspace)))
        max_value = target.max()

        # apply coil compression via the truncated left singular vectors
        new_shape = (num_compressed_coils,) + kspace.shape[1:]
        kspace = np.reshape(kspace, (kspace.shape[0], -1))
        left_vec, _, _ = np.linalg.svd(
            kspace, compute_uv=True, full_matrices=False
        )
        kspace = np.reshape(
            left_vec[:, :num_compressed_coils].conj().T @ kspace,
            new_shape,
        )
        kspace = to_tensor(kspace)

        # mask k-space
        if self.mask_func:
            masked_kspace, mask, _ = apply_mask(
                kspace, self.mask_func, seed=seed, padding=(acq_start, acq_end)
            )
            mask = mask.byte()
        elif mask is not None:
            masked_kspace = kspace
            shape = np.array(kspace.shape)
            num_cols = shape[-2]
            shape[:-3] = 1
            mask_shape = [1] * len(shape)
            mask_shape[-2] = num_cols
            mask = torch.from_numpy(
                mask.reshape(*mask_shape).astype(np.float32)
            )
            mask = mask.reshape(*mask_shape)
            mask = mask.byte()
        else:
            masked_kspace = kspace
            shape = np.array(kspace.shape)
            num_cols = shape[-2]

        return MiniCoilSample(
            kspace,
            masked_kspace,
            mask,
            target,
            fname,
            slice_num,
            max_value,
            crop_size,
        )


"""
sens maps & feature transformations
- expand
- reduce
- batch -> chan
- chan -> batch
"""


def sens_expand(x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    """
    Calculates F(x * sens_maps)

    Parameters
    ----------
    x : torch.Tensor
        Single-channel image of shape (..., H, W, 2)
    sens_maps : torch.Tensor
        Sensitivity maps (image space)

    Returns
    -------
    torch.Tensor
        Result of the operation F(x * sens_maps)
    """
    return fastmri.fft2c(fastmri.complex_mul(x, sens_maps))


def sens_reduce(k: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    """
    Calculates F^{-1}(k) * conj(sens_maps), where conj(sens_maps) is the
    element-wise complex conjugate, summed over the coil dimension.

    Parameters
    ----------
    k : torch.Tensor
        Multi-channel k-space of shape (B, C, H, W, 2)
    sens_maps : torch.Tensor
        Sensitivity maps (image space)

    Returns
    -------
    torch.Tensor
        Result of the operation F^{-1}(k) * conj(sens_maps)
    """
    return fastmri.complex_mul(
        fastmri.ifft2c(k), fastmri.complex_conj(sens_maps)
    ).sum(dim=1, keepdim=True)


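# Editor's sketch (not in the original file): with sensitivity maps normalized
# so that sum_c |S_c|^2 = 1 at every pixel, sens_reduce inverts sens_expand.
def _demo_sens_round_trip() -> None:
    b, c, h, w = 1, 4, 16, 16
    x = torch.randn(b, 1, h, w, 2)
    maps = torch.randn(b, c, h, w, 2)
    norm = fastmri.complex_abs(maps).square().sum(dim=1, keepdim=True).sqrt()
    maps = maps / norm.unsqueeze(-1)
    x_round_trip = sens_reduce(sens_expand(x, maps), maps)
    assert torch.allclose(x_round_trip, x, atol=1e-4)

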
def chans_to_batch_dim(x: torch.Tensor) -> Tuple[torch.Tensor, int]:
    """Reshapes batched multi-channel samples into multiple single-channel samples.

    Parameters
    ----------
    x : torch.Tensor
        x has shape (b, c, h, w, 2)

    Returns
    -------
    Tuple[torch.Tensor, int]
        tensor of shape (b * c, 1, h, w, 2), b
    """
    b, c, h, w, comp = x.shape
    return x.view(b * c, 1, h, w, comp), b


def batch_chans_to_chan_dim(x: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Reshapes batched independent samples into original multi-channel samples.

    Parameters
    ----------
    x : torch.Tensor
        tensor of shape (b * c, 1, h, w, 2)
    batch_size : int
        batch size

    Returns
    -------
    torch.Tensor
        original multi-channel tensor of shape (b, c, h, w, 2)
    """
    bc, _, h, w, comp = x.shape
    c = bc // batch_size
    return x.view(batch_size, c, h, w, comp)
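# Editor's sketch (not in the original file): the two reshape helpers above
# are exact inverses of each other.
def _demo_chan_batch_round_trip() -> None:
    x = torch.randn(2, 8, 16, 16, 2)
    flat, b = chans_to_batch_dim(x)  # (16, 1, 16, 16, 2), b == 2
    restored = batch_chans_to_chan_dim(flat, b)
    assert torch.equal(restored, x)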
models/lightning/mri_module.py
ADDED
@@ -0,0 +1,402 @@
"""
Modified for use in <TODO: paper name>
- minified and removed extraneous abstractions
- updated to latest version of lightning

Copyright (c) Facebook, Inc. and its affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

from collections import defaultdict
from io import BytesIO
import pathlib
import os
from argparse import ArgumentParser

import numpy as np
import wandb
import lightning as L
import torch
from torchmetrics.metric import Metric
import matplotlib
import matplotlib.pyplot as plt
from PIL import Image

matplotlib.use("Agg")

from fastmri import evaluate


class DistributedMetricSum(Metric):
    def __init__(self, dist_sync_on_step=True):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.add_state(
            "quantity", default=torch.tensor(0.0), dist_reduce_fx="sum"
        )

    def update(self, batch: torch.Tensor):  # type: ignore
        self.quantity += batch

    def compute(self):
        return self.quantity


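# Editor's sketch (not in the original file): DistributedMetricSum keeps a
# running scalar total that torchmetrics reduces with a sum across processes
# under DDP before compute() reads it back.
def _demo_metric_sum() -> None:
    metric = DistributedMetricSum()
    metric.update(torch.tensor(2.0))
    metric.update(torch.tensor(3.0))
    print(metric.compute())  # tensor(5.)

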
| 48 |
+
class MriModule(L.LightningModule):
    """
    Abstract superclass for deep learning reconstruction models.

    This is a subclass of the LightningModule class from lightning,
    with some additional functionality specific to fastMRI:
    - Evaluating reconstructions
    - Visualization

    To implement a new reconstruction model, inherit from this class and
    implement the following methods:
    - training_step: Define what happens in one step of training
    - validation_step: Define what happens in one step of validation
    - test_step: Define what happens in one step of testing
    - configure_optimizers: Create and return the optimizers

    Other methods from LightningModule can be overridden as needed.
    """

    def __init__(self, num_log_images: int = 16):
        """
        Initialize the MRI module.

        Parameters
        ----------
        num_log_images : int, optional
            Number of images to log. Defaults to 16.
        """
        super().__init__()

        self.num_log_images = num_log_images
        self.val_log_indices = [1, 2, 3, 4, 5]
        self.val_batch_results = []

        self.NMSE = DistributedMetricSum()
        self.SSIM = DistributedMetricSum()
        self.PSNR = DistributedMetricSum()
        self.ValLoss = DistributedMetricSum()
        self.TotExamples = DistributedMetricSum()
        self.TotSliceExamples = DistributedMetricSum()

    def log_image(self, name, image):
        if self.logger is not None:
            self.logger.log_image(
                key=f"{name}", images=[image], caption=[f"{self.global_step}"]
            )

    def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        val_logs = outputs

        mse_vals = defaultdict(dict)
        target_norms = defaultdict(dict)
        ssim_vals = defaultdict(dict)
        max_vals = dict()

        for i, fname in enumerate(val_logs["fname"]):
            # log images for the first slice of a few fixed validation batches
            if i == 0 and batch_idx in self.val_log_indices:
                key = f"val_images_idx_{batch_idx}"
                target = val_logs["target"][i].unsqueeze(0)
                output = val_logs["output"][i].unsqueeze(0)
                error = torch.abs(target - output)
                output = output / output.max()
                target = target / target.max()
                error = error / error.max()
                self.log_image(f"{key}/target", target)
                self.log_image(f"{key}/reconstruction", output)
                self.log_image(f"{key}/error", error)
            slice_num = int(val_logs["slice_num"][i].cpu())

            maxval = val_logs["max_value"][i].cpu().numpy()
            output = val_logs["output"][i].cpu().numpy()
            target = val_logs["target"][i].cpu().numpy()
            mse_vals[fname][slice_num] = torch.tensor(
                evaluate.mse(target, output)
            ).view(1)
            target_norms[fname][slice_num] = torch.tensor(
                evaluate.mse(target, np.zeros_like(target))
            ).view(1)
            ssim_vals[fname][slice_num] = torch.tensor(
                evaluate.ssim(target[None, ...], output[None, ...], maxval=maxval)
            ).view(1)
            max_vals[fname] = maxval

        self.val_batch_results.append(
            {
                "slug": val_logs["slug"],
                "val_loss": val_logs["val_loss"],
                "mse_vals": dict(mse_vals),
                "target_norms": dict(target_norms),
                "ssim_vals": dict(ssim_vals),
                "max_vals": max_vals,
            }
        )

    def on_validation_epoch_end(self):
        val_logs = self.val_batch_results

        dataset_metrics = defaultdict(
            lambda: {
                "losses": [],
                "mse_vals": defaultdict(dict),
                "target_norms": defaultdict(dict),
                "ssim_vals": defaultdict(dict),
                "max_vals": dict(),
            }
        )

        # use dict updates to handle duplicate slices
        for val_log in val_logs:
            slug = val_log["slug"]
            dataset_metrics[slug]["losses"].append(val_log["val_loss"].view(-1))

            for k in val_log["mse_vals"].keys():
                dataset_metrics[slug]["mse_vals"][k].update(val_log["mse_vals"][k])
            for k in val_log["target_norms"].keys():
                dataset_metrics[slug]["target_norms"][k].update(
                    val_log["target_norms"][k]
                )
            for k in val_log["ssim_vals"].keys():
                dataset_metrics[slug]["ssim_vals"][k].update(val_log["ssim_vals"][k])
            for k in val_log["max_vals"]:
                dataset_metrics[slug]["max_vals"][k] = val_log["max_vals"][k]

        metrics_to_plot = {"psnr": [], "ssim": [], "nmse": []}
        slugs = []

        for slug, metrics_data in dataset_metrics.items():
            mse_vals, target_norms, ssim_vals, max_vals, losses = (
                metrics_data["mse_vals"],
                metrics_data["target_norms"],
                metrics_data["ssim_vals"],
                metrics_data["max_vals"],
                metrics_data["losses"],
            )
            # check to make sure we have all files in all metrics
            assert (
                mse_vals.keys()
                == target_norms.keys()
                == ssim_vals.keys()
                == max_vals.keys()
            )

            # apply means across image volumes
            metrics = {"nmse": 0, "ssim": 0, "psnr": 0}
            # individual per-volume values, kept for the std computation
            metric_values = {"nmse": [], "ssim": [], "psnr": []}
            local_examples = 0

            for fname in mse_vals.keys():
                local_examples = local_examples + 1
                mse_val = torch.mean(
                    torch.cat([v.view(-1) for _, v in mse_vals[fname].items()])
                )
                target_norm = torch.mean(
                    torch.cat([v.view(-1) for _, v in target_norms[fname].items()])
                )
                nmse = mse_val / target_norm
                psnr = 20 * torch.log10(
                    torch.tensor(
                        max_vals[fname], dtype=mse_val.dtype, device=mse_val.device
                    )
                ) - 10 * torch.log10(mse_val)
                ssim = torch.mean(
                    torch.cat([v.view(-1) for _, v in ssim_vals[fname].items()])
                )

                # accumulate metric values
                metrics["nmse"] += nmse
                metrics["psnr"] += psnr
                metrics["ssim"] += ssim

                # store individual metric values for std calculation
                metric_values["nmse"].append(nmse)
                metric_values["psnr"].append(psnr)
                metric_values["ssim"].append(ssim)

            # reduce across ddp via sum
            metrics["nmse"] = self.NMSE(metrics["nmse"])
            metrics["ssim"] = self.SSIM(metrics["ssim"])
            metrics["psnr"] = self.PSNR(metrics["psnr"])

            tot_examples = self.TotExamples(torch.tensor(local_examples))
            val_loss = self.ValLoss(torch.sum(torch.cat(losses)))  # type: ignore
            tot_slice_examples = self.TotSliceExamples(
                torch.tensor(len(losses), dtype=torch.float)
            )

            metrics_to_plot["nmse"].append(
                (
                    (metrics["nmse"] / tot_examples).item(),
                    torch.std(torch.stack(metric_values["nmse"])).item(),
                )
            )
            metrics_to_plot["psnr"].append(
                (
                    (metrics["psnr"] / tot_examples).item(),
                    torch.std(torch.stack(metric_values["psnr"])).item(),
                )
            )
            metrics_to_plot["ssim"].append(
                (
                    (metrics["ssim"] / tot_examples).item(),
                    torch.std(torch.stack(metric_values["ssim"])).item(),
                )
            )
            slugs.append(slug)

            # log the mean values
            self.log(
                f"{slug}--validation_loss",
                val_loss / tot_slice_examples,
                prog_bar=True,
            )
            for metric, value in metrics.items():
                self.log(f"{slug}--val_metrics_{metric}", value / tot_examples)

            # calculate and log the standard deviation for each metric
            for metric, values in metric_values.items():
                std_value = torch.std(torch.stack(values))
                self.log(f"{slug}--val_metrics_{metric}_std", std_value)

        # generate a per-dataset summary bar chart for each metric
        for metric_name, values in metrics_to_plot.items():
            scores = [val[0] for val in values]
            std_devs = [val[1] for val in values]

            plt.figure(figsize=(10, 6))
            plt.bar(slugs, scores, yerr=std_devs, capsize=5)
            plt.xlabel("Dataset Slug")
            plt.ylabel(f"{metric_name.upper()} Score")
            plt.title(f"{metric_name.upper()} per Dataset with Standard Deviation")
            plt.xticks(rotation=45)
            plt.tight_layout()

            # save the plot to a buffer and log it as an image
            buf = BytesIO()
            plt.savefig(buf, format="png")
            buf.seek(0)
            image = Image.open(buf)
            image_array = np.array(image)
            self.log_image(f"summary_plot_{metric_name}", image_array)
            buf.close()
            plt.close()

    def OLD_on_validation_epoch_end(self):
        val_logs = self.val_batch_results

        # aggregate losses
        losses = []
        mse_vals = defaultdict(dict)
        target_norms = defaultdict(dict)
        ssim_vals = defaultdict(dict)
        max_vals = dict()

        # use dict updates to handle duplicate slices
        for val_log in val_logs:
            losses.append(val_log["val_loss"].view(-1))

            for k in val_log["mse_vals"].keys():
                mse_vals[k].update(val_log["mse_vals"][k])
            for k in val_log["target_norms"].keys():
                target_norms[k].update(val_log["target_norms"][k])
            for k in val_log["ssim_vals"].keys():
                ssim_vals[k].update(val_log["ssim_vals"][k])
            for k in val_log["max_vals"]:
                max_vals[k] = val_log["max_vals"][k]

        # check to make sure we have all files in all metrics
        assert (
            mse_vals.keys()
            == target_norms.keys()
            == ssim_vals.keys()
            == max_vals.keys()
        )

        # apply means across image volumes
        metrics = {"nmse": 0, "ssim": 0, "psnr": 0}
        local_examples = 0
        for fname in mse_vals.keys():
            local_examples = local_examples + 1
            mse_val = torch.mean(
                torch.cat([v.view(-1) for _, v in mse_vals[fname].items()])
            )
            target_norm = torch.mean(
                torch.cat([v.view(-1) for _, v in target_norms[fname].items()])
            )
            metrics["nmse"] = metrics["nmse"] + mse_val / target_norm
            metrics["psnr"] = (
                metrics["psnr"]
                + 20
                * torch.log10(
                    torch.tensor(
                        max_vals[fname], dtype=mse_val.dtype, device=mse_val.device
                    )
                )
                - 10 * torch.log10(mse_val)
            )
            metrics["ssim"] = metrics["ssim"] + torch.mean(
                torch.cat([v.view(-1) for _, v in ssim_vals[fname].items()])
            )

        # reduce across ddp via sum
        metrics["nmse"] = self.NMSE(metrics["nmse"])
        metrics["ssim"] = self.SSIM(metrics["ssim"])
        metrics["psnr"] = self.PSNR(metrics["psnr"])

        tot_examples = self.TotExamples(torch.tensor(local_examples))
        val_loss = self.ValLoss(torch.sum(torch.cat(losses)))
        tot_slice_examples = self.TotSliceExamples(
            torch.tensor(len(losses), dtype=torch.float)
        )

        self.log("validation_loss", val_loss / tot_slice_examples, prog_bar=True)
        for metric, value in metrics.items():
            self.log(f"val_metrics_{metric}", value / tot_examples)

    @staticmethod
    def add_model_specific_args(parent_parser):  # pragma: no cover
        """
        Define parameters that only apply to this model
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

        # logging params
        parser.add_argument(
            "--num_log_images", default=16, type=int,
            help="Number of images to log to the experiment logger",
        )

        return parser
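MriModule defines no forward model itself; subclasses supply the steps, and the hooks above consume the exact dict returned by validation_step. Per volume, NMSE is the mean MSE divided by the target's mean squared norm, and PSNR is 20*log10(max_value) - 10*log10(MSE). A minimal illustrative subclass (not part of the commit; `self.net` is a hypothetical reconstruction network):

    import torch
    from models.lightning.mri_module import MriModule

    class ExampleModule(MriModule):
        def __init__(self, net, **kwargs):
            super().__init__(**kwargs)
            self.net = net  # hypothetical reconstruction network

        def training_step(self, batch, batch_idx):
            output = self.net(batch.masked_kspace, batch.mask)
            return torch.nn.functional.l1_loss(output, batch.target)

        def validation_step(self, batch, batch_idx, dataloader_idx=0):
            output = self.net(batch.masked_kspace, batch.mask)
            return {
                "slug": "example",          # dataset identifier used for grouping
                "fname": batch.fname,
                "slice_num": batch.slice_num,
                "max_value": batch.max_value,
                "output": output,
                "target": batch.target,
                "val_loss": torch.nn.functional.l1_loss(output, batch.target),
            }

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)
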
models/lightning/no_shared_module.py
ADDED
@@ -0,0 +1,274 @@
from argparse import ArgumentParser
from typing import Tuple

import torch

import fastmri
from fastmri import transforms
from models.no_shared import NOShared

from models.lightning.mri_module import MriModule
from type_utils import tuple_type


class NOSharedModule(MriModule):
    """
    NO-Shared training module.
    """

    def __init__(
        self,
        num_cascades: int = 12,
        pools: int = 4,
        chans: int = 18,
        sens_pools: int = 4,
        sens_chans: int = 8,
        gno_pools: int = 4,
        gno_chans: int = 16,
        gno_radius_cutoff: float = 0.02,
        gno_kernel_shape: Tuple[int, int] = (6, 7),
        radius_cutoff: float = 0.02,
        kernel_shape: Tuple[int, int] = (6, 7),
        in_shape: Tuple[int, int] = (320, 320),
        use_dc_term: bool = True,
        lr: float = 0.0003,
        lr_step_size: int = 40,
        lr_gamma: float = 0.1,
        weight_decay: float = 0.0,
        **kwargs,
    ):
        """
        Parameters
        ----------
        num_cascades : int
            Number of cascades (i.e., layers) for the variational network.
        pools : int
            Number of downsampling and upsampling layers for the cascade U-Net.
        chans : int
            Number of channels for the cascade U-Net.
        sens_pools : int
            Number of downsampling and upsampling layers for the sensitivity map U-Net.
        sens_chans : int
            Number of channels for the sensitivity map U-Net.
        lr : float
            Learning rate.
        lr_step_size : int
            Learning rate step size.
        lr_gamma : float
            Learning rate gamma decay.
        weight_decay : float
            Parameter for penalizing weights norm.
        """
        super().__init__(**kwargs)
        self.save_hyperparameters()

        self.num_cascades = num_cascades
        self.pools = pools
        self.chans = chans
        self.sens_pools = sens_pools
        self.sens_chans = sens_chans
        self.gno_pools = gno_pools
        self.gno_chans = gno_chans
        self.gno_radius_cutoff = gno_radius_cutoff
        self.gno_kernel_shape = gno_kernel_shape
        self.radius_cutoff = radius_cutoff
        self.kernel_shape = kernel_shape
        self.in_shape = in_shape
        self.use_dc_term = use_dc_term
        self.lr = lr
        self.lr_step_size = lr_step_size
        self.lr_gamma = lr_gamma
        self.weight_decay = weight_decay

        self.model = NOShared(
            num_cascades=self.num_cascades,
            sens_chans=self.sens_chans,
            sens_pools=self.sens_pools,
            chans=self.chans,
            pools=self.pools,
            gno_chans=self.gno_chans,
            gno_pools=self.gno_pools,
            gno_radius_cutoff=self.gno_radius_cutoff,
            gno_kernel_shape=self.gno_kernel_shape,
            radius_cutoff=radius_cutoff,
            kernel_shape=kernel_shape,
            in_shape=in_shape,
            use_dc_term=use_dc_term,
        )

        self.criterion = fastmri.SSIMLoss()
        self.num_params = sum(p.numel() for p in self.parameters())

    def forward(self, masked_kspace, mask, num_low_frequencies):
        return self.model(masked_kspace, mask, num_low_frequencies)

    def training_step(self, batch, batch_idx):
        output = self.forward(
            batch.masked_kspace, batch.mask, batch.num_low_frequencies
        )

        target, output = transforms.center_crop_to_smallest(batch.target, output)
        loss = self.criterion(
            output.unsqueeze(1), target.unsqueeze(1), data_range=batch.max_value
        )

        self.log("train_loss", loss, on_step=True, on_epoch=True)
        self.log("epoch", int(self.current_epoch), on_step=True, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        dataloaders = self.trainer.val_dataloaders
        slug = list(dataloaders.keys())[dataloader_idx]

        output = self.forward(
            batch.masked_kspace, batch.mask, batch.num_low_frequencies
        )

        target, output = transforms.center_crop_to_smallest(batch.target, output)

        loss = self.criterion(
            output.unsqueeze(1),
            target.unsqueeze(1),
            data_range=batch.max_value,
        )

        return {
            "slug": slug,
            "fname": batch.fname,
            "slice_num": batch.slice_num,
            "max_value": batch.max_value,
            "output": output,
            "target": target,
            "val_loss": loss,
        }

    def configure_optimizers(self):
        optim = torch.optim.Adam(
            self.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        scheduler = torch.optim.lr_scheduler.StepLR(
            optim, self.lr_step_size, self.lr_gamma
        )

        return [optim], [scheduler]

    @staticmethod
    def add_model_specific_args(parent_parser):
        """
        Define parameters that only apply to this model
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser = MriModule.add_model_specific_args(parser)

        # network params
        parser.add_argument(
            "--num_cascades", default=12, type=int,
            help="Number of VarNet cascades",
        )
        parser.add_argument(
            "--pools", default=4, type=int,
            help="Number of U-Net pooling layers in VarNet blocks",
        )
        parser.add_argument(
            "--chans", default=18, type=int,
            help="Number of channels for U-Net in VarNet blocks",
        )
        parser.add_argument(
            "--sens_pools", default=4, type=int,
            help="Number of pooling layers for sense map estimation U-Net in VarNet",
        )
        parser.add_argument(
            "--sens_chans", default=8, type=int,
            help="Number of channels for sense map estimation U-Net in VarNet",
        )
        parser.add_argument(
            "--gno_pools", default=4, type=int,
            help="Number of pooling layers for GNO",
        )
        parser.add_argument(
            "--gno_chans", default=16, type=int,
            help="Number of channels for GNO",
        )
        parser.add_argument(
            "--gno_radius_cutoff", default=0.02, type=float,
            help="GNO module radius_cutoff",
        )
        parser.add_argument(
            "--gno_kernel_shape", default=(6, 7), type=tuple_type,
            help="GNO module kernel_shape. Ex: (6, 7)",
        )
        parser.add_argument(
            "--radius_cutoff", default=0.02, type=float,
            help="DISCO module radius_cutoff",
        )
        parser.add_argument(
            "--kernel_shape", default=(6, 7), type=tuple_type,
            help="DISCO module kernel_shape. Ex: (6, 7)",
        )
        parser.add_argument(
            "--in_shape", default=(320, 320), type=tuple_type,
            help="Spatial dimensions of masked_kspace samples. Ex: (640, 320)",
        )
        parser.add_argument(
            "--use_dc_term", default=True, type=bool,
            help="Whether to use the DC term in the unrolled iterative update step",
        )

        # training params (opt)
        parser.add_argument(
            "--lr", default=0.0003, type=float, help="Adam learning rate"
        )
        parser.add_argument(
            "--lr_step_size", default=40, type=int,
            help="Epoch at which to decrease step size",
        )
        parser.add_argument(
            "--lr_gamma", default=0.1, type=float,
            help="Extent to which step size should be decreased",
        )
        parser.add_argument(
            "--weight_decay", default=0.0, type=float,
            help="Strength of weight decay regularization",
        )

        return parser
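A sketch of how this module's CLI plumbing is meant to be used (illustrative only; the flag values are arbitrary and the script is assumed to run from the repo root so the imports resolve):

    from argparse import ArgumentParser
    from models.lightning.no_shared_module import NOSharedModule

    parent = ArgumentParser(add_help=False)
    parser = NOSharedModule.add_model_specific_args(parent)
    args = parser.parse_args(["--num_cascades", "8", "--in_shape", "(640, 320)"])

    # in_shape arrives as an int tuple thanks to tuple_type
    module = NOSharedModule(num_cascades=args.num_cascades, in_shape=args.in_shape)
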
models/lightning/no_varnet_module.py
ADDED
@@ -0,0 +1,299 @@
from argparse import ArgumentParser
from typing import Tuple

import torch

import fastmri
from fastmri import transforms
from models.no_varnet import NOVarnet

from models.lightning.mri_module import MriModule
# from type_utils import tuple_type


def tuple_type(strings):
    """Parse a CLI string such as "(6, 7)" into a tuple of ints."""
    strings = strings.replace("(", "").replace(")", "").replace(" ", "")
    return tuple(map(int, strings.split(",")))


class NOVarnetModule(MriModule):
    """
    NO-Varnet training module.
    """

    def __init__(
        self,
        num_cascades: int = 12,
        pools: int = 4,
        chans: int = 18,
        sens_pools: int = 4,
        sens_chans: int = 8,
        gno_pools: int = 4,
        gno_chans: int = 16,
        gno_radius_cutoff: float = 0.02,
        gno_kernel_shape: Tuple[int, int] = (6, 7),
        radius_cutoff: float = 0.02,
        kernel_shape: Tuple[int, int] = (6, 7),
        in_shape: Tuple[int, int] = (320, 320),
        use_dc_term: bool = True,
        lr: float = 0.0003,
        lr_step_size: int = 40,
        lr_gamma: float = 0.1,
        weight_decay: float = 0.0,
        reduction_method: str = "rss",
        skip_method: str = "add",
        **kwargs,
    ):
        """
        Parameters
        ----------
        num_cascades : int
            Number of cascades (i.e., layers) for the variational network.
        pools : int
            Number of downsampling and upsampling layers for the cascade U-Net.
        chans : int
            Number of channels for the cascade U-Net.
        sens_pools : int
            Number of downsampling and upsampling layers for the sensitivity map U-Net.
        sens_chans : int
            Number of channels for the sensitivity map U-Net.
        lr : float
            Learning rate.
        lr_step_size : int
            Learning rate step size.
        lr_gamma : float
            Learning rate gamma decay.
        weight_decay : float
            Parameter for penalizing weights norm.
        """
        super().__init__(**kwargs)
        self.save_hyperparameters()

        self.num_cascades = num_cascades
        self.pools = pools
        self.chans = chans
        self.sens_pools = sens_pools
        self.sens_chans = sens_chans
        self.gno_pools = gno_pools
        self.gno_chans = gno_chans
        self.gno_radius_cutoff = gno_radius_cutoff
        self.gno_kernel_shape = gno_kernel_shape
        self.radius_cutoff = radius_cutoff
        self.kernel_shape = kernel_shape
        self.in_shape = in_shape
        self.use_dc_term = use_dc_term
        self.lr = lr
        self.lr_step_size = lr_step_size
        self.lr_gamma = lr_gamma
        self.weight_decay = weight_decay
        self.reduction_method = reduction_method
        self.skip_method = skip_method

        self.model = NOVarnet(
            num_cascades=self.num_cascades,
            sens_chans=self.sens_chans,
            sens_pools=self.sens_pools,
            chans=self.chans,
            pools=self.pools,
            gno_chans=self.gno_chans,
            gno_pools=self.gno_pools,
            gno_radius_cutoff=self.gno_radius_cutoff,
            gno_kernel_shape=self.gno_kernel_shape,
            radius_cutoff=radius_cutoff,
            kernel_shape=kernel_shape,
            in_shape=in_shape,
            use_dc_term=use_dc_term,
            reduction_method=reduction_method,
            skip_method=skip_method,
        )

        self.criterion = fastmri.SSIMLoss()
        self.num_params = sum(p.numel() for p in self.parameters())

    def forward(self, masked_kspace, mask, num_low_frequencies):
        return self.model(masked_kspace, mask, num_low_frequencies)

    def training_step(self, batch, batch_idx):
        output = self.forward(
            batch.masked_kspace, batch.mask, batch.num_low_frequencies
        )

        target, output = transforms.center_crop_to_smallest(batch.target, output)
        loss = self.criterion(
            output.unsqueeze(1), target.unsqueeze(1), data_range=batch.max_value
        )

        self.log("train_loss", loss, on_step=True, on_epoch=True)
        self.log("epoch", int(self.current_epoch), on_step=True, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        dataloaders = self.trainer.val_dataloaders
        slug = list(dataloaders.keys())[dataloader_idx]

        output = self.forward(
            batch.masked_kspace, batch.mask, batch.num_low_frequencies
        )

        target, output = transforms.center_crop_to_smallest(batch.target, output)

        loss = self.criterion(
            output.unsqueeze(1),
            target.unsqueeze(1),
            data_range=batch.max_value,
        )

        return {
            "slug": slug,
            "fname": batch.fname,
            "slice_num": batch.slice_num,
            "max_value": batch.max_value,
            "output": output,
            "target": target,
            "val_loss": loss,
        }

    def configure_optimizers(self):
        optim = torch.optim.Adam(
            self.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        scheduler = torch.optim.lr_scheduler.StepLR(
            optim, self.lr_step_size, self.lr_gamma
        )

        return [optim], [scheduler]

    @staticmethod
    def add_model_specific_args(parent_parser):
        """
        Define parameters that only apply to this model
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser = MriModule.add_model_specific_args(parser)

        # network params
        parser.add_argument(
            "--num_cascades", default=12, type=int,
            help="Number of VarNet cascades",
        )
        parser.add_argument(
            "--pools", default=4, type=int,
            help="Number of U-Net pooling layers in VarNet blocks",
        )
        parser.add_argument(
            "--chans", default=18, type=int,
            help="Number of channels for U-Net in VarNet blocks",
        )
        parser.add_argument(
            "--sens_pools", default=4, type=int,
            help="Number of pooling layers for sense map estimation U-Net in VarNet",
        )
        parser.add_argument(
            "--sens_chans", default=8, type=int,
            help="Number of channels for sense map estimation U-Net in VarNet",
        )
        parser.add_argument(
            "--gno_pools", default=4, type=int,
            help="Number of pooling layers for GNO",
        )
        parser.add_argument(
            "--gno_chans", default=16, type=int,
            help="Number of channels for GNO",
        )
        parser.add_argument(
            "--gno_radius_cutoff", default=0.02, type=float, required=True,
            help="GNO module radius_cutoff",
        )
        parser.add_argument(
            "--gno_kernel_shape", default=(6, 7), type=tuple_type, required=True,
            help="GNO module kernel_shape. Ex: (6, 7)",
        )
        parser.add_argument(
            "--radius_cutoff", default=0.01, type=float, required=True,
            help="DISCO module radius_cutoff",
        )
        parser.add_argument(
            "--kernel_shape", default=(6, 7), type=tuple_type, required=True,
            help="DISCO module kernel_shape. Ex: (6, 7)",
        )
        parser.add_argument(
            "--in_shape", default=(640, 320), type=tuple_type, required=True,
            help="Spatial dimensions of masked_kspace samples. Ex: (640, 320)",
        )
        parser.add_argument(
            "--use_dc_term", default=True, type=bool,
            help="Whether to use the DC term in the unrolled iterative update step",
        )

        # training params (opt)
        parser.add_argument(
            "--lr", default=0.0003, type=float, help="Adam learning rate"
        )
        parser.add_argument(
            "--lr_step_size", default=40, type=int,
            help="Epoch at which to decrease step size",
        )
        parser.add_argument(
            "--lr_gamma", default=0.1, type=float,
            help="Extent to which step size should be decreased",
        )
        parser.add_argument(
            "--weight_decay", default=0.0, type=float,
            help="Strength of weight decay regularization",
        )
        parser.add_argument(
            "--reduction_method", default="rss", type=str, choices=["rss", "batch"],
            help="Reduction method used to reduce multi-channel k-space data "
            "before the inpainting module; see the GNO documentation for details.",
        )
        parser.add_argument(
            "--skip_method", default="add_inv", type=str,
            choices=["add_inv", "add", "concat", "replace"],
            help="Method for skip connection around inpainting module.",
        )

        return parser
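The locally defined tuple_type above simply strips parentheses and whitespace and splits on commas, so both the quoted-tuple and the bare comma form parse to the same value:

    tuple_type("(6, 7)")   # -> (6, 7)
    tuple_type("640,320")  # -> (640, 320)
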
models/lightning/no_varnet_nokno_module.py
ADDED
@@ -0,0 +1,294 @@
from argparse import ArgumentParser
from typing import Tuple

import torch

import fastmri
from fastmri import transforms

from models.lightning.mri_module import MriModule
from models.no_varnet_nokno import NOVarnet_no_KNO
from type_utils import tuple_type


class NOVarnet_no_KNOModule(MriModule):
    """
    NO-Varnet training module without the KNO component.
    """

    def __init__(
        self,
        num_cascades: int = 12,
        pools: int = 4,
        chans: int = 18,
        sens_pools: int = 4,
        sens_chans: int = 8,
        gno_pools: int = 4,
        gno_chans: int = 16,
        gno_radius_cutoff: float = 0.02,
        gno_kernel_shape: Tuple[int, int] = (6, 7),
        radius_cutoff: float = 0.02,
        kernel_shape: Tuple[int, int] = (6, 7),
        in_shape: Tuple[int, int] = (320, 320),
        use_dc_term: bool = True,
        lr: float = 0.0003,
        lr_step_size: int = 40,
        lr_gamma: float = 0.1,
        weight_decay: float = 0.0,
        reduction_method: str = "rss",
        skip_method: str = "add",
        **kwargs,
    ):
        """
        Parameters
        ----------
        num_cascades : int
            Number of cascades (i.e., layers) for the variational network.
        pools : int
            Number of downsampling and upsampling layers for the cascade U-Net.
        chans : int
            Number of channels for the cascade U-Net.
        sens_pools : int
            Number of downsampling and upsampling layers for the sensitivity map U-Net.
        sens_chans : int
            Number of channels for the sensitivity map U-Net.
        lr : float
            Learning rate.
        lr_step_size : int
            Learning rate step size.
        lr_gamma : float
            Learning rate gamma decay.
        weight_decay : float
            Parameter for penalizing weights norm.
        """
        super().__init__(**kwargs)
        self.save_hyperparameters()

        self.num_cascades = num_cascades
        self.pools = pools
        self.chans = chans
        self.sens_pools = sens_pools
        self.sens_chans = sens_chans
        self.gno_pools = gno_pools
        self.gno_chans = gno_chans
        self.gno_radius_cutoff = gno_radius_cutoff
        self.gno_kernel_shape = gno_kernel_shape
        self.radius_cutoff = radius_cutoff
        self.kernel_shape = kernel_shape
        self.in_shape = in_shape
        self.use_dc_term = use_dc_term
        self.lr = lr
        self.lr_step_size = lr_step_size
        self.lr_gamma = lr_gamma
        self.weight_decay = weight_decay
        self.reduction_method = reduction_method
        self.skip_method = skip_method

        self.model = NOVarnet_no_KNO(
            num_cascades=self.num_cascades,
            sens_chans=self.sens_chans,
            sens_pools=self.sens_pools,
            chans=self.chans,
            pools=self.pools,
            gno_chans=self.gno_chans,
            gno_pools=self.gno_pools,
            gno_radius_cutoff=self.gno_radius_cutoff,
            gno_kernel_shape=self.gno_kernel_shape,
            radius_cutoff=radius_cutoff,
            kernel_shape=kernel_shape,
            in_shape=in_shape,
            use_dc_term=use_dc_term,
            reduction_method=reduction_method,
            skip_method=skip_method,
        )

        self.criterion = fastmri.SSIMLoss()
        self.num_params = sum(p.numel() for p in self.parameters())

    def forward(self, masked_kspace, mask, num_low_frequencies):
        return self.model(masked_kspace, mask, num_low_frequencies)

    def training_step(self, batch, batch_idx):
        output = self.forward(
            batch.masked_kspace, batch.mask, batch.num_low_frequencies
        )

        target, output = transforms.center_crop_to_smallest(batch.target, output)
        loss = self.criterion(
            output.unsqueeze(1), target.unsqueeze(1), data_range=batch.max_value
        )

        self.log("train_loss", loss, on_step=True, on_epoch=True)
        self.log("epoch", int(self.current_epoch), on_step=True, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        dataloaders = self.trainer.val_dataloaders
        slug = list(dataloaders.keys())[dataloader_idx]

        output = self.forward(
            batch.masked_kspace, batch.mask, batch.num_low_frequencies
        )

        target, output = transforms.center_crop_to_smallest(batch.target, output)

        loss = self.criterion(
            output.unsqueeze(1),
            target.unsqueeze(1),
            data_range=batch.max_value,
        )

        return {
            "slug": slug,
            "fname": batch.fname,
            "slice_num": batch.slice_num,
            "max_value": batch.max_value,
            "output": output,
            "target": target,
            "val_loss": loss,
        }

    def configure_optimizers(self):
        optim = torch.optim.Adam(
            self.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        scheduler = torch.optim.lr_scheduler.StepLR(
            optim, self.lr_step_size, self.lr_gamma
        )

        return [optim], [scheduler]

    @staticmethod
    def add_model_specific_args(parent_parser):
        """
        Define parameters that only apply to this model
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser = MriModule.add_model_specific_args(parser)

        # network params
        parser.add_argument(
            "--num_cascades", default=12, type=int,
            help="Number of VarNet cascades",
        )
        parser.add_argument(
            "--pools", default=4, type=int,
            help="Number of U-Net pooling layers in VarNet blocks",
        )
        parser.add_argument(
            "--chans", default=18, type=int,
            help="Number of channels for U-Net in VarNet blocks",
        )
        parser.add_argument(
            "--sens_pools", default=4, type=int,
            help="Number of pooling layers for sense map estimation U-Net in VarNet",
        )
        parser.add_argument(
            "--sens_chans", default=8, type=int,
            help="Number of channels for sense map estimation U-Net in VarNet",
        )
        parser.add_argument(
            "--gno_pools", default=4, type=int,
            help="Number of pooling layers for GNO",
        )
        parser.add_argument(
            "--gno_chans", default=16, type=int,
            help="Number of channels for GNO",
        )
        parser.add_argument(
            "--gno_radius_cutoff", default=0.02, type=float, required=True,
            help="GNO module radius_cutoff",
        )
        parser.add_argument(
            "--gno_kernel_shape", default=(6, 7), type=tuple_type, required=True,
            help="GNO module kernel_shape. Ex: (6, 7)",
        )
        parser.add_argument(
            "--radius_cutoff", default=0.01, type=float, required=True,
            help="DISCO module radius_cutoff",
        )
        parser.add_argument(
            "--kernel_shape", default=(6, 7), type=tuple_type, required=True,
            help="DISCO module kernel_shape. Ex: (6, 7)",
        )
        parser.add_argument(
            "--in_shape", default=(640, 320), type=tuple_type, required=True,
            help="Spatial dimensions of masked_kspace samples. Ex: (640, 320)",
        )
        parser.add_argument(
            "--use_dc_term", default=True, type=bool,
            help="Whether to use the DC term in the unrolled iterative update step",
        )

        # training params (opt)
        parser.add_argument(
            "--lr", default=0.0003, type=float, help="Adam learning rate"
        )
        parser.add_argument(
            "--lr_step_size", default=40, type=int,
            help="Epoch at which to decrease step size",
        )
        parser.add_argument(
            "--lr_gamma", default=0.1, type=float,
            help="Extent to which step size should be decreased",
        )
        parser.add_argument(
            "--weight_decay", default=0.0, type=float,
            help="Strength of weight decay regularization",
        )
        parser.add_argument(
            "--reduction_method", default="rss", type=str, choices=["rss", "batch"],
            help="Reduction method used to reduce multi-channel k-space data "
            "before the inpainting module; see the GNO documentation for details.",
        )
        parser.add_argument(
            "--skip_method", default="add_inv", type=str,
            choices=["add_inv", "add", "concat", "replace"],
            help="Method for skip connection around inpainting module.",
        )

        return parser
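One caveat worth noting about the --use_dc_term flag declared in these modules: argparse applies `type` to the raw string, and `bool("False")` is `True`, so any non-empty value enables the term. A hypothetical converter (not in this repo) that a caller could pass as `type=` instead:

    def str2bool(s: str) -> bool:
        if s.lower() in ("true", "1", "yes"):
            return True
        if s.lower() in ("false", "0", "no"):
            return False
        raise ValueError(f"expected a boolean, got {s!r}")

    # bool("False") == True, whereas str2bool("False") == False
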
models/lightning/varnet_module.py
ADDED
@@ -0,0 +1,224 @@
"""
Copyright (c) Facebook, Inc. and its affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

from argparse import ArgumentParser

import torch

import fastmri
from fastmri import transforms
from ..varnet import VarNet
import wandb

from .mri_module import MriModule


class VarNetModule(MriModule):
    """
    VarNet training module.

    This can be used to train variational networks from the paper:

    A. Sriram et al. End-to-end variational networks for accelerated MRI
    reconstruction. In International Conference on Medical Image Computing and
    Computer-Assisted Intervention, 2020.

    which was inspired by the earlier paper:

    K. Hammernik et al. Learning a variational network for reconstruction of
    accelerated MRI data. Magnetic Resonance in Medicine, 79(6):3055–3071, 2018.
    """

    def __init__(
        self,
        num_cascades: int = 12,
        pools: int = 4,
        chans: int = 18,
        sens_pools: int = 4,
        sens_chans: int = 8,
        lr: float = 0.0003,
        lr_step_size: int = 40,
        lr_gamma: float = 0.1,
        weight_decay: float = 0.0,
        **kwargs,
    ):
        """
        Parameters
        ----------
        num_cascades : int
            Number of cascades (i.e., layers) for the variational network.
        pools : int
            Number of downsampling and upsampling layers for the cascade U-Net.
        chans : int
            Number of channels for the cascade U-Net.
        sens_pools : int
            Number of downsampling and upsampling layers for the sensitivity map U-Net.
        sens_chans : int
            Number of channels for the sensitivity map U-Net.
        lr : float
            Learning rate.
        lr_step_size : int
            Learning rate step size.
        lr_gamma : float
            Learning rate gamma decay.
        weight_decay : float
            Parameter for penalizing weights norm.
        num_sense_lines : int, optional
            Number of low-frequency lines to use for sensitivity map computation.
            Must be even or `None`. Default `None` will automatically compute the
            number from masks. Default behavior may cause some slices to use more
            low-frequency lines than others, when used in conjunction with e.g. the
            EquispacedMaskFunc defaults. To prevent this, either set
            `num_sense_lines`, or set `skip_low_freqs` and `skip_around_low_freqs`
            to `True` in the EquispacedMaskFunc. Note that setting this value may
            lead to undesired behavior when training on multiple accelerations
            simultaneously.
        """
        super().__init__(**kwargs)
        self.save_hyperparameters()

        self.num_cascades = num_cascades
        self.pools = pools
        self.chans = chans
        self.sens_pools = sens_pools
        self.sens_chans = sens_chans
        self.lr = lr
        self.lr_step_size = lr_step_size
        self.lr_gamma = lr_gamma
        self.weight_decay = weight_decay

        self.varnet = VarNet(
            num_cascades=self.num_cascades,
            sens_chans=self.sens_chans,
            sens_pools=self.sens_pools,
            chans=self.chans,
            pools=self.pools,
        )

        self.criterion = fastmri.SSIMLoss()
        self.num_params = sum(p.numel() for p in self.parameters())

    def forward(self, masked_kspace, mask, num_low_frequencies):
        return self.varnet(masked_kspace, mask, num_low_frequencies)

    def training_step(self, batch, batch_idx):
        output = self.forward(
            batch.masked_kspace, batch.mask, batch.num_low_frequencies
        )

        target, output = transforms.center_crop_to_smallest(batch.target, output)
        loss = self.criterion(
            output.unsqueeze(1), target.unsqueeze(1), data_range=batch.max_value
        )

        self.log("train_loss", loss, on_step=True, on_epoch=True)
        self.log("epoch", int(self.current_epoch), on_step=True, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        dataloaders = self.trainer.val_dataloaders
        slug = list(dataloaders.keys())[dataloader_idx]

        output = self.forward(
            batch.masked_kspace, batch.mask, batch.num_low_frequencies
        )

        target, output = transforms.center_crop_to_smallest(batch.target, output)

        loss = self.criterion(
            output.unsqueeze(1),
            target.unsqueeze(1),
            data_range=batch.max_value,
        )

        return {
            "slug": slug,
            "fname": batch.fname,
            "slice_num": batch.slice_num,
            "max_value": batch.max_value,
            "output": output,
            "target": target,
            "val_loss": loss,
        }

    def configure_optimizers(self):
        optim = torch.optim.Adam(
            self.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        scheduler = torch.optim.lr_scheduler.StepLR(
            optim, self.lr_step_size, self.lr_gamma
        )

        return [optim], [scheduler]

    @staticmethod
    def add_model_specific_args(parent_parser):  # pragma: no cover
        """
        Define parameters that only apply to this model
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser = MriModule.add_model_specific_args(parser)

        # network params
        parser.add_argument(
            "--num_cascades", default=12, type=int,
            help="Number of VarNet cascades",
        )
        parser.add_argument(
            "--pools", default=4, type=int,
            help="Number of U-Net pooling layers in VarNet blocks",
        )
        parser.add_argument(
            "--chans", default=18, type=int,
            help="Number of channels for U-Net in VarNet blocks",
        )
        parser.add_argument(
            "--sens_pools", default=4, type=int,
            help="Number of pooling layers for sense map estimation U-Net in VarNet",
        )
        parser.add_argument(
            "--sens_chans", default=8, type=int,
            help="Number of channels for sense map estimation U-Net in VarNet",
        )

        # training params (opt)
        parser.add_argument(
            "--lr", default=0.0003, type=float, help="Adam learning rate"
        )
        parser.add_argument(
            "--lr_step_size", default=40, type=int,
            help="Epoch at which to decrease step size",
        )
        parser.add_argument(
            "--lr_gamma", default=0.1, type=float,
            help="Extent to which step size should be decreased",
        )
        parser.add_argument(
            "--weight_decay", default=0.0, type=float,
            help="Strength of weight decay regularization",
        )

        return parser
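Since validation_step looks up the dataset slug via `self.trainer.val_dataloaders.keys()`, the validation loaders are expected as a dict keyed by slug. An illustrative training wiring (dataloader construction omitted; `L` is the lightning import used in mri_module.py):

    import lightning as L
    from models.lightning.varnet_module import VarNetModule

    model = VarNetModule(num_cascades=8, chans=18, sens_chans=8)
    trainer = L.Trainer(max_epochs=50, devices=1)
    # trainer.fit(model, train_dataloaders=..., val_dataloaders={"knee_4x": ...})
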
models/no_shared.py
ADDED
@@ -0,0 +1,467 @@
import math
from typing import Iterable, List, Literal, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

import fastmri
from fastmri.transforms import (
    batched_mask_center,
    batch_chans_to_chan_dim,
    chans_to_batch_dim,
    sens_reduce,
    sens_expand,
)
from models.udno import UDNO


class NormUDNO(nn.Module):
    """
    Normalized UDNO model.

    Inputs are normalized before the UDNO for numerically stable training.
    """

    def __init__(
        self,
        chans: int,
        num_pool_layers: int,
        radius_cutoff: float,
        in_shape: Tuple[int, int],
        kernel_shape: Tuple[int, int],
        in_chans: int = 2,
        out_chans: int = 2,
        drop_prob: float = 0.0,
    ):
        """
        Initialize the NormUDNO model.

        Parameters
        ----------
        chans : int
            Number of output channels of the first convolution layer.
        num_pool_layers : int
            Number of down-sampling and up-sampling layers.
        in_chans : int, optional
            Number of channels in the input to the U-Net model. Default is 2.
        out_chans : int, optional
            Number of channels in the output to the U-Net model. Default is 2.
        drop_prob : float, optional
            Dropout probability. Default is 0.0.
        """
        super().__init__()

        self.udno = UDNO(
            in_chans=in_chans,
            out_chans=out_chans,
            radius_cutoff=radius_cutoff,
            chans=chans,
            num_pool_layers=num_pool_layers,
            drop_prob=drop_prob,
            in_shape=in_shape,
            kernel_shape=kernel_shape,
        )

    def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w, two = x.shape
        assert two == 2
        return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w)

    def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor:
        b, c2, h, w = x.shape
        assert c2 % 2 == 0
        c = c2 // 2
        return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous()

    def norm(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # group norm
        b, c, h, w = x.shape
        x = x.view(b, 2, c // 2 * h * w)

        mean = x.mean(dim=2).view(b, 2, 1, 1)
        std = x.std(dim=2).view(b, 2, 1, 1)

        x = x.view(b, c, h, w)

        return (x - mean) / std, mean, std

    def norm_new(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # FIXME: not working, wip
        # group norm
        b, c, h, w = x.shape
        num_groups = 2
        assert (
            c % num_groups == 0
        ), f"Number of channels ({c}) must be divisible by number of groups ({num_groups})."

        x = x.view(b, num_groups, c // num_groups * h * w)

        mean = x.mean(dim=2).view(b, num_groups, 1, 1)
        std = x.std(dim=2).view(b, num_groups, 1, 1)
        print(x.shape, mean.shape, std.shape)  # debug output

        x = x.view(b, c, h, w)
        mean = (
            mean.view(b, num_groups, 1, 1)
            .repeat(1, c // num_groups, h, w)
            .view(b, c, h, w)
        )
        std = (
            std.view(b, num_groups, 1, 1)
            .repeat(1, c // num_groups, h, w)
            .view(b, c, h, w)
        )

        return (x - mean) / std, mean, std

    def unnorm(
        self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
    ) -> torch.Tensor:
        return x * std + mean

    def pad(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]:
        _, _, h, w = x.shape
        w_mult = ((w - 1) | 15) + 1
        h_mult = ((h - 1) | 15) + 1
        w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)]
        h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)]
        # TODO: fix this type when PyTorch fixes theirs
        # the documentation lies - this actually takes a list
        # https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L3457
        # https://github.com/pytorch/pytorch/pull/16949
        x = F.pad(x, w_pad + h_pad)

        return x, (h_pad, w_pad, h_mult, w_mult)

    def unpad(
        self,
        x: torch.Tensor,
        h_pad: List[int],
        w_pad: List[int],
        h_mult: int,
        w_mult: int,
    ) -> torch.Tensor:
        return x[
            ..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]
        ]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not x.shape[-1] == 2:
            raise ValueError("Last dimension must be 2 for complex.")

        chans = x.shape[1]
        if chans == 2:
            # FIXME: hard coded skip norm/pad temporarily to avoid group norm bug
            x = self.complex_to_chan_dim(x)
            x = self.udno(x)
            return self.chan_complex_to_last_dim(x)

        # get shapes for unet and normalize
        x = self.complex_to_chan_dim(x)
        x, mean, std = self.norm(x)
        x, pad_sizes = self.pad(x)

        x = self.udno(x)

        # get shapes back and unnormalize
        x = self.unpad(x, *pad_sizes)
        x = self.unnorm(x, mean, std)
        x = self.chan_complex_to_last_dim(x)

        return x

class SensitivityModel(nn.Module):
    """
    Learn sensitivity maps
    """

    def __init__(
        self,
        chans: int,
        num_pools: int,
        radius_cutoff: float,
        in_shape: Tuple[int, int],
        kernel_shape: Tuple[int, int],
        in_chans: int = 2,
        out_chans: int = 2,
        drop_prob: float = 0.0,
        mask_center: bool = True,
    ):
        """
        Parameters
        ----------
        chans : int
            Number of output channels of the first convolution layer.
        num_pools : int
            Number of down-sampling and up-sampling layers.
        in_chans : int, optional
            Number of channels in the input to the U-Net model. Default is 2.
        out_chans : int, optional
            Number of channels in the output to the U-Net model. Default is 2.
        drop_prob : float, optional
            Dropout probability. Default is 0.0.
        mask_center : bool, optional
            Whether to mask center of k-space for sensitivity map calculation.
            Default is True.
        """
        super().__init__()
        self.mask_center = mask_center
        self.norm_udno = NormUDNO(
            chans,
            num_pools,
            radius_cutoff,
            in_shape,
            kernel_shape,
            in_chans=in_chans,
            out_chans=out_chans,
            drop_prob=drop_prob,
        )

    def divide_root_sum_of_squares(self, x: torch.Tensor) -> torch.Tensor:
        return x / fastmri.rss_complex(x, dim=1).unsqueeze(-1).unsqueeze(1)

    def get_pad_and_num_low_freqs(
        self, mask: torch.Tensor, num_low_frequencies: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # isinstance guard so a plain int takes the else branch instead of
        # raising a TypeError when iterated
        if num_low_frequencies is None or (
            isinstance(num_low_frequencies, Iterable)
            and any(torch.any(t == 0) for t in num_low_frequencies)
        ):
            # get low frequency line locations and mask them out
            squeezed_mask = mask[:, 0, 0, :, 0].to(torch.int8)
            cent = squeezed_mask.shape[1] // 2
            # argmin over the 0/1 mask returns the index of the first zero,
            # i.e. the length of the sampled run on each side of the center
            left = torch.argmin(squeezed_mask[:, :cent].flip(1), dim=1)
            right = torch.argmin(squeezed_mask[:, cent:], dim=1)
            num_low_frequencies_tensor = torch.max(
                2 * torch.min(left, right), torch.ones_like(left)
            )  # force a symmetric center unless 1
        else:
            num_low_frequencies_tensor = num_low_frequencies * torch.ones(
                mask.shape[0], dtype=mask.dtype, device=mask.device
            )

        pad = (mask.shape[-2] - num_low_frequencies_tensor + 1) // 2

        return pad.type(torch.long), num_low_frequencies_tensor.type(torch.long)

    def forward(
        self,
        masked_kspace: torch.Tensor,
        mask: torch.Tensor,
        num_low_frequencies: Optional[int] = None,
    ) -> torch.Tensor:
        if self.mask_center:
            pad, num_low_freqs = self.get_pad_and_num_low_freqs(
                mask, num_low_frequencies
            )
            masked_kspace = batched_mask_center(
                masked_kspace, pad, pad + num_low_freqs
            )

        # convert to image space
        images, batches = chans_to_batch_dim(fastmri.ifft2c(masked_kspace))

        # estimate sensitivities
        return self.divide_root_sum_of_squares(
            batch_chans_to_chan_dim(self.norm_udno(images), batches)
        )

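# Editor's note (illustrative, not from the commit): when num_low_frequencies
# is not supplied, get_pad_and_num_low_freqs infers the ACS width from the
# mask itself: argmin over the 0/1 center columns finds the first unsampled
# line on each side of center, and twice the smaller run becomes a symmetric
# center width. A toy 1-D version of that logic:
_line = torch.tensor([0, 0, 1, 1, 1, 1, 0, 0], dtype=torch.int8)
_cent = _line.shape[0] // 2
_left = torch.argmin(_line[:_cent].flip(0))  # sampled run left of center
_right = torch.argmin(_line[_cent:])         # sampled run right of center
assert max(2 * min(int(_left), int(_right)), 1) == 4
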
class VarNetBlock(nn.Module):
    """
    Model block for iterative refinement of k-space data.

    This model applies a combination of soft data consistency with the input
    model as a regularizer. A series of these blocks can be stacked to form
    the full variational network.

    aka Refinement Module in Fig 1
    """

    def __init__(self, kno: nn.Module, ino: nn.Module):
        """
        Args:
            kno: Module for the k-space (measurement space) "regularization"
                component of the variational network.
            ino: Module for the image-space "regularization" component of the
                variational network.
        """
        super().__init__()
        self.kno = kno
        self.ino = ino
        self.dc_weight = nn.Parameter(torch.ones(1))

    def forward(
        self,
        current_kspace: torch.Tensor,
        ref_kspace: torch.Tensor,
        mask: torch.Tensor,
        sens_maps: torch.Tensor,
        use_dc_term: bool = True,
    ) -> torch.Tensor:
        """
        Args:
            current_kspace: The current k-space data (frequency domain data)
                being processed by the network. (torch.Tensor)
            ref_kspace: Original subsampled k-space data from which we are
                reconstructing the image (reference k-space). (torch.Tensor)
            mask: A binary mask indicating the locations in k-space where
                data consistency should be enforced. (torch.Tensor)
            sens_maps: Sensitivity maps for the different coils in parallel
                imaging. (torch.Tensor)
        """

        # model-term: see orange box of Fig 1 in E2E-VarNet paper!
        # multi channel k-space -> single channel image-space
        b, c, h, w, _ = current_kspace.shape

        # ======= kNO in measurement (k) space ========
        current_kspace, b = chans_to_batch_dim(current_kspace)  # reduce
        current_kspace = self.kno(current_kspace)  # inpaint
        current_kspace = batch_chans_to_chan_dim(current_kspace, b)  # expand

        # ======= iNO in image (i) space ========
        reduced_image = sens_reduce(current_kspace, sens_maps)
        # single channel image-space
        refined_image = self.ino(reduced_image)
        # single channel image-space -> multi channel k-space
        model_term = sens_expand(refined_image, sens_maps)

        # only use first 15 channels (masked_kspace) in the update
        # current_kspace = current_kspace[:, :15, :, :, :]

        if not use_dc_term:
            return current_kspace - model_term

        """
        Soft data consistency term:
        - Calculates the difference between current k-space and reference k-space where the mask is true.
        - Multiplies this difference by the data consistency weight.
        """
        # dc_term: see green box of Fig 1 in E2E-VarNet paper!
        zero = torch.zeros(1, 1, 1, 1, 1).to(current_kspace)
        soft_dc = (
            torch.where(mask, current_kspace - ref_kspace, zero)
            * self.dc_weight
        )
        return current_kspace - soft_dc - model_term

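# Editor's note (illustrative, not from the commit): with the DC term enabled,
# each cascade implements the update
#     k_{t+1} = k_t - dc_weight * mask * (k_t - k_ref) - G(k_t)
# where G is the learned refinement (kNO in k-space, then sens_reduce -> iNO
# -> sens_expand). On sampled locations the first correction pulls the
# prediction back toward the measured data; elsewhere torch.where leaves it
# untouched, so only G contributes there:
_k = torch.randn(1, 2, 4, 4, 2)
_ref = torch.randn(1, 2, 4, 4, 2)
_m = torch.zeros(1, 1, 4, 4, 1, dtype=torch.bool)
_soft_dc = torch.where(_m, _k - _ref, torch.zeros(1, 1, 1, 1, 1))
assert torch.all(_soft_dc == 0)  # empty mask: the DC correction vanishes
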
class NOShared(nn.Module):
    """
    Neural Operator model with shared cascade parameters for MRI reconstruction.

    Uses a variational architecture (iterative updates) with a learned sensitivity
    model. All operations are resolution invariant employing neural operator
    modules.
    """

    def __init__(
        self,
        num_cascades: int = 12,
        sens_chans: int = 8,
        sens_pools: int = 4,
        chans: int = 18,
        pools: int = 4,
        gno_chans: int = 16,
        gno_pools: int = 4,
        gno_radius_cutoff: float = 0.02,
        gno_kernel_shape: Tuple[int, int] = (6, 7),
        radius_cutoff: float = 0.01,
        kernel_shape: Tuple[int, int] = (3, 4),
        in_shape: Tuple[int, int] = (320, 320),
        mask_center: bool = True,
        use_dc_term: bool = True,
    ):
        """
        Parameters
        ----------
        num_cascades : int
            Number of cascades (i.e., layers) for variational network.
        sens_chans : int
            Number of channels for sensitivity map U-Net.
        sens_pools : int
            Number of downsampling and upsampling layers for sensitivity map U-Net.
        chans : int
            Number of channels for cascade U-Net.
        pools : int
            Number of downsampling and upsampling layers for cascade U-Net.
        mask_center : bool
            Whether to mask center of k-space for sensitivity map calculation.
        use_dc_term : bool
            Whether to use the data consistency term.
        """

        super().__init__()

        self.num_cascades = num_cascades

        self.sens_net = SensitivityModel(
            sens_chans,
            sens_pools,
            radius_cutoff,
            in_shape,
            kernel_shape,
            mask_center=False,
        )
        self.kno = NormUDNO(
            gno_chans,
            gno_pools,
            in_shape=in_shape,
            radius_cutoff=gno_radius_cutoff,
            kernel_shape=gno_kernel_shape,
            in_chans=2,
            out_chans=2,
        )
        self.ino = NormUDNO(
            chans,
            pools,
            radius_cutoff,
            in_shape,
            kernel_shape,
            in_chans=2,
            out_chans=2,
        )
        self.cascade = VarNetBlock(self.kno, self.ino)
        self.use_dc_term = use_dc_term

    def forward(
        self,
        masked_kspace: torch.Tensor,
        mask: torch.Tensor,
        num_low_frequencies: Optional[int] = None,
    ) -> torch.Tensor:

        # (B, C, X, Y, 2)
        kspace_pred = masked_kspace
        # iterative update
        for _ in range(self.num_cascades):
            # sens model
            sens_maps = self.sens_net(kspace_pred, mask, num_low_frequencies)

            # kno + ino (cascade)
            kspace_pred = self.cascade(
                kspace_pred, masked_kspace, mask, sens_maps, self.use_dc_term
            )

        spatial_pred = fastmri.ifft2c(kspace_pred)
        spatial_pred_abs = fastmri.complex_abs(spatial_pred)
        combined_spatial = fastmri.rss(spatial_pred_abs, dim=1)

        return combined_spatial


if __name__ == "__main__":
    model = NOShared(
        num_cascades=4,
        radius_cutoff=0.02,
        kernel_shape=(6, 7),
    )

    x = torch.rand((2, 15, 320, 320, 2))
    o = model(x, x.bool(), None)
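Note that `NOShared` builds a single `VarNetBlock` and reuses it at every iteration, so the kNO/iNO weights are shared across all cascades and the parameter count is independent of `num_cascades` (unlike the per-cascade blocks in `models/no_varnet.py`). An illustrative check, assuming the constructor defaults above:

n4 = sum(p.numel() for p in NOShared(num_cascades=4).parameters())
n12 = sum(p.numel() for p in NOShared(num_cascades=12).parameters())
assert n4 == n12  # shared weights: extra depth adds no parameters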
models/no_varnet.py
ADDED
|
@@ -0,0 +1,598 @@
import math
from typing import Iterable, List, Literal, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

import fastmri
from fastmri import transforms
from models.udno import UDNO


def sens_expand(x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    """
    Calculates F (x sens_maps)

    Parameters
    ----------
    x : torch.Tensor
        Single-channel image of shape (..., H, W, 2)
    sens_maps : torch.Tensor
        Sensitivity maps (image space)

    Returns
    -------
    torch.Tensor
        Result of the operation F (x sens_maps)
    """
    return fastmri.fft2c(fastmri.complex_mul(x, sens_maps))


def sens_reduce(k: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    """
    Calculates F^{-1}(k) * conj(sens_maps)
    where conj(sens_maps) is the element-wise applied complex conjugate

    Parameters
    ----------
    k : torch.Tensor
        Multi-channel k-space of shape (B, C, H, W, 2)
    sens_maps : torch.Tensor
        Sensitivity maps (image space)

    Returns
    -------
    torch.Tensor
        Result of the operation F^{-1}(k) * conj(sens_maps)
    """
    return fastmri.complex_mul(fastmri.ifft2c(k), fastmri.complex_conj(sens_maps)).sum(
        dim=1, keepdim=True
    )


def chans_to_batch_dim(x: torch.Tensor) -> Tuple[torch.Tensor, int]:
    """Reshapes batched multi-channel samples into multiple single channel samples.

    Parameters
    ----------
    x : torch.Tensor
        x has shape (b, c, h, w, 2)

    Returns
    -------
    Tuple[torch.Tensor, int]
        tensor of shape (b * c, 1, h, w, 2), b
    """
    b, c, h, w, comp = x.shape
    return x.view(b * c, 1, h, w, comp), b


def batch_chans_to_chan_dim(x: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Reshapes batched independent samples into original multi-channel samples.

    Parameters
    ----------
    x : torch.Tensor
        tensor of shape (b * c, 1, h, w, 2)
    batch_size : int
        batch size

    Returns
    -------
    torch.Tensor
        original multi-channel tensor of shape (b, c, h, w, 2)
    """
    bc, _, h, w, comp = x.shape
    c = bc // batch_size
    return x.view(batch_size, c, h, w, comp)

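# Editor's note (illustrative, not from the commit): the two reshape helpers
# above are exact inverses, which is what lets multi-coil data be pushed
# through a single-channel operator coil-by-coil and reassembled afterwards:
_x = torch.randn(2, 15, 8, 8, 2)
_flat, _b = chans_to_batch_dim(_x)  # (30, 1, 8, 8, 2)
assert torch.equal(batch_chans_to_chan_dim(_flat, _b), _x)
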
+
class NormUDNO(nn.Module):
|
| 92 |
+
"""
|
| 93 |
+
Normalized UDNO model.
|
| 94 |
+
|
| 95 |
+
Inputs are normalized before the UDNO for numerically stable training.
|
| 96 |
+
"""
|
| 97 |
+
|
| 98 |
+
def __init__(
|
| 99 |
+
self,
|
| 100 |
+
chans: int,
|
| 101 |
+
num_pool_layers: int,
|
| 102 |
+
radius_cutoff: float,
|
| 103 |
+
in_shape: Tuple[int, int],
|
| 104 |
+
kernel_shape: Tuple[int, int],
|
| 105 |
+
in_chans: int = 2,
|
| 106 |
+
out_chans: int = 2,
|
| 107 |
+
drop_prob: float = 0.0,
|
| 108 |
+
):
|
| 109 |
+
"""
|
| 110 |
+
Initialize the VarNet model.
|
| 111 |
+
|
| 112 |
+
Parameters
|
| 113 |
+
----------
|
| 114 |
+
chans : int
|
| 115 |
+
Number of output channels of the first convolution layer.
|
| 116 |
+
num_pools : int
|
| 117 |
+
Number of down-sampling and up-sampling layers.
|
| 118 |
+
in_chans : int, optional
|
| 119 |
+
Number of channels in the input to the U-Net model. Default is 2.
|
| 120 |
+
out_chans : int, optional
|
| 121 |
+
Number of channels in the output to the U-Net model. Default is 2.
|
| 122 |
+
drop_prob : float, optional
|
| 123 |
+
Dropout probability. Default is 0.0.
|
| 124 |
+
"""
|
| 125 |
+
super().__init__()
|
| 126 |
+
|
| 127 |
+
self.udno = UDNO(
|
| 128 |
+
in_chans=in_chans,
|
| 129 |
+
out_chans=out_chans,
|
| 130 |
+
radius_cutoff=radius_cutoff,
|
| 131 |
+
chans=chans,
|
| 132 |
+
num_pool_layers=num_pool_layers,
|
| 133 |
+
drop_prob=drop_prob,
|
| 134 |
+
in_shape=in_shape,
|
| 135 |
+
kernel_shape=kernel_shape,
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor:
|
| 139 |
+
b, c, h, w, two = x.shape
|
| 140 |
+
assert two == 2
|
| 141 |
+
return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w)
|
| 142 |
+
|
| 143 |
+
def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor:
|
| 144 |
+
b, c2, h, w = x.shape
|
| 145 |
+
assert c2 % 2 == 0
|
| 146 |
+
c = c2 // 2
|
| 147 |
+
return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous()
|
| 148 |
+
|
| 149 |
+
def norm(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 150 |
+
# group norm
|
| 151 |
+
b, c, h, w = x.shape
|
| 152 |
+
x = x.view(b, 2, c // 2 * h * w)
|
| 153 |
+
|
| 154 |
+
mean = x.mean(dim=2).view(b, 2, 1, 1)
|
| 155 |
+
std = x.std(dim=2).view(b, 2, 1, 1)
|
| 156 |
+
|
| 157 |
+
x = x.view(b, c, h, w)
|
| 158 |
+
|
| 159 |
+
return (x - mean) / std, mean, std
|
| 160 |
+
|
| 161 |
+
def norm_new(
|
| 162 |
+
self, x: torch.Tensor
|
| 163 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 164 |
+
# FIXME: not working, wip
|
| 165 |
+
# group norm
|
| 166 |
+
b, c, h, w = x.shape
|
| 167 |
+
num_groups = 2
|
| 168 |
+
assert (
|
| 169 |
+
c % num_groups == 0
|
| 170 |
+
), f"Number of channels ({c}) must be divisible by number of groups ({num_groups})."
|
| 171 |
+
|
| 172 |
+
x = x.view(b, num_groups, c // num_groups * h * w)
|
| 173 |
+
|
| 174 |
+
mean = x.mean(dim=2).view(b, num_groups, 1, 1)
|
| 175 |
+
std = x.std(dim=2).view(b, num_groups, 1, 1)
|
| 176 |
+
print(x.shape, mean.shape, std.shape)
|
| 177 |
+
|
| 178 |
+
x = x.view(b, c, h, w)
|
| 179 |
+
mean = (
|
| 180 |
+
mean.view(b, num_groups, 1, 1)
|
| 181 |
+
.repeat(1, c // num_groups, h, w)
|
| 182 |
+
.view(b, c, h, w)
|
| 183 |
+
)
|
| 184 |
+
std = (
|
| 185 |
+
std.view(b, num_groups, 1, 1)
|
| 186 |
+
.repeat(1, c // num_groups, h, w)
|
| 187 |
+
.view(b, c, h, w)
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
return (x - mean) / std, mean, std
|
| 191 |
+
|
| 192 |
+
def unnorm(
|
| 193 |
+
self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
|
| 194 |
+
) -> torch.Tensor:
|
| 195 |
+
return x * std + mean
|
| 196 |
+
|
| 197 |
+
def pad(
|
| 198 |
+
self, x: torch.Tensor
|
| 199 |
+
) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]:
|
| 200 |
+
_, _, h, w = x.shape
|
| 201 |
+
w_mult = ((w - 1) | 15) + 1
|
| 202 |
+
h_mult = ((h - 1) | 15) + 1
|
| 203 |
+
w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)]
|
| 204 |
+
h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)]
|
| 205 |
+
# TODO: fix this type when PyTorch fixes theirs
|
| 206 |
+
# the documentation lies - this actually takes a list
|
| 207 |
+
# https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L3457
|
| 208 |
+
# https://github.com/pytorch/pytorch/pull/16949
|
| 209 |
+
x = F.pad(x, w_pad + h_pad)
|
| 210 |
+
|
| 211 |
+
return x, (h_pad, w_pad, h_mult, w_mult)
|
| 212 |
+
|
| 213 |
+
def unpad(
|
| 214 |
+
self,
|
| 215 |
+
x: torch.Tensor,
|
| 216 |
+
h_pad: List[int],
|
| 217 |
+
w_pad: List[int],
|
| 218 |
+
h_mult: int,
|
| 219 |
+
w_mult: int,
|
| 220 |
+
) -> torch.Tensor:
|
| 221 |
+
return x[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]]
|
| 222 |
+
|
| 223 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 224 |
+
if not x.shape[-1] == 2:
|
| 225 |
+
raise ValueError("Last dimension must be 2 for complex.")
|
| 226 |
+
|
| 227 |
+
chans = x.shape[1]
|
| 228 |
+
if chans == 2:
|
| 229 |
+
# FIXME: hard coded skip norm/pad temporarily to avoid group norm bug
|
| 230 |
+
x = self.complex_to_chan_dim(x)
|
| 231 |
+
x = self.udno(x)
|
| 232 |
+
return self.chan_complex_to_last_dim(x)
|
| 233 |
+
|
| 234 |
+
# get shapes for unet and normalize
|
| 235 |
+
x = self.complex_to_chan_dim(x)
|
| 236 |
+
x, mean, std = self.norm(x)
|
| 237 |
+
x, pad_sizes = self.pad(x)
|
| 238 |
+
|
| 239 |
+
x = self.udno(x)
|
| 240 |
+
|
| 241 |
+
# get shapes back and unnormalize
|
| 242 |
+
x = self.unpad(x, *pad_sizes)
|
| 243 |
+
x = self.unnorm(x, mean, std)
|
| 244 |
+
x = self.chan_complex_to_last_dim(x)
|
| 245 |
+
|
| 246 |
+
return x
|
| 247 |
+
|
| 248 |
+
|
class SensitivityModel(nn.Module):
    """
    Learn sensitivity maps
    """

    def __init__(
        self,
        chans: int,
        num_pools: int,
        radius_cutoff: float,
        in_shape: Tuple[int, int],
        kernel_shape: Tuple[int, int],
        in_chans: int = 2,
        out_chans: int = 2,
        drop_prob: float = 0.0,
        mask_center: bool = True,
    ):
        """
        Parameters
        ----------
        chans : int
            Number of output channels of the first convolution layer.
        num_pools : int
            Number of down-sampling and up-sampling layers.
        in_chans : int, optional
            Number of channels in the input to the U-Net model. Default is 2.
        out_chans : int, optional
            Number of channels in the output to the U-Net model. Default is 2.
        drop_prob : float, optional
            Dropout probability. Default is 0.0.
        mask_center : bool, optional
            Whether to mask center of k-space for sensitivity map calculation.
            Default is True.
        """
        super().__init__()
        self.mask_center = mask_center
        self.norm_udno = NormUDNO(
            chans,
            num_pools,
            radius_cutoff,
            in_shape,
            kernel_shape,
            in_chans=in_chans,
            out_chans=out_chans,
            drop_prob=drop_prob,
        )

    def divide_root_sum_of_squares(self, x: torch.Tensor) -> torch.Tensor:
        return x / fastmri.rss_complex(x, dim=1).unsqueeze(-1).unsqueeze(1)

    def get_pad_and_num_low_freqs(
        self, mask: torch.Tensor, num_low_frequencies: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # isinstance guard so a plain int takes the else branch instead of
        # raising a TypeError when iterated
        if num_low_frequencies is None or (
            isinstance(num_low_frequencies, Iterable)
            and any(torch.any(t == 0) for t in num_low_frequencies)
        ):
            # get low frequency line locations and mask them out
            squeezed_mask = mask[:, 0, 0, :, 0].to(torch.int8)
            cent = squeezed_mask.shape[1] // 2
            # argmin over the 0/1 mask returns the index of the first zero,
            # i.e. the length of the sampled run on each side of the center
            left = torch.argmin(squeezed_mask[:, :cent].flip(1), dim=1)
            right = torch.argmin(squeezed_mask[:, cent:], dim=1)
            num_low_frequencies_tensor = torch.max(
                2 * torch.min(left, right), torch.ones_like(left)
            )  # force a symmetric center unless 1
        else:
            num_low_frequencies_tensor = num_low_frequencies * torch.ones(
                mask.shape[0], dtype=mask.dtype, device=mask.device
            )

        pad = (mask.shape[-2] - num_low_frequencies_tensor + 1) // 2

        return pad.type(torch.long), num_low_frequencies_tensor.type(torch.long)

    def forward(
        self,
        masked_kspace: torch.Tensor,
        mask: torch.Tensor,
        num_low_frequencies: Optional[int] = None,
    ) -> torch.Tensor:
        if self.mask_center:
            pad, num_low_freqs = self.get_pad_and_num_low_freqs(
                mask, num_low_frequencies
            )
            masked_kspace = transforms.batched_mask_center(
                masked_kspace, pad, pad + num_low_freqs
            )

        # convert to image space
        images, batches = chans_to_batch_dim(fastmri.ifft2c(masked_kspace))

        # estimate sensitivities
        return self.divide_root_sum_of_squares(
            batch_chans_to_chan_dim(self.norm_udno(images), batches)
        )

class VarNetBlock(nn.Module):
    """
    Model block for iterative refinement of k-space data.

    This model applies a combination of soft data consistency with the input
    model as a regularizer. A series of these blocks can be stacked to form
    the full variational network.

    aka Refinement Module in Fig 1
    """

    def __init__(self, model: nn.Module):
        """
        Args:
            model: Module for "regularization" component of variational
                network.
        """
        super().__init__()

        self.model = model
        self.dc_weight = nn.Parameter(torch.ones(1))

    def forward(
        self,
        current_kspace: torch.Tensor,
        ref_kspace: torch.Tensor,
        mask: torch.Tensor,
        sens_maps: torch.Tensor,
        use_dc_term: bool = True,
    ) -> torch.Tensor:
        """
        Args:
            current_kspace: The current k-space data (frequency domain data)
                being processed by the network. (torch.Tensor)
            ref_kspace: Original subsampled k-space data from which we are
                reconstructing the image (reference k-space). (torch.Tensor)
            mask: A binary mask indicating the locations in k-space where
                data consistency should be enforced. (torch.Tensor)
            sens_maps: Sensitivity maps for the different coils in parallel
                imaging. (torch.Tensor)
        """

        # model-term: see orange box of Fig 1 in E2E-VarNet paper!
        # multi channel k-space -> single channel image-space
        b, c, h, w, _ = current_kspace.shape

        if c == 30:
            # get kspace and inpainted kspace
            kspace = current_kspace[:, :15, :, :, :]
            in_kspace = current_kspace[:, 15:, :, :, :]
            # convert to image space
            image = sens_reduce(kspace, sens_maps)
            in_image = sens_reduce(in_kspace, sens_maps)
            # concatenate both onto each other
            reduced_image = torch.cat([image, in_image], dim=1)
        else:
            reduced_image = sens_reduce(current_kspace, sens_maps)

        # single channel image-space
        refined_image = self.model(reduced_image)

        # single channel image-space -> multi channel k-space
        model_term = sens_expand(refined_image, sens_maps)

        # only use first 15 channels (masked_kspace) in the update
        # current_kspace = current_kspace[:, :15, :, :, :]

        if not use_dc_term:
            return current_kspace - model_term

        """
        Soft data consistency term:
        - Calculates the difference between current k-space and reference k-space where the mask is true.
        - Multiplies this difference by the data consistency weight.
        """
        # dc_term: see green box of Fig 1 in E2E-VarNet paper!
        zero = torch.zeros(1, 1, 1, 1, 1).to(current_kspace)
        soft_dc = torch.where(mask, current_kspace - ref_kspace, zero) * self.dc_weight
        return current_kspace - soft_dc - model_term

class NOVarnet(nn.Module):
    """
    Neural Operator model for MRI reconstruction.

    Uses a variational architecture (iterative updates) with a learned sensitivity
    model. All operations are resolution invariant employing neural operator
    modules (GNO, UDNO).
    """

    def __init__(
        self,
        num_cascades: int = 12,
        sens_chans: int = 8,
        sens_pools: int = 4,
        chans: int = 18,
        pools: int = 4,
        gno_chans: int = 16,
        gno_pools: int = 4,
        gno_radius_cutoff: float = 0.02,
        gno_kernel_shape: Tuple[int, int] = (6, 7),
        radius_cutoff: float = 0.01,
        kernel_shape: Tuple[int, int] = (3, 4),
        in_shape: Tuple[int, int] = (640, 320),
        mask_center: bool = True,
        use_dc_term: bool = True,
        reduction_method: Literal["batch", "rss"] = "rss",
        skip_method: Literal["replace", "add", "add_inv", "concat"] = "add",
    ):
        """
        Parameters
        ----------
        num_cascades : int
            Number of cascades (i.e., layers) for variational network.
        sens_chans : int
            Number of channels for sensitivity map U-Net.
        sens_pools : int
            Number of downsampling and upsampling layers for sensitivity map U-Net.
        chans : int
            Number of channels for cascade U-Net.
        pools : int
            Number of downsampling and upsampling layers for cascade U-Net.
        mask_center : bool
            Whether to mask center of k-space for sensitivity map calculation.
        use_dc_term : bool
            Whether to use the data consistency term.
        reduction_method : "batch" or "rss"
            Method for reducing sensitivity maps to single channel.
            "batch" reduces to single channel by stacking channels.
            "rss" reduces to single channel by root sum of squares.
        skip_method : "replace" or "add" or "add_inv" or "concat"
            "replace" replaces the input with the output of the GNO
            "add" adds the output of the GNO to the input
            "add_inv" adds the output of the GNO to the input (only where samples are missing)
            "concat" concatenates the output of the GNO to the input
        """

        super().__init__()

        self.sens_net = SensitivityModel(
            sens_chans,
            sens_pools,
            radius_cutoff,
            in_shape,
            kernel_shape,
            mask_center=mask_center,
        )
        self.gno = NormUDNO(
            gno_chans,
            gno_pools,
            in_shape=in_shape,
            radius_cutoff=radius_cutoff,
            kernel_shape=kernel_shape,
            # radius_cutoff=gno_radius_cutoff,
            # kernel_shape=gno_kernel_shape,
            in_chans=2,
            out_chans=2,
        )
        self.cascades = nn.ModuleList(
            [
                VarNetBlock(
                    NormUDNO(
                        chans,
                        pools,
                        radius_cutoff,
                        in_shape,
                        kernel_shape,
                        in_chans=(
                            4 if skip_method == "concat" and cascade_idx == 0 else 2
                        ),
                        out_chans=2,
                    )
                )
                for cascade_idx in range(num_cascades)
            ]
        )
        self.use_dc_term = use_dc_term
        self.reduction_method = reduction_method
        self.skip_method = skip_method

    def forward(
        self,
        masked_kspace: torch.Tensor,
        mask: torch.Tensor,
        num_low_frequencies: Optional[int] = None,
    ) -> torch.Tensor:

        # (B, C, X, Y, 2)
        sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies)

        # reduce before inpainting
        if self.reduction_method == "rss":
            # (B, 1, H, W, 2) single channel image space
            x_reduced = sens_reduce(masked_kspace, sens_maps)
            # (B, 1, H, W, 2)
            k_reduced = fastmri.fft2c(x_reduced)
        elif self.reduction_method == "batch":
            k_reduced, b = chans_to_batch_dim(masked_kspace)

        # inpainting
        if self.skip_method == "replace":
            kspace_pred = self.gno(k_reduced)
        elif self.skip_method == "add_inv":
            # FIXME: this is not correct (mask has shape B, 1, H, W, 2 and self.gno(k_reduced) has shape B*C, 1, H, W, 2)
            kspace_pred = k_reduced.clone() + (~mask * self.gno(k_reduced))
        elif self.skip_method == "add":
            kspace_pred = k_reduced.clone() + self.gno(k_reduced)
        elif self.skip_method == "concat":
            kspace_pred = torch.cat([k_reduced.clone(), self.gno(k_reduced)], dim=1)
        else:
            raise NotImplementedError("skip_method not implemented")

        # expand after inpainting
        if self.reduction_method == "rss":
            if self.skip_method == "concat":
                # kspace_pred is (B, 2, H, W, 2)
                kspace = kspace_pred[:, :1, :, :, :]
                in_kspace = kspace_pred[:, 1:, :, :, :]
                # B, 2C, H, W, 2
                kspace_pred = torch.cat(
                    [sens_expand(kspace, sens_maps), sens_expand(in_kspace, sens_maps)],
                    dim=1,
                )
            else:
                # (B, 1, H, W, 2) -> (B, C, H, W, 2) multi-channel k space
                kspace_pred = sens_expand(kspace_pred, sens_maps)
        elif self.reduction_method == "batch":
            # (B, C, H, W, 2) multi-channel k space
            if self.skip_method == "concat":
                kspace = kspace_pred[:, :1, :, :, :]
                in_kspace = kspace_pred[:, 1:, :, :, :]
                # B, 2C, H, W, 2
                kspace_pred = torch.cat(
                    [
                        batch_chans_to_chan_dim(kspace, b),
                        batch_chans_to_chan_dim(in_kspace, b),
                    ],
                    dim=1,
                )
            else:
                kspace_pred = batch_chans_to_chan_dim(kspace_pred, b)

        # iterative update
        for cascade in self.cascades:
            kspace_pred = cascade(
                kspace_pred, masked_kspace, mask, sens_maps, self.use_dc_term
            )

        spatial_pred = fastmri.ifft2c(kspace_pred)
        spatial_pred_abs = fastmri.complex_abs(spatial_pred)
        combined_spatial = fastmri.rss(spatial_pred_abs, dim=1)

        return combined_spatial
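A minimal, illustrative way to exercise `NOVarnet` along the two ablation axes it exposes (`reduction_method`, `skip_method`). Shapes follow the `(B, C, H, W, 2)` convention used throughout, the mask broadcasts over coils, and actually running this requires the `UDNO` module added elsewhere in this commit:

import torch

model = NOVarnet(
    num_cascades=2,
    in_shape=(640, 320),
    reduction_method="rss",  # coil-combine to one channel before the inpainting GNO
    skip_method="add",       # add the GNO output back onto the input k-space
)
kspace = torch.randn(1, 15, 640, 320, 2)
mask = torch.zeros(1, 1, 640, 320, 1, dtype=torch.bool)
mask[:, :, :, ::4, :] = True      # keep every fourth phase-encode column
mask[:, :, :, 148:172, :] = True  # plus a 24-wide fully sampled center (ACS)
recon = model(kspace, mask)       # -> (1, 640, 320) magnitude image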
models/no_varnet_nokno.py
ADDED
|
@@ -0,0 +1,581 @@
"""
NO Varnet WITHOUT KNO for ablation
"""

import math
from typing import Iterable, List, Literal, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

import fastmri
from fastmri import transforms
from fastmri.datasets import SliceDatasetLMDB, SliceSample
from models.udno import UDNO


def sens_expand(x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    """
    Calculates F (x sens_maps)

    Parameters
    ----------
    x : torch.Tensor
        Single-channel image of shape (..., H, W, 2)
    sens_maps : torch.Tensor
        Sensitivity maps (image space)

    Returns
    -------
    torch.Tensor
        Result of the operation F (x sens_maps)
    """
    return fastmri.fft2c(fastmri.complex_mul(x, sens_maps))


def sens_reduce(k: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    """
    Calculates F^{-1}(k) * conj(sens_maps)
    where conj(sens_maps) is the element-wise applied complex conjugate

    Parameters
    ----------
    k : torch.Tensor
        Multi-channel k-space of shape (B, C, H, W, 2)
    sens_maps : torch.Tensor
        Sensitivity maps (image space)

    Returns
    -------
    torch.Tensor
        Result of the operation F^{-1}(k) * conj(sens_maps)
    """
    return fastmri.complex_mul(
        fastmri.ifft2c(k), fastmri.complex_conj(sens_maps)
    ).sum(dim=1, keepdim=True)


def chans_to_batch_dim(x: torch.Tensor) -> Tuple[torch.Tensor, int]:
    """Reshapes batched multi-channel samples into multiple single channel samples.

    Parameters
    ----------
    x : torch.Tensor
        x has shape (b, c, h, w, 2)

    Returns
    -------
    Tuple[torch.Tensor, int]
        tensor of shape (b * c, 1, h, w, 2), b
    """
    b, c, h, w, comp = x.shape
    return x.view(b * c, 1, h, w, comp), b


def batch_chans_to_chan_dim(x: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Reshapes batched independent samples into original multi-channel samples.

    Parameters
    ----------
    x : torch.Tensor
        tensor of shape (b * c, 1, h, w, 2)
    batch_size : int
        batch size

    Returns
    -------
    torch.Tensor
        original multi-channel tensor of shape (b, c, h, w, 2)
    """
    bc, _, h, w, comp = x.shape
    c = bc // batch_size
    return x.view(batch_size, c, h, w, comp)

class NormUDNO(nn.Module):
    """
    Normalized UDNO model.

    Inputs are normalized before the UDNO for numerically stable training.
    """

    def __init__(
        self,
        chans: int,
        num_pool_layers: int,
        radius_cutoff: float,
        in_shape: Tuple[int, int],
        kernel_shape: Tuple[int, int],
        in_chans: int = 2,
        out_chans: int = 2,
        drop_prob: float = 0.0,
    ):
        """
        Initialize the NormUDNO model.

        Parameters
        ----------
        chans : int
            Number of output channels of the first convolution layer.
        num_pool_layers : int
            Number of down-sampling and up-sampling layers.
        in_chans : int, optional
            Number of channels in the input to the U-Net model. Default is 2.
        out_chans : int, optional
            Number of channels in the output to the U-Net model. Default is 2.
        drop_prob : float, optional
            Dropout probability. Default is 0.0.
        """
        super().__init__()

        self.udno = UDNO(
            in_chans=in_chans,
            out_chans=out_chans,
            radius_cutoff=radius_cutoff,
            chans=chans,
            num_pool_layers=num_pool_layers,
            drop_prob=drop_prob,
            in_shape=in_shape,
            kernel_shape=kernel_shape,
        )

    def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w, two = x.shape
        assert two == 2
        return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w)

    def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor:
        b, c2, h, w = x.shape
        assert c2 % 2 == 0
        c = c2 // 2
        return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous()

    def norm(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # group norm
        b, c, h, w = x.shape
        x = x.view(b, 2, c // 2 * h * w)

        mean = x.mean(dim=2).view(b, 2, 1, 1)
        std = x.std(dim=2).view(b, 2, 1, 1)

        x = x.view(b, c, h, w)

        return (x - mean) / std, mean, std

    def norm_new(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # FIXME: not working, wip
        # group norm
        b, c, h, w = x.shape
        num_groups = 2
        assert (
            c % num_groups == 0
        ), f"Number of channels ({c}) must be divisible by number of groups ({num_groups})."

        x = x.view(b, num_groups, c // num_groups * h * w)

        mean = x.mean(dim=2).view(b, num_groups, 1, 1)
        std = x.std(dim=2).view(b, num_groups, 1, 1)
        print(x.shape, mean.shape, std.shape)  # debug output

        x = x.view(b, c, h, w)
        mean = (
            mean.view(b, num_groups, 1, 1)
            .repeat(1, c // num_groups, h, w)
            .view(b, c, h, w)
        )
        std = (
            std.view(b, num_groups, 1, 1)
            .repeat(1, c // num_groups, h, w)
            .view(b, c, h, w)
        )

        return (x - mean) / std, mean, std

    def unnorm(
        self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
    ) -> torch.Tensor:
        return x * std + mean

    def pad(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]:
        _, _, h, w = x.shape
        w_mult = ((w - 1) | 15) + 1
        h_mult = ((h - 1) | 15) + 1
        w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)]
        h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)]
        # TODO: fix this type when PyTorch fixes theirs
        # the documentation lies - this actually takes a list
        # https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L3457
        # https://github.com/pytorch/pytorch/pull/16949
        x = F.pad(x, w_pad + h_pad)

        return x, (h_pad, w_pad, h_mult, w_mult)

    def unpad(
        self,
        x: torch.Tensor,
        h_pad: List[int],
        w_pad: List[int],
        h_mult: int,
        w_mult: int,
    ) -> torch.Tensor:
        return x[
            ..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]
        ]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not x.shape[-1] == 2:
            raise ValueError("Last dimension must be 2 for complex.")

        chans = x.shape[1]
        if chans == 2:
            # FIXME: hard coded skip norm/pad temporarily to avoid group norm bug
            x = self.complex_to_chan_dim(x)
            x = self.udno(x)
            return self.chan_complex_to_last_dim(x)

        # get shapes for unet and normalize
        x = self.complex_to_chan_dim(x)
        x, mean, std = self.norm(x)
        x, pad_sizes = self.pad(x)

        x = self.udno(x)

        # get shapes back and unnormalize
        x = self.unpad(x, *pad_sizes)
        x = self.unnorm(x, mean, std)
        x = self.chan_complex_to_last_dim(x)

        return x

| 258 |
+
class SensitivityModel(nn.Module):
|
| 259 |
+
"""
|
| 260 |
+
Learn sensitivity maps
|
| 261 |
+
"""
|
| 262 |
+
|
| 263 |
+
def __init__(
|
| 264 |
+
self,
|
| 265 |
+
chans: int,
|
| 266 |
+
num_pools: int,
|
| 267 |
+
radius_cutoff: float,
|
| 268 |
+
in_shape: Tuple[int, int],
|
| 269 |
+
kernel_shape: Tuple[int, int],
|
| 270 |
+
in_chans: int = 2,
|
| 271 |
+
out_chans: int = 2,
|
| 272 |
+
drop_prob: float = 0.0,
|
| 273 |
+
mask_center: bool = True,
|
| 274 |
+
):
|
| 275 |
+
"""
|
| 276 |
+
Parameters
|
| 277 |
+
----------
|
| 278 |
+
chans : int
|
| 279 |
+
Number of output channels of the first convolution layer.
|
| 280 |
+
num_pools : int
|
| 281 |
+
Number of down-sampling and up-sampling layers.
|
| 282 |
+
in_chans : int, optional
|
| 283 |
+
Number of channels in the input to the U-Net model. Default is 2.
|
| 284 |
+
out_chans : int, optional
|
| 285 |
+
Number of channels in the output to the U-Net model. Default is 2.
|
| 286 |
+
drop_prob : float, optional
|
| 287 |
+
Dropout probability. Default is 0.0.
|
| 288 |
+
mask_center : bool, optional
|
| 289 |
+
Whether to mask center of k-space for sensitivity map calculation.
|
| 290 |
+
Default is True.
|
| 291 |
+
"""
|
| 292 |
+
super().__init__()
|
| 293 |
+
self.mask_center = mask_center
|
| 294 |
+
self.norm_udno = NormUDNO(
|
| 295 |
+
chans,
|
| 296 |
+
num_pools,
|
| 297 |
+
radius_cutoff,
|
| 298 |
+
in_shape,
|
| 299 |
+
kernel_shape,
|
| 300 |
+
in_chans=in_chans,
|
| 301 |
+
out_chans=out_chans,
|
| 302 |
+
drop_prob=drop_prob,
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
def divide_root_sum_of_squares(self, x: torch.Tensor) -> torch.Tensor:
|
| 306 |
+
return x / fastmri.rss_complex(x, dim=1).unsqueeze(-1).unsqueeze(1)
|
| 307 |
+
|
| 308 |
+
def get_pad_and_num_low_freqs(
|
| 309 |
+
self, mask: torch.Tensor, num_low_frequencies: Optional[int] = None
|
| 310 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 311 |
+
if num_low_frequencies is None or (isinstance(num_low_frequencies, Iterable) and any(
|
| 312 |
+
torch.any(t == 0) for t in num_low_frequencies
|
| 313 |
+
)):
|
| 314 |
+
# get low frequency line locations and mask them out
|
| 315 |
+
squeezed_mask = mask[:, 0, 0, :, 0].to(torch.int8)
|
| 316 |
+
cent = squeezed_mask.shape[1] // 2
|
| 317 |
+
# running argmin returns the first non-zero
|
| 318 |
+
left = torch.argmin(squeezed_mask[:, :cent].flip(1), dim=1)
|
| 319 |
+
right = torch.argmin(squeezed_mask[:, cent:], dim=1)
|
| 320 |
+
num_low_frequencies_tensor = torch.max(
|
| 321 |
+
2 * torch.min(left, right), torch.ones_like(left)
|
| 322 |
+
) # force a symmetric center unless 1
|
| 323 |
+
else:
|
| 324 |
+
num_low_frequencies_tensor = num_low_frequencies * torch.ones(
|
| 325 |
+
mask.shape[0], dtype=mask.dtype, device=mask.device
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
pad = (mask.shape[-2] - num_low_frequencies_tensor + 1) // 2
|
| 329 |
+
|
| 330 |
+
return pad.type(torch.long), num_low_frequencies_tensor.type(torch.long)
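To see what `get_pad_and_num_low_freqs` computes, here is a toy sketch of the same center-width detection on a hand-built `(B, W)` 0/1 mask (the real method works on a `(B, 1, 1, W, 1)` layout, so this is an illustration, not a drop-in call):

```python
import torch

mask = torch.tensor([[0, 1, 0, 1, 1, 1, 1, 0, 1, 0]], dtype=torch.int8)  # W = 10
cent = mask.shape[1] // 2                             # 5
left = torch.argmin(mask[:, :cent].flip(1), dim=1)    # first zero left of center
right = torch.argmin(mask[:, cent:], dim=1)           # first zero right of center
num_low = torch.max(2 * torch.min(left, right), torch.ones_like(left))
pad = (mask.shape[-1] - num_low + 1) // 2
print(num_low.item(), pad.item())  # 4 3 -> the fully sampled band is columns 3..6
```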
    def forward(
        self,
        masked_kspace: torch.Tensor,
        mask: torch.Tensor,
        num_low_frequencies: Optional[int] = None,
    ) -> torch.Tensor:
        if self.mask_center:
            pad, num_low_freqs = self.get_pad_and_num_low_freqs(
                mask, num_low_frequencies
            )
            masked_kspace = transforms.batched_mask_center(
                masked_kspace, pad, pad + num_low_freqs
            )

        # convert to image space
        images, batches = chans_to_batch_dim(fastmri.ifft2c(masked_kspace))

        # estimate sensitivities
        return self.divide_root_sum_of_squares(
            batch_chans_to_chan_dim(self.norm_udno(images), batches)
        )


class VarNetBlock(nn.Module):
    """
    Model block for iterative refinement of k-space data.

    This model applies a combination of soft data consistency with the input
    model as a regularizer. A series of these blocks can be stacked to form
    the full variational network.

    aka the Refinement Module in Fig. 1 of the E2E-VarNet paper
    """

    def __init__(self, model: nn.Module):
        """
        Args:
            model: Module for the "regularization" component of the variational
                network.
        """
        super().__init__()

        self.model = model
        self.dc_weight = nn.Parameter(torch.ones(1))

    def forward(
        self,
        current_kspace: torch.Tensor,
        ref_kspace: torch.Tensor,
        mask: torch.Tensor,
        sens_maps: torch.Tensor,
        use_dc_term: bool = True,
    ) -> torch.Tensor:
        """
        Args:
            current_kspace: The current k-space data (frequency-domain data)
                being processed by the network. (torch.Tensor)
            ref_kspace: The original subsampled k-space data from which we are
                reconstructing the image (the reference k-space). (torch.Tensor)
            mask: A binary mask indicating the locations in k-space where
                data consistency should be enforced. (torch.Tensor)
            sens_maps: Sensitivity maps for the different coils in parallel
                imaging. (torch.Tensor)
        """

        # model term: see the orange box of Fig. 1 in the E2E-VarNet paper
        # multi-channel k-space -> single-channel image space
        b, c, h, w, _ = current_kspace.shape

        if c == 30:
            # split the measured k-space from the inpainted k-space
            kspace = current_kspace[:, :15, :, :, :]
            in_kspace = current_kspace[:, 15:, :, :, :]
            # convert both to image space
            image = sens_reduce(kspace, sens_maps)
            in_image = sens_reduce(in_kspace, sens_maps)
            # concatenate along the channel dimension
            reduced_image = torch.cat([image, in_image], dim=1)
        else:
            reduced_image = sens_reduce(current_kspace, sens_maps)

        # single-channel image space
        refined_image = self.model(reduced_image)

        # single-channel image space -> multi-channel k-space
        model_term = sens_expand(refined_image, sens_maps)

        # only use the first 15 channels (masked_kspace) in the update
        # current_kspace = current_kspace[:, :15, :, :, :]

        if not use_dc_term:
            return current_kspace - model_term

        # Soft data consistency term: takes the difference between the current
        # k-space and the reference k-space where the mask is true, scaled by
        # the learned data consistency weight.
        # dc term: see the green box of Fig. 1 in the E2E-VarNet paper
        zero = torch.zeros(1, 1, 1, 1, 1).to(current_kspace)
        soft_dc = (
            torch.where(mask, current_kspace - ref_kspace, zero)
            * self.dc_weight
        )
        return current_kspace - soft_dc - model_term
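Written out, one cascade above is the E2E-VarNet refinement step, with the learned `dc_weight` playing the role of the step size η, `sens_reduce`/`sens_expand` as the reduce and expand operators S*F⁻¹ and FS, and the NormUDNO as the regularizer (a paraphrase of the update in this code, not a formula quoted from the paper):

```latex
k^{(t+1)} = k^{(t)}
          \;-\; \eta \, M \odot \bigl(k^{(t)} - k_{\mathrm{ref}}\bigr)
          \;-\; \mathcal{F} S \, \mathrm{UDNO}\!\bigl(S^{*} \mathcal{F}^{-1} k^{(t)}\bigr)
```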

class NOVarnet_no_KNO(nn.Module):
    """
    Neural Operator model for MRI reconstruction.

    Uses a variational architecture (iterative updates) with a learned
    sensitivity model. All operations are resolution invariant, employing
    neural operator modules (GNO, UDNO).
    """

    def __init__(
        self,
        num_cascades: int = 12,
        sens_chans: int = 8,
        sens_pools: int = 4,
        chans: int = 18,
        pools: int = 4,
        gno_chans: int = 16,
        gno_pools: int = 4,
        gno_radius_cutoff: float = 0.02,
        gno_kernel_shape: Tuple[int, int] = (6, 7),
        radius_cutoff: float = 0.01,
        kernel_shape: Tuple[int, int] = (3, 4),
        in_shape: Tuple[int, int] = (640, 320),
        mask_center: bool = True,
        use_dc_term: bool = True,
        reduction_method: Literal["batch", "rss"] = "rss",
        skip_method: Literal["replace", "add", "add_inv", "concat"] = "add",
    ):
        """
        Parameters
        ----------
        num_cascades : int
            Number of cascades (i.e., layers) for the variational network.
        sens_chans : int
            Number of channels for the sensitivity map U-Net.
        sens_pools : int
            Number of downsampling and upsampling layers for the sensitivity map U-Net.
        chans : int
            Number of channels for the cascade U-Net.
        pools : int
            Number of downsampling and upsampling layers for the cascade U-Net.
        mask_center : bool
            Whether to mask center of k-space for sensitivity map calculation.
        use_dc_term : bool
            Whether to use the data consistency term.
        reduction_method : "batch" or "rss"
            Method for reducing sensitivity maps to a single channel.
            "batch" reduces to a single channel by stacking channels.
            "rss" reduces to a single channel by root sum of squares.
        skip_method : "replace", "add", "add_inv", or "concat"
            "replace" replaces the input with the output of the GNO.
            "add" adds the output of the GNO to the input.
            "add_inv" adds the output of the GNO to the input (only where samples are missing).
            "concat" concatenates the output of the GNO to the input.
        """

        super().__init__()

        self.sens_net = SensitivityModel(
            sens_chans,
            sens_pools,
            radius_cutoff,
            in_shape,
            kernel_shape,
            mask_center=mask_center,
        )
        # self.gno = NormUDNO(
        #     gno_chans,
        #     gno_pools,
        #     in_shape=in_shape,
        #     radius_cutoff=radius_cutoff,
        #     kernel_shape=kernel_shape,
        #     # radius_cutoff=gno_radius_cutoff,
        #     # kernel_shape=gno_kernel_shape,
        #     in_chans=2,
        #     out_chans=2,
        # )
        self.cascades = nn.ModuleList(
            [
                VarNetBlock(
                    NormUDNO(
                        chans,
                        pools,
                        radius_cutoff,
                        in_shape,
                        kernel_shape,
                        in_chans=(
                            4
                            if skip_method == "concat" and cascade_idx == 0
                            else 2
                        ),
                        out_chans=2,
                    )
                )
                for cascade_idx in range(num_cascades)
            ]
        )
        self.use_dc_term = use_dc_term
        self.reduction_method = reduction_method  # not used anywhere anymore
        self.skip_method = skip_method  # not used anywhere anymore

    def forward(
        self,
        masked_kspace: torch.Tensor,
        mask: torch.Tensor,
        num_low_frequencies: Optional[int] = None,
    ) -> torch.Tensor:

        # (B, C, X, Y, 2)
        sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies)

        kspace_pred = masked_kspace.clone()

        # iterative update
        for cascade in self.cascades:
            kspace_pred = cascade(
                kspace_pred, masked_kspace, mask, sens_maps, self.use_dc_term
            )

        spatial_pred = fastmri.ifft2c(kspace_pred)
        spatial_pred_abs = fastmri.complex_abs(spatial_pred)
        combined_spatial = fastmri.rss(spatial_pred_abs, dim=1)

        return combined_spatial
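The last three lines collapse the multi-coil complex prediction into a single magnitude image: a per-coil complex modulus followed by a root-sum-of-squares over the coil dimension. A plain-torch sketch of the same combination (assuming `fastmri.complex_abs` and `fastmri.rss` compute the usual modulus and RSS):

```python
import torch

coil_images = torch.randn(1, 15, 8, 8, 2)               # stand-in for ifft2c output
cplx = torch.view_as_complex(coil_images.contiguous())  # (B, C, H, W) complex
magnitude = cplx.abs()                                  # |x_c| per coil
rss = torch.sqrt((magnitude ** 2).sum(dim=1))           # (B, H, W) combined image
print(rss.shape)  # torch.Size([1, 8, 8])
```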

if __name__ == "__main__":
    # Smoke test. SliceDatasetLMDB / SliceSample are expected to be imported
    # at the top of this file (see fastmri/datasets.py).
    ds = SliceDatasetLMDB(
        "knee",
        partition="train",
        mask_fns=None,  # type: ignore
        complex=False,
        sample_rate=0.5,
        crop_shape=(320, 320),
        coils=15,
    )

    sample: SliceSample = ds[0]
    kspace = sample.masked_kspace
    target = sample.target

    model = NOVarnet_no_KNO(1)
    res = model.forward(
        sample.masked_kspace.unsqueeze(0),
        sample.mask.unsqueeze(0),
        torch.tensor(sample.num_low_frequencies).unsqueeze(0),
    )
models/temp/no_repeatk.py
ADDED
@@ -0,0 +1,562 @@
import math
import pdb
from typing import List, Optional, Tuple, Literal

import torch
import torch.nn as nn
import torch.nn.functional as F

import fastmri
from fastmri import transforms

from models.udno import UDNO


def sens_expand(x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    """
    Calculates F(x * sens_maps), i.e. expands a single-channel image to
    multi-coil k-space.

    Parameters
    ----------
    x : torch.Tensor
        Single-channel image of shape (..., H, W, 2)
    sens_maps : torch.Tensor
        Sensitivity maps (image space)

    Returns
    -------
    torch.Tensor
        Result of the operation F(x * sens_maps)
    """
    return fastmri.fft2c(fastmri.complex_mul(x, sens_maps))


def sens_reduce(k: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    """
    Calculates F^{-1}(k) * conj(sens_maps), summed over coils, where
    conj(sens_maps) is the element-wise complex conjugate.

    Parameters
    ----------
    k : torch.Tensor
        Multi-channel k-space of shape (B, C, H, W, 2)
    sens_maps : torch.Tensor
        Sensitivity maps (image space)

    Returns
    -------
    torch.Tensor
        Result of the operation sum_c F^{-1}(k_c) * conj(S_c)
    """
    return fastmri.complex_mul(fastmri.ifft2c(k), fastmri.complex_conj(sens_maps)).sum(
        dim=1, keepdim=True
    )
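`sens_reduce` is the adjoint of `sens_expand`, and with RSS-normalized maps (∑_c |S_c|² = 1) the round trip is the identity. A sketch with plain `torch.complex` tensors, assuming `fastmri.fft2c`/`ifft2c` behave as mutually inverse orthonormal FFTs:

```python
import torch

B, C, H, W = 1, 4, 16, 16
x = torch.randn(B, 1, H, W, dtype=torch.complex64)           # single-channel image
S = torch.randn(B, C, H, W, dtype=torch.complex64)
S = S / torch.sqrt((S.abs() ** 2).sum(dim=1, keepdim=True))  # RSS-normalize the maps

k = torch.fft.fft2(x * S, norm="ortho")                      # sens_expand
x_rec = (torch.fft.ifft2(k, norm="ortho") * S.conj()).sum(dim=1, keepdim=True)
print(torch.allclose(x_rec, x, atol=1e-5))                   # True: reduce(expand(x)) == x
```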

def chans_to_batch_dim(x: torch.Tensor) -> Tuple[torch.Tensor, int]:
    """Reshapes batched multi-channel samples into multiple single channel samples.

    Parameters
    ----------
    x : torch.Tensor
        x has shape (b, c, h, w, 2)

    Returns
    -------
    Tuple[torch.Tensor, int]
        tensor of shape (b * c, 1, h, w, 2), b
    """
    b, c, h, w, comp = x.shape
    return x.view(b * c, 1, h, w, comp), b


def batch_chans_to_chan_dim(x: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Reshapes batched independent samples into original multi-channel samples.

    Parameters
    ----------
    x : torch.Tensor
        tensor of shape (b * c, 1, h, w, 2)
    batch_size : int
        batch size

    Returns
    -------
    torch.Tensor
        original multi-channel tensor of shape (b, c, h, w, 2)
    """
    bc, _, h, w, comp = x.shape
    c = bc // batch_size
    return x.view(batch_size, c, h, w, comp)
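A quick shape check of the pair above: folding coils into the batch axis lets the single-channel UDNO process every coil image independently, and unfolding restores the original layout exactly. The `view` calls below are the same ones the two helpers perform:

```python
import torch

x = torch.randn(2, 15, 64, 64, 2)              # (b, c, h, w, 2)
folded, b = x.view(2 * 15, 1, 64, 64, 2), 2    # what chans_to_batch_dim(x) returns
restored = folded.view(b, -1, 64, 64, 2)       # what batch_chans_to_chan_dim(folded, b) does
print(folded.shape, restored.shape, torch.equal(restored, x))  # ..., True
```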

class NormUDNO(nn.Module):
    """
    Normalized UDNO model.

    Inputs are normalized before the UDNO for numerically stable training.
    """

    def __init__(
        self,
        chans: int,
        num_pool_layers: int,
        radius_cutoff: float,
        in_shape: Tuple[int, int],
        kernel_shape: Tuple[int, int],
        in_chans: int = 2,
        out_chans: int = 2,
        drop_prob: float = 0.0,
    ):
        """
        Initialize the NormUDNO model.

        Parameters
        ----------
        chans : int
            Number of output channels of the first convolution layer.
        num_pool_layers : int
            Number of down-sampling and up-sampling layers.
        in_chans : int, optional
            Number of channels in the input to the U-Net model. Default is 2.
        out_chans : int, optional
            Number of channels in the output of the U-Net model. Default is 2.
        drop_prob : float, optional
            Dropout probability. Default is 0.0.
        """
        super().__init__()

        self.udno = UDNO(
            in_chans=in_chans,
            out_chans=out_chans,
            radius_cutoff=radius_cutoff,
            chans=chans,
            num_pool_layers=num_pool_layers,
            drop_prob=drop_prob,
            in_shape=in_shape,
            kernel_shape=kernel_shape,
        )

    def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w, two = x.shape
        assert two == 2
        return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w)

    def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor:
        b, c2, h, w = x.shape
        assert c2 % 2 == 0
        c = c2 // 2
        return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous()

    def norm(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # group norm: one mean/std per sample per (real/imag) channel group
        b, c, h, w = x.shape
        x = x.view(b, 2, c // 2 * h * w)

        mean = x.mean(dim=2).view(b, 2, 1, 1)
        std = x.std(dim=2).view(b, 2, 1, 1)

        x = x.view(b, c, h, w)

        # NOTE: broadcasting (x - mean) against (b, 2, 1, 1) stats assumes c == 2
        return (x - mean) / std, mean, std
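For the `c == 2` case this normalization is used with, `norm` produces one mean/std per sample and per real/imag group, shared across the spatial grid. A small sketch checking the resulting statistics:

```python
import torch

b, c, h, w = 3, 2, 8, 8
x = 5.0 + 2.0 * torch.randn(b, c, h, w)
flat = x.view(b, 2, c // 2 * h * w)
mean = flat.mean(dim=2).view(b, 2, 1, 1)
std = flat.std(dim=2).view(b, 2, 1, 1)
x_norm = (x - mean) / std

print(x_norm.mean(dim=(2, 3)))  # ~0 for every sample/group
print(x_norm.std(dim=(2, 3)))   # ~1 for every sample/group
```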

    def norm_new(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # FIXME: not working, wip
        # group norm
        b, c, h, w = x.shape
        num_groups = 2
        assert (
            c % num_groups == 0
        ), f"Number of channels ({c}) must be divisible by number of groups ({num_groups})."

        x = x.view(b, num_groups, c // num_groups * h * w)

        mean = x.mean(dim=2).view(b, num_groups, 1, 1)
        std = x.std(dim=2).view(b, num_groups, 1, 1)
        print(x.shape, mean.shape, std.shape)  # debug output

        x = x.view(b, c, h, w)
        mean = (
            mean.view(b, num_groups, 1, 1)
            .repeat(1, c // num_groups, h, w)
            .view(b, c, h, w)
        )
        std = (
            std.view(b, num_groups, 1, 1)
            .repeat(1, c // num_groups, h, w)
            .view(b, c, h, w)
        )

        return (x - mean) / std, mean, std

    def unnorm(
        self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
    ) -> torch.Tensor:
        return x * std + mean

    def pad(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]:
        _, _, h, w = x.shape
        w_mult = ((w - 1) | 15) + 1
        h_mult = ((h - 1) | 15) + 1
        w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)]
        h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)]
        # TODO: fix this type when PyTorch fixes theirs
        # the documentation lies - this actually takes a list
        # https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L3457
        # https://github.com/pytorch/pytorch/pull/16949
        x = F.pad(x, w_pad + h_pad)

        return x, (h_pad, w_pad, h_mult, w_mult)

    def unpad(
        self,
        x: torch.Tensor,
        h_pad: List[int],
        w_pad: List[int],
        h_mult: int,
        w_mult: int,
    ) -> torch.Tensor:
        return x[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not x.shape[-1] == 2:
            raise ValueError("Last dimension must be 2 for complex.")

        chans = x.shape[1]
        if chans == 2:
            # FIXME: hard-coded skip of norm/pad, temporarily, to avoid the group norm bug
            x = self.complex_to_chan_dim(x)
            x = self.udno(x)
            return self.chan_complex_to_last_dim(x)

        # get shapes for unet and normalize
        x = self.complex_to_chan_dim(x)
        x, mean, std = self.norm(x)
        x, pad_sizes = self.pad(x)

        x = self.udno(x)

        # get shapes back and unnormalize
        x = self.unpad(x, *pad_sizes)
        x = self.unnorm(x, mean, std)
        x = self.chan_complex_to_last_dim(x)

        return x


class SensitivityModel(nn.Module):
    """
    Learn sensitivity maps.
    """

    def __init__(
        self,
        chans: int,
        num_pools: int,
        radius_cutoff: float,
        in_shape: Tuple[int, int],
        kernel_shape: Tuple[int, int],
        in_chans: int = 2,
        out_chans: int = 2,
        drop_prob: float = 0.0,
        mask_center: bool = True,
    ):
        """
        Parameters
        ----------
        chans : int
            Number of output channels of the first convolution layer.
        num_pools : int
            Number of down-sampling and up-sampling layers.
        in_chans : int, optional
            Number of channels in the input to the U-Net model. Default is 2.
        out_chans : int, optional
            Number of channels in the output of the U-Net model. Default is 2.
        drop_prob : float, optional
            Dropout probability. Default is 0.0.
        mask_center : bool, optional
            Whether to mask center of k-space for sensitivity map calculation.
            Default is True.
        """
        super().__init__()
        self.mask_center = mask_center
        self.norm_udno = NormUDNO(
            chans,
            num_pools,
            radius_cutoff,
            in_shape,
            kernel_shape,
            in_chans=in_chans,
            out_chans=out_chans,
            drop_prob=drop_prob,
        )

    def divide_root_sum_of_squares(self, x: torch.Tensor) -> torch.Tensor:
        return x / fastmri.rss_complex(x, dim=1).unsqueeze(-1).unsqueeze(1)

    def get_pad_and_num_low_freqs(
        self, mask: torch.Tensor, num_low_frequencies: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if num_low_frequencies is None or any(
            torch.any(t == 0) for t in num_low_frequencies
        ):
            # infer the low-frequency line locations from the mask itself
            squeezed_mask = mask[:, 0, 0, :, 0].to(torch.int8)
            cent = squeezed_mask.shape[1] // 2
            # argmin returns the first zero, i.e. the first unsampled line
            left = torch.argmin(squeezed_mask[:, :cent].flip(1), dim=1)
            right = torch.argmin(squeezed_mask[:, cent:], dim=1)
            num_low_frequencies_tensor = torch.max(
                2 * torch.min(left, right), torch.ones_like(left)
            )  # force a symmetric center unless the width is 1
        else:
            num_low_frequencies_tensor = num_low_frequencies * torch.ones(
                mask.shape[0], dtype=mask.dtype, device=mask.device
            )

        pad = (mask.shape[-2] - num_low_frequencies_tensor + 1) // 2

        return pad.type(torch.long), num_low_frequencies_tensor.type(torch.long)

    def forward(
        self,
        masked_kspace: torch.Tensor,
        mask: torch.Tensor,
        num_low_frequencies: Optional[int] = None,
    ) -> torch.Tensor:
        if self.mask_center:
            pad, num_low_freqs = self.get_pad_and_num_low_freqs(
                mask, num_low_frequencies
            )
            masked_kspace = transforms.batched_mask_center(
                masked_kspace, pad, pad + num_low_freqs
            )

        # convert to image space
        images, batches = chans_to_batch_dim(fastmri.ifft2c(masked_kspace))

        # estimate sensitivities
        return self.divide_root_sum_of_squares(
            batch_chans_to_chan_dim(self.norm_udno(images), batches)
        )


class VarNetBlock(nn.Module):
    """
    Model block for iterative refinement of k-space data.

    This model applies a combination of soft data consistency with the input
    model as a regularizer. A series of these blocks can be stacked to form
    the full variational network.

    aka the Refinement Module in Fig. 1 of the E2E-VarNet paper
    """

    def __init__(self, model: nn.Module):
        """
        Args:
            model: Module for the "regularization" component of the variational
                network.
        """
        super().__init__()

        self.model = model
        self.dc_weight = nn.Parameter(torch.ones(1))

    def forward(
        self,
        current_kspace: torch.Tensor,
        ref_kspace: torch.Tensor,
        mask: torch.Tensor,
        sens_maps: torch.Tensor,
        use_dc_term: bool = True,
    ) -> torch.Tensor:
        """
        Args:
            current_kspace: The current k-space data (frequency-domain data)
                being processed by the network. (torch.Tensor)
            ref_kspace: The original subsampled k-space data from which we are
                reconstructing the image (the reference k-space). (torch.Tensor)
            mask: A binary mask indicating the locations in k-space where
                data consistency should be enforced. (torch.Tensor)
            sens_maps: Sensitivity maps for the different coils in parallel
                imaging. (torch.Tensor)
        """

        # model term: see the orange box of Fig. 1 in the E2E-VarNet paper
        # multi-channel k-space -> single-channel image space
        b, c, h, w, _ = current_kspace.shape

        if c == 30:
            # split the measured k-space from the inpainted k-space
            kspace = current_kspace[:, :15, :, :, :]
            in_kspace = current_kspace[:, 15:, :, :, :]
            # convert both to image space
            image = sens_reduce(kspace, sens_maps)
            in_image = sens_reduce(in_kspace, sens_maps)
            # concatenate along the channel dimension
            reduced_image = torch.cat([image, in_image], dim=1)
        else:
            reduced_image = sens_reduce(current_kspace, sens_maps)

        # single-channel image space
        refined_image = self.model(reduced_image)

        # single-channel image space -> multi-channel k-space
        model_term = sens_expand(refined_image, sens_maps)

        # only use the first 15 channels (masked_kspace) in the update
        current_kspace = current_kspace[:, :15, :, :, :]

        if not use_dc_term:
            return current_kspace - model_term

        # Soft data consistency term: takes the difference between the current
        # k-space and the reference k-space where the mask is true, scaled by
        # the learned data consistency weight.
        # dc term: see the green box of Fig. 1 in the E2E-VarNet paper
        zero = torch.zeros(1, 1, 1, 1, 1).to(current_kspace)
        soft_dc = torch.where(mask, current_kspace - ref_kspace, zero) * self.dc_weight
        return current_kspace - soft_dc - model_term


class NOVarnet(nn.Module):
    """
    Neural Operator model for MRI reconstruction.

    Uses a variational architecture (iterative updates) with a learned
    sensitivity model. All operations are resolution invariant, employing
    neural operator modules (GNO, UDNO).
    """

    def __init__(
        self,
        num_cascades: int = 12,
        sens_chans: int = 8,
        sens_pools: int = 4,
        chans: int = 18,
        pools: int = 4,
        gno_chans: int = 16,
        gno_pools: int = 4,
        gno_radius_cutoff: float = 0.02,
        gno_kernel_shape: Tuple[int, int] = (6, 7),
        radius_cutoff: float = 0.01,
        kernel_shape: Tuple[int, int] = (3, 4),
        in_shape: Tuple[int, int] = (640, 320),
        mask_center: bool = True,
        use_dc_term: bool = True,
        reduction_method: Literal["batch", "rss"] = "rss",
        skip_method: Literal["replace", "add", "add_inv", "concat"] = "add",
    ):
        """
        Parameters
        ----------
        num_cascades : int
            Number of cascades (i.e., layers) for the variational network.
        sens_chans : int
            Number of channels for the sensitivity map U-Net.
        sens_pools : int
            Number of downsampling and upsampling layers for the sensitivity map U-Net.
        chans : int
            Number of channels for the cascade U-Net.
        pools : int
            Number of downsampling and upsampling layers for the cascade U-Net.
        mask_center : bool
            Whether to mask center of k-space for sensitivity map calculation.
        use_dc_term : bool
            Whether to use the data consistency term.
        reduction_method : "batch" or "rss"
            Method for reducing sensitivity maps to a single channel.
            "batch" reduces to a single channel by stacking channels.
            "rss" reduces to a single channel by root sum of squares.
        skip_method : "replace", "add", "add_inv", or "concat"
            "replace" replaces the input with the output of the GNO.
            "add" adds the output of the GNO to the input.
            "add_inv" adds the output of the GNO to the input (only where samples are missing).
            "concat" concatenates the output of the GNO to the input.
        """

        super().__init__()

        self.sens_net = SensitivityModel(
            sens_chans,
            sens_pools,
            radius_cutoff,
            in_shape,
            kernel_shape,
            mask_center=mask_center,
        )
        self.gno = NormUDNO(
            gno_chans,
            gno_pools,
            in_shape=in_shape,
            radius_cutoff=radius_cutoff,
            kernel_shape=kernel_shape,
            in_chans=2,
            out_chans=2,
        )
        self.cascades = nn.ModuleList(
            [
                VarNetBlock(
                    NormUDNO(
                        chans,
                        pools,
                        radius_cutoff,
                        in_shape,
                        kernel_shape,
                        in_chans=(
                            4 if skip_method == "concat" and cascade_idx == 0 else 2
                        ),
                        out_chans=2,
                    )
                )
                for cascade_idx in range(num_cascades)
            ]
        )
        self.use_dc_term = use_dc_term
        self.reduction_method = reduction_method
        self.skip_method = skip_method

        print("===================================")
        print("===================================")
        print("initialized no repeat k")
        print("===================================")
        print("===================================")

    def forward(
        self,
        masked_kspace: torch.Tensor,
        mask: torch.Tensor,
        num_low_frequencies: Optional[int] = None,
    ) -> torch.Tensor:

        # (B, C, X, Y, 2)
        sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies)

        kspace_pred = masked_kspace
        # iterative update
        for cascade in self.cascades:
            # breakpoint()
            # re-estimate sensitivity maps from the current k-space prediction
            sens_maps = self.sens_net(kspace_pred, mask, num_low_frequencies)
            # fold coils into the batch dimension before inpainting
            kspace_pred, b = chans_to_batch_dim(kspace_pred)
            # inpainting
            kspace_pred = self.gno(kspace_pred)
            kspace_pred = batch_chans_to_chan_dim(kspace_pred, b)

            # refinement step
            kspace_pred = cascade(
                kspace_pred, masked_kspace, mask, sens_maps, self.use_dc_term
            )

        spatial_pred = fastmri.ifft2c(kspace_pred)
        spatial_pred_abs = fastmri.complex_abs(spatial_pred)
        combined_spatial = fastmri.rss(spatial_pred_abs, dim=1)

        return combined_spatial
models/temp/no_repeatk_module.py
ADDED
@@ -0,0 +1,303 @@
from argparse import ArgumentParser
from typing import Tuple

import torch

import fastmri
from fastmri import transforms
from models.temp.no_repeatk import NOVarnet

from models.lightning.mri_module import MriModule
from type_utils import tuple_type


class NORepeatKModule(MriModule):
    """
    NO-Varnet repeat-k (temp) training module.
    """

    def __init__(
        self,
        num_cascades: int = 12,
        pools: int = 4,
        chans: int = 18,
        sens_pools: int = 4,
        sens_chans: int = 8,
        gno_pools: int = 4,
        gno_chans: int = 16,
        gno_radius_cutoff: float = 0.02,
        gno_kernel_shape: Tuple[int, int] = (6, 7),
        radius_cutoff: float = 0.02,
        kernel_shape: Tuple[int, int] = (6, 7),
        in_shape: Tuple[int, int] = (320, 320),
        use_dc_term: bool = True,
        lr: float = 0.0003,
        lr_step_size: int = 40,
        lr_gamma: float = 0.1,
        weight_decay: float = 0.0,
        reduction_method: str = "rss",
        skip_method: str = "add",
        **kwargs,
    ):
        """
        Parameters
        ----------
        num_cascades : int
            Number of cascades (i.e., layers) for the variational network.
        pools : int
            Number of downsampling and upsampling layers for the cascade U-Net.
        chans : int
            Number of channels for the cascade U-Net.
        sens_pools : int
            Number of downsampling and upsampling layers for the sensitivity map U-Net.
        sens_chans : int
            Number of channels for the sensitivity map U-Net.
        lr : float
            Learning rate.
        lr_step_size : int
            Learning rate step size.
        lr_gamma : float
            Learning rate gamma decay.
        weight_decay : float
            Parameter for penalizing weights norm.
        num_sense_lines : int, optional
            Number of low-frequency lines to use for sensitivity map computation.
            Must be even or `None`. Default `None` will automatically compute the number
            from masks. Default behavior may cause some slices to use more low-frequency
            lines than others, when used in conjunction with e.g. the EquispacedMaskFunc
            defaults. To prevent this, either set `num_sense_lines`, or set
            `skip_low_freqs` and `skip_around_low_freqs` to `True` in the EquispacedMaskFunc.
            Note that setting this value may lead to undesired behavior when training on
            multiple accelerations simultaneously.
        """
        super().__init__(**kwargs)
        self.save_hyperparameters()

        self.num_cascades = num_cascades
        self.pools = pools
        self.chans = chans
        self.sens_pools = sens_pools
        self.sens_chans = sens_chans
        self.gno_pools = gno_pools
        self.gno_chans = gno_chans
        self.gno_radius_cutoff = gno_radius_cutoff
        self.gno_kernel_shape = gno_kernel_shape
        self.radius_cutoff = radius_cutoff
        self.kernel_shape = kernel_shape
        self.in_shape = in_shape
        self.use_dc_term = use_dc_term
        self.lr = lr
        self.lr_step_size = lr_step_size
        self.lr_gamma = lr_gamma
        self.weight_decay = weight_decay
        self.reduction_method = reduction_method
        self.skip_method = skip_method

        self.model = NOVarnet(
            num_cascades=self.num_cascades,
            sens_chans=self.sens_chans,
            sens_pools=self.sens_pools,
            chans=self.chans,
            pools=self.pools,
            gno_chans=self.gno_chans,
            gno_pools=self.gno_pools,
            gno_radius_cutoff=self.gno_radius_cutoff,
            gno_kernel_shape=self.gno_kernel_shape,
            radius_cutoff=radius_cutoff,
            kernel_shape=kernel_shape,
            in_shape=in_shape,
            use_dc_term=use_dc_term,
            reduction_method=reduction_method,
            skip_method=skip_method,
        )

        self.criterion = fastmri.SSIMLoss()
        self.num_params = sum(p.numel() for p in self.parameters())

    def forward(self, masked_kspace, mask, num_low_frequencies):
        return self.model(masked_kspace, mask, num_low_frequencies)

    def training_step(self, batch, batch_idx):
        output = self.forward(
            batch.masked_kspace, batch.mask, batch.num_low_frequencies
        )

        target, output = transforms.center_crop_to_smallest(batch.target, output)
        loss = self.criterion(
            output.unsqueeze(1), target.unsqueeze(1), data_range=batch.max_value
        )

        self.log("train_loss", loss, on_step=True, on_epoch=True)
        self.log("epoch", int(self.current_epoch), on_step=True, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        dataloaders = self.trainer.val_dataloaders
        slug = list(dataloaders.keys())[dataloader_idx]

        output = self.forward(
            batch.masked_kspace, batch.mask, batch.num_low_frequencies
        )

        target, output = transforms.center_crop_to_smallest(batch.target, output)

        loss = self.criterion(
            output.unsqueeze(1),
            target.unsqueeze(1),
            data_range=batch.max_value,
        )

        return {
            "slug": slug,
            "fname": batch.fname,
            "slice_num": batch.slice_num,
            "max_value": batch.max_value,
            "output": output,
            "target": target,
            "val_loss": loss,
        }

    def configure_optimizers(self):
        optim = torch.optim.Adam(
            self.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        scheduler = torch.optim.lr_scheduler.StepLR(
            optim, self.lr_step_size, self.lr_gamma
        )

        return [optim], [scheduler]
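Since `MriModule` is a PyTorch Lightning module, training is driven by a standard `pl.Trainer`. A hypothetical minimal sketch (not from the repo): the dataloaders are placeholders, `validation_step` above expects `val_dataloaders` to be a dict keyed by slug, and `MriModule` may require further constructor kwargs:

```python
import pytorch_lightning as pl

from models.temp.no_repeatk_module import NORepeatKModule

model = NORepeatKModule(
    num_cascades=6,
    radius_cutoff=0.01,
    kernel_shape=(3, 4),
    in_shape=(640, 320),
)
trainer = pl.Trainer(max_epochs=50, accelerator="gpu", devices=1)
# trainer.fit(model, train_dataloaders=train_loader,
#             val_dataloaders={"knee_4x": val_loader})
```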

    @staticmethod
    def add_model_specific_args(parent_parser):
        """
        Define parameters that only apply to this model.
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser = MriModule.add_model_specific_args(parser)

        # network params
        parser.add_argument(
            "--num_cascades",
            default=12,
            type=int,
            help="Number of VarNet cascades",
        )
        parser.add_argument(
            "--pools",
            default=4,
            type=int,
            help="Number of U-Net pooling layers in VarNet blocks",
        )
        parser.add_argument(
            "--chans",
            default=18,
            type=int,
            help="Number of channels for U-Net in VarNet blocks",
        )
        parser.add_argument(
            "--sens_pools",
            default=4,
            type=int,
            help="Number of pooling layers for sense map estimation U-Net in VarNet",
        )
        parser.add_argument(
            "--sens_chans",
            default=8,
            type=int,  # was type=float; channel counts must be integers
            help="Number of channels for sense map estimation U-Net in VarNet",
        )
        parser.add_argument(
            "--gno_pools",
            default=4,
            type=int,
            help="Number of pooling layers for GNO",
        )
        parser.add_argument(
            "--gno_chans",
            default=16,
            type=int,
            help="Number of channels for GNO",
        )
        parser.add_argument(
            "--gno_radius_cutoff",
            default=0.02,
            type=float,
            required=True,
            help="GNO module radius_cutoff",
        )
        parser.add_argument(
            "--gno_kernel_shape",
            default=(6, 7),
            type=tuple_type,
            required=True,
            help="GNO module kernel_shape. Ex: (6, 7)",
        )
        parser.add_argument(
            "--radius_cutoff",
            default=0.01,
            type=float,
            required=True,
            help="DISCO module radius_cutoff",
        )
        parser.add_argument(
            "--kernel_shape",
            default=(6, 7),
            type=tuple_type,
            required=True,
            help="DISCO module kernel_shape. Ex: (6, 7)",
        )
        parser.add_argument(
            "--in_shape",
            default=(640, 320),
            type=tuple_type,
            required=True,
            help="Spatial dimensions of masked_kspace samples. Ex: (640, 320)",
        )
        parser.add_argument(
            "--use_dc_term",
            default=True,
            type=bool,
            # NOTE: argparse's type=bool treats any non-empty string
            # (including "False") as True
            help="Whether to use the DC term in the unrolled iterative update step",
        )

        # training params (opt)
        parser.add_argument(
            "--lr", default=0.0003, type=float, help="Adam learning rate"
        )
        parser.add_argument(
            "--lr_step_size",
            default=40,
            type=int,
            help="Epoch at which to decrease step size",
        )
        parser.add_argument(
            "--lr_gamma",
            default=0.1,
            type=float,
            help="Extent to which step size should be decreased",
        )
        parser.add_argument(
            "--weight_decay",
            default=0.0,
            type=float,
            help="Strength of weight decay regularization",
        )
        parser.add_argument(
            "--reduction_method",
            default="rss",
            type=str,
            choices=["rss", "batch"],
            help="Reduction method used to reduce multi-channel k-space data before the inpainting module. Read the documentation of the GNO for more information.",
        )
        parser.add_argument(
            "--skip_method",
            default="add_inv",
            type=str,
            choices=["add_inv", "add", "concat", "replace"],
            help="Method for skip connection around the inpainting module.",
        )

        return parser
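A sketch of driving this parser from a training script. It assumes `type_utils.tuple_type` parses strings such as `"(6,7)"` into an integer tuple (an assumption about that helper), and that `MriModule.add_model_specific_args` adds no further required flags:

```python
from argparse import ArgumentParser

from models.temp.no_repeatk_module import NORepeatKModule

parser = NORepeatKModule.add_model_specific_args(ArgumentParser())
args = parser.parse_args(
    [
        "--gno_radius_cutoff", "0.02",
        "--gno_kernel_shape", "(6,7)",
        "--radius_cutoff", "0.01",
        "--kernel_shape", "(3,4)",
        "--in_shape", "(640,320)",
    ]
)
print(args.num_cascades, args.in_shape)  # 12 (640, 320)
```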
models/udno.py
ADDED
@@ -0,0 +1,369 @@
"""
U-shaped DISCO Neural Operator
"""

from typing import List, Tuple

import torch
import torch.nn as nn
from torch.nn import functional as F

from torch_harmonics_local.convolution import (
    EquidistantDiscreteContinuousConv2d as DISCO2d,
)


class UDNO(nn.Module):
    """
    U-shaped DISCO Neural Operator in PyTorch
    """

    def __init__(
        self,
        in_chans: int,
        out_chans: int,
        radius_cutoff: float,
        chans: int = 32,
        num_pool_layers: int = 4,
        drop_prob: float = 0.0,
        in_shape: Tuple[int, int] = (320, 320),
        kernel_shape: Tuple[int, int] = (3, 4),
    ):
        """
        Parameters
        ----------
        in_chans : int
            Number of channels in the input to the U-Net model.
        out_chans : int
            Number of channels in the output of the U-Net model.
        radius_cutoff : float
            Controls the effective radius of the DISCO kernel. Values are
            between 0.0 and 1.0. The radius_cutoff is represented as a proportion
            of the normalized input space, to ensure that kernels are resolution
            invariant.
        chans : int, optional
            Number of output channels of the first DISCO layer. Default is 32.
        num_pool_layers : int, optional
            Number of down-sampling and up-sampling layers. Default is 4.
        drop_prob : float, optional
            Dropout probability. Default is 0.0.
        in_shape : Tuple[int, int]
            Shape of the input to the UDNO. This is required to dynamically
            compile DISCO kernels for resolution invariance.
        kernel_shape : Tuple[int, int], optional
            Shape of the DISCO kernel. Default is (3, 4). This corresponds to 3
            rings and 4 anisotropic basis functions. Under the hood, each DISCO
            kernel has (3 - 1) * 4 + 1 = 9 parameters, equivalent to a standard
            3x3 convolution kernel.

            Note: This is NOT kernel_size, as under the DISCO framework,
            kernels are dynamically compiled to support resolution invariance.
        """
        super().__init__()
        assert len(in_shape) == 2, "Input shape must be 2D"

        self.in_chans = in_chans
        self.out_chans = out_chans
        self.chans = chans
        self.num_pool_layers = num_pool_layers
        self.drop_prob = drop_prob
        self.in_shape = in_shape
        self.kernel_shape = kernel_shape

        self.down_sample_layers = nn.ModuleList(
            [
                DISCOBlock(
                    in_chans,
                    chans,
                    radius_cutoff,
                    drop_prob,
                    in_shape,
                    kernel_shape,
                )
            ]
        )
        ch = chans
        # each pooling halves the grid, so the spatial shape halves and the
        # relative radius_cutoff doubles, keeping the kernel's extent in pixels
        # constant across scales (like a fixed 3x3 conv at each U-Net level)
        shape = (in_shape[0] // 2, in_shape[1] // 2)
        radius_cutoff = radius_cutoff * 2
        for _ in range(num_pool_layers - 1):
            self.down_sample_layers.append(
                DISCOBlock(
                    ch,
                    ch * 2,
                    radius_cutoff,
                    drop_prob,
                    in_shape=shape,
                    kernel_shape=kernel_shape,
                )
            )
            ch *= 2
            shape = (shape[0] // 2, shape[1] // 2)
            radius_cutoff *= 2

        self.bottleneck = DISCOBlock(
            ch,
            ch * 2,
            radius_cutoff,
            drop_prob,
            in_shape=shape,
            kernel_shape=kernel_shape,
        )

        self.up = nn.ModuleList()
        self.up_transpose = nn.ModuleList()
        for _ in range(num_pool_layers - 1):
            self.up_transpose.append(
                TransposeDISCOBlock(
                    ch * 2,
                    ch,
                    radius_cutoff,
                    in_shape=shape,
                    kernel_shape=kernel_shape,
                )
            )
            shape = (shape[0] * 2, shape[1] * 2)
            radius_cutoff /= 2
            self.up.append(
                DISCOBlock(
                    ch * 2,
                    ch,
                    radius_cutoff,
                    drop_prob,
                    in_shape=shape,
                    kernel_shape=kernel_shape,
                )
            )
            ch //= 2

        self.up_transpose.append(
            TransposeDISCOBlock(
                ch * 2,
                ch,
                radius_cutoff,
                in_shape=shape,
                kernel_shape=kernel_shape,
            )
        )
        shape = (shape[0] * 2, shape[1] * 2)
        radius_cutoff /= 2
        self.up.append(
            nn.Sequential(
                DISCOBlock(
                    ch * 2,
                    ch,
                    radius_cutoff,
                    drop_prob,
                    in_shape=shape,
                    kernel_shape=kernel_shape,
                ),
                nn.Conv2d(
                    ch, self.out_chans, kernel_size=1, stride=1
                ),  # a 1x1 conv is always resolution invariant (pixel-wise channel transformation)
            )
        )
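The per-kernel parameter count quoted in the docstring, (rings − 1) · basis + 1, checked for a few of the shapes used in this repo:

```python
# (rings - 1) * basis + 1 parameters per DISCO kernel
for rings, basis in [(3, 4), (2, 4), (6, 7)]:
    params = (rings - 1) * basis + 1
    print(f"kernel_shape=({rings}, {basis}) -> {params} parameters per kernel")
# (3, 4) -> 9, the same budget as a dense 3x3 convolution kernel
```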
    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        image : torch.Tensor
            Input 4D tensor of shape `(N, in_chans, H, W)`.

        Returns
        -------
        torch.Tensor
            Output tensor of shape `(N, out_chans, H, W)`.
        """
        stack = []
        output = image

        # apply down-sampling layers
        for layer in self.down_sample_layers:
            output = layer(output)
            stack.append(output)
            output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0)

        output = self.bottleneck(output)

        # apply up-sampling layers
        for transpose, disco in zip(self.up_transpose, self.up):
            downsample_layer = stack.pop()
            output = transpose(output)

            # reflect-pad on the right/bottom if needed to handle odd input dimensions
            padding = [0, 0, 0, 0]
            if output.shape[-1] != downsample_layer.shape[-1]:
                padding[1] = 1  # padding right
            if output.shape[-2] != downsample_layer.shape[-2]:
                padding[3] = 1  # padding bottom
            if torch.sum(torch.tensor(padding)) != 0:
                output = F.pad(output, padding, "reflect")

            output = torch.cat([output, downsample_layer], dim=1)
            output = disco(output)

        return output
|
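# Usage sketch (illustrative addition, not part of the original file). The
# class name `UDNO` and the argument values below are assumptions chosen to
# match this file's constructor parameters; adjust to the actual class name
# and data at hand.
#
#   model = UDNO(
#       in_chans=1,
#       out_chans=1,
#       chans=32,
#       num_pool_layers=4,
#       drop_prob=0.0,
#       radius_cutoff=0.02,      # proportion of normalized input space
#       in_shape=(320, 320),     # unbatched spatial shape
#       kernel_shape=(3, 4),     # 3 rings x 4 anisotropic basis functions
#   )
#   out = model(torch.rand(2, 1, 320, 320))  # -> (2, 1, 320, 320)
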
class DISCOBlock(nn.Module):
    """
    A DISCO Block that consists of two DISCO layers, each followed by
    instance normalization, LeakyReLU activation, and dropout.
    """

    def __init__(
        self,
        in_chans: int,
        out_chans: int,
        radius_cutoff: float,
        drop_prob: float,
        in_shape: Tuple[int, int],
        kernel_shape: Tuple[int, int] = (3, 4),
    ):
        """
        Parameters
        ----------
        in_chans : int
            Number of channels in the input.
        out_chans : int
            Number of channels in the output.
        radius_cutoff : float
            Controls the effective radius of the DISCO kernel. Values are
            between 0.0 and 1.0. The radius_cutoff is expressed as a proportion
            of the normalized input space, to ensure that kernels are resolution
            invariant.
        in_shape : Tuple[int, int]
            Unbatched spatial 2D shape of the input to this block.
            Required to dynamically compile DISCO kernels for resolution invariance.
        kernel_shape : Tuple[int, int], optional
            Shape of the DISCO kernel. Default is (3, 4). This corresponds to 3
            rings and 4 anisotropic basis functions. Under the hood, each DISCO
            kernel has (3 - 1) * 4 + 1 = 9 parameters, equivalent to a standard
            3x3 convolution kernel.

            Note: This is NOT kernel_size; under the DISCO framework,
            kernels are dynamically compiled to support resolution invariance.
        drop_prob : float
            Dropout probability.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans
        self.drop_prob = drop_prob

        self.layers = nn.Sequential(
            DISCO2d(
                in_chans,
                out_chans,
                kernel_shape=kernel_shape,
                in_shape=in_shape,
                bias=False,
                radius_cutoff=radius_cutoff,
                padding_mode="constant",
            ),
            nn.InstanceNorm2d(out_chans),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Dropout2d(drop_prob),
            DISCO2d(
                out_chans,
                out_chans,
                kernel_shape=kernel_shape,
                in_shape=in_shape,
                bias=False,
                radius_cutoff=radius_cutoff,
                padding_mode="constant",
            ),
            nn.InstanceNorm2d(out_chans),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Dropout2d(drop_prob),
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        image : torch.Tensor
            Input 4D tensor of shape `(N, in_chans, H, W)`.

        Returns
        -------
        torch.Tensor
            Output tensor of shape `(N, out_chans, H, W)`.
        """
        return self.layers(image)

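# Parameter-count check (explanatory note, not in the original file): for
# kernel_shape = (R, A), each dynamically compiled DISCO kernel carries
# (R - 1) * A + 1 learnable coefficients, so the default (3, 4) gives
# (3 - 1) * 4 + 1 = 9 -- the same budget as one standard 3x3 convolution
# kernel, while remaining resolution invariant.
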
class TransposeDISCOBlock(nn.Module):
    """
    A transpose DISCO Block that consists of an up-sampling layer followed by a
    DISCO layer, instance normalization, and LeakyReLU activation.
    """

    def __init__(
        self,
        in_chans: int,
        out_chans: int,
        radius_cutoff: float,
        in_shape: Tuple[int, int],
        kernel_shape: Tuple[int, int] = (3, 4),
    ):
        """
        Parameters
        ----------
        in_chans : int
            Number of channels in the input.
        out_chans : int
            Number of channels in the output.
        radius_cutoff : float
            Controls the effective radius of the DISCO kernel. Values are
            between 0.0 and 1.0. The radius_cutoff is expressed as a proportion
            of the normalized input space, to ensure that kernels are resolution
            invariant.
        in_shape : Tuple[int, int]
            Unbatched spatial 2D shape of the input to this block.
            Required to dynamically compile DISCO kernels for resolution invariance.
        kernel_shape : Tuple[int, int], optional
            Shape of the DISCO kernel. Default is (3, 4). This corresponds to 3
            rings and 4 anisotropic basis functions. Under the hood, each DISCO
            kernel has (3 - 1) * 4 + 1 = 9 parameters, equivalent to a standard
            3x3 convolution kernel.

            Note: This is NOT kernel_size; under the DISCO framework,
            kernels are dynamically compiled to support resolution invariance.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans

        self.layers = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            DISCO2d(
                in_chans,
                out_chans,
                kernel_shape=kernel_shape,
                in_shape=(2 * in_shape[0], 2 * in_shape[1]),
                bias=False,
                radius_cutoff=(radius_cutoff / 2),
                padding_mode="constant",
            ),
            nn.InstanceNorm2d(out_chans),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        image : torch.Tensor
            Input 4D tensor of shape `(N, in_chans, H, W)`.

        Returns
        -------
        torch.Tensor
            Output tensor of shape `(N, out_chans, H*2, W*2)`.
        """
        return self.layers(image)
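# Design note (explanatory addition): TransposeDISCOBlock first doubles the
# grid with bilinear upsampling, then compiles its DISCO2d for
# in_shape = (2*H, 2*W) with radius_cutoff halved. Since the cutoff is a
# proportion of the normalized input space, proportion x resolution stays
# constant, so the kernel keeps the same support in grid points at every
# scale -- the encoder mirrors this by doubling radius_cutoff after each
# 2x2 average pool.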
models/unet.py
ADDED
@@ -0,0 +1,209 @@
"""
Copyright (c) Facebook, Inc. and its affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

from typing import List, Tuple

import torch
from torch import nn
from torch.nn import functional as F


class Unet(nn.Module):
    """
    PyTorch implementation of a U-Net model.

    O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks
    for biomedical image segmentation. In International Conference on Medical
    image computing and computer-assisted intervention, pages 234–241.
    Springer, 2015.
    """

    def __init__(
        self,
        in_chans: int,
        out_chans: int,
        chans: int = 32,
        num_pool_layers: int = 4,
        drop_prob: float = 0.0,
    ):
        """
        Parameters
        ----------
        in_chans : int
            Number of channels in the input to the U-Net model.
        out_chans : int
            Number of channels in the output of the U-Net model.
        chans : int, optional
            Number of output channels of the first convolution layer. Default is 32.
        num_pool_layers : int, optional
            Number of down-sampling and up-sampling layers. Default is 4.
        drop_prob : float, optional
            Dropout probability. Default is 0.0.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans
        self.chans = chans
        self.num_pool_layers = num_pool_layers
        self.drop_prob = drop_prob

        self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)])
        ch = chans
        for _ in range(num_pool_layers - 1):
            self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob))
            ch *= 2
        self.conv = ConvBlock(ch, ch * 2, drop_prob)

        self.up_conv = nn.ModuleList()
        self.up_transpose_conv = nn.ModuleList()
        for _ in range(num_pool_layers - 1):
            self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch))
            self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob))
            ch //= 2

        self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch))
        self.up_conv.append(
            nn.Sequential(
                ConvBlock(ch * 2, ch, drop_prob),
                nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1),
            )
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        image : torch.Tensor
            Input 4D tensor of shape `(N, in_chans, H, W)`.

        Returns
        -------
        torch.Tensor
            Output tensor of shape `(N, out_chans, H, W)`.
        """
        stack = []
        output = image

        # apply down-sampling layers
        for layer in self.down_sample_layers:
            output = layer(output)
            stack.append(output)
            output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0)

        output = self.conv(output)

        # apply up-sampling layers
        for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv):
            downsample_layer = stack.pop()
            output = transpose_conv(output)

            # reflect pad on the right/bottom if needed to handle odd input dimensions
            padding = [0, 0, 0, 0]
            if output.shape[-1] != downsample_layer.shape[-1]:
                padding[1] = 1  # padding right
            if output.shape[-2] != downsample_layer.shape[-2]:
                padding[3] = 1  # padding bottom
            if torch.sum(torch.tensor(padding)) != 0:
                output = F.pad(output, padding, "reflect")

            output = torch.cat([output, downsample_layer], dim=1)
            output = conv(output)

        return output


class ConvBlock(nn.Module):
    """
    A Convolutional Block that consists of two convolution layers, each followed by
    instance normalization, LeakyReLU activation, and dropout.
    """

    def __init__(self, in_chans: int, out_chans: int, drop_prob: float):
        """
        Parameters
        ----------
        in_chans : int
            Number of channels in the input.
        out_chans : int
            Number of channels in the output.
        drop_prob : float
            Dropout probability.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans
        self.drop_prob = drop_prob

        self.layers = nn.Sequential(
            nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(out_chans),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Dropout2d(drop_prob),
            nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(out_chans),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Dropout2d(drop_prob),
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        image : torch.Tensor
            Input 4D tensor of shape `(N, in_chans, H, W)`.

        Returns
        -------
        torch.Tensor
            Output tensor of shape `(N, out_chans, H, W)`.
        """
        return self.layers(image)


class TransposeConvBlock(nn.Module):
    """
    A Transpose Convolutional Block that consists of one transpose convolution
    layer followed by instance normalization and LeakyReLU activation.
    """

    def __init__(self, in_chans: int, out_chans: int):
        """
        Parameters
        ----------
        in_chans : int
            Number of channels in the input.
        out_chans : int
            Number of channels in the output.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans

        self.layers = nn.Sequential(
            nn.ConvTranspose2d(
                in_chans, out_chans, kernel_size=2, stride=2, bias=False
            ),
            nn.InstanceNorm2d(out_chans),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        image : torch.Tensor
            Input 4D tensor of shape `(N, in_chans, H, W)`.

        Returns
        -------
        torch.Tensor
            Output tensor of shape `(N, out_chans, H*2, W*2)`.
        """
        return self.layers(image)
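# Smoke test (illustrative addition, not part of the original fastMRI file):
# checks that the network preserves spatial dimensions, including odd sizes
# handled by the reflect-padding branch in `forward`.
if __name__ == "__main__":
    _model = Unet(in_chans=1, out_chans=1, chans=8, num_pool_layers=2)
    _x = torch.rand(1, 1, 65, 63)  # odd dims exercise the padding logic
    assert _model(_x).shape == (1, 1, 65, 63)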
models/varnet.py
ADDED
@@ -0,0 +1,416 @@
"""
Copyright (c) Facebook, Inc. and its affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import math
import os
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

import fastmri
from fastmri import transforms
from models.unet import Unet


class NormUnet(nn.Module):
    """
    Normalized U-Net model.

    This is the same as a regular U-Net, but with normalization applied to the
    input before the U-Net. This keeps the values more numerically stable
    during training.
    """

    def __init__(
        self,
        chans: int,
        num_pools: int,
        in_chans: int = 2,
        out_chans: int = 2,
        drop_prob: float = 0.0,
    ):
        """
        Initialize the NormUnet model.

        Parameters
        ----------
        chans : int
            Number of output channels of the first convolution layer.
        num_pools : int
            Number of down-sampling and up-sampling layers.
        in_chans : int, optional
            Number of channels in the input to the U-Net model. Default is 2.
        out_chans : int, optional
            Number of channels in the output to the U-Net model. Default is 2.
        drop_prob : float, optional
            Dropout probability. Default is 0.0.
        """
        super().__init__()

        self.unet = Unet(
            in_chans=in_chans,
            out_chans=out_chans,
            chans=chans,
            num_pool_layers=num_pools,
            drop_prob=drop_prob,
        )

    def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w, two = x.shape
        assert two == 2
        return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w)

    def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor:
        b, c2, h, w = x.shape
        assert c2 % 2 == 0
        c = c2 // 2
        return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous()

    def norm(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # group norm
        b, c, h, w = x.shape
        x = x.view(b, 2, c // 2 * h * w)

        mean = x.mean(dim=2).view(b, 2, 1, 1)
        std = x.std(dim=2).view(b, 2, 1, 1)

        x = x.view(b, c, h, w)

        return (x - mean) / std, mean, std

    def unnorm(
        self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
    ) -> torch.Tensor:
        return x * std + mean

    def pad(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]:
        _, _, h, w = x.shape
        w_mult = ((w - 1) | 15) + 1
        h_mult = ((h - 1) | 15) + 1
        w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)]
        h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)]
        # TODO: fix this type when PyTorch fixes theirs
        # the documentation lies - this actually takes a list
        # https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L3457
        # https://github.com/pytorch/pytorch/pull/16949
        x = F.pad(x, w_pad + h_pad)

        return x, (h_pad, w_pad, h_mult, w_mult)

    def unpad(
        self,
        x: torch.Tensor,
        h_pad: List[int],
        w_pad: List[int],
        h_mult: int,
        w_mult: int,
    ) -> torch.Tensor:
        return x[
            ..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]
        ]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not x.shape[-1] == 2:
            raise ValueError("Last dimension must be 2 for complex.")

        # get shapes for unet and normalize
        x = self.complex_to_chan_dim(x)
        x, mean, std = self.norm(x)
        x, pad_sizes = self.pad(x)

        x = self.unet(x)

        # get shapes back and unnormalize
        x = self.unpad(x, *pad_sizes)
        x = self.unnorm(x, mean, std)
        x = self.chan_complex_to_last_dim(x)

        return x

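# Note on NormUnet.pad (explanatory addition): `((w - 1) | 15) + 1` rounds w up
# to the next multiple of 16, which guarantees the U-Net's four 2x downsamplings
# divide evenly. Example: w = 100 -> (99 | 15) = 111 -> 112 = 7 * 16; the
# 12 extra columns are split 6/6 between the left and right pads and removed
# again by `unpad`.
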
class SensitivityModel(nn.Module):
    """
    Model for learning sensitivity estimation from k-space data.

    This model applies an IFFT to multichannel k-space data and then a U-Net
    to the coil images to estimate coil sensitivities. It can be used with the
    end-to-end variational network.

    Input: multi-coil k-space data
    Output: multi-coil spatial-domain sensitivity maps
    """

    def __init__(
        self,
        chans: int,
        num_pools: int,
        in_chans: int = 2,
        out_chans: int = 2,
        drop_prob: float = 0.0,
        mask_center: bool = True,
    ):
        """
        Parameters
        ----------
        chans : int
            Number of output channels of the first convolution layer.
        num_pools : int
            Number of down-sampling and up-sampling layers.
        in_chans : int, optional
            Number of channels in the input to the U-Net model. Default is 2.
        out_chans : int, optional
            Number of channels in the output to the U-Net model. Default is 2.
        drop_prob : float, optional
            Dropout probability. Default is 0.0.
        mask_center : bool, optional
            Whether to mask center of k-space for sensitivity map calculation.
            Default is True.
        """
        super().__init__()
        self.mask_center = mask_center
        self.norm_unet = NormUnet(
            chans,
            num_pools,
            in_chans=in_chans,
            out_chans=out_chans,
            drop_prob=drop_prob,
        )

    def chans_to_batch_dim(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
        b, c, h, w, comp = x.shape

        return x.view(b * c, 1, h, w, comp), b

    def batch_chans_to_chan_dim(
        self,
        x: torch.Tensor,
        batch_size: int,
    ) -> torch.Tensor:
        bc, _, h, w, comp = x.shape
        c = bc // batch_size

        return x.view(batch_size, c, h, w, comp)

    def divide_root_sum_of_squares(self, x: torch.Tensor) -> torch.Tensor:
        return x / fastmri.rss_complex(x, dim=1).unsqueeze(-1).unsqueeze(1)

    def get_pad_and_num_low_freqs(
        self, mask: torch.Tensor, num_low_frequencies: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if num_low_frequencies is None or any(
            torch.any(t == 0) for t in num_low_frequencies
        ):
            # get low frequency line locations and mask them out
            squeezed_mask = mask[:, 0, 0, :, 0].to(torch.int8)
            cent = squeezed_mask.shape[1] // 2
            # argmin returns the index of the first zero (first unsampled line)
            left = torch.argmin(squeezed_mask[:, :cent].flip(1), dim=1)
            right = torch.argmin(squeezed_mask[:, cent:], dim=1)
            num_low_frequencies_tensor = torch.max(
                2 * torch.min(left, right), torch.ones_like(left)
            )  # force a symmetric center unless 1
        else:
            num_low_frequencies_tensor = num_low_frequencies * torch.ones(
                mask.shape[0], dtype=mask.dtype, device=mask.device
            )

        pad = (mask.shape[-2] - num_low_frequencies_tensor + 1) // 2

        return pad.type(torch.long), num_low_frequencies_tensor.type(torch.long)

    def forward(
        self,
        masked_kspace: torch.Tensor,
        mask: torch.Tensor,
        num_low_frequencies: Optional[int] = None,
    ) -> torch.Tensor:
        if self.mask_center:
            pad, num_low_freqs = self.get_pad_and_num_low_freqs(
                mask, num_low_frequencies
            )
            masked_kspace = transforms.batched_mask_center(
                masked_kspace, pad, pad + num_low_freqs
            )

        # convert to image space
        images, batches = self.chans_to_batch_dim(fastmri.ifft2c(masked_kspace))

        # estimate sensitivities
        return self.divide_root_sum_of_squares(
            self.batch_chans_to_chan_dim(self.norm_unet(images), batches)
        )

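# Worked example for get_pad_and_num_low_freqs (illustrative): take one mask
# row [0 0 1 1 1 1 0 0] with W = 8 and cent = 4. The flipped left half
# [1 1 0 0] has its first zero at index 2 (left = 2) and the right half
# [1 1 0 0] at index 2 (right = 2), so num_low_frequencies =
# max(2 * min(2, 2), 1) = 4 fully sampled center lines, and
# pad = (8 - 4 + 1) // 2 = 2, i.e. the center block starts at column 2.
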
class VarNet(nn.Module):
    """
    A full variational network model.

    This model applies a combination of soft data consistency with a U-Net
    regularizer. To use non-U-Net regularizers, use VarNetBlock.

    Input: multi-channel k-space data
    Output: single-channel RSS reconstructed image
    """

    def __init__(
        self,
        num_cascades: int = 12,
        sens_chans: int = 8,
        sens_pools: int = 4,
        chans: int = 18,
        pools: int = 4,
        mask_center: bool = True,
    ):
        """
        Parameters
        ----------
        num_cascades : int
            Number of cascades (i.e., layers) for the variational network.
        sens_chans : int
            Number of channels for the sensitivity map U-Net.
        sens_pools : int
            Number of downsampling and upsampling layers for the sensitivity map U-Net.
        chans : int
            Number of channels for the cascade U-Net.
        pools : int
            Number of downsampling and upsampling layers for the cascade U-Net.
        mask_center : bool
            Whether to mask center of k-space for sensitivity map calculation.
        """

        super().__init__()

        self.sens_net = SensitivityModel(
            chans=sens_chans,
            num_pools=sens_pools,
            mask_center=mask_center,
        )
        self.cascades = nn.ModuleList(
            [VarNetBlock(NormUnet(chans, pools)) for _ in range(num_cascades)]
        )

    def forward(
        self,
        masked_kspace: torch.Tensor,
        mask: torch.Tensor,
        num_low_frequencies: Optional[int] = None,
    ) -> torch.Tensor:
        sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies)
        kspace_pred = masked_kspace.clone()
        for cascade in self.cascades:
            kspace_pred = cascade(kspace_pred, masked_kspace, mask, sens_maps)

        spatial_pred = fastmri.ifft2c(kspace_pred)

        # ---------> FIXME: CHANGE FOR MVUE MODE
        if self.training and os.getenv("MVUE") in ["yes", "1", "true", "True"]:
            combined_spatial = fastmri.mvue(spatial_pred, sens_maps, dim=1)
        else:
            spatial_pred_abs = fastmri.complex_abs(spatial_pred)
            combined_spatial = fastmri.rss(spatial_pred_abs, dim=1)
        return combined_spatial


class VarNetBlock(nn.Module):
    """
    Model block for the end-to-end variational network (refinement module).

    This model applies a combination of soft data consistency with the input
    model as a regularizer. A series of these blocks can be stacked to form
    the full variational network.

    Input: multi-channel k-space data
    Output: multi-channel k-space data
    """

    def __init__(self, model: nn.Module):
        """
        Parameters
        ----------
        model : nn.Module
            Module for the "regularization" component of the variational network.
        """
        super().__init__()

        self.model = model
        self.dc_weight = nn.Parameter(torch.ones(1))

    def sens_expand(
        self, x: torch.Tensor, sens_maps: torch.Tensor
    ) -> torch.Tensor:
        """
        Computes F(x * sens_maps), i.e. expands a combined image to
        per-coil k-space.
        """
        return fastmri.fft2c(fastmri.complex_mul(x, sens_maps))

    def sens_reduce(
        self, x: torch.Tensor, sens_maps: torch.Tensor
    ) -> torch.Tensor:
        """
        Computes sum_c F^{-1}(x) * conj(sens_maps), i.e. reduces per-coil
        k-space to a single combined image, where conj is the element-wise
        complex conjugate.
        """
        return fastmri.complex_mul(
            fastmri.ifft2c(x), fastmri.complex_conj(sens_maps)
        ).sum(dim=1, keepdim=True)

    def forward(
        self,
        current_kspace: torch.Tensor,
        ref_kspace: torch.Tensor,
        mask: torch.Tensor,
        sens_maps: torch.Tensor,
    ) -> torch.Tensor:
        """
        Parameters
        ----------
        current_kspace : torch.Tensor
            The current k-space data (frequency-domain data) being processed by the network.
        ref_kspace : torch.Tensor
            The reference k-space data (measured data) used for data consistency.
        mask : torch.Tensor
            A binary mask indicating the locations in k-space where data consistency should be enforced.
        sens_maps : torch.Tensor
            Sensitivity maps for the different coils in parallel imaging.

        Returns
        -------
        torch.Tensor
            The output k-space data after applying the variational network block.
        """

        """
        Model term:
        - Reduces the current k-space data using the sensitivity maps (inverse Fourier transform followed by element-wise multiplication and summation).
        - Applies the neural network model to the reduced data.
        - Expands the output of the model using the sensitivity maps (element-wise multiplication followed by Fourier transform).
        """

        model_term = self.sens_expand(
            self.model(self.sens_reduce(current_kspace, sens_maps)), sens_maps
        )

        """
        Soft data consistency term:
        - Calculates the difference between the current k-space and the reference k-space where the mask is true.
        - Multiplies this difference by the data consistency weight.
        """
        zero = torch.zeros(1, 1, 1, 1, 1).to(current_kspace)
        soft_dc = (
            torch.where(mask, current_kspace - ref_kspace, zero)
            * self.dc_weight
        )

        # with data consistency term (removed for single cascade experiments)
        return current_kspace - soft_dc - model_term
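# Cascade update implemented by VarNetBlock.forward (explanatory addition),
# with M the sampling mask, lambda = dc_weight, S the sensitivity maps, and
# F the 2D Fourier transform:
#
#   k_{i+1} = k_i - lambda * M .* (k_i - k_ref) - F( S * CNN( S^H (F^{-1} k_i) ) )
#
# i.e. a learned gradient-descent step that mixes a soft data-consistency
# correction with the U-Net refinement term.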
pyproject.toml
ADDED
@@ -0,0 +1,46 @@
[build-system]
requires = ["setuptools>=64.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "no-med"
version = "0.0.0"
description = "Neural operators for medical imaging."
readme = "README.md"
# readme-content-type = "text/markdown"
authors = [{ name = "Armeet Singh Jatyani", email = "[email protected]" }]
license = { text = "MIT" }
keywords = ["medical imaging", "neural operators", "AI"]
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
]
requires-python = ">=3.6"

[project.optional-dependencies]
dev = ["pytest>=6.0", "black", "flake8", "pdoc3"]

[tool.setuptools]
include-package-data = true
packages = { find = {} } # auto find packages

[tool.black]
line-length = 80 # Default is 88, but you can set it to 100 or 120 if needed
target-version = ['py38', 'py39', 'py310'] # Set target Python versions
include = '\.pyi?$' # Format only .py and .pyi files
skip-string-normalization = true
exclude = '''
/(
    \.eggs # Exclude files generated by packaging
  | \.git # Exclude version control files
  | \.mypy_cache # Exclude mypy caches
  | \.tox # Exclude tox environments
  | \.venv # Exclude virtual environments
)/
'''
fast = true

[tool.isort]
profile = "black"
line_length = 80
pytest.ini
ADDED
@@ -0,0 +1,4 @@
[pytest]
markers =
    train: marks training tests (long runtime) (deselect with '-m "not train"')
addopts = -m "not train"
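# Usage note (added): with the addopts above, a plain `pytest` run skips the
# long training tests by default; run them explicitly with `pytest -m train`.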
setup_config.py
ADDED
@@ -0,0 +1,21 @@
import os

config_filename = "fastmri.yaml"
default_config_content = """brain_path: brain path here
knee_path: knee path here
log_path: log path here
checkpoint_path: checkpoint path here
"""


def check_and_create_config():
    if not os.path.exists(config_filename):
        print(f"{config_filename} not found. Creating with default template...")
        with open(config_filename, "w") as config_file:
            config_file.write(default_config_content)
        print(f"Default configuration file created at {config_filename}.")
    else:
        print(f"{config_filename} already exists. No changes made.")


check_and_create_config()
torch_harmonics_local/__init__.py
ADDED
File without changes

torch_harmonics_local/_disco_convolution.py
ADDED
@@ -0,0 +1,502 @@
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import math

import torch

# triton will only be available on CUDA installations of pytorch
import triton
import triton.language as tl

BLOCK_SIZE_BATCH = 4
BLOCK_SIZE_NZ = 8
BLOCK_SIZE_POUT = 8


@triton.jit
def _disco_s2_contraction_kernel(
    inz_ptr,
    vnz_ptr,
    nnz,
    inz_stride_ii,
    inz_stride_nz,
    vnz_stride,
    x_ptr,
    batch_size,
    nlat_in,
    nlon_in,
    x_stride_b,
    x_stride_t,
    x_stride_p,
    y_ptr,
    kernel_size,
    nlat_out,
    nlon_out,
    y_stride_b,
    y_stride_f,
    y_stride_t,
    y_stride_p,
    pscale,
    backward: tl.constexpr,
    BLOCK_SIZE_BATCH: tl.constexpr,
    BLOCK_SIZE_NZ: tl.constexpr,
    BLOCK_SIZE_POUT: tl.constexpr,
):
    """
    Kernel for the sparse-dense contraction for the S2 DISCO convolution.
    """

    pid_batch = tl.program_id(0)
    pid_pout = tl.program_id(2)

    # pid_nz should always be 0 as we do not account for larger grids in this dimension
    pid_nz = tl.program_id(1)  # should always be 0
    tl.device_assert(pid_nz == 0)

    # create the pointer block for pout
    pout = pid_pout * BLOCK_SIZE_POUT + tl.arange(0, BLOCK_SIZE_POUT)
    b = pid_batch * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)

    # create pointer blocks for the psi data structure
    iinz = tl.arange(0, BLOCK_SIZE_NZ)

    # get the initial pointers
    fout_ptrs = inz_ptr + iinz * inz_stride_nz
    tout_ptrs = inz_ptr + iinz * inz_stride_nz + inz_stride_ii
    tpnz_ptrs = inz_ptr + iinz * inz_stride_nz + 2 * inz_stride_ii
    vals_ptrs = vnz_ptr + iinz * vnz_stride

    # iterate in a blocked fashion over the non-zero entries
    for offs_nz in range(0, nnz, BLOCK_SIZE_NZ):
        # load input/output latitude coordinate pairs
        fout = tl.load(
            fout_ptrs + offs_nz * inz_stride_nz, mask=(offs_nz + iinz < nnz), other=-1
        )
        tout = tl.load(
            tout_ptrs + offs_nz * inz_stride_nz, mask=(offs_nz + iinz < nnz), other=-1
        )
        tpnz = tl.load(
            tpnz_ptrs + offs_nz * inz_stride_nz, mask=(offs_nz + iinz < nnz), other=-1
        )

        # load corresponding values
        vals = tl.load(
            vals_ptrs + offs_nz * vnz_stride, mask=(offs_nz + iinz < nnz), other=0.0
        )

        # compute the shifted longitude coordinates p+p' to read in a coalesced fashion
        tnz = tpnz // nlon_in
        pnz = tpnz % nlon_in

        # make sure the value is not out of bounds
        tl.device_assert(fout < kernel_size)
        tl.device_assert(tout < nlat_out)
        tl.device_assert(tnz < nlat_in)
        tl.device_assert(pnz < nlon_in)

        # load corresponding portion of the input array
        x_ptrs = (
            x_ptr
            + tnz[None, :, None] * x_stride_t
            + ((pnz[None, :, None] + pout[None, None, :] * pscale) % nlon_in)
            * x_stride_p
            + b[:, None, None] * x_stride_b
        )
        y_ptrs = (
            y_ptr
            + fout[None, :, None] * y_stride_f
            + tout[None, :, None] * y_stride_t
            + (pout[None, None, :] % nlon_out) * y_stride_p
            + b[:, None, None] * y_stride_b
        )

        # precompute the mask
        mask = (
            (b[:, None, None] < batch_size) and (offs_nz + iinz[None, :, None] < nnz)
        ) and (pout[None, None, :] < nlon_out)

        # do the actual computation. Backward is essentially just the same operation with swapped tensors.
        if not backward:
            x = tl.load(x_ptrs, mask=mask, other=0.0)
            y = vals[None, :, None] * x

            # store it to the output array
            tl.atomic_add(y_ptrs, y, mask=mask)
        else:
            y = tl.load(y_ptrs, mask=mask, other=0.0)
            x = vals[None, :, None] * y

            # store it to the output array
            tl.atomic_add(x_ptrs, x, mask=mask)


def _disco_s2_contraction_fwd(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
    """
    Wrapper function for the triton implementation of the efficient DISCO convolution on the sphere.

    Parameters
    ----------
    x : torch.Tensor
        Input signal on the sphere. Expects a tensor of shape (batch_size x channels x nlat_in x nlon_in).
    psi : torch.Tensor
        Pre-computed convolution tensor. Expects a sparse tensor of shape kernel_size x nlat_out x (nlat_in * nlon_in).
    nlon_out : int
        Number of longitude points the output should have.
    """

    # check the shapes of all input tensors
    assert len(psi.shape) == 3
    assert len(x.shape) == 4
    assert psi.is_sparse, "Psi must be a sparse COO tensor"

    # TODO: check that Psi is also coalesced

    # get the dimensions of the problem
    kernel_size, nlat_out, n_in = psi.shape
    nnz = psi.indices().shape[-1]
    batch_size, n_chans, nlat_in, nlon_in = x.shape
    assert nlat_in * nlon_in == n_in

    # TODO: check that the Psi index vector is of type long

    # make sure that the grid-points of the output grid fall onto the grid points of the input grid
    assert nlon_in % nlon_out == 0
    pscale = nlon_in // nlon_out

    # to simplify things, we merge batch and channel dimensions
    x = x.reshape(batch_size * n_chans, nlat_in, nlon_in)

    # prepare the output tensor
    y = torch.zeros(
        batch_size * n_chans,
        kernel_size,
        nlat_out,
        nlon_out,
        device=x.device,
        dtype=x.dtype,
    )

    # determine the grid for the computation
    grid = (
        triton.cdiv(batch_size * n_chans, BLOCK_SIZE_BATCH),
        1,
        triton.cdiv(nlon_out, BLOCK_SIZE_POUT),
    )

    # launch the kernel
    _disco_s2_contraction_kernel[grid](
        psi.indices(),
        psi.values(),
        nnz,
        psi.indices().stride(-2),
        psi.indices().stride(-1),
        psi.values().stride(-1),
        x,
        batch_size * n_chans,
        nlat_in,
        nlon_in,
        x.stride(0),
        x.stride(-2),
        x.stride(-1),
        y,
        kernel_size,
        nlat_out,
        nlon_out,
        y.stride(0),
        y.stride(1),
        y.stride(-2),
        y.stride(-1),
        pscale,
        False,
        BLOCK_SIZE_BATCH,
        BLOCK_SIZE_NZ,
        BLOCK_SIZE_POUT,
    )

    # reshape y back to expose the correct dimensions
    y = y.reshape(batch_size, n_chans, kernel_size, nlat_out, nlon_out)

    return y

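# Launch-geometry note (explanatory addition): the 3D grid above parallelizes
# over (batch * channel) blocks (axis 0) and output-longitude blocks (axis 2)
# only; axis 1 is fixed at 1 because the non-zero entries of psi are looped
# over inside the kernel, which is why the kernel asserts pid_nz == 0.
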
def _disco_s2_contraction_bwd(grad_y: torch.Tensor, psi: torch.Tensor, nlon_in: int):
    """
    Backward pass for the triton implementation of the efficient DISCO convolution on the sphere.

    Parameters
    ----------
    grad_y : torch.Tensor
        Input gradient on the sphere. Expects a tensor of shape batch_size x channels x kernel_size x nlat_out x nlon_out.
    psi : torch.Tensor
        Pre-computed convolution tensor. Expects a sparse tensor of shape kernel_size x nlat_out x (nlat_in * nlon_in).
    nlon_in : int
        Number of longitude points the input used. Required to infer the correct dimensions.
    """

    # check the shapes of all input tensors
    assert len(psi.shape) == 3
    assert len(grad_y.shape) == 5
    assert psi.is_sparse, "psi must be a sparse COO tensor"

    # TODO: check that Psi is also coalesced

    # get the dimensions of the problem
    kernel_size, nlat_out, n_in = psi.shape
    nnz = psi.indices().shape[-1]
    assert grad_y.shape[-2] == nlat_out
    assert grad_y.shape[-3] == kernel_size
    assert n_in % nlon_in == 0
    nlat_in = n_in // nlon_in
    batch_size, n_chans, _, _, nlon_out = grad_y.shape

    # make sure that the grid-points of the output grid fall onto the grid points of the input grid
    assert nlon_in % nlon_out == 0
    pscale = nlon_in // nlon_out

    # to simplify things, we merge batch and channel dimensions
    grad_y = grad_y.reshape(batch_size * n_chans, kernel_size, nlat_out, nlon_out)

    # prepare the output tensor
    grad_x = torch.zeros(
        batch_size * n_chans, nlat_in, nlon_in, device=grad_y.device, dtype=grad_y.dtype
    )

    # determine the grid for the computation
    grid = (
        triton.cdiv(batch_size * n_chans, BLOCK_SIZE_BATCH),
        1,
        triton.cdiv(nlon_out, BLOCK_SIZE_POUT),
    )

    # launch the kernel
    _disco_s2_contraction_kernel[grid](
        psi.indices(),
        psi.values(),
        nnz,
        psi.indices().stride(-2),
        psi.indices().stride(-1),
        psi.values().stride(-1),
        grad_x,
        batch_size * n_chans,
        nlat_in,
        nlon_in,
        grad_x.stride(0),
        grad_x.stride(-2),
        grad_x.stride(-1),
        grad_y,
        kernel_size,
        nlat_out,
        nlon_out,
        grad_y.stride(0),
        grad_y.stride(1),
        grad_y.stride(-2),
        grad_y.stride(-1),
        pscale,
        True,
        BLOCK_SIZE_BATCH,
        BLOCK_SIZE_NZ,
        BLOCK_SIZE_POUT,
    )

    # reshape grad_x back to expose the correct dimensions
    grad_x = grad_x.reshape(batch_size, n_chans, nlat_in, nlon_in)

    return grad_x


class _DiscoS2ContractionTriton(torch.autograd.Function):
    """
    Helper class to make the triton implementation work with PyTorch autograd functionality
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
        ctx.save_for_backward(psi)
        ctx.nlon_in = x.shape[-1]

        return _disco_s2_contraction_fwd(x, psi, nlon_out)

    @staticmethod
    def backward(ctx, grad_output):
        (psi,) = ctx.saved_tensors
        grad_input = _disco_s2_contraction_bwd(grad_output, psi, ctx.nlon_in)

        return grad_input, None, None


class _DiscoS2TransposeContractionTriton(torch.autograd.Function):
    """
    Helper class to make the triton implementation work with PyTorch autograd functionality
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
        ctx.save_for_backward(psi)
        ctx.nlon_in = x.shape[-1]

        return _disco_s2_contraction_bwd(x, psi, nlon_out)

    @staticmethod
    def backward(ctx, grad_output):
        (psi,) = ctx.saved_tensors
        grad_input = _disco_s2_contraction_fwd(grad_output, psi, ctx.nlon_in)

        return grad_input, None, None


def _disco_s2_contraction_triton(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
    return _DiscoS2ContractionTriton.apply(x, psi, nlon_out)


def _disco_s2_transpose_contraction_triton(
    x: torch.Tensor, psi: torch.Tensor, nlon_out: int
):
    return _DiscoS2TransposeContractionTriton.apply(x, psi, nlon_out)


def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
    """
    Reference implementation of the custom contraction as described in [1]. This requires repeated
    shifting of the input tensor, which can potentially be costly. For an efficient implementation
    on GPU, make sure to use the custom kernel written in Triton.
    """
    assert len(psi.shape) == 3
    assert len(x.shape) == 4
    psi = psi.to(x.device)

    batch_size, n_chans, nlat_in, nlon_in = x.shape
    kernel_size, nlat_out, _ = psi.shape

    assert psi.shape[-1] == nlat_in * nlon_in
    assert nlon_in % nlon_out == 0
    assert nlon_in >= nlat_out
    pscale = nlon_in // nlon_out

    # add a dummy dimension for nkernel and move the batch and channel dims to the end
    x = x.reshape(1, batch_size * n_chans, nlat_in, nlon_in).permute(0, 2, 3, 1)
    x = x.expand(kernel_size, -1, -1, -1)

    y = torch.zeros(
        nlon_out,
        kernel_size,
        nlat_out,
        batch_size * n_chans,
        device=x.device,
        dtype=x.dtype,
    )

    for pout in range(nlon_out):
        # sparse contraction with psi
        y[pout] = torch.bmm(psi, x.reshape(kernel_size, nlat_in * nlon_in, -1))
        # we need to repeatedly roll the input tensor to facilitate the shifted multiplication
        x = torch.roll(x, -pscale, dims=2)

    # reshape y back to expose the correct dimensions
    y = y.permute(3, 1, 2, 0).reshape(
        batch_size, n_chans, kernel_size, nlat_out, nlon_out
    )

    return y

|
| 433 |
+
def _disco_s2_transpose_contraction_torch(
|
| 434 |
+
x: torch.Tensor, psi: torch.Tensor, nlon_out: int
|
| 435 |
+
):
|
| 436 |
+
"""
|
| 437 |
+
Reference implementation of the custom contraction as described in [1]. This requires repeated
|
| 438 |
+
shifting of the input tensor, which can potentially be costly. For an efficient implementation
|
| 439 |
+
on GPU, make sure to use the custom kernel written in Triton.
|
| 440 |
+
"""
|
| 441 |
+
assert len(psi.shape) == 3
|
| 442 |
+
assert len(x.shape) == 5
|
| 443 |
+
psi = psi.to(x.device)
|
| 444 |
+
|
| 445 |
+
batch_size, n_chans, kernel_size, nlat_in, nlon_in = x.shape
|
| 446 |
+
kernel_size, _, n_out = psi.shape
|
| 447 |
+
|
| 448 |
+
assert psi.shape[-2] == nlat_in
|
| 449 |
+
assert n_out % nlon_out == 0
|
| 450 |
+
nlat_out = n_out // nlon_out
|
| 451 |
+
assert nlon_out >= nlat_in
|
| 452 |
+
pscale = nlon_out // nlon_in
|
| 453 |
+
|
| 454 |
+
# we do a semi-transposition to faciliate the computation
|
| 455 |
+
inz = psi.indices()
|
| 456 |
+
tout = inz[2] // nlon_out
|
| 457 |
+
pout = inz[2] % nlon_out
|
| 458 |
+
# flip the axis of longitudes
|
| 459 |
+
pout = nlon_out - 1 - pout
|
| 460 |
+
tin = inz[1]
|
| 461 |
+
inz = torch.stack([inz[0], tout, tin * nlon_out + pout], dim=0)
|
| 462 |
+
psi_mod = torch.sparse_coo_tensor(
|
| 463 |
+
inz, psi.values(), size=(kernel_size, nlat_out, nlat_in * nlon_out)
|
| 464 |
+
)
|
| 465 |
+
|
| 466 |
+
# interleave zeros along the longitude dimension to allow for fractional offsets to be considered
|
| 467 |
+
x_ext = torch.zeros(
|
| 468 |
+
kernel_size,
|
| 469 |
+
nlat_in,
|
| 470 |
+
nlon_out,
|
| 471 |
+
batch_size * n_chans,
|
| 472 |
+
device=x.device,
|
| 473 |
+
dtype=x.dtype,
|
| 474 |
+
)
|
| 475 |
+
x_ext[:, :, ::pscale, :] = x.reshape(
|
| 476 |
+
batch_size * n_chans, kernel_size, nlat_in, nlon_in
|
| 477 |
+
).permute(1, 2, 3, 0)
|
| 478 |
+
# we need to go backwards through the vector, so we flip the axis
|
| 479 |
+
x_ext = x_ext.contiguous()
|
| 480 |
+
|
| 481 |
+
y = torch.zeros(
|
| 482 |
+
kernel_size,
|
| 483 |
+
nlon_out,
|
| 484 |
+
nlat_out,
|
| 485 |
+
batch_size * n_chans,
|
| 486 |
+
device=x.device,
|
| 487 |
+
dtype=x.dtype,
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
for pout in range(nlon_out):
|
| 491 |
+
# we need to repeatedly roll the input tensor to faciliate the shifted multiplication
|
| 492 |
+
# TODO: double-check why this has to happen first
|
| 493 |
+
x_ext = torch.roll(x_ext, -1, dims=2)
|
| 494 |
+
# sparse contraction with the modified psi
|
| 495 |
+
y[:, pout, :, :] = torch.bmm(
|
| 496 |
+
psi_mod, x_ext.reshape(kernel_size, nlat_in * nlon_out, -1)
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
# sum over the kernel dimension and reshape to the correct output size
|
| 500 |
+
y = y.sum(dim=0).permute(2, 1, 0).reshape(batch_size, n_chans, nlat_out, nlon_out)
|
| 501 |
+
|
| 502 |
+
return y
|
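The torch reference path above also fixes the shape contract shared with the Triton wrappers: x is (batch, channels, nlat_in, nlon_in), psi is a sparse (kernel_size, nlat_out, nlat_in * nlon_in) tensor, and the result is (batch, channels, kernel_size, nlat_out, nlon_out). A minimal sketch (not part of the committed file; all shapes are illustrative assumptions) exercising _disco_s2_contraction_torch with a random sparse stand-in for psi:

import torch

batch, chans, nlat_in, nlon_in = 2, 3, 8, 16
kernel_size, nlat_out, nlon_out = 4, 4, 8  # needs nlon_in % nlon_out == 0 and nlon_in >= nlat_out

# psi is normally the precomputed sparse filter tensor; a random sparse
# stand-in is enough to check the shape contract of the contraction.
psi = torch.randn(kernel_size, nlat_out, nlat_in * nlon_in).to_sparse()
x = torch.randn(batch, chans, nlat_in, nlon_in)

y = _disco_s2_contraction_torch(x, psi, nlon_out)
assert y.shape == (batch, chans, kernel_size, nlat_out, nlon_out)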
torch_harmonics_local/convolution.py
ADDED
@@ -0,0 +1,1014 @@
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import abc
import math
from functools import partial
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .quadrature import _precompute_grid, _precompute_latitudes

if torch.cuda.is_available():
    from ._disco_convolution import (
        _disco_s2_contraction_triton,
        _disco_s2_transpose_contraction_triton,
    )


def _compute_support_vals_isotropic(
    r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float, norm: str = "s2"
):
    """
    Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
    """

    # compute the support
    dr = (r_cutoff - 0.0) / nr
    ikernel = torch.arange(nr).reshape(-1, 1, 1)
    ir = ikernel * dr

    if norm == "none":
        norm_factor = 1.0
    elif norm == "2d":
        norm_factor = (
            math.pi * (r_cutoff * nr / (nr + 1)) ** 2
            + math.pi * r_cutoff**2 * (2 * nr / (nr + 1) + 1) / (nr + 1) / 3
        )
    elif norm == "s2":
        norm_factor = (
            2
            * math.pi
            * (
                1
                - math.cos(r_cutoff - dr)
                + math.cos(r_cutoff - dr)
                + (math.sin(r_cutoff - dr) - math.sin(r_cutoff)) / dr
            )
        )
    else:
        raise ValueError(f"Unknown normalization mode {norm}.")

    # find the indices where the rotated position falls into the support of the kernel
    iidx = torch.argwhere(((r - ir).abs() <= dr) & (r <= r_cutoff))
    vals = (
        1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr
    ) / norm_factor
    return iidx, vals


def _compute_support_vals_anisotropic(
    r: torch.Tensor,
    phi: torch.Tensor,
    nr: int,
    nphi: int,
    r_cutoff: float,
    norm: str = "s2",
):
    """
    Computes the index set that falls into the anisotropic kernel's support and returns both indices and values.
    """

    # compute the support
    dr = (r_cutoff - 0.0) / nr
    dphi = 2.0 * math.pi / nphi
    kernel_size = (nr - 1) * nphi + 1
    ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
    ir = ((ikernel - 1) // nphi + 1) * dr
    iphi = ((ikernel - 1) % nphi) * dphi

    if norm == "none":
        norm_factor = 1.0
    elif norm == "2d":
        norm_factor = (
            math.pi * (r_cutoff * nr / (nr + 1)) ** 2
            + math.pi * r_cutoff**2 * (2 * nr / (nr + 1) + 1) / (nr + 1) / 3
        )
    elif norm == "s2":
        norm_factor = (
            2
            * math.pi
            * (
                1
                - math.cos(r_cutoff - dr)
                + math.cos(r_cutoff - dr)
                + (math.sin(r_cutoff - dr) - math.sin(r_cutoff)) / dr
            )
        )
    else:
        raise ValueError(f"Unknown normalization mode {norm}.")

    # find the indices where the rotated position falls into the support of the kernel
    cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
    cond_phi = (
        (ikernel == 0)
        | ((phi - iphi).abs() <= dphi)
        | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
    )
    iidx = torch.argwhere(cond_r & cond_phi)
    vals = (
        1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr
    ) / norm_factor
    vals *= torch.where(
        iidx[:, 0] > 0,
        (
            1
            - torch.minimum(
                (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs(),
                (
                    2 * math.pi
                    - (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()
                ),
            )
            / dphi
        ),
        1.0,
    )
    return iidx, vals


def _precompute_convolution_tensor_s2(
    in_shape,
    out_shape,
    kernel_shape,
    grid_in="equiangular",
    grid_out="equiangular",
    theta_cutoff=0.01 * math.pi,
):
    r"""
    Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
    Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
    The output tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in).

    The rotation of the Euler angles uses the YZY convention, which applied to the north pole $(0,0,1)^T$ yields
    $$
    Y(\alpha) Z(\beta) Y(\gamma) n =
    \begin{bmatrix}
        \cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
        \sin(\beta)\sin(\gamma) \\
        \cos(\alpha)\cos(\gamma) - \cos(\beta)\sin(\alpha)\sin(\gamma)
    \end{bmatrix}
    $$
    """

    assert len(in_shape) == 2
    assert len(out_shape) == 2

    if len(kernel_shape) == 1:
        kernel_handle = partial(
            _compute_support_vals_isotropic,
            nr=kernel_shape[0],
            r_cutoff=theta_cutoff,
            norm="s2",
        )
    elif len(kernel_shape) == 2:
        kernel_handle = partial(
            _compute_support_vals_anisotropic,
            nr=kernel_shape[0],
            nphi=kernel_shape[1],
            r_cutoff=theta_cutoff,
            norm="s2",
        )
    else:
        raise ValueError("kernel_shape should be either one- or two-dimensional.")

    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

    lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
    lats_in = torch.from_numpy(lats_in).float()
    lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
    lats_out = torch.from_numpy(lats_out).float()

    # array for accumulating non-zero indices
    out_idx = torch.empty([3, 0], dtype=torch.long)
    out_vals = torch.empty([0], dtype=torch.long)

    # compute the phi differences
    # it is important not to include the 2 pi point in the longitudes, as it is equivalent to lon=0
    lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]

    for t in range(nlat_out):
        # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
        alpha = -lats_out[t]
        beta = lons_in
        gamma = lats_in.reshape(-1, 1)

        # compute cartesian coordinates of the rotated position
        # this uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation
        # and is therefore applied with a negative sign
        z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(
            alpha
        ) * torch.cos(gamma)
        x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(
            gamma
        ) * torch.sin(alpha)
        y = torch.sin(beta) * torch.sin(gamma)

        # normalization is important to avoid NaNs when arccos and atan are applied
        # this can otherwise lead to spurious artifacts in the solution
        norm = torch.sqrt(x * x + y * y + z * z)
        x = x / norm
        y = y / norm
        z = z / norm

        # compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
        theta = torch.arccos(z)
        phi = torch.arctan2(y, x) + torch.pi

        # find the indices where the rotated position falls into the support of the kernel
        iidx, vals = kernel_handle(theta, phi)

        # add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
        idx = torch.stack(
            [
                iidx[:, 0],
                t * torch.ones_like(iidx[:, 0]),
                iidx[:, 1] * nlon_in + iidx[:, 2],
            ],
            dim=0,
        )

        # append indices and values to the COO datastructure
        out_idx = torch.cat([out_idx, idx], dim=-1)
        out_vals = torch.cat([out_vals, vals], dim=-1)

    return out_idx, out_vals


def _precompute_convolution_tensor_2d(
    grid_in, grid_out, kernel_shape, radius_cutoff=0.01, periodic=False
):
    r"""
    Precomputes the translated filters at positions $T^{-1}_j \omega_i = T^{-1}_j T_i \nu$. Similar to the S2 routine,
    except that it assumes a non-periodic subset of the euclidean plane.
    """

    # check that input arrays are valid point clouds in 2D
    assert len(grid_in) == 2
    assert len(grid_out) == 2
    assert grid_in.shape[0] == 2
    assert grid_out.shape[0] == 2

    n_in = grid_in.shape[-1]
    n_out = grid_out.shape[-1]

    if len(kernel_shape) == 1:
        kernel_handle = partial(
            _compute_support_vals_isotropic,
            nr=kernel_shape[0],
            r_cutoff=radius_cutoff,
            norm="2d",
        )
    elif len(kernel_shape) == 2:
        kernel_handle = partial(
            _compute_support_vals_anisotropic,
            nr=kernel_shape[0],
            nphi=kernel_shape[1],
            r_cutoff=radius_cutoff,
            norm="2d",
        )
    else:
        raise ValueError("kernel_shape should be either one- or two-dimensional.")

    grid_in = grid_in.reshape(2, 1, n_in)
    grid_out = grid_out.reshape(2, n_out, 1)

    diffs = grid_in - grid_out
    if periodic:
        periodic_diffs = torch.where(diffs > 0.0, diffs - 1, diffs + 1)
        diffs = torch.where(diffs.abs() < periodic_diffs.abs(), diffs, periodic_diffs)

    r = torch.sqrt(diffs[0] ** 2 + diffs[1] ** 2)
    phi = torch.arctan2(diffs[1], diffs[0]) + torch.pi

    idx, vals = kernel_handle(r, phi)
    idx = idx.permute(1, 0)

    return idx, vals


class DiscreteContinuousConv(nn.Module, abc.ABC):
    """
    Abstract base class for DISCO convolutions
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_shape: Union[int, List[int]],
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
    ):
        super().__init__()

        if isinstance(kernel_shape, int):
            self.kernel_shape = [kernel_shape]
        else:
            self.kernel_shape = kernel_shape

        if len(self.kernel_shape) == 1:
            self.kernel_size = self.kernel_shape[0]
        elif len(self.kernel_shape) == 2:
            self.kernel_size = (self.kernel_shape[0] - 1) * self.kernel_shape[1] + 1
        else:
            raise ValueError("kernel_shape should be either one- or two-dimensional.")

        # groups
        self.groups = groups

        # weight tensor
        if in_channels % self.groups != 0:
            raise ValueError(
                "Error, the number of input channels has to be an integer multiple of the group size"
            )
        if out_channels % self.groups != 0:
            raise ValueError(
                "Error, the number of output channels has to be an integer multiple of the group size"
            )
        self.groupsize = in_channels // self.groups
        scale = math.sqrt(1.0 / self.groupsize)
        self.weight = nn.Parameter(
            scale * torch.randn(out_channels, self.groupsize, self.kernel_size)
        )

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

    @abc.abstractmethod
    def forward(self, x: torch.Tensor):
        raise NotImplementedError


def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
    """
    Reference implementation of the custom contraction as described in [1]. This requires repeated
    shifting of the input tensor, which can potentially be costly. For an efficient implementation
    on GPU, make sure to use the custom kernel written in Triton.
    """
    assert len(psi.shape) == 3
    assert len(x.shape) == 4
    psi = psi.to(x.device)

    batch_size, n_chans, nlat_in, nlon_in = x.shape
    kernel_size, nlat_out, _ = psi.shape

    assert psi.shape[-1] == nlat_in * nlon_in
    assert nlon_in % nlon_out == 0
    assert nlon_in >= nlat_out
    pscale = nlon_in // nlon_out

    # add a dummy dimension for nkernel and move the batch and channel dims to the end
    x = x.reshape(1, batch_size * n_chans, nlat_in, nlon_in).permute(0, 2, 3, 1)
    x = x.expand(kernel_size, -1, -1, -1)

    y = torch.zeros(
        nlon_out,
        kernel_size,
        nlat_out,
        batch_size * n_chans,
        device=x.device,
        dtype=x.dtype,
    )

    for pout in range(nlon_out):
        # sparse contraction with psi
        y[pout] = torch.bmm(psi, x.reshape(kernel_size, nlat_in * nlon_in, -1))
        # we need to repeatedly roll the input tensor to facilitate the shifted multiplication
        x = torch.roll(x, -pscale, dims=2)

    # reshape y back to expose the correct dimensions
    y = y.permute(3, 1, 2, 0).reshape(
        batch_size, n_chans, kernel_size, nlat_out, nlon_out
    )

    return y


def _disco_s2_transpose_contraction_torch(
    x: torch.Tensor, psi: torch.Tensor, nlon_out: int
):
    """
    Reference implementation of the custom transpose contraction as described in [1]. This requires
    repeated shifting of the input tensor, which can potentially be costly. For an efficient
    implementation on GPU, make sure to use the custom kernel written in Triton.
    """
    assert len(psi.shape) == 3
    assert len(x.shape) == 5
    psi = psi.to(x.device)

    batch_size, n_chans, kernel_size, nlat_in, nlon_in = x.shape
    kernel_size, _, n_out = psi.shape

    assert psi.shape[-2] == nlat_in
    assert n_out % nlon_out == 0
    nlat_out = n_out // nlon_out
    assert nlon_out >= nlat_in
    pscale = nlon_out // nlon_in

    # we do a semi-transposition to facilitate the computation
    inz = psi.indices()
    tout = inz[2] // nlon_out
    pout = inz[2] % nlon_out
    # flip the axis of longitudes
    pout = nlon_out - 1 - pout
    tin = inz[1]
    inz = torch.stack([inz[0], tout, tin * nlon_out + pout], dim=0)
    psi_mod = torch.sparse_coo_tensor(
        inz, psi.values(), size=(kernel_size, nlat_out, nlat_in * nlon_out)
    )

    # interleave zeros along the longitude dimension to allow for fractional offsets to be considered
    x_ext = torch.zeros(
        kernel_size,
        nlat_in,
        nlon_out,
        batch_size * n_chans,
        device=x.device,
        dtype=x.dtype,
    )
    x_ext[:, :, ::pscale, :] = x.reshape(
        batch_size * n_chans, kernel_size, nlat_in, nlon_in
    ).permute(1, 2, 3, 0)
    # we need to go backwards through the vector, so we flip the axis
    x_ext = x_ext.contiguous()

    y = torch.zeros(
        kernel_size,
        nlon_out,
        nlat_out,
        batch_size * n_chans,
        device=x.device,
        dtype=x.dtype,
    )

    for pout in range(nlon_out):
        # we need to repeatedly roll the input tensor to facilitate the shifted multiplication
        # TODO: double-check why this has to happen first
        x_ext = torch.roll(x_ext, -1, dims=2)
        # sparse contraction with the modified psi
        y[:, pout, :, :] = torch.bmm(
            psi_mod, x_ext.reshape(kernel_size, nlat_in * nlon_out, -1)
        )

    # sum over the kernel dimension and reshape to the correct output size
    y = y.sum(dim=0).permute(2, 1, 0).reshape(batch_size, n_chans, nlat_out, nlon_out)

    return y


class DiscreteContinuousConvS2(DiscreteContinuousConv):
    """
    Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, List[int]],
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # compute theta cutoff based on the bandlimit of the input field
        if theta_cutoff is None:
            theta_cutoff = (
                (self.kernel_shape[0] + 1) * torch.pi / float(self.nlat_in - 1)
            )

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # integration weights
        _, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
        quad_weights = (
            2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / self.nlon_in
        )
        self.register_buffer("quad_weights", quad_weights, persistent=False)

        idx, vals = _precompute_convolution_tensor_s2(
            in_shape,
            out_shape,
            self.kernel_shape,
            grid_in=grid_in,
            grid_out=grid_out,
            theta_cutoff=theta_cutoff,
        )

        self.register_buffer("psi_idx", idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

    def get_psi(self):
        psi = torch.sparse_coo_tensor(
            self.psi_idx,
            self.psi_vals,
            size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in),
        ).coalesce()
        return psi

    def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
        # pre-multiply x with the quadrature weights
        x = self.quad_weights * x

        psi = self.get_psi()

        if x.is_cuda and use_triton_kernel:
            x = _disco_s2_contraction_triton(x, psi, self.nlon_out)
        else:
            x = _disco_s2_contraction_torch(x, psi, self.nlon_out)

        # extract shape
        B, C, K, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, K, H, W)

        # do weight multiplication
        out = torch.einsum(
            "bgckxy,gock->bgoxy",
            x,
            self.weight.reshape(
                self.groups, -1, self.weight.shape[1], self.weight.shape[2]
            ),
        )
        out = out.reshape(out.shape[0], -1, out.shape[-2], out.shape[-1])

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out


class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
    """
    Discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1].

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, List[int]],
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # bandlimit
        if theta_cutoff is None:
            theta_cutoff = (
                (self.kernel_shape[0] + 1) * torch.pi / float(self.nlat_in - 1)
            )

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # integration weights
        _, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
        quad_weights = (
            2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / self.nlon_in
        )
        self.register_buffer("quad_weights", quad_weights, persistent=False)

        # switch in_shape and out_shape since we want transpose conv
        idx, vals = _precompute_convolution_tensor_s2(
            out_shape,
            in_shape,
            self.kernel_shape,
            grid_in=grid_out,
            grid_out=grid_in,
            theta_cutoff=theta_cutoff,
        )

        self.register_buffer("psi_idx", idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

    def get_psi(self):
        psi = torch.sparse_coo_tensor(
            self.psi_idx,
            self.psi_vals,
            size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out),
        ).coalesce()
        return psi

    def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
        # extract shape
        B, C, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, H, W)

        # do weight multiplication
        x = torch.einsum(
            "bgcxy,gock->bgokxy",
            x,
            self.weight.reshape(
                self.groups, -1, self.weight.shape[1], self.weight.shape[2]
            ),
        )
        x = x.reshape(x.shape[0], -1, x.shape[-3], x.shape[-2], x.shape[-1])

        # pre-multiply x with the quadrature weights
        x = self.quad_weights * x

        psi = self.get_psi()

        if x.is_cuda and use_triton_kernel:
            out = _disco_s2_transpose_contraction_triton(x, psi, self.nlon_out)
        else:
            out = _disco_s2_transpose_contraction_torch(x, psi, self.nlon_out)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out


class DiscreteContinuousConv2d(DiscreteContinuousConv):
    """
    Discrete-continuous convolutions (DISCO) on arbitrary 2d grids.

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        grid_in: torch.Tensor,
        grid_out: torch.Tensor,
        kernel_shape: Union[int, List[int]],
        n_in: Optional[Tuple[int]] = None,
        n_out: Optional[Tuple[int]] = None,
        quad_weights: Optional[torch.Tensor] = None,
        periodic: Optional[bool] = False,
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
        radius_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

        # the instantiator supports convenience constructors for the input and output grids
        if isinstance(grid_in, torch.Tensor):
            assert isinstance(quad_weights, torch.Tensor)
            assert not periodic
        elif isinstance(grid_in, str):
            assert n_in is not None
            assert len(n_in) == 2
            x, wx = _precompute_grid(n_in[0], grid=grid_in, periodic=periodic)
            y, wy = _precompute_grid(n_in[1], grid=grid_in, periodic=periodic)
            x, y = torch.meshgrid(torch.from_numpy(x), torch.from_numpy(y))
            wx, wy = torch.meshgrid(torch.from_numpy(wx), torch.from_numpy(wy))
            grid_in = torch.stack([x.reshape(-1), y.reshape(-1)])
            quad_weights = (wx * wy).reshape(-1)
        else:
            raise ValueError(f"Unknown grid input type of type {type(grid_in)}")

        if isinstance(grid_out, torch.Tensor):
            pass
        elif isinstance(grid_out, str):
            assert n_out is not None
            assert len(n_out) == 2
            x, wx = _precompute_grid(n_out[0], grid=grid_out, periodic=periodic)
            y, wy = _precompute_grid(n_out[1], grid=grid_out, periodic=periodic)
            x, y = torch.meshgrid(torch.from_numpy(x), torch.from_numpy(y))
            grid_out = torch.stack([x.reshape(-1), y.reshape(-1)])
        else:
            raise ValueError(f"Unknown grid output type of type {type(grid_out)}")

        # check that input arrays are valid point clouds in 2D
        assert len(grid_in.shape) == 2
        assert len(grid_out.shape) == 2
        assert len(quad_weights.shape) == 1
        assert grid_in.shape[0] == 2
        assert grid_out.shape[0] == 2

        self.n_in = grid_in.shape[-1]
        self.n_out = grid_out.shape[-1]

        # compute the cutoff radius based on the bandlimit of the input field
        # TODO: this heuristic is ad-hoc! Verify that we do the right one
        if radius_cutoff is None:
            radius_cutoff = (
                2 * (self.kernel_shape[0] + 1) / float(math.sqrt(self.n_in) - 1)
            )

        if radius_cutoff <= 0.0:
            raise ValueError("Error, radius_cutoff has to be positive.")

        # integration weights
        self.register_buffer("quad_weights", quad_weights, persistent=False)

        idx, vals = _precompute_convolution_tensor_2d(
            grid_in,
            grid_out,
            self.kernel_shape,
            radius_cutoff=radius_cutoff,
            periodic=periodic,
        )

        # to improve performance, we make psi a matrix by merging the first two dimensions
        # this has to be accounted for in the forward pass
        idx = torch.stack([idx[0] * self.n_out + idx[1], idx[2]], dim=0)

        self.register_buffer("psi_idx", idx.contiguous(), persistent=False)
        self.register_buffer("psi_vals", vals.contiguous(), persistent=False)

    def get_psi(self):
        psi = torch.sparse_coo_tensor(
            self.psi_idx, self.psi_vals, size=(self.kernel_size * self.n_out, self.n_in)
        )
        return psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # pre-multiply x with the quadrature weights
        x = self.quad_weights * x

        psi = self.get_psi()

        # extract shape
        B, C, _ = x.shape

        # bring into the right shape for the bmm and perform it
        x = x.reshape(B * C, self.n_in).permute(1, 0).contiguous()
        x = torch.mm(psi, x)
        x = x.permute(1, 0).reshape(B, C, self.kernel_size, self.n_out)
        x = x.reshape(B, self.groups, self.groupsize, self.kernel_size, self.n_out)

        # do weight multiplication
        out = torch.einsum(
            "bgckx,gock->bgox",
            x,
            self.weight.reshape(
                self.groups, -1, self.weight.shape[1], self.weight.shape[2]
            ),
        )
        out = out.reshape(out.shape[0], -1, out.shape[-1])

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1)

        return out


class DiscreteContinuousConvTranspose2d(DiscreteContinuousConv):
    """
    Discrete-continuous transpose convolutions (DISCO) on arbitrary 2d grids.

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        grid_in: torch.Tensor,
        grid_out: torch.Tensor,
        kernel_shape: Union[int, List[int]],
        n_in: Optional[Tuple[int]] = None,
        n_out: Optional[Tuple[int]] = None,
        quad_weights: Optional[torch.Tensor] = None,
        periodic: Optional[bool] = False,
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
        radius_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

        # the instantiator supports convenience constructors for the input and output grids
        if isinstance(grid_in, torch.Tensor):
            assert isinstance(quad_weights, torch.Tensor)
            assert not periodic
        elif isinstance(grid_in, str):
            assert n_in is not None
            assert len(n_in) == 2
            x, wx = _precompute_grid(n_in[0], grid=grid_in, periodic=periodic)
            y, wy = _precompute_grid(n_in[1], grid=grid_in, periodic=periodic)
            x, y = torch.meshgrid(torch.from_numpy(x), torch.from_numpy(y))
            wx, wy = torch.meshgrid(torch.from_numpy(wx), torch.from_numpy(wy))
            grid_in = torch.stack([x.reshape(-1), y.reshape(-1)])
            quad_weights = (wx * wy).reshape(-1)
        else:
            raise ValueError(f"Unknown grid input type of type {type(grid_in)}")

        if isinstance(grid_out, torch.Tensor):
            pass
        elif isinstance(grid_out, str):
            assert n_out is not None
            assert len(n_out) == 2
            x, wx = _precompute_grid(n_out[0], grid=grid_out, periodic=periodic)
            y, wy = _precompute_grid(n_out[1], grid=grid_out, periodic=periodic)
            x, y = torch.meshgrid(torch.from_numpy(x), torch.from_numpy(y))
            grid_out = torch.stack([x.reshape(-1), y.reshape(-1)])
        else:
            raise ValueError(f"Unknown grid output type of type {type(grid_out)}")

        # check that input arrays are valid point clouds in 2D
        assert len(grid_in.shape) == 2
        assert len(grid_out.shape) == 2
        assert len(quad_weights.shape) == 1
        assert grid_in.shape[0] == 2
        assert grid_out.shape[0] == 2

        self.n_in = grid_in.shape[-1]
        self.n_out = grid_out.shape[-1]

        # compute the cutoff radius based on the bandlimit of the input field
        # TODO: this heuristic is ad-hoc! Verify that we do the right one
        if radius_cutoff is None:
            radius_cutoff = (
                2 * (self.kernel_shape[0] + 1) / float(math.sqrt(self.n_in) - 1)
            )

        if radius_cutoff <= 0.0:
            raise ValueError("Error, radius_cutoff has to be positive.")

        # integration weights
        self.register_buffer("quad_weights", quad_weights, persistent=False)

        # precompute the transposed tensor
        idx, vals = _precompute_convolution_tensor_2d(
            grid_out,
            grid_in,
            self.kernel_shape,
            radius_cutoff=radius_cutoff,
            periodic=periodic,
        )

        # to improve performance, we make psi a matrix by merging the first two dimensions
        # this has to be accounted for in the forward pass
        idx = torch.stack([idx[0] * self.n_out + idx[2], idx[1]], dim=0)

        self.register_buffer("psi_idx", idx.contiguous(), persistent=False)
        self.register_buffer("psi_vals", vals.contiguous(), persistent=False)

    def get_psi(self):
        psi = torch.sparse_coo_tensor(
            self.psi_idx, self.psi_vals, size=(self.kernel_size * self.n_out, self.n_in)
        )
        return psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # pre-multiply x with the quadrature weights
        x = self.quad_weights * x

        psi = self.get_psi()

        # extract shape
        B, C, _ = x.shape

        # bring into the right shape for the bmm and perform it
        x = x.reshape(B * C, self.n_in).permute(1, 0).contiguous()
        x = torch.mm(psi, x)
        x = x.permute(1, 0).reshape(B, C, self.kernel_size, self.n_out)
        x = x.reshape(B, self.groups, self.groupsize, self.kernel_size, self.n_out)

        # do weight multiplication
        out = torch.einsum(
            "bgckx,gock->bgox",
            x,
            self.weight.reshape(
                self.groups, -1, self.weight.shape[1], self.weight.shape[2]
            ),
        )
        out = out.reshape(out.shape[0], -1, out.shape[-1])

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1)

        return out


class EquidistantDiscreteContinuousConv2d(DiscreteContinuousConv):
    """
    Discrete-continuous convolutions (DISCO) on equidistant 2d grids.

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_shape: Union[int, List[int]],
        in_shape: Tuple[int],
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
        radius_cutoff: Optional[float] = None,
        padding_mode: str = "circular",
        use_min_dim: bool = True,
        **kwargs,
    ):
        """
        use_min_dim (bool, optional): Use the minimum dimension of the input
            shape to compute the cutoff radius. Otherwise use the maximum
            dimension. Defaults to True.
        """
        super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

        self.padding_mode = padding_mode

        # compute the cutoff radius based on the assumption that the grid is [-1, 1]^2
        # this still assumes a quadratic domain
        f = min if use_min_dim else max
        if radius_cutoff is None:
            radius_cutoff = 2 * (self.kernel_shape[0]) / float(f(*in_shape))
        # 2 * 0.02 * 7 / 2 + 1 = 1.14
        self.psi_local_size = math.floor(2 * radius_cutoff * f(*in_shape) / 2) + 1

        # psi_local is essentially the support of the hat functions evaluated locally
        x = torch.linspace(-radius_cutoff, radius_cutoff, self.psi_local_size)
        x, y = torch.meshgrid(x, x)
        grid_in = torch.stack([x.reshape(-1), y.reshape(-1)])
        grid_out = torch.Tensor([[0.0], [0.0]])

        idx, vals = _precompute_convolution_tensor_2d(
            grid_in,
            grid_out,
            self.kernel_shape,
            radius_cutoff=radius_cutoff,
            periodic=False,
        )

        psi_loc = torch.zeros(
            self.kernel_size, self.psi_local_size * self.psi_local_size
        )
        for ie in range(len(vals)):
            k = idx[0, ie]
            j = idx[2, ie]
            v = vals[ie]
            psi_loc[k, j] = v

        # compute local version of the filter matrix
        psi_loc = psi_loc.reshape(
            self.kernel_size, self.psi_local_size, self.psi_local_size
        )
        # normalization by the quadrature weights
        psi_loc = 4.0 * psi_loc / float(in_shape[0] * in_shape[1])

        self.register_buffer("psi_loc", psi_loc, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        kernel = torch.einsum("kxy,ogk->ogxy", self.psi_loc, self.weight)

        left_pad = self.psi_local_size // 2
        right_pad = (self.psi_local_size + 1) // 2 - 1
        x = F.pad(x, (left_pad, right_pad, left_pad, right_pad), mode=self.padding_mode)
        out = F.conv2d(
            x, kernel, self.bias, stride=1, dilation=1, padding=0, groups=self.groups
        )

        return out
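For orientation, a minimal usage sketch of DiscreteContinuousConvS2 (not part of the committed file; the channel counts and grid sizes are arbitrary assumptions). Passing use_triton_kernel=False forces the torch reference contraction, so the sketch also runs on CPU:

import torch

conv = DiscreteContinuousConvS2(
    in_channels=4,
    out_channels=8,
    in_shape=(16, 32),   # (nlat, nlon) of the input grid
    out_shape=(16, 32),  # (nlat, nlon) of the output grid
    kernel_shape=[2],    # isotropic kernel with two rings
)
x = torch.randn(1, 4, 16, 32)  # (batch, channels, nlat, nlon)
y = conv(x, use_triton_kernel=False)
assert y.shape == (1, 8, 16, 32)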
torch_harmonics_local/quadrature.py
ADDED
@@ -0,0 +1,207 @@
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import numpy as np


def _precompute_grid(n, grid="equidistant", a=0.0, b=1.0, periodic=False):

    if (grid != "equidistant") and periodic:
        raise ValueError("Periodic grid is only supported on equidistant grids.")

    # compute coordinates
    if grid == "equidistant":
        xlg, wlg = trapezoidal_weights(n, a=a, b=b, periodic=periodic)
    elif grid == "legendre-gauss":
        xlg, wlg = legendre_gauss_weights(n, a=a, b=b)
    elif grid == "lobatto":
        xlg, wlg = lobatto_weights(n, a=a, b=b)
    elif grid == "equiangular":
        xlg, wlg = clenshaw_curtiss_weights(n, a=a, b=b)
    else:
        raise ValueError(f"Unknown grid type {grid}")

    return xlg, wlg


def _precompute_latitudes(nlat, grid="equiangular"):
    r"""
    Convenience routine to precompute latitudes
    """

    # compute coordinates
    xlg, wlg = _precompute_grid(nlat, grid=grid, a=-1.0, b=1.0, periodic=False)

    lats = np.flip(np.arccos(xlg)).copy()
    wlg = np.flip(wlg).copy()

    return lats, wlg


def trapezoidal_weights(n, a=-1.0, b=1.0, periodic=False):
    r"""
    Helper routine which returns equidistant nodes with trapezoidal weights
    on the interval [a, b]
    """

    xlg = np.linspace(a, b, n)
    wlg = (b - a) / (n - 1) * np.ones(n)

    if not periodic:
        wlg[0] *= 0.5
        wlg[-1] *= 0.5

    return xlg, wlg


def legendre_gauss_weights(n, a=-1.0, b=1.0):
    r"""
    Helper routine which returns the Legendre-Gauss nodes and weights
    on the interval [a, b]
    """

    xlg, wlg = np.polynomial.legendre.leggauss(n)
    xlg = (b - a) * 0.5 * xlg + (b + a) * 0.5
    wlg = wlg * (b - a) * 0.5

    return xlg, wlg


def lobatto_weights(n, a=-1.0, b=1.0, tol=1e-16, maxiter=100):
    r"""
    Helper routine which returns the Legendre-Gauss-Lobatto nodes and weights
    on the interval [a, b]
    """

    wlg = np.zeros((n,))
    tlg = np.zeros((n,))
    tmp = np.zeros((n,))

    # Vandermonde matrix
    vdm = np.zeros((n, n))

    # initialize Chebyshev nodes as first guess
    for i in range(n):
        tlg[i] = -np.cos(np.pi * i / (n - 1))

    tmp = 2.0

    for i in range(maxiter):
        tmp = tlg

        vdm[:, 0] = 1.0
        vdm[:, 1] = tlg

        for k in range(2, n):
            vdm[:, k] = (
                (2 * k - 1) * tlg * vdm[:, k - 1] - (k - 1) * vdm[:, k - 2]
            ) / k

        tlg = tmp - (tlg * vdm[:, n - 1] - vdm[:, n - 2]) / (n * vdm[:, n - 1])

        if max(abs(tlg - tmp).flatten()) < tol:
            break

    wlg = 2.0 / ((n * (n - 1)) * (vdm[:, n - 1] ** 2))

    # rescale
    tlg = (b - a) * 0.5 * tlg + (b + a) * 0.5
    wlg = wlg * (b - a) * 0.5

    return tlg, wlg


def clenshaw_curtiss_weights(n, a=-1.0, b=1.0):
    r"""
    Computation of the Clenshaw-Curtis quadrature nodes and weights.
    This implementation follows

    [1] Joerg Waldvogel, Fast Construction of the Fejer and Clenshaw-Curtis Quadrature Rules; BIT Numerical Mathematics, Vol. 43, No. 1, pp. 001–018.
    """

    assert n > 1

    tcc = np.cos(np.linspace(np.pi, 0, n))

    if n == 2:
        wcc = np.array([1.0, 1.0])
    else:
        n1 = n - 1
        N = np.arange(1, n1, 2)
        l = len(N)
        m = n1 - l

        v = np.concatenate([2 / N / (N - 2), 1 / N[-1:], np.zeros(m)])
        v = 0 - v[:-1] - v[-1:0:-1]

        g0 = -np.ones(n1)
        g0[l] = g0[l] + n1
        g0[m] = g0[m] + n1
        g = g0 / (n1**2 - 1 + (n1 % 2))
        wcc = np.fft.ifft(v + g).real
        wcc = np.concatenate((wcc, wcc[:1]))

    # rescale
    tcc = (b - a) * 0.5 * tcc + (b + a) * 0.5
    wcc = wcc * (b - a) * 0.5

    return tcc, wcc


def fejer2_weights(n, a=-1.0, b=1.0):
    r"""
    Computation of the Fejer quadrature nodes and weights.
    This implementation follows

    [1] Joerg Waldvogel, Fast Construction of the Fejer and Clenshaw-Curtis Quadrature Rules; BIT Numerical Mathematics, Vol. 43, No. 1, pp. 001–018.
    """

    assert n > 2

    tcc = np.cos(np.linspace(np.pi, 0, n))

    n1 = n - 1
    N = np.arange(1, n1, 2)
    l = len(N)
    m = n1 - l

    v = np.concatenate([2 / N / (N - 2), 1 / N[-1:], np.zeros(m)])
    v = 0 - v[:-1] - v[-1:0:-1]

    wcc = np.fft.ifft(v).real
    wcc = np.concatenate((wcc, wcc[:1]))

    # rescale
    tcc = (b - a) * 0.5 * tcc + (b + a) * 0.5
    wcc = wcc * (b - a) * 0.5

    return tcc, wcc
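The quadrature helpers can be sanity-checked by integrating a monomial whose integral is known in closed form; a small sketch (illustrative only, not part of the committed file): an n-point Clenshaw-Curtis rule is exact for polynomials up to degree n-1, and the weights of any rule on [a, b] must sum to the interval length.

import numpy as np

x, w = clenshaw_curtiss_weights(33, a=-1.0, b=1.0)
assert np.isclose(np.sum(w), 2.0)               # weights sum to b - a
assert np.isclose(np.sum(w * x**2), 2.0 / 3.0)  # exact integral of x^2 on [-1, 1]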
type_utils.py
ADDED
@@ -0,0 +1,4 @@
def tuple_type(strings):
    strings = strings.replace("(", "").replace(")", "").replace(" ", "")
    mapped_int = map(int, strings.split(","))
    return tuple(mapped_int)
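tuple_type reads like an argparse type converter for shape-style arguments such as "(320,320)"; a minimal usage sketch under that assumption (the flag name and default are hypothetical, not taken from the app):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--crop-size", type=tuple_type, default=(320, 320))

args = parser.parse_args(["--crop-size", "(128, 160)"])
print(args.crop_size)  # -> (128, 160)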