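"""Precompute joint VAE latents for paired source/masked videos.

For every clip under <video_root>/videos with a counterpart under
<video_root>/masked_videos, encode both with the CogVideoX VAE (after
overlaying dashed guidance rays on the masked clip), concatenate the two
latents along the temporal axis, and save the result to
<video_root>/joint_latents as a .safetensors file. The file list is sharded
across workers via --start_idx/--end_idx, one GPU per process (--gpu_id).
"""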
import argparse
import math
import os

import numpy as np
import torch
import tqdm
from decord import VideoReader
from diffusers import AutoencoderKLCogVideoX
from safetensors.torch import save_file

def encode_video(video, vae):
    """Encode a (T, C, H, W) video in [-1, 1] into a scaled VAE latent."""
    # (T, C, H, W) -> (1, C, T, H, W): add a batch dim and move channels
    # ahead of time, the layout the CogVideoX VAE expects.
    video = video[None].permute(0, 2, 1, 3, 4).contiguous()
    video = video.to(vae.device, dtype=vae.dtype)
    latent_dist = vae.encode(video).latent_dist
    latent = latent_dist.sample() * vae.config.scaling_factor
    return latent
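
# Note: with the stock CogVideoX VAE config (16 latent channels, 8x spatial
# and 4x temporal compression), a 49-frame clip encodes to 13 latent frames;
# e.g. (1, 3, 49, 480, 720) -> (1, 16, 13, 60, 90). These numbers assume the
# default pretrained config and are illustrative.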

def add_dashed_rays_to_video(video_tensor, num_perp_samples=50, density_decay=0.075):
    """Overlay dashed rays on frames 1..T-1, colored from frame 0.

    A random direction is drawn, and rays are cast from num_perp_samples
    origins spread along the perpendicular line through the image center.
    Dash spacing grows with distance (controlled by density_decay), and each
    ray carries the color frame 0 has at its origin.
    """
    T, C, H, W = video_tensor.shape
    max_length = int((H**2 + W**2) ** 0.5) + 10  # image diagonal plus margin
    center = torch.tensor([W / 2, H / 2])
    # Random ray direction and its perpendicular.
    theta = torch.rand(1).item() * 2 * math.pi
    direction = torch.tensor([math.cos(theta), math.sin(theta)])
    direction = direction / direction.norm()
    d_perp = torch.tensor([-direction[1], direction[0]])
    # Ray origins: evenly spaced along the perpendicular through the center.
    half_len = max(H, W) // 2
    positions = torch.linspace(-half_len, half_len, num_perp_samples)
    perp_coords = center[None, :] + positions[:, None] * d_perp[None, :]
    x0, y0 = perp_coords[:, 0], perp_coords[:, 1]
    # Step lengths along each ray; spacing grows with distance, so the
    # dashes thin out away from the origin line.
    steps = []
    dist = 0.0
    while dist < max_length:
        steps.append(dist)
        dist += 1.0 + density_decay * dist
    steps = torch.tensor(steps)
    S = len(steps)
    # All sample points: (num_perp_samples, S, 2) -> (num_perp_samples * S, 2).
    dxdy = direction[None, :] * steps[:, None]
    all_xy = perp_coords[:, None, :] + dxdy[None, :, :]
    all_xy = all_xy.reshape(-1, 2)
    all_x = all_xy[:, 0].round().long()
    all_y = all_xy[:, 1].round().long()
    valid = (all_x >= 0) & (all_x < W) & (all_y >= 0) & (all_y < H)
    all_x = all_x[valid]
    all_y = all_y[valid]
    # Each ray's color is frame 0's pixel at the ray origin; repeat_interleave
    # matches the (origin-major, step-minor) ordering of all_xy's reshape.
    x0r = x0.round().long().clamp(0, W - 1)
    y0r = y0.round().long().clamp(0, H - 1)
    frame0 = video_tensor[0]
    base_colors = frame0[:, y0r, x0r]  # (C, num_perp_samples)
    base_colors = base_colors.repeat_interleave(S, dim=1)[:, valid]
    video_out = video_tensor.clone()
    # Paint each dash as a 2x2 block so it stays visible after downsampling.
    offsets = [(0, 0), (0, 1), (1, 0), (1, 1)]
    for dxo, dyo in offsets:
        ox = all_x + dxo
        oy = all_y + dyo
        inside = (ox >= 0) & (ox < W) & (oy >= 0) & (oy < H)
        ox = ox[inside]
        oy = oy[inside]
        colors = base_colors[:, inside]
        for c in range(C):
            # Write the same dash pattern into every frame after frame 0.
            video_out[1:, c, oy, ox] = colors[c][None, :].expand(T - 1, -1)
    return video_out
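
# Minimal self-check sketch (illustrative; the clip shape is assumed):
#   clip = torch.rand(49, 3, 480, 720) * 2 - 1
#   out = add_dashed_rays_to_video(clip)
#   assert torch.equal(out[0], clip[0])        # frame 0 is never painted
#   assert not torch.equal(out[1:], clip[1:])  # later frames carry the rays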

def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Frozen fp16 VAE; this script only encodes, so no gradients are needed.
    vae = AutoencoderKLCogVideoX.from_pretrained(args.pretrained_model_path, subfolder="vae")
    vae.requires_grad_(False)
    vae = vae.to(device, dtype=torch.float16)

    masked_video_path = os.path.join(args.video_root, "masked_videos")
    source_video_path = os.path.join(args.video_root, "videos")
    joint_latent_path = os.path.join(args.video_root, "joint_latents")
    os.makedirs(joint_latent_path, exist_ok=True)

    # Shard the sorted (hence deterministic) file list across workers.
    all_video_names = sorted(os.listdir(source_video_path))
    video_names = all_video_names[args.start_idx : args.end_idx]

    for video_name in tqdm.tqdm(video_names, desc=f"GPU {args.gpu_id}"):
        masked_video_file = os.path.join(masked_video_path, video_name)
        source_video_file = os.path.join(source_video_path, video_name)
        output_file = os.path.join(joint_latent_path, video_name.replace('.mp4', '.safetensors'))
        if not os.path.exists(masked_video_file):
            print(f"Skipping {video_name}, masked video not found.")
            continue
        if os.path.exists(output_file):
            continue  # already processed; makes reruns resumable
        try:
            # Encode the first 49 frames (CogVideoX's native clip length)
            # of the source clip, rescaled from uint8 to [-1, 1].
            vr = VideoReader(source_video_file)
            video = torch.from_numpy(vr.get_batch(np.arange(49)).asnumpy()).permute(0, 3, 1, 2).contiguous()
            video = (video / 255.0) * 2 - 1
            source_latent = encode_video(video, vae)

            # Encode the masked clip with dashed guidance rays overlaid.
            vr = VideoReader(masked_video_file)
            video = torch.from_numpy(vr.get_batch(np.arange(49)).asnumpy()).permute(0, 3, 1, 2).contiguous()
            video = (video / 255.0) * 2 - 1
            video = add_dashed_rays_to_video(video)
            masked_latent = encode_video(video, vae)

            # Concatenate along the temporal axis and save from CPU memory.
            source_latent = source_latent.to("cpu")
            masked_latent = masked_latent.to("cpu")
            cated_latent = torch.cat([source_latent, masked_latent], dim=2)
            save_file({'joint_latents': cated_latent}, output_file)
        except Exception as e:
            print(f"[GPU {args.gpu_id}] Error processing {video_name}: {e}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--video_root", type=str, required=True)
    parser.add_argument("--pretrained_model_path", type=str, required=True)
    parser.add_argument("--start_idx", type=int, required=True)
    parser.add_argument("--end_idx", type=int, required=True)
    parser.add_argument("--gpu_id", type=int, required=True)
    args = parser.parse_args()
    # Must be set before CUDA is initialized, or the pinning has no effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    main(args)
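
# Example launch across 8 GPUs, sharding 8,000 clips (script name, dataset
# path, and checkpoint are illustrative):
#   for i in $(seq 0 7); do
#       python precompute_joint_latents.py \
#           --video_root /data/clips \
#           --pretrained_model_path THUDM/CogVideoX-2b \
#           --start_idx $((i * 1000)) --end_idx $(((i + 1) * 1000)) \
#           --gpu_id $i &
#   done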