|
import torch |
|
import torch.nn.functional as F |
|
from segmentation_models_pytorch.utils import base |
|
from segmentation_models_pytorch.base.modules import Activation |
|
|
|
class FocalLossFunction(base.Loss):
    """Focal loss computed on top of ``F.cross_entropy``.

    For every element the cross-entropy ``ce`` is scaled by
    ``alpha * (1 - exp(-ce)) ** gamma``, which shrinks the contribution of
    elements whose target class already receives a high probability.

    Args:
        activation: Passed to ``Activation`` and stored as
            ``self.activation``. ``forward`` does not apply it.
        alpha: Weighting factor. With more than one input channel it is a
            single scalar applied to every element. With a single input
            channel, elements whose target is the first (given) channel get
            ``alpha`` and the remaining elements get ``1 - alpha``.
        gamma: Exponent of the ``(1 - probability)`` modulating term.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced per-element loss.
        **kwargs: Forwarded to ``base.Loss.__init__``.
    """

    def __init__(self, activation=None, alpha=0.25, gamma=1.5, reduction='mean', **kwargs):
        super().__init__(**kwargs)
        self.activation = Activation(activation)
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss between ``inputs`` and ``targets``.

        Args:
            inputs: Predictions with the class/channel axis at dim 1.
            targets: Tensor with the same channel layout as ``inputs``; it is
                converted to class indices with ``argmax`` over dim 1.

        Returns:
            A scalar tensor for ``'mean'``/``'sum'`` reduction, otherwise the
            per-element loss tensor.
        """
        # Decide this BEFORE widening the channel axis: afterwards the input
        # always has two channels and the binary case can no longer be told
        # apart from a genuine two-class input.
        is_binary = inputs.shape[1] == 1
        if is_binary:
            # NOTE(review): ``1 - inputs`` is a complement for probabilities,
            # while F.cross_entropy treats its input as unnormalised logits.
            # Kept as in the original -- confirm this is the intended pairing.
            inputs = torch.cat((inputs, 1 - inputs), dim=1)
            targets = torch.cat((targets, 1 - targets), dim=1)

        targets = torch.argmax(targets, dim=1)
        cross_entropy = F.cross_entropy(inputs, targets, reduction='none')
        # ce = -log(p_target), so exp(-ce) recovers the target-class probability.
        probability = torch.exp(-cross_entropy)

        if is_binary:
            # Index 1 is the appended ``1 - targets`` channel, i.e. (for 0/1
            # targets) elements whose original target was 0.
            # Tensor operands keep torch.where usable on releases that do not
            # accept Python scalars for both branches.
            alpha_factor = torch.where(
                targets == 1,
                torch.full_like(cross_entropy, 1 - self.alpha),
                torch.full_like(cross_entropy, self.alpha),
            )
        else:
            alpha_factor = self.alpha

        focal_weight = alpha_factor * (1 - probability) ** self.gamma * cross_entropy

        if self.reduction == 'mean':
            return focal_weight.mean()
        if self.reduction == 'sum':
            return focal_weight.sum()
        return focal_weight
|
|
|
class TverskyLossFunction(base.Loss):
    """Tversky loss accumulated over the channels of the prediction.

    For each channel the index
    ``TP / (TP + alpha * FP + beta * FN + 1e-10)`` is computed over the whole
    batch, and ``1 - index`` is added to the running total.

    Args:
        activation: Passed to ``Activation`` and stored as
            ``self.activation``. ``forward`` does not apply it; it uses
            sigmoid (one channel) or softmax over dim 1 (several channels).
        alpha: Weight of the false-positive term.
        beta: Weight of the false-negative term.
        ignore_channels: Channel indices to exclude, or ``None``.
        reduction: ``'mean'`` divides the total by the number of channels,
            ``'sum'`` returns the total, and any other value divides the
            total by ``inputs.shape[0]``.
        **kwargs: Forwarded to ``base.Loss.__init__``.
    """

    def __init__(self, activation=None, alpha=0.5, beta=0.5, ignore_channels=None,
                 reduction='mean', **kwargs):
        super().__init__(**kwargs)
        self.activation = Activation(activation)
        self.alpha = alpha
        self.beta = beta
        self.ignore_channels = ignore_channels
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the Tversky loss between ``inputs`` and ``targets``.

        Args:
            inputs: Raw predictions with the channel axis at dim 1.
            targets: Ground truth; indexed per channel on dim 1 when more
                than one channel is evaluated, flattened whole otherwise.

        Returns:
            The reduced loss (see ``reduction`` on the class).
        """
        if self.ignore_channels is not None:
            keep = torch.ones(inputs.shape[1], dtype=torch.bool, device=inputs.device)
            keep[self.ignore_channels] = False
            # Drop the same channels from the targets. Otherwise channel i of
            # the filtered predictions is scored against channel i of the
            # unfiltered targets, which is a different class.
            if targets.dim() == inputs.dim() and targets.shape[1] == inputs.shape[1]:
                targets = targets[:, keep.to(targets.device), ...]
            inputs = inputs[:, keep, ...]

        num_classes = inputs.shape[1]
        is_binary = num_classes == 1

        if is_binary:
            probabilities = torch.sigmoid(inputs).squeeze(1)
            # squeeze(1) is a no-op when dim 1 is not of size 1.
            targets = targets.squeeze(1)
        else:
            probabilities = F.softmax(inputs, dim=1)

        tversky_loss = 0
        for class_idx in range(num_classes):
            if is_binary:
                flat_inputs = probabilities.reshape(-1)
                flat_targets = targets.reshape(-1)
            else:
                flat_inputs = probabilities[:, class_idx].reshape(-1)
                flat_targets = targets[:, class_idx].reshape(-1)

            intersection = (flat_inputs * flat_targets).sum()
            fps = ((1 - flat_targets) * flat_inputs).sum()
            fns = (flat_targets * (1 - flat_inputs)).sum()

            # The small constant keeps the ratio finite when every term is zero.
            denominator = intersection + self.alpha * fps + self.beta * fns + 1e-10
            tversky_loss += 1 - intersection / denominator

        if self.reduction == 'mean':
            return tversky_loss / num_classes
        if self.reduction == 'sum':
            return tversky_loss
        return tversky_loss / inputs.shape[0]
