# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple, Union

import torch

from cosmos_predict1.diffusion.functional.batch_ops import batch_mul


def create_per_sample_loss_mask(
    loss_masking_cfg: dict,
    data_batch: dict,
    x_shape: Tuple[int, ...],
    dtype: torch.dtype,
    device: Union[str, torch.device] = "cuda",
):
""" | |
Creates a per-sample loss mask based on the given configuration and input data batch. | |
This function generates a dictionary of loss masks for each specified key in the loss masking configuration. | |
For keys present in both the configuration and the data batch, the corresponding data batch value is used. | |
For keys present only in the configuration, a tensor of zeros with the specified shape is created. | |
Additionally, it computes loss mask weights for each key based on the configuration values and adjusts them | |
based on the presence of certain keys in the data batch, such as "skip_face" and "object_loss_map". | |
Note: | |
- The original `loss_masking_cfg` and `data_batch` are not modified by this function. | |
- For image data, it is assumed that the channel is always the first dimension. | |
- `skip_face` is for face regions that should be skipped during training, the key is provided so that we can generate | |
diverse human and avoid collapse to a single face given certain prompts. The issue happens for getty projects, | |
where face distribution in the dataset is high unbalanced that single man face can be shown in more than 100+ images. | |
Parameters: | |
loss_masking_cfg (dict): Configuration for loss masking, specifying which keys to include and their weights. | |
data_batch (dict): The batch of data containing actual data points and potential mask indicators like "skip_face". | |
x_shape (tuple): The shape of the input data, used to initialize zero masks for keys not in the data batch. | |
dtype (torch.dtype): The data type for the tensors in the loss masks. | |
device (str, optional): The device on which to create the tensors. Defaults to 'cuda'. | |
Returns: | |
dict: A dictionary containing combined loss masks adjusted according to the `loss_masking_cfg` and `data_batch`. | |
Raises: | |
AssertionError: If "skip_face" is not present in `data_batch`. | |
Note: `create_combined_loss_mask` is assumed to be a separate function that combines individual loss masks into a | |
single mask or set of masks based on the given parameters. Its behavior should be documented separately. | |
""" | |
    loss_mask_data: dict = {}
    for key in loss_masking_cfg:
        if key not in data_batch:
            loss_mask_data[key] = torch.zeros((x_shape[0], 1, x_shape[2], x_shape[3]), device=device)
        else:
            loss_mask_data[key] = data_batch[key]
if "skip_face" not in data_batch: | |
# When skip_face is not there in data_dict, use 0 as default. This will not skip any sample. | |
data_batch["skip_face"] = torch.zeros((x_shape[0],), dtype=dtype, device=device) | |
    loss_mask_weight: dict = {}
    for k, v in loss_masking_cfg.items():
        loss_mask_weight[k] = torch.tensor(v, device=device).expand(data_batch["skip_face"].size())
    if "human_face_mask" in loss_mask_weight:
        # Zero out the face-mask weight for samples flagged with skip_face.
        loss_mask_weight["human_face_mask"] = (1 - data_batch["skip_face"]) * loss_mask_weight["human_face_mask"]
    if "object_loss_map" in data_batch:
        # Object loss maps always use a per-sample weight of 1, regardless of the configured value.
        loss_mask_weight["object_loss_map"] = torch.ones(data_batch["object_loss_map"].shape[0], device=device)
    return create_combined_loss_mask(loss_mask_data, x_shape, dtype, device, loss_mask_weight)
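

# Minimal usage sketch (illustration only, not part of the training pipeline): it shows how
# `create_per_sample_loss_mask` can be called for a small image batch. The helper name
# `_demo_per_sample_loss_mask` and the toy tensor values are assumptions added for this sketch;
# the key "human_face_mask" matches the special-cased key handled above.
def _demo_per_sample_loss_mask() -> torch.Tensor:
    batch, channels, height, width = 2, 3, 4, 4
    x_shape = (batch, channels, height, width)
    data_batch = {
        # Binary face mask with a single masked pixel per sample.
        "human_face_mask": torch.zeros((batch, 1, height, width)),
        # Skip the face weighting for the second sample only.
        "skip_face": torch.tensor([0.0, 1.0]),
    }
    data_batch["human_face_mask"][:, :, 0, 0] = 1.0
    loss_masking_cfg = {"human_face_mask": 3.0}
    # Sample 0 gets weight 3 on the face pixel; sample 1 has that pixel zeroed out entirely.
    return create_per_sample_loss_mask(loss_masking_cfg, data_batch, x_shape, torch.float32, device="cpu")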


def create_combined_loss_mask(data, x_shape, dtype, device="cuda", loss_masking=None):
    """
    Creates a combined loss mask from multiple input masks.

    This function combines several loss masks into a single mask. In regions where masks overlap,
    the weight of the mask processed last takes precedence, and regions not covered by any mask
    default to 1. Regions covered by a mask whose weight is zero are always zeroed out, regardless
    of other masks, which is essential for padded loss calculations.

    Example:
        Given the following masks and weights:
            mask1: [0, 1, 1, 1, 0, 0], weight: 2
            mask2: [1, 0, 1, 0, 0, 0], weight: 4
            mask3: [0, 1, 0, 0, 0, 0], weight: 0
        The resulting combined loss mask would be:
            [4, 0, 4, 2, 1, 1]

    Parameters:
        data (dict): Contains the loss masks, keyed by mask name.
        x_shape (tuple): The shape of the output mask.
        dtype: The data type for the output mask.
        device: The device on which the output mask will be allocated.
        loss_masking (dict, optional): Per-sample weight tensors for each mask key.

    Returns:
        torch.Tensor: The combined loss mask.
    """
    loss_mask = torch.ones(x_shape, dtype=dtype, device=device)
    zero_mask = torch.ones(x_shape, dtype=dtype, device=device)
    if loss_masking:
        for key in loss_masking:
            # Repeat the mask along the channel dimension; ndim=4 for images.
            repeat_dims = (1, x_shape[1]) + tuple([1] * (data[key].ndim - 2))
            mask_key = torch.tile(data[key], dims=repeat_dims)
            weight_key = loss_masking[key]
            # Handle the zero-weight case: zero_mask becomes 0 wherever a zero-weight mask is active,
            # so those regions are excluded from the loss in the final product below.
            is_zero_weight = (weight_key == 0).float()[:, None, None, None]
            zero_mask = zero_mask * (
                (1 - is_zero_weight) * torch.ones(x_shape, dtype=dtype, device=device)
                + is_zero_weight * (1 - mask_key.bool().float())
            )
            # Calculate weights: masked regions take this mask's per-sample weight,
            # unmasked regions keep the value accumulated so far.
            no_mask_region = (mask_key.bool() == 0).float()
            loss_mask = batch_mul(mask_key, weight_key) + batch_mul(no_mask_region, loss_mask)
    loss_mask_final = loss_mask * zero_mask
    return loss_mask_final
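

# Minimal sketch reproducing the worked example from the `create_combined_loss_mask` docstring above.
# This block is illustrative only (it is not part of the training pipeline); the toy shapes and the
# CPU device are assumptions made so the example runs without a GPU.
if __name__ == "__main__":
    x_shape = (1, 1, 1, 6)
    masks = {
        "mask1": torch.tensor([0.0, 1.0, 1.0, 1.0, 0.0, 0.0]).view(x_shape),
        "mask2": torch.tensor([1.0, 0.0, 1.0, 0.0, 0.0, 0.0]).view(x_shape),
        "mask3": torch.tensor([0.0, 1.0, 0.0, 0.0, 0.0, 0.0]).view(x_shape),
    }
    # Per-sample weight tensors of shape (batch,), mirroring what create_per_sample_loss_mask builds.
    weights = {
        "mask1": torch.tensor([2.0]),
        "mask2": torch.tensor([4.0]),
        "mask3": torch.tensor([0.0]),
    }
    combined = create_combined_loss_mask(masks, x_shape, torch.float32, device="cpu", loss_masking=weights)
    # Expected output, matching the docstring example: [4.0, 0.0, 4.0, 2.0, 1.0, 1.0]
    print(combined.flatten().tolist())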