Spaces:
bla
/
Runtime error

File size: 31,099 Bytes
9bc4638
 
 
 
 
 
 
 
 
 
 
 
5bc3a57
 
 
9bc4638
 
 
 
 
5bc3a57
 
9bc4638
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5bc3a57
9bc4638
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e60611
 
5bc3a57
eafda84
5bc3a57
 
 
 
 
 
 
 
 
 
 
9bc4638
 
 
5bc3a57
 
 
 
9bc4638
 
5bc3a57
9bc4638
 
e508568
9bc4638
 
e508568
 
 
 
 
 
 
 
 
5bc3a57
e508568
 
5bc3a57
 
e508568
 
5bc3a57
e508568
 
 
 
 
 
 
5bc3a57
e508568
 
 
 
 
 
5bc3a57
 
 
e508568
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5bc3a57
 
 
 
 
 
 
 
 
e508568
 
5bc3a57
 
 
e508568
 
 
 
 
 
 
 
 
 
 
 
5bc3a57
9bc4638
1affb38
 
 
 
 
e508568
9bc4638
e508568
 
 
 
 
1affb38
9bc4638
 
 
 
1affb38
9bc4638
e508568
 
1affb38
e508568
 
1affb38
9bc4638
 
 
 
 
 
 
 
 
 
 
1affb38
e508568
9bc4638
 
 
1affb38
 
 
e508568
 
1affb38
e508568
 
1affb38
 
 
5bc3a57
 
1affb38
e508568
9bc4638
 
5bc3a57
2a466e4
5bc3a57
9bc4638
e508568
9bc4638
 
5bc3a57
 
 
 
 
 
e508568
9bc4638
 
 
5bc3a57
9bc4638
 
 
 
 
1affb38
e508568
1affb38
 
5bc3a57
1affb38
e508568
 
1affb38
 
 
e508568
1affb38
 
 
 
9bc4638
 
 
 
 
 
1affb38
 
 
9bc4638
1affb38
 
 
9bc4638
1affb38
 
9bc4638
1affb38
9bc4638
1affb38
 
e508568
9bc4638
e508568
9bc4638
e508568
1affb38
 
 
 
fa0b563
1affb38
 
9bc4638
 
5bc3a57
9bc4638
 
 
5bc3a57
 
e508568
 
1affb38
5bc3a57
fa0b563
1affb38
fa0b563
1affb38
 
 
 
 
 
 
9bc4638
1affb38
e508568
 
 
1affb38
 
 
 
 
 
fa0b563
1affb38
 
 
 
fa0b563
1affb38
5bc3a57
 
 
1affb38
 
9bc4638
 
 
1affb38
 
 
 
2a466e4
 
1affb38
9bc4638
1affb38
9bc4638
 
1affb38
 
 
 
 
2a466e4
1affb38
 
 
2a466e4
1affb38
2a466e4
1affb38
2a466e4
1affb38
 
 
 
 
 
fa0b563
 
 
1affb38
 
 
 
9bc4638
1affb38
 
 
 
9bc4638
 
5bc3a57
9bc4638
e508568
9bc4638
 
1affb38
 
 
e508568
 
9bc4638
1affb38
 
9bc4638
e508568
9bc4638
1affb38
9bc4638
1affb38
9bc4638
 
 
1affb38
 
 
2a466e4
1affb38
 
 
 
 
 
2a466e4
 
e508568
2a466e4
1affb38
 
 
 
 
 
 
 
 
 
 
 
9bc4638
 
 
1affb38
2a466e4
 
1affb38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5bc3a57
 
 
1affb38
 
 
 
 
 
 
9bc4638
 
e508568
 
1affb38
 
 
 
 
 
 
 
 
 
9bc4638
1affb38
 
 
 
 
 
 
 
 
9bc4638
1affb38
5bc3a57
1affb38
5bc3a57
 
 
 
 
 
 
 
 
1affb38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9bc4638
 
5bc3a57
1affb38
9bc4638
 
 
 
1affb38
9bc4638
 
1affb38
 
 
 
 
 
