goryhon commited on
Commit
3fe708d
·
verified ·
1 Parent(s): 61a603e

Update web-demos/hugging_face/app.py

Browse files
Files changed (1) hide show
  1. web-demos/hugging_face/app.py +27 -4
web-demos/hugging_face/app.py CHANGED
@@ -233,9 +233,33 @@ def show_mask(video_state, interactive_state, mask_dropdown):
233
  # tracking vos
234
  def vos_tracking_video(video_state, interactive_state, mask_dropdown):
235
  operation_log = [("",""), ("Tracking finished! Try to click the Inpainting button to get the inpainting result.","Normal")]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
 
237
  model.cutie.clear_memory()
238
-
239
 
240
  if interactive_state["track_end_number"]:
241
  video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
@@ -251,7 +275,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
251
  orig_h, orig_w = video_state["origin_images"][0].shape[:2]
252
  for mask in video_state["masks"]:
253
  mask_up = cv2.resize(mask.astype(np.uint8), (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
254
- binary_mask = np.where(mask_up > 0, 255, 0).astype(np.uint8)
255
  bw_frame = np.stack([binary_mask]*3, axis=-1) # RGB ч/б
256
  bw_mask_frames.append(bw_frame)
257
 
@@ -278,8 +302,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
278
  # save_mask(video_state["masks"], video_state["video_name"])
279
  #### shanggao code for mask save
280
  return video_output, video_state, interactive_state, operation_log, operation_log
281
-
282
- # inpaint
283
  def inpaint_video(video_state, *_args):
284
  operation_log = [("",""), ("Inpainting started in smooth-overlap mode.","Normal")]
285
 
 
233
  # tracking vos
234
  def vos_tracking_video(video_state, interactive_state, mask_dropdown):
235
  operation_log = [("",""), ("Tracking finished! Try to click the Inpainting button to get the inpainting result.","Normal")]
236
+ model.cutie.clear_memory()
237
+ if interactive_state["track_end_number"]:
238
+ following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
239
+ else:
240
+ following_frames = video_state["origin_images"][video_state["select_frame_number"]:]
241
+
242
+ if interactive_state["multi_mask"]["masks"]:
243
+ if len(mask_dropdown) == 0:
244
+ mask_dropdown = ["mask_001"]
245
+ mask_dropdown.sort()
246
+ template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1]))
247
+ for i in range(1,len(mask_dropdown)):
248
+ mask_number = int(mask_dropdown[i].split("_")[1]) - 1
249
+ template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1)
250
+ video_state["masks"][video_state["select_frame_number"]]= template_mask
251
+ else:
252
+ template_mask = video_state["masks"][video_state["select_frame_number"]]
253
+ fps = video_state["fps"]
254
+
255
+ # operation error
256
+ if len(np.unique(template_mask))==1:
257
+ template_mask[0][0]=1
258
+ operation_log = [("Please add at least one mask to track by clicking the image in step2.","Error"), ("","")]
259
+ # return video_output, video_state, interactive_state, operation_error
260
  masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
261
+ # clear GPU memory
262
  model.cutie.clear_memory()
 
263
 
264
  if interactive_state["track_end_number"]:
265
  video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
 
275
  orig_h, orig_w = video_state["origin_images"][0].shape[:2]
276
  for mask in video_state["masks"]:
277
  mask_up = cv2.resize(mask.astype(np.uint8), (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
278
+ binary_mask = np.where(mask > 0, 255, 0).astype(np.uint8)
279
  bw_frame = np.stack([binary_mask]*3, axis=-1) # RGB ч/б
280
  bw_mask_frames.append(bw_frame)
281
 
 
302
  # save_mask(video_state["masks"], video_state["video_name"])
303
  #### shanggao code for mask save
304
  return video_output, video_state, interactive_state, operation_log, operation_log
305
+
 
306
  def inpaint_video(video_state, *_args):
307
  operation_log = [("",""), ("Inpainting started in smooth-overlap mode.","Normal")]
308