Spaces:
Build error
Build error
File size: 3,260 Bytes
b6af722 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import IO, Any, Union
import cv2
import numpy as np
import torch
from einops import rearrange
from PIL import Image as PILImage
from torch import Tensor
from cosmos_predict1.utils import log
from cosmos_predict1.utils.easy_io import easy_io
# ffmpegcv is treated as an optional dependency: if importing it fails for any
# reason, the error is logged and the module-level name is set to None instead
# of aborting the import of this module. save_video() is the user of ffmpegcv.
try:
    import ffmpegcv
except Exception as e: # ImportError cannot catch all problems
    log.info(e)
    ffmpegcv = None
def save_video(grid, video_name, fps=30):
    """Encode an RGB frame stack as an H.264 video via ffmpegcv.

    Args:
        grid: NumPy array that is scaled by 255, cast to uint8 and transposed
            with axes (1, 2, 3, 0) before being written frame by frame, i.e.
            it is expected to be laid out as (C, T, H, W) with RGB values in
            [0, 1]. Values outside that range are clipped.
        video_name: Output target handed to ``ffmpegcv.VideoWriter``.
        fps: Frame rate of the written video. Defaults to 30.

    Raises:
        ImportError: If the optional ``ffmpegcv`` import failed when this
            module was loaded.
    """
    if ffmpegcv is None:
        # The module-level import guard sets ffmpegcv to None on failure; fail
        # with an actionable message instead of an AttributeError on None.
        raise ImportError("ffmpegcv is required by save_video but could not be imported")

    # Clip before the uint8 cast: casting out-of-range floats to uint8 wraps
    # around (or is platform-dependent), which would corrupt pixel values.
    frames = (np.clip(grid, 0.0, 1.0) * 255).astype(np.uint8)
    frames = np.transpose(frames, (1, 2, 3, 0))  # (C, T, H, W) -> (T, H, W, C)

    with ffmpegcv.VideoWriter(video_name, "h264", fps) as writer:
        for rgb_frame in frames:
            # Frames are converted from RGB to OpenCV's BGR order before writing.
            writer.write(cv2.cvtColor(rgb_frame, cv2.COLOR_RGB2BGR))
def save_img_or_video(sample_C_T_H_W_in01: Tensor, save_fp_wo_ext: Union[str, IO[Any]], fps: int = 24) -> None:
    """
    Save a (C, T, H, W) tensor as a JPEG image (T == 1) or an MP4 video (otherwise).

    Args:
        sample_C_T_H_W_in01 (Tensor): 4D tensor with shape (C, T, H, W). Floating-point
            input is clamped to [0, 1]; uint8 input is rescaled from [0, 255] to [0, 1].
        save_fp_wo_ext (Union[str, IO[Any]]): File path without extension (".jpg" or
            ".mp4" is appended), or an object with a ``write`` attribute, which is
            passed to ``easy_io.dump`` unchanged.
        fps (int): Frames per second for video. Default is 24.

    Raises:
        AssertionError: If the tensor is not 4D, if ``save_fp_wo_ext`` is neither a
            string nor an object with ``write``, or if a non-floating-point tensor
            is not uint8.
    """
    # Validation raises AssertionError explicitly (rather than using `assert`) so the
    # checks still run under `python -O` while callers see the same exception type.
    if sample_C_T_H_W_in01.ndim != 4:
        raise AssertionError("Only support 4D tensor")
    is_path = isinstance(save_fp_wo_ext, str)
    if not (is_path or hasattr(save_fp_wo_ext, "write")):
        raise AssertionError("save_fp_wo_ext must be a string or file-like object")

    if torch.is_floating_point(sample_C_T_H_W_in01):
        sample_in01 = sample_C_T_H_W_in01.clamp(0, 1)
    else:
        if sample_C_T_H_W_in01.dtype != torch.uint8:
            raise AssertionError("Only support uint8 tensor")
        sample_in01 = sample_C_T_H_W_in01.float().div(255)

    # detach() lets tensors that require grad be converted; float() guarantees a
    # dtype NumPy can represent before scaling back to the [0, 255] range.
    pixels_c_t_h_w = sample_in01.detach().cpu().float().numpy() * 255

    if sample_in01.shape[1] == 1:
        # Single frame: drop the time axis and write a JPEG.
        image = PILImage.fromarray(
            rearrange(pixels_c_t_h_w, "c 1 h w -> h w c").astype(np.uint8),
            mode="RGB",
        )
        easy_io.dump(
            image,
            f"{save_fp_wo_ext}.jpg" if is_path else save_fp_wo_ext,
            file_format="jpg",
            format="JPEG",
            quality=85,
        )
    else:
        # Multiple frames: move channels last and write an MP4.
        frames_t_h_w_c = rearrange(pixels_c_t_h_w, "c t h w -> t h w c").astype(np.uint8)
        easy_io.dump(
            frames_t_h_w_c,
            f"{save_fp_wo_ext}.mp4" if is_path else save_fp_wo_ext,
            file_format="mp4",
            format="mp4",
            fps=fps,
        )
|