Vemund Fredriksen commited on
Commit
c7c329a
·
1 Parent(s): 3ff57cd

Initial pre and post implementation

Browse files
Files changed (1) hide show
  1. lungtumormask/dataprocessing.py +252 -0
lungtumormask/dataprocessing.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import lungmask
2
+ from lungmask import mask
3
+ from monai import transforms
4
+ from monai.transforms.intensity.array import ThresholdIntensity
5
+ from monai.transforms.spatial.array import Resize, Spacing
6
+ import torch
7
+ from tqdm import tqdm
8
+ import numpy as np
9
+ from monai.transforms import (Compose, LoadImaged, ToNumpyd, ThresholdIntensityd, AddChanneld, NormalizeIntensityd, SpatialCropd, DivisiblePadd, Spacingd, SqueezeDimd)
10
+
11
+ def mask_lung(scan_path, batch_size=20):
12
+ model = lungmask.mask.get_model('unet', 'R231')
13
+ device = torch.device('cuda')
14
+ model.to(device)
15
+
16
+ scan_dict = {
17
+ 'image' : scan_path
18
+ }
19
+
20
+ transformer = Compose(
21
+ [
22
+ LoadImaged(keys=['image']),
23
+ ToNumpyd(keys=['image']),
24
+ ]
25
+ )
26
+
27
+ scan_read = transformer(scan_dict)
28
+ inimg_raw = scan_read['image'].swapaxes(0,2)
29
+
30
+ tvolslices, xnew_box = lungmask.utils.preprocess(inimg_raw, resolution=[256, 256])
31
+ tvolslices[tvolslices > 600] = 600
32
+ tvolslices = np.divide((tvolslices + 1024), 1624)
33
+
34
+ torch_ds_val = lungmask.utils.LungLabelsDS_inf(tvolslices)
35
+ dataloader_val = torch.utils.data.DataLoader(torch_ds_val, batch_size=batch_size, shuffle=False, num_workers=1,
36
+ pin_memory=False)
37
+
38
+ timage_res = np.empty((np.append(0, tvolslices[0].shape)), dtype=np.uint8)
39
+
40
+ with torch.no_grad():
41
+ for X in tqdm(dataloader_val):
42
+ X = X.float().to(device)
43
+ prediction = model(X)
44
+ pls = torch.max(prediction, 1)[1].detach().cpu().numpy().astype(np.uint8)
45
+ timage_res = np.vstack((timage_res, pls))
46
+
47
+ outmask = lungmask.utils.postrocessing(timage_res)
48
+
49
+
50
+ outmask = np.asarray(
51
+ [lungmask.utils.reshape_mask(outmask[i], xnew_box[i], inimg_raw.shape[1:]) for i in range(outmask.shape[0])],
52
+ dtype=np.uint8)
53
+
54
+ outmask = np.swapaxes(outmask, 0, 2)
55
+ #outmask = np.flip(outmask, 0)
56
+
57
+
58
+ return outmask.astype(np.uint8), scan_read['image_meta_dict']['affine']
59
+
60
+ def calculate_extremes(image, annotation_value):
61
+
62
+ holder = np.copy(image)
63
+
64
+ x_min = float('inf')
65
+ x_max = 0
66
+ y_min = float('inf')
67
+ y_max = 0
68
+ z_min = -1
69
+ z_max = 0
70
+
71
+ holder[holder != annotation_value] = 0
72
+
73
+ holder = np.swapaxes(holder, 0, 2)
74
+ for i, layer in enumerate(holder):
75
+ if(np.amax(layer) < 1):
76
+ continue
77
+ if(z_min == -1):
78
+ z_min = i
79
+ z_max = i
80
+
81
+ y = np.any(layer, axis = 1)
82
+ x = np.any(layer, axis = 0)
83
+ y_minl, y_maxl = np.argmax(y) + 1, layer.shape[0] - np.argmax(np.flipud(y))
84
+ x_minl, x_maxl = np.argmax(x) + 1, layer.shape[1] - np.argmax(np.flipud(x))
85
+
86
+ if(y_minl < y_min):
87
+ y_min = y_minl
88
+ if(x_minl < x_min):
89
+ x_min = x_minl
90
+ if(y_maxl > y_max):
91
+ y_max = y_maxl
92
+ if(x_maxl > x_max):
93
+ x_max = x_maxl
94
+
95
+ return ((x_min, x_max), (y_min, y_max), (z_min, z_max))
96
+
97
+ def process_lung_scan(scan_dict, extremes):
98
+
99
+ load_transformer = Compose(
100
+ [
101
+ LoadImaged(keys=["image"]),
102
+ ThresholdIntensityd(keys=['image'], above = False, threshold = 1000, cval = 1000),
103
+ ThresholdIntensityd(keys=['image'], above = True, threshold = -1024, cval = -1024),
104
+ AddChanneld(keys=["image"]),
105
+ NormalizeIntensityd(keys=["image"]),
106
+ SpatialCropd(keys=["image"], roi_start=(extremes[0][0], extremes[1][0], extremes[2][0]), roi_end=(extremes[0][1], extremes[1][1], extremes[2][1])),
107
+ Spacingd(keys=["image"], pixdim=(1, 1, 1.5)),
108
+ ]
109
+ )
110
+
111
+ processed_1 = load_transformer(scan_dict)
112
+
113
+ transformer_1 = Compose(
114
+ [
115
+ DivisiblePadd(keys=["image"], k=16, mode='constant'),
116
+ SqueezeDimd(keys=["image"], dim = 0),
117
+ ToNumpyd(keys=["image"]),
118
+ ]
119
+ )
120
+
121
+ processed_2 = transformer_1(processed_1)
122
+
123
+ affine = processed_2['image_meta_dict']['affine']
124
+
125
+ normalized_image = processed_2['image']
126
+
127
+ return normalized_image, affine
128
+
129
+ def preprocess(image_path):
130
+
131
+ preprocess_dump = {}
132
+
133
+ scan_dict = {
134
+ 'image' : image_path
135
+ }
136
+
137
+ im = LoadImaged(keys=['image'])(scan_dict)
138
+ preprocess_dump['org_shape'] = im['image'].shape
139
+ preprocess_dump['pixdim'] = im['image_meta_dict']['pixdim'][1:4]
140
+ preprocess_dump['org_affine'] = im['image_meta_dict']['affine']
141
+
142
+ masked_lungs = mask_lung(image_path, 5)
143
+ right_lung_extreme = calculate_extremes(masked_lungs[0], 1)
144
+ preprocess_dump['right_extremes'] = right_lung_extreme
145
+ right_lung_processed = process_lung_scan(scan_dict, right_lung_extreme)
146
+
147
+ left_lung_extreme = calculate_extremes(masked_lungs[0], 2)
148
+ preprocess_dump['left_extremes'] = left_lung_extreme
149
+ left_lung_processed = process_lung_scan(scan_dict, left_lung_extreme)
150
+
151
+
152
+ preprocess_dump['affine'] = left_lung_processed[1]
153
+
154
+ preprocess_dump['right_lung'] = right_lung_processed[0]
155
+ preprocess_dump['left_lung'] = left_lung_processed[0]
156
+
157
+ return preprocess_dump
158
+
159
+ def find_pad_edge(original):
160
+ a_min = -1
161
+ a_max = original.shape[0]
162
+
163
+ for i in range(len(original)):
164
+ a_min = i
165
+ if(np.any(original[i])):
166
+ break
167
+
168
+ for i in range(len(original) - 1, 0, -1):
169
+ a_max = i
170
+ if(np.any(original[i])):
171
+ break
172
+
173
+ original = original.swapaxes(0,1)
174
+
175
+ b_min = -1
176
+ b_max = original.shape[0]
177
+
178
+ for i in range(len(original)):
179
+ b_min = i
180
+ if(np.any(original[i])):
181
+ break
182
+
183
+ for i in range(len(original) - 1, 0, -1):
184
+ b_max = i
185
+ if(np.any(original[i])):
186
+ break
187
+
188
+ original = original.swapaxes(0,1)
189
+ original = original.swapaxes(0,2)
190
+
191
+ c_min = -1
192
+ c_max = original.shape[0]
193
+
194
+ for i in range(len(original)):
195
+ c_min = i
196
+ if(np.any(original[i])):
197
+ break
198
+
199
+ for i in range(len(original) - 1, 0, -1):
200
+ c_max = i
201
+ if(np.any(original[i])):
202
+ break
203
+
204
+ return a_min, a_max + 1, b_min, b_max + 1, c_min, c_max + 1
205
+
206
+
207
+ def remove_pad(mask, original):
208
+ a_min, a_max, b_min, b_max, c_min, c_max = find_pad_edge(original)
209
+ return mask[a_min:a_max, b_min:b_max, c_min: c_max]
210
+
211
+ def voxel_space(image, target):
212
+ image = Resize((target[0][1]-target[0][0], target[1][1]-target[1][0], target[2][1]-target[2][0]), mode='trilinear')(np.expand_dims(image, 0))[0]
213
+ image = ThresholdIntensity(above = False, threshold = 0.5, cval = 1)(image)
214
+ image = ThresholdIntensity(above = True, threshold = 0.5, cval = 0)(image)
215
+
216
+ return image
217
+
218
+ def stitch(org_shape, cropped, roi):
219
+ holder = np.zeros(org_shape)
220
+
221
+ holder[roi[0][0]:roi[0][1], roi[1][0]:roi[1][1], roi[2][0]:roi[2][1]] = cropped
222
+
223
+ return holder
224
+
225
+ def post_process(left_mask, right_mask, preprocess_dump):
226
+ left = remove_pad(left_mask, preprocess_dump['left_lung'])
227
+ right = remove_pad(right_mask, preprocess_dump['right_lung'])
228
+
229
+ left = voxel_space(left, preprocess_dump['left_extremes'])
230
+ right = voxel_space(right, preprocess_dump['right_extremes'])
231
+
232
+ left = stitch(preprocess_dump['org_shape'], left, preprocess_dump['left_extremes'])
233
+ right = stitch(preprocess_dump['org_shape'], right, preprocess_dump['right_extremes'])
234
+
235
+ stitched = np.logical_or(left, right).astype(int)
236
+
237
+ return stitched
238
+
239
+
240
+ if __name__ == "__main__":
241
+ path = "D:\\Datasets\MSD\\Images\\lung_003.nii.gz"
242
+ preprocess_dump = preprocess(path)
243
+
244
+ unpad = post_process(preprocess_dump['left_lung'], preprocess_dump['right_lung'], preprocess_dump)
245
+
246
+ import nibabel
247
+
248
+ nimage = nibabel.Nifti1Image(unpad, preprocess_dump['org_affine'])
249
+ nibabel.save(nimage, "D:\\Datasets\\stitched.nii.gz")
250
+
251
+
252
+