Commit a37c14e
Parent(s): 9da2171

init

Files changed:
- app.py           +53 -0
- unet.py          +372 -0
- unet_model.pth   +3 -0
app.py
ADDED
@@ -0,0 +1,53 @@
# Import necessary libraries and load the model
import gradio as gr
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from unet import UNet  # The model class defined in unet.py

# Per-channel mean and std of the training images
MEAN = np.array([0.4732661, 0.44874457, 0.3948762], dtype=np.float32)
STD = np.array([0.22674961, 0.22012031, 0.2238305], dtype=np.float32)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = UNet(in_chns=3, class_num=2)  # Initialize the model
# map_location keeps the checkpoint loadable on CPU-only machines
model.load_state_dict(torch.load('unet_model.pth', map_location=device))

model = model.to(device)
model.eval()

# Define the segmentation function
def segment(img):
    img = Image.fromarray(img.astype('uint8'), 'RGB')
    original_size = img.size  # Store the original size

    img = img.resize((224, 224), Image.BILINEAR)
    img = transforms.ToTensor()(img)
    # ToTensor yields a channel-first (C, H, W) tensor, so normalize along dim 0
    for i in range(3):
        img[i, :, :] -= float(MEAN[i])
    for i in range(3):
        img[i, :, :] /= float(STD[i])

    img = img.unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(img)
    output = torch.argmax(torch.softmax(output, dim=1), dim=1).squeeze().cpu().numpy()

    # Resize the mask back to the original image size
    output = Image.fromarray(output.astype('uint8')).resize(original_size, resample=Image.BILINEAR)

    # Convert the PIL Image back to a numpy array
    output = np.array(output)
    binary_mask = np.zeros_like(output)
    binary_mask[output > 0] = 255

    return binary_mask

# Create a Gradio interface
iface = gr.Interface(fn=segment, inputs="image", outputs="image", title="Segmentation Model",
                     description="Segment objects in an image.",
                     allow_flagging="never")

# Launch the interface
iface.launch()
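A quick way to sanity-check segment() locally, without the web UI, is to feed it a random RGB array of the shape Gradio passes in (H, W, 3). A minimal sketch, assuming the model and segment() above are already loaded (i.e. run before iface.launch()); the 300x400 size is purely illustrative:

    import numpy as np

    dummy = (np.random.rand(300, 400, 3) * 255).astype('uint8')  # HxWxC RGB, as Gradio supplies
    mask = segment(dummy)
    print(mask.shape, mask.dtype)  # expected: (300, 400) uint8, values in {0, 255}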
unet.py
ADDED
@@ -0,0 +1,372 @@
# -*- coding: utf-8 -*-
"""
The implementation is borrowed from: https://github.com/HiLab-git/PyMIC
"""
from __future__ import division, print_function

import numpy as np
import torch
import torch.nn as nn
from torch.distributions.uniform import Uniform


class ConvBlock(nn.Module):
    """Two convolution layers with batch norm and leaky ReLU."""

    def __init__(self, in_channels, out_channels, dropout_p):
        super(ConvBlock, self).__init__()
        self.conv_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
            nn.Dropout(dropout_p),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU()
        )

    def forward(self, x):
        return self.conv_conv(x)


class DownBlock(nn.Module):
    """Downsampling followed by ConvBlock."""

    def __init__(self, in_channels, out_channels, dropout_p):
        super(DownBlock, self).__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            ConvBlock(in_channels, out_channels, dropout_p)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class UpBlock(nn.Module):
    """Upsampling followed by ConvBlock.

    `bilinear` selects the upsampling mode: 'nearest', 'bilinear', or
    'convtrans' (transposed convolution). Legacy boolean values are
    mapped to 'bilinear' (True) and 'convtrans' (False).
    """

    def __init__(self, in_channels1, in_channels2, out_channels, dropout_p,
                 bilinear=True):
        super(UpBlock, self).__init__()
        # Map legacy boolean flags onto the string modes
        if bilinear is True:
            bilinear = 'bilinear'
        elif bilinear is False:
            bilinear = 'convtrans'
        self.bilinear = bilinear
        if self.bilinear != 'convtrans':
            # Interpolation keeps the channel count, so reduce it with a 1x1 conv
            self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1)
            if self.bilinear == 'nearest':
                self.up = nn.Upsample(scale_factor=2, mode='nearest')
            else:
                # align_corners only applies to interpolating modes such as 'bilinear'
                self.up = nn.Upsample(
                    scale_factor=2, mode=self.bilinear, align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(
                in_channels1, in_channels2, kernel_size=2, stride=2)
        self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p)

    def forward(self, x1, x2):
        if self.bilinear != 'convtrans':
            x1 = self.conv1x1(x1)
        x1 = self.up(x1)
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class Encoder(nn.Module):
    def __init__(self, params):
        super(Encoder, self).__init__()
        self.params = params
        self.in_chns = self.params['in_chns']
        self.ft_chns = self.params['feature_chns']
        self.n_class = self.params['class_num']
        self.bilinear = self.params['bilinear']
        self.dropout = self.params['dropout']
        assert (len(self.ft_chns) == 5)
        self.in_conv = ConvBlock(
            self.in_chns, self.ft_chns[0], self.dropout[0])
        self.down1 = DownBlock(
            self.ft_chns[0], self.ft_chns[1], self.dropout[1])
        self.down2 = DownBlock(
            self.ft_chns[1], self.ft_chns[2], self.dropout[2])
        self.down3 = DownBlock(
            self.ft_chns[2], self.ft_chns[3], self.dropout[3])
        self.down4 = DownBlock(
            self.ft_chns[3], self.ft_chns[4], self.dropout[4])

    def forward(self, x):
        x0 = self.in_conv(x)
        x1 = self.down1(x0)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        return [x0, x1, x2, x3, x4]


class Decoder(nn.Module):
    def __init__(self, params):
        super(Decoder, self).__init__()
        self.params = params
        self.in_chns = self.params['in_chns']
        self.ft_chns = self.params['feature_chns']
        self.n_class = self.params['class_num']
        self.bilinear = self.params['bilinear']
        assert (len(self.ft_chns) == 5)

        self.up1 = UpBlock(
            self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0, bilinear=self.bilinear)
        self.up2 = UpBlock(
            self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0, bilinear=self.bilinear)
        self.up3 = UpBlock(
            self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0, bilinear=self.bilinear)
        self.up4 = UpBlock(
            self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0, bilinear=self.bilinear)

        self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class,
                                  kernel_size=3, padding=1)

    def forward(self, feature):
        x0 = feature[0]
        x1 = feature[1]
        x2 = feature[2]
        x3 = feature[3]
        x4 = feature[4]

        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)
        x = self.up4(x, x0)
        output = self.out_conv(x)
        return output


class Decoder_DS(nn.Module):
    def __init__(self, params):
        super(Decoder_DS, self).__init__()
        self.params = params
        self.in_chns = self.params['in_chns']
        self.ft_chns = self.params['feature_chns']
        self.n_class = self.params['class_num']
        self.bilinear = self.params['bilinear']
        assert (len(self.ft_chns) == 5)

        self.up1 = UpBlock(
            self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0)
        self.up2 = UpBlock(
            self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0)
        self.up3 = UpBlock(
            self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0)
        self.up4 = UpBlock(
            self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0)

        self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class,
                                  kernel_size=3, padding=1)
        self.out_conv_dp4 = nn.Conv2d(self.ft_chns[4], self.n_class,
                                      kernel_size=3, padding=1)
        self.out_conv_dp3 = nn.Conv2d(self.ft_chns[3], self.n_class,
                                      kernel_size=3, padding=1)
        self.out_conv_dp2 = nn.Conv2d(self.ft_chns[2], self.n_class,
                                      kernel_size=3, padding=1)
        self.out_conv_dp1 = nn.Conv2d(self.ft_chns[1], self.n_class,
                                      kernel_size=3, padding=1)

    def forward(self, feature, shape):
        x0 = feature[0]
        x1 = feature[1]
        x2 = feature[2]
        x3 = feature[3]
        x4 = feature[4]
        x = self.up1(x4, x3)
        dp3_out_seg = self.out_conv_dp3(x)
        dp3_out_seg = torch.nn.functional.interpolate(dp3_out_seg, shape)

        x = self.up2(x, x2)
        dp2_out_seg = self.out_conv_dp2(x)
        dp2_out_seg = torch.nn.functional.interpolate(dp2_out_seg, shape)

        x = self.up3(x, x1)
        dp1_out_seg = self.out_conv_dp1(x)
        dp1_out_seg = torch.nn.functional.interpolate(dp1_out_seg, shape)

        x = self.up4(x, x0)
        dp0_out_seg = self.out_conv(x)
        return dp0_out_seg, dp1_out_seg, dp2_out_seg, dp3_out_seg


class Decoder_URDS(nn.Module):
    def __init__(self, params):
        super(Decoder_URDS, self).__init__()
        self.params = params
        self.in_chns = self.params['in_chns']
        self.ft_chns = self.params['feature_chns']
        self.n_class = self.params['class_num']
        self.bilinear = self.params['bilinear']
        assert (len(self.ft_chns) == 5)

        self.up1 = UpBlock(
            self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0)
        self.up2 = UpBlock(
            self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0)
        self.up3 = UpBlock(
            self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0)
        self.up4 = UpBlock(
            self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0)

        self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class,
                                  kernel_size=3, padding=1)
        self.out_conv_dp4 = nn.Conv2d(self.ft_chns[4], self.n_class,
                                      kernel_size=3, padding=1)
        self.out_conv_dp3 = nn.Conv2d(self.ft_chns[3], self.n_class,
                                      kernel_size=3, padding=1)
        self.out_conv_dp2 = nn.Conv2d(self.ft_chns[2], self.n_class,
                                      kernel_size=3, padding=1)
        self.out_conv_dp1 = nn.Conv2d(self.ft_chns[1], self.n_class,
                                      kernel_size=3, padding=1)
        self.feature_noise = FeatureNoise()

    def forward(self, feature, shape):
        x0 = feature[0]
        x1 = feature[1]
        x2 = feature[2]
        x3 = feature[3]
        x4 = feature[4]
        x = self.up1(x4, x3)
        # During training, each auxiliary head sees a differently perturbed input
        if self.training:
            dp3_out_seg = self.out_conv_dp3(Dropout(x, p=0.5))
        else:
            dp3_out_seg = self.out_conv_dp3(x)
        dp3_out_seg = torch.nn.functional.interpolate(dp3_out_seg, shape)

        x = self.up2(x, x2)
        if self.training:
            dp2_out_seg = self.out_conv_dp2(FeatureDropout(x))
        else:
            dp2_out_seg = self.out_conv_dp2(x)
        dp2_out_seg = torch.nn.functional.interpolate(dp2_out_seg, shape)

        x = self.up3(x, x1)
        if self.training:
            dp1_out_seg = self.out_conv_dp1(self.feature_noise(x))
        else:
            dp1_out_seg = self.out_conv_dp1(x)
        dp1_out_seg = torch.nn.functional.interpolate(dp1_out_seg, shape)

        x = self.up4(x, x0)
        dp0_out_seg = self.out_conv(x)
        return dp0_out_seg, dp1_out_seg, dp2_out_seg, dp3_out_seg


def Dropout(x, p=0.5):
    x = torch.nn.functional.dropout2d(x, p)
    return x


def FeatureDropout(x):
    # Drop the most strongly activated regions, using the channel-mean
    # activation as an attention map
    attention = torch.mean(x, dim=1, keepdim=True)
    max_val, _ = torch.max(attention.view(
        x.size(0), -1), dim=1, keepdim=True)
    threshold = max_val * np.random.uniform(0.7, 0.9)
    threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention)
    drop_mask = (attention < threshold).float()
    x = x.mul(drop_mask)
    return x


class FeatureNoise(nn.Module):
    def __init__(self, uniform_range=0.3):
        super(FeatureNoise, self).__init__()
        self.uni_dist = Uniform(-uniform_range, uniform_range)

    def feature_based_noise(self, x):
        noise_vector = self.uni_dist.sample(
            x.shape[1:]).to(x.device).unsqueeze(0)
        x_noise = x.mul(noise_vector) + x
        return x_noise

    def forward(self, x):
        x = self.feature_based_noise(x)
        return x


class UNet(nn.Module):
    def __init__(self, in_chns, class_num):
        super(UNet, self).__init__()

        params = {'in_chns': in_chns,
                  'feature_chns': [16, 32, 64, 128, 256],
                  'dropout': [0.05, 0.1, 0.2, 0.3, 0.5],
                  'class_num': class_num,
                  'bilinear': 'nearest',
                  'acti_func': 'relu'}

        self.encoder = Encoder(params)
        self.decoder = Decoder(params)

    def forward(self, x):
        feature = self.encoder(x)
        output = self.decoder(feature)
        return output


class UNet_DS(nn.Module):
    def __init__(self, in_chns, class_num):
        super(UNet_DS, self).__init__()

        params = {'in_chns': in_chns,
                  'feature_chns': [16, 32, 64, 128, 256],
                  'dropout': [0.05, 0.1, 0.2, 0.3, 0.5],
                  'class_num': class_num,
                  'bilinear': False,
                  'acti_func': 'relu'}
        self.encoder = Encoder(params)
        self.decoder = Decoder_DS(params)

    def forward(self, x):
        shape = x.shape[2:]
        feature = self.encoder(x)
        dp0_out_seg, dp1_out_seg, dp2_out_seg, dp3_out_seg = self.decoder(
            feature, shape)
        return dp0_out_seg, dp1_out_seg, dp2_out_seg, dp3_out_seg


class UNet_CCT(nn.Module):
    def __init__(self, in_chns, class_num):
        super(UNet_CCT, self).__init__()

        params = {'in_chns': in_chns,
                  'feature_chns': [16, 32, 64, 128, 256],
                  'dropout': [0.05, 0.1, 0.2, 0.3, 0.5],
                  'class_num': class_num,
                  'bilinear': 'nearest',
                  'acti_func': 'relu'}
        self.encoder = Encoder(params)
        self.main_decoder = Decoder(params)
        self.aux_decoder1 = Decoder(params)

    def forward(self, x):
        feature = self.encoder(x)
        main_seg = self.main_decoder(feature)
        aux1_feature = [Dropout(i) for i in feature]
        aux_seg1 = self.aux_decoder1(aux1_feature)
        return main_seg, aux_seg1


class UNet_CCT_3H(nn.Module):
    def __init__(self, in_chns, class_num):
        super(UNet_CCT_3H, self).__init__()

        params = {'in_chns': in_chns,
                  'feature_chns': [16, 32, 64, 128, 256],
                  'dropout': [0.05, 0.1, 0.2, 0.3, 0.5],
                  'class_num': class_num,
                  'bilinear': False,
                  'acti_func': 'relu'}
        self.encoder = Encoder(params)
        self.main_decoder = Decoder(params)
        self.aux_decoder1 = Decoder(params)
        self.aux_decoder2 = Decoder(params)

    def forward(self, x):
        feature = self.encoder(x)
        main_seg = self.main_decoder(feature)
        # Each auxiliary decoder sees a differently perturbed copy of the features
        aux1_feature = [Dropout(i) for i in feature]
        aux_seg1 = self.aux_decoder1(aux1_feature)
        aux2_feature = [FeatureNoise()(i) for i in feature]
        aux_seg2 = self.aux_decoder2(aux2_feature)
        return main_seg, aux_seg1, aux_seg2
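As a quick structural check that the network is constructible with the settings app.py uses and that the decoder restores the input resolution, one can push a dummy batch through UNet; a minimal sketch (224x224 matches the resize in app.py, and H and W must stay divisible by 16 because of the four pooling stages):

    import torch
    from unet import UNet

    net = UNet(in_chns=3, class_num=2)   # same settings as app.py
    net.eval()
    x = torch.randn(1, 3, 224, 224)      # one RGB image
    with torch.no_grad():
        logits = net(x)
    print(logits.shape)                  # expected: torch.Size([1, 2, 224, 224])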
unet_model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4a1747468feb2005f81ff818f234f8358553ec017c6c1603dc5e046f2fc6ea39
size 7316273