lofrienger committed on
Commit
a37c14e
·
1 Parent(s): 9da2171
Files changed (3) hide show
  1. app.py +53 -0
  2. unet.py +372 -0
  3. unet_model.pth +3 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import necessary libraries and load the model
2
+ import gradio as gr
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ from torchvision import transforms
7
+ from unet import UNet # Assuming UNet is the model class
8
+
9
# Normalisation statistics, one value per colour channel
# (presumably the training-set RGB mean / std -- confirm against the
# training code).
MEAN = np.array([0.4732661, 0.44874457, 0.3948762], dtype=np.float32)
STD = np.array([0.22674961, 0.22012031, 0.2238305], dtype=np.float32)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the network and restore its weights. ``map_location`` remaps the
# stored tensors onto ``device`` so that a checkpoint written from a CUDA
# process can still be loaded on a CPU-only host.
model = UNet(in_chns=3, class_num=2)
model.load_state_dict(torch.load('unet_model.pth', map_location=device))

model = model.to(device)
model.eval()  # switch dropout / batch-norm layers to inference behaviour
19
+
20
+ # Define the segmentation function
21
def segment(img):
    """Segment an image and return a binary foreground mask.

    Args:
        img: Image as a numpy array, as delivered by the Gradio
            ``"image"`` input component.

    Returns:
        A ``uint8`` numpy array with the same width and height as ``img``
        that holds 255 wherever the model predicts a non-zero class and 0
        elsewhere.
    """
    # ``convert('RGB')`` also copes with greyscale or RGBA input arrays.
    pil_img = Image.fromarray(np.asarray(img).astype('uint8')).convert('RGB')
    original_size = pil_img.size  # PIL reports (width, height)

    # The network is fed a fixed 224x224 input.
    resized = pil_img.resize((224, 224), Image.BILINEAR)
    tensor = transforms.ToTensor()(resized)  # channel-first float tensor

    # ToTensor produces (C, H, W) data, so the per-channel statistics must
    # broadcast along dim 0. Indexing the last axis (as ``t[:, :, i]``) would
    # only modify the first three pixel columns of every channel.
    mean = torch.from_numpy(MEAN).view(-1, 1, 1)
    std = torch.from_numpy(STD).view(-1, 1, 1)
    tensor = (tensor - mean) / std

    batch = tensor.unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    # argmax over the class dimension; softmax is monotonic, so applying it
    # first would not change the winning class.
    labels = torch.argmax(logits, dim=1)[0].cpu().numpy()

    # Resize the label map back to the original image size.
    mask = Image.fromarray(labels.astype('uint8')).resize(
        original_size, resample=Image.BILINEAR)

    # Convert back to numpy and paint every non-zero label white.
    mask = np.array(mask)
    return np.where(mask > 0, 255, 0).astype(mask.dtype)
46
+
47
# Create a Gradio interface
iface = gr.Interface(
    fn=segment,
    inputs="image",
    outputs="image",
    title="Segmentation Model",
    description="Segment objects in an image.",
    # Flagging stays disabled. Gradio expects one of the strings
    # "never" / "auto" / "manual" here; the boolean form is deprecated.
    allow_flagging="never",
)

# Launch the interface
iface.launch()
unet.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ The implementation is borrowed from: https://github.com/HiLab-git/PyMIC
4
+ """
5
+ from __future__ import division, print_function
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.distributions.uniform import Uniform
11
+
12
+
13
class ConvBlock(nn.Module):
    """Two 3x3 convolutions, each followed by batch norm and LeakyReLU.

    A dropout layer sits between the two convolution stages. Both
    convolutions use ``padding=1``.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count produced by both convolutions.
        dropout_p: Drop probability handed to ``nn.Dropout``.
    """

    def __init__(self, in_channels, out_channels, dropout_p):
        super(ConvBlock, self).__init__()
        first_stage = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
        ]
        second_stage = [
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
        ]
        # The position of each layer inside the Sequential determines its
        # state-dict key (conv_conv.0, conv_conv.1, ...), so order matters.
        self.conv_conv = nn.Sequential(
            *(first_stage + [nn.Dropout(dropout_p)] + second_stage))

    def forward(self, x):
        """Run ``x`` through the stacked layers and return the result."""
        return self.conv_conv(x)
30
+
31
+
32
class DownBlock(nn.Module):
    """2x2 max pooling followed by a ConvBlock.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count produced by the ConvBlock.
        dropout_p: Drop probability forwarded to the ConvBlock.
    """

    def __init__(self, in_channels, out_channels, dropout_p):
        super(DownBlock, self).__init__()
        pool = nn.MaxPool2d(2)
        conv = ConvBlock(in_channels, out_channels, dropout_p)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Pool ``x`` and then convolve it."""
        return self.maxpool_conv(x)
