# EPiC-fps/preprocess/get_vae_latent.py
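"""Precompute joint VAE latents for training.

For each clip, encode the first 49 frames of the source video and of its
masked counterpart (overlaid with dashed rays sampled from frame 0) with the
CogVideoX VAE, concatenate the two latents along the temporal axis, and save
the result as a .safetensors file. --start_idx/--end_idx/--gpu_id shard the
dataset across single-GPU worker processes.
"""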
import os
import torch
import numpy as np
import math
import argparse
from decord import VideoReader
from diffusers import AutoencoderKLCogVideoX
from safetensors.torch import save_file
import tqdm


@torch.no_grad()
def encode_video(video, vae):
    """Encode a (T, C, H, W) video tensor into the scaled VAE latent space."""
    # Add a batch dim and reorder to (B, C, T, H, W), the layout the CogVideoX VAE expects.
    video = video[None].permute(0, 2, 1, 3, 4).contiguous()
    video = video.to(vae.device, dtype=vae.dtype)
    latent_dist = vae.encode(video).latent_dist
    # Sample from the posterior and apply the model's latent scaling factor.
    latent = latent_dist.sample() * vae.config.scaling_factor
    return latent


def add_dashed_rays_to_video(video_tensor, num_perp_samples=50, density_decay=0.075):
    """Overlay dashed rays onto frames 1..T-1 of a (T, C, H, W) video.

    A random direction is drawn, and rays are cast along it from a line of
    origin points perpendicular to it through the image center. Each ray is
    stamped with the color frame 0 has at its origin, with sample spacing
    that grows with distance from the origin (hence "dashed").
    """
    T, C, H, W = video_tensor.shape
    # Longest possible ray: the image diagonal, plus a small margin.
    max_length = int((H**2 + W**2) ** 0.5) + 10
    center = torch.tensor([W / 2, H / 2])
    # Random ray direction and its perpendicular.
    theta = torch.rand(1).item() * 2 * math.pi
    direction = torch.tensor([math.cos(theta), math.sin(theta)])
    direction = direction / direction.norm()
    d_perp = torch.tensor([-direction[1], direction[0]])
    # Ray origins: evenly spaced points along the perpendicular through the center.
    half_len = max(H, W) // 2
    positions = torch.linspace(-half_len, half_len, num_perp_samples)
    perp_coords = center[None, :] + positions[:, None] * d_perp[None, :]
    x0, y0 = perp_coords[:, 0], perp_coords[:, 1]
    # Step schedule along each ray: spacing grows linearly with distance,
    # which makes the dashes sparser farther from the origin.
    steps = []
    dist = 0.0
    while dist < max_length:
        steps.append(dist)
        dist += 1.0 + density_decay * dist
    steps = torch.tensor(steps)
    S = len(steps)
    # All sample coordinates: (num_perp_samples, S, 2), flattened to (N, 2).
    dxdy = direction[None, :] * steps[:, None]
    all_xy = perp_coords[:, None, :] + dxdy[None, :, :]
    all_xy = all_xy.reshape(-1, 2)
    all_x = all_xy[:, 0].round().long()
    all_y = all_xy[:, 1].round().long()
    # Keep only samples that land inside the frame.
    valid = (0 <= all_x) & (all_x < W) & (0 <= all_y) & (all_y < H)
    all_x = all_x[valid]
    all_y = all_y[valid]
    # Each ray inherits the color of frame 0 at its origin point.
    x0r = x0.round().long().clamp(0, W - 1)
    y0r = y0.round().long().clamp(0, H - 1)
    frame0 = video_tensor[0]
    base_colors = frame0[:, y0r, x0r]
    base_colors = base_colors.repeat_interleave(S, dim=1)[:, valid]
    video_out = video_tensor.clone()
    # Paint each sample as a 2x2 block into every frame after the first.
    offsets = [(0, 0), (0, 1), (1, 0), (1, 1)]
    for dxo, dyo in offsets:
        ox = all_x + dxo
        oy = all_y + dyo
        inside = (0 <= ox) & (ox < W) & (0 <= oy) & (oy < H)
        ox = ox[inside]
        oy = oy[inside]
        colors = base_colors[:, inside]
        for c in range(C):
            video_out[1:, c, oy, ox] = colors[c][None, :].expand(T - 1, -1)
    return video_out
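

# Expected layout under --video_root:
#   videos/         source clips (*.mp4)
#   masked_videos/  masked counterparts with matching filenames
#   joint_latents/  output .safetensors (created if missing)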
def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Load the frozen CogVideoX VAE in fp16; this script only encodes.
    vae = AutoencoderKLCogVideoX.from_pretrained(args.pretrained_model_path, subfolder="vae")
    vae.requires_grad_(False)
    vae = vae.to(device, dtype=torch.float16)
    masked_video_path = os.path.join(args.video_root, "masked_videos")
    source_video_path = os.path.join(args.video_root, "videos")
    joint_latent_path = os.path.join(args.video_root, "joint_latents")
    os.makedirs(joint_latent_path, exist_ok=True)
    # Process only this worker's shard of the (sorted) video list.
    all_video_names = sorted(os.listdir(source_video_path))
    video_names = all_video_names[args.start_idx : args.end_idx]
    for video_name in tqdm.tqdm(video_names, desc=f"GPU {args.gpu_id}"):
        masked_video_file = os.path.join(masked_video_path, video_name)
        source_video_file = os.path.join(source_video_path, video_name)
        output_file = os.path.join(joint_latent_path, video_name.replace('.mp4', '.safetensors'))
        if not os.path.exists(masked_video_file):
            print(f"Skipping {video_name}, masked video not found.")
            continue
        # Skip clips that were already encoded, so reruns are resumable.
        if os.path.exists(output_file):
            continue
        try:
            # Encode the first 49 frames of the source video, normalized to [-1, 1].
            vr = VideoReader(source_video_file)
            video = torch.from_numpy(vr.get_batch(np.arange(49)).asnumpy()).permute(0, 3, 1, 2).contiguous()
            video = (video / 255.0) * 2 - 1
            source_latent = encode_video(video, vae)
            # Encode the masked video with dashed rays overlaid.
            vr = VideoReader(masked_video_file)
            video = torch.from_numpy(vr.get_batch(np.arange(49)).asnumpy()).permute(0, 3, 1, 2).contiguous()
            video = (video / 255.0) * 2 - 1
            video = add_dashed_rays_to_video(video)
            masked_latent = encode_video(video, vae)
            # Concatenate the two latents along the temporal axis and save.
            source_latent = source_latent.to("cpu")
            masked_latent = masked_latent.to("cpu")
            cated_latent = torch.cat([source_latent, masked_latent], dim=2)
            save_file({'joint_latents': cated_latent}, output_file)
        except Exception as e:
            # Keep going on decode/encode failures; just log the offending clip.
            print(f"[GPU {args.gpu_id}] Error processing {video_name}: {e}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--video_root", type=str, required=True)
    parser.add_argument("--pretrained_model_path", type=str, required=True)
    parser.add_argument("--start_idx", type=int, required=True)
    parser.add_argument("--end_idx", type=int, required=True)
    parser.add_argument("--gpu_id", type=int, required=True)
    args = parser.parse_args()
    # Pin this process to one GPU. Setting the variable here still works
    # because CUDA is initialized lazily, when main() first touches the device.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    main(args)
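
# Example invocation (hypothetical paths and model id), sharding 1000 clips
# across two GPU worker processes:
#   python get_vae_latent.py --video_root /data/clips \
#       --pretrained_model_path THUDM/CogVideoX-2b --start_idx 0 --end_idx 500 --gpu_id 0
#   python get_vae_latent.py --video_root /data/clips \
#       --pretrained_model_path THUDM/CogVideoX-2b --start_idx 500 --end_idx 1000 --gpu_id 1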