File size: 1,130 Bytes
e3af1ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
import numpy as np
import torch
from typing import Callable
def attribute_preserving_loss(
    generated: torch.Tensor,
    original: torch.Tensor,
    attr_predictor: Callable[[torch.Tensor], torch.Tensor],
    y_target: torch.Tensor,
    lambda_pred: float = 1.0,
    lambda_recon: float = 1.0,
) -> torch.Tensor:
    """
    Custom loss enforcing attribute fidelity and identity preservation.

    L_attr(G(z + alpha d)) = lambda_pred * ||f_attr(G(.)) - y_target||^2
                             + lambda_recon * ||G(z + alpha d) - G(z)||^2

    Both terms are computed with ``torch.nn.functional.mse_loss`` using its
    default reduction, i.e. they are *mean* squared errors over all elements
    rather than summed squared norms.

    If the predictor output and ``y_target`` differ only by singleton
    dimensions (for example ``(B, 1)`` versus ``(B,)``), ``y_target`` is
    reshaped to the prediction's shape so the comparison is element-wise
    instead of being silently broadcast to ``(B, B)``.

    :param generated: Generated image tensor (B, ...)
    :param original: Original image tensor (B, ...)
    :param attr_predictor: Function mapping image tensor to attribute prediction
    :param y_target: Target attribute value tensor (B, ...)
    :param lambda_pred: Weight for attribute prediction loss
    :param lambda_recon: Weight for reconstruction loss
    :return: Scalar loss tensor
    """
    predicted = attr_predictor(generated)

    # mse_loss only warns on mismatched shapes and then broadcasts, which
    # yields a wrong value for the common (B, 1) vs (B,) case. Align the
    # target when the two shapes agree once size-1 dimensions are ignored;
    # any other mismatch is left to mse_loss (broadcast or error) as before.
    if predicted.shape != y_target.shape:
        pred_dims = [dim for dim in predicted.shape if dim != 1]
        target_dims = [dim for dim in y_target.shape if dim != 1]
        if pred_dims == target_dims:
            y_target = y_target.reshape(predicted.shape)

    pred_loss = torch.nn.functional.mse_loss(predicted, y_target)
    recon_loss = torch.nn.functional.mse_loss(generated, original)
    return lambda_pred * pred_loss + lambda_recon * recon_loss