File size: 1,130 Bytes
e3af1ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import numpy as np
import torch
from typing import Callable

def attribute_preserving_loss(
    generated: torch.Tensor,
    original: torch.Tensor,
    attr_predictor: Callable[[torch.Tensor], torch.Tensor],
    y_target: torch.Tensor,
    lambda_pred: float = 1.0,
    lambda_recon: float = 1.0
) -> torch.Tensor:
    """
    Custom loss enforcing attribute fidelity and identity preservation.
    L_attr(G(z + alpha d)) = lambda_pred * ||f_attr(G(.)) - y_target||^2 + lambda_recon * ||G(z + alpha d) - G(z)||^2
    :param generated: Generated image tensor (B, ...)
    :param original: Original image tensor (B, ...)
    :param attr_predictor: Function mapping image tensor to attribute prediction
    :param y_target: Target attribute value tensor (B, ...)
    :param lambda_pred: Weight for attribute prediction loss
    :param lambda_recon: Weight for reconstruction loss
    :return: Scalar loss tensor
    """
    pred_loss = torch.nn.functional.mse_loss(attr_predictor(generated), y_target)
    recon_loss = torch.nn.functional.mse_loss(generated, original)
    return lambda_pred * pred_loss + lambda_recon * recon_loss