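"""Control dataset wrappers and frame-conditioning utilities for control training."""
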
import random
from typing import Any, Dict, Optional

import torch
import torch.distributed.checkpoint.stateful
from diffusers.video_processor import VideoProcessor

import finetrainers.functional as FF
from finetrainers.logging import get_logger
from finetrainers.processors import CannyProcessor, CopyProcessor

from .config import ControlType, FrameConditioningType

logger = get_logger()


class IterableControlDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
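    """Wraps an iterable dataset and augments every sample with a control signal.

    Depending on ``control_type``, the control is a Canny edge map computed from
    the sample (``ControlType.CANNY``), a copy of the sample itself
    (``ControlType.NONE``), or whatever ``control_image``/``control_video``
    entries the wrapped dataset already provides (``ControlType.CUSTOM``).
    """
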
    def __init__(
        self, dataset: torch.utils.data.IterableDataset, control_type: str, device: Optional[torch.device] = None
    ):
        super().__init__()

        self.dataset = dataset
        self.control_type = control_type

        self.control_processors = []
        if control_type == ControlType.CANNY:
            self.control_processors.append(
                CannyProcessor(
                    output_names=["control_output"], input_names={"image": "input", "video": "input"}, device=device
                )
            )
        elif control_type == ControlType.NONE:
            self.control_processors.append(
                CopyProcessor(output_names=["control_output"], input_names={"image": "input", "video": "input"})
            )

        logger.info("Initialized IterableControlDataset")

    def __iter__(self):
        logger.info("Starting IterableControlDataset")
        for data in iter(self.dataset):
            control_augmented_data = self._run_control_processors(data)
            yield control_augmented_data

    def load_state_dict(self, state_dict):
        self.dataset.load_state_dict(state_dict)

    def state_dict(self):
        return self.dataset.state_dict()

    def _run_control_processors(self, data: Dict[str, Any]) -> Dict[str, Any]:
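        """Attaches a ``control_image`` or ``control_video`` entry to ``data``.

        If the wrapped dataset already supplies a control tensor, it is only
        resized to the bucket of the underlying sample. Otherwise, the configured
        control processors compute one from the sample and the result is
        normalized to [-1, 1].
        """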
if "control_image" in data: | |
if "image" in data: | |
data["control_image"] = FF.resize_to_nearest_bucket_image( | |
data["control_image"], [data["image"].shape[-2:]], resize_mode="bicubic" | |
) | |
if "video" in data: | |
batch_size, num_frames, num_channels, height, width = data["video"].shape | |
data["control_video"], _first_frame_only = FF.resize_to_nearest_bucket_video( | |
data["control_video"], [[num_frames, height, width]], resize_mode="bicubic" | |
) | |
if _first_frame_only: | |
msg = ( | |
"The number of frames in the control video is less than the minimum bucket size " | |
"specified. The first frame is being used as a single frame video. This " | |
"message is logged at the first occurence and for every 128th occurence " | |
"after that." | |
) | |
logger.log_freq("WARNING", "BUCKET_TEMPORAL_SIZE_UNAVAILABLE_CONTROL", msg, frequency=128) | |
data["control_video"] = data["control_video"][0] | |
return data | |
if "control_video" in data: | |
if "image" in data: | |
data["control_image"] = FF.resize_to_nearest_bucket_image( | |
data["control_video"][0], [data["image"].shape[-2:]], resize_mode="bicubic" | |
) | |
if "video" in data: | |
batch_size, num_frames, num_channels, height, width = data["video"].shape | |
data["control_video"], _first_frame_only = FF.resize_to_nearest_bucket_video( | |
data["control_video"], [[num_frames, height, width]], resize_mode="bicubic" | |
) | |
if _first_frame_only: | |
msg = ( | |
"The number of frames in the control video is less than the minimum bucket size " | |
"specified. The first frame is being used as a single frame video. This " | |
"message is logged at the first occurence and for every 128th occurence " | |
"after that." | |
) | |
logger.log_freq("WARNING", "BUCKET_TEMPORAL_SIZE_UNAVAILABLE_CONTROL", msg, frequency=128) | |
data["control_video"] = data["control_video"][0] | |
return data | |
        if self.control_type == ControlType.CUSTOM:
            # Custom control: the dataset is expected to provide control tensors itself.
            return data

        shallow_copy_data = dict(data.items())
        is_image_control = "image" in shallow_copy_data
        is_video_control = "video" in shallow_copy_data
        if (is_image_control + is_video_control) != 1:
            raise ValueError("Exactly one of 'image' or 'video' should be present in the data.")

        for processor in self.control_processors:
            result = processor(**shallow_copy_data)
            result_keys = set(result.keys())
            repeat_keys = result_keys.intersection(shallow_copy_data.keys())
            if repeat_keys:
                logger.warning(
                    f"Processor {processor.__class__.__name__} returned keys that already exist in "
                    f"conditions: {repeat_keys}. Overwriting the existing values, but this may not "
                    f"be intended. Please rename the keys in the processor to avoid conflicts."
                )
            shallow_copy_data.update(result)

        if "control_output" in shallow_copy_data:
            # Normalize to [-1, 1] range
            control_output = shallow_copy_data.pop("control_output")
            # TODO(aryan): need to specify a dim for normalize here across channels
            control_output = FF.normalize(control_output, min=-1.0, max=1.0)
            key = "control_image" if is_image_control else "control_video"
            shallow_copy_data[key] = control_output

        return shallow_copy_data


class ValidationControlDataset(torch.utils.data.IterableDataset):
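    """Validation-time counterpart of ``IterableControlDataset``.

    Computes the same control signals, but post-processes tensor outputs back
    to PIL images/videos with ``VideoProcessor`` so they can be passed directly
    to validation pipelines.
    """
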
    def __init__(
        self, dataset: torch.utils.data.IterableDataset, control_type: str, device: Optional[torch.device] = None
    ):
        super().__init__()

        self.dataset = dataset
        self.control_type = control_type
        self.device = device
        self._video_processor = VideoProcessor()

        self.control_processors = []
        if control_type == ControlType.CANNY:
            self.control_processors.append(
                CannyProcessor(["control_output"], input_names={"image": "input", "video": "input"}, device=device)
            )
        elif control_type == ControlType.NONE:
            self.control_processors.append(
                CopyProcessor(["control_output"], input_names={"image": "input", "video": "input"})
            )

        logger.info("Initialized ValidationControlDataset")

    def __iter__(self):
        logger.info("Starting ValidationControlDataset")
        for data in iter(self.dataset):
            control_augmented_data = self._run_control_processors(data)
            yield control_augmented_data

    def load_state_dict(self, state_dict):
        self.dataset.load_state_dict(state_dict)

    def state_dict(self):
        return self.dataset.state_dict()

    def _run_control_processors(self, data: Dict[str, Any]) -> Dict[str, Any]:
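        """Computes the control signal for a validation sample.

        Pre-supplied control tensors are passed through untouched; freshly
        computed ones are normalized and converted to PIL.
        """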
        if self.control_type == ControlType.CUSTOM:
            return data

        # These are already expected to be tensors provided by the dataset
        if "control_image" in data or "control_video" in data:
            return data

        shallow_copy_data = dict(data.items())
        is_image_control = "image" in shallow_copy_data
        is_video_control = "video" in shallow_copy_data
        if (is_image_control + is_video_control) != 1:
            raise ValueError("Exactly one of 'image' or 'video' should be present in the data.")

        for processor in self.control_processors:
            result = processor(**shallow_copy_data)
            result_keys = set(result.keys())
            repeat_keys = result_keys.intersection(shallow_copy_data.keys())
            if repeat_keys:
                logger.warning(
                    f"Processor {processor.__class__.__name__} returned keys that already exist in "
                    f"conditions: {repeat_keys}. Overwriting the existing values, but this may not "
                    f"be intended. Please rename the keys in the processor to avoid conflicts."
                )
            shallow_copy_data.update(result)

        if "control_output" in shallow_copy_data:
            # Normalize to [-1, 1] range
            control_output = shallow_copy_data.pop("control_output")
            if torch.is_tensor(control_output):
                # TODO(aryan): need to specify a dim for normalize here across channels
                control_output = FF.normalize(control_output, min=-1.0, max=1.0)
                ndim = control_output.ndim
                assert 3 <= ndim <= 5, "Control output must have 3, 4, or 5 dimensions"
                if ndim == 5:
                    # Batched video tensor -> list of PIL frame lists
                    control_output = self._video_processor.postprocess_video(control_output, output_type="pil")
                else:
                    # Single image (add a batch dim if needed) -> PIL image
                    if ndim == 3:
                        control_output = control_output.unsqueeze(0)
                    control_output = self._video_processor.postprocess(control_output, output_type="pil")[0]
            key = "control_image" if is_image_control else "control_video"
            shallow_copy_data[key] = control_output

        return shallow_copy_data


# TODO(aryan): write a test for this function
def apply_frame_conditioning_on_latents(
    latents: torch.Tensor,
    expected_num_frames: int,
    channel_dim: int,
    frame_dim: int,
    frame_conditioning_type: FrameConditioningType,
    frame_conditioning_index: Optional[int] = None,
    concatenate_mask: bool = False,
) -> torch.Tensor:
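    """Zeroes out latent frames that should not act as conditioning.

    A binary mask marks which frames along ``frame_dim`` are kept as
    conditioning; all other frames are zeroed. The result is then trimmed or
    zero-padded along ``frame_dim`` to ``expected_num_frames``, and the mask is
    optionally concatenated along ``channel_dim``.

    Example (shapes are illustrative; this assumes ``[B, C, F, H, W]`` latents):

        latents = torch.randn(1, 4, 16, 32, 32)
        conditioned = apply_frame_conditioning_on_latents(
            latents,
            expected_num_frames=16,
            channel_dim=1,
            frame_dim=2,
            frame_conditioning_type=FrameConditioningType.INDEX,
            frame_conditioning_index=0,
        )  # only the first frame is kept; the rest are zeroed
    """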
    num_frames = latents.size(frame_dim)
    mask = torch.zeros_like(latents)

    if frame_conditioning_type == FrameConditioningType.INDEX:
        # Keep a single frame; frame_conditioning_index is required for this mode
        frame_index = min(frame_conditioning_index, num_frames - 1)
        indexing = [slice(None)] * latents.ndim
        indexing[frame_dim] = frame_index
        mask[tuple(indexing)] = 1
        latents = latents * mask
    elif frame_conditioning_type == FrameConditioningType.PREFIX:
        # Keep a random-length prefix of frames
        frame_index = random.randint(1, num_frames)
        indexing = [slice(None)] * latents.ndim
        indexing[frame_dim] = slice(0, frame_index)  # Keep frames 0 to frame_index-1
        mask[tuple(indexing)] = 1
        latents = latents * mask
    elif frame_conditioning_type == FrameConditioningType.RANDOM:
        # One or more random frames to keep
        num_frames_to_keep = random.randint(1, num_frames)
        frame_indices = random.sample(range(num_frames), num_frames_to_keep)
        indexing = [slice(None)] * latents.ndim
        indexing[frame_dim] = frame_indices
        mask[tuple(indexing)] = 1
        latents = latents * mask
    elif frame_conditioning_type == FrameConditioningType.FIRST_AND_LAST:
        # Keep only the first and last frames
        indexing = [slice(None)] * latents.ndim
        indexing[frame_dim] = 0
        mask[tuple(indexing)] = 1
        indexing[frame_dim] = num_frames - 1
        mask[tuple(indexing)] = 1
        latents = latents * mask
    elif frame_conditioning_type == FrameConditioningType.FULL:
        # Keep all frames; the latents are left untouched
        indexing = [slice(None)] * latents.ndim
        indexing[frame_dim] = slice(0, num_frames)
        mask[tuple(indexing)] = 1

    # Trim or zero-pad along the frame dimension to the expected length
    if latents.size(frame_dim) >= expected_num_frames:
        slicing = [slice(None)] * latents.ndim
        slicing[frame_dim] = slice(expected_num_frames)
        latents = latents[tuple(slicing)]
        mask = mask[tuple(slicing)]
    else:
        pad_size = expected_num_frames - num_frames
        pad_shape = list(latents.shape)
        pad_shape[frame_dim] = pad_size
        padding = latents.new_zeros(pad_shape)
        latents = torch.cat([latents, padding], dim=frame_dim)
        # Padded frames are masked out as well
        mask = torch.cat([mask, padding], dim=frame_dim)

    if concatenate_mask:
        # Expose the conditioning mask to the model alongside the latents
        latents = torch.cat([latents, mask], dim=channel_dim)

    return latents