45
+
46
+
47
class UpBlock(nn.Module):
    """Upsampling followed by ConvBlock.

    The low-resolution input is upsampled by a factor of two, concatenated
    with the skip connection along the channel dimension and passed through
    a ConvBlock.

    Args:
        in_channels1: Channel count of the low-resolution input ``x1``.
        in_channels2: Channel count of the skip connection ``x2``; ``x1`` is
            projected to this many channels before concatenation.
        out_channels: Channel count produced by the final ConvBlock.
        dropout_p: Drop probability forwarded to the ConvBlock.
        bilinear: Upsampling method. ``'convtrans'`` selects a learned
            ``nn.ConvTranspose2d``; any other string is used as the
            ``nn.Upsample`` mode after a 1x1 convolution. Booleans are
            accepted for compatibility with the flag-style convention:
            ``True`` means ``'bilinear'`` and ``False`` means ``'convtrans'``.
    """

    # nn.Upsample only accepts ``align_corners`` for these modes.
    _ALIGNABLE_MODES = ('linear', 'bilinear', 'bicubic', 'trilinear')

    def __init__(self, in_channels1, in_channels2, out_channels, dropout_p,
                 bilinear=True):
        super(UpBlock, self).__init__()
        self.bilinear = bilinear
        # A boolean is not a valid nn.Upsample mode and would only fail once
        # forward() runs, so translate it to the equivalent mode name here.
        if bilinear is True:
            self.up_mode = 'bilinear'
        elif bilinear is False:
            self.up_mode = 'convtrans'
        else:
            self.up_mode = bilinear

        if self.up_mode != 'convtrans':
            self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1)
            if self.up_mode in self._ALIGNABLE_MODES:
                self.up = nn.Upsample(
                    scale_factor=2, mode=self.up_mode, align_corners=True)
            else:
                self.up = nn.Upsample(scale_factor=2, mode=self.up_mode)
        else:
            self.up = nn.ConvTranspose2d(
                in_channels1, in_channels2, kernel_size=2, stride=2)
        self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p)

    def forward(self, x1, x2):
        """Upsample ``x1``, concatenate it after ``x2`` and convolve."""
        if self.up_mode != 'convtrans':
            # Interpolation keeps the channel count, so reduce it first.
            x1 = self.conv1x1(x1)
        x1 = self.up(x1)
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
70
+
71
+
72
class Encoder(nn.Module):
    """Contracting path: one input ConvBlock followed by four DownBlocks.

    Args:
        params: Dict providing 'in_chns', 'feature_chns' (exactly five
            entries), 'class_num', 'bilinear' and 'dropout' (indexed 0..4).
    """

    def __init__(self, params):
        super(Encoder, self).__init__()
        self.params = params
        self.in_chns = params['in_chns']
        self.ft_chns = params['feature_chns']
        self.n_class = params['class_num']
        self.bilinear = params['bilinear']
        self.dropout = params['dropout']
        assert (len(self.ft_chns) == 5)

        chns, drop = self.ft_chns, self.dropout
        self.in_conv = ConvBlock(self.in_chns, chns[0], drop[0])
        self.down1 = DownBlock(chns[0], chns[1], drop[1])
        self.down2 = DownBlock(chns[1], chns[2], drop[2])
        self.down3 = DownBlock(chns[2], chns[3], drop[3])
        self.down4 = DownBlock(chns[3], chns[4], drop[4])

    def forward(self, x):
        """Return the feature maps of all five stages, first stage first."""
        features = [self.in_conv(x)]
        for stage in (self.down1, self.down2, self.down3, self.down4):
            # Each stage consumes the output of the previous one.
            features.append(stage(features[-1]))
        return features
