andreped commited on
Commit
7914847
·
1 Parent(s): 545e839

One File importer for each input in demo

Browse files
Files changed (2) hide show
  1. ddmr/main.py +1 -1
  2. demo/src/gui.py +70 -49
ddmr/main.py CHANGED
@@ -191,7 +191,7 @@ def main():
191
  parser.add_argument('--save-displacement-map', action='store_true', help='Save the displacement map. An NPZ file will be created.',
192
  default=False)
193
  args = parser.parse_args()
194
-
195
  assert os.path.exists(args.fixed), 'Fixed image not found'
196
  assert os.path.exists(args.moving), 'Moving image not found'
197
  assert args.model in C.MODEL_TYPES.keys(), 'Invalid model type'
 
191
  parser.add_argument('--save-displacement-map', action='store_true', help='Save the displacement map. An NPZ file will be created.',
192
  default=False)
193
  args = parser.parse_args()
194
+
195
  assert os.path.exists(args.fixed), 'Fixed image not found'
196
  assert os.path.exists(args.moving), 'Moving image not found'
197
  assert args.model in C.MODEL_TYPES.keys(), 'Invalid model type'
demo/src/gui.py CHANGED
@@ -28,6 +28,11 @@ class WebUI:
28
  "Liver": "L"
29
  }
30
 
 
 
 
 
 
31
  # define widgets not to be rendered immediantly, but later on
32
  self.slider = gr.Slider(
33
  1,
@@ -48,26 +53,36 @@ class WebUI:
48
  def upload_file(self, files):
49
  return [f.name for f in files]
50
 
51
- def process(self, mesh_file_names):
52
- if not (len(mesh_file_names) in [2, 4]):
53
- raise ValueError("Unsupported number of elements were provided as input to the DDMR CLI."
54
- "Either provided 2 or 4 elements, where the two first being the fixed"
55
- "and moving CT/MRIs and the two other being the binary segmentation"
56
- "which will be used for ROI filtering in preprocessing.")
57
- fixed_image_path = mesh_file_names[0].name
58
- moving_image_path = mesh_file_names[1].name
 
 
 
 
 
 
 
 
 
 
 
 
59
  output_path = self.cwd
 
 
60
 
61
- if len(mesh_file_names) == 2:
62
- run_model(fixed_image_path, moving_image_path, output_path, self.class_names[self.class_name])
63
- else:
64
- fixed_seg_path = mesh_file_names[2].name
65
- moving_seg_path = mesh_file_names[3].name
66
-
67
- run_model(fixed_image_path, moving_image_path, fixed_seg_path, moving_seg_path, output_path, self.class_names[self.class_name])
68
 
69
- self.fixed_images = load_ct_to_numpy(fixed_image_path)
70
- self.moving_images = load_ct_to_numpy(moving_image_path)
71
  self.pred_images = load_ct_to_numpy(output_path + "pred_image.nii.gz")
72
  return None
73
 
@@ -100,41 +115,47 @@ class WebUI:
100
 
101
  def run(self):
102
  css = """
103
- #model-2d-fixed {
104
- height: 512px;
105
- margin: auto;
106
- }
107
- #model-2d-moving {
108
- height: 512px;
109
- margin: auto;
110
- }
111
- #model-2d-pred {
112
  height: 512px;
113
  margin: auto;
114
  }
115
  #upload {
116
- height: 120px;
117
  }
118
  """
119
  with gr.Blocks(css=css) as demo:
120
  with gr.Row():
121
- file_output = gr.File(file_count="multiple", elem_id="upload")
122
- file_output.upload(self.upload_file, file_output, file_output)
123
-
124
- model_selector = gr.Dropdown(
125
- list(self.class_names.keys()),
126
- label="Task",
127
- info="Which task to perform image-to-registration on",
128
- multiselect=False,
129
- size="sm",
130
- )
131
- model_selector.input(
132
- fn=lambda x: self.set_class_name(x),
133
- inputs=model_selector,
134
- outputs=None,
135
- )
136
-
137
- self.run_btn.render()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
  """
140
  with gr.Row():
@@ -159,7 +180,7 @@ class WebUI:
159
  for i in range(self.nb_slider_items):
160
  visibility = True if i == 1 else False
161
  t = gr.Image(
162
- visible=visibility, elem_id="model-2d-fixed", label="fixed image", show_label=True,
163
  ).style(
164
  height=512,
165
  width=512,
@@ -170,7 +191,7 @@ class WebUI:
170
  for i in range(self.nb_slider_items):
171
  visibility = True if i == 1 else False
172
  t = gr.Image(
173
- visible=visibility, elem_id="model-2d-moving", label="moving image", show_label=True,
174
  ).style(
175
  height=512,
176
  width=512,
@@ -181,7 +202,7 @@ class WebUI:
181
  for i in range(self.nb_slider_items):
182
  visibility = True if i == 1 else False
183
  t = gr.Image(
184
- visible=visibility, elem_id="model-2d-pred", label="predicted fixed image", show_label=True,
185
  ).style(
186
  height=512,
187
  width=512,
@@ -189,8 +210,8 @@ class WebUI:
189
  pred_images.append(t)
190
 
191
  self.run_btn.click(
192
- fn=lambda x: self.process(x),
193
- inputs=file_output,
194
  outputs=None,
195
  )
196
 
 
28
  "Liver": "L"
29
  }
30
 
31
+ self.fixed_image_path = None
32
+ self.moving_image_path = None
33
+ self.fixed_seg_path = None
34
+ self.moving_seg_path = None
35
+
36
  # define widgets not to be rendered immediantly, but later on
37
  self.slider = gr.Slider(
38
  1,
 
53
  def upload_file(self, files):
54
  return [f.name for f in files]
55
 
56
+ def update_fixed(self, cfile):
57
+ self.fixed_image_path = cfile.name
58
+ return self.fixed_image_path
59
+
60
+ def update_moving(self, cfile):
61
+ self.moving_image_path = cfile.name
62
+ return self.moving_image_path
63
+
64
+ def update_fixed_seg(self, cfile):
65
+ self.fixed_seg_path = cfile.name
66
+ return self.fixed_seg_path
67
+
68
+ def update_moving_seg(self, cfile):
69
+ self.moving_seg_path = cfile.name
70
+ return self.moving_seg_path
71
+
72
+ def process(self):
73
+ if (self.fixed_image_path is None) or (self.moving_image_path is None):
74
+ raise ValueError("Please, select both a fixed and moving image before running inference.")
75
+
76
  output_path = self.cwd
77
+
78
+ run_model(self.fixed_image_path, self.moving_image_path, self.fixed_seg_path, self.moving_seg_path, output_path, self.class_names[self.class_name])
79
 
80
+ # reset - to avoid using these segmentations again for new images
81
+ self.fixed_seg_path = None
82
+ self.moving_seg_path = None
 
 
 
 
83
 
84
+ self.fixed_images = load_ct_to_numpy(self.fixed_image_path)
85
+ self.moving_images = load_ct_to_numpy(self.moving_image_path)
86
  self.pred_images = load_ct_to_numpy(output_path + "pred_image.nii.gz")
87
  return None
88
 
 
115
 
116
  def run(self):
117
  css = """
118
+ #model-2d {
 
 
 
 
 
 
 
 
119
  height: 512px;
120
  margin: auto;
121
  }
