from dataclasses import dataclass
from typing import NamedTuple, Optional, Tuple

import torch
from torch import Tensor, nn

from visualizr.model import BeatGANsUNetConfig, BeatGANsUNetModel
from visualizr.model.blocks import ResBlock
from visualizr.model.latentnet import MLPSkipNetConfig
from visualizr.model.nn import linear, timestep_embedding
from visualizr.model.unet import BeatGANsEncoderConfig


@dataclass
class BeatGANsAutoencConfig(BeatGANsUNetConfig):
    """BeatGANs UNet config extended with a semantic-encoder branch.

    ``enc_*`` fields configure the encoder; ``None`` values fall back to the
    corresponding UNet setting where noted below.
    """

    # number of style channels
    enc_out_channels: int = 512
    # None -> reuse the UNet's attention_resolutions
    enc_attn_resolutions: Optional[Tuple[int, ...]] = None
    enc_pool: str = "depthconv"
    enc_num_res_block: int = 2
    # None -> reuse the UNet's channel_mult
    enc_channel_mult: Optional[Tuple[int, ...]] = None
    enc_grad_checkpoint: bool = False
    latent_net_conf: Optional[MLPSkipNetConfig] = None

    def make_model(self) -> "BeatGANsAutoencModel":
        """Instantiate the autoencoder model described by this config."""
        return BeatGANsAutoencModel(self)