100
+
101
+
102
class Decoder(nn.Module):
    """Expanding path: four UpBlocks followed by a 3x3 output convolution.

    Args:
        params: Dict providing 'in_chns', 'feature_chns' (exactly five
            entries), 'class_num' and 'bilinear' (forwarded to every
            UpBlock).
    """

    def __init__(self, params):
        super(Decoder, self).__init__()
        self.params = params
        self.in_chns = params['in_chns']
        self.ft_chns = params['feature_chns']
        self.n_class = params['class_num']
        self.bilinear = params['bilinear']
        assert (len(self.ft_chns) == 5)

        chns, mode = self.ft_chns, self.bilinear
        self.up1 = UpBlock(chns[4], chns[3], chns[3],
                           dropout_p=0.0, bilinear=mode)
        self.up2 = UpBlock(chns[3], chns[2], chns[2],
                           dropout_p=0.0, bilinear=mode)
        self.up3 = UpBlock(chns[2], chns[1], chns[1],
                           dropout_p=0.0, bilinear=mode)
        self.up4 = UpBlock(chns[1], chns[0], chns[0],
                           dropout_p=0.0, bilinear=mode)

        self.out_conv = nn.Conv2d(chns[0], self.n_class,
                                  kernel_size=3, padding=1)

    def forward(self, feature):
        """Decode the five encoder feature maps into the output map.

        ``feature`` is indexed 0..4; entry 4 is the starting point and
        entries 3..0 are used as skip connections in that order.
        """
        skips = [feature[0], feature[1], feature[2], feature[3]]
        x = feature[4]
        for up in (self.up1, self.up2, self.up3, self.up4):
            x = up(x, skips.pop())
        return self.out_conv(x)
137
+
138
+
139
class Decoder_DS(nn.Module):
    """Decoder with deep supervision: one extra output head per scale.

    Args:
        params: Dict providing 'in_chns', 'feature_chns' (exactly five
            entries), 'class_num' and 'bilinear' (forwarded to every
            UpBlock, as in ``Decoder``).
    """

    def __init__(self, params):
        super(Decoder_DS, self).__init__()
        self.params = params
        self.in_chns = self.params['in_chns']
        self.ft_chns = self.params['feature_chns']
        self.n_class = self.params['class_num']
        self.bilinear = self.params['bilinear']
        assert (len(self.ft_chns) == 5)

        # Forward the configured upsampling mode instead of silently falling
        # back to UpBlock's default, so this decoder honours 'bilinear' the
        # same way Decoder does.
        self.up1 = UpBlock(
            self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0,
            bilinear=self.bilinear)
        self.up2 = UpBlock(
            self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0,
            bilinear=self.bilinear)
        self.up3 = UpBlock(
            self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0,
            bilinear=self.bilinear)
        self.up4 = UpBlock(
            self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0,
            bilinear=self.bilinear)

        self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class,
                                  kernel_size=3, padding=1)
        # out_conv_dp4 is not used by forward(); it is kept so the set of
        # parameters (and therefore the state-dict layout) stays unchanged.
        self.out_conv_dp4 = nn.Conv2d(self.ft_chns[4], self.n_class,
                                      kernel_size=3, padding=1)
        self.out_conv_dp3 = nn.Conv2d(self.ft_chns[3], self.n_class,
                                      kernel_size=3, padding=1)
        self.out_conv_dp2 = nn.Conv2d(self.ft_chns[2], self.n_class,
                                      kernel_size=3, padding=1)
        self.out_conv_dp1 = nn.Conv2d(self.ft_chns[1], self.n_class,
                                      kernel_size=3, padding=1)

    def forward(self, feature, shape):
        """Decode ``feature`` and return four segmentation outputs.

        Args:
            feature: The five encoder feature maps, indexed 0..4.
            shape: Target spatial size the auxiliary outputs are
                interpolated to.

        Returns:
            Tuple ``(dp0, dp1, dp2, dp3)``: the output of the last stage
            followed by the resized outputs of the three earlier stages.
        """
        x0 = feature[0]
        x1 = feature[1]
        x2 = feature[2]
        x3 = feature[3]
        x4 = feature[4]

        x = self.up1(x4, x3)
        dp3_out_seg = self.out_conv_dp3(x)
        dp3_out_seg = torch.nn.functional.interpolate(dp3_out_seg, shape)

        x = self.up2(x, x2)
        dp2_out_seg = self.out_conv_dp2(x)
        dp2_out_seg = torch.nn.functional.interpolate(dp2_out_seg, shape)

        x = self.up3(x, x1)
        dp1_out_seg = self.out_conv_dp1(x)
        dp1_out_seg = torch.nn.functional.interpolate(dp1_out_seg, shape)

        x = self.up4(x, x0)
        dp0_out_seg = self.out_conv(x)
        return dp0_out_seg, dp1_out_seg, dp2_out_seg, dp3_out_seg
