# NOTE: the three lines below are Hugging Face web-page residue, not module code.
# MH0386's picture
# Upload folder using huggingface_hub
# 3e165b2 verified
from dataclasses import dataclass
from typing import NamedTuple, Optional, Tuple

import torch as th
import torch.nn.functional as F
from torch import nn

from visualizr.config_base import BaseConfig
from visualizr.model.blocks import (
    AttentionBlock,
    Downsample,
    ResBlockConfig,
    TimestepEmbedSequential,
    Upsample,
)
from visualizr.model.nn import (
    conv_nd,
    linear,
    normalization,
    timestep_embedding,
    zero_module,
)
@dataclass
class BeatGANsUNetConfig(BaseConfig):
    """Hyper-parameters consumed by ``BeatGANsUNetModel``.

    Field order is part of the interface (the dataclass can be constructed
    positionally), so new fields must only ever be appended.
    """

    # input resolution; compared against `attention_resolutions` while the
    # network is being built (it is halved after every downsampling stage)
    image_size: int = 64
    in_channels: int = 3
    # base channels, will be multiplied by each entry of `channel_mult`
    model_channels: int = 64
    # output of the unet
    # suggest: 3
    # you only need 6 if you also model the variance of the noise prediction
    # (usually we use an analytical variance hence 3)
    out_channels: int = 3
    # how many repeating resblocks per resolution
    # the decoding side would have "one more" resblock
    # default: 2
    num_res_blocks: int = 2
    # you can also set the number of resblocks specifically for the input blocks
    # default: None = same as `num_res_blocks`
    num_input_res_blocks: Optional[int] = None
    # width of the time-embedding MLP output; also passed to every resblock
    # as its embedding width
    embed_channels: int = 512
    # at what resolutions you want to do self-attention of the feature maps
    # attentions generally improve performance
    # default: [16]
    # beatgans: [32, 16, 8]
    attention_resolutions: Tuple[int, ...] = (16,)
    # size requested from `timestep_embedding` (input width of the time MLP)
    # default: None = `model_channels`
    time_embed_channels: Optional[int] = None
    # dropout applies to the resblocks (on feature maps)
    dropout: float = 0.1
    # per-resolution channel multipliers (decoder side, and encoder side
    # unless `input_channel_mult` is given)
    channel_mult: Tuple[int, ...] = (1, 2, 4, 8)
    # optional separate multipliers for the input (encoder) blocks
    # default: None = `channel_mult`
    input_channel_mult: Optional[Tuple[int, ...]] = None
    conv_resample: bool = True
    # always 2 = 2d conv
    dims: int = 2
    # don't use this, legacy from BeatGANs
    # (the model's forward raises NotImplementedError when it is set)
    num_classes: Optional[int] = None
    use_checkpoint: bool = False
    # number of attention heads
    num_heads: int = 1
    # or specify the number of channels per attention head
    num_head_channels: int = -1
    # number of attention heads used by the output (decoder) attention blocks
    # default: -1 = same as `num_heads`
    num_heads_upsample: int = -1
    # use resblock for upscale/downscale blocks (expensive)
    # default: True (BeatGANs)
    resblock_updown: bool = True
    # never tried
    use_new_attention_order: bool = False
    # forwarded to every resblock as `two_cond`
    resnet_two_cond: bool = False
    # forwarded to every resblock as `cond_emb_channels` (style channels)
    resnet_cond_channels: Optional[int] = None
    # init the decoding conv layers with zero weights, this speeds up training
    # default: True (BeatGANs)
    resnet_use_zero_module: bool = True
    # gradient checkpoint the attention operation
    attn_checkpoint: bool = False

    def make_model(self) -> "BeatGANsUNetModel":
        """Build the UNet described by this configuration."""
        return BeatGANsUNetModel(self)
