|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json |
|
import os |
|
|
|
import torch |
|
|
|
from cosmos_transfer1.auxiliary.depth_anything.model.depth_anything import DepthAnythingModel |
|
from cosmos_transfer1.auxiliary.human_keypoint.human_keypoint import HumanKeypointModel |
|
from cosmos_transfer1.auxiliary.sam2.sam2_model import VideoSegmentationModel |
|
from cosmos_transfer1.diffusion.inference.inference_utils import valid_hint_keys |
|
from cosmos_transfer1.utils import log |
|
from cosmos_transfer1.utils.video_utils import is_valid_video, video_to_tensor |
|
|
|
|
|
class Preprocessors:
    """Fill in missing control inputs by running auxiliary models on demand.

    The depth, segmentation and keypoint models start out as ``None`` and are
    instantiated the first time the corresponding method is called; the same
    instance is then reused for later calls.
    """

    # Segmentation prompts are cut down to this many whitespace-separated words.
    _MAX_SEG_PROMPT_WORDS = 128

    def __init__(self):
        self.depth_model = None
        self.seg_model = None
        self.keypoint_model = None

    def __call__(self, input_video, input_prompt, control_inputs, output_folder, regional_prompts=None):
        """Preprocess ``control_inputs`` (and ``regional_prompts``) in place.

        For every key of ``control_inputs`` that is listed in
        ``valid_hint_keys``:

        * ``depth`` / ``seg`` / ``keypoint`` entries without an
          ``input_control`` get one generated (see ``gen_input_control``).
        * entries with a non-``None`` ``control_weight_prompt`` get a control
          weight tensor generated through ``segmentation``; the entry's
          ``control_weight`` is then replaced by the tensor path.

        For every regional prompt, a ``mask_prompt`` is turned into a region
        tensor through ``segmentation``; otherwise a string
        ``region_definitions_path`` that is a valid video is converted with
        ``video_to_tensor``.

        Args:
            input_video: Path of the source video handed to the models.
            input_prompt: Fallback prompt used for ``seg`` input control.
            control_inputs: Mapping from hint key to its settings dict.
            output_folder: Directory in which generated files are placed.
            regional_prompts: Optional list of regional prompt dicts.

        Returns:
            The same ``control_inputs`` object, mutated in place.

        Raises:
            ValueError: If a string ``region_definitions_path`` is not a
                valid video file.
        """
        for hint_key in control_inputs:
            if hint_key not in valid_hint_keys:
                continue

            if hint_key in ("depth", "seg", "keypoint"):
                self.gen_input_control(input_video, input_prompt, hint_key, control_inputs[hint_key], output_folder)

            control_input = control_inputs[hint_key]
            if control_input.get("control_weight_prompt", None) is None:
                continue

            prompt = control_input["control_weight_prompt"]
            log.info(f"{hint_key}: generating control weight tensor with SAM using {prompt=}")
            out_tensor = os.path.join(output_folder, f"{hint_key}_control_weight.pt")
            out_video = os.path.join(output_folder, f"{hint_key}_control_weight.mp4")
            self.segmentation(
                in_video=input_video,
                out_tensor=out_tensor,
                out_video=out_video,
                prompt=prompt,
                weight_scaler=self._resolve_weight_scaler(control_input),
                binarize_video=True,
            )
            control_input["control_weight"] = out_tensor

        if regional_prompts:
            log.info(f"processing regional prompts: {regional_prompts}")
            for i, regional_prompt in enumerate(regional_prompts):
                log.info(f"generating regional context for {regional_prompt}")
                out_tensor = os.path.join(output_folder, f"regional_context_{i}.pt")
                if "mask_prompt" in regional_prompt:
                    prompt = regional_prompt["mask_prompt"]
                    out_video = os.path.join(output_folder, f"regional_context_{i}.mp4")
                    self.segmentation(
                        in_video=input_video,
                        out_tensor=out_tensor,
                        out_video=out_video,
                        prompt=prompt,
                        weight_scaler=1.0,
                        legacy_mask=True,
                    )
                    # Only point at the tensor if it is actually on disk.
                    if os.path.exists(out_tensor):
                        regional_prompt["region_definitions_path"] = out_tensor
                elif "region_definitions_path" in regional_prompt and isinstance(
                    regional_prompt["region_definitions_path"], str
                ):
                    if is_valid_video(regional_prompt["region_definitions_path"]):
                        log.info(f"converting video to tensor: {regional_prompt['region_definitions_path']}")
                        video_to_tensor(regional_prompt["region_definitions_path"], out_tensor)
                        regional_prompt["region_definitions_path"] = out_tensor
                    else:
                        raise ValueError(f"Invalid video file: {regional_prompt['region_definitions_path']}")
                else:
                    log.info("do nothing!")

        return control_inputs

    @staticmethod
    def _resolve_weight_scaler(control_input):
        """Return the ``control_weight`` of ``control_input`` if it is a float, else 1.0.

        A missing key is treated like any other non-float value (for example a
        string) and yields 1.0 instead of raising ``KeyError``.
        """
        weight = control_input.get("control_weight")
        return weight if isinstance(weight, float) else 1.0

    @classmethod
    def _build_seg_prompt(cls, control_input, in_prompt):
        """Pick and truncate the prompt used to generate ``seg`` input control.

        ``input_control_prompt`` wins when it is present and not ``None``;
        otherwise ``in_prompt`` is used. The result is whitespace-normalised
        and limited to ``_MAX_SEG_PROMPT_WORDS`` words.

        Raises:
            ValueError: If neither prompt is available.
        """
        prompt = control_input.get("input_control_prompt")
        if prompt is None:
            prompt = in_prompt
        if prompt is None:
            raise ValueError("seg input control requires 'input_control_prompt' or an input prompt")
        return " ".join(prompt.split()[: cls._MAX_SEG_PROMPT_WORDS])

    def gen_input_control(self, in_video, in_prompt, hint_key, control_input, output_folder):
        """Generate the input control video for ``hint_key`` if none was given.

        Does nothing when ``control_input["input_control"]`` is already set to
        a non-``None`` value. Otherwise the video is written to
        ``<output_folder>/<hint_key>_input_control.mp4`` using segmentation for
        ``"seg"``, the depth model for ``"depth"`` and the keypoint model for
        any other hint key, and the path is stored in
        ``control_input["input_control"]``.

        Raises:
            ValueError: For ``"seg"`` when no prompt is available.
        """
        if control_input.get("input_control", None) is not None:
            return

        out_video = os.path.join(output_folder, f"{hint_key}_input_control.mp4")
        if hint_key == "seg":
            prompt = self._build_seg_prompt(control_input, in_prompt)
            log.info(
                f"no input_control provided for {hint_key}. generating input control video with SAM using {prompt=}"
            )
            self.segmentation(
                in_video=in_video,
                out_video=out_video,
                prompt=prompt,
            )
        elif hint_key == "depth":
            log.info(
                f"no input_control provided for {hint_key}. generating input control video with DepthAnythingModel"
            )
            self.depth(
                in_video=in_video,
                out_video=out_video,
            )
        else:
            log.info(f"no input_control provided for {hint_key}. generating input control video with Openpose")
            self.keypoint(
                in_video=in_video,
                out_video=out_video,
            )

        # Record the path only after generation returned, so a failure does
        # not leave the settings pointing at a file that was never written.
        control_input["input_control"] = out_video

    def depth(self, in_video, out_video):
        """Run the depth model on ``in_video``, creating the model on first use."""
        if self.depth_model is None:
            self.depth_model = DepthAnythingModel()

        self.depth_model(in_video, out_video)

    def keypoint(self, in_video, out_video):
        """Run the keypoint model on ``in_video``, creating the model on first use."""
        if self.keypoint_model is None:
            self.keypoint_model = HumanKeypointModel()

        self.keypoint_model(in_video, out_video)

    def segmentation(
        self,
        in_video,
        prompt,
        out_video=None,
        out_tensor=None,
        weight_scaler=None,
        binarize_video=False,
        legacy_mask=False,
    ):
        """Run the segmentation model, creating it on first use.

        All arguments are passed through unchanged to the model as
        ``input_video``, ``output_video``, ``output_tensor``, ``prompt``,
        ``weight_scaler``, ``binarize_video`` and ``legacy_mask``.
        """
        if self.seg_model is None:
            self.seg_model = VideoSegmentationModel()
        self.seg_model(
            input_video=in_video,
            output_video=out_video,
            output_tensor=out_tensor,
            prompt=prompt,
            weight_scaler=weight_scaler,
            binarize_video=binarize_video,
            legacy_mask=legacy_mask,
        )
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: neither entry provides "input_control", so the depth
    # and segmentation models are run to generate it, and each
    # "control_weight_prompt" triggers generation of a control weight tensor.
    control_inputs = {
        "depth": {
            # Explicit float weight; the preprocessor reads this key when a
            # control_weight_prompt is present.
            "control_weight": 1.0,
            "control_weight_prompt": "a boy",
        },
        "seg": {
            "input_control_prompt": "A boy",
            "control_weight": 1.0,
            "control_weight_prompt": "A boy",
        },
    }

    preprocessor = Preprocessors()
    input_video = "cosmos_transfer1/models/sam2/assets/input_video.mp4"
    input_prompt = "A boy"
    output_folder = "outputs/preprocessors_demo"
    # Generated videos and tensors are written here, so make sure it exists.
    os.makedirs(output_folder, exist_ok=True)

    # __call__ requires (input_video, input_prompt, control_inputs, output_folder).
    preprocessor(input_video, input_prompt, control_inputs, output_folder)
    print(json.dumps(control_inputs, indent=4))
|
|