# Source path: audio-flamingo-3 / llava / model / liger / cross_entropy.py
# Uploaded by SreyanG-NVIDIA ("Upload 225 files", commit 174ae06, verified).
# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
import operator
from typing import Optional
import torch
import triton
import triton.language as tl
from .utils import compare_version, element_mul_kernel, is_hip
# Choose where to import `tanh` from. The location depends on the result of
# compare_version("triton", operator.ge, "3.0.0") -- presumably "installed
# Triton is at least 3.0.0" (helper is defined in .utils; confirm there).
if compare_version("triton", operator.ge, "3.0.0"):
    try:
        # typical import path with dispatch available
        from triton.language.extra.libdevice import tanh
    except ModuleNotFoundError:
        # for working with NGC containers
        from triton.language.extra.cuda.libdevice import tanh
else:
    # version check failed: fall back to the triton.language.math location
    from triton.language.math import tanh

# Integer constexpr flags. `_TRUE` is what the kernel compares RETURN_Z_LOSS
# against; the `.value` of both is used on the host side to normalise the
# `return_z_loss` argument.
_TRUE = tl.constexpr(1)
_FALSE = tl.constexpr(0)
@triton.jit
def liger_cross_entropy_kernel(
    X_ptr,
    X_stride,
    Y_ptr,
    Y_stride,
    loss_ptr,
    z_loss_ptr,
    loss_stride,
    n_cols,
    n_non_ignore,
    ignore_index,
    lse_square_scale: tl.constexpr,
    label_smoothing: tl.constexpr,
    reduction: tl.constexpr,  # set it as constexpr since reduction is always known at compile time
    softcap,
    RETURN_Z_LOSS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    HAS_SOFTCAPPING: tl.constexpr,
):
    """
    Compute the cross entropy loss of one row and overwrite that row of X with its gradient.

    Each program instance uses its program id to select one row of X, one target in Y and
    one slot in the loss buffers. The row is read twice (online softmax statistics, then
    gradient computation) and the gradient is stored back into X in place.
    Optional label smoothing, z-loss (lse_square_scale) and logit soft-capping are supported.
    Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math.

    Parameters:
    X_ptr: Pointer to input tensor. The row is overwritten with the gradient.
    X_stride (int): The stride of the input tensor (multiplied by the program id to reach the row).
    Y_ptr: Pointer to target tensor.
    Y_stride (int): The stride of the target tensor.
    loss_ptr: Pointer to tensor to store the loss.
    z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
    loss_stride (int): The stride of the loss tensor (also used to offset z_loss_ptr).
    n_cols (int): The number of columns in the input tensor.
    n_non_ignore (int): The number of non-ignored elements in the batch. Used as the divisor when reduction is "mean".
    ignore_index (int): The index to ignore in the target.
    lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
    label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
    reduction (str): The string for the reduction to apply. Only "mean" triggers normalization inside the kernel.
    softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
    RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
    BLOCK_SIZE (int): The block size for Triton operations.
    HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
    """

    # https://github.com/triton-lang/triton/issues/1058
    # If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64
    program_id = tl.program_id(0).to(tl.int64)

    # 1. Load Y_ptr first because if the target is ignore_index, we can return right away
    Y_ptr += program_id * Y_stride
    y = tl.load(Y_ptr)

    # 2. locate the start index
    X_ptr += program_id * X_stride

    if y == ignore_index:
        # Ignored row: its gradient is zero everywhere, and no loss is stored for it
        for i in range(0, n_cols, BLOCK_SIZE):
            X_offsets = i + tl.arange(0, BLOCK_SIZE)
            tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
        return

    loss_ptr += program_id * loss_stride
    z_loss_ptr += program_id * loss_stride

    # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
    # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867

    # 3. [Online softmax] first pass: find max + sum
    m = float("-inf")  # m is the max value. use the notation from the paper
    d = 0.0  # d is the sum. use the notation from the paper
    # Read X_y now: the second pass overwrites the row with gradients
    ori_X_y = tl.load(X_ptr + y)  # we need to store the original value of X_y for the loss calculation
    if HAS_SOFTCAPPING:
        ori_X_y = softcap * tanh(ori_X_y / softcap)

    # Label smoothing is a general case of normal cross entropy
    # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
    scaled_x_sum = 0.0
    eps = label_smoothing / n_cols

    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        # Out-of-range lanes are loaded as -inf so they contribute exp(-inf) = 0 to the sum
        X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf"))
        if HAS_SOFTCAPPING:
            X_block = softcap * tanh(X_block / softcap)
        block_max = tl.max(X_block)
        if label_smoothing > 0:
            # scale X beforehand to avoid overflow
            # (masked lanes hold -inf, so they are replaced by 0.0 before summing)
            scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
        m_new = tl.maximum(m, block_max)
        # Rescale the running sum to the new max before adding this block's contribution
        d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
        m = m_new

    # log(sum(e^(X_i))) = log(sum(e^(max(X)) * e^(X_i - max(X))))
    #                   = log(e^(max(X)) * sum(e^(X_i - max(X))))
    #                   = max(X) + log(sum(e^(X_i - max(X)))) = m + log d
    lse = m + tl.log(d)

    # 4. [Online Softmax] Second pass: compute gradients
    # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N)
    # dx_y = (softmax(x_y) - 1) / N
    # dx_i = softmax(x_i) / N, i != y
    # For label smoothing:
    # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
    # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
    #      = dx_i - (1 - label_smoothing) / N
    # With Z loss:
    # dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y
    # dx_y = dx_i - (1 - label_smoothing) / N
    # For any other reduction (e.g. 'sum'), no normalization is applied:
    # dx_y = softmax(x_y) - 1
    # dx_i = softmax(x_i), for i != y

    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf"))
        if HAS_SOFTCAPPING:
            # keep tanh(x / softcap) around: it is reused for the chain rule below
            intermediate = tanh(X_block / softcap)
            X_block = softcap * intermediate
        # softmax(x_i)
        X_block = tl.exp(X_block - m) / d
        # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
        X_block += 2 * lse_square_scale * lse * X_block
        # smoothing term
        X_block += -eps
        # special handle dx_y
        X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
        # reduction scale
        if reduction == "mean":
            X_block = X_block / (n_non_ignore)
        # chain rule
        # d/dx (softcap * tanh(x / softcap)) = 1 - tanh^2(x / softcap)
        if HAS_SOFTCAPPING:
            X_block = X_block * (1 - intermediate * intermediate)

        # Overwrite the logits with their gradient (in-place gradient storage)
        tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

    # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
    # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
    tl.debug_barrier()

    # 5. Calculate the loss

    # loss = -log(softmax(X_y)) = -log(e^(X_y - max(X)) / sum(e^(X - max(X))))
    #      = -((X_y - max(X)) - log(sum(e^(X - max(X)))))
    #      = -(X_y - m - log d) = lse - X_y
    # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
    # So we can safely calculate log (softmax(X_y)) without overflow
    loss = lse - ori_X_y

    # Original loss = H(q, p),  with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
    # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
    #          = (1 - label_smoothing) * H(q, p) + eps * sum(-logsoftmax(x_i))
    # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
    #          = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd))
    # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
    # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
    # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
    if label_smoothing > 0:
        smooth_loss = scaled_x_sum + label_smoothing * lse
        loss = loss * (1 - label_smoothing) + smooth_loss

    # An auxiliary loss, z_loss
    # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
    z_loss = lse_square_scale * lse * lse
    loss += z_loss
    # Normalize the loss by the number of non-ignored elements if reduction is "mean"
    if reduction == "mean":
        z_loss = z_loss / n_non_ignore
        loss = loss / n_non_ignore

    tl.store(loss_ptr, loss)
    # z_loss is already included in `loss`; it is stored separately only on request
    if RETURN_Z_LOSS == _TRUE:
        tl.store(z_loss_ptr, z_loss)
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576:
# https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
# This value caps the BLOCK_SIZE chosen by the host-side launchers below.
MAX_FUSED_SIZE = 65536 // 2  # the best size we found by manually tuning

