# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
from functools import partial
import torch
from torch import nn
from torch.nn import Module, ModuleList
from diffusion_model.network.attention import LinearAttention, Attention
from diffusion_model.network.timestep_embedding import SinusoidalEmbedding
from diffusion_model.network.blocks import ResnetBlock
# helper functions

def exists(x):
    return x is not None

def default(val, d):
    # fall back to d when val is None, calling it first if it is callable
    if exists(val):
        return val
    return d() if callable(d) else d

def cast_tuple(t, length = 1):
    # broadcast a single value to a tuple of the given length
    if isinstance(t, tuple):
        return t
    return ((t,) * length)

def divisible_by(numer, denom):
    return (numer % denom) == 0
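# Illustrative behavior of the helpers above (examples added here, not in the original file):
#   default(None, 3)           -> 3
#   default(None, lambda: 3)   -> 3
#   cast_tuple(32, length = 4) -> (32, 32, 32, 32)
#   divisible_by(64, 8)        -> True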
# small helper modules
class DownSample(nn.Module):
    def __init__(self, dim: int, dim_out: int):
        """
        Downsamples the spatial dimensions by a factor of 2 using a strided convolution.

        Args:
            dim: Input channel dimension.
            dim_out: Output channel dimension.
        """
        super().__init__()
        self.downsample = nn.Conv2d(dim, dim_out, kernel_size=4, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor of shape [B, C, H, W].

        Returns:
            Downsampled tensor of shape [B, dim_out, H/2, W/2].
        """
        return self.downsample(x)
class UpSample(nn.Module):
    def __init__(self, dim: int, dim_out: int):
        """
        Upsamples the spatial dimensions by a factor of 2 using a transposed convolution.

        Args:
            dim: Input channel dimension.
            dim_out: Output channel dimension.
        """
        super().__init__()
        self.upsample = nn.ConvTranspose2d(dim, dim_out, kernel_size=4, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor of shape [B, C, H, W].

        Returns:
            Upsampled tensor of shape [B, dim_out, 2*H, 2*W].
        """
        return self.upsample(x)
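# Shape sanity check for the two modules above (illustrative, not part of the original file):
#   DownSample(64, 128)(torch.randn(1, 64, 32, 32)).shape -> torch.Size([1, 128, 16, 16])
#   UpSample(128, 64)(torch.randn(1, 128, 16, 16)).shape  -> torch.Size([1, 64, 32, 32])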
# model
class Unet(Module):
def __init__(
self,
dim,
init_dim = None,
out_dim = None,
cond_dim = None,
dim_mults = (1, 2, 4, 8),
channels = 3,
dropout = 0.,
attn_dim_head = 32,
attn_heads = 4,
        full_attn = None, # defaults to full attention only for the innermost stage
):
super().__init__()
# determine dimensions
self.channels = channels
input_channels = channels
init_dim = default(init_dim, dim)
self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)] # one entry per resolution; without init_dim the number of stages would not match dim_mults
        in_out = list(zip(dims[:-1], dims[1:]))
# time embeddings
time_dim = dim * 4
sinu_pos_emb = SinusoidalEmbedding(dim)
self.time_mlp = nn.Sequential(
sinu_pos_emb,
nn.Linear(dim, time_dim),
nn.GELU(),
nn.Linear(time_dim, time_dim)
)
# attention
if not full_attn:
full_attn = (*((False,) * (len(dim_mults) - 1)), True)
num_stages = len(dim_mults)
full_attn = cast_tuple(full_attn, num_stages)
attn_heads = cast_tuple(attn_heads, num_stages)
attn_dim_head = cast_tuple(attn_dim_head, num_stages)
assert len(full_attn) == len(dim_mults)
# prepare blocks
FullAttention = Attention
resnet_block = partial(ResnetBlock,
t_emb_dim = time_dim, y_emb_dim = cond_dim, dropout = dropout)
# layers
self.downs = ModuleList([])
self.ups = ModuleList([])
num_resolutions = len(in_out)
for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(in_out, full_attn, attn_heads, attn_dim_head)):
is_last = ind >= (num_resolutions - 1)
attn_klass = FullAttention if layer_full_attn else LinearAttention
self.downs.append(ModuleList([
resnet_block(dim_in, dim_in),
resnet_block(dim_in, dim_in),
attn_klass(dim_in, dim_head = layer_attn_dim_head, heads = layer_attn_heads),
DownSample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
]))
mid_dim = dims[-1]
self.mid_block1 = resnet_block(mid_dim, mid_dim)
self.mid_attn = FullAttention(mid_dim, heads = attn_heads[-1], dim_head = attn_dim_head[-1])
self.mid_block2 = resnet_block(mid_dim, mid_dim)
for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(*map(reversed, (in_out, full_attn, attn_heads, attn_dim_head)))):
is_last = ind == (len(in_out) - 1)
attn_klass = FullAttention if layer_full_attn else LinearAttention
self.ups.append(ModuleList([
resnet_block(dim_out + dim_in, dim_out),
resnet_block(dim_out + dim_in, dim_out),
attn_klass(dim_out, dim_head = layer_attn_dim_head, heads = layer_attn_heads),
UpSample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
]))
default_out_dim = channels
self.out_dim = default(out_dim, default_out_dim)
self.final_res_block = resnet_block(init_dim * 2, init_dim)
self.final_conv = nn.Conv2d(init_dim, self.out_dim, 1)
    @property
    def downsample_factor(self):
        # the last down stage uses a stride-1 conv instead of DownSample, hence len(self.downs) - 1
        return 2 ** (len(self.downs) - 1)
    def forward(self, x, t, y = None):
        assert all([divisible_by(d, self.downsample_factor) for d in x.shape[-2:]]), f'your input dimensions {x.shape[-2:]} need to be divisible by {self.downsample_factor}, given the unet'

        x = self.init_conv(x)
        r = x.clone()            # keep the stem output for the final skip connection

        t = self.time_mlp(t)     # embed the diffusion timestep

        h = []                   # skip connections consumed by the up path

        # down path: two resnet blocks and attention per stage, then downsample
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t, y)
            h.append(x)

            x = block2(x, t, y)
            x = attn(x) + x
            h.append(x)

            x = downsample(x)

        # bottleneck
        x = self.mid_block1(x, t, y)
        x = self.mid_attn(x) + x
        x = self.mid_block2(x, t, y)

        # up path: concatenate the matching skip connection before each block
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim = 1)
            x = block1(x, t, y)

            x = torch.cat((x, h.pop()), dim = 1)
            x = block2(x, t, y)
            x = attn(x) + x

            x = upsample(x)

        x = torch.cat((x, r), dim = 1)
        x = self.final_res_block(x, t, y)
        return self.final_conv(x)
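# Minimal usage sketch (added here, not part of the original file). It assumes the local
# SinusoidalEmbedding / ResnetBlock / attention modules follow the interfaces of their
# counterparts in the lucidrains repo linked at the top: the time embedding maps a [B]
# tensor of timesteps to [B, dim], and ResnetBlock accepts (x, t, y).
if __name__ == '__main__':
    model = Unet(dim = 64, channels = 3)

    x = torch.randn(2, 3, 64, 64)    # [B, C, H, W]; H and W must be divisible by model.downsample_factor
    t = torch.randint(0, 1000, (2,)) # one diffusion timestep per sample

    out = model(x, t)
    print(out.shape)                 # expected: torch.Size([2, 3, 64, 64]), since out_dim defaults to channels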