import math
import torch
import numpy as np
from torch import nn
# ResBlock.forward calls checkpoint(fn, *args); torch's gradient checkpointing
# matches that call signature, so it is imported here (it was missing)
from torch.utils.checkpoint import checkpoint
from typing import Callable, Iterable, Union, Optional
from einops import rearrange, repeat
from comfy import model_management
from .kl import (
    Encoder, Decoder, Upsample, Normalize,
    AttnBlock, ResnetBlock,  # MemoryEfficientAttnBlock,
    DiagonalGaussianDistribution, nonlinearity, make_attn
)

class AutoencoderKL(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_dim = config["embed_dim"]
        self.encoder = Encoder(**config)
        self.decoder = VideoDecoder(**config)
        assert config["double_z"]
        # these aren't used here for some reason
        # self.quant_conv = torch.nn.Conv2d(2*config["z_channels"], 2*self.embed_dim, 1)
        # self.post_quant_conv = torch.nn.Conv2d(self.embed_dim, config["z_channels"], 1)

    def encode(self, x):
        ## batched
        # n_samples = x.shape[0]
        # n_rounds = math.ceil(x.shape[0] / n_samples)
        # all_out = []
        # for n in range(n_rounds):
        #     h = self.encoder(
        #         x[n * n_samples : (n + 1) * n_samples]
        #     )
        #     moments = h  # self.quant_conv(h)
        #     posterior = DiagonalGaussianDistribution(moments)
        #     all_out.append(posterior.sample())
        # z = torch.cat(all_out, dim=0)
        # return z

        ## default
        h = self.encoder(x)
        moments = h  # self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior.sample()

    def decode(self, z):
        ## batched - seems the same as default?
        # n_samples = z.shape[0]
        # n_rounds = math.ceil(z.shape[0] / n_samples)
        # all_out = []
        # for n in range(n_rounds):
        #     dec = self.decoder(
        #         z[n * n_samples : (n + 1) * n_samples],
        #         timesteps=len(z[n * n_samples : (n + 1) * n_samples]),
        #     )
        #     all_out.append(dec)
        # out = torch.cat(all_out, dim=0)

        ## default
        out = self.decoder(z, timesteps=len(z))
        return out

    def forward(self, input, sample_posterior=True):
        # encode() samples internally and returns z rather than the posterior,
        # so the posterior is rebuilt here; this also keeps
        # sample_posterior=False working instead of calling .sample()/.mode()
        # on a tensor
        moments = self.encoder(input)
        posterior = DiagonalGaussianDistribution(moments)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior
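
    # Usage sketch (illustrative; the config keys below mirror the usual LDM
    # first-stage config consumed by Encoder/VideoDecoder -- the concrete
    # values are assumptions, not requirements):
    #
    #   config = dict(
    #       double_z=True, z_channels=4, resolution=256, in_channels=3,
    #       out_ch=3, ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
    #       attn_resolutions=(), dropout=0.0, embed_dim=4,
    #   )
    #   vae = AutoencoderKL(config)
    #   frames = torch.randn(8, 3, 256, 256)   # an 8-frame clip, frames as batch
    #   z = vae.encode(frames)                  # -> (8, 4, 32, 32) sampled latents
    #   recon = vae.decode(z)                   # -> (8, 3, 256, 256); len(z) is
    #                                           #    used as the timestep count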

class VideoDecoder(nn.Module):
    available_time_modes = ["all", "conv-only", "attn-only"]

    def __init__(
        self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
        attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
        resolution, z_channels, give_pre_end=False, tanh_out=False,
        use_linear_attn=False, attn_type="vanilla",
        video_kernel_size: Union[int, list] = 3, alpha: float = 0.0,
        merge_strategy: str = "learned", time_mode: str = "conv-only",
        **ignorekwargs
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out
        self.video_kernel_size = video_kernel_size
        self.alpha = alpha
        self.merge_strategy = merge_strategy
        self.time_mode = time_mode
        assert (
            self.time_mode in self.available_time_modes
        ), f"time_mode parameter has to be in {self.available_time_modes}"

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))
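        # e.g. resolution=256 with len(ch_mult)=4 gives curr_res = 256 // 8 = 32,
        # i.e. z_shape = (1, z_channels, 32, 32)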

        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels,
            block_in,
            kernel_size=3,
            stride=1,
            padding=1,
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = VideoResBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
            video_kernel_size=self.video_kernel_size,
            alpha=self.alpha,
            merge_strategy=self.merge_strategy,
        )
        self.mid.attn_1 = make_attn(
            block_in,
            attn_type=attn_type,
        )
        self.mid.block_2 = VideoResBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
            video_kernel_size=self.video_kernel_size,
            alpha=self.alpha,
            merge_strategy=self.merge_strategy,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(VideoResBlock(
                    in_channels=block_in,
                    out_channels=block_out,
                    temb_channels=self.temb_ch,
                    dropout=dropout,
                    video_kernel_size=self.video_kernel_size,
                    alpha=self.alpha,
                    merge_strategy=self.merge_strategy,
                ))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(
                        block_in,
                        attn_type=attn_type,
                    ))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = AE3DConv(
            in_channels=block_in,
            out_channels=out_ch,
            video_kernel_size=self.video_kernel_size,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def get_last_layer(self, skip_time_mix=False, **kwargs):
        if self.time_mode == "attn-only":
            raise NotImplementedError("TODO")
        else:
            return (
                self.conv_out.time_mix_conv.weight
                if not skip_time_mix
                else self.conv_out.weight
            )

    def forward(self, z, **kwargs):
        # assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb, **kwargs)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb, **kwargs)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb, **kwargs)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h, **kwargs)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
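
# Shape sketch (assumed config: ch_mult=(1, 2, 4, 4), resolution=256,
# z_channels=4, out_ch=3): a 14-frame latent enters as (14, 4, 32, 32); the
# spatial layers treat the 14 frames as batch entries, while VideoResBlock and
# AE3DConv regroup them to (1, C, 14, H, W) internally to convolve across time,
# so the decoded output is (14, 3, 256, 256).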

class ResBlock(nn.Module):
    """
    A residual block that can optionally change the number of channels.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels: int,
        emb_channels: int,
        dropout: float,
        out_channels: Optional[int] = None,
        use_conv: bool = False,
        use_scale_shift_norm: bool = False,
        dims: int = 2,
        use_checkpoint: bool = False,
        up: bool = False,
        down: bool = False,
        kernel_size: Union[int, list] = 3,
        exchange_temb_dims: bool = False,
        skip_t_emb: bool = False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm
        self.exchange_temb_dims = exchange_temb_dims

        if isinstance(kernel_size, Iterable):
            padding = [k // 2 for k in kernel_size]
        else:
            padding = kernel_size // 2

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
        )

        self.updown = up or down
        # note: the up/down branches assume openaimodel-style
        # Upsample/Downsample(channels, use_conv, dims) signatures, and
        # Downsample is not imported in this file; VideoResBlock below always
        # builds this block with up=False and down=False, so neither branch is
        # exercised here
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.skip_t_emb = skip_t_emb
        self.emb_out_channels = (
            2 * self.out_channels if use_scale_shift_norm else self.out_channels
        )
        if self.skip_t_emb:
            print(f"Skipping timestep embedding in {self.__class__.__name__}")
            assert not self.use_scale_shift_norm
            self.emb_layers = None
            self.exchange_temb_dims = False
        else:
            self.emb_layers = nn.Sequential(
                nn.SiLU(),
                linear(
                    emb_channels,
                    self.emb_out_channels,
                ),
            )

        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(
                    dims,
                    self.out_channels,
                    self.out_channels,
                    kernel_size,
                    padding=padding,
                )
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, kernel_size, padding=padding
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        if self.use_checkpoint:
            return checkpoint(self._forward, x, emb)
        else:
            return self._forward(x, emb)

    def _forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)

        if self.skip_t_emb:
            emb_out = torch.zeros_like(h)
        else:
            emb_out = self.emb_layers(emb).type(h.dtype)
            while len(emb_out.shape) < len(h.shape):
                emb_out = emb_out[..., None]

        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            if self.exchange_temb_dims:
                emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
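
# Minimal usage sketch (illustrative; it matches how VideoResBlock below builds
# its time_stack, but the shapes are assumptions):
#   rb = ResBlock(channels=64, emb_channels=0, dropout=0.0, dims=3,
#                 kernel_size=[3, 1, 1], skip_t_emb=True)
#   x = torch.randn(1, 64, 14, 32, 32)   # (batch, channels, time, h, w)
#   y = rb(x, emb=None)                  # emb is ignored when skip_t_emb=True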

class VideoResBlock(ResnetBlock):
    def __init__(
        self,
        out_channels,
        *args,
        dropout=0.0,
        video_kernel_size=3,
        alpha=0.0,
        merge_strategy="learned",
        **kwargs,
    ):
        super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
        if video_kernel_size is None:
            video_kernel_size = [3, 1, 1]
        self.time_stack = ResBlock(
            channels=out_channels,
            emb_channels=0,
            dropout=dropout,
            dims=3,
            use_scale_shift_norm=False,
            use_conv=False,
            up=False,
            down=False,
            kernel_size=video_kernel_size,
            use_checkpoint=False,
            skip_t_emb=True,
        )
        self.merge_strategy = merge_strategy
        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned":
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

    def get_alpha(self, bs):
        if self.merge_strategy == "fixed":
            return self.mix_factor
        elif self.merge_strategy == "learned":
            return torch.sigmoid(self.mix_factor)
        else:
            raise NotImplementedError()

    def forward(self, x, temb, skip_video=False, timesteps=None):
        if timesteps is None:
            # self.timesteps is never set in this file; callers are expected to
            # pass timesteps explicitly (VideoDecoder forwards it via **kwargs)
            timesteps = self.timesteps
        b, c, h, w = x.shape

        x = super().forward(x, temb)

        if not skip_video:
            x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
            x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
            x = self.time_stack(x, temb)

            alpha = self.get_alpha(bs=b // timesteps)
            x = alpha * x + (1.0 - alpha) * x_mix

            x = rearrange(x, "b c t h w -> (b t) c h w")
        return x
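
    # The temporal merge is a convex blend: out = alpha * time_stack(x) +
    # (1 - alpha) * x, with alpha = sigmoid(mix_factor) under the "learned"
    # strategy, so training can move continuously between the pretrained
    # spatial path (alpha ~ 0) and the new temporal path (alpha ~ 1).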

class AE3DConv(torch.nn.Conv2d):
    def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        if isinstance(video_kernel_size, Iterable):
            padding = [int(k // 2) for k in video_kernel_size]
        else:
            padding = int(video_kernel_size // 2)

        self.time_mix_conv = torch.nn.Conv3d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=video_kernel_size,
            padding=padding,
        )

    def forward(self, input, timesteps, skip_video=False):
        x = super().forward(input)
        if skip_video:
            return x
        x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
        x = self.time_mix_conv(x)
        return rearrange(x, "b c t h w -> (b t) c h w")
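
    # Example of the reshape around time_mix_conv (illustrative shapes): with
    # timesteps=14, a (14, C, H, W) stack of frames becomes (1, C, 14, H, W)
    # so the Conv3d can mix information across neighboring frames, then is
    # flattened back to (14, C, H, W).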

def normalization(channels):
    """
    Make a standard normalization layer.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)

class SiLU(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)
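
# GroupNorm32 runs the normalization in float32 and casts back to the input
# dtype, which keeps group norm numerically stable when the surrounding model
# runs in fp16/bf16.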

def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
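
# `linear` is used by ResBlock.emb_layers above but was not defined in this
# file; in the upstream sgm utilities it is a thin wrapper around nn.Linear,
# which is what we assume here.
def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)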

def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module
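
# zero_module implements the common "zero-init the final conv" trick: the
# residual branch contributes nothing at initialization, so a freshly added
# temporal layer starts out as an identity on top of the pretrained spatial
# weights.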