190
+
191
+
192
class Decoder_URDS(nn.Module):
    """Deep-supervision decoder whose auxiliary heads are perturbed in training.

    While the module is in training mode, the inputs of the three auxiliary
    heads are perturbed with channel dropout, FeatureDropout and
    FeatureNoise respectively; in eval mode the heads see the clean features.

    Args:
        params: Dict providing 'in_chns', 'feature_chns' (exactly five
            entries), 'class_num' and 'bilinear' (forwarded to every
            UpBlock, as in ``Decoder``).
    """

    def __init__(self, params):
        super(Decoder_URDS, self).__init__()
        self.params = params
        self.in_chns = self.params['in_chns']
        self.ft_chns = self.params['feature_chns']
        self.n_class = self.params['class_num']
        self.bilinear = self.params['bilinear']
        assert (len(self.ft_chns) == 5)

        # Forward the configured upsampling mode instead of silently falling
        # back to UpBlock's default, so this decoder honours 'bilinear' the
        # same way Decoder does.
        self.up1 = UpBlock(
            self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0,
            bilinear=self.bilinear)
        self.up2 = UpBlock(
            self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0,
            bilinear=self.bilinear)
        self.up3 = UpBlock(
            self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0,
            bilinear=self.bilinear)
        self.up4 = UpBlock(
            self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0,
            bilinear=self.bilinear)

        self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class,
                                  kernel_size=3, padding=1)
        # out_conv_dp4 is not used by forward(); it is kept so the set of
        # parameters (and therefore the state-dict layout) stays unchanged.
        self.out_conv_dp4 = nn.Conv2d(self.ft_chns[4], self.n_class,
                                      kernel_size=3, padding=1)
        self.out_conv_dp3 = nn.Conv2d(self.ft_chns[3], self.n_class,
                                      kernel_size=3, padding=1)
        self.out_conv_dp2 = nn.Conv2d(self.ft_chns[2], self.n_class,
                                      kernel_size=3, padding=1)
        self.out_conv_dp1 = nn.Conv2d(self.ft_chns[1], self.n_class,
                                      kernel_size=3, padding=1)
        self.feature_noise = FeatureNoise()

    def forward(self, feature, shape):
        """Decode ``feature`` and return four segmentation outputs.

        Args:
            feature: The five encoder feature maps, indexed 0..4.
            shape: Target spatial size the auxiliary outputs are
                interpolated to.

        Returns:
            Tuple ``(dp0, dp1, dp2, dp3)``: the output of the last stage
            followed by the resized outputs of the three earlier stages.
        """
        x0 = feature[0]
        x1 = feature[1]
        x2 = feature[2]
        x3 = feature[3]
        x4 = feature[4]

        x = self.up1(x4, x3)
        if self.training:
            dp3_out_seg = self.out_conv_dp3(Dropout(x, p=0.5))
        else:
            dp3_out_seg = self.out_conv_dp3(x)
        dp3_out_seg = torch.nn.functional.interpolate(dp3_out_seg, shape)

        x = self.up2(x, x2)
        if self.training:
            dp2_out_seg = self.out_conv_dp2(FeatureDropout(x))
        else:
            dp2_out_seg = self.out_conv_dp2(x)
        dp2_out_seg = torch.nn.functional.interpolate(dp2_out_seg, shape)

        x = self.up3(x, x1)
        if self.training:
            dp1_out_seg = self.out_conv_dp1(self.feature_noise(x))
        else:
            dp1_out_seg = self.out_conv_dp1(x)
        dp1_out_seg = torch.nn.functional.interpolate(dp1_out_seg, shape)

        # The main head always sees the unperturbed features.
        x = self.up4(x, x0)
        dp0_out_seg = self.out_conv(x)
        return dp0_out_seg, dp1_out_seg, dp2_out_seg, dp3_out_seg
