# EPiC-fps/training/controlnet_datasets_camera_pcd_mask.py
import os
import random
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
import numpy as np
from decord import VideoReader
from torch.utils.data.dataset import Dataset


class RandomHorizontalFlipWithPose(nn.Module):
    """Horizontal flip whose per-frame decisions can be drawn once and reused,
    so paired streams (frames, renders, poses) are flipped consistently."""

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def get_flip_flag(self, n_image):
        # One independent Bernoulli(p) draw per frame.
        return torch.rand(n_image) < self.p

    def forward(self, image, flip_flag=None):
        n_image = image.shape[0]
        if flip_flag is not None:
            assert n_image == flip_flag.shape[0], "need one flip flag per frame"
        else:
            flip_flag = self.get_flip_flag(n_image)
        ret_images = []
        for fflag, img in zip(flip_flag, image):
            ret_images.append(F.hflip(img) if fflag else img)
        return torch.stack(ret_images, dim=0)
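
# A minimal usage sketch (illustrative only, not part of the training pipeline):
# draw the per-frame flip flags once, then pass the same flags to every paired
# stream so the RGB frames and their renders stay spatially aligned.
#
#   flip = RandomHorizontalFlipWithPose(p=0.5)
#   flags = flip.get_flip_flag(n_image=49)      # one Bernoulli(p) draw per frame
#   frames = flip(frames, flip_flag=flags)      # frames: [T, C, H, W] tensor
#   anchors = flip(anchors, flip_flag=flags)    # same flags -> consistent flips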


class RealEstate10KPCDRenderDataset(Dataset):
    """RealEstate10K clips paired with point-cloud-rendered (masked) videos
    and captions.

    Expects `video_root_dir` to contain `videos/`, `masked_videos/`, and
    `captions/` subdirectories with matching clip names.
    """

    def __init__(
        self,
        video_root_dir,
        sample_n_frames=49,
        image_size=[480, 720],
        shuffle_frames=False,
        hflip_p=0.0,
    ):
        use_flip = hflip_p != 0.0
        self.root_path = video_root_dir
        self.sample_n_frames = sample_n_frames
        self.source_video_root = os.path.join(self.root_path, 'videos')
        self.mask_video_root = os.path.join(self.root_path, 'masked_videos')
        self.captions_root = os.path.join(self.root_path, 'captions')
        self.dataset = sorted(n.replace('.mp4', '') for n in os.listdir(self.source_video_root))
        self.length = len(self.dataset)
        sample_size = (image_size, image_size) if isinstance(image_size, int) else tuple(image_size)
        self.sample_size = sample_size
        # Resize to the target resolution, optionally flip, then map [0, 1] -> [-1, 1].
        if use_flip:
            pixel_transforms = [transforms.Resize(sample_size),
                                RandomHorizontalFlipWithPose(hflip_p),
                                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]
        else:
            pixel_transforms = [transforms.Resize(sample_size),
                                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]
        self.sample_wh_ratio = sample_size[1] / sample_size[0]
        self.pixel_transforms = pixel_transforms
        self.shuffle_frames = shuffle_frames  # accepted for interface parity; not used below
        self.use_flip = use_flip

    def load_video_reader(self, idx):
        clip_name = self.dataset[idx]
        video_path = os.path.join(self.source_video_root, clip_name + '.mp4')
        video_reader = VideoReader(video_path)
        mask_video_path = os.path.join(self.mask_video_root, clip_name + '.mp4')
        mask_video_reader = VideoReader(mask_video_path)
        caption_path = os.path.join(self.captions_root, clip_name + '.txt')
        if os.path.exists(caption_path):
            with open(caption_path, 'r') as f:  # close the handle instead of leaking it
                caption = f.read().strip()
        else:
            caption = ''
        return clip_name, video_reader, mask_video_reader, caption

    def get_batch(self, idx):
        clip_name, video_reader, mask_video_reader, video_caption = self.load_video_reader(idx)
        if self.use_flip:
            flip_flag = self.pixel_transforms[1].get_flip_flag(self.sample_n_frames)
        else:
            flip_flag = torch.zeros(self.sample_n_frames, dtype=torch.bool)
        # Read the first sample_n_frames frames; this raises if the clip is
        # shorter, which __getitem__ catches by resampling another index.
        indices = np.arange(self.sample_n_frames)
        pixel_values = torch.from_numpy(video_reader.get_batch(indices).asnumpy()).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.  # [T, C, H, W] in [0, 1]
        anchor_pixels = torch.from_numpy(mask_video_reader.get_batch(indices).asnumpy()).permute(0, 3, 1, 2).contiguous()
        anchor_pixels = anchor_pixels / 255.
        return pixel_values, anchor_pixels, video_caption, flip_flag, clip_name

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        while True:
            try:
                video, anchor_video, video_caption, flip_flag, clip_name = self.get_batch(idx)
                break
            except Exception:
                # Unreadable or too-short clip: fall back to a random index.
                idx = random.randint(0, self.length - 1)
        if self.use_flip:
            # Apply the flip with the same per-frame flags to both streams
            # so video and anchor stay spatially aligned.
            video = self.pixel_transforms[0](video)
            video = self.pixel_transforms[1](video, flip_flag)
            video = self.pixel_transforms[2](video)
            anchor_video = self.pixel_transforms[0](anchor_video)
            anchor_video = self.pixel_transforms[1](anchor_video, flip_flag)
            anchor_video = self.pixel_transforms[2](anchor_video)
        else:
            for transform in self.pixel_transforms:
                video = transform(video)
                anchor_video = transform(anchor_video)
        data = {
            'video': video,
            'anchor_video': anchor_video,
            'caption': video_caption,
        }
        return data
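
# A minimal sketch of driving the dataset with a DataLoader (the root path is a
# placeholder; it must contain videos/, masked_videos/, and captions/):
#
#   from torch.utils.data import DataLoader
#   dataset = RealEstate10KPCDRenderDataset('data/RealEstate10K', sample_n_frames=49)
#   loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
#   batch = next(iter(loader))
#   batch['video'].shape  # [1, 49, 3, 480, 720], values normalized to [-1, 1]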


class RealEstate10KPCDRenderCapEmbDataset(RealEstate10KPCDRenderDataset):
    """Variant that loads precomputed caption embeddings (one .pt per clip)
    and per-frame masks (.npz under `masks/`, with an on-the-fly fallback)."""

    def __init__(
        self,
        video_root_dir,
        text_embedding_path,
        sample_n_frames=49,
        image_size=[480, 720],
        shuffle_frames=False,
        hflip_p=0.0,
    ):
        super().__init__(
            video_root_dir,
            sample_n_frames=sample_n_frames,
            image_size=image_size,
            shuffle_frames=shuffle_frames,
            hflip_p=hflip_p,
        )
        self.text_embedding_path = text_embedding_path
        self.mask_root = os.path.join(self.root_path, 'masks')

    def get_batch(self, idx):
        clip_name, video_reader, mask_video_reader, video_caption = self.load_video_reader(idx)
        cap_emb_path = os.path.join(self.text_embedding_path, clip_name + '.pt')
        video_caption_emb = torch.load(cap_emb_path, weights_only=True)
        if self.use_flip:
            flip_flag = self.pixel_transforms[1].get_flip_flag(self.sample_n_frames)
        else:
            flip_flag = torch.zeros(self.sample_n_frames, dtype=torch.bool)
        indices = np.arange(self.sample_n_frames)
        pixel_values = torch.from_numpy(video_reader.get_batch(indices).asnumpy()).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.
        anchor_pixels = torch.from_numpy(mask_video_reader.get_batch(indices).asnumpy()).permute(0, 3, 1, 2).contiguous()
        anchor_pixels = anchor_pixels / 255.
        try:
            # Prefer the precomputed per-frame masks; cast bool -> float32.
            masks = np.load(os.path.join(self.mask_root, clip_name + '.npz'))['mask'].astype(np.float32)
            masks = torch.from_numpy(masks).unsqueeze(1)  # [T, 1, H, W]
        except (FileNotFoundError, KeyError):
            # Fallback: treat near-black pixels of the rendered video as masked.
            threshold = 0.1  # tune as needed
            masks = (anchor_pixels.sum(dim=1, keepdim=True) < threshold).float()
        return pixel_values, anchor_pixels, masks, video_caption_emb, flip_flag, clip_name

    def __getitem__(self, idx):
        while True:
            try:
                video, anchor_video, mask, video_caption_emb, flip_flag, clip_name = self.get_batch(idx)
                break
            except Exception:
                # Unreadable clip or missing embedding: resample another index.
                idx = random.randint(0, self.length - 1)
        if self.use_flip:
            video = self.pixel_transforms[0](video)
            video = self.pixel_transforms[1](video, flip_flag)
            video = self.pixel_transforms[2](video)
            anchor_video = self.pixel_transforms[0](anchor_video)
            anchor_video = self.pixel_transforms[1](anchor_video, flip_flag)
            anchor_video = self.pixel_transforms[2](anchor_video)
        else:
            for transform in self.pixel_transforms:
                video = transform(video)
                anchor_video = transform(anchor_video)
        # Note: the mask is returned as loaded; it is not resized, flipped, or normalized.
        data = {
            'video': video,
            'anchor_video': anchor_video,
            'caption_emb': video_caption_emb,
            'mask': mask,
        }
        return data
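

if __name__ == '__main__':
    # Smoke-test sketch; the paths below are placeholders for a local dataset
    # layout and must be adapted before running.
    dataset = RealEstate10KPCDRenderCapEmbDataset(
        video_root_dir='data/RealEstate10K',
        text_embedding_path='data/RealEstate10K/caption_embs',
    )
    sample = dataset[0]
    print(sample['video'].shape)         # [49, 3, 480, 720]
    print(sample['anchor_video'].shape)  # [49, 3, 480, 720]
    print(sample['mask'].shape)          # [49, 1, H, W]; resolution depends on how masks were stored
    print(type(sample['caption_emb']))   # whatever object was saved per clip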