File size: 7,930 Bytes
57746f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 |
"""
Misc Losses
Author: Xiaoyang Wu ([email protected])
Please cite our work if the code is helpful to you.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .builder import LOSSES
@LOSSES.register_module()
class CrossEntropyLoss(nn.Module):
    """Thin wrapper around ``nn.CrossEntropyLoss`` with a global loss weight."""

    def __init__(
        self,
        weight=None,
        size_average=None,
        reduce=None,
        reduction="mean",
        label_smoothing=0.0,
        loss_weight=1.0,
        ignore_index=-1,
    ):
        """
        Args:
            weight (sequence, optional): Per-class rescaling weights of length C.
            size_average, reduce: Deprecated torch arguments, forwarded as-is.
            reduction (str): "none" | "mean" | "sum".
            label_smoothing (float): Smoothing factor in [0, 1].
            loss_weight (float): Multiplier applied to the final loss value.
            ignore_index (int): Target value ignored by the criterion.
        """
        super(CrossEntropyLoss, self).__init__()
        # Build the class-weight tensor on CPU. The previous code called
        # .cuda() unconditionally, which crashes on CPU-only machines and
        # pins the tensor to the default GPU; it is now moved to the
        # prediction's device lazily in forward().
        weight = torch.tensor(weight) if weight is not None else None
        self.loss_weight = loss_weight
        self.loss = nn.CrossEntropyLoss(
            weight=weight,
            size_average=size_average,
            ignore_index=ignore_index,
            reduce=reduce,
            reduction=reduction,
            label_smoothing=label_smoothing,
        )

    def forward(self, pred, target):
        # Keep the (optional) class weight on the same device as the logits
        # so the criterion works on CPU and on any GPU alike.
        if self.loss.weight is not None and self.loss.weight.device != pred.device:
            self.loss.weight = self.loss.weight.to(pred.device)
        return self.loss(pred, target) * self.loss_weight
@LOSSES.register_module()
class SmoothCELoss(nn.Module):
    """Cross-entropy with uniform label smoothing."""

    def __init__(self, smoothing_ratio=0.1):
        """
        Args:
            smoothing_ratio (float): Probability mass (eps) spread uniformly
                over the C - 1 non-target classes.
        """
        super(SmoothCELoss, self).__init__()
        self.smoothing_ratio = smoothing_ratio

    def forward(self, pred, target):
        """
        Args:
            pred (torch.Tensor): Logits of shape (N, C).
            target (torch.Tensor): Class indices of shape (N).
        Returns:
            torch.Tensor: Scalar smoothed cross-entropy loss.
        """
        eps = self.smoothing_ratio
        n_class = pred.size(1)
        # Smoothed one-hot target: 1 - eps on the true class and
        # eps / (C - 1) on every other class.
        one_hot = torch.zeros_like(pred).scatter(1, target.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(pred, dim=1)
        # BUG FIX: torch.Tensor has no .total(); the per-sample loss must be
        # reduced over the class dimension with .sum(dim=1).
        loss = -(one_hot * log_prb).sum(dim=1)
        # Average only over finite entries to shield against inf/nan logits.
        loss = loss[torch.isfinite(loss)].mean()
        return loss
@LOSSES.register_module()
class BinaryFocalLoss(nn.Module):
    def __init__(self, gamma=2.0, alpha=0.5, logits=True, reduce=True, loss_weight=1.0):
        """Binary Focal Loss
        <https://arxiv.org/abs/1708.02002>`
        """
        super(BinaryFocalLoss, self).__init__()
        assert 0 < alpha < 1
        self.gamma = gamma
        self.alpha = alpha
        self.logits = logits
        self.reduce = reduce
        self.loss_weight = loss_weight

    def forward(self, pred, target, **kwargs):
        """Forward function.
        Args:
            pred (torch.Tensor): The prediction with shape (N)
            target (torch.Tensor): The ground truth. If containing class
                indices, shape (N) where each value is 0≤targets[i]≤1, If containing class probabilities,
                same shape as the input.
        Returns:
            torch.Tensor: The calculated loss
        """
        # Element-wise BCE; use the logits variant when raw scores are given.
        if self.logits:
            raw_bce = F.binary_cross_entropy_with_logits(
                pred, target, reduction="none"
            )
        else:
            raw_bce = F.binary_cross_entropy(pred, target, reduction="none")
        # p_t: probability the model assigns to the correct label per element.
        p_t = torch.exp(-raw_bce)
        # Class-balance factor: alpha for positives, 1 - alpha for negatives.
        balance = self.alpha * target + (1 - self.alpha) * (1 - target)
        # Down-weight easy examples by (1 - p_t) ** gamma.
        focal = balance * (1 - p_t) ** self.gamma * raw_bce
        if self.reduce:
            focal = focal.mean()
        return focal * self.loss_weight
@LOSSES.register_module()
class FocalLoss(nn.Module):
    def __init__(
        self, gamma=2.0, alpha=0.5, reduction="mean", loss_weight=1.0, ignore_index=-1
    ):
        """Focal Loss
        <https://arxiv.org/abs/1708.02002>`

        Args:
            gamma (float): Focusing parameter; larger values down-weight easy
                examples more strongly.
            alpha (float | list): Class-balance factor; a list gives one
                factor per class.
            reduction (str): "mean" or "sum".
            loss_weight (float): Multiplier applied to the final loss value.
            ignore_index (int): Target value excluded from the loss.
        """
        super(FocalLoss, self).__init__()
        assert reduction in (
            "mean",
            "sum",
        ), "AssertionError: reduction should be 'mean' or 'sum'"
        assert isinstance(
            alpha, (float, list)
        ), "AssertionError: alpha should be of type float"
        assert isinstance(gamma, float), "AssertionError: gamma should be of type float"
        assert isinstance(
            loss_weight, float
        ), "AssertionError: loss_weight should be of type float"
        assert isinstance(ignore_index, int), "ignore_index must be of type int"
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.ignore_index = ignore_index

    def forward(self, pred, target, **kwargs):
        """Forward function.
        Args:
            pred (torch.Tensor): The prediction with shape (N, C) where C = number of classes.
            target (torch.Tensor): The ground truth. If containing class
                indices, shape (N) where each value is 0≤targets[i]≤C−1, If containing class probabilities,
                same shape as the input.
        Returns:
            torch.Tensor: The calculated loss
        """
        # [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k]
        pred = pred.transpose(0, 1)
        # [C, B, d_1, d_2, ..., d_k] -> [C, N]
        pred = pred.reshape(pred.size(0), -1)
        # [C, N] -> [N, C]
        pred = pred.transpose(0, 1).contiguous()
        # (B, d_1, d_2, ..., d_k) --> (B * d_1 * d_2 * ... * d_k,)
        target = target.view(-1).contiguous()
        assert pred.size(0) == target.size(
            0
        ), "The shape of pred doesn't match the shape of target"
        # Drop ignored positions before computing the loss.
        valid_mask = target != self.ignore_index
        target = target[valid_mask]
        pred = pred[valid_mask]
        if len(target) == 0:
            # BUG FIX: return a zero tensor attached to the graph instead of a
            # Python float 0.0, so .backward() still works when every target
            # position is ignored.
            return pred.sum() * 0.0
        num_classes = pred.size(1)
        target = F.one_hot(target, num_classes=num_classes)
        alpha = self.alpha
        if isinstance(alpha, list):
            alpha = pred.new_tensor(alpha)
        # Per-class sigmoid focal loss (binary formulation over one-hot targets).
        pred_sigmoid = pred.sigmoid()
        target = target.type_as(pred)
        # 1 - p_t: small for well-classified entries, large for hard ones.
        one_minus_pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
        focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * one_minus_pt.pow(
            self.gamma
        )
        loss = (
            F.binary_cross_entropy_with_logits(pred, target, reduction="none")
            * focal_weight
        )
        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            # BUG FIX: torch.Tensor has no .total(); reduce with .sum().
            loss = loss.sum()
        return self.loss_weight * loss
@LOSSES.register_module()
class DiceLoss(nn.Module):
    def __init__(self, smooth=1, exponent=2, loss_weight=1.0, ignore_index=-1):
        """DiceLoss.
        This loss is proposed in `V-Net: Fully Convolutional Neural Networks for
        Volumetric Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.
        """
        super(DiceLoss, self).__init__()
        self.smooth = smooth
        self.exponent = exponent
        self.loss_weight = loss_weight
        self.ignore_index = ignore_index

    def forward(self, pred, target, **kwargs):
        """Compute the mean per-class dice loss over all valid positions.

        pred: [B, C, d_1, ..., d_k] class scores; target: matching index map.
        """
        # Flatten pred to [N, C]: move C to the front, collapse the rest,
        # then transpose back so rows are positions and columns are classes.
        pred = pred.transpose(0, 1)
        pred = pred.reshape(pred.size(0), -1)
        pred = pred.transpose(0, 1).contiguous()
        # Collapse target to a flat index vector of length N.
        target = target.view(-1).contiguous()
        assert pred.size(0) == target.size(
            0
        ), "The shape of pred doesn't match the shape of target"
        # Keep only positions whose label is not the ignore index.
        keep = target != self.ignore_index
        target = target[keep]
        pred = pred[keep]
        # Scores -> per-class probabilities.
        probs = F.softmax(pred, dim=1)
        num_classes = probs.shape[1]
        onehot = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1), num_classes=num_classes
        )
        total_loss = 0
        for cls in range(num_classes):
            if cls == self.ignore_index:
                continue
            # Dice coefficient for this class, smoothed to avoid 0/0.
            numer = 2 * torch.sum(probs[:, cls] * onehot[:, cls]) + self.smooth
            denom = (
                torch.sum(
                    probs[:, cls].pow(self.exponent)
                    + onehot[:, cls].pow(self.exponent)
                )
                + self.smooth
            )
            total_loss = total_loss + (1 - numer / denom)
        return self.loss_weight * (total_loss / num_classes)
|