
bla committed on
Commit 6e871ac · verified · 1 Parent(s): b950bc5

Update app.py

Files changed (1)
  1. app.py +137 -574
app.py CHANGED
@@ -1,615 +1,178 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
-
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- import copy
  import os
- from datetime import datetime
  import tempfile

  import cv2
- import matplotlib.pyplot as plt
  import numpy as np
- import gradio as gr
  import torch
-
  from moviepy.editor import ImageSequenceClip
- from PIL import Image
- from sam2.build_sam import build_sam2_video_predictor
-
- # Remove CUDA environment variables
- if 'TORCH_CUDNN_SDPA_ENABLED' in os.environ:
-     del os.environ["TORCH_CUDNN_SDPA_ENABLED"]
-
- # Description
- title = "<center><strong><font size='8'>EdgeTAM CPU<font></strong> <a href='https://github.com/facebookresearch/EdgeTAM'><font size='6'>[GitHub]</font></a> </center>"

- description_p = """# Instructions
- <ol>
- <li> Upload one video or click one example video</li>
- <li> Click 'include' point type, select the object to segment and track</li>
- <li> Click 'exclude' point type (optional), select the area you want to avoid segmenting and tracking</li>
- <li> Click the 'Track' button to obtain the masked video </li>
- </ol>
- """
-
- # examples - keeping fewer examples to reduce memory footprint
- examples = [
-     ["examples/01_dog.mp4"],
-     ["examples/02_cups.mp4"],
-     ["examples/03_blocks.mp4"],
-     ["examples/04_coffee.mp4"],
-     ["examples/05_default_juggle.mp4"],
- ]

- OBJ_ID = 0

- # Initialize model on CPU - add error handling for file paths
  sam2_checkpoint = "checkpoints/edgetam.pt"
  model_cfg = "edgetam.yaml"

- # Check if model files exist
- def check_file_exists(filepath):
-     import os
-     exists = os.path.exists(filepath)
-     if not exists:
-         print(f"WARNING: File not found: {filepath}")
-     return exists
-
- # Verify files exist
- model_files_exist = check_file_exists(sam2_checkpoint) and check_file_exists(model_cfg)
- try:
-     # Load model with more careful error handling
-     predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
-     print("predictor loaded on CPU")
- except Exception as e:
-     print(f"Error loading model: {e}")
-     import traceback
-     traceback.print_exc()
-     # Still create a predictor variable to avoid NameError
      predictor = None

- # Function to get video frame rate
- def get_video_fps(video_path):
      cap = cv2.VideoCapture(video_path)
-     if not cap.isOpened():
-         print("Error: Could not open video.")
-         return 30.0  # Default fallback value
      fps = cap.get(cv2.CAP_PROP_FPS)
      cap.release()
      return fps

- def reset(session_state):
-     session_state["input_points"] = []
-     session_state["input_labels"] = []
-     if session_state["inference_state"] is not None:
-         predictor.reset_state(session_state["inference_state"])
-     session_state["first_frame"] = None
-     session_state["all_frames"] = None
-     session_state["inference_state"] = None
-     return (
-         None,
-         gr.update(open=True),
-         None,
-         None,
-         gr.update(value=None, visible=False),
-         session_state,
-     )
-
-
- def clear_points(session_state):
-     session_state["input_points"] = []
-     session_state["input_labels"] = []
-     if session_state["inference_state"] is not None and session_state["inference_state"].get("tracking_has_started", False):
-         predictor.reset_state(session_state["inference_state"])
-     return (
-         session_state["first_frame"],
-         None,
-         gr.update(value=None, visible=False),
-         session_state,
-     )
-
- def preprocess_video_in(video_path, session_state):
-     if video_path is None:
-         return (
-             gr.update(open=True),  # video_in_drawer
-             None,  # points_map
-             None,  # output_image
-             gr.update(value=None, visible=False),  # output_video
-             session_state,
-         )
-
-     # Read the first frame
      cap = cv2.VideoCapture(video_path)
-     if not cap.isOpened():
-         print("Error: Could not open video.")
-         return (
-             gr.update(open=True),  # video_in_drawer
-             None,  # points_map
-             None,  # output_image
-             gr.update(value=None, visible=False),  # output_video
-             session_state,
-         )
-
-     # For CPU optimization - determine video properties
-     frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-     frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-
-     # Determine if we need to resize for CPU performance
-     target_width = 640  # Target width for processing on CPU
-     scale_factor = 1.0
-
-     if frame_width > target_width:
-         scale_factor = target_width / frame_width
-         frame_width = target_width
-         frame_height = int(frame_height * scale_factor)

-     # Read frames - for CPU we'll be more selective about which frames to keep
-     frame_number = 0
-     first_frame = None
-     all_frames = []
-
-     # For CPU optimization, skip frames if video is too long
-     frame_stride = 1
-     if total_frames > 300:  # If more than 300 frames
-         frame_stride = max(1, int(total_frames / 300))  # Process at most ~300 frames

      while True:
          ret, frame = cap.read()
-         if not ret:
-             break
-
-         if frame_number % frame_stride == 0:  # Process every frame_stride frames
-             # Resize the frame if needed
-             if scale_factor != 1.0:
-                 frame = cv2.resize(frame, (frame_width, frame_height), interpolation=cv2.INTER_AREA)
-
              frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-             frame = np.array(frame)
-
-             # Store the first frame
-             if first_frame is None:
-                 first_frame = frame
-             all_frames.append(frame)
-
-         frame_number += 1
-
      cap.release()
-     session_state["first_frame"] = copy.deepcopy(first_frame)
-     session_state["all_frames"] = all_frames
-     session_state["frame_stride"] = frame_stride
-     session_state["scale_factor"] = scale_factor
-     session_state["original_dimensions"] = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
-                                             int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
-
-     session_state["inference_state"] = predictor.init_state(video_path=video_path)
-     session_state["input_points"] = []
-     session_state["input_labels"] = []
-
-     return [
-         gr.update(open=False),  # video_in_drawer
-         first_frame,  # points_map
-         None,  # output_image
-         gr.update(value=None, visible=False),  # output_video
-         session_state,
-     ]
-
-
- def segment_with_points(
-     point_type,
-     session_state,
-     evt: gr.SelectData,
- ):
-     session_state["input_points"].append(evt.index)
-     print(f"TRACKING INPUT POINT: {session_state['input_points']}")
-
-     if point_type == "include":
-         session_state["input_labels"].append(1)
-     elif point_type == "exclude":
-         session_state["input_labels"].append(0)
-     print(f"TRACKING INPUT LABEL: {session_state['input_labels']}")

-     # Open the image and get its dimensions
-     first_frame = session_state["first_frame"]
-     h, w = first_frame.shape[:2]
-     transparent_background = Image.fromarray(first_frame).convert("RGBA")
-
-     # Define the circle radius as a fraction of the smaller dimension
-     fraction = 0.01  # You can adjust this value as needed
-     radius = int(fraction * min(w, h))
-
-     # Create a transparent layer to draw on
-     transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
-
-     for index, track in enumerate(session_state["input_points"]):
-         if session_state["input_labels"][index] == 1:
-             cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
-         else:
-             cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
-
-     # Convert the transparent layer back to an image
-     transparent_layer = Image.fromarray(transparent_layer, "RGBA")
-     selected_point_map = Image.alpha_composite(
-         transparent_background, transparent_layer
-     )
-
-     # Let's add a positive click at (x, y) = (210, 350) to get started
-     points = np.array(session_state["input_points"], dtype=np.float32)
-     # for labels, `1` means positive click and `0` means negative click
-     labels = np.array(session_state["input_labels"], np.int32)

-     try:
-         # For CPU optimization, we'll process with smaller batch size
-         _, _, out_mask_logits = predictor.add_new_points(
-             inference_state=session_state["inference_state"],
-             frame_idx=0,
-             obj_id=OBJ_ID,
-             points=points,
-             labels=labels,
-         )
-
-         # Create the mask
-         mask_array = (out_mask_logits[0] > 0.0).cpu().numpy()
-
-         # Ensure the mask has the same size as the frame
-         if mask_array.shape[:2] != (h, w):
-             mask_array = cv2.resize(
-                 mask_array.astype(np.uint8),
-                 (w, h),
-                 interpolation=cv2.INTER_NEAREST
-             ).astype(bool)
-
-         mask_image = show_mask(mask_array)
-
-         # Make sure mask_image has the same size as the background
-         if mask_image.size != transparent_background.size:
-             mask_image = mask_image.resize(transparent_background.size, Image.NEAREST)
-
-         first_frame_output = Image.alpha_composite(transparent_background, mask_image)
-     except Exception as e:
-         print(f"Error in segmentation: {e}")
-         # Return just the points as fallback
-         first_frame_output = selected_point_map
-
-     return selected_point_map, first_frame_output, session_state
-
-
- def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
-     if random_color:
-         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
-     else:
-         cmap = plt.get_cmap("tab10")
-         cmap_idx = 0 if obj_id is None else obj_id
-         color = np.array([*cmap(cmap_idx)[:3], 0.6])

-     # Handle different mask shapes properly
-     if len(mask.shape) == 2:
-         h, w = mask.shape
-     else:
-         h, w = mask.shape[-2:]
-
-     # Ensure correct reshaping based on mask dimensions
-     mask_reshaped = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
-     mask_rgba = (mask_reshaped * 255).astype(np.uint8)
-
-     if convert_to_image:
-         try:
-             # Ensure the mask has correct RGBA shape (h, w, 4)
-             if mask_rgba.shape[2] != 4:
-                 # If not RGBA, create a proper RGBA array
-                 proper_mask = np.zeros((h, w, 4), dtype=np.uint8)
-                 # Copy available channels
-                 proper_mask[:, :, :min(mask_rgba.shape[2], 4)] = mask_rgba[:, :, :min(mask_rgba.shape[2], 4)]
-                 mask_rgba = proper_mask
-
-             # Create the PIL image
-             return Image.fromarray(mask_rgba, "RGBA")
-         except Exception as e:
-             print(f"Error converting mask to image: {e}")
-             # Fallback: create a blank transparent image of correct size
-             blank = np.zeros((h, w, 4), dtype=np.uint8)
-             return Image.fromarray(blank, "RGBA")
-
-     return mask_rgba
-

- def propagate_to_all(
-     video_in,
-     session_state,
- ):
-     if (
-         len(session_state["input_points"]) == 0
-         or video_in is None
-         or session_state["inference_state"] is None
-     ):
-         return (
-             None,
-             session_state,
-         )
-
-     # For CPU optimization: process in smaller batches
-     chunk_size = 3  # Process 3 frames at a time to avoid memory issues on CPU
-
      try:
-         # run propagation throughout the video and collect the results in a dict
-         video_segments = {}  # video_segments contains the per-frame segmentation results
-         print("starting propagate_in_video on CPU")
-
-         # Get the frames in chunks for CPU memory optimization
-         for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
-             session_state["inference_state"]
-         ):
-             try:
-                 # Store the masks for each object ID
-                 video_segments[out_frame_idx] = {
-                     out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
-                     for i, out_obj_id in enumerate(out_obj_ids)
-                 }
-
-                 print(f"Processed frame {out_frame_idx}")
-
-                 # Release memory periodically
-                 if out_frame_idx % chunk_size == 0:
-                     # Explicitly clear any tensors
-                     del out_mask_logits
-                     import gc
-                     gc.collect()
-             except Exception as e:
-                 print(f"Error processing frame {out_frame_idx}: {e}")
-                 continue
-
-         # For CPU optimization: increase stride to reduce processing
-         # Create a more aggressive stride to limit to fewer frames in output
-         total_frames = len(video_segments)
-         print(f"Total frames processed: {total_frames}")
-
-         # Limit to max 50 frames for CPU processing
-         max_output_frames = 50
-         vis_frame_stride = max(1, total_frames // max_output_frames)
-
-         # Get dimensions of the frames
-         first_frame = session_state["all_frames"][0]
-         h, w = first_frame.shape[:2]
-
-         output_frames = []
-         for out_frame_idx in range(0, total_frames, vis_frame_stride):
-             if out_frame_idx not in video_segments or OBJ_ID not in video_segments[out_frame_idx]:
-                 continue
-
-             try:
-                 frame = session_state["all_frames"][out_frame_idx]
-                 transparent_background = Image.fromarray(frame).convert("RGBA")
-
-                 # Get the mask and ensure it's the right size
-                 out_mask = video_segments[out_frame_idx][OBJ_ID]
-
-                 # Resize mask if dimensions don't match
-                 if out_mask.shape[:2] != (h, w):
-                     out_mask = cv2.resize(
-                         out_mask.astype(np.uint8),
-                         (w, h),
-                         interpolation=cv2.INTER_NEAREST
-                     ).astype(bool)
-
-                 mask_image = show_mask(out_mask)
-
-                 # Make sure mask has same dimensions as background
-                 if mask_image.size != transparent_background.size:
-                     mask_image = mask_image.resize(transparent_background.size, Image.NEAREST)
-
-                 output_frame = Image.alpha_composite(transparent_background, mask_image)
-                 output_frame = np.array(output_frame)
-                 output_frames.append(output_frame)
-
-                 # Clear memory periodically
-                 if len(output_frames) % 10 == 0:
-                     import gc
-                     gc.collect()
-
-             except Exception as e:
-                 print(f"Error creating output frame {out_frame_idx}: {e}")
-                 continue
-
-         # Create a video clip from the image sequence
-         original_fps = get_video_fps(video_in)
-         fps = original_fps
-
-         # For CPU optimization - lower FPS if original is high
-         if fps > 15:
-             fps = 15  # Lower fps for CPU processing
-
-         print(f"Creating video with {len(output_frames)} frames at {fps} FPS")
-         clip = ImageSequenceClip(output_frames, fps=fps)
-
-         # Write the result to a file - use lower quality for CPU
-         unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
-         final_vid_output_path = f"output_video_{unique_id}.mp4"
-         final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_output_path)
-
-         # Lower bitrate for CPU processing
-         clip.write_videofile(
-             final_vid_output_path,
-             codec="libx264",
-             bitrate="800k",
-             threads=2,  # Use fewer threads for CPU
-             logger=None  # Disable logger to reduce console output
-         )
-
-         # Free memory
-         del video_segments
-         del output_frames
-         import gc
-         gc.collect()
-
-         return (
-             gr.update(value=final_vid_output_path, visible=True),
-             session_state,
-         )
-
      except Exception as e:
-         print(f"Error in propagate_to_all: {e}")
-         return (
-             gr.update(value=None, visible=False),
-             session_state,
-         )

- def update_ui():
-     return gr.update(visible=True)

  with gr.Blocks() as demo:
-     session_state = gr.State(
-         {
-             "first_frame": None,
-             "all_frames": None,
-             "input_points": [],
-             "input_labels": [],
-             "inference_state": None,
-             "frame_stride": 1,
-             "scale_factor": 1.0,
-             "original_dimensions": None,
-         }
-     )
-
-     with gr.Column():
-         # Title
-         gr.Markdown(title)
-         with gr.Row():
-
-             with gr.Column():
-                 # Instructions
-                 gr.Markdown(description_p)
-
-                 with gr.Accordion("Input Video", open=True) as video_in_drawer:
-                     video_in = gr.Video(label="Input Video", format="mp4")
-
-                 with gr.Row():
-                     point_type = gr.Radio(
-                         label="point type",
-                         choices=["include", "exclude"],
-                         value="include",
-                         scale=2,
-                     )
-                     propagate_btn = gr.Button("Track", scale=1, variant="primary")
-                     clear_points_btn = gr.Button("Clear Points", scale=1)
-                     reset_btn = gr.Button("Reset", scale=1)
-
-                 points_map = gr.Image(
-                     label="Frame with Point Prompt", type="numpy", interactive=False
-                 )
-
-             with gr.Column():
-                 gr.Markdown("# Try some of the examples below ⬇️")
-                 gr.Examples(
-                     examples=examples,
-                     inputs=[
-                         video_in,
-                     ],
-                     examples_per_page=5,
-                 )
-
-                 output_image = gr.Image(label="Reference Mask")
-                 output_video = gr.Video(visible=False)
-
-     # When new video is uploaded
-     video_in.upload(
-         fn=preprocess_video_in,
-         inputs=[
-             video_in,
-             session_state,
-         ],
-         outputs=[
-             video_in_drawer,  # Accordion to hide uploaded video player
-             points_map,  # Image component where we add new tracking points
-             output_image,
-             output_video,
-             session_state,
-         ],
-         queue=False,
-     )
-
-     video_in.change(
-         fn=preprocess_video_in,
-         inputs=[
-             video_in,
-             session_state,
-         ],
-         outputs=[
-             video_in_drawer,  # Accordion to hide uploaded video player
-             points_map,  # Image component where we add new tracking points
-             output_image,
-             output_video,
-             session_state,
-         ],
-         queue=False,
-     )
-
-     # triggered when we click on image to add new points
-     points_map.select(
-         fn=segment_with_points,
-         inputs=[
-             point_type,  # "include" or "exclude"
-             session_state,
-         ],
-         outputs=[
-             points_map,  # updated image with points
-             output_image,
-             session_state,
-         ],
-         queue=False,
-     )
-
-     # Clear every points clicked and added to the map
-     clear_points_btn.click(
-         fn=clear_points,
-         inputs=session_state,
-         outputs=[
-             points_map,
-             output_image,
-             output_video,
-             session_state,
-         ],
-         queue=False,
-     )
-
-     reset_btn.click(
-         fn=reset,
-         inputs=session_state,
-         outputs=[
-             video_in,
-             video_in_drawer,
-             points_map,
-             output_image,
-             output_video,
-             session_state,
-         ],
-         queue=False,
-     )
-
-     propagate_btn.click(
-         fn=update_ui,
-         inputs=[],
-         outputs=output_video,
-         queue=False,
-     ).then(
-         fn=propagate_to_all,
-         inputs=[
-             video_in,
-             session_state,
-         ],
-         outputs=[
-             output_video,
-             session_state,
-         ],
-         queue=True,  # Use queue for CPU processing
-     )
-
-
- demo.queue()
- demo.launch()
+ # The full rewritten version of the provided code with progress bar, error fixes, and proper Gradio integration
+
  import os
+ import copy
  import tempfile
+ from datetime import datetime
+ import gc

  import cv2
  import numpy as np
+ from PIL import Image
+ import matplotlib.pyplot as plt
  import torch
+ import gradio as gr
  from moviepy.editor import ImageSequenceClip

+ from sam2.build_sam import build_sam2_video_predictor
+
+ # Remove CUDA-related env var to force CPU-only mode
+ os.environ.pop("TORCH_CUDNN_SDPA_ENABLED", None)
+
+ # Config
  sam2_checkpoint = "checkpoints/edgetam.pt"
  model_cfg = "edgetam.yaml"
+ examples = [[f"examples/{vid}"] for vid in ["01_dog.mp4", "02_cups.mp4", "03_blocks.mp4", "04_coffee.mp4", "05_default_juggle.mp4"]]
+ OBJ_ID = 0
+
+ # Model loader
+ if os.path.exists(sam2_checkpoint) and os.path.exists(model_cfg):
+     try:
+         predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
+     except Exception as e:
+         print("Error loading predictor:", e)
+         predictor = None
+ else:
+     print("Model files missing.")
      predictor = None

+ def get_fps(video_path):
      cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened(): return 30.0
      fps = cap.get(cv2.CAP_PROP_FPS)
      cap.release()
      return fps

+ def reset(session):
+     if session["inference_state"]:
+         predictor.reset_state(session["inference_state"])
+     session.update({"input_points": [], "input_labels": [], "first_frame": None, "all_frames": None, "inference_state": None})
+     return None, gr.update(open=True), None, None, gr.update(value=None, visible=False), session
+
+ def clear_points(session):
+     session["input_points"] = []
+     session["input_labels"] = []
+     if session["inference_state"] and session["inference_state"].get("tracking_has_started"):
+         predictor.reset_state(session["inference_state"])
+     return session["first_frame"], None, gr.update(value=None, visible=False), session
+
+ def preprocess_video(video_path, session):
      cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened(): return gr.update(open=True), None, None, gr.update(value=None, visible=False), session
+
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     stride = max(1, total_frames // 300)
+     frames, first_frame = [], None
+
+     w, h = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+     target_w = 640
+     scale = target_w / w if w > target_w else 1.0
+
+     frame_id = 0
      while True:
          ret, frame = cap.read()
+         if not ret: break
+         if frame_id % stride == 0:
+             if scale < 1.0:
+                 frame = cv2.resize(frame, (int(w*scale), int(h*scale)))
              frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             if first_frame is None: first_frame = frame
+             frames.append(frame)
+         frame_id += 1
      cap.release()

+     session.update({"first_frame": first_frame, "all_frames": frames, "frame_stride": stride, "scale_factor": scale, "inference_state": predictor.init_state(video_path=video_path), "input_points": [], "input_labels": []})
+     return gr.update(open=False), first_frame, None, gr.update(value=None, visible=False), session
+
+ def show_mask(mask, obj_id=None):
+     cmap = plt.get_cmap("tab10")
+     color = np.array([*cmap(0 if obj_id is None else obj_id)[:3], 0.6])
+     h, w = mask.shape
+     mask_rgba = (mask.reshape(h, w, 1) * color.reshape(1, 1, -1) * 255).astype(np.uint8)
+     proper_mask = np.zeros((h, w, 4), dtype=np.uint8)
+     proper_mask[:, :, :min(mask_rgba.shape[2], 4)] = mask_rgba[:, :, :min(mask_rgba.shape[2], 4)]
+     return Image.fromarray(proper_mask, "RGBA")
+
+ def segment_with_points(ptype, session, evt):
+     session["input_points"].append(evt.index)
+     session["input_labels"].append(1 if ptype == "include" else 0)
+     first = session["first_frame"]
+     h, w = first.shape[:2]
+
+     layer = np.zeros((h, w, 4), dtype=np.uint8)
+     for idx, pt in enumerate(session["input_points"]):
+         color = (0, 255, 0, 255) if session["input_labels"][idx] == 1 else (255, 0, 0, 255)
+         cv2.circle(layer, pt, int(min(w, h)*0.01), color, -1)
+
+     overlay = Image.alpha_composite(Image.fromarray(first).convert("RGBA"), Image.fromarray(layer, "RGBA"))
+
      try:
+         _, _, logits = predictor.add_new_points(session["inference_state"], 0, OBJ_ID, np.array(session["input_points"]), np.array(session["input_labels"]))
+         mask = (logits[0] > 0.0).cpu().numpy()
+         mask = cv2.resize(mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)
+         mask_img = show_mask(mask)
+         return overlay, Image.alpha_composite(Image.fromarray(first).convert("RGBA"), mask_img), session
      except Exception as e:
+         print("Segmentation error:", e)
+         return overlay, overlay, session
+
+ def propagate(video_in, session, progress=gr.Progress()):
+     if not session["input_points"] or not session["inference_state"]: return None, session
+
+     masks = {}
+     for i, (idxs, obj_ids, logits) in enumerate(predictor.propagate_in_video(session["inference_state"])):
+         try:
+             masks[idxs] = {oid: (logits[j] > 0.0).cpu().numpy() for j, oid in enumerate(obj_ids)}
+             progress(i / 300, desc=f"Tracking frame {idxs}")
+         except: continue
+
+     frames_out, stride = [], max(1, len(masks) // 50)
+     for i in range(0, len(masks), stride):
+         if i not in masks or OBJ_ID not in masks[i]: continue
+         try:
+             frame = session["all_frames"][i]
+             mask = masks[i][OBJ_ID]
+             h, w = frame.shape[:2]
+             mask = cv2.resize(mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)
+             output = Image.alpha_composite(Image.fromarray(frame).convert("RGBA"), show_mask(mask))
+             frames_out.append(np.array(output))
+         except: continue
+
+     out_path = os.path.join(tempfile.gettempdir(), f"output_video_{datetime.now().strftime('%Y%m%d%H%M%S')}.mp4")
+     fps = min(15, get_fps(video_in))
+     ImageSequenceClip(frames_out, fps=fps).write_videofile(out_path, codec="libx264", bitrate="800k", threads=2, logger=None)
+     gc.collect()
+     return gr.update(value=out_path, visible=True), session

  with gr.Blocks() as demo:
+     state = gr.State({"first_frame": None, "all_frames": None, "input_points": [], "input_labels": [], "inference_state": None, "frame_stride": 1, "scale_factor": 1.0, "original_dimensions": None})
+
+     gr.Markdown("<center><strong><font size='8'>EdgeTAM CPU</font></strong> <a href='https://github.com/facebookresearch/EdgeTAM'><font size='6'>[GitHub]</font></a></center>")
+
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown("""<ol><li>Upload a video or use an example</li><li>Select 'include' or 'exclude' and click points</li><li>Click 'Track' to segment and track</li></ol>""")
+             drawer = gr.Accordion("Input Video", open=True)
+             with drawer:
+                 video_in = gr.Video(label="Input Video", format="mp4")
+             ptype = gr.Radio(label="Point Type", choices=["include", "exclude"], value="include")
+             track_btn = gr.Button("Track", variant="primary")
+             clear_btn = gr.Button("Clear Points")
+             reset_btn = gr.Button("Reset")
+             points_map = gr.Image(label="Frame with Points", type="numpy", interactive=False)
+         with gr.Column():
+             gr.Markdown("# Try some examples ⬇️")
+             gr.Examples(examples, inputs=[video_in], examples_per_page=5)
+             output_img = gr.Image(label="Reference Mask")
+             output_vid = gr.Video(visible=False)
+
+     video_in.upload(preprocess_video, [video_in, state], [drawer, points_map, output_img, output_vid, state])
+     video_in.change(preprocess_video, [video_in, state], [drawer, points_map, output_img, output_vid, state])
+     points_map.select(segment_with_points, [ptype, state], [points_map, output_img, state])
+     clear_btn.click(clear_points, state, [points_map, output_img, output_vid, state])
+     reset_btn.click(reset, state, [video_in, drawer, points_map, output_img, output_vid, state])
+     track_btn.click(fn=propagate, inputs=[video_in, state], outputs=[output_vid, state])
+
+ if __name__ == '__main__':
+     demo.queue()
+     demo.launch()
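
The `propagate` callback added in this commit reports progress through Gradio's built-in tracker: declaring a `progress=gr.Progress()` parameter makes Gradio inject a tracker, and calling it with a fraction plus a `desc` label updates the progress bar shown on the output component while the propagation loop runs. Below is a minimal, standalone sketch of that pattern only; `slow_count`, its step count, and the tiny demo UI are invented for illustration and are not part of this Space.

import time
import gradio as gr

def slow_count(n, progress=gr.Progress()):
    # Gradio injects a Progress tracker for any parameter whose default is gr.Progress().
    n = int(n)
    for i in range(n):
        time.sleep(0.1)  # stand-in for one step of mask propagation
        progress((i + 1) / n, desc=f"Step {i + 1}/{n}")  # fraction in [0, 1] plus a label
    return f"done after {n} steps"

with gr.Blocks() as demo:
    steps = gr.Number(value=20, label="Steps")
    result = gr.Textbox(label="Result")
    gr.Button("Run").click(slow_count, inputs=steps, outputs=result)

if __name__ == "__main__":
    demo.queue()  # progress updates are streamed through the queue
    demo.launch()

In the commit itself the call is `progress(i / 300, desc=...)`; the fixed denominator appears to mirror the roughly 300-frame cap that `preprocess_video` applies when it picks a frame stride.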