jpdefrutos commited on
Commit
1fe3eab
·
2 Parent(s): 674737f ed3a7b8

Merge remote-tracking branch 'origin/HF_spatialtransformer' into HF_spatialtransformer

Browse files
DeepDeformationMapRegistration/main.py CHANGED
@@ -300,7 +300,6 @@ def main():
300
 
301
  LOGGER.info('Applying displacement map...')
302
  time_pred_img_start = time.time()
303
- # pred_image = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([moving_image[np.newaxis, ...], disp_map[np.newaxis, ...]]).eval()
304
  pred_image = spatialtransformer_model.predict([moving_image[np.newaxis, ...], disp_map[np.newaxis, ...]])
305
  time_pred_img_end = time.time()
306
  LOGGER.info(f'\t... done ({time_pred_img_end - time_pred_img_start} s)')
@@ -313,6 +312,7 @@ def main():
313
  # disp_map = disp_map_or
314
  pred_image = zoom(pred_image, 1 / zoom_factors)
315
  pred_image = pad_crop_to_original_shape(pred_image, fixed_image_or.shape, crop_min)
 
316
  LOGGER.info('Done...')
317
 
318
  if args.original_resolution:
 
300
 
301
  LOGGER.info('Applying displacement map...')
302
  time_pred_img_start = time.time()
 
303
  pred_image = spatialtransformer_model.predict([moving_image[np.newaxis, ...], disp_map[np.newaxis, ...]])
304
  time_pred_img_end = time.time()
305
  LOGGER.info(f'\t... done ({time_pred_img_end - time_pred_img_start} s)')
 
312
  # disp_map = disp_map_or
313
  pred_image = zoom(pred_image, 1 / zoom_factors)
314
  pred_image = pad_crop_to_original_shape(pred_image, fixed_image_or.shape, crop_min)
315
+ pred_image = np.squeeze(pred_image, axis=-1)
316
  LOGGER.info('Done...')
317
 
318
  if args.original_resolution:
demo/src/gui.py CHANGED
@@ -20,7 +20,7 @@ class WebUI:
20
  self.pred_images = []
21
 
22
  # @TODO: This should be dynamically set based on chosen volume size
23
- self.nb_slider_items = 150
24
 
25
  self.model_name = model_name
26
  self.cwd = cwd
@@ -60,8 +60,8 @@ class WebUI:
60
 
61
  self.fixed_images = load_ct_to_numpy(fixed_image_path)
62
  self.moving_images = load_ct_to_numpy(moving_image_path)
63
- self.pred_images = np.ones_like(self.moving_images)
64
- return self.pred_images
65
 
66
  def get_fixed_image(self, k):
67
  k = int(k) - 1
@@ -151,7 +151,7 @@ class WebUI:
151
  for i in range(self.nb_slider_items):
152
  visibility = True if i == 1 else False
153
  t = gr.Image(
154
- visible=visibility, elem_id="model-2d-fixed"
155
  ).style(
156
  height=512,
157
  width=512,
@@ -162,7 +162,7 @@ class WebUI:
162
  for i in range(self.nb_slider_items):
163
  visibility = True if i == 1 else False
164
  t = gr.Image(
165
- visible=visibility, elem_id="model-2d-moving"
166
  ).style(
167
  height=512,
168
  width=512,
@@ -173,7 +173,7 @@ class WebUI:
173
  for i in range(self.nb_slider_items):
174
  visibility = True if i == 1 else False
175
  t = gr.Image(
176
- visible=visibility, elem_id="model-2d-pred"
177
  ).style(
178
  height=512,
179
  width=512,
@@ -183,7 +183,7 @@ class WebUI:
183
  self.run_btn.click(
184
  fn=lambda x: self.process(x),
185
  inputs=file_output,
186
- outputs=t,
187
  )
188
 
189
  self.slider.input(
 
20
  self.pred_images = []
21
 
22
  # @TODO: This should be dynamically set based on chosen volume size
23
+ self.nb_slider_items = 128
24
 
25
  self.model_name = model_name
26
  self.cwd = cwd
 
60
 
61
  self.fixed_images = load_ct_to_numpy(fixed_image_path)
62
  self.moving_images = load_ct_to_numpy(moving_image_path)
63
+ self.pred_images = load_ct_to_numpy(output_path + "pred_image.nii.gz")
64
+ return None
65
 
66
  def get_fixed_image(self, k):
67
  k = int(k) - 1
 
151
  for i in range(self.nb_slider_items):
152
  visibility = True if i == 1 else False
153
  t = gr.Image(
154
+ visible=visibility, elem_id="model-2d-fixed", label="fixed image", show_label=True,
155
  ).style(
156
  height=512,
157
  width=512,
 
162
  for i in range(self.nb_slider_items):
163
  visibility = True if i == 1 else False
164
  t = gr.Image(
165
+ visible=visibility, elem_id="model-2d-moving", label="moving image", show_label=True,
166
  ).style(
167
  height=512,
168
  width=512,
 
173
  for i in range(self.nb_slider_items):
174
  visibility = True if i == 1 else False
175
  t = gr.Image(
176
+ visible=visibility, elem_id="model-2d-pred", label="predicted fixed image", show_label=True,
177
  ).style(
178
  height=512,
179
  width=512,
 
183
  self.run_btn.click(
184
  fn=lambda x: self.process(x),
185
  inputs=file_output,
186
+ outputs=None,
187
  )
188
 
189
  self.slider.input(
demo/src/utils.py CHANGED
@@ -21,7 +21,6 @@ def load_ct_to_numpy(data_path):
21
  data = data / np.amax(data) * 255
22
  data = data.astype("uint8")
23
 
24
- print(data.shape)
25
  return [data[..., i] for i in range(data.shape[-1])]
26
 
27
 
@@ -38,7 +37,6 @@ def load_pred_volume_to_numpy(data_path):
38
  data[data > 0] = 1
39
  data = data.astype("uint8")
40
 
41
- print(data.shape)
42
  return [data[..., i] for i in range(data.shape[-1])]
43
 
44
 
 
21
  data = data / np.amax(data) * 255
22
  data = data.astype("uint8")
23
 
 
24
  return [data[..., i] for i in range(data.shape[-1])]
25
 
26
 
 
37
  data[data > 0] = 1
38
  data = data.astype("uint8")
39
 
 
40
  return [data[..., i] for i in range(data.shape[-1])]
41
 
42