9bc4638
 
 
 
 
 
 
 
 
 
 
 
 
1affb38
9bc4638
 
 
 
 
 
 
1affb38
9bc4638
1affb38
 
 
 
9bc4638
1affb38
 
9bc4638
fa0b563
1affb38
5bc3a57
1affb38
 
 
 
9bc4638
 
 
 
 
 
1affb38
9bc4638
1affb38
 
5bc3a57
1affb38
 
 
fa0b563
1affb38
 
 
 
 
 
9bc4638
 
1affb38
 
 
 
 
9bc4638
1affb38
e508568
9bc4638
 
1affb38
9bc4638
e508568
 
9bc4638
1affb38
9bc4638
 
1affb38
e508568
9bc4638
 
1affb38
e508568
 
 
9bc4638
1affb38
9bc4638
 
1affb38
 
9bc4638
 
 
1affb38
 
9bc4638
 
1affb38
 
 
9bc4638
1affb38
9bc4638
 
1affb38
9bc4638
 
1affb38
9bc4638
1affb38
 
 
 
9bc4638
1affb38
9bc4638
 
1affb38
9bc4638
 
1affb38
9bc4638
e508568
 
9bc4638
1affb38
9bc4638
 
1affb38
9bc4638
1affb38
9bc4638
1affb38
 
 
9bc4638
 
2a466e4
 
9bc4638
 
1affb38
e508568
9bc4638
5bc3a57
 
 
 
9bc4638
 
 
1affb38
5bc3a57
1affb38
5bc3a57
9bc4638
1affb38
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy
import os
from datetime import datetime

import gradio as gr

# Removed GPU-specific environment variable setting
# os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0,1,2,3,4,5,6,7"

import tempfile

import cv2
import matplotlib.pyplot as plt
import numpy as np
# Removed spaces decorator import for CPU-only demo
# import spaces # Removed spaces import
import torch

from moviepy.editor import ImageSequenceClip
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor

# Description
# HTML header passed to gr.Markdown: the demo name followed by a link to the
# EdgeTAM GitHub repository. Every <font> tag is explicitly closed.
title = "<center><strong><font size='8'>EdgeTAM</font></strong> <a href='https://github.com/facebookresearch/EdgeTAM'><font size='6'>[GitHub]</font></a> </center>"

# Usage instructions passed to gr.Markdown next to the input video.
description_p = """# Instructions
                <ol>
                <li> Upload one video or click one example video</li>
                <li> Click 'include' point type, select the object to segment and track</li>
                <li> Click 'exclude' point type (optional), select the area you want to avoid segmenting and tracking</li>
                <li> Click the 'Track' button to obtain the masked video </li>
                </ol>
              """

# examples - sample clips that ship with the demo (they are input files).
# Each entry of `examples` is a one-element row holding a clip path.
examples = [
    [clip_path]
    for clip_path in (
        "examples/01_dog.mp4",
        "examples/02_cups.mp4",
        "examples/03_blocks.mp4",
        "examples/04_coffee.mp4",
        "examples/05_default_juggle.mp4",
        "examples/01_breakdancer.mp4",
        "examples/02_hummingbird.mp4",
        "examples/03_skateboarder.mp4",
        "examples/04_octopus.mp4",
        "examples/05_landing_dog_soccer.mp4",
        "examples/06_pingpong.mp4",
        "examples/07_snowboarder.mp4",
        "examples/08_driving.mp4",
        "examples/09_birdcartoon.mp4",
        "examples/10_cloth_magic.mp4",
        "examples/11_polevault.mp4",
        "examples/12_hideandseek.mp4",
        "examples/13_butterfly.mp4",
        "examples/14_social_dog_training.mp4",
        "examples/15_cricket.mp4",
        "examples/16_robotarm.mp4",
        "examples/17_childrendancing.mp4",
        "examples/18_threedogs.mp4",
        "examples/19_cyclist.mp4",
        "examples/20_doughkneading.mp4",
        "examples/21_biker.mp4",
        "examples/22_dogskateboarder.mp4",
        "examples/23_racecar.mp4",
        "examples/24_clownfish.mp4",
    )
]