# Lookup from a Python bool to the plain integer held by the matching constexpr
# flag. It is also used as the set of accepted `return_z_loss` values.
_bool_to_return_z_loss = {
    True: _TRUE.value,
    False: _FALSE.value,
}
def cross_entropy_forward(
    _input,
    target,
    ignore_index,
    lse_square_scale,
    label_smoothing,
    reduction,
    softcap,
    return_z_loss,
):
    """
    Launch the cross entropy kernel over every row of `_input` and reduce the per-row losses.

    Parameters:
    _input (tensor): 2-D logits; unpacked as (rows, vocab). Overwritten in place with the
        gradient unless a contiguous copy had to be made, in which case the copy holds it.
    target (tensor): Target indices, compared against `ignore_index`.
    ignore_index (int): Target value whose rows are skipped.
    lse_square_scale (float): Weight of the z-loss term, forwarded to the kernel.
    label_smoothing (float): Label smoothing amount, forwarded to the kernel.
    reduction (str): Forwarded to the kernel, which only normalizes for "mean".
        NOTE(review): the per-row losses are always summed here, so "none" yields the
        same scalar as "sum".
    softcap (Optional[float]): Soft-capping threshold; None disables soft-capping.
    return_z_loss: True/False (or 1/0). Anything else fails the assertion.

    Returns:
    tuple: (loss, z_loss, _input) where z_loss is None when not requested and `_input`
        is the tensor that now holds the gradient.
    """
    # Validate once, then map non-int truthy/falsy equivalents onto 1/0.
    assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}"
    if not isinstance(return_z_loss, int):
        return_z_loss = _bool_to_return_z_loss[return_z_loss]
    wants_z_loss = return_z_loss == _TRUE.value

    n_rows, vocab_size = _input.shape
    block_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(vocab_size))

    # One slot per row; rows skipped by the kernel keep their initial zero.
    loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
    # When z-loss is not requested the kernel still needs a valid pointer, so reuse loss_1d.
    z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if wants_z_loss else loss_1d

    n_non_ignore = (target != ignore_index).sum().item()

    # The kernel indexes the last dimension with unit stride.
    if _input.stride(-1) != 1:
        _input = _input.contiguous()
    if target.stride(-1) != 1:
        target = target.contiguous()

    softcap_enabled = softcap is not None

    # The gradient is written back into the logits buffer to save memory.
    liger_cross_entropy_kernel[(n_rows,)](
        X_ptr=_input,
        X_stride=_input.stride(-2),
        Y_ptr=target,
        Y_stride=target.stride(-1),  # always 1
        loss_ptr=loss_1d,
        z_loss_ptr=z_loss_1d,
        loss_stride=loss_1d.stride(-1),  # always 1
        n_cols=vocab_size,
        n_non_ignore=n_non_ignore,
        ignore_index=ignore_index,
        lse_square_scale=lse_square_scale,
        label_smoothing=label_smoothing,
        reduction=reduction,
        softcap=softcap if softcap_enabled else 0.0,
        RETURN_Z_LOSS=return_z_loss,
        BLOCK_SIZE=block_size,
        HAS_SOFTCAPPING=softcap_enabled,
        # TODO: 32 seems to give the best performance
        # Performance is quite sensitive to num_warps
        num_warps=16 if is_hip() else 32,
    )

    loss = torch.sum(loss_1d)
    z_loss = torch.sum(z_loss_1d) if wants_z_loss else None
    return loss, z_loss, _input
