goryhon commited on
Commit
88237d7
·
verified ·
1 Parent(s): 2e83883

Update web-demos/hugging_face/app.py

Browse files
Files changed (1) hide show
  1. web-demos/hugging_face/app.py +4 -2
web-demos/hugging_face/app.py CHANGED
@@ -257,7 +257,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
257
  template_mask[0][0]=1
258
  operation_log = [("Please add at least one mask to track by clicking the image in step2.","Error"), ("","")]
259
  # return video_output, video_state, interactive_state, operation_error
260
- chunk_size = 5 # можно попробовать 3, если всё ещё вылетает
261
  masks = []
262
  logits = []
263
  painted_images = []
@@ -266,15 +266,17 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
266
  end = min(start + chunk_size, len(following_frames))
267
  chunk_frames = following_frames[start:end]
268
 
 
269
  chunk_masks, chunk_logits, chunk_painted = model.generator(
270
  images=chunk_frames,
271
- template_mask=template_mask if start == 0 else chunk_masks[-1]
272
  )
273
 
274
  masks.extend(chunk_masks)
275
  logits.extend(chunk_logits)
276
  painted_images.extend(chunk_painted)
277
 
 
278
  model.cutie.clear_memory()
279
 
280
  if interactive_state["track_end_number"]:
 
257
  template_mask[0][0]=1
258
  operation_log = [("Please add at least one mask to track by clicking the image in step2.","Error"), ("","")]
259
  # return video_output, video_state, interactive_state, operation_error
260
+ chunk_size = 5
261
  masks = []
262
  logits = []
263
  painted_images = []
 
266
  end = min(start + chunk_size, len(following_frames))
267
  chunk_frames = following_frames[start:end]
268
 
269
+ # всегда используем одну и ту же template_mask
270
  chunk_masks, chunk_logits, chunk_painted = model.generator(
271
  images=chunk_frames,
272
+ template_mask=template_mask
273
  )
274
 
275
  masks.extend(chunk_masks)
276
  logits.extend(chunk_logits)
277
  painted_images.extend(chunk_painted)
278
 
279
+
280
  model.cutie.clear_memory()
281
 
282
  if interactive_state["track_end_number"]: