import hashlib
import json
import random
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Tuple

import torch
from accelerate.logging import get_logger
from safetensors.torch import load_file, save_file
from torch.utils.data import Dataset
from torchvision import transforms
from typing_extensions import override

from finetune.constants import LOG_LEVEL, LOG_NAME

from .utils import (
    load_binary_mask_compressed,
    load_images,
    load_images_from_videos,
    load_prompts,
    load_videos,
    preprocess_image_with_resize,
    preprocess_video_with_buckets,
    preprocess_video_with_resize,
)
if TYPE_CHECKING:
    from finetune.trainer import Trainer

# decord must be imported after torch; importing it earlier can sometimes lead to a nasty
# segmentation fault or stack-smashing error. There are very few bug reports, but it happens.
# See the decord GitHub issues for more information.
import decord  # isort:skip

decord.bridge.set_bridge("torch")

logger = get_logger(LOG_NAME, LOG_LEVEL)

class I2VFlowDataset(Dataset):
    """
    A dataset class for (image, flow)-to-video or image-to-flow-video generation that resizes
    inputs to fixed dimensions.

    This class preprocesses videos and images by resizing them to the specified dimensions:
        - Videos are resized to max_num_frames x height x width
        - Images are resized to height x width

    Args:
        max_num_frames (int): Maximum number of frames to extract from videos
        height (int): Target height for resizing videos and images
        width (int): Target width for resizing videos and images
    """
    def __init__(
        self,
        max_num_frames: int,
        height: int,
        width: int,
        data_root: str,
        caption_column: str,
        video_column: str,
        image_column: str | None,
        device: torch.device,
        trainer: "Trainer" = None,
        *args,
        **kwargs,
    ) -> None:
        data_root = Path(data_root)
        metadata_path = data_root / "metadata_revised.jsonl"
        assert metadata_path.is_file(), "For this dataset type, metadata_revised.jsonl must exist in the data root"

        # Load metadata. Each line is a JSON object of the form:
        # {
        #     "video_path": ...,
        #     "hash_code": ...,
        #     "prompt": ...,
        # }
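        # Hypothetical example line (illustrative only; the values below are not from the repo):
        #     {"video_path": "videos/000001.mp4", "hash_code": "a1b2c3d4", "prompt": "a dog running on a beach"}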
        metadata = []
        with open(metadata_path, "r") as f:
            for line in f:
                metadata.append(json.loads(line))

        self.prompts = [x["prompt"] for x in metadata]
        if "curated" in str(data_root).lower():
            self.prompt_embeddings = [
                data_root / "prompt_embeddings" / (x["hash_code"] + ".safetensors") for x in metadata
            ]
        else:
            self.prompt_embeddings = [
                data_root / "prompt_embeddings_revised" / (x["hash_code"] + ".safetensors") for x in metadata
            ]
        self.videos = [
            data_root / "video_latent" / "x".join(str(x) for x in trainer.args.train_resolution) / (x["hash_code"] + ".safetensors")
            for x in metadata
        ]
        self.images = [data_root / "first_frames" / (x["hash_code"] + ".png") for x in metadata]
        self.flows = [data_root / "flow_direct_f_latent" / (x["hash_code"] + ".safetensors") for x in metadata]

        # data_root = Path(data_root)
        # self.prompts = load_prompts(data_root / caption_column)
        # self.videos = load_videos(data_root / video_column)

        self.trainer = trainer
        self.device = device
        self.encode_video = trainer.encode_video
        self.encode_text = trainer.encode_text

        # Check that the numbers of prompts, videos, images, and flows match
        if not (len(self.videos) == len(self.prompts) == len(self.images) == len(self.flows)):
            raise ValueError(
                f"Expected prompts, videos, images, and flows to have the same length but found "
                f"{len(self.prompts)=}, {len(self.videos)=}, {len(self.images)=} and {len(self.flows)=}. "
                "Please ensure that the number of caption prompts, videos, images, and flows match in your dataset."
            )

        self.max_num_frames = max_num_frames
        self.height = height
        self.width = width

        self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
        self.__image_transforms = self.__frame_transforms

        self.length = len(self.videos)
        logger.info(f"Dataset size: {self.length}")

    def __len__(self) -> int:
        return self.length

    def load_data_pair(self, index):
        # prompt = self.prompts[index]
        prompt_embedding_path = self.prompt_embeddings[index]
        encoded_video_path = self.videos[index]
        encoded_flow_path = self.flows[index]
        # mask_path = self.masks[index]
        # image_path = self.images[index]
        # train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)

        prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
        encoded_video = load_file(encoded_video_path)["encoded_video"]  # [C, F, H, W]
        encoded_flow = load_file(encoded_flow_path)["encoded_flow_f"]  # [C, F, H, W]

        return prompt_embedding, encoded_video, encoded_flow
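
    # Note: the .safetensors files loaded above are assumed (not produced in this file) to have
    # been written during preprocessing with matching keys, e.g. roughly:
    #     save_file({"prompt_embedding": emb}, prompt_embedding_path)
    #     save_file({"encoded_video": video_latent}, encoded_video_path)  # latent in [C, F, H, W]
    #     save_file({"encoded_flow_f": flow_latent}, encoded_flow_path)   # flow latent in [C, F, H, W]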

    def __getitem__(self, index: int) -> Dict[str, Any]:
        # Retry with a random index if a sample fails to load (e.g. a missing or corrupt file)
        while True:
            try:
                prompt_embedding, encoded_video, encoded_flow = self.load_data_pair(index)
                break
            except Exception as e:
                logger.warning(f"Error loading {self.prompt_embeddings[index]}: {str(e)}")
                index = random.randint(0, self.length - 1)

        image_path = self.images[index]
        prompt = self.prompts[index]
        train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)

        _, image = self.preprocess(None, image_path)
        image = self.image_transform(image)

        # shape of encoded_video: [C, F, H, W]
        # shape and scale of image: [C, H, W], in [-1, 1]
        return {
            "image": image,
            "prompt_embedding": prompt_embedding,
            "encoded_video": encoded_video,
            "encoded_flow": encoded_flow,
            "video_metadata": {
                "num_frames": encoded_video.shape[1],
                "height": encoded_video.shape[2],
                "width": encoded_video.shape[3],
            },
        }

    @override
    def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
        if video_path is not None:
            video = preprocess_video_with_resize(video_path, self.max_num_frames, self.height, self.width)
        else:
            video = None
        if image_path is not None:
            image = preprocess_image_with_resize(image_path, self.height, self.width)
        else:
            image = None
        return video, image

    @override
    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)

    @override
    def image_transform(self, image: torch.Tensor) -> torch.Tensor:
        return self.__image_transforms(image)
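

# Minimal usage sketch (illustrative only). The dataset is normally constructed by the finetune
# Trainer, which passes itself in so that `trainer.args.train_resolution` and the encode_*
# callbacks are available; the argument values below are hypothetical.
#
#     dataset = I2VFlowDataset(
#         max_num_frames=49,
#         height=480,
#         width=720,
#         data_root="/path/to/dataset",
#         caption_column="prompt.txt",
#         video_column="videos.txt",
#         image_column=None,
#         device=torch.device("cuda"),
#         trainer=trainer,
#     )
#     sample = dataset[0]
#     # sample["encoded_video"] and sample["encoded_flow"] are latents of shape [C, F, H, W];
#     # sample["image"] is a [C, H, W] tensor scaled to [-1, 1].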