# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import torch

from cosmos_transfer1.utils import log


class RegionalPromptProcessor:
    """
    Processes regional prompts and creates corresponding masks for attention.
    """

    def __init__(self, max_img_h, max_img_w, max_frames):
        self.max_img_h = max_img_h
        self.max_img_w = max_img_w
        self.max_frames = max_frames

    def create_region_masks_from_boxes(
        self,
        bounding_boxes: List[List[float]],
        batch_size: int,
        time_dim: int,
        height: int,
        width: int,
        device: torch.device,
    ) -> torch.Tensor:
        """
        Create region masks from bounding boxes [x1, y1, x2, y2] in normalized coordinates (0-1).

        Returns:
            region_masks: Tensor of shape (B, R, T, H, W); 1 inside each box, 0 elsewhere
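
        Example (a minimal sketch; sizes and device are hypothetical):
            >>> proc = RegionalPromptProcessor(max_img_h=32, max_img_w=32, max_frames=8)
            >>> masks = proc.create_region_masks_from_boxes(
            ...     [[0.0, 0.0, 0.5, 0.5]], batch_size=1, time_dim=8,
            ...     height=32, width=32, device=torch.device("cpu"))
            >>> tuple(masks.shape)
            (1, 1, 8, 32, 32)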
        """
        num_regions = len(bounding_boxes)
        region_masks = torch.zeros(
            batch_size, num_regions, time_dim, height, width, device=device, dtype=torch.bfloat16
        )

        for r, box in enumerate(bounding_boxes):
            # Convert normalized coordinates to pixel coordinates
            x1, y1, x2, y2 = box
            x1 = int(x1 * width)
            y1 = int(y1 * height)
            x2 = int(x2 * width)
            y2 = int(y2 * height)

            # Create mask for this region
            region_masks[:, r, :, y1:y2, x1:x2] = 1.0

        return region_masks

    def create_region_masks_from_segmentation(
        self,
        segmentation_maps: List[torch.Tensor],
        batch_size: int,
        time_dim: int,
        height: int,
        width: int,
        device: torch.device,
    ) -> torch.Tensor:
        """
        Create masks from binary segmentation maps.

        Args:
            segmentation_maps: List of Tensors, each of shape (T, H, W) with binary values

        Returns:
            region_masks: Tensor of shape (B, R, T, H, W) with binary values
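
        Example (a minimal sketch; sizes and device are hypothetical):
            >>> proc = RegionalPromptProcessor(max_img_h=16, max_img_w=16, max_frames=4)
            >>> masks = proc.create_region_masks_from_segmentation(
            ...     [torch.ones(4, 16, 16)], batch_size=1, time_dim=4,
            ...     height=16, width=16, device=torch.device("cpu"))
            >>> tuple(masks.shape)
            (1, 1, 4, 16, 16)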
        """
        num_regions = len(segmentation_maps)
        region_masks = torch.zeros(
            batch_size, num_regions, time_dim, height, width, device=device, dtype=torch.bfloat16
        )

        for r, seg_map in enumerate(segmentation_maps):
            # Clip the segmentation map to time_dim frames if longer
            if seg_map.shape[0] > time_dim:
                log.info(f"clipping segmentation map to {time_dim} frames")
                seg_map = seg_map[:time_dim]
            region_masks[:, r] = seg_map.float()

        return region_masks

    def visualize_region_masks(
        self, region_masks: torch.Tensor, save_path: str, time_dim: int, height: int, width: int
    ) -> None:
        """
        Visualize region masks for debugging purposes.

        Args:
            region_masks: Tensor of shape (B, R, T, H, W)
            save_path: Path to save the visualization
            time_dim: Number of frames
            height: Height in latent space
            width: Width in latent space
        """

        B, R, T, H, W = region_masks.shape

        # Create figure
        fig, axes = plt.subplots(R, 1, figsize=(10, 3 * R))
        if R == 1:
            axes = [axes]
        for r in range(R):
            # Index batch 0, region r; convert bfloat16 to float32 before NumPy conversion
            axes[r].imshow(region_masks[0, r, time_dim // 2].float().cpu().numpy(), cmap="gray")
            axes[r].set_title(f"Region {r+1} Mask (Middle Frame)")
        plt.tight_layout()
        plt.savefig(save_path)
        plt.close()


def compress_segmentation_map(segmentation_map, compression_factor):
    """
    Spatially downsample a segmentation map by `compression_factor` using
    trilinear interpolation, keeping the temporal dimension unchanged.

    Accepts maps of shape [T, H, W] or [C, T, H, W]; for the latter, only the
    first channel is used as the segmentation mask.
    """
    if len(segmentation_map.shape) == 4:  # [C, T, H, W] format
        C, T, H, W = segmentation_map.shape
        # Assume the first channel contains the main segmentation mask;
        # modify if a different channel is required
        segmentation_map = segmentation_map[0]  # Take first channel, now [T, H, W]

    # Add batch and channel dimensions [1, 1, T, H, W]
    expanded_map = segmentation_map.unsqueeze(0).unsqueeze(0)
    T, H, W = segmentation_map.shape
    new_H = H // compression_factor
    new_W = W // compression_factor

    compressed_map = torch.nn.functional.interpolate(
        expanded_map, size=(T, new_H, new_W), mode="trilinear", align_corners=False
    )

    return compressed_map.squeeze(0).squeeze(0)
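
# A minimal usage sketch (path and factor are hypothetical): downsample a saved
# [T, H, W] segmentation mask to latent resolution before building region masks.
#
#   seg_map = torch.load("assets/region_mask.pt", weights_only=False)  # hypothetical path
#   latent_seg = compress_segmentation_map(seg_map, compression_factor=8)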


def prepare_regional_prompts(
    model,
    global_prompt: Union[str, torch.Tensor],
    regional_prompts: List[torch.Tensor],
    region_definitions: List[Union[List[float], str]],
    batch_size: int,
    time_dim: int,
    height: int,
    width: int,
    device: torch.device,
    cache_dir: Optional[str] = None,
    local_files_only: bool = False,
    visualize_masks: bool = False,
    visualization_path: Optional[str] = None,
    compression_factor: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Prepare regional prompts and masks for inference.

    Args:
        model: DiT model
        global_prompt: Pre-computed global prompt embedding
        regional_prompts: List of pre-computed regional prompt embeddings
        region_definitions: List of bounding boxes [x1, y1, x2, y2] in normalized
            coordinates, or paths to segmentation map tensors loadable with torch.load
        batch_size: Batch size
        time_dim: Number of frames
        height: Height in latent space
        width: Width in latent space
        device: Device to create tensors on
        cache_dir: Cache directory for text encoder
        local_files_only: Whether to use only local files for text encoder
        visualize_masks: Whether to visualize the region masks for debugging
        visualization_path: Path to save the visualization
        compression_factor: Spatial downsampling factor applied to loaded segmentation maps

    Returns:
        global_context: Global prompt embedding
        regional_contexts: Regional prompt embeddings stacked along dim 1
        region_masks: Region mask tensor of shape (B, R, T, H, W) with values in [0, 1]
    """
    processor = RegionalPromptProcessor(max_img_h=height, max_img_w=width, max_frames=time_dim)

    # Validate that we have matching number of prompts and region definitions
    if len(regional_prompts) != len(region_definitions):
        raise ValueError(
            f"Number of regional prompts ({len(regional_prompts)}) must match "
            f"total number of region definitions ({len(region_definitions)})"
        )

    # Split prompts by region type while preserving their pairing with definitions
    box_prompts = []
    seg_prompts = []

    segmentation_maps: List[torch.Tensor] = []
    region_definitions_list: List[List[float]] = []
    for region_definition, regional_prompt in zip(region_definitions, regional_prompts):
        if isinstance(region_definition, str):
            segmentation_map = torch.load(region_definition, weights_only=False)
            # Validate segmentation map dimensions
            if len(segmentation_map.shape) not in [3, 4]:
                raise ValueError(
                    f"Segmentation map should have shape [T,H,W] or [C,T,H,W], got shape {segmentation_map.shape}"
                )

            segmentation_map = compress_segmentation_map(segmentation_map, compression_factor)
            log.info(f"segmentation_map shape: {segmentation_map.shape}")
            segmentation_maps.append(segmentation_map)
            seg_prompts.append(regional_prompt)
        elif isinstance(region_definition, list):
            region_definitions_list.append(region_definition)
            box_prompts.append(regional_prompt)
        else:
            raise ValueError(f"Region definition format not recognized: {type(region_definition)}")

    # Reorder prompts to match the mask concatenation order: box masks first, then segmentation masks
    regional_prompts = box_prompts + seg_prompts
    region_masks_boxes = processor.create_region_masks_from_boxes(
        region_definitions_list, batch_size, time_dim, height, width, device
    )
    region_masks_segmentation = processor.create_region_masks_from_segmentation(
        segmentation_maps, batch_size, time_dim, height, width, device
    )
    region_masks = torch.cat([region_masks_boxes, region_masks_segmentation], dim=1)

    if visualize_masks and visualization_path:
        processor.visualize_region_masks(region_masks, visualization_path, time_dim, height, width)

    if isinstance(global_prompt, str):
        # A raw string would leave global_context undefined; require a pre-computed embedding
        raise ValueError("Global prompt should be a pre-computed embedding tensor, got a string.")
    elif isinstance(global_prompt, torch.Tensor):
        global_context = global_prompt.to(dtype=torch.bfloat16)
    else:
        raise ValueError(f"Global prompt format not recognized: {type(global_prompt)}")

    regional_contexts = []
    for regional_prompt in regional_prompts:
        if isinstance(regional_prompt, str):
            raise ValueError(f"Regional prompt should be converted to embedding: {type(regional_prompt)}")
        elif isinstance(regional_prompt, torch.Tensor):
            regional_context = regional_prompt.to(dtype=torch.bfloat16)
        else:
            raise ValueError(f"Regional prompt format not recognized: {type(regional_prompt)}")

        regional_contexts.append(regional_context)

    regional_contexts = torch.stack(regional_contexts, dim=1)
    return global_context, regional_contexts, region_masks
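

# A minimal end-to-end smoke test (a sketch, not the production path): the
# embedding shapes below are made up, and the real pipeline supplies
# pre-computed text embeddings (e.g. from T5) plus a DiT model. `model` is not
# used by the mask logic, so None suffices here.
if __name__ == "__main__":
    device = torch.device("cpu")
    emb = torch.randn(1, 512, 1024)  # hypothetical (B, L, D) text embedding
    global_ctx, regional_ctxs, masks = prepare_regional_prompts(
        model=None,
        global_prompt=emb,
        regional_prompts=[emb.clone(), emb.clone()],
        region_definitions=[[0.0, 0.0, 0.5, 1.0], [0.5, 0.0, 1.0, 1.0]],
        batch_size=1,
        time_dim=8,
        height=32,
        width=32,
        device=device,
    )
    # Expected shapes: (1, 512, 1024), (1, 2, 512, 1024), (1, 2, 8, 32, 32)
    print(global_ctx.shape, regional_ctxs.shape, masks.shape)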