Tournesol-Saturday commited on
Commit
0dc7c0f
·
verified ·
1 Parent(s): 24c0f42

Delete railnet_model.py

Browse files
Files changed (1) hide show
  1. railnet_model.py +0 -975
railnet_model.py DELETED
@@ -1,975 +0,0 @@
1
- import os
2
- os.environ['KMP_DUPLICATE_LIB_OK']='True'
3
- os.environ["CUDA_VISIBLE_DEVICES"] = "0"
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- from huggingface_hub import PyTorchModelHubMixin
9
-
10
- import numpy as np
11
- import nibabel as nib
12
- from skimage import morphology
13
-
14
- import math
15
- from scipy import ndimage
16
- from medpy import metric
17
-
18
- from huggingface_hub import hf_hub_download
19
-
20
-
21
- class ConvBlock(nn.Module):
22
- def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
23
- super(ConvBlock, self).__init__()
24
-
25
- ops = []
26
- for i in range(n_stages):
27
- if i == 0:
28
- input_channel = n_filters_in
29
- else:
30
- input_channel = n_filters_out
31
-
32
- ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
33
- if normalization == 'batchnorm':
34
- ops.append(nn.BatchNorm3d(n_filters_out))
35
- elif normalization == 'groupnorm':
36
- ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
37
- elif normalization == 'instancenorm':
38
- ops.append(nn.InstanceNorm3d(n_filters_out))
39
- elif normalization != 'none':
40
- assert False
41
- ops.append(nn.ReLU(inplace=True))
42
-
43
- self.conv = nn.Sequential(*ops)
44
-
45
- def forward(self, x):
46
- x = self.conv(x)
47
- return x
48
-
49
-
50
- class DownsamplingConvBlock(nn.Module):
51
- def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
52
- super(DownsamplingConvBlock, self).__init__()
53
-
54
- ops = []
55
- if normalization != 'none':
56
- ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
57
- if normalization == 'batchnorm':
58
- ops.append(nn.BatchNorm3d(n_filters_out))
59
- elif normalization == 'groupnorm':
60
- ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
61
- elif normalization == 'instancenorm':
62
- ops.append(nn.InstanceNorm3d(n_filters_out))
63
- else:
64
- assert False
65
- else:
66
- ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
67
-
68
- ops.append(nn.ReLU(inplace=True))
69
-
70
- self.conv = nn.Sequential(*ops)
71
-
72
- def forward(self, x):
73
- x = self.conv(x)
74
- return x
75
-
76
-
77
- class UpsamplingDeconvBlock(nn.Module):
78
- def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
79
- super(UpsamplingDeconvBlock, self).__init__()
80
-
81
- ops = []
82
- if normalization != 'none':
83
- ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
84
- if normalization == 'batchnorm':
85
- ops.append(nn.BatchNorm3d(n_filters_out))
86
- elif normalization == 'groupnorm':
87
- ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
88
- elif normalization == 'instancenorm':
89
- ops.append(nn.InstanceNorm3d(n_filters_out))
90
- else:
91
- assert False
92
- else:
93
- ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
94
-
95
- ops.append(nn.ReLU(inplace=True))
96
-
97
- self.conv = nn.Sequential(*ops)
98
-
99
- def forward(self, x):
100
- x = self.conv(x)
101
- return x
102
-
103
-
104
- class Upsampling(nn.Module):
105
- def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
106
- super(Upsampling, self).__init__()
107
-
108
- ops = []
109
- ops.append(nn.Upsample(scale_factor=stride, mode='trilinear', align_corners=False))
110
- ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1))
111
- if normalization == 'batchnorm':
112
- ops.append(nn.BatchNorm3d(n_filters_out))
113
- elif normalization == 'groupnorm':
114
- ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
115
- elif normalization == 'instancenorm':
116
- ops.append(nn.InstanceNorm3d(n_filters_out))
117
- elif normalization != 'none':
118
- assert False
119
- ops.append(nn.ReLU(inplace=True))
120
-
121
- self.conv = nn.Sequential(*ops)
122
-
123
- def forward(self, x):
124
- x = self.conv(x)
125
- return x
126
-
127
-
128
- class ConnectNet(nn.Module):
129
- def __init__(self, in_channels, out_channels, input_size):
130
- super(ConnectNet, self).__init__()
131
- self.encoder = nn.Sequential(
132
- nn.Conv3d(in_channels, 128, kernel_size=3, stride=1, padding=1),
133
- nn.ReLU(),
134
- nn.MaxPool3d(kernel_size=2, stride=2),
135
- nn.Conv3d(128, 64, kernel_size=3, stride=1, padding=1),
136
- nn.ReLU(),
137
- nn.MaxPool3d(kernel_size=2, stride=2)
138
- )
139
-
140
- self.decoder = nn.Sequential(
141
- nn.ConvTranspose3d(64, 128, kernel_size=2, stride=2),
142
- nn.ReLU(),
143
- nn.ConvTranspose3d(128, out_channels, kernel_size=2, stride=2),
144
- nn.Sigmoid()
145
- )
146
-
147
- def forward(self, x):
148
- encoded = self.encoder(x)
149
- decoded = self.decoder(encoded)
150
- return decoded
151
-
152
-
153
- class VNet(nn.Module):
154
- def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False):
155
- super(VNet, self).__init__()
156
- self.has_dropout = has_dropout
157
-
158
- self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
159
- self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)
160
-
161
- self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
162
- self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)
163
-
164
- self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
165
- self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)
166
-
167
- self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
168
- self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)
169
-
170
- self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
171
- self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)
172
-
173
- self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
174
- self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)
175
-
176
- self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
177
- self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)
178
-
179
- self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
180
- self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)
181
-
182
- self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
183
- self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)
184
-
185
- self.dropout = nn.Dropout3d(p=0.5, inplace=False)
186
-
187
- self.__init_weight()
188
-
189
- def encoder(self, input):
190
- x1 = self.block_one(input)
191
- x1_dw = self.block_one_dw(x1)
192
-
193
- x2 = self.block_two(x1_dw)
194
- x2_dw = self.block_two_dw(x2)
195
-
196
- x3 = self.block_three(x2_dw)
197
- x3_dw = self.block_three_dw(x3)
198
-
199
- x4 = self.block_four(x3_dw)
200
- x4_dw = self.block_four_dw(x4)
201
-
202
- x5 = self.block_five(x4_dw)
203
- if self.has_dropout:
204
- x5 = self.dropout(x5)
205
-
206
- res = [x1, x2, x3, x4, x5]
207
-
208
- return res
209
-
210
- def decoder(self, features):
211
- x1 = features[0]
212
- x2 = features[1]
213
- x3 = features[2]
214
- x4 = features[3]
215
- x5 = features[4]
216
-
217
- x5_up = self.block_five_up(x5)
218
- x5_up = x5_up + x4
219
-
220
- x6 = self.block_six(x5_up)
221
- x6_up = self.block_six_up(x6)
222
- x6_up = x6_up + x3
223
-
224
- x7 = self.block_seven(x6_up)
225
- x7_up = self.block_seven_up(x7)
226
- x7_up = x7_up + x2
227
-
228
- x8 = self.block_eight(x7_up)
229
- x8_up = self.block_eight_up(x8)
230
- x8_up = x8_up + x1
231
- x9 = self.block_nine(x8_up)
232
- if self.has_dropout:
233
- x9 = self.dropout(x9)
234
- out = self.out_conv(x9)
235
- return out
236
-
237
- def forward(self, input, turnoff_drop=False):
238
- if turnoff_drop:
239
- has_dropout = self.has_dropout
240
- self.has_dropout = False
241
- features = self.encoder(input)
242
- out = self.decoder(features)
243
- if turnoff_drop:
244
- self.has_dropout = has_dropout
245
- return out
246
-
247
- def __init_weight(self):
248
- for m in self.modules():
249
- if isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d):
250
- torch.nn.init.kaiming_normal_(m.weight)
251
- elif isinstance(m, nn.BatchNorm3d):
252
- m.weight.data.fill_(1)
253
- m.bias.data.zero_()
254
-
255
-
256
- class VNet_roi(nn.Module):
257
- def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False):
258
- super(VNet_roi, self).__init__()
259
- self.has_dropout = has_dropout
260
-
261
- self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
262
- self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)
263
-
264
- self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
265
- self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)
266
-
267
- self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
268
- self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)
269
-
270
- self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
271
- self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)
272
-
273
- self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
274
- self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)
275
-
276
- self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
277
- self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)
278
-
279
- self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
280
- self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)
281
-
282
- self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
283
- self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)
284
-
285
- self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
286
- self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)
287
-
288
- self.dropout = nn.Dropout3d(p=0.5, inplace=False)
289
- # self.__init_weight()
290
-
291
- def encoder(self, input):
292
- x1 = self.block_one(input)
293
- x1_dw = self.block_one_dw(x1)
294
-
295
- x2 = self.block_two(x1_dw)
296
- x2_dw = self.block_two_dw(x2)
297
-
298
- x3 = self.block_three(x2_dw)
299
- x3_dw = self.block_three_dw(x3)
300
-
301
- x4 = self.block_four(x3_dw)
302
- x4_dw = self.block_four_dw(x4)
303
-
304
- x5 = self.block_five(x4_dw)
305
- # x5 = F.dropout3d(x5, p=0.5, training=True)
306
- if self.has_dropout:
307
- x5 = self.dropout(x5)
308
-
309
- res = [x1, x2, x3, x4, x5]
310
-
311
- return res
312
-
313
- def decoder(self, features):
314
- x1 = features[0]
315
- x2 = features[1]
316
- x3 = features[2]
317
- x4 = features[3]
318
- x5 = features[4]
319
-
320
- x5_up = self.block_five_up(x5)
321
- x5_up = x5_up + x4
322
-
323
- x6 = self.block_six(x5_up)
324
- x6_up = self.block_six_up(x6)
325
- x6_up = x6_up + x3
326
-
327
- x7 = self.block_seven(x6_up)
328
- x7_up = self.block_seven_up(x7)
329
- x7_up = x7_up + x2
330
-
331
- x8 = self.block_eight(x7_up)
332
- x8_up = self.block_eight_up(x8)
333
- x8_up = x8_up + x1
334
- x9 = self.block_nine(x8_up)
335
- # x9 = F.dropout3d(x9, p=0.5, training=True)
336
- if self.has_dropout:
337
- x9 = self.dropout(x9)
338
- out = self.out_conv(x9)
339
- return out
340
-
341
-
342
- def forward(self, input, turnoff_drop=False):
343
- if turnoff_drop:
344
- has_dropout = self.has_dropout
345
- self.has_dropout = False
346
- features = self.encoder(input)
347
- out = self.decoder(features)
348
- if turnoff_drop:
349
- self.has_dropout = has_dropout
350
- return out
351
-
352
-
353
- class ResVNet(nn.Module):
354
- def __init__(self, n_channels=1, n_classes=2, n_filters=16, normalization='instancenorm', has_dropout=False):
355
- super(ResVNet, self).__init__()
356
- self.resencoder = resnet34()
357
- self.has_dropout = has_dropout
358
-
359
- self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
360
- self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)
361
-
362
- self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
363
- self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)
364
-
365
- self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
366
- self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)
367
-
368
- self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
369
- self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)
370
-
371
- self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
372
- self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)
373
-
374
- self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
375
- self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)
376
-
377
- self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
378
- self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)
379
-
380
- self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
381
- self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)
382
-
383
-
384
- self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
385
- self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)
386
-
387
-
388
- if has_dropout:
389
- self.dropout = nn.Dropout3d(p=0.5)
390
- self.branchs = nn.ModuleList()
391
- for i in range(1):
392
- if has_dropout:
393
- seq = nn.Sequential(
394
- ConvBlock(1, n_filters, n_filters, normalization=normalization),
395
- nn.Dropout3d(p=0.5),
396
- nn.Conv3d(n_filters, n_classes, 1, padding=0)
397
- )
398
- else:
399
- seq = nn.Sequential(
400
- ConvBlock(1, n_filters, n_filters, normalization=normalization),
401
- nn.Conv3d(n_filters, n_classes, 1, padding=0)
402
- )
403
- self.branchs.append(seq)
404
-
405
- def encoder(self, input):
406
- x1 = self.block_one(input)
407
- x1_dw = self.block_one_dw(x1)
408
-
409
- x2 = self.block_two(x1_dw)
410
- x2_dw = self.block_two_dw(x2)
411
-
412
- x3 = self.block_three(x2_dw)
413
- x3_dw = self.block_three_dw(x3)
414
-
415
- x4 = self.block_four(x3_dw)
416
- x4_dw = self.block_four_dw(x4)
417
-
418
- x5 = self.block_five(x4_dw)
419
-
420
- if self.has_dropout:
421
- x5 = self.dropout(x5)
422
-
423
- res = [x1, x2, x3, x4, x5]
424
-
425
- return res
426
-
427
- def decoder(self, features):
428
- x1 = features[0]
429
- x2 = features[1]
430
- x3 = features[2]
431
- x4 = features[3]
432
- x5 = features[4]
433
-
434
- x5_up = self.block_five_up(x5)
435
- x5_up = x5_up + x4
436
-
437
- x6 = self.block_six(x5_up)
438
- x6_up = self.block_six_up(x6)
439
- x6_up = x6_up + x3
440
-
441
- x7 = self.block_seven(x6_up)
442
- x7_up = self.block_seven_up(x7)
443
- x7_up = x7_up + x2
444
-
445
- x8 = self.block_eight(x7_up)
446
- x8_up = self.block_eight_up(x8)
447
- x8_up = x8_up + x1
448
-
449
-
450
- x9 = self.block_nine(x8_up)
451
-
452
- out = self.out_conv(x9)
453
-
454
-
455
- return out
456
-
457
- def forward(self, input, turnoff_drop=False):
458
- if turnoff_drop:
459
- has_dropout = self.has_dropout
460
- self.has_dropout = False
461
- features = self.resencoder(input)
462
- out = self.decoder(features)
463
- if turnoff_drop:
464
- self.has_dropout = has_dropout
465
- return out
466
-
467
-
468
- __all__ = ['ResNet', 'resnet34']
469
-
470
-
471
- def conv3x3(in_planes, out_planes, stride=1):
472
- """3x3 convolution with padding"""
473
- return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
474
-
475
-
476
- def conv3x3_bn_relu(in_planes, out_planes, stride=1):
477
- return nn.Sequential(
478
- conv3x3(in_planes, out_planes, stride),
479
- nn.InstanceNorm3d(out_planes),
480
- nn.ReLU()
481
- )
482
-
483
-
484
- class BasicBlock(nn.Module):
485
- expansion = 1
486
-
487
- def __init__(self, inplanes, planes, stride=1, downsample=None,
488
- groups=1, base_width=64, dilation=-1):
489
- super(BasicBlock, self).__init__()
490
- if groups != 1 or base_width != 64:
491
- raise ValueError('BasicBlock only supports groups=1 and base_width=64')
492
- self.conv1 = conv3x3(inplanes, planes, stride)
493
- self.bn1 = nn.InstanceNorm3d(planes)
494
- self.relu = nn.ReLU(inplace=True)
495
- self.conv2 = conv3x3(planes, planes)
496
- self.bn2 = nn.InstanceNorm3d(planes)
497
- self.downsample = downsample
498
- self.stride = stride
499
-
500
- def forward(self, x):
501
- residual = x
502
-
503
- out = self.conv1(x)
504
- out = self.bn1(out)
505
- out = self.relu(out)
506
-
507
- out = self.conv2(out)
508
- out = self.bn2(out)
509
-
510
- if self.downsample is not None:
511
- residual = self.downsample(x)
512
-
513
- out += residual
514
- out = self.relu(out)
515
-
516
- return out
517
-
518
-
519
- class Bottleneck(nn.Module):
520
- expansion = 4
521
-
522
- def __init__(self, inplanes, planes, stride=1, downsample=None,
523
- groups=1, base_width=64, dilation=1):
524
- super(Bottleneck, self).__init__()
525
- width = int(planes * (base_width / 64.)) * groups
526
- self.conv1 = nn.Conv3d(inplanes, width, kernel_size=1, bias=False)
527
- self.bn1 = nn.InstanceNorm3d(width)
528
- self.conv2 = nn.Conv3d(width, width, kernel_size=3, stride=stride, dilation=dilation,
529
- padding=dilation, groups=groups, bias=False)
530
- self.bn2 = nn.InstanceNorm3d(width)
531
- self.conv3 = nn.Conv3d(width, planes * self.expansion, kernel_size=1, bias=False)
532
- self.bn3 = nn.InstanceNorm3d(planes * self.expansion)
533
- self.relu = nn.ReLU(inplace=True)
534
- self.downsample = downsample
535
- self.stride = stride
536
-
537
- def forward(self, x):
538
- residual = x
539
-
540
- out = self.conv1(x)
541
- out = self.bn1(out)
542
- out = self.relu(out)
543
-
544
- out = self.conv2(out)
545
- out = self.bn2(out)
546
- out = self.relu(out)
547
-
548
- out = self.conv3(out)
549
- out = self.bn3(out)
550
-
551
- if self.downsample is not None:
552
- residual = self.downsample(x)
553
-
554
- out += residual
555
- out = self.relu(out)
556
-
557
- return out
558
-
559
-
560
- class ResNet(nn.Module):
561
-
562
- def __init__(self, block, layers, in_channel=1, width=1,
563
- groups=1, width_per_group=64,
564
- mid_dim=1024, low_dim=128,
565
- avg_down=False, deep_stem=False,
566
- head_type='mlp_head', layer4_dilation=1):
567
- super(ResNet, self).__init__()
568
- self.avg_down = avg_down
569
- self.inplanes = 16 * width
570
- self.base = int(16 * width)
571
- self.groups = groups
572
- self.base_width = width_per_group
573
-
574
- mid_dim = self.base * 8 * block.expansion
575
-
576
- if deep_stem:
577
- self.conv1 = nn.Sequential(
578
- conv3x3_bn_relu(in_channel, 32, stride=2),
579
- conv3x3_bn_relu(32, 32, stride=1),
580
- conv3x3(32, 64, stride=1)
581
- )
582
- else:
583
- self.conv1 = nn.Conv3d(in_channel, self.inplanes, kernel_size=7, stride=1, padding=3, bias=False)
584
-
585
- self.bn1 = nn.InstanceNorm3d(self.inplanes)
586
- self.relu = nn.ReLU(inplace=True)
587
-
588
- self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
589
- self.layer1 = self._make_layer(block, self.base*2, layers[0],stride=2)
590
- self.layer2 = self._make_layer(block, self.base * 4, layers[1], stride=2)
591
- self.layer3 = self._make_layer(block, self.base * 8, layers[2], stride=2)
592
- if layer4_dilation == 1:
593
- self.layer4 = self._make_layer(block, self.base * 16, layers[3], stride=2)
594
- elif layer4_dilation == 2:
595
- self.layer4 = self._make_layer(block, self.base * 16, layers[3], stride=1, dilation=2)
596
- else:
597
- raise NotImplementedError
598
- self.avgpool = nn.AvgPool3d(7, stride=1)
599
-
600
- def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
601
- downsample = None
602
- if stride != 1 or self.inplanes != planes * block.expansion:
603
- if self.avg_down:
604
- downsample = nn.Sequential(
605
- nn.AvgPool3d(kernel_size=stride, stride=stride),
606
- nn.Conv3d(self.inplanes, planes * block.expansion,
607
- kernel_size=1, stride=1, bias=False),
608
- nn.InstanceNorm3d(planes * block.expansion),
609
- )
610
- else:
611
- downsample = nn.Sequential(
612
- nn.Conv3d(self.inplanes, planes * block.expansion,
613
- kernel_size=1, stride=stride, bias=False),
614
- nn.InstanceNorm3d(planes * block.expansion),
615
- )
616
-
617
- layers = [block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, dilation)]
618
- self.inplanes = planes * block.expansion
619
- for _ in range(1, blocks):
620
- layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=dilation))
621
-
622
- return nn.Sequential(*layers)
623
-
624
- def forward(self, x):
625
- x = self.conv1(x)
626
- x = self.bn1(x)
627
- x = self.relu(x)
628
- #c2 = self.maxpool(x)
629
- c2 = self.layer1(x)
630
- c3 = self.layer2(c2)
631
- c4 = self.layer3(c3)
632
- c5 = self.layer4(c4)
633
-
634
-
635
- return [x,c2,c3,c4,c5]
636
-
637
-
638
- def resnet34(**kwargs):
639
- return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
640
-
641
-
642
- def label_rescale(image_label, w_ori, h_ori, z_ori, flag):
643
- w_ori, h_ori, z_ori = int(w_ori), int(h_ori), int(z_ori)
644
- # resize label map (int)
645
- if flag == 'trilinear':
646
- teeth_ids = np.unique(image_label)
647
- image_label_ori = np.zeros((w_ori, h_ori, z_ori))
648
-
649
-
650
- image_label = torch.from_numpy(image_label).cuda(0)
651
-
652
-
653
- # image_label = torch.from_numpy(image_label).to("cpu")
654
- for label_id in range(len(teeth_ids)):
655
- image_label_bn = (image_label == teeth_ids[label_id]).float()
656
- image_label_bn = image_label_bn[None, None, :, :, :]
657
- image_label_bn = torch.nn.functional.interpolate(image_label_bn, size=(w_ori, h_ori, z_ori),
658
- mode='trilinear', align_corners=False)
659
- image_label_bn = image_label_bn[0, 0, :, :, :]
660
- image_label_bn = image_label_bn.cpu().data.numpy()
661
- image_label_ori[image_label_bn > 0.5] = teeth_ids[label_id]
662
- image_label = image_label_ori
663
-
664
- if flag == 'nearest':
665
-
666
-
667
- image_label = torch.from_numpy(image_label).cuda(0)
668
-
669
-
670
- # image_label = torch.from_numpy(image_label).to("cpu")
671
- image_label = image_label[None, None, :, :, :].float()
672
- image_label = torch.nn.functional.interpolate(image_label, size=(w_ori, h_ori, z_ori), mode='nearest')
673
- image_label = image_label[0, 0, :, :, :].cpu().data.numpy()
674
- return image_label
675
-
676
-
677
- def img_crop(image_bbox):
678
- if image_bbox.sum() > 0:
679
-
680
- x_min = np.nonzero(image_bbox)[0].min() - 8
681
- x_max = np.nonzero(image_bbox)[0].max() + 8
682
-
683
- y_min = np.nonzero(image_bbox)[1].min() - 16
684
- y_max = np.nonzero(image_bbox)[1].max() + 16
685
-
686
- z_min = np.nonzero(image_bbox)[2].min() - 16
687
- z_max = np.nonzero(image_bbox)[2].max() + 16
688
-
689
- if x_min < 0:
690
- x_min = 0
691
- if y_min < 0:
692
- y_min = 0
693
- if z_min < 0:
694
- z_min = 0
695
- if x_max > image_bbox.shape[0]:
696
- x_max = image_bbox.shape[0]
697
- if y_max > image_bbox.shape[1]:
698
- y_max = image_bbox.shape[1]
699
- if z_max > image_bbox.shape[2]:
700
- z_max = image_bbox.shape[2]
701
-
702
- if (x_max - x_min) % 16 != 0:
703
- x_max -= (x_max - x_min) % 16
704
- if (y_max - y_min) % 16 != 0:
705
- y_max -= (y_max - y_min) % 16
706
- if (z_max - z_min) % 16 != 0:
707
- z_max -= (z_max - z_min) % 16
708
-
709
- if image_bbox.sum() == 0:
710
- x_min, x_max, y_min, y_max, z_min, z_max = -1, image_bbox.shape[0], 0, image_bbox.shape[1], 0, image_bbox.shape[
711
- 2]
712
- return x_min, x_max, y_min, y_max, z_min, z_max
713
-
714
-
715
- def roi_extraction(image, net_roi, ids):
716
- w, h, d = image.shape
717
- # roi binary segmentation parameters, the input spacing is 0.4 mm
718
- print('---run the roi binary segmentation.')
719
-
720
- stride_xy = 32
721
- stride_z = 16
722
- patch_size_roi_stage = (112, 112, 80)
723
-
724
- label_roi = roi_detection(net_roi, image[0:w:2, 0:h:2, 0:d:2], stride_xy, stride_z,
725
- patch_size_roi_stage) # (400,400,200)
726
- print(label_roi.shape, np.max(label_roi))
727
- label_roi = label_rescale(label_roi, w, h, d, 'trilinear') # (800,800,400)
728
-
729
- label_roi = morphology.remove_small_objects(label_roi.astype(bool), 5000, connectivity=3).astype(float)
730
-
731
- label_roi = ndimage.grey_dilation(label_roi, size=(5, 5, 5))
732
-
733
- label_roi = morphology.remove_small_objects(label_roi.astype(bool), 400000, connectivity=3).astype(
734
- float)
735
-
736
- label_roi = ndimage.grey_erosion(label_roi, size=(5, 5, 5))
737
-
738
- # crop image
739
- x_min, x_max, y_min, y_max, z_min, z_max = img_crop(label_roi)
740
- if x_min == -1: # non-foreground label
741
- whole_label = np.zeros((w, h, d))
742
- return whole_label
743
- image = image[x_min:x_max, y_min:y_max, z_min:z_max]
744
- print("image shape(after roi): ", image.shape)
745
-
746
- return image, x_min, x_max, y_min, y_max, z_min, z_max
747
-
748
-
749
- def roi_detection(net, image, stride_xy, stride_z, patch_size):
750
- w, h, d = image.shape # (400,400,200)
751
-
752
- # if the size of image is less than patch_size, then padding it
753
- add_pad = False
754
- if w < patch_size[0]:
755
- w_pad = patch_size[0] - w
756
- add_pad = True
757
- else:
758
- w_pad = 0
759
- if h < patch_size[1]:
760
- h_pad = patch_size[1] - h
761
- add_pad = True
762
- else:
763
- h_pad = 0
764
- if d < patch_size[2]:
765
- d_pad = patch_size[2] - d
766
- add_pad = True
767
- else:
768
- d_pad = 0
769
- wl_pad, wr_pad = w_pad // 2, w_pad - w_pad // 2
770
- hl_pad, hr_pad = h_pad // 2, h_pad - h_pad // 2
771
- dl_pad, dr_pad = d_pad // 2, d_pad - d_pad // 2
772
- if add_pad:
773
- image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), (dl_pad, dr_pad)], mode='constant',
774
- constant_values=0)
775
- ww, hh, dd = image.shape
776
-
777
- sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 # 2
778
- sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 # 2
779
- sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 # 2
780
- score_map = np.zeros((2,) + image.shape).astype(np.float32)
781
- cnt = np.zeros(image.shape).astype(np.float32)
782
- count = 0
783
- for x in range(0, sx):
784
- xs = min(stride_xy * x, ww - patch_size[0])
785
- for y in range(0, sy):
786
- ys = min(stride_xy * y, hh - patch_size[1])
787
- for z in range(0, sz):
788
- zs = min(stride_z * z, dd - patch_size[2])
789
- test_patch = image[xs:xs + patch_size[0], ys:ys + patch_size[1],
790
- zs:zs + patch_size[2]]
791
- test_patch = np.expand_dims(np.expand_dims(test_patch, axis=0), axis=0).astype(
792
- np.float32)
793
-
794
-
795
- test_patch = torch.from_numpy(test_patch).cuda(0)
796
-
797
-
798
- # test_patch = torch.from_numpy(test_patch).to("cpu")
799
- with torch.no_grad():
800
- y1 = net(test_patch) # (1,2,256,256,160)
801
- y = F.softmax(y1, dim=1) # (1,2,256,256,160)
802
- y = y.cpu().data.numpy()
803
- y = y[0, :, :, :, :] # (2,256,256,160)
804
- score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
805
- = score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1],
806
- zs:zs + patch_size[2]] + y # (2,400,400,200)
807
- cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
808
- = cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + 1 # (400,400,200)
809
- count = count + 1
810
- score_map = score_map / np.expand_dims(cnt, axis=0)
811
-
812
- label_map = np.argmax(score_map, axis=0) # (400,400,200),0/1
813
- if add_pad:
814
- label_map = label_map[wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
815
- score_map = score_map[:, wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
816
- return label_map
817
-
818
-
819
- def test_single_case_array(model_array, image=None, stride_xy=None, stride_z=None, patch_size=None, num_classes=1):
820
- w, h, d = image.shape
821
-
822
- # if the size of image is less than patch_size, then padding it
823
- add_pad = False
824
- if w < patch_size[0]:
825
- w_pad = patch_size[0]-w
826
- add_pad = True
827
- else:
828
- w_pad = 0
829
- if h < patch_size[1]:
830
- h_pad = patch_size[1]-h
831
- add_pad = True
832
- else:
833
- h_pad = 0
834
- if d < patch_size[2]:
835
- d_pad = patch_size[2]-d
836
- add_pad = True
837
- else:
838
- d_pad = 0
839
- wl_pad, wr_pad = w_pad//2,w_pad-w_pad//2
840
- hl_pad, hr_pad = h_pad//2,h_pad-h_pad//2
841
- dl_pad, dr_pad = d_pad//2,d_pad-d_pad//2
842
- if add_pad:
843
- image = np.pad(image, [(wl_pad,wr_pad),(hl_pad,hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0)
844
-
845
- ww,hh,dd = image.shape
846
-
847
- sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1
848
- sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1
849
- sz = math.ceil((dd - patch_size[2]) / stride_z) + 1
850
- score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32)
851
- cnt = np.zeros(image.shape).astype(np.float32)
852
-
853
- for x in range(0, sx):
854
- xs = min(stride_xy*x, ww-patch_size[0])
855
- for y in range(0, sy):
856
- ys = min(stride_xy * y,hh-patch_size[1])
857
- for z in range(0, sz):
858
- zs = min(stride_z * z, dd-patch_size[2])
859
- test_patch = image[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]]
860
- test_patch = np.expand_dims(np.expand_dims(test_patch,axis=0),axis=0).astype(np.float32)
861
-
862
-
863
- test_patch = torch.from_numpy(test_patch).cuda()
864
-
865
-
866
- # test_patch = torch.from_numpy(test_patch).to("cpu")
867
- for model in model_array:
868
- output = model(test_patch)
869
- y_temp = F.softmax(output, dim=1)
870
- y_temp = y_temp.cpu().data.numpy()
871
- y += y_temp[0,:,:,:,:]
872
- y /= len(model_array)
873
- score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \
874
- = score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y
875
- cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \
876
- = cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1
877
- score_map = score_map/np.expand_dims(cnt,axis=0)
878
-
879
- label_map = np.argmax(score_map, axis = 0)
880
- if add_pad:
881
- label_map = label_map[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d]
882
- score_map = score_map[:,wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d]
883
- return label_map, score_map
884
-
885
- def calculate_metric_percase(pred, gt):
886
- dice = metric.binary.dc(pred, gt)
887
- jc = metric.binary.jc(pred, gt)
888
- hd = metric.binary.hd95(pred, gt)
889
- asd = metric.binary.asd(pred, gt)
890
-
891
- return dice, jc, hd, asd
892
-
893
-
894
- class RailNetSystem(nn.Module, PyTorchModelHubMixin):
895
- def __init__(self, n_channels: int, n_classes: int, normalization: str):
896
- super().__init__()
897
-
898
- self.num_classes = 2
899
-
900
-
901
- self.net_roi = VNet_roi(n_channels = n_channels, n_classes = n_classes, normalization = normalization, has_dropout=False).cuda()
902
-
903
-
904
- # self.net_roi = VNet_roi(n_channels = n_channels, n_classes = n_classes, normalization = normalization, has_dropout=False).to("cpu")
905
-
906
- self.model_array = []
907
- for i in range(4):
908
- if i < 2:
909
-
910
-
911
- model = VNet(n_channels = n_channels, n_classes = n_classes, normalization = normalization, has_dropout=True).cuda()
912
-
913
-
914
- # model = VNet(n_channels = n_channels, n_classes = n_classes, normalization = normalization, has_dropout=True).to("cpu")
915
- else:
916
-
917
-
918
- model = ResVNet(n_channels = n_channels, n_classes = n_classes, normalization = normalization, has_dropout=True).cuda()
919
-
920
-
921
- # model = ResVNet(n_channels = n_channels, n_classes = n_classes, normalization = normalization, has_dropout=True).to("cpu")
922
- self.model_array.append(model)
923
-
924
- def load_weights(self, weight_dir=".", from_hub=False, repo_id=None):
925
- def load(file_name):
926
- if from_hub:
927
- return hf_hub_download(repo_id=repo_id, filename=f"model weights/{file_name}")
928
- else:
929
- return os.path.join(weight_dir, "model weights", file_name)
930
-
931
-
932
- # self.net_roi.load_state_dict(torch.load(os.path.join(weight_dir, "model weights", "roi_best_model.pth"), map_location="cuda", weights_only=True))
933
-
934
-
935
- # self.net_roi.load_state_dict(torch.load(os.path.join(weight_dir, "model weights", "roi_best_model.pth"), map_location="cpu", weights_only=True))
936
- self.net_roi.load_state_dict(torch.load(load("roi_best_model.pth"), map_location="cuda", weights_only=True))
937
- self.net_roi.eval()
938
-
939
- model_files = [
940
- "rail_0_iter_7995_best.pth",
941
- "rail_1_iter_7995_best.pth",
942
- "rail_2_iter_7995_best.pth",
943
- "rail_3_iter_7995_best.pth",
944
- ]
945
- for i, file in enumerate(model_files):
946
-
947
-
948
- # self.model_array[i].load_state_dict(torch.load(os.path.join(weight_dir, "model weights", file), map_location="cuda", weights_only=True))
949
-
950
-
951
- # self.model_array[i].load_state_dict(torch.load(os.path.join(weight_dir, "model weights", file), map_location="cpu", weights_only=True))
952
- self.model_array[i].load_state_dict(torch.load(load(file), map_location="cuda", weights_only=True))
953
- self.model_array[i].eval()
954
-
955
- def forward(self, image, label, save_path="./output", name="case"):
956
- if not os.path.exists(save_path):
957
- os.makedirs(save_path)
958
- nib.save(nib.Nifti1Image(image.astype(np.float32), np.eye(4)), os.path.join(save_path, f"{name}_img.nii.gz"))
959
-
960
- w, h, d = image.shape
961
-
962
- image, x_min, x_max, y_min, y_max, z_min, z_max = roi_extraction(image, self.net_roi, name)
963
-
964
- prediction, _ = test_single_case_array(self.model_array, image, stride_xy=64, stride_z=32, patch_size=(112, 112, 80), num_classes=self.num_classes)
965
-
966
- prediction = morphology.remove_small_objects(prediction.astype(bool), 3000, connectivity=3).astype(float)
967
-
968
- new_prediction = np.zeros((w, h, d))
969
- new_prediction[x_min:x_max, y_min:y_max, z_min:z_max] = prediction
970
-
971
- dice, jc, hd, asd = calculate_metric_percase(new_prediction, label[:])
972
-
973
- nib.save(nib.Nifti1Image(new_prediction.astype(np.float32), np.eye(4)), os.path.join(save_path, f"{name}_pred.nii.gz"))
974
-
975
- return new_prediction, dice, jc, hd, asd