|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | """ | 
					
						
						|  | @Author  :   Peike Li | 
					
						
						|  | @Contact :   [email protected] | 
					
						
						|  | @File    :   criterion.py | 
					
						
						|  | @Time    :   8/30/19 8:59 PM | 
					
						
						|  | @Desc    : | 
					
						
						|  | @License :   This source code is licensed under the license found in the | 
					
						
						|  | LICENSE file in the root directory of this source tree. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | import torch.nn as nn | 
					
						
						|  | import torch | 
					
						
						|  | import numpy as np | 
					
						
						|  | from torch.nn import functional as F | 
					
						
						|  | from .lovasz_softmax import LovaszSoftmax | 
					
						
						|  | from .kl_loss import KLDivergenceLoss | 
					
						
						|  | from .consistency_loss import ConsistencyLoss | 
					
						
						|  |  | 
					
						
						|  | NUM_CLASSES = 20 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class CriterionAll(nn.Module): | 
					
						
						|  | def __init__(self, use_class_weight=False, ignore_index=255, lambda_1=1, lambda_2=1, lambda_3=1, | 
					
						
						|  | num_classes=20): | 
					
						
						|  | super(CriterionAll, self).__init__() | 
					
						
						|  | self.ignore_index = ignore_index | 
					
						
						|  | self.use_class_weight = use_class_weight | 
					
						
						|  | self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_index) | 
					
						
						|  | self.lovasz = LovaszSoftmax(ignore_index=ignore_index) | 
					
						
						|  | self.kldiv = KLDivergenceLoss(ignore_index=ignore_index) | 
					
						
						|  | self.reg = ConsistencyLoss(ignore_index=ignore_index) | 
					
						
						|  | self.lamda_1 = lambda_1 | 
					
						
						|  | self.lamda_2 = lambda_2 | 
					
						
						|  | self.lamda_3 = lambda_3 | 
					
						
						|  | self.num_classes = num_classes | 
					
						
						|  |  | 
					
						
						|  | def parsing_loss(self, preds, target, cycle_n=None): | 
					
						
						|  | """ | 
					
						
						|  | Loss function definition. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | preds: [[parsing result1, parsing result2],[edge result]] | 
					
						
						|  | target: [parsing label, egde label] | 
					
						
						|  | soft_preds: [[parsing result1, parsing result2],[edge result]] | 
					
						
						|  | Returns: | 
					
						
						|  | Calculated Loss. | 
					
						
						|  | """ | 
					
						
						|  | h, w = target[0].size(1), target[0].size(2) | 
					
						
						|  |  | 
					
						
						|  | pos_num = torch.sum(target[1] == 1, dtype=torch.float) | 
					
						
						|  | neg_num = torch.sum(target[1] == 0, dtype=torch.float) | 
					
						
						|  |  | 
					
						
						|  | weight_pos = neg_num / (pos_num + neg_num) | 
					
						
						|  | weight_neg = pos_num / (pos_num + neg_num) | 
					
						
						|  | weights = torch.tensor([weight_neg, weight_pos]) | 
					
						
						|  |  | 
					
						
						|  | loss = 0 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | preds_parsing = preds[0] | 
					
						
						|  | for pred_parsing in preds_parsing: | 
					
						
						|  | scale_pred = F.interpolate(input=pred_parsing, size=(h, w), | 
					
						
						|  | mode='bilinear', align_corners=True) | 
					
						
						|  |  | 
					
						
						|  | loss += 0.5 * self.lamda_1 * self.lovasz(scale_pred, target[0]) | 
					
						
						|  | if target[2] is None: | 
					
						
						|  | loss += 0.5 * self.lamda_1 * self.criterion(scale_pred, target[0]) | 
					
						
						|  | else: | 
					
						
						|  | soft_scale_pred = F.interpolate(input=target[2], size=(h, w), | 
					
						
						|  | mode='bilinear', align_corners=True) | 
					
						
						|  | soft_scale_pred = moving_average(soft_scale_pred, to_one_hot(target[0], num_cls=self.num_classes), | 
					
						
						|  | 1.0 / (cycle_n + 1.0)) | 
					
						
						|  | loss += 0.5 * self.lamda_1 * self.kldiv(scale_pred, soft_scale_pred, target[0]) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | preds_edge = preds[1] | 
					
						
						|  | for pred_edge in preds_edge: | 
					
						
						|  | scale_pred = F.interpolate(input=pred_edge, size=(h, w), | 
					
						
						|  | mode='bilinear', align_corners=True) | 
					
						
						|  | if target[3] is None: | 
					
						
						|  | loss += self.lamda_2 * F.cross_entropy(scale_pred, target[1], | 
					
						
						|  | weights.cuda(), ignore_index=self.ignore_index) | 
					
						
						|  | else: | 
					
						
						|  | soft_scale_edge = F.interpolate(input=target[3], size=(h, w), | 
					
						
						|  | mode='bilinear', align_corners=True) | 
					
						
						|  | soft_scale_edge = moving_average(soft_scale_edge, to_one_hot(target[1], num_cls=2), | 
					
						
						|  | 1.0 / (cycle_n + 1.0)) | 
					
						
						|  | loss += self.lamda_2 * self.kldiv(scale_pred, soft_scale_edge, target[0]) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | preds_parsing = preds[0] | 
					
						
						|  | preds_edge = preds[1] | 
					
						
						|  | for pred_parsing in preds_parsing: | 
					
						
						|  | scale_pred = F.interpolate(input=pred_parsing, size=(h, w), | 
					
						
						|  | mode='bilinear', align_corners=True) | 
					
						
						|  | scale_edge = F.interpolate(input=preds_edge[0], size=(h, w), | 
					
						
						|  | mode='bilinear', align_corners=True) | 
					
						
						|  | loss += self.lamda_3 * self.reg(scale_pred, scale_edge, target[0]) | 
					
						
						|  |  | 
					
						
						|  | return loss | 
					
						
						|  |  | 
					
						
						|  | def forward(self, preds, target, cycle_n=None): | 
					
						
						|  | loss = self.parsing_loss(preds, target, cycle_n) | 
					
						
						|  | return loss | 
					
						
						|  |  | 
					
						
						|  | def _generate_weights(self, masks, num_classes): | 
					
						
						|  | """ | 
					
						
						|  | masks: torch.Tensor with shape [B, H, W] | 
					
						
						|  | """ | 
					
						
						|  | masks_label = masks.data.cpu().numpy().astype(np.int64) | 
					
						
						|  | pixel_nums = [] | 
					
						
						|  | tot_pixels = 0 | 
					
						
						|  | for i in range(num_classes): | 
					
						
						|  | pixel_num_of_cls_i = np.sum(masks_label == i).astype(np.float) | 
					
						
						|  | pixel_nums.append(pixel_num_of_cls_i) | 
					
						
						|  | tot_pixels += pixel_num_of_cls_i | 
					
						
						|  | weights = [] | 
					
						
						|  | for i in range(num_classes): | 
					
						
						|  | weights.append( | 
					
						
						|  | (tot_pixels - pixel_nums[i]) / tot_pixels / (num_classes - 1) | 
					
						
						|  | ) | 
					
						
						|  | weights = np.array(weights, dtype=np.float) | 
					
						
						|  |  | 
					
						
						|  | return weights | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def moving_average(target1, target2, alpha=1.0): | 
					
						
						|  | target = 0 | 
					
						
						|  | target += (1.0 - alpha) * target1 | 
					
						
						|  | target += target2 * alpha | 
					
						
						|  | return target | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def to_one_hot(tensor, num_cls, dim=1, ignore_index=255): | 
					
						
						|  | b, h, w = tensor.shape | 
					
						
						|  | tensor[tensor == ignore_index] = 0 | 
					
						
						|  | onehot_tensor = torch.zeros(b, num_cls, h, w).cuda() | 
					
						
						|  | onehot_tensor.scatter_(dim, tensor.unsqueeze(dim), 1) | 
					
						
						|  | return onehot_tensor | 
					
						
						|  |  |