# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from megatron.core import parallel_state
from torch import Tensor

from cosmos_predict1.diffusion.conditioner import VideoExtendCondition
from cosmos_predict1.diffusion.model.model_v2w import DiffusionV2WModel, broadcast_condition


class DiffusionGen3CModel(DiffusionV2WModel):
    def __init__(self, config):
        super().__init__(config)
        # Maximum number of warped-frame buffers stacked as pose conditioning.
        self.frame_buffer_max = config.frame_buffer_max
        # Number of pixel-space frames per video chunk.
        self.chunk_size = 121

    def encode_warped_frames(
        self,
        condition_state: torch.Tensor,
        condition_state_mask: torch.Tensor,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Encode warped frame buffers and their validity masks into latent space.

        Args:
            condition_state: Warped RGB frames in shape (B, T, N, C, H, W), where N
                is the number of frame buffers.
            condition_state_mask: Per-pixel validity masks in shape (B, T, N, 1, H, W)
                with values in [0, 1].
            dtype: Target dtype for the encoder input.

        Returns:
            torch.Tensor: video and mask latents concatenated along the channel
            dimension, zero-padded up to frame_buffer_max buffers.
        """
        assert condition_state.dim() == 6, "condition_state must be (B, T, N, C, H, W)"
        assert condition_state.shape[2] >= 1, "at least one frame buffer is required"
        # Rescale masks from [0, 1] to [-1, 1] and broadcast to 3 channels so they
        # can pass through the RGB video tokenizer.
        condition_state_mask = (condition_state_mask * 2 - 1).repeat(1, 1, 1, 3, 1, 1)
        latent_condition = []
        for i in range(condition_state.shape[2]):
            current_video_latent = self.encode(
                condition_state[:, :, i].permute(0, 2, 1, 3, 4).to(dtype)
            ).contiguous()  # e.g. (1, 16, 8, 88, 160)

            current_mask_latent = self.encode(
                condition_state_mask[:, :, i].permute(0, 2, 1, 3, 4).to(dtype)
            ).contiguous()
            latent_condition.append(current_video_latent)
            latent_condition.append(current_mask_latent)
        # Zero-pad missing buffers so the channel layout is fixed at frame_buffer_max.
        for _ in range(self.frame_buffer_max - condition_state.shape[2]):
            latent_condition.append(torch.zeros_like(current_video_latent))
            latent_condition.append(torch.zeros_like(current_mask_latent))

        latent_condition = torch.cat(latent_condition, dim=1)
        return latent_condition
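
    # Shape sketch for encode_warped_frames (illustrative values only, assuming
    # frame_buffer_max = 2 and the per-buffer latent shape noted inline above):
    #
    #   condition_state       (B, T, N, 3, H, W)   e.g. (1, 121, 2, 3, 704, 1280)
    #   condition_state_mask  (B, T, N, 1, H, W)   e.g. (1, 121, 2, 1, 704, 1280)
    #   per-buffer latents    (B, 16, T', H', W')  e.g. (1, 16, 8, 88, 160)
    #   latent_condition      (B, 2 * frame_buffer_max * 16, T', H', W')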

    def _get_conditions(
        self,
        data_batch: dict,
        is_negative_prompt: bool = False,
        condition_latent: Optional[torch.Tensor] = None,
        num_condition_t: Optional[int] = None,
        add_input_frames_guidance: bool = False,
    ):
        """Get the conditions for the model.

        Args:
            data_batch: Input data dictionary
            is_negative_prompt: Whether to use negative prompting
            condition_latent: Conditioning frames tensor (B,C,T,H,W)
            num_condition_t: Number of frames to condition on
            add_input_frames_guidance: Whether to apply guidance to input frames

        Returns:
            condition: Fully populated condition object
            uncondition: Condition object with optional inputs dropped, used for the
                unconditional branch of classifier-free guidance
        """
        if is_negative_prompt:
            condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
        else:
            condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)

        # encode warped frames
        condition_state, condition_state_mask = (
            data_batch["condition_state"],
            data_batch["condition_state_mask"],
        )
        latent_condition = self.encode_warped_frames(
            condition_state, condition_state_mask, self.tensor_kwargs["dtype"]
        )

        condition.video_cond_bool = True
        condition = self.add_condition_video_indicator_and_video_input_mask(
            condition_latent, condition, num_condition_t
        )
        condition = self.add_condition_pose(latent_condition, condition)

        uncondition.video_cond_bool = not add_input_frames_guidance
        uncondition = self.add_condition_video_indicator_and_video_input_mask(
            condition_latent, uncondition, num_condition_t
        )
        # Zero out the pose latents on the unconditional branch (classifier-free guidance).
        uncondition = self.add_condition_pose(latent_condition, uncondition, drop_out_latent=True)
        assert condition.gt_latent.allclose(uncondition.gt_latent)

        # For inference, check if parallel_state is initialized
        to_cp = self.net.is_context_parallel_enabled
        if parallel_state.is_initialized():
            condition = broadcast_condition(condition, to_tp=False, to_cp=to_cp)
            uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=to_cp)

        return condition, uncondition
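
    # A hedged sketch of how the (condition, uncondition) pair is consumed
    # downstream for classifier-free guidance; `denoise` and `guidance` are
    # illustrative names, not APIs defined in this file:
    #
    #   x_cond   = denoise(x_t, sigma, condition)     # pose + video conditioned
    #   x_uncond = denoise(x_t, sigma, uncondition)   # pose latents zeroed
    #   x_pred   = x_uncond + guidance * (x_cond - x_uncond)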

    def add_condition_pose(
        self,
        latent_condition: torch.Tensor,
        condition: VideoExtendCondition,
        drop_out_latent: bool = False,
    ) -> VideoExtendCondition:
        """Attach pose conditioning latents to the condition object (camera control).

        Args:
            latent_condition (torch.Tensor): encoded warped frames and masks in shape (B, C, T, H, W)
            condition (VideoExtendCondition): condition object to update
            drop_out_latent (bool): if True, attach zeros instead of the latents,
                disabling pose conditioning (used for the unconditional branch)

        Returns:
            VideoExtendCondition: updated condition object
        """
        if drop_out_latent:
            condition.condition_video_pose = torch.zeros_like(latent_condition.contiguous())
        else:
            condition.condition_video_pose = latent_condition.contiguous()

        to_cp = self.net.is_context_parallel_enabled

        # For inference, check if parallel_state is initialized
        if parallel_state.is_initialized():
            condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp)
        else:
            assert not to_cp, "parallel_state is not initialized, context parallel should be turned off."

        return condition
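

if __name__ == "__main__":
    # Minimal standalone sanity check of the mask preprocessing performed in
    # encode_warped_frames (shapes are illustrative assumptions, not values
    # required by this module): masks in [0, 1] are rescaled to [-1, 1] and
    # broadcast from 1 to 3 channels to match the range and channel count the
    # RGB video tokenizer expects.
    mask = torch.randint(0, 2, (1, 121, 2, 1, 64, 64)).float()  # (B, T, N, 1, H, W)
    rescaled = (mask * 2 - 1).repeat(1, 1, 1, 3, 1, 1)
    assert rescaled.shape == (1, 121, 2, 3, 64, 64)
    assert rescaled.min() >= -1.0 and rescaled.max() <= 1.0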