Commit c0d9a8c · 1 Parent(s): ade0ab5 · committed by andreped

Major refactor on deployment

.gitignore CHANGED
@@ -2,3 +2,8 @@
 .vs/
 *__pycache__/
 venv/
+*.nii
+*.nii.gz
+*.egg-info/
+*.h5
+*.log
DeepDeformationMapRegistration/main.py CHANGED
@@ -1,16 +1,10 @@
 import datetime
 import os, sys
 import shutil
-import re
 import argparse
 import subprocess
 import logging
 import time
-import warnings
-
-# currentdir = os.path.dirname(os.path.realpath(__file__))
-# parentdir = os.path.dirname(currentdir)
-# sys.path.append(parentdir)  # PYTHON > 3.3 does not allow relative referencing
 
 import tensorflow as tf
 
@@ -20,13 +14,10 @@ from scipy.ndimage import gaussian_filter, zoom
 from skimage.measure import regionprops
 import SimpleITK as sitk
 
-import voxelmorph as vxm
 from voxelmorph.tf.layers import SpatialTransformer
 
 import DeepDeformationMapRegistration.utils.constants as C
 from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
-from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC
-from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
 from DeepDeformationMapRegistration.utils.operators import min_max_norm
 from DeepDeformationMapRegistration.utils.misc import resize_displacement_map
 from DeepDeformationMapRegistration.utils.model_utils import get_models_path, load_model
@@ -191,7 +182,7 @@ def main():
     parser.add_argument('--gpu', type=int,
                         help='In case of multi-GPU systems, limits the execution to the defined GPU number',
                         default=None)
-    parser.add_argument('--model', type=str, help='Which model to use: BL-N, BL-S, SG-ND, SG-NSD, UW-NSD, UW-NSDH',
+    parser.add_argument('--model', type=str, help='Which model to use: BL-N, BL-S, BL-NS, SG-ND, SG-NSD, UW-NSD, UW-NSDH',
                         default='UW-NSD')
     parser.add_argument('-d', '--debug', action='store_true', help='Produce additional debug information', default=False)
     parser.add_argument('-c', '--clear-outputdir', action='store_true', help='Clear output folder if this has content', default=False)
@@ -206,17 +197,7 @@ def main():
     assert os.path.exists(args.moving), 'Moving image not found'
     assert args.model in C.MODEL_TYPES.keys(), 'Invalid model type'
     assert args.anatomy in C.ANATOMIES.keys(), 'Invalid anatomy option'
-    if os.path.exists(args.outputdir) and len(os.listdir(args.outputdir)):
-        if args.clear_outputdir:
-            erase = 'y'
-        else:
-            erase = input('Output directory is not empty, erase content? (y/n)')
-        if erase.lower() in ['y', 'yes']:
-            shutil.rmtree(args.outputdir, ignore_errors=True)
-            print('Erased directory: ' + args.outputdir)
-        elif erase.lower() in ['n', 'no']:
-            args.outputdir = os.path.join(args.outputdir, datetime.datetime.now().strftime('%H%M%S_%Y%m%d'))
-            print('New output directory: ' + args.outputdir)
+
     os.makedirs(args.outputdir, exist_ok=True)
 
     log_format = '%(asctime)s [%(levelname)s]:\t%(message)s'
@@ -296,20 +277,11 @@ def main():
 
     # 3. Build the whole graph
     LOGGER.info('Building TF graph')
-    ### METRICS GRAPH ###
-    fix_img_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, 1), name='fix_img')
-    pred_img_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, 1), name='pred_img')
-
-    ssim_tf = StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric(fix_img_ph, pred_img_ph)
-    ncc_tf = NCC(image_shape_or).metric(fix_img_ph, pred_img_ph)
-    mse_tf = vxm.losses.MSE().loss(fix_img_ph, pred_img_ph)
-    ms_ssim_tf = MultiScaleStructuralSimilarity(max_val=1., filter_size=3).metric(fix_img_ph, pred_img_ph)
 
     LOGGER.info(f'Getting model: {"Brain" if args.anatomy == "B" else "Liver"} -> {args.model}')
     MODEL_FILE = get_models_path(args.anatomy, args.model, os.getcwd())  # MODELS_FILE[args.anatomy][args.model]
 
     network, registration_model = load_model(MODEL_FILE, False, True)
-    deb_model = network.apply_transform
 
     LOGGER.info('Computing registration')
     with sess.as_default():
@@ -317,20 +289,17 @@
         registration_model.summary(line_length=C.SUMMARY_LINE_LENGTH)
         LOGGER.info('Computing displacement map...')
         time_disp_map_start = time.time()
-        # disp_map = registration_model.predict([moving_image[np.newaxis, ...], fixed_image[np.newaxis, ...]])
         p, disp_map = network.predict([moving_image[np.newaxis, ...], fixed_image[np.newaxis, ...]])
         time_disp_map_end = time.time()
         LOGGER.info(f'\t... done ({time_disp_map_end - time_disp_map_start})')
         disp_map = np.squeeze(disp_map)
         debug_save_image(np.squeeze(disp_map), 'disp_map_0_raw', args.outputdir, args.debug)
         debug_save_image(p[0, ...], 'img_4_net_pred_image', args.outputdir, args.debug)
-        # pred_image = min_max_norm(pred_image)
-        # pred_image_isot = zoom(pred_image[0, ...], zoom_factors, order=3)[np.newaxis, ...]
-        # fixed_image_isot = zoom(fixed_image[0, ...], zoom_factors, order=3)[np.newaxis, ...]
 
         LOGGER.info('Applying displacement map...')
         time_pred_img_start = time.time()
-        pred_image = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([moving_image[np.newaxis, ...], disp_map[np.newaxis, ...]]).eval()
+        #pred_image = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([moving_image[np.newaxis, ...], disp_map[np.newaxis, ...]]).eval()
+        pred_image = np.zeros_like(moving_image[np.newaxis, ...])  # @TODO: Replace this with Keras' Model with SpatialTransformer Layer
         time_pred_img_end = time.time()
         LOGGER.info(f'\t... done ({time_pred_img_end - time_pred_img_start} s)')
         pred_image = pred_image[0, ...]
@@ -340,31 +309,10 @@ def main():
             moving_image = moving_image_or
            fixed_image = fixed_image_or
            # disp_map = disp_map_or
-            pred_image = zoom(pred_image, 1/zoom_factors)
+            pred_image = zoom(pred_image, 1 / zoom_factors)
             pred_image = pad_crop_to_original_shape(pred_image, fixed_image_or.shape, crop_min)
         LOGGER.info('Done...')
 
-        LOGGER.info('Computing metrics...')
-        if args.original_resolution:
-            ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
-                                               {'fix_img:0': fixed_image[np.newaxis,
-                                                                         crop_min[0]: crop_max[0],
-                                                                         crop_min[1]: crop_max[1],
-                                                                         crop_min[2]: crop_max[2],
-                                                                         ...],
-                                                'pred_img:0': pred_image[np.newaxis,
-                                                                         crop_min[0]: crop_max[0],
-                                                                         crop_min[1]: crop_max[1],
-                                                                         crop_min[2]: crop_max[2],
-                                                                         ...]})  # to only compare the deformed region!
-        else:
-
-            ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
-                                               {'fix_img:0': fixed_image[np.newaxis, ...],
-                                                'pred_img:0': pred_image[np.newaxis, ...]})
-        ssim = np.mean(ssim)
-        ms_ssim = ms_ssim[0]
-
         if args.original_resolution:
             save_nifti(pred_image, os.path.join(args.outputdir, 'pred_image.nii.gz'), header=moving_image_header)
         else:
@@ -389,26 +337,11 @@ def main():
             np.savez_compressed(os.path.join(args.outputdir, 'displacement_map.npz'), disp_map)
         else:
             np.savez_compressed(os.path.join(os.path.join(args.outputdir, 'debug'), 'displacement_map.npz'), disp_map)
+
         LOGGER.info('Predicted image and displacement map saved in: '.format(args.outputdir))
         LOGGER.info(f'Displacement map prediction time: {time_disp_map_end - time_disp_map_start} s')
         LOGGER.info(f'Predicted image time: {time_pred_img_end - time_pred_img_start} s')
 
-        LOGGER.info('Similarity metrics\n------------------')
-        LOGGER.info('SSIM: {:.03f}'.format(ssim))
-        LOGGER.info('NCC: {:.03f}'.format(ncc))
-        LOGGER.info('MSE: {:.03f}'.format(mse))
-        LOGGER.info('MS SSIM: {:.03f}'.format(ms_ssim))
-
-        # ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
-        #                                    {'fix_img:0': fixed_image[np.newaxis, ...], 'pred_img:0': p})
-        # ssim = np.mean(ssim)
-        # ms_ssim = ms_ssim[0]
-        # LOGGER.info('\nSimilarity metrics (ROI)\n------------------')
-        # LOGGER.info('SSIM: {:.03f}'.format(ssim))
-        # LOGGER.info('NCC: {:.03f}'.format(ncc))
-        # LOGGER.info('MSE: {:.03f}'.format(mse))
-        # LOGGER.info('MS SSIM: {:.03f}'.format(ms_ssim))
-
         del registration_model
         LOGGER.info('Done')
         exit(0)
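
Note: the refactored hunk above replaces the eager `SpatialTransformer(...).eval()` call with a `np.zeros_like` stub and leaves a `@TODO` to wrap the transformer in a Keras model. A minimal sketch of such a wrapper is shown below; the layer arguments are copied from the commented-out call, while the function name, input shapes, and the `predict` usage are assumptions, not part of this commit:

```
import numpy as np
import tensorflow as tf
from voxelmorph.tf.layers import SpatialTransformer

def build_warp_model(vol_shape):
    # vol_shape is the spatial shape of the volumes, e.g. (128, 128, 128); channel dims added below
    moving_in = tf.keras.Input(shape=(*vol_shape, 1), name='moving_image')
    disp_in = tf.keras.Input(shape=(*vol_shape, 3), name='disp_map')
    warped = SpatialTransformer(interp_method='linear', indexing='ij',
                                single_transform=False)([moving_in, disp_in])
    return tf.keras.Model(inputs=[moving_in, disp_in], outputs=warped)

# usage sketch, replacing the zeros stub:
# warp_model = build_warp_model(moving_image.shape[:-1])
# pred_image = warp_model.predict([moving_image[np.newaxis, ...], disp_map[np.newaxis, ...]])
```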
DeepDeformationMapRegistration/utils/logger.py ADDED
@@ -0,0 +1,4 @@
+import logging
+
+
+LOGGER = logging.getLogger(__name__)
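
The new shared logger module only creates the module-level `LOGGER`; handlers and formatting are still configured by the caller. A small sketch of how it might be wired up (the `basicConfig` call is an assumption; only the format string mirrors `log_format` in `main.py`):

```
import logging

from DeepDeformationMapRegistration.utils.logger import LOGGER

# one-time configuration at application start-up (illustrative)
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s [%(levelname)s]:\t%(message)s')
LOGGER.info('Shared logger ready')
```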
Dockerfile ADDED
@@ -0,0 +1,63 @@
+# read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
+# you will also find guides on how best to write your Dockerfile
+FROM python:3.8-slim
+
+# set language, format and stuff
+ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
+
+WORKDIR /code
+
+RUN apt-get update -y
+#RUN apt-get install -y python3 python3-pip
+RUN apt install git --fix-missing -y
+RUN apt install wget -y
+
+# installing other libraries
+RUN apt-get install python3-pip -y && \
+    apt-get -y install sudo
+RUN apt-get install curl -y
+RUN apt-get install nano -y
+RUN apt-get update && apt-get install -y git
+RUN apt-get install libblas-dev -y && apt-get install liblapack-dev -y
+RUN apt-get install gfortran -y
+RUN apt-get install libpng-dev -y
+RUN apt-get install python3-dev -y
+
+WORKDIR /code
+
+# install dependencies
+COPY ./demo/requirements.txt /code/demo/requirements.txt
+RUN pip install --no-cache-dir --upgrade -r /code/demo/requirements.txt
+
+# resolve issue with tf==2.4 and gradio dependency collision issue
+RUN pip install --force-reinstall typing_extensions==4.7.1
+
+# lower pydantic version to work with typing_extensions deprecation
+RUN pip install --force-reinstall "pydantic<2.0.0"
+
+# Install wget
+RUN apt install wget -y && \
+    apt install unzip
+
+# Set up a new user named "user" with user ID 1000
+RUN useradd -m -u 1000 user
+
+# Switch to the "user" user
+USER user
+
+# Set home to the user's home directory
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:$PATH
+
+# Set the working directory to the user's home directory
+WORKDIR $HOME/app
+
+# Copy the current directory contents into the container at $HOME/app setting the owner to the user
+COPY --chown=user . $HOME/app
+
+# download test data
+RUN wget "https://github.com/jpdefrutos/DDMR/releases/download/test_data_brain_v0/ixi_image.nii.gz" && \
+    wget "https://github.com/jpdefrutos/DDMR/releases/download/test_data_brain_v0/ixi_image2.nii.gz"
+
+# CMD ["/bin/bash"]
+CMD ["python3", "app.py"]
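
For local testing, the Space image can be built and run with standard Docker commands, e.g. `docker build -t ddmr-demo .` followed by `docker run -p 7860:7860 ddmr-demo` (the `ddmr-demo` tag is only illustrative); port 7860 matches the `app_port` declared in the README metadata and the `server_port` used in `demo/src/gui.py`.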
README.md CHANGED
@@ -1,3 +1,15 @@
+---
+title: 'DDMR: Deep Deformation Map Registration of CT/MRIs'
+colorFrom: indigo
+colorTo: indigo
+sdk: docker
+app_port: 7860
+emoji: 🧠
+pinned: false
+license: mit
+app_file: demo/app.py
+---
+
 <div align="center">
 <img src="https://user-images.githubusercontent.com/30429725/204778476-4d24c659-9287-48b8-b616-92016ffcf4f6.svg" alt="drawing" width="600">
 </div>
@@ -6,6 +18,8 @@
 
 <h1 align="center">DDMR: Deep Deformation Map Registration</h1>
 <h3 align="center">Learning deep abdominal CT registration through adaptive loss weighting and synthetic data generation</h3>
+
+# ⚠️***WARNING: Under construction***
 
 **DDMR** was developed by SINTEF Health Research. The corresponding manuscript describing the framework has been published in [PLOS ONE](https://journals.plos.org/plosone/) and the manuscript is openly available [here](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0282110).
 
@@ -22,8 +36,27 @@ source venv/bin/activate
 
 2. Install requirements:
 ```
-pip install -r requirements.txt
+pip install /path/to/clone/.
+```
+
+## 🤖 How to use
+Use the following CLI command to register images:
+```
+ddmr --fixed path/to/fixed_image.nii.gz --moving path/to/moving_image.nii.gz --outputdir path/to/output/dir -a <anatomy> --model <model> --gpu <gpu-number> --original-resolution
 ```
+where:
+* anatomy: the type of anatomy you want to register: B (brain) or L (liver)
+* model: the model you want to use:
+    + BL-N (baseline with NCC)
+    + BL-NS (baseline with NCC and SSIM)
+    + SG-ND (segmentation guided with NCC and DSC)
+    + SG-NSD (segmentation guided with NCC, SSIM, and DSC)
+    + UW-NSD (uncertainty weighted with NCC, SSIM, and DSC)
+    + UW-NSDH (uncertainty weighted with NCC, SSIM, DSC, and HD).
+* gpu: the GPU number you want the model to run on, if you have multiple GPUs and want to use only one
+* original-resolution: (flag) whether to upsample the registered image to the fixed image resolution (disabled if the flag is not present)
+
+Use ```ddmr --help``` to see additional options, such as using precomputed segmentations to crop the images to the desired ROI, or debugging.
 
 ## 🏋️‍♂️ Training
 
demo/app.py ADDED
@@ -0,0 +1,40 @@
+import os
+from argparse import ArgumentParser
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--cwd",
+        type=str,
+        default="/home/user/app/",
+        help="Set current working directory (path to app.py).",
+    )
+    parser.add_argument(
+        "--share",
+        type=int,
+        default=1,
+        help="Whether to enable the app to be accessible online"
+        "-> setups a public link which requires internet access.",
+    )
+    args = parser.parse_args()
+
+    print("Current working directory:", args.cwd)
+
+    if not os.path.exists(args.cwd):
+        raise ValueError("Chosen 'cwd' is not a valid path!")
+    if args.share not in [0, 1]:
+        raise ValueError(
+            "The 'share' argument can only be set to 0 or 1, but was:",
+            args.share,
+        )
+
+    # initialize and run app
+    print("Launching demo...")
+    from src.gui import WebUI
+    app = WebUI(cwd=args.cwd, share=args.share)
+    app.run()
+
+
+if __name__ == "__main__":
+    main()
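
Outside the container, the same entry point can be launched with, for example, `python3 app.py --cwd /path/to/DDMR/demo/ --share 0`; the path is a placeholder and the values simply restate the argparse options above, with `--share 0` keeping the Gradio link local instead of creating a public one.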
demo/requirements.txt ADDED
@@ -0,0 +1,2 @@
+ddmr @ git+https://github.com/jpdefrutos/DDMR.git
+gradio==3.44.4
demo/src/__init__.py ADDED
File without changes
demo/src/compute.py ADDED
@@ -0,0 +1,3 @@
+def run_model(input_path):
+    from lungtumormask import mask
+    mask.mask(input_path, "./prediction.nii.gz", lung_filter=True, threshold=0.5, radius=1, batch_size=1)
demo/src/gui.py ADDED
@@ -0,0 +1,177 @@
+import os
+
+import gradio as gr
+
+from .inference import run_model
+from .utils import load_ct_to_numpy
+from .utils import load_pred_volume_to_numpy
+from .utils import nifti_to_glb
+
+
+class WebUI:
+    def __init__(
+        self,
+        model_name: str = None,
+        cwd: str = "/home/user/app/",
+        share: int = 1,
+    ):
+        # global states
+        self.images = []
+        self.pred_images = []
+
+        # @TODO: This should be dynamically set based on chosen volume size
+        self.nb_slider_items = 150
+
+        self.model_name = model_name
+        self.cwd = cwd
+        self.share = share
+
+        self.class_name = "meningioma"  # default
+        self.class_names = {
+            "meningioma": "MRI_Meningioma",
+            "low-grade": "MRI_LGGlioma",
+            "metastasis": "MRI_Metastasis",
+            "high-grade": "MRI_GBM",
+            "brain": "MRI_Brain",
+        }
+
+        self.result_names = {
+            "meningioma": "Tumor",
+            "low-grade": "Tumor",
+            "metastasis": "Tumor",
+            "high-grade": "Tumor",
+            "brain": "Brain",
+        }
+
+        # define widgets not to be rendered immediantly, but later on
+        self.slider = gr.Slider(
+            1,
+            self.nb_slider_items,
+            value=1,
+            step=1,
+            label="Which 2D slice to show",
+        )
+        self.volume_renderer = gr.Model3D(
+            clear_color=[0.0, 0.0, 0.0, 0.0],
+            label="3D Model",
+            visible=True,
+            elem_id="model-3d",
+        ).style(height=512)
+
+    def set_class_name(self, value):
+        print("Changed task to:", value)
+        self.class_name = value
+
+    def combine_ct_and_seg(self, img, pred):
+        return (img, [(pred, self.class_name)])
+
+    def upload_file(self, file):
+        return file.name
+
+    def process(self, mesh_file_name):
+        path = mesh_file_name.name
+        run_model(
+            path,
+            model_path=os.path.join(self.cwd, "resources/models/"),
+            task=self.class_names[self.class_name],
+            name=self.result_names[self.class_name],
+        )
+        nifti_to_glb("prediction.nii.gz")
+
+        self.images = load_ct_to_numpy(path)
+        self.pred_images = load_pred_volume_to_numpy("./prediction.nii.gz")
+        return "./prediction.obj"
+
+    def get_img_pred_pair(self, k):
+        k = int(k) - 1
+        out = [gr.AnnotatedImage.update(visible=False)] * self.nb_slider_items
+        out[k] = gr.AnnotatedImage.update(
+            self.combine_ct_and_seg(self.images[k], self.pred_images[k]),
+            visible=True,
+        )
+        return out
+
+    def run(self):
+        css = """
+        #model-3d {
+        height: 512px;
+        }
+        #model-2d {
+        height: 512px;
+        margin: auto;
+        }
+        #upload {
+        height: 120px;
+        }
+        """
+        with gr.Blocks(css=css) as demo:
+            with gr.Row():
+                file_output = gr.File(file_count="single", elem_id="upload")
+                file_output.upload(self.upload_file, file_output, file_output)
+
+                model_selector = gr.Dropdown(
+                    list(self.class_names.keys()),
+                    label="Task",
+                    info="Which task to perform - one model for"
+                    "each brain tumor type and brain extraction",
+                    multiselect=False,
+                    size="sm",
+                )
+                model_selector.input(
+                    fn=lambda x: self.set_class_name(x),
+                    inputs=model_selector,
+                    outputs=None,
+                )
+
+                run_btn = gr.Button("Run analysis").style(
+                    full_width=False, size="lg"
+                )
+                run_btn.click(
+                    fn=lambda x: self.process(x),
+                    inputs=file_output,
+                    outputs=self.volume_renderer,
+                )
+
+            with gr.Row():
+                gr.Examples(
+                    examples=[
+                        os.path.join(self.cwd, "RegLib_C01_1.nii"),
+                        os.path.join(self.cwd, "RegLib_C01_2.nii"),
+                    ],
+                    inputs=file_output,
+                    outputs=file_output,
+                    fn=self.upload_file,
+                    cache_examples=True,
+                )
+
+            with gr.Row():
+                with gr.Box():
+                    with gr.Column():
+                        image_boxes = []
+                        for i in range(self.nb_slider_items):
+                            visibility = True if i == 1 else False
+                            t = gr.AnnotatedImage(
+                                visible=visibility, elem_id="model-2d"
+                            ).style(
+                                color_map={self.class_name: "#ffae00"},
+                                height=512,
+                                width=512,
+                            )
+                            image_boxes.append(t)
+
+                        self.slider.input(
+                            self.get_img_pred_pair, self.slider, image_boxes
+                        )
+
+                    self.slider.render()
+
+                with gr.Box():
+                    self.volume_renderer.render()
+
+        # sharing app publicly -> share=True:
+        # https://gradio.app/sharing-your-app/
+        # inference times > 60 seconds -> need queue():
+        # https://github.com/tloen/alpaca-lora/issues/60#issuecomment-1510006062
+        demo.queue().launch(
+            server_name="0.0.0.0", server_port=7860, share=self.share
+        )
demo/src/utils.py ADDED
@@ -0,0 +1,67 @@
+import nibabel as nib
+import numpy as np
+from nibabel.processing import resample_to_output
+from skimage.measure import marching_cubes
+
+
+def load_ct_to_numpy(data_path):
+    if type(data_path) != str:
+        data_path = data_path.name
+
+    image = nib.load(data_path)
+    resampled = resample_to_output(image, None, order=0)
+    data = resampled.get_fdata()
+
+    data = np.rot90(data, k=1, axes=(0, 1))
+
+    data[data < -150] = -150
+    data[data > 250] = 250
+
+    data = data - np.amin(data)
+    data = data / np.amax(data) * 255
+    data = data.astype("uint8")
+
+    print(data.shape)
+    return [data[..., i] for i in range(data.shape[-1])]
+
+
+def load_pred_volume_to_numpy(data_path):
+    if type(data_path) != str:
+        data_path = data_path.name
+
+    image = nib.load(data_path)
+    resampled = resample_to_output(image, None, order=0)
+    data = resampled.get_fdata()
+
+    data = np.rot90(data, k=1, axes=(0, 1))
+
+    data[data > 0] = 1
+    data = data.astype("uint8")
+
+    print(data.shape)
+    return [data[..., i] for i in range(data.shape[-1])]
+
+
+def nifti_to_glb(path, output="prediction.obj"):
+    # load NIFTI into numpy array
+    image = nib.load(path)
+    resampled = resample_to_output(image, [1, 1, 1], order=1)
+    data = resampled.get_fdata().astype("uint8")
+
+    # extract surface
+    verts, faces, normals, values = marching_cubes(data, 0)
+    faces += 1
+
+    with open(output, "w") as thefile:
+        for item in verts:
+            thefile.write("v {0} {1} {2}\n".format(item[0], item[1], item[2]))
+
+        for item in normals:
+            thefile.write("vn {0} {1} {2}\n".format(item[0], item[1], item[2]))
+
+        for item in faces:
+            thefile.write(
+                "f {0}//{0} {1}//{1} {2}//{2}\n".format(
+                    item[0], item[1], item[2]
+                )
+            )
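
The three helpers above are consumed by `demo/src/gui.py`; a standalone sketch of how they chain together is given below (the import path and file names are assumptions for illustration):

```
from demo.src.utils import load_ct_to_numpy, load_pred_volume_to_numpy, nifti_to_glb

image_path = "ixi_image.nii.gz"        # e.g. the test volume downloaded in the Dockerfile
pred_path = "./prediction.nii.gz"      # file written by run_model in the GUI

slices = load_ct_to_numpy(image_path)             # list of 2D uint8 slices for the slice viewer
masks = load_pred_volume_to_numpy(pred_path)      # binarised masks, one per slice
nifti_to_glb(pred_path, output="prediction.obj")  # writes a Wavefront OBJ mesh for gr.Model3D
```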
setup.py CHANGED
@@ -23,21 +23,18 @@ setup(
     ],
     python_requires='>=3.6',
     install_requires=[
-        'fastrlock>=0.3',  # required by cupy-cuda110
-        'testresources',  # required by launchpadlib
         'scipy',
         'scikit-image',
         'simpleITK',
         'voxelmorph==0.1',
         'pystrum==0.1',
-        'tensorflow-gpu==1.14.0',
+        'tensorflow==2.13',
         'tensorflow-addons',
         'tensorflow-datasets',
         'tensorflow-metadata',
-        'tensorboard==1.14.0',
         'nibabel==3.2.1',
-        'numpy==1.18.5',
-        'h5py==2.10'
+        'numpy',
+        'h5py'
     ],
     entry_points={
         'console_scripts': ['ddmr=DeepDeformationMapRegistration.main:main']
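
With the dependency pins moved from the TensorFlow 1.x stack to `tensorflow==2.13` (and `numpy`/`h5py` unpinned), a quick post-install sanity check might look like the sketch below; it is illustrative only and not part of the commit:

```
import tensorflow as tf
import numpy, h5py

print(tf.__version__)                       # expected 2.13.x given the pin in setup.py
print(numpy.__version__, h5py.__version__)  # unpinned, any compatible release
```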