253
+
254
+
255
def Dropout(x, p=0.5):
    """Return ``x`` after channel-wise dropout with drop probability ``p``.

    ``dropout2d`` is called with its default ``training=True``, so the
    dropout is applied on every call, independent of any module's mode.
    """
    return torch.nn.functional.dropout2d(x, p)
258
+
259
+
260
def FeatureDropout(x):
    """Zero the spatial positions with the highest channel-mean activation.

    For every sample the channel-mean map is compared against a threshold
    drawn as a random fraction (between 0.7 and 0.9, one draw per call) of
    that sample's maximum; positions whose mean is not below the threshold
    are set to zero in all channels.
    """
    batch = x.size(0)
    attention = x.mean(dim=1, keepdim=True)
    peak = attention.view(batch, -1).max(dim=1, keepdim=True)[0]
    cutoff = peak * np.random.uniform(0.7, 0.9)
    cutoff = cutoff.view(batch, 1, 1, 1).expand_as(attention)
    keep = (attention < cutoff).float()
    return x.mul(keep)
269
+
270
+
271
class FeatureNoise(nn.Module):
    """Perturb features with multiplicative uniform noise.

    Args:
        uniform_range: The noise is drawn from
            ``Uniform(-uniform_range, uniform_range)``.
    """

    def __init__(self, uniform_range=0.3):
        super(FeatureNoise, self).__init__()
        low, high = -uniform_range, uniform_range
        self.uni_dist = Uniform(low, high)

    def feature_based_noise(self, x):
        """Return ``x * noise + x``.

        One noise tensor of shape ``x.shape[1:]`` is sampled and broadcast
        over the leading dimension, so every sample gets the same noise.
        """
        per_sample_shape = x.shape[1:]
        noise = self.uni_dist.sample(per_sample_shape).to(x.device)
        return x.mul(noise.unsqueeze(0)) + x

    def forward(self, x):
        """Apply ``feature_based_noise`` to ``x``."""
        return self.feature_based_noise(x)
285
+
286
+
287
class UNet(nn.Module):
    """U-Net assembled from one Encoder and one Decoder.

    Args:
        in_chns: Number of input channels.
        class_num: Number of output channels of the final convolution.
    """

    def __init__(self, in_chns, class_num):
        super(UNet, self).__init__()

        params = dict(
            in_chns=in_chns,
            feature_chns=[16, 32, 64, 128, 256],
            dropout=[0.05, 0.1, 0.2, 0.3, 0.5],
            class_num=class_num,
            bilinear='nearest',
            acti_func='relu',
        )

        self.encoder = Encoder(params)
        self.decoder = Decoder(params)

    def forward(self, x):
        """Encode ``x`` and decode the resulting feature maps."""
        return self.decoder(self.encoder(x))
