Spaces:
Runtime error
Runtime error
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch
import torch.nn as nn
class WeightedMSELoss(nn.Module):
    r"""Compute a weighted mean-squared-error loss.

    The squared difference between ``input`` and ``target`` is multiplied
    element-wise by ``weight`` and then reduced to a single value.

    Args:
        reduction (str): ``'mean'`` averages the weighted squared errors.
            Any other value sums them (there is no unreduced mode).
    """

    def __init__(self, reduction='mean'):
        super(WeightedMSELoss, self).__init__()
        self.reduction = reduction

    def forward(self, input, target, weight):
        r"""Return weighted MSE Loss.

        Args:
            input (tensor): Predicted values.
            target (tensor): Reference values; combined with ``input`` by
                element-wise subtraction.
            weight (tensor): Multiplier applied to the squared error;
                combined with it by element-wise multiplication.
        Returns:
            (tensor): Loss value, reduced over all elements.
        """
        # Compute the weighted squared error once; only the reduction differs.
        weighted_sq_err = weight * (input - target) ** 2
        if self.reduction == 'mean':
            return torch.mean(weighted_sq_err)
        # Every non-'mean' setting falls back to a sum, as callers expect.
        return torch.sum(weighted_sq_err)