class BeatGANsUNetModel(nn.Module):
    """UNet of residual blocks with optional self-attention and a timestep embedding.

    The network is assembled from ``conf``: a stack of input (downsampling)
    blocks, a middle block, and a stack of output (upsampling) blocks that
    receive the stored input activations as lateral (skip) connections.
    """

    def __init__(self, conf: BeatGANsUNetConfig):
        """Create all sub-modules described by ``conf``.

        :param conf: the model configuration; kept on ``self.conf``.
        """
        super().__init__()
        self.conf = conf

        # -1 means "reuse num_heads". Any explicit value must be stored too,
        # otherwise the decoder attention blocks below would hit an
        # AttributeError when they read self.num_heads_upsample.
        if conf.num_heads_upsample == -1:
            self.num_heads_upsample = conf.num_heads
        else:
            self.num_heads_upsample = conf.num_heads_upsample

        self.dtype = th.float32

        self.time_emb_channels = conf.time_embed_channels or conf.model_channels
        self.time_embed = nn.Sequential(
            linear(self.time_emb_channels, conf.embed_channels),
            nn.SiLU(),
            linear(conf.embed_channels, conf.embed_channels),
        )

        if conf.num_classes is not None:
            self.label_emb = nn.Embedding(conf.num_classes, conf.embed_channels)

        num_levels = len(conf.channel_mult)
        ch = input_ch = int(conf.channel_mult[0] * conf.model_channels)
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1)
                )
            ]
        )

        # options shared by every residual block in this model
        resblock_kwargs = dict(
            use_condition=True,
            two_cond=conf.resnet_two_cond,
            use_zero_module=conf.resnet_use_zero_module,
            # style channels for the resnet block
            cond_emb_channels=conf.resnet_cond_channels,
        )
        attn_use_checkpoint = conf.use_checkpoint or conf.attn_checkpoint

        self._feature_size = ch

        # channel count of every input block output, grouped per resolution
        # level so the decoder can pop its skip connections level by level
        input_block_chans = [[] for _ in range(num_levels)]
        input_block_chans[0].append(ch)

        # number of blocks at each resolution
        self.input_num_blocks = [0] * num_levels
        self.input_num_blocks[0] = 1
        self.output_num_blocks = [0] * num_levels

        resolution = conf.image_size
        for level, mult in enumerate(conf.input_channel_mult or conf.channel_mult):
            for _ in range(conf.num_input_res_blocks or conf.num_res_blocks):
                layers = [
                    ResBlockConfig(
                        ch,
                        conf.embed_channels,
                        conf.dropout,
                        out_channels=int(mult * conf.model_channels),
                        dims=conf.dims,
                        use_checkpoint=conf.use_checkpoint,
                        **resblock_kwargs,
                    ).make_model()
                ]
                ch = int(mult * conf.model_channels)
                if resolution in conf.attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=attn_use_checkpoint,
                            num_heads=conf.num_heads,
                            num_head_channels=conf.num_head_channels,
                            use_new_attention_order=conf.use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans[level].append(ch)
                self.input_num_blocks[level] += 1

            if level != num_levels - 1:
                # downsample between levels; its output is counted as the
                # first block of the next level
                resolution //= 2
                out_ch = ch
                if conf.resblock_updown:
                    down_block = ResBlockConfig(
                        ch,
                        conf.embed_channels,
                        conf.dropout,
                        out_channels=out_ch,
                        dims=conf.dims,
                        use_checkpoint=conf.use_checkpoint,
                        down=True,
                        **resblock_kwargs,
                    ).make_model()
                else:
                    down_block = Downsample(ch, conf.conv_resample, conf.dims, out_ch)
                self.input_blocks.append(TimestepEmbedSequential(down_block))
                ch = out_ch
                input_block_chans[level + 1].append(ch)
                self.input_num_blocks[level + 1] += 1
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            ResBlockConfig(
                ch,
                conf.embed_channels,
                conf.dropout,
                dims=conf.dims,
                use_checkpoint=conf.use_checkpoint,
                **resblock_kwargs,
            ).make_model(),
            AttentionBlock(
                ch,
                use_checkpoint=attn_use_checkpoint,
                num_heads=conf.num_heads,
                num_head_channels=conf.num_head_channels,
                use_new_attention_order=conf.use_new_attention_order,
            ),
            ResBlockConfig(
                ch,
                conf.embed_channels,
                conf.dropout,
                dims=conf.dims,
                use_checkpoint=conf.use_checkpoint,
                **resblock_kwargs,
            ).make_model(),
        )
        self._feature_size += ch

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(conf.channel_mult))[::-1]:
            level_chans = input_block_chans[level]
            for i in range(conf.num_res_blocks + 1):
                # An empty list happens only when the decoder has more blocks
                # than the encoder at this level: there are then not enough
                # lateral (skip) connections for all decoder blocks.
                ich = level_chans.pop() if level_chans else 0
                layers = [
                    ResBlockConfig(
                        # only direct channels when gated
                        channels=ch + ich,
                        emb_channels=conf.embed_channels,
                        dropout=conf.dropout,
                        out_channels=int(conf.model_channels * mult),
                        dims=conf.dims,
                        use_checkpoint=conf.use_checkpoint,
                        # lateral channels are described here when gated
                        has_lateral=ich > 0,
                        lateral_channels=None,
                        **resblock_kwargs,
                    ).make_model()
                ]
                ch = int(conf.model_channels * mult)
                if resolution in conf.attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=attn_use_checkpoint,
                            num_heads=self.num_heads_upsample,
                            num_head_channels=conf.num_head_channels,
                            use_new_attention_order=conf.use_new_attention_order,
                        )
                    )
                if level and i == conf.num_res_blocks:
                    # last block of every level except level 0 also upsamples
                    resolution *= 2
                    out_ch = ch
                    if conf.resblock_updown:
                        up_block = ResBlockConfig(
                            ch,
                            conf.embed_channels,
                            conf.dropout,
                            out_channels=out_ch,
                            dims=conf.dims,
                            use_checkpoint=conf.use_checkpoint,
                            up=True,
                            **resblock_kwargs,
                        ).make_model()
                    else:
                        up_block = Upsample(
                            ch, conf.conv_resample, dims=conf.dims, out_channels=out_ch
                        )
                    layers.append(up_block)
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self.output_num_blocks[level] += 1
                self._feature_size += ch

        # After the decoder loop ch == input_ch: the last level processed is
        # level 0, whose width is model_channels * channel_mult[0].
        out_norm = normalization(ch)
        out_act = nn.SiLU()
        out_conv = conv_nd(conf.dims, input_ch, conf.out_channels, 3, padding=1)
        if conf.resnet_use_zero_module:
            out_conv = zero_module(out_conv)
        self.out = nn.Sequential(out_norm, out_act, out_conv)

    def forward(self, x, t, y=None, **kwargs):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param t: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels, if class-conditional.
        :param kwargs: accepted for interface compatibility and ignored.
        :return: a ``Return`` tuple whose ``pred`` is an [N x C x ...] Tensor of outputs.
        :raises NotImplementedError: if ``conf.num_classes`` is set
            (class conditioning is not supported).
        """
        assert (y is not None) == (self.conf.num_classes is not None), (
            "must specify y if and only if the model is class-conditional"
        )

        emb = self.time_embed(timestep_embedding(t, self.time_emb_channels))

        if self.conf.num_classes is not None:
            raise NotImplementedError()

        # activations of the input blocks, grouped per resolution level;
        # this supports input_num_blocks != output_num_blocks
        hs = [[] for _ in range(len(self.conf.channel_mult))]
        h = x.type(self.dtype)
        k = 0
        for level, n_blocks in enumerate(self.input_num_blocks):
            for _ in range(n_blocks):
                h = self.input_blocks[k](h, emb=emb)
                hs[level].append(h)
                k += 1
        assert k == len(self.input_blocks)

        h = self.middle_block(h, emb=emb)

        k = 0
        for i, n_blocks in enumerate(self.output_num_blocks):
            # take the lateral connections from the matching level (in reverse)
            skips = hs[-i - 1]
            for _ in range(n_blocks):
                # until there is no more, use None
                lateral = skips.pop() if skips else None
                h = self.output_blocks[k](h, emb=emb, lateral=lateral)
                k += 1

        h = h.type(x.dtype)
        pred = self.out(h)
        return Return(pred=pred)
class Return(NamedTuple):
    """Result container returned by ``BeatGANsUNetModel.forward``."""

    # output tensor produced by the model's final `out` layers
    pred: th.Tensor
@dataclass
class BeatGANsEncoderConfig(BaseConfig):
    """Hyper-parameters consumed by ``BeatGANsEncoderModel``.

    Field order is part of the interface (the dataclass can be constructed
    positionally), so new fields must only ever be appended.
    """

    # input resolution; compared against `attention_resolutions` while the
    # network is being built (it is halved after every downsampling stage)
    image_size: int
    in_channels: int
    # base channels, will be multiplied by each entry of `channel_mult`
    model_channels: int
    # NOTE(review): not read by BeatGANsEncoderModel in this file
    out_hid_channels: int
    # number of channels of the final 1x1 convolution (size of the output vector)
    out_channels: int
    # how many resblocks per resolution
    num_res_blocks: int
    # resolutions at which a self-attention block follows each resblock
    attention_resolutions: Tuple[int, ...]
    dropout: float = 0
    channel_mult: Tuple[int, ...] = (1, 2, 4, 8)
    # when True the model builds a time-embedding MLP and conditions its
    # resblocks on it
    use_time_condition: bool = True
    conv_resample: bool = True
    dims: int = 2
    use_checkpoint: bool = False
    num_heads: int = 1
    num_head_channels: int = -1
    # use resblocks (instead of Downsample) between resolutions
    resblock_updown: bool = False
    use_new_attention_order: bool = False
    # only "adaptivenonzero" is accepted by the model constructor
    pool: str = "adaptivenonzero"

    def make_model(self) -> "BeatGANsEncoderModel":
        """Build the encoder described by this configuration."""
        return BeatGANsEncoderModel(self)
class BeatGANsEncoderModel(nn.Module):
    """
    The half UNet model with attention and timestep embedding.

    Only the downsampling path and the middle block are built; the result is
    pooled and projected to a flat vector of ``conf.out_channels`` features.
    """

    def __init__(self, conf: BeatGANsEncoderConfig):
        """Create all sub-modules described by ``conf``.

        :param conf: the encoder configuration; kept on ``self.conf``.
        :raises NotImplementedError: if ``conf.pool`` is not "adaptivenonzero".
        """
        super().__init__()
        self.conf = conf
        self.dtype = th.float32

        if conf.use_time_condition:
            time_embed_dim = conf.model_channels * 4
            self.time_embed = nn.Sequential(
                linear(conf.model_channels, time_embed_dim),
                nn.SiLU(),
                linear(time_embed_dim, time_embed_dim),
            )
        else:
            time_embed_dim = None

        ch = int(conf.channel_mult[0] * conf.model_channels)
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1)
                )
            ]
        )
        self._feature_size = ch

        last_level = len(conf.channel_mult) - 1
        resolution = conf.image_size
        for level, mult in enumerate(conf.channel_mult):
            for _ in range(conf.num_res_blocks):
                layers = [
                    ResBlockConfig(
                        ch,
                        time_embed_dim,
                        conf.dropout,
                        out_channels=int(mult * conf.model_channels),
                        dims=conf.dims,
                        use_condition=conf.use_time_condition,
                        use_checkpoint=conf.use_checkpoint,
                    ).make_model()
                ]
                ch = int(mult * conf.model_channels)
                if resolution in conf.attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=conf.use_checkpoint,
                            num_heads=conf.num_heads,
                            num_head_channels=conf.num_head_channels,
                            use_new_attention_order=conf.use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

            if level != last_level:
                # downsample between resolution levels
                resolution //= 2
                out_ch = ch
                if conf.resblock_updown:
                    down_block = ResBlockConfig(
                        ch,
                        time_embed_dim,
                        conf.dropout,
                        out_channels=out_ch,
                        dims=conf.dims,
                        use_condition=conf.use_time_condition,
                        use_checkpoint=conf.use_checkpoint,
                        down=True,
                    ).make_model()
                else:
                    down_block = Downsample(
                        ch, conf.conv_resample, dims=conf.dims, out_channels=out_ch
                    )
                self.input_blocks.append(TimestepEmbedSequential(down_block))
                ch = out_ch
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            ResBlockConfig(
                ch,
                time_embed_dim,
                conf.dropout,
                dims=conf.dims,
                use_condition=conf.use_time_condition,
                use_checkpoint=conf.use_checkpoint,
            ).make_model(),
            AttentionBlock(
                ch,
                use_checkpoint=conf.use_checkpoint,
                num_heads=conf.num_heads,
                num_head_channels=conf.num_head_channels,
                use_new_attention_order=conf.use_new_attention_order,
            ),
            ResBlockConfig(
                ch,
                time_embed_dim,
                conf.dropout,
                dims=conf.dims,
                use_condition=conf.use_time_condition,
                use_checkpoint=conf.use_checkpoint,
            ).make_model(),
        )
        self._feature_size += ch

        if conf.pool == "adaptivenonzero":
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                conv_nd(conf.dims, ch, conf.out_channels, 1),
                nn.Flatten(),
            )
        else:
            raise NotImplementedError(f"Unexpected {conf.pool} pooling")

    def forward(self, x, t=None, return_2d_feature=False):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param t: a 1-D batch of timesteps; used only when
            ``conf.use_time_condition`` is set.
        :param return_2d_feature: when True, also return the feature map that
            was fed to the output head.
        :return: an [N x K] Tensor of outputs, or ``(outputs, feature_map)``
            when ``return_2d_feature`` is True.
        """
        if self.conf.use_time_condition:
            # model_channels lives on the config; the module itself never
            # defines a `model_channels` attribute.
            emb = self.time_embed(timestep_embedding(t, self.conf.model_channels))
        else:
            emb = None

        # NOTE(review): the constructor rejects every pool other than
        # "adaptivenonzero", so the "spatial" path looks like legacy code.
        spatial_pool = self.conf.pool.startswith("spatial")

        results = []
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb=emb)
            if spatial_pool:
                results.append(h.type(x.dtype).mean(dim=(2, 3)))
        h = self.middle_block(h, emb=emb)

        if spatial_pool:
            results.append(h.type(x.dtype).mean(dim=(2, 3)))
            h = th.cat(results, dim=-1)
            # no 2-D feature map exists in this mode
            h_2d = None
        else:
            h = h.type(x.dtype)
            h_2d = h

        h = self.out(h)

        if return_2d_feature:
            return h, h_2d
        return h

    def forward_flatten(self, x):
        """
        transform the last 2d feature into a flatten vector
        """
        return self.out(x)
