import torch import torch.nn as nn from packaging import version from einops import repeat, rearrange from diffusers.utils import _get_model_file from diffusers.models.modeling_utils import load_state_dict from ...modules.diffusionmodules.augment_pipeline import AugmentPipe from ...modules.encoders.modules import ConcatTimestepEmbedderND from ...util import append_dims OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper" class IdentityWrapper(nn.Module): def __init__(self, diffusion_model, compile_model: bool = False): super().__init__() compile = ( torch.compile if (version.parse(torch.__version__) >= version.parse("2.0.0")) and compile_model else lambda x: x ) self.diffusion_model = compile(diffusion_model) def forward(self, *args, **kwargs): return self.diffusion_model(*args, **kwargs) class OpenAIWrapper(IdentityWrapper): def __init__( self, diffusion_model, compile_model: bool = False, ada_aug_percent=0.0, fix_image_leak=False, add_embeddings=False, im_size=[64, 64], n_channels=4, ): super().__init__(diffusion_model, compile_model) self.fix_image_leak = fix_image_leak if fix_image_leak: self.beta_m = 15 self.a = 5 self.noise_encoder = ConcatTimestepEmbedderND(256) self.augment_pipe = None if ada_aug_percent > 0.0: augment_kwargs = dict( xflip=1e8, yflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1 ) self.augment_pipe = AugmentPipe(ada_aug_percent, **augment_kwargs) self.add_embeddings = add_embeddings if add_embeddings: self.learned_mask = nn.Parameter( torch.zeros(n_channels, im_size[0], im_size[1]) ) def get_noised_input( self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor ) -> torch.Tensor: noised_input = input + noise * sigmas_bc return noised_input def forward( self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs ) -> torch.Tensor: cond_cat = c.get("concat", torch.Tensor([]).type_as(x)) if len(cond_cat.shape) and cond_cat.shape[0]: T = x.shape[0] // cond_cat.shape[0] if self.fix_image_leak: noise_aug_strength = get_sigma_s( rearrange(t, "(b t) ... -> b t ...", b=T)[: cond_cat.shape[0], 0] / 700, self.a, self.beta_m, ) noise_aug = append_dims(noise_aug_strength, 4).to(x.device) noise = torch.randn_like(noise_aug) cond_cat = self.get_noised_input(noise_aug, noise, cond_cat) noise_emb = self.noise_encoder(noise_aug_strength).to(x.device) c["vector"] = ( noise_emb if "vector" not in c else torch.cat([c["vector"], noise_emb], dim=1) ) if ( len(cond_cat.shape) and cond_cat.shape[0] and x.shape[0] != cond_cat.shape[0] ): cond_cat = repeat(cond_cat, "b c h w -> b c t h w", t=T) cond_cat = rearrange(cond_cat, "b c t h w -> (b t) c h w") x = torch.cat((x, cond_cat), dim=1) if self.add_embeddings: learned_mask = repeat( self.learned_mask.to(x.device), "c h w -> b c h w", b=cond_cat.shape[0] ) x = torch.cat((x, learned_mask), dim=1) if self.augment_pipe is not None: x, labels = self.augment_pipe(x) else: labels = torch.zeros(x.shape[0], 9, device=x.device) return self.diffusion_model( x, timesteps=t, context=c.get("crossattn", None), reference_context=c.get("reference", None), y=c.get("vector", None), audio_emb=c.get("audio_emb", None), landmarks=c.get("landmarks", None), aug_labels=labels, **kwargs, ) class DubbingWrapper(IdentityWrapper): def __init__(self, diffusion_model, compile_model: bool = False, mask_input=False): super().__init__(diffusion_model, compile_model) self.mask_input = mask_input def forward( self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs ) -> torch.Tensor: cond_cat = c.get("concat", torch.Tensor([]).type_as(x)) if len(cond_cat.shape): T = x.shape[0] // cond_cat.shape[0] if cond_cat.shape[1] == 4: cond_cat = repeat(cond_cat, "b c h w -> b (t c) h w", t=T) cond_cat = rearrange(cond_cat, "b (t c) h w -> (b t) c h w", t=T) x = torch.cat((x, cond_cat), dim=1) out = self.diffusion_model( x, timesteps=t, context=c.get("crossattn", None), y=c.get("vector", None), audio_emb=c.get("audio_emb", None), skip_spatial_attention_at=c.get("skip_spatial_attention_at", None), skip_temporal_attention_at=c.get("skip_temporal_attention_at", None), **kwargs, ) return out class StabilityWrapper(IdentityWrapper): def __init__( self, diffusion_model, compile_model: bool = False, use_ipadapter: bool = False, ipadapter_model: str = "ip-adapter_sd15.bin", adapter_scale: float = 1.0, n_adapters: int = 1, skip_text_emb: bool = False, # pass_image_emb_to_hidden_states: bool = False, ): super().__init__(diffusion_model, compile_model) self.use_ipadapter = use_ipadapter # self.pass_image_emb_to_hidden_states = pass_image_emb_to_hidden_states if use_ipadapter: model_file = _get_model_file( "h94/IP-Adapter", weights_name=ipadapter_model, # ip-adapter_sd15.bin # cache_dir="/vol/paramonos2/projects/antoni/.cache", subfolder="models", ) state_dict = load_state_dict(model_file) state_dict = [load_state_dict(model_file)] * n_adapters print(f"Loading IP-Adapter weights from {model_file}") diffusion_model.set_ip_adapter_scale(adapter_scale) def forward( self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs ) -> torch.Tensor: added_cond_kwargs = None if self.use_ipadapter: added_cond_kwargs = {"image_embeds": c.get("image_embeds", None)} landmarks = c.get("landmarks", None) if landmarks is not None: added_cond_kwargs["image_embeds"] = [ added_cond_kwargs["image_embeds"], landmarks, ] cond_cat = c.get("concat", torch.Tensor([]).type_as(x)) if len(cond_cat.shape) and cond_cat.shape[0]: cond_cat = repeat( cond_cat, "b c h w -> b c t h w", t=x.shape[0] // cond_cat.shape[0] ) cond_cat = rearrange(cond_cat, "b c t h w -> (b t) c h w") x = torch.cat((x, cond_cat), dim=1) return self.diffusion_model( x, t, encoder_hidden_states=c.get("crossattn", None), added_cond_kwargs=added_cond_kwargs, audio_emb=c.get("audio_emb", None), **kwargs, )[0] def logit_normal_sampler(m, s=1, beta_m=15, sample_num=1000000): y_samples = torch.randn(sample_num) * s + m x_samples = beta_m * (torch.exp(y_samples) / (1 + torch.exp(y_samples))) return x_samples def mu_t(t, a=5, mu_max=1): t = t.to("cpu") return 2 * mu_max * t**a - mu_max def get_sigma_s(t, a, beta_m): mu = mu_t(t, a=a) sigma_s = logit_normal_sampler(m=mu, sample_num=t.shape[0], beta_m=beta_m) return sigma_s class InterpolationWrapper(IdentityWrapper): def __init__( self, diffusion_model, compile_model: bool = False, im_size=[512, 512], n_channels=4, starting_mask_method="zeros", add_mask=True, fix_image_leak=False, ): super().__init__(diffusion_model, compile_model) im_size = [ x // 8 for x in im_size ] # 8 is the default downscaling factor in the vae model if starting_mask_method == "zeros": self.learned_mask = nn.Parameter( torch.zeros(n_channels, im_size[0], im_size[1]) ) elif starting_mask_method == "ones": self.learned_mask = nn.Parameter( torch.ones(n_channels, im_size[0], im_size[1]) ) elif starting_mask_method == "random": self.learned_mask = nn.Parameter( torch.randn(n_channels, im_size[0], im_size[1]) ) elif starting_mask_method == "none": self.learned_mask = None elif starting_mask_method == "fixed_ones": self.learned_mask = torch.ones(n_channels, im_size[0], im_size[1]) elif starting_mask_method == "fixed_zeros": self.learned_mask = torch.zeros(n_channels, im_size[0], im_size[1]) else: raise NotImplementedError( f"Unknown stating_mask_method: {starting_mask_method}" ) self.add_mask = add_mask self.fix_image_leak = fix_image_leak if fix_image_leak: self.beta_m = 15 self.a = 5 self.noise_encoder = ConcatTimestepEmbedderND(256) def get_noised_input( self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor ) -> torch.Tensor: noised_input = input + noise * sigmas_bc return noised_input def forward( self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs ) -> torch.Tensor: cond_cat = c.get("concat", torch.Tensor([]).type_as(x)) T = x.shape[0] // cond_cat.shape[0] if self.fix_image_leak: noise_aug_strength = get_sigma_s( rearrange(t, "(b t) ... -> b t ...", b=T)[: cond_cat.shape[0], 0] / 700, self.a, self.beta_m, ) noise_aug = append_dims(noise_aug_strength, 4).to(x.device) noise = torch.randn_like(noise_aug) cond_cat = self.get_noised_input(noise_aug, noise, cond_cat) noise_emb = self.noise_encoder(noise_aug_strength).to(x.device) # cond["vector"] = noise_emb if "vector" not in cond else torch.cat([cond["vector"], noise_emb], dim=1) c["vector"] = noise_emb cond_cat = rearrange(cond_cat, "b (t c) h w -> b c t h w", t=2) start, end = cond_cat.chunk(2, dim=2) if self.learned_mask is None: learned_mask = torch.stack( [start.squeeze(2)] * (T // 2 - 1) + [end.squeeze(2)] * (T // 2 - 1), dim=2, ) else: learned_mask = repeat( self.learned_mask.to(x.device), "c h w -> b c h w", b=cond_cat.shape[0] ) ones_mask = torch.ones_like(learned_mask)[:, 0].unsqueeze(1) zeros_mask = torch.zeros_like(learned_mask)[:, 0].unsqueeze(1) if self.learned_mask is None: cond_seq = torch.cat([start] + [learned_mask] + [end], dim=2) else: cond_seq = torch.stack( [start.squeeze(2)] + [learned_mask] * (T - 2) + [end.squeeze(2)], dim=2 ) cond_seq = rearrange(cond_seq, "b c t h w -> (b t) c h w") x = torch.cat((x, cond_seq), dim=1) if self.add_mask: mask_seq = torch.stack( [ones_mask] + [zeros_mask] * (T - 2) + [ones_mask], dim=2 ) mask_seq = rearrange(mask_seq, "b c t h w -> (b t) c h w") x = torch.cat((x, mask_seq), dim=1) return self.diffusion_model( x, timesteps=t, context=c.get("crossattn", None), y=c.get("vector", None), audio_emb=c.get("audio_emb", None), **kwargs, )