class BeatGANsAutoencModel(BeatGANsUNetModel):
    """BeatGANs UNet with an auxiliary image encoder.

    The encoder maps an image to a semantic ``cond`` vector that conditions
    every ResBlock of the UNet alongside the timestep embedding.
    """

    def __init__(self, conf: BeatGANsAutoencConfig):
        super().__init__(conf)
        self.conf = conf

        # having only time, cond
        self.time_embed = TimeStyleSeperateEmbed(
            time_channels=conf.model_channels,
            time_out_channels=conf.embed_channels,
        )

        self.encoder = BeatGANsEncoderConfig(
            image_size=conf.image_size,
            in_channels=conf.in_channels,
            model_channels=conf.model_channels,
            out_hid_channels=conf.enc_out_channels,
            out_channels=conf.enc_out_channels,
            num_res_blocks=conf.enc_num_res_block,
            attention_resolutions=(
                conf.enc_attn_resolutions or conf.attention_resolutions
            ),
            dropout=conf.dropout,
            channel_mult=conf.enc_channel_mult or conf.channel_mult,
            use_time_condition=False,
            conv_resample=conf.conv_resample,
            dims=conf.dims,
            use_checkpoint=conf.use_checkpoint or conf.enc_grad_checkpoint,
            num_heads=conf.num_heads,
            num_head_channels=conf.num_head_channels,
            resblock_updown=conf.resblock_updown,
            use_new_attention_order=conf.use_new_attention_order,
            pool=conf.enc_pool,
        ).make_model()

        if conf.latent_net_conf is not None:
            self.latent_net = conf.latent_net_conf.make_model()

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from N(0,1).

        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        assert self.conf.is_stochastic
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample_z(self, n: int, device) -> Tensor:
        """Draw ``n`` latent vectors from the standard-normal prior."""
        assert self.conf.is_stochastic
        return torch.randn(n, self.conf.enc_out_channels, device=device)

    def noise_to_cond(self, noise: Tensor):
        """Predict ``cond`` from noise — not implemented in this model."""
        raise NotImplementedError()
        # assert self.conf.noise_net_conf is not None
        # return self.noise_net.forward(noise)

    def encode(self, x) -> dict:
        """Encode an image batch into its semantic ``cond`` vector."""
        cond = self.encoder.forward(x)
        return {"cond": cond}

    @property
    def stylespace_sizes(self):
        """Output widths of each ResBlock's final cond-embedding layer."""
        modules = (
            list(self.input_blocks.modules())
            + list(self.middle_block.modules())
            + list(self.output_blocks.modules())
        )
        sizes = []
        for module in modules:
            if isinstance(module, ResBlock):
                # FIX: renamed local from `linear`, which shadowed the
                # imported `linear` helper from visualizr.model.nn.
                cond_linear = module.cond_emb_layers[-1]
                sizes.append(cond_linear.weight.shape[0])
        return sizes

    def encode_stylespace(self, x, return_vector: bool = True):
        """
        encode to style space
        """
        modules = (
            list(self.input_blocks.modules())
            + list(self.middle_block.modules())
            + list(self.output_blocks.modules())
        )
        # (n, c)
        cond = self.encoder.forward(x)
        S = []
        for module in modules:
            if isinstance(module, ResBlock):
                # (n, c')
                s = module.cond_emb_layers.forward(cond)
                S.append(s)

        if return_vector:
            # (n, sum_c)
            return torch.cat(S, dim=1)
        return S

    def forward(
        self,
        x,
        t,
        y=None,
        x_start=None,
        cond=None,
        style=None,
        noise=None,
        t_cond=None,
        **kwargs,
    ):
        """
        Apply the model to an input batch.

        Args:
            x_start: the original image to encode
            cond: output of the encoder
            noise: random noise (to predict the cond)
        """
        if t_cond is None:
            t_cond = t

        if noise is not None:
            # if the noise is given, we predict the cond from noise
            cond = self.noise_to_cond(noise)

        if cond is None:
            if x is not None:
                assert len(x) == len(x_start), f"{len(x)} != {len(x_start)}"
            tmp = self.encode(x_start)
            cond = tmp["cond"]

        if t is not None:
            _t_emb = timestep_embedding(t, self.conf.model_channels)
            _t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels)
        else:
            # this happens when training only autoenc
            _t_emb = None
            _t_cond_emb = None

        if self.conf.resnet_two_cond:
            res = self.time_embed.forward(
                time_emb=_t_emb,
                cond=cond,
                time_cond_emb=_t_cond_emb,
            )
        else:
            raise NotImplementedError()

        if self.conf.resnet_two_cond:
            # two cond: first = time emb, second = cond_emb
            emb = res.time_emb
            cond_emb = res.emb
        else:
            # one cond = combined of both time and cond
            emb = res.emb
            cond_emb = None

        # override the style if given
        # FIX: was `style = style or res.style`; truth-testing a multi-element
        # Tensor raises "Boolean value of Tensor ... is ambiguous", so any
        # caller-supplied style tensor would crash here. Test identity instead.
        style = style if style is not None else res.style

        assert (y is not None) == (self.conf.num_classes is not None), (
            "must specify y if and only if the model is class-conditional"
        )

        if self.conf.num_classes is not None:
            raise NotImplementedError()
            # assert y.shape == (x.shape[0], )
            # emb = emb + self.label_emb(y)

        # where in the model to supply time conditions
        enc_time_emb = emb
        mid_time_emb = emb
        dec_time_emb = emb
        # where in the model to supply style conditions
        enc_cond_emb = cond_emb
        mid_cond_emb = cond_emb
        dec_cond_emb = cond_emb

        # one list of laterals per resolution level of the UNet
        hs = [[] for _ in range(len(self.conf.channel_mult))]

        if x is not None:
            h = x.type(self.dtype)

            # input blocks
            k = 0
            for i in range(len(self.input_num_blocks)):
                for j in range(self.input_num_blocks[i]):
                    h = self.input_blocks[k](h, emb=enc_time_emb, cond=enc_cond_emb)
                    # print(i, j, h.shape)
                    hs[i].append(h)
                    k += 1
            assert k == len(self.input_blocks)

            # middle blocks
            h = self.middle_block(h, emb=mid_time_emb, cond=mid_cond_emb)
        else:
            # no lateral connections; happens when training only the
            # autoencoder. `hs` is already a list of empty lists here, so the
            # original's duplicate re-initialization was dropped.
            h = None

        # output blocks
        k = 0
        for i in range(len(self.output_num_blocks)):
            for j in range(self.output_num_blocks[i]):
                # take the lateral connection from the same layer (in reverse)
                # until there is no more, use None
                try:
                    lateral = hs[-i - 1].pop()
                    # print(i, j, lateral.shape)
                except IndexError:
                    lateral = None
                    # print(i, j, lateral)
                h = self.output_blocks[k](
                    h, emb=dec_time_emb, cond=dec_cond_emb, lateral=lateral
                )
                k += 1

        pred = self.out(h)
        return AutoencReturn(pred=pred, cond=cond)


class AutoencReturn(NamedTuple):
    """Output of `BeatGANsAutoencModel.forward`."""

    pred: Tensor
    cond: Tensor = None


class EmbedReturn(NamedTuple):
    """Output of the time/style embedding module."""

    # style and time
    emb: Tensor = None
    # time only
    time_emb: Tensor = None
    # style only (but could depend on time)
    style: Tensor = None


class TimeStyleSeperateEmbed(nn.Module):
    """Embeds time with an MLP; passes the style/cond through unchanged."""

    # embed only style
    def __init__(self, time_channels, time_out_channels):
        super().__init__()
        self.time_embed = nn.Sequential(
            linear(time_channels, time_out_channels),
            nn.SiLU(),
            linear(time_out_channels, time_out_channels),
        )
        self.style = nn.Identity()

    def forward(self, time_emb=None, cond=None, **kwargs):
        # time_emb is None when training only the autoencoder; leave it None.
        # (FIX: removed the redundant `if is None: time_emb = None` branch.)
        if time_emb is not None:
            time_emb = self.time_embed(time_emb)
        style = self.style(cond)
        return EmbedReturn(emb=style, time_emb=time_emb, style=style)