andreped commited on
Commit
8ece154
·
unverified ·
2 Parent(s): 00ecfe4 d24d994

Merge pull request #46 from andreped/liver-support

Browse files
Files changed (2) hide show
  1. ddmr/main.py +1 -1
  2. demo/src/gui.py +71 -49
ddmr/main.py CHANGED
@@ -191,7 +191,7 @@ def main():
191
  parser.add_argument('--save-displacement-map', action='store_true', help='Save the displacement map. An NPZ file will be created.',
192
  default=False)
193
  args = parser.parse_args()
194
-
195
  assert os.path.exists(args.fixed), 'Fixed image not found'
196
  assert os.path.exists(args.moving), 'Moving image not found'
197
  assert args.model in C.MODEL_TYPES.keys(), 'Invalid model type'
 
191
  parser.add_argument('--save-displacement-map', action='store_true', help='Save the displacement map. An NPZ file will be created.',
192
  default=False)
193
  args = parser.parse_args()
194
+
195
  assert os.path.exists(args.fixed), 'Fixed image not found'
196
  assert os.path.exists(args.moving), 'Moving image not found'
197
  assert args.model in C.MODEL_TYPES.keys(), 'Invalid model type'
demo/src/gui.py CHANGED
@@ -28,6 +28,11 @@ class WebUI:
28
  "Liver": "L"
29
  }
30
 
 
 
 
 
 
31
  # define widgets not to be rendered immediantly, but later on
32
  self.slider = gr.Slider(
33
  1,
@@ -48,27 +53,38 @@ class WebUI:
48
  def upload_file(self, files):
49
  return [f.name for f in files]
50
 
51
- def process(self, mesh_file_names):
52
- if not (len(mesh_file_names) in [2, 4]):
53
- raise ValueError("Unsupported number of elements were provided as input to the DDMR CLI."
54
- "Either provided 2 or 4 elements, where the two first being the fixed"
55
- "and moving CT/MRIs and the two other being the binary segmentation"
56
- "which will be used for ROI filtering in preprocessing.")
57
- fixed_image_path = mesh_file_names[0].name
58
- moving_image_path = mesh_file_names[1].name
 
 
 
 
 
 
 
 
 
 
 
 
59
  output_path = self.cwd
 
 
60
 
61
- if len(mesh_file_names) == 2:
62
- run_model(fixed_image_path, moving_image_path, output_path, self.class_names[self.class_name])
63
- else:
64
- fixed_seg_path = mesh_file_names[2].name
65
- moving_seg_path = mesh_file_names[3].name
66
-
67
- run_model(fixed_image_path, moving_image_path, fixed_seg_path, moving_seg_path, output_path, self.class_names[self.class_name])
68
 
69
- self.fixed_images = load_ct_to_numpy(fixed_image_path)
70
- self.moving_images = load_ct_to_numpy(moving_image_path)
71
  self.pred_images = load_ct_to_numpy(output_path + "pred_image.nii.gz")
 
72
  return None
73
 
74
  def get_fixed_image(self, k):
@@ -100,41 +116,47 @@ class WebUI:
100
 
101
  def run(self):
102
  css = """
103
- #model-2d-fixed {
104
- height: 512px;
105
- margin: auto;
106
- }
107
- #model-2d-moving {
108
- height: 512px;
109
- margin: auto;
110
- }
111
- #model-2d-pred {
112
  height: 512px;
113
  margin: auto;
114
  }
115
  #upload {
116
- height: 120px;
117
  }
118
  """
119
  with gr.Blocks(css=css) as demo:
120
  with gr.Row():
121
- file_output = gr.File(file_count="multiple", elem_id="upload")
122
- file_output.upload(self.upload_file, file_output, file_output)
123
-
124
- model_selector = gr.Dropdown(
125
- list(self.class_names.keys()),
126
- label="Task",
127
- info="Which task to perform image-to-registration on",
128
- multiselect=False,
129
- size="sm",
130
- )
131
- model_selector.input(
132
- fn=lambda x: self.set_class_name(x),
133
- inputs=model_selector,
134
- outputs=None,
135
- )
136
-
137
- self.run_btn.render()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
  """
140
  with gr.Row():
@@ -159,7 +181,7 @@ class WebUI:
159
  for i in range(self.nb_slider_items):
160
  visibility = True if i == 1 else False
161
  t = gr.Image(
162
- visible=visibility, elem_id="model-2d-fixed", label="fixed image", show_label=True,
163
  ).style(
164
  height=512,
165
  width=512,
@@ -170,7 +192,7 @@ class WebUI:
170
  for i in range(self.nb_slider_items):
171
  visibility = True if i == 1 else False
172
  t = gr.Image(
173
- visible=visibility, elem_id="model-2d-moving", label="moving image", show_label=True,
174
  ).style(
175
  height=512,
176
  width=512,
@@ -181,7 +203,7 @@ class WebUI:
181
  for i in range(self.nb_slider_items):
182
  visibility = True if i == 1 else False
183
  t = gr.Image(
184
- visible=visibility, elem_id="model-2d-pred", label="predicted fixed image", show_label=True,
185
  ).style(
186
  height=512,
187
  width=512,
@@ -189,8 +211,8 @@ class WebUI:
189
  pred_images.append(t)
190
 
191
  self.run_btn.click(
192
- fn=lambda x: self.process(x),
193
- inputs=file_output,
194
  outputs=None,
195
  )
196
 
 
28
  "Liver": "L"
29
  }
30
 
31
+ self.fixed_image_path = None
32
+ self.moving_image_path = None
33
+ self.fixed_seg_path = None
34
+ self.moving_seg_path = None
35
+
36
  # define widgets not to be rendered immediantly, but later on
37
  self.slider = gr.Slider(
38
  1,
 
53
  def upload_file(self, files):
54
  return [f.name for f in files]
55
 
56
+ def update_fixed(self, cfile):
57
+ self.fixed_image_path = cfile.name
58
+ return self.fixed_image_path
59
+
60
+ def update_moving(self, cfile):
61
+ self.moving_image_path = cfile.name
62
+ return self.moving_image_path
63
+
64
+ def update_fixed_seg(self, cfile):
65
+ self.fixed_seg_path = cfile.name
66
+ return self.fixed_seg_path
67
+
68
+ def update_moving_seg(self, cfile):
69
+ self.moving_seg_path = cfile.name
70
+ return self.moving_seg_path
71
+
72
+ def process(self):
73
+ if (self.fixed_image_path is None) or (self.moving_image_path is None):
74
+ raise ValueError("Please, select both a fixed and moving image before running inference.")
75
+
76
  output_path = self.cwd
77
+
78
+ run_model(self.fixed_image_path, self.moving_image_path, self.fixed_seg_path, self.moving_seg_path, output_path, self.class_names[self.class_name])
79
 
80
+ # reset - to avoid using these segmentations again for new images
81
+ self.fixed_seg_path = None
82
+ self.moving_seg_path = None
 
 
 
 
83
 
84
+ self.fixed_images = load_ct_to_numpy(self.fixed_image_path)
85
+ self.moving_images = load_ct_to_numpy(self.moving_image_path)
86
  self.pred_images = load_ct_to_numpy(output_path + "pred_image.nii.gz")
87
+
88
  return None
89
 
90
  def get_fixed_image(self, k):
 
116
 
117
  def run(self):
118
  css = """
119
+ #model-2d {
 
 
 
 
 
 
 
 
120
  height: 512px;
121
  margin: auto;
122
  }
