andreped commited on
Commit
49b38ba
·
unverified ·
1 Parent(s): bc39452

Added bug fixes to the demo

Browse files
Files changed (2) hide show
  1. demo/src/compute.py +2 -2
  2. demo/src/gui.py +50 -38
demo/src/compute.py CHANGED
@@ -1,6 +1,6 @@
1
  import subprocess as sp
2
 
3
 
4
- def run_model(fixed_path, moving_path, output_path):
5
  sp.check_call(["ddmr", "--fixed", fixed_path, "--moving", moving_path, \
6
- "-o", output_path, "-a", "B", "--model", "BL-NS", "--original-resolution"])
 
1
  import subprocess as sp
2
 
3
 
4
+ def run_model(fixed_path, moving_path, output_path, task):
5
  sp.check_call(["ddmr", "--fixed", fixed_path, "--moving", moving_path, \
6
+ "-o", output_path, "-a", task, "--model", "BL-NS", "--original-resolution"])
demo/src/gui.py CHANGED
@@ -28,8 +28,8 @@ class WebUI:
28
  self.share = share
29
 
30
  self.class_names = {
31
- "B": "Brain",
32
- "L": "Liver"
33
  }
34
 
35
  # define widgets not to be rendered immediantly, but later on
@@ -40,42 +40,49 @@ class WebUI:
40
  step=1,
41
  label="Which 2D slice to show",
42
  )
43
- self.volume_renderer = gr.Model3D(
44
- clear_color=[0.0, 0.0, 0.0, 0.0],
45
- label="3D Model",
46
- visible=True,
47
- elem_id="model-3d",
48
- ).style(height=512)
49
 
50
  def set_class_name(self, value):
51
  print("Changed task to:", value)
52
  self.class_name = value
53
 
54
- def combine_ct_and_seg(self, img, pred):
55
- return (img, [(pred, self.class_name)])
56
-
57
  def upload_file(self, file):
58
  return file.name
59
 
60
- def process(self, mesh_file_name):
61
- path = mesh_file_name.name
62
- run_model(
63
- path,
64
- model_path=os.path.join(self.cwd, "resources/models/"),
65
- task=self.class_names[self.class_name],
66
- name=self.result_names[self.class_name],
67
- )
68
- nifti_to_glb("prediction.nii.gz")
69
 
70
- self.images = load_ct_to_numpy(path)
71
- self.pred_images = load_pred_volume_to_numpy("./prediction.nii.gz")
72
- return "./prediction.obj"
73
 
74
- def get_img_pred_pair(self, k):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  k = int(k) - 1
76
- out = [gr.AnnotatedImage.update(visible=False)] * self.nb_slider_items
77
- out[k] = gr.AnnotatedImage.update(
78
- self.combine_ct_and_seg(self.images[k], self.pred_images[k]),
 
 
 
 
 
 
 
 
 
79
  visible=True,
80
  )
81
  return out
