Spaces:
Sleeping
Sleeping
| from itertools import zip_longest | |
| import torch | |
class MultitaskLoss(torch.nn.Module):
    """A generic multitask loss that applies one loss function per task.

    Args:
        loss_fns: sequence of loss callables, one per task. Task ``i`` is
            scored as ``loss_fns[i](preds[i], target[i])``, or as
            ``loss_fns[i](preds[i])`` when no target is supplied for it.
        reduction: how the per-task losses are combined -- ``'sum'``
            (default) adds them, ``'mean'`` divides that sum by the number of
            tasks, and ``'none'`` returns them unreduced as a tuple.

    Raises:
        ValueError: if ``reduction`` is not one of the values above.
    """

    _REDUCTIONS = ('sum', 'mean', 'none')

    def __init__(self, loss_fns, reduction='sum'):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        # The number of tasks is taken to be the number of loss functions.
        self.n_tasks = len(loss_fns)
        # NOTE(review): stored as a plain attribute, so any nn.Module losses in
        # here are not registered as submodules (their buffers will not follow
        # .to(device) or appear in state_dict) -- confirm that is intended.
        self.loss_fns = loss_fns
        self.reduction = reduction

    def forward(self, preds, target):
        """Compute every task's loss and combine them per ``self.reduction``.

        Args:
            preds: a tensor (single task) or a sequence of per-task predictions.
            target: a tensor (single task), a sequence of per-task targets, or
                ``None`` for no targets at all. Tasks beyond the end of
                ``target`` are scored without a target.

        Returns:
            The summed or mean loss, or a tuple of per-task losses when
            ``reduction == 'none'``.

        Raises:
            ValueError: if ``self.reduction`` holds an unsupported value.
        """
        if isinstance(preds, torch.Tensor):
            preds = (preds,)
        if target is None:
            target = ()
        elif isinstance(target, torch.Tensor):
            target = (target,)
        # zip_longest pads missing targets with None, so target-free tasks
        # (e.g. regularisers) can be mixed with supervised ones.
        losses = [
            loss_fn(p) if t is None else loss_fn(p, t)
            for loss_fn, p, t in zip_longest(self.loss_fns, preds, target)
        ]
        if self.reduction == 'sum':
            return sum(losses)
        if self.reduction == 'mean':
            return sum(losses) / self.n_tasks
        if self.reduction == 'none':
            return tuple(losses)
        raise ValueError(f"unsupported reduction {self.reduction!r}")
class MultitaskWeightedLoss(MultitaskLoss):
    """A multitask loss that scales each task's loss by a fixed weight.

    Args:
        loss_fns: sequence of loss callables, one per task.
        weights: sequence of per-task multipliers, aligned with ``loss_fns``.
        reduction: ``'sum'`` (default) adds the weighted losses, ``'mean'``
            divides that sum by the number of tasks, and ``'none'`` returns
            the weighted per-task losses as a tuple.
    """

    def __init__(self, loss_fns, weights, reduction='sum'):
        super().__init__(loss_fns, reduction)
        self.weights = weights  # one multiplier per task

    def forward(self, preds, target):
        """Compute every task's weighted loss and combine them.

        Args:
            preds: a tensor (single task) or a sequence of per-task predictions.
            target: a tensor (single task), a sequence of per-task targets, or
                ``None`` for no targets at all. Tasks beyond the end of
                ``target`` are scored without a target.

        Returns:
            The summed or mean weighted loss, or a tuple of weighted per-task
            losses when ``reduction == 'none'``.

        Raises:
            ValueError: if ``self.reduction`` holds an unsupported value.
        """
        if isinstance(preds, torch.Tensor):
            preds = (preds,)
        if target is None:
            target = ()
        elif isinstance(target, torch.Tensor):
            target = (target,)
        # zip_longest pads missing targets with None, so target-free tasks
        # are called as loss_fn(pred).
        losses = [
            weight * (loss_fn(p) if t is None else loss_fn(p, t))
            for weight, loss_fn, p, t in zip_longest(
                self.weights, self.loss_fns, preds, target
            )
        ]
        if self.reduction == 'sum':
            return sum(losses)
        if self.reduction == 'mean':
            return sum(losses) / self.n_tasks
        if self.reduction == 'none':
            return tuple(losses)
        raise ValueError(f"unsupported reduction {self.reduction!r}")
