import math
from typing import Dict, List, Optional, Union

import lpips
import torch
import torch.nn as nn
from einops import rearrange, repeat
from facenet_pytorch import InceptionResnetV1

from ...data.data_utils import extract_face
from ...modules.autoencoding.lpips.loss.lpips import LPIPS
from ...modules.autoencoding.temporal_ae import VideoDecoder
from ...modules.encoders.modules import ConcatTimestepEmbedderND, GeneralConditioner
from ...util import append_dims, default, instantiate_from_config


def logit_normal_sampler(m, s=1, beta_m=15, sample_num=1000000):
    # Sample from a logit-normal distribution, rescaled to lie in (0, beta_m).
    y_samples = torch.randn(sample_num) * s + m
    x_samples = beta_m * (torch.exp(y_samples) / (1 + torch.exp(y_samples)))
    return x_samples


def mu_t(t, a=5, mu_max=1):
    # Polynomial schedule mapping t in [0, 1] to a mean in [-mu_max, mu_max].
    t = t.to("cpu")
    return 2 * mu_max * t**a - mu_max


def get_sigma_s(t, a, beta_m):
    # Noise-augmentation strength: one logit-normal sample per element of t,
    # with a mean that grows with t according to mu_t.
    mu = mu_t(t, a=a)
    sigma_s = logit_normal_sampler(m=mu, sample_num=t.shape[0], beta_m=beta_m)
    return sigma_s


class StandardDiffusionLoss(nn.Module):
    def __init__(
        self,
        sigma_sampler_config: dict,
        loss_weighting_config: dict,
        loss_type: str = "l2",
        offset_noise_level: float = 0.0,
        batch2model_keys: Optional[Union[str, List[str]]] = None,
        lambda_lower: float = 1.0,
        lambda_upper: float = 1.0,
        fix_image_leak: bool = False,
        add_lpips: bool = False,
        weight_pixel: float = 0.0,
        n_frames_pixel: Optional[int] = 1,
        what_pixel_losses: Optional[List[str]] = [],
        disable_first_stage_autocast: bool = True,
    ):
        super().__init__()

        assert loss_type in ["l2", "l1", "lpips"]

        self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
        self.loss_weighting = instantiate_from_config(loss_weighting_config)

        self.loss_type = loss_type
        self.offset_noise_level = offset_noise_level
        self.lambda_lower = lambda_lower
        self.lambda_upper = lambda_upper
        self.add_lpips = add_lpips
        self.weight_pixel = weight_pixel
        self.n_frames_pixel = n_frames_pixel
        self.what_pixel_losses = what_pixel_losses
        self.en_and_decode_n_samples_a_time = 1
        self.disable_first_stage_autocast = disable_first_stage_autocast

        if loss_type == "lpips":
            self.lpips = LPIPS().eval()

        if add_lpips or "lpips" in what_pixel_losses:
            self.lpips = lpips.LPIPS(net="vgg").eval()

        if "id" in what_pixel_losses or "id_mse" in what_pixel_losses:
            self.id_model = InceptionResnetV1(pretrained="vggface2").eval().cuda()
            for param in self.id_model.parameters():
                param.requires_grad = False

        if not batch2model_keys:
            batch2model_keys = []

        if isinstance(batch2model_keys, str):
            batch2model_keys = [batch2model_keys]

        self.batch2model_keys = set(batch2model_keys)

        self.fix_image_leak = fix_image_leak
        if fix_image_leak:
            self.beta_m = 15
            self.a = 5
            self.noise_encoder = ConcatTimestepEmbedderND(256)

    def get_noised_input(
        self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor
    ) -> torch.Tensor:
        noised_input = input + noise * sigmas_bc
        return noised_input

    def decode_first_stage(self, z, first_stage_model):
        if len(z.shape) == 5:
            z = rearrange(z, "b c t h w -> (b t) c h w")
        z = 1.0 / 0.18215 * z
        n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])

        n_rounds = math.ceil(z.shape[0] / n_samples)
        all_out = []
        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
            for n in range(n_rounds):
                if isinstance(first_stage_model.decoder, VideoDecoder):
                    kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
                else:
                    kwargs = {}
                out = first_stage_model.decode(
                    z[n * n_samples : (n + 1) * n_samples], **kwargs
                )
                all_out.append(out)
        out = torch.cat(all_out, dim=0)
        # out = rearrange(out, "b c h w -> b h w c")
        torch.cuda.empty_cache()
        return out.clip(-1, 1)
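
    # Note: `cond` may carry "concat", "vector", and "masks" entries produced by the
    # conditioner; `batch` may optionally carry "original_frames" and "landmarks",
    # which are only required when the pixel-space / identity losses are enabled.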

    def forward(
        self,
        network: nn.Module,
        denoiser: nn.Module,
        conditioner: GeneralConditioner,
        input: torch.Tensor,
        batch: Dict,
        first_stage_model: Optional[nn.Module] = None,
    ) -> Dict:
        cond = conditioner(batch)
        return self._forward(network, denoiser, cond, input, batch, first_stage_model)

    def _forward(
        self,
        network: nn.Module,
        denoiser: nn.Module,
        cond: Dict,
        input: torch.Tensor,
        batch: Dict,
        first_stage_model: Optional[nn.Module] = None,
    ) -> Dict:
        additional_model_inputs = {
            key: batch[key] for key in self.batch2model_keys.intersection(batch)
        }
        sigmas = self.sigma_sampler(input.shape[0]).to(input)

        noise = torch.randn_like(input)
        if self.offset_noise_level > 0.0:
            # `n_frames` is only set for video inputs; fall back to a per-channel
            # offset when it is absent.
            offset_shape = (
                (input.shape[0], 1, input.shape[2])
                if getattr(self, "n_frames", None) is not None
                else (input.shape[0], input.shape[1])
            )
            noise = noise + self.offset_noise_level * append_dims(
                torch.randn(offset_shape, device=input.device),
                input.ndim,
            )

        sigmas_bc = append_dims(sigmas, input.ndim)
        noised_input = self.get_noised_input(sigmas_bc, noise, input)

        if self.fix_image_leak:
            # Noise-augment the conditioning frames and hand the augmentation
            # strength to the model as an extra (vector) conditioning signal.
            noise_aug_strength = get_sigma_s(sigmas / 700, self.a, self.beta_m)
            noise_aug = append_dims(noise_aug_strength, 4).to(input.device)
            noise = torch.randn_like(noise_aug)
            cond["concat"] = self.get_noised_input(noise_aug, noise, cond["concat"])
            noise_emb = self.noise_encoder(noise_aug_strength).to(input.device)
            # cond["vector"] = noise_emb if "vector" not in cond else torch.cat([cond["vector"], noise_emb], dim=1)
            cond["vector"] = noise_emb
            # print(cond["concat"].shape, cond["vector"].shape, noise.shape, noise_aug.shape, noise_emb.shape)

        model_output = denoiser(
            network, noised_input, sigmas, cond, **additional_model_inputs
        )
        mask = cond.get("masks", None)
        w = append_dims(self.loss_weighting(sigmas), input.ndim)

        return self.get_loss(
            model_output,
            input,
            w,
            sigmas,
            mask,
            first_stage_model,
            batch.get("original_frames", None),
            batch.get("landmarks", None),
        )

    def get_loss(
        self,
        model_output,
        target,
        w,
        sigmas,
        mask=None,
        first_stage_model=None,
        original_frames=None,
        landmarks=None,
    ):
        # Per-sample weighting, reused for the pixel-space terms below.
        scaling_w = w[:, 0, 0, 0]

        T = 1
        if target.ndim == 5:
            target = rearrange(target, "b c t h w -> (b t) c h w")
            B = w.shape[0]
            T = target.shape[0] // B
            if w.shape[2] != T:
                w = repeat(w, "b () () () () -> (b t) () () ()", t=T)
            else:
                w = rearrange(w, "b c t h w -> (b t) c h w")
        or_w = w.clone()

        if self.lambda_lower != 1.0:
            weight_lower = torch.ones_like(model_output, device=w.device)
            weight_lower[:, :, model_output.shape[2] // 2 :] *= self.lambda_lower
            w = weight_lower * w

        if self.lambda_upper != 1.0:
            weight_upper = torch.ones_like(model_output, device=w.device)
            weight_upper[:, :, : model_output.shape[2] // 2] *= self.lambda_upper
            w = weight_upper * w

        loss_dict = {}

        if self.loss_type == "l2":
            loss = torch.mean(
                (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
            )
        elif self.loss_type == "l1":
            loss = torch.mean(
                (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
            )
        elif self.loss_type == "lpips":
            loss = self.lpips(model_output, target).reshape(-1)
        else:
            raise NotImplementedError(f"Unknown loss type {self.loss_type}")

        loss_dict[self.loss_type] = loss.clone()
        loss_dict["loss"] = loss

        if self.add_lpips:
            loss_dict["lpips"] = w[:, 0, 0, 0] * self.lpips(
                (model_output[:, :3] * 0.18215).clip(-1, 1),
                (target[:, :3] * 0.18215).clip(-1, 1),
            ).reshape(-1)
            loss_dict["loss"] += loss_dict["lpips"].mean()

        if self.weight_pixel > 0.0:
            assert original_frames is not None
            # Randomly select n_frames_pixel frames
            selected_frames = torch.randperm(T)[: self.n_frames_pixel]
            selected_model_output = rearrange(
                model_output, "(b t) ... -> b t ...", t=T
            )[:, selected_frames]
            selected_model_output = rearrange(
                selected_model_output, "b t ... -> (b t) ..."
            )
            selected_original_frames = original_frames[:, :, selected_frames]
            selected_original_frames = rearrange(
                selected_original_frames, "b c t ... -> (b t) c ..."
            )
            selected_w = rearrange(or_w, "(b t) ... -> b t ...", t=T)[
                :, selected_frames
            ]
            selected_w = rearrange(selected_w, "b t ... -> (b t) ...")

            if selected_w.shape[-1] != selected_original_frames.shape[-1]:
                # Interpolate the weights to match the number of frames
                selected_w = torch.nn.functional.interpolate(
                    selected_w, size=selected_original_frames.shape[-1], mode="nearest"
                )

            decoded_frames = self.decode_first_stage(
                selected_model_output, first_stage_model
            )
            # print(decoded_frames.shape, selected_original_frames.shape, selected_w.shape)

            for loss_name in self.what_pixel_losses:
                if loss_name == "l2":
                    # print(selected_w.shape, decoded_frames.shape, selected_original_frames.shape)
                    loss_pixel = torch.mean(
                        (
                            selected_w
                            * (decoded_frames - selected_original_frames) ** 2
                        ).reshape(selected_original_frames.shape[0], -1),
                        1,
                    )
                    loss_dict["pixel_l2"] = self.weight_pixel * loss_pixel.mean()
                    loss += self.weight_pixel * loss_pixel.mean()
                elif loss_name == "lpips":
                    loss_pixel = (
                        self.lpips(decoded_frames, selected_original_frames).reshape(-1)
                        * scaling_w
                    )
                    loss_dict["pixel_lpips"] = loss_pixel.mean()
                    loss += self.weight_pixel * loss_pixel.mean()
                elif loss_name == "l1":
                    loss_pixel = torch.mean(
                        (
                            selected_w
                            * (decoded_frames - selected_original_frames).abs()
                        ).reshape(selected_original_frames.shape[0], -1),
                        1,
                    )
                    loss_dict["pixel_l1"] = self.weight_pixel * loss_pixel.mean()
                    loss += self.weight_pixel * loss_pixel.mean()
                elif loss_name == "id":
                    landmarks = landmarks[:, selected_frames]
                    cat_id_input = (
                        (
                            torch.cat([decoded_frames, selected_original_frames], dim=0)
                            + 1
                        )
                        / 2
                    ) * 255
                    cat_id_landmarks = torch.cat([landmarks, landmarks], dim=0)
                    cat_id_landmarks = (
                        rearrange(cat_id_landmarks, "b t ... -> (b t) ...")
                        .cpu()
                        .numpy()
                    )
                    try:
                        cropped_decoded_frames = extract_face(
                            rearrange(cat_id_input, "b c h w -> b h w c"),
                            cat_id_landmarks,
                            margin=30,
                            postprocess=True,
                        )
                        # Save first frame to debug
                        n = cat_id_input.shape[0] // 2
                        id_embeddings = self.id_model(
                            rearrange(cropped_decoded_frames, "b h w c -> b c h w")
                        )
                        pred_embeddings, target_embeddings = (
                            id_embeddings[:n],
                            id_embeddings[n:],
                        )
                        # Cosine similarity loss (1 - cos_sim to make it a loss that should be minimized)
                        id_w = scaling_w
                        loss_pixel = (
                            id_w
                            * (
                                1
                                - torch.nn.functional.cosine_similarity(
                                    pred_embeddings, target_embeddings
                                )
                            )
                        ).mean()
                        loss_dict["pixel_id"] = self.weight_pixel * loss_pixel
                        loss += self.weight_pixel * loss_pixel
                    except RuntimeError as e:
                        if "adaptive_avg_pool2d()" in str(e):
                            print(
                                "Warning: Invalid face crop dimensions, skipping ID loss for this batch"
                            )
                            loss_dict["pixel_id"] = torch.tensor(
                                0.0, device=cat_id_input.device
                            )
                            continue
                        else:
                            raise  # Re-raise other RuntimeErrors
                elif loss_name == "id_mse":
                    landmarks = landmarks[:, selected_frames]
                    cat_id_input = (
                        (
                            torch.cat([decoded_frames, selected_original_frames], dim=0)
                            + 1
                        )
                        / 2
                    ) * 255
                    cat_id_landmarks = torch.cat([landmarks, landmarks], dim=0)
-> (b t) ...") .cpu() .numpy() ) cropped_decoded_frames = extract_face( rearrange(cat_id_input, "b c h w -> b h w c"), cat_id_landmarks, margin=30, postprocess=True, ) # Save first frame to debug n = cat_id_input.shape[0] // 2 id_embeddings = self.id_model( rearrange(cropped_decoded_frames, "b h w c -> b c h w") ) pred_embeddings, target_embeddings = ( id_embeddings[:n], id_embeddings[n:], ) # Cosine similarity loss (1 - cos_sim to make it a loss that should be minimized) id_w = append_dims( self.loss_weighting(sigmas), pred_embeddings.ndim ) loss_pixel = ( id_w * ((pred_embeddings - target_embeddings) ** 2) ).mean() loss_dict["pixel_id_mse"] = self.weight_pixel * loss_pixel loss += self.weight_pixel * loss_pixel else: raise NotImplementedError(f"Unknown pixel loss type {loss_name}") return loss_dict