# Object id used for every point prompt sent to the predictor and for looking
# up the resulting mask; this demo tracks a single object.
OBJ_ID = 0

# Checkpoint file and config name handed to build_sam2_video_predictor below.
sam2_checkpoint = "checkpoints/edgetam.pt"
model_cfg = "edgetam.yaml"
# Ensure predictor is explicitly built for CPU
# NOTE: this runs at import time, so the module-level `predictor` is shared by
# every handler in this file.
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
# Removed .to("cuda") - predictor is already on CPU from build_sam2_video_predictor
# predictor.to("cuda")
print("predictor loaded on CPU")

# Removed CUDA specific autocast and backend settings
# torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
# if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
#     torch.backends.cuda.matmul.allow_tf32 = True
#     torch.backends.cudnn.allow_tf32 = True
# elif not torch.cuda.is_available():
#     print("Warning: CUDA not available. Running on CPU.")


def get_video_fps(video_path):
    """Return the frame rate OpenCV reports for a video file.

    Args:
        video_path: Path to the video file; may be None.

    Returns:
        The value of ``cv2.CAP_PROP_FPS`` for the file, or None when the path
        is None, does not exist on disk, or cannot be opened by OpenCV.
    """
    if video_path is None or not os.path.exists(video_path):
        print(f"Warning: Video file not found at {video_path}")
        return None

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"Error: Could not open video file {video_path}.")
            return None
        return cap.get(cv2.CAP_PROP_FPS)
    finally:
        # Release the capture handle on every exit path, including the
        # "could not open" branch and any exception raised while querying.
        cap.release()


def reset(session_state):
    """Return the demo to its initial, no-video condition.

    Empties the collected points and labels, asks the predictor to reset any
    existing inference state, then clears the stored frames, inference state
    and video path in ``session_state``.

    Args:
        session_state: Per-session dict; mutated in place and also returned.

    Returns:
        A 9-tuple of component values/updates followed by ``session_state``.
    """
    print("Resetting demo.")
    session_state["input_points"] = []
    session_state["input_labels"] = []

    existing_state = session_state["inference_state"]
    if existing_state is not None:
        try:
            predictor.reset_state(existing_state)
            # Drop the reference so a later video starts from a fresh state.
            session_state["inference_state"] = None
        except Exception as e:
            print(f"Error resetting predictor state: {e}")
            # Force-clear even when the predictor call failed.
            session_state["inference_state"] = None

    # Forget everything tied to the previously loaded video.
    for stale_key in ("first_frame", "all_frames", "inference_state", "video_path"):
        session_state[stale_key] = None

    cleared_video_in = None
    opened_drawer = gr.update(open=True)
    cleared_points_map = None
    cleared_output_image = None
    hidden_output_video = gr.update(value=None, visible=False)
    propagate_disabled = gr.update(interactive=False)
    clear_points_disabled = gr.update(interactive=False)
    reset_disabled = gr.update(interactive=False)

    return (
        cleared_video_in,
        opened_drawer,
        cleared_points_map,
        cleared_output_image,
        hidden_output_video,
        propagate_disabled,
        clear_points_disabled,
        reset_disabled,
        session_state,
    )


def clear_points(session_state):
    """Discard all clicked points and show the untouched first frame again.

    When an inference state exists it is reset through the predictor and then
    rebuilt from the stored video path; if the path is missing or anything
    raises, the inference state is set to None.

    Args:
        session_state: Per-session dict; mutated in place and also returned.

    Returns:
        ``(first frame for points_map, None for output_image,
        update hiding output_video, session_state)``.
    """
    print("Clearing points.")
    session_state["input_points"] = []
    session_state["input_labels"] = []

    if session_state["inference_state"] is not None:
        try:
            predictor.reset_state(session_state["inference_state"])
            print("Predictor state reset for clearing points.")
            # NOTE(review): the state is rebuilt from the video right after
            # being reset - confirm whether reset_state alone would suffice.
            if session_state["video_path"] is None:
                print("Warning: Could not re-initialize state after clear_points (video_path missing).")
                session_state["inference_state"] = None
            else:
                session_state["inference_state"] = predictor.init_state(video_path=session_state["video_path"])
                print("Predictor state re-initialized after clearing points.")
        except Exception as e:
            print(f"Error resetting predictor state during clear_points: {e}")
            # A failed reset could leave old masks behind, so drop the state.
            session_state["inference_state"] = None

    untouched_first_frame = session_state["first_frame"]
    hidden_output_video = gr.update(value=None, visible=False)
    return (untouched_first_frame, None, hidden_output_video, session_state)


