andreped commited on
Commit
545e839
·
1 Parent(s): 80a22dd

Added support for seg-guiding during inference

Browse files
Files changed (2) hide show
  1. demo/src/compute.py +8 -3
  2. demo/src/gui.py +12 -4
demo/src/compute.py CHANGED
@@ -1,6 +1,11 @@
1
  import subprocess as sp
2
 
3
 
4
- def run_model(fixed_path, moving_path, output_path, task):
5
- sp.check_call(["ddmr", "--fixed", fixed_path, "--moving", moving_path, \
6
- "-o", output_path, "-a", task, "--model", "BL-NS", "--original-resolution"])
 
 
 
 
 
 
1
  import subprocess as sp
2
 
3
 
4
+ def run_model(fixed_path=None, moving_path=None, fixed_seg_path=None, moving_seg_path=None, output_path=None, task="B"):
5
+ if (fixed_seg_path is None) or (moving_seg_path is None):
6
+ print("The fixed or moving segmentation were not provided and are thus ignored for inference.")
7
+ sp.check_call(["ddmr", "-f", fixed_path, "-m", moving_path, \
8
+ "-o", output_path, "-a", task, "--model", "BL-NS", "--original-resolution"])
9
+ else:
10
+ sp.check_call(["ddmr", "-f", fixed_path, "-m", moving_path, "-fs", fixed_seg_path, "-ms", moving_seg_path,
11
+ "-o", output_path, "-a", task, "--model", "BL-NS", "--original-resolution"])
demo/src/gui.py CHANGED
@@ -1,7 +1,4 @@
1
- import os
2
-
3
  import gradio as gr
4
- import numpy as np
5
 
6
  from .compute import run_model
7
  from .utils import load_ct_to_numpy
@@ -52,11 +49,22 @@ class WebUI:
52
  return [f.name for f in files]
53
 
54
  def process(self, mesh_file_names):
 
 
 
 
 
55
  fixed_image_path = mesh_file_names[0].name
56
  moving_image_path = mesh_file_names[1].name
57
  output_path = self.cwd
58
 
59
- run_model(fixed_image_path, moving_image_path, output_path, self.class_names[self.class_name])
 
 
 
 
 
 
60
 
61
  self.fixed_images = load_ct_to_numpy(fixed_image_path)
62
  self.moving_images = load_ct_to_numpy(moving_image_path)
 
 
 
1
  import gradio as gr
 
2
 
3
  from .compute import run_model
4
  from .utils import load_ct_to_numpy
 
49
  return [f.name for f in files]
50
 
51
  def process(self, mesh_file_names):
52
+ if not (len(mesh_file_names) in [2, 4]):
53
+ raise ValueError("Unsupported number of elements were provided as input to the DDMR CLI."
54
+ "Either provided 2 or 4 elements, where the two first being the fixed"
55
+ "and moving CT/MRIs and the two other being the binary segmentation"
56
+ "which will be used for ROI filtering in preprocessing.")
57
  fixed_image_path = mesh_file_names[0].name
58
  moving_image_path = mesh_file_names[1].name
59
  output_path = self.cwd
60
 
61
+ if len(mesh_file_names) == 2:
62
+ run_model(fixed_image_path, moving_image_path, output_path, self.class_names[self.class_name])
63
+ else:
64
+ fixed_seg_path = mesh_file_names[2].name
65
+ moving_seg_path = mesh_file_names[3].name
66
+
67
+ run_model(fixed_image_path, moving_image_path, fixed_seg_path, moving_seg_path, output_path, self.class_names[self.class_name])
68
 
69
  self.fixed_images = load_ct_to_numpy(fixed_image_path)
70
  self.moving_images = load_ct_to_numpy(moving_image_path)