---
library_name: project-lighter
tags:
- lighter
- model_hub_mixin
- pytorch_model_hub_mixin
language: en
license: apache-2.0
arxiv: 2501.09001
---

# Whole Body Segmentation

This model is a whole body segmentation model based on the SegResNet architecture, fine-tuned from the CT-FM foundation model.

## Running instructions

This notebook demonstrates how to:

1. Load a pre-trained whole body segmentation model from the Hugging Face Hub
2. Set up preprocessing and postprocessing pipelines
3. Perform sliding window inference on CT volumes
4. Save the segmentation results

The model segments 118 different anatomical structures from CT scans.

## Setup

Install the requirements and import the necessary packages.

```python
# Install (or upgrade) the lighter_zoo package quietly
%pip install lighter_zoo -U -qq
```

Note: you may need to restart the kernel for the updated packages to take effect.

```python
# Imports
import torch
from lighter_zoo import SegResNet
from monai.transforms import (
    Compose, LoadImage, EnsureType, Orientation,
    ScaleIntensityRange, CropForeground, Invert,
    Activations, AsDiscrete, KeepLargestConnectedComponent,
    SaveImage
)
from monai.inferers import SlidingWindowInferer
```

## Load Model

Download and initialize the pre-trained model from the Hugging Face Hub.

```python
# Load pre-trained model
model = SegResNet.from_pretrained(
    "project-lighter/whole_body_segmentation",
    force_download=True  # Re-download the weights even if they are already cached
)
model.eval()  # Make sure the model is in inference mode
```

## Configure Inference

Set up sliding window inference to process large volumes.

```python
# Configure sliding window inference
inferer = SlidingWindowInferer(
    roi_size=[96, 160, 160],  # Size of patches to process
    sw_batch_size=2,          # Number of windows to process in parallel
    overlap=0.625,            # Fraction of overlap between adjacent windows (reduces boundary artifacts)
    mode="gaussian"           # Gaussian weighting for overlap regions
)
```
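
To estimate what these settings cost, here is a rough count of how many windows a volume needs. It is a simplified model of the tiling, not MONAI's code: it assumes window starts are spaced `int(roi * (1 - overlap))` voxels apart on each axis, that an axis no longer than the window needs a single window, and a hypothetical preprocessed volume of 300 × 400 × 400 voxels.

```python
import math

# Settings copied from the SlidingWindowInferer configuration above
ROI_SIZE = (96, 160, 160)
OVERLAP = 0.625
SW_BATCH_SIZE = 2


def windows_per_axis(image_size: int, roi: int, overlap: float) -> int:
    """Number of window positions along one axis (simplified tiling model)."""
    if image_size <= roi:
        # An axis no longer than the window is covered by one (padded) window
        return 1
    # Step between consecutive window start positions
    interval = max(int(roi * (1 - overlap)), 1)
    # Enough steps for the last window to reach the end of the axis
    return math.ceil((image_size - roi) / interval) + 1


def count_windows(image_shape, roi_size=ROI_SIZE, overlap=OVERLAP) -> int:
    """Total windows for a 3D volume: the product of the per-axis counts."""
    total = 1
    for size, roi in zip(image_shape, roi_size):
        total *= windows_per_axis(size, roi, overlap)
    return total


# Hypothetical preprocessed volume shape (D, H, W)
shape = (300, 400, 400)
n_windows = count_windows(shape)
n_forward_passes = math.ceil(n_windows / SW_BATCH_SIZE)
print(n_windows, n_forward_passes)  # 175 88
```

Because each step is only `1 - overlap` of the window, `overlap=0.625` needs roughly 1 / 0.375 ≈ 2.7 times as many windows per axis as non-overlapping tiles; lowering `overlap` or raising `sw_batch_size` (memory permitting) are the usual levers for speed.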

## Setup Processing Pipelines

Define the preprocessing and postprocessing transforms.

```python
# Preprocessing pipeline
preprocess = Compose([
    LoadImage(ensure_channel_first=True),  # Load image and ensure channel dimension
    EnsureType(),                          # Ensure correct data type
    Orientation(axcodes="SPL"),            # Standardize orientation
    # Scale intensity to [0,1] range, clipping outliers
    ScaleIntensityRange(
        a_min=-1024,  # Min HU value
        a_max=2048,   # Max HU value
        b_min=0,      # Target min
        b_max=1,      # Target max
        clip=True     # Clip values outside range
    ),
    # Remove background to reduce computation; allow_smaller=True is passed explicitly
    # to keep the current default behavior and silence MONAI's deprecation warning
    CropForeground(allow_smaller=True)
])

# Postprocessing pipeline
postprocess = Compose([
    Activations(softmax=True),                   # Apply softmax to get probabilities
    AsDiscrete(argmax=True, dtype=torch.int32),  # Convert to class labels
    KeepLargestConnectedComponent(),             # Keep only the largest connected region of each label
    Invert(transform=preprocess),                # Restore original space
    # Save the result
    SaveImage(output_dir="./segmentations")
])
```
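
`ScaleIntensityRange` maps Hounsfield units (HU) linearly from `[a_min, a_max]` to `[b_min, b_max]` and, with `clip=True`, saturates anything outside that window. The scalar sketch below reproduces that mapping with the same parameters so you can check where typical values land; the helper name `scale_hu` and the sample HU values are illustrative, not part of MONAI.

```python
def scale_hu(hu: float, a_min: float = -1024, a_max: float = 2048,
             b_min: float = 0.0, b_max: float = 1.0, clip: bool = True) -> float:
    """Linearly map one HU value from [a_min, a_max] to [b_min, b_max]."""
    scaled = (hu - a_min) / (a_max - a_min) * (b_max - b_min) + b_min
    if clip:
        # Values outside the HU window saturate at the target bounds
        scaled = min(max(scaled, b_min), b_max)
    return scaled


# Illustrative HU values: air, water, dense bone, beyond the window
for hu in (-1024, 0, 512, 3000):
    print(hu, round(scale_hu(hu), 4))
# -1024 -> 0.0, 0 -> 0.3333, 512 -> 0.5, 3000 -> 1.0 (clipped)
```

Water (0 HU) lands at one third of the range, so most soft tissue occupies a narrow band just above it, while anything above 2048 HU saturates at 1.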
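
In the postprocessing pipeline, `Activations(softmax=True)` turns each voxel's per-class scores into probabilities and `AsDiscrete(argmax=True)` then picks the most probable class; the real pipeline does this along the channel dimension of the model output. The plain-Python sketch below does the same for a single voxel with made-up scores for four hypothetical classes. Because softmax is monotonic, the argmax of the probabilities equals the argmax of the raw scores, so the softmax step only matters if you also want the probabilities themselves.

```python
import math


def softmax(scores):
    """Numerically stable softmax over one voxel's class scores."""
    shift = max(scores)  # subtract the max to avoid overflow in exp
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Index of the largest value (the first one wins on ties)."""
    return max(range(len(values)), key=values.__getitem__)


# Made-up scores for one voxel over four hypothetical classes
voxel_scores = [0.2, 3.1, -1.0, 2.9]
probs = softmax(voxel_scores)
label = argmax(probs)

print(label)                          # 1
print(label == argmax(voxel_scores))  # True
```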

## Run Inference

Process an input CT scan and generate the segmentation.

```python
# Input path (replace with the path to your own CT scan)
input_path = "path/to/ct_scan.nii.gz"

# Preprocess input
input_tensor = preprocess(input_path)

# Run inference
with torch.no_grad():
    output = inferer(input_tensor.unsqueeze(dim=0), model)[0]

# Copy metadata from input
output.applied_operations = input_tensor.applied_operations
output.affine = input_tensor.affine

# Postprocess and save result
result = postprocess(output[0])
print("✅ Segmentation completed and saved")
```
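
Inside `postprocess`, the `KeepLargestConnectedComponent()` step cleans the label map before it is inverted and saved. With its default arguments (`applied_labels=None`, `independent=True`), MONAI analyzes each non-zero label separately and keeps only that label's largest connected region, which removes small spurious islands. The sketch below illustrates the idea on a tiny 2D label map in plain Python using 8-connectivity; the function names and the label map are hypothetical, and MONAI's own neighborhood definition in 3D may differ, so check its documentation for the exact behavior.

```python
from collections import deque


def largest_component(mask):
    """Return the cells of the largest 8-connected True region in a 2D mask."""
    rows, cols = len(mask), len(mask[0])
    seen = [[False] * cols for _ in range(rows)]
    best = []
    for r in range(rows):
        for c in range(cols):
            if not mask[r][c] or seen[r][c]:
                continue
            # Breadth-first search over one connected region
            component, queue = [], deque([(r, c)])
            seen[r][c] = True
            while queue:
                y, x = queue.popleft()
                component.append((y, x))
                for dy in (-1, 0, 1):
                    for dx in (-1, 0, 1):
                        ny, nx = y + dy, x + dx
                        if (0 <= ny < rows and 0 <= nx < cols
                                and mask[ny][nx] and not seen[ny][nx]):
                            seen[ny][nx] = True
                            queue.append((ny, nx))
            if len(component) > len(best):
                best = component
    return best


def keep_largest_per_label(labels):
    """Keep only the largest connected region of each non-zero label."""
    rows, cols = len(labels), len(labels[0])
    out = [[0] * cols for _ in range(rows)]
    for lab in {v for row in labels for v in row} - {0}:
        mask = [[v == lab for v in row] for row in labels]
        for y, x in largest_component(mask):
            out[y][x] = lab
    return out


# Tiny hypothetical label map: labels 1 and 2 each have one stray voxel
labels = [
    [1, 1, 0, 2],
    [1, 0, 0, 2],
    [0, 0, 0, 0],
    [0, 1, 0, 2],
]
for row in keep_largest_per_label(labels):
    print(row)
# [1, 1, 0, 2]
# [1, 0, 0, 2]
# [0, 0, 0, 0]
# [0, 0, 0, 0]
```

One consequence worth knowing: if a structure genuinely appears as several separate pieces in a scan, this step keeps only the largest piece of that label.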