Tournesol-Saturday committed
Commit c01dd14 · verified · 1 Parent(s): 0dc7c0f

Upload 2 files

Files changed (2):
  1. app.py +136 -0
  2. railnet_model.py +948 -0
app.py ADDED
@@ -0,0 +1,136 @@
import os
import h5py
import numpy as np
import gradio as gr
import plotly.graph_objects as go
from railnet_model import RailNetSystem

from huggingface_hub import hf_hub_download

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

model = RailNetSystem.from_pretrained("Tournesol-Saturday/railNet-tooth-segmentation-in-CBCT-image").cuda()

model.load_weights(from_hub=True, repo_id="Tournesol-Saturday/railNet-tooth-segmentation-in-CBCT-image")

def render_plotly_volume(pred, x_eye=1.25, y_eye=1.25, z_eye=1.25):
    # Downsample the mask so the interactive Plotly volume stays responsive.
    downsample_factor = 2
    pred_ds = pred[::downsample_factor, ::downsample_factor, ::downsample_factor]

    # go.Volume expects flat x/y/z coordinate vectors aligned with the
    # flattened value array, so build the full coordinate grid explicitly.
    fig = go.Figure(data=go.Volume(
        x=np.repeat(np.arange(pred_ds.shape[0]), pred_ds.shape[1] * pred_ds.shape[2]),
        y=np.tile(np.repeat(np.arange(pred_ds.shape[1]), pred_ds.shape[2]), pred_ds.shape[0]),
        z=np.tile(np.arange(pred_ds.shape[2]), pred_ds.shape[0] * pred_ds.shape[1]),
        value=pred_ds.flatten(),
        isomin=0.5,
        isomax=1.0,
        opacity=0.1,
        surface_count=1,
        colorscale=[[0, 'rgb(255, 0, 0)'], [1, 'rgb(255, 0, 0)']],
        showscale=False
    ))

    fig.update_layout(
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
            camera=dict(eye=dict(x=x_eye, y=y_eye, z=z_eye))
        ),
        margin=dict(l=0, r=0, b=0, t=0)
    )
    return fig

def handle_example(filename):
    repo_id = "Tournesol-Saturday/railNet-tooth-segmentation-in-CBCT-image"
    h5_path = hf_hub_download(repo_id=repo_id, filename=f"example_input_file/{filename}")

    with h5py.File(h5_path, "r") as f:
        image = f["image"][:]
        label = f["label"][:]

    name = filename.replace(".h5", "")
    pred, dice, jc, hd, asd = model(image, label, "./output", name)

    fig = render_plotly_volume(pred)

    img_path = f"./output/{name}_img.nii.gz"
    pred_path = f"./output/{name}_pred.nii.gz"

    metrics = f"Dice: {dice:.4f}, Jaccard: {jc:.4f}, 95HD: {hd:.2f}, ASD: {asd:.2f}"

    return metrics, pred, fig, img_path, pred_path

def clear_all():
    return "", None, None, None, None

with gr.Blocks() as demo:
    gr.HTML("<div style='text-align: center; font-size: 22px; font-weight: bold;'>🦷 Demo of RailNet: A CBCT Tooth Segmentation System</div>")
    gr.HTML("<div style='text-align: center; font-size: 15px'>✅ Steps: Select a CBCT example file (.h5) → Automatic inference and metrics display → View the 3D segmentation result (drag to rotate, scroll to zoom)</div>")

    gr.HTML("""
    <style>
    .code-style {
        font-family: monospace;
        background-color: #2f363d;
        color: #ffffff;
        padding: 2px 6px;
        border-radius: 4px;
        font-size: 90%;
    }
    </style>

    <div style='font-size: 15px; font-weight: bold;'>
        📂 Step 1: Select a <span class='code-style'>.h5</span> example file from the <span class='code-style'>example_input_file</span> folder in our
        <a href='https://huggingface.co/Tournesol-Saturday/railNet-tooth-segmentation-in-CBCT-image' target='_blank' style='text-decoration: none; color: #1f6feb; font-weight: bold;'>
            Hugging Face model
        </a> repository.
    </div>
    """)

    example_files = ["CBCT_01.h5", "CBCT_02.h5", "CBCT_03.h5", "CBCT_04.h5"]
    dropdown = gr.Dropdown(choices=example_files, label="Example File", value=example_files[0])

    with gr.Row():
        clear_btn = gr.Button("Clear", variant="secondary")
        submit_btn = gr.Button("Submit", variant="primary")

    gr.HTML("<div style='font-size: 15px; font-weight: bold;'>📊 Step 2: Metrics (Dice, Jaccard, 95HD, ASD)</div>")
    result_text = gr.Textbox()
    hidden_pred = gr.State(value=None)

    gr.HTML("<div style='font-size: 15px; font-weight: bold;'>👁️ Step 3: 3D Visualization</div>")
    plot_output = gr.Plot()

    # hidden_img_file = gr.File(visible=False)
    # hidden_pred_file = gr.File(visible=False)

    gr.HTML("<div style='font-size: 15px; font-weight: bold;'>⬇️ Step 4: Download the <span class='code-style'>NIfTI</span> files for accurate 1:1 visualization in <span class='code-style'>ITK-SNAP</span></div>")
    with gr.Row():
        hidden_img_file = gr.File(label="Download Original Image", interactive=False)
        hidden_pred_file = gr.File(label="Download Segmentation Result", interactive=False)

    submit_btn.click(
        fn=handle_example,
        inputs=[dropdown],
        outputs=[result_text, hidden_pred, plot_output, hidden_img_file, hidden_pred_file]
    )

    # def update_view(pred, x_eye, y_eye, z_eye):
    #     if pred is None:
    #         return gr.update()
    #     return render_plotly_volume(pred, x_eye, y_eye, z_eye)

    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[result_text, hidden_pred, plot_output, hidden_img_file, hidden_pred_file]
    )

    # download_img_btn.click(fn=lambda f: f, inputs=[hidden_img_file], outputs=[hidden_img_file])
    # download_pred_btn.click(fn=lambda f: f, inputs=[hidden_pred_file], outputs=[hidden_pred_file])

demo.launch()
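For reference, a minimal sketch of running the same pipeline without the Gradio UI (it assumes a CUDA-capable environment with railnet_model.py importable, and uses one of the example files listed above):

    import h5py
    from huggingface_hub import hf_hub_download
    from railnet_model import RailNetSystem

    repo_id = "Tournesol-Saturday/railNet-tooth-segmentation-in-CBCT-image"
    model = RailNetSystem.from_pretrained(repo_id).cuda()
    model.load_weights(from_hub=True, repo_id=repo_id)

    # Fetch one bundled example volume and segment it.
    h5_path = hf_hub_download(repo_id=repo_id, filename="example_input_file/CBCT_01.h5")
    with h5py.File(h5_path, "r") as f:
        image, label = f["image"][:], f["label"][:]

    pred, dice, jc, hd, asd = model(image, label, "./output", "CBCT_01")
    print(f"Dice: {dice:.4f}, Jaccard: {jc:.4f}, 95HD: {hd:.2f}, ASD: {asd:.2f}")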
railnet_model.py ADDED
@@ -0,0 +1,948 @@
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin

import numpy as np
import nibabel as nib
from skimage import morphology

import math
from scipy import ndimage
from medpy import metric

from huggingface_hub import hf_hub_download


class ConvBlock(nn.Module):
    def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
        super(ConvBlock, self).__init__()

        ops = []
        for i in range(n_stages):
            if i == 0:
                input_channel = n_filters_in
            else:
                input_channel = n_filters_out

            ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            elif normalization != 'none':
                assert False
            ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class DownsamplingConvBlock(nn.Module):
    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(DownsamplingConvBlock, self).__init__()

        ops = []
        if normalization != 'none':
            ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            else:
                assert False
        else:
            ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))

        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class UpsamplingDeconvBlock(nn.Module):
    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(UpsamplingDeconvBlock, self).__init__()

        ops = []
        if normalization != 'none':
            ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            else:
                assert False
        else:
            ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))

        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class Upsampling(nn.Module):
    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(Upsampling, self).__init__()

        ops = []
        ops.append(nn.Upsample(scale_factor=stride, mode='trilinear', align_corners=False))
        ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1))
        if normalization == 'batchnorm':
            ops.append(nn.BatchNorm3d(n_filters_out))
        elif normalization == 'groupnorm':
            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
        elif normalization == 'instancenorm':
            ops.append(nn.InstanceNorm3d(n_filters_out))
        elif normalization != 'none':
            assert False
        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class ConnectNet(nn.Module):
    def __init__(self, in_channels, out_channels, input_size):
        super(ConnectNet, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv3d(in_channels, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2, stride=2),
            nn.Conv3d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2, stride=2)
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(64, 128, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose3d(128, out_channels, kernel_size=2, stride=2),
            nn.Sigmoid()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded


class VNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False):
        super(VNet, self).__init__()
        self.has_dropout = has_dropout

        self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)

        self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)

        self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)

        self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)

        self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)

        self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)

        self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)

        self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)

        self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
        self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)

        self.dropout = nn.Dropout3d(p=0.5, inplace=False)

        self.__init_weight()

    def encoder(self, input):
        x1 = self.block_one(input)
        x1_dw = self.block_one_dw(x1)

        x2 = self.block_two(x1_dw)
        x2_dw = self.block_two_dw(x2)

        x3 = self.block_three(x2_dw)
        x3_dw = self.block_three_dw(x3)

        x4 = self.block_four(x3_dw)
        x4_dw = self.block_four_dw(x4)

        x5 = self.block_five(x4_dw)
        if self.has_dropout:
            x5 = self.dropout(x5)

        res = [x1, x2, x3, x4, x5]

        return res

    def decoder(self, features):
        x1 = features[0]
        x2 = features[1]
        x3 = features[2]
        x4 = features[3]
        x5 = features[4]

        x5_up = self.block_five_up(x5)
        x5_up = x5_up + x4

        x6 = self.block_six(x5_up)
        x6_up = self.block_six_up(x6)
        x6_up = x6_up + x3

        x7 = self.block_seven(x6_up)
        x7_up = self.block_seven_up(x7)
        x7_up = x7_up + x2

        x8 = self.block_eight(x7_up)
        x8_up = self.block_eight_up(x8)
        x8_up = x8_up + x1
        x9 = self.block_nine(x8_up)
        if self.has_dropout:
            x9 = self.dropout(x9)
        out = self.out_conv(x9)
        return out

    def forward(self, input, turnoff_drop=False):
        if turnoff_drop:
            has_dropout = self.has_dropout
            self.has_dropout = False
        features = self.encoder(input)
        out = self.decoder(features)
        if turnoff_drop:
            self.has_dropout = has_dropout
        return out

    def __init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


class VNet_roi(nn.Module):
    def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False):
        super(VNet_roi, self).__init__()
        self.has_dropout = has_dropout

        self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)

        self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)

        self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)

        self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)

        self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)

        self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)

        self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)

        self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)

        self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
        self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)

        self.dropout = nn.Dropout3d(p=0.5, inplace=False)
        # self.__init_weight()

    def encoder(self, input):
        x1 = self.block_one(input)
        x1_dw = self.block_one_dw(x1)

        x2 = self.block_two(x1_dw)
        x2_dw = self.block_two_dw(x2)

        x3 = self.block_three(x2_dw)
        x3_dw = self.block_three_dw(x3)

        x4 = self.block_four(x3_dw)
        x4_dw = self.block_four_dw(x4)

        x5 = self.block_five(x4_dw)
        # x5 = F.dropout3d(x5, p=0.5, training=True)
        if self.has_dropout:
            x5 = self.dropout(x5)

        res = [x1, x2, x3, x4, x5]

        return res

    def decoder(self, features):
        x1 = features[0]
        x2 = features[1]
        x3 = features[2]
        x4 = features[3]
        x5 = features[4]

        x5_up = self.block_five_up(x5)
        x5_up = x5_up + x4

        x6 = self.block_six(x5_up)
        x6_up = self.block_six_up(x6)
        x6_up = x6_up + x3

        x7 = self.block_seven(x6_up)
        x7_up = self.block_seven_up(x7)
        x7_up = x7_up + x2

        x8 = self.block_eight(x7_up)
        x8_up = self.block_eight_up(x8)
        x8_up = x8_up + x1
        x9 = self.block_nine(x8_up)
        # x9 = F.dropout3d(x9, p=0.5, training=True)
        if self.has_dropout:
            x9 = self.dropout(x9)
        out = self.out_conv(x9)
        return out

    def forward(self, input, turnoff_drop=False):
        if turnoff_drop:
            has_dropout = self.has_dropout
            self.has_dropout = False
        features = self.encoder(input)
        out = self.decoder(features)
        if turnoff_drop:
            self.has_dropout = has_dropout
        return out


class ResVNet(nn.Module):
    def __init__(self, n_channels=1, n_classes=2, n_filters=16, normalization='instancenorm', has_dropout=False):
        super(ResVNet, self).__init__()
        self.resencoder = resnet34()
        self.has_dropout = has_dropout

        self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)

        self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)

        self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)

        self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)

        self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)

        self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)

        self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)

        self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)

        self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
        self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)

        if has_dropout:
            self.dropout = nn.Dropout3d(p=0.5)
        self.branchs = nn.ModuleList()
        for i in range(1):
            if has_dropout:
                seq = nn.Sequential(
                    ConvBlock(1, n_filters, n_filters, normalization=normalization),
                    nn.Dropout3d(p=0.5),
                    nn.Conv3d(n_filters, n_classes, 1, padding=0)
                )
            else:
                seq = nn.Sequential(
                    ConvBlock(1, n_filters, n_filters, normalization=normalization),
                    nn.Conv3d(n_filters, n_classes, 1, padding=0)
                )
            self.branchs.append(seq)

    def encoder(self, input):
        x1 = self.block_one(input)
        x1_dw = self.block_one_dw(x1)

        x2 = self.block_two(x1_dw)
        x2_dw = self.block_two_dw(x2)

        x3 = self.block_three(x2_dw)
        x3_dw = self.block_three_dw(x3)

        x4 = self.block_four(x3_dw)
        x4_dw = self.block_four_dw(x4)

        x5 = self.block_five(x4_dw)

        if self.has_dropout:
            x5 = self.dropout(x5)

        res = [x1, x2, x3, x4, x5]

        return res

    def decoder(self, features):
        x1 = features[0]
        x2 = features[1]
        x3 = features[2]
        x4 = features[3]
        x5 = features[4]

        x5_up = self.block_five_up(x5)
        x5_up = x5_up + x4

        x6 = self.block_six(x5_up)
        x6_up = self.block_six_up(x6)
        x6_up = x6_up + x3

        x7 = self.block_seven(x6_up)
        x7_up = self.block_seven_up(x7)
        x7_up = x7_up + x2

        x8 = self.block_eight(x7_up)
        x8_up = self.block_eight_up(x8)
        x8_up = x8_up + x1

        x9 = self.block_nine(x8_up)

        out = self.out_conv(x9)

        return out

    def forward(self, input, turnoff_drop=False):
        if turnoff_drop:
            has_dropout = self.has_dropout
            self.has_dropout = False
        features = self.resencoder(input)
        out = self.decoder(features)
        if turnoff_drop:
            self.has_dropout = has_dropout
        return out


__all__ = ['ResNet', 'resnet34']


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


def conv3x3_bn_relu(in_planes, out_planes, stride=1):
    return nn.Sequential(
        conv3x3(in_planes, out_planes, stride),
        nn.InstanceNorm3d(out_planes),
        nn.ReLU()
    )


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=-1):
        super(BasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.InstanceNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.InstanceNorm3d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        super(Bottleneck, self).__init__()
        width = int(planes * (base_width / 64.)) * groups
        self.conv1 = nn.Conv3d(inplanes, width, kernel_size=1, bias=False)
        self.bn1 = nn.InstanceNorm3d(width)
        self.conv2 = nn.Conv3d(width, width, kernel_size=3, stride=stride, dilation=dilation,
                               padding=dilation, groups=groups, bias=False)
        self.bn2 = nn.InstanceNorm3d(width)
        self.conv3 = nn.Conv3d(width, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.InstanceNorm3d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, in_channel=1, width=1,
                 groups=1, width_per_group=64,
                 mid_dim=1024, low_dim=128,
                 avg_down=False, deep_stem=False,
                 head_type='mlp_head', layer4_dilation=1):
        super(ResNet, self).__init__()
        self.avg_down = avg_down
        self.inplanes = 16 * width
        self.base = int(16 * width)
        self.groups = groups
        self.base_width = width_per_group

        mid_dim = self.base * 8 * block.expansion

        if deep_stem:
            self.conv1 = nn.Sequential(
                conv3x3_bn_relu(in_channel, 32, stride=2),
                conv3x3_bn_relu(32, 32, stride=1),
                conv3x3(32, 64, stride=1)
            )
        else:
            self.conv1 = nn.Conv3d(in_channel, self.inplanes, kernel_size=7, stride=1, padding=3, bias=False)

        self.bn1 = nn.InstanceNorm3d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)

        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, self.base * 2, layers[0], stride=2)
        self.layer2 = self._make_layer(block, self.base * 4, layers[1], stride=2)
        self.layer3 = self._make_layer(block, self.base * 8, layers[2], stride=2)
        if layer4_dilation == 1:
            self.layer4 = self._make_layer(block, self.base * 16, layers[3], stride=2)
        elif layer4_dilation == 2:
            self.layer4 = self._make_layer(block, self.base * 16, layers[3], stride=1, dilation=2)
        else:
            raise NotImplementedError
        self.avgpool = nn.AvgPool3d(7, stride=1)

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            if self.avg_down:
                downsample = nn.Sequential(
                    nn.AvgPool3d(kernel_size=stride, stride=stride),
                    nn.Conv3d(self.inplanes, planes * block.expansion,
                              kernel_size=1, stride=1, bias=False),
                    nn.InstanceNorm3d(planes * block.expansion),
                )
            else:
                downsample = nn.Sequential(
                    nn.Conv3d(self.inplanes, planes * block.expansion,
                              kernel_size=1, stride=stride, bias=False),
                    nn.InstanceNorm3d(planes * block.expansion),
                )

        layers = [block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, dilation)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        # c2 = self.maxpool(x)
        c2 = self.layer1(x)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        return [x, c2, c3, c4, c5]


def resnet34(**kwargs):
    return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)


def label_rescale(image_label, w_ori, h_ori, z_ori, flag):
    w_ori, h_ori, z_ori = int(w_ori), int(h_ori), int(z_ori)
    # resize label map (int)
    if flag == 'trilinear':
        teeth_ids = np.unique(image_label)
        image_label_ori = np.zeros((w_ori, h_ori, z_ori))

        image_label = torch.from_numpy(image_label).cuda(0)

        # rescale each id as a separate binary mask so interpolation
        # never blends neighbouring label ids
        for label_id in range(len(teeth_ids)):
            image_label_bn = (image_label == teeth_ids[label_id]).float()
            image_label_bn = image_label_bn[None, None, :, :, :]
            image_label_bn = torch.nn.functional.interpolate(image_label_bn, size=(w_ori, h_ori, z_ori),
                                                             mode='trilinear', align_corners=False)
            image_label_bn = image_label_bn[0, 0, :, :, :]
            image_label_bn = image_label_bn.cpu().data.numpy()
            image_label_ori[image_label_bn > 0.5] = teeth_ids[label_id]
        image_label = image_label_ori

    if flag == 'nearest':
        image_label = torch.from_numpy(image_label).cuda(0)

        image_label = image_label[None, None, :, :, :].float()
        image_label = torch.nn.functional.interpolate(image_label, size=(w_ori, h_ori, z_ori), mode='nearest')
        image_label = image_label[0, 0, :, :, :].cpu().data.numpy()
    return image_label


def img_crop(image_bbox):
    if image_bbox.sum() > 0:
        # bounding box of the foreground, padded by a small margin
        x_min = np.nonzero(image_bbox)[0].min() - 8
        x_max = np.nonzero(image_bbox)[0].max() + 8

        y_min = np.nonzero(image_bbox)[1].min() - 16
        y_max = np.nonzero(image_bbox)[1].max() + 16

        z_min = np.nonzero(image_bbox)[2].min() - 16
        z_max = np.nonzero(image_bbox)[2].max() + 16

        if x_min < 0:
            x_min = 0
        if y_min < 0:
            y_min = 0
        if z_min < 0:
            z_min = 0
        if x_max > image_bbox.shape[0]:
            x_max = image_bbox.shape[0]
        if y_max > image_bbox.shape[1]:
            y_max = image_bbox.shape[1]
        if z_max > image_bbox.shape[2]:
            z_max = image_bbox.shape[2]

        # trim so every side length is a multiple of 16
        if (x_max - x_min) % 16 != 0:
            x_max -= (x_max - x_min) % 16
        if (y_max - y_min) % 16 != 0:
            y_max -= (y_max - y_min) % 16
        if (z_max - z_min) % 16 != 0:
            z_max -= (z_max - z_min) % 16

    if image_bbox.sum() == 0:
        # x_min == -1 flags an empty foreground mask to the caller
        x_min, x_max, y_min, y_max, z_min, z_max = -1, image_bbox.shape[0], 0, image_bbox.shape[1], 0, image_bbox.shape[2]
    return x_min, x_max, y_min, y_max, z_min, z_max


def roi_extraction(image, net_roi, ids):
    w, h, d = image.shape
    # roi binary segmentation parameters, the input spacing is 0.4 mm
    print('---run the roi binary segmentation.')

    stride_xy = 32
    stride_z = 16
    patch_size_roi_stage = (112, 112, 80)

    # coarse binary segmentation on a 2x-downsampled volume
    label_roi = roi_detection(net_roi, image[0:w:2, 0:h:2, 0:d:2], stride_xy, stride_z,
                              patch_size_roi_stage)  # (400,400,200)
    print(label_roi.shape, np.max(label_roi))
    label_roi = label_rescale(label_roi, w, h, d, 'trilinear')  # (800,800,400)

    label_roi = morphology.remove_small_objects(label_roi.astype(bool), 5000, connectivity=3).astype(float)

    label_roi = ndimage.grey_dilation(label_roi, size=(5, 5, 5))

    label_roi = morphology.remove_small_objects(label_roi.astype(bool), 400000, connectivity=3).astype(float)

    label_roi = ndimage.grey_erosion(label_roi, size=(5, 5, 5))

    # crop image
    x_min, x_max, y_min, y_max, z_min, z_max = img_crop(label_roi)
    if x_min == -1:  # non-foreground label: the caller builds an all-zero prediction
        return None
    image = image[x_min:x_max, y_min:y_max, z_min:z_max]
    print("image shape(after roi): ", image.shape)

    return image, x_min, x_max, y_min, y_max, z_min, z_max


def roi_detection(net, image, stride_xy, stride_z, patch_size):
    w, h, d = image.shape  # (400,400,200)

    # if the size of image is less than patch_size, then padding it
    add_pad = False
    if w < patch_size[0]:
        w_pad = patch_size[0] - w
        add_pad = True
    else:
        w_pad = 0
    if h < patch_size[1]:
        h_pad = patch_size[1] - h
        add_pad = True
    else:
        h_pad = 0
    if d < patch_size[2]:
        d_pad = patch_size[2] - d
        add_pad = True
    else:
        d_pad = 0
    wl_pad, wr_pad = w_pad // 2, w_pad - w_pad // 2
    hl_pad, hr_pad = h_pad // 2, h_pad - h_pad // 2
    dl_pad, dr_pad = d_pad // 2, d_pad - d_pad // 2
    if add_pad:
        image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), (dl_pad, dr_pad)], mode='constant',
                       constant_values=0)
    ww, hh, dd = image.shape

    sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1  # 2
    sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1  # 2
    sz = math.ceil((dd - patch_size[2]) / stride_z) + 1  # 2
    score_map = np.zeros((2,) + image.shape).astype(np.float32)
    cnt = np.zeros(image.shape).astype(np.float32)
    count = 0
    for x in range(0, sx):
        xs = min(stride_xy * x, ww - patch_size[0])
        for y in range(0, sy):
            ys = min(stride_xy * y, hh - patch_size[1])
            for z in range(0, sz):
                zs = min(stride_z * z, dd - patch_size[2])
                test_patch = image[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]]
                test_patch = np.expand_dims(np.expand_dims(test_patch, axis=0), axis=0).astype(np.float32)

                test_patch = torch.from_numpy(test_patch).cuda(0)

                with torch.no_grad():
                    y1 = net(test_patch)  # (1,2,256,256,160)
                    y = F.softmax(y1, dim=1)  # (1,2,256,256,160)
                y = y.cpu().data.numpy()
                y = y[0, :, :, :, :]  # (2,256,256,160)
                score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
                    = score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1],
                        zs:zs + patch_size[2]] + y  # (2,400,400,200)
                cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
                    = cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + 1  # (400,400,200)
                count = count + 1
    score_map = score_map / np.expand_dims(cnt, axis=0)

    label_map = np.argmax(score_map, axis=0)  # (400,400,200),0/1
    if add_pad:
        label_map = label_map[wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
        score_map = score_map[:, wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
    return label_map


def test_single_case_array(model_array, image=None, stride_xy=None, stride_z=None, patch_size=None, num_classes=1):
    w, h, d = image.shape

    # if the size of image is less than patch_size, then padding it
    add_pad = False
    if w < patch_size[0]:
        w_pad = patch_size[0] - w
        add_pad = True
    else:
        w_pad = 0
    if h < patch_size[1]:
        h_pad = patch_size[1] - h
        add_pad = True
    else:
        h_pad = 0
    if d < patch_size[2]:
        d_pad = patch_size[2] - d
        add_pad = True
    else:
        d_pad = 0
    wl_pad, wr_pad = w_pad // 2, w_pad - w_pad // 2
    hl_pad, hr_pad = h_pad // 2, h_pad - h_pad // 2
    dl_pad, dr_pad = d_pad // 2, d_pad - d_pad // 2
    if add_pad:
        image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0)

    ww, hh, dd = image.shape

    sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1
    sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1
    sz = math.ceil((dd - patch_size[2]) / stride_z) + 1
    score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32)
    cnt = np.zeros(image.shape).astype(np.float32)

    for x in range(0, sx):
        xs = min(stride_xy * x, ww - patch_size[0])
        for y in range(0, sy):
            ys = min(stride_xy * y, hh - patch_size[1])
            for z in range(0, sz):
                zs = min(stride_z * z, dd - patch_size[2])
                test_patch = image[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]]
                test_patch = np.expand_dims(np.expand_dims(test_patch, axis=0), axis=0).astype(np.float32)

                test_patch = torch.from_numpy(test_patch).cuda()

                # average the softmax scores of the ensemble for this patch;
                # the accumulator is freshly zeroed so it cannot shadow the
                # spatial loop variable `y`
                y_ens = np.zeros((num_classes,) + tuple(patch_size), dtype=np.float32)
                with torch.no_grad():
                    for model in model_array:
                        output = model(test_patch)
                        y_temp = F.softmax(output, dim=1)
                        y_temp = y_temp.cpu().data.numpy()
                        y_ens += y_temp[0, :, :, :, :]
                y_ens /= len(model_array)
                score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
                    = score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + y_ens
                cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
                    = cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + 1
    score_map = score_map / np.expand_dims(cnt, axis=0)

    label_map = np.argmax(score_map, axis=0)
    if add_pad:
        label_map = label_map[wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
        score_map = score_map[:, wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
    return label_map, score_map

def calculate_metric_percase(pred, gt):
    dice = metric.binary.dc(pred, gt)
    jc = metric.binary.jc(pred, gt)
    hd = metric.binary.hd95(pred, gt)
    asd = metric.binary.asd(pred, gt)

    return dice, jc, hd, asd

class RailNetSystem(nn.Module, PyTorchModelHubMixin):
    def __init__(self, n_channels: int, n_classes: int, normalization: str):
        super().__init__()

        self.num_classes = 2

        self.net_roi = VNet_roi(n_channels=n_channels, n_classes=n_classes, normalization=normalization, has_dropout=False).cuda()

        # four-model ensemble: two VNets and two ResVNets
        self.model_array = []
        for i in range(4):
            if i < 2:
                model = VNet(n_channels=n_channels, n_classes=n_classes, normalization=normalization, has_dropout=True).cuda()
            else:
                model = ResVNet(n_channels=n_channels, n_classes=n_classes, normalization=normalization, has_dropout=True).cuda()
            self.model_array.append(model)

    def load_weights(self, weight_dir=".", from_hub=False, repo_id=None):
        def load(file_name):
            if from_hub:
                return hf_hub_download(repo_id=repo_id, filename=f"model weights/{file_name}")
            else:
                return os.path.join(weight_dir, "model weights", file_name)

        self.net_roi.load_state_dict(torch.load(load("roi_best_model.pth"), map_location="cuda", weights_only=True))
        self.net_roi.eval()

        model_files = [
            "rail_0_iter_7995_best.pth",
            "rail_1_iter_7995_best.pth",
            "rail_2_iter_7995_best.pth",
            "rail_3_iter_7995_best.pth",
        ]
        for i, file in enumerate(model_files):
            self.model_array[i].load_state_dict(torch.load(load(file), map_location="cuda", weights_only=True))
            self.model_array[i].eval()

    def forward(self, image, label, save_path="./output", name="case"):
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        nib.save(nib.Nifti1Image(image.astype(np.float32), np.eye(4)), os.path.join(save_path, f"{name}_img.nii.gz"))

        w, h, d = image.shape

        roi = roi_extraction(image, self.net_roi, name)
        if roi is None:
            # the ROI stage found no foreground, so the prediction is empty
            new_prediction = np.zeros((w, h, d))
        else:
            image, x_min, x_max, y_min, y_max, z_min, z_max = roi

            prediction, _ = test_single_case_array(self.model_array, image, stride_xy=64, stride_z=32, patch_size=(112, 112, 80), num_classes=self.num_classes)

            prediction = morphology.remove_small_objects(prediction.astype(bool), 3000, connectivity=3).astype(float)

            # paste the cropped prediction back into the full volume
            new_prediction = np.zeros((w, h, d))
            new_prediction[x_min:x_max, y_min:y_max, z_min:z_max] = prediction

        dice, jc, hd, asd = calculate_metric_percase(new_prediction, label[:])

        nib.save(nib.Nifti1Image(new_prediction.astype(np.float32), np.eye(4)), os.path.join(save_path, f"{name}_pred.nii.gz"))

        return new_prediction, dice, jc, hd, asd
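
Both roi_detection and test_single_case_array tile the volume with the same sliding-window scheme: window starts are spaced every stride voxels, the final start is clamped to the volume border, and overlapping softmax scores are averaged through the cnt count map. A small illustrative sketch of the tiling arithmetic (patch_starts is a hypothetical helper, not part of the repository):

    import math

    def patch_starts(length, patch, stride):
        # number of windows needed to cover the axis; last start is clamped
        n = math.ceil((length - patch) / stride) + 1
        return [min(stride * i, length - patch) for i in range(n)]

    # One 400-voxel axis with 112-voxel patches and stride 32:
    print(patch_starts(400, 112, 32))
    # -> [0, 32, 64, 96, 128, 160, 192, 224, 256, 288]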