305
+
306
+
307
class UNet_DS(nn.Module):
    """U-Net variant built from an Encoder and a Decoder_DS.

    Args:
        in_chns: Number of input channels.
        class_num: Number of output channels of every output head.
    """

    def __init__(self, in_chns, class_num):
        super(UNet_DS, self).__init__()

        params = dict(
            in_chns=in_chns,
            feature_chns=[16, 32, 64, 128, 256],
            dropout=[0.05, 0.1, 0.2, 0.3, 0.5],
            class_num=class_num,
            bilinear=False,
            acti_func='relu',
        )
        self.encoder = Encoder(params)
        self.decoder = Decoder_DS(params)

    def forward(self, x):
        """Return the four outputs produced by the decoder for ``x``."""
        # The decoder resizes its auxiliary outputs to the input's
        # spatial size (all dimensions after batch and channel).
        spatial_size = x.shape[2:]
        outputs = self.decoder(self.encoder(x), spatial_size)
        dp0_out_seg, dp1_out_seg, dp2_out_seg, dp3_out_seg = outputs
        return dp0_out_seg, dp1_out_seg, dp2_out_seg, dp3_out_seg
326
+
327
+
328
class UNet_CCT(nn.Module):
    """U-Net with a main decoder and one auxiliary decoder.

    The auxiliary decoder receives the encoder features after ``Dropout``
    has been applied to each of them.

    Args:
        in_chns: Number of input channels.
        class_num: Number of output channels of each decoder.
    """

    def __init__(self, in_chns, class_num):
        super(UNet_CCT, self).__init__()

        params = dict(
            in_chns=in_chns,
            feature_chns=[16, 32, 64, 128, 256],
            dropout=[0.05, 0.1, 0.2, 0.3, 0.5],
            class_num=class_num,
            bilinear='nearest',
            acti_func='relu',
        )
        self.encoder = Encoder(params)
        self.main_decoder = Decoder(params)
        self.aux_decoder1 = Decoder(params)

    def forward(self, x):
        """Return ``(main_seg, aux_seg1)`` for the input ``x``."""
        feature = self.encoder(x)
        main_seg = self.main_decoder(feature)
        # The auxiliary branch decodes a dropout-perturbed copy of every
        # encoder feature map.
        dropped = list(map(Dropout, feature))
        aux_seg1 = self.aux_decoder1(dropped)
        return main_seg, aux_seg1
348
+
349
+
350
class UNet_CCT_3H(nn.Module):
    """U-Net with a main decoder and two auxiliary decoders (three heads).

    The first auxiliary decoder receives dropout-perturbed encoder
    features, the second one noise-perturbed encoder features.

    Args:
        in_chns: Number of input channels.
        class_num: Number of output channels of each decoder.
    """

    def __init__(self, in_chns, class_num):
        super(UNet_CCT_3H, self).__init__()

        # NOTE(review): 'bilinear': False is handed through Decoder to
        # UpBlock as the upsampling mode -- confirm UpBlock accepts a
        # boolean there.
        params = {'in_chns': in_chns,
                  'feature_chns': [16, 32, 64, 128, 256],
                  'dropout': [0.05, 0.1, 0.2, 0.3, 0.5],
                  'class_num': class_num,
                  'bilinear': False,
                  'acti_func': 'relu'}
        self.encoder = Encoder(params)
        self.main_decoder = Decoder(params)
        self.aux_decoder1 = Decoder(params)
        self.aux_decoder2 = Decoder(params)

    def forward(self, x):
        """Return ``(main_seg, aux_seg1, aux_seg2)`` for the input ``x``."""
        feature = self.encoder(x)
        main_seg = self.main_decoder(feature)

        aux1_feature = [Dropout(i) for i in feature]
        aux_seg1 = self.aux_decoder1(aux1_feature)

        aux2_feature = [FeatureNoise()(i) for i in feature]
        # Each auxiliary branch has its own decoder; the noise branch must
        # go through aux_decoder2, otherwise that decoder is never used.
        aux_seg2 = self.aux_decoder2(aux2_feature)
        return main_seg, aux_seg1, aux_seg2
unet_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a1747468feb2005f81ff818f234f8358553ec017c6c1603dc5e046f2fc6ea39
3
+ size 7316273