# cosmos_transfer1/utils/regional_prompting_utils.py
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import matplotlib.pyplot as plt
import torch
from cosmos_transfer1.utils import log
class RegionalPromptProcessor:
"""
Processes regional prompts and creates corresponding masks for attention.
"""
    def __init__(self, max_img_h: int, max_img_w: int, max_frames: int):
self.max_img_h = max_img_h
self.max_img_w = max_img_w
self.max_frames = max_frames
def create_region_masks_from_boxes(
self,
bounding_boxes: List[List[float]],
batch_size: int,
time_dim: int,
height: int,
width: int,
device: torch.device,
    ) -> torch.Tensor:
"""
Create region masks from bounding boxes [x1, y1, x2, y2] in normalized coordinates (0-1).
Returns:
region_masks: Tensor of shape (B, R, T, H, W) with values between 0 and 1
"""
num_regions = len(bounding_boxes)
region_masks = torch.zeros(
batch_size, num_regions, time_dim, height, width, device=device, dtype=torch.bfloat16
)
for r, box in enumerate(bounding_boxes):
# Convert normalized coordinates to pixel coordinates
x1, y1, x2, y2 = box
x1 = int(x1 * width)
y1 = int(y1 * height)
x2 = int(x2 * width)
y2 = int(y2 * height)
# Create mask for this region
region_masks[:, r, :, y1:y2, x1:x2] = 1.0
return region_masks
def create_region_masks_from_segmentation(
self,
segmentation_maps: List[torch.Tensor],
batch_size: int,
time_dim: int,
height: int,
width: int,
device: torch.device,
) -> torch.Tensor:
"""
Create masks from binary segmentation maps.
Args:
segmentation_maps: List of Tensors, each of shape (T, H, W) with binary values
Returns:
region_masks: Tensor of shape (B, R, T, H, W) with binary values
"""
num_regions = len(segmentation_maps)
region_masks = torch.zeros(
batch_size, num_regions, time_dim, height, width, device=device, dtype=torch.bfloat16
)
for r, seg_map in enumerate(segmentation_maps):
            # Clip the segmentation map to time_dim frames if it is longer
if seg_map.shape[0] > time_dim:
log.info(f"clipping segmentation map to {time_dim} frames")
seg_map = seg_map[:time_dim]
region_masks[:, r] = seg_map.float()
return region_masks
def visualize_region_masks(
self, region_masks: torch.Tensor, save_path: str, time_dim: int, height: int, width: int
) -> None:
"""
Visualize region masks for debugging purposes.
Args:
            region_masks: Tensor of shape (B, R, T, H, W)
save_path: Path to save the visualization
time_dim: Number of frames
height: Height in latent space
width: Width in latent space
"""
        B, R, T, H, W = region_masks.shape
# Create figure
fig, axes = plt.subplots(R, 1, figsize=(10, 3 * R))
if R == 1:
axes = [axes]
for r in range(R):
            # Plot the middle frame of the first batch element for each region
            axes[r].imshow(region_masks[0, r, time_dim // 2].float().cpu().numpy(), cmap="gray")
axes[r].set_title(f"Region {r+1} Mask (Middle Frame)")
plt.tight_layout()
plt.savefig(save_path)
plt.close()
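

# Usage sketch (illustrative, not part of the original module): build masks for
# two hypothetical normalized boxes and one random segmentation map, then render
# the box masks with visualize_region_masks. The latent dimensions (T=16, H=44,
# W=80) and the output path are assumptions made for this example.
def _demo_region_masks() -> None:
    device = torch.device("cpu")
    processor = RegionalPromptProcessor(max_img_h=44, max_img_w=80, max_frames=16)
    # Two [x1, y1, x2, y2] boxes in normalized coordinates: left half and bottom-right quadrant
    boxes = [[0.0, 0.0, 0.5, 1.0], [0.5, 0.5, 1.0, 1.0]]
    box_masks = processor.create_region_masks_from_boxes(boxes, 1, 16, 44, 80, device)
    assert box_masks.shape == (1, 2, 16, 44, 80)
    # A random binary [T, H, W] map stands in for a real segmentation result
    seg_maps = [(torch.rand(16, 44, 80) > 0.5).float()]
    seg_masks = processor.create_region_masks_from_segmentation(seg_maps, 1, 16, 44, 80, device)
    assert seg_masks.shape == (1, 1, 16, 44, 80)
    processor.visualize_region_masks(box_masks, "region_masks_demo.png", 16, 44, 80)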
def compress_segmentation_map(segmentation_map: torch.Tensor, compression_factor: int) -> torch.Tensor:
    """
    Spatially downsample a segmentation map by compression_factor, keeping the time dimension.
    Accepts either a [T, H, W] map or a [C, T, H, W] map; for the latter, the first
    channel is assumed to hold the main segmentation mask.
    """
    # Handle both [T,H,W] and [C,T,H,W] formats
    if len(segmentation_map.shape) == 4:  # [C,T,H,W] format
        # Take the first channel, assumed to contain the main segmentation mask;
        # modify this selection for other channel layouts.
        segmentation_map = segmentation_map[0]  # Now [T,H,W]
    T, H, W = segmentation_map.shape
    # Add batch and channel dimensions for interpolate: [1, 1, T, H, W]
    expanded_map = segmentation_map.unsqueeze(0).unsqueeze(0)
new_H = H // compression_factor
new_W = W // compression_factor
compressed_map = torch.nn.functional.interpolate(
expanded_map, size=(T, new_H, new_W), mode="trilinear", align_corners=False
)
return compressed_map.squeeze(0).squeeze(0)
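

# Shape sketch (assumed example, not from the original file): compression
# divides H and W by compression_factor while the time dimension is preserved.
def _demo_compress_segmentation_map() -> None:
    seg = (torch.rand(8, 64, 64) > 0.5).float()  # hypothetical binary [T, H, W] map
    out = compress_segmentation_map(seg, compression_factor=8)
    assert out.shape == (8, 8, 8)  # T unchanged; H and W divided by 8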
def prepare_regional_prompts(
    model,
    global_prompt: Union[str, torch.Tensor],
    regional_prompts: List[torch.Tensor],
    region_definitions: List[Union[List[float], str]],
    batch_size: int,
    time_dim: int,
    height: int,
    width: int,
    device: torch.device,
    cache_dir: Optional[str] = None,
    local_files_only: bool = False,
    visualize_masks: bool = False,
    visualization_path: Optional[str] = None,
    compression_factor: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Prepare regional prompts and masks for inference.
Args:
model: DiT model
global_prompt: Global text prompt or pre-computed embedding
        regional_prompts: List of pre-computed regional prompt embeddings
        region_definitions: List of bounding boxes [x1, y1, x2, y2] in normalized
            coordinates, or paths to saved segmentation map tensors
        batch_size: Batch size
        time_dim: Number of frames
        height: Height in latent space
        width: Width in latent space
        device: Device to create tensors on
        cache_dir: Cache directory for text encoder
        local_files_only: Whether to use only local files for text encoder
        visualize_masks: Whether to visualize the region masks for debugging
        visualization_path: Path to save the visualization
        compression_factor: Spatial downsampling factor applied to loaded segmentation maps
    Returns:
        global_context: Global prompt embedding
        regional_contexts: Regional prompt embeddings stacked into a tensor of shape (B, R, ...)
        region_masks: Region masks tensor of shape (B, R, T, H, W) with values between 0 and 1
    """
processor = RegionalPromptProcessor(max_img_h=height, max_img_w=width, max_frames=time_dim)
# Validate that we have matching number of prompts and region definitions
if len(regional_prompts) != len(region_definitions):
raise ValueError(
f"Number of regional prompts ({len(regional_prompts)}) must match "
f"total number of region definitions ({len(region_definitions)})"
)
    # Track which prompts correspond to which region types while maintaining order
    box_prompts = []
    seg_prompts = []
    segmentation_maps: List[torch.Tensor] = []
    region_definitions_list: List[List[float]] = []
    # Iterate prompts and region definitions together to keep their correspondence
    for regional_prompt, region_definition in zip(regional_prompts, region_definitions):
        if isinstance(region_definition, str):
            # A string is treated as a path to a saved segmentation map tensor
            segmentation_map = torch.load(region_definition, weights_only=False)
            # Validate segmentation map dimensions
            if len(segmentation_map.shape) not in [3, 4]:
                raise ValueError(
                    f"Segmentation map should have shape [T,H,W] or [C,T,H,W], got shape {segmentation_map.shape}"
                )
            segmentation_map = compress_segmentation_map(segmentation_map, compression_factor)
            log.info(f"segmentation_map shape: {segmentation_map.shape}")
            segmentation_maps.append(segmentation_map)
            seg_prompts.append(regional_prompt)
        elif isinstance(region_definition, list):
            region_definitions_list.append(region_definition)
            box_prompts.append(regional_prompt)
        else:
            raise ValueError(f"Region definition format not recognized: {type(region_definition)}")
    # Reorder prompts so they line up with the concatenated masks (boxes first, then segmentations)
    regional_prompts = box_prompts + seg_prompts
region_masks_boxes = processor.create_region_masks_from_boxes(
region_definitions_list, batch_size, time_dim, height, width, device
)
region_masks_segmentation = processor.create_region_masks_from_segmentation(
segmentation_maps, batch_size, time_dim, height, width, device
)
region_masks = torch.cat([region_masks_boxes, region_masks_segmentation], dim=1)
if visualize_masks and visualization_path:
processor.visualize_region_masks(region_masks, visualization_path, time_dim, height, width)
    if isinstance(global_prompt, str):
        raise ValueError("Global prompt should be converted to an embedding before calling prepare_regional_prompts.")
elif isinstance(global_prompt, torch.Tensor):
global_context = global_prompt.to(dtype=torch.bfloat16)
else:
raise ValueError("Global prompt format not recognized.")
regional_contexts = []
for regional_prompt in regional_prompts:
if isinstance(regional_prompt, str):
raise ValueError(f"Regional prompt should be converted to embedding: {type(regional_prompt)}")
elif isinstance(regional_prompt, torch.Tensor):
regional_context = regional_prompt.to(dtype=torch.bfloat16)
else:
raise ValueError(f"Regional prompt format not recognized: {type(regional_prompt)}")
regional_contexts.append(regional_context)
regional_contexts = torch.stack(regional_contexts, dim=1)
return global_context, regional_contexts, region_masks
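

# End-to-end sketch (illustrative only): the model argument is unused inside
# prepare_regional_prompts, so None suffices here, and the embedding shape
# (1, 512, 1024) is an assumed stand-in for real text-encoder outputs.
def _demo_prepare_regional_prompts() -> None:
    global_emb = torch.randn(1, 512, 1024)
    regional_embs = [torch.randn(1, 512, 1024), torch.randn(1, 512, 1024)]
    boxes = [[0.0, 0.0, 0.5, 1.0], [0.5, 0.0, 1.0, 1.0]]
    global_ctx, regional_ctxs, masks = prepare_regional_prompts(
        model=None,
        global_prompt=global_emb,
        regional_prompts=regional_embs,
        region_definitions=boxes,
        batch_size=1,
        time_dim=16,
        height=44,
        width=80,
        device=torch.device("cpu"),
    )
    assert global_ctx.dtype == torch.bfloat16
    assert regional_ctxs.shape == (1, 2, 512, 1024)
    assert masks.shape == (1, 2, 16, 44, 80)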