import math
from typing import Dict, List, Optional, Tuple, Union

import lpips
import torch
import torch.nn as nn
from einops import rearrange, repeat
from facenet_pytorch import InceptionResnetV1

from ...modules.autoencoding.lpips.loss.lpips import LPIPS
from ...modules.autoencoding.temporal_ae import VideoDecoder
from ...modules.encoders.modules import ConcatTimestepEmbedderND, GeneralConditioner
from ...util import append_dims, default, instantiate_from_config
from ...data.data_utils import extract_face
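

# Helpers for the logit-normal noise-augmentation schedule used by the
# `fix_image_leak` path below: `mu_t` maps a normalized diffusion time in [0, 1]
# to a location parameter in [-mu_max, mu_max], and `logit_normal_sampler`
# draws augmentation strengths in (0, beta_m) as a scaled sigmoid of Gaussian
# samples centered at that location.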
def logit_normal_sampler(m, s=1, beta_m=15, sample_num=1000000):
    y_samples = torch.randn(sample_num) * s + m
    x_samples = beta_m * (torch.exp(y_samples) / (1 + torch.exp(y_samples)))
    return x_samples


def mu_t(t, a=5, mu_max=1):
    t = t.to("cpu")
    return 2 * mu_max * t**a - mu_max


def get_sigma_s(t, a, beta_m):
    mu = mu_t(t, a=a)
    sigma_s = logit_normal_sampler(m=mu, sample_num=t.shape[0], beta_m=beta_m)
    return sigma_s
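

# StandardDiffusionLoss combines the usual denoising loss in latent space with
# optional extras: spatial re-weighting of the lower/upper halves of each frame,
# noise-augmented conditioning frames (`fix_image_leak`), an LPIPS term on the
# (scaled) latent prediction, and pixel-space losses computed on decoded frames
# against the original frames.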
class StandardDiffusionLoss(nn.Module):
    def __init__(
        self,
        sigma_sampler_config: dict,
        loss_weighting_config: dict,
        loss_type: str = "l2",
        offset_noise_level: float = 0.0,
        batch2model_keys: Optional[Union[str, List[str]]] = None,
        lambda_lower: float = 1.0,
        lambda_upper: float = 1.0,
        fix_image_leak: bool = False,
        add_lpips: bool = False,
        weight_pixel: float = 0.0,
        n_frames_pixel: Optional[int] = 1,
        what_pixel_losses: Optional[List[str]] = None,
        disable_first_stage_autocast: bool = True,
    ):
        super().__init__()
        assert loss_type in ["l2", "l1", "lpips"]

        self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
        self.loss_weighting = instantiate_from_config(loss_weighting_config)

        self.loss_type = loss_type
        self.offset_noise_level = offset_noise_level
        self.lambda_lower = lambda_lower
        self.lambda_upper = lambda_upper
        self.add_lpips = add_lpips
        self.weight_pixel = weight_pixel
        self.n_frames_pixel = n_frames_pixel
        self.what_pixel_losses = default(what_pixel_losses, [])
        self.en_and_decode_n_samples_a_time = 1
        self.disable_first_stage_autocast = disable_first_stage_autocast

        if loss_type == "lpips":
            self.lpips = LPIPS().eval()
        if add_lpips or "lpips" in self.what_pixel_losses:
            self.lpips = lpips.LPIPS(net="vgg").eval()

        if "id" in self.what_pixel_losses or "id_mse" in self.what_pixel_losses:
            # Frozen face-recognition backbone for the identity losses.
            self.id_model = InceptionResnetV1(pretrained="vggface2").eval().cuda()
            for param in self.id_model.parameters():
                param.requires_grad = False

        if not batch2model_keys:
            batch2model_keys = []
        if isinstance(batch2model_keys, str):
            batch2model_keys = [batch2model_keys]
        self.batch2model_keys = set(batch2model_keys)

        self.fix_image_leak = fix_image_leak
        if fix_image_leak:
            self.beta_m = 15
            self.a = 5
            self.noise_encoder = ConcatTimestepEmbedderND(256)

    def get_noised_input(
        self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor
    ) -> torch.Tensor:
        noised_input = input + noise * sigmas_bc
        return noised_input
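
    # Decode latents back to pixel space in chunks of
    # `en_and_decode_n_samples_a_time` samples to bound memory use. Video
    # latents (b, c, t, h, w) are flattened to a frame batch first, the
    # 0.18215 latent scaling is undone, and a VideoDecoder additionally
    # receives the number of frames in each chunk.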
    def decode_first_stage(self, z, first_stage_model):
        if len(z.shape) == 5:
            z = rearrange(z, "b c t h w -> (b t) c h w")
        z = 1.0 / 0.18215 * z
        n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])

        n_rounds = math.ceil(z.shape[0] / n_samples)
        all_out = []
        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
            for n in range(n_rounds):
                if isinstance(first_stage_model.decoder, VideoDecoder):
                    kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
                else:
                    kwargs = {}
                out = first_stage_model.decode(
                    z[n * n_samples : (n + 1) * n_samples], **kwargs
                )
                all_out.append(out)
        out = torch.cat(all_out, dim=0)
        # out = rearrange(out, "b c h w -> b h w c")
        torch.cuda.empty_cache()
        return out.clip(-1, 1)

    def forward(
        self,
        network: nn.Module,
        denoiser: nn.Module,
        conditioner: GeneralConditioner,
        input: torch.Tensor,
        batch: Dict,
        first_stage_model: Optional[nn.Module] = None,
    ) -> Dict:
        cond = conditioner(batch)
        return self._forward(network, denoiser, cond, input, batch, first_stage_model)

    def _forward(
        self,
        network: nn.Module,
        denoiser: nn.Module,
        cond: Dict,
        input: torch.Tensor,
        batch: Dict,
        first_stage_model: Optional[nn.Module] = None,
    ) -> Dict:
        additional_model_inputs = {
            key: batch[key] for key in self.batch2model_keys.intersection(batch)
        }
        sigmas = self.sigma_sampler(input.shape[0]).to(input)

        noise = torch.randn_like(input)
        if self.offset_noise_level > 0.0:
            # Video inputs (b, c, t, h, w) get per-frame offset noise;
            # image inputs (b, c, h, w) get per-channel offset noise.
            offset_shape = (
                (input.shape[0], 1, input.shape[2])
                if input.ndim == 5
                else (input.shape[0], input.shape[1])
            )
            noise = noise + self.offset_noise_level * append_dims(
                torch.randn(offset_shape, device=input.device),
                input.ndim,
            )

        sigmas_bc = append_dims(sigmas, input.ndim)
        noised_input = self.get_noised_input(sigmas_bc, noise, input)
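
        # Image-leak fix: noise-augment the conditioning frames in cond["concat"]
        # with a strength drawn from the logit-normal schedule above (one scalar
        # per batch element, broadcast over the conditioning frames), and pass an
        # embedding of that strength to the network through cond["vector"].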
        if self.fix_image_leak:
            noise_aug_strength = get_sigma_s(sigmas / 700, self.a, self.beta_m)
            noise_aug = append_dims(noise_aug_strength, 4).to(input.device)
            noise = torch.randn_like(noise_aug)
            cond["concat"] = self.get_noised_input(noise_aug, noise, cond["concat"])
            noise_emb = self.noise_encoder(noise_aug_strength).to(input.device)
            # cond["vector"] = noise_emb if "vector" not in cond else torch.cat([cond["vector"], noise_emb], dim=1)
            cond["vector"] = noise_emb

        model_output = denoiser(
            network, noised_input, sigmas, cond, **additional_model_inputs
        )
        mask = cond.get("masks", None)
        w = append_dims(self.loss_weighting(sigmas), input.ndim)
        return self.get_loss(
            model_output,
            input,
            w,
            sigmas,
            mask,
            first_stage_model,
            batch.get("original_frames", None),
            batch.get("landmarks", None),
        )
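
    # get_loss flattens video batches into a frame batch, applies the sigma- and
    # spatially-dependent weights, computes the base l2/l1/lpips loss, and then
    # adds the optional latent-LPIPS and pixel-space terms. It returns a dict
    # holding the total under "loss" and the individual terms for logging.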
    def get_loss(
        self,
        model_output,
        target,
        w,
        sigmas,
        mask=None,
        first_stage_model=None,
        original_frames=None,
        landmarks=None,
    ):
        scaling_w = w[:, 0, 0, 0]

        T = 1
        if target.ndim == 5:
            # Flatten the temporal axis so losses are computed per frame.
            target = rearrange(target, "b c t h w -> (b t) c h w")
            B = w.shape[0]
            T = target.shape[0] // B
            if w.shape[2] != T:
                w = repeat(w, "b () () () () -> (b t) () () ()", t=T)
            else:
                w = rearrange(w, "b c t h w -> (b t) c h w")
        or_w = w.clone()

        # Optionally re-weight the lower/upper halves of each frame.
        if self.lambda_lower != 1.0:
            weight_lower = torch.ones_like(model_output, device=w.device)
            weight_lower[:, :, model_output.shape[2] // 2 :] *= self.lambda_lower
            w = weight_lower * w
        if self.lambda_upper != 1.0:
            weight_upper = torch.ones_like(model_output, device=w.device)
            weight_upper[:, :, : model_output.shape[2] // 2] *= self.lambda_upper
            w = weight_upper * w

        loss_dict = {}
        if self.loss_type == "l2":
            loss = torch.mean(
                (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
            )
        elif self.loss_type == "l1":
            loss = torch.mean(
                (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
            )
        elif self.loss_type == "lpips":
            loss = self.lpips(model_output, target).reshape(-1)
        else:
            raise NotImplementedError(f"Unknown loss type {self.loss_type}")

        loss_dict[self.loss_type] = loss.clone()
        loss_dict["loss"] = loss

        if self.add_lpips:
            # LPIPS on the first three (scaled) latent channels.
            loss_dict["lpips"] = w[:, 0, 0, 0] * self.lpips(
                (model_output[:, :3] * 0.18215).clip(-1, 1),
                (target[:, :3] * 0.18215).clip(-1, 1),
            ).reshape(-1)
            loss_dict["loss"] += loss_dict["lpips"].mean()
        if self.weight_pixel > 0.0:
            assert original_frames is not None
            # Randomly select n_frames_pixel frames
            selected_frames = torch.randperm(T)[: self.n_frames_pixel]
            selected_model_output = rearrange(
                model_output, "(b t) ... -> b t ...", t=T
            )[:, selected_frames]
            selected_model_output = rearrange(
                selected_model_output, "b t ... -> (b t) ..."
            )
            selected_original_frames = original_frames[:, :, selected_frames]
            selected_original_frames = rearrange(
                selected_original_frames, "b c t ... -> (b t) c ..."
            )
            selected_w = rearrange(or_w, "(b t) ... -> b t ...", t=T)[
                :, selected_frames
            ]
            selected_w = rearrange(selected_w, "b t ... -> (b t) ...")
            if selected_w.shape[-1] != selected_original_frames.shape[-1]:
                # Interpolate the weights to match the decoded frame resolution
                selected_w = torch.nn.functional.interpolate(
                    selected_w, size=selected_original_frames.shape[-1], mode="nearest"
                )
            decoded_frames = self.decode_first_stage(
                selected_model_output, first_stage_model
            )
            for loss_name in self.what_pixel_losses:
                if loss_name == "l2":
                    loss_pixel = torch.mean(
                        (
                            selected_w
                            * (decoded_frames - selected_original_frames) ** 2
                        ).reshape(selected_original_frames.shape[0], -1),
                        1,
                    )
                    loss_dict["pixel_l2"] = self.weight_pixel * loss_pixel.mean()
                    loss += self.weight_pixel * loss_pixel.mean()
                elif loss_name == "lpips":
                    loss_pixel = (
                        self.lpips(decoded_frames, selected_original_frames).reshape(-1)
                        * scaling_w
                    )
                    loss_dict["pixel_lpips"] = loss_pixel.mean()
                    loss += self.weight_pixel * loss_pixel.mean()
                elif loss_name == "l1":
                    loss_pixel = torch.mean(
                        (
                            selected_w
                            * (decoded_frames - selected_original_frames).abs()
                        ).reshape(selected_original_frames.shape[0], -1),
                        1,
                    )
                    loss_dict["pixel_l1"] = self.weight_pixel * loss_pixel.mean()
                    loss += self.weight_pixel * loss_pixel.mean()
| elif loss_name == "id": | |
| landmarks = landmarks[:, selected_frames] | |
| cat_id_input = ( | |
| ( | |
| torch.cat([decoded_frames, selected_original_frames], dim=0) | |
| + 1 | |
| ) | |
| / 2 | |
| ) * 255 | |
| cat_id_landmarks = torch.cat([landmarks, landmarks], dim=0) | |
| cat_id_landmarks = ( | |
| rearrange(cat_id_landmarks, "b t ... -> (b t) ...") | |
| .cpu() | |
| .numpy() | |
| ) | |
| try: | |
| cropped_decoded_frames = extract_face( | |
| rearrange(cat_id_input, "b c h w -> b h w c"), | |
| cat_id_landmarks, | |
| margin=30, | |
| postprocess=True, | |
| ) | |
| # Save first frame to debug | |
| n = cat_id_input.shape[0] // 2 | |
| id_embeddings = self.id_model( | |
| rearrange(cropped_decoded_frames, "b h w c -> b c h w") | |
| ) | |
| pred_embeddings, target_embeddings = ( | |
| id_embeddings[:n], | |
| id_embeddings[n:], | |
| ) | |
| # Cosine similarity loss (1 - cos_sim to make it a loss that should be minimized) | |
| id_w = scaling_w | |
| loss_pixel = ( | |
| id_w | |
| * ( | |
| 1 | |
| - torch.nn.functional.cosine_similarity( | |
| pred_embeddings, target_embeddings | |
| ) | |
| ) | |
| ).mean() | |
| loss_dict["pixel_id"] = self.weight_pixel * loss_pixel | |
| loss += self.weight_pixel * loss_pixel | |
| except RuntimeError as e: | |
| if "adaptive_avg_pool2d()" in str(e): | |
| print( | |
| "Warning: Invalid face crop dimensions, skipping ID loss for this batch" | |
| ) | |
| loss_dict["pixel_id"] = torch.tensor( | |
| 0.0, device=cat_id_input.device | |
| ) | |
| continue | |
| else: | |
| raise # Re-raise other RuntimeErrors | |
| elif loss_name == "id_mse": | |
| landmarks = landmarks[:, selected_frames] | |
| cat_id_input = ( | |
| ( | |
| torch.cat([decoded_frames, selected_original_frames], dim=0) | |
| + 1 | |
| ) | |
| / 2 | |
| ) * 255 | |
| cat_id_landmarks = torch.cat([landmarks, landmarks], dim=0) | |
| cat_id_landmarks = ( | |
| rearrange(cat_id_landmarks, "b t ... -> (b t) ...") | |
| .cpu() | |
| .numpy() | |
| ) | |
| cropped_decoded_frames = extract_face( | |
| rearrange(cat_id_input, "b c h w -> b h w c"), | |
| cat_id_landmarks, | |
| margin=30, | |
| postprocess=True, | |
| ) | |
| # Save first frame to debug | |
| n = cat_id_input.shape[0] // 2 | |
| id_embeddings = self.id_model( | |
| rearrange(cropped_decoded_frames, "b h w c -> b c h w") | |
| ) | |
| pred_embeddings, target_embeddings = ( | |
| id_embeddings[:n], | |
| id_embeddings[n:], | |
| ) | |
| # Cosine similarity loss (1 - cos_sim to make it a loss that should be minimized) | |
| id_w = append_dims( | |
| self.loss_weighting(sigmas), pred_embeddings.ndim | |
| ) | |
| loss_pixel = ( | |
| id_w * ((pred_embeddings - target_embeddings) ** 2) | |
| ).mean() | |
| loss_dict["pixel_id_mse"] = self.weight_pixel * loss_pixel | |
| loss += self.weight_pixel * loss_pixel | |
| else: | |
| raise NotImplementedError(f"Unknown pixel loss type {loss_name}") | |
| return loss_dict | |