import os
import random

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from decord import VideoReader
from torch.utils.data.dataset import Dataset


class RandomHorizontalFlipWithPose(nn.Module):
    """Horizontal flip that exposes its per-frame flip decisions, so the same
    flips can be re-applied to paired data (e.g. an aligned second video stream)."""

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def get_flip_flag(self, n_image):
        # One independent Bernoulli(p) draw per frame.
        return torch.rand(n_image) < self.p

    def forward(self, image, flip_flag=None):
        n_image = image.shape[0]
        if flip_flag is not None:
            assert n_image == flip_flag.shape[0]
        else:
            flip_flag = self.get_flip_flag(n_image)

        ret_images = []
        for fflag, img in zip(flip_flag, image):
            if fflag:
                ret_images.append(F.hflip(img))
            else:
                ret_images.append(img)
        return torch.stack(ret_images, dim=0)


class RealEstate10KPCDRenderDataset(Dataset):
    """RealEstate10K clips paired with their point-cloud renders.

    Expects `video_root_dir` to contain `videos/`, `masked_videos/`, and
    `captions/`, with matching file stems per clip.
    """

    def __init__(
        self,
        video_root_dir,
        sample_n_frames=49,
        image_size=(480, 720),
        shuffle_frames=False,
        hflip_p=0.0,
    ):
        use_flip = hflip_p != 0.0
        self.root_path = video_root_dir
        self.sample_n_frames = sample_n_frames
        self.source_video_root = os.path.join(self.root_path, 'videos')
        self.mask_video_root = os.path.join(self.root_path, 'masked_videos')
        self.captions_root = os.path.join(self.root_path, 'captions')
        # Clip names are the .mp4 file stems in the source video directory.
        self.dataset = sorted(
            n[:-len('.mp4')] for n in os.listdir(self.source_video_root) if n.endswith('.mp4')
        )
        self.length = len(self.dataset)

        sample_size = tuple(image_size) if not isinstance(image_size, int) else (image_size, image_size)
        self.sample_size = sample_size
        if use_flip:
            pixel_transforms = [
                transforms.Resize(sample_size),
                RandomHorizontalFlipWithPose(hflip_p),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        else:
            pixel_transforms = [
                transforms.Resize(sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        self.sample_wh_ratio = sample_size[1] / sample_size[0]

        self.pixel_transforms = pixel_transforms
        self.shuffle_frames = shuffle_frames
        self.use_flip = use_flip

    def load_video_reader(self, idx):
        clip_name = self.dataset[idx]

        video_path = os.path.join(self.source_video_root, clip_name + '.mp4')
        video_reader = VideoReader(video_path)

        mask_video_path = os.path.join(self.mask_video_root, clip_name + '.mp4')
        mask_video_reader = VideoReader(mask_video_path)

        caption_path = os.path.join(self.captions_root, clip_name + '.txt')
        if os.path.exists(caption_path):
            with open(caption_path, 'r') as f:
                caption = f.read().strip()
        else:
            caption = ''

        return clip_name, video_reader, mask_video_reader, caption

    def get_batch(self, idx):
        clip_name, video_reader, mask_video_reader, video_caption = self.load_video_reader(idx)

        if self.use_flip:
            flip_flag = self.pixel_transforms[1].get_flip_flag(self.sample_n_frames)
        else:
            flip_flag = torch.zeros(self.sample_n_frames, dtype=torch.bool)

        # Read the first `sample_n_frames` frames as [T, C, H, W] in [0, 1].
        indices = np.arange(self.sample_n_frames)
        pixel_values = torch.from_numpy(video_reader.get_batch(indices).asnumpy()).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.

        anchor_pixels = torch.from_numpy(mask_video_reader.get_batch(indices).asnumpy()).permute(0, 3, 1, 2).contiguous()
        anchor_pixels = anchor_pixels / 255.

        return pixel_values, anchor_pixels, video_caption, flip_flag, clip_name

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # If a clip fails to load or decode, resample a random one instead.
        while True:
            try:
                video, anchor_video, video_caption, flip_flag, clip_name = self.get_batch(idx)
                break
            except Exception:
                idx = random.randint(0, self.length - 1)

        if self.use_flip:
            # Apply the same per-frame flip decisions to the video and its anchor render.
            video = self.pixel_transforms[0](video)
            video = self.pixel_transforms[1](video, flip_flag)
            video = self.pixel_transforms[2](video)
            anchor_video = self.pixel_transforms[0](anchor_video)
            anchor_video = self.pixel_transforms[1](anchor_video, flip_flag)
            anchor_video = self.pixel_transforms[2](anchor_video)
        else:
            for transform in self.pixel_transforms:
                video = transform(video)
                anchor_video = transform(anchor_video)

        return {
            'video': video,
            'anchor_video': anchor_video,
            'caption': video_caption,
        }


class RealEstate10KPCDRenderCapEmbDataset(RealEstate10KPCDRenderDataset):
    """Variant that returns precomputed caption embeddings and occlusion masks
    instead of raw captions.

    Expects one `<clip_name>.pt` embedding per clip under `text_embedding_path`
    and, optionally, one `<clip_name>.npz` mask per clip under `<root>/masks/`.
    """

    def __init__(
        self,
        video_root_dir,
        text_embedding_path,
        sample_n_frames=49,
        image_size=(480, 720),
        shuffle_frames=False,
        hflip_p=0.0,
    ):
        super().__init__(
            video_root_dir,
            sample_n_frames=sample_n_frames,
            image_size=image_size,
            shuffle_frames=shuffle_frames,
            hflip_p=hflip_p,
        )
        self.text_embedding_path = text_embedding_path
        self.mask_root = os.path.join(self.root_path, 'masks')

    def get_batch(self, idx):
        clip_name, video_reader, mask_video_reader, video_caption = self.load_video_reader(idx)
        cap_emb_path = os.path.join(self.text_embedding_path, clip_name + '.pt')
        video_caption_emb = torch.load(cap_emb_path, weights_only=True)

        if self.use_flip:
            flip_flag = self.pixel_transforms[1].get_flip_flag(self.sample_n_frames)
        else:
            flip_flag = torch.zeros(self.sample_n_frames, dtype=torch.bool)

        indices = np.arange(self.sample_n_frames)
        pixel_values = torch.from_numpy(video_reader.get_batch(indices).asnumpy()).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.

        anchor_pixels = torch.from_numpy(mask_video_reader.get_batch(indices).asnumpy()).permute(0, 3, 1, 2).contiguous()
        anchor_pixels = anchor_pixels / 255.

        try:
            masks = np.load(os.path.join(self.mask_root, clip_name + '.npz'))['mask']
            masks = torch.from_numpy(masks).unsqueeze(1).float()  # [T, 1, H, W]
        except Exception:
            # No precomputed mask: derive one from near-black pixels of the render.
            threshold = 0.1  # adjust to taste
            masks = (anchor_pixels.sum(dim=1, keepdim=True) < threshold).float()

        return pixel_values, anchor_pixels, masks, video_caption_emb, flip_flag, clip_name

    def __getitem__(self, idx):
        while True:
            try:
                video, anchor_video, mask, video_caption_emb, flip_flag, clip_name = self.get_batch(idx)
                break
            except Exception:
                idx = random.randint(0, self.length - 1)

        if self.use_flip:
            video = self.pixel_transforms[0](video)
            video = self.pixel_transforms[1](video, flip_flag)
            video = self.pixel_transforms[2](video)
            anchor_video = self.pixel_transforms[0](anchor_video)
            anchor_video = self.pixel_transforms[1](anchor_video, flip_flag)
            anchor_video = self.pixel_transforms[2](anchor_video)
        else:
            for transform in self.pixel_transforms:
                video = transform(video)
                anchor_video = transform(anchor_video)

        return {
            'video': video,
            'anchor_video': anchor_video,
            'caption_emb': video_caption_emb,
            'mask': mask,
        }
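

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): wiring
# the dataset into a DataLoader. The root path below is a placeholder and is
# assumed to follow the layout expected above (videos/, masked_videos/,
# captions/, and optionally masks/).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = RealEstate10KPCDRenderDataset(
        video_root_dir='/path/to/RealEstate10K',  # placeholder path
        sample_n_frames=49,
        image_size=(480, 720),
        hflip_p=0.5,
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)

    batch = next(iter(loader))
    # 'video' and 'anchor_video' are [B, T, C, H, W] tensors normalized to [-1, 1].
    print(batch['video'].shape, batch['anchor_video'].shape)
    print(batch['caption'])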