@@ -95,13 +102,13 @@ class WebUI:
95
  """
96
  with gr.Blocks(css=css) as demo:
97
  with gr.Row():
98
- file_output = gr.File(file_count="single", elem_id="upload")
99
  file_output.upload(self.upload_file, file_output, file_output)
100
 
101
  model_selector = gr.Dropdown(
102
  list(self.class_names.keys()),
103
  label="Task",
104
- info="Which task to perform registration for",
105
  multiselect=False,
106
  size="sm",
107
  )
@@ -117,7 +124,7 @@ class WebUI:
117
  run_btn.click(
118
  fn=lambda x: self.process(x),
119
  inputs=file_output,
120
- outputs=self.volume_renderer,
121
  )
122
 
123
  with gr.Row():
@@ -135,27 +142,32 @@ class WebUI:
135
  with gr.Row():
136
  with gr.Box():
137
  with gr.Column():
138
- image_boxes = []
139
  for i in range(self.nb_slider_items):
140
  visibility = True if i == 1 else False
141
- t = gr.AnnotatedImage(
142
  visible=visibility, elem_id="model-2d"
143
  ).style(
144
- color_map={self.class_name: "#ffae00"},
145
  height=512,
146
  width=512,
147
  )
148
- image_boxes.append(t)
 
 
 
149
 
150
  self.slider.input(
151
- self.get_img_pred_pair, self.slider, image_boxes
 
 
 
 
 
 
152
  )
153
 
154
  self.slider.render()
155
 
156
- with gr.Box():
157
- self.volume_renderer.render()
158
-
159
  # sharing app publicly -> share=True:
160
  # https://gradio.app/sharing-your-app/
161
  # inference times > 60 seconds -> need queue():
 
28
  self.share = share
29
 
30
  self.class_names = {
31
+ "Brain": "B",
32
+ "Liver": "L"
33
  }
34
 
35
  # define widgets not to be rendered immediantly, but later on
 
40
  step=1,
41
  label="Which 2D slice to show",
42
  )
 
 
 
 
 
 
43
 
44
  def set_class_name(self, value):
45
  print("Changed task to:", value)
46
  self.class_name = value
47
 
 
 
 
48
  def upload_file(self, file):
49
  return file.name
50
 
51
+ def process(self, mesh_file_names):
52
+ fixed_image_path = mesh_file_names[0].name
53
+ moving_image_path = mesh_file_names[1].name
 
 
 
 
 
 
54
 
55
+ run_model(fixed_path, moving_path, output_path, self.class_names[self.class_name])
 
 
56
 
57
+ self.fixed_images = load_ct_to_numpy(fixed_image_path)
58
+ self.moving_images = load_ct_to_numpy(moving_image_path)
59
+ #self.pred_images = load_ct_to_numpy("./prediction.nii.gz")
60
+ self.pred_images = np.ones_like(moving_images)
61
+ return None
62
+
63
+ def get_fixed_image(self, k):
64
+ k = int(k) - 1
65
+ out = [gr.Image.update(visible=False)] * self.nb_slider_items
66
+ out[k] = gr.Image.update(
67
+ self.fixed_images[k]
68
+ visible=True,
69
+ )
70
+ return out
71
+
72
+ def get_moving_image(self, k):
73
  k = int(k) - 1
74
+ out = [gr.Image.update(visible=False)] * self.nb_slider_items
75
+ out[k] = gr.Image.update(
76
+ self.moving_images[k]
77
+ visible=True,
78
+ )
79
+ return out
80
+
81
+ def get_pred_image(self, k):
82
+ k = int(k) - 1
83
+ out = [gr.Image.update(visible=False)] * self.nb_slider_items
84
+ out[k] = gr.Image.update(
85
+ self.pred_images[k]
86
  visible=True,
87
  )
88
  return out
 
102
  """
103
  with gr.Blocks(css=css) as demo:
104
  with gr.Row():
105
+ file_output = gr.File(file_count="multiple", elem_id="upload")
106
  file_output.upload(self.upload_file, file_output, file_output)
107
 
108
  model_selector = gr.Dropdown(
109
  list(self.class_names.keys()),
110
  label="Task",
111
+ info="Which task to perform image-to-registration on",
112
  multiselect=False,
113
  size="sm",
114
  )
 
124
  run_btn.click(
125
  fn=lambda x: self.process(x),
126
  inputs=file_output,
127
+ outputs=None,
128
  )
129
 
130
  with gr.Row():
 
142
  with gr.Row():
143
  with gr.Box():
144
  with gr.Column():
145
+ fixed_images = []
146
  for i in range(self.nb_slider_items):
147
  visibility = True if i == 1 else False
148
+ t = gr.Image(
149
  visible=visibility, elem_id="model-2d"
150
  ).style(
 
151
  height=512,
152
  width=512,
153
  )
154
+ fixed_images.append(t)
155
+
156
+ moving_images = fixed_images.copy()
157
+ pred_images = fixed_images.copy()
158
 
159
  self.slider.input(
160
+ self.get_fixed_image, self.slider, fixed_images
161
+ )
162
+ self.slider.input(
163
+ self.get_moving_image, self.slider, moving_images
164
+ )
165
+ self.slider.input(
166
+ self.get_pred_image, self.slider, pred_images
167
  )
168
 
169
  self.slider.render()
170
 
 
 
 
171
  # sharing app publicly -> share=True:
172
  # https://gradio.app/sharing-your-app/
173
  # inference times > 60 seconds -> need queue():