eawolf2357-git / Dataset /video_dataset.py
seawolf2357's picture
Upload folder using huggingface_hub
321d89c verified
raw
history blame
13.6 kB
from datetime import timedelta
from pathlib import Path
from typing import List, Optional, Tuple, Union
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from lineart_extractor.annotator.lineart import LineartDetector
from torchvision.transforms.functional import center_crop, resize
from torchvision.transforms import InterpolationMode
import torchvision.transforms as TT
import numpy as np
import accelerate
import torch
# `decord` is a hard requirement: fail at import time with an install hint,
# keeping the original ImportError attached as the direct cause.
try:
    import decord
except ImportError as err:
    raise ImportError(
        "The `decord` package is required for loading the video dataset. Install with `pip install decord`"
    ) from err

# Module-level side effect: make decord return torch tensors, so that
# `VideoReader.get_batch` results support `.permute` etc. in VideoDataset.
decord.bridge.set_bridge("torch")
class VideoDataset(Dataset):
    """Text/video pair dataset that decodes, subsamples and resizes clips lazily.

    Prompts and video paths come either from a Hugging Face Hub dataset
    (``dataset_name``) or from two line-separated text files inside
    ``instance_data_root`` (``caption_column`` and ``video_column`` name those
    files). Videos are decoded on demand in ``__getitem__``.
    """

    def __init__(
        self,
        instance_data_root: Optional[str] = None,
        dataset_name: Optional[str] = None,
        dataset_config_name: Optional[str] = None,
        caption_column: str = "text",
        video_column: str = "video",
        height: int = 480,
        width: int = 720,
        video_reshape_mode: str = "center",
        fps: int = 8,
        max_num_frames: int = 49,
        skip_frames_start: int = 0,
        skip_frames_end: int = 0,
        cache_dir: Optional[str] = None,
        id_token: Optional[str] = None,
    ) -> None:
        """Collect prompts and video paths.

        Args:
            instance_data_root: Directory containing the videos (and, for local
                loading, the prompt/video list files).
            dataset_name: Hub dataset name; when given, the Hub loader is used.
            dataset_config_name: Hub dataset configuration name.
            caption_column: Hub column or local file name holding the prompts.
            video_column: Hub column or local file name holding the video paths.
            height, width: Output frame size after resize + crop.
            video_reshape_mode: "center", "random" or "none" (treated as random).
            fps: Stored only; not used by this class.
            max_num_frames: Upper bound on frames sampled per video (>= 1).
            skip_frames_start, skip_frames_end: Frames ignored at each end.
            cache_dir: Cache directory passed to ``datasets.load_dataset``.
            id_token: Optional prefix prepended to every prompt.

        Raises:
            ValueError: if ``max_num_frames < 1``, or the prompt and video counts differ.
        """
        super().__init__()
        if max_num_frames < 1:
            raise ValueError(f"`max_num_frames` must be at least 1 but got {max_num_frames}.")
        self.instance_data_root = Path(instance_data_root) if instance_data_root is not None else None
        self.dataset_name = dataset_name
        self.dataset_config_name = dataset_config_name
        self.caption_column = caption_column
        self.video_column = video_column
        self.height = height
        self.width = width
        self.video_reshape_mode = video_reshape_mode
        self.fps = fps
        self.max_num_frames = max_num_frames
        self.skip_frames_start = skip_frames_start
        self.skip_frames_end = skip_frames_end
        self.cache_dir = cache_dir
        self.id_token = id_token or ""

        if dataset_name is not None:
            self.instance_prompts, self.instance_video_paths = self._load_dataset_from_hub()
        else:
            self.instance_prompts, self.instance_video_paths = self._load_dataset_from_local_path()
        self.instance_prompts = [self.id_token + prompt for prompt in self.instance_prompts]

        self.num_instance_videos = len(self.instance_video_paths)
        if self.num_instance_videos != len(self.instance_prompts):
            raise ValueError(
                f"Expected length of instance prompts and videos to be the same but found {len(self.instance_prompts)=} and {len(self.instance_video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset."
            )

        # TODO: lineart sketch extraction and eager preprocessing are disabled here.
        # The author noted that pinning the detector to a device (CUDA) in
        # __init__ may cause problems, so per-video preprocessing now runs in
        # __getitem__ instead.
        # self.detector = LineartDetector('cpu')
        # self.instance_videos = self._preprocess_data()

    def __len__(self):
        """Return the number of (prompt, video) pairs."""
        return self.num_instance_videos

    def encode_video(self, video, vae, device):
        """Encode a clip and a lightly-noised copy of its first frame with ``vae``.

        Args:
            video: [F, C, H, W] clip, as returned under ``instance_video``.
            vae: Object exposing ``dtype`` and ``encode(x).latent_dist``.
            device: Device to move the clip to before encoding.

        Returns:
            ``(latent_dist, image_latent_dist)``: the encoding of the whole clip
            (fed as [1, C, F, H, W]) and of its first frame plus Gaussian noise
            (fed as [1, C, 1, H, W]).
        """
        video = video.to(device, dtype=vae.dtype).unsqueeze(0)
        video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
        image = video[:, :, :1].clone()

        latent_dist = vae.encode(video).latent_dist

        # Noise scale sigma = exp(N(-3, 0.5)), sampled once per call.
        image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=image.device)
        image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=image.dtype)
        # Add the noise to the frame itself so the conditioning latent keeps the
        # first frame's content (encoding bare noise would discard the image).
        noisy_image = image + torch.randn_like(image) * image_noise_sigma[:, None, None, None, None]
        image_latent_dist = vae.encode(noisy_image).latent_dist
        return latent_dist, image_latent_dist

    def __getitem__(self, index):
        """Decode, subsample and resize/crop one video.

        Returns:
            Dict with ``instance_prompt`` (str) and ``instance_video``, a
            contiguous [F, C, height, width] tensor with F = 4k + 1.
        """
        frames = self._load_frames(self.instance_video_paths[index])
        frames = self._resize_for_rectangle_crop(frames)
        # NOTE(review): unlike `_preprocess_data`, frames are NOT rescaled to
        # [-1, 1] here (presumably still decord's uint8 0-255 range -- confirm).
        # Make sure the training loop normalizes before calling `encode_video`.
        final_frames = frames.contiguous()  # [F, C, H, W]

        # TODO: lineart sketch conditioning (disabled):
        # with torch.no_grad():
        #     sketch = self.detector(final_frames, coarse=False)
        #     # Expand the sketch to three channels for downstream processing.
        #     sketch = sketch.repeat(1, 3, 1, 1)
        #     sketch = (sketch - 0.5) / 0.5
        #     final_sketch = sketch.contiguous()
        return {
            "instance_prompt": self.instance_prompts[index],
            "instance_video": final_frames,
            # "instance_sketch": final_sketch,
        }

    def _select_frame_indices(self, video_num_frames: int) -> List[int]:
        """Choose the frame indices to decode from a video with ``video_num_frames >= 1`` frames.

        The usable window is ``[skip_frames_start, video_num_frames - skip_frames_end)``.
        A window of at most ``max_num_frames`` frames is taken whole. A longer
        window is strided and then capped at ``max_num_frames``. An empty window
        falls back to a single frame. The result is trimmed to the first 4k + 1
        indices, because that is the frame count the VAE requires.
        """
        start_frame = min(self.skip_frames_start, video_num_frames)
        end_frame = max(0, video_num_frames - self.skip_frames_end)

        if end_frame <= start_frame:
            # Clamp so the fallback frame is a valid index even when
            # skip_frames_start >= video_num_frames.
            indices = [min(start_frame, video_num_frames - 1)]
        elif end_frame - start_frame <= self.max_num_frames:
            indices = list(range(start_frame, end_frame))
        else:
            step = (end_frame - start_frame) // self.max_num_frames
            indices = list(range(start_frame, end_frame, step))

        # The integer stride can overshoot, so cap before decoding anything.
        indices = indices[: self.max_num_frames]
        # Keep the first 4k + 1 frames.
        return indices[: (len(indices) - 1) // 4 * 4 + 1]

    def _load_frames(self, filename: Path) -> torch.Tensor:
        """Decode the selected frames of ``filename`` as a [F, C, H, W] tensor (not resized).

        Raises:
            ValueError: if the video contains no frames.
        """
        video_reader = decord.VideoReader(uri=filename.as_posix())
        video_num_frames = len(video_reader)
        if video_num_frames == 0:
            raise ValueError(f"Video {filename} contains no frames.")
        # Only the frames that are kept get decoded.
        frames = video_reader.get_batch(self._select_frame_indices(video_num_frames))
        assert (frames.shape[0] - 1) % 4 == 0
        return frames.permute(0, 3, 1, 2)  # [F, H, W, C] -> [F, C, H, W]

    def _load_dataset_from_hub(self):
        """Read prompts and video paths from the ``train`` split of a Hub dataset.

        Video paths are joined onto ``instance_data_root`` when it is set, and
        are used as given otherwise.

        Raises:
            ImportError: if the ``datasets`` library is not installed.
            ValueError: if a requested column is missing from the dataset.
        """
        try:
            from datasets import load_dataset
        except ImportError as err:
            raise ImportError(
                "You are trying to load your data using the datasets library. If you wish to train using custom "
                "captions please install the datasets library: `pip install datasets`. If you wish to load a "
                "local folder containing images only, specify --instance_data_root instead."
            ) from err

        # Downloading and loading a dataset from the hub. See more about loading custom images at
        # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
        dataset = load_dataset(
            self.dataset_name,
            self.dataset_config_name,
            cache_dir=self.cache_dir,
        )
        column_names = dataset["train"].column_names

        if self.video_column is None:
            video_column = column_names[0]
            print(f"`video_column` defaulting to {video_column}")
        else:
            video_column = self.video_column
            if video_column not in column_names:
                raise ValueError(
                    f"`--video_column` value '{video_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
                )

        if self.caption_column is None:
            caption_column = column_names[1]
            print(f"`caption_column` defaulting to {caption_column}")
        else:
            caption_column = self.caption_column
            if self.caption_column not in column_names:
                raise ValueError(
                    f"`--caption_column` value '{self.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
                )

        instance_prompts = dataset["train"][caption_column]
        root = self.instance_data_root
        # Path(None, ...) raises TypeError, so only join when a root is configured.
        instance_videos = [
            Path(filepath) if root is None else root / filepath
            for filepath in dataset["train"][video_column]
        ]
        return instance_prompts, instance_videos

    def _load_dataset_from_local_path(self):
        """Read prompts and video paths from line-separated files under ``instance_data_root``.

        Blank lines are skipped, and each line is stripped. Video paths are
        resolved relative to ``instance_data_root``.

        Raises:
            ValueError: if the root is unset or missing, either list file is
                missing, or a listed video is not a file.
        """
        if self.instance_data_root is None or not self.instance_data_root.exists():
            raise ValueError("Instance videos root folder does not exist")

        prompt_path = self.instance_data_root.joinpath(self.caption_column)
        video_path = self.instance_data_root.joinpath(self.video_column)

        if not prompt_path.exists() or not prompt_path.is_file():
            raise ValueError(
                "Expected `--caption_column` to be path to a file in `--instance_data_root` containing line-separated text prompts."
            )
        if not video_path.exists() or not video_path.is_file():
            raise ValueError(
                "Expected `--video_column` to be path to a file in `--instance_data_root` containing line-separated paths to video data in the same directory."
            )

        with open(prompt_path, "r", encoding="utf-8") as file:
            instance_prompts = [line.strip() for line in file if line.strip()]
        with open(video_path, "r", encoding="utf-8") as file:
            instance_videos = [self.instance_data_root.joinpath(line.strip()) for line in file if line.strip()]

        if any(not path.is_file() for path in instance_videos):
            raise ValueError(
                "Expected '--video_column' to be a path to a file in `--instance_data_root` containing line-separated paths to video data but found atleast one path that is not a valid file."
            )

        return instance_prompts, instance_videos

    def _resize_for_rectangle_crop(self, arr):
        """Resize a [F, C, H, W] tensor toward ``(self.height, self.width)`` and crop to it.

        The aspect ratio is preserved: one side is scaled (bicubic) to exactly
        match the target, then the other side is cropped. With "center" the crop
        is centered; with "random" or "none" it is placed at a random offset.
        The frame axis is kept, even for a single frame.

        Raises:
            NotImplementedError: for any other ``video_reshape_mode``.
        """
        target_h, target_w = self.height, self.width
        src_h, src_w = arr.shape[2], arr.shape[3]

        if src_w / src_h > target_w / target_h:
            new_size = [target_h, int(src_w * target_h / src_h)]
        else:
            new_size = [int(src_h * target_w / src_w), target_w]
        arr = resize(arr, size=new_size, interpolation=InterpolationMode.BICUBIC)

        # No squeeze(0) here: it would drop the frame axis of 1-frame clips.
        delta_h = arr.shape[2] - target_h
        delta_w = arr.shape[3] - target_w

        if self.video_reshape_mode in ("random", "none"):
            top = np.random.randint(0, delta_h + 1)
            left = np.random.randint(0, delta_w + 1)
        elif self.video_reshape_mode == "center":
            top, left = delta_h // 2, delta_w // 2
        else:
            raise NotImplementedError

        return TT.functional.crop(arr, top=top, left=left, height=target_h, width=target_w)

    def _preprocess_data(self):
        """Eagerly load, resize/crop and rescale every video.

        Not called from this class (the eager call in __init__ is commented out);
        ``__getitem__`` does the per-video work lazily instead.

        Returns:
            List of contiguous [F, C, H, W] tensors, rescaled with
            ``(x - 127.5) / 127.5`` (i.e. [-1, 1] for 0-255 input).
        """
        decord.bridge.set_bridge("torch")
        videos = []
        with tqdm(
            range(0, len(self.instance_video_paths)),
            desc="Loading progress resize and crop videos",
        ) as progress_dataset_bar:
            for filename in self.instance_video_paths:
                frames = self._load_frames(filename)
                progress_dataset_bar.set_description(
                    f"Loading progress Resizing video from {frames.shape[2]}x{frames.shape[3]} to {self.height}x{self.width}"
                )
                frames = self._resize_for_rectangle_crop(frames)
                frames = (frames - 127.5) / 127.5
                videos.append(frames.contiguous())  # [F, C, H, W]
                progress_dataset_bar.update(1)
        return videos