# Transcendental-Programmer
# Refactor core logic: move and modularize all latent space, sampling, and utility code into faceforge_core/
# Commit: e3af1ef
from typing import Callable

import numpy as np
import torch
def attribute_preserving_loss(
    generated: torch.Tensor,
    original: torch.Tensor,
    attr_predictor: Callable[[torch.Tensor], torch.Tensor],
    y_target: torch.Tensor,
    lambda_pred: float = 1.0,
    lambda_recon: float = 1.0,
) -> torch.Tensor:
    """Custom loss enforcing attribute fidelity and identity preservation.

    L_attr(G(z + alpha d)) = lambda_pred  * ||f_attr(G(.)) - y_target||^2
                           + lambda_recon * ||G(z + alpha d) - G(z)||^2

    :param generated: Generated image tensor (B, ...)
    :param original: Original image tensor (B, ...); must be broadcastable
        against ``generated`` for the reconstruction MSE term
    :param attr_predictor: Function mapping image tensor to attribute prediction
    :param y_target: Target attribute value tensor (B, ...)
    :param lambda_pred: Weight for attribute prediction loss
    :param lambda_recon: Weight for reconstruction loss
    :return: Scalar loss tensor (differentiable w.r.t. ``generated``)
    """
    # Attribute-fidelity term: predicted attributes should match the target.
    pred_loss = torch.nn.functional.mse_loss(attr_predictor(generated), y_target)
    # Identity-preservation term: edited image should stay close to the original.
    recon_loss = torch.nn.functional.mse_loss(generated, original)
    return lambda_pred * pred_loss + lambda_recon * recon_loss