class SuperResModel(BeatGANsUNetModel):
    """
    A UNetModel that performs super-resolution.

    Expects an extra kwarg `low_res` to condition on a low-resolution image.
    """

    def __init__(self, image_size, in_channels, *args, **kwargs):
        """Build a UNet whose input has twice ``in_channels`` channels.

        The doubled input width makes room for the upsampled low-resolution
        image that ``forward`` concatenates to ``x``.

        :param image_size: forwarded to ``BeatGANsUNetConfig``.
        :param in_channels: channels of ``x`` alone (before concatenation).
        :param args: remaining positional ``BeatGANsUNetConfig`` fields.
        :param kwargs: remaining keyword ``BeatGANsUNetConfig`` fields.
        """
        # BeatGANsUNetModel.__init__ accepts a single config object, so the
        # individual arguments are packed into a BeatGANsUNetConfig first.
        # NOTE(review): assumes BaseConfig contributes no dataclass fields
        # ahead of image_size/in_channels — confirm.
        conf = BeatGANsUNetConfig(image_size, in_channels * 2, *args, **kwargs)
        super().__init__(conf)

    def forward(self, x, timesteps, low_res=None, **kwargs):
        """Run the UNet on ``x`` concatenated with the resized ``low_res``.

        :param x: a 4-D [N x C x H x W] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param low_res: the low-resolution conditioning image; required.
        :param kwargs: forwarded to ``BeatGANsUNetModel.forward``.
        :return: whatever ``BeatGANsUNetModel.forward`` returns.
        :raises ValueError: if ``low_res`` is not provided.
        """
        if low_res is None:
            raise ValueError("SuperResModel.forward requires the `low_res` argument")
        _, _, new_height, new_width = x.shape
        # bring the conditioning image to the spatial size of x
        upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
        x = th.cat([x, upsampled], dim=1)
        return super().forward(x, timesteps, **kwargs)