class MultitaskUncertaintyLoss(MultitaskLoss):
    """Uncertainty-weighted multitask loss.

    Modified from https://arxiv.org/abs/1705.07115.
    Removed task-specific scale factor for flexibility.

    Each task ``i`` has a learnable log-variance ``s_i`` (initialised to 0),
    and its weighted loss is ``exp(-s_i) * L_i + s_i / 2``.

    Args:
        loss_fns: sequence of loss callables, one per task. Each is assumed
            to return a scalar tensor -- TODO confirm for all callers.
    """

    def __init__(self, loss_fns):
        super().__init__(loss_fns, reduction='none')
        # One learnable log-variance per task; zeros give unit initial weights.
        self.log_vars = torch.nn.Parameter(torch.zeros(self.n_tasks, requires_grad=True))

    def forward(self, preds, targets, rescale=True):
        """Return the uncertainty-weighted per-task losses.

        Args:
            preds: a tensor (single task) or a sequence of per-task predictions.
            targets: a tensor (single task), a sequence of per-task targets,
                or ``None``. Tasks beyond the end of ``targets`` are scored
                without a target.
            rescale: currently unused; kept for backward compatibility.

        Returns:
            Tensor of per-task weighted losses (one entry per task, not
            reduced); callers must reduce it, e.g. ``.sum()``, before
            ``backward()``.
        """
        if isinstance(preds, torch.Tensor):
            preds = (preds,)
        if targets is None:
            targets = ()
        elif isinstance(targets, torch.Tensor):
            targets = (targets,)
        # The per-task losses are needed as one tensor so they can be combined
        # element-wise with the per-task log-variances.
        losses = torch.stack([
            loss_fn(p) if t is None else loss_fn(p, t)
            for loss_fn, p, t in zip_longest(self.loss_fns, preds, targets)
        ])
        # With std = exp(s / 2): 1 / std**2 == exp(-s) and log(std) == s / 2.
        # The direct forms avoid exp/log overflow for large |s|.
        return torch.exp(-self.log_vars) * losses + 0.5 * self.log_vars
class MultitaskAutomaticWeightedLoss(MultitaskLoss):
    """Automatically weighted multitask loss.

    Each task ``i`` has a learnable parameter ``p_i`` (initialised to 1), and
    the total loss is ``sum_i 0.5 / p_i**2 * L_i + log(1 + p_i**2)``.

    Params:
        loss_fns: tuple of loss functions, one per task.

    Examples:
        awl = MultitaskAutomaticWeightedLoss((loss_fn1, loss_fn2))
        loss_sum = awl((pred1, pred2), (target1, target2))
    """

    def __init__(self, loss_fns):
        super().__init__(loss_fns, reduction='none')
        # One learnable weighting parameter per task.
        self.params = torch.nn.Parameter(torch.ones(self.n_tasks, requires_grad=True))

    def forward(self, preds, target):
        """Return the automatically weighted sum of the per-task losses.

        Args:
            preds: a tensor (single task) or a sequence of per-task predictions.
            target: a tensor (single task), a sequence of per-task targets, or
                ``None``. Tasks beyond the end of ``target`` are scored
                without a target.

        Returns:
            The weighted total loss summed over tasks. Pairing with
            ``self.params`` uses ``zip``, so it stops at the shorter of the
            two.
        """
        if isinstance(preds, torch.Tensor):
            preds = (preds,)
        if target is None:
            target = ()
        elif isinstance(target, torch.Tensor):
            target = (target,)
        task_losses = [
            loss_fn(p) if t is None else loss_fn(p, t)
            for loss_fn, p, t in zip_longest(self.loss_fns, preds, target)
        ]
        # log(1 + p**2) regularises p so the weights cannot collapse to zero.
        return sum(
            0.5 / (param ** 2) * task_loss + torch.log(1 + param ** 2)
            for param, task_loss in zip(self.params, task_losses)
        )