# RynnEC: rynnec/model/loss.py
# Uploaded by lixin4ever in "Upload (#2)" (commit 372785b, verified).
# Adapted from https://github.com/magic-research/Sa2VA.
# Below is the original copyright:
# coding=utf-8
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Optional fused "linear projection + cross-entropy" kernel. It is left as
# None here, so `cross_entropy_loss` always materializes logits through
# `lm_head` and calls `F.cross_entropy`. When set, it is called as
# linear_cross_entropy(hidden, lm_head.weight, labels, bias=..., reduction=...,
# accum_e_fp32=True, accum_c_fp32=True).
# NOTE(review): that signature looks like a third-party fused CE kernel;
# confirm which package is intended before wiring one in.
linear_cross_entropy = None
import functools
import logging
import warnings
from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from torch import Tensor

from rynnec.constants import IGNORE_INDEX
def reduce_loss(loss: Tensor, reduction: str) -> Tensor:
    """Collapse an element-wise loss tensor according to ``reduction``.

    Args:
        loss (Tensor): Element-wise loss values.
        reduction (str): ``"none"`` returns ``loss`` itself, ``"mean"``
            returns its mean and ``"sum"`` returns its sum. The string is
            validated by ``F._Reduction.get_enum``.

    Returns:
        Tensor: The reduced (or untouched) loss.
    """
    # get_enum maps 'none' -> 0, 'mean' -> 1, 'sum' -> 2.
    reducers = {
        0: lambda t: t,
        1: lambda t: t.mean(),
        2: lambda t: t.sum(),
    }
    reducer = reducers.get(F._Reduction.get_enum(reduction))
    return None if reducer is None else reducer(loss)
def weight_reduce_loss(loss: Tensor,
                       weight: Optional[Tensor] = None,
                       reduction: str = 'mean',
                       avg_factor: Optional[float] = None) -> Tensor:
    """Optionally weight an element-wise loss, then reduce it.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Optional[Tensor], optional): Element-wise weights multiplied
            into ``loss`` before reduction. Defaults to None.
        reduction (str, optional): "none", "mean" or "sum", as in PyTorch
            losses. Defaults to 'mean'.
        avg_factor (Optional[float], optional): When given, the "mean"
            reduction divides the summed loss by this value instead of the
            element count. Defaults to None.

    Returns:
        Tensor: The weighted and reduced loss.

    Raises:
        ValueError: If ``avg_factor`` is given with a reduction other than
            "mean" or "none".
    """
    weighted = loss if weight is None else loss * weight

    if avg_factor is None:
        return reduce_loss(weighted, reduction)

    if reduction == 'none':
        return weighted
    if reduction != 'mean':
        raise ValueError('avg_factor can not be used with reduction="sum"')

    # The epsilon keeps the division finite when avg_factor is 0, i.e. when
    # every label of an image belongs to the ignore index.
    eps = torch.finfo(torch.float32).eps
    return weighted.sum() / (avg_factor + eps)
def dice_loss(pred,
              target,
              weight=None,
              eps=1e-3,
              reduction='mean',
              naive_dice=False,
              avg_factor=None):
    """Compute the Dice loss between predictions and targets.

    Two variants are supported:

    - the V-Net formulation (`V-Net: Fully Convolutional Neural Networks
      for Volumetric Medical Image Segmentation
      <https://arxiv.org/abs/1606.04797>`_), whose denominator sums the
      squared predictions and squared targets;
    - the "naive" formulation, whose denominator sums the raw
      (first-power) predictions and targets.

    Args:
        pred (torch.Tensor): Predictions of shape (n, *). All dimensions
            after the first are flattened before the per-sample sums.
        target (torch.Tensor): Labels with the same shape as ``pred``;
            cast to float.
        weight (torch.Tensor, optional): Per-sample loss weight of shape
            (n,). Defaults to None.
        eps (float): Smoothing term that avoids division by zero.
            Default: 1e-3.
        reduction (str, optional): "none", "mean" or "sum".
            Defaults to 'mean'.
        naive_dice (bool, optional): Use the naive (first-power) variant
            instead of the V-Net one. Defaults to False.
        avg_factor (int, optional): Normalizer used instead of the element
            count when averaging. Defaults to None.

    Returns:
        torch.Tensor: The reduced loss, or the per-sample losses when
        ``reduction='none'`` and ``avg_factor`` is None.
    """
    pred_flat = pred.flatten(1)
    target_flat = target.flatten(1).float()

    intersection = (pred_flat * target_flat).sum(dim=1)
    if not naive_dice:
        # V-Net: eps is folded into each squared-sum term of the denominator.
        pred_energy = (pred_flat * pred_flat).sum(dim=1) + eps
        target_energy = (target_flat * target_flat).sum(dim=1) + eps
        dice = (2 * intersection) / (pred_energy + target_energy)
    else:
        # Naive: eps smooths both the numerator and the denominator.
        pred_mass = pred_flat.sum(dim=1)
        target_mass = target_flat.sum(dim=1)
        dice = (2 * intersection + eps) / (pred_mass + target_mass + eps)
    per_sample = 1 - dice

    if weight is not None:
        assert weight.ndim == per_sample.ndim
        assert len(weight) == len(pred)
    return weight_reduce_loss(per_sample, weight, reduction, avg_factor)