def _video_load_failure_outputs():
    """Build the eight outputs returned when a video cannot be loaded.

    The drawer is re-opened, both images are cleared, the output video is
    hidden, the three buttons are disabled and a blank session dict is supplied.
    """
    return (
        gr.update(open=True), None, None, gr.update(value=None, visible=False),
        gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False),
        {
            "first_frame": None, "all_frames": None, "input_points": [],
            "input_labels": [], "inference_state": None, "video_path": None,
        },
    )


# Removed @spaces.GPU decorator
def preprocess_video_in(video_path, session_state):
    """Decode a video into RGB frames and start a predictor state for it.

    Args:
        video_path: Path of the uploaded/selected video; may be None.
        session_state: Per-session dict, filled with the frames, the path,
            empty point/label lists and the new inference state.

    Returns:
        Eight outputs: drawer update, first frame, None (output image), hidden
        output-video update, three button updates, and the session state. On
        failure the buttons are disabled and a blank session dict is returned.
    """
    print(f"Processing video: {video_path}")
    if video_path is None or not os.path.exists(video_path):
        print("No video path provided or file not found.")
        return _video_load_failure_outputs()

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video file {video_path}.")
        return _video_load_failure_outputs()

    # Read until the capture reports no more frames, converting BGR -> RGB.
    all_frames = []
    ret, frame = cap.read()
    while ret:
        all_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        ret, frame = cap.read()
    cap.release()

    if not all_frames:
        print(f"Error: No frames read from video file {video_path}.")
        return _video_load_failure_outputs()

    first_frame = all_frames[0]

    # The session keeps its own copy of the first frame.
    session_state.update(
        {
            "first_frame": copy.deepcopy(first_frame),
            "all_frames": all_frames,
            "video_path": video_path,
            "input_points": [],
            "input_labels": [],
        }
    )
    # No device argument is passed; the predictor was built with device="cpu".
    session_state["inference_state"] = predictor.init_state(video_path=video_path)
    print("Video loaded and predictor state initialized on CPU.")

    drawer_closed = gr.update(open=False)
    output_video_hidden = gr.update(value=None, visible=False)
    propagate_enabled = gr.update(interactive=True)
    clear_points_enabled = gr.update(interactive=True)
    reset_enabled = gr.update(interactive=True)
    return [
        drawer_closed,
        first_frame,
        None,
        output_video_hidden,
        propagate_enabled,
        clear_points_enabled,
        reset_enabled,
        session_state,
    ]


