harry900000's picture
add cosmos-tranfer1/ into repo
226c7c9
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from cosmos_transfer1.diffusion.config.transfer.blurs import BlurAugmentorConfig, random_blur_config
from cosmos_transfer1.diffusion.config.transfer.conditioner import CTRL_AUG_KEYS, CTRL_HINT_KEYS, CTRL_HINT_KEYS_COMB
from cosmos_transfer1.diffusion.datasets.augmentors.basic_augmentors import (
ReflectionPadding,
ResizeLargestSideAspectPreserving,
)
from cosmos_transfer1.diffusion.datasets.augmentors.control_input import (
VIDEO_RES_SIZE_INFO,
AddControlInput,
AddControlInputComb,
)
from cosmos_transfer1.diffusion.datasets.augmentors.merge_datadict import DataDictMerger
from cosmos_transfer1.utils.lazy_config import LazyCall as L
# Registry of augmentor factories, keyed by the name given to ``augmentor_register``.
AUGMENTOR_OPTIONS = {}


def augmentor_register(key):
    """Return a decorator that records the decorated object in ``AUGMENTOR_OPTIONS``.

    Args:
        key: Name under which the decorated object is stored. Registering the
            same key again replaces the earlier entry.

    Returns:
        A decorator that stores its argument under ``key`` and hands the very
        same object back, so the decorated name stays usable as-is.
    """

    def _register(factory):
        AUGMENTOR_OPTIONS[key] = factory
        return factory

    return _register
@augmentor_register("video_basic_augmentor")
def get_video_augmentor(
    resolution: str,
    blur_config=None,
):
    """Build the basic video augmentation pipeline (no control input).

    Args:
        resolution: Key used to look up the target size in ``VIDEO_RES_SIZE_INFO``.
        blur_config: Accepted but not used by this function.

    Returns:
        A dict mapping stage name to a ``LazyCall``-wrapped augmentor, in the
        order: merge, aspect-preserving resize, reflection padding.
    """
    merge_stage = L(DataDictMerger)(
        input_keys=["video"],
        output_keys=[
            "video",
            "fps",
            "num_frames",
            "frame_start",
            "frame_end",
            "orig_num_frames",
        ],
    )

    # Resize and padding both target the size registered for this resolution.
    target_size = VIDEO_RES_SIZE_INFO[resolution]
    resize_stage = L(ResizeLargestSideAspectPreserving)(
        input_keys=["video"],
        args={"size": target_size},
    )
    padding_stage = L(ReflectionPadding)(
        input_keys=["video"],
        args={"size": target_size},
    )

    return {
        "merge_datadict": merge_stage,
        "resize_largest_side_aspect_ratio_preserving": resize_stage,
        "reflection_padding": padding_stage,
    }
"""
Register one video ControlNet augmentor per hint key in CTRL_HINT_KEYS, for use in data loading.
"""
def get_video_ctrlnet_augmentor(hint_key, use_random=True):
    """Create the augmentor factory for a single control-hint key.

    Args:
        hint_key: Name of the control input; also the output key under which
            the control input is added to the data dict.
        use_random: Forwarded unchanged to the control-input augmentor.

    Returns:
        A function ``(resolution, blur_config) -> dict`` that builds the
        augmentation pipeline for ``hint_key``.
    """

    def _get_video_ctrlnet_augmentor(
        resolution: str,
        blur_config: BlurAugmentorConfig = random_blur_config,
    ):
        """Build the augmentation pipeline for ``hint_key`` at ``resolution``.

        Args:
            resolution: Key used to look up the target size in ``VIDEO_RES_SIZE_INFO``.
            blur_config: Forwarded unchanged to the control-input augmentor.

        Returns:
            A dict mapping stage name to a ``LazyCall``-wrapped augmentor, in
            the order: add control input, aspect-preserving resize,
            reflection padding.
        """
        # NOTE(review): the first input key is an empty string in every branch —
        # presumably a placeholder; confirm against AddControlInput.
        if hint_key == "control_input_keypoint":
            # The keypoint hint is a combined control input that additionally
            # carries keypoint-specific options.
            add_control_input = L(AddControlInputComb)(
                input_keys=["", "video"],
                output_keys=[hint_key],
                args={
                    "comb": CTRL_HINT_KEYS_COMB[hint_key],
                    "use_openpose_format": True,
                    "kpt_thr": 0.6,
                    "human_kpt_line_width": 4,
                },
                use_random=use_random,
                blur_config=blur_config,
            )
        elif hint_key in CTRL_HINT_KEYS_COMB:
            # Any other hint listed in CTRL_HINT_KEYS_COMB only needs its comb spec.
            add_control_input = L(AddControlInputComb)(
                input_keys=["", "video"],
                output_keys=[hint_key],
                args={"comb": CTRL_HINT_KEYS_COMB[hint_key]},
                use_random=use_random,
                blur_config=blur_config,
            )
        else:
            add_control_input = L(AddControlInput)(
                input_keys=["", "video"],
                output_keys=[hint_key],
                use_random=use_random,
                blur_config=blur_config,
            )

        # Resize and padding both target the size registered for this resolution.
        target_size = VIDEO_RES_SIZE_INFO[resolution]

        # This pipeline has no "merge_datadict" stage.
        return {
            # this adds the control input tensor to the data dict
            "add_control_input": add_control_input,
            # this resizes both the video and the control input to the model's required input size
            "resize_largest_side_aspect_ratio_preserving": L(ResizeLargestSideAspectPreserving)(
                input_keys=["video", hint_key],
                args={"size": target_size},
            ),
            "reflection_padding": L(ReflectionPadding)(
                input_keys=["video", hint_key],
                args={"size": target_size},
            ),
        }

    return _get_video_ctrlnet_augmentor


# The factory takes hint_key as an argument, so each registered closure is bound
# to its own key rather than to the loop variable.
for hint_key in CTRL_HINT_KEYS:
    augmentor_register(f"video_ctrlnet_augmentor_{hint_key}")(get_video_ctrlnet_augmentor(hint_key))