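"""Precompute joint CogVideoX VAE latents for paired source / masked videos.

For each clip under <video_root>/videos, encode the first 49 frames of the
source video and of its masked counterpart (with dashed rays overlaid on
frames 1..48), concatenate the two latents along the temporal axis, and save
the result to <video_root>/joint_latents/<name>.safetensors. The work is
sharded across processes/GPUs via --start_idx, --end_idx, and --gpu_id.
"""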
import argparse
import math
import os

import numpy as np
import torch
import tqdm
from decord import VideoReader
from diffusers import AutoencoderKLCogVideoX
from safetensors.torch import save_file

def encode_video(video, vae):
    """Encode a (T, C, H, W) clip in [-1, 1] to a scaled CogVideoX latent."""
    # The VAE expects (B, C, T, H, W): add a batch dim, then swap T and C.
    video = video[None].permute(0, 2, 1, 3, 4).contiguous()
    video = video.to(vae.device, dtype=vae.dtype)
    latent_dist = vae.encode(video).latent_dist
    # Sample from the posterior and apply the VAE's scaling factor.
    latent = latent_dist.sample() * vae.config.scaling_factor
    return latent
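# Rough sanity check, assuming the stock CogVideoX VAE (16 latent channels,
# 4x temporal and 8x spatial compression): a 49-frame 480x720 clip should
# encode to a latent of shape (1, 16, 13, 60, 90).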

def add_dashed_rays_to_video(video_tensor, num_perp_samples=50, density_decay=0.075):
    """Overlay dashed rays, colored from frame 0, onto frames 1..T-1.

    All rays share one random direction; their origins are spaced evenly
    along the perpendicular line through the frame center, and the dash
    spacing grows with distance from that line (controlled by density_decay).
    """
    T, C, H, W = video_tensor.shape
    # Longest ray needed to cover the frame: the diagonal, plus slack.
    max_length = int((H**2 + W**2) ** 0.5) + 10
    center = torch.tensor([W / 2, H / 2])
    # One random unit direction shared by all rays, and its perpendicular.
    theta = torch.rand(1).item() * 2 * math.pi
    direction = torch.tensor([math.cos(theta), math.sin(theta)])
    direction = direction / direction.norm()
    d_perp = torch.tensor([-direction[1], direction[0]])
    # Ray origins: evenly spaced along the perpendicular through the center.
    half_len = max(H, W) // 2
    positions = torch.linspace(-half_len, half_len, num_perp_samples)
    perp_coords = center[None, :] + positions[:, None] * d_perp[None, :]
    x0, y0 = perp_coords[:, 0], perp_coords[:, 1]
    # Dash offsets along each ray; the step size grows with distance, so
    # samples thin out away from the origin line.
    steps = []
    dist = 0.0
    while dist < max_length:
        steps.append(dist)
        dist += 1.0 + density_decay * dist
    steps = torch.tensor(steps)
    S = len(steps)
    # All dash positions: (num_perp_samples, S, 2) flattened to (N, 2).
    dxdy = direction[None, :] * steps[:, None]
    all_xy = perp_coords[:, None, :] + dxdy[None, :, :]
    all_xy = all_xy.reshape(-1, 2)
    all_x = all_xy[:, 0].round().long()
    all_y = all_xy[:, 1].round().long()
    # Drop dash positions that fall outside the frame.
    valid = (all_x >= 0) & (all_x < W) & (all_y >= 0) & (all_y < H)
    all_x = all_x[valid]
    all_y = all_y[valid]
    # Each ray carries the first-frame color at its (clamped) origin pixel.
    x0r = x0.round().long().clamp(0, W - 1)
    y0r = y0.round().long().clamp(0, H - 1)
    frame0 = video_tensor[0]
    base_colors = frame0[:, y0r, x0r]  # (C, num_perp_samples)
    base_colors = base_colors.repeat_interleave(S, dim=1)[:, valid]
    # Stamp each dash as a 2x2 pixel block so it stays visible.
    video_out = video_tensor.clone()
    offsets = [(0, 0), (0, 1), (1, 0), (1, 1)]
    for dxo, dyo in offsets:
        ox = all_x + dxo
        oy = all_y + dyo
        inside = (ox >= 0) & (ox < W) & (oy >= 0) & (oy < H)
        ox = ox[inside]
        oy = oy[inside]
        colors = base_colors[:, inside]
        for c in range(C):
            # Write the ray colors into every frame except frame 0.
            video_out[1:, c, oy, ox] = colors[c][None, :].expand(T - 1, -1)
    return video_out
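# Minimal illustration (assumed shapes, random data): the overlay preserves
# the clip's shape and never touches frame 0.
#   clip = torch.rand(49, 3, 480, 720) * 2 - 1   # (T, C, H, W) in [-1, 1]
#   out = add_dashed_rays_to_video(clip)
#   assert out.shape == clip.shape
#   assert torch.equal(out[0], clip[0])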

def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Load the CogVideoX VAE once, frozen and in fp16, for encoding only.
    vae = AutoencoderKLCogVideoX.from_pretrained(args.pretrained_model_path, subfolder="vae")
    vae.requires_grad_(False)
    vae = vae.to(device, dtype=torch.float16)

    # Expected layout: <video_root>/videos and <video_root>/masked_videos
    # hold paired .mp4 files with identical names.
    masked_video_path = os.path.join(args.video_root, "masked_videos")
    source_video_path = os.path.join(args.video_root, "videos")
    joint_latent_path = os.path.join(args.video_root, "joint_latents")
    os.makedirs(joint_latent_path, exist_ok=True)

    # Deterministic ordering so the [start_idx, end_idx) shards are disjoint
    # across workers.
    all_video_names = sorted(os.listdir(source_video_path))
    video_names = all_video_names[args.start_idx : args.end_idx]

    for video_name in tqdm.tqdm(video_names, desc=f"GPU {args.gpu_id}"):
        masked_video_file = os.path.join(masked_video_path, video_name)
        source_video_file = os.path.join(source_video_path, video_name)
        output_file = os.path.join(joint_latent_path, video_name.replace('.mp4', '.safetensors'))

        if not os.path.exists(masked_video_file):
            print(f"Skipping {video_name}, masked video not found.")
            continue
        if os.path.exists(output_file):
            continue

        try:
            # Read the first 49 frames (CogVideoX's standard clip length),
            # normalize uint8 [0, 255] to float [-1, 1], and encode.
            vr = VideoReader(source_video_file)
            video = torch.from_numpy(vr.get_batch(np.arange(49)).asnumpy()).permute(0, 3, 1, 2).contiguous()
            video = (video / 255.0) * 2 - 1
            source_latent = encode_video(video, vae)

            # Same for the masked video, with dashed rays overlaid first.
            vr = VideoReader(masked_video_file)
            video = torch.from_numpy(vr.get_batch(np.arange(49)).asnumpy()).permute(0, 3, 1, 2).contiguous()
            video = (video / 255.0) * 2 - 1
            video = add_dashed_rays_to_video(video)
            masked_latent = encode_video(video, vae)

            # Concatenate along the temporal latent axis (dim 2 of
            # (B, C, T, H, W)) and save to a single safetensors file.
            source_latent = source_latent.to("cpu")
            masked_latent = masked_latent.to("cpu")
            cated_latent = torch.cat([source_latent, masked_latent], dim=2)
            save_file({'joint_latents': cated_latent}, output_file)

        except Exception as e:
            # Videos shorter than 49 frames or unreadable files land here.
            print(f"[GPU {args.gpu_id}] Error processing {video_name}: {e}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--video_root", type=str, required=True,
                        help="Directory containing videos/ and masked_videos/")
    parser.add_argument("--pretrained_model_path", type=str, required=True,
                        help="CogVideoX checkpoint with a vae/ subfolder")
    parser.add_argument("--start_idx", type=int, required=True,
                        help="First video index (inclusive) of this worker's shard")
    parser.add_argument("--end_idx", type=int, required=True,
                        help="Last video index (exclusive) of this worker's shard")
    parser.add_argument("--gpu_id", type=int, required=True,
                        help="Physical GPU to pin this process to")
    args = parser.parse_args()

    # Pin the process to one GPU. Setting CUDA_VISIBLE_DEVICES this late still
    # works because torch initializes CUDA lazily and no CUDA call has been
    # made yet.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    main(args)
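# Example invocation (paths and the script name are illustrative):
#   python precompute_joint_latents.py \
#       --video_root /data/my_dataset \
#       --pretrained_model_path THUDM/CogVideoX-2b \
#       --start_idx 0 --end_idx 500 --gpu_id 0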