# File size: 2,631 bytes; commit b14067d (repository-viewer metadata, not code)
import os
import random
import json
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
import numpy as np
from torch.utils.data.dataset import Dataset
from packaging import version as pver
from decord import VideoReader
from safetensors.torch import load_file
class RealEstate10KPCDRenderLatentCapEmbDataset(Dataset):
    """RealEstate10K clips served as precomputed joint latents + caption embeddings.

    Expects the following layout under ``video_root_dir``:
      - ``joint_latents/<clip>.safetensors`` : tensor ``joint_latent`` with source
        and anchor latents concatenated along dim 2 (split in half at load time)
      - ``videos/<clip>.mp4``                : source video (only frame 0 is read)
      - ``masks/<clip>.npz``                 : per-clip ``mask`` array
      - ``captions/``                        : recorded as an attribute, unused here

    ``text_embedding_path`` holds one ``<clip>.pt`` caption-embedding tensor
    per clip, loaded with ``torch.load(..., weights_only=True)``.
    """

    # Upper bound on resampling attempts when a clip fails to load; keeps a
    # systematically broken dataset from spinning __getitem__ forever.
    _MAX_RETRIES = 128

    def __init__(
        self,
        video_root_dir,
        text_embedding_path
    ):
        self.root_path = video_root_dir
        self.latent_root = os.path.join(self.root_path, 'joint_latents')
        self.source_video_root = os.path.join(self.root_path, 'videos')
        self.captions_root = os.path.join(self.root_path, 'captions')
        self.mask_root = os.path.join(self.root_path, 'masks')
        # Clip names are derived from the latent files. Strip the extension by
        # slicing (not str.replace, which would also mangle a name containing
        # '.safetensors' mid-string) and skip any stray non-latent files.
        suffix = '.safetensors'
        self.dataset = sorted(
            name[: -len(suffix)]
            for name in os.listdir(self.latent_root)
            if name.endswith(suffix)
        )
        self.length = len(self.dataset)
        self.text_embedding_path = text_embedding_path

    def get_batch(self, idx):
        """Load every modality for the clip at ``idx``.

        Returns ``(source_latent, anchor_latent, first_frame, masks,
        video_caption_emb, clip_name)``. Raises on any missing or corrupt
        file; ``__getitem__`` is responsible for retrying.
        """
        clip_name = self.dataset[idx]
        cap_emb_path = os.path.join(self.text_embedding_path, clip_name + '.pt')
        video_caption_emb = torch.load(cap_emb_path, weights_only=True)
        joint_latent_path = os.path.join(self.latent_root, clip_name + '.safetensors')
        joint_latent = load_file(joint_latent_path, device='cpu')['joint_latent']
        # Only the first frame of the source video is used as conditioning.
        video_reader = VideoReader(os.path.join(self.source_video_root, clip_name + '.mp4'))
        first_frame = torch.from_numpy(
            video_reader.get_batch([0]).asnumpy()
        ).permute(0, 3, 1, 2).contiguous()  # (1, C, H, W)
        # uint8 [0, 255] -> float [-1, 1]
        first_frame = (first_frame / 255.) * 2 - 1
        # The joint latent stacks source and anchor halves along dim 2.
        T = joint_latent.shape[2] // 2
        source_latent = joint_latent[:, :, :T]
        anchor_latent = joint_latent[:, :, T:]
        masks = np.load(os.path.join(self.mask_root, clip_name + '.npz'))['mask'] * 1.0
        masks = torch.from_numpy(masks).unsqueeze(1)  # add channel dim
        return source_latent, anchor_latent, first_frame, masks, video_caption_emb, clip_name

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Best-effort loading: a clip that fails for any reason is replaced by
        # a uniformly random one. Bounded so pathological datasets fail loudly
        # instead of looping forever (the original `while True` had no exit).
        for _ in range(self._MAX_RETRIES):
            try:
                (source_latent, anchor_latent, image, mask,
                 video_caption_emb, clip_name) = self.get_batch(idx)
                break
            except Exception:
                idx = random.randint(0, self.length - 1)
        else:
            raise RuntimeError(
                'failed to load a valid clip after %d attempts' % self._MAX_RETRIES
            )
        return {
            'source_latent': source_latent,
            'anchor_latent': anchor_latent,
            'image': image,
            'caption_emb': video_caption_emb,
            'mask': mask,
        }