Spaces:
Build error
Build error
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# SPDX-License-Identifier: Apache-2.0 | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Augmentations for tokenizer training (image and video)""" | |
from typing import Optional

from cosmos_predict1.tokenizer.training.datasets.augmentors import (
    CenterCrop,
    CropResizeAugmentor,
    HorizontalFlip,
    Normalize,
    RandomReverse,
    ReflectionPadding,
    ResizeSmallestSideAspectPreserving,
    UnsqueezeImage,
)
from cosmos_predict1.tokenizer.training.datasets.utils import (
    VIDEO_KEY,
    VIDEO_RES_SIZE_INFO,
    VIDEO_VAL_CROP_SIZE_INFO,
    get_crop_size_info,
)
from cosmos_predict1.utils import log
from cosmos_predict1.utils.lazy_config import LazyCall, LazyDict
# Value handed to CropResizeAugmentor as ``args["prob"]`` in the training
# pipeline. NOTE(review): going by the name, this is the chance of cropping
# without resizing — confirm against CropResizeAugmentor's implementation.
_PROB_OF_CROP_ONLY: float = 0.1


def video_train_augmentations(
    input_keys: list[str],
    resolution: str = "1080",
    crop_height: int = 256,
) -> dict[str, LazyDict]:
    """Assemble the lazily-configured video augmentation stages for training.

    Entries are inserted in this order: ``crop_resize``, ``random_reverse``,
    ``reflection_padding``, ``horizontal_flip``, ``normalize``,
    ``unsqueeze_padding``.

    Args:
        input_keys: Must hold exactly one key. The crop/resize stage reads that
            key and writes its result to ``VIDEO_KEY``; every later stage reads
            ``VIDEO_KEY`` (except the last, which reads ``"padding_mask"``).
        resolution: Lookup key into ``VIDEO_RES_SIZE_INFO`` selecting the
            resize target size.
        crop_height: Forwarded to ``get_crop_size_info`` to derive the crop
            sizes shared by the crop and reflection-padding stages.

    Returns:
        A dict mapping stage name to the ``LazyCall`` config for that stage.

    Raises:
        ValueError: If ``input_keys`` does not contain exactly one element.
    """
    # Single-element unpacking doubles as an arity check on input_keys.
    (source_key,) = input_keys
    crop_sizes = get_crop_size_info(crop_height)
    log.info(f"[video] training crop_height={crop_height} and crop_sizes: {crop_sizes}.")
    resize_size = VIDEO_RES_SIZE_INFO[resolution]

    # Keyword-argument dict construction keeps insertion order (Python 3.7+),
    # so the stage ordering is identical to a literal.
    return dict(
        crop_resize=LazyCall(CropResizeAugmentor)(
            input_keys=[source_key],
            output_keys=[VIDEO_KEY],
            crop_args={"size": crop_sizes},
            resize_args={"size": resize_size},
            args={"prob": _PROB_OF_CROP_ONLY},
        ),
        random_reverse=LazyCall(RandomReverse)(
            input_keys=[VIDEO_KEY],
            args={"prob": 0.5},
        ),
        reflection_padding=LazyCall(ReflectionPadding)(
            input_keys=[VIDEO_KEY],
            args={"size": crop_sizes},
        ),
        horizontal_flip=LazyCall(HorizontalFlip)(
            input_keys=[VIDEO_KEY],
            args={"size": crop_sizes},
        ),
        normalize=LazyCall(Normalize)(
            input_keys=[VIDEO_KEY],
            args={"mean": 0.5, "std": 0.5},
        ),
        unsqueeze_padding=LazyCall(UnsqueezeImage)(input_keys=["padding_mask"]),
    )
def video_val_augmentations(
    input_keys: list[str],
    resolution: str = "1080",
    crop_height: Optional[int] = None,
) -> dict[str, LazyDict]:
    """Build the video augmentation stages used for validation.

    Entries are inserted in this order: smallest-side aspect-preserving resize,
    center crop, reflection padding, normalize, padding-mask unsqueeze.

    Args:
        input_keys: Must hold exactly one key (enforced by unpacking below).
        resolution: Lookup key into ``VIDEO_RES_SIZE_INFO`` (resize target) and,
            when ``crop_height`` is None, into ``VIDEO_VAL_CROP_SIZE_INFO``.
        crop_height: If given, crop sizes come from
            ``get_crop_size_info(crop_height)``; if None, from the per-resolution
            ``VIDEO_VAL_CROP_SIZE_INFO`` table.

    Returns:
        A dict mapping stage name to the ``LazyCall`` config for that stage.

    Raises:
        ValueError: If ``input_keys`` does not contain exactly one element.
    """
    # Unpacking enforces exactly one input key.
    # NOTE(review): unlike video_train_augmentations, which remaps the input
    # key to VIDEO_KEY, every stage here reads VIDEO_KEY directly and
    # _video_key is unused — confirm callers always supply VIDEO_KEY.
    [_video_key] = input_keys
    if crop_height is None:
        crop_sizes = VIDEO_VAL_CROP_SIZE_INFO[resolution]
    else:
        crop_sizes = get_crop_size_info(crop_height)
    log.info(f"[video] validation crop_sizes: {crop_sizes}.")
    augmentations = {
        "resize_smallest_side_aspect_ratio_preserving": LazyCall(ResizeSmallestSideAspectPreserving)(
            input_keys=[VIDEO_KEY],
            args={"size": VIDEO_RES_SIZE_INFO[resolution]},
        ),
        "center_crop": LazyCall(CenterCrop)(
            input_keys=[VIDEO_KEY],
            args={"size": crop_sizes},
        ),
        "reflection_padding": LazyCall(ReflectionPadding)(
            input_keys=[VIDEO_KEY],
            args={"size": crop_sizes},
        ),
        "normalize": LazyCall(Normalize)(
            input_keys=[VIDEO_KEY],
            args={"mean": 0.5, "std": 0.5},
        ),
        "unsqueeze_padding": LazyCall(UnsqueezeImage)(input_keys=["padding_mask"]),
    }
    return augmentations