# sgm/models/diffusion.py
import math
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union
import re
import pytorch_lightning as pl
import torch
from omegaconf import ListConfig, OmegaConf
from safetensors.torch import load_file as load_safetensors
from torch.optim.lr_scheduler import LambdaLR
from einops import rearrange
from diffusers.models.attention_processor import IPAdapterAttnProcessor2_0
from ..modules import UNCONDITIONAL_CONFIG
from ..modules.autoencoding.temporal_ae import VideoDecoder
from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from ..modules.ema import LitEma
from ..util import (
default,
disabled_train,
get_obj_from_str,
instantiate_from_config,
log_txt_as_img,
)
class DiffusionEngine(pl.LightningModule):
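    """Denoising-diffusion training and inference engine built on PyTorch Lightning.

    Wraps a denoising network, a frozen first-stage autoencoder, a conditioner,
    and optional sampler, loss, and EMA components behind a single LightningModule.
    """
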
def __init__(
self,
network_config,
denoiser_config,
first_stage_config,
conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
network_wrapper: Union[None, str, Dict, ListConfig, OmegaConf] = None,
ckpt_path: Union[None, str] = None,
remove_keys_from_weights: Union[None, List, Tuple] = None,
pattern_to_remove: Union[None, str] = None,
remove_keys_from_unet_weights: Union[None, List, Tuple] = None,
use_ema: bool = False,
ema_decay_rate: float = 0.9999,
scale_factor: float = 1.0,
        disable_first_stage_autocast: bool = False,
input_key: str = "jpg",
log_keys: Union[List, None] = None,
no_log_keys: Union[List, None] = None,
no_cond_log: bool = False,
compile_model: bool = False,
en_and_decode_n_samples_a_time: Optional[int] = None,
        only_train_ipadapter: bool = False,
        to_unfreeze: Optional[List[str]] = None,
        to_freeze: Optional[List[str]] = None,
        separate_unet_ckpt: Optional[str] = None,
        use_thunder: bool = False,
        is_dubbing: bool = False,
bad_model_path: Optional[str] = None,
bad_model_config: Optional[Dict] = None,
):
super().__init__()
# self.automatic_optimization = False
self.log_keys = log_keys
self.no_log_keys = no_log_keys
self.input_key = input_key
self.is_dubbing = is_dubbing
self.optimizer_config = default(
optimizer_config, {"target": "torch.optim.AdamW"}
)
self.model = self.initialize_network(
network_config, network_wrapper, compile_model=compile_model
)
self.denoiser = instantiate_from_config(denoiser_config)
self.sampler = (
instantiate_from_config(sampler_config)
if sampler_config is not None
else None
)
self.is_guided = True
if (
self.sampler
and "IdentityGuider" in sampler_config["params"]["guider_config"]["target"]
):
self.is_guided = False
        if self.sampler is not None:
            # Build a second sampler with the guider disabled so that unguided
            # samples can be logged alongside guided ones.
            config_guider = sampler_config["params"]["guider_config"]
            sampler_config["params"]["guider_config"] = None
            self.sampler_no_guidance = instantiate_from_config(sampler_config)
            sampler_config["params"]["guider_config"] = config_guider
self.conditioner = instantiate_from_config(
default(conditioner_config, UNCONDITIONAL_CONFIG)
)
self.scheduler_config = scheduler_config
self._init_first_stage(first_stage_config)
self.loss_fn = (
instantiate_from_config(loss_fn_config)
if loss_fn_config is not None
else None
)
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self.model, decay=ema_decay_rate)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
self.scale_factor = scale_factor
self.disable_first_stage_autocast = disable_first_stage_autocast
self.no_cond_log = no_cond_log
if ckpt_path is not None:
self.init_from_ckpt(
ckpt_path,
remove_keys_from_weights=remove_keys_from_weights,
pattern_to_remove=pattern_to_remove,
)
if separate_unet_ckpt is not None:
            sd = torch.load(
                separate_unet_ckpt, map_location="cpu", weights_only=False
            )["state_dict"]
if remove_keys_from_unet_weights is not None:
for k in list(sd.keys()):
for remove_key in remove_keys_from_unet_weights:
if remove_key in k:
del sd[k]
self.model.diffusion_model.load_state_dict(sd, strict=False)
        self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
        if self.en_and_decode_n_samples_a_time is not None:
            print(
                f"Encoding/decoding {self.en_and_decode_n_samples_a_time} "
                "samples at a time"
            )
        if to_freeze:
            # Freeze parameters by name pattern; a leading "!" inverts the match
            # (i.e. freeze every parameter whose name does NOT contain it).
            for name, p in self.model.diffusion_model.named_parameters():
                for layer in to_freeze:
                    if layer.startswith("!"):
                        if layer[1:] not in name:
                            p.requires_grad = False
                    elif layer in name:
                        p.requires_grad = False
        if only_train_ipadapter:
            # Freeze the whole model...
            for p in self.model.parameters():
                p.requires_grad = False
            # ...then unfreeze the adapter projection layer
            for p in self.model.diffusion_model.encoder_hid_proj.parameters():
                p.requires_grad = True
            # ...and the IP-Adapter cross-attention processors
            for att_layer in self.model.diffusion_model.attn_processors.values():
                if isinstance(att_layer, IPAdapterAttnProcessor2_0):
                    for p in att_layer.parameters():
                        p.requires_grad = True
if to_unfreeze:
for name in to_unfreeze:
for p in getattr(self.model.diffusion_model, name).parameters():
p.requires_grad = True
if use_thunder:
import thunder
self.model.diffusion_model = thunder.jit(self.model.diffusion_model)
if "Karras" in denoiser_config.target:
assert bad_model_path is not None, (
"bad_model_path must be provided for KarrasGuidanceDenoiser"
)
karras_config = default(bad_model_config, network_config)
bad_model = self.initialize_network(
karras_config, network_wrapper, compile_model=compile_model
)
state_dict = self.load_bad_model_weights(bad_model_path)
bad_model.load_state_dict(state_dict)
self.denoiser.set_bad_network(bad_model)
    def load_bad_model_weights(self, path: str) -> Dict[str, torch.Tensor]:
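        """Load and re-key the weights of the "bad" model used by the Karras-style
        guidance denoiser, keeping only learned-mask and diffusion-model entries."""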
print(f"Restoring bad model from {path}")
state_dict = torch.load(path, map_location="cpu", weights_only=False)
new_dict = {}
for k, v in state_dict["module"].items():
if "learned_mask" in k:
new_dict[k.replace("_forward_module.", "").replace("model.", "")] = v
if "diffusion_model" in k:
new_dict["diffusion_model" + k.split("diffusion_model")[1]] = v
return new_dict
def initialize_network(self, network_config, network_wrapper, compile_model=False):
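        """Instantiate the denoising network and wrap it, using the OpenAI UNet
        wrapper by default or a custom wrapper given as a string or config."""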
model = instantiate_from_config(network_config)
if isinstance(network_wrapper, str) or network_wrapper is None:
model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
model, compile_model=compile_model
)
else:
target = network_wrapper["target"]
params = network_wrapper.get("params", dict())
model = get_obj_from_str(target)(
model, compile_model=compile_model, **params
)
return model
def init_from_ckpt(
self,
path: str,
remove_keys_from_weights: Optional[Union[List, Tuple]] = None,
        pattern_to_remove: Optional[str] = None,
) -> None:
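        """Restore weights from a .ckpt/.pt/.bin/.safetensors checkpoint, optionally
        dropping keys that match a regex pattern or an explicit list."""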
print(f"Restoring from {path}")
if path.endswith("ckpt"):
sd = torch.load(path, map_location="cpu", weights_only=False)["state_dict"]
elif path.endswith("pt"):
sd = torch.load(path, map_location="cpu", weights_only=False)["module"]
# Remove leading _forward_module from keys
sd = {k.replace("_forward_module.", ""): v for k, v in sd.items()}
elif path.endswith("bin"):
sd = torch.load(path, map_location="cpu", weights_only=False)
# Remove leading _forward_module from keys
sd = {k.replace("_forward_module.", ""): v for k, v in sd.items()}
elif path.endswith("safetensors"):
sd = load_safetensors(path)
        else:
            raise NotImplementedError(f"Unsupported checkpoint format: {path}")
print(f"Loaded state dict from {path} with {len(sd)} keys")
if pattern_to_remove is not None or remove_keys_from_weights is not None:
sd = self.remove_mismatched_keys(
sd, pattern_to_remove, remove_keys_from_weights
)
missing, unexpected = self.load_state_dict(sd, strict=False)
print(
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
)
if len(missing) > 0:
print(f"Missing Keys: {missing}")
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
    def remove_mismatched_keys(self, state_dict, pattern=None, additional_keys=None):
        """Remove keys from the state dictionary based on a regex pattern and a
        list of additional specific keys."""
        # Find keys that match the pattern
        if pattern is not None:
            mismatched_keys = [key for key in state_dict if re.search(pattern, key)]
            print(f"Removing {len(mismatched_keys)} keys based on pattern {pattern}")
            print(mismatched_keys)
        else:
            mismatched_keys = []
# Add specific keys to be removed
if additional_keys:
mismatched_keys.extend(
[key for key in additional_keys if key in state_dict]
)
# Remove all identified keys
for key in mismatched_keys:
if key in state_dict:
del state_dict[key]
return state_dict
def _init_first_stage(self, config):
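        """Instantiate the first-stage autoencoder and freeze it in eval mode;
        its encoder is dropped when training directly on precomputed latents."""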
model = instantiate_from_config(config).eval()
model.train = disabled_train
for param in model.parameters():
param.requires_grad = False
self.first_stage_model = model
if self.input_key == "latents":
# Remove encoder to save memory
self.first_stage_model.encoder = None
torch.cuda.empty_cache()
def get_input(self, batch):
# assuming unified data format, dataloader returns a dict.
# image tensors should be scaled to -1 ... 1 and in bchw format
return batch[self.input_key]
@torch.no_grad()
def decode_first_stage(self, z):
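        """Decode latents back to pixel space, `en_and_decode_n_samples_a_time`
        samples per chunk; 5D video latents (b c t h w) are flattened over time
        and reshaped back afterwards."""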
is_video = False
if len(z.shape) == 5:
is_video = True
T = z.shape[2]
z = rearrange(z, "b c t h w -> (b t) c h w")
z = 1.0 / self.scale_factor * z
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
n_rounds = math.ceil(z.shape[0] / n_samples)
all_out = []
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
for n in range(n_rounds):
if isinstance(self.first_stage_model.decoder, VideoDecoder):
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
else:
kwargs = {}
out = self.first_stage_model.decode(
z[n * n_samples : (n + 1) * n_samples], **kwargs
)
all_out.append(out)
out = torch.cat(all_out, dim=0)
if is_video:
out = rearrange(out, "(b t) c h w -> b c t h w", t=T)
torch.cuda.empty_cache()
return out
@torch.no_grad()
def encode_first_stage(self, x):
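        """Encode images into scaled latents in chunks; mirrors
        `decode_first_stage` for 5D video tensors."""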
is_video = False
if len(x.shape) == 5:
is_video = True
T = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
n_rounds = math.ceil(x.shape[0] / n_samples)
all_out = []
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
for n in range(n_rounds):
out = self.first_stage_model.encode(
x[n * n_samples : (n + 1) * n_samples]
)
all_out.append(out)
z = torch.cat(all_out, dim=0)
z = self.scale_factor * z
if is_video:
z = rearrange(z, "(b t) c h w -> b c t h w", t=T)
return z
def forward(self, x, batch):
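        """Compute the per-key loss dict via `loss_fn` and return (loss, dict)."""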
loss_dict = self.loss_fn(
self.model,
self.denoiser,
self.conditioner,
x,
batch,
self.first_stage_model,
)
for k in loss_dict:
loss_dict[k] = loss_dict[k].mean()
return loss_dict["loss"], loss_dict
def shared_step(self, batch: Dict) -> Any:
x = self.get_input(batch)
if self.input_key != "latents":
x = self.encode_first_stage(x)
batch["global_step"] = self.global_step
loss, loss_dict = self(x, batch)
return loss, loss_dict
def training_step(self, batch, batch_idx):
loss, loss_dict = self.shared_step(batch)
self.log_dict(
loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
)
self.log(
"global_step",
self.global_step,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=False,
)
if self.scheduler_config is not None:
lr = self.optimizers().param_groups[0]["lr"]
self.log(
"lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
)
return loss
def on_train_start(self, *args, **kwargs):
if self.sampler is None or self.loss_fn is None:
raise ValueError("Sampler and loss function need to be set for training.")
def on_train_batch_end(self, *args, **kwargs):
# print(f"RANK - {self.trainer.global_rank}: on_train_batch_end")
if self.use_ema:
self.model_ema(self.model)
@contextmanager
def ema_scope(self, context=None):
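        """Temporarily swap the model weights for their EMA counterparts, restoring
        the training weights on exit."""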
if self.use_ema:
self.model_ema.store(self.model.parameters())
self.model_ema.copy_to(self.model)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.model.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def instantiate_optimizer_from_config(self, params, lr, cfg):
return get_obj_from_str(cfg["target"])(
params, lr=lr, **cfg.get("params", dict())
)
def configure_optimizers(self):
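        """Build the optimizer over UNet and trainable-embedder parameters, with an
        optional per-step LambdaLR schedule. Note that `self.learning_rate` is not
        set in `__init__`; it is expected to be assigned externally (e.g. by the
        training script) before this hook runs."""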
lr = self.learning_rate
params = list(self.model.parameters())
for embedder in self.conditioner.embedders:
if embedder.is_trainable:
params = params + list(embedder.parameters())
opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
if self.scheduler_config is not None:
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
scheduler = [
{
"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
"interval": "step",
"frequency": 1,
}
]
return [opt], scheduler
return opt
@torch.no_grad()
def sample(
self,
cond: Dict,
uc: Union[Dict, None] = None,
batch_size: int = 16,
shape: Union[None, Tuple, List] = None,
**kwargs,
):
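        """Sample `batch_size` latents of `shape` with the guided sampler."""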
randn = torch.randn(batch_size, *shape).to(self.device)
        denoiser = lambda x_in, sigma, c: self.denoiser(
            self.model, x_in, sigma, c, **kwargs
        )
samples = self.sampler(denoiser, randn, cond, uc=uc)
return samples
@torch.no_grad()
def sample_no_guider(
self,
cond: Dict,
uc: Union[Dict, None] = None,
batch_size: int = 16,
shape: Union[None, Tuple, List] = None,
**kwargs,
):
randn = torch.randn(batch_size, *shape).to(self.device)
        denoiser = lambda x_in, sigma, c: self.denoiser(
            self.model, x_in, sigma, c, **kwargs
        )
samples = self.sampler_no_guidance(denoiser, randn, cond, uc=uc)
return samples
@torch.no_grad()
def log_conditionings(self, batch: Dict, n: int) -> Dict:
"""
Defines heuristics to log different conditionings.
These can be lists of strings (text-to-image), tensors, ints, ...
"""
image_h, image_w = batch[self.input_key].shape[-2:]
log = dict()
for embedder in self.conditioner.embedders:
if (
(self.log_keys is None) or (embedder.input_key in self.log_keys)
) and not self.no_cond_log:
                if self.no_log_keys and embedder.input_key in self.no_log_keys:
continue
x = batch[embedder.input_key][:n]
if isinstance(x, torch.Tensor):
if x.dim() == 1:
# class-conditional, convert integer to string
x = [str(x[i].item()) for i in range(x.shape[0])]
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
elif x.dim() == 2:
# size and crop cond and the like
x = [
"x".join([str(xx) for xx in x[i].tolist()])
for i in range(x.shape[0])
]
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
elif x.dim() == 4: # already an image
xc = x
elif x.dim() == 5:
xc = torch.cat([x[:, :, i] for i in range(x.shape[2])], dim=-1)
else:
print(x.shape, embedder.input_key)
raise NotImplementedError()
            elif isinstance(x, (list, ListConfig)):
if isinstance(x[0], str):
# strings
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
else:
raise NotImplementedError()
else:
raise NotImplementedError()
log[embedder.input_key] = xc
return log
@torch.no_grad()
def log_images(
self,
batch: Dict,
N: int = 8,
sample: bool = True,
        ucg_keys: Optional[List[str]] = None,
**kwargs,
) -> Dict:
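        """Log inputs, reconstructions, conditionings, and samples (guided and
        unguided) for a batch of images."""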
conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
if ucg_keys:
            assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
                "Each defined ucg key for sampling must be in the provided conditioner input keys, "
                f"but we have {ucg_keys} vs. {conditioner_input_keys}"
            )
else:
ucg_keys = conditioner_input_keys
log = dict()
x = self.get_input(batch)
c, uc = self.conditioner.get_unconditional_conditioning(
batch,
force_uc_zero_embeddings=ucg_keys
if len(self.conditioner.embedders) > 0
else [],
)
sampling_kwargs = {}
N = min(x.shape[0], N)
x = x.to(self.device)[:N]
if self.input_key != "latents":
log["inputs"] = x
z = self.encode_first_stage(x)
else:
z = x
log["reconstructions"] = self.decode_first_stage(z)
log.update(self.log_conditionings(batch, N))
for k in c:
if isinstance(c[k], torch.Tensor):
c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
if sample:
with self.ema_scope("Plotting"):
samples = self.sample(
c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
)
samples = self.decode_first_stage(samples)
log["samples"] = samples
with self.ema_scope("Plotting"):
samples = self.sample_no_guider(
c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
)
samples = self.decode_first_stage(samples)
log["samples_no_guidance"] = samples
return log
@torch.no_grad()
def log_videos(
self,
batch: Dict,
N: int = 8,
sample: bool = True,
        ucg_keys: Optional[List[str]] = None,
**kwargs,
) -> Dict:
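        """Video variant of `log_images`: builds per-batch unconditional
        conditioning, logs masked reconstructions when masks are present, and
        samples with and without guidance."""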
log = dict()
batch_uc = {}
x = self.get_input(batch)
num_frames = x.shape[2] # assuming bcthw format
for key in batch.keys():
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
batch_uc[key] = torch.clone(batch[key])
c, uc = self.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=ucg_keys
if ucg_keys is not None
else [
"cond_frames",
"cond_frames_without_noise",
],
)
sampling_kwargs = {}
N = min(x.shape[0], N)
x = x.to(self.device)[:N]
if self.input_key != "latents":
log["inputs"] = x
z = self.encode_first_stage(x)
else:
z = x
log["reconstructions"] = self.decode_first_stage(z)
log.update(self.log_conditionings(batch, N))
if c.get("masks", None) is not None:
# Create a mask reconstruction
masks = 1 - c["masks"]
t = masks.shape[2]
masks = rearrange(masks, "b c t h w -> (b t) c h w")
target_size = (
log["reconstructions"].shape[-2],
log["reconstructions"].shape[-1],
)
masks = torch.nn.functional.interpolate(
masks, size=target_size, mode="nearest"
)
masks = rearrange(masks, "(b t) c h w -> b c t h w", t=t)
log["mask_reconstructions"] = log["reconstructions"] * masks
for k in c:
if isinstance(c[k], torch.Tensor):
c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
elif isinstance(c[k], list):
for i in range(len(c[k])):
c[k][i], uc[k][i] = map(
lambda y: y[k][i][:N].to(self.device), (c, uc)
)
if sample:
n = 2 if self.is_guided else 1
sampling_kwargs["image_only_indicator"] = torch.zeros(n, num_frames).to(
self.device
)
sampling_kwargs["num_video_frames"] = batch["num_video_frames"]
with self.ema_scope("Plotting"):
samples = self.sample(
c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
)
samples = self.decode_first_stage(samples)
if self.is_dubbing:
samples[:, :, :, : samples.shape[-2] // 2] = log["reconstructions"][
:, :, :, : samples.shape[-2] // 2
]
log["samples"] = samples
# Without guidance
sampling_kwargs["image_only_indicator"] = torch.zeros(1, num_frames).to(
self.device
)
sampling_kwargs["num_video_frames"] = batch["num_video_frames"]
with self.ema_scope("Plotting"):
samples = self.sample_no_guider(
c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
)
samples = self.decode_first_stage(samples)
if self.is_dubbing:
samples[:, :, :, : samples.shape[-2] // 2] = log["reconstructions"][
:, :, :, : samples.shape[-2] // 2
]
log["samples_no_guidance"] = samples
torch.cuda.empty_cache()
return log
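

# Minimal usage sketch (hypothetical config objects and data module, for
# illustration only; actual configs live in the repo's config files):
#
#   engine = DiffusionEngine(
#       network_config=unet_cfg,
#       denoiser_config=denoiser_cfg,
#       first_stage_config=autoencoder_cfg,
#       sampler_config=sampler_cfg,
#       loss_fn_config=loss_cfg,
#   )
#   engine.learning_rate = 1e-4  # configure_optimizers reads this attribute
#   pl.Trainer(max_steps=100_000).fit(engine, datamodule)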