# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import torch
from megatron.core import parallel_state

from cosmos_predict1.diffusion.conditioner import VideoExtendCondition
from cosmos_predict1.diffusion.model.model_v2w import DiffusionV2WModel, broadcast_condition


class DiffusionGen3CModel(DiffusionV2WModel):
    """Video-to-world diffusion model with Gen3C camera control.

    Extends DiffusionV2WModel to condition generation on warped-frame renders
    and their validity masks, encoded into a fixed-size latent buffer.
    """

    def __init__(self, config):
        super().__init__(config)
        # Maximum number of warped-frame buffers (views) the conditioning can
        # hold; unused slots are zero-padded in encode_warped_frames.
        self.frame_buffer_max = config.frame_buffer_max
        # Number of RGB frames handled per generation chunk.
        self.chunk_size = 121
    def encode_warped_frames(
        self,
        condition_state: torch.Tensor,
        condition_state_mask: torch.Tensor,
        dtype: torch.dtype,
    ):
        """Encode warped frames and their validity masks into one latent tensor.

        Args:
            condition_state: Warped RGB frames in shape (B, T, V, C, H, W),
                where V is the number of warped views (buffers).
            condition_state_mask: Per-pixel validity masks in shape
                (B, T, V, 1, H, W) with values in [0, 1].
            dtype: Dtype expected by the tokenizer encoder.

        Returns:
            Latent tensor of shape (B, 2 * frame_buffer_max * C_lat, T_lat, H_lat, W_lat):
            per-view video and mask latents interleaved along the channel axis,
            zero-padded up to frame_buffer_max buffers.
        """
        assert condition_state.dim() == 6, "expected condition_state of shape (B, T, V, C, H, W)"
        # Rescale masks from [0, 1] to [-1, 1] and expand them to 3 channels so
        # they can be encoded by the same tokenizer as the RGB frames.
        condition_state_mask = (condition_state_mask * 2 - 1).repeat(1, 1, 1, 3, 1, 1)
        latent_condition = []
        for i in range(condition_state.shape[2]):
            # (B, T, C, H, W) -> (B, C, T, H, W) before encoding.
            current_video_latent = self.encode(
                condition_state[:, :, i].permute(0, 2, 1, 3, 4).to(dtype)
            ).contiguous()  # e.g. (1, 16, 8, 88, 160)
            current_mask_latent = self.encode(
                condition_state_mask[:, :, i].permute(0, 2, 1, 3, 4).to(dtype)
            ).contiguous()
            latent_condition.append(current_video_latent)
            latent_condition.append(current_mask_latent)
        # Zero-pad the unused buffer slots so the channel count stays fixed.
        for _ in range(self.frame_buffer_max - condition_state.shape[2]):
            latent_condition.append(torch.zeros_like(current_video_latent))
            latent_condition.append(torch.zeros_like(current_mask_latent))
        latent_condition = torch.cat(latent_condition, dim=1)
        return latent_condition
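
    # Channel-layout sketch (illustrative numbers, not guarantees): assuming
    # frame_buffer_max = 2, one populated view, and a 16-channel tokenizer,
    # encode_warped_frames returns channels ordered as
    #   [video_latent_0 (16) | mask_latent_0 (16) | zeros (16) | zeros (16)],
    # i.e. a tensor of shape (B, 2 * frame_buffer_max * 16, T_lat, H_lat, W_lat).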
    def _get_conditions(
        self,
        data_batch: dict,
        is_negative_prompt: bool = False,
        condition_latent: Optional[torch.Tensor] = None,
        num_condition_t: Optional[int] = None,
        add_input_frames_guidance: bool = False,
    ):
        """Build the conditioned and unconditioned inputs for the model.

        Args:
            data_batch: Input data dictionary; must contain "condition_state"
                and "condition_state_mask" holding the warped frames and masks.
            is_negative_prompt: Whether to use negative prompting.
            condition_latent: Conditioning frames tensor of shape (B, C, T, H, W).
            num_condition_t: Number of latent frames to condition on.
            add_input_frames_guidance: Whether to apply guidance to the input frames.

        Returns:
            condition: Fully populated input conditions.
            uncondition: Conditions removed or reduced to the minimum (unconditional branch).
        """
        if is_negative_prompt:
            condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
        else:
            condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)
        # Encode the warped frames and masks into the camera-conditioning latent.
        condition_state, condition_state_mask = (
            data_batch["condition_state"],
            data_batch["condition_state_mask"],
        )
        latent_condition = self.encode_warped_frames(
            condition_state, condition_state_mask, self.tensor_kwargs["dtype"]
        )
        condition.video_cond_bool = True
        condition = self.add_condition_video_indicator_and_video_input_mask(
            condition_latent, condition, num_condition_t
        )
        condition = self.add_condition_pose(latent_condition, condition)
        # Keep video conditioning on the unconditional branch unless guidance
        # should also be applied to the input frames.
        uncondition.video_cond_bool = not add_input_frames_guidance
        uncondition = self.add_condition_video_indicator_and_video_input_mask(
            condition_latent, uncondition, num_condition_t
        )
        uncondition = self.add_condition_pose(latent_condition, uncondition, drop_out_latent=True)
        assert condition.gt_latent.allclose(uncondition.gt_latent)
# For inference, check if parallel_state is initialized
to_cp = self.net.is_context_parallel_enabled
if parallel_state.is_initialized():
condition = broadcast_condition(condition, to_tp=False, to_cp=to_cp)
uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=to_cp)
return condition, uncondition
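
    # Usage sketch (hedged; key names follow this file, the surrounding sampler
    # pipeline is assumed):
    #   condition, uncondition = model._get_conditions(
    #       data_batch, is_negative_prompt=True,
    #       condition_latent=cond_latent, num_condition_t=1,
    #   )
    # The pair then drives classifier-free guidance: the sampler denoises with
    # both and blends the two outputs using the guidance scale.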
    def add_condition_pose(
        self,
        latent_condition: torch.Tensor,
        condition: VideoExtendCondition,
        drop_out_latent: bool = False,
    ) -> VideoExtendCondition:
        """Attach the warped-frame (camera pose) latent to the condition object.

        Used by the camera-control model.

        Args:
            latent_condition: Encoded warped frames and masks from
                encode_warped_frames, in shape (B, C, T, H, W).
            condition: Condition object to update.
            drop_out_latent: If True, replace the latent with zeros; used for
                the unconditional branch of classifier-free guidance.

        Returns:
            VideoExtendCondition: Updated condition object with
            condition_video_pose set.
        """
        if drop_out_latent:
            condition.condition_video_pose = torch.zeros_like(latent_condition.contiguous())
        else:
            condition.condition_video_pose = latent_condition.contiguous()
        to_cp = self.net.is_context_parallel_enabled
        # For inference, check if parallel_state is initialized.
        if parallel_state.is_initialized():
            condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp)
        else:
            assert not to_cp, "parallel_state is not initialized, context parallel should be turned off."
        return condition
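

if __name__ == "__main__":
    # Minimal standalone sketch (not part of the model): reproduces the
    # interleave-and-pad bookkeeping of encode_warped_frames with a stand-in
    # encoder, so the output channel count can be checked without weights.
    # All sizes below are assumptions chosen for illustration only.
    def _toy_encode(frames: torch.Tensor) -> torch.Tensor:
        # Stand-in for self.encode: 8x spatial / 8x temporal compression into
        # 16 latent channels, mimicking a typical causal video tokenizer.
        b, _, t, h, w = frames.shape
        return frames.new_zeros(b, 16, (t - 1) // 8 + 1, h // 8, w // 8)

    frame_buffer_max, n_views = 2, 1
    state = torch.rand(1, 9, n_views, 3, 64, 64) * 2 - 1  # (B, T, V, C, H, W)
    mask = torch.ones(1, 9, n_views, 1, 64, 64)           # (B, T, V, 1, H, W)
    mask = (mask * 2 - 1).repeat(1, 1, 1, 3, 1, 1)

    latents = []
    for i in range(state.shape[2]):
        video_latent = _toy_encode(state[:, :, i].permute(0, 2, 1, 3, 4))
        mask_latent = _toy_encode(mask[:, :, i].permute(0, 2, 1, 3, 4))
        latents += [video_latent, mask_latent]
    for _ in range(frame_buffer_max - state.shape[2]):
        latents += [torch.zeros_like(video_latent), torch.zeros_like(mask_latent)]
    latent_condition = torch.cat(latents, dim=1)
    # Two latents (video + mask) per buffer slot, 16 channels each.
    assert latent_condition.shape == (1, 2 * frame_buffer_max * 16, 2, 8, 8)
    print("latent_condition:", tuple(latent_condition.shape))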