122
  #upload {
123
+ height: 80px;
124
  }
125
  """
126
  with gr.Blocks(css=css) as demo:
127
  with gr.Row():
128
+
129
+ with gr.Column():
130
+ file_fixed = gr.File(file_count="single", elem_id="upload", label="Select Fixed Image", show_label=True)
131
+ file_fixed.upload(self.update_fixed, file_fixed, file_fixed)
132
+
133
+ file_moving = gr.File(file_count="single", elem_id="upload", label="Select Moving Image", show_label=True)
134
+ file_moving.upload(self.update_moving, file_moving, file_moving)
135
+
136
+ #with gr.Group():
137
+ with gr.Column():
138
+ file_fixed_seg = gr.File(file_count="single", elem_id="upload", label="Select Fixed Seg Image", show_label=True)
139
+ file_fixed_seg.upload(self.update_fixed_seg, file_fixed_seg, file_fixed_seg)
140
+
141
+ file_moving_seg = gr.File(file_count="single", elem_id="upload", label="Select Moving Seg Image", show_label=True)
142
+ file_moving_seg.upload(self.update_moving_seg, file_moving_seg, file_moving_seg)
143
+
144
+ with gr.Column():
145
+ model_selector = gr.Dropdown(
146
+ list(self.class_names.keys()),
147
+ label="Task",
148
+ info="Which task to perform image-to-registration on",
149
+ multiselect=False,
150
+ size="sm",
151
+ )
152
+ model_selector.input(
153
+ fn=lambda x: self.set_class_name(x),
154
+ inputs=model_selector,
155
+ outputs=None,
156
+ )
157
+
158
+ self.run_btn.render()
159
 
160
  """
161
  with gr.Row():
 
180
  for i in range(self.nb_slider_items):
181
  visibility = True if i == 1 else False
182
  t = gr.Image(
183
+ visible=visibility, elem_id="model-2d", label="fixed image", show_label=True,
184
  ).style(
185
  height=512,
186
  width=512,
 
191
  for i in range(self.nb_slider_items):
192
  visibility = True if i == 1 else False
193
  t = gr.Image(
194
+ visible=visibility, elem_id="model-2d", label="moving image", show_label=True,
195
  ).style(
196
  height=512,
197
  width=512,
 
202
  for i in range(self.nb_slider_items):
203
  visibility = True if i == 1 else False
204
  t = gr.Image(
205
+ visible=visibility, elem_id="model-2d", label="predicted fixed image", show_label=True,
206
  ).style(
207
  height=512,
208
  width=512,
 
210
  pred_images.append(t)
211
 
212
  self.run_btn.click(
213
+ fn=self.process,
214
+ inputs=None,
215
  outputs=None,
216
  )
217