# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy
import os
from datetime import datetime
import tempfile

import cv2
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr
import torch

from moviepy.editor import ImageSequenceClip
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor

# Running on CPU: drop the cuDNN SDPA flag if it was set for a GPU environment
if "TORCH_CUDNN_SDPA_ENABLED" in os.environ:
    del os.environ["TORCH_CUDNN_SDPA_ENABLED"]

# Description
title = "<center><strong><font size='8'>EdgeTAM CPU<font></strong> <a href='https://github.com/facebookresearch/EdgeTAM'><font size='6'>[GitHub]</font></a> </center>"

description_p = """# Instructions
                <ol>
                <li> Upload one video or click one example video</li>
                <li> Click 'include' point type, select the object to segment and track</li>
                <li> Click 'exclude' point type (optional), select the area you want to avoid segmenting and tracking</li>
                <li> Click the 'Track' button to obtain the masked video </li>
                </ol>
              """

# examples - keeping fewer examples to reduce memory footprint
examples = [
    ["examples/01_dog.mp4"],
    ["examples/02_cups.mp4"],
    ["examples/03_blocks.mp4"],
    ["examples/04_coffee.mp4"],
    ["examples/05_default_juggle.mp4"],
]

OBJ_ID = 0

# Initialize model on CPU - add error handling for file paths
sam2_checkpoint = "checkpoints/edgetam.pt"
model_cfg = "edgetam.yaml"

# Check if model files exist
def check_file_exists(filepath):
    exists = os.path.exists(filepath)
    if not exists:
        print(f"WARNING: File not found: {filepath}")
    return exists

# Verify files exist
model_files_exist = check_file_exists(sam2_checkpoint) and check_file_exists(model_cfg)
try:
    # Load model with more careful error handling
    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
    print("predictor loaded on CPU")
except Exception as e:
    print(f"Error loading model: {e}")
    import traceback
    traceback.print_exc()
    # Still create a predictor variable to avoid NameError
    predictor = None

# Function to get video frame rate
def get_video_fps(video_path):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        return 30.0  # Default fallback value
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    return fps


def reset(session_state):
    session_state["input_points"] = []
    session_state["input_labels"] = []
    if session_state["inference_state"] is not None:
        predictor.reset_state(session_state["inference_state"])
    session_state["first_frame"] = None
    session_state["all_frames"] = None
    session_state["inference_state"] = None
    return (
        None,
        gr.update(open=True),
        None,
        None,
        gr.update(value=None, visible=False),
        session_state,
    )


def clear_points(session_state):
    session_state["input_points"] = []
    session_state["input_labels"] = []
    if session_state["inference_state"] is not None and session_state["inference_state"].get("tracking_has_started", False):
        predictor.reset_state(session_state["inference_state"])
    return (
        session_state["first_frame"],
        None,
        gr.update(value=None, visible=False),
        session_state,
    )


def preprocess_video_in(video_path, session_state):
    if video_path is None:
        return (
            gr.update(open=True),  # video_in_drawer
            None,  # points_map
            None,  # output_image
            gr.update(value=None, visible=False),  # output_video
            session_state,
        )

    # Read the first frame
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        return (
            gr.update(open=True),  # video_in_drawer
            None,  # points_map
            None,  # output_image
            gr.update(value=None, visible=False),  # output_video
            session_state,
        )

    # For CPU optimization - determine video properties
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Keep the original dimensions before any downscaling (the capture is released later)
    original_width, original_height = frame_width, frame_height
    
    # Determine if we need to resize for CPU performance
    target_width = 640  # Target width for processing on CPU
    scale_factor = 1.0
    
    if frame_width > target_width:
        scale_factor = target_width / frame_width
        frame_width = target_width
        frame_height = int(frame_height * scale_factor)
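    # Example of the downscaling math (illustrative values, not from the app): a
    # 1920x1080 input gives scale_factor = 640 / 1920 = 1/3, so frames are processed
    # at 640x360; inputs that are already 640 pixels wide or narrower are left as-is.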
    
    # Read frames - for CPU we'll be more selective about which frames to keep
    frame_number = 0
    first_frame = None
    all_frames = []
    
    # For CPU optimization, skip frames if video is too long
    frame_stride = 1
    if total_frames > 300:  # If more than 300 frames
        frame_stride = max(1, int(total_frames / 300))  # Process at most ~300 frames
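    # Illustrative example of the striding: a 900-frame video gives
    # frame_stride = 900 // 300 = 3, so roughly every third frame is kept and
    # all_frames[i] corresponds to original frame index i * frame_stride. This
    # mapping matters again when propagated masks are matched back to frames.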
    
    while True:
        ret, frame = cap.read()
        if not ret:
            break
            
        if frame_number % frame_stride == 0:  # Process every frame_stride frames
            # Resize the frame if needed
            if scale_factor != 1.0:
                frame = cv2.resize(frame, (frame_width, frame_height), interpolation=cv2.INTER_AREA)
                
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = np.array(frame)
            
            # Store the first frame
            if first_frame is None:
                first_frame = frame
            all_frames.append(frame)
        
        frame_number += 1

    cap.release()
    session_state["first_frame"] = copy.deepcopy(first_frame)
    session_state["all_frames"] = all_frames
    session_state["frame_stride"] = frame_stride
    session_state["scale_factor"] = scale_factor
    session_state["original_dimensions"] = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), 
                                          int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))

    session_state["inference_state"] = predictor.init_state(video_path=video_path)
    session_state["input_points"] = []
    session_state["input_labels"] = []

    return [
        gr.update(open=False),  # video_in_drawer
        first_frame,  # points_map
        None,  # output_image
        gr.update(value=None, visible=False),  # output_video
        session_state,
    ]


def segment_with_points(
    point_type,
    session_state,
    evt: gr.SelectData,
):
    session_state["input_points"].append(evt.index)
    print(f"TRACKING INPUT POINT: {session_state['input_points']}")

    if point_type == "include":
        session_state["input_labels"].append(1)
    elif point_type == "exclude":
        session_state["input_labels"].append(0)
    print(f"TRACKING INPUT LABEL: {session_state['input_labels']}")

    # Open the image and get its dimensions
    first_frame = session_state["first_frame"]
    h, w = first_frame.shape[:2]
    transparent_background = Image.fromarray(first_frame).convert("RGBA")

    # Define the circle radius as a fraction of the smaller dimension
    fraction = 0.01  # You can adjust this value as needed
    radius = int(fraction * min(w, h))
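    # For instance (illustrative only), on a 640x360 frame min(w, h) = 360 and the
    # click marker radius comes out to int(0.01 * 360) = 3 pixels.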

    # Create a transparent layer to draw on
    transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)

    for index, track in enumerate(session_state["input_points"]):
        if session_state["input_labels"][index] == 1:
            cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
        else:
            cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)

    # Convert the transparent layer back to an image
    transparent_layer = Image.fromarray(transparent_layer, "RGBA")
    selected_point_map = Image.alpha_composite(
        transparent_background, transparent_layer
    )

    # Convert the accumulated clicks and labels into arrays for the predictor
    points = np.array(session_state["input_points"], dtype=np.float32)
    # for labels, `1` means a positive (include) click and `0` a negative (exclude) click
    labels = np.array(session_state["input_labels"], np.int32)
    
    try:
        # For CPU optimization, we'll process with smaller batch size
        _, _, out_mask_logits = predictor.add_new_points(
            inference_state=session_state["inference_state"],
            frame_idx=0,
            obj_id=OBJ_ID,
            points=points,
            labels=labels,
        )
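        # The predictor returns mask logits as torch tensors at the resolution tracked by
        # the inference state, typically with a leading singleton dimension (e.g.
        # (num_objects, 1, H, W); the exact shape may vary by version), which is why the
        # array is squeezed and resized before compositing below.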
        
        # Create the binary mask and drop any singleton dimensions so it is (H, W)
        mask_array = (out_mask_logits[0] > 0.0).cpu().numpy()
        mask_array = np.squeeze(mask_array)

        # Ensure the mask has the same size as the (possibly downscaled) first frame
        if mask_array.shape[:2] != (h, w):
            mask_array = cv2.resize(
                mask_array.astype(np.uint8),
                (w, h),
                interpolation=cv2.INTER_NEAREST,
            ).astype(bool)
        
        mask_image = show_mask(mask_array)
        
        # Make sure mask_image has the same size as the background
        if mask_image.size != transparent_background.size:
            mask_image = mask_image.resize(transparent_background.size, Image.NEAREST)
            
        first_frame_output = Image.alpha_composite(transparent_background, mask_image)
    except Exception as e:
        print(f"Error in segmentation: {e}")
        # Return just the points as fallback
        first_frame_output = selected_point_map

    return selected_point_map, first_frame_output, session_state


def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])
    
    # Handle different mask shapes properly
    if len(mask.shape) == 2:
        h, w = mask.shape
    else:
        h, w = mask.shape[-2:]
    
    # Ensure correct reshaping based on mask dimensions
    mask_reshaped = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    mask_rgba = (mask_reshaped * 255).astype(np.uint8)
    
    if convert_to_image:
        try:
            # Ensure the mask has correct RGBA shape (h, w, 4)
            if mask_rgba.shape[2] != 4:
                # If not RGBA, create a proper RGBA array
                proper_mask = np.zeros((h, w, 4), dtype=np.uint8)
                # Copy available channels
                proper_mask[:, :, :min(mask_rgba.shape[2], 4)] = mask_rgba[:, :, :min(mask_rgba.shape[2], 4)]
                mask_rgba = proper_mask
            
            # Create the PIL image
            return Image.fromarray(mask_rgba, "RGBA")
        except Exception as e:
            print(f"Error converting mask to image: {e}")
            # Fallback: create a blank transparent image of correct size
            blank = np.zeros((h, w, 4), dtype=np.uint8)
            return Image.fromarray(blank, "RGBA")
    
    return mask_rgba
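# Illustrative usage of show_mask (not executed here; the names below are hypothetical):
#
#   mask = np.zeros((360, 640), dtype=bool)   # (H, W) boolean mask
#   overlay = show_mask(mask)                 # RGBA PIL.Image, ~60% alpha where mask is True
#   composited = Image.alpha_composite(frame_rgba, overlay)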


def propagate_to_all(
    video_in,
    session_state,
):
    if (
        len(session_state["input_points"]) == 0
        or video_in is None
        or session_state["inference_state"] is None
    ):
        return (
            None,
            session_state,
        )

    # For CPU optimization: trigger garbage collection every few frames
    chunk_size = 3  # run gc roughly every 3 processed frames to limit memory growth on CPU
    
    try:
        # run propagation throughout the video and collect the results in a dict
        video_segments = {}  # video_segments contains the per-frame segmentation results
        print("starting propagate_in_video on CPU")
        
        # propagate_in_video yields per-frame results; collect them as they arrive
        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
            session_state["inference_state"]
        ):
            try:
                # Store the masks for each object ID
                video_segments[out_frame_idx] = {
                    out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                    for i, out_obj_id in enumerate(out_obj_ids)
                }
                
                print(f"Processed frame {out_frame_idx}")
                
                # Release memory periodically
                if out_frame_idx % chunk_size == 0:
                    # Explicitly clear any tensors
                    del out_mask_logits
                    import gc
                    gc.collect()
            except Exception as e:
                print(f"Error processing frame {out_frame_idx}: {e}")
                continue

        # For CPU optimization: subsample the output so at most ~50 frames are rendered
        total_frames = len(video_segments)
        print(f"Total frames processed: {total_frames}")

        max_output_frames = 50
        vis_frame_stride = max(1, total_frames // max_output_frames)
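        # Example of the subsampling (illustrative numbers): 300 processed frames give
        # vis_frame_stride = 300 // 50 = 6, so frames 0, 6, 12, ... are rendered and the
        # output video contains 50 frames.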
        
        # Get dimensions of the frames
        first_frame = session_state["all_frames"][0]
        h, w = first_frame.shape[:2]
        
        output_frames = []
        frame_stride = session_state.get("frame_stride", 1)
        for out_frame_idx in range(0, total_frames, vis_frame_stride):
            if out_frame_idx not in video_segments or OBJ_ID not in video_segments[out_frame_idx]:
                continue

            try:
                # The predictor indexes frames of the full video, while all_frames was
                # subsampled with frame_stride, so map the index back and clamp it
                all_frames_idx = min(out_frame_idx // frame_stride, len(session_state["all_frames"]) - 1)
                frame = session_state["all_frames"][all_frames_idx]
                transparent_background = Image.fromarray(frame).convert("RGBA")
                
                # Get the mask and ensure it is a 2-D (H, W) boolean array
                out_mask = video_segments[out_frame_idx][OBJ_ID]
                out_mask = np.squeeze(out_mask)

                # Resize the mask if its dimensions don't match the stored frame
                if out_mask.shape[:2] != (h, w):
                    out_mask = cv2.resize(
                        out_mask.astype(np.uint8),
                        (w, h),
                        interpolation=cv2.INTER_NEAREST,
                    ).astype(bool)
                
                mask_image = show_mask(out_mask)
                
                # Make sure mask has same dimensions as background
                if mask_image.size != transparent_background.size:
                    mask_image = mask_image.resize(transparent_background.size, Image.NEAREST)
                
                output_frame = Image.alpha_composite(transparent_background, mask_image)
                output_frame = np.array(output_frame)
                output_frames.append(output_frame)
                
                # Clear memory periodically
                if len(output_frames) % 10 == 0:
                    import gc
                    gc.collect()
                    
            except Exception as e:
                print(f"Error creating output frame {out_frame_idx}: {e}")
                continue

        # Create a video clip from the image sequence
        original_fps = get_video_fps(video_in)
        fps = original_fps
        
        # For CPU optimization - lower FPS if original is high
        if fps > 15:
            fps = 15  # Lower fps for CPU processing
        
        print(f"Creating video with {len(output_frames)} frames at {fps} FPS")
        clip = ImageSequenceClip(output_frames, fps=fps)
        
        # Write the result to a file - use lower quality for CPU
        unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
        final_vid_output_path = f"output_video_{unique_id}.mp4"
        final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_output_path)

        # Lower bitrate for CPU processing
        clip.write_videofile(
            final_vid_output_path, 
            codec="libx264", 
            bitrate="800k",
            threads=2,  # Use fewer threads for CPU
            logger=None   # Disable logger to reduce console output
        )
        
        # Free memory
        del video_segments
        del output_frames
        import gc
        gc.collect()

        return (
            gr.update(value=final_vid_output_path, visible=True),
            session_state,
        )
    
    except Exception as e:
        print(f"Error in propagate_to_all: {e}")
        return (
            gr.update(value=None, visible=False),
            session_state,
        )


def update_ui():
    return gr.update(visible=True)
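
# update_ui simply makes the output video component visible before the long-running
# propagate_to_all call; the actual video path is filled in by the chained .then()
# callback wired up below.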


with gr.Blocks() as demo:
    session_state = gr.State(
        {
            "first_frame": None,
            "all_frames": None,
            "input_points": [],
            "input_labels": [],
            "inference_state": None,
            "frame_stride": 1,
            "scale_factor": 1.0,
            "original_dimensions": None,
        }
    )

    with gr.Column():
        # Title
        gr.Markdown(title)
        with gr.Row():

            with gr.Column():
                # Instructions
                gr.Markdown(description_p)

                with gr.Accordion("Input Video", open=True) as video_in_drawer:
                    video_in = gr.Video(label="Input Video", format="mp4")

                with gr.Row():
                    point_type = gr.Radio(
                        label="point type",
                        choices=["include", "exclude"],
                        value="include",
                        scale=2,
                    )
                    propagate_btn = gr.Button("Track", scale=1, variant="primary")
                    clear_points_btn = gr.Button("Clear Points", scale=1)
                    reset_btn = gr.Button("Reset", scale=1)

                points_map = gr.Image(
                    label="Frame with Point Prompt", type="numpy", interactive=False
                )

            with gr.Column():
                gr.Markdown("# Try some of the examples below ⬇️")
                gr.Examples(
                    examples=examples,
                    inputs=[
                        video_in,
                    ],
                    examples_per_page=5,
                )
                
                output_image = gr.Image(label="Reference Mask")
                output_video = gr.Video(visible=False)

    # When new video is uploaded
    video_in.upload(
        fn=preprocess_video_in,
        inputs=[
            video_in,
            session_state,
        ],
        outputs=[
            video_in_drawer,  # Accordion to hide uploaded video player
            points_map,  # Image component where we add new tracking points
            output_image,
            output_video,
            session_state,
        ],
        queue=False,
    )

    video_in.change(
        fn=preprocess_video_in,
        inputs=[
            video_in,
            session_state,
        ],
        outputs=[
            video_in_drawer,  # Accordion to hide uploaded video player
            points_map,  # Image component where we add new tracking points
            output_image,
            output_video,
            session_state,
        ],
        queue=False,
    )

    # triggered when we click on image to add new points
    points_map.select(
        fn=segment_with_points,
        inputs=[
            point_type,  # "include" or "exclude"
            session_state,
        ],
        outputs=[
            points_map,  # updated image with points
            output_image,
            session_state,
        ],
        queue=False,
    )

    # Clear every points clicked and added to the map
    clear_points_btn.click(
        fn=clear_points,
        inputs=session_state,
        outputs=[
            points_map,
            output_image,
            output_video,
            session_state,
        ],
        queue=False,
    )

    reset_btn.click(
        fn=reset,
        inputs=session_state,
        outputs=[
            video_in,
            video_in_drawer,
            points_map,
            output_image,
            output_video,
            session_state,
        ],
        queue=False,
    )

    propagate_btn.click(
        fn=update_ui,
        inputs=[],
        outputs=output_video,
        queue=False,
    ).then(
        fn=propagate_to_all,
        inputs=[
            video_in,
            session_state,
        ],
        outputs=[
            output_video,
            session_state,
        ],
        queue=True,  # Use queue for CPU processing
    )
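
# demo.queue() enables Gradio's request queue so the long-running 'Track' step
# (registered with queue=True above) is processed through it; launch() starts the app
# with default settings (options such as share=True are possibilities, not used here).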


demo.queue()
demo.launch()