# Adapted from finetrainers (commit 9fd1204, "we are going to hack into finetrainers").
import random
from typing import Any, Dict, Optional
import torch
import torch.distributed.checkpoint.stateful
from diffusers.video_processor import VideoProcessor
import finetrainers.functional as FF
from finetrainers.logging import get_logger
from finetrainers.processors import CannyProcessor, CopyProcessor
from .config import ControlType, FrameConditioningType
logger = get_logger()
class IterableControlDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
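    """Wraps an iterable dataset and augments each sample with a control signal.

    Depending on `control_type`, each sample gains a "control_image" or
    "control_video" entry: a Canny edge map for ControlType.CANNY, or a verbatim
    copy of the input for ControlType.NONE. Samples that already carry control
    tensors are only resized to the target bucket, and ControlType.CUSTOM passes
    data through untouched. Checkpointing is delegated to the wrapped dataset
    via the Stateful interface.
    """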
def __init__(
self, dataset: torch.utils.data.IterableDataset, control_type: str, device: Optional[torch.device] = None
):
super().__init__()
self.dataset = dataset
self.control_type = control_type
self.control_processors = []
if control_type == ControlType.CANNY:
self.control_processors.append(
CannyProcessor(
output_names=["control_output"], input_names={"image": "input", "video": "input"}, device=device
)
)
elif control_type == ControlType.NONE:
self.control_processors.append(
CopyProcessor(output_names=["control_output"], input_names={"image": "input", "video": "input"})
)
logger.info("Initialized IterableControlDataset")
def __iter__(self):
logger.info("Starting IterableControlDataset")
for data in iter(self.dataset):
control_augmented_data = self._run_control_processors(data)
yield control_augmented_data
def load_state_dict(self, state_dict):
self.dataset.load_state_dict(state_dict)
def state_dict(self):
return self.dataset.state_dict()
def _run_control_processors(self, data: Dict[str, Any]) -> Dict[str, Any]:
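        """Attach or post-process control data for a single sample.

        Pre-existing "control_image"/"control_video" tensors are resized to the
        bucket implied by the target "image"/"video" and returned early. For
        ControlType.CUSTOM the sample is returned unmodified. Otherwise the
        configured processors run on a shallow copy of the sample, and their
        "control_output" is normalized to [-1, 1] and stored under
        "control_image" or "control_video".
        """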
if "control_image" in data:
if "image" in data:
data["control_image"] = FF.resize_to_nearest_bucket_image(
data["control_image"], [data["image"].shape[-2:]], resize_mode="bicubic"
)
if "video" in data:
batch_size, num_frames, num_channels, height, width = data["video"].shape
data["control_video"], _first_frame_only = FF.resize_to_nearest_bucket_video(
data["control_video"], [[num_frames, height, width]], resize_mode="bicubic"
)
if _first_frame_only:
msg = (
"The number of frames in the control video is less than the minimum bucket size "
"specified. The first frame is being used as a single frame video. This "
"message is logged at the first occurence and for every 128th occurence "
"after that."
)
logger.log_freq("WARNING", "BUCKET_TEMPORAL_SIZE_UNAVAILABLE_CONTROL", msg, frequency=128)
data["control_video"] = data["control_video"][0]
return data
if "control_video" in data:
if "image" in data:
data["control_image"] = FF.resize_to_nearest_bucket_image(
data["control_video"][0], [data["image"].shape[-2:]], resize_mode="bicubic"
)
if "video" in data:
batch_size, num_frames, num_channels, height, width = data["video"].shape
data["control_video"], _first_frame_only = FF.resize_to_nearest_bucket_video(
data["control_video"], [[num_frames, height, width]], resize_mode="bicubic"
)
if _first_frame_only:
msg = (
"The number of frames in the control video is less than the minimum bucket size "
"specified. The first frame is being used as a single frame video. This "
"message is logged at the first occurence and for every 128th occurence "
"after that."
)
logger.log_freq("WARNING", "BUCKET_TEMPORAL_SIZE_UNAVAILABLE_CONTROL", msg, frequency=128)
data["control_video"] = data["control_video"][0]
return data
if self.control_type == ControlType.CUSTOM:
return data
shallow_copy_data = dict(data.items())
is_image_control = "image" in shallow_copy_data
is_video_control = "video" in shallow_copy_data
if (is_image_control + is_video_control) != 1:
raise ValueError("Exactly one of 'image' or 'video' should be present in the data.")
for processor in self.control_processors:
result = processor(**shallow_copy_data)
result_keys = set(result.keys())
repeat_keys = result_keys.intersection(shallow_copy_data.keys())
if repeat_keys:
logger.warning(
f"Processor {processor.__class__.__name__} returned keys that already exist in "
f"conditions: {repeat_keys}. Overwriting the existing values, but this may not "
f"be intended. Please rename the keys in the processor to avoid conflicts."
)
shallow_copy_data.update(result)
if "control_output" in shallow_copy_data:
# Normalize to [-1, 1] range
control_output = shallow_copy_data.pop("control_output")
# TODO(aryan): need to specify a dim for normalize here across channels
control_output = FF.normalize(control_output, min=-1.0, max=1.0)
key = "control_image" if is_image_control else "control_video"
shallow_copy_data[key] = control_output
return shallow_copy_data
class ValidationControlDataset(torch.utils.data.IterableDataset):
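    """Validation-time counterpart of IterableControlDataset.

    Runs the same control processors, but converts the normalized control
    output back to PIL via diffusers' VideoProcessor so it can be logged or fed
    to a validation pipeline, instead of being kept as a raw tensor.
    """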
def __init__(
self, dataset: torch.utils.data.IterableDataset, control_type: str, device: Optional[torch.device] = None
):
super().__init__()
self.dataset = dataset
self.control_type = control_type
self.device = device
self._video_processor = VideoProcessor()
self.control_processors = []
if control_type == ControlType.CANNY:
self.control_processors.append(
CannyProcessor(["control_output"], input_names={"image": "input", "video": "input"}, device=device)
)
elif control_type == ControlType.NONE:
self.control_processors.append(
CopyProcessor(["control_output"], input_names={"image": "input", "video": "input"})
)
logger.info("Initialized ValidationControlDataset")
def __iter__(self):
logger.info("Starting ValidationControlDataset")
for data in iter(self.dataset):
control_augmented_data = self._run_control_processors(data)
yield control_augmented_data
def load_state_dict(self, state_dict):
self.dataset.load_state_dict(state_dict)
def state_dict(self):
return self.dataset.state_dict()
def _run_control_processors(self, data: Dict[str, Any]) -> Dict[str, Any]:
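        """Attach control data to a validation sample and convert it to PIL.

        Tensors already present under "control_image"/"control_video" are
        passed through as-is. Otherwise the processors run as in training, and
        a tensor "control_output" is normalized and post-processed according to
        its dimensionality (ndim 3/4 -> single PIL image, ndim 5 -> video).
        """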
if self.control_type == ControlType.CUSTOM:
return data
# These are already expected to be tensors
if "control_image" in data or "control_video" in data:
return data
shallow_copy_data = dict(data.items())
is_image_control = "image" in shallow_copy_data
is_video_control = "video" in shallow_copy_data
if (is_image_control + is_video_control) != 1:
raise ValueError("Exactly one of 'image' or 'video' should be present in the data.")
for processor in self.control_processors:
result = processor(**shallow_copy_data)
result_keys = set(result.keys())
repeat_keys = result_keys.intersection(shallow_copy_data.keys())
if repeat_keys:
logger.warning(
f"Processor {processor.__class__.__name__} returned keys that already exist in "
f"conditions: {repeat_keys}. Overwriting the existing values, but this may not "
f"be intended. Please rename the keys in the processor to avoid conflicts."
)
shallow_copy_data.update(result)
if "control_output" in shallow_copy_data:
# Normalize to [-1, 1] range
control_output = shallow_copy_data.pop("control_output")
if torch.is_tensor(control_output):
# TODO(aryan): need to specify a dim for normalize here across channels
control_output = FF.normalize(control_output, min=-1.0, max=1.0)
ndim = control_output.ndim
                assert 3 <= ndim <= 5, f"Control output must have 3 <= ndim <= 5, got ndim={ndim}"
if ndim == 5:
control_output = self._video_processor.postprocess_video(control_output, output_type="pil")
else:
if ndim == 3:
control_output = control_output.unsqueeze(0)
control_output = self._video_processor.postprocess(control_output, output_type="pil")[0]
key = "control_image" if is_image_control else "control_video"
shallow_copy_data[key] = control_output
return shallow_copy_data
# TODO(aryan): write a test for this function
def apply_frame_conditioning_on_latents(
latents: torch.Tensor,
expected_num_frames: int,
channel_dim: int,
frame_dim: int,
frame_conditioning_type: FrameConditioningType,
frame_conditioning_index: Optional[int] = None,
concatenate_mask: bool = False,
) -> torch.Tensor:
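    """Mask latent frames according to the chosen frame-conditioning strategy.

    - INDEX: keep only the frame at `frame_conditioning_index` (clamped to the
      last frame).
    - PREFIX: keep a random-length prefix of at least one frame.
    - RANDOM: keep a random non-empty subset of frames.
    - FIRST_AND_LAST: keep the first and last frames.
    - FULL: keep all frames.

    The result is then truncated or zero-padded along `frame_dim` to
    `expected_num_frames`. If `concatenate_mask` is True, a single channel of
    the binary keep-mask is appended along `channel_dim`.
    """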
num_frames = latents.size(frame_dim)
mask = torch.zeros_like(latents)
if frame_conditioning_type == FrameConditioningType.INDEX:
        if frame_conditioning_index is None:
            raise ValueError("`frame_conditioning_index` is required for FrameConditioningType.INDEX.")
        frame_index = min(frame_conditioning_index, num_frames - 1)
indexing = [slice(None)] * latents.ndim
indexing[frame_dim] = frame_index
mask[tuple(indexing)] = 1
latents = latents * mask
elif frame_conditioning_type == FrameConditioningType.PREFIX:
frame_index = random.randint(1, num_frames)
indexing = [slice(None)] * latents.ndim
indexing[frame_dim] = slice(0, frame_index) # Keep frames 0 to frame_index-1
mask[tuple(indexing)] = 1
latents = latents * mask
elif frame_conditioning_type == FrameConditioningType.RANDOM:
        # Keep one or more random frames (random.randint is inclusive on both ends)
        num_frames_to_keep = random.randint(1, num_frames)
frame_indices = random.sample(range(num_frames), num_frames_to_keep)
indexing = [slice(None)] * latents.ndim
indexing[frame_dim] = frame_indices
mask[tuple(indexing)] = 1
latents = latents * mask
elif frame_conditioning_type == FrameConditioningType.FIRST_AND_LAST:
indexing = [slice(None)] * latents.ndim
indexing[frame_dim] = 0
mask[tuple(indexing)] = 1
indexing[frame_dim] = num_frames - 1
mask[tuple(indexing)] = 1
latents = latents * mask
elif frame_conditioning_type == FrameConditioningType.FULL:
indexing = [slice(None)] * latents.ndim
indexing[frame_dim] = slice(0, num_frames)
mask[tuple(indexing)] = 1
if latents.size(frame_dim) >= expected_num_frames:
slicing = [slice(None)] * latents.ndim
slicing[frame_dim] = slice(expected_num_frames)
latents = latents[tuple(slicing)]
mask = mask[tuple(slicing)]
else:
pad_size = expected_num_frames - num_frames
pad_shape = list(latents.shape)
pad_shape[frame_dim] = pad_size
padding = latents.new_zeros(pad_shape)
latents = torch.cat([latents, padding], dim=frame_dim)
mask = torch.cat([mask, padding], dim=frame_dim)
    if concatenate_mask:
        # Append a single channel of the binary keep-mask so downstream layers can
        # tell which frames carry conditioning; slice(0, 1) keeps the channel dim
        # so shapes line up for torch.cat.
        slicing = [slice(None)] * latents.ndim
        slicing[channel_dim] = slice(0, 1)
        latents = torch.cat([latents, mask[tuple(slicing)]], dim=channel_dim)
return latents
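

# Illustrative usage sketch below; the latent layout is an assumption made for
# the example only.
def _example_frame_conditioning() -> torch.Tensor:
    """Sketch of how `apply_frame_conditioning_on_latents` might be called.

    Assumes a (batch, channels, frames, height, width) latent layout purely for
    illustration; it is not used by the trainer.
    """
    latents = torch.randn(1, 16, 8, 32, 32)
    conditioned = apply_frame_conditioning_on_latents(
        latents,
        expected_num_frames=8,
        channel_dim=1,
        frame_dim=2,
        frame_conditioning_type=FrameConditioningType.PREFIX,
        concatenate_mask=True,
    )
    # One mask channel is appended when concatenate_mask=True -> (1, 17, 8, 32, 32)
    return conditioned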