# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """The combined loss functions for continuous-space tokenizers training.""" import einops import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint import torchvision.models.optical_flow as optical_flow from cosmos_predict1.tokenizer.modules.utils import batch2time, time2batch from cosmos_predict1.tokenizer.training.datasets.utils import INPUT_KEY, LATENT_KEY, MASK_KEY, RECON_KEY from cosmos_predict1.tokenizer.training.losses import ReduceMode from cosmos_predict1.tokenizer.training.losses.lpips import LPIPS from cosmos_predict1.utils.lazy_config import instantiate _VALID_LOSS_NAMES = ["color", "perceptual", "flow", "kl", "video_consistency"] VIDEO_CONSISTENCY_LOSS = "video_consistency" RECON_CONSISTENCY_KEY = f"{RECON_KEY}_consistency" class TokenizerLoss(nn.Module): def __init__(self, config) -> None: super().__init__() self.config = config _reduce = ReduceMode(config.reduce.upper()) if hasattr(config, "reduce") else None self.reduce = _reduce.function self.loss_modules = nn.ModuleDict() for key in _VALID_LOSS_NAMES: self.loss_modules[key] = instantiate(getattr(config, key)) if hasattr(config, key) else NullLoss() def forward(self, inputs, output_batch, iteration) -> tuple[dict[str, torch.Tensor], torch.Tensor]: loss = dict() total_loss = 0.0 inputs[MASK_KEY] = torch.ones_like(inputs[INPUT_KEY]) # Calculates reconstruction losses (`total_loss`). for key, module in self.loss_modules.items(): curr_loss = module(inputs, output_batch, iteration) loss.update({k: torch.mean(v) for k, v in curr_loss.items()}) total_loss += sum([self.reduce(v) if (v.dim() > 0) else v for v in curr_loss.values()]) loss.update({k: torch.mean(v) for k, v in curr_loss.items()}) # Computes the overall loss as sum of the reconstruction losses and the generator loss. total_loss += sum([self.reduce(v) if (v.dim() > 0) else v for v in curr_loss.values()]) return dict(loss=loss), total_loss class WeightScheduler(torch.nn.Module): def __init__(self, boundaries, values): super().__init__() self.boundaries = list(boundaries) self.values = list(values) def forward(self, iteration): for boundary, value in zip(self.boundaries, self.values): if iteration < boundary: return value return self.values[-1] class NullLoss(torch.nn.Module): def __init__(self) -> None: super().__init__() def forward(self, inputs, output_batch, iteration) -> dict[dict, torch.Tensor]: return dict() class ColorLoss(torch.nn.Module): def __init__(self, config) -> None: super().__init__() self.schedule = WeightScheduler(boundaries=config.boundaries, values=config.values) def forward(self, inputs, output_batch, iteration) -> dict[str, torch.Tensor]: reconstructions = output_batch[RECON_KEY] weights = inputs[MASK_KEY] recon = weights * torch.abs(inputs[INPUT_KEY].contiguous() - reconstructions.contiguous()) color_weighted = self.schedule(iteration) * recon if torch.isnan(color_weighted).any(): raise ValueError("[COLOR] NaN detected in loss") return dict(color=color_weighted) class KLLoss(torch.nn.Module): def __init__(self, config) -> None: super().__init__() self.schedule = WeightScheduler(boundaries=config.boundaries, values=config.values) def kl(self, mean, logvar): _dims = [idx for idx in range(1, mean.ndim)] var = torch.exp(logvar) return 0.5 * (torch.pow(mean, 2) + var - 1.0 - logvar) def forward(self, inputs, output_batch, iteration) -> dict[str, torch.Tensor]: if "posteriors" not in output_batch: # No KL loss for discrete tokens. return dict() mean, logvar = output_batch["posteriors"] if mean.ndim == 1: # No KL if the mean is a scalar. return dict() kl = self.kl(mean, logvar) kl_weighted = self.schedule(iteration) * kl if torch.isnan(kl_weighted).any(): raise ValueError("[KL] NaN detected in loss") return dict(kl=kl_weighted) class PerceptualLoss(LPIPS): """Relevant changes that're internal to us: - Remove linear projection layers, simply use the raw pre-normalized features. - Use pyramid-layer weights: [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10.0 / 1.5]. - Accepts pixel-wise masks and modulates the features before norm calculation. - Implements gram-matrix and correlation losses. """ def __init__(self, config): super(PerceptualLoss, self).__init__(config.checkpoint_activations) self.net = self.net.eval() self.gram_enabled = config.gram_enabled self.corr_enabled = config.corr_enabled self.layer_weights = list(config.layer_weights) self.lpips_schedule = WeightScheduler(config.lpips_boundaries, config.lpips_values) self.gram_schedule = WeightScheduler(config.gram_boundaries, config.gram_values) self.corr_schedule = WeightScheduler(config.corr_boundaries, config.corr_values) self.checkpoint_activations = config.checkpoint_activations def _temporal_gram_matrix(self, x, batch_size=None): x = batch2time(x, batch_size) c, t, h, w = x.shape[-4], x.shape[-3], x.shape[-2], x.shape[-1] reshaped_x = torch.reshape(x, [-1, c, t * h * w]) return torch.matmul(reshaped_x, reshaped_x.transpose(1, 2)) / float(t * h * w) def _gram_matrix(self, x, batch_size=None): if batch_size is not None and x.shape[0] != batch_size: return self._temporal_gram_matrix(x, batch_size) c, h, w = x.shape[-3], x.shape[-2], x.shape[-1] reshaped_x = torch.reshape(x, [-1, c, h * w]) return torch.matmul(reshaped_x, reshaped_x.transpose(1, 2)) / float(h * w) def forward(self, inputs, output_batch, iteration): output_dict = dict() reconstructions = output_batch[RECON_KEY] weights = inputs[MASK_KEY] input_images = inputs[INPUT_KEY] if input_images.ndim == 5: input_images, batch_size = time2batch(input_images) reconstructions, _ = time2batch(reconstructions) weights, _ = time2batch(weights) else: batch_size = input_images.shape[0] in0_input, in1_input = (self.scaling_layer(input_images), self.scaling_layer(reconstructions)) outs0, outs1 = self.net(in0_input), self.net(in1_input) _layer_weights = self.layer_weights weights_map, res, diffs = {}, {}, {} for kk in range(len(self.chns)): weights_map[kk] = torch.nn.functional.interpolate(weights[:, :1, ...], outs0[kk].shape[-2:]) diffs[kk] = weights_map[kk] * torch.abs(outs0[kk] - outs1[kk]) res[kk] = _layer_weights[kk] * diffs[kk].mean([1, 2, 3], keepdim=True) val = res[0] for ll in range(1, len(self.chns)): val += res[ll] # Scale by number of pixels to match pixel-wise losses. val = val.expand(-1, input_images.shape[-3], input_images.shape[-2], input_images.shape[-1]) if batch_size != input_images.shape[0]: val = batch2time(val, batch_size) if torch.isnan(val).any(): raise ValueError("[LPIPS] NaN detected in loss") output_dict["lpips"] = self.lpips_schedule(iteration) * val if self.gram_enabled and self.gram_schedule(iteration) > 0.0: num_chans = len(self.chns) grams0 = [self._gram_matrix(weights_map[kk] * outs0[kk], batch_size) for kk in range(num_chans)] grams1 = [self._gram_matrix(weights_map[kk] * outs1[kk], batch_size) for kk in range(num_chans)] gram_diffs = [(grams0[kk] - grams1[kk]) ** 2 for kk in range(num_chans)] grams_res = [_layer_weights[kk] * gram_diffs[kk].mean([1, 2], keepdim=True) for kk in range(num_chans)] gram_val = grams_res[0] for ll in range(1, len(self.chns)): gram_val += grams_res[ll] # Scale by number of total pixels to match pixel-wise losses. gram_val = gram_val.unsqueeze(1).expand( -1, input_images.shape[-3], input_images.shape[-2], input_images.shape[-1] ) if batch_size != input_images.shape[0]: gram_val = batch2time(gram_val, batch_size) if torch.isnan(gram_val).any(): raise ValueError("[GRAM] NaN detected in loss") output_dict["gram"] = self.gram_schedule(iteration) * gram_val return output_dict def torch_compile(self): """ This method invokes torch.compile() on this loss """ # cuda-graphs crash after 1k iterations self.net = torch.compile(self.net, dynamic=False) class FlowLoss(torch.nn.Module): def __init__(self, config) -> None: super().__init__() self.schedule = WeightScheduler(config.boundaries, config.values) self.scale = config.scale self.dtype = getattr(torch, config.dtype) self.checkpoint_activations = config.checkpoint_activations self.enabled = config.enabled current_device = torch.device(torch.cuda.current_device()) # In order to be able to run model in bf16 we need to change make_coords_grid() # to allow it to return arbitrary type provided by us in argument # the line from orginal implementation that caused results to be only fp32 is commented # Additionally I've changed that function to run on GPU instead of CPU, which results in # less graph breaks when torch.compile() is used # This function is copied from # https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/_utils.py#L22 # commit: b06ea39d5f0adbe949d08257837bda912339e415 def make_coords_grid( batch_size: int, h: int, w: int, device: torch.device = current_device, dtype: torch.dtype = self.dtype ): # Original: def make_coords_grid(batch_size: int, h: int, w: int, device: str = "cpu"): device = torch.device(device) coords = torch.meshgrid(torch.arange(h, device=device), torch.arange(w, device=device), indexing="ij") coords = torch.stack(coords[::-1], dim=0).to(dtype) # Original: coords = torch.stack(coords[::-1], dim=0).float() return coords[None].repeat(batch_size, 1, 1, 1) # We also need to specify output dtype of torch.linspace() in index_pyramid() # method of CorrBlock, otherwise it uses default fp32 dtype as output. # Additionally I've changed that function to run on GPU instead of CPU, which results in # less graph breaks when torch.compile() is used # This function is copied from # https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py#L394 # commit: b06ea39d5f0adbe949d08257837bda912339e415 def index_pyramid( self, centroids_coords, dtype: torch.dtype = self.dtype, device: torch.device = current_device ): # Original: def index_pyramid(self, centroids_coords): """Return correlation features by indexing from the pyramid.""" neighborhood_side_len = 2 * self.radius + 1 # see note in __init__ about out_channels di = torch.linspace(-self.radius, self.radius, neighborhood_side_len, dtype=dtype, device=device) dj = torch.linspace(-self.radius, self.radius, neighborhood_side_len, dtype=dtype, device=device) # Original: di = torch.linspace(-self.radius, self.radius, neighborhood_side_len) # Original: dj = torch.linspace(-self.radius, self.radius, neighborhood_side_len) delta = torch.stack(torch.meshgrid(di, dj, indexing="ij"), dim=-1).to(centroids_coords.device) delta = delta.view(1, neighborhood_side_len, neighborhood_side_len, 2) batch_size, _, h, w = centroids_coords.shape # _ = 2 centroids_coords = centroids_coords.permute(0, 2, 3, 1).reshape(batch_size * h * w, 1, 1, 2) indexed_pyramid = [] for corr_volume in self.corr_pyramid: sampling_coords = centroids_coords + delta # end shape is (batch_size * h * w, side_len, side_len, 2) indexed_corr_volume = optical_flow.raft.grid_sample( corr_volume, sampling_coords, align_corners=True, mode="bilinear" ).view(batch_size, h, w, -1) indexed_pyramid.append(indexed_corr_volume) centroids_coords = centroids_coords / 2 corr_features = torch.cat(indexed_pyramid, dim=-1).permute(0, 3, 1, 2).contiguous() expected_output_shape = (batch_size, self.out_channels, h, w) if corr_features.shape != expected_output_shape: raise ValueError( f"Output shape of index pyramid is incorrect. Should be {expected_output_shape}, got {corr_features.shape}" ) return corr_features optical_flow.raft.make_coords_grid = make_coords_grid optical_flow.raft.CorrBlock.index_pyramid = index_pyramid flow_model = optical_flow.raft_large(pretrained=True, progress=False) flow_model.requires_grad_(False) flow_model.eval() flow_model = flow_model.to(self.dtype) self.flow_model = flow_model def _run_model(self, input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor: """Runs flow_model in the forward mode on explicit dtype=float32. Args: input1: First video frames batch, layout (T, C, H, W), bfloat16. input2: Next video frames batch, layout (T, C, H, W), bfloat16. Returns: Forward optical flow, (T, 2, H, W), bfloat16. """ input_dtype = input1.dtype flow_output = self.flow_model.to(self.dtype)(input1.to(self.dtype), input2.to(self.dtype))[-1] return flow_output.to(input_dtype) def _run_model_fwd(self, input_video: torch.Tensor) -> torch.Tensor: """Runs foward flow on a batch of videos, one batch at a time. Args: input_video: The input batch of videos, layout (B, T, C, H, W). Returns: Forward optical flow, layout (B, 2, T-1, H, W). """ output_list = list() for fwd_input_frames in input_video: fwd_input_frames = fwd_input_frames.transpose(1, 0) fwd_flow_output = self._run_model(fwd_input_frames[:-1], fwd_input_frames[1:]) output_list.append(fwd_flow_output.transpose(1, 0)) return torch.stack(output_list, dim=0) def _bidirectional_flow(self, input_video: torch.Tensor) -> torch.Tensor: """The bidirectional optical flow on a batch of videos. The forward and backward flows are averaged to get the bidirectional flow. To reduce memory pressure, the input video is scaled down by a factor of `self.scale`, and rescaled back to match other pixel-wise losses. Args: input_video: The input batch of videos, layout (B, T, C, H, W). Returns: Biderectinoal flow, layout (B, 2, T-1, H, W). """ # scale down the input video to reduce memory pressure. t, h, w = input_video.shape[-3:] input_video_scaled = F.interpolate(input_video, (t, h // self.scale, w // self.scale), mode="trilinear") # forward flow. if self.checkpoint_activations: fwd_flow_output = checkpoint.checkpoint(self._run_model_fwd, input_video_scaled, use_reentrant=False) else: fwd_flow_output = self._run_model_fwd(input_video_scaled) # backward flow. input_video_scaled = input_video_scaled.flip([2]) if self.checkpoint_activations: bwd_flow_output = checkpoint.checkpoint(self._run_model_fwd, input_video_scaled, use_reentrant=False) else: bwd_flow_output = self._run_model_fwd(input_video_scaled) bwd_flow_output = bwd_flow_output.flip([2]) # bidirectional flow, concat fwd and bwd along temporal axis. flow_input = torch.cat([fwd_flow_output, bwd_flow_output], dim=2) return self.scale * F.interpolate(flow_input, (2 * (t - 1), h, w), mode="trilinear") def forward( self, inputs: dict[str, torch.Tensor], output_batch: dict[str, torch.Tensor], iteration: int ) -> dict[str, torch.Tensor]: input_images = inputs[INPUT_KEY] if input_images.ndim == 4 or input_images.shape[2] == 1: return dict() if not self.enabled or self.schedule(iteration) == 0.0: return dict() # Biderectional flow (B, 2, 2*(T-1), H, W) flow_input = self._bidirectional_flow(input_images) flow_recon = self._bidirectional_flow(output_batch[RECON_KEY]) # L1 loss on the flow. (B, 1, 2*(T-1), H, W) flow_loss = torch.abs(flow_input - flow_recon).mean(dim=1, keepdim=True) flow_loss_weighted = self.schedule(iteration) * flow_loss if torch.isnan(flow_loss_weighted).any(): raise ValueError("[FLOW] NaN detected in loss") return dict(flow=flow_loss_weighted) def torch_compile(self): """ This method invokes torch.compile() on this loss """ self.flow_model = torch.compile(self.flow_model, dynamic=False) class VideoConsistencyLoss(torch.nn.Module): def __init__(self, config) -> None: super().__init__() self.schedule = WeightScheduler(boundaries=config.boundaries, values=config.values) self.enabled = config.enabled self.num_frames = config.num_frames self.step = config.step self.num_windows = None def shuffle(self, inputs: torch.Tensor) -> torch.Tensor: """ For input video of [B, 3, T, H, W], this function will reshape the video to the shape of [B*(T-num_frames+1)//step, 3, num_frames, H, W] using a sliding window This function is used to compute the temporal consistency between overlapped frames to enable temporal consistency """ assert len(inputs.shape) == 5, f"inputs shape should be [B, 3, T, H, W]. currently {inputs.shape}" B, C, T, H, W = inputs.shape assert T >= self.num_frames, f"inputs {T} should be greater than {self.num_frames}" # [B, C, num_windows, H, W, num_frames] outputs = inputs.unfold(dimension=2, size=self.num_frames, step=self.step) self.num_windows = outputs.shape[2] outputs = einops.rearrange(outputs, "b c m h w n -> (b m) c n h w") return outputs def forward(self, inputs, output_batch, iteration) -> dict[str, torch.Tensor]: if not self.enabled or self.num_windows is None: return dict() if self.schedule(iteration) == 0.0: return dict() # reshape output_batch to compute loss between overlapped frames reconstructions = output_batch[RECON_CONSISTENCY_KEY] B, C, T, H, W = reconstructions.shape assert T == self.num_frames, f"reconstruction shape invalid (shape[2] should be {self.num_frames})" assert ( B % self.num_windows == 0 ), f"reconstruction shape invalid (shape[0]={B} not dividable by {self.num_windows})" B = B // self.num_windows videos = reconstructions.view(B, self.num_windows, C, self.num_frames, H, W) # Compute the L1 distance between overlapped frames for all windows at once diff = torch.mean(torch.abs(videos[:, :-1, :, self.step :, :, :] - videos[:, 1:, :, : -self.step, :, :])) diff_weighted = self.schedule(iteration) * diff if LATENT_KEY not in output_batch: return dict(frame_consistency=diff_weighted) B_latent, C_latent, T_latent, H_latent, W_latent = output_batch["latent"].shape assert B_latent % self.num_windows == 0, f"latent batches should be divisible by {self.num_windows}" latents = output_batch[LATENT_KEY].view( B_latent // self.num_windows, self.num_windows, C_latent, T_latent, H_latent, W_latent ) temporal_rate = self.num_frames // T_latent spatial_rate = (H // H_latent) * (W // W_latent) step_latent = self.step // temporal_rate latent_diff = torch.mean( torch.abs(latents[:, :-1, :, step_latent:, :, :] - latents[:, 1:, :, :-step_latent, :, :]) ) latent_diff_weighted = self.schedule(iteration) * latent_diff * (C * temporal_rate * spatial_rate) / (C_latent) return dict(frame_consistency=diff_weighted, latent_consistency=latent_diff_weighted) def unshuffle(self, inputs: torch.Tensor) -> torch.Tensor: """ For input video of [B*num_windows, 3, num_frames, H, W], this function will undo the shuffle to a tensor of shape [B, 3, T, H, W] """ assert len(inputs.shape) == 5, f"inputs shape should be [B, 3, T, H, W]. currently {inputs.shape}" B, C, T, H, W = inputs.shape assert T == self.num_frames, f"inputs shape invalid (shape[2] should be {self.num_frames})" assert B % self.num_windows == 0, f"inputs shape invalid (shape[0]={B} not dividable by {self.num_windows})" B = B // self.num_windows videos = inputs.view(B, self.num_windows, C, self.num_frames, H, W) T = self.num_frames + (self.num_windows - 1) * self.step current_device = torch.device(torch.cuda.current_device()) outputs = torch.zeros(B, C, T, H, W).to(inputs.dtype).to(current_device) counter = torch.zeros_like(outputs) for i in range(self.num_windows): outputs[:, :, i * self.step : i * self.step + self.num_frames, :, :] += videos[:, i, :, :, :, :] counter[:, :, i * self.step : i * self.step + self.num_frames, :, :] += 1 outputs = outputs / counter return outputs