import json
import os
import random

import numpy as np
import torch
import torch.utils.data as data
from glob import glob
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm

from opensora.dataset.transform import center_crop, RandomCropVideo
from opensora.utils.dataset_utils import DecordInit


class T2V_Feature_dataset(Dataset):
    """Loads pre-extracted CausalVideoVAE latents paired with cached T5 text features."""

    def __init__(self, args, temporal_sample):
        self.video_folder = args.video_folder
        self.num_frames = args.video_length
        self.temporal_sample = temporal_sample
        print('Building dataset...')
        # The directory walk is expensive, so the sample list is cached on disk;
        # delete samples_430k.json to force a rebuild.
        if os.path.exists('samples_430k.json'):
            with open('samples_430k.json', 'r') as f:
                self.samples = json.load(f)
        else:
            self.samples = self._make_dataset()
            with open('samples_430k.json', 'w') as f:
                json.dump(self.samples, f, indent=2)
        self.use_image_num = args.use_image_num
        self.use_img_from_vid = args.use_img_from_vid
        if self.use_image_num != 0 and not self.use_img_from_vid:
            self.img_cap_list = self.get_img_cap_list()

    def _make_dataset(self):
        all_mp4 = list(glob(os.path.join(self.video_folder, '**', '*.mp4'), recursive=True))
        # all_mp4 = all_mp4[:1000]
        samples = []
        for i in tqdm(all_mp4):
            video_id = os.path.basename(i).split('.')[0]
            # Map the raw video path to its pre-computed VAE latent; skip videos
            # whose latents were never extracted.
            ae = os.path.split(i)[0].replace('data_split_tt', 'lb_causalvideovae444_feature')
            ae = os.path.join(ae, f'{video_id}_causalvideovae444.npy')
            if not os.path.exists(ae):
                continue
            # Collect whichever cached T5 caption features (LLaVA and/or ShareGPT4V) exist.
            t5 = os.path.split(i)[0].replace('data_split_tt', 'lb_t5_feature')
            cond_list = []
            cond_llava = os.path.join(t5, f'{video_id}_t5_llava_fea.npy')
            mask_llava = os.path.join(t5, f'{video_id}_t5_llava_mask.npy')
            if os.path.exists(cond_llava) and os.path.exists(mask_llava):
                llava = dict(cond=cond_llava, mask=mask_llava)
                cond_list.append(llava)
            cond_sharegpt4v = os.path.join(t5, f'{video_id}_t5_sharegpt4v_fea.npy')
            mask_sharegpt4v = os.path.join(t5, f'{video_id}_t5_sharegpt4v_mask.npy')
            if os.path.exists(cond_sharegpt4v) and os.path.exists(mask_sharegpt4v):
                sharegpt4v = dict(cond=cond_sharegpt4v, mask=mask_sharegpt4v)
                cond_list.append(sharegpt4v)
            if len(cond_list) > 0:
                sample = dict(ae=ae, t5=cond_list)
                samples.append(sample)
        return samples
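
    # On-disk layout implied by the .replace() calls above (an assumption
    # reconstructed from the path rewriting, not documented in the original code):
    #
    #   .../data_split_tt/<sub>/<video_id>.mp4
    #   .../lb_causalvideovae444_feature/<sub>/<video_id>_causalvideovae444.npy
    #   .../lb_t5_feature/<sub>/<video_id>_t5_llava_fea.npy
    #   .../lb_t5_feature/<sub>/<video_id>_t5_llava_mask.npy
    #
    # Each entry of self.samples then looks like:
    #   {'ae': '<latent .npy path>', 't5': [{'cond': ..., 'mask': ...}, ...]}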

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # try:
        sample = self.samples[idx]
        ae, t5 = sample['ae'], sample['t5']
        t5 = random.choice(t5)  # pick one caption source (LLaVA or ShareGPT4V) at random
        video_origin = np.load(ae)[0]  # C T H W
        _, total_frames, _, _ = video_origin.shape
        # Sample num_frames latent frames evenly from the clip chosen by temporal_sample.
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        assert end_frame_ind - start_frame_ind >= self.num_frames
        select_video_idx = np.linspace(start_frame_ind, end_frame_ind - 1, num=self.num_frames, dtype=int)
        # print('select_video_idx', total_frames, select_video_idx)
        video = video_origin[:, select_video_idx]  # C num_frames H W
        video = torch.from_numpy(video)
        cond = torch.from_numpy(np.load(t5['cond']))[0]  # L D
        cond_mask = torch.from_numpy(np.load(t5['mask']))[0]  # L
        if self.use_image_num != 0 and self.use_img_from_vid:
            # Append use_image_num single frames from the same video and repeat the
            # caption features so every frame/image shares the text condition.
            select_image_idx = np.random.randint(0, total_frames, self.use_image_num)
            # print('select_image_idx', total_frames, self.use_image_num, select_image_idx)
            images = video_origin[:, select_image_idx]  # C num_img H W
            images = torch.from_numpy(images)
            video = torch.cat([video, images], dim=1)  # C num_frames+num_img H W
            cond = torch.stack([cond] * (1 + self.use_image_num))  # 1+use_image_num, L, D
            cond_mask = torch.stack([cond_mask] * (1 + self.use_image_num))  # 1+use_image_num, L
        elif self.use_image_num != 0 and not self.use_img_from_vid:
            images, captions = self.img_cap_list[idx]
            raise NotImplementedError
        else:
            pass
        return video, cond, cond_mask
        # except Exception as e:
        #     print(f'Error with {e}, {sample}')
        #     return self.__getitem__(random.randint(0, self.__len__() - 1))

    def get_img_cap_list(self):
        raise NotImplementedError
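
# Usage sketch for T2V_Feature_dataset (illustrative; the argument values and the
# stand-in sampler below are assumptions, not taken from this file):
#
#   from argparse import Namespace
#   args = Namespace(video_folder='/path/to/videos', video_length=17,
#                    use_image_num=0, use_img_from_vid=False)
#   temporal_sample = lambda n: (0, n)  # stand-in: total_frames -> (start, end)
#   dataset = T2V_Feature_dataset(args, temporal_sample)
#   video, cond, cond_mask = dataset[0]  # (C, T, H, W), (L, D), (L,)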


class T2V_T5_Feature_dataset(Dataset):
    """Decodes raw videos with Decord and pairs them with cached T5 text features."""

    def __init__(self, args, transform, temporal_sample):
        self.video_folder = args.video_folder
        self.num_frames = args.num_frames
        self.transform = transform
        self.temporal_sample = temporal_sample
        self.v_decoder = DecordInit()
        print('Building dataset...')
        if os.path.exists('samples_430k.json'):
            with open('samples_430k.json', 'r') as f:
                self.samples = json.load(f)
            # The cache stores VAE-latent paths; map them back to the raw .mp4 files,
            # since this dataset decodes videos instead of loading latents.
            self.samples = [
                dict(
                    ae=i['ae'].replace('lb_causalvideovae444_feature', 'data_split_1024').replace('_causalvideovae444.npy', '.mp4'),
                    t5=i['t5'],
                )
                for i in self.samples
            ]
        else:
            self.samples = self._make_dataset()
            with open('samples_430k.json', 'w') as f:
                json.dump(self.samples, f, indent=2)
        self.use_image_num = args.use_image_num
        self.use_img_from_vid = args.use_img_from_vid
        if self.use_image_num != 0 and not self.use_img_from_vid:
            self.img_cap_list = self.get_img_cap_list()

    def _make_dataset(self):
        all_mp4 = list(glob(os.path.join(self.video_folder, '**', '*.mp4'), recursive=True))
        # all_mp4 = all_mp4[:1000]
        samples = []
        for i in tqdm(all_mp4):
            video_id = os.path.basename(i).split('.')[0]
            # Unlike T2V_Feature_dataset, 'ae' here is the raw .mp4 itself.
            # ae = os.path.split(i)[0].replace('data_split', 'lb_causalvideovae444_feature')
            # ae = os.path.join(ae, f'{video_id}_causalvideovae444.npy')
            ae = i
            if not os.path.exists(ae):
                continue
            t5 = os.path.split(i)[0].replace('data_split_1024', 'lb_t5_feature')
            cond_list = []
            cond_llava = os.path.join(t5, f'{video_id}_t5_llava_fea.npy')
            mask_llava = os.path.join(t5, f'{video_id}_t5_llava_mask.npy')
            if os.path.exists(cond_llava) and os.path.exists(mask_llava):
                llava = dict(cond=cond_llava, mask=mask_llava)
                cond_list.append(llava)
            cond_sharegpt4v = os.path.join(t5, f'{video_id}_t5_sharegpt4v_fea.npy')
            mask_sharegpt4v = os.path.join(t5, f'{video_id}_t5_sharegpt4v_mask.npy')
            if os.path.exists(cond_sharegpt4v) and os.path.exists(mask_sharegpt4v):
                sharegpt4v = dict(cond=cond_sharegpt4v, mask=mask_sharegpt4v)
                cond_list.append(sharegpt4v)
            if len(cond_list) > 0:
                sample = dict(ae=ae, t5=cond_list)
                samples.append(sample)
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        try:
            sample = self.samples[idx]
            ae, t5 = sample['ae'], sample['t5']
            t5 = random.choice(t5)  # pick one caption source at random
            video = self.decord_read(ae)
            video = self.transform(video)  # T C H W -> T C H W
            video = video.transpose(0, 1)  # T C H W -> C T H W
            total_frames = video.shape[1]
            cond = torch.from_numpy(np.load(t5['cond']))[0]  # L D
            cond_mask = torch.from_numpy(np.load(t5['mask']))[0]  # L
            if self.use_image_num != 0 and self.use_img_from_vid:
                select_image_idx = np.random.randint(0, total_frames, self.use_image_num)
                # print('select_image_idx', total_frames, self.use_image_num, select_image_idx)
                images = video[:, select_image_idx]  # C num_img H W
                video = torch.cat([video, images], dim=1)  # C num_frames+num_img H W
                cond = torch.stack([cond] * (1 + self.use_image_num))  # 1+use_image_num, L, D
                cond_mask = torch.stack([cond_mask] * (1 + self.use_image_num))  # 1+use_image_num, L
            elif self.use_image_num != 0 and not self.use_img_from_vid:
                images, captions = self.img_cap_list[idx]
                raise NotImplementedError
            else:
                pass
            return video, cond, cond_mask
        except Exception as e:
            # A corrupt video or missing feature file falls back to a random sample,
            # so a single bad item does not kill the training run.
            print(f'Error with {e}, {sample}')
            return self.__getitem__(random.randint(0, self.__len__() - 1))

    def decord_read(self, path):
        decord_vr = self.v_decoder(path)
        total_frames = len(decord_vr)
        # Sample num_frames indices evenly from the clip chosen by temporal_sample.
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        # assert end_frame_ind - start_frame_ind >= self.num_frames
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
        video_data = decord_vr.get_batch(frame_indice).asnumpy()
        video_data = torch.from_numpy(video_data)
        video_data = video_data.permute(0, 3, 1, 2)  # (T H W C) -> (T C H W)
        return video_data

    def get_img_cap_list(self):
        raise NotImplementedError