# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions needed for long video generation.

* function `generate_video_from_batch_with_loop` is used by `single_gpu_sep20`
"""

import os
from contextlib import contextmanager
from typing import Optional, Tuple, Union

import einops
import numpy as np
import torch
import torchvision
import torchvision.transforms.functional as transforms_F
from matplotlib import pyplot as plt

from cosmos_predict1.diffusion.training.models.extend_model import ExtendDiffusionModel
from cosmos_predict1.utils import log
from cosmos_predict1.utils.easy_io import easy_io


def _ensure_parent_dir(file_path: str) -> None:
    """Create the parent directory of ``file_path`` if the path has one."""
    parent_dir = os.path.dirname(file_path)
    # os.makedirs("") raises, so skip paths that live in the current directory.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)


def _grid_to_uint8_thwc(grid_BCTHW: torch.Tensor) -> np.ndarray:
    """Convert the first sample of a [0, 1] BCTHW grid to a THWC uint8 array."""
    return (grid_BCTHW[0].permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy()


@contextmanager
def switch_config_for_inference(model):
    """Temporarily switch the conditioner config to inference settings.

    Inside the context, ``condition_location`` is forced to ``"first_n"`` unless it is
    already ``"first_n"`` or ``"first_and_last_1"``, and
    ``apply_corruption_to_condition_region`` is mapped ``"gaussian_blur"`` -> ``"clean"``
    and ``"noise_with_sigma"`` -> ``"noise_with_sigma_fixed"`` (other values are kept).
    Both original values are restored on exit, even if the body raises.

    Args:
        model (ExtendDiffusionModel): video generation model
    """
    video_cond_cfg = model.config.conditioner.video_cond_bool
    # Keep the untouched originals so the `finally` block can really restore them.
    original_condition_location = video_cond_cfg.condition_location
    original_corruption = video_cond_cfg.apply_corruption_to_condition_region

    if original_condition_location in ("first_n", "first_and_last_1"):
        inference_condition_location = original_condition_location
    else:
        inference_condition_location = "first_n"

    try:
        log.info(
            "Change the condition_location to 'first_n' for inference, and apply_corruption_to_condition_region to False"
        )
        video_cond_cfg.condition_location = inference_condition_location
        if original_corruption == "gaussian_blur":
            video_cond_cfg.apply_corruption_to_condition_region = "clean"
        elif original_corruption == "noise_with_sigma":
            video_cond_cfg.apply_corruption_to_condition_region = "noise_with_sigma_fixed"
        # Yield control back to the calling context
        yield
    finally:
        log.info(
            f"Restore the original condition_location {original_condition_location}, apply_corruption_to_condition_region {original_corruption}"
        )
        # Re-fetch the config node in case it was replaced while the context was active.
        video_cond_cfg = model.config.conditioner.video_cond_bool
        video_cond_cfg.condition_location = original_condition_location
        video_cond_cfg.apply_corruption_to_condition_region = original_corruption


def visualize_latent_tensor_bcthw(tensor, nrow=1, show_norm=False, save_fig_path=None):
    """Debug function to display a latent tensor as a grid of images.

    The channel-wise mean is shown (and optionally saved); the channel-wise norm is
    shown in a second figure when ``show_norm`` is set.

    Args:
        tensor (torch.Tensor): tensor in shape BCTHW; T must be divisible by nrow
        nrow (int): number of images per row
        show_norm (bool): whether to display the norm of the tensor
        save_fig_path (str): path to save the visualization
    """
    log.info(
        f"display latent tensor shape {tensor.shape}, max={tensor.max()}, min={tensor.min()}, mean={tensor.mean()}, std={tensor.std()}"
    )
    tensor = tensor.float().cpu().detach()
    # Lay frames out as one tall image: rows stack (b, t, h), columns stack (n, w).
    tensor = einops.rearrange(tensor, "b c (t n) h w -> (b t h) (n w) c", n=nrow)
    tensor_mean = tensor.mean(-1)
    tensor_norm = tensor.norm(dim=-1)
    log.info(f"tensor_norm, tensor_mean {tensor_norm.shape}, {tensor_mean.shape}")

    mean_fig = plt.figure(figsize=(20, 20))
    plt.imshow(tensor_mean)
    plt.title(f"mean {tensor_mean.mean()}, std {tensor_mean.std()}")
    if save_fig_path:
        _ensure_parent_dir(save_fig_path)
        log.info(f"save to {os.path.abspath(save_fig_path)}")
        plt.savefig(save_fig_path, bbox_inches="tight", pad_inches=0)
    plt.show()
    # Close explicitly: this is called once per generation loop and figures would pile up.
    plt.close(mean_fig)

    if show_norm:
        norm_fig = plt.figure(figsize=(20, 20))
        plt.imshow(tensor_norm)
        plt.show()
        plt.close(norm_fig)


def visualize_tensor_bcthw(tensor: torch.Tensor, nrow=4, save_fig_path=None):
    """Debug function to display a tensor as a grid of images.

    Args:
        tensor (torch.Tensor): tensor in shape BCTHW
        nrow (int): number of images per row
        save_fig_path (str): path to save the visualization
    """
    log.info(f"display {tensor.shape}, {tensor.max()}, {tensor.min()}")
    assert tensor.max() < 200, f"tensor max {tensor.max()} > 200, the data range is likely wrong"
    tensor = tensor.float().cpu().detach()
    tensor = einops.rearrange(tensor, "b c t h w -> (b t) c h w")
    # use torchvision to arrange the frames as a grid of images
    grid = torchvision.utils.make_grid(tensor, nrow=nrow)
    if save_fig_path:
        _ensure_parent_dir(save_fig_path)
        log.info(f"save to {os.path.abspath(save_fig_path)}")
        # Save the already-built grid so the file honours `nrow` like the displayed figure.
        torchvision.utils.save_image(grid, save_fig_path)
    # display the grid
    fig = plt.figure(figsize=(20, 20))
    plt.imshow(grid.permute(1, 2, 0))
    plt.show()
    plt.close(fig)


def compute_num_frames_condition(model: "ExtendDiffusionModel", num_of_latent_overlap: int, downsample_factor=8) -> int:
    """Compute the number of condition pixel frames for a given number of overlapping latent frames.

    For a causal VAE, every full latent chunk maps to ``pixel_chunk_duration`` pixel
    frames; in a partial chunk the first latent frame maps to one pixel frame and each
    further latent frame to ``downsample_factor`` pixel frames. For a non-causal VAE
    every latent frame maps to ``downsample_factor`` pixel frames.

    Args:
        model (ExtendDiffusionModel): Video generation model
        num_of_latent_overlap (int): Number of latent frames to overlap
        downsample_factor (int): Downsample factor for temporal reduce

    Returns:
        int: Number of condition frames in output space
    """
    # Access the VAE: use tokenizer.video_vae if it exists, otherwise use tokenizer directly
    vae = model.tokenizer.video_vae if hasattr(model.tokenizer, "video_vae") else model.tokenizer

    # "is_casual" (sic) is the attribute name read from the VAE; default to causal if absent.
    if not getattr(vae, "is_casual", True):
        return num_of_latent_overlap * downsample_factor

    full_chunks, remainder = divmod(num_of_latent_overlap, vae.latent_chunk_duration)
    num_frames_condition = full_chunks * vae.pixel_chunk_duration
    if remainder >= 1:
        num_frames_condition += 1 + (remainder - 1) * downsample_factor
    return num_frames_condition


def read_video_or_image_into_frames_BCTHW(
    input_path: str,
    input_path_format: Optional[str] = None,
    H: Optional[int] = None,
    W: Optional[int] = None,
    normalize: bool = True,
    max_frames: int = -1,
    also_return_fps: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, int]]:
    """Read video or image from file and convert it to a BCTHW tensor.

    With ``normalize=True`` the frames are mapped from [0, 255] to roughly [-1, 1]
    (``x / 128 - 1``), converted to bfloat16, optionally resized and moved to CUDA.

    Args:
        input_path (str): path to the input video or image, end with .mp4 or .png or .jpg
        input_path_format (str): file format passed to ``easy_io.load``
        H (int): height to resize the video
        W (int): width to resize the video
        normalize (bool): whether to normalize, convert to tensor, resize and move to CUDA
        max_frames (int): keep at most this many video frames; -1 keeps all
        also_return_fps (bool): whether to also return the fps (0 for images or unknown)

    Returns:
        torch.Tensor: video tensor in shape (1, C, T, H, W), range [-1, 1]; followed by
        the integer fps when ``also_return_fps`` is True
    """
    log.info(f"Reading video from {input_path}")

    loaded_data = easy_io.load(input_path, file_format=input_path_format, backend_args=None)
    if input_path.endswith((".png", ".jpg", ".jpeg")):
        frames = np.array(loaded_data)  # HWC, [0,255]
        # Only a 3-D array can carry an alpha channel; a 2-D array's last axis is width.
        if frames.ndim == 3 and frames.shape[-1] > 3:
            # RGBA: blend onto a white background according to alpha
            rgb_channels = frames[..., :3]
            alpha = (frames[..., 3] / 255.0)[..., None]  # alpha in [0, 1], broadcast over RGB
            white_bg = np.ones_like(rgb_channels) * 255
            frames = (rgb_channels * alpha + white_bg * (1 - alpha)).astype(np.uint8)
        frames = [frames]
        fps = 0
    else:
        frames, meta_data = loaded_data
        # Missing fps metadata falls back to 0, the same value used for still images.
        fps = int(meta_data.get("fps") or 0)
        if max_frames != -1:
            frames = frames[:max_frames]

    input_tensor = np.stack(frames, axis=0)
    input_tensor = einops.rearrange(input_tensor, "t h w c -> t c h w")
    # NOTE(review): the nesting below is reconstructed; tensor conversion and resize
    # are assumed to happen only when normalize is True — confirm against history.
    if normalize:
        input_tensor = input_tensor / 128.0 - 1.0
        input_tensor = torch.from_numpy(input_tensor).bfloat16()  # TCHW
        log.info(f"Raw data shape: {input_tensor.shape}")
        if H is not None and W is not None:
            input_tensor = transforms_F.resize(
                input_tensor,
                size=(H, W),  # type: ignore
                interpolation=transforms_F.InterpolationMode.BICUBIC,
                antialias=True,
            )
    input_tensor = einops.rearrange(input_tensor, "(b t) c h w -> b c t h w", b=1)
    if normalize:
        input_tensor = input_tensor.to("cuda")
    log.info(f"Load shape {input_tensor.shape} value {input_tensor.min()}, {input_tensor.max()}")
    if also_return_fps:
        return input_tensor, fps
    return input_tensor


def create_condition_latent_from_input_frames(
    model: ExtendDiffusionModel,
    input_frames: torch.Tensor,
    num_frames_condition: int = 25,
):
    """Create condition latent for video generation.

    Takes the last ``num_frames_condition`` frames of the input as condition, places them
    at the start of a clip of ``pixel_chunk_duration`` frames padded with zeros, and
    encodes the clip. When ``condition_location`` is ``"first_and_last_1"``, the first
    and last condition frames surround the zero padding and the trailing part is encoded
    separately.

    Args:
        model (ExtendDiffusionModel): Video generation model
        input_frames (torch.Tensor): Video tensor in shape (1,C,T,H,W), range [-1, 1]
        num_frames_condition (int): Number of condition frames

    Returns:
        tuple: (condition latent in shape B,C,T,H,W, the pixel frames that were encoded)
    """
    B, C, T, H, W = input_frames.shape

    # Dynamically access the VAE: use tokenizer.video_vae if it exists, otherwise use tokenizer directly
    vae = model.tokenizer.video_vae if hasattr(model.tokenizer, "video_vae") else model.tokenizer
    num_frames_encode = vae.pixel_chunk_duration  # Access pixel_chunk_duration from the VAE
    log.info(
        f"num_frames_encode not set, set it based on pixel chunk duration and model state shape: {num_frames_encode}"
    )

    log.info(
        f"Create condition latent from input frames {input_frames.shape}, value {input_frames.min()}, {input_frames.max()}, dtype {input_frames.dtype}"
    )

    assert (
        input_frames.shape[2] >= num_frames_condition
    ), f"input_frames not enough for condition, require at least {num_frames_condition}, got {input_frames.shape[2]}, {input_frames.shape}"
    assert (
        num_frames_encode >= num_frames_condition
    ), f"num_frames_encode should be larger than num_frames_condition, got {num_frames_encode}, {num_frames_condition}"

    is_first_and_last = model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1"
    # Slice with an explicit start index: `-0:` would select every frame when
    # num_frames_condition is 0 instead of none.
    tail_start = T - num_frames_condition

    # Put the conditional frames at the beginning of the video, and pad the end with zeros
    if is_first_and_last:
        condition_frames_first = input_frames[:, :, :num_frames_condition]
        condition_frames_last = input_frames[:, :, tail_start:]
        padding_frames = condition_frames_first.new_zeros(B, C, num_frames_encode + 1 - 2 * num_frames_condition, H, W)
        encode_input_frames = torch.cat([condition_frames_first, padding_frames, condition_frames_last], dim=2)
    else:
        condition_frames = input_frames[:, :, tail_start:]
        padding_frames = condition_frames.new_zeros(B, C, num_frames_encode - num_frames_condition, H, W)
        encode_input_frames = torch.cat([condition_frames, padding_frames], dim=2)

    log.info(
        f"create latent with input shape {encode_input_frames.shape} including padding {num_frames_encode - num_frames_condition} at the end"
    )
    if hasattr(model, "n_views"):
        # Fold the views from the batch axis into the time axis before encoding.
        encode_input_frames = einops.rearrange(encode_input_frames, "(B V) C T H W -> B C (V T) H W", V=model.n_views)

    if is_first_and_last:
        # The clip is one chunk plus trailing frame(s); encode the two parts separately.
        latent1 = model.encode(encode_input_frames[:, :, :num_frames_encode])  # BCTHW
        latent2 = model.encode(encode_input_frames[:, :, num_frames_encode:])
        latent = torch.cat([latent1, latent2], dim=2)  # BCTHW
    else:
        latent = model.encode(encode_input_frames)
    return latent, encode_input_frames


def get_condition_latent(
    model: ExtendDiffusionModel,
    conditioned_image_or_video_path: str,
    num_of_latent_condition: int = 4,
    state_shape: Optional[list[int]] = None,
    input_path_format: Optional[str] = None,
    frame_index: int = 0,
    frame_stride: int = 1,
):
    """Load a conditioning image/video from disk and encode it into a condition latent.

    Args:
        model (ExtendDiffusionModel): Video generation model
        conditioned_image_or_video_path (str): path of the conditioning image or video
        num_of_latent_condition (int): number of latent frames used as condition; 0 returns
            an all-zero latent without reading the file
        state_shape (list[int]): latent state shape; defaults to ``model.state_shape``
        input_path_format (str): file format passed to the reader
        frame_index (int): index of the frame pair, used only for "first_and_last_1"
        frame_stride (int): distance between the pair's frames, used only for "first_and_last_1"

    Returns:
        tuple: (condition latent in bfloat16, input frames in BCTHW or None when no
        condition is needed)
    """
    if state_shape is None:
        state_shape = model.state_shape
    if num_of_latent_condition == 0:
        log.info("No condition latent needed, return empty latent")
        # Unpacking accepts list or tuple state shapes; allocate directly on the GPU.
        condition_latent = torch.zeros([1, *state_shape], dtype=torch.bfloat16, device="cuda")
        return condition_latent, None

    # Pixel resolution that corresponds to the latent spatial size.
    H, W = (
        state_shape[-2] * model.vae.spatial_compression_factor,
        state_shape[-1] * model.vae.spatial_compression_factor,
    )
    input_frames = read_video_or_image_into_frames_BCTHW(
        conditioned_image_or_video_path,
        input_path_format=input_path_format,
        H=H,
        W=W,
    )
    if model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1":
        # Keep only the two frames that bracket the segment to generate.
        start_frame = frame_index * frame_stride
        end_frame = (frame_index + 1) * frame_stride
        input_frames = torch.cat(
            [input_frames[:, :, start_frame : start_frame + 1], input_frames[:, :, end_frame : end_frame + 1]], dim=2
        ).contiguous()  # BCTHW

    num_frames_condition = compute_num_frames_condition(
        model, num_of_latent_condition, downsample_factor=model.vae.temporal_compression_factor
    )

    condition_latent, _ = create_condition_latent_from_input_frames(model, input_frames, num_frames_condition)
    condition_latent = condition_latent.to(torch.bfloat16)

    return condition_latent, input_frames


def generate_video_from_batch_with_loop(
    model: ExtendDiffusionModel,
    state_shape: list[int],
    is_negative_prompt: bool,
    data_batch: dict,
    condition_latent: torch.Tensor,
    # hyper-parameters for inference
    num_of_loops: int,
    num_of_latent_overlap_list: list[int],
    guidance: float,
    num_steps: int,
    seed: int,
    add_input_frames_guidance: bool = False,
    augment_sigma_list: Optional[list[float]] = None,
    data_batch_list: Union[None, list[dict]] = None,
    visualize: bool = False,
    save_fig_path: Optional[str] = None,
    skip_reencode: int = 0,
    return_noise: bool = False,
) -> Union[Tuple[np.ndarray, list, list], Tuple[np.ndarray, list, list, torch.Tensor]]:
    """Generate video with loop, given data batch. The condition latent will be updated at each loop.

    Args:
        model (ExtendDiffusionModel)
        state_shape (list): shape of the state tensor
        is_negative_prompt (bool): whether to use negative prompt
        data_batch (dict): data batch for video generation
        condition_latent (torch.Tensor): condition latent in shape BCTHW
        num_of_loops (int): number of loops to generate video
        num_of_latent_overlap_list (list[int]): list number of latent frames to overlap between clips,
            different clips can have different overlap; the last entry is reused when the list is
            shorter than num_of_loops
        guidance (float): The guidance scale to use during sample generation.
        num_steps (int): number of steps for diffusion sampling
        seed (int): random seed for sampling; loop i uses seed + i
        add_input_frames_guidance (bool): whether to add image guidance, must be False (not supported)
        augment_sigma_list (list): list of sigma value for the condition corruption at different clip,
            used when apply_corruption_to_condition_region is "noise_with_sigma" or
            "noise_with_sigma_fixed". default is None, which reads the value from the model config
        data_batch_list (list): list of data batch for video generation, used when num_of_loops >= 1,
            to support multiple prompts in auto-regressive generation. default is None
        visualize (bool): whether to visualize the latent and grid, default is False
        save_fig_path (str): directory to save the visualization, default is None
        skip_reencode (int): whether to skip re-encode the input frames, default is 0
        return_noise (bool): whether to return the initial noise used for sampling, used for ODE pairs
            generation. Default is False

    Returns:
        np.ndarray: generated video in shape THWC, range [0, 255]
        list: list of condition latent, each in shape BCTHW
        list: list of sample latent, each in shape BCTHW
        torch.Tensor: initial noise of the last loop, shape BCTHW (only if return_noise is True)
    """
    if data_batch_list is None:
        data_batch_list = [data_batch for _ in range(num_of_loops)]
    if visualize:
        assert save_fig_path is not None, "save_fig_path should be set when visualize is True"
    assert not add_input_frames_guidance, "add_input_frames_guidance should be False, not supported"

    video_cond_cfg = model.config.conditioner.video_cond_bool
    decode_latents_at_end = video_cond_cfg.sample_tokens_start_from_p_or_i

    # Generate video with loop
    condition_latent_list = []
    decode_latent_list = []  # list collect the latent token to be decoded at the end
    sample_latent = []
    grid_list = []
    noise = None  # stays None if no loop runs

    augment_sigma_list = (
        video_cond_cfg.apply_corruption_to_condition_region_sigma_value
        if augment_sigma_list is None
        else augment_sigma_list
    )

    last_overlap_index = len(num_of_latent_overlap_list) - 1
    for i in range(num_of_loops):
        # Reuse the last overlap value once the list is exhausted (same rule for i and i + 1).
        num_of_latent_overlap_i = num_of_latent_overlap_list[min(i, last_overlap_index)]
        num_of_latent_overlap_i_plus_1 = num_of_latent_overlap_list[min(i + 1, last_overlap_index)]

        if condition_latent.shape[2] < state_shape[1]:
            # Padding condition latent to state shape
            log.info(f"Padding condition latent {condition_latent.shape} to state shape {state_shape}")
            b, c, t, h, w = condition_latent.shape
            condition_latent = torch.cat(
                [
                    condition_latent,
                    condition_latent.new_zeros(b, c, state_shape[1] - t, h, w),
                ],
                dim=2,
            ).contiguous()
            log.info(f"after padding, condition latent shape {condition_latent.shape}")

        log.info(f"Generate video loop {i} / {num_of_loops}")
        if visualize:
            log.info(f"Visualize condition latent {i}")
            visualize_latent_tensor_bcthw(
                condition_latent[:, :, :4].float(),
                nrow=4,
                save_fig_path=os.path.join(save_fig_path, f"loop_{i:02d}_condition_latent_first_4.png"),
            )  # BCTHW

        condition_latent_list.append(condition_latent)

        if i < len(augment_sigma_list):
            condition_video_augment_sigma_in_inference = augment_sigma_list[i]
            log.info(f"condition_video_augment_sigma_in_inference {condition_video_augment_sigma_in_inference}")
        else:
            condition_video_augment_sigma_in_inference = augment_sigma_list[-1]

        sample = model.generate_samples_from_batch(
            data_batch_list[i],
            guidance=guidance,
            state_shape=state_shape,
            num_steps=num_steps,
            is_negative_prompt=is_negative_prompt,
            seed=seed + i,
            condition_latent=condition_latent,
            num_condition_t=num_of_latent_overlap_i,
            condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
            return_noise=return_noise,
        )
        if return_noise:
            sample, noise = sample

        if visualize:
            log.info(f"Visualize sampled latent {i} 4-8 frames")
            visualize_latent_tensor_bcthw(
                sample[:, :, 4:8].float(),
                nrow=4,
                save_fig_path=os.path.join(save_fig_path, f"loop_{i:02d}_sample_latent_last_4.png"),
            )  # BCTHW

            diff_between_sample_and_condition = (sample - condition_latent)[:, :, :num_of_latent_overlap_i]
            log.info(
                f"Visualize diff between sample and condition latent {i} first 4 frames {diff_between_sample_and_condition.mean()}"
            )

        sample_latent.append(sample)
        T = condition_latent.shape[2]
        assert num_of_latent_overlap_i <= T, f"num_of_latent_overlap should be < T, get {num_of_latent_overlap_i}, {T}"

        if decode_latents_at_end:
            assert skip_reencode, "skip_reencode should be turned on when sample_tokens_start_from_p_or_i is True"
            # Collect latents and decode once at the end; drop the overlap after the first clip.
            if i == 0:
                decode_latent_list.append(sample)
            else:
                decode_latent_list.append(sample[:, :, num_of_latent_overlap_i:])
        else:
            # Interpolator mode. Decode the first and last as an image.
            if video_cond_cfg.condition_location == "first_and_last_1":
                grid_BCTHW_1 = (1.0 + model.decode(sample[:, :, :-1, ...])).clamp(0, 2) / 2  # [B, 3, T-1, H, W], [0, 1]
                grid_BCTHW_2 = (1.0 + model.decode(sample[:, :, -1:, ...])).clamp(0, 2) / 2  # [B, 3, 1, H, W], [0, 1]
                grid_BCTHW = torch.cat([grid_BCTHW_1, grid_BCTHW_2], dim=2)  # [B, 3, T, H, W], [0, 1]
            else:
                grid_BCTHW = (1.0 + model.decode(sample)).clamp(0, 2) / 2  # [B, 3, T, H, W], [0, 1]

            if visualize:
                log.info(f"Visualize grid {i}")
                visualize_tensor_bcthw(
                    grid_BCTHW.float(), nrow=5, save_fig_path=os.path.join(save_fig_path, f"loop_{i:02d}_grid.png")
                )
            grid_np_THWC = _grid_to_uint8_thwc(grid_BCTHW)  # THW3, range [0, 255]

            # Post-process the output: cut the conditional frames from the output if it's not the first loop.
            # The frames overlapping the previous clip are the ones this clip was conditioned on,
            # so their count follows the CURRENT overlap (same rule as the latent path above).
            if i == 0:
                new_grid_np_THWC = grid_np_THWC  # First output, dont cut the conditional frames
            else:
                num_overlap_frames = compute_num_frames_condition(
                    model, num_of_latent_overlap_i, downsample_factor=model.tokenizer.temporal_compression_factor
                )
                new_grid_np_THWC = grid_np_THWC[num_overlap_frames:]
            grid_list.append(new_grid_np_THWC)

            # Prepare the next loop: the condition frames follow the NEXT clip's overlap
            num_cond_frames = compute_num_frames_condition(
                model, num_of_latent_overlap_i_plus_1, downsample_factor=model.tokenizer.temporal_compression_factor
            )
            if hasattr(model, "n_views"):
                grid_BCTHW = einops.rearrange(grid_BCTHW, "B C (V T) H W -> (B V) C T H W", V=model.n_views)
            condition_frame_input = grid_BCTHW[:, :, -num_cond_frames:] * 2 - 1  # BCTHW, range [0, 1] to [-1, 1]

        if skip_reencode:
            # Use the last num_of_latent_overlap latent token as condition latent
            log.info(f"Skip re-encode the condition frames, use the last {num_of_latent_overlap_i_plus_1} latent token")
            condition_latent = sample[:, :, -num_of_latent_overlap_i_plus_1:]
        else:
            # Re-encode the condition frames to get the new condition latent
            condition_latent, _ = create_condition_latent_from_input_frames(
                model, condition_frame_input, num_frames_condition=num_cond_frames
            )  # BCTHW
            condition_latent = condition_latent.to(torch.bfloat16)

    # save videos
    if decode_latents_at_end:
        # decode all video together
        decode_latent_list = torch.cat(decode_latent_list, dim=2)
        grid_BCTHW = (1.0 + model.decode(decode_latent_list)).clamp(0, 2) / 2  # [B, 3, T, H, W], [0, 1]
        video_THWC = _grid_to_uint8_thwc(grid_BCTHW)  # THW3, range [0, 255]
    else:
        video_THWC = np.concatenate(grid_list, axis=0)  # THW3, range [0, 255]

    if return_noise:
        return video_THWC, condition_latent_list, sample_latent, noise
    return video_THWC, condition_latent_list, sample_latent