|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import importlib |
|
import json |
|
import os |
|
from contextlib import contextmanager |
|
from typing import Any, Dict, List, NamedTuple, Optional, Tuple |
|
|
|
import cv2 |
|
import einops |
|
import imageio |
|
import numpy as np |
|
import torch |
|
import torchvision.transforms.functional as transforms_F |
|
from einops import rearrange |
|
|
|
from cosmos_transfer1.auxiliary.guardrail.common.io_utils import save_video |
|
from cosmos_transfer1.checkpoints import ( |
|
DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
EDGE2WORLD_CONTROLNET_DISTILLED_CHECKPOINT_PATH, |
|
HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH, |
|
VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
SV2MV_t2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
SV2MV_v2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
) |
|
from cosmos_transfer1.diffusion.config.transfer.augmentors import BilateralOnlyBlurAugmentorConfig |
|
from cosmos_transfer1.diffusion.datasets.augmentors.control_input import get_augmentor_for_eval |
|
from cosmos_transfer1.diffusion.model.model_t2w import DiffusionT2WModel |
|
from cosmos_transfer1.diffusion.model.model_v2w import DiffusionV2WModel |
|
from cosmos_transfer1.diffusion.model.model_v2w_multiview import DiffusionV2WMultiviewModel |
|
from cosmos_transfer1.utils import log |
|
from cosmos_transfer1.utils.config_helper import get_config_module, override |
|
from cosmos_transfer1.utils.io import load_from_fileobj |
|
|
|
TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2]) |
|
if TORCH_VERSION >= (1, 11): |
|
from torch.ao import quantization |
|
from torch.ao.quantization import FakeQuantizeBase, ObserverBase |
|
elif ( |
|
TORCH_VERSION >= (1, 8) |
|
and hasattr(torch.quantization, "FakeQuantizeBase") |
|
and hasattr(torch.quantization, "ObserverBase") |
|
): |
|
from torch import quantization |
|
from torch.quantization import FakeQuantizeBase, ObserverBase |
|
|
|
DEFAULT_AUGMENT_SIGMA = 0.001 |
|
NUM_MAX_FRAMES = 5000 |
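# Maps the aspect-ratio key returned by detect_aspect_ratio to the default (width, height) generation resolution.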
|
VIDEO_RES_SIZE_INFO = { |
|
"1,1": (960, 960), |
|
"4,3": (960, 704), |
|
"3,4": (704, 960), |
|
"16,9": (1280, 704), |
|
"9,16": (704, 1280), |
|
} |
|
|
|
|
|
default_model_names = { |
|
"vis": VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
"seg": SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
"edge": EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
"depth": DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
"keypoint": KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
"upscale": UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH, |
|
"hdmap": HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
"lidar": LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
} |
|
|
|
default_distilled_model_names = { |
|
"edge": EDGE2WORLD_CONTROLNET_DISTILLED_CHECKPOINT_PATH, |
|
} |
|
|
|
|
|
class _IncompatibleKeys( |
|
NamedTuple( |
|
"IncompatibleKeys", |
|
[ |
|
("missing_keys", List[str]), |
|
("unexpected_keys", List[str]), |
|
("incorrect_shapes", List[Tuple[str, Tuple[int], Tuple[int]]]), |
|
], |
|
) |
|
): |
|
pass |
|
|
|
|
|
def non_strict_load_model(model: torch.nn.Module, checkpoint_state_dict: dict) -> _IncompatibleKeys: |
|
"""Load a model checkpoint with non-strict matching, handling shape mismatches. |
|
|
|
Args: |
|
model (torch.nn.Module): Model to load weights into |
|
checkpoint_state_dict (dict): State dict from checkpoint |
|
|
|
Returns: |
|
_IncompatibleKeys: Named tuple containing: |
|
- missing_keys: Keys present in model but missing from checkpoint |
|
- unexpected_keys: Keys present in checkpoint but not in model |
|
- incorrect_shapes: Keys with mismatched tensor shapes |
|
|
|
The function handles special cases like: |
|
- Uninitialized parameters |
|
- Quantization observers |
|
- TransformerEngine FP8 states |
|
""" |
|
|
|
model_state_dict = model.state_dict() |
|
incorrect_shapes = [] |
|
for k in list(checkpoint_state_dict.keys()): |
|
if k in model_state_dict: |
|
if "_extra_state" in k: |
|
log.debug(f"Skipping key {k} introduced by TransformerEngine for FP8 in the checkpoint.") |
|
continue |
|
model_param = model_state_dict[k] |
|
|
|
if TORCH_VERSION >= (1, 8) and isinstance(model_param, torch.nn.parameter.UninitializedParameter): |
|
continue |
|
if not isinstance(model_param, torch.Tensor): |
|
raise ValueError( |
|
f"Find non-tensor parameter {k} in the model. type: {type(model_param)} {type(checkpoint_state_dict[k])}, please check if this key is safe to skip or not." |
|
) |
|
|
|
shape_model = tuple(model_param.shape) |
|
shape_checkpoint = tuple(checkpoint_state_dict[k].shape) |
|
if shape_model != shape_checkpoint: |
|
has_observer_base_classes = ( |
|
TORCH_VERSION >= (1, 8) |
|
and hasattr(quantization, "ObserverBase") |
|
and hasattr(quantization, "FakeQuantizeBase") |
|
) |
|
if has_observer_base_classes: |
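                    # Quantization observer / fake-quantize modules can legitimately have buffers whose
                    # shapes differ from the checkpoint; locate the owning module so such mismatches can be skipped.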
|
|
|
|
|
def _get_module_for_key(model: torch.nn.Module, key: str) -> torch.nn.Module: |
|
|
|
key_parts = key.split(".")[:-1] |
|
cur_module = model |
|
for key_part in key_parts: |
|
cur_module = getattr(cur_module, key_part) |
|
return cur_module |
|
|
|
cls_to_skip = ( |
|
ObserverBase, |
|
FakeQuantizeBase, |
|
) |
|
target_module = _get_module_for_key(model, k) |
|
if isinstance(target_module, cls_to_skip): |
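                        # Expected shape mismatch on an observer/fake-quantize buffer; skip it silently.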
|
|
|
|
|
|
|
continue |
|
|
|
incorrect_shapes.append((k, shape_checkpoint, shape_model)) |
|
checkpoint_state_dict.pop(k) |
|
incompatible = model.load_state_dict(checkpoint_state_dict, strict=False) |
|
|
|
missing_keys = [k for k in incompatible.missing_keys if "_extra_state" not in k] |
|
unexpected_keys = [k for k in incompatible.unexpected_keys if "_extra_state" not in k] |
|
return _IncompatibleKeys( |
|
missing_keys=missing_keys, |
|
unexpected_keys=unexpected_keys, |
|
incorrect_shapes=incorrect_shapes, |
|
) |
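
# Illustrative sketch (not part of the original API): typical use of non_strict_load_model when a
# checkpoint's head shapes differ from the instantiated network. The checkpoint path is hypothetical.
#
#   state_dict = torch.load("checkpoints/example_net.pt", map_location="cpu", weights_only=False)
#   result = non_strict_load_model(model, state_dict)
#   if result.missing_keys or result.incorrect_shapes:
#       log.warning(f"missing={result.missing_keys}, shape mismatches={result.incorrect_shapes}")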
|
|
|
|
|
@contextmanager |
|
def skip_init_linear(): |
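    """Temporarily disable torch.nn.Linear.reset_parameters and xavier_uniform_ initialization.

    Used to skip redundant weight initialization when checkpoint weights are loaded right after construction.
    """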
|
|
|
orig_reset_parameters = torch.nn.Linear.reset_parameters |
|
torch.nn.Linear.reset_parameters = lambda x: x |
|
xavier_uniform_ = torch.nn.init.xavier_uniform_ |
|
torch.nn.init.xavier_uniform_ = lambda x: x |
|
yield |
|
torch.nn.Linear.reset_parameters = orig_reset_parameters |
|
torch.nn.init.xavier_uniform_ = xavier_uniform_ |
|
|
|
|
|
def load_model_by_config( |
|
config_job_name, |
|
config_file="projects/cosmos_video/config/config.py", |
|
model_class=DiffusionT2WModel, |
|
base_checkpoint_dir="", |
|
): |
|
config_module = get_config_module(config_file) |
|
config = importlib.import_module(config_module).make_config() |
|
config = override(config, ["--", f"experiment={config_job_name}"]) |
|
if base_checkpoint_dir != "" and hasattr(config.model, "base_load_from"): |
|
if hasattr(config.model.base_load_from, "load_path"): |
|
if config.model.base_load_from.load_path != "": |
|
config.model.base_load_from.load_path = config.model.base_load_from.load_path.replace( |
|
"checkpoints", base_checkpoint_dir |
|
) |
|
log.info( |
|
f"Model need to load a base model weight, change the loading path from default folder to the {base_checkpoint_dir}" |
|
) |
|
|
|
|
|
config.validate() |
|
|
|
config.freeze() |
|
|
|
|
|
with skip_init_linear(): |
|
model = model_class(config.model) |
|
return model |
|
|
|
|
|
def load_network_model(model: DiffusionT2WModel, ckpt_path: str): |
|
if ckpt_path: |
|
with skip_init_linear(): |
|
model.set_up_model() |
|
net_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False) |
|
non_strict_load_model(model.model, net_state_dict) |
|
else: |
|
model.set_up_model() |
|
model.cuda() |
|
|
|
|
|
def load_tokenizer_model(model: DiffusionT2WModel, tokenizer_dir: str): |
|
with skip_init_linear(): |
|
model.set_up_tokenizer(tokenizer_dir) |
|
model.cuda() |
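
# Illustrative end-to-end loading sketch (hedged): the experiment name and checkpoint paths below are
# hypothetical placeholders, not shipped assets.
#
#   model = load_model_by_config(
#       config_job_name="my_registered_experiment",
#       model_class=DiffusionV2WModel,
#       base_checkpoint_dir="checkpoints",
#   )
#   load_network_model(model, "checkpoints/example_net.pt")
#   load_tokenizer_model(model, "checkpoints/example_tokenizer_dir")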
|
|
|
|
|
def prepare_data_batch( |
|
height: int, |
|
width: int, |
|
num_frames: int, |
|
fps: int, |
|
prompt_embedding: torch.Tensor, |
|
negative_prompt_embedding: Optional[torch.Tensor] = None, |
|
): |
|
"""Prepare input batch tensors for video generation. |
|
|
|
Args: |
|
height (int): Height of video frames |
|
width (int): Width of video frames |
|
num_frames (int): Number of frames to generate |
|
fps (int): Frames per second |
|
prompt_embedding (torch.Tensor): Encoded text prompt embeddings |
|
negative_prompt_embedding (torch.Tensor, optional): Encoded negative prompt embeddings |
|
|
|
Returns: |
|
dict: Batch dictionary containing: |
|
- video: Zero tensor of target video shape |
|
- t5_text_mask: Attention mask for text embeddings |
|
- image_size: Target frame dimensions |
|
- fps: Target frame rate |
|
- num_frames: Number of frames |
|
- padding_mask: Frame padding mask |
|
- t5_text_embeddings: Prompt embeddings |
|
- neg_t5_text_embeddings: Negative prompt embeddings (if provided) |
|
- neg_t5_text_mask: Mask for negative embeddings (if provided) |
|
""" |
|
|
|
data_batch = { |
|
"video": torch.zeros((1, 3, num_frames, height, width), dtype=torch.uint8).cuda(), |
|
"t5_text_mask": torch.ones(1, 512, dtype=torch.bfloat16).cuda(), |
|
"image_size": torch.tensor([[height, width, height, width]] * 1, dtype=torch.bfloat16).cuda(), |
|
"fps": torch.tensor([fps] * 1, dtype=torch.bfloat16).cuda(), |
|
"num_frames": torch.tensor([num_frames] * 1, dtype=torch.bfloat16).cuda(), |
|
"padding_mask": torch.zeros((1, 1, height, width), dtype=torch.bfloat16).cuda(), |
|
} |
|
|
|
|
|
|
|
t5_embed = prompt_embedding.to(dtype=torch.bfloat16).cuda() |
|
data_batch["t5_text_embeddings"] = t5_embed |
|
|
|
if negative_prompt_embedding is not None: |
|
neg_t5_embed = negative_prompt_embedding.to(dtype=torch.bfloat16).cuda() |
|
data_batch["neg_t5_text_embeddings"] = neg_t5_embed |
|
data_batch["neg_t5_text_mask"] = torch.ones(1, 512, dtype=torch.bfloat16).cuda() |
|
|
|
return data_batch |
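
# Illustrative sketch (assumes text embeddings of shape [1, 512, 1024]; values are placeholders):
#
#   prompt_emb = torch.randn(1, 512, 1024)
#   batch = prepare_data_batch(height=704, width=1280, num_frames=121, fps=24, prompt_embedding=prompt_emb)
#   assert batch["video"].shape == (1, 3, 121, 704, 1280)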
|
|
|
|
|
def get_video_batch(model, prompt_embedding, negative_prompt_embedding, height, width, fps, num_video_frames): |
|
"""Prepare complete input batch for video generation including latent dimensions. |
|
|
|
Args: |
|
model: Diffusion model instance |
|
prompt_embedding (torch.Tensor): Text prompt embeddings |
|
negative_prompt_embedding (torch.Tensor): Negative prompt embeddings |
|
height (int): Output video height |
|
width (int): Output video width |
|
fps (int): Output video frame rate |
|
num_video_frames (int): Number of frames to generate |
|
|
|
Returns: |
|
tuple: |
|
- data_batch (dict): Complete model input batch |
|
- state_shape (list): Shape of latent state [C,T,H,W] accounting for VAE compression |
|
""" |
|
raw_video_batch = prepare_data_batch( |
|
height=height, |
|
width=width, |
|
num_frames=num_video_frames, |
|
fps=fps, |
|
prompt_embedding=prompt_embedding, |
|
negative_prompt_embedding=negative_prompt_embedding, |
|
) |
|
state_shape = [ |
|
model.tokenizer.channel, |
|
model.tokenizer.get_latent_num_frames(num_video_frames), |
|
height // model.tokenizer.spatial_compression_factor, |
|
width // model.tokenizer.spatial_compression_factor, |
|
] |
|
return raw_video_batch, state_shape |
|
|
|
|
|
def resize_video(video_np, h, w, interpolation=cv2.INTER_AREA): |
|
"""Resize video frames to the specified height and width.""" |
|
video_np = video_np[0].transpose((1, 2, 3, 0)) |
|
t = video_np.shape[0] |
|
resized_video = np.zeros((t, h, w, 3), dtype=np.uint8) |
|
for i in range(t): |
|
resized_video[i] = cv2.resize(video_np[i], (w, h), interpolation=interpolation) |
|
return resized_video.transpose((3, 0, 1, 2))[None] |
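
# Illustrative sketch: resize_video takes a [B, C, T, H, W] uint8 array and returns the same layout.
#
#   video = np.zeros((1, 3, 8, 360, 640), dtype=np.uint8)
#   assert resize_video(video, 704, 1280).shape == (1, 3, 8, 704, 1280)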
|
|
|
|
|
def detect_aspect_ratio(img_size: tuple[int, int]):
    """Detect the closest supported aspect ratio for a (width, height) image size."""
|
|
|
_aspect_ratios = np.array([(16 / 9), (4 / 3), 1, (3 / 4), (9 / 16)]) |
|
_aspect_ratio_keys = ["16,9", "4,3", "1,1", "3,4", "9,16"] |
|
w, h = img_size |
|
current_ratio = w / h |
|
closest_aspect_ratio = np.argmin((_aspect_ratios - current_ratio) ** 2) |
|
return _aspect_ratio_keys[closest_aspect_ratio] |
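
# Worked example: a 1280x704 frame has ratio ~1.82, closest to 16/9, so detect_aspect_ratio((1280, 704))
# returns "16,9", which maps to (1280, 704) in VIDEO_RES_SIZE_INFO.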
|
|
|
|
|
def get_upscale_size(orig_size: tuple[int, int], aspect_ratio: str, upscale_factor: int = 3, patch_overlap: int = 256):
|
patch_w, patch_h = orig_size |
|
if aspect_ratio == "16,9" or aspect_ratio == "4,3": |
|
ratio = int(aspect_ratio.split(",")[1]) / int(aspect_ratio.split(",")[0]) |
|
target_w = patch_w * upscale_factor - patch_overlap |
|
target_h = patch_h * upscale_factor - int(patch_overlap * ratio) |
|
elif aspect_ratio == "9,16" or aspect_ratio == "3,4": |
|
ratio = int(aspect_ratio.split(",")[0]) / int(aspect_ratio.split(",")[1]) |
|
target_h = patch_h * upscale_factor - patch_overlap |
|
target_w = patch_w * upscale_factor - int(patch_overlap * ratio) |
|
else: |
|
target_h = patch_h * upscale_factor - patch_overlap |
|
target_w = patch_w * upscale_factor - patch_overlap |
|
return target_w, target_h |
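
# Worked example (3x upscale of a 1280x704 "16,9" patch):
#   target_w = 1280 * 3 - 256 = 3584
#   target_h = 704 * 3 - int(256 * 9 / 16) = 2112 - 144 = 1968
#   so get_upscale_size((1280, 704), "16,9") == (3584, 1968)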
|
|
|
|
|
def read_and_resize_input(input_control_path, num_total_frames, interpolation): |
|
control_input, fps = read_video_or_image_into_frames_BCTHW( |
|
input_control_path, |
|
normalize=False, |
|
max_frames=num_total_frames, |
|
also_return_fps=True, |
|
) |
|
aspect_ratio = detect_aspect_ratio((control_input.shape[-1], control_input.shape[-2])) |
|
w, h = VIDEO_RES_SIZE_INFO[aspect_ratio] |
|
control_input = resize_video(control_input, h, w, interpolation=interpolation) |
|
control_input = torch.from_numpy(control_input[0]) |
|
return control_input, fps, aspect_ratio |
|
|
|
|
|
def get_video_batch_for_multiview_model( |
|
model, prompt_embedding, height, width, fps, num_video_frames, frame_repeat_negative_condition |
|
): |
|
"""Prepare complete input batch for video generation including latent dimensions. |
|
|
|
Args: |
|
model: Diffusion model instance |
|
        prompt_embedding (list[torch.Tensor]): Per-view text prompt embeddings
|
height (int): Output video height |
|
width (int): Output video width |
|
fps (int): Output video frame rate |
|
num_video_frames (int): Number of frames to generate |
|
        frame_repeat_negative_condition (int): Frame-repeat value applied to the last two views for negative conditioning (-1 disables it)
|
|
|
Returns: |
|
tuple: |
|
- data_batch (dict): Complete model input batch |
|
- state_shape (list): Shape of latent state [C,T,H,W] accounting for VAE compression |
|
""" |
|
n_views = len(prompt_embedding) |
|
prompt_embedding = einops.rearrange(prompt_embedding, "n t d -> (n t) d").unsqueeze(0) |
|
raw_video_batch = prepare_data_batch( |
|
height=height, |
|
width=width, |
|
num_frames=num_video_frames, |
|
fps=fps, |
|
prompt_embedding=prompt_embedding, |
|
) |
|
if frame_repeat_negative_condition != -1: |
|
frame_repeat = torch.zeros(n_views) |
|
frame_repeat[-1] = frame_repeat_negative_condition |
|
frame_repeat[-2] = frame_repeat_negative_condition |
|
raw_video_batch["frame_repeat"] = frame_repeat.unsqueeze(0).to(dtype=torch.bfloat16).cuda() |
|
state_shape = [ |
|
model.tokenizer.channel, |
|
model.tokenizer.get_latent_num_frames(int(num_video_frames / n_views)) * n_views, |
|
height // model.tokenizer.spatial_compression_factor, |
|
width // model.tokenizer.spatial_compression_factor, |
|
] |
|
return raw_video_batch, state_shape |
|
|
|
|
|
def get_ctrl_batch_mv(H, W, data_batch, num_total_frames, control_inputs, num_views, num_video_frames): |
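    """Prepare the control-conditioned input batch for the multiview models."""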
|
|
|
control_input_dict = {k: v for k, v in data_batch.items()} |
|
control_weights = [] |
|
hint_keys = [] |
|
for hint_key, control_info in control_inputs.items(): |
|
if hint_key not in valid_hint_keys: |
|
continue |
|
if "input_control" in control_info: |
|
cond_videos = [] |
|
for in_file in control_info["input_control"]: |
|
log.info(f"reading control input {in_file} for hint {hint_key}") |
|
cond_vid, fps = read_video_or_image_into_frames_BCTHW( |
|
in_file, |
|
normalize=False, |
|
max_frames=num_total_frames, |
|
also_return_fps=True, |
|
) |
|
cond_vid = resize_video(cond_vid, H, W, interpolation=cv2.INTER_LINEAR) |
|
cond_vid = torch.from_numpy(cond_vid[0]) |
|
|
|
cond_videos.append(cond_vid) |
|
|
|
input_frames = torch.cat(cond_videos, dim=1) |
|
control_input_dict[f"control_input_{hint_key}"] = input_frames |
|
hint_keys.append(hint_key) |
|
control_weights.append(control_info["control_weight"]) |
|
|
|
target_w, target_h = W, H |
|
hint_key = "control_input_" + "_".join(hint_keys) |
|
add_control_input = get_augmentor_for_eval(input_key="video", output_key=hint_key) |
|
|
|
if len(control_input_dict): |
|
control_input = add_control_input(control_input_dict)[hint_key] |
|
if control_input.ndim == 4: |
|
control_input = control_input[None] |
|
control_input = control_input.bfloat16() / 255 * 2 - 1 |
|
control_weights = load_spatial_temporal_weights( |
|
control_weights, B=1, T=num_total_frames, H=target_h, W=target_w, patch_h=H, patch_w=W |
|
) |
|
data_batch["control_weight"] = control_weights |
|
|
|
if len(control_inputs) > 1: |
|
data_batch["hint_key"] = "control_input_multi" |
|
data_batch["control_input_multi"] = control_input |
|
else: |
|
data_batch["hint_key"] = hint_key |
|
data_batch[hint_key] = control_input |
|
|
|
data_batch["target_h"], data_batch["target_w"] = target_h // 8, target_w // 8 |
|
data_batch["video"] = torch.zeros((1, 3, 57, H, W), dtype=torch.uint8).cuda() |
|
data_batch["image_size"] = torch.tensor([[H, W, H, W]] * 1, dtype=torch.bfloat16).cuda() |
|
data_batch["padding_mask"] = torch.zeros((1, 1, H, W), dtype=torch.bfloat16).cuda() |
|
|
|
|
|
if num_views == 5: |
|
mapped_indices = [0, 1, 2, 4, 5] |
|
view_indices_conditioning = [] |
|
for v_index in mapped_indices: |
|
view_indices_conditioning.append(torch.ones(num_video_frames, device="cuda") * v_index) |
|
view_indices_conditioning = torch.cat(view_indices_conditioning, dim=0) |
|
data_batch["view_indices"] = view_indices_conditioning.unsqueeze(0).contiguous() |
|
|
|
return data_batch |
|
|
|
|
|
def get_batched_ctrl_batch( |
|
model, |
|
prompt_embeddings, |
|
negative_prompt_embeddings, |
|
height, |
|
width, |
|
fps, |
|
num_video_frames, |
|
input_video_paths, |
|
control_inputs_list, |
|
blur_strength, |
|
canny_threshold, |
|
): |
|
""" |
|
Create a fully batched data_batch for video generation, including all control and video inputs. |
|
|
|
Args: |
|
model: The diffusion model instance. |
|
prompt_embeddings: [B, ...] tensor of prompt embeddings. |
|
negative_prompt_embeddings: [B, ...] tensor of negative prompt embeddings or None. |
|
height, width, fps, num_video_frames: Video parameters. |
|
input_video_paths: List of input video paths, length B. |
|
control_inputs_list: List of control input dicts, length B. |
|
blur_strength, canny_threshold: ControlNet augmentation parameters. |
|
|
|
Returns: |
|
data_batch: Dict with all fields batched along dim 0 (batch dimension). |
|
state_shape: List describing the latent state shape. |
|
""" |
|
B = len(input_video_paths) |
|
|
|
def prepare_single_data_batch(b): |
|
data_batch = { |
|
"video": torch.zeros((1, 3, num_video_frames, height, width), dtype=torch.uint8).cuda(), |
|
"t5_text_mask": torch.ones(1, 512, dtype=torch.bfloat16).cuda(), |
|
"image_size": torch.tensor([[height, width, height, width]], dtype=torch.bfloat16).cuda(), |
|
"fps": torch.tensor([fps], dtype=torch.bfloat16).cuda(), |
|
"num_frames": torch.tensor([num_video_frames], dtype=torch.bfloat16).cuda(), |
|
"padding_mask": torch.zeros((1, 1, height, width), dtype=torch.bfloat16).cuda(), |
|
"t5_text_embeddings": prompt_embeddings[b : b + 1].to(dtype=torch.bfloat16).cuda(), |
|
} |
|
if negative_prompt_embeddings is not None: |
|
data_batch["neg_t5_text_embeddings"] = negative_prompt_embeddings[b : b + 1].to(dtype=torch.bfloat16).cuda() |
|
data_batch["neg_t5_text_mask"] = torch.ones(1, 512, dtype=torch.bfloat16).cuda() |
|
return data_batch |
|
|
|
|
|
single_batches = [] |
|
for b in range(B): |
|
single_data_batch = prepare_single_data_batch(b) |
|
processed = get_ctrl_batch( |
|
model, |
|
single_data_batch, |
|
num_video_frames, |
|
input_video_paths[b], |
|
control_inputs_list[b], |
|
blur_strength, |
|
canny_threshold, |
|
) |
|
single_batches.append(processed) |
|
|
|
|
|
batched_data_batch = {} |
|
for k in single_batches[0]: |
|
if isinstance(single_batches[0][k], torch.Tensor): |
|
if k == "control_weight" and single_batches[0][k].ndim == 6: |
|
|
|
|
|
batched_data_batch[k] = torch.cat([d[k] for d in single_batches], dim=1) |
|
else: |
|
|
|
batched_data_batch[k] = torch.cat([d[k] for d in single_batches], dim=0) |
|
else: |
|
batched_data_batch[k] = single_batches[0][k] |
|
|
|
state_shape = [ |
|
model.tokenizer.channel, |
|
model.tokenizer.get_latent_num_frames(num_video_frames), |
|
height // model.tokenizer.spatial_compression_factor, |
|
width // model.tokenizer.spatial_compression_factor, |
|
] |
|
|
|
return batched_data_batch, state_shape |
|
|
|
|
|
def get_ctrl_batch( |
|
model, data_batch, num_video_frames, input_video_path, control_inputs, blur_strength, canny_threshold |
|
): |
|
"""Prepare complete input batch for video generation including latent dimensions. |
|
|
|
Args: |
|
model: Diffusion model instance |
|
|
|
Returns: |
|
- data_batch (dict): Complete model input batch |
|
""" |
|
state_shape = model.state_shape |
|
|
|
H, W = ( |
|
state_shape[-2] * model.tokenizer.spatial_compression_factor, |
|
state_shape[-1] * model.tokenizer.spatial_compression_factor, |
|
) |
|
|
|
|
|
control_input_dict = {k: v for k, v in data_batch.items()} |
|
num_total_frames = NUM_MAX_FRAMES |
|
if input_video_path: |
|
input_frames, fps, aspect_ratio = read_and_resize_input( |
|
input_video_path, num_total_frames=num_total_frames, interpolation=cv2.INTER_AREA |
|
) |
|
_, num_total_frames, H, W = input_frames.shape |
|
control_input_dict["video"] = input_frames.numpy() |
|
data_batch["input_video"] = input_frames.bfloat16()[None] / 255 * 2 - 1 |
|
else: |
|
data_batch["input_video"] = None |
|
target_w, target_h = W, H |
|
|
|
control_weights = [] |
|
for hint_key, control_info in control_inputs.items(): |
|
if "input_control" in control_info: |
|
in_file = control_info["input_control"] |
|
interpolation = cv2.INTER_NEAREST if hint_key == "seg" else cv2.INTER_LINEAR |
|
log.info(f"reading control input {in_file} for hint {hint_key}") |
|
control_input_dict[f"control_input_{hint_key}"], fps, aspect_ratio = read_and_resize_input( |
|
in_file, num_total_frames=num_total_frames, interpolation=interpolation |
|
) |
|
num_total_frames = min(num_total_frames, control_input_dict[f"control_input_{hint_key}"].shape[1]) |
|
target_h, target_w = H, W = control_input_dict[f"control_input_{hint_key}"].shape[2:] |
|
if hint_key == "upscale": |
|
orig_size = (W, H) |
|
target_w, target_h = get_upscale_size(orig_size, aspect_ratio, upscale_factor=3) |
|
input_resized = resize_video( |
|
input_frames[None].numpy(), |
|
target_h, |
|
target_w, |
|
interpolation=cv2.INTER_LINEAR, |
|
) |
|
control_input_dict["control_input_upscale"] = split_video_into_patches( |
|
torch.from_numpy(input_resized), H, W |
|
) |
|
data_batch["input_video"] = control_input_dict["control_input_upscale"].bfloat16() / 255 * 2 - 1 |
|
control_weights.append(control_info["control_weight"]) |
|
|
|
|
|
log.info(f"Making all control and input videos to be length of {num_total_frames} frames.") |
|
if len(control_inputs) > 1: |
|
for hint_key in control_inputs.keys(): |
|
cur_key = f"control_input_{hint_key}" |
|
if cur_key in control_input_dict: |
|
control_input_dict[cur_key] = control_input_dict[cur_key][:, :num_total_frames] |
|
if input_video_path: |
|
control_input_dict["video"] = control_input_dict["video"][:, :num_total_frames] |
|
data_batch["input_video"] = data_batch["input_video"][:, :, :num_total_frames] |
|
|
|
hint_key = "control_input_" + "_".join(control_inputs.keys()) |
|
add_control_input = get_augmentor_for_eval( |
|
input_key="video", |
|
output_key=hint_key, |
|
preset_blur_strength=blur_strength, |
|
preset_canny_threshold=canny_threshold, |
|
blur_config=BilateralOnlyBlurAugmentorConfig[blur_strength], |
|
) |
|
|
|
if len(control_input_dict): |
|
control_input = add_control_input(control_input_dict)[hint_key] |
|
if control_input.ndim == 4: |
|
control_input = control_input[None] |
|
control_input = control_input.bfloat16() / 255 * 2 - 1 |
|
control_weights = load_spatial_temporal_weights( |
|
control_weights, B=1, T=num_video_frames, H=target_h, W=target_w, patch_h=H, patch_w=W |
|
) |
|
data_batch["control_weight"] = control_weights |
|
|
|
if len(control_inputs) > 1: |
|
data_batch["hint_key"] = "control_input_multi" |
|
data_batch["control_input_multi"] = control_input |
|
else: |
|
data_batch["hint_key"] = hint_key |
|
data_batch[hint_key] = control_input |
|
|
|
data_batch["target_h"], data_batch["target_w"] = target_h // 8, target_w // 8 |
|
data_batch["video"] = torch.zeros((1, 3, 121, H, W), dtype=torch.uint8).cuda() |
|
data_batch["image_size"] = torch.tensor([[H, W, H, W]] * 1, dtype=torch.bfloat16).cuda() |
|
data_batch["padding_mask"] = torch.zeros((1, 1, H, W), dtype=torch.bfloat16).cuda() |
|
|
|
return data_batch |
|
|
|
|
|
def generate_control_input(input_file_path, save_folder, hint_key, blur_strength, canny_threshold, num_total_frames=10): |
|
log.info( |
|
f"Generating control input for {hint_key} with blur strength {blur_strength} and canny threshold {canny_threshold}" |
|
) |
|
video_input = read_video_or_image_into_frames_BCTHW(input_file_path, normalize=False)[0, :, :num_total_frames] |
|
control_input = get_augmentor_for_eval( |
|
input_key="video", |
|
output_key=hint_key, |
|
preset_blur_strength=blur_strength, |
|
preset_canny_threshold=canny_threshold, |
|
blur_config=BilateralOnlyBlurAugmentorConfig[blur_strength], |
|
) |
|
control_input = control_input({"video": video_input})[hint_key] |
|
control_input = control_input.numpy().transpose((1, 2, 3, 0)) |
|
|
|
output_file_path = f"{save_folder}/{hint_key}_upsampler.mp4" |
|
log.info(f"Saving control input to {output_file_path}") |
|
save_video(frames=control_input, fps=24, filepath=output_file_path) |
|
return output_file_path |
|
|
|
|
|
def generate_world_from_control( |
|
model: DiffusionV2WModel, |
|
state_shape: list[int], |
|
is_negative_prompt: bool, |
|
data_batch: dict, |
|
guidance: float, |
|
num_steps: int, |
|
seed: int, |
|
condition_latent: torch.Tensor, |
|
num_input_frames: int, |
|
sigma_max: float, |
|
x_sigma_max=None, |
|
augment_sigma=None, |
|
use_batch_processing: bool = True, |
|
chunking: Optional[int] = None, |
|
) -> torch.Tensor:
|
"""Generate video using a conditioning video/image input. |
|
|
|
Args: |
|
model (DiffusionV2WModel): The diffusion model instance |
|
state_shape (list[int]): Shape of the latent state [C,T,H,W] |
|
is_negative_prompt (bool): Whether negative prompt is provided |
|
data_batch (dict): Batch containing model inputs including text embeddings |
|
guidance (float): Classifier-free guidance scale for sampling |
|
num_steps (int): Number of diffusion sampling steps |
|
seed (int): Random seed for generation |
|
condition_latent (torch.Tensor): Latent tensor from conditioning video/image file |
|
num_input_frames (int): Number of input frames |
|
chunking: Chunking size, if None, chunking is disabled |
|
|
|
Returns: |
|
        torch.Tensor: Generated latent samples; decode with the model tokenizer to obtain video frames
|
""" |
|
assert not model.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i, "not supported" |
|
|
|
if augment_sigma is None: |
|
augment_sigma = DEFAULT_AUGMENT_SIGMA |
|
|
|
b, c, t, h, w = condition_latent.shape |
|
if condition_latent.shape[2] < state_shape[1]: |
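        # Zero-pad the condition latent along the temporal axis to match the target state shape.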
|
|
|
condition_latent = torch.cat( |
|
[ |
|
condition_latent, |
|
condition_latent.new_zeros(b, c, state_shape[1] - t, h, w), |
|
], |
|
dim=2, |
|
).contiguous() |
|
num_of_latent_condition = compute_num_latent_frames(model, num_input_frames) |
|
|
|
sample = model.generate_samples_from_batch( |
|
data_batch, |
|
guidance=guidance, |
|
state_shape=[c, t, h, w], |
|
num_steps=num_steps, |
|
is_negative_prompt=is_negative_prompt, |
|
seed=seed, |
|
condition_latent=condition_latent, |
|
num_condition_t=num_of_latent_condition, |
|
condition_video_augment_sigma_in_inference=augment_sigma, |
|
x_sigma_max=x_sigma_max, |
|
sigma_max=sigma_max, |
|
target_h=data_batch["target_h"], |
|
target_w=data_batch["target_w"], |
|
patch_h=h, |
|
patch_w=w, |
|
use_batch_processing=use_batch_processing, |
|
chunking=chunking, |
|
) |
|
return sample |
|
|
|
|
|
def read_video_or_image_into_frames_BCTHW( |
|
input_path: str, |
|
input_path_format: str = "mp4", |
|
H: int = None, |
|
W: int = None, |
|
normalize: bool = True, |
|
max_frames: int = -1, |
|
also_return_fps: bool = False, |
|
) -> torch.Tensor: |
|
"""Read video or image file and convert to tensor format. |
|
|
|
Args: |
|
input_path (str): Path to input video/image file |
|
input_path_format (str): Format of input file (default: "mp4") |
|
H (int, optional): Height to resize frames to |
|
W (int, optional): Width to resize frames to |
|
normalize (bool): Whether to normalize pixel values to [-1,1] (default: True) |
|
max_frames (int): Maximum number of frames to read (-1 for all frames) |
|
also_return_fps (bool): Whether to return fps along with frames |
|
|
|
Returns: |
|
torch.Tensor | tuple: Video tensor in shape [B,C,T,H,W], optionally with fps if requested |
|
""" |
|
log.debug(f"Reading video from {input_path}") |
|
|
|
loaded_data = load_from_fileobj(input_path, format=input_path_format) |
|
frames, meta_data = loaded_data |
|
if input_path.endswith(".png") or input_path.endswith(".jpg") or input_path.endswith(".jpeg"): |
|
frames = np.array(frames[0]) |
|
if frames.shape[-1] > 3: |
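            # RGBA input: composite onto a white background using the alpha channel.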
|
|
|
rgb_channels = frames[..., :3] |
|
alpha_channel = frames[..., 3] / 255.0 |
|
|
|
|
|
white_bg = np.ones_like(rgb_channels) * 255 |
|
|
|
|
|
frames = (rgb_channels * alpha_channel[..., None] + white_bg * (1 - alpha_channel[..., None])).astype( |
|
np.uint8 |
|
) |
|
frames = [frames] |
|
fps = 0 |
|
else: |
|
fps = int(meta_data.get("fps")) |
|
if max_frames != -1: |
|
frames = frames[:max_frames] |
|
input_tensor = np.stack(frames, axis=0) |
|
input_tensor = einops.rearrange(input_tensor, "t h w c -> t c h w") |
|
if normalize: |
|
input_tensor = input_tensor / 128.0 - 1.0 |
|
input_tensor = torch.from_numpy(input_tensor).bfloat16() |
|
log.debug(f"Raw data shape: {input_tensor.shape}") |
|
if H is not None and W is not None: |
|
input_tensor = transforms_F.resize( |
|
input_tensor, |
|
size=(H, W), |
|
interpolation=transforms_F.InterpolationMode.BICUBIC, |
|
antialias=True, |
|
) |
|
input_tensor = einops.rearrange(input_tensor, "(b t) c h w -> b c t h w", b=1) |
|
if normalize: |
|
input_tensor = input_tensor.to("cuda") |
|
log.debug(f"Load shape {input_tensor.shape} value {input_tensor.min()}, {input_tensor.max()}") |
|
if also_return_fps: |
|
return input_tensor, fps |
|
return input_tensor |
|
|
|
|
|
def compute_num_latent_frames(model: DiffusionV2WModel, num_input_frames: int, downsample_factor=8) -> int: |
|
"""This function computes the number of latent frames given the number of input frames. |
|
Args: |
|
model (DiffusionV2WModel): video generation model |
|
num_input_frames (int): number of input frames |
|
downsample_factor (int): downsample factor for temporal reduce |
|
Returns: |
|
int: number of latent frames |
|
""" |
|
|
|
num_latent_frames = ( |
|
num_input_frames |
|
// model.tokenizer.video_vae.pixel_chunk_duration |
|
* model.tokenizer.video_vae.latent_chunk_duration |
|
) |
|
|
|
if num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration == 1: |
|
num_latent_frames += 1 |
|
elif num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration > 1: |
|
assert ( |
|
num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration - 1 |
|
        ) % downsample_factor == 0, f"(num_input_frames % pixel_chunk_duration - 1) must be divisible by {downsample_factor}"
|
num_latent_frames += ( |
|
1 + (num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration - 1) // downsample_factor |
|
) |
|
|
|
return num_latent_frames |
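
# Worked example (assuming pixel_chunk_duration=121 and latent_chunk_duration=16, i.e. an 8x
# temporally-compressing causal video VAE):
#   num_input_frames=121 -> 121 // 121 * 16 = 16 latent frames
#   num_input_frames=1   -> 0 + 1 = 1 latent frame
#   num_input_frames=9   -> 0 + 1 + (9 - 1) // 8 = 2 latent frames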
|
|
|
|
|
def create_condition_latent_from_input_frames( |
|
model: DiffusionV2WModel, |
|
input_frames: torch.Tensor, |
|
num_frames_condition: int = 25, |
|
from_back: bool = True, |
|
): |
|
"""Create condition latent for video generation from input frames. |
|
|
|
Takes the last num_frames_condition frames from input as conditioning. |
|
|
|
Args: |
|
model (DiffusionV2WModel): Video generation model |
|
input_frames (torch.Tensor): Input video tensor [B,C,T,H,W], range [-1,1] |
|
num_frames_condition (int): Number of frames to use for conditioning |
|
|
|
Returns: |
|
tuple: (condition_latent, encode_input_frames) where: |
|
- condition_latent (torch.Tensor): Encoded latent condition [B,C,T,H,W] |
|
- encode_input_frames (torch.Tensor): Padded input frames used for encoding |
|
""" |
|
B, C, T, H, W = input_frames.shape |
|
    num_frames_encode = model.tokenizer.pixel_chunk_duration
    log.debug(
        f"num_frames_encode derived from the tokenizer pixel chunk duration: {num_frames_encode}"
    )
|
|
|
log.debug( |
|
f"Create condition latent from input frames {input_frames.shape}, value {input_frames.min()}, {input_frames.max()}, dtype {input_frames.dtype}" |
|
) |
|
|
|
    assert (
        input_frames.shape[2] >= num_frames_condition
    ), f"Not enough input frames for conditioning: require at least {num_frames_condition}, got {input_frames.shape[2]} (shape {input_frames.shape})"
    assert (
        num_frames_encode >= num_frames_condition
    ), f"num_frames_encode must be at least num_frames_condition, got {num_frames_encode} and {num_frames_condition}"
|
|
|
|
|
if model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1": |
|
condition_frames_first = input_frames[:, :, :num_frames_condition] |
|
condition_frames_last = input_frames[:, :, -num_frames_condition:] |
|
padding_frames = condition_frames_first.new_zeros(B, C, num_frames_encode + 1 - 2 * num_frames_condition, H, W) |
|
encode_input_frames = torch.cat([condition_frames_first, padding_frames, condition_frames_last], dim=2) |
|
elif not from_back: |
|
condition_frames = input_frames[:, :, :num_frames_condition] |
|
padding_frames = condition_frames.new_zeros(B, C, num_frames_encode - num_frames_condition, H, W) |
|
encode_input_frames = torch.cat([condition_frames, padding_frames], dim=2) |
|
else: |
|
condition_frames = input_frames[:, :, -num_frames_condition:] |
|
padding_frames = condition_frames.new_zeros(B, C, num_frames_encode - num_frames_condition, H, W) |
|
encode_input_frames = torch.cat([condition_frames, padding_frames], dim=2) |
|
|
|
log.info( |
|
f"create latent with input shape {encode_input_frames.shape} including padding {num_frames_encode - num_frames_condition} at the end" |
|
) |
|
if hasattr(model, "n_views") and encode_input_frames.shape[0] == model.n_views: |
|
encode_input_frames = einops.rearrange(encode_input_frames, "(B V) C T H W -> B C (V T) H W", V=model.n_views) |
|
latent = model.encode(encode_input_frames) |
|
elif model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1": |
|
latent1 = model.encode(encode_input_frames[:, :, :num_frames_encode]) |
|
latent2 = model.encode(encode_input_frames[:, :, num_frames_encode:]) |
|
latent = torch.cat([latent1, latent2], dim=2) |
|
elif encode_input_frames.shape[0] == 1: |
|
|
|
latent = model.tokenizer.encode(encode_input_frames) * model.sigma_data |
|
else: |
|
raise ValueError( |
|
f"First dimension of encode_input_frames {encode_input_frames.shape[0]} does not match " |
|
f"model.n_views or model.n_views is not defined and first dimension is not 1" |
|
) |
|
return latent, encode_input_frames |
|
|
|
|
|
def compute_num_frames_condition(model: DiffusionV2WModel, num_of_latent_overlap: int, downsample_factor=8) -> int: |
|
"""This function computes the number of condition pixel frames given the number of latent frames to overlap. |
|
Args: |
|
        model (DiffusionV2WModel): video generation model
|
num_of_latent_overlap (int): number of latent frames to overlap |
|
downsample_factor (int): downsample factor for temporal reduce |
|
Returns: |
|
int: number of condition frames in output space |
|
""" |
|
if getattr(model.tokenizer.video_vae, "is_casual", True): |
|
|
|
num_frames_condition = ( |
|
num_of_latent_overlap |
|
// model.tokenizer.video_vae.latent_chunk_duration |
|
* model.tokenizer.video_vae.pixel_chunk_duration |
|
) |
|
if num_of_latent_overlap % model.tokenizer.video_vae.latent_chunk_duration == 1: |
|
num_frames_condition += 1 |
|
elif num_of_latent_overlap % model.tokenizer.video_vae.latent_chunk_duration > 1: |
|
num_frames_condition += ( |
|
1 + (num_of_latent_overlap % model.tokenizer.video_vae.latent_chunk_duration - 1) * downsample_factor |
|
) |
|
else: |
|
num_frames_condition = num_of_latent_overlap * downsample_factor |
|
|
|
return num_frames_condition |
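
# Worked example (assuming a causal VAE with latent_chunk_duration=16, pixel_chunk_duration=121,
# downsample_factor=8):
#   num_of_latent_overlap=1  -> 0 + 1 = 1 conditioning pixel frame
#   num_of_latent_overlap=16 -> 16 // 16 * 121 = 121 conditioning pixel frames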
|
|
|
|
|
def get_condition_latent( |
|
model: DiffusionV2WModel, |
|
input_image_or_video_path: str, |
|
num_input_frames: int = 1, |
|
state_shape: list[int] = None, |
|
frame_index: int = 0, |
|
frame_stride: int = 1, |
|
from_back: bool = True, |
|
start_frame: int = 0, |
|
) -> torch.Tensor: |
|
"""Get condition latent from input image/video file. |
|
|
|
Args: |
|
model (DiffusionV2WModel): Video generation model |
|
input_image_or_video_path (str): Path to conditioning image/video |
|
num_input_frames (int): Number of input frames for video2world prediction |
|
|
|
Returns: |
|
tuple: (condition_latent, input_frames) where: |
|
- condition_latent (torch.Tensor): Encoded latent condition [B,C,T,H,W] |
|
- input_frames (torch.Tensor): Input frames tensor [B,C,T,H,W] |
|
""" |
|
if state_shape is None: |
|
state_shape = model.state_shape |
|
assert num_input_frames > 0, "num_input_frames must be greater than 0" |
|
|
|
H, W = ( |
|
state_shape[-2] * model.tokenizer.spatial_compression_factor, |
|
state_shape[-1] * model.tokenizer.spatial_compression_factor, |
|
) |
|
|
|
input_path_format = input_image_or_video_path.split(".")[-1] |
|
input_frames = read_video_or_image_into_frames_BCTHW( |
|
input_image_or_video_path, |
|
input_path_format=input_path_format, |
|
H=H, |
|
W=W, |
|
) |
|
if model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1": |
|
start_frame = frame_index * frame_stride |
|
end_frame = (frame_index + 1) * frame_stride |
|
curr_input_frames = torch.cat( |
|
[input_frames[:, :, start_frame : start_frame + 1], input_frames[:, :, end_frame : end_frame + 1]], dim=2 |
|
).contiguous() |
|
num_of_latent_condition = 1 |
|
num_frames_condition = compute_num_frames_condition( |
|
model, num_of_latent_condition, downsample_factor=model.tokenizer.temporal_compression_factor |
|
) |
|
|
|
condition_latent, _ = create_condition_latent_from_input_frames(model, curr_input_frames, num_frames_condition) |
|
condition_latent = condition_latent.to(torch.bfloat16) |
|
return condition_latent |
|
input_frames = input_frames[:, :, start_frame:, :, :] |
|
condition_latent, _ = create_condition_latent_from_input_frames( |
|
model, input_frames, num_input_frames, from_back=from_back |
|
) |
|
condition_latent = condition_latent.to(torch.bfloat16) |
|
|
|
return condition_latent |
|
|
|
|
|
def check_input_frames(input_path: str, required_frames: int) -> bool: |
|
"""Check if input video/image has sufficient frames. |
|
|
|
Args: |
|
input_path: Path to input video or image |
|
required_frames: Number of required frames |
|
|
|
Returns: |
|
        bool: True if the input has at least the required number of frames, False otherwise
|
""" |
|
if input_path.endswith((".jpg", ".jpeg", ".png")): |
|
if required_frames > 1: |
|
log.error(f"Input ({input_path}) is an image but {required_frames} frames are required") |
|
return False |
|
return True |
|
|
|
try: |
|
vid = imageio.get_reader(input_path, "ffmpeg") |
|
frame_count = vid.count_frames() |
|
|
|
if frame_count < required_frames: |
|
log.error(f"Input video has {frame_count} frames but {required_frames} frames are required") |
|
return False |
|
else: |
|
return True |
|
except Exception as e: |
|
log.error(f"Error reading video file {input_path}: {e}") |
|
return False |
|
|
|
|
|
def load_spatial_temporal_weights(weight_paths, B, T, H, W, patch_h, patch_w): |
|
""" |
|
Load and process spatial-temporal weight maps from .pt files |
|
Args: |
|
weight_paths: List of weights that can be scalars, paths to .pt files, or empty strings |
|
B, T, H, W: Desired tensor dimensions |
|
patch_h, patch_w: Patch dimensions for splitting |
|
Returns: |
|
For all scalar weights: tensor of shape [num_controls] |
|
For any spatial maps: tensor of shape [num_controls, B, 1, T, H, W] |
|
""" |
|
|
|
weights = [] |
|
has_spatial_weights = False |
|
for path in weight_paths: |
|
if not path or (isinstance(path, str) and path.lower() == "none"): |
|
|
|
w = torch.ones((T, H, W), dtype=torch.bfloat16) |
|
else: |
|
try: |
|
|
|
scalar_value = float(path) |
|
w = torch.full((T, H, W), scalar_value, dtype=torch.bfloat16) |
|
except ValueError: |
|
|
|
has_spatial_weights = True |
|
w = torch.load(path, weights_only=False).to(dtype=torch.bfloat16) |
|
if w.ndim == 2: |
|
w = w.unsqueeze(0).repeat(T, 1, 1) |
|
elif w.ndim != 3: |
|
raise ValueError(f"Weight map must be 2D or 3D, got shape {w.shape}") |
|
|
|
if w.shape != (T, H, W): |
|
w = ( |
|
torch.nn.functional.interpolate( |
|
w.unsqueeze(0).unsqueeze(0), |
|
size=(T, H, W), |
|
mode="trilinear", |
|
align_corners=False, |
|
) |
|
.squeeze(0) |
|
.squeeze(0) |
|
) |
|
w = torch.clamp(w, min=0) |
|
w = w.unsqueeze(0).unsqueeze(1) |
|
w = w.expand(B, 1, -1, -1, -1) |
|
weights.append(w) |
|
|
|
if not has_spatial_weights: |
|
scalar_weights = [float(w) for w in weight_paths] |
|
weights_tensor = torch.tensor(scalar_weights, dtype=torch.bfloat16) |
|
weights_tensor = weights_tensor / (weights_tensor.sum().clip(1)) |
|
return weights_tensor.cuda() |
|
|
|
weights = torch.stack(weights, dim=0).cuda() |
|
weights = weights / (weights.sum(dim=0, keepdim=True).clip(1)) |
|
|
|
|
|
if patch_h != H or patch_w != W: |
|
num_controls = len(weights) |
|
weights = weights.reshape(num_controls * B, 1, T, H, W) |
|
weights = split_video_into_patches(weights, patch_h, patch_w) |
|
B_new = weights.shape[0] // num_controls |
|
weights = weights.reshape(num_controls, B_new, 1, T, H, W) |
|
|
|
return weights |
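
# Illustrative sketch: two scalar control weights are normalized so they sum to at most 1, e.g.
#   load_spatial_temporal_weights(["1.0", "1.0"], B=1, T=121, H=704, W=1280, patch_h=704, patch_w=1280)
# returns tensor([0.5, 0.5]) on GPU. If any entry is a .pt spatial map instead, the output switches to
# the [num_controls, B, 1, T, H, W] form.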
|
|
|
|
|
def resize_control_weight_map(control_weight_map, size): |
|
assert control_weight_map.shape[2] == 1 |
|
weight_map = control_weight_map.squeeze(2) |
|
T, H, W = size |
|
if weight_map.shape[2:5] != (T, H, W): |
|
assert (weight_map.shape[2] == T) or (weight_map.shape[2] == 8 * (T - 1) + 1) |
|
weight_map_i = [ |
|
torch.nn.functional.interpolate( |
|
weight_map[:, :, :1], |
|
size=(1, H, W), |
|
mode="trilinear", |
|
align_corners=False, |
|
) |
|
] |
|
weight_map_i += [ |
|
torch.nn.functional.interpolate( |
|
weight_map[:, :, 1:], |
|
size=(T - 1, H, W), |
|
mode="trilinear", |
|
align_corners=False, |
|
) |
|
] |
|
weight_map = torch.cat(weight_map_i, dim=2) |
|
return weight_map.unsqueeze(2) |
|
|
|
|
|
def split_video_into_patches(tensor, patch_h, patch_w): |
|
h, w = tensor.shape[-2:] |
|
n_img_w = (w - 1) // patch_w + 1 |
|
n_img_h = (h - 1) // patch_h + 1 |
|
overlap_size_h = overlap_size_w = 0 |
|
if n_img_w > 1: |
|
overlap_size_w = (n_img_w * patch_w - w) // (n_img_w - 1) |
|
assert n_img_w * patch_w - overlap_size_w * (n_img_w - 1) == w |
|
if n_img_h > 1: |
|
overlap_size_h = (n_img_h * patch_h - h) // (n_img_h - 1) |
|
assert n_img_h * patch_h - overlap_size_h * (n_img_h - 1) == h |
|
p_h = patch_h - overlap_size_h |
|
p_w = patch_w - overlap_size_w |
|
|
|
patches = [] |
|
for i in range(n_img_h): |
|
for j in range(n_img_w): |
|
patches += [tensor[:, :, :, p_h * i : (p_h * i + patch_h), p_w * j : (p_w * j + patch_w)]] |
|
return torch.cat(patches) |
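
# Worked example of the patch/overlap arithmetic (shapes only):
#   input [1, 3, 121, 704, 1920] with patch_h=704, patch_w=1280 -> n_img_w = (1920 - 1) // 1280 + 1 = 2,
#   overlap_size_w = (2 * 1280 - 1920) // 1 = 640, so the two patches share a 640-pixel band and the
#   result has shape [2, 3, 121, 704, 1280]; merge_patches_into_video blends that band back together.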
|
|
|
|
|
def merge_patches_into_video(imgs, overlap_size_h, overlap_size_w, n_img_h, n_img_w): |
|
b, c, t, h, w = imgs.shape |
|
imgs = rearrange(imgs, "(b m n) c t h w -> m n b c t h w", m=n_img_h, n=n_img_w) |
|
H = n_img_h * h - (n_img_h - 1) * overlap_size_h |
|
W = n_img_w * w - (n_img_w - 1) * overlap_size_w |
|
img_sum = torch.zeros((b // (n_img_h * n_img_w), c, t, H, W)).to(imgs) |
|
mask_sum = torch.zeros((H, W)).to(imgs) |
|
|
|
|
|
def create_linear_gradient_tensor(H, W, overlap_size_h, overlap_size_w): |
|
y, x = torch.meshgrid( |
|
torch.minimum(torch.arange(H), H - torch.arange(H)) / (overlap_size_h + 1e-6), |
|
torch.minimum(torch.arange(W), W - torch.arange(W)) / (overlap_size_w + 1e-6), |
|
) |
|
return torch.clamp(y, 0.01, 1) * torch.clamp(x, 0.01, 1) |
|
|
|
mask_ij = create_linear_gradient_tensor(h, w, overlap_size_h, overlap_size_w).to(imgs) |
|
|
|
for i in range(n_img_h): |
|
for j in range(n_img_w): |
|
h_start = i * (h - overlap_size_h) |
|
w_start = j * (w - overlap_size_w) |
|
img_sum[:, :, :, h_start : h_start + h, w_start : w_start + w] += ( |
|
imgs[i, j] * mask_ij[None, None, None, :, :] |
|
) |
|
mask_sum[h_start : h_start + h, w_start : w_start + w] += mask_ij |
|
return img_sum / (mask_sum[None, None, None, :, :] + 1e-6) |
|
|
|
|
|
valid_hint_keys = {"vis", "seg", "edge", "depth", "keypoint", "upscale", "hdmap", "lidar"} |
|
|
|
|
|
def load_controlnet_specs(cfg) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
with open(cfg.controlnet_specs, "r") as f: |
|
controlnet_specs_in = json.load(f) |
|
|
|
controlnet_specs = {} |
|
args = {} |
|
|
|
for hint_key, config in controlnet_specs_in.items(): |
|
if hint_key in valid_hint_keys: |
|
controlnet_specs[hint_key] = config |
|
else: |
|
            if isinstance(config, dict):
|
raise ValueError(f"Invalid hint_key: {hint_key}. Must be one of {valid_hint_keys}") |
|
else: |
|
args[hint_key] = config |
|
continue |
|
return controlnet_specs, args |
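
# Illustrative controlnet spec JSON (hedged; file names are hypothetical). Top-level keys found in
# valid_hint_keys become controlnet_specs entries; other non-dict keys are returned separately as args.
#
#   {
#       "prompt": "a robot arm assembling a toy car",
#       "input_video_path": "assets/example_input.mp4",
#       "edge": {"control_weight": 0.5},
#       "depth": {"input_control": "assets/example_depth.mp4", "control_weight": "0.3"}
#   }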
|
|
|
|
|
def validate_controlnet_specs(cfg, controlnet_specs) -> Dict[str, Any]: |
|
""" |
|
Load and validate controlnet specifications from a JSON file. |
|
|
|
Args: |
|
json_path (str): Path to the JSON file containing controlnet specs. |
|
checkpoint_dir (str): Base directory for checkpoint files. |
|
|
|
Returns: |
|
Dict[str, Any]: Validated and processed controlnet specifications. |
|
""" |
|
checkpoint_dir = cfg.checkpoint_dir |
|
sigma_max = cfg.sigma_max |
|
input_video_path = cfg.input_video_path |
|
use_distilled = cfg.use_distilled |
|
|
|
for hint_key, config in controlnet_specs.items(): |
|
if hint_key not in valid_hint_keys: |
|
raise ValueError(f"Invalid hint_key: {hint_key}. Must be one of {valid_hint_keys}") |
|
|
|
if not input_video_path and sigma_max < 80: |
|
raise ValueError("Must have 'input_video' specified if sigma_max < 80") |
|
|
|
if not input_video_path and "input_control" not in config: |
|
raise ValueError( |
|
f"{hint_key} controlnet must have 'input_control' video specified if no 'input_video' specified." |
|
) |
|
|
|
if "ckpt_path" not in config: |
|
log.info(f"No checkpoint path specified for {hint_key}. Using default.") |
|
ckpt_path = os.path.join(checkpoint_dir, default_model_names[hint_key]) |
|
if use_distilled: |
|
if hint_key in default_distilled_model_names: |
|
ckpt_path = os.path.join(checkpoint_dir, default_distilled_model_names[hint_key]) |
|
else: |
|
log.info(f"No default distilled checkpoint for {hint_key}. Using full checkpoint") |
|
|
|
config["ckpt_path"] = ckpt_path |
|
log.info(f"Using default checkpoint path: {config['ckpt_path']}") |
|
|
|
|
|
|
|
if "control_weight" not in config: |
|
log.warning(f"No control weight specified for {hint_key}. Setting to 0.5.") |
|
config["control_weight"] = "0.5" |
|
else: |
|
|
|
weight = config["control_weight"] |
|
if not isinstance(weight, str) or not weight.endswith(".pt"): |
|
                try:
                    scalar_value = float(weight)
                except ValueError:
                    raise ValueError(
                        f"Control weight for {hint_key} must be a valid non-negative float or a path to a .pt file."
                    )
                if scalar_value < 0:
                    raise ValueError(f"Control weight for {hint_key} must be non-negative.")
|
|
|
return controlnet_specs |
|
|