# Removed @spaces.GPU decorator
def segment_with_points(
    point_type,
    session_state,
    evt: gr.SelectData,
):
    """Record a clicked point and segment the first frame with all points so far.

    Args:
        point_type: "include" (stored as label 1) or "exclude" (label 0). Any
            other value is rejected and the click is ignored.
        session_state: Per-session dict. Reads "first_frame" and
            "inference_state"; appends to "input_points" and "input_labels".
        evt: Gradio select event; ``evt.index`` is used as the click (x, y).

    Returns:
        ``(points_map image, output image or None, session_state)``. The
        output image is None when the predictor call fails.
    """
    # Ensure we have state and first frame
    if session_state["first_frame"] is None or session_state["inference_state"] is None:
        print("Error: Cannot segment. No video loaded or inference state missing.")
        # Return current states to avoid errors, without changing UI much
        return (
            session_state.get("first_frame"),  # points_map (show first frame if exists)
            None,  # output_image (keep cleared)
            session_state,
        )

    # Resolve the label BEFORE recording the point: appending a point without
    # a label would leave the two lists out of sync and make the drawing loop
    # below raise IndexError.
    label = {"include": 1, "exclude": 0}.get(point_type)
    if label is None:
        print(f"Error: Unknown point type {point_type!r}. Click ignored.")
        # Argument-less updates leave both images as they currently are.
        return gr.update(), gr.update(), session_state

    # evt.index is the (x, y) coordinate tuple
    click_coords = evt.index
    print(f"Clicked at: {click_coords} ({point_type})")

    session_state["input_points"].append(click_coords)
    session_state["input_labels"].append(label)

    # Get the first frame as a PIL image for drawing
    first_frame_pil = Image.fromarray(session_state["first_frame"]).convert("RGBA")
    w, h = first_frame_pil.size

    # Marker radius is 1% of the shorter image side, but never below 2 px.
    fraction = 0.01
    radius = max(2, int(fraction * min(w, h)))

    # Draw every recorded point on a fully transparent RGBA layer.
    transparent_layer_points = np.zeros((h, w, 4), dtype=np.uint8)
    for point, point_label in zip(session_state["input_points"], session_state["input_labels"]):
        # cv2.circle needs integer pixel coordinates.
        point_coords = (int(point[0]), int(point[1]))
        # Green for include, red for exclude (RGBA, 0-255).
        marker_color = (0, 255, 0, 255) if point_label == 1 else (255, 0, 0, 255)
        cv2.circle(transparent_layer_points, point_coords, radius, marker_color, -1)

    # points_map shows the first frame with the points added so far.
    transparent_layer_points_pil = Image.fromarray(transparent_layer_points, "RGBA")
    selected_point_map_img = Image.alpha_composite(
        first_frame_pil.copy(), transparent_layer_points_pil
    )

    # Pack all points/labels into tensors on the predictor's device, with a
    # leading batch dimension.
    points = np.array(session_state["input_points"], dtype=np.float32)
    labels = np.array(session_state["input_labels"], np.int32)
    device = next(predictor.parameters()).device
    points_tensor = torch.tensor(points, dtype=torch.float32, device=device).unsqueeze(0)
    labels_tensor = torch.tensor(labels, dtype=torch.int32, device=device).unsqueeze(0)

    first_frame_output_img = None  # Stays None if the predictor call fails
    try:
        # Note: predictor.add_new_points modifies the internal inference_state
        _, _, out_mask_logits = predictor.add_new_points(
            inference_state=session_state["inference_state"],
            frame_idx=0,  # Always segment on the first frame initially
            obj_id=OBJ_ID,
            points=points_tensor,
            labels=labels_tensor,
        )

        # Threshold the logits at 0 to get a boolean mask.
        # assumes out_mask_logits[0][0] is the logit map for OBJ_ID - TODO confirm
        mask_tensor = (out_mask_logits[0][0].detach().cpu() > 0.0)
        mask_numpy = mask_tensor.numpy()

        # show_mask returns an RGBA PIL image by default.
        mask_image_pil = show_mask(mask_numpy, obj_id=OBJ_ID)

        # output_image shows the first frame with the segmentation mask on top.
        first_frame_output_img = Image.alpha_composite(first_frame_pil.copy(), mask_image_pil)

    except Exception as e:
        print(f"Error during segmentation on first frame: {e}")
        # On error, first_frame_output_img remains None

    # Removed CUDA cache clearing call
    # if torch.cuda.is_available():
    #     torch.cuda.empty_cache()

    return selected_point_map_img, first_frame_output_img, session_state


