# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import IO, Any, Union import cv2 import numpy as np import torch from einops import rearrange from PIL import Image as PILImage from torch import Tensor from cosmos_predict1.utils import log from cosmos_predict1.utils.easy_io import easy_io try: import ffmpegcv except Exception as e: # ImportError cannot catch all problems log.info(e) ffmpegcv = None def save_video(grid, video_name, fps=30): grid = (grid * 255).astype(np.uint8) grid = np.transpose(grid, (1, 2, 3, 0)) with ffmpegcv.VideoWriter(video_name, "h264", fps) as writer: for frame in grid: frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) writer.write(frame) def save_img_or_video(sample_C_T_H_W_in01: Tensor, save_fp_wo_ext: Union[str, IO[Any]], fps: int = 24) -> None: """ Save a tensor as an image or video file based on shape Args: sample_C_T_H_W_in01 (Tensor): Input tensor with shape (C, T, H, W) in [0, 1] range. save_fp_wo_ext (Union[str, IO[Any]]): File path without extension or file-like object. fps (int): Frames per second for video. Default is 24. """ assert sample_C_T_H_W_in01.ndim == 4, "Only support 4D tensor" assert isinstance(save_fp_wo_ext, str) or hasattr( save_fp_wo_ext, "write" ), "save_fp_wo_ext must be a string or file-like object" if torch.is_floating_point(sample_C_T_H_W_in01): sample_C_T_H_W_in01 = sample_C_T_H_W_in01.clamp(0, 1) else: assert sample_C_T_H_W_in01.dtype == torch.uint8, "Only support uint8 tensor" sample_C_T_H_W_in01 = sample_C_T_H_W_in01.float().div(255) if sample_C_T_H_W_in01.shape[1] == 1: save_obj = PILImage.fromarray( rearrange((sample_C_T_H_W_in01.cpu().float().numpy() * 255), "c 1 h w -> h w c").astype(np.uint8), mode="RGB", ) ext = ".jpg" if isinstance(save_fp_wo_ext, str) else "" easy_io.dump( save_obj, f"{save_fp_wo_ext}{ext}" if isinstance(save_fp_wo_ext, str) else save_fp_wo_ext, file_format="jpg", format="JPEG", quality=85, ) else: save_obj = rearrange((sample_C_T_H_W_in01.cpu().float().numpy() * 255), "c t h w -> t h w c").astype(np.uint8) ext = ".mp4" if isinstance(save_fp_wo_ext, str) else "" easy_io.dump( save_obj, f"{save_fp_wo_ext}{ext}" if isinstance(save_fp_wo_ext, str) else save_fp_wo_ext, file_format="mp4", format="mp4", fps=fps, )