# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The combined loss functions for continuous-space tokenizers training."""
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import torchvision.models.optical_flow as optical_flow
from cosmos_predict1.tokenizer.modules.utils import batch2time, time2batch
from cosmos_predict1.tokenizer.training.datasets.utils import INPUT_KEY, LATENT_KEY, MASK_KEY, RECON_KEY
from cosmos_predict1.tokenizer.training.losses import ReduceMode
from cosmos_predict1.tokenizer.training.losses.lpips import LPIPS
from cosmos_predict1.utils.lazy_config import instantiate
_VALID_LOSS_NAMES = ["color", "perceptual", "flow", "kl", "video_consistency"]
VIDEO_CONSISTENCY_LOSS = "video_consistency"
RECON_CONSISTENCY_KEY = f"{RECON_KEY}_consistency"
class TokenizerLoss(nn.Module):
def __init__(self, config) -> None:
super().__init__()
self.config = config
_reduce = ReduceMode(config.reduce.upper()) if hasattr(config, "reduce") else None
self.reduce = _reduce.function
self.loss_modules = nn.ModuleDict()
for key in _VALID_LOSS_NAMES:
self.loss_modules[key] = instantiate(getattr(config, key)) if hasattr(config, key) else NullLoss()
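    # Sketch of the assumed config layout (illustrative, not the canonical schema): each name in
    # _VALID_LOSS_NAMES is looked up as an attribute on `config` and passed to `instantiate`; names
    # missing from the config fall back to NullLoss, which contributes nothing to the total loss.
    # `config.reduce`, if present, names the ReduceMode (uppercased before lookup) whose `.function`
    # collapses each per-element loss map in `forward`.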
def forward(self, inputs, output_batch, iteration) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
loss = dict()
total_loss = 0.0
inputs[MASK_KEY] = torch.ones_like(inputs[INPUT_KEY])
# Calculates reconstruction losses (`total_loss`).
for key, module in self.loss_modules.items():
curr_loss = module(inputs, output_batch, iteration)
loss.update({k: torch.mean(v) for k, v in curr_loss.items()})
total_loss += sum([self.reduce(v) if (v.dim() > 0) else v for v in curr_loss.values()])
# Computes the overall loss as sum of the reconstruction losses and the generator loss.
return dict(loss=loss), total_loss
class WeightScheduler(torch.nn.Module):
def __init__(self, boundaries, values):
super().__init__()
self.boundaries = list(boundaries)
self.values = list(values)
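    # Piecewise-constant schedule: `values` is expected to have one more entry than `boundaries`.
    # For example, WeightScheduler(boundaries=[1000], values=[0.0, 1.0]) returns 0.0 while
    # iteration < 1000 and 1.0 from iteration 1000 onwards.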
def forward(self, iteration):
for boundary, value in zip(self.boundaries, self.values):
if iteration < boundary:
return value
return self.values[-1]
class NullLoss(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
    def forward(self, inputs, output_batch, iteration) -> dict[str, torch.Tensor]:
return dict()
class ColorLoss(torch.nn.Module):
def __init__(self, config) -> None:
super().__init__()
self.schedule = WeightScheduler(boundaries=config.boundaries, values=config.values)
def forward(self, inputs, output_batch, iteration) -> dict[str, torch.Tensor]:
reconstructions = output_batch[RECON_KEY]
weights = inputs[MASK_KEY]
recon = weights * torch.abs(inputs[INPUT_KEY].contiguous() - reconstructions.contiguous())
color_weighted = self.schedule(iteration) * recon
if torch.isnan(color_weighted).any():
raise ValueError("[COLOR] NaN detected in loss")
return dict(color=color_weighted)
class KLLoss(torch.nn.Module):
def __init__(self, config) -> None:
super().__init__()
self.schedule = WeightScheduler(boundaries=config.boundaries, values=config.values)
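    # The per-element KL below is the closed form of KL(N(mean, exp(logvar)) || N(0, 1)):
    #   0.5 * (mean^2 + var - 1 - logvar), with var = exp(logvar).
    # Reduction to a scalar happens later in TokenizerLoss via `self.reduce` / torch.mean.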
def kl(self, mean, logvar):
_dims = [idx for idx in range(1, mean.ndim)]
var = torch.exp(logvar)
return 0.5 * (torch.pow(mean, 2) + var - 1.0 - logvar)
def forward(self, inputs, output_batch, iteration) -> dict[str, torch.Tensor]:
if "posteriors" not in output_batch: # No KL loss for discrete tokens.
return dict()
mean, logvar = output_batch["posteriors"]
if mean.ndim == 1: # No KL if the mean is a scalar.
return dict()
kl = self.kl(mean, logvar)
kl_weighted = self.schedule(iteration) * kl
if torch.isnan(kl_weighted).any():
raise ValueError("[KL] NaN detected in loss")
return dict(kl=kl_weighted)
class PerceptualLoss(LPIPS):
"""Relevant changes that're internal to us:
- Remove linear projection layers, simply use the raw pre-normalized features.
- Use pyramid-layer weights: [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10.0 / 1.5].
- Accepts pixel-wise masks and modulates the features before norm calculation.
- Implements gram-matrix and correlation losses.
"""
def __init__(self, config):
super(PerceptualLoss, self).__init__(config.checkpoint_activations)
self.net = self.net.eval()
self.gram_enabled = config.gram_enabled
self.corr_enabled = config.corr_enabled
self.layer_weights = list(config.layer_weights)
self.lpips_schedule = WeightScheduler(config.lpips_boundaries, config.lpips_values)
self.gram_schedule = WeightScheduler(config.gram_boundaries, config.gram_values)
self.corr_schedule = WeightScheduler(config.corr_boundaries, config.corr_values)
self.checkpoint_activations = config.checkpoint_activations
def _temporal_gram_matrix(self, x, batch_size=None):
x = batch2time(x, batch_size)
c, t, h, w = x.shape[-4], x.shape[-3], x.shape[-2], x.shape[-1]
reshaped_x = torch.reshape(x, [-1, c, t * h * w])
return torch.matmul(reshaped_x, reshaped_x.transpose(1, 2)) / float(t * h * w)
def _gram_matrix(self, x, batch_size=None):
if batch_size is not None and x.shape[0] != batch_size:
return self._temporal_gram_matrix(x, batch_size)
c, h, w = x.shape[-3], x.shape[-2], x.shape[-1]
reshaped_x = torch.reshape(x, [-1, c, h * w])
return torch.matmul(reshaped_x, reshaped_x.transpose(1, 2)) / float(h * w)
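    # Both Gram variants compute G = F @ F^T / N, where F is the (masked) feature map flattened to
    # (batch, C, N) with N = H*W for images and N = T*H*W for videos. Matching Gram matrices between
    # input and reconstruction encourages similar second-order feature statistics (texture/style).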
def forward(self, inputs, output_batch, iteration):
output_dict = dict()
reconstructions = output_batch[RECON_KEY]
weights = inputs[MASK_KEY]
input_images = inputs[INPUT_KEY]
if input_images.ndim == 5:
input_images, batch_size = time2batch(input_images)
reconstructions, _ = time2batch(reconstructions)
weights, _ = time2batch(weights)
else:
batch_size = input_images.shape[0]
in0_input, in1_input = (self.scaling_layer(input_images), self.scaling_layer(reconstructions))
outs0, outs1 = self.net(in0_input), self.net(in1_input)
_layer_weights = self.layer_weights
weights_map, res, diffs = {}, {}, {}
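        # For each backbone feature stage: downsample the pixel mask to the feature resolution, mask
        # the absolute feature difference, then average it and apply the per-layer pyramid weight.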
for kk in range(len(self.chns)):
weights_map[kk] = torch.nn.functional.interpolate(weights[:, :1, ...], outs0[kk].shape[-2:])
diffs[kk] = weights_map[kk] * torch.abs(outs0[kk] - outs1[kk])
res[kk] = _layer_weights[kk] * diffs[kk].mean([1, 2, 3], keepdim=True)
val = res[0]
for ll in range(1, len(self.chns)):
val += res[ll]
# Scale by number of pixels to match pixel-wise losses.
val = val.expand(-1, input_images.shape[-3], input_images.shape[-2], input_images.shape[-1])
if batch_size != input_images.shape[0]:
val = batch2time(val, batch_size)
if torch.isnan(val).any():
raise ValueError("[LPIPS] NaN detected in loss")
output_dict["lpips"] = self.lpips_schedule(iteration) * val
if self.gram_enabled and self.gram_schedule(iteration) > 0.0:
num_chans = len(self.chns)
grams0 = [self._gram_matrix(weights_map[kk] * outs0[kk], batch_size) for kk in range(num_chans)]
grams1 = [self._gram_matrix(weights_map[kk] * outs1[kk], batch_size) for kk in range(num_chans)]
gram_diffs = [(grams0[kk] - grams1[kk]) ** 2 for kk in range(num_chans)]
grams_res = [_layer_weights[kk] * gram_diffs[kk].mean([1, 2], keepdim=True) for kk in range(num_chans)]
gram_val = grams_res[0]
for ll in range(1, len(self.chns)):
gram_val += grams_res[ll]
# Scale by number of total pixels to match pixel-wise losses.
gram_val = gram_val.unsqueeze(1).expand(
-1, input_images.shape[-3], input_images.shape[-2], input_images.shape[-1]
)
if batch_size != input_images.shape[0]:
gram_val = batch2time(gram_val, batch_size)
if torch.isnan(gram_val).any():
raise ValueError("[GRAM] NaN detected in loss")
output_dict["gram"] = self.gram_schedule(iteration) * gram_val
return output_dict
def torch_compile(self):
"""
This method invokes torch.compile() on this loss
"""
# cuda-graphs crash after 1k iterations
self.net = torch.compile(self.net, dynamic=False)
class FlowLoss(torch.nn.Module):
def __init__(self, config) -> None:
super().__init__()
self.schedule = WeightScheduler(config.boundaries, config.values)
self.scale = config.scale
self.dtype = getattr(torch, config.dtype)
self.checkpoint_activations = config.checkpoint_activations
self.enabled = config.enabled
current_device = torch.device(torch.cuda.current_device())
        # To run the model in bf16, make_coords_grid() must be able to return an arbitrary dtype
        # passed in as an argument; the line from the original implementation that forced fp32-only
        # results is commented out below.
        # The function is also changed to run on the GPU instead of the CPU, which results in
        # fewer graph breaks when torch.compile() is used.
# This function is copied from
# https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/_utils.py#L22
# commit: b06ea39d5f0adbe949d08257837bda912339e415
def make_coords_grid(
batch_size: int, h: int, w: int, device: torch.device = current_device, dtype: torch.dtype = self.dtype
):
# Original: def make_coords_grid(batch_size: int, h: int, w: int, device: str = "cpu"):
device = torch.device(device)
coords = torch.meshgrid(torch.arange(h, device=device), torch.arange(w, device=device), indexing="ij")
coords = torch.stack(coords[::-1], dim=0).to(dtype)
# Original: coords = torch.stack(coords[::-1], dim=0).float()
return coords[None].repeat(batch_size, 1, 1, 1)
        # We also need to specify the output dtype of torch.linspace() in the index_pyramid()
        # method of CorrBlock, otherwise it defaults to fp32.
        # That function is likewise changed to run on the GPU instead of the CPU, which results in
        # fewer graph breaks when torch.compile() is used.
# This function is copied from
# https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py#L394
# commit: b06ea39d5f0adbe949d08257837bda912339e415
def index_pyramid(
self, centroids_coords, dtype: torch.dtype = self.dtype, device: torch.device = current_device
):
# Original: def index_pyramid(self, centroids_coords):
"""Return correlation features by indexing from the pyramid."""
neighborhood_side_len = 2 * self.radius + 1 # see note in __init__ about out_channels
di = torch.linspace(-self.radius, self.radius, neighborhood_side_len, dtype=dtype, device=device)
dj = torch.linspace(-self.radius, self.radius, neighborhood_side_len, dtype=dtype, device=device)
# Original: di = torch.linspace(-self.radius, self.radius, neighborhood_side_len)
# Original: dj = torch.linspace(-self.radius, self.radius, neighborhood_side_len)
delta = torch.stack(torch.meshgrid(di, dj, indexing="ij"), dim=-1).to(centroids_coords.device)
delta = delta.view(1, neighborhood_side_len, neighborhood_side_len, 2)
batch_size, _, h, w = centroids_coords.shape # _ = 2
centroids_coords = centroids_coords.permute(0, 2, 3, 1).reshape(batch_size * h * w, 1, 1, 2)
indexed_pyramid = []
for corr_volume in self.corr_pyramid:
sampling_coords = centroids_coords + delta # end shape is (batch_size * h * w, side_len, side_len, 2)
indexed_corr_volume = optical_flow.raft.grid_sample(
corr_volume, sampling_coords, align_corners=True, mode="bilinear"
).view(batch_size, h, w, -1)
indexed_pyramid.append(indexed_corr_volume)
centroids_coords = centroids_coords / 2
corr_features = torch.cat(indexed_pyramid, dim=-1).permute(0, 3, 1, 2).contiguous()
expected_output_shape = (batch_size, self.out_channels, h, w)
if corr_features.shape != expected_output_shape:
raise ValueError(
f"Output shape of index pyramid is incorrect. Should be {expected_output_shape}, got {corr_features.shape}"
)
return corr_features
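        # Monkey-patch torchvision's RAFT helpers process-wide so that the `raft_large` instance
        # constructed below picks up the dtype/device-aware versions defined above.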
optical_flow.raft.make_coords_grid = make_coords_grid
optical_flow.raft.CorrBlock.index_pyramid = index_pyramid
flow_model = optical_flow.raft_large(pretrained=True, progress=False)
flow_model.requires_grad_(False)
flow_model.eval()
flow_model = flow_model.to(self.dtype)
self.flow_model = flow_model
def _run_model(self, input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor:
"""Runs flow_model in the forward mode on explicit dtype=float32.
Args:
input1: First video frames batch, layout (T, C, H, W), bfloat16.
input2: Next video frames batch, layout (T, C, H, W), bfloat16.
Returns:
Forward optical flow, (T, 2, H, W), bfloat16.
"""
input_dtype = input1.dtype
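        # torchvision's RAFT returns a list of iteratively refined flow predictions; [-1] selects the
        # final (most refined) estimate.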
flow_output = self.flow_model.to(self.dtype)(input1.to(self.dtype), input2.to(self.dtype))[-1]
return flow_output.to(input_dtype)
def _run_model_fwd(self, input_video: torch.Tensor) -> torch.Tensor:
"""Runs foward flow on a batch of videos, one batch at a time.
Args:
input_video: The input batch of videos, layout (B, T, C, H, W).
Returns:
Forward optical flow, layout (B, 2, T-1, H, W).
"""
output_list = list()
for fwd_input_frames in input_video:
fwd_input_frames = fwd_input_frames.transpose(1, 0)
fwd_flow_output = self._run_model(fwd_input_frames[:-1], fwd_input_frames[1:])
output_list.append(fwd_flow_output.transpose(1, 0))
return torch.stack(output_list, dim=0)
def _bidirectional_flow(self, input_video: torch.Tensor) -> torch.Tensor:
"""The bidirectional optical flow on a batch of videos.
The forward and backward flows are averaged to get the bidirectional flow.
To reduce memory pressure, the input video is scaled down by a factor of `self.scale`,
and rescaled back to match other pixel-wise losses.
Args:
input_video: The input batch of videos, layout (B, T, C, H, W).
Returns:
            Bidirectional flow, layout (B, 2, 2*(T-1), H, W).
"""
# scale down the input video to reduce memory pressure.
t, h, w = input_video.shape[-3:]
input_video_scaled = F.interpolate(input_video, (t, h // self.scale, w // self.scale), mode="trilinear")
# forward flow.
if self.checkpoint_activations:
fwd_flow_output = checkpoint.checkpoint(self._run_model_fwd, input_video_scaled, use_reentrant=False)
else:
fwd_flow_output = self._run_model_fwd(input_video_scaled)
# backward flow.
input_video_scaled = input_video_scaled.flip([2])
if self.checkpoint_activations:
bwd_flow_output = checkpoint.checkpoint(self._run_model_fwd, input_video_scaled, use_reentrant=False)
else:
bwd_flow_output = self._run_model_fwd(input_video_scaled)
bwd_flow_output = bwd_flow_output.flip([2])
# bidirectional flow, concat fwd and bwd along temporal axis.
flow_input = torch.cat([fwd_flow_output, bwd_flow_output], dim=2)
return self.scale * F.interpolate(flow_input, (2 * (t - 1), h, w), mode="trilinear")
def forward(
self, inputs: dict[str, torch.Tensor], output_batch: dict[str, torch.Tensor], iteration: int
) -> dict[str, torch.Tensor]:
input_images = inputs[INPUT_KEY]
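        # Optical flow needs at least two frames: skip for image batches (4-D) and single-frame videos.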
if input_images.ndim == 4 or input_images.shape[2] == 1:
return dict()
if not self.enabled or self.schedule(iteration) == 0.0:
return dict()
        # Bidirectional flow, layout (B, 2, 2*(T-1), H, W).
flow_input = self._bidirectional_flow(input_images)
flow_recon = self._bidirectional_flow(output_batch[RECON_KEY])
# L1 loss on the flow. (B, 1, 2*(T-1), H, W)
flow_loss = torch.abs(flow_input - flow_recon).mean(dim=1, keepdim=True)
flow_loss_weighted = self.schedule(iteration) * flow_loss
if torch.isnan(flow_loss_weighted).any():
raise ValueError("[FLOW] NaN detected in loss")
return dict(flow=flow_loss_weighted)
def torch_compile(self):
"""
This method invokes torch.compile() on this loss
"""
self.flow_model = torch.compile(self.flow_model, dynamic=False)
class VideoConsistencyLoss(torch.nn.Module):
def __init__(self, config) -> None:
super().__init__()
self.schedule = WeightScheduler(boundaries=config.boundaries, values=config.values)
self.enabled = config.enabled
self.num_frames = config.num_frames
self.step = config.step
self.num_windows = None
def shuffle(self, inputs: torch.Tensor) -> torch.Tensor:
"""
For input video of [B, 3, T, H, W], this function will reshape the video to
the shape of [B*(T-num_frames+1)//step, 3, num_frames, H, W] using a sliding window
This function is used to compute the temporal consistency between overlapped frames
to enable temporal consistency
"""
assert len(inputs.shape) == 5, f"inputs shape should be [B, 3, T, H, W]. currently {inputs.shape}"
B, C, T, H, W = inputs.shape
        assert T >= self.num_frames, f"number of input frames {T} should be at least {self.num_frames}"
# [B, C, num_windows, H, W, num_frames]
outputs = inputs.unfold(dimension=2, size=self.num_frames, step=self.step)
self.num_windows = outputs.shape[2]
outputs = einops.rearrange(outputs, "b c m h w n -> (b m) c n h w")
return outputs
def forward(self, inputs, output_batch, iteration) -> dict[str, torch.Tensor]:
if not self.enabled or self.num_windows is None:
return dict()
if self.schedule(iteration) == 0.0:
return dict()
# reshape output_batch to compute loss between overlapped frames
reconstructions = output_batch[RECON_CONSISTENCY_KEY]
B, C, T, H, W = reconstructions.shape
assert T == self.num_frames, f"reconstruction shape invalid (shape[2] should be {self.num_frames})"
assert (
B % self.num_windows == 0
), f"reconstruction shape invalid (shape[0]={B} not dividable by {self.num_windows})"
B = B // self.num_windows
videos = reconstructions.view(B, self.num_windows, C, self.num_frames, H, W)
# Compute the L1 distance between overlapped frames for all windows at once
diff = torch.mean(torch.abs(videos[:, :-1, :, self.step :, :, :] - videos[:, 1:, :, : -self.step, :, :]))
diff_weighted = self.schedule(iteration) * diff
if LATENT_KEY not in output_batch:
return dict(frame_consistency=diff_weighted)
B_latent, C_latent, T_latent, H_latent, W_latent = output_batch["latent"].shape
assert B_latent % self.num_windows == 0, f"latent batches should be divisible by {self.num_windows}"
latents = output_batch[LATENT_KEY].view(
B_latent // self.num_windows, self.num_windows, C_latent, T_latent, H_latent, W_latent
)
temporal_rate = self.num_frames // T_latent
spatial_rate = (H // H_latent) * (W // W_latent)
step_latent = self.step // temporal_rate
latent_diff = torch.mean(
torch.abs(latents[:, :-1, :, step_latent:, :, :] - latents[:, 1:, :, :-step_latent, :, :])
)
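        # The factor (C * temporal_rate * spatial_rate) / C_latent rescales the latent-space L1 term
        # so it is comparable to the pixel-space term despite the different number of elements.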
latent_diff_weighted = self.schedule(iteration) * latent_diff * (C * temporal_rate * spatial_rate) / (C_latent)
return dict(frame_consistency=diff_weighted, latent_consistency=latent_diff_weighted)
def unshuffle(self, inputs: torch.Tensor) -> torch.Tensor:
"""
For input video of [B*num_windows, 3, num_frames, H, W], this function will
undo the shuffle to a tensor of shape [B, 3, T, H, W]
"""
assert len(inputs.shape) == 5, f"inputs shape should be [B, 3, T, H, W]. currently {inputs.shape}"
B, C, T, H, W = inputs.shape
assert T == self.num_frames, f"inputs shape invalid (shape[2] should be {self.num_frames})"
assert B % self.num_windows == 0, f"inputs shape invalid (shape[0]={B} not dividable by {self.num_windows})"
B = B // self.num_windows
videos = inputs.view(B, self.num_windows, C, self.num_frames, H, W)
T = self.num_frames + (self.num_windows - 1) * self.step
current_device = torch.device(torch.cuda.current_device())
outputs = torch.zeros(B, C, T, H, W).to(inputs.dtype).to(current_device)
counter = torch.zeros_like(outputs)
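        # Overlap-add: accumulate each window into its frame range and count contributions per frame,
        # then divide so that frames covered by several windows are averaged rather than summed.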
for i in range(self.num_windows):
outputs[:, :, i * self.step : i * self.step + self.num_frames, :, :] += videos[:, i, :, :, :, :]
counter[:, :, i * self.step : i * self.step + self.num_frames, :, :] += 1
outputs = outputs / counter
return outputs