# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import contextmanager
from typing import Tuple, Union
import einops
import numpy as np
import torch
import torchvision
import torchvision.transforms.functional as transforms_F
from matplotlib import pyplot as plt
from cosmos_predict1.diffusion.training.models.extend_model import ExtendDiffusionModel
from cosmos_predict1.utils import log
from cosmos_predict1.utils.easy_io import easy_io
"""This file contain functions needed for long video generation,
* function `generate_video_from_batch_with_loop` is used by `single_gpu_sep20`
"""
@contextmanager
def switch_config_for_inference(model):
"""For extend model inference, we need to make sure the condition_location is set to "first_n" and apply_corruption_to_condition_region is False.
This context manager changes the model configuration to the correct settings for inference, and then restores the original settings when exiting the context.
Args:
model (ExtendDiffusionModel): video generation model
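    Example (illustrative; `model` is assumed to be a loaded ExtendDiffusionModel):
        >>> with switch_config_for_inference(model):
        ...     ...  # run inference here, e.g. generate_video_from_batch_with_loop(...)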
"""
    # Store the original settings so they can be restored on exit
    original_condition_location = model.config.conditioner.video_cond_bool.condition_location
    inference_condition_location = original_condition_location
    if inference_condition_location not in ("first_n", "first_and_last_1"):
        inference_condition_location = "first_n"
current_apply_corruption_to_condition_region = (
model.config.conditioner.video_cond_bool.apply_corruption_to_condition_region
)
try:
        log.info(
            "Switch condition_location to 'first_n' (unless it is 'first_and_last_1') and apply_corruption_to_condition_region to its inference-safe variant"
        )
        # Apply the inference condition_location
        model.config.conditioner.video_cond_bool.condition_location = inference_condition_location
if current_apply_corruption_to_condition_region == "gaussian_blur":
model.config.conditioner.video_cond_bool.apply_corruption_to_condition_region = "clean"
elif current_apply_corruption_to_condition_region == "noise_with_sigma":
model.config.conditioner.video_cond_bool.apply_corruption_to_condition_region = "noise_with_sigma_fixed"
# Yield control back to the calling context
yield
finally:
        # Restore the original settings after exiting the context
        log.info(
            f"Restore the original condition_location {original_condition_location}, apply_corruption_to_condition_region {current_apply_corruption_to_condition_region}"
        )
        model.config.conditioner.video_cond_bool.condition_location = original_condition_location
model.config.conditioner.video_cond_bool.apply_corruption_to_condition_region = (
current_apply_corruption_to_condition_region
)
def visualize_latent_tensor_bcthw(tensor, nrow=1, show_norm=False, save_fig_path=None):
"""Debug function to display a latent tensor as a grid of images.
Args:
tensor (torch.Tensor): tensor in shape BCTHW
nrow (int): number of images per row
show_norm (bool): whether to display the norm of the tensor
save_fig_path (str): path to save the visualization
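    Example (illustrative; a random latent-like tensor):
        >>> visualize_latent_tensor_bcthw(torch.randn(1, 16, 8, 32, 32), nrow=2, save_fig_path="debug/latent_vis.png")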
"""
log.info(
f"display latent tensor shape {tensor.shape}, max={tensor.max()}, min={tensor.min()}, mean={tensor.mean()}, std={tensor.std()}"
)
tensor = tensor.float().cpu().detach()
tensor = einops.rearrange(tensor, "b c (t n) h w -> (b t h) (n w) c", n=nrow) # .numpy()
# display the grid
tensor_mean = tensor.mean(-1)
tensor_norm = tensor.norm(dim=-1)
log.info(f"tensor_norm, tensor_mean {tensor_norm.shape}, {tensor_mean.shape}")
plt.figure(figsize=(20, 20))
plt.imshow(tensor_mean)
plt.title(f"mean {tensor_mean.mean()}, std {tensor_mean.std()}")
if save_fig_path:
os.makedirs(os.path.dirname(save_fig_path), exist_ok=True)
log.info(f"save to {os.path.abspath(save_fig_path)}")
plt.savefig(save_fig_path, bbox_inches="tight", pad_inches=0)
plt.show()
if show_norm:
plt.figure(figsize=(20, 20))
plt.imshow(tensor_norm)
plt.show()
def visualize_tensor_bcthw(tensor: torch.Tensor, nrow=4, save_fig_path=None):
"""Debug function to display a tensor as a grid of images.
Args:
tensor (torch.Tensor): tensor in shape BCTHW
nrow (int): number of images per row
save_fig_path (str): path to save the visualization
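    Example (illustrative; a random video tensor in [0, 1]):
        >>> visualize_tensor_bcthw(torch.rand(1, 3, 8, 64, 64), nrow=4, save_fig_path="debug/grid_vis.png")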
"""
log.info(f"display {tensor.shape}, {tensor.max()}, {tensor.min()}")
assert tensor.max() < 200, f"tensor max {tensor.max()} > 200, the data range is likely wrong"
tensor = tensor.float().cpu().detach()
tensor = einops.rearrange(tensor, "b c t h w -> (b t) c h w")
# use torchvision to save the tensor as a grid of images
grid = torchvision.utils.make_grid(tensor, nrow=nrow)
if save_fig_path is not None:
os.makedirs(os.path.dirname(save_fig_path), exist_ok=True)
log.info(f"save to {os.path.abspath(save_fig_path)}")
        torchvision.utils.save_image(tensor, save_fig_path, nrow=nrow)
# display the grid
plt.figure(figsize=(20, 20))
plt.imshow(grid.permute(1, 2, 0))
plt.show()
def compute_num_frames_condition(model: "ExtendDiffusionModel", num_of_latent_overlap: int, downsample_factor=8) -> int:
"""This function computes the number of condition pixel frames given the number of latent frames to overlap.
Args:
model (ExtendDiffusionModel): Video generation model
num_of_latent_overlap (int): Number of latent frames to overlap
downsample_factor (int): Downsample factor for temporal reduce
Returns:
int: Number of condition frames in output space
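    Example (with illustrative causal-VAE settings latent_chunk_duration=16, pixel_chunk_duration=121, downsample_factor=8):
        num_of_latent_overlap=1 -> 0 * 121 + 1 = 1 condition frame
        num_of_latent_overlap=2 -> 0 * 121 + 1 + (2 - 1) * 8 = 9 condition frames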
"""
# Access the VAE: use tokenizer.video_vae if it exists, otherwise use tokenizer directly
vae = model.tokenizer.video_vae if hasattr(model.tokenizer, "video_vae") else model.tokenizer
# Check if the VAE is causal (default to True if attribute not found)
if getattr(vae, "is_casual", True):
# For causal model
num_frames_condition = num_of_latent_overlap // vae.latent_chunk_duration * vae.pixel_chunk_duration
if num_of_latent_overlap % vae.latent_chunk_duration == 1:
num_frames_condition += 1
elif num_of_latent_overlap % vae.latent_chunk_duration > 1:
num_frames_condition += 1 + (num_of_latent_overlap % vae.latent_chunk_duration - 1) * downsample_factor
else:
num_frames_condition = num_of_latent_overlap * downsample_factor
return num_frames_condition
def read_video_or_image_into_frames_BCTHW(
input_path: str,
input_path_format: str = None,
H: int = None,
W: int = None,
normalize: bool = True,
max_frames: int = -1,
also_return_fps: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, int]]:
"""Read video or image from file and convert it to tensor. The frames will be normalized to [-1, 1].
Args:
input_path (str): path to the input video or image, end with .mp4 or .png or .jpg
H (int): height to resize the video
W (int): width to resize the video
Returns:
torch.Tensor: video tensor in shape (1, C, T, H, W), range [-1, 1]
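    Example (illustrative; the path and 704x1280 resolution are placeholder values):
        >>> frames_BCTHW, fps = read_video_or_image_into_frames_BCTHW("input.mp4", H=704, W=1280, also_return_fps=True)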
"""
log.info(f"Reading video from {input_path}")
loaded_data = easy_io.load(input_path, file_format=input_path_format, backend_args=None)
if input_path.endswith(".png") or input_path.endswith(".jpg") or input_path.endswith(".jpeg"):
frames = np.array(loaded_data) # HWC, [0,255]
if frames.shape[-1] > 3: # RGBA, set the transparent to white
# Separate the RGB and Alpha channels
rgb_channels = frames[..., :3]
alpha_channel = frames[..., 3] / 255.0 # Normalize alpha channel to [0, 1]
# Create a white background
white_bg = np.ones_like(rgb_channels) * 255 # White background in RGB
# Blend the RGB channels with the white background based on the alpha channel
frames = (rgb_channels * alpha_channel[..., None] + white_bg * (1 - alpha_channel[..., None])).astype(
np.uint8
)
frames = [frames]
fps = 0
else:
frames, meta_data = loaded_data
fps = int(meta_data.get("fps"))
if max_frames != -1:
frames = frames[:max_frames]
input_tensor = np.stack(frames, axis=0)
input_tensor = einops.rearrange(input_tensor, "t h w c -> t c h w")
if normalize:
input_tensor = input_tensor / 128.0 - 1.0
input_tensor = torch.from_numpy(input_tensor).bfloat16() # TCHW
log.info(f"Raw data shape: {input_tensor.shape}")
if H is not None and W is not None:
input_tensor = transforms_F.resize(
input_tensor,
size=(H, W), # type: ignore
interpolation=transforms_F.InterpolationMode.BICUBIC,
antialias=True,
)
input_tensor = einops.rearrange(input_tensor, "(b t) c h w -> b c t h w", b=1)
if normalize:
input_tensor = input_tensor.to("cuda")
log.info(f"Load shape {input_tensor.shape} value {input_tensor.min()}, {input_tensor.max()}")
if also_return_fps:
return input_tensor, fps
return input_tensor
def create_condition_latent_from_input_frames(
model: ExtendDiffusionModel,
input_frames: torch.Tensor,
num_frames_condition: int = 25,
):
"""Create condition latent for video generation. It will take the last num_frames_condition frames from the input frames as condition latent.
Args:
model (ExtendDiffusionModel): Video generation model
input_frames (torch.Tensor): Video tensor in shape (1,C,T,H,W), range [-1, 1]
num_frames_condition (int): Number of condition frames
    Returns:
        torch.Tensor: Condition latent in shape BCTHW
        torch.Tensor: The padded pixel frames that were encoded, in shape BCTHW
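    Example (illustrative; `model` and `input_frames` are assumed to be prepared by the caller):
        >>> latent, encoded_frames = create_condition_latent_from_input_frames(model, input_frames, num_frames_condition=9)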
"""
B, C, T, H, W = input_frames.shape
# Dynamically access the VAE: use tokenizer.video_vae if it exists, otherwise use tokenizer directly
vae = model.tokenizer.video_vae if hasattr(model.tokenizer, "video_vae") else model.tokenizer
num_frames_encode = vae.pixel_chunk_duration # Access pixel_chunk_duration from the VAE
    log.info(f"Use pixel_chunk_duration from the VAE as num_frames_encode: {num_frames_encode}")
log.info(
f"Create condition latent from input frames {input_frames.shape}, value {input_frames.min()}, {input_frames.max()}, dtype {input_frames.dtype}"
)
assert (
input_frames.shape[2] >= num_frames_condition
), f"input_frames not enough for condition, require at least {num_frames_condition}, got {input_frames.shape[2]}, {input_frames.shape}"
    assert (
        num_frames_encode >= num_frames_condition
    ), f"num_frames_encode should be at least num_frames_condition, got {num_frames_encode}, {num_frames_condition}"
# Put the conditional frames at the beginning of the video, and pad the end with zeros
if model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1":
condition_frames_first = input_frames[:, :, :num_frames_condition]
condition_frames_last = input_frames[:, :, -num_frames_condition:]
padding_frames = condition_frames_first.new_zeros(B, C, num_frames_encode + 1 - 2 * num_frames_condition, H, W)
encode_input_frames = torch.cat([condition_frames_first, padding_frames, condition_frames_last], dim=2)
else:
condition_frames = input_frames[:, :, -num_frames_condition:]
padding_frames = condition_frames.new_zeros(B, C, num_frames_encode - num_frames_condition, H, W)
encode_input_frames = torch.cat([condition_frames, padding_frames], dim=2)
log.info(
f"create latent with input shape {encode_input_frames.shape} including padding {num_frames_encode - num_frames_condition} at the end"
)
if hasattr(model, "n_views"):
encode_input_frames = einops.rearrange(encode_input_frames, "(B V) C T H W -> B C (V T) H W", V=model.n_views)
if model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1":
latent1 = model.encode(encode_input_frames[:, :, :num_frames_encode]) # BCTHW
latent2 = model.encode(encode_input_frames[:, :, num_frames_encode:])
latent = torch.cat([latent1, latent2], dim=2) # BCTHW
else:
latent = model.encode(encode_input_frames)
return latent, encode_input_frames
def get_condition_latent(
model: ExtendDiffusionModel,
conditioned_image_or_video_path: str,
num_of_latent_condition: int = 4,
state_shape: list[int] = None,
input_path_format: str = None,
frame_index: int = 0,
frame_stride: int = 1,
):
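    """Build the condition latent from a conditioning image or video on disk.
    Args:
        model (ExtendDiffusionModel): Video generation model
        conditioned_image_or_video_path (str): Path to the conditioning image or video
        num_of_latent_condition (int): Number of latent frames used as condition; 0 returns an all-zero latent
        state_shape (list[int]): Latent state shape; defaults to model.state_shape
        input_path_format (str): Optional file-format hint passed to easy_io
        frame_index (int): Index of the conditioning frame pair (scaled by frame_stride); only used when condition_location is "first_and_last_1"
        frame_stride (int): Stride between the first and last conditioning frames; only used when condition_location is "first_and_last_1"
    Returns:
        torch.Tensor: Condition latent in shape BCTHW (bfloat16)
        torch.Tensor or None: Input frames used to build the condition latent (None when num_of_latent_condition is 0)
    """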
if state_shape is None:
state_shape = model.state_shape
if num_of_latent_condition == 0:
log.info("No condition latent needed, return empty latent")
condition_latent = (
torch.zeros(
[
1,
]
+ state_shape
)
.to(torch.bfloat16)
.cuda()
)
return condition_latent, None
H, W = (
state_shape[-2] * model.vae.spatial_compression_factor,
state_shape[-1] * model.vae.spatial_compression_factor,
)
input_frames = read_video_or_image_into_frames_BCTHW(
conditioned_image_or_video_path,
input_path_format=input_path_format,
H=H,
W=W,
)
if model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1":
start_frame = frame_index * frame_stride
end_frame = (frame_index + 1) * frame_stride
input_frames = torch.cat(
[input_frames[:, :, start_frame : start_frame + 1], input_frames[:, :, end_frame : end_frame + 1]], dim=2
).contiguous() # BCTHW
num_frames_condition = compute_num_frames_condition(
model, num_of_latent_condition, downsample_factor=model.vae.temporal_compression_factor
)
condition_latent, _ = create_condition_latent_from_input_frames(model, input_frames, num_frames_condition)
condition_latent = condition_latent.to(torch.bfloat16)
return condition_latent, input_frames
def generate_video_from_batch_with_loop(
model: ExtendDiffusionModel,
state_shape: list[int],
is_negative_prompt: bool,
data_batch: dict,
condition_latent: torch.Tensor,
# hyper-parameters for inference
num_of_loops: int,
num_of_latent_overlap_list: list[int],
guidance: float,
num_steps: int,
seed: int,
add_input_frames_guidance: bool = False,
augment_sigma_list: list[float] = None,
data_batch_list: Union[None, list[dict]] = None,
visualize: bool = False,
save_fig_path: str = None,
skip_reencode: int = 0,
return_noise: bool = False,
) -> Tuple[np.ndarray, list, list] | Tuple[np.ndarray, list, list, torch.Tensor]:
"""Generate video with loop, given data batch. The condition latent will be updated at each loop.
Args:
model (ExtendDiffusionModel)
state_shape (list): shape of the state tensor
is_negative_prompt (bool): whether to use negative prompt
data_batch (dict): data batch for video generation
condition_latent (torch.Tensor): condition latent in shape BCTHW
num_of_loops (int): number of loops to generate video
        num_of_latent_overlap_list (list[int]): number of latent frames to overlap between consecutive clips; different clips can use different overlaps
guidance (float): The guidance scale to use during sample generation; defaults to 5.0.
num_steps (int): number of steps for diffusion sampling
seed (int): random seed for sampling
add_input_frames_guidance (bool): whether to add image guidance, default is False
        augment_sigma_list (list): sigma values for condition-region corruption, one per clip, used when apply_corruption_to_condition_region is "noise_with_sigma" or "noise_with_sigma_fixed". Default is None
        data_batch_list (list): list of data batches, one per loop, to support multiple prompts in auto-regressive generation. Default is None
visualize (bool): whether to visualize the latent and grid, default is False
save_fig_path (str): path to save the visualization, default is None
skip_reencode (int): whether to skip re-encode the input frames, default is 0
return_noise (bool): whether to return the initial noise used for sampling, used for ODE pairs generation. Default is False
Returns:
np.array: generated video in shape THWC, range [0, 255]
list: list of condition latent, each in shape BCTHW
list: list of sample latent, each in shape BCTHW
torch.Tensor: initial noise used for sampling, shape BCTHW (if return_noise is True)
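    Example (illustrative sketch; `model`, `data_batch`, and `condition_latent` are assumed to be prepared by the caller, and the hyper-parameter values are placeholders):
        >>> video_THWC, cond_latents, sample_latents = generate_video_from_batch_with_loop(
        ...     model, model.state_shape, is_negative_prompt=False, data_batch=data_batch,
        ...     condition_latent=condition_latent, num_of_loops=2, num_of_latent_overlap_list=[1, 1],
        ...     guidance=5.0, num_steps=35, seed=0,
        ... )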
"""
if data_batch_list is None:
data_batch_list = [data_batch for _ in range(num_of_loops)]
if visualize:
assert save_fig_path is not None, "save_fig_path should be set when visualize is True"
# Generate video with loop
condition_latent_list = []
decode_latent_list = [] # list collect the latent token to be decoded at the end
sample_latent = []
grid_list = []
augment_sigma_list = (
model.config.conditioner.video_cond_bool.apply_corruption_to_condition_region_sigma_value
if augment_sigma_list is None
else augment_sigma_list
)
for i in range(num_of_loops):
num_of_latent_overlap_i = num_of_latent_overlap_list[i]
num_of_latent_overlap_i_plus_1 = (
num_of_latent_overlap_list[i + 1]
if i + 1 < len(num_of_latent_overlap_list)
else num_of_latent_overlap_list[-1]
)
if condition_latent.shape[2] < state_shape[1]:
# Padding condition latent to state shape
log.info(f"Padding condition latent {condition_latent.shape} to state shape {state_shape}")
b, c, t, h, w = condition_latent.shape
condition_latent = torch.cat(
[
condition_latent,
condition_latent.new_zeros(b, c, state_shape[1] - t, h, w),
],
dim=2,
).contiguous()
log.info(f"after padding, condition latent shape {condition_latent.shape}")
log.info(f"Generate video loop {i} / {num_of_loops}")
if visualize:
log.info(f"Visualize condition latent {i}")
visualize_latent_tensor_bcthw(
condition_latent[:, :, :4].float(),
nrow=4,
save_fig_path=os.path.join(save_fig_path, f"loop_{i:02d}_condition_latent_first_4.png"),
) # BCTHW
condition_latent_list.append(condition_latent)
if i < len(augment_sigma_list):
condition_video_augment_sigma_in_inference = augment_sigma_list[i]
log.info(f"condition_video_augment_sigma_in_inference {condition_video_augment_sigma_in_inference}")
else:
condition_video_augment_sigma_in_inference = augment_sigma_list[-1]
assert not add_input_frames_guidance, "add_input_frames_guidance should be False, not supported"
sample = model.generate_samples_from_batch(
data_batch_list[i],
guidance=guidance,
state_shape=state_shape,
num_steps=num_steps,
is_negative_prompt=is_negative_prompt,
seed=seed + i,
condition_latent=condition_latent,
num_condition_t=num_of_latent_overlap_i,
condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
return_noise=return_noise,
)
if return_noise:
sample, noise = sample
if visualize:
log.info(f"Visualize sampled latent {i} 4-8 frames")
visualize_latent_tensor_bcthw(
sample[:, :, 4:8].float(),
nrow=4,
save_fig_path=os.path.join(save_fig_path, f"loop_{i:02d}_sample_latent_last_4.png"),
) # BCTHW
diff_between_sample_and_condition = (sample - condition_latent)[:, :, :num_of_latent_overlap_i]
log.info(
f"Visualize diff between sample and condition latent {i} first 4 frames {diff_between_sample_and_condition.mean()}"
)
sample_latent.append(sample)
T = condition_latent.shape[2]
        assert num_of_latent_overlap_i <= T, f"num_of_latent_overlap should be <= T, got {num_of_latent_overlap_i}, {T}"
if model.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i:
assert skip_reencode, "skip_reencode should be turned on when sample_tokens_start_from_p_or_i is True"
if i == 0:
decode_latent_list.append(sample)
else:
decode_latent_list.append(sample[:, :, num_of_latent_overlap_i:])
else:
            if model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1":
                # Interpolator mode: decode the last latent frame separately as an image
grid_BCTHW_1 = (1.0 + model.decode(sample[:, :, :-1, ...])).clamp(0, 2) / 2 # [B, 3, T-1, H, W], [0, 1]
grid_BCTHW_2 = (1.0 + model.decode(sample[:, :, -1:, ...])).clamp(0, 2) / 2 # [B, 3, 1, H, W], [0, 1]
grid_BCTHW = torch.cat([grid_BCTHW_1, grid_BCTHW_2], dim=2) # [B, 3, T, H, W], [0, 1]
else:
grid_BCTHW = (1.0 + model.decode(sample)).clamp(0, 2) / 2 # [B, 3, T, H, W], [0, 1]
if visualize:
log.info(f"Visualize grid {i}")
visualize_tensor_bcthw(
grid_BCTHW.float(), nrow=5, save_fig_path=os.path.join(save_fig_path, f"loop_{i:02d}_grid.png")
)
grid_np_THWC = (
(grid_BCTHW[0].permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy().astype(np.uint8)
) # THW3, range [0, 255]
# Post-process the output: cut the conditional frames from the output if it's not the first loop
num_cond_frames = compute_num_frames_condition(
model, num_of_latent_overlap_i_plus_1, downsample_factor=model.tokenizer.temporal_compression_factor
)
if i == 0:
            new_grid_np_THWC = grid_np_THWC  # First output, don't cut the conditional frames
else:
new_grid_np_THWC = grid_np_THWC[
num_cond_frames:
] # Remove the conditional frames from the output, since it's overlapped with previous loop
grid_list.append(new_grid_np_THWC)
# Prepare the next loop: re-compute the condition latent
if hasattr(model, "n_views"):
grid_BCTHW = einops.rearrange(grid_BCTHW, "B C (V T) H W -> (B V) C T H W", V=model.n_views)
condition_frame_input = grid_BCTHW[:, :, -num_cond_frames:] * 2 - 1 # BCTHW, range [0, 1] to [-1, 1]
if skip_reencode:
# Use the last num_of_latent_overlap latent token as condition latent
log.info(f"Skip re-encode the condition frames, use the last {num_of_latent_overlap_i_plus_1} latent token")
condition_latent = sample[:, :, -num_of_latent_overlap_i_plus_1:]
else:
# Re-encode the condition frames to get the new condition latent
condition_latent, _ = create_condition_latent_from_input_frames(
model, condition_frame_input, num_frames_condition=num_cond_frames
) # BCTHW
condition_latent = condition_latent.to(torch.bfloat16)
# save videos
if model.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i:
# decode all video together
decode_latent_list = torch.cat(decode_latent_list, dim=2)
grid_BCTHW = (1.0 + model.decode(decode_latent_list)).clamp(0, 2) / 2 # [B, 3, T, H, W], [0, 1]
video_THWC = (
(grid_BCTHW[0].permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy().astype(np.uint8)
) # THW3, range [0, 255]
else:
video_THWC = np.concatenate(grid_list, axis=0) # THW3, range [0, 255]
if return_noise:
return video_THWC, condition_latent_list, sample_latent, noise
return video_THWC, condition_latent_list, sample_latent