Spaces:
Build error
Build error
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import IO, Any, Union | |
import cv2 | |
import numpy as np | |
import torch | |
from einops import rearrange | |
from PIL import Image as PILImage | |
from torch import Tensor | |
from cosmos_predict1.utils import log | |
from cosmos_predict1.utils.easy_io import easy_io | |
try: | |
import ffmpegcv | |
except Exception as e: # ImportError cannot catch all problems | |
log.info(e) | |
ffmpegcv = None | |
def save_video(grid, video_name, fps=30):
    """
    Encode a float video array to ``video_name`` with ffmpegcv's H.264 writer.

    Args:
        grid: NumPy array whose axes are transposed with (1, 2, 3, 0) so that the
            channel axis ends up last -- presumably (C, T, H, W); TODO confirm with callers.
            Values are expected in [0, 1]; anything outside that range is clipped.
            Channels are treated as RGB and converted to BGR for the writer.
        video_name: Output path handed to ``ffmpegcv.VideoWriter``.
        fps: Frames per second of the written video. Default is 30.

    Raises:
        ImportError: If ``ffmpegcv`` failed to import when this module was loaded.
    """
    if ffmpegcv is None:
        # The guarded module-level import sets ffmpegcv to None on failure; fail clearly
        # instead of with "'NoneType' object has no attribute 'VideoWriter'".
        raise ImportError("save_video requires ffmpegcv, which failed to import")
    # Clip before scaling: out-of-range floats would otherwise overflow in the uint8 cast.
    # The cast truncates toward zero, matching the previous behavior for in-range values.
    frames_uint8 = (np.clip(grid, 0, 1) * 255).astype(np.uint8)
    frames_T_H_W_C = np.transpose(frames_uint8, (1, 2, 3, 0))
    with ffmpegcv.VideoWriter(video_name, "h264", fps) as writer:
        for frame in frames_T_H_W_C:
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
def save_img_or_video(sample_C_T_H_W_in01: Tensor, save_fp_wo_ext: Union[str, IO[Any]], fps: int = 24) -> None:
    """
    Save a tensor as a JPEG image (single frame) or an MP4 video (multiple frames).

    Args:
        sample_C_T_H_W_in01 (Tensor): Input tensor with shape (C, T, H, W). Floating-point
            inputs are expected in [0, 1] (values outside are clamped); uint8 inputs are
            used as-is. When T == 1 a JPEG is written, otherwise an MP4.
        save_fp_wo_ext (Union[str, IO[Any]]): File path without extension (".jpg" or ".mp4"
            is appended), or a file-like object with ``write`` that is passed through as-is.
        fps (int): Frames per second for video. Default is 24. Unused for a single frame.
    """
    assert sample_C_T_H_W_in01.ndim == 4, "Only support 4D tensor"
    assert isinstance(save_fp_wo_ext, str) or hasattr(
        save_fp_wo_ext, "write"
    ), "save_fp_wo_ext must be a string or file-like object"

    # detach() so tensors that require grad can still be converted with .numpy().
    sample = sample_C_T_H_W_in01.detach()
    if torch.is_floating_point(sample):
        # Scale [0, 1] -> [0, 255]; the uint8 cast truncates, as before.
        sample_np = (sample.clamp(0, 1).cpu().float().numpy() * 255).astype(np.uint8)
    else:
        assert sample.dtype == torch.uint8, "Only support uint8 tensor"
        # Use the uint8 data directly: a /255 -> *255 float round trip followed by a
        # truncating cast can turn some values v into v - 1.
        sample_np = sample.cpu().numpy()

    is_path = isinstance(save_fp_wo_ext, str)

    if sample_np.shape[1] == 1:
        # Single frame -> (H, W, C) image; ascontiguousarray keeps the buffer contiguous
        # like the original .astype() copy did.
        frame_H_W_C = np.ascontiguousarray(rearrange(sample_np, "c 1 h w -> h w c"))
        save_obj = PILImage.fromarray(frame_H_W_C, mode="RGB")
        easy_io.dump(
            save_obj,
            f"{save_fp_wo_ext}.jpg" if is_path else save_fp_wo_ext,
            file_format="jpg",
            format="JPEG",
            quality=85,
        )
    else:
        # Multiple frames -> (T, H, W, C) array handed to the mp4 writer.
        save_obj = np.ascontiguousarray(rearrange(sample_np, "c t h w -> t h w c"))
        easy_io.dump(
            save_obj,
            f"{save_fp_wo_ext}.mp4" if is_path else save_fp_wo_ext,
            file_format="mp4",
            format="mp4",
            fps=fps,
        )