def cross_entropy_backward(_input, grad_output):
    """
    Scale the gradient stored in `_input` by the incoming `grad_output`.

    Parameters:
    _input (tensor): 2-D tensor holding the gradient produced by the forward pass.
    grad_output (tensor): Upstream gradient of the loss.

    Returns:
    tensor: `_input` itself; untouched when `grad_output` equals 1.0.
    """
    # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
    one = torch.tensor(1.0, device=grad_output.device)
    if torch.equal(grad_output, one):
        return _input

    # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
    # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
    n_rows, n_cols = _input.shape
    block_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
    # element_mul_kernel comes from .utils; presumably it multiplies each row of
    # _input by grad_output in place -- confirm against its definition.
    element_mul_kernel[(n_rows,)](
        _input,
        _input.stride(-2),
        grad_output,
        n_cols,
        BLOCK_SIZE=block_size,
        num_warps=16 if is_hip() else 32,
    )
    return _input
class LigerCrossEntropyFunction(torch.autograd.Function):
    """
    Custom autograd function wrapping the Liger cross entropy forward/backward helpers.

    The forward pass computes the loss and, at the same time, the gradient of the logits;
    the backward pass only rescales that saved gradient by the upstream gradient.
    """

    @staticmethod
    def forward(
        ctx,
        _input: torch.Tensor,
        target: torch.Tensor,
        ignore_index: int = -100,
        lse_square_scale: float = 0.0,
        label_smoothing: float = 0.0,
        reduction: str = "mean",
        softcap: Optional[float] = None,
        return_z_loss: bool = False,
    ):
        """
        The forward pass of the Liger Cross Entropy loss.

        Parameters:
        ctx : The context object.
        _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
        target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
        ignore_index (int): The index to ignore in the target.
        lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
        reduction (str): The reduction to apply to the output: "none" | "mean" | "sum".
            NOTE(review): cross_entropy_forward always sums the per-row losses, so "none"
            currently gives the same result as "sum".
        softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`

        Returns:
        tuple: A tuple with the computed losses with respect to loss and z loss. The elements are tensors or None.
        """
        loss, z_loss, grad_buffer = cross_entropy_forward(
            _input=_input,
            target=target,
            ignore_index=ignore_index,
            lse_square_scale=lse_square_scale,
            label_smoothing=label_smoothing,
            reduction=reduction,
            softcap=softcap,
            return_z_loss=return_z_loss,
        )
        # TODO: investigation
        # If we don't detach the _input tensor, the memory will double
        # Not sure why but seems that there will be a time both grad and value exist but in different location
        ctx.save_for_backward(grad_buffer.detach())
        ctx.return_z_loss = return_z_loss
        return loss, z_loss

    @staticmethod
    def backward(ctx, grad_output, grad_ouput2):
        """
        The backward pass of the Liger Cross Entropy loss.

        Parameters:
        ctx : The context object with saved tensors.
        grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
        grad_ouput2 (tensor): Gradient with respect to the z loss output. Not used.

        Returns:
        tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
        """
        if ctx.return_z_loss:
            del grad_ouput2  # z_loss is only for logging

        (stored_grad,) = ctx.saved_tensors
        grad_input = cross_entropy_backward(stored_grad, grad_output)
        # Only `_input` receives a gradient; the other seven forward arguments get None.
        return (grad_input,) + (None,) * 7
