File size: 875 Bytes
f670afc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch
import torch.nn as nn


class WeightedMSELoss(nn.Module):
    r"""Squared-error loss with an element-wise weighting term.

    Args:
        reduction (str): ``'mean'`` averages the weighted squared errors
            over all elements; any other value sums them.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, input, target, weight):
        r"""Reduce ``weight * (input - target) ** 2`` to a single value.

        Args:
           input (tensor): Values compared against ``target``.
           target (tensor): Values subtracted from ``input``.
           weight (tensor): Multiplier applied to the squared difference.
        Returns:
           (tensor): Mean of the weighted squared errors when
           ``self.reduction == 'mean'``, otherwise their sum.
        """
        weighted_sq_err = weight * (input - target) ** 2
        # Anything that is not exactly 'mean' falls back to a sum.
        reduce_fn = torch.mean if self.reduction == 'mean' else torch.sum
        return reduce_fn(weighted_sq_err)