from pathlib import Path
from typing import Optional

from torch.utils.data import Dataset
from tqdm.auto import tqdm
# Only needed for the commented-out sketch-extraction path below:
# from lineart_extractor.annotator.lineart import LineartDetector
from torchvision.transforms.functional import resize
from torchvision.transforms import InterpolationMode
import torchvision.transforms as TT
import numpy as np
import torch
try:
import decord
except ImportError:
raise ImportError(
"The `decord` package is required for loading the video dataset. Install with `pip install decord`"
)
decord.bridge.set_bridge("torch")
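# With the torch bridge, decord.VideoReader.get_batch returns torch tensors
# (uint8, shape [F, H, W, C]) instead of NumPy arrays.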
class VideoDataset(Dataset):
def __init__(
self,
instance_data_root: Optional[str] = None,
dataset_name: Optional[str] = None,
dataset_config_name: Optional[str] = None,
caption_column: str = "text",
video_column: str = "video",
height: int = 480,
width: int = 720,
video_reshape_mode: str = "center",
fps: int = 8,
max_num_frames: int = 49,
skip_frames_start: int = 0,
skip_frames_end: int = 0,
cache_dir: Optional[str] = None,
id_token: Optional[str] = None,
) -> None:
super().__init__()
self.instance_data_root = Path(instance_data_root) if instance_data_root is not None else None
self.dataset_name = dataset_name
self.dataset_config_name = dataset_config_name
self.caption_column = caption_column
self.video_column = video_column
self.height = height
self.width = width
self.video_reshape_mode = video_reshape_mode
self.fps = fps
self.max_num_frames = max_num_frames
self.skip_frames_start = skip_frames_start
self.skip_frames_end = skip_frames_end
self.cache_dir = cache_dir
self.id_token = id_token or ""
if dataset_name is not None:
self.instance_prompts, self.instance_video_paths = self._load_dataset_from_hub()
else:
self.instance_prompts, self.instance_video_paths = self._load_dataset_from_local_path()
self.instance_prompts = [self.id_token + prompt for prompt in self.instance_prompts]
self.num_instance_videos = len(self.instance_video_paths)
if self.num_instance_videos != len(self.instance_prompts):
raise ValueError(
f"Expected length of instance prompts and videos to be the same but found {len(self.instance_prompts)=} and {len(self.instance_video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset."
)
        # self.detector = LineartDetector("cpu")
        # TODO: constructing the detector on CUDA here may cause problems (e.g. with
        # DataLoader worker processes), so preprocessing is done lazily in __getitem__
        # instead of eagerly via:
        # self.instance_videos = self._preprocess_data()
def __len__(self):
return self.num_instance_videos
    def encode_video(self, video, vae, device):
        # Encode a video (and its first frame, as the image condition) into VAE latent distributions.
        video = video.to(device, dtype=vae.dtype).unsqueeze(0)
        video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
        image = video[:, :, :1].clone()
        latent_dist = vae.encode(video).latent_dist
        # Perturb the conditioning image with log-normal noise, sigma ~ exp(N(-3.0, 0.5)),
        # before encoding it (the image itself is added, so this is a noisy image, not pure noise).
        image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=image.device)
        image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=image.dtype)
        noisy_image = image + torch.randn_like(image) * image_noise_sigma[:, None, None, None, None]
        image_latent_dist = vae.encode(noisy_image).latent_dist
        return latent_dist, image_latent_dist
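    # Hypothetical usage in a training loop (assumes a diffusers-style `vae` whose
    # encode() returns a distribution with .sample() and whose config has a
    # `scaling_factor`; neither is defined in this file):
    #     sample = dataset[0]
    #     latent_dist, image_latent_dist = dataset.encode_video(sample["instance_video"], vae, device)
    #     latents = latent_dist.sample() * vae.config.scaling_factor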
    def __getitem__(self, index):
        # VAE encoding (see encode_video) is deferred to the training loop;
        # here we lazily decode, subsample, resize, and crop a single video.
filename = self.instance_video_paths[index]
video_reader = decord.VideoReader(uri=filename.as_posix())
        video_num_frames = len(video_reader)
        # Clamp the configured frame skips to the actual video length.
        start_frame = min(self.skip_frames_start, video_num_frames)
        end_frame = max(0, video_num_frames - self.skip_frames_end)
        if end_frame <= start_frame:
            frames = video_reader.get_batch([start_frame])
        elif end_frame - start_frame <= self.max_num_frames:
            frames = video_reader.get_batch(list(range(start_frame, end_frame)))
        else:
            # Subsample with a fixed stride so roughly max_num_frames frames remain.
            indices = list(range(start_frame, end_frame, (end_frame - start_frame) // self.max_num_frames))
            frames = video_reader.get_batch(indices)
# Ensure that we don't go over the limit
frames = frames[: self.max_num_frames]
selected_num_frames = frames.shape[0]
        # Keep only the first (4k + 1) frames: the VAE's temporal compression requires frame counts of the form 4k + 1.
remainder = (3 + (selected_num_frames % 4)) % 4
if remainder != 0:
frames = frames[:-remainder]
selected_num_frames = frames.shape[0]
assert (selected_num_frames - 1) % 4 == 0
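        # Worked example: 50 sampled frames -> remainder = (3 + 50 % 4) % 4 = 1,
        # so the last frame is dropped and 49 = 4 * 12 + 1 frames remain.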
        # Training transforms
        frames = frames.permute(0, 3, 1, 2)  # [F, H, W, C] -> [F, C, H, W]
        frames = self._resize_for_rectangle_crop(frames)
        frames = (frames - 127.5) / 127.5  # normalize uint8 range to [-1, 1], matching _preprocess_data
        final_frames = frames.contiguous()  # [F, C, H, W]
        # with torch.no_grad():
        #     sketch = self.detector(final_frames, coarse=False)
        #     # Expand the sketch to three channels for downstream processing.
        #     sketch = sketch.repeat(1, 3, 1, 1)
        #     sketch = (sketch - 0.5) / 0.5
        #     final_sketch = sketch.contiguous()  # [F, C, H, W]
        # print("Frames is contiguous after arithmetic operations:", final_frames.is_contiguous())
        # Per-frame variant of the sketch extraction:
        # for i in range(selected_num_frames):
        #     np_img = np.array(Image.open(img_path).convert("RGB").resize((720, 480)))
        #     with torch.no_grad():
        #         sketch = detector(np_img, coarse=False)
        #     sketch = (sketch - 127.5) / 127.5
        #     sketch = sketch.permute(0, 3, 1, 2)  # [F, C, H, W]
        #     sketch = self._resize_for_rectangle_crop(sketch)
        #     final_sketch = final_sketch.contiguous()  # [F, C, H, W]
return {
"instance_prompt": self.instance_prompts[index],
"instance_video": final_frames,
#"instance_sketch": final_sketch,
}
def _load_dataset_from_hub(self):
try:
from datasets import load_dataset
except ImportError:
            raise ImportError(
                "You are trying to load your data using the datasets library. If you wish to train using custom "
                "captions please install the datasets library: `pip install datasets`. If you wish to load a "
                "local folder containing videos only, specify `--instance_data_root` instead."
            )
# Downloading and loading a dataset from the hub. See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
dataset = load_dataset(
self.dataset_name,
self.dataset_config_name,
cache_dir=self.cache_dir,
)
column_names = dataset["train"].column_names
if self.video_column is None:
video_column = column_names[0]
#logger.info(f"`video_column` defaulting to {video_column}")
print(f"`video_column` defaulting to {video_column}")
else:
video_column = self.video_column
if video_column not in column_names:
raise ValueError(
f"`--video_column` value '{video_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
if self.caption_column is None:
caption_column = column_names[1]
#logger.info(f"`caption_column` defaulting to {caption_column}")
print(f"`caption_column` defaulting to {caption_column}")
else:
caption_column = self.caption_column
        if caption_column not in column_names:
            raise ValueError(
                f"`--caption_column` value '{caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )
instance_prompts = dataset["train"][caption_column]
instance_videos = [Path(self.instance_data_root, filepath) for filepath in dataset["train"][video_column]]
return instance_prompts, instance_videos
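    # A compatible hub dataset is assumed to expose a "train" split with a caption
    # column and a column of video paths relative to `instance_data_root`, e.g.
    # (hypothetical row): {"text": "a cat playing piano", "video": "videos/0001.mp4"}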
def _load_dataset_from_local_path(self):
if not self.instance_data_root.exists():
raise ValueError("Instance videos root folder does not exist")
prompt_path = self.instance_data_root.joinpath(self.caption_column)
video_path = self.instance_data_root.joinpath(self.video_column)
if not prompt_path.exists() or not prompt_path.is_file():
            raise ValueError(
                "Expected `--caption_column` to be a path to a file in `--instance_data_root` containing line-separated text prompts."
            )
        if not video_path.exists() or not video_path.is_file():
            raise ValueError(
                "Expected `--video_column` to be a path to a file in `--instance_data_root` containing line-separated paths to video data in the same directory."
            )
with open(prompt_path, "r", encoding="utf-8") as file:
instance_prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0]
with open(video_path, "r", encoding="utf-8") as file:
instance_videos = [
self.instance_data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0
]
if any(not path.is_file() for path in instance_videos):
            raise ValueError(
                "Expected `--video_column` to be a path to a file in `--instance_data_root` containing line-separated paths to video data, but found at least one path that is not a valid file."
            )
return instance_prompts, instance_videos
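    # Expected local layout, assuming the default column names ("text", "video"):
    #     instance_data_root/
    #         text      <- line-separated prompts, one per video
    #         video     <- line-separated relative paths, e.g. "videos/0001.mp4"
    #         videos/   <- the actual video files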
    def _resize_for_rectangle_crop(self, arr):
        image_size = self.height, self.width
        reshape_mode = self.video_reshape_mode
        # If the source is wider than the target aspect ratio, resize to match the
        # target height (the width will be cropped); otherwise match the target width.
        if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
arr = resize(
arr,
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
interpolation=InterpolationMode.BICUBIC,
)
else:
arr = resize(
arr,
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
interpolation=InterpolationMode.BICUBIC,
)
        h, w = arr.shape[2], arr.shape[3]
        arr = arr.squeeze(0)  # no-op for [F, C, H, W] input with F > 1 (kept from the original batched code)
delta_h = h - image_size[0]
delta_w = w - image_size[1]
        if reshape_mode in ("random", "none"):
            top = np.random.randint(0, delta_h + 1)
            left = np.random.randint(0, delta_w + 1)
        elif reshape_mode == "center":
            top, left = delta_h // 2, delta_w // 2
        else:
            raise NotImplementedError(f"Unsupported video_reshape_mode: {reshape_mode}")
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
return arr
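    # Worked example for a 1080x1920 source and a 480x720 target: 1920/1080 ~ 1.78 >
    # 720/480 = 1.5, so the frames are resized to 480x853, then 853 - 720 = 133
    # columns are cropped away (66 from the left edge in "center" mode).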
    # Legacy eager preprocessing that decodes, resizes, crops, and normalizes every
    # video up front; the same per-video work now happens in __getitem__.
def _preprocess_data(self):
decord.bridge.set_bridge("torch")
        progress_dataset_bar = tqdm(
            range(0, len(self.instance_video_paths)),
            desc="Resizing and cropping videos",
        )
videos = []
for filename in self.instance_video_paths:
video_reader = decord.VideoReader(uri=filename.as_posix())
video_num_frames = len(video_reader)
start_frame = min(self.skip_frames_start, video_num_frames)
end_frame = max(0, video_num_frames - self.skip_frames_end)
if end_frame <= start_frame:
frames = video_reader.get_batch([start_frame])
elif end_frame - start_frame <= self.max_num_frames:
frames = video_reader.get_batch(list(range(start_frame, end_frame)))
else:
indices = list(range(start_frame, end_frame, (end_frame - start_frame) // self.max_num_frames))
frames = video_reader.get_batch(indices)
# Ensure that we don't go over the limit
frames = frames[: self.max_num_frames]
selected_num_frames = frames.shape[0]
            # Keep only the first (4k + 1) frames: the VAE's temporal compression requires frame counts of the form 4k + 1.
            remainder = (3 + (selected_num_frames % 4)) % 4
if remainder != 0:
frames = frames[:-remainder]
selected_num_frames = frames.shape[0]
assert (selected_num_frames - 1) % 4 == 0
# Training transforms
frames = frames.permute(0, 3, 1, 2) # [F, C, H, W]
            progress_dataset_bar.set_description(
                f"Resizing video from {frames.shape[2]}x{frames.shape[3]} to {self.height}x{self.width}"
            )
            frames = self._resize_for_rectangle_crop(frames)  # resize and crop to the target resolution
            frames = (frames - 127.5) / 127.5  # normalize uint8 range to [-1, 1]
            videos.append(frames.contiguous())  # [F, C, H, W]
progress_dataset_bar.update(1)
progress_dataset_bar.close()
return videos
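

# A minimal usage sketch, assuming a local dataset laid out as described above.
# The path and settings below are placeholders, not part of this module.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = VideoDataset(
        instance_data_root="path/to/dataset",  # hypothetical location
        caption_column="text",
        video_column="video",
        height=480,
        width=720,
        max_num_frames=49,
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    for batch in loader:
        print(batch["instance_prompt"][0], batch["instance_video"].shape)  # [B, F, C, H, W]
        break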