Spaces:
Runtime error
Runtime error
| import math | |
| import os | |
| import decord | |
| import numpy as np | |
| import torch | |
| import torchvision | |
| from decord import VideoReader, cpu | |
| from torch.utils.data import Dataset | |
| from torchvision.transforms import Compose, Lambda, ToTensor | |
| from torchvision.transforms._transforms_video import NormalizeVideo, RandomCropVideo, RandomHorizontalFlipVideo | |
| from pytorchvideo.transforms import ApplyTransformToKey, ShortSideScale, UniformTemporalSubsample | |
| from torch.nn import functional as F | |
| import random | |
| from opensora.utils.dataset_utils import DecordInit | |
class UCF101(Dataset):
    """Video classification dataset read from ``<data_path>/<class_name>/<clip>.avi``.

    Every sub-directory of ``args.data_path`` is treated as one class; class
    indices follow the alphabetical order of the directory names. Each item is
    a ``(video, label)`` pair where ``video`` is ``self.num_frames`` sampled
    frames after ``transform`` and a swap of the first two axes.

    Args:
        args: object exposing ``data_path`` (dataset root directory) and
            ``num_frames`` (number of frames sampled per clip).
        transform: callable applied to the sampled frames, which arrive in
            (T, C, H, W) layout. Its result must support ``.transpose(0, 1)``.
        temporal_sample: callable mapping the total frame count of a clip to a
            ``(start_frame_ind, end_frame_ind)`` pair; the end index is treated
            as exclusive when frame indices are computed.
    """

    # Upper bound on consecutive failed loads in ``__getitem__``. A handful of
    # corrupt clips is tolerated by resampling, but a systematic failure (bad
    # transform, wrong path, missing codec) must surface instead of looping.
    _MAX_LOAD_ATTEMPTS = 10

    def __init__(self, args, transform, temporal_sample):
        self.data_path = args.data_path
        self.num_frames = args.num_frames
        self.transform = transform
        self.temporal_sample = temporal_sample
        # Project helper used only by ``decord_read``; it is called with a path
        # and its result is used like a decord reader (``len``, ``get_batch``).
        self.v_decoder = DecordInit()

        self.classes = self._find_classes()
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
        self.samples = self._make_dataset()

    def _find_classes(self):
        """Return the sorted names of the class directories under ``data_path``.

        Plain files sitting next to the class folders (annotation lists,
        ``.DS_Store``, ...) are skipped: they are not classes, and listing
        them as directories would raise ``NotADirectoryError``.
        """
        return sorted(
            entry
            for entry in os.listdir(self.data_path)
            if os.path.isdir(os.path.join(self.data_path, entry))
        )

    def _make_dataset(self):
        """Collect ``(video_path, class_index)`` for every ``.avi`` file of every class.

        File names are sorted so that the sample order (and therefore any
        index-based split or seeded shuffle) is reproducible across machines;
        ``os.listdir`` alone returns entries in arbitrary order.
        """
        dataset = []
        for class_name in self.classes:
            class_path = os.path.join(self.data_path, class_name)
            label = self.class_to_idx[class_name]
            dataset.extend(
                (os.path.join(class_path, fname), label)
                for fname in sorted(os.listdir(class_path))
                if fname.endswith('.avi')
            )
        return dataset

    def __len__(self):
        """Return the number of video clips found on disk."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load, transform and return the clip at ``idx`` together with its label.

        If loading or transforming the clip fails, the error is printed and a
        randomly chosen clip is tried instead, so that a few broken files do
        not abort training.

        Returns:
            ``(video, label)`` where ``video`` is the transformed clip with its
            first two axes swapped (T C H W -> C T H W) and ``label`` is the
            integer class index.

        Raises:
            IndexError: if ``idx`` is out of range.
            RuntimeError: if ``_MAX_LOAD_ATTEMPTS`` clips in a row fail; the
                last underlying error is attached as ``__cause__``.
        """
        last_error = None
        for _ in range(self._MAX_LOAD_ATTEMPTS):
            video_path, label = self.samples[idx]
            try:
                video = self.tv_read(video_path)
                video = self.transform(video)  # T C H W -> T C H W
                video = video.transpose(0, 1)  # T C H W -> C T H W
                return video, label
            except Exception as e:
                print(f'Error with {e}, {video_path}')
                last_error = e
                # Fall back to a random clip rather than failing the batch.
                idx = random.randint(0, len(self.samples) - 1)
        raise RuntimeError(
            f'Failed to load a video after {self._MAX_LOAD_ATTEMPTS} consecutive attempts'
        ) from last_error

    def _sample_frame_indices(self, total_frames):
        """Return ``num_frames`` evenly spaced integer frame indices for one clip.

        ``temporal_sample`` picks the window ``[start, end)``; the indices are
        spread linearly over it and truncated to ints, so frames repeat when
        the window is shorter than ``num_frames``.

        Raises:
            ValueError: if the clip has no frames.
        """
        if total_frames <= 0:
            raise ValueError(f'video has no frames (total_frames={total_frames})')
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        # assert end_frame_ind - start_frame_ind >= self.num_frames
        return np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)

    def tv_read(self, path):
        """Decode ``path`` with torchvision and return the sampled frames as (T, C, H, W).

        Note that ``read_video`` decodes the whole clip before sampling.
        """
        vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
        frame_indice = self._sample_frame_indices(len(vframes))
        video = vframes[frame_indice]  # (T, C, H, W)
        return video

    def decord_read(self, path):
        """Decode ``path`` through ``self.v_decoder`` and return the sampled frames as (T, C, H, W).

        Alternative to ``tv_read`` that only fetches the selected frames.
        """
        decord_vr = self.v_decoder(path)
        frame_indice = self._sample_frame_indices(len(decord_vr))
        video_data = decord_vr.get_batch(frame_indice).asnumpy()
        video_data = torch.from_numpy(video_data)
        video_data = video_data.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T C H W)
        return video_data