def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
    """Render a binary mask as a semi-transparent RGBA overlay.

    Args:
        mask: Mask as a torch tensor or array-like. Singleton dimensions are
            squeezed away; the result is expected to be 2-D (H, W).
        obj_id: Selects the overlay colour from matplotlib's "tab10" colormap
            (``obj_id % 10``; None maps to index 0). Ignored when
            ``random_color`` is True.
        random_color: Use a random RGB colour instead of the colormap.
        convert_to_image: Return a PIL "RGBA" image when True, otherwise the
            raw ``uint8`` array of shape (H, W, 4).

    Returns:
        The overlay: the chosen colour with alpha 0.6 where the mask is set,
        fully transparent elsewhere. If the mask cannot be reduced to 2-D, a
        fully transparent overlay is returned, sized from the mask's last two
        dimensions (or 100x100 when it has fewer than two).
    """
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    # Any float/int mask becomes a boolean mask.
    mask = np.asarray(mask).astype(bool)

    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id % 10
        color = np.array([*cmap(cmap_idx)[:3], 0.6])

    # Drop singleton dimensions for any rank above 2, e.g. (1, H, W),
    # (H, W, 1) or (1, 1, H, W).
    if mask.ndim > 2:
        mask = mask.squeeze()
    if mask.ndim != 2:
        print(f"Warning: show_mask received mask with shape {mask.shape}. Expected 2D.")
        # Masks in this file keep H and W as their trailing dimensions, so the
        # placeholder is sized from the last two axes.
        h, w = mask.shape[-2:] if mask.ndim >= 2 else (100, 100)
        empty_overlay = np.zeros((h, w, 4), dtype=np.uint8)
        if convert_to_image:
            return Image.fromarray(empty_overlay, "RGBA")
        return empty_overlay

    h, w = mask.shape
    # Start fully transparent and paint the colour only where the mask is set.
    colored_mask = np.zeros((h, w, 4), dtype=np.float32)
    colored_mask[mask] = color

    # Scale the 0-1 floats to uint8 0-255.
    colored_mask_uint8 = (colored_mask * 255).astype(np.uint8)

    if convert_to_image:
        return Image.fromarray(colored_mask_uint8, "RGBA")
    return colored_mask_uint8


# Removed @spaces.GPU decorator
def propagate_to_all(
    video_in, # Keep video_in path as in original
    session_state,
):
    """Track the prompted object through every frame and write an overlay video.

    Args:
        video_in: Not read by this function; the stored "video_path" from
            ``session_state`` is used instead.
        session_state: Per-session dict; reads "input_points", "all_frames",
            "inference_state" and "video_path".

    Returns:
        ``(output-video update, session_state)``. The update carries the
        written file path and makes the player visible on success; on any
        failure it clears and hides the player.
    """

    def _hidden_result():
        # Result shared by every failure path: clear and hide the player.
        return (
            gr.update(value=None, visible=False),
            session_state,
        )

    print("Starting propagation...")
    current_video_path = session_state.get("video_path")
    not_ready = (
        len(session_state["input_points"]) == 0
        or session_state["all_frames"] is None
        or session_state["inference_state"] is None
        or current_video_path is None
    )
    if not_ready:
        print("Error: Cannot propagate. No points selected, video not loaded, or inference state missing.")
        return _hidden_result()

    # frame index -> {object id -> boolean mask}
    video_segments = {}
    try:
        frame_results = predictor.propagate_in_video(session_state["inference_state"])
        for out_frame_idx, out_obj_ids, out_mask_logits in frame_results:
            masks_by_object = {}
            for i, out_obj_id in enumerate(out_obj_ids):
                # Threshold the logits at 0 and keep the mask as numpy.
                # assumes out_mask_logits[i][0] is the logit map of object i - TODO confirm
                masks_by_object[out_obj_id] = (out_mask_logits[i][0].detach().cpu() > 0.0).numpy()
            video_segments[out_frame_idx] = masks_by_object

        print("Propagation finished.")
    except Exception as e:
        print(f"Error during propagation: {e}")
        return _hidden_result()

    # Blend the mask of OBJ_ID onto each original frame; frames without a mask
    # for that object pass through unchanged.
    output_frames = []
    for out_frame_idx, original_frame_rgb in enumerate(session_state["all_frames"]):
        frame_rgba = Image.fromarray(original_frame_rgb).convert("RGBA")
        frame_masks = video_segments.get(out_frame_idx)
        if frame_masks is not None and OBJ_ID in frame_masks:
            overlay_rgba = show_mask(frame_masks[OBJ_ID], obj_id=OBJ_ID)
            blended_rgba = Image.alpha_composite(frame_rgba, overlay_rgba)
            output_frames.append(np.array(blended_rgba.convert("RGB")))
        else:
            output_frames.append(original_frame_rgb)

    # Timestamped file name (down to microseconds) inside the temp directory.
    unique_id = datetime.now().strftime("%Y%m%d%H%M%S%f")
    final_vid_output_path = os.path.join(
        tempfile.gettempdir(), f"output_video_{unique_id}.mp4"
    )
    print(f"Output video path: {final_vid_output_path}")

    # Reuse the source frame rate; fall back to 30 when it is missing or not positive.
    original_fps = get_video_fps(current_video_path)
    if original_fps is not None and original_fps > 0:
        fps = original_fps
    else:
        fps = 30
    print(f"Creating output video with FPS: {fps}")

    if not output_frames:
        print("No output frames generated.")
        return _hidden_result()

    try:
        clip = ImageSequenceClip(output_frames, fps=fps)
    except Exception as e:
        print(f"Error creating ImageSequenceClip: {e}")
        return _hidden_result()

    try:
        print(f"Writing video file with codec='libx264', fps={fps}, preset='medium', threads='auto'")
        clip.write_videofile(
            final_vid_output_path,
            codec="libx264",
            fps=fps,
            preset="medium",
            threads="auto",
            logger=None,  # Suppress moviepy output
        )
        print("Video writing complete.")
        return (
            gr.update(value=final_vid_output_path, visible=True),
            session_state,
        )
    except Exception as e:
        print(f"Error writing video file: {e}")
        # Best-effort removal of a partially written file.
        if os.path.exists(final_vid_output_path):
            try:
                os.remove(final_vid_output_path)
                print(f"Removed partial video file: {final_vid_output_path}")
            except Exception as clean_e:
                print(f"Error removing partial file: {clean_e}")

        return _hidden_result()


