import torch l1_loss = torch.nn.L1Loss() l2_loss = torch.nn.MSELoss() def hinge_loss(X, positive=True): if positive: return torch.relu(1-X) else: return torch.relu(X+1) def compute_generator_losses(G, Y, Xt, Xt_attr, Di, embed, ZY, eye_heatmaps, loss_adv_accumulated, diff_person, same_person, args): # adversarial loss L_adv = 0. for di in Di: L_adv += hinge_loss(di[0], True).mean(dim=[1, 2, 3]) L_adv = torch.sum(L_adv * diff_person) / (diff_person.sum() + 1e-4) # id loss L_id =(1 - torch.cosine_similarity(embed, ZY, dim=1)).mean() # attr loss if args.optim_level == "O2" or args.optim_level == "O3": Y_attr = G.get_attr(Y.type(torch.half)) else: Y_attr = G.get_attr(Y) L_attr = 0 for i in range(len(Xt_attr)): L_attr += torch.mean(torch.pow(Xt_attr[i] - Y_attr[i], 2).reshape(args.batch_size, -1), dim=1).mean() L_attr /= 2.0 # reconstruction loss L_rec = torch.sum(0.5 * torch.mean(torch.pow(Y - Xt, 2).reshape(args.batch_size, -1), dim=1) * same_person) / (same_person.sum() + 1e-6) # l2 eyes loss if args.eye_detector_loss: Xt_heatmap_left, Xt_heatmap_right, Y_heatmap_left, Y_heatmap_right = eye_heatmaps L_l2_eyes = l2_loss(Xt_heatmap_left, Y_heatmap_left) + l2_loss(Xt_heatmap_right, Y_heatmap_right) else: L_l2_eyes = 0 # final loss of generator lossG = args.weight_adv*L_adv + args.weight_attr*L_attr + args.weight_id*L_id + args.weight_rec*L_rec + args.weight_eyes*L_l2_eyes loss_adv_accumulated = loss_adv_accumulated*0.98 + L_adv.item()*0.02 return lossG, loss_adv_accumulated, L_adv, L_attr, L_id, L_rec, L_l2_eyes def compute_discriminator_loss(D, Y, Xs, diff_person): # fake part fake_D = D(Y.detach()) loss_fake = 0 for di in fake_D: loss_fake += torch.sum(hinge_loss(di[0], False).mean(dim=[1, 2, 3]) * diff_person) / (diff_person.sum() + 1e-4) # ground truth part true_D = D(Xs) loss_true = 0 for di in true_D: loss_true += torch.sum(hinge_loss(di[0], True).mean(dim=[1, 2, 3]) * diff_person) / (diff_person.sum() + 1e-4) lossD = 0.5*(loss_true.mean() + loss_fake.mean()) return lossD