import torch
import torch.nn as nn
from packaging import version
from einops import repeat, rearrange
from diffusers.utils import _get_model_file
from diffusers.models.modeling_utils import load_state_dict

from ...modules.diffusionmodules.augment_pipeline import AugmentPipe
from ...modules.encoders.modules import ConcatTimestepEmbedderND
from ...util import append_dims

OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"

class IdentityWrapper(nn.Module):
    def __init__(self, diffusion_model, compile_model: bool = False):
        super().__init__()
        # torch.compile only exists from PyTorch 2.0 onwards; on older
        # versions (or when compile_model is False) fall back to a no-op.
        compile = (
            torch.compile
            if (version.parse(torch.__version__) >= version.parse("2.0.0"))
            and compile_model
            else lambda x: x
        )
        self.diffusion_model = compile(diffusion_model)

    def forward(self, *args, **kwargs):
        return self.diffusion_model(*args, **kwargs)
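
# Hypothetical usage sketch (the model and inputs are placeholders, not part
# of this module):
#   wrapper = IdentityWrapper(my_unet, compile_model=True)
#   out = wrapper(x, t, context=ctx)  # arguments pass through unchanged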

class OpenAIWrapper(IdentityWrapper):
    def __init__(
        self,
        diffusion_model,
        compile_model: bool = False,
        ada_aug_percent=0.0,
        fix_image_leak=False,
        add_embeddings=False,
        im_size=[64, 64],
        n_channels=4,
    ):
        super().__init__(diffusion_model, compile_model)
        self.fix_image_leak = fix_image_leak
        if fix_image_leak:
            # Hyperparameters of the logit-normal noise-augmentation schedule
            # used by the image-leak fix (see get_sigma_s below).
            self.beta_m = 15
            self.a = 5
            self.noise_encoder = ConcatTimestepEmbedderND(256)

        self.augment_pipe = None
        if ada_aug_percent > 0.0:
            augment_kwargs = dict(
                xflip=1e8, yflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1
            )
            self.augment_pipe = AugmentPipe(ada_aug_percent, **augment_kwargs)

        self.add_embeddings = add_embeddings
        if add_embeddings:
            self.learned_mask = nn.Parameter(
                torch.zeros(n_channels, im_size[0], im_size[1])
            )

    def get_noised_input(
        self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor
    ) -> torch.Tensor:
        noised_input = input + noise * sigmas_bc
        return noised_input

    def forward(
        self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
    ) -> torch.Tensor:
        cond_cat = c.get("concat", torch.Tensor([]).type_as(x))
        if len(cond_cat.shape) and cond_cat.shape[0]:
            T = x.shape[0] // cond_cat.shape[0]
            if self.fix_image_leak:
                noise_aug_strength = get_sigma_s(
                    rearrange(t, "(b t) ... -> b t ...", b=T)[: cond_cat.shape[0], 0]
                    / 700,
                    self.a,
                    self.beta_m,
                )
                noise_aug = append_dims(noise_aug_strength, 4).to(x.device)
                # Note: the noise is sampled per clip (shape (b, 1, 1, 1)) and
                # broadcast over the conditioning frame when added below.
                noise = torch.randn_like(noise_aug)
                cond_cat = self.get_noised_input(noise_aug, noise, cond_cat)
                noise_emb = self.noise_encoder(noise_aug_strength).to(x.device)
                c["vector"] = (
                    noise_emb
                    if "vector" not in c
                    else torch.cat([c["vector"], noise_emb], dim=1)
                )

        if (
            len(cond_cat.shape)
            and cond_cat.shape[0]
            and x.shape[0] != cond_cat.shape[0]
        ):
            # Tile the per-clip conditioning over the T frames of each clip.
            cond_cat = repeat(cond_cat, "b c h w -> b c t h w", t=T)
            cond_cat = rearrange(cond_cat, "b c t h w -> (b t) c h w")
        x = torch.cat((x, cond_cat), dim=1)

        if self.add_embeddings:
            learned_mask = repeat(
                self.learned_mask.to(x.device), "c h w -> b c h w", b=cond_cat.shape[0]
            )
            x = torch.cat((x, learned_mask), dim=1)

        if self.augment_pipe is not None:
            x, labels = self.augment_pipe(x)
        else:
            labels = torch.zeros(x.shape[0], 9, device=x.device)

        return self.diffusion_model(
            x,
            timesteps=t,
            context=c.get("crossattn", None),
            reference_context=c.get("reference", None),
            y=c.get("vector", None),
            audio_emb=c.get("audio_emb", None),
            landmarks=c.get("landmarks", None),
            aug_labels=labels,
            **kwargs,
        )
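
# Sketch of the conditioning dict consumed by OpenAIWrapper.forward. All keys
# are optional and the names follow the c.get(...) calls above; the shapes are
# assumptions for a (b t)-flattened video batch:
#   c = {
#       "concat": cond_latents,     # (b, c, h, w), tiled across the T frames
#       "crossattn": context_emb,   # cross-attention context
#       "vector": vector_cond,      # pooled conditioning, forwarded as y=...
#       "reference": reference_ctx,
#       "audio_emb": audio_features,
#       "landmarks": landmark_tensor,
#   }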

class DubbingWrapper(IdentityWrapper):
    def __init__(self, diffusion_model, compile_model: bool = False, mask_input=False):
        super().__init__(diffusion_model, compile_model)
        self.mask_input = mask_input

    def forward(
        self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
    ) -> torch.Tensor:
        cond_cat = c.get("concat", torch.Tensor([]).type_as(x))
        if len(cond_cat.shape):
            T = x.shape[0] // cond_cat.shape[0]
            if cond_cat.shape[1] == 4:
                # A single 4-channel latent: repeat it across the T frames
                # before unpacking to the (b t) layout below.
                cond_cat = repeat(cond_cat, "b c h w -> b (t c) h w", t=T)
            cond_cat = rearrange(cond_cat, "b (t c) h w -> (b t) c h w", t=T)
            x = torch.cat((x, cond_cat), dim=1)
        out = self.diffusion_model(
            x,
            timesteps=t,
            context=c.get("crossattn", None),
            y=c.get("vector", None),
            audio_emb=c.get("audio_emb", None),
            skip_spatial_attention_at=c.get("skip_spatial_attention_at", None),
            skip_temporal_attention_at=c.get("skip_temporal_attention_at", None),
            **kwargs,
        )
        return out

class StabilityWrapper(IdentityWrapper):
    def __init__(
        self,
        diffusion_model,
        compile_model: bool = False,
        use_ipadapter: bool = False,
        ipadapter_model: str = "ip-adapter_sd15.bin",
        adapter_scale: float = 1.0,
        n_adapters: int = 1,
        skip_text_emb: bool = False,
    ):
        super().__init__(diffusion_model, compile_model)
        self.use_ipadapter = use_ipadapter
        if use_ipadapter:
            model_file = _get_model_file(
                "h94/IP-Adapter",
                weights_name=ipadapter_model,
                subfolder="models",
            )
            # Load the checkpoint once and share the reference across all
            # n_adapters entries (the weights are identical per adapter).
            state_dict = [load_state_dict(model_file)] * n_adapters
            print(f"Loading IP-Adapter weights from {model_file}")
            diffusion_model.set_ip_adapter_scale(adapter_scale)

    def forward(
        self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
    ) -> torch.Tensor:
        added_cond_kwargs = None
        if self.use_ipadapter:
            added_cond_kwargs = {"image_embeds": c.get("image_embeds", None)}
            landmarks = c.get("landmarks", None)
            if landmarks is not None:
                added_cond_kwargs["image_embeds"] = [
                    added_cond_kwargs["image_embeds"],
                    landmarks,
                ]

        cond_cat = c.get("concat", torch.Tensor([]).type_as(x))
        if len(cond_cat.shape) and cond_cat.shape[0]:
            cond_cat = repeat(
                cond_cat, "b c h w -> b c t h w", t=x.shape[0] // cond_cat.shape[0]
            )
            cond_cat = rearrange(cond_cat, "b c t h w -> (b t) c h w")
            x = torch.cat((x, cond_cat), dim=1)

        return self.diffusion_model(
            x,
            t,
            encoder_hidden_states=c.get("crossattn", None),
            added_cond_kwargs=added_cond_kwargs,
            audio_emb=c.get("audio_emb", None),
            **kwargs,
        )[0]
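
# Note: added_cond_kwargs follows the diffusers UNet convention, where
# IP-Adapter image embeddings are passed via added_cond_kwargs["image_embeds"];
# when landmarks are present a list is passed, one entry per loaded adapter.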

def logit_normal_sampler(m, s=1, beta_m=15, sample_num=1000000):
    # Sample from a logit-normal distribution scaled to (0, beta_m):
    # a sigmoid of a Gaussian with mean m and std s. torch.sigmoid is the
    # numerically stable form of exp(y) / (1 + exp(y)).
    y_samples = torch.randn(sample_num) * s + m
    x_samples = beta_m * torch.sigmoid(y_samples)
    return x_samples


def mu_t(t, a=5, mu_max=1):
    t = t.to("cpu")
    return 2 * mu_max * t**a - mu_max


def get_sigma_s(t, a, beta_m):
    mu = mu_t(t, a=a)
    sigma_s = logit_normal_sampler(m=mu, sample_num=t.shape[0], beta_m=beta_m)
    return sigma_s
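
# Shape of the schedule (a worked check; values are illustrative):
#   mu(t)   = 2 * mu_max * t**a - mu_max       -> maps t in [0, 1] to [-mu_max, mu_max]
#   sigma_s ~ beta_m * sigmoid(N(mu(t), s))    -> one draw per element of t, in (0, beta_m)
# e.g. get_sigma_s(torch.tensor([0.0, 0.5, 1.0]), a=5, beta_m=15) returns three
# values in (0, 15), biased low for small t and high for t near 1.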

class InterpolationWrapper(IdentityWrapper):
    def __init__(
        self,
        diffusion_model,
        compile_model: bool = False,
        im_size=[512, 512],
        n_channels=4,
        starting_mask_method="zeros",
        add_mask=True,
        fix_image_leak=False,
    ):
        super().__init__(diffusion_model, compile_model)
        # 8 is the default downscaling factor of the VAE, so the masks below
        # are built at latent resolution.
        im_size = [x // 8 for x in im_size]
        if starting_mask_method == "zeros":
            self.learned_mask = nn.Parameter(
                torch.zeros(n_channels, im_size[0], im_size[1])
            )
        elif starting_mask_method == "ones":
            self.learned_mask = nn.Parameter(
                torch.ones(n_channels, im_size[0], im_size[1])
            )
        elif starting_mask_method == "random":
            self.learned_mask = nn.Parameter(
                torch.randn(n_channels, im_size[0], im_size[1])
            )
        elif starting_mask_method == "none":
            self.learned_mask = None
        elif starting_mask_method == "fixed_ones":
            self.learned_mask = torch.ones(n_channels, im_size[0], im_size[1])
        elif starting_mask_method == "fixed_zeros":
            self.learned_mask = torch.zeros(n_channels, im_size[0], im_size[1])
        else:
            raise NotImplementedError(
                f"Unknown starting_mask_method: {starting_mask_method}"
            )
        self.add_mask = add_mask
        self.fix_image_leak = fix_image_leak
        if fix_image_leak:
            self.beta_m = 15
            self.a = 5
            self.noise_encoder = ConcatTimestepEmbedderND(256)

    def get_noised_input(
        self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor
    ) -> torch.Tensor:
        noised_input = input + noise * sigmas_bc
        return noised_input

    def forward(
        self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
    ) -> torch.Tensor:
        cond_cat = c.get("concat", torch.Tensor([]).type_as(x))
        T = x.shape[0] // cond_cat.shape[0]
        if self.fix_image_leak:
            noise_aug_strength = get_sigma_s(
                rearrange(t, "(b t) ... -> b t ...", b=T)[: cond_cat.shape[0], 0]
                / 700,
                self.a,
                self.beta_m,
            )
            noise_aug = append_dims(noise_aug_strength, 4).to(x.device)
            noise = torch.randn_like(noise_aug)
            cond_cat = self.get_noised_input(noise_aug, noise, cond_cat)
            noise_emb = self.noise_encoder(noise_aug_strength).to(x.device)
            # Unlike OpenAIWrapper, any existing "vector" conditioning is
            # overwritten rather than concatenated.
            c["vector"] = noise_emb

        # "concat" packs the start and end frames along the channel axis.
        cond_cat = rearrange(cond_cat, "b (t c) h w -> b c t h w", t=2)
        start, end = cond_cat.chunk(2, dim=2)

        if self.learned_mask is None:
            learned_mask = torch.stack(
                [start.squeeze(2)] * (T // 2 - 1) + [end.squeeze(2)] * (T // 2 - 1),
                dim=2,
            )
        else:
            learned_mask = repeat(
                self.learned_mask.to(x.device), "c h w -> b c h w", b=cond_cat.shape[0]
            )
        ones_mask = torch.ones_like(learned_mask)[:, 0].unsqueeze(1)
        zeros_mask = torch.zeros_like(learned_mask)[:, 0].unsqueeze(1)

        # Build the full frame sequence: the real start/end frames at the
        # ends, the (learned) mask filling the T - 2 frames in between.
        if self.learned_mask is None:
            cond_seq = torch.cat([start] + [learned_mask] + [end], dim=2)
        else:
            cond_seq = torch.stack(
                [start.squeeze(2)] + [learned_mask] * (T - 2) + [end.squeeze(2)], dim=2
            )
        cond_seq = rearrange(cond_seq, "b c t h w -> (b t) c h w")
        x = torch.cat((x, cond_seq), dim=1)

        if self.add_mask:
            # A binary channel marking which frames are real (1) vs inpainted (0).
            mask_seq = torch.stack(
                [ones_mask] + [zeros_mask] * (T - 2) + [ones_mask], dim=2
            )
            mask_seq = rearrange(mask_seq, "b c t h w -> (b t) c h w")
            x = torch.cat((x, mask_seq), dim=1)

        return self.diffusion_model(
            x,
            timesteps=t,
            context=c.get("crossattn", None),
            y=c.get("vector", None),
            audio_emb=c.get("audio_emb", None),
            **kwargs,
        )
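
# Minimal smoke-test sketch (hypothetical; kept as a comment because this
# module uses relative imports and is not meant to run as a script):
#
#   class _DummyModel(nn.Module):
#       def forward(self, x, timesteps=None, context=None, y=None,
#                   audio_emb=None, **kwargs):
#           return x
#
#   wrapper = IdentityWrapper(_DummyModel())
#   out = wrapper(torch.randn(2, 4, 8, 8))
#   assert out.shape == (2, 4, 8, 8)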