import os
import warnings
import glob
import random
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision
import torch.distributed as dist
from decord import VideoReader


class DummyDataset(Dataset):
    def __init__(
        self,
        # width=1024, height=576,
        sample_frames=25,
        base_folder='data/samples/',
        file_list=None,
        temporal_sample=None,
        transform=None,
        seed=42,
    ):
"""
Args:
num_samples (int): Number of samples in the dataset.
channels (int): Number of channels, default is 3 for RGB.
"""
        # Define the path to the folder containing video frames
        # self.base_folder = 'bdd100k/images/track/mini'
        self.base_folder = base_folder
        self.file_list = file_list
        if file_list is None:
            self.video_lists = glob.glob(os.path.join(self.base_folder, '*.mp4'))
        else:
            # read video paths from file_list.txt
            self.video_lists = []
            with open(file_list, 'r') as f:
                for line in f:
                    video_path = line.strip()
                    self.video_lists.append(os.path.join(self.base_folder, video_path))
        self.num_samples = len(self.video_lists)
        self.channels = 3
        # self.width = width
        # self.height = height
        self.sample_frames = sample_frames
        self.temporal_sample = temporal_sample
        self.transform = transform
        self.seed = seed

    def __len__(self):
        return self.num_samples

    def get_sample(self, idx):
        """
        Args:
            idx (int): Index of the sample to return.
        Returns:
            dict: A dictionary containing the 'pixel_values' tensor of shape
                (sample_frames, channels, height, width).
        """
        # path = random.choice(self.video_lists)
        path = self.video_lists[idx]
        if self.file_list is not None:  # read with decord (e.g. from pcache)
            with open(path, 'rb') as f:
                vframes = VideoReader(f)
        else:
            vframes, aframes, info = torchvision.io.read_video(
                filename=path, pts_unit='sec', output_format='TCHW')
        total_frames = len(vframes)
        # Sample a temporal window, then take evenly spaced frame indices within it.
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        if end_frame_ind - start_frame_ind < self.sample_frames:
            raise ValueError(f'video {path} does not have enough frames')
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.sample_frames, dtype=int)
        if self.file_list is not None:  # decord returns (f h w c) frames; convert to (f c h w)
            video = torch.from_numpy(vframes.get_batch(frame_indice).asnumpy()).permute(0, 3, 1, 2).contiguous()
        else:
            video = vframes[frame_indice]
        # (f c h w)
        pixel_values = self.transform(video)
        return {'pixel_values': pixel_values}

    def __getitem__(self, idx):
        # return self.get_sample(idx)
        # If a video fails to load (or is too short), retry with a random index.
        while True:
            try:
                # idx = np.random.randint(0, len(self.video_lists) - 1)
                # idx = self.rng.integers(0, len(self.video_lists))
                item = self.get_sample(idx)
                return item
            except Exception:
                warnings.warn(f'loading {idx} failed, retrying...')
                # np.random.randint's upper bound is exclusive, so pass len()
                # directly to keep the last video reachable.
                idx = np.random.randint(0, len(self.video_lists))
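

if __name__ == '__main__':
    # Minimal usage sketch (illustrative addition, not part of the original file).
    # `TemporalRandomCrop` is a hypothetical stand-in for whatever `temporal_sample`
    # callable the training code actually passes; the transform simply rescales
    # uint8 frames to [-1, 1].
    from torch.utils.data import DataLoader
    from torchvision import transforms

    class TemporalRandomCrop:
        """Pick a random window of roughly `size` consecutive frames."""
        def __init__(self, size):
            self.size = size

        def __call__(self, total_frames):
            rand_end = max(0, total_frames - self.size - 1)
            begin = random.randint(0, rand_end)
            end = min(begin + self.size, total_frames)
            return begin, end

    dataset = DummyDataset(
        sample_frames=25,
        base_folder='data/samples/',
        temporal_sample=TemporalRandomCrop(25),
        transform=transforms.Lambda(lambda x: x.float() / 127.5 - 1.0),
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)
    for batch in loader:
        print(batch['pixel_values'].shape)  # (1, sample_frames, 3, H, W)
        break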