from typing import List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import torch

from cosmos_transfer1.utils import log


class RegionalPromptProcessor:
    """
    Processes regional prompts and creates corresponding masks for attention.
    """

    def __init__(self, max_img_h, max_img_w, max_frames):
        self.max_img_h = max_img_h
        self.max_img_w = max_img_w
        self.max_frames = max_frames

    def create_region_masks_from_boxes(
        self,
        bounding_boxes: List[List[float]],
        batch_size: int,
        time_dim: int,
        height: int,
        width: int,
        device: torch.device,
    ) -> torch.Tensor:
        """
        Create region masks from bounding boxes [x1, y1, x2, y2] in normalized coordinates (0-1).

        Returns:
            region_masks: Tensor of shape (B, R, T, H, W) with values between 0 and 1
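
        Example (illustrative sizes, not from the original source; a single box
        covering the top-left quadrant)::

            processor = RegionalPromptProcessor(max_img_h=44, max_img_w=80, max_frames=8)
            masks = processor.create_region_masks_from_boxes(
                bounding_boxes=[[0.0, 0.0, 0.5, 0.5]],
                batch_size=1,
                time_dim=8,
                height=44,
                width=80,
                device=torch.device("cpu"),
            )
            assert masks.shape == (1, 1, 8, 44, 80)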
        """
        num_regions = len(bounding_boxes)
        region_masks = torch.zeros(
            batch_size, num_regions, time_dim, height, width, device=device, dtype=torch.bfloat16
        )

        for r, box in enumerate(bounding_boxes):
            x1, y1, x2, y2 = box
            # Convert normalized coordinates to integer indices in the latent grid.
            x1 = int(x1 * width)
            y1 = int(y1 * height)
            x2 = int(x2 * width)
            y2 = int(y2 * height)
            # Activate the box region across all batch entries and frames.
            region_masks[:, r, :, y1:y2, x1:x2] = 1.0

        return region_masks

    def create_region_masks_from_segmentation(
        self,
        segmentation_maps: List[torch.Tensor],
        batch_size: int,
        time_dim: int,
        height: int,
        width: int,
        device: torch.device,
    ) -> torch.Tensor:
        """
        Create masks from binary segmentation maps.

        Args:
            segmentation_maps: List of Tensors, each of shape (T, H, W) with binary values

        Returns:
            region_masks: Tensor of shape (B, R, T, H, W) with binary values
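
        Example (illustrative sizes, not from the original source; an all-ones
        map selects every latent position)::

            processor = RegionalPromptProcessor(max_img_h=44, max_img_w=80, max_frames=8)
            masks = processor.create_region_masks_from_segmentation(
                [torch.ones(8, 44, 80)],
                batch_size=1,
                time_dim=8,
                height=44,
                width=80,
                device=torch.device("cpu"),
            )
            assert masks.shape == (1, 1, 8, 44, 80)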
        """
        num_regions = len(segmentation_maps)
        region_masks = torch.zeros(
            batch_size, num_regions, time_dim, height, width, device=device, dtype=torch.bfloat16
        )

        for r, seg_map in enumerate(segmentation_maps):
            # Drop extra frames so the map matches the model's temporal length.
            if seg_map.shape[0] > time_dim:
                log.info(f"clipping segmentation map to {time_dim} frames")
                seg_map = seg_map[:time_dim]
            # Broadcast the (T, H, W) map across the batch dimension; the copy
            # converts it to the mask tensor's dtype and device.
            region_masks[:, r] = seg_map.float()

        return region_masks

    def visualize_region_masks(
        self, region_masks: torch.Tensor, save_path: str, time_dim: int, height: int, width: int
    ) -> None:
        """
        Visualize region masks for debugging purposes.

        Args:
            region_masks: Tensor of shape (B, R, T, H, W)
            save_path: Path to save the visualization
            time_dim: Number of frames
            height: Height in latent space
            width: Width in latent space
        """
        B, R, T, H, W = region_masks.shape

        fig, axes = plt.subplots(R, 1, figsize=(10, 3 * R))
        if R == 1:
            axes = [axes]
        for r in range(R):
            # Show the middle frame of the first batch element; cast to float32
            # first since numpy has no bfloat16 support.
            axes[r].imshow(region_masks[0, r, time_dim // 2].float().cpu().numpy(), cmap="gray")
            axes[r].set_title(f"Region {r + 1} Mask (Middle Frame)")
        plt.tight_layout()
        plt.savefig(save_path)
        plt.close()


def compress_segmentation_map(segmentation_map, compression_factor):
    """
    Spatially downsample a segmentation map by ``compression_factor`` using
    trilinear interpolation; the temporal dimension is preserved.
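
    Example (illustrative sizes, not from the original source)::

        seg = torch.ones(8, 176, 320)
        small = compress_segmentation_map(seg, compression_factor=4)
        assert small.shape == (8, 44, 80)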
    """
    # Accept [C, T, H, W] inputs by keeping only the first channel.
    if len(segmentation_map.shape) == 4:
        segmentation_map = segmentation_map[0]

    T, H, W = segmentation_map.shape
    new_H = H // compression_factor
    new_W = W // compression_factor

    # interpolate() requires a floating-point tensor in (N, C, D, H, W) layout.
    expanded_map = segmentation_map.float().unsqueeze(0).unsqueeze(0)
    compressed_map = torch.nn.functional.interpolate(
        expanded_map, size=(T, new_H, new_W), mode="trilinear", align_corners=False
    )

    return compressed_map.squeeze(0).squeeze(0)


def prepare_regional_prompts(
    model,
    global_prompt: Union[str, torch.Tensor],
    regional_prompts: List[torch.Tensor],
    region_definitions: List[Union[List[float], str]],
    batch_size: int,
    time_dim: int,
    height: int,
    width: int,
    device: torch.device,
    cache_dir: Optional[str] = None,
    local_files_only: bool = False,
    visualize_masks: bool = False,
    visualization_path: Optional[str] = None,
    compression_factor: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Prepare regional prompts and masks for inference.

    Args:
        model: DiT model
        global_prompt: Pre-computed global prompt embedding (a raw string is rejected with ValueError)
        regional_prompts: List of pre-computed regional prompt embeddings
        region_definitions: List of bounding boxes [x1, y1, x2, y2] or paths to segmentation maps
        batch_size: Batch size
        time_dim: Number of frames
        height: Height in latent space
        width: Width in latent space
        device: Device to create tensors on
        cache_dir: Cache directory for text encoder
        local_files_only: Whether to use only local files for text encoder
        visualize_masks: Whether to visualize the region masks for debugging
        visualization_path: Path to save the visualization
        compression_factor: Spatial downsampling factor applied to segmentation maps

    Returns:
        global_context: Global prompt embedding
        regional_contexts: Regional prompt embeddings stacked along dim 1
        region_masks: Region masks tensor with values between 0 and 1
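
    Example (sketch with dummy embedding shapes, not from the original source;
    ``model`` is not referenced in this function, so a placeholder is passed)::

        global_emb = torch.randn(1, 512, 1024)
        regional_embs = [torch.randn(1, 512, 1024)]
        ctx, regional_ctx, masks = prepare_regional_prompts(
            model=None,
            global_prompt=global_emb,
            regional_prompts=regional_embs,
            region_definitions=[[0.0, 0.0, 0.5, 0.5]],
            batch_size=1,
            time_dim=8,
            height=44,
            width=80,
            device=torch.device("cpu"),
        )
        # ctx: (1, 512, 1024), regional_ctx: (1, 1, 512, 1024), masks: (1, 1, 8, 44, 80)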
    """
    processor = RegionalPromptProcessor(max_img_h=height, max_img_w=width, max_frames=time_dim)

    if len(regional_prompts) != len(region_definitions):
        raise ValueError(
            f"Number of regional prompts ({len(regional_prompts)}) must match "
            f"total number of region definitions ({len(region_definitions)})"
        )

    # Split definitions into bounding boxes and segmentation-map paths, keeping
    # each prompt paired with its definition.
    box_prompts = []
    seg_prompts = []
    segmentation_maps: List[torch.Tensor] = []
    region_definitions_list: List[List[float]] = []

    for region_definition, regional_prompt in zip(region_definitions, regional_prompts):
        if isinstance(region_definition, str):
            # A string is treated as a path to a segmentation map saved with torch.save.
            segmentation_map = torch.load(region_definition, weights_only=False)
            if len(segmentation_map.shape) not in [3, 4]:
                raise ValueError(
                    f"Segmentation map should have shape [T,H,W] or [C,T,H,W], got shape {segmentation_map.shape}"
                )
            segmentation_map = compress_segmentation_map(segmentation_map, compression_factor)
            log.info(f"segmentation_map shape: {segmentation_map.shape}")
            segmentation_maps.append(segmentation_map)
            seg_prompts.append(regional_prompt)
        elif isinstance(region_definition, list):
            region_definitions_list.append(region_definition)
            box_prompts.append(regional_prompt)
        else:
            raise ValueError(f"Region definition format not recognized: {type(region_definition)}")

    # Reorder prompts so they stay aligned with the mask concatenation below
    # (box-based masks first, then segmentation-based masks).
    regional_prompts = box_prompts + seg_prompts
    region_masks_boxes = processor.create_region_masks_from_boxes(
        region_definitions_list, batch_size, time_dim, height, width, device
    )
    region_masks_segmentation = processor.create_region_masks_from_segmentation(
        segmentation_maps, batch_size, time_dim, height, width, device
    )
    region_masks = torch.cat([region_masks_boxes, region_masks_segmentation], dim=1)

    if visualize_masks and visualization_path:
        processor.visualize_region_masks(region_masks, visualization_path, time_dim, height, width)

    if isinstance(global_prompt, torch.Tensor):
        global_context = global_prompt.to(dtype=torch.bfloat16)
    elif isinstance(global_prompt, str):
        raise ValueError("Global prompt should be converted to an embedding before calling prepare_regional_prompts.")
    else:
        raise ValueError("Global prompt format not recognized.")

    regional_contexts = []
    for regional_prompt in regional_prompts:
        if isinstance(regional_prompt, str):
            raise ValueError(f"Regional prompt should be converted to embedding: {type(regional_prompt)}")
        elif isinstance(regional_prompt, torch.Tensor):
            regional_context = regional_prompt.to(dtype=torch.bfloat16)
        else:
            raise ValueError(f"Regional prompt format not recognized: {type(regional_prompt)}")
        regional_contexts.append(regional_context)

    # Stack per-region embeddings into a single (B, R, ...) tensor.
    regional_contexts = torch.stack(regional_contexts, dim=1)
    return global_context, regional_contexts, region_masks