|
import torch |
|
from torch import nn as nn |
|
from torch.nn import functional as F |
|
import numpy as np |
|
|
|
from basicsr.models.losses.loss_util import weighted_loss |
|
|
|
_reduction_modes = ['none', 'mean', 'sum'] |
|
|
|
|
|
@weighted_loss |
|
def l1_loss(pred, target): |
|
return F.l1_loss(pred, target, reduction='none') |
|
|
|
|
|
@weighted_loss |
|
def mse_loss(pred, target): |
|
return F.mse_loss(pred, target, reduction='none') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class L1Loss(nn.Module): |
|
"""L1 (mean absolute error, MAE) loss. |
|
|
|
Args: |
|
loss_weight (float): Loss weight for L1 loss. Default: 1.0. |
|
reduction (str): Specifies the reduction to apply to the output. |
|
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. |
|
""" |
|
|
|
def __init__(self, loss_weight=1.0, reduction='mean'): |
|
super(L1Loss, self).__init__() |
|
if reduction not in ['none', 'mean', 'sum']: |
|
raise ValueError(f'Unsupported reduction mode: {reduction}. ' |
|
f'Supported ones are: {_reduction_modes}') |
|
|
|
self.loss_weight = loss_weight |
|
self.reduction = reduction |
|
|
|
def forward(self, pred, target, weight=None, **kwargs): |
|
""" |
|
Args: |
|
pred (Tensor): of shape (N, C, H, W). Predicted tensor. |
|
target (Tensor): of shape (N, C, H, W). Ground truth tensor. |
|
weight (Tensor, optional): of shape (N, C, H, W). Element-wise |
|
weights. Default: None. |
|
""" |
|
return self.loss_weight * l1_loss( |
|
pred, target, weight, reduction=self.reduction) |
|
|
|
class MSELoss(nn.Module): |
|
"""MSE (L2) loss. |
|
|
|
Args: |
|
loss_weight (float): Loss weight for MSE loss. Default: 1.0. |
|
reduction (str): Specifies the reduction to apply to the output. |
|
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. |
|
""" |
|
|
|
def __init__(self, loss_weight=1.0, reduction='mean'): |
|
super(MSELoss, self).__init__() |
|
if reduction not in ['none', 'mean', 'sum']: |
|
raise ValueError(f'Unsupported reduction mode: {reduction}. ' |
|
f'Supported ones are: {_reduction_modes}') |
|
|
|
self.loss_weight = loss_weight |
|
self.reduction = reduction |
|
|
|
def forward(self, pred, target, weight=None, **kwargs): |
|
""" |
|
Args: |
|
pred (Tensor): of shape (N, C, H, W). Predicted tensor. |
|
target (Tensor): of shape (N, C, H, W). Ground truth tensor. |
|
weight (Tensor, optional): of shape (N, C, H, W). Element-wise |
|
weights. Default: None. |
|
""" |
|
return self.loss_weight * mse_loss( |
|
pred, target, weight, reduction=self.reduction) |
|
|
|
class PSNRLoss(nn.Module): |
|
|
|
def __init__(self, loss_weight=1.0, reduction='mean', toY=False): |
|
super(PSNRLoss, self).__init__() |
|
assert reduction == 'mean' |
|
self.loss_weight = loss_weight |
|
self.scale = 10 / np.log(10) |
|
self.toY = toY |
|
self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1) |
|
self.first = True |
|
|
|
def forward(self, pred, target): |
|
assert len(pred.size()) == 4 |
|
if self.toY: |
|
if self.first: |
|
self.coef = self.coef.to(pred.device) |
|
self.first = False |
|
|
|
pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. |
|
target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. |
|
|
|
pred, target = pred / 255., target / 255. |
|
pass |
|
assert len(pred.size()) == 4 |
|
|
|
return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean() |
|
|
|
class CharbonnierLoss(nn.Module): |
|
"""Charbonnier Loss (L1)""" |
|
|
|
def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-3): |
|
super(CharbonnierLoss, self).__init__() |
|
self.eps = eps |
|
|
|
def forward(self, x, y): |
|
diff = x - y |
|
|
|
loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps))) |
|
return loss |
|
|