import os import random import json import torch import torch.nn as nn import torchvision.transforms.functional as F import numpy as np from torch.utils.data.dataset import Dataset from packaging import version as pver from decord import VideoReader from safetensors.torch import load_file class RealEstate10KPCDRenderLatentCapEmbDataset(Dataset): def __init__( self, video_root_dir, text_embedding_path ): root_path = video_root_dir self.root_path = root_path self.latent_root = os.path.join(self.root_path, 'joint_latents') self.source_video_root = os.path.join(self.root_path, 'videos') self.captions_root = os.path.join(self.root_path, 'captions') self.dataset = sorted([n.replace('.safetensors','') for n in os.listdir(self.latent_root)]) self.length = len(self.dataset) self.text_embedding_path = text_embedding_path self.mask_root = os.path.join(self.root_path, 'masks') def get_batch(self, idx): clip_name = self.dataset[idx] cap_emb_path = os.path.join(self.text_embedding_path, clip_name + '.pt') video_caption_emb = torch.load(cap_emb_path, weights_only=True) joint_latent_path = os.path.join(self.latent_root, clip_name + '.safetensors') joint_latent = load_file(joint_latent_path, device='cpu')['joint_latent'] video_reader = VideoReader(os.path.join(self.source_video_root, clip_name + '.mp4')) indices = [0] first_frame = torch.from_numpy(video_reader.get_batch(indices).asnumpy()).permute(0, 3, 1, 2).contiguous() first_frame = (first_frame / 255.)*2-1 T = joint_latent.shape[2] // 2 source_latent = joint_latent[:, :, :T] anchor_latent = joint_latent[:, :, T:] masks = np.load(os.path.join(self.mask_root, clip_name + '.npz'))['mask']*1.0 masks = torch.from_numpy(masks).unsqueeze(1) return source_latent, anchor_latent, first_frame, masks, video_caption_emb, clip_name def __len__(self): return self.length def __getitem__(self, idx): while True: try: source_latent, anchor_latent, image, mask, video_caption_emb, clip_name = self.get_batch(idx) break except Exception as e: idx = random.randint(0, self.length - 1) data = { 'source_latent': source_latent, 'anchor_latent': anchor_latent, 'image': image, 'caption_emb': video_caption_emb, 'mask': mask } return data