from pathlib import Path, PosixPath
from typing import Optional

import numpy as np
import os
import pandas as pd
import torch
import torchvision.transforms as TT
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
from tqdm.auto import tqdm

try:
    import decord
except ImportError:
    raise ImportError(
        "The `decord` package is required for loading the video dataset. Install with `pip install decord`"
    )

decord.bridge.set_bridge("torch")


class Sakuga_Dataset(Dataset):
    def __init__(
        self,
        instance_data_root: Optional[str] = None,
        sketch_data_root: Optional[str] = None,
        dataset_name: Optional[str] = None,
        dataset_config_name: Optional[str] = None,
        caption_column: str = "text",
        video_column: str = "video",
        height: int = 480,
        width: int = 720,
        video_reshape_mode: str = "center",
        fps: int = 8,
        max_num_frames: int = 49,
        skip_frames_start: int = 0,
        skip_frames_end: int = 0,
        cache_dir: Optional[str] = None,
        id_token: Optional[str] = None,
        data_information: Optional[str] = None,
        stage: Optional[str] = "1",
    ) -> None:
        super().__init__()

        self.instance_data_root = Path(instance_data_root) if instance_data_root is not None else None
        self.sketch_data_root = Path(sketch_data_root) if sketch_data_root is not None else None
        self.dataset_name = dataset_name
        self.dataset_config_name = dataset_config_name
        self.caption_column = caption_column
        self.video_column = video_column
        self.height = height
        self.width = width
        self.video_reshape_mode = video_reshape_mode
        self.fps = fps
        self.max_num_frames = max_num_frames
        self.skip_frames_start = skip_frames_start
        self.skip_frames_end = skip_frames_end
        self.cache_dir = cache_dir
        self.id_token = id_token or ""
        self.stage = stage

        '''
        if dataset_name is not None:
            self.instance_prompts, self.instance_video_paths = self._load_dataset_from_hub()
        else:
            self.instance_prompts, self.instance_video_paths = self._load_dataset_from_local_path()
        '''
        self.data_information = pd.read_parquet(data_information)
        self.num_instance_videos = self.data_information.shape[0]
        # self.detector = LineartDetector('cpu')  # TODO: pointing this directly at CUDA may cause problems
        # Videos are decoded and preprocessed lazily in __getitem__ rather than
        # eagerly here via _preprocess_data().
        # self.instance_videos = self._preprocess_data()

    def __len__(self):
        return self.num_instance_videos

    def encode_video(self, video, vae, device):
        video = video.to(device, dtype=vae.dtype).unsqueeze(0)
        video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
        image = video[:, :, :1].clone()  # first frame, used as the image condition

        latent_dist = vae.encode(video).latent_dist

        # Perturb the conditioning frame with a small amount of noise before encoding.
        image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=image.device)
        image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=image.dtype)
        noisy_image = image + torch.randn_like(image) * image_noise_sigma[:, None, None, None, None]
        image_latent_dist = vae.encode(noisy_image).latent_dist

        return latent_dist, image_latent_dist
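    # Sketch of how `encode_video` is expected to be used (assumption: a
    # CogVideoX-style 3D VAE from `diffusers`; the checkpoint id below is
    # illustrative, not prescribed by this repo):
    #
    #   from diffusers import AutoencoderKLCogVideoX
    #   vae = AutoencoderKLCogVideoX.from_pretrained(
    #       "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.bfloat16
    #   ).to("cuda")
    #   video = dataset[0]["instance_video"]  # [F, C, H, W], normalized to [-1, 1]
    #   latent_dist, image_latent_dist = dataset.encode_video(video, vae, "cuda")
    #   latents = latent_dist.sample() * vae.config.scaling_factor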
    def read_video(self, video_path):
        filename = PosixPath(video_path)
        # TODO: this decode path still has issues
        try:
            video_reader = decord.VideoReader(uri=filename.as_posix())
            video_num_frames = len(video_reader)
            # TODO: should decoding be forced to start from frame 10 here?

            start_frame = min(self.skip_frames_start, video_num_frames)
            end_frame = max(0, video_num_frames - self.skip_frames_end)
            if end_frame <= start_frame:
                frames = video_reader.get_batch([start_frame])
            elif end_frame - start_frame <= self.max_num_frames:
                frames = video_reader.get_batch(list(range(start_frame, end_frame)))
            else:
                # Take a contiguous clip of `max_num_frames` frames from `start_frame`.
                # An evenly-strided alternative (as in _preprocess_data) would be:
                # indices = list(range(start_frame, end_frame, (end_frame - start_frame) // self.max_num_frames))
                indices = list(range(start_frame, start_frame + self.max_num_frames))
                frames = video_reader.get_batch(indices)

            # Ensure that we don't go over the limit
            frames = frames[: self.max_num_frames]
            selected_num_frames = frames.shape[0]

            # Keep only the first (4k + 1) frames, as required by the VAE
            remainder = (3 + (selected_num_frames % 4)) % 4
            if remainder != 0:
                frames = frames[:-remainder]
            selected_num_frames = frames.shape[0]

            assert (selected_num_frames - 1) % 4 == 0

            # Training transforms
            frames = frames.permute(0, 3, 1, 2)  # [F, C, H, W]
            frames = self._resize_for_rectangle_crop(frames)
            # NOTE: frames are returned in [0, 255]; normalize to [-1, 1]
            # (as done in _preprocess_data) before passing them to the VAE.
            final_frames = frames.contiguous()
            if final_frames.dim() == 3:
                final_frames = final_frames.unsqueeze(0)
            return final_frames
        except Exception:
            return None

    def __getitem__(self, index):
        # Resolve the clip file for this sample from the metadata row.
        folder_path = os.path.join(self.instance_data_root, str(self.data_information.iloc[index]["identifier_video"]))
        # sketch_path = os.path.join(self.sketch_data_root, str(self.data_information.iloc[index]["identifier_video"]))
        # Note: the sketches are stored as values in [0, 255], so no `1 -` inversion is
        # needed, but they may still have to be binarized later (how to binarize is TBD).
        # TODO: the identifier lookup below is still not quite right.
        # The identifier corresponds to the clip name: the part after the colon is the scene number.
        # indices = self.data_information.index[self.data_information["identifier_video"] == self.data_information.iloc[index]["identifier_video"]].tolist()
        # video_name = self.data_information.iloc[index]["identifier_video"]
        # sence_number = self.data_information.iloc[index]["identifier"].split(":")[1]
        # data_name = f"{video_name}-Scene-{int(sence_number):03d}.mp4"

        frames = self.data_information.iloc[index]["start_frame"]
        video_name = self.data_information.iloc[index]["identifier"].split(":")[0]
        # Try the recorded scene number first, then its neighbors.
        data_path_1 = f"{video_name}-Scene-{frames}.mp4"
        data_path_2 = f"{video_name}-Scene-{frames+1}.mp4"
        data_path_3 = f"{video_name}-Scene-{frames-1}.mp4"
        fd1 = os.path.join(folder_path, data_path_1)
        # sketch_fd1 = os.path.join(sketch_path, data_path_1)
        fd2 = os.path.join(folder_path, data_path_2)
        # sketch_fd2 = os.path.join(sketch_path, data_path_2)
        fd3 = os.path.join(folder_path, data_path_3)
        # sketch_fd3 = os.path.join(sketch_path, data_path_3)

        if os.path.exists(fd1):
            file_path = fd1
        elif os.path.exists(fd2):
            file_path = fd2
        elif os.path.exists(fd3):
            file_path = fd3
        else:
            raise FileNotFoundError(f"No clip found for index {index}: tried {fd1}, {fd2}, {fd3}")

        prompt = self.data_information.iloc[index]["text_description"]
        final_frames = self.read_video(PosixPath(file_path))
        # final_sketch_frames = self.read_video(PosixPath(sketch_file_path))
        final_sketch_frames = None

        instance_prompt = prompt + self.id_token

        return {
            "instance_prompt": instance_prompt,
            "instance_video": final_frames,
            "file_path": file_path,
            "sketch_video": final_sketch_frames,
            "instance_image": None,
            # "instance_sketch": final_sketch,
        }
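    # The metadata parquet passed as `data_information` is assumed (inferred from
    # the column accesses in __getitem__ above) to contain at least:
    #   identifier        "<video_name>:<scene_number>" string per clip
    #   identifier_video  sub-folder of `instance_data_root` holding the clip
    #   start_frame       number used to build the "<video_name>-Scene-<n>.mp4" filename
    #   text_description  caption used as the instance prompt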
    def _load_dataset_from_hub(self):
        try:
            from datasets import load_dataset
        except ImportError:
            raise ImportError(
                "You are trying to load your data using the datasets library. If you wish to train using custom "
                "captions please install the datasets library: `pip install datasets`. If you wish to load a "
                "local folder containing images only, specify --instance_data_root instead."
            )

        # Downloading and loading a dataset from the hub. See more about loading custom images at
        # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
        dataset = load_dataset(
            self.dataset_name,
            self.dataset_config_name,
            cache_dir=self.cache_dir,
        )
        column_names = dataset["train"].column_names

        if self.video_column is None:
            video_column = column_names[0]
            print(f"`video_column` defaulting to {video_column}")
        else:
            video_column = self.video_column
            if video_column not in column_names:
                raise ValueError(
                    f"`--video_column` value '{video_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
                )

        if self.caption_column is None:
            caption_column = column_names[1]
            print(f"`caption_column` defaulting to {caption_column}")
        else:
            caption_column = self.caption_column
            if self.caption_column not in column_names:
                raise ValueError(
                    f"`--caption_column` value '{self.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
                )

        instance_prompts = dataset["train"][caption_column]
        instance_videos = [Path(self.instance_data_root, filepath) for filepath in dataset["train"][video_column]]

        return instance_prompts, instance_videos

    def _load_dataset_from_local_path(self):
        if not self.instance_data_root.exists():
            raise ValueError("Instance videos root folder does not exist")

        prompt_path = self.instance_data_root.joinpath(self.caption_column)
        video_path = self.instance_data_root.joinpath(self.video_column)

        if not prompt_path.exists() or not prompt_path.is_file():
            raise ValueError(
                "Expected `--caption_column` to be a path to a file in `--instance_data_root` containing line-separated text prompts."
            )
        if not video_path.exists() or not video_path.is_file():
            raise ValueError(
                "Expected `--video_column` to be a path to a file in `--instance_data_root` containing line-separated paths to video data in the same directory."
            )

        with open(prompt_path, "r", encoding="utf-8") as file:
            instance_prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0]
        with open(video_path, "r", encoding="utf-8") as file:
            instance_videos = [
                self.instance_data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0
            ]

        if any(not path.is_file() for path in instance_videos):
            raise ValueError(
                "Expected `--video_column` to be a path to a file in `--instance_data_root` containing line-separated paths to video data but found at least one path that is not a valid file."
            )

        return instance_prompts, instance_videos
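    # Worked example for `_resize_for_rectangle_crop` below (values illustrative):
    # a 1080x1920 (HxW) clip targeted at 480x720 has aspect 1920/1080 ≈ 1.78 >
    # 720/480 = 1.5, so it is resized to height 480 (width becomes
    # int(1920 * 480 / 1080) = 853) and the excess 853 - 720 = 133 columns are
    # cropped away, centered or at a random offset depending on `video_reshape_mode`.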
    def _resize_for_rectangle_crop(self, arr):
        image_size = self.height, self.width
        reshape_mode = self.video_reshape_mode
        if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
            arr = resize(
                arr,
                size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
                interpolation=InterpolationMode.BICUBIC,
            )
        else:
            arr = resize(
                arr,
                size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
                interpolation=InterpolationMode.BICUBIC,
            )

        h, w = arr.shape[2], arr.shape[3]
        arr = arr.squeeze(0)

        delta_h = h - image_size[0]
        delta_w = w - image_size[1]

        if reshape_mode == "random" or reshape_mode == "none":
            top = np.random.randint(0, delta_h + 1)
            left = np.random.randint(0, delta_w + 1)
        elif reshape_mode == "center":
            top, left = delta_h // 2, delta_w // 2
        else:
            raise NotImplementedError
        arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
        return arr

    # Eagerly preprocesses every video up front; kept for reference, but loading
    # is done lazily per sample in __getitem__ instead.
    def _preprocess_data(self):
        decord.bridge.set_bridge("torch")

        progress_dataset_bar = tqdm(
            range(0, len(self.instance_video_paths)),
            desc="Loading progress resize and crop videos",
        )

        videos = []

        for filename in self.instance_video_paths:
            video_reader = decord.VideoReader(uri=filename.as_posix())
            video_num_frames = len(video_reader)

            start_frame = min(self.skip_frames_start, video_num_frames)
            end_frame = max(0, video_num_frames - self.skip_frames_end)
            if end_frame <= start_frame:
                frames = video_reader.get_batch([start_frame])
            elif end_frame - start_frame <= self.max_num_frames:
                frames = video_reader.get_batch(list(range(start_frame, end_frame)))
            else:
                indices = list(range(start_frame, end_frame, (end_frame - start_frame) // self.max_num_frames))
                frames = video_reader.get_batch(indices)

            # Ensure that we don't go over the limit
            frames = frames[: self.max_num_frames]
            selected_num_frames = frames.shape[0]

            # Keep only the first (4k + 1) frames, as required by the VAE
            remainder = (3 + (selected_num_frames % 4)) % 4
            if remainder != 0:
                frames = frames[:-remainder]
            selected_num_frames = frames.shape[0]

            assert (selected_num_frames - 1) % 4 == 0

            # Training transforms
            frames = frames.permute(0, 3, 1, 2)  # [F, C, H, W]
            progress_dataset_bar.set_description(
                f"Loading progress Resizing video from {frames.shape[2]}x{frames.shape[3]} to {self.height}x{self.width}"
            )
            frames = self._resize_for_rectangle_crop(frames)
            frames = (frames - 127.5) / 127.5  # normalize to [-1, 1]
            videos.append(frames.contiguous())  # [F, C, H, W]
            progress_dataset_bar.update(1)

        progress_dataset_bar.close()

        return videos


if __name__ == "__main__":
    train_dataset = Sakuga_Dataset(
        instance_data_root="",
        height=480,
        width=720,
        video_reshape_mode="center",
        fps=8,
        max_num_frames=49,
        skip_frames_start=0,
        skip_frames_end=0,
        cache_dir="~/.cache",
        id_token="",
        data_information="../../../Datasets/SakugaDataset/parquet/fliter_59_aesthetic_precise.parquet",
    )
    data = train_dataset[0]
    print(data["instance_video"].shape)
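
    # Sketch (assumption, not part of the original training code): wrap the dataset
    # in a DataLoader. Because `read_video` returns None on decode failure and clips
    # vary in frame count, a collate_fn that drops failed samples and returns a plain
    # list is used instead of the default tensor stacking.
    def collate_fn(batch):
        return [sample for sample in batch if sample["instance_video"] is not None]

    loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
    for batch in loader:
        for sample in batch:
            print(sample["file_path"], sample["instance_video"].shape)
        break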