SegVol Space — commit: init the space
(This diff view is limited to 50 files because the commit contains too many changes.)
- README.md +5 -4
- SegVol_v1.pth +3 -0
- __pycache__/utils.cpython-39.pyc +0 -0
- app.py +308 -0
- model/LICENSE +21 -0
- model/README.md +74 -0
- model/__pycache__/inference_cpu.cpython-39.pyc +0 -0
- model/asset/FLARE22_Tr_0002_0000.nii.gz +3 -0
- model/asset/FLARE22_Tr_0005_0000.nii.gz +3 -0
- model/asset/FLARE22_Tr_0034_0000.nii.gz +3 -0
- model/asset/FLARE22_Tr_0045_0000.nii.gz +3 -0
- model/asset/model.png +0 -0
- model/asset/overview back.png +0 -0
- model/asset/overview.png +0 -0
- model/config/clip/config.json +157 -0
- model/config/clip/special_tokens_map.json +1 -0
- model/config/clip/tokenizer.json +0 -0
- model/config/clip/tokenizer_config.json +1 -0
- model/config/clip/vocab.json +0 -0
- model/config/config_demo.json +8 -0
- model/data_process/__pycache__/demo_data_process.cpython-39.pyc +0 -0
- model/data_process/demo_data_process.py +91 -0
- model/inference_cpu.py +173 -0
- model/inference_demo.py +219 -0
- model/network/__pycache__/model.cpython-39.pyc +0 -0
- model/network/model.py +91 -0
- model/script/inference_demo.sh +8 -0
- model/segment_anything_volumetric/.ipynb_checkpoints/build_sam-checkpoint.py +172 -0
- model/segment_anything_volumetric/__init__.py +12 -0
- model/segment_anything_volumetric/__pycache__/__init__.cpython-310.pyc +0 -0
- model/segment_anything_volumetric/__pycache__/__init__.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/__pycache__/automatic_mask_generator.cpython-310.pyc +0 -0
- model/segment_anything_volumetric/__pycache__/automatic_mask_generator.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/__pycache__/build_sam.cpython-310.pyc +0 -0
- model/segment_anything_volumetric/__pycache__/build_sam.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/__pycache__/predictor.cpython-310.pyc +0 -0
- model/segment_anything_volumetric/__pycache__/predictor.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/automatic_mask_generator.py +372 -0
- model/segment_anything_volumetric/build_sam.py +111 -0
- model/segment_anything_volumetric/modeling/.ipynb_checkpoints/image_encoder_swin-checkpoint.py +709 -0
- model/segment_anything_volumetric/modeling/.ipynb_checkpoints/prompt_encoder-checkpoint.py +232 -0
- model/segment_anything_volumetric/modeling/__init__.py +11 -0
- model/segment_anything_volumetric/modeling/__pycache__/__init__.cpython-310.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/__init__.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/common.cpython-310.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/common.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/image_encoder.cpython-310.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/image_encoder.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/image_encoder_swin.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/mask_decoder.cpython-310.pyc +0 -0
README.md CHANGED (+5 -4)
@@ -1,12 +1,13 @@
 ---
 title: SegVol
-emoji:
-colorFrom:
-colorTo:
+emoji: 🏢
+colorFrom: indigo
+colorTo: blue
 sdk: streamlit
-sdk_version: 1.
+sdk_version: 1.28.2
 app_file: app.py
 pinned: false
+license: mit
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
SegVol_v1.pth ADDED (Git LFS pointer, +3 lines)

version https://git-lfs.github.com/spec/v1
oid sha256:b751dc95f1a0c0c6086c1e6fa7f8a17bbb87635e5226e15f5d156fbd364dbb85
size 1660308695
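Because the checkpoint is stored as a Git LFS pointer, a plain clone without LFS only fetches the three-line stub above, not the ~1.66 GB weight file. A minimal sketch of pulling the real file with `huggingface_hub`; the `repo_id` is a placeholder, not something stated in this diff:

```python
# Sketch: download the actual checkpoint behind the LFS pointer.
# repo_id is a placeholder -- substitute the real Space id before use.
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="<user>/<segvol-space>",  # hypothetical, not taken from the diff
    repo_type="space",
    filename="SegVol_v1.pth",
)
print(ckpt_path)  # local cache path of the downloaded weight file
```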
__pycache__/utils.cpython-39.pyc ADDED — binary file (3.85 kB)
app.py ADDED (+308 lines)

import streamlit as st
from streamlit_drawable_canvas import st_canvas
from streamlit_image_coordinates import streamlit_image_coordinates

from model.data_process.demo_data_process import process_ct_gt
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
import monai.transforms as transforms
from utils import show_points, make_fig, reflect_points_into_model, initial_rectangle, reflect_json_data_to_3D_box, reflect_box_into_model, run

print('script run')

#############################################
# init session_state
if 'option' not in st.session_state:
    st.session_state.option = None
if 'text_prompt' not in st.session_state:
    st.session_state.text_prompt = None

if 'reset_demo_case' not in st.session_state:
    st.session_state.reset_demo_case = False

if 'preds_3D' not in st.session_state:
    st.session_state.preds_3D = None

if 'data_item' not in st.session_state:
    st.session_state.data_item = None

if 'points' not in st.session_state:
    st.session_state.points = []

if 'use_text_prompt' not in st.session_state:
    st.session_state.use_text_prompt = False

if 'use_point_prompt' not in st.session_state:
    st.session_state.use_point_prompt = False

if 'use_box_prompt' not in st.session_state:
    st.session_state.use_box_prompt = False

if 'rectangle_3Dbox' not in st.session_state:
    st.session_state.rectangle_3Dbox = [0, 0, 0, 0, 0, 0]

if 'irregular_box' not in st.session_state:
    st.session_state.irregular_box = False

if 'running' not in st.session_state:
    st.session_state.running = False

if 'transparency' not in st.session_state:
    st.session_state.transparency = 0.25

case_list = [
    'model/asset/FLARE22_Tr_0002_0000.nii.gz',
    'model/asset/FLARE22_Tr_0005_0000.nii.gz',
    'model/asset/FLARE22_Tr_0034_0000.nii.gz',
    'model/asset/FLARE22_Tr_0045_0000.nii.gz'
]

#############################################

#############################################
# reset functions
def clear_prompts():
    st.session_state.points = []
    st.session_state.rectangle_3Dbox = [0, 0, 0, 0, 0, 0]

def reset_demo_case():
    st.session_state.data_item = None
    st.session_state.reset_demo_case = True
    clear_prompts()

def clear_file():
    st.session_state.option = None
    process_ct_gt.clear()
    reset_demo_case()
    clear_prompts()

#############################################

st.image(Image.open('model/asset/overview back.png'), use_column_width=True)

github_col, arxive_col = st.columns(2)

with github_col:
    st.write('GitHub repo:https://github.com/BAAI-DCAI/SegVol')

with arxive_col:
    st.write('Paper:https://arxiv.org/abs/2311.13385')

# modify demo case here
demo_type = st.radio(
    "Demo case source",
    ["Select", "Upload"],
    on_change=clear_file
)

if demo_type == "Select":
    uploaded_file = st.selectbox(
        "Select a demo case",
        case_list,
        index=None,
        placeholder="Select a demo case...",
        on_change=reset_demo_case
    )
else:
    uploaded_file = st.file_uploader("Upload demo case(nii.gz)", type='nii.gz', on_change=reset_demo_case)

st.session_state.option = uploaded_file

if st.session_state.option is not None and \
        st.session_state.reset_demo_case or (st.session_state.data_item is None and st.session_state.option is not None):
    st.session_state.data_item = process_ct_gt(st.session_state.option)
    st.session_state.reset_demo_case = False
    st.session_state.preds_3D = None

prompt_col1, prompt_col2 = st.columns(2)

with prompt_col1:
    st.session_state.use_text_prompt = st.toggle('Sematic prompt')
    text_prompt_type = st.radio(
        "Sematic prompt type",
        ["Predefined", "Custom"],
        disabled=(not st.session_state.use_text_prompt)
    )
    if text_prompt_type == "Predefined":
        pre_text = st.selectbox(
            "Predefined anatomical category:",
            ['liver', 'right kidney', 'spleen', 'pancreas', 'aorta', 'inferior vena cava', 'right adrenal gland', 'left adrenal gland', 'gallbladder', 'esophagus', 'stomach', 'duodenum', 'left kidney'],
            index=None,
            disabled=(not st.session_state.use_text_prompt)
        )
    else:
        pre_text = st.text_input('Enter an Anatomical word or phrase:', None, max_chars=20,
                                 disabled=(not st.session_state.use_text_prompt))
    if pre_text is None or len(pre_text) > 0:
        st.session_state.text_prompt = pre_text
    else:
        st.session_state.text_prompt = None

with prompt_col2:
    spatial_prompt_on = st.toggle('Spatial prompt', on_change=clear_prompts)
    spatial_prompt = st.radio(
        "Spatial prompt type",
        ["Point prompt", "Box prompt"],
        on_change=clear_prompts,
        disabled=(not spatial_prompt_on))

    if spatial_prompt == "Point prompt":
        st.session_state.use_point_prompt = True
        st.session_state.use_box_prompt = False
    elif spatial_prompt == "Box prompt":
        st.session_state.use_box_prompt = True
        st.session_state.use_point_prompt = False
    else:
        st.session_state.use_point_prompt = False
        st.session_state.use_box_prompt = False

    if not spatial_prompt_on:
        st.session_state.use_point_prompt = False
        st.session_state.use_box_prompt = False

if not st.session_state.use_text_prompt:
    st.session_state.text_prompt = None

if st.session_state.option is None:
    st.write('please select demo case first')
else:
    image_3D = st.session_state.data_item['z_image'][0].numpy()
    col_control1, col_control2 = st.columns(2)

    with col_control1:
        selected_index_z = st.slider('X-Y view', 0, image_3D.shape[0] - 1, 162, key='xy', disabled=st.session_state.running)

    with col_control2:
        selected_index_y = st.slider('X-Z view', 0, image_3D.shape[1] - 1, 162, key='xz', disabled=st.session_state.running)
        if st.session_state.use_box_prompt:
            top, bottom = st.select_slider(
                'Top and bottom of box',
                options=range(0, 325),
                value=(0, 324),
                disabled=st.session_state.running
            )
            st.session_state.rectangle_3Dbox[0] = top
            st.session_state.rectangle_3Dbox[3] = bottom
    col_image1, col_image2 = st.columns(2)

    if st.session_state.preds_3D is not None:
        st.session_state.transparency = st.slider('Mask opacity', 0.0, 1.0, 0.25, disabled=st.session_state.running)

    with col_image1:
        image_z_array = image_3D[selected_index_z]

        preds_z_array = None
        if st.session_state.preds_3D is not None:
            preds_z_array = st.session_state.preds_3D[selected_index_z]

        image_z = make_fig(image_z_array, preds_z_array, st.session_state.points, selected_index_z, 'xy')

        if st.session_state.use_point_prompt:
            value_xy = streamlit_image_coordinates(image_z, width=325)

            if value_xy is not None:
                point_ax_xy = (selected_index_z, value_xy['y'], value_xy['x'])
                if len(st.session_state.points) >= 3:
                    st.warning('Max point num is 3', icon="⚠️")
                elif point_ax_xy not in st.session_state.points:
                    st.session_state.points.append(point_ax_xy)
                    print('point_ax_xy add rerun')
                    st.rerun()
        elif st.session_state.use_box_prompt:
            canvas_result_xy = st_canvas(
                fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
                stroke_width=3,
                stroke_color='#2909F1',
                background_image=image_z,
                update_streamlit=True,
                height=325,
                width=325,
                drawing_mode='transform',
                point_display_radius=0,
                key="canvas_xy",
                initial_drawing=initial_rectangle,
                display_toolbar=True
            )
            try:
                print(canvas_result_xy.json_data['objects'][0]['angle'])
                if canvas_result_xy.json_data['objects'][0]['angle'] != 0:
                    st.warning('Rotating is undefined behavior', icon="⚠️")
                    st.session_state.irregular_box = True
                else:
                    st.session_state.irregular_box = False
                    reflect_json_data_to_3D_box(canvas_result_xy.json_data, view='xy')
            except:
                print('exception')
                pass
        else:
            st.image(image_z, use_column_width=False)

    with col_image2:
        image_y_array = image_3D[:, selected_index_y, :]

        preds_y_array = None
        if st.session_state.preds_3D is not None:
            preds_y_array = st.session_state.preds_3D[:, selected_index_y, :]

        image_y = make_fig(image_y_array, preds_y_array, st.session_state.points, selected_index_y, 'xz')

        if st.session_state.use_point_prompt:
            value_yz = streamlit_image_coordinates(image_y, width=325)

            if value_yz is not None:
                point_ax_xz = (value_yz['y'], selected_index_y, value_yz['x'])
                if len(st.session_state.points) >= 3:
                    st.warning('Max point num is 3', icon="⚠️")
                elif point_ax_xz not in st.session_state.points:
                    st.session_state.points.append(point_ax_xz)
                    print('point_ax_xz add rerun')
                    st.rerun()
        elif st.session_state.use_box_prompt:
            if st.session_state.rectangle_3Dbox[1] <= selected_index_y and selected_index_y <= st.session_state.rectangle_3Dbox[4]:
                draw = ImageDraw.Draw(image_y)
                # rectangle xz view (upper-left and lower-right)
                rectangle_coords = [(st.session_state.rectangle_3Dbox[2], st.session_state.rectangle_3Dbox[0]),
                                    (st.session_state.rectangle_3Dbox[5], st.session_state.rectangle_3Dbox[3])]
                # Draw the rectangle on the image
                draw.rectangle(rectangle_coords, outline='#2909F1', width=3)
            st.image(image_y, use_column_width=False)
        else:
            st.image(image_y, use_column_width=False)


col1, col2, col3 = st.columns(3)

with col1:
    if st.button("Clear", use_container_width=True,
                 disabled=(st.session_state.option is None or (len(st.session_state.points) == 0 and not st.session_state.use_box_prompt and st.session_state.preds_3D is None))):
        clear_prompts()
        st.session_state.preds_3D = None
        st.rerun()

with col3:
    run_button_name = 'Run' if not st.session_state.running else 'Running'
    if st.button(run_button_name, type="primary", use_container_width=True,
                 disabled=(
                     st.session_state.data_item is None or
                     (st.session_state.text_prompt is None and len(st.session_state.points) == 0 and st.session_state.use_box_prompt is False) or
                     st.session_state.irregular_box or
                     st.session_state.running
                 )):
        st.session_state.running = True
        st.rerun()

# if len(st.session_state.points) > 0:
#     st.write(st.session_state.points)

if st.session_state.running:
    st.session_state.running = False
    with st.status("Running...", expanded=False) as status:
        run()
        st.rerun()
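The Run button in app.py sets `st.session_state.running = True` and immediately calls `st.rerun()`, so the next script run renders every widget disabled, clears the flag, and executes `run()` (from the repo's `utils.py`, which this 50-file view does not show) inside an `st.status` container before a final rerun. A stripped-down sketch of that same run/rerun pattern, independent of SegVol:

```python
# Minimal sketch of the run/rerun pattern used by app.py (not the app itself).
import time
import streamlit as st

if "running" not in st.session_state:
    st.session_state.running = False

# While the flag is set, controls render disabled for the whole rerun.
label = "Running" if st.session_state.running else "Run"
if st.button(label, disabled=st.session_state.running):
    st.session_state.running = True
    st.rerun()  # rerender immediately so the UI shows the disabled state

if st.session_state.running:
    st.session_state.running = False
    with st.status("Running...", expanded=False):
        time.sleep(2)  # stand-in for the heavy inference call
    st.rerun()  # final rerun re-enables the widgets and shows the result
```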
model/LICENSE ADDED (+21 lines)

MIT License

Copyright (c) 2023 BAAI-DCAI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
model/README.md ADDED (+74 lines)

# SegVol: Universal and Interactive Volumetric Medical Image Segmentation
This repo is the official implementation of [SegVol: Universal and Interactive Volumetric Medical Image Segmentation](https://arxiv.org/abs/2311.13385).

## News🚀
(2023.11.24) *You can download the weight files of SegVol and the ViT (CT pre-training) [here](https://drive.google.com/drive/folders/1TEJtgctH534Ko5r4i79usJvqmXVuLf54?usp=drive_link).* 🔥

(2023.11.23) *A brief introduction and instructions have been uploaded.*

(2023.11.23) *The inference demo code has been uploaded.*

(2023.11.22) *The first edition of our paper has been uploaded to arXiv.* 📃

## Introduction
<img src="https://github.com/BAAI-DCAI/SegVol/blob/main/asset/overview.png" width="60%" height="60%">

SegVol is a universal and interactive model for volumetric medical image segmentation. It accepts **point**, **box**, and **text** prompts and outputs volumetric segmentation masks. Trained on 90k unlabeled Computed Tomography (CT) volumes and 6k labeled CTs, this foundation model supports the segmentation of over 200 anatomical categories.

We will release SegVol's **inference code**, **training code**, **model params**, and **ViT pre-training params** (pre-training is performed over 2,000 epochs on 96k CTs).

## Usage
### Requirements
[PyTorch v1.11.0](https://pytorch.org/get-started/previous-versions/) (or a higher version) is required first. Then install the key requirements with the following commands:

```
pip install 'monai[all]==0.9.0'
pip install einops==0.6.1
pip install transformers==4.18.0
pip install matplotlib
```
### Config and run demo script
1. Download the demo case [here](https://drive.google.com/drive/folders/1TEJtgctH534Ko5r4i79usJvqmXVuLf54?usp=drive_link), or download the whole demo dataset [AbdomenCT-1K](https://github.com/JunMa11/AbdomenCT-1K) and choose any demo case you want.
2. Set the CT path and ground-truth path of the case in [config_demo.json](https://github.com/BAAI-DCAI/SegVol/blob/main/config/config_demo.json).
3. After that, configure the [inference_demo.sh](https://github.com/BAAI-DCAI/SegVol/blob/main/script/inference_demo.sh) file for execution:

- `$segvol_ckpt`: the path of SegVol's checkpoint (download it [here](https://drive.google.com/drive/folders/1TEJtgctH534Ko5r4i79usJvqmXVuLf54?usp=drive_link)).

- `$work_dir`: any folder path where you want to save the log files and visualization results.

4. Finally, you can control the **prompt type**, **zoom-in-zoom-out mechanism**, and **visualization switch** [here](https://github.com/BAAI-DCAI/SegVol/blob/35f3ff9c943a74f630e6948051a1fe21aaba91bc/inference_demo.py#L208C11-L208C11).
5. Now, just run `bash script/inference_demo.sh` to infer your demo case.

## Citation
If you find this repository helpful, please consider citing:
```
@misc{du2023segvol,
      title={SegVol: Universal and Interactive Volumetric Medical Image Segmentation},
      author={Yuxin Du and Fan Bai and Tiejun Huang and Bo Zhao},
      year={2023},
      eprint={2311.13385},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```

## Acknowledgement
Thanks to the following amazing works:

[HuggingFace](https://huggingface.co/).

[CLIP](https://github.com/openai/CLIP).

[MONAI](https://github.com/Project-MONAI/MONAI).

[Image by brgfx](https://www.freepik.com/free-vector/anatomical-structure-human-bodies_26353260.htm) on Freepik.

[Image by muammark](https://www.freepik.com/free-vector/people-icon-collection_1157380.htm#query=user&position=2&from_view=search&track=sph) on Freepik.

[Image by pch.vector](https://www.freepik.com/free-vector/different-phone-hand-gestures-set_9649376.htm#query=Vector%20touch%20screen%20hand%20gestures&position=4&from_view=search&track=ais) on Freepik.

[Image by starline](https://www.freepik.com/free-vector/set-three-light-bulb-represent-effective-business-idea-concept_37588597.htm#query=idea&position=0&from_view=search&track=sph) on Freepik.
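Before running the demo, it can help to confirm the environment matches the pins listed in the README above. A small illustrative sketch, not part of the repo:

```python
# Quick environment check against the versions pinned in the README.
from importlib.metadata import version, PackageNotFoundError

pins = {"monai": "0.9.0", "einops": "0.6.1", "transformers": "4.18.0"}
for pkg, expected in pins.items():
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        installed = "not installed"
    status = "OK" if installed == expected else "check"
    print(f"{pkg}: installed {installed}, expected {expected} [{status}]")

import torch
print(f"torch: {torch.__version__} (v1.11.0 or higher required)")
```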
model/__pycache__/inference_cpu.cpython-39.pyc ADDED — binary file (4.67 kB)
model/asset/FLARE22_Tr_0002_0000.nii.gz ADDED (Git LFS pointer, +3 lines)

version https://git-lfs.github.com/spec/v1
oid sha256:eb16eced003524fa005e28b2822c0b53503f1223d758cdf72528fad359aa10ba
size 30611274

model/asset/FLARE22_Tr_0005_0000.nii.gz ADDED (Git LFS pointer, +3 lines)

version https://git-lfs.github.com/spec/v1
oid sha256:2be5019bfc7e805d5e24785bcd44ffe7720e13e38b2a3124ad25b454811b221c
size 26615527

model/asset/FLARE22_Tr_0034_0000.nii.gz ADDED (Git LFS pointer, +3 lines)

version https://git-lfs.github.com/spec/v1
oid sha256:023c5d06ea2a6c8866c1e214ecee06a4447a8d0c50225142cdfdbbccc2bf8c66
size 28821917

model/asset/FLARE22_Tr_0045_0000.nii.gz ADDED (Git LFS pointer, +3 lines)

version https://git-lfs.github.com/spec/v1
oid sha256:336b3719af673fd6fafe89d7d5d95d5f18239a9faccde9753703fc1465f43736
size 32885093

model/asset/model.png ADDED

model/asset/overview back.png ADDED

model/asset/overview.png ADDED
model/config/clip/config.json ADDED (+157 lines)

{
  "_name_or_path": "openai/clip-vit-base-patch32",
  "architectures": [
    "CLIPModel"
  ],
  "initializer_factor": 1.0,
  "logit_scale_init_value": 2.6592,
  "model_type": "clip",
  "projection_dim": 512,
  "text_config": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "bos_token_id": 0,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "dropout": 0.0,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 2,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "quick_gelu",
    "hidden_size": 512,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 2048,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-05,
    "length_penalty": 1.0,
    "max_length": 20,
    "max_position_embeddings": 77,
    "min_length": 0,
    "model_type": "clip_text_model",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 8,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_hidden_layers": 12,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": 1,
    "prefix": null,
    "projection_dim": 512,
    "problem_type": null,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "transformers_version": "4.16.0.dev0",
    "use_bfloat16": false,
    "vocab_size": 49408
  },
  "text_config_dict": null,
  "transformers_version": null,
  "vision_config": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "dropout": 0.0,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "quick_gelu",
    "hidden_size": 768,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "image_size": 224,
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-05,
    "length_penalty": 1.0,
    "max_length": 20,
    "min_length": 0,
    "model_type": "clip_vision_model",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 12,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_hidden_layers": 12,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "patch_size": 32,
    "prefix": null,
    "projection_dim": 512,
    "problem_type": null,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "transformers_version": "4.16.0.dev0",
    "use_bfloat16": false
  },
  "vision_config_dict": null
}
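This directory carries only the config and tokenizer files of `openai/clip-vit-base-patch32`; SegVol uses the CLIP text tower to encode semantic prompts, with the actual encoder weights coming from the SegVol checkpoint rather than from this folder. A hedged sketch of building the text tower from the local config with `transformers` (the model below is therefore randomly initialized, which is fine for inspecting shapes):

```python
# Sketch: build the CLIP text tower from model/config/clip.
# No weight file ships in this folder, so the encoder is randomly initialized;
# in SegVol the trained text-encoder weights come from SegVol_v1.pth.
from transformers import AutoTokenizer, CLIPConfig, CLIPTextModel

clip_dir = "model/config/clip"
tokenizer = AutoTokenizer.from_pretrained(clip_dir)
clip_config = CLIPConfig.from_pretrained(clip_dir)
text_encoder = CLIPTextModel(clip_config.text_config)

tokens = tokenizer(["liver", "right kidney"], padding=True, return_tensors="pt")
out = text_encoder(**tokens)
print(out.pooler_output.shape)  # torch.Size([2, 512]) given hidden_size=512 above
```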
model/config/clip/special_tokens_map.json ADDED (+1 line)

{"bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": "<|endoftext|>"}
model/config/clip/tokenizer.json ADDED — diff too large to render
model/config/clip/tokenizer_config.json ADDED (+1 line)

{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "pad_token": "<|endoftext|>", "add_prefix_space": false, "errors": "replace", "do_lower_case": true, "name_or_path": "./clip_ViT_B_32/"}
model/config/clip/vocab.json ADDED — diff too large to render
model/config/config_demo.json ADDED (+8 lines)

{
    "dataset_name": "AbdomenCT-1k",
    "categories": ["liver", "kidney", "spleen", "pancreas"],
    "demo_case": {
        "ct_path": "path/to/Case_image",
        "gt_path": "path/to/Case_label"
    }
}
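`inference_demo.py` (further down in this diff) reads this file with `json.load` and pulls out `demo_case.ct_path`, `demo_case.gt_path`, and `categories`. A minimal sketch of writing a concrete config for a local AbdomenCT-1K case; both paths are placeholders, not values taken from the repo:

```python
# Sketch: fill in config_demo.json for a local demo case (paths are placeholders).
import json

config = {
    "dataset_name": "AbdomenCT-1k",
    "categories": ["liver", "kidney", "spleen", "pancreas"],
    "demo_case": {
        "ct_path": "/data/AbdomenCT-1K/Case_00001_0000.nii.gz",  # placeholder
        "gt_path": "/data/AbdomenCT-1K/Case_00001.nii.gz",       # placeholder
    },
}
with open("model/config/config_demo.json", "w") as f:
    json.dump(config, f, indent=4)
```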
model/data_process/__pycache__/demo_data_process.cpython-39.pyc ADDED — binary file (3.26 kB)
model/data_process/demo_data_process.py ADDED (+91 lines)

import numpy as np
import monai.transforms as transforms
import streamlit as st
import tempfile

class MinMaxNormalization(transforms.Transform):
    def __call__(self, data):
        d = dict(data)
        k = "image"
        d[k] = d[k] - d[k].min()
        d[k] = d[k] / np.clip(d[k].max(), a_min=1e-8, a_max=None)
        return d

class DimTranspose(transforms.Transform):
    def __init__(self, keys):
        self.keys = keys

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            d[key] = np.swapaxes(d[key], -1, -3)
        return d

class ForegroundNormalization(transforms.Transform):
    def __init__(self, keys):
        self.keys = keys

    def __call__(self, data):
        d = dict(data)

        for key in self.keys:
            d[key] = self.normalize(d[key])
        return d

    def normalize(self, ct_narray):
        ct_voxel_ndarray = ct_narray.copy()
        ct_voxel_ndarray = ct_voxel_ndarray.flatten()
        thred = np.mean(ct_voxel_ndarray)
        voxel_filtered = ct_voxel_ndarray[(ct_voxel_ndarray > thred)]
        upper_bound = np.percentile(voxel_filtered, 99.95)
        lower_bound = np.percentile(voxel_filtered, 00.05)
        mean = np.mean(voxel_filtered)
        std = np.std(voxel_filtered)
        ### transform ###
        ct_narray = np.clip(ct_narray, lower_bound, upper_bound)
        ct_narray = (ct_narray - mean) / max(std, 1e-8)
        return ct_narray

@st.cache_data
def process_ct_gt(case_path, spatial_size=(32,256,256)):
    if case_path is None:
        return None
    print('Data preprocessing...')
    # transform
    img_loader = transforms.LoadImage(dtype=np.float32)
    transform = transforms.Compose(
        [
            transforms.Orientationd(keys=["image"], axcodes="RAS"),
            ForegroundNormalization(keys=["image"]),
            DimTranspose(keys=["image"]),
            MinMaxNormalization(),
            transforms.SpatialPadd(keys=["image"], spatial_size=spatial_size, mode='constant'),
            transforms.CropForegroundd(keys=["image"], source_key="image"),
            transforms.ToTensord(keys=["image"]),
        ]
    )
    zoom_out_transform = transforms.Resized(keys=["image"], spatial_size=spatial_size, mode='nearest-exact')
    z_transform = transforms.Resized(keys=["image"], spatial_size=(325,325,325), mode='nearest-exact')
    ###
    item = {}
    # generate ct_voxel_ndarray
    if type(case_path) is str:
        ct_voxel_ndarray, _ = img_loader(case_path)
    else:
        bytes_data = case_path.read()
        with tempfile.NamedTemporaryFile(suffix='.nii.gz') as tmp:
            tmp.write(bytes_data)
            tmp.seek(0)
            ct_voxel_ndarray, _ = img_loader(tmp.name)
    ct_voxel_ndarray = np.array(ct_voxel_ndarray).squeeze()
    ct_voxel_ndarray = np.expand_dims(ct_voxel_ndarray, axis=0)
    item['image'] = ct_voxel_ndarray

    # transform
    item = transform(item)
    item_zoom_out = zoom_out_transform(item)
    item['zoom_out_image'] = item_zoom_out['image']

    item_z = z_transform(item)
    item['z_image'] = item_z['image']
    return item
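A hedged usage sketch of the preprocessing above: called on one of the bundled FLARE22 cases, it returns a dict with the cropped and normalized `image`, a `(32, 256, 256)` `zoom_out_image` for the global (zoom-out) pass, and a `(325, 325, 325)` `z_image` that the Streamlit app slices for display. It should also run in a plain script, though the `st.cache_data` decorator may warn that no Streamlit app is running:

```python
# Sketch: inspect the outputs of process_ct_gt on a bundled demo case.
from model.data_process.demo_data_process import process_ct_gt

item = process_ct_gt("model/asset/FLARE22_Tr_0002_0000.nii.gz")
print(item["image"].shape)           # (1, D, H, W) after crop-foreground; D/H/W are case-dependent
print(item["zoom_out_image"].shape)  # (1, 32, 256, 256)
print(item["z_image"].shape)         # (1, 325, 325, 325)
```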
model/inference_cpu.py ADDED (+173 lines)

import argparse
import os
import torch
import torch.nn.functional as F
import json
import monai.transforms as transforms

from model.segment_anything_volumetric import sam_model_registry
from model.network.model import SegVol
from model.data_process.demo_data_process import process_ct_gt
from model.utils.monai_inferers_utils import sliding_window_inference, generate_box, select_points, build_binary_cube, build_binary_points, logits2roi_coor
from model.utils.visualize import draw_result
import streamlit as st

def set_parse():
    # %% set up parser
    parser = argparse.ArgumentParser()
    parser.add_argument("--test_mode", default=True, type=bool)
    parser.add_argument("--resume", type=str, default='SegVol_v1.pth')
    parser.add_argument("-infer_overlap", default=0.0, type=float, help="sliding window inference overlap")
    parser.add_argument("-spatial_size", default=(32, 256, 256), type=tuple)
    parser.add_argument("-patch_size", default=(4, 16, 16), type=tuple)
    parser.add_argument('-work_dir', type=str, default='./work_dir')
    ### demo
    parser.add_argument("--clip_ckpt", type=str, default='model/config/clip')
    args = parser.parse_args()
    return args

def zoom_in_zoom_out(args, segvol_model, image, image_resize, text_prompt, point_prompt, box_prompt):
    image_single_resize = image_resize
    image_single = image[0,0]
    ori_shape = image_single.shape
    resize_shape = image_single_resize.shape[2:]

    # generate prompts
    text_single = None if text_prompt is None else [text_prompt]
    points_single = None
    box_single = None

    if args.use_point_prompt:
        point, point_label = point_prompt
        points_single = (point.unsqueeze(0).float(), point_label.unsqueeze(0).float())
        binary_points_resize = build_binary_points(point, point_label, resize_shape)
    if args.use_box_prompt:
        box_single = box_prompt.unsqueeze(0).float()
        binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=resize_shape)

    ####################
    # zoom-out inference:
    print('--- zoom out inference ---')
    print(text_single)
    print(f'use text-prompt [{text_single!=None}], use box-prompt [{box_single!=None}], use point-prompt [{points_single!=None}]')
    with torch.no_grad():
        logits_global_single = segvol_model(image_single_resize,
                                            text=text_single,
                                            boxes=box_single,
                                            points=points_single)

    # resize back global logits
    logits_global_single = F.interpolate(
        logits_global_single.cpu(),
        size=ori_shape, mode='nearest')[0][0]

    # build prompt reflection for zoom-in
    if args.use_point_prompt:
        binary_points = F.interpolate(
            binary_points_resize.unsqueeze(0).unsqueeze(0).float(),
            size=ori_shape, mode='nearest')[0][0]
    if args.use_box_prompt:
        binary_cube = F.interpolate(
            binary_cube_resize.unsqueeze(0).unsqueeze(0).float(),
            size=ori_shape, mode='nearest')[0][0]
    # draw_result('unknow', image_single_resize, None, point_prompt, logits_global_single, logits_global_single)
    if not args.use_zoom_in:
        return logits_global_single

    ####################
    # zoom-in inference:
    min_d, min_h, min_w, max_d, max_h, max_w = logits2roi_coor(args.spatial_size, logits_global_single)
    if min_d is None:
        print('Fail to detect foreground!')
        return logits_global_single

    # Crop roi
    image_single_cropped = image_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1].unsqueeze(0).unsqueeze(0)
    global_preds = (torch.sigmoid(logits_global_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1])>0.5).long()

    assert not (args.use_box_prompt and args.use_point_prompt)
    # label_single_cropped = label_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1].unsqueeze(0).unsqueeze(0)
    prompt_reflection = None
    if args.use_box_prompt:
        binary_cube_cropped = binary_cube[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
        prompt_reflection = (
            binary_cube_cropped.unsqueeze(0).unsqueeze(0),
            global_preds.unsqueeze(0).unsqueeze(0)
        )
    if args.use_point_prompt:
        binary_points_cropped = binary_points[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
        prompt_reflection = (
            binary_points_cropped.unsqueeze(0).unsqueeze(0),
            global_preds.unsqueeze(0).unsqueeze(0)
        )

    ## inference
    with torch.no_grad():
        logits_single_cropped = sliding_window_inference(
                image_single_cropped, prompt_reflection,
                args.spatial_size, 1, segvol_model, args.infer_overlap,
                text=text_single,
                use_box=args.use_box_prompt,
                use_point=args.use_point_prompt,
                logits_global_single=logits_global_single,
            )
    logits_single_cropped = logits_single_cropped.cpu().squeeze()
    if logits_single_cropped.shape != logits_global_single.shape:
        logits_global_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1] = logits_single_cropped

    return logits_global_single

@st.cache_resource
def build_model():
    # build model
    st.write('building model')
    clip_ckpt = 'model/config/clip'
    resume = 'SegVol_v1.pth'
    sam_model = sam_model_registry['vit']()
    segvol_model = SegVol(
        image_encoder=sam_model.image_encoder,
        mask_decoder=sam_model.mask_decoder,
        prompt_encoder=sam_model.prompt_encoder,
        clip_ckpt=clip_ckpt,
        roi_size=(32,256,256),
        patch_size=(4,16,16),
        test_mode=True,
    )
    segvol_model = torch.nn.DataParallel(segvol_model)
    segvol_model.eval()
    # load param
    if os.path.isfile(resume):
        ## Map model to be loaded to specified single GPU
        loc = 'cpu'
        checkpoint = torch.load(resume, map_location=loc)
        segvol_model.load_state_dict(checkpoint['model'], strict=True)
        print("loaded checkpoint '{}' (epoch {})".format(resume, checkpoint['epoch']))
    print('model build done!')
    return segvol_model

@st.cache_data
def inference_case(_image, _image_zoom_out, _point_prompt, text_prompt, _box_prompt):
    # seg config
    args = set_parse()
    args.use_zoom_in = True
    args.use_text_prompt = text_prompt is not None
    args.use_box_prompt = _box_prompt is not None
    args.use_point_prompt = _point_prompt is not None

    segvol_model = build_model()

    # run inference
    logits = zoom_in_zoom_out(
        args, segvol_model,
        _image.unsqueeze(0), _image_zoom_out.unsqueeze(0),
        text_prompt, _point_prompt, _box_prompt)
    print(logits.shape)
    resize_transform = transforms.Compose([
        transforms.AddChannel(),
        transforms.Resize((325,325,325), mode='trilinear')
        ]
    )
    logits = resize_transform(logits)[0]
    print(logits.shape)
    return (torch.sigmoid(logits) > 0.5).int().numpy()
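The glue between the Streamlit UI and this module lives in the repo's `utils.py` (only its `.pyc` appears in this 50-file view), so the call below is a hedged reconstruction rather than the app's actual code path: preprocess a case, hand it to `inference_case` with a semantic prompt, and treat the returned `(325, 325, 325)` binary mask as `st.session_state.preds_3D`.

```python
# Hedged sketch: drive inference_cpu directly, bypassing the Streamlit UI.
# Prompt handling is illustrative; the real coordinate mapping is in utils.py,
# which this diff view omits. Requires SegVol_v1.pth in the working directory.
from model.data_process.demo_data_process import process_ct_gt
from model.inference_cpu import inference_case

item = process_ct_gt("model/asset/FLARE22_Tr_0002_0000.nii.gz")
preds_3D = inference_case(
    item["image"].float(),
    item["zoom_out_image"].float(),
    _point_prompt=None,
    text_prompt="liver",   # semantic prompt only
    _box_prompt=None,
)
print(preds_3D.shape, preds_3D.sum())  # (325, 325, 325) binary mask and its voxel count
```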
model/inference_demo.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import json
|
| 6 |
+
from segment_anything_volumetric import sam_model_registry
|
| 7 |
+
from network.model import SegVol
|
| 8 |
+
from data_process.demo_data_process import process_ct_gt
|
| 9 |
+
import monai.transforms as transforms
|
| 10 |
+
from utils.monai_inferers_utils import sliding_window_inference, generate_box, select_points, build_binary_cube, build_binary_points, logits2roi_coor
|
| 11 |
+
from utils.visualize import draw_result
|
| 12 |
+
|
| 13 |
+
def set_parse():
|
| 14 |
+
# %% set up parser
|
| 15 |
+
parser = argparse.ArgumentParser()
|
| 16 |
+
parser.add_argument("--test_mode", default=True, type=bool)
|
| 17 |
+
parser.add_argument("--resume", type = str, default = '')
|
| 18 |
+
parser.add_argument("-infer_overlap", default=0.5, type=float, help="sliding window inference overlap")
|
| 19 |
+
parser.add_argument("-spatial_size", default=(32, 256, 256), type=tuple)
|
| 20 |
+
parser.add_argument("-patch_size", default=(4, 16, 16), type=tuple)
|
| 21 |
+
parser.add_argument('-work_dir', type=str, default='./work_dir')
|
| 22 |
+
### demo
|
| 23 |
+
parser.add_argument('--demo_config', type=str, required=True)
|
| 24 |
+
parser.add_argument("--clip_ckpt", type = str, default = './config/clip')
|
| 25 |
+
args = parser.parse_args()
|
| 26 |
+
return args
|
| 27 |
+
|
| 28 |
+
def dice_score(preds, labels): # on GPU
|
| 29 |
+
assert preds.shape[0] == labels.shape[0], "predict & target batch size don't match\n" + str(preds.shape) + str(labels.shape)
|
| 30 |
+
predict = preds.view(1, -1)
|
| 31 |
+
target = labels.view(1, -1)
|
| 32 |
+
if target.shape[1] < 1e8:
|
| 33 |
+
predict = predict.cuda()
|
| 34 |
+
target = target.cuda()
|
| 35 |
+
predict = torch.sigmoid(predict)
|
| 36 |
+
predict = torch.where(predict > 0.5, 1., 0.)
|
| 37 |
+
|
| 38 |
+
tp = torch.sum(torch.mul(predict, target))
|
| 39 |
+
den = torch.sum(predict) + torch.sum(target) + 1
|
| 40 |
+
dice = 2 * tp / den
|
| 41 |
+
|
| 42 |
+
if target.shape[1] < 1e8:
|
| 43 |
+
predict = predict.cpu()
|
| 44 |
+
target = target.cpu()
|
| 45 |
+
return dice
|
| 46 |
+
|
| 47 |
+
def zoom_in_zoom_out(args, segvol_model, image, image_resize, gt3D, gt3D_resize, categories=None):
|
| 48 |
+
logits_labels_record = {}
|
| 49 |
+
image_single_resize = image_resize
|
| 50 |
+
image_single = image[0,0]
|
| 51 |
+
ori_shape = image_single.shape
|
| 52 |
+
for item_idx in range(len(categories)):
|
| 53 |
+
# get label to generate prompts
|
| 54 |
+
label_single = gt3D[0][item_idx]
|
| 55 |
+
label_single_resize = gt3D_resize[0][item_idx]
|
| 56 |
+
# skip meaningless categories
|
| 57 |
+
if torch.sum(label_single) == 0:
|
| 58 |
+
print('No object, skip')
|
| 59 |
+
continue
|
| 60 |
+
# generate prompts
|
| 61 |
+
text_single = categories[item_idx] if args.use_text_prompt else None
|
| 62 |
+
if categories is not None: print(f'inference |{categories[item_idx]}| target...')
|
| 63 |
+
points_single = None
|
| 64 |
+
box_single = None
|
| 65 |
+
if args.use_point_prompt:
|
| 66 |
+
point, point_label = select_points(label_single_resize, num_positive_extra=3, num_negative_extra=3)
|
| 67 |
+
points_single = (point.unsqueeze(0).float().cuda(), point_label.unsqueeze(0).float().cuda())
|
| 68 |
+
binary_points_resize = build_binary_points(point, point_label, label_single_resize.shape)
|
| 69 |
+
if args.use_box_prompt:
|
| 70 |
+
box_single = generate_box(label_single_resize).unsqueeze(0).float().cuda()
|
| 71 |
+
            binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=label_single_resize.shape)

        ####################
        # zoom-out inference:
        print('--- zoom out inference ---')
        print(f'use text-prompt [{text_single!=None}], use box-prompt [{box_single!=None}], use point-prompt [{points_single!=None}]')
        with torch.no_grad():
            logits_global_single = segvol_model(image_single_resize.cuda(),
                                                text=text_single,
                                                boxes=box_single,
                                                points=points_single)

        # resize back global logits
        logits_global_single = F.interpolate(
                logits_global_single.cpu(),
                size=ori_shape, mode='nearest')[0][0]

        # build prompt reflection for zoom-in
        if args.use_point_prompt:
            binary_points = F.interpolate(
                binary_points_resize.unsqueeze(0).unsqueeze(0).float(),
                size=ori_shape, mode='nearest')[0][0]
        if args.use_box_prompt:
            binary_cube = F.interpolate(
                binary_cube_resize.unsqueeze(0).unsqueeze(0).float(),
                size=ori_shape, mode='nearest')[0][0]
        zoom_out_dice = dice_score(logits_global_single.squeeze(), label_single.squeeze())
        logits_labels_record[categories[item_idx]] = (
            zoom_out_dice,
            image_single,
            points_single,
            box_single,
            logits_global_single,
            label_single)
        print(f'zoom out inference done with zoom_out_dice: {zoom_out_dice:.4f}')
        if not args.use_zoom_in:
            continue

        ####################
        # zoom-in inference:
        min_d, min_h, min_w, max_d, max_h, max_w = logits2roi_coor(args.spatial_size, logits_global_single)
        if min_d is None:
            print('Fail to detect foreground!')
            continue

        # Crop roi
        image_single_cropped = image_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1].unsqueeze(0).unsqueeze(0)
        global_preds = (torch.sigmoid(logits_global_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1])>0.5).long()

        assert not (args.use_box_prompt and args.use_point_prompt)
        # label_single_cropped = label_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1].unsqueeze(0).unsqueeze(0)
        prompt_reflection = None
        if args.use_box_prompt:
            binary_cube_cropped = binary_cube[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
            prompt_reflection = (
                binary_cube_cropped.unsqueeze(0).unsqueeze(0),
                global_preds.unsqueeze(0).unsqueeze(0)
            )
        if args.use_point_prompt:
            binary_points_cropped = binary_points[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
            prompt_reflection = (
                binary_points_cropped.unsqueeze(0).unsqueeze(0),
                global_preds.unsqueeze(0).unsqueeze(0)
            )

        ## inference
        with torch.no_grad():
            logits_single_cropped = sliding_window_inference(
                    image_single_cropped.cuda(), prompt_reflection,
                    args.spatial_size, 1, segvol_model, args.infer_overlap,
                    text=text_single,
                    use_box=args.use_box_prompt,
                    use_point=args.use_point_prompt,
                )
            logits_single_cropped = logits_single_cropped.cpu().squeeze()
        logits_global_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1] = logits_single_cropped
        zoom_in_dice = dice_score(logits_global_single.squeeze(), label_single.squeeze())
        logits_labels_record[categories[item_idx]] = (
            zoom_in_dice,
            image_single,
            points_single,
            box_single,
            logits_global_single,
            label_single)
        print(f'===> zoom out dice {zoom_out_dice:.4f} -> zoom-out-zoom-in dice {zoom_in_dice:.4f} <===')
    return logits_labels_record

def inference_single_ct(args, segvol_model, data_item, categories):
    segvol_model.eval()
    image, gt3D = data_item["image"].float(), data_item["label"]
    image_zoom_out, gt3D__zoom_out = data_item["zoom_out_image"].float(), data_item['zoom_out_label']

    logits_labels_record = zoom_in_zoom_out(
        args, segvol_model,
        image.unsqueeze(0), image_zoom_out.unsqueeze(0),
        gt3D.unsqueeze(0), gt3D__zoom_out.unsqueeze(0),  # add batch dim
        categories=categories)

    # visualize
    if args.visualize:
        for target, values in logits_labels_record.items():
            dice_score, image, point_prompt, box_prompt, logits, labels = values
            print(f'{target} result with Dice score {dice_score:.4f} visualizing')
            draw_result(target + f"-Dice {dice_score:.4f}", image, box_prompt, point_prompt, logits, labels, args.spatial_size, args.work_dir)

def main(args):
    gpu = 0
    torch.cuda.set_device(gpu)
    # build model
    sam_model = sam_model_registry['vit'](args=args)
    segvol_model = SegVol(
                    image_encoder=sam_model.image_encoder,
                    mask_decoder=sam_model.mask_decoder,
                    prompt_encoder=sam_model.prompt_encoder,
                    clip_ckpt=args.clip_ckpt,
                    roi_size=args.spatial_size,
                    patch_size=args.patch_size,
                    test_mode=args.test_mode,
                    ).cuda()
    segvol_model = torch.nn.DataParallel(segvol_model, device_ids=[gpu])

    # load param
    if os.path.isfile(args.resume):
        ## Map model to be loaded to specified single GPU
        loc = 'cuda:{}'.format(gpu)
        checkpoint = torch.load(args.resume, map_location=loc)
        segvol_model.load_state_dict(checkpoint['model'], strict=True)
        print("loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))

    # load demo config
    with open(args.demo_config, 'r') as file:
        config_dict = json.load(file)
    ct_path, gt_path, categories = config_dict['demo_case']['ct_path'], config_dict['demo_case']['gt_path'], config_dict['categories']

    # preprocess for data
    data_item = process_ct_gt(ct_path, gt_path, categories, args.spatial_size)  # keys: image, label

    # seg config for prompt & zoom-in-zoom-out
    args.use_zoom_in = True
    args.use_text_prompt = True
    args.use_box_prompt = True
    args.use_point_prompt = False
    args.visualize = False

    inference_single_ct(args, segvol_model, data_item, categories)

if __name__ == "__main__":
    args = set_parse()
    main(args)
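The demo above leans on two pieces of bookkeeping: the coarse zoom-out logits are resized back to the original CT resolution, and a foreground ROI is then cut out of them for the zoom-in pass. A minimal, self-contained sketch of that resize-and-crop flow with dummy tensors (shapes are illustrative only; this is not the repository's logits2roi_coor helper):

import torch
import torch.nn.functional as F

# Dummy "zoom-out" prediction at the resized resolution and an illustrative original CT shape.
ori_shape = (64, 128, 128)                      # original (D, H, W), made up for the sketch
logits_resized = torch.randn(1, 1, 16, 64, 64)  # stand-in for the zoom-out logits

# Map the coarse prediction back to the original resolution, as the demo does.
logits_global = F.interpolate(logits_resized, size=ori_shape, mode='nearest')[0][0]

# Derive a foreground bounding box from the coarse prediction.
fg = torch.sigmoid(logits_global) > 0.5
idx = fg.nonzero()
if idx.numel() > 0:
    (min_d, min_h, min_w) = idx.min(0).values
    (max_d, max_h, max_w) = idx.max(0).values
    roi = logits_global[min_d:max_d + 1, min_h:max_h + 1, min_w:max_w + 1]
    print('ROI for zoom-in:', roi.shape)

In the real pipeline the cropped region is re-predicted with sliding_window_inference and written back into logits_global_single, which is why the zoom-out and zoom-in results can share one Dice evaluation.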
model/network/__pycache__/model.cpython-39.pyc
ADDED
|
Binary file (3.28 kB). View file
|
|
|
model/network/model.py
ADDED
|
@@ -0,0 +1,91 @@
|
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import AutoTokenizer, CLIPTextModel, CLIPTextConfig

#%% set up model
class SegVol(nn.Module):
    def __init__(self,
                 image_encoder,
                 mask_decoder,
                 prompt_encoder,
                 clip_ckpt,
                 roi_size,
                 patch_size,
                 test_mode=False,
                 ):
        super().__init__()
        self.image_encoder = image_encoder
        self.mask_decoder = mask_decoder
        self.prompt_encoder = prompt_encoder
        self.text_encoder = TextEncoder(clip_ckpt)
        self.feat_shape = np.array(roi_size)/np.array(patch_size)
        self.test_mode = test_mode

    def forward(self, image, text=None, boxes=None, points=None, **kwargs):
        bs = image.shape[0]
        img_shape = (image.shape[2], image.shape[3], image.shape[4])
        image_embedding, _ = self.image_encoder(image)
        image_embedding = image_embedding.transpose(1, 2).view(bs, -1,
            int(self.feat_shape[0]), int(self.feat_shape[1]), int(self.feat_shape[2]))
        # test mode
        if self.test_mode:
            return self.forward_decoder(image_embedding, img_shape, text, boxes, points)
        # train mode
        # future release

    def forward_decoder(self, image_embedding, img_shape, text=None, boxes=None, points=None):
        with torch.no_grad():
            if boxes is not None:
                if len(boxes.shape) == 2:
                    boxes = boxes[:, None, :]  # (B, 1, 6)
            if text is not None:
                text_embedding = self.text_encoder(text)  # (B, 768)
            else:
                text_embedding = None
        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=None,
            text_embedding=text_embedding,
        )

        dense_pe = self.prompt_encoder.get_dense_pe()
        low_res_masks, _ = self.mask_decoder(
            image_embeddings=image_embedding,
            text_embedding = text_embedding,
            image_pe=dense_pe,
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )
        logits = F.interpolate(low_res_masks, size=img_shape, mode='trilinear', align_corners=False)
        return logits

class TextEncoder(nn.Module):
    def __init__(self, clip_ckpt):
        super().__init__()
        config = CLIPTextConfig()
        self.clip_text_model = CLIPTextModel(config)
        self.tokenizer = AutoTokenizer.from_pretrained(clip_ckpt)
        self.dim_align = nn.Linear(512, 768)
        # freeze text encoder
        for param in self.clip_text_model.parameters():
            param.requires_grad = False

    def organ2tokens(self, organ_names):
        text_list = ['A computerized tomography of a {}.'.format(organ_name) for organ_name in organ_names]
        tokens = self.tokenizer(text_list, padding=True, return_tensors="pt")
        return tokens

    def forward(self, text):
        if text is None:
            return None
        if type(text) is str:
            text = [text]
        tokens = self.organ2tokens(text)
        clip_outputs = self.clip_text_model(**tokens)
        text_embedding = clip_outputs.pooler_output
        text_embedding = self.dim_align(text_embedding)
        return text_embedding
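For reference, the text branch above simply wraps a CLIP text tower and projects its pooled output from 512 to 768 dimensions so it can act as a prompt embedding. A small standalone sketch of that flow (the "openai/clip-vit-base-patch32" tokenizer name is a placeholder assumption; the Space passes its own clip_ckpt path):

import torch
import torch.nn as nn
from transformers import AutoTokenizer, CLIPTextModel, CLIPTextConfig

tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")  # placeholder checkpoint
clip_text = CLIPTextModel(CLIPTextConfig())  # randomly initialised config, as in TextEncoder.__init__
dim_align = nn.Linear(512, 768)              # CLIP pooled dim -> SAM prompt dim

prompts = ['A computerized tomography of a liver.']
tokens = tokenizer(prompts, padding=True, return_tensors="pt")
with torch.no_grad():
    text_embedding = dim_align(clip_text(**tokens).pooler_output)
print(text_embedding.shape)  # torch.Size([1, 768])

This is the same path SegVol.forward_decoder takes when a text prompt is supplied: the 768-d embedding is handed both to the prompt encoder and to the mask decoder.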
model/script/inference_demo.sh
ADDED
|
@@ -0,0 +1,8 @@
|
export segvol_ckpt="path/to/SegVol_v1.pth"
export work_dir="path/to/work_dir"
export demo_config_path="./config/config_demo.json"

CUDA_VISIBLE_DEVICES=0 python inference_demo.py \
--resume $segvol_ckpt \
--work_dir $work_dir \
--demo_config $demo_config_path
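The script assumes inference_demo.py exposes --resume, --work_dir and --demo_config through its set_parse() helper, which is not shown in this diff. A hypothetical minimal parser with those flags, for orientation only:

import argparse

def set_parse():
    # Hypothetical sketch; the real set_parse() in the repo defines more options
    # (e.g. spatial_size, patch_size, clip_ckpt, infer_overlap, test_mode used above).
    parser = argparse.ArgumentParser(description='SegVol inference demo')
    parser.add_argument('--resume', type=str, required=True, help='path to SegVol_v1.pth')
    parser.add_argument('--work_dir', type=str, required=True, help='output directory for visualizations')
    parser.add_argument('--demo_config', type=str, default='./config/config_demo.json')
    return parser.parse_args()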
model/segment_anything_volumetric/.ipynb_checkpoints/build_sam-checkpoint.py
ADDED
|
@@ -0,0 +1,172 @@
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from functools import partial
from pathlib import Path
import urllib.request
import torch

from .modeling import (
    ImageEncoderViT,
    MaskDecoder,
    PromptEncoder,
    Sam,
    TwoWayTransformer,
)

from .modeling.image_encoder_swin import SwinTransformer

from monai.utils import ensure_tuple_rep, optional_import

def build_sam_vit_h(checkpoint=None, image_size=1024):
    return _build_sam(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
        checkpoint=checkpoint,
        image_size=image_size,
    )


build_sam = build_sam_vit_h


def build_sam_vit_l(checkpoint=None, image_size=1024):
    return _build_sam(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
        checkpoint=checkpoint,
        image_size=image_size,
    )


def build_sam_vit_b(checkpoint=None, image_size=1024):
    return _build_sam(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
        image_size=image_size,
    )
"""
Examples::
    # for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48.
    >>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48)
    # for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage.
    >>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2))
    # for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
    >>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)
"""

def build_sam_vit_swin(checkpoint=None, image_size=96):
    print('==> build_sam_vit_swin')
    return _build_sam(
        encoder_embed_dim=48,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
        image_size=image_size,
    )

sam_model_registry = {
    "default": build_sam_vit_h,
    "vit_h": build_sam_vit_h,
    "vit_l": build_sam_vit_l,
    "vit_b": build_sam_vit_b,
    "swin_vit": build_sam_vit_swin,
}


def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
    image_size=None,
    spatial_dims=3,
):
    prompt_embed_dim = 768
    patch_size = ensure_tuple_rep(2, spatial_dims)
    window_size = ensure_tuple_rep(7, spatial_dims)
    image_embedding_size = [size // 32 for size in image_size]
    sam = Sam(
        image_encoder=SwinTransformer(
            in_chans=1,
            embed_dim=encoder_embed_dim,
            window_size=window_size,
            patch_size=patch_size,
            depths=(2, 2, 6, 2),  #(2, 2, 6, 2),
            num_heads=(3, 6, 12, 24),
            mlp_ratio=4.0,
            qkv_bias=True,
            spatial_dims=spatial_dims,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=image_embedding_size,
            input_image_size=image_size,
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    sam.eval()
    if checkpoint is not None:
        checkpoint = Path(checkpoint)
        if checkpoint.name == "sam_vit_b_01ec64.pth" and not checkpoint.exists():
            cmd = input("Download sam_vit_b_01ec64.pth from facebook AI? [y]/n: ")
            if len(cmd) == 0 or cmd.lower() == 'y':
                checkpoint.parent.mkdir(parents=True, exist_ok=True)
                print("Downloading SAM ViT-B checkpoint...")
                urllib.request.urlretrieve(
                    "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
                    checkpoint,
                )
                print(checkpoint.name, " is downloaded!")
        elif checkpoint.name == "sam_vit_h_4b8939.pth" and not checkpoint.exists():
            cmd = input("Download sam_vit_h_4b8939.pth from facebook AI? [y]/n: ")
            if len(cmd) == 0 or cmd.lower() == 'y':
                checkpoint.parent.mkdir(parents=True, exist_ok=True)
                print("Downloading SAM ViT-H checkpoint...")
                urllib.request.urlretrieve(
                    "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
                    checkpoint,
                )
                print(checkpoint.name, " is downloaded!")
        elif checkpoint.name == "sam_vit_l_0b3195.pth" and not checkpoint.exists():
            cmd = input("Download sam_vit_l_0b3195.pth from facebook AI? [y]/n: ")
            if len(cmd) == 0 or cmd.lower() == 'y':
                checkpoint.parent.mkdir(parents=True, exist_ok=True)
                print("Downloading SAM ViT-L checkpoint...")
                urllib.request.urlretrieve(
                    "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
                    checkpoint,
                )
                print(checkpoint.name, " is downloaded!")


    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        sam.load_state_dict(state_dict)
    return sam
model/segment_anything_volumetric/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .build_sam import (
    build_sam_vit_3d,
    sam_model_registry,
)
from .predictor import SamPredictor
from .automatic_mask_generator import SamAutomaticMaskGenerator
model/segment_anything_volumetric/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (407 Bytes). View file
|
|
|
model/segment_anything_volumetric/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (377 Bytes). View file
|
|
|
model/segment_anything_volumetric/__pycache__/automatic_mask_generator.cpython-310.pyc
ADDED
|
Binary file (11.4 kB). View file
|
|
|
model/segment_anything_volumetric/__pycache__/automatic_mask_generator.cpython-39.pyc
ADDED
|
Binary file (11.4 kB). View file
|
|
|
model/segment_anything_volumetric/__pycache__/build_sam.cpython-310.pyc
ADDED
|
Binary file (3.3 kB). View file
|
|
|
model/segment_anything_volumetric/__pycache__/build_sam.cpython-39.pyc
ADDED
|
Binary file (2.62 kB). View file
|
|
|
model/segment_anything_volumetric/__pycache__/predictor.cpython-310.pyc
ADDED
|
Binary file (9.96 kB). View file
|
|
|
model/segment_anything_volumetric/__pycache__/predictor.cpython-39.pyc
ADDED
|
Binary file (9.98 kB). View file
|
|
|
model/segment_anything_volumetric/automatic_mask_generator.py
ADDED
|
@@ -0,0 +1,372 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
|
| 10 |
+
|
| 11 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 12 |
+
|
| 13 |
+
from .modeling import Sam
|
| 14 |
+
from .predictor import SamPredictor
|
| 15 |
+
from .utils.amg import (
|
| 16 |
+
MaskData,
|
| 17 |
+
area_from_rle,
|
| 18 |
+
batch_iterator,
|
| 19 |
+
batched_mask_to_box,
|
| 20 |
+
box_xyxy_to_xywh,
|
| 21 |
+
build_all_layer_point_grids,
|
| 22 |
+
calculate_stability_score,
|
| 23 |
+
coco_encode_rle,
|
| 24 |
+
generate_crop_boxes,
|
| 25 |
+
is_box_near_crop_edge,
|
| 26 |
+
mask_to_rle_pytorch,
|
| 27 |
+
remove_small_regions,
|
| 28 |
+
rle_to_mask,
|
| 29 |
+
uncrop_boxes_xyxy,
|
| 30 |
+
uncrop_masks,
|
| 31 |
+
uncrop_points,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class SamAutomaticMaskGenerator:
|
| 36 |
+
def __init__(
|
| 37 |
+
self,
|
| 38 |
+
model: Sam,
|
| 39 |
+
points_per_side: Optional[int] = 32,
|
| 40 |
+
points_per_batch: int = 64,
|
| 41 |
+
pred_iou_thresh: float = 0.88,
|
| 42 |
+
stability_score_thresh: float = 0.95,
|
| 43 |
+
stability_score_offset: float = 1.0,
|
| 44 |
+
box_nms_thresh: float = 0.7,
|
| 45 |
+
crop_n_layers: int = 0,
|
| 46 |
+
crop_nms_thresh: float = 0.7,
|
| 47 |
+
crop_overlap_ratio: float = 512 / 1500,
|
| 48 |
+
crop_n_points_downscale_factor: int = 1,
|
| 49 |
+
point_grids: Optional[List[np.ndarray]] = None,
|
| 50 |
+
min_mask_region_area: int = 0,
|
| 51 |
+
output_mode: str = "binary_mask",
|
| 52 |
+
) -> None:
|
| 53 |
+
"""
|
| 54 |
+
Using a SAM model, generates masks for the entire image.
|
| 55 |
+
Generates a grid of point prompts over the image, then filters
|
| 56 |
+
low quality and duplicate masks. The default settings are chosen
|
| 57 |
+
for SAM with a ViT-H backbone.
|
| 58 |
+
|
| 59 |
+
Arguments:
|
| 60 |
+
model (Sam): The SAM model to use for mask prediction.
|
| 61 |
+
points_per_side (int or None): The number of points to be sampled
|
| 62 |
+
along one side of the image. The total number of points is
|
| 63 |
+
points_per_side**2. If None, 'point_grids' must provide explicit
|
| 64 |
+
point sampling.
|
| 65 |
+
points_per_batch (int): Sets the number of points run simultaneously
|
| 66 |
+
by the model. Higher numbers may be faster but use more GPU memory.
|
| 67 |
+
pred_iou_thresh (float): A filtering threshold in [0,1], using the
|
| 68 |
+
model's predicted mask quality.
|
| 69 |
+
stability_score_thresh (float): A filtering threshold in [0,1], using
|
| 70 |
+
the stability of the mask under changes to the cutoff used to binarize
|
| 71 |
+
the model's mask predictions.
|
| 72 |
+
stability_score_offset (float): The amount to shift the cutoff when
|
| 73 |
+
calculated the stability score.
|
| 74 |
+
box_nms_thresh (float): The box IoU cutoff used by non-maximal
|
| 75 |
+
suppression to filter duplicate masks.
|
| 76 |
+
crop_n_layers (int): If >0, mask prediction will be run again on
|
| 77 |
+
crops of the image. Sets the number of layers to run, where each
|
| 78 |
+
layer has 2**i_layer number of image crops.
|
| 79 |
+
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
|
| 80 |
+
suppression to filter duplicate masks between different crops.
|
| 81 |
+
crop_overlap_ratio (float): Sets the degree to which crops overlap.
|
| 82 |
+
In the first crop layer, crops will overlap by this fraction of
|
| 83 |
+
the image length. Later layers with more crops scale down this overlap.
|
| 84 |
+
crop_n_points_downscale_factor (int): The number of points-per-side
|
| 85 |
+
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
| 86 |
+
point_grids (list(np.ndarray) or None): A list over explicit grids
|
| 87 |
+
of points used for sampling, normalized to [0,1]. The nth grid in the
|
| 88 |
+
list is used in the nth crop layer. Exclusive with points_per_side.
|
| 89 |
+
min_mask_region_area (int): If >0, postprocessing will be applied
|
| 90 |
+
to remove disconnected regions and holes in masks with area smaller
|
| 91 |
+
than min_mask_region_area. Requires opencv.
|
| 92 |
+
output_mode (str): The form masks are returned in. Can be 'binary_mask',
|
| 93 |
+
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
|
| 94 |
+
For large resolutions, 'binary_mask' may consume large amounts of
|
| 95 |
+
memory.
|
| 96 |
+
"""
|
| 97 |
+
|
| 98 |
+
assert (points_per_side is None) != (
|
| 99 |
+
point_grids is None
|
| 100 |
+
), "Exactly one of points_per_side or point_grid must be provided."
|
| 101 |
+
if points_per_side is not None:
|
| 102 |
+
self.point_grids = build_all_layer_point_grids(
|
| 103 |
+
points_per_side,
|
| 104 |
+
crop_n_layers,
|
| 105 |
+
crop_n_points_downscale_factor,
|
| 106 |
+
)
|
| 107 |
+
elif point_grids is not None:
|
| 108 |
+
self.point_grids = point_grids
|
| 109 |
+
else:
|
| 110 |
+
raise ValueError("Can't have both points_per_side and point_grid be None.")
|
| 111 |
+
|
| 112 |
+
assert output_mode in [
|
| 113 |
+
"binary_mask",
|
| 114 |
+
"uncompressed_rle",
|
| 115 |
+
"coco_rle",
|
| 116 |
+
], f"Unknown output_mode {output_mode}."
|
| 117 |
+
if output_mode == "coco_rle":
|
| 118 |
+
from pycocotools import mask as mask_utils # type: ignore # noqa: F401
|
| 119 |
+
|
| 120 |
+
if min_mask_region_area > 0:
|
| 121 |
+
import cv2 # type: ignore # noqa: F401
|
| 122 |
+
|
| 123 |
+
self.predictor = SamPredictor(model)
|
| 124 |
+
self.points_per_batch = points_per_batch
|
| 125 |
+
self.pred_iou_thresh = pred_iou_thresh
|
| 126 |
+
self.stability_score_thresh = stability_score_thresh
|
| 127 |
+
self.stability_score_offset = stability_score_offset
|
| 128 |
+
self.box_nms_thresh = box_nms_thresh
|
| 129 |
+
self.crop_n_layers = crop_n_layers
|
| 130 |
+
self.crop_nms_thresh = crop_nms_thresh
|
| 131 |
+
self.crop_overlap_ratio = crop_overlap_ratio
|
| 132 |
+
self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
|
| 133 |
+
self.min_mask_region_area = min_mask_region_area
|
| 134 |
+
self.output_mode = output_mode
|
| 135 |
+
|
| 136 |
+
@torch.no_grad()
|
| 137 |
+
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
| 138 |
+
"""
|
| 139 |
+
Generates masks for the given image.
|
| 140 |
+
|
| 141 |
+
Arguments:
|
| 142 |
+
image (np.ndarray): The image to generate masks for, in HWC uint8 format.
|
| 143 |
+
|
| 144 |
+
Returns:
|
| 145 |
+
list(dict(str, any)): A list over records for masks. Each record is
|
| 146 |
+
a dict containing the following keys:
|
| 147 |
+
segmentation (dict(str, any) or np.ndarray): The mask. If
|
| 148 |
+
output_mode='binary_mask', is an array of shape HW. Otherwise,
|
| 149 |
+
is a dictionary containing the RLE.
|
| 150 |
+
bbox (list(float)): The box around the mask, in XYWH format.
|
| 151 |
+
area (int): The area in pixels of the mask.
|
| 152 |
+
predicted_iou (float): The model's own prediction of the mask's
|
| 153 |
+
quality. This is filtered by the pred_iou_thresh parameter.
|
| 154 |
+
point_coords (list(list(float))): The point coordinates input
|
| 155 |
+
to the model to generate this mask.
|
| 156 |
+
stability_score (float): A measure of the mask's quality. This
|
| 157 |
+
is filtered on using the stability_score_thresh parameter.
|
| 158 |
+
crop_box (list(float)): The crop of the image used to generate
|
| 159 |
+
the mask, given in XYWH format.
|
| 160 |
+
"""
|
| 161 |
+
|
| 162 |
+
# Generate masks
|
| 163 |
+
mask_data = self._generate_masks(image)
|
| 164 |
+
|
| 165 |
+
# Filter small disconnected regions and holes in masks
|
| 166 |
+
if self.min_mask_region_area > 0:
|
| 167 |
+
mask_data = self.postprocess_small_regions(
|
| 168 |
+
mask_data,
|
| 169 |
+
self.min_mask_region_area,
|
| 170 |
+
max(self.box_nms_thresh, self.crop_nms_thresh),
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
# Encode masks
|
| 174 |
+
if self.output_mode == "coco_rle":
|
| 175 |
+
mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
|
| 176 |
+
elif self.output_mode == "binary_mask":
|
| 177 |
+
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
|
| 178 |
+
else:
|
| 179 |
+
mask_data["segmentations"] = mask_data["rles"]
|
| 180 |
+
|
| 181 |
+
# Write mask records
|
| 182 |
+
curr_anns = []
|
| 183 |
+
for idx in range(len(mask_data["segmentations"])):
|
| 184 |
+
ann = {
|
| 185 |
+
"segmentation": mask_data["segmentations"][idx],
|
| 186 |
+
"area": area_from_rle(mask_data["rles"][idx]),
|
| 187 |
+
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
|
| 188 |
+
"predicted_iou": mask_data["iou_preds"][idx].item(),
|
| 189 |
+
"point_coords": [mask_data["points"][idx].tolist()],
|
| 190 |
+
"stability_score": mask_data["stability_score"][idx].item(),
|
| 191 |
+
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
|
| 192 |
+
}
|
| 193 |
+
curr_anns.append(ann)
|
| 194 |
+
|
| 195 |
+
return curr_anns
|
| 196 |
+
|
| 197 |
+
def _generate_masks(self, image: np.ndarray) -> MaskData:
|
| 198 |
+
orig_size = image.shape[:2]
|
| 199 |
+
crop_boxes, layer_idxs = generate_crop_boxes(
|
| 200 |
+
orig_size, self.crop_n_layers, self.crop_overlap_ratio
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
# Iterate over image crops
|
| 204 |
+
data = MaskData()
|
| 205 |
+
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
|
| 206 |
+
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
|
| 207 |
+
data.cat(crop_data)
|
| 208 |
+
|
| 209 |
+
# Remove duplicate masks between crops
|
| 210 |
+
if len(crop_boxes) > 1:
|
| 211 |
+
# Prefer masks from smaller crops
|
| 212 |
+
scores = 1 / box_area(data["crop_boxes"])
|
| 213 |
+
scores = scores.to(data["boxes"].device)
|
| 214 |
+
keep_by_nms = batched_nms(
|
| 215 |
+
data["boxes"].float(),
|
| 216 |
+
scores,
|
| 217 |
+
torch.zeros_like(data["boxes"][:, 0]), # categories
|
| 218 |
+
iou_threshold=self.crop_nms_thresh,
|
| 219 |
+
)
|
| 220 |
+
data.filter(keep_by_nms)
|
| 221 |
+
|
| 222 |
+
data.to_numpy()
|
| 223 |
+
return data
|
| 224 |
+
|
| 225 |
+
def _process_crop(
|
| 226 |
+
self,
|
| 227 |
+
image: np.ndarray,
|
| 228 |
+
crop_box: List[int],
|
| 229 |
+
crop_layer_idx: int,
|
| 230 |
+
orig_size: Tuple[int, ...],
|
| 231 |
+
) -> MaskData:
|
| 232 |
+
# Crop the image and calculate embeddings
|
| 233 |
+
x0, y0, x1, y1 = crop_box
|
| 234 |
+
cropped_im = image[y0:y1, x0:x1, :]
|
| 235 |
+
cropped_im_size = cropped_im.shape[:2]
|
| 236 |
+
self.predictor.set_image(cropped_im)
|
| 237 |
+
|
| 238 |
+
# Get points for this crop
|
| 239 |
+
points_scale = np.array(cropped_im_size)[None, ::-1]
|
| 240 |
+
points_for_image = self.point_grids[crop_layer_idx] * points_scale
|
| 241 |
+
|
| 242 |
+
# Generate masks for this crop in batches
|
| 243 |
+
data = MaskData()
|
| 244 |
+
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
|
| 245 |
+
batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
|
| 246 |
+
data.cat(batch_data)
|
| 247 |
+
del batch_data
|
| 248 |
+
self.predictor.reset_image()
|
| 249 |
+
|
| 250 |
+
# Remove duplicates within this crop.
|
| 251 |
+
keep_by_nms = batched_nms(
|
| 252 |
+
data["boxes"].float(),
|
| 253 |
+
data["iou_preds"],
|
| 254 |
+
torch.zeros_like(data["boxes"][:, 0]), # categories
|
| 255 |
+
iou_threshold=self.box_nms_thresh,
|
| 256 |
+
)
|
| 257 |
+
data.filter(keep_by_nms)
|
| 258 |
+
|
| 259 |
+
# Return to the original image frame
|
| 260 |
+
data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
|
| 261 |
+
data["points"] = uncrop_points(data["points"], crop_box)
|
| 262 |
+
data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
|
| 263 |
+
|
| 264 |
+
return data
|
| 265 |
+
|
| 266 |
+
def _process_batch(
|
| 267 |
+
self,
|
| 268 |
+
points: np.ndarray,
|
| 269 |
+
im_size: Tuple[int, ...],
|
| 270 |
+
crop_box: List[int],
|
| 271 |
+
orig_size: Tuple[int, ...],
|
| 272 |
+
) -> MaskData:
|
| 273 |
+
orig_h, orig_w = orig_size
|
| 274 |
+
|
| 275 |
+
# Run model on this batch
|
| 276 |
+
transformed_points = self.predictor.transform.apply_coords(points, im_size)
|
| 277 |
+
in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
|
| 278 |
+
in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
|
| 279 |
+
masks, iou_preds, _ = self.predictor.predict_torch(
|
| 280 |
+
in_points[:, None, :],
|
| 281 |
+
in_labels[:, None],
|
| 282 |
+
multimask_output=True,
|
| 283 |
+
return_logits=True,
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
# Serialize predictions and store in MaskData
|
| 287 |
+
data = MaskData(
|
| 288 |
+
masks=masks.flatten(0, 1),
|
| 289 |
+
iou_preds=iou_preds.flatten(0, 1),
|
| 290 |
+
points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
|
| 291 |
+
)
|
| 292 |
+
del masks
|
| 293 |
+
|
| 294 |
+
# Filter by predicted IoU
|
| 295 |
+
if self.pred_iou_thresh > 0.0:
|
| 296 |
+
keep_mask = data["iou_preds"] > self.pred_iou_thresh
|
| 297 |
+
data.filter(keep_mask)
|
| 298 |
+
|
| 299 |
+
# Calculate stability score
|
| 300 |
+
data["stability_score"] = calculate_stability_score(
|
| 301 |
+
data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
|
| 302 |
+
)
|
| 303 |
+
if self.stability_score_thresh > 0.0:
|
| 304 |
+
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
| 305 |
+
data.filter(keep_mask)
|
| 306 |
+
|
| 307 |
+
# Threshold masks and calculate boxes
|
| 308 |
+
data["masks"] = data["masks"] > self.predictor.model.mask_threshold
|
| 309 |
+
data["boxes"] = batched_mask_to_box(data["masks"])
|
| 310 |
+
|
| 311 |
+
# Filter boxes that touch crop boundaries
|
| 312 |
+
keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
|
| 313 |
+
if not torch.all(keep_mask):
|
| 314 |
+
data.filter(keep_mask)
|
| 315 |
+
|
| 316 |
+
# Compress to RLE
|
| 317 |
+
data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
|
| 318 |
+
data["rles"] = mask_to_rle_pytorch(data["masks"])
|
| 319 |
+
del data["masks"]
|
| 320 |
+
|
| 321 |
+
return data
|
| 322 |
+
|
| 323 |
+
@staticmethod
|
| 324 |
+
def postprocess_small_regions(
|
| 325 |
+
mask_data: MaskData, min_area: int, nms_thresh: float
|
| 326 |
+
) -> MaskData:
|
| 327 |
+
"""
|
| 328 |
+
Removes small disconnected regions and holes in masks, then reruns
|
| 329 |
+
box NMS to remove any new duplicates.
|
| 330 |
+
|
| 331 |
+
Edits mask_data in place.
|
| 332 |
+
|
| 333 |
+
Requires open-cv as a dependency.
|
| 334 |
+
"""
|
| 335 |
+
if len(mask_data["rles"]) == 0:
|
| 336 |
+
return mask_data
|
| 337 |
+
|
| 338 |
+
# Filter small disconnected regions and holes
|
| 339 |
+
new_masks = []
|
| 340 |
+
scores = []
|
| 341 |
+
for rle in mask_data["rles"]:
|
| 342 |
+
mask = rle_to_mask(rle)
|
| 343 |
+
|
| 344 |
+
mask, changed = remove_small_regions(mask, min_area, mode="holes")
|
| 345 |
+
unchanged = not changed
|
| 346 |
+
mask, changed = remove_small_regions(mask, min_area, mode="islands")
|
| 347 |
+
unchanged = unchanged and not changed
|
| 348 |
+
|
| 349 |
+
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
|
| 350 |
+
# Give score=0 to changed masks and score=1 to unchanged masks
|
| 351 |
+
# so NMS will prefer ones that didn't need postprocessing
|
| 352 |
+
scores.append(float(unchanged))
|
| 353 |
+
|
| 354 |
+
# Recalculate boxes and remove any new duplicates
|
| 355 |
+
masks = torch.cat(new_masks, dim=0)
|
| 356 |
+
boxes = batched_mask_to_box(masks)
|
| 357 |
+
keep_by_nms = batched_nms(
|
| 358 |
+
boxes.float(),
|
| 359 |
+
torch.as_tensor(scores),
|
| 360 |
+
torch.zeros_like(boxes[:, 0]), # categories
|
| 361 |
+
iou_threshold=nms_thresh,
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
# Only recalculate RLEs for masks that have changed
|
| 365 |
+
for i_mask in keep_by_nms:
|
| 366 |
+
if scores[i_mask] == 0.0:
|
| 367 |
+
mask_torch = masks[i_mask].unsqueeze(0)
|
| 368 |
+
mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
|
| 369 |
+
mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
|
| 370 |
+
mask_data.filter(keep_by_nms)
|
| 371 |
+
|
| 372 |
+
return mask_data
|
model/segment_anything_volumetric/build_sam.py
ADDED
|
@@ -0,0 +1,111 @@
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from functools import partial
from pathlib import Path
import urllib.request
import torch

from .modeling import (
    ImageEncoderViT,
    MaskDecoder,
    PromptEncoder,
    Sam,
    TwoWayTransformer,
)
import numpy as np
from .modeling.image_encoder_swin import SwinTransformer
from monai.networks.nets import ViT
from monai.networks.nets.swin_unetr import SwinTransformer as SwinViT

from monai.utils import ensure_tuple_rep, optional_import


"""
Examples::
    # for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48.
    >>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48)
    # for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage.
    >>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2))
    # for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
    >>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)
"""

def build_sam_vit_3d(checkpoint=None):
    print('build_sam_vit_3d...')
    return _build_sam(
        image_encoder_type='vit',
        embed_dim = 768,
        patch_size=[4,16,16],
        checkpoint=checkpoint,
        image_size=[32,256,256],
    )

sam_model_registry = {
    "vit": build_sam_vit_3d,
}


def _build_sam(
    image_encoder_type,
    embed_dim,
    patch_size,
    checkpoint,
    image_size,
):
    mlp_dim = 3072
    num_layers = 12
    num_heads = 12
    pos_embed = 'perceptron'
    dropout_rate = 0.0

    image_encoder=ViT(
        in_channels=1,
        img_size=image_size,
        patch_size=patch_size,
        hidden_size=embed_dim,
        mlp_dim=mlp_dim,
        num_layers=num_layers,
        num_heads=num_heads,
        pos_embed=pos_embed,
        classification=False,
        dropout_rate=dropout_rate,
    )
    image_embedding_size = [int(item) for item in (np.array(image_size) / np.array(patch_size))]

    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f, map_location='cpu')['state_dict']
            encoder_dict = {k.replace('model.encoder.', ''): v for k, v in state_dict.items() if 'model.encoder.' in k}
        image_encoder.load_state_dict(encoder_dict)
        print(f'===> image_encoder.load_param: {checkpoint}')
    sam = Sam(
        image_encoder=image_encoder,
        prompt_encoder=PromptEncoder(
            embed_dim=embed_dim,
            image_embedding_size=image_embedding_size,
            input_image_size=image_size,
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            image_encoder_type=image_encoder_type,
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
            image_size=np.array(image_size),
            patch_size=np.array(patch_size),
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    sam.eval()
    return sam
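A quick sanity check of the geometry baked into build_sam_vit_3d: with image_size=[32,256,256] and patch_size=[4,16,16], the 3D ViT produces an 8x16x16 token grid, which is the image_embedding_size handed to the prompt encoder and the same ratio SegVol.feat_shape computes from roi_size and patch_size (assuming the demo's spatial_size matches the ViT's image_size). Sketch:

import numpy as np

image_size = [32, 256, 256]   # D, H, W fed to the 3D ViT
patch_size = [4, 16, 16]      # volumetric patch size

image_embedding_size = [int(x) for x in np.array(image_size) / np.array(patch_size)]
num_tokens = int(np.prod(image_embedding_size))

print(image_embedding_size)   # [8, 16, 16]
print(num_tokens)             # 2048 patch tokens per volume

These 2048 tokens are what SegVol.forward reshapes back into a 5-D feature map via feat_shape before the mask decoder runs.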
model/segment_anything_volumetric/modeling/.ipynb_checkpoints/image_encoder_swin-checkpoint.py
ADDED
|
@@ -0,0 +1,709 @@
| 1 |
+
from typing import Sequence, Tuple, Type, Union
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import torch.utils.checkpoint as checkpoint
|
| 8 |
+
from torch.nn import LayerNorm
|
| 9 |
+
|
| 10 |
+
from monai.networks.blocks import MLPBlock as Mlp
|
| 11 |
+
from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock
|
| 12 |
+
from monai.networks.layers import DropPath, trunc_normal_
|
| 13 |
+
from monai.utils import ensure_tuple_rep, optional_import
|
| 14 |
+
|
| 15 |
+
rearrange, _ = optional_import("einops", name="rearrange")
|
| 16 |
+
|
| 17 |
+
def window_partition(x, window_size):
|
| 18 |
+
"""window partition operation based on: "Liu et al.,
|
| 19 |
+
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
|
| 20 |
+
<https://arxiv.org/abs/2103.14030>"
|
| 21 |
+
https://github.com/microsoft/Swin-Transformer
|
| 22 |
+
Args:
|
| 23 |
+
x: input tensor.
|
| 24 |
+
window_size: local window size.
|
| 25 |
+
"""
|
| 26 |
+
x_shape = x.size()
|
| 27 |
+
if len(x_shape) == 5:
|
| 28 |
+
b, d, h, w, c = x_shape
|
| 29 |
+
x = x.view(
|
| 30 |
+
b,
|
| 31 |
+
d // window_size[0],
|
| 32 |
+
window_size[0],
|
| 33 |
+
h // window_size[1],
|
| 34 |
+
window_size[1],
|
| 35 |
+
w // window_size[2],
|
| 36 |
+
window_size[2],
|
| 37 |
+
c,
|
| 38 |
+
)
|
| 39 |
+
windows = (
|
| 40 |
+
x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0] * window_size[1] * window_size[2], c)
|
| 41 |
+
)
|
| 42 |
+
elif len(x_shape) == 4:
|
| 43 |
+
b, h, w, c = x.shape
|
| 44 |
+
x = x.view(b, h // window_size[0], window_size[0], w // window_size[1], window_size[1], c)
|
| 45 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], c)
|
| 46 |
+
return windows
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def window_reverse(windows, window_size, dims):
|
| 50 |
+
"""window reverse operation based on: "Liu et al.,
|
| 51 |
+
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
|
| 52 |
+
<https://arxiv.org/abs/2103.14030>"
|
| 53 |
+
https://github.com/microsoft/Swin-Transformer
|
| 54 |
+
Args:
|
| 55 |
+
windows: windows tensor.
|
| 56 |
+
window_size: local window size.
|
| 57 |
+
dims: dimension values.
|
| 58 |
+
"""
|
| 59 |
+
if len(dims) == 4:
|
| 60 |
+
b, d, h, w = dims
|
| 61 |
+
x = windows.view(
|
| 62 |
+
b,
|
| 63 |
+
d // window_size[0],
|
| 64 |
+
h // window_size[1],
|
| 65 |
+
w // window_size[2],
|
| 66 |
+
window_size[0],
|
| 67 |
+
window_size[1],
|
| 68 |
+
window_size[2],
|
| 69 |
+
-1,
|
| 70 |
+
)
|
| 71 |
+
x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(b, d, h, w, -1)
|
| 72 |
+
|
| 73 |
+
elif len(dims) == 3:
|
| 74 |
+
b, h, w = dims
|
| 75 |
+
x = windows.view(b, h // window_size[0], w // window_size[0], window_size[0], window_size[1], -1)
|
| 76 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
|
| 77 |
+
return x
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def get_window_size(x_size, window_size, shift_size=None):
|
| 81 |
+
"""Computing window size based on: "Liu et al.,
|
| 82 |
+
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
|
| 83 |
+
<https://arxiv.org/abs/2103.14030>"
|
| 84 |
+
https://github.com/microsoft/Swin-Transformer
|
| 85 |
+
Args:
|
| 86 |
+
x_size: input size.
|
| 87 |
+
window_size: local window size.
|
| 88 |
+
shift_size: window shifting size.
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
use_window_size = list(window_size)
|
| 92 |
+
if shift_size is not None:
|
| 93 |
+
use_shift_size = list(shift_size)
|
| 94 |
+
for i in range(len(x_size)):
|
| 95 |
+
if x_size[i] <= window_size[i]:
|
| 96 |
+
use_window_size[i] = x_size[i]
|
| 97 |
+
if shift_size is not None:
|
| 98 |
+
use_shift_size[i] = 0
|
| 99 |
+
|
| 100 |
+
if shift_size is None:
|
| 101 |
+
return tuple(use_window_size)
|
| 102 |
+
else:
|
| 103 |
+
return tuple(use_window_size), tuple(use_shift_size)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class WindowAttention(nn.Module):
|
| 107 |
+
"""
|
| 108 |
+
Window based multi-head self attention module with relative position bias based on: "Liu et al.,
|
| 109 |
+
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
|
| 110 |
+
<https://arxiv.org/abs/2103.14030>"
|
| 111 |
+
https://github.com/microsoft/Swin-Transformer
|
| 112 |
+
"""
|
| 113 |
+
|
| 114 |
+
def __init__(
|
| 115 |
+
self,
|
| 116 |
+
dim: int,
|
| 117 |
+
num_heads: int,
|
| 118 |
+
window_size: Sequence[int],
|
| 119 |
+
qkv_bias: bool = False,
|
| 120 |
+
attn_drop: float = 0.0,
|
| 121 |
+
proj_drop: float = 0.0,
|
| 122 |
+
) -> None:
|
| 123 |
+
"""
|
| 124 |
+
Args:
|
| 125 |
+
dim: number of feature channels.
|
| 126 |
+
num_heads: number of attention heads.
|
| 127 |
+
window_size: local window size.
|
| 128 |
+
qkv_bias: add a learnable bias to query, key, value.
|
| 129 |
+
attn_drop: attention dropout rate.
|
| 130 |
+
proj_drop: dropout rate of output.
|
| 131 |
+
"""
|
| 132 |
+
|
| 133 |
+
super().__init__()
|
| 134 |
+
self.dim = dim
|
| 135 |
+
self.window_size = window_size
|
| 136 |
+
self.num_heads = num_heads
|
| 137 |
+
head_dim = dim // num_heads
|
| 138 |
+
self.scale = head_dim**-0.5
|
| 139 |
+
mesh_args = torch.meshgrid.__kwdefaults__
|
| 140 |
+
|
| 141 |
+
if len(self.window_size) == 3:
|
| 142 |
+
self.relative_position_bias_table = nn.Parameter(
|
| 143 |
+
torch.zeros(
|
| 144 |
+
(2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1),
|
| 145 |
+
num_heads,
|
| 146 |
+
)
|
| 147 |
+
)
|
| 148 |
+
coords_d = torch.arange(self.window_size[0])
|
| 149 |
+
coords_h = torch.arange(self.window_size[1])
|
| 150 |
+
coords_w = torch.arange(self.window_size[2])
|
| 151 |
+
if mesh_args is not None:
|
| 152 |
+
coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing="ij"))
|
| 153 |
+
else:
|
| 154 |
+
coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w))
|
| 155 |
+
coords_flatten = torch.flatten(coords, 1)
|
| 156 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
| 157 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
|
| 158 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1
|
| 159 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
| 160 |
+
relative_coords[:, :, 2] += self.window_size[2] - 1
|
| 161 |
+
relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
|
| 162 |
+
relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1
|
| 163 |
+
elif len(self.window_size) == 2:
|
| 164 |
+
self.relative_position_bias_table = nn.Parameter(
|
| 165 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
|
| 166 |
+
)
|
| 167 |
+
coords_h = torch.arange(self.window_size[0])
|
| 168 |
+
coords_w = torch.arange(self.window_size[1])
|
| 169 |
+
if mesh_args is not None:
|
| 170 |
+
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))
|
| 171 |
+
else:
|
| 172 |
+
coords = torch.stack(torch.meshgrid(coords_h, coords_w))
|
| 173 |
+
coords_flatten = torch.flatten(coords, 1)
|
| 174 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
| 175 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
|
| 176 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1
|
| 177 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
| 178 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
| 179 |
+
|
| 180 |
+
relative_position_index = relative_coords.sum(-1)
|
| 181 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
| 182 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 183 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 184 |
+
self.proj = nn.Linear(dim, dim)
|
| 185 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 186 |
+
trunc_normal_(self.relative_position_bias_table, std=0.02)
|
| 187 |
+
self.softmax = nn.Softmax(dim=-1)
|
| 188 |
+
|
| 189 |
+
def forward(self, x, mask):
|
| 190 |
+
b, n, c = x.shape
|
| 191 |
+
qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 192 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 193 |
+
q = q * self.scale
|
| 194 |
+
attn = q @ k.transpose(-2, -1)
|
| 195 |
+
relative_position_bias = self.relative_position_bias_table[
|
| 196 |
+
self.relative_position_index.clone()[:n, :n].reshape(-1)
|
| 197 |
+
].reshape(n, n, -1)
|
| 198 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
|
| 199 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
| 200 |
+
if mask is not None:
|
| 201 |
+
nw = mask.shape[0]
|
| 202 |
+
attn = attn.view(b // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
|
| 203 |
+
attn = attn.view(-1, self.num_heads, n, n)
|
| 204 |
+
attn = self.softmax(attn)
|
| 205 |
+
else:
|
| 206 |
+
attn = self.softmax(attn)
|
| 207 |
+
|
| 208 |
+
attn = self.attn_drop(attn)
|
| 209 |
+
x = (attn @ v).transpose(1, 2).reshape(b, n, c)
|
| 210 |
+
x = self.proj(x)
|
| 211 |
+
x = self.proj_drop(x)
|
| 212 |
+
return x
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class SwinTransformerBlock(nn.Module):
    """
    Swin Transformer block based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: Sequence[int],
        shift_size: Sequence[int],
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        act_layer: str = "GELU",
        norm_layer: Type[LayerNorm] = nn.LayerNorm,  # type: ignore
        use_checkpoint: bool = False,
    ) -> None:
        """
        Args:
            dim: number of feature channels.
            num_heads: number of attention heads.
            window_size: local window size.
            shift_size: window shift size.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: add a learnable bias to query, key, value.
            drop: dropout rate.
            attn_drop: attention dropout rate.
            drop_path: stochastic depth rate.
            act_layer: activation layer.
            norm_layer: normalization layer.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
        """

        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.use_checkpoint = use_checkpoint
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=self.window_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(hidden_size=dim, mlp_dim=mlp_hidden_dim, act=act_layer, dropout_rate=drop, dropout_mode="swin")

    def forward_part1(self, x, mask_matrix):
        x_shape = x.size()
        x = self.norm1(x)
        if len(x_shape) == 5:
            b, d, h, w, c = x.shape
            window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
            pad_l = pad_t = pad_d0 = 0
            pad_d1 = (window_size[0] - d % window_size[0]) % window_size[0]
            pad_b = (window_size[1] - h % window_size[1]) % window_size[1]
            pad_r = (window_size[2] - w % window_size[2]) % window_size[2]
            x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))
            _, dp, hp, wp, _ = x.shape
            dims = [b, dp, hp, wp]

        elif len(x_shape) == 4:
            b, h, w, c = x.shape
            window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
            pad_l = pad_t = 0
            pad_r = (window_size[0] - h % window_size[0]) % window_size[0]
            pad_b = (window_size[1] - w % window_size[1]) % window_size[1]
            x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
            _, hp, wp, _ = x.shape
            dims = [b, hp, wp]

        if any(i > 0 for i in shift_size):
            if len(x_shape) == 5:
                shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
            elif len(x_shape) == 4:
                shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None
        x_windows = window_partition(shifted_x, window_size)
        attn_windows = self.attn(x_windows, mask=attn_mask)
        attn_windows = attn_windows.view(-1, *(window_size + (c,)))
        shifted_x = window_reverse(attn_windows, window_size, dims)
        if any(i > 0 for i in shift_size):
            if len(x_shape) == 5:
                x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))
            elif len(x_shape) == 4:
                x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))
        else:
            x = shifted_x

        if len(x_shape) == 5:
            if pad_d1 > 0 or pad_r > 0 or pad_b > 0:
                x = x[:, :d, :h, :w, :].contiguous()
        elif len(x_shape) == 4:
            if pad_r > 0 or pad_b > 0:
                x = x[:, :h, :w, :].contiguous()

        return x

    def forward_part2(self, x):
        return self.drop_path(self.mlp(self.norm2(x)))

    def load_from(self, weights, n_block, layer):
        root = f"module.{layer}.0.blocks.{n_block}."
        block_names = [
            "norm1.weight",
            "norm1.bias",
            "attn.relative_position_bias_table",
            "attn.relative_position_index",
            "attn.qkv.weight",
            "attn.qkv.bias",
            "attn.proj.weight",
            "attn.proj.bias",
            "norm2.weight",
            "norm2.bias",
            "mlp.fc1.weight",
            "mlp.fc1.bias",
            "mlp.fc2.weight",
            "mlp.fc2.bias",
        ]
        with torch.no_grad():
            self.norm1.weight.copy_(weights["state_dict"][root + block_names[0]])
            self.norm1.bias.copy_(weights["state_dict"][root + block_names[1]])
            self.attn.relative_position_bias_table.copy_(weights["state_dict"][root + block_names[2]])
            self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]])
            self.attn.qkv.weight.copy_(weights["state_dict"][root + block_names[4]])
            self.attn.qkv.bias.copy_(weights["state_dict"][root + block_names[5]])
            self.attn.proj.weight.copy_(weights["state_dict"][root + block_names[6]])
            self.attn.proj.bias.copy_(weights["state_dict"][root + block_names[7]])
            self.norm2.weight.copy_(weights["state_dict"][root + block_names[8]])
            self.norm2.bias.copy_(weights["state_dict"][root + block_names[9]])
            self.mlp.linear1.weight.copy_(weights["state_dict"][root + block_names[10]])
            self.mlp.linear1.bias.copy_(weights["state_dict"][root + block_names[11]])
            self.mlp.linear2.weight.copy_(weights["state_dict"][root + block_names[12]])
            self.mlp.linear2.bias.copy_(weights["state_dict"][root + block_names[13]])

    def forward(self, x, mask_matrix):
        shortcut = x
        if self.use_checkpoint:
            x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
        else:
            x = self.forward_part1(x, mask_matrix)
        x = shortcut + self.drop_path(x)
        if self.use_checkpoint:
            x = x + checkpoint.checkpoint(self.forward_part2, x)
        else:
            x = x + self.forward_part2(x)
        return x

class PatchMerging(nn.Module):
    """
    Patch merging layer based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(
        self, dim: int, norm_layer: Type[LayerNorm] = nn.LayerNorm, spatial_dims: int = 3
    ) -> None:  # type: ignore
        """
        Args:
            dim: number of feature channels.
            norm_layer: normalization layer.
            spatial_dims: number of spatial dims.
        """

        super().__init__()
        self.dim = dim
        if spatial_dims == 3:
            self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False)
            self.norm = norm_layer(8 * dim)
        elif spatial_dims == 2:
            self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
            self.norm = norm_layer(4 * dim)

    def forward(self, x):

        x_shape = x.size()
        if len(x_shape) == 5:
            b, d, h, w, c = x_shape
            pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1)
            if pad_input:
                x = F.pad(x, (0, 0, 0, d % 2, 0, w % 2, 0, h % 2))
            x0 = x[:, 0::2, 0::2, 0::2, :]
            x1 = x[:, 1::2, 0::2, 0::2, :]
            x2 = x[:, 0::2, 1::2, 0::2, :]
            x3 = x[:, 0::2, 0::2, 1::2, :]
            x4 = x[:, 1::2, 0::2, 1::2, :]
            x5 = x[:, 0::2, 1::2, 0::2, :]
            x6 = x[:, 0::2, 0::2, 1::2, :]
            x7 = x[:, 1::2, 1::2, 1::2, :]
            x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)

        elif len(x_shape) == 4:
            b, h, w, c = x_shape
            pad_input = (h % 2 == 1) or (w % 2 == 1)
            if pad_input:
                x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2))
            x0 = x[:, 0::2, 0::2, :]
            x1 = x[:, 1::2, 0::2, :]
            x2 = x[:, 0::2, 1::2, :]
            x3 = x[:, 1::2, 1::2, :]
            x = torch.cat([x0, x1, x2, x3], -1)

        x = self.norm(x)
        x = self.reduction(x)
        return x

def compute_mask(dims, window_size, shift_size, device):
    """Computing region masks based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    Args:
        dims: dimension values.
        window_size: local window size.
        shift_size: shift size.
        device: device.
    """

    cnt = 0

    if len(dims) == 3:
        d, h, w = dims
        img_mask = torch.zeros((1, d, h, w, 1), device=device)
        for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
            for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
                for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None):
                    img_mask[:, d, h, w, :] = cnt
                    cnt += 1

    elif len(dims) == 2:
        h, w = dims
        img_mask = torch.zeros((1, h, w, 1), device=device)
        for h in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
            for w in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
                img_mask[:, h, w, :] = cnt
                cnt += 1

    mask_windows = window_partition(img_mask, window_size)
    mask_windows = mask_windows.squeeze(-1)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

    return attn_mask

class BasicLayer(nn.Module):
    """
    Basic Swin Transformer layer in one stage based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        num_heads: int,
        window_size: Sequence[int],
        drop_path: list,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        norm_layer: Type[LayerNorm] = nn.LayerNorm,  # type: ignore
        downsample: isinstance = None,  # type: ignore
        use_checkpoint: bool = False,
    ) -> None:
        """
        Args:
            dim: number of feature channels.
            depths: number of layers in each stage.
            num_heads: number of attention heads.
            window_size: local window size.
            drop_path: stochastic depth rate.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: add a learnable bias to query, key, value.
            drop: dropout rate.
            attn_drop: attention dropout rate.
            norm_layer: normalization layer.
            downsample: downsample layer at the end of the layer.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
        """

        super().__init__()
        self.window_size = window_size
        self.shift_size = tuple(i // 2 for i in window_size)
        self.no_shift = tuple(0 for i in window_size)
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    num_heads=num_heads,
                    window_size=self.window_size,
                    shift_size=self.no_shift if (i % 2 == 0) else self.shift_size,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                    norm_layer=norm_layer,
                    use_checkpoint=use_checkpoint,
                )
                for i in range(depth)
            ]
        )
        self.downsample = downsample
        if self.downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size))

    def forward(self, x):
        x_shape = x.size()
        if len(x_shape) == 5:
            b, c, d, h, w = x_shape
            window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
            x = rearrange(x, "b c d h w -> b d h w c")
            dp = int(np.ceil(d / window_size[0])) * window_size[0]
            hp = int(np.ceil(h / window_size[1])) * window_size[1]
            wp = int(np.ceil(w / window_size[2])) * window_size[2]
            attn_mask = compute_mask([dp, hp, wp], window_size, shift_size, x.device)
            for blk in self.blocks:
                x = blk(x, attn_mask)
            x = x.view(b, d, h, w, -1)
            if self.downsample is not None:
                x = self.downsample(x)
            x = rearrange(x, "b d h w c -> b c d h w")

        elif len(x_shape) == 4:
            b, c, h, w = x_shape
            window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
            x = rearrange(x, "b c h w -> b h w c")
            hp = int(np.ceil(h / window_size[0])) * window_size[0]
            wp = int(np.ceil(w / window_size[1])) * window_size[1]
            attn_mask = compute_mask([hp, wp], window_size, shift_size, x.device)
            for blk in self.blocks:
                x = blk(x, attn_mask)
            x = x.view(b, h, w, -1)
            if self.downsample is not None:
                x = self.downsample(x)
            x = rearrange(x, "b h w c -> b c h w")
        return x

class SwinTransformer(nn.Module):
    """
    Swin Transformer based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(
        self,
        in_chans: int,
        embed_dim: int,
        window_size: Sequence[int],
        patch_size: Sequence[int],
        depths: Sequence[int],
        num_heads: Sequence[int],
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_layer: Type[LayerNorm] = nn.LayerNorm,  # type: ignore
        patch_norm: bool = False,
        use_checkpoint: bool = False,
        spatial_dims: int = 3,
    ) -> None:
        """
        Args:
            in_chans: dimension of input channels.
            embed_dim: number of linear projection output channels.
            window_size: local window size.
            patch_size: patch size.
            depths: number of layers in each stage.
            num_heads: number of attention heads.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: add a learnable bias to query, key, value.
            drop_rate: dropout rate.
            attn_drop_rate: attention dropout rate.
            drop_path_rate: stochastic depth rate.
            norm_layer: normalization layer.
            patch_norm: add normalization after patch embedding.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
            spatial_dims: spatial dimension.
        """

        super().__init__()
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.window_size = window_size
        self.patch_size = patch_size
        self.patch_embed = PatchEmbed(
            patch_size=self.patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,  # type: ignore
            spatial_dims=spatial_dims,
        )
        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        # self.layers1 = nn.ModuleList()
        # self.layers2 = nn.ModuleList()
        # self.layers3 = nn.ModuleList()
        # self.layers4 = nn.ModuleList()
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2**i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=self.window_size,
                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                norm_layer=norm_layer,
                downsample=PatchMerging,
                use_checkpoint=use_checkpoint,
            )
            self.layers.append(layer)
            # if i_layer == 0:
            #     self.layers1.append(layer)
            # elif i_layer == 1:
            #     self.layers2.append(layer)
            # elif i_layer == 2:
            #     self.layers3.append(layer)
            # elif i_layer == 3:
            #     self.layers4.append(layer)
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))

    def proj_out(self, x, normalize=False):
        if normalize:
            x_shape = x.size()
            if len(x_shape) == 5:
                n, ch, d, h, w = x_shape
                x = rearrange(x, "n c d h w -> n d h w c")
                x = F.layer_norm(x, [ch])
                x = rearrange(x, "n d h w c -> n c d h w")
            elif len(x_shape) == 4:
                n, ch, h, w = x_shape
                x = rearrange(x, "n c h w -> n h w c")
                x = F.layer_norm(x, [ch])
                x = rearrange(x, "n h w c -> n c h w")
        return x

    def forward(self, x, normalize=True):
        # x input: [B*sample, C(1), H, W, D]
        # x = rearrange(x, "b c h w d -> b c d h w")
        # print('>> input: ', x.shape)
        x = self.patch_embed(x)
        # print('>> patch_embed: ', x.shape)
        x = self.pos_drop(x)
        for layer in self.layers:
            x = layer(x.contiguous())
            # print('>> layer: ', x.shape)
        return x
        # # x0_out = self.proj_out(x0, normalize)
        # x1 = self.layers1[0](x0.contiguous())
        # # x1_out = self.proj_out(x1, normalize)
        # x2 = self.layers2[0](x1.contiguous())
        # # x2_out = self.proj_out(x2, normalize)
        # x3 = self.layers3[0](x2.contiguous())
        # # x3_out = self.proj_out(x3, normalize)
        # x4 = self.layers4[0](x3.contiguous())
        # # x4_out = self.proj_out(x4, normalize)
        # # return [x0_out, x1_out, x2_out, x3_out, x4_out]
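Note (editor's sketch, not part of the diff): a minimal example of how this volumetric SwinTransformer image encoder could be exercised on a dummy CT-sized tensor. The import path and every hyper-parameter below are illustrative assumptions, not values read from this repository's configuration.

# Sketch only: assumed module path and illustrative hyper-parameters.
import torch
from segment_anything_volumetric.modeling.image_encoder_swin import SwinTransformer  # assumed path

encoder = SwinTransformer(
    in_chans=1,                   # single-channel CT volume
    embed_dim=48,                 # illustrative
    window_size=(7, 7, 7),        # illustrative
    patch_size=(2, 2, 2),         # illustrative
    depths=(2, 2, 2, 2),          # four stages, each ending in a PatchMerging downsample
    num_heads=(3, 6, 12, 24),
    spatial_dims=3,
)

volume = torch.randn(1, 1, 64, 64, 64)  # (B, C, D, H, W) dummy input
features = encoder(volume)              # single downsampled 3D feature map from the last stage
print(features.shape)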
model/segment_anything_volumetric/modeling/.ipynb_checkpoints/prompt_encoder-checkpoint.py
ADDED
@@ -0,0 +1,232 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from torch import nn

from typing import Any, Optional, Tuple, Type

from .common import LayerNorm2d
import os

class PromptEncoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int, int],
        input_image_size: Tuple[int, int, int],
        mask_in_chans: int,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Encodes prompts for input to SAM's mask decoder.

        Arguments:
          embed_dim (int): The prompts' embedding dimension
          image_embedding_size (tuple(int, int)): The spatial size of the
            image embedding, as (H, W).
          input_image_size (int): The padded size of the image as input
            to the image encoder, as (H, W).
          mask_in_chans (int): The number of hidden channels used for
            encoding input masks.
          activation (nn.Module): The activation to use when encoding
            input masks.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

        self.num_point_embeddings: int = 4  # pos/neg point + 2 box corners
        point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
        self.point_embeddings = nn.ModuleList(point_embeddings)
        self.not_a_point_embed = nn.Embedding(1, embed_dim)

        self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1], 4 * image_embedding_size[2])
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans // 4),
            activation(),
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans),
            activation(),
            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def get_dense_pe(self) -> torch.Tensor:
        """
        Returns the positional encoding used to encode point prompts,
        applied to a dense set of points the shape of the image encoding.

        Returns:
          torch.Tensor: Positional encoding with shape
            1x(embed_dim)x(embedding_h)x(embedding_w)
        """
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)

    def _embed_points(
        self,
        points: torch.Tensor,
        labels: torch.Tensor,
        pad: bool,
    ) -> torch.Tensor:
        """Embeds point prompts."""
        points = points + 0.5  # Shift to center of pixel
        if pad:
            padding_point = torch.zeros((points.shape[0], 1, 3), device=points.device)
            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            points = torch.cat([points, padding_point], dim=1)
            labels = torch.cat([labels, padding_label], dim=1)
        point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
        point_embedding[labels == -1] = 0.0
        point_embedding[labels == -1] += self.not_a_point_embed.weight
        point_embedding[labels == 0] += self.point_embeddings[0].weight
        point_embedding[labels == 1] += self.point_embeddings[1].weight
        return point_embedding

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """Embeds box prompts."""
        boxes = boxes + 0.5  # Shift to center of pixel
        coords = boxes.reshape(-1, 2, 3)
        corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
        return corner_embedding

    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """Embeds mask inputs."""
        mask_embedding = self.mask_downscaling(masks)
        return mask_embedding

    def _get_batch_size(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
        text_embedding: Optional[torch.Tensor],
    ) -> int:
        """
        Gets the batch size of the output given the batch size of the input prompts.
        """
        if points is not None:
            return points[0].shape[0]
        elif boxes is not None:
            return boxes.shape[0]
        elif masks is not None:
            return masks.shape[0]
        elif text_embedding is not None:
            return text_embedding.shape[0]
        else:
            return 1

    def _get_device(self) -> torch.device:
        return self.point_embeddings[0].weight.device

    def forward(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
        text_embedding: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense
        embeddings.

        Arguments:
          points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
            and labels to embed.
          boxes (torch.Tensor or none): boxes to embed
          masks (torch.Tensor or none): masks to embed
          text_embedding (torch.Tensor or none): text prompt embedding (B, 768)

        Returns:
          torch.Tensor: sparse embeddings for the points and boxes, with shape
            BxNx(embed_dim), where N is determined by the number of input points
            and boxes.
          torch.Tensor: dense embeddings for the masks, in the shape
            Bx(embed_dim)x(embed_H)x(embed_W)
        """
        # print('prompt encoder here...')

        bs = self._get_batch_size(points, boxes, masks, text_embedding)
        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
        # print('sparse_embeddings ', sparse_embeddings.shape)
        if points is not None:
            coords, labels = points
            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)

        if boxes is not None:
            box_embeddings = self._embed_boxes(boxes)
            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)

        if text_embedding is not None:
            sparse_embeddings = torch.cat([sparse_embeddings, text_embedding.unsqueeze(dim=1)], dim=1)

        # print('box_embeddings ', box_embeddings.shape)
        # print('sparse_embeddings after box/point/text', sparse_embeddings.shape)

        if masks is not None:
            dense_embeddings = self._embed_masks(masks)
        else:
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1, 1).expand(
                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1], self.image_embedding_size[2]
            )
        # print('dense_embeddings ', dense_embeddings.shape)
        return sparse_embeddings, dense_embeddings


class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((3, num_pos_feats)),
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self, size: Tuple[int, int, int]) -> torch.Tensor:
        """Generate positional encoding for a grid of the specified size."""
        h, w, d = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w, d), device=device, dtype=torch.float32)
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        z_embed = grid.cumsum(dim=2) - 0.5
        y_embed = y_embed / h
        x_embed = x_embed / w
        z_embed = z_embed / d

        pe = self._pe_encoding(torch.stack([x_embed, y_embed, z_embed], dim=-1))
        return pe.permute(3, 0, 1, 2)  # C x H x W x D

    def forward_with_coords(
        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
    ) -> torch.Tensor:
        """Positionally encode points that are not normalized to [0,1]."""
        coords = coords_input.clone()
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
        coords[:, :, 2] = coords[:, :, 2] / image_size[2]
        return self._pe_encoding(coords.to(torch.float))  # B x N x C
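Note (editor's sketch, not part of the diff): a hedged example of driving this volumetric PromptEncoder with a 3D bounding-box prompt plus a text embedding. The import path and all sizes are assumptions chosen only so the shapes line up with the code above; because the text embedding is concatenated directly into the sparse tokens, embed_dim must equal the text-feature width here.

# Sketch only: assumed module path; sizes are illustrative, not taken from the repo's config.
import torch
from segment_anything_volumetric.modeling.prompt_encoder import PromptEncoder  # assumed path

prompt_encoder = PromptEncoder(
    embed_dim=768,                     # must match the text-embedding width used below
    image_embedding_size=(8, 8, 8),    # spatial size of the image embedding (illustrative)
    input_image_size=(256, 256, 256),  # padded input volume size (illustrative)
    mask_in_chans=16,
)

box = torch.tensor([[10.0, 20.0, 30.0, 120.0, 140.0, 90.0]])  # one 3D box: two (x, y, z) corners
text_embedding = torch.randn(1, 768)                          # e.g. a pooled CLIP text feature

sparse, dense = prompt_encoder(points=None, boxes=box, masks=None, text_embedding=text_embedding)
# sparse: (1, 3, 768) -> two box-corner tokens plus one text token
# dense:  (1, 768, 8, 8, 8) -> the learned "no mask" embedding broadcast over the feature grid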
model/segment_anything_volumetric/modeling/__init__.py
ADDED
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .sam import Sam
from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder
from .transformer import TwoWayTransformer
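Note (editor's sketch, not part of the diff): with this __init__.py in place, the volumetric SAM components can be imported from the package root, assuming the model/ directory is on the Python path.

# Sketch only: assumes model/ is on sys.path so the package resolves.
from segment_anything_volumetric.modeling import (
    Sam,
    ImageEncoderViT,
    MaskDecoder,
    PromptEncoder,
    TwoWayTransformer,
)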
model/segment_anything_volumetric/modeling/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (394 Bytes).

model/segment_anything_volumetric/modeling/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (424 Bytes).

model/segment_anything_volumetric/modeling/__pycache__/common.cpython-310.pyc
ADDED
Binary file (1.75 kB).

model/segment_anything_volumetric/modeling/__pycache__/common.cpython-39.pyc
ADDED
Binary file (1.77 kB).

model/segment_anything_volumetric/modeling/__pycache__/image_encoder.cpython-310.pyc
ADDED
Binary file (12.6 kB).

model/segment_anything_volumetric/modeling/__pycache__/image_encoder.cpython-39.pyc
ADDED
Binary file (11.4 kB).

model/segment_anything_volumetric/modeling/__pycache__/image_encoder_swin.cpython-39.pyc
ADDED
Binary file (21.5 kB).

model/segment_anything_volumetric/modeling/__pycache__/mask_decoder.cpython-310.pyc
ADDED
Binary file (5.5 kB).