harry900000's picture
add cosmos-tranfer1/ into repo
226c7c9
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import torch
from cosmos_transfer1.auxiliary.depth_anything.model.depth_anything import DepthAnythingModel
from cosmos_transfer1.auxiliary.human_keypoint.human_keypoint import HumanKeypointModel
from cosmos_transfer1.auxiliary.sam2.sam2_model import VideoSegmentationModel
from cosmos_transfer1.diffusion.inference.inference_utils import valid_hint_keys
from cosmos_transfer1.utils import log
from cosmos_transfer1.utils.video_utils import is_valid_video, video_to_tensor
class Preprocessors:
    """Fill in missing control videos and control-weight tensors for control hints.

    The depth, segmentation and keypoint models are created lazily on first use
    and then cached on the instance, so a model that is never needed is never
    constructed.
    """

    def __init__(self):
        # Models are instantiated on demand by depth() / segmentation() / keypoint().
        self.depth_model = None
        self.seg_model = None
        self.keypoint_model = None

    def __call__(self, input_video, input_prompt, control_inputs, output_folder, regional_prompts=None):
        """Generate any missing control inputs, weight tensors and regional masks.

        Args:
            input_video: Source video forwarded unchanged to the preprocessing models.
            input_prompt: Fallback text prompt for the "seg" hint when its config
                has no "input_control_prompt".
            control_inputs: Mapping of hint key -> config dict. Keys that are not in
                ``valid_hint_keys`` are skipped. The config dicts are updated in place
                ("input_control" and "control_weight" may be filled in or replaced).
            output_folder: Directory under which the paths of generated files are built.
            regional_prompts: Optional list of dicts. An entry with "mask_prompt" gets a
                mask tensor generated by segmentation; otherwise an entry whose
                "region_definitions_path" is a string pointing at a video gets that
                video converted to a tensor. Entries are updated in place.

        Returns:
            The same ``control_inputs`` object that was passed in.

        Raises:
            ValueError: If a string "region_definitions_path" is not accepted by
                ``is_valid_video``.
        """
        for hint_key, control_input in control_inputs.items():
            if hint_key not in valid_hint_keys:
                continue

            if hint_key in ("depth", "seg", "keypoint"):
                self.gen_input_control(input_video, input_prompt, hint_key, control_input, output_folder)

            # For each control input modality, compute a spatiotemporal weight tensor as long as
            # the user provides "control_weight_prompt". The object specified in the
            # control_weight_prompt will be treated as foreground and have control_weight for
            # these locations. Everything else will be treated as background and have control
            # weight 0 at those locations.
            prompt = control_input.get("control_weight_prompt")
            if prompt is None:
                continue

            log.info(f"{hint_key}: generating control weight tensor with SAM using {prompt=}")
            out_tensor = os.path.join(output_folder, f"{hint_key}_control_weight.pt")
            out_video = os.path.join(output_folder, f"{hint_key}_control_weight.mp4")
            self.segmentation(
                in_video=input_video,
                out_tensor=out_tensor,
                out_video=out_video,
                prompt=prompt,
                weight_scaler=self._resolve_weight_scaler(control_input),
                binarize_video=True,
            )
            # From here on the weight is the path of the generated tensor, not a scalar.
            control_input["control_weight"] = out_tensor

        if regional_prompts:
            log.info(f"processing regional prompts: {regional_prompts}")
            for i, regional_prompt in enumerate(regional_prompts):
                log.info(f"generating regional context for {regional_prompt}")
                out_tensor = os.path.join(output_folder, f"regional_context_{i}.pt")
                region_path = regional_prompt.get("region_definitions_path")

                if "mask_prompt" in regional_prompt:
                    prompt = regional_prompt["mask_prompt"]
                    out_video = os.path.join(output_folder, f"regional_context_{i}.mp4")
                    self.segmentation(
                        in_video=input_video,
                        out_tensor=out_tensor,
                        out_video=out_video,
                        prompt=prompt,
                        weight_scaler=1.0,
                        legacy_mask=True,
                    )
                    # Only point the prompt at the tensor if the file was actually written.
                    if os.path.exists(out_tensor):
                        regional_prompt["region_definitions_path"] = out_tensor
                elif isinstance(region_path, str):
                    if not is_valid_video(region_path):
                        raise ValueError(f"Invalid video file: {region_path}")
                    log.info(f"converting video to tensor: {region_path}")
                    video_to_tensor(region_path, out_tensor)
                    regional_prompt["region_definitions_path"] = out_tensor
                else:
                    log.info("do nothing!")

        return control_inputs

    @staticmethod
    def _resolve_weight_scaler(control_input):
        """Return the scalar used to scale a generated control-weight mask.

        A float "control_weight" is used as-is. A missing key or any other type
        (for example a path string left by an earlier run) falls back to 1.0.
        Using ``.get`` avoids a ``KeyError`` when only "control_weight_prompt" is given.
        """
        control_weight = control_input.get("control_weight")
        return control_weight if isinstance(control_weight, float) else 1.0

    def gen_input_control(self, in_video, in_prompt, hint_key, control_input, output_folder):
        """Create the control video for ``hint_key`` when the config does not name one.

        Does nothing when ``control_input["input_control"]`` is already set. Otherwise
        stores ``<output_folder>/<hint_key>_input_control.mp4`` under "input_control"
        and runs the matching preprocessor to produce it: segmentation for "seg",
        depth for "depth", keypoint for any other hint key.

        Args:
            in_video: Source video forwarded to the preprocessing model.
            in_prompt: Prompt used for "seg" when "input_control_prompt" is absent.
            hint_key: Name of the control hint being prepared.
            control_input: Config dict for this hint; updated in place.
            output_folder: Directory under which the output video path is built.
        """
        if control_input.get("input_control", None) is not None:
            return

        out_video = os.path.join(output_folder, f"{hint_key}_input_control.mp4")
        control_input["input_control"] = out_video

        if hint_key == "seg":
            prompt = control_input.get("input_control_prompt", in_prompt)
            # Keep at most the first 128 whitespace-separated words of the prompt.
            # NOTE(review): presumably a length limit of the segmentation model - confirm.
            prompt = " ".join(prompt.split()[:128])
            log.info(
                f"no input_control provided for {hint_key}. generating input control video with SAM using {prompt=}"
            )
            self.segmentation(
                in_video=in_video,
                out_video=out_video,
                prompt=prompt,
            )
        elif hint_key == "depth":
            log.info(
                f"no input_control provided for {hint_key}. generating input control video with DepthAnythingModel"
            )
            self.depth(
                in_video=in_video,
                out_video=out_video,
            )
        else:
            log.info(f"no input_control provided for {hint_key}. generating input control video with Openpose")
            self.keypoint(
                in_video=in_video,
                out_video=out_video,
            )

    def depth(self, in_video, out_video):
        """Run the depth model on ``in_video`` and ``out_video``, creating the model on first use."""
        if self.depth_model is None:
            self.depth_model = DepthAnythingModel()
        self.depth_model(in_video, out_video)

    def keypoint(self, in_video, out_video):
        """Run the keypoint model on ``in_video`` and ``out_video``, creating the model on first use."""
        if self.keypoint_model is None:
            self.keypoint_model = HumanKeypointModel()
        self.keypoint_model(in_video, out_video)

    def segmentation(
        self,
        in_video,
        prompt,
        out_video=None,
        out_tensor=None,
        weight_scaler=None,
        binarize_video=False,
        legacy_mask=False,
    ):
        """Run the segmentation model, creating it on first use.

        All arguments are forwarded to the model as keyword arguments under the
        names ``input_video``, ``output_video``, ``output_tensor``, ``prompt``,
        ``weight_scaler``, ``binarize_video`` and ``legacy_mask``.
        """
        if self.seg_model is None:
            self.seg_model = VideoSegmentationModel()
        self.seg_model(
            input_video=in_video,
            output_video=out_video,
            output_tensor=out_tensor,
            prompt=prompt,
            weight_scaler=weight_scaler,
            binarize_video=binarize_video,
            legacy_mask=legacy_mask,
        )
if __name__ == "__main__":
    # Manual smoke test: run the preprocessors on a sample video and print the
    # control-input config after it has been filled in.
    control_inputs = {
        "depth": {
            # No "input_control" here, so the depth preprocessor generates the control video.
            # Float scale for the generated weight mask; replaced by the tensor path afterwards.
            "control_weight": 1.0,
            "control_weight_prompt": "a boy",  # SAM weights prompt
        },
        "seg": {
            # No "input_control" here, so SAM generates it from "input_control_prompt".
            "input_control_prompt": "A boy",
            "control_weight": 1.0,
            "control_weight_prompt": "A boy",  # if present we need to generate weight tensor
        },
    }
    preprocessor = Preprocessors()
    input_video = "cosmos_transfer1/models/sam2/assets/input_video.mp4"
    # Only used as the "seg" fallback when "input_control_prompt" is absent.
    input_prompt = "A boy"
    output_folder = "outputs"
    os.makedirs(output_folder, exist_ok=True)
    # Preprocessors.__call__ requires the prompt and the output folder as well.
    preprocessor(input_video, input_prompt, control_inputs, output_folder)
    print(json.dumps(control_inputs, indent=4))