123
  #upload {
124
+ height: 80px;
125
  }
126
  """
127
  with gr.Blocks(css=css) as demo:
128
  with gr.Row():
129
+
130
+ with gr.Column():
131
+ file_fixed = gr.File(file_count="single", elem_id="upload", label="Select Fixed Image", show_label=True)
132
+ file_fixed.upload(self.update_fixed, file_fixed, file_fixed)
133
+
134
+ file_moving = gr.File(file_count="single", elem_id="upload", label="Select Moving Image", show_label=True)
135
+ file_moving.upload(self.update_moving, file_moving, file_moving)
136
+
137
+ #with gr.Group():
138
+ with gr.Column():
139
+ file_fixed_seg = gr.File(file_count="single", elem_id="upload", label="Select Fixed Seg Image", show_label=True)
140
+ file_fixed_seg.upload(self.update_fixed_seg, file_fixed_seg, file_fixed_seg)
141
+
142
+ file_moving_seg = gr.File(file_count="single", elem_id="upload", label="Select Moving Seg Image", show_label=True)
143
+ file_moving_seg.upload(self.update_moving_seg, file_moving_seg, file_moving_seg)
144
+
145
+ with gr.Column():
146
+ model_selector = gr.Dropdown(
147
+ list(self.class_names.keys()),
148
+ label="Task",
149
+ info="Which task to perform image-to-registration on",
150
+ multiselect=False,
151
+ size="sm",
152
+ )
153
+ model_selector.input(
154
+ fn=lambda x: self.set_class_name(x),
155
+ inputs=model_selector,
156
+ outputs=None,
157
+ )
158
+
159
+ self.run_btn.render()
160
 
161
  """
162
  with gr.Row():
 
181
  for i in range(self.nb_slider_items):
182
  visibility = True if i == 1 else False
183
  t = gr.Image(
184
+ visible=visibility, elem_id="model-2d", label="fixed image", show_label=True,
185
  ).style(
186
  height=512,
187
  width=512,
 
192
  for i in range(self.nb_slider_items):
193
  visibility = True if i == 1 else False
194
  t = gr.Image(
195
+ visible=visibility, elem_id="model-2d", label="moving image", show_label=True,
196
  ).style(
197
  height=512,
198
  width=512,
 
203
  for i in range(self.nb_slider_items):
204
  visibility = True if i == 1 else False
205
  t = gr.Image(
206
+ visible=visibility, elem_id="model-2d", label="predicted fixed image", show_label=True,
207
  ).style(
208
  height=512,
209
  width=512,
 
211
  pred_images.append(t)
212
 
213
  self.run_btn.click(
214
+ fn=self.process,
215
+ inputs=None,
216
  outputs=None,
217
  )
218