# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions for the inference libraries."""

import os
from glob import glob
from typing import Any

import mediapy as media
import numpy as np
import torch

from cosmos_predict1.tokenizer.networks import TokenizerModels

_DTYPE, _DEVICE = torch.bfloat16, "cuda"
_UINT8_MAX_F = float(torch.iinfo(torch.uint8).max)
_SPATIAL_ALIGN = 16
_TEMPORAL_ALIGN = 8


def load_model(
    jit_filepath: str | None = None,
    tokenizer_config: dict[str, Any] | None = None,
    device: str = "cuda",
) -> torch.nn.Module | torch.jit.ScriptModule:
    """Loads a full tokenizer model (encoder and decoder) from a filepath.

    Args:
        jit_filepath: The filepath to the JIT-compiled model.
        tokenizer_config: The tokenizer configuration. If None, the JIT model
            is returned as-is; otherwise a PyTorch model is built from the
            config and initialized with the JIT checkpoint's weights.
        device: The device to load the model onto, default=cuda.

    Returns:
        The model loaded onto `device`, in eval mode.
    """
    if tokenizer_config is None:
        return load_jit_model(jit_filepath, device)
    full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device)
    full_model.load_state_dict(ckpts.state_dict(), strict=True)
    return full_model.eval().to(device)


def load_encoder_model(
    jit_filepath: str | None = None,
    tokenizer_config: dict[str, Any] | None = None,
    device: str = "cuda",
) -> torch.nn.Module | torch.jit.ScriptModule:
    """Loads the encoder of a tokenizer model from a filepath.

    Args:
        jit_filepath: The filepath to the JIT-compiled model.
        tokenizer_config: The tokenizer configuration. If None, the JIT model
            is returned as-is.
        device: The device to load the model onto, default=cuda.

    Returns:
        The encoder model loaded onto `device`, in eval mode.
    """
    if tokenizer_config is None:
        return load_jit_model(jit_filepath, device)
    full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device)
    encoder_model = full_model.encoder_jit()
    encoder_model.load_state_dict(ckpts.state_dict(), strict=True)
    return encoder_model.eval().to(device)


def load_decoder_model(
    jit_filepath: str | None = None,
    tokenizer_config: dict[str, Any] | None = None,
    device: str = "cuda",
) -> torch.nn.Module | torch.jit.ScriptModule:
    """Loads the decoder of a tokenizer model from a filepath.

    Args:
        jit_filepath: The filepath to the JIT-compiled model.
        tokenizer_config: The tokenizer configuration. If None, the JIT model
            is returned as-is.
        device: The device to load the model onto, default=cuda.

    Returns:
        The decoder model loaded onto `device`, in eval mode.
    """
    if tokenizer_config is None:
        return load_jit_model(jit_filepath, device)
    full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device)
    decoder_model = full_model.decoder_jit()
    decoder_model.load_state_dict(ckpts.state_dict(), strict=True)
    return decoder_model.eval().to(device)
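
# A minimal usage sketch for the loaders above. This helper is hypothetical
# (not part of the original API), and its argument values are placeholders.
def _example_load_encoder_decoder(
    jit_filepath: str, tokenizer_config: dict[str, Any]
) -> tuple[torch.nn.Module, torch.nn.Module]:
    """Loads the encoder and decoder halves of one tokenizer checkpoint.

    Assumes `jit_filepath` points to a full JIT autoencoder checkpoint and
    `tokenizer_config` is its matching config dict (including a valid "name"
    key). Splitting the model this way lets tokenization and detokenization
    run at different stages of a pipeline.
    """
    encoder = load_encoder_model(jit_filepath, tokenizer_config)
    decoder = load_decoder_model(jit_filepath, tokenizer_config)
    return encoder, decoder
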
""" tokenizer_name = tokenizer_config["name"] model = TokenizerModels[tokenizer_name].value(**tokenizer_config) ckpts = torch.jit.load(jit_filepath, map_location=device) return model, ckpts def load_jit_model(jit_filepath: str = None, device: str = "cuda") -> torch.jit.ScriptModule: """Loads a torch.jit.ScriptModule from a filepath. Args: jit_filepath: The filepath to the JIT-compiled model. device: The device to load the model onto, default=cuda. Returns: The JIT compiled model loaded to device and on eval mode. """ model = torch.jit.load(jit_filepath, map_location=device) return model.eval().to(device) def save_jit_model( model: torch.jit.ScriptModule | torch.jit.RecursiveScriptModule = None, jit_filepath: str = None, ) -> None: """Saves a torch.jit.ScriptModule or torch.jit.RecursiveScriptModule to file. Args: model: JIT compiled model loaded onto `config.checkpoint.jit.device`. jit_filepath: The filepath to the JIT-compiled model. """ torch.jit.save(model, jit_filepath) def get_filepaths(input_pattern) -> list[str]: """Returns a list of filepaths from a pattern.""" filepaths = sorted(glob(str(input_pattern))) return list(set(filepaths)) def get_output_filepath(filepath: str, output_dir: str = None) -> str: """Returns the output filepath for the given input filepath.""" output_dir = output_dir or f"{os.path.dirname(filepath)}/reconstructions" output_filepath = f"{output_dir}/{os.path.basename(filepath)}" os.makedirs(output_dir, exist_ok=True) return output_filepath def read_image(filepath: str) -> np.ndarray: """Reads an image from a filepath. Args: filepath: The filepath to the image. Returns: The image as a numpy array, layout HxWxC, range [0..255], uint8 dtype. """ image = media.read_image(filepath) # convert the grey scale image to RGB # since our tokenizers always assume 3-channel RGB image if image.ndim == 2: image = np.stack([image] * 3, axis=-1) # convert RGBA to RGB if image.shape[-1] == 4: image = image[..., :3] return image def read_video(filepath: str) -> np.ndarray: """Reads a video from a filepath. Args: filepath: The filepath to the video. Returns: The video as a numpy array, layout TxHxWxC, range [0..255], uint8 dtype. """ video = media.read_video(filepath) # convert the grey scale frame to RGB # since our tokenizers always assume 3-channel video if video.ndim == 3: video = np.stack([video] * 3, axis=-1) # convert RGBA to RGB if video.shape[-1] == 4: video = video[..., :3] return video def resize_image(image: np.ndarray, short_size: int = None) -> np.ndarray: """Resizes an image to have the short side of `short_size`. Args: image: The image to resize, layout HxWxC, of any range. short_size: The size of the short side. Returns: The resized image. """ if short_size is None: return image height, width = image.shape[-3:-1] if height <= width: height_new, width_new = short_size, int(width * short_size / height + 0.5) width_new = width_new if width_new % 2 == 0 else width_new + 1 else: height_new, width_new = ( int(height * short_size / width + 0.5), short_size, ) height_new = height_new if height_new % 2 == 0 else height_new + 1 return media.resize_image(image, shape=(height_new, width_new)) def resize_video(video: np.ndarray, short_size: int = None) -> np.ndarray: """Resizes a video to have the short side of `short_size`. Args: video: The video to resize, layout TxHxWxC, of any range. short_size: The size of the short side. Returns: The resized video. 
""" if short_size is None: return video height, width = video.shape[-3:-1] if height <= width: height_new, width_new = short_size, int(width * short_size / height + 0.5) width_new = width_new if width_new % 2 == 0 else width_new + 1 else: height_new, width_new = ( int(height * short_size / width + 0.5), short_size, ) height_new = height_new if height_new % 2 == 0 else height_new + 1 return media.resize_video(video, shape=(height_new, width_new)) def write_image(filepath: str, image: np.ndarray): """Writes an image to a filepath.""" return media.write_image(filepath, image) def write_video(filepath: str, video: np.ndarray, fps: int = 24) -> None: """Writes a video to a filepath.""" return media.write_video(filepath, video, fps=fps) def numpy2tensor( input_image: np.ndarray, dtype: torch.dtype = _DTYPE, device: str = _DEVICE, range_min: int = -1, ) -> torch.Tensor: """Converts image(dtype=np.uint8) to `dtype` in range [0..255]. Args: input_image: A batch of images in range [0..255], BxHxWx3 layout. Returns: A torch.Tensor of layout Bx3xHxW in range [-1..1], dtype. """ ndim = input_image.ndim indices = list(range(1, ndim))[-1:] + list(range(1, ndim))[:-1] image = input_image.transpose((0,) + tuple(indices)) / _UINT8_MAX_F if range_min == -1: image = 2.0 * image - 1.0 return torch.from_numpy(image).to(dtype).to(device) def tensor2numpy(input_tensor: torch.Tensor, range_min: int = -1) -> np.ndarray: """Converts tensor in [-1,1] to image(dtype=np.uint8) in range [0..255]. Args: input_tensor: Input image tensor of Bx3xHxW layout, range [-1..1]. Returns: A numpy image of layout BxHxWx3, range [0..255], uint8 dtype. """ if range_min == -1: input_tensor = (input_tensor.float() + 1.0) / 2.0 ndim = input_tensor.ndim output_image = input_tensor.clamp(0, 1).cpu().numpy() output_image = output_image.transpose((0,) + tuple(range(2, ndim)) + (1,)) return (output_image * _UINT8_MAX_F + 0.5).astype(np.uint8) def pad_image_batch(batch: np.ndarray, spatial_align: int = _SPATIAL_ALIGN) -> tuple[np.ndarray, list[int]]: """Pads a batch of images to be divisible by `spatial_align`. Args: batch: The batch of images to pad, layout BxHxWx3, in any range. align: The alignment to pad to. Returns: The padded batch and the crop region. """ height, width = batch.shape[1:3] align = spatial_align height_to_pad = (align - height % align) if height % align != 0 else 0 width_to_pad = (align - width % align) if width % align != 0 else 0 crop_region = [ height_to_pad >> 1, width_to_pad >> 1, height + (height_to_pad >> 1), width + (width_to_pad >> 1), ] batch = np.pad( batch, ( (0, 0), (height_to_pad >> 1, height_to_pad - (height_to_pad >> 1)), (width_to_pad >> 1, width_to_pad - (width_to_pad >> 1)), (0, 0), ), mode="constant", ) return batch, crop_region def pad_video_batch( batch: np.ndarray, temporal_align: int = _TEMPORAL_ALIGN, spatial_align: int = _SPATIAL_ALIGN, ) -> tuple[np.ndarray, list[int]]: """Pads a batch of videos to be divisible by `temporal_align` or `spatial_align`. Zero pad spatially. Reflection pad temporally to handle causality better. Args: batch: The batch of videos to pad., layout BxFxHxWx3, in any range. align: The alignment to pad to. Returns: The padded batch and the crop region. 
""" num_frames, height, width = batch.shape[-4:-1] align = spatial_align height_to_pad = (align - height % align) if height % align != 0 else 0 width_to_pad = (align - width % align) if width % align != 0 else 0 align = temporal_align frames_to_pad = (align - (num_frames - 1) % align) if (num_frames - 1) % align != 0 else 0 crop_region = [ frames_to_pad >> 1, height_to_pad >> 1, width_to_pad >> 1, num_frames + (frames_to_pad >> 1), height + (height_to_pad >> 1), width + (width_to_pad >> 1), ] batch = np.pad( batch, ( (0, 0), (0, 0), (height_to_pad >> 1, height_to_pad - (height_to_pad >> 1)), (width_to_pad >> 1, width_to_pad - (width_to_pad >> 1)), (0, 0), ), mode="constant", ) batch = np.pad( batch, ( (0, 0), (frames_to_pad >> 1, frames_to_pad - (frames_to_pad >> 1)), (0, 0), (0, 0), (0, 0), ), mode="edge", ) return batch, crop_region def unpad_video_batch(batch: np.ndarray, crop_region: list[int]) -> np.ndarray: """Unpads video with `crop_region`. Args: batch: A batch of numpy videos, layout BxFxHxWxC. crop_region: [f1,y1,x1,f2,y2,x2] first, top, left, last, bot, right crop indices. Returns: np.ndarray: Cropped numpy video, layout BxFxHxWxC. """ assert len(crop_region) == 6, "crop_region should be len of 6." f1, y1, x1, f2, y2, x2 = crop_region return batch[..., f1:f2, y1:y2, x1:x2, :] def unpad_image_batch(batch: np.ndarray, crop_region: list[int]) -> np.ndarray: """Unpads image with `crop_region`. Args: batch: A batch of numpy images, layout BxHxWxC. crop_region: [y1,x1,y2,x2] top, left, bot, right crop indices. Returns: np.ndarray: Cropped numpy image, layout BxHxWxC. """ assert len(crop_region) == 4, "crop_region should be len of 4." y1, x1, y2, x2 = crop_region return batch[..., y1:y2, x1:x2, :]