def update_output_video_visibility():
    """Build the Gradio update that un-hides the output video component.

    Takes no inputs; the returned update only sets ``visible=True`` and
    leaves the component's value untouched.
    """
    show_player = gr.update(visible=True)
    return show_player


with gr.Blocks() as demo:
    # Session state dictionary handed to every handler below. It holds the
    # video frames, the clicked points and their labels, and the predictor
    # state.
    session_state = gr.State(
        {
            "first_frame": None,  # numpy array (RGB)
            "all_frames": None,  # list of numpy arrays (RGB)
            "input_points": [],  # list of (x, y) tuples/lists
            "input_labels": [],  # list of 1s and 0s
            "inference_state": None,  # EdgeTAM predictor state object
            "video_path": None,  # Store the input video path
        }
    )

    # --- Layout ---
    with gr.Column():
        # Title
        gr.Markdown(title)
        with gr.Row():

            # Left column: instructions, video input, controls, click target.
            with gr.Column():
                # Instructions
                gr.Markdown(description_p)

                with gr.Accordion("Input Video", open=True) as video_in_drawer:
                    # Will hold the video file path
                    video_in = gr.Video(label="Input Video", format="mp4")

                with gr.Row():
                    point_type = gr.Radio(
                        label="point type",
                        choices=["include", "exclude"],
                        value="include",
                        scale=2,
                        interactive=True,
                    )
                    # Buttons start disabled; the preprocess and reset handlers
                    # wired below list them as outputs and can re-enable them.
                    propagate_btn = gr.Button("Track", scale=1, variant="primary", interactive=False)
                    clear_points_btn = gr.Button("Clear Points", scale=1, interactive=False)
                    reset_btn = gr.Button("Reset", scale=1, interactive=False)

                # points_map is where users click to add points.
                # Shows the first frame with points drawn on it.
                points_map = gr.Image(
                    label="Click on the First Frame to Add Points",
                    type="numpy",
                    interactive=True,  # Kept interactive; clicks are handled by the select listener below
                    height=400,  # Fixed height for a stable layout
                    width="auto",  # Let width adjust
                    show_share_button=False,
                    show_download_button=False,
                )

            # Right column: examples, first-frame mask preview, tracking result.
            with gr.Column():
                gr.Markdown("# Try some of the examples below ⬇️")
                gr.Examples(
                    examples=examples,
                    inputs=[video_in],
                    examples_per_page=8,
                    cache_examples=False,  # Do not cache processed examples, as state is involved
                )

                # output_image shows the segmentation mask prediction on the *first* frame
                output_image = gr.Image(
                    label="Segmentation Mask on First Frame",
                    type="numpy",
                    interactive=False,  # Display only
                    height=400,  # Match height of points_map
                    width="auto",  # Let width adjust
                    show_share_button=False,
                    show_download_button=False,
                )

                # output_video shows the final tracking result; hidden until needed
                output_video = gr.Video(visible=False, label="Tracking Result")

    # --- Event Handlers ---

    # Preprocess the video whenever the input value changes. Gradio fires
    # `change` both for user uploads and for programmatic value updates (e.g.
    # selecting an example), so this single listener covers both paths.
    # Binding the same function to `upload` as well would run
    # preprocess_video_in twice for every uploaded file.
    video_in.change(
        fn=preprocess_video_in,
        inputs=[video_in, session_state],
        outputs=[
            video_in_drawer, points_map, output_image, output_video,
            propagate_btn, clear_points_btn, reset_btn, session_state,
        ],
        queue=False,  # Process immediately
    )

    # Triggered when a user clicks on the points_map image
    points_map.select(
        fn=segment_with_points,
        inputs=[
            point_type,  # "include" or "exclude" radio button value
            session_state,  # Pass session state
        ],
        outputs=[
            points_map,  # Updated image with points drawn
            output_image,  # Updated image with first frame segmentation mask
            session_state,  # Updated session state (points/labels added)
        ],
        queue=False,  # Process clicks immediately
    )

    # Button to clear all selected points and reset the first frame mask
    clear_points_btn.click(
        fn=clear_points,
        inputs=[session_state],
        outputs=[
            points_map,  # points_map shows original first frame without points
            output_image,  # output_image cleared (or shows original first frame without mask)
            output_video,  # Hide output video
            session_state,  # Updated session state (points/labels cleared, inference state reset)
        ],
        queue=False,  # Process immediately
    )

    # Button to reset the entire demo state and UI
    reset_btn.click(
        fn=reset,
        inputs=[session_state],
        outputs=[
            video_in, video_in_drawer, points_map, output_image, output_video,
            propagate_btn, clear_points_btn, reset_btn, session_state,
        ],
        queue=False,  # Process immediately
    )

    # Button to start mask propagation through the video
    propagate_btn.click(
        fn=update_output_video_visibility,  # First, make the output video player visible
        inputs=[],
        outputs=[output_video],
        queue=False,  # Process this UI update immediately
    ).then(  # Then, run the propagation function
        fn=propagate_to_all,
        inputs=[
            video_in,  # Get the input video path (can also get from session_state["video_path"])
            session_state,  # Pass session state (contains frames, points, inference_state, video_path)
        ],
        outputs=[
            output_video,  # Update output video player with result
            session_state,  # Update session state
        ],
        # CPU Optimization: Limit concurrency to 1 to prevent resource exhaustion.
        # queue=True ensures requests wait if another is processing.
        concurrency_limit=1,
        queue=True,
    )


# Launch the Gradio demo.
# Enable the request queue first so the concurrency limit set on the tracking
# handler is honoured and requests are processed sequentially.
demo.queue()
print("Gradio demo starting...")
# No share=True here: the app is served without a public share link.
# NOTE(review): launch() normally blocks the main thread while the server is
# running, so the "launched" message below may only print once the server
# stops — confirm against the Gradio version in use.
demo.launch()
print("Gradio demo launched.")