import json
import os
import random
from typing import Optional

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import crop, resize

try:
    import decord
except ImportError:
    raise ImportError(
        "The `decord` package is required for loading the video dataset. Install with `pip install decord`"
    )

decord.bridge.set_bridge("torch")


class ImageVideoDataset(Dataset):
    def __init__(
        self,
        root_path,
        annotation_json,
        tokenizer,
        max_sequence_length: int = 226,
        height: int = 480,
        width: int = 640,
        video_reshape_mode: str = "center",
        fps: int = 8,
        stripe: int = 2,  # temporal stride between sampled frames
        max_num_frames: int = 49,
        skip_frames_start: int = 0,
        skip_frames_end: int = 0,
        random_flip: Optional[float] = None,
    ) -> None:
        super().__init__()
        self.root_path = root_path
        with open(annotation_json, 'r') as f:
            self.data_list = json.load(f)
        self.tokenizer = tokenizer
        self.max_sequence_length = max_sequence_length
        self.height = height
        self.width = width
        self.video_reshape_mode = video_reshape_mode
        self.fps = fps
        self.max_num_frames = max_num_frames
        self.skip_frames_start = skip_frames_start
        self.skip_frames_end = skip_frames_end
        self.stripe = stripe

        self.video_transforms = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(random_flip) if random_flip else transforms.Lambda(lambda x: x),
                transforms.Lambda(lambda x: x / 255.0),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

    def __len__(self):
        return len(self.data_list)

    def _resize_for_rectangle_crop(self, arr):
        # Resize so the target rectangle is fully covered, then crop the excess.
        image_size = self.height, self.width
        reshape_mode = self.video_reshape_mode
        if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
            # Wider than the target aspect ratio: match height, crop width.
            arr = resize(
                arr,
                size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
                interpolation=InterpolationMode.BICUBIC,
            )
        else:
            # Taller than the target aspect ratio: match width, crop height.
            arr = resize(
                arr,
                size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
                interpolation=InterpolationMode.BICUBIC,
            )

        h, w = arr.shape[2], arr.shape[3]
        arr = arr.squeeze(0)  # no-op for [F, C, H, W] input; kept from upstream batched handling

        delta_h = h - image_size[0]
        delta_w = w - image_size[1]

        if reshape_mode == "random" or reshape_mode == "none":
            top = np.random.randint(0, delta_h + 1)
            left = np.random.randint(0, delta_w + 1)
        elif reshape_mode == "center":
            top, left = delta_h // 2, delta_w // 2
        else:
            raise NotImplementedError
        arr = crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
        return arr

    def __getitem__(self, index):
        while True:
            try:
                video_path = os.path.join(self.root_path, self.data_list[index]['clip_path'])
                video_reader = decord.VideoReader(video_path, width=self.width, height=self.height)
                video_num_frames = len(video_reader)

                # Fall back to a stride of 1 when the clip is too short for the configured stride.
                if self.stripe * self.max_num_frames > video_num_frames:
                    stripe = 1
                else:
                    stripe = self.stripe

                random_range = max(1, video_num_frames - stripe * self.max_num_frames - 1)
                start_frame = random.randint(1, random_range)
                indices = list(range(start_frame, start_frame + stripe * self.max_num_frames, stripe))
                # Out-of-range indices raise inside `get_batch` and are handled by the except below.
                frames = video_reader.get_batch(indices)

                # Ensure that we don't go over the limit.
                frames = frames[: self.max_num_frames]
                selected_num_frames = frames.shape[0]

                # Keep only the first (4k + 1) frames, as the VAE requires this count.
                remainder = (3 + (selected_num_frames % 4)) % 4
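                # Rationale (assumed, not stated in the original): video VAEs of this
                # kind compress time by 4x while encoding the first frame separately,
                # so a valid clip length is 4k + 1. `remainder` counts the trailing
                # frames to drop so that (selected_num_frames - 1) % 4 == 0.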
                if remainder != 0:
                    frames = frames[:-remainder]
                selected_num_frames = frames.shape[0]

                assert (selected_num_frames - 1) % 4 == 0
                if selected_num_frames == self.max_num_frames:
                    break
                else:
                    # The clip could not supply enough frames; try the next sample.
                    index = (index + 1) % len(self.data_list)
                    continue
            except Exception as e:
                print("[WARN] Problem loading video:", self.data_list[index]['clip_path'])
                print("Error encountered during video loading:", e)
                index = (index + 1) % len(self.data_list)
                continue

        # Training transforms.
        frames = frames.permute(0, 3, 1, 2).contiguous()  # [F, H, W, C] -> [F, C, H, W]
        frames = self._resize_for_rectangle_crop(frames)
        frames = torch.stack([self.video_transforms(frame) for frame in frames], dim=0)

        text_inputs = self.tokenizer(
            [self.data_list[index]['caption']],
            padding="max_length",
            max_length=self.max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids[0]

        return frames.contiguous(), text_input_ids


class AutoEncoderDataset(ImageVideoDataset):
    """Returns a single random raw frame per clip, e.g. for image autoencoder training."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __getitem__(self, index):
        while True:
            try:
                video_path = os.path.join(self.root_path, self.data_list[index]['clip_path'])
                video_reader = decord.VideoReader(video_path, width=self.width, height=self.height)
                video_num_frames = len(video_reader)
                # Randomly select a single frame from the video (frame 0 is skipped).
                random_index = [random.randint(1, video_num_frames - 1)]
                frames = video_reader.get_batch(random_index)
                break
            except Exception as e:
                print("[WARN] Problem loading video:", self.data_list[index]['clip_path'])
                print("Error encountered during video loading:", e)
                index = random.randint(0, len(self.data_list) - 1)
                continue
        # Raw uint8 frames of shape [1, H, W, C]; no transforms are applied here.
        return frames


class LvisDataset(Dataset):
    def __init__(
        self,
        root_path,
        annotation_json,
        height: int = 480,
        width: int = 640,
        random_flip: Optional[float] = None,
    ) -> None:
        super().__init__()
        self.root_path = root_path
        with open(annotation_json, 'r') as f:
            self.data_list = json.load(f)['images']
        self.height = height
        self.width = width

        self.video_transforms = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(random_flip) if random_flip else transforms.Lambda(lambda x: x),
                transforms.Lambda(lambda x: x / 255.0),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        image_path = os.path.join(self.root_path, "unlabeled2017", self.data_list[index]['file_name'])
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # cv2 loads BGR; convert to match decord's RGB output
        image = cv2.resize(image, (self.width, self.height))
        image = self.video_transforms(torch.from_numpy(image).permute(2, 0, 1))
        return image.contiguous()
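

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module): wires
# ImageVideoDataset into a DataLoader. The tokenizer checkpoint and the data
# paths below are hypothetical placeholders; substitute whichever tokenizer
# your text encoder expects (max_sequence_length=226 suggests a T5 family).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-xxl")  # hypothetical choice

    dataset = ImageVideoDataset(
        root_path="/path/to/clips",            # hypothetical path
        annotation_json="/path/to/anno.json",  # hypothetical path
        tokenizer=tokenizer,
        height=480,
        width=640,
        max_num_frames=49,
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)

    frames, text_input_ids = next(iter(loader))
    print(frames.shape)          # [1, F, C, H, W], values normalized to [-1, 1]
    print(text_input_ids.shape)  # [1, max_sequence_length]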