---
license: mit
---

Instructions for loading the pre-trained model weights:

```python
import torch
from monai.networks.nets import SegResNetDS

# Load the pre-trained weights (a state dict); map to CPU so this also works without a GPU
weights = torch.load("pretrained_segresnet.torch", map_location="cpu")

model = SegResNetDS(
    blocks_down=(1, 2, 2, 4, 4)
)

# Set strict=False because the checkpoint contains only the encoder weights
model.load_state_dict(weights, strict=False)

# Dummy forward pass with a (batch, channel, z, y, x) tensor
tensor = torch.randn((2, 1, 32, 32, 32))

# Note: the input data must be in "SPL" orientation, i.e. the default (z, y, x) numpy/torch axis order.
# You can reorient images with MONAI's Orientation transform using axcodes="SPL".
# All subsequent transforms must also be specified in (z, y, x) order,
# e.g. a patch size of [16, 32, 32] means 16 voxels along the z-axis.
out = model(tensor)
```
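To make the orientation note concrete, here is a minimal sketch of what reorienting to "SPL" means for the array itself. In a MONAI pipeline you would normally apply `Orientationd(keys="image", axcodes="SPL")` after loading, which works from the image affine. This sketch instead assumes the volume is already stored in RAS order (x, y, z) and performs the equivalent axis permutation and flips by hand with NumPy. The helpers `ras_to_spl` and `crop_patch`, and the volume sizes, are hypothetical and only for illustration.

```python
import numpy as np


def ras_to_spl(volume: np.ndarray) -> np.ndarray:
    """Reorient a volume stored with RAS axis codes (x, y, z) to SPL (z, y, x).

    Assumes the input axes point Right, Anterior, Superior. After reorientation:
    axis 0 points Superior, axis 1 points Posterior, axis 2 points Left.
    """
    spl = np.transpose(volume, (2, 1, 0))  # (x, y, z) -> (z, y, x)
    # Flip y (Anterior -> Posterior) and x (Right -> Left); copy to a contiguous array
    return np.ascontiguousarray(np.flip(spl, axis=(1, 2)))


def crop_patch(volume_zyx: np.ndarray, start_zyx, patch_size_zyx) -> np.ndarray:
    """Crop a patch whose start and size are both given in (z, y, x) order."""
    slices = tuple(slice(s, s + p) for s, p in zip(start_zyx, patch_size_zyx))
    return volume_zyx[slices]


# Hypothetical RAS volume: 64 voxels along x, 48 along y, 40 along z
ras = np.random.rand(64, 48, 40).astype(np.float32)
spl = ras_to_spl(ras)
print(spl.shape)  # (40, 48, 64)

# A patch size of [16, 32, 32] means 16 slices along z, 32 along y, 32 along x
patch = crop_patch(spl, start_zyx=(0, 8, 16), patch_size_zyx=(16, 32, 32))
print(patch.shape)  # (16, 32, 32)
```

Before feeding such a patch to the model, add batch and channel dimensions (e.g. `torch.from_numpy(patch)[None, None]`). Also keep in mind that SegResNetDS expects each spatial size to be divisible by `2 ** (len(blocks_down) - 1)`, which is 16 for the configuration above, so `[16, 32, 32]` is a valid patch size.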