# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for the inference libraries."""
import os
from glob import glob
from typing import Any

import mediapy as media
import numpy as np
import torch

from cosmos_predict1.tokenizer.networks import TokenizerModels

_DTYPE, _DEVICE = torch.bfloat16, "cuda"
_UINT8_MAX_F = float(torch.iinfo(torch.uint8).max)
_SPATIAL_ALIGN = 16
_TEMPORAL_ALIGN = 8


def load_model(
    jit_filepath: str | None = None,
    tokenizer_config: dict[str, Any] | None = None,
    device: str = "cuda",
) -> torch.nn.Module | torch.jit.ScriptModule:
    """Loads a tokenizer model from a filepath.

    Args:
        jit_filepath: The filepath to the JIT-compiled model checkpoint.
        tokenizer_config: The tokenizer configuration. If None, the JIT model is
            returned as-is; otherwise a PyTorch model is built from the config
            and the checkpoint's weights are loaded into it.
        device: The device to load the model onto, default="cuda".

    Returns:
        The model, loaded onto `device` and set to eval mode.
    """
    if tokenizer_config is None:
        return load_jit_model(jit_filepath, device)
    full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device)
    full_model.load_state_dict(ckpts.state_dict(), strict=True)
    return full_model.eval().to(device)


def load_encoder_model(
    jit_filepath: str | None = None,
    tokenizer_config: dict[str, Any] | None = None,
    device: str = "cuda",
) -> torch.nn.Module | torch.jit.ScriptModule:
    """Loads the encoder of a tokenizer model from a filepath.

    Args:
        jit_filepath: The filepath to the JIT-compiled model checkpoint.
        tokenizer_config: The tokenizer configuration. If None, the JIT model is
            returned as-is; otherwise only the encoder submodule is built and
            its weights are loaded.
        device: The device to load the model onto, default="cuda".

    Returns:
        The encoder model, loaded onto `device` and set to eval mode.
    """
    if tokenizer_config is None:
        return load_jit_model(jit_filepath, device)
    full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device)
    encoder_model = full_model.encoder_jit()
    encoder_model.load_state_dict(ckpts.state_dict(), strict=True)
    return encoder_model.eval().to(device)


def load_decoder_model(
    jit_filepath: str | None = None,
    tokenizer_config: dict[str, Any] | None = None,
    device: str = "cuda",
) -> torch.nn.Module | torch.jit.ScriptModule:
    """Loads the decoder of a tokenizer model from a filepath.

    Args:
        jit_filepath: The filepath to the JIT-compiled model checkpoint.
        tokenizer_config: The tokenizer configuration. If None, the JIT model is
            returned as-is; otherwise only the decoder submodule is built and
            its weights are loaded.
        device: The device to load the model onto, default="cuda".

    Returns:
        The decoder model, loaded onto `device` and set to eval mode.
    """
    if tokenizer_config is None:
        return load_jit_model(jit_filepath, device)
    full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device)
    decoder_model = full_model.decoder_jit()
    decoder_model.load_state_dict(ckpts.state_dict(), strict=True)
    return decoder_model.eval().to(device)
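
# A minimal usage sketch of the loaders above (illustrative only; the checkpoint
# paths and the "CV8x8x8" tokenizer name are assumptions, not values defined in
# this file):
#
#     tokenizer_config = {"name": "CV8x8x8"}  # hypothetical name/config
#     encoder = load_encoder_model("checkpoints/encoder.jit", tokenizer_config)
#     decoder = load_decoder_model("checkpoints/decoder.jit", tokenizer_config)
#     latent = encoder(input_tensor)   # output structure depends on the tokenizer
#     reconstruction = decoder(latent)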


def _load_pytorch_model(
    jit_filepath: str | None = None,
    tokenizer_config: dict[str, Any] | None = None,
    device: str = "cuda",
) -> tuple[torch.nn.Module, torch.jit.ScriptModule]:
    """Builds a PyTorch model from config and loads its JIT checkpoint.

    Args:
        jit_filepath: The filepath to the JIT-compiled model checkpoint.
        tokenizer_config: The tokenizer configuration; its "name" key selects
            the model class from `TokenizerModels`.
        device: The device to load the checkpoint onto, default="cuda".

    Returns:
        A tuple of the freshly built model and the loaded JIT checkpoint.
    """
    tokenizer_name = tokenizer_config["name"]
    model = TokenizerModels[tokenizer_name].value(**tokenizer_config)
    ckpts = torch.jit.load(jit_filepath, map_location=device)
    return model, ckpts


def load_jit_model(jit_filepath: str | None = None, device: str = "cuda") -> torch.jit.ScriptModule:
    """Loads a torch.jit.ScriptModule from a filepath.

    Args:
        jit_filepath: The filepath to the JIT-compiled model.
        device: The device to load the model onto, default="cuda".

    Returns:
        The JIT-compiled model, loaded onto `device` and set to eval mode.
    """
    model = torch.jit.load(jit_filepath, map_location=device)
    return model.eval().to(device)


def save_jit_model(
    model: torch.jit.ScriptModule | torch.jit.RecursiveScriptModule = None,
    jit_filepath: str = None,
) -> None:
    """Saves a torch.jit.ScriptModule or torch.jit.RecursiveScriptModule to file.

    Args:
        model: JIT-compiled model loaded onto `config.checkpoint.jit.device`.
        jit_filepath: The filepath to which the model will be saved.
    """
    torch.jit.save(model, jit_filepath)


def get_filepaths(input_pattern: str) -> list[str]:
    """Returns a sorted list of unique filepaths matching a glob pattern."""
    filepaths = glob(str(input_pattern))
    # Deduplicate first, then sort, so the returned order is deterministic.
    return sorted(set(filepaths))


def get_output_filepath(filepath: str, output_dir: str = None) -> str:
    """Returns the output filepath for the given input filepath."""
    output_dir = output_dir or f"{os.path.dirname(filepath)}/reconstructions"
    output_filepath = f"{output_dir}/{os.path.basename(filepath)}"
    os.makedirs(output_dir, exist_ok=True)
    return output_filepath
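
# For example (illustrative paths, not part of this file):
#
#     get_filepaths("data/*.png")
#     # -> ["data/img_000.png", "data/img_001.png", ...]
#     get_output_filepath("data/img_000.png")
#     # -> "data/reconstructions/img_000.png" (directory is created if missing)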


def read_image(filepath: str) -> np.ndarray:
    """Reads an image from a filepath.

    Args:
        filepath: The filepath to the image.

    Returns:
        The image as a numpy array, layout HxWxC, range [0..255], uint8 dtype.
    """
    image = media.read_image(filepath)
    # Convert a grayscale image to RGB,
    # since the tokenizers always assume a 3-channel RGB image.
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    # Convert RGBA to RGB by dropping the alpha channel.
    if image.shape[-1] == 4:
        image = image[..., :3]
    return image


def read_video(filepath: str) -> np.ndarray:
    """Reads a video from a filepath.

    Args:
        filepath: The filepath to the video.

    Returns:
        The video as a numpy array, layout TxHxWxC, range [0..255], uint8 dtype.
    """
    video = media.read_video(filepath)
    # Convert a grayscale video to RGB,
    # since the tokenizers always assume a 3-channel video.
    if video.ndim == 3:
        video = np.stack([video] * 3, axis=-1)
    # Convert RGBA to RGB by dropping the alpha channel.
    if video.shape[-1] == 4:
        video = video[..., :3]
    return video
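
# A minimal reading sketch (hypothetical filenames):
#
#     image = read_image("frame.png")   # HxWx3 uint8, even if the source
#                                       # file was grayscale or RGBA
#     video = read_video("clip.mp4")    # TxHxWx3 uint8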


def resize_image(image: np.ndarray, short_size: int = None) -> np.ndarray:
    """Resizes an image so that its short side equals `short_size`.

    The long side is scaled proportionally and rounded up to an even number.

    Args:
        image: The image to resize, layout HxWxC, of any range.
        short_size: The target size of the short side. If None, the image is
            returned unchanged.

    Returns:
        The resized image.
    """
    if short_size is None:
        return image
    height, width = image.shape[-3:-1]
    if height <= width:
        height_new, width_new = short_size, int(width * short_size / height + 0.5)
        width_new = width_new if width_new % 2 == 0 else width_new + 1
    else:
        height_new, width_new = (
            int(height * short_size / width + 0.5),
            short_size,
        )
        height_new = height_new if height_new % 2 == 0 else height_new + 1
    return media.resize_image(image, shape=(height_new, width_new))


def resize_video(video: np.ndarray, short_size: int = None) -> np.ndarray:
    """Resizes a video so that its short side equals `short_size`.

    The long side is scaled proportionally and rounded up to an even number.

    Args:
        video: The video to resize, layout TxHxWxC, of any range.
        short_size: The target size of the short side. If None, the video is
            returned unchanged.

    Returns:
        The resized video.
    """
    if short_size is None:
        return video
    height, width = video.shape[-3:-1]
    if height <= width:
        height_new, width_new = short_size, int(width * short_size / height + 0.5)
        width_new = width_new if width_new % 2 == 0 else width_new + 1
    else:
        height_new, width_new = (
            int(height * short_size / width + 0.5),
            short_size,
        )
        height_new = height_new if height_new % 2 == 0 else height_new + 1
    return media.resize_video(video, shape=(height_new, width_new))
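
# Worked example (illustrative): a 720x1280 video resized with short_size=512
# keeps the aspect ratio, so height_new = 512 and
# width_new = int(1280 * 512 / 720 + 0.5) = 910, which is already even:
#
#     video = np.zeros((8, 720, 1280, 3), dtype=np.uint8)
#     resize_video(video, short_size=512).shape   # -> (8, 512, 910, 3)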


def write_image(filepath: str, image: np.ndarray):
    """Writes an image to a filepath."""
    return media.write_image(filepath, image)


def write_video(filepath: str, video: np.ndarray, fps: int = 24) -> None:
    """Writes a video to a filepath."""
    return media.write_video(filepath, video, fps=fps)


def numpy2tensor(
    input_image: np.ndarray,
    dtype: torch.dtype = _DTYPE,
    device: str = _DEVICE,
    range_min: int = -1,
) -> torch.Tensor:
    """Converts a uint8 array in [0..255] to a `dtype` tensor in [-1..1] or [0..1].

    Args:
        input_image: A batch of images in range [0..255], BxHxWx3 layout
            (or BxTxHxWx3 for videos).
        dtype: The dtype of the output tensor.
        device: The device to place the output tensor on.
        range_min: The lower bound of the output range, -1 (default) or 0.

    Returns:
        A torch.Tensor of layout Bx3xHxW (or Bx3xTxHxW), in the requested range.
    """
    ndim = input_image.ndim
    # Move the channel axis to just after the batch axis: BxHxWxC -> BxCxHxW.
    indices = list(range(1, ndim))[-1:] + list(range(1, ndim))[:-1]
    image = input_image.transpose((0,) + tuple(indices)) / _UINT8_MAX_F
    if range_min == -1:
        image = 2.0 * image - 1.0
    return torch.from_numpy(image).to(dtype).to(device)


def tensor2numpy(input_tensor: torch.Tensor, range_min: int = -1) -> np.ndarray:
    """Converts a tensor in [-1..1] (or [0..1]) to a uint8 array in [0..255].

    Args:
        input_tensor: Input image tensor of Bx3xHxW layout, range [-1..1]
            (or [0..1] if range_min == 0).
        range_min: The lower bound of the input range, -1 (default) or 0.

    Returns:
        A numpy image of layout BxHxWx3, range [0..255], uint8 dtype.
    """
    if range_min == -1:
        input_tensor = (input_tensor.float() + 1.0) / 2.0
    ndim = input_tensor.ndim
    output_image = input_tensor.clamp(0, 1).cpu().numpy()
    # Move the channel axis back to last: BxCxHxW -> BxHxWxC.
    output_image = output_image.transpose((0,) + tuple(range(2, ndim)) + (1,))
    return (output_image * _UINT8_MAX_F + 0.5).astype(np.uint8)
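
# A minimal round-trip sketch (illustrative; device="cpu" avoids the default
# CUDA placement so it runs without a GPU):
#
#     frames = np.random.randint(0, 256, size=(1, 32, 32, 3), dtype=np.uint8)
#     tensor = numpy2tensor(frames, dtype=torch.float32, device="cpu")
#     tensor.shape                        # -> torch.Size([1, 3, 32, 32]), in [-1..1]
#     frames_back = tensor2numpy(tensor)  # -> (1, 32, 32, 3) uint8, in [0..255]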


def pad_image_batch(batch: np.ndarray, spatial_align: int = _SPATIAL_ALIGN) -> tuple[np.ndarray, list[int]]:
    """Pads a batch of images so height and width are divisible by `spatial_align`.

    Args:
        batch: The batch of images to pad, layout BxHxWx3, in any range.
        spatial_align: The spatial alignment to pad to.

    Returns:
        The zero-padded batch and the crop region [y1, x1, y2, x2] that
        recovers the original images.
    """
    height, width = batch.shape[1:3]
    align = spatial_align
    height_to_pad = (align - height % align) if height % align != 0 else 0
    width_to_pad = (align - width % align) if width % align != 0 else 0
    # Split the padding (near-)evenly between the two sides of each axis.
    crop_region = [
        height_to_pad >> 1,
        width_to_pad >> 1,
        height + (height_to_pad >> 1),
        width + (width_to_pad >> 1),
    ]
    batch = np.pad(
        batch,
        (
            (0, 0),
            (height_to_pad >> 1, height_to_pad - (height_to_pad >> 1)),
            (width_to_pad >> 1, width_to_pad - (width_to_pad >> 1)),
            (0, 0),
        ),
        mode="constant",
    )
    return batch, crop_region
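
# Worked example (illustrative): a 100x100 batch padded to the default 16-pixel
# alignment grows to 112x112, with 6 pixels added before and after on each axis:
#
#     batch = np.zeros((1, 100, 100, 3), dtype=np.uint8)
#     padded, crop_region = pad_image_batch(batch)
#     padded.shape   # -> (1, 112, 112, 3)
#     crop_region    # -> [6, 6, 106, 106]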


def pad_video_batch(
    batch: np.ndarray,
    temporal_align: int = _TEMPORAL_ALIGN,
    spatial_align: int = _SPATIAL_ALIGN,
) -> tuple[np.ndarray, list[int]]:
    """Pads a batch of videos to satisfy the temporal and spatial alignments.

    Spatially, height and width are zero-padded to be divisible by
    `spatial_align`. Temporally, frames are edge-padded (replicating the
    boundary frames, which suits causal models better than zeros) so that the
    number of frames minus one is divisible by `temporal_align`.

    Args:
        batch: The batch of videos to pad, layout BxFxHxWx3, in any range.
        temporal_align: The temporal alignment to pad to.
        spatial_align: The spatial alignment to pad to.

    Returns:
        The padded batch and the crop region [f1, y1, x1, f2, y2, x2] that
        recovers the original videos.
    """
    num_frames, height, width = batch.shape[-4:-1]
    align = spatial_align
    height_to_pad = (align - height % align) if height % align != 0 else 0
    width_to_pad = (align - width % align) if width % align != 0 else 0
    align = temporal_align
    frames_to_pad = (align - (num_frames - 1) % align) if (num_frames - 1) % align != 0 else 0
    crop_region = [
        frames_to_pad >> 1,
        height_to_pad >> 1,
        width_to_pad >> 1,
        num_frames + (frames_to_pad >> 1),
        height + (height_to_pad >> 1),
        width + (width_to_pad >> 1),
    ]
    # Zero-pad the spatial dimensions first.
    batch = np.pad(
        batch,
        (
            (0, 0),
            (0, 0),
            (height_to_pad >> 1, height_to_pad - (height_to_pad >> 1)),
            (width_to_pad >> 1, width_to_pad - (width_to_pad >> 1)),
            (0, 0),
        ),
        mode="constant",
    )
    # Then edge-pad the temporal dimension by replicating boundary frames.
    batch = np.pad(
        batch,
        (
            (0, 0),
            (frames_to_pad >> 1, frames_to_pad - (frames_to_pad >> 1)),
            (0, 0),
            (0, 0),
            (0, 0),
        ),
        mode="edge",
    )
    return batch, crop_region
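
# Worked example (illustrative): 20 frames at 100x100 with the default
# alignments (temporal 8, spatial 16). Since (20 - 1) % 8 == 3, five frames are
# added (two before, three after), and the spatial padding matches the image case:
#
#     batch = np.zeros((1, 20, 100, 100, 3), dtype=np.uint8)
#     padded, crop_region = pad_video_batch(batch)
#     padded.shape   # -> (1, 25, 112, 112, 3)
#     crop_region    # -> [2, 6, 6, 22, 106, 106]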


def unpad_video_batch(batch: np.ndarray, crop_region: list[int]) -> np.ndarray:
    """Unpads a video batch using `crop_region`.

    Args:
        batch: A batch of numpy videos, layout BxFxHxWxC.
        crop_region: [f1, y1, x1, f2, y2, x2] first, top, left, last, bottom,
            right crop indices.

    Returns:
        np.ndarray: Cropped numpy video, layout BxFxHxWxC.
    """
    assert len(crop_region) == 6, "crop_region should be len of 6."
    f1, y1, x1, f2, y2, x2 = crop_region
    return batch[..., f1:f2, y1:y2, x1:x2, :]


def unpad_image_batch(batch: np.ndarray, crop_region: list[int]) -> np.ndarray:
    """Unpads an image batch using `crop_region`.

    Args:
        batch: A batch of numpy images, layout BxHxWxC.
        crop_region: [y1, x1, y2, x2] top, left, bottom, right crop indices.

    Returns:
        np.ndarray: Cropped numpy image, layout BxHxWxC.
    """
    assert len(crop_region) == 4, "crop_region should be len of 4."
    y1, x1, y2, x2 = crop_region
    return batch[..., y1:y2, x1:x2, :]
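

if __name__ == "__main__":
    # A minimal self-check sketch (illustrative only): pad a random video batch,
    # then verify that unpadding with the returned crop region restores it.
    demo_batch = np.random.randint(0, 256, size=(1, 20, 100, 100, 3), dtype=np.uint8)
    padded, crop_region = pad_video_batch(demo_batch)
    assert padded.shape == (1, 25, 112, 112, 3)
    restored = unpad_video_batch(padded, crop_region)
    assert np.array_equal(restored, demo_batch)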