class DiceLoss(nn.Module):
    """``nn.Module`` wrapper around :func:`dice_loss` with optional sigmoid."""

    def __init__(self,
                 use_sigmoid=True,
                 activate=True,
                 reduction='mean',
                 naive_dice=False,
                 loss_weight=1.0,
                 eps=1e-3):
        """Configure the Dice loss.

        Args:
            use_sigmoid (bool, optional): Activate raw predictions with a
                sigmoid. Only sigmoid activation is implemented.
                Defaults to True.
            activate (bool): Apply the activation inside ``forward``; set to
                False when ``pred`` is already activated. Defaults to True.
            reduction (str, optional): Default reduction, one of "none",
                "mean" or "sum". Defaults to 'mean'.
            naive_dice (bool, optional): Use the naive (first-power)
                denominator instead of the V-Net (squared) one.
                Defaults to False.
            loss_weight (float, optional): Scalar multiplier applied to the
                loss. Defaults to 1.0.
            eps (float): Smoothing term that avoids division by zero.
                Defaults to 1e-3.
        """
        super().__init__()
        self.activate = activate
        self.use_sigmoid = use_sigmoid
        self.naive_dice = naive_dice
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                reduction_override=None,
                avg_factor=None):
        """Compute the weighted Dice loss.

        Args:
            pred (torch.Tensor): Predictions of shape (n, *).
            target (torch.Tensor): Labels with the same shape as ``pred``.
            weight (torch.Tensor, optional): Per-sample loss weight of shape
                (n,). Defaults to None.
            reduction_override (str, optional): Replaces the configured
                reduction for this call; one of "none", "mean" or "sum".
            avg_factor (int, optional): Normalizer used instead of the
                element count when averaging. Defaults to None.

        Returns:
            torch.Tensor: ``loss_weight`` times the Dice loss.

        Raises:
            NotImplementedError: If activation is requested without sigmoid.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction

        if self.activate and not self.use_sigmoid:
            # Only sigmoid activation is supported.
            raise NotImplementedError
        if self.activate:
            pred = pred.sigmoid()

        raw_loss = dice_loss(
            pred,
            target,
            weight,
            eps=self.eps,
            reduction=reduction,
            naive_dice=self.naive_dice,
            avg_factor=avg_factor)
        return self.loss_weight * raw_loss
def cross_entropy_loss(
    hidden_states,
    lm_head,
    position_ids,
    labels,
    reduction_scope="sequence",
    **loss_kwargs
):
    """Next-token cross-entropy computed from final hidden states.

    Hidden states are shifted against the labels (position ``t`` predicts
    label ``t + 1``). Positions whose shifted label equals ``IGNORE_INDEX``
    are dropped, and only the remaining hidden states are projected through
    ``lm_head``.

    Args:
        hidden_states (torch.Tensor): Indexed as (batch, seq, hidden).
        lm_head (nn.Module): Projection to vocabulary logits. Its ``weight``
            and ``bias`` are read directly only when a fused
            ``linear_cross_entropy`` kernel is configured.
        position_ids (torch.Tensor): Position ids of shape (batch, seq).
            Only read in the packed case (``reduction_scope="sequence"``
            with batch size 1), where every position id of 0 starts a new
            sub-sequence. May be None otherwise.
        labels (torch.Tensor): Target token ids of shape (batch, seq).
        reduction_scope (str): How per-token losses are normalized when
            ``num_items_in_batch`` is present in ``loss_kwargs``:
            ``"batch"`` sums them and divides by ``num_items_in_batch``;
            ``"sequence"`` divides each token's loss by the number of
            supervised tokens in its own (sub-)sequence and by
            ``num_items_in_batch``, then sums.
        **loss_kwargs: Without ``num_items_in_batch``, the plain mean over
            supervised tokens is returned regardless of ``reduction_scope``.

    Returns:
        torch.Tensor: Scalar loss. If no token is supervised, a zero loss is
        returned that is still computed from ``lm_head``'s output.

    Raises:
        ValueError: If ``num_items_in_batch`` is given and
            ``reduction_scope`` is neither "batch" nor "sequence".
    """
    batch_size, seq_len = hidden_states.shape[:2]

    # Position t predicts label t + 1; keep only supervised positions.
    shift_labels = labels[..., 1:]
    mask = shift_labels != IGNORE_INDEX
    shift_hidden_states = hidden_states[..., :-1, :][mask].contiguous()
    shift_labels = shift_labels[mask].contiguous()

    if not mask.any():
        logging.getLogger(__name__).warning(
            "Get labels=%s. Found no sample to calculate loss!", labels)
        # A zero loss derived from lm_head's output keeps lm_head in the
        # autograd graph (presumably so distributed training sees a gradient
        # for every parameter).
        pseudo_logits = lm_head(hidden_states[:, 0:1])
        return 0.0 * pseudo_logits.mean()

    if "num_items_in_batch" not in loss_kwargs:
        reduction = "mean"
        denominator = None
    elif reduction_scope == "batch":
        reduction = "sum"
        denominator = loss_kwargs["num_items_in_batch"]
    elif reduction_scope == "sequence":
        reduction = "none"
        if batch_size == 1:
            # NOTE: packed sequence. A position id of 0 opens a new
            # sub-sequence; token 0 always opens one, so every token gets a
            # segment id even if the row does not start at position 0.
            is_start = position_ids[0] == 0
            is_start[0] = True
            batch_indices = (torch.cumsum(is_start.long(), dim=0) - 1).unsqueeze(0)
        else:
            # Build on the labels' device so indexing with `mask` matches.
            batch_indices = torch.arange(batch_size, device=labels.device)
            batch_indices = batch_indices.unsqueeze(1).expand(-1, seq_len)
        shift_batch_indices = batch_indices[..., :-1][mask].contiguous()
        # Number of supervised tokens in each (sub-)sequence.
        num_tokens = F.one_hot(shift_batch_indices).sum(dim=0)
        denominator = num_tokens[shift_batch_indices] * loss_kwargs["num_items_in_batch"]
    else:
        raise ValueError(f"Unknown reduction scope: {reduction_scope}")

    if linear_cross_entropy is None:
        shift_logits = lm_head(shift_hidden_states)
        loss = F.cross_entropy(
            shift_logits,
            shift_labels,
            reduction=reduction,
        )
    else:
        loss = linear_cross_entropy(
            shift_hidden_states,
            lm_head.weight,
            shift_labels,
            bias=lm_head.bias,
            reduction=reduction,
            accum_e_fp32=True,
            accum_c_fp32=True,
        )

    if denominator is not None:
        loss = loss / denominator
    if loss.ndim > 0:
        loss = loss.sum()
    return loss
def cross_entropy(pred,
                  label,
                  weight=None,
                  reduction='mean',
                  avg_factor=None,
                  class_weight=None,
                  ignore_index=-100,
                  avg_non_ignore=False):
    """Softmax cross-entropy with optional sample and class weighting.

    Args:
        pred (torch.Tensor): Logits of shape (N, C), where C is the number
            of classes.
        label (torch.Tensor): Class-index targets.
        weight (torch.Tensor, optional): Sample-wise loss weight, cast to
            float before use.
        reduction (str, optional): "none", "mean" or "sum".
        avg_factor (int, optional): Normalizer used instead of the element
            count when averaging. Defaults to None.
        class_weight (torch.Tensor, optional): Per-class weight forwarded to
            ``F.cross_entropy``.
        ignore_index (int | None): Label value to ignore; None means -100,
            the ``F.cross_entropy`` default. Default: -100.
        avg_non_ignore (bool): With "mean" reduction and no ``avg_factor``,
            average only over non-ignored labels. Default: False.

    Returns:
        torch.Tensor: The calculated loss.
    """
    if ignore_index is None:
        ignore_index = -100

    per_element = F.cross_entropy(
        pred,
        label,
        weight=class_weight,
        reduction='none',
        ignore_index=ignore_index)

    # PyTorch's own 'mean' averages over non-ignored targets, see
    # https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
    # Reproduce that only when avg_non_ignore is requested.
    if avg_non_ignore and reduction == 'mean' and avg_factor is None:
        avg_factor = label.numel() - (label == ignore_index).sum().item()

    sample_weight = None if weight is None else weight.float()
    return weight_reduce_loss(
        per_element,
        weight=sample_weight,
        reduction=reduction,
        avg_factor=avg_factor)
def _expand_onehot_labels(labels, label_weights, label_channels, ignore_index):
"""Expand onehot labels to match the size of prediction."""
bin_labels = labels.new_full((labels.size(0), label_channels), 0)
valid_mask = (labels >= 0) & (labels != ignore_index)
inds = torch.nonzero(
valid_mask & (labels < label_channels), as_tuple=False)
if inds.numel() > 0:
bin_labels[inds, labels[inds]] = 1
valid_mask = valid_mask.view(-1, 1).expand(labels.size(0),
label_channels).float()
if label_weights is None:
bin_label_weights = valid_mask
else:
bin_label_weights = label_weights.view(-1, 1).repeat(1, label_channels)
bin_label_weights *= valid_mask
return bin_labels, bin_label_weights, valid_mask
def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None,
                         ignore_index=-100,
                         avg_non_ignore=False):
    """Sigmoid (binary) cross-entropy with ignore-index masking.

    Args:
        pred (torch.Tensor): Logits. When ``pred`` has a different number of
            dimensions than ``label``, ``label`` is treated as class indices
            and expanded to one-hot over ``pred.size(-1)`` columns; otherwise
            ``label`` is used element-wise.
        label (torch.Tensor): Targets (class indices or element-wise).
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): "none", "mean" or "sum".
        avg_factor (int, optional): Normalizer used instead of the element
            count when averaging. Defaults to None.
        class_weight (torch.Tensor, optional): Forwarded as ``pos_weight`` to
            ``F.binary_cross_entropy_with_logits``.
        ignore_index (int | None): Label value to ignore; None means -100.
            Default: -100.
        avg_non_ignore (bool): With "mean" reduction and no ``avg_factor``,
            average only over valid (non-ignored) entries. Default: False.

    Returns:
        torch.Tensor: The calculated loss.
    """
    ignore_index = -100 if ignore_index is None else ignore_index

    if pred.dim() == label.dim():
        # Element-wise labels: mask out negative / ignored entries.
        valid_mask = ((label >= 0) & (label != ignore_index)).float()
        # Out-of-place product: an in-place update would fail to broadcast
        # when shapes differ, e.g. weight (B, N, 1) vs. mask (B, N, C).
        weight = valid_mask if weight is None else weight * valid_mask
    else:
        # Class-index labels: expand labels and weights to one-hot columns.
        label, weight, valid_mask = _expand_onehot_labels(
            label, weight, pred.size(-1), ignore_index)

    if avg_factor is None and avg_non_ignore and reduction == 'mean':
        avg_factor = valid_mask.sum().item()

    per_element = F.binary_cross_entropy_with_logits(
        pred, label.float(), pos_weight=class_weight, reduction='none')
    return weight_reduce_loss(
        per_element, weight.float(), reduction=reduction, avg_factor=avg_factor)
def mask_cross_entropy(pred,
                       target,
                       label,
                       reduction='mean',
                       avg_factor=None,
                       class_weight=None,
                       ignore_index=None,
                       **kwargs):
    """Binary cross-entropy on the class-specific channel of mask logits.

    For each ROI ``i``, the channel ``pred[i, label[i]]`` is selected and
    compared with ``target[i]`` via ``F.binary_cross_entropy_with_logits``.

    Args:
        pred (torch.Tensor): Mask logits of shape (N, C, *), where C is the
            number of classes and * is an arbitrary trailing shape.
        target (torch.Tensor): Binary mask targets for the selected channels.
        label (torch.Tensor): Class index of each ROI, used to pick its
            channel from ``pred``.
        reduction (str, optional): Must be 'mean' (the only supported mode).
        avg_factor (int, optional): Must be None (reserved).
        class_weight (torch.Tensor, optional): Forwarded as ``weight`` to
            ``F.binary_cross_entropy_with_logits``.
        ignore_index (None): Placeholder for API parity; must be None.

    Returns:
        torch.Tensor: Loss of shape (1,).

    Example:
        >>> N, C = 3, 11
        >>> H, W = 2, 2
        >>> pred = torch.randn(N, C, H, W) * 1000
        >>> target = torch.rand(N, H, W)
        >>> label = torch.randint(0, C, size=(N,))
        >>> reduction = 'mean'
        >>> avg_factor = None
        >>> class_weights = None
        >>> loss = mask_cross_entropy(pred, target, label, reduction,
        ...                           avg_factor, class_weights)
        >>> assert loss.shape == (1,)
    """
    assert ignore_index is None, 'BCE loss does not support ignore_index'
    # TODO: handle these two reserved arguments
    assert reduction == 'mean' and avg_factor is None

    roi_ids = torch.arange(pred.size(0), dtype=torch.long, device=pred.device)
    # Pick, for every ROI, the mask channel of its own class.
    selected = pred[roi_ids, label].squeeze(1)
    loss = F.binary_cross_entropy_with_logits(
        selected, target, weight=class_weight, reduction='mean')
    return loss.unsqueeze(0)
class CrossEntropyLoss(nn.Module):
    """Configurable cross-entropy loss module.

    Dispatches to :func:`binary_cross_entropy` when ``use_sigmoid`` is set,
    to :func:`mask_cross_entropy` when ``use_mask`` is set, and to
    :func:`cross_entropy` otherwise.
    """

    def __init__(self,
                 use_sigmoid=False,
                 use_mask=False,
                 reduction='mean',
                 class_weight=None,
                 ignore_index=None,
                 loss_weight=1.0,
                 avg_non_ignore=False):
        """Configure the loss.

        Args:
            use_sigmoid (bool, optional): Use sigmoid (binary) rather than
                softmax cross-entropy. Defaults to False.
            use_mask (bool, optional): Use mask cross-entropy. Cannot be
                combined with ``use_sigmoid``. Defaults to False.
            reduction (str, optional): Default reduction, one of "none",
                "mean" or "sum". Defaults to 'mean'.
            class_weight (list[float], optional): Weight of each class,
                converted to a tensor on the prediction's device at call
                time. Defaults to None.
            ignore_index (int | None): Default label index to ignore.
                Defaults to None.
            loss_weight (float, optional): Scalar multiplier for the loss.
                Defaults to 1.0.
            avg_non_ignore (bool): Average the loss only over non-ignored
                targets. Default: False.
        """
        super().__init__()
        assert use_sigmoid is False or use_mask is False

        self.use_sigmoid = use_sigmoid
        self.use_mask = use_mask
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight
        self.ignore_index = ignore_index
        self.avg_non_ignore = avg_non_ignore

        # Ignoring a label while averaging over *all* labels differs from
        # PyTorch's own cross_entropy; point users at avg_non_ignore.
        mean_includes_ignored = (ignore_index is not None
                                 and not avg_non_ignore
                                 and reduction == 'mean')
        if mean_includes_ignored:
            warnings.warn(
                'Default ``avg_non_ignore`` is False, if you would like to '
                'ignore the certain label and average loss over non-ignore '
                'labels, which is the same with PyTorch official '
                'cross_entropy, set ``avg_non_ignore=True``.')

        if use_sigmoid:
            criterion = binary_cross_entropy
        elif use_mask:
            criterion = mask_cross_entropy
        else:
            criterion = cross_entropy
        self.cls_criterion = criterion

    def extra_repr(self):
        """Expose ``avg_non_ignore`` in the module's repr."""
        return f'avg_non_ignore={self.avg_non_ignore}'

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                ignore_index=None,
                **kwargs):
        """Compute the weighted loss.

        Args:
            cls_score (torch.Tensor): The prediction.
            label (torch.Tensor): The learning label of the prediction.
            weight (torch.Tensor, optional): Sample-wise loss weight.
            avg_factor (int, optional): Normalizer used instead of the
                element count when averaging. Defaults to None.
            reduction_override (str, optional): Replaces the configured
                reduction for this call; one of "none", "mean" or "sum".
            ignore_index (int | None): Overrides the configured ignore index
                when not None. Default: None.
            **kwargs: Forwarded to the selected criterion.

        Returns:
            torch.Tensor: ``loss_weight`` times the criterion's loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction
        effective_ignore = (self.ignore_index
                            if ignore_index is None else ignore_index)

        class_weight = None
        if self.class_weight is not None:
            class_weight = cls_score.new_tensor(
                self.class_weight, device=cls_score.device)

        raw_loss = self.cls_criterion(
            cls_score,
            label,
            weight,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            ignore_index=effective_ignore,
            avg_non_ignore=self.avg_non_ignore,
            **kwargs)
        return self.loss_weight * raw_loss