|
|
|
class EnhancedCrossEntropy(base.Loss):
    """Cross-entropy loss with an optional activation and ignorable channels.

    Args:
        activation: Passed to ``Activation``; the resulting module is applied
            to ``inputs`` before the loss is computed.
        ignore_channels: Channel indices to exclude, or ``None``.
        reduction: Forwarded to ``F.cross_entropy``.
        **kwargs: Forwarded to ``base.Loss.__init__``.
    """

    def __init__(self, activation=None, ignore_channels=None, reduction='mean', **kwargs):
        super().__init__(**kwargs)
        self.activation = Activation(activation)
        self.ignore_channels = ignore_channels
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the cross-entropy between ``inputs`` and ``targets``.

        Args:
            inputs: Predictions with the class/channel axis at dim 1.
            targets: A 4-D tensor is reduced to class indices with ``argmax``
                over dim 1; anything else is used as given.

        Returns:
            The output of ``F.cross_entropy`` with ``self.reduction``.
        """
        # Default ``ignore_index`` of F.cross_entropy, spelled out because the
        # remapping below relies on it.
        ignore_index = -100

        inputs = self.activation(inputs)

        if targets.dim() == 4:
            targets = torch.argmax(targets, dim=1)

        if self.ignore_channels is not None:
            num_channels = inputs.shape[1]
            keep = torch.ones(num_channels, dtype=torch.bool, device=inputs.device)
            keep[self.ignore_channels] = False
            inputs = inputs[:, keep, ...]

            # Class indices still refer to the unfiltered channel axis. Map
            # them onto the surviving channels and mark ignored classes with
            # ``ignore_index`` so they line up with the filtered predictions
            # instead of being mislabelled or falling out of range.
            # Floating-point targets are left untouched, as before.
            if not torch.is_floating_point(targets):
                labels = targets.long()
                remap = torch.full(
                    (num_channels,), ignore_index, dtype=torch.long, device=labels.device
                )
                remap[keep.to(labels.device)] = torch.arange(
                    int(keep.sum()), device=labels.device
                )
                # Negative labels (e.g. already-ignored elements) pass through.
                targets = torch.where(labels >= 0, remap[labels.clamp(min=0)], labels)

        return F.cross_entropy(
            inputs, targets, ignore_index=ignore_index, reduction=self.reduction
        )