def liger_fixed_cross_entropy(
    source, target, num_items_in_batch: Optional[int] = None, ignore_index: int = -100, **kwargs
):
    """
    Cross entropy that normalizes by an externally supplied item count when one is given.

    Drop-in counterpart of
    `nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)`
    that routes through LigerCrossEntropyFunction.

    Parameters:
    source (tensor): Logits passed to LigerCrossEntropyFunction.
    target (tensor): Target indices passed to LigerCrossEntropyFunction.
    num_items_in_batch (Optional[int]): When not None, the loss is computed with "sum"
        reduction and divided by this value; otherwise "mean" reduction is used.
        A tensor is accepted as well.
    ignore_index (int): Target value to ignore.
    **kwargs: Accepted for signature compatibility and ignored.

    Returns:
    tensor: The scalar loss.
    """
    reduction = "sum" if num_items_in_batch is not None else "mean"
    # lse_square_scale and label_smoothing are fixed to 0.0 (plain cross entropy)
    loss, _ = LigerCrossEntropyFunction.apply(source, target, ignore_index, 0.0, 0.0, reduction)
    if reduction == "sum":
        # With model parallelism the count may be a tensor living on another device;
        # move it next to the loss so the division does not raise a device mismatch.
        if torch.is_tensor(num_items_in_batch):
            num_items_in_batch = num_items_in_batch.to(loss.device)
        loss = loss / num_items_in_batch
    return loss
def LigerForCausalLMLoss(
    logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs
):
    """
    Next-token prediction loss for causal language models, computed with the Liger kernel.

    Parameters:
    logits (tensor): Model outputs; the last dimension is reshaped to `vocab_size`.
    labels (tensor): Token labels aligned with `logits` before shifting.
    vocab_size (int): Size of the last dimension after flattening.
    num_items_in_batch (int): Optional normalizer forwarded to liger_fixed_cross_entropy.
    ignore_index (int): Label value to ignore.
    **kwargs: Forwarded to liger_fixed_cross_entropy.

    Returns:
    tensor: The loss returned by liger_fixed_cross_entropy.
    """
    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.float()

    # Shift so that tokens < n predict n: drop the last logit and the first label
    predictions = logits[..., :-1, :].contiguous()
    next_tokens = labels[..., 1:].contiguous()

    # Flatten the tokens
    flat_logits = predictions.view(-1, vocab_size)
    # Enable model parallelism: labels must sit on the same device as the logits
    flat_labels = next_tokens.view(-1).to(flat_logits.device)

    return liger_fixed_cross_entropy(flat_logits, flat_labels, num_items_in_batch, ignore_index, **kwargs)