mahmed10 committed (verified)
Commit c70c475 · 1 Parent(s): 1260432

final upload
models/diffloss.py ADDED
@@ -0,0 +1,248 @@
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+import math
+
+from diffusion import create_diffusion
+
+
+class DiffLoss(nn.Module):
+    """Diffusion Loss"""
+    def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps, grad_checkpointing=False):
+        super(DiffLoss, self).__init__()
+        self.in_channels = target_channels
+        self.net = SimpleMLPAdaLN(
+            in_channels=target_channels,
+            model_channels=width,
+            out_channels=target_channels * 2,  # for vlb loss
+            z_channels=z_channels,
+            num_res_blocks=depth,
+            grad_checkpointing=grad_checkpointing
+        )
+
+        self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="cosine")
+        self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps, noise_schedule="cosine")
+
+    def forward(self, target, z, mask=None):
+        t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
+        model_kwargs = dict(c=z)
+        loss_dict = self.train_diffusion.training_losses(self.net, target, t, model_kwargs)
+        loss = loss_dict["loss"]
+        if mask is not None:
+            loss = (loss * mask).sum() / mask.sum()
+        return loss.mean()
+
+    def sample(self, z, temperature=1.0, cfg=1.0):
+        # diffusion loss sampling
+        if not cfg == 1.0:
+            noise = torch.randn(z.shape[0] // 2, self.in_channels).cuda()
+            noise = torch.cat([noise, noise], dim=0)
+            model_kwargs = dict(c=z, cfg_scale=cfg)
+            sample_fn = self.net.forward_with_cfg
+        else:
+            noise = torch.randn(z.shape[0], self.in_channels).cuda()
+            model_kwargs = dict(c=z)
+            sample_fn = self.net.forward
+
+        sampled_token_latent = self.gen_diffusion.p_sample_loop(
+            sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs, progress=False,
+            temperature=temperature
+        )
+
+        return sampled_token_latent
+
+
+def modulate(x, shift, scale):
+    return x * (1 + scale) + shift
+
+
+class TimestepEmbedder(nn.Module):
+    """
+    Embeds scalar timesteps into vector representations.
+    """
+    def __init__(self, hidden_size, frequency_embedding_size=256):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+            nn.SiLU(),
+            nn.Linear(hidden_size, hidden_size, bias=True),
+        )
+        self.frequency_embedding_size = frequency_embedding_size
+
+    @staticmethod
+    def timestep_embedding(t, dim, max_period=10000):
+        """
+        Create sinusoidal timestep embeddings.
+        :param t: a 1-D Tensor of N indices, one per batch element.
+                  These may be fractional.
+        :param dim: the dimension of the output.
+        :param max_period: controls the minimum frequency of the embeddings.
+        :return: an (N, D) Tensor of positional embeddings.
+        """
+        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+        half = dim // 2
+        freqs = torch.exp(
+            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+        ).to(device=t.device)
+        args = t[:, None].float() * freqs[None]
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        if dim % 2:
+            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+        return embedding
+
+    def forward(self, t):
+        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+        t_emb = self.mlp(t_freq)
+        return t_emb
+
+
+class ResBlock(nn.Module):
+    """
+    A residual block that can optionally change the number of channels.
+    :param channels: the number of input channels.
+    """
+
+    def __init__(
+        self,
+        channels
+    ):
+        super().__init__()
+        self.channels = channels
+
+        self.in_ln = nn.LayerNorm(channels, eps=1e-6)
+        self.mlp = nn.Sequential(
+            nn.Linear(channels, channels, bias=True),
+            nn.SiLU(),
+            nn.Linear(channels, channels, bias=True),
+        )
+
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(channels, 3 * channels, bias=True)
+        )
+
+    def forward(self, x, y):
+        shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1)
+        h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
+        h = self.mlp(h)
+        return x + gate_mlp * h
+
+
+class FinalLayer(nn.Module):
+    """
+    The final layer adopted from DiT.
+    """
+    def __init__(self, model_channels, out_channels):
+        super().__init__()
+        self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
+        self.linear = nn.Linear(model_channels, out_channels, bias=True)
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(model_channels, 2 * model_channels, bias=True)
+        )
+
+    def forward(self, x, c):
+        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+        x = modulate(self.norm_final(x), shift, scale)
+        x = self.linear(x)
+        return x
+
+
+class SimpleMLPAdaLN(nn.Module):
+    """
+    The MLP for Diffusion Loss.
+    :param in_channels: channels in the input Tensor.
+    :param model_channels: base channel count for the model.
+    :param out_channels: channels in the output Tensor.
+    :param z_channels: channels in the condition.
+    :param num_res_blocks: number of residual blocks per downsample.
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        model_channels,
+        out_channels,
+        z_channels,
+        num_res_blocks,
+        grad_checkpointing=False
+    ):
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.model_channels = model_channels
+        self.out_channels = out_channels
+        self.num_res_blocks = num_res_blocks
+        self.grad_checkpointing = grad_checkpointing
+
+        self.time_embed = TimestepEmbedder(model_channels)
+        self.cond_embed = nn.Linear(z_channels, model_channels)
+
+        self.input_proj = nn.Linear(in_channels, model_channels)
+
+        res_blocks = []
+        for i in range(num_res_blocks):
+            res_blocks.append(ResBlock(
+                model_channels,
+            ))
+
+        self.res_blocks = nn.ModuleList(res_blocks)
+        self.final_layer = FinalLayer(model_channels, out_channels)
+
+        self.initialize_weights()
+
+    def initialize_weights(self):
+        def _basic_init(module):
+            if isinstance(module, nn.Linear):
+                torch.nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    nn.init.constant_(module.bias, 0)
+        self.apply(_basic_init)
+
+        # Initialize timestep embedding MLP
+        nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
+        nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
+
+        # Zero-out adaLN modulation layers
+        for block in self.res_blocks:
+            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+        # Zero-out output layers
+        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+        nn.init.constant_(self.final_layer.linear.weight, 0)
+        nn.init.constant_(self.final_layer.linear.bias, 0)
+
+    def forward(self, x, t, c):
+        """
+        Apply the model to an input batch.
+        :param x: an [N x C] Tensor of inputs.
+        :param t: a 1-D batch of timesteps.
+        :param c: conditioning from AR transformer.
+        :return: an [N x C] Tensor of outputs.
+        """
+        x = self.input_proj(x)
+        t = self.time_embed(t)
+        c = self.cond_embed(c)
+
+        y = t + c
+
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            for block in self.res_blocks:
+                x = checkpoint(block, x, y)
+        else:
+            for block in self.res_blocks:
+                x = block(x, y)
+
+        return self.final_layer(x, y)
+
+    def forward_with_cfg(self, x, t, c, cfg_scale):
+        half = x[: len(x) // 2]
+        combined = torch.cat([half, half], dim=0)
+        model_out = self.forward(combined, t, c)
+        eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
+        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
+        eps = torch.cat([half_eps, half_eps], dim=0)
+        return torch.cat([eps, rest], dim=1)
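
A minimal usage sketch of `DiffLoss` (illustrative only, assuming the MAR repo's `diffusion` package is importable and a CUDA device is available; the shapes are examples, not requirements):

import torch
from models.diffloss import DiffLoss

diffloss = DiffLoss(target_channels=16, z_channels=768, depth=3, width=1024,
                    num_sampling_steps='100').cuda()
target = torch.randn(8, 16).cuda()   # per-token ground-truth VAE latents
z = torch.randn(8, 768).cuda()       # per-token conditioning from the MAE decoder
loss = diffloss(target, z)                               # training objective
latents = diffloss.sample(z, temperature=1.0, cfg=1.0)   # (8, 16) sampled tokens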
models/mar.py ADDED
@@ -0,0 +1,355 @@
+from functools import partial
+
+import numpy as np
+from tqdm import tqdm
+import scipy.stats as stats
+import math
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+from timm.models.vision_transformer import Block
+
+from models.diffloss import DiffLoss
+
+
+def mask_by_order(mask_len, order, bsz, seq_len):
+    masking = torch.zeros(bsz, seq_len).cuda()
+    masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()], src=torch.ones(bsz, seq_len).cuda()).bool()
+    return masking
+
+
+class MAR(nn.Module):
+    """ Masked Autoencoder with VisionTransformer backbone
+    """
+    def __init__(self, img_size=256, vae_stride=16, patch_size=1,
+                 encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
+                 decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
+                 mlp_ratio=4., norm_layer=nn.LayerNorm,
+                 vae_embed_dim=16,
+                 mask_ratio_min=0.7,
+                 label_drop_prob=0.1,
+                 attn_dropout=0.1,
+                 proj_dropout=0.1,
+                 buffer_size=64,
+                 diffloss_d=3,
+                 diffloss_w=1024,
+                 num_sampling_steps='100',
+                 diffusion_batch_mul=4,
+                 grad_checkpointing=False,
+                 ):
+        super().__init__()
+
+        # --------------------------------------------------------------------------
+        # VAE and patchify specifics
+        self.vae_embed_dim = vae_embed_dim
+
+        self.img_size = img_size
+        self.vae_stride = vae_stride
+        self.patch_size = patch_size
+        self.seq_h = self.seq_w = img_size // vae_stride // patch_size
+        self.seq_len = self.seq_h * self.seq_w
+        self.token_embed_dim = vae_embed_dim * patch_size**2
+        self.grad_checkpointing = grad_checkpointing
+
+        # --------------------------------------------------------------------------
+        # image drop
+        self.label_drop_prob = label_drop_prob
+        # Fake class embedding for CFG's unconditional generation
+        # self.fake_latent = nn.Parameter(torch.zeros(1, encoder_embed_dim))
+
+        # --------------------------------------------------------------------------
+        # MAR variant masking ratio, a left-half truncated Gaussian centered at 100% masking ratio with std 0.25
+        self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25)
+
+        # --------------------------------------------------------------------------
+        # MAR encoder specifics
+        self.z_proj1 = nn.Linear(self.token_embed_dim, encoder_embed_dim, bias=True)
+        self.z_proj2 = nn.Linear(self.token_embed_dim, encoder_embed_dim, bias=True)
+        self.z_proj_ln = nn.LayerNorm(encoder_embed_dim, eps=1e-6)
+        self.buffer_size = buffer_size
+        self.encoder_pos_embed_learned = nn.Parameter(torch.zeros(1, 2 * self.seq_len, encoder_embed_dim))
+
+        self.encoder_blocks = nn.ModuleList([
+            Block(encoder_embed_dim, encoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
+                  proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(encoder_depth)])
+        self.encoder_norm = norm_layer(encoder_embed_dim)
+
+        # --------------------------------------------------------------------------
+        # MAR decoder specifics
+        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
+        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
+        self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, 2 * self.seq_len, decoder_embed_dim))
+
+        self.decoder_blocks = nn.ModuleList([
+            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
+                  proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(decoder_depth)])
+
+        self.decoder_norm = norm_layer(decoder_embed_dim)
+        self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, 2 * self.seq_len, decoder_embed_dim))
+
+        self.initialize_weights()
+
+        # --------------------------------------------------------------------------
+        # Diffusion Loss
+        self.diffloss = DiffLoss(
+            target_channels=self.token_embed_dim,
+            z_channels=decoder_embed_dim,
+            width=diffloss_w,
+            depth=diffloss_d,
+            num_sampling_steps=num_sampling_steps,
+            grad_checkpointing=grad_checkpointing
+        )
+        self.diffusion_batch_mul = diffusion_batch_mul
+
+    def initialize_weights(self):
+        # parameters
+        # torch.nn.init.normal_(self.class_emb.weight, std=.02)
+        # torch.nn.init.normal_(self.fake_latent, std=.02)
+        torch.nn.init.normal_(self.mask_token, std=.02)
+        torch.nn.init.normal_(self.encoder_pos_embed_learned, std=.02)
+        torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
+        torch.nn.init.normal_(self.diffusion_pos_embed_learned, std=.02)
+
+        # initialize nn.Linear and nn.LayerNorm
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            # we use xavier_uniform following official JAX ViT:
+            torch.nn.init.xavier_uniform_(m.weight)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+            if m.weight is not None:
+                nn.init.constant_(m.weight, 1.0)
+
+    def patchify(self, x):
+        bsz, c, h, w = x.shape
+        p = self.patch_size
+        h_, w_ = h // p, w // p
+
+        x = x.reshape(bsz, c, h_, p, w_, p)
+        x = torch.einsum('nchpwq->nhwcpq', x)
+        x = x.reshape(bsz, h_ * w_, c * p ** 2)
+        return x  # [n, l, d]
+
+    def unpatchify(self, x):
+        bsz = x.shape[0]
+        p = self.patch_size
+        c = self.vae_embed_dim
+        h_, w_ = self.seq_h, self.seq_w
+
+        x = x.reshape(bsz, h_, w_, c, p, p)
+        x = torch.einsum('nhwcpq->nchpwq', x)
+        x = x.reshape(bsz, c, h_ * p, w_ * p)
+        return x  # [n, c, h, w]
+
+    def sample_orders(self, bsz):
+        # generate a batch of random generation orders
+        orders = []
+        for _ in range(bsz):
+            order = np.array(list(range(self.seq_len)))
+            np.random.shuffle(order)
+            orders.append(order)
+        orders = torch.Tensor(np.array(orders)).cuda().long()
+        return orders
+
+    def random_masking(self, x, orders):
+        # generate token mask
+        bsz, seq_len, embed_dim = x.shape
+        mask_rate = self.mask_ratio_generator.rvs(1)[0]
+        num_masked_tokens = int(np.ceil(seq_len * mask_rate))
+        mask = torch.zeros(bsz, seq_len, device=x.device)
+        mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens],
+                             src=torch.ones(bsz, seq_len, device=x.device))
+        return mask
+
+    def forward_mae_encoder(self, x, mask, y):
+        x = self.z_proj1(x)
+        y = self.z_proj2(y)
+        bsz, seq_len, embed_dim = y.shape
+
+        # concat buffer
+        x = torch.cat([x, y], dim=1)
+        mask_with_buffer = mask  # torch.cat([torch.zeros(y.size(0), self.seq_len, device=y.device), mask], dim=1)
+
+        # # random drop class embedding during training
+        # if self.training:
+        #     drop_latent_mask = torch.rand(bsz) < self.label_drop_prob
+        #     drop_latent_mask = drop_latent_mask.unsqueeze(-1).cuda().to(x.dtype)
+        #     class_embedding = drop_latent_mask * self.fake_latent + (1 - drop_latent_mask) * class_embedding
+
+        # x[:, :self.buffer_size] = class_embedding.unsqueeze(1)
+
+        # encoder position embedding
+        x = x + self.encoder_pos_embed_learned
+        x = self.z_proj_ln(x)
+
+        # dropping
+        x = x[(1 - mask_with_buffer).nonzero(as_tuple=True)].reshape(bsz, -1, embed_dim)
+
+        # apply Transformer blocks
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            for block in self.encoder_blocks:
+                x = checkpoint(block, x)
+        else:
+            for block in self.encoder_blocks:
+                x = block(x)
+        x = self.encoder_norm(x)
+
+        return x
+
+    def forward_mae_decoder(self, x, mask):
+
+        x = self.decoder_embed(x)
+        mask_with_buffer = mask  # torch.cat([torch.zeros(x.size(0), self.seq_len, device=x.device), mask], dim=1)
+
+        # pad mask tokens
+        mask_tokens = self.mask_token.repeat(mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype)
+        x_after_pad = mask_tokens.clone()
+        x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
+
+        # decoder position embedding
+        x = x_after_pad + self.decoder_pos_embed_learned
+
+        # apply Transformer blocks
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            for block in self.decoder_blocks:
+                x = checkpoint(block, x)
+        else:
+            for block in self.decoder_blocks:
+                x = block(x)
+        x = self.decoder_norm(x)
+
+        # x = x[:, self.seq_len:]
+        x = x + self.diffusion_pos_embed_learned
+        return x
+
+    def forward_loss(self, z, target, mask):
+        bsz, seq_len, _ = target.shape
+        target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
+        z = z.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
+        mask = mask.reshape(bsz * seq_len).repeat(self.diffusion_batch_mul)
+        loss = self.diffloss(z=z, target=target, mask=mask)
+        return loss
+
+    def forward(self, imgs, labels):
+
+        # class embed
+        # class_embedding = self.class_emb(labels)
+
+        # patchify and mask (drop) tokens
+        x = self.patchify(imgs)
+        y = self.patchify(labels)
+        gt_latents = torch.cat([x, y], dim=1).clone().detach()
+        orders = self.sample_orders(bsz=y.size(0))
+        mask = self.random_masking(x, orders)
+        mask = torch.cat([torch.zeros(y.size(0), self.seq_len).cuda(), mask], dim=1)
+        # mask = torch.cat([torch.zeros(y.size(0), self.seq_len), torch.ones(y.size(0), self.seq_len)], dim=1)
+
+        # mae encoder
+        x = self.forward_mae_encoder(x, mask, y)
+
+        # mae decoder
+        z = self.forward_mae_decoder(x, mask)
+
+        # diffloss
+        loss = self.forward_loss(z=z, target=gt_latents, mask=mask)
+
+        return loss
+
+    def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):
+
+        # init and sample generation orders
+        mask = torch.ones(bsz, self.seq_len).cuda()
+        tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).cuda()
+        orders = self.sample_orders(bsz)
+
+        indices = list(range(num_iter))
+        if progress:
+            indices = tqdm(indices)
+        # generate latents
+        for step in indices:
+            cur_tokens = tokens.clone()
+
+            # class embedding and CFG
+            if labels is not None:
+                class_embedding = self.class_emb(labels)
+            else:
+                class_embedding = self.fake_latent.repeat(bsz, 1)
+            if not cfg == 1.0:
+                tokens = torch.cat([tokens, tokens], dim=0)
+                class_embedding = torch.cat([class_embedding, self.fake_latent.repeat(bsz, 1)], dim=0)
+                mask = torch.cat([mask, mask], dim=0)
+
+            # mae encoder
+            x = self.forward_mae_encoder(tokens, mask, class_embedding)
+
+            # mae decoder
+            z = self.forward_mae_decoder(x, mask)
+
+            # mask ratio for the next round, following MaskGIT and MAGE.
+            mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
+            mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).cuda()
+
+            # masks out at least one for the next iteration
+            mask_len = torch.maximum(torch.Tensor([1]).cuda(),
+                                     torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))
+
+            # get masking for next iteration and locations to be predicted in this iteration
+            mask_next = mask_by_order(mask_len[0], orders, bsz, self.seq_len)
+            if step >= num_iter - 1:
+                mask_to_pred = mask[:bsz].bool()
+            else:
+                mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool())
+            mask = mask_next
+            if not cfg == 1.0:
+                mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)
+
+            # sample token latents for this step
+            z = z[mask_to_pred.nonzero(as_tuple=True)]
+            # cfg schedule follow Muse
+            if cfg_schedule == "linear":
+                cfg_iter = 1 + (cfg - 1) * (self.seq_len - mask_len[0]) / self.seq_len
+            elif cfg_schedule == "constant":
+                cfg_iter = cfg
+            else:
+                raise NotImplementedError
+            sampled_token_latent = self.diffloss.sample(z, temperature, cfg_iter)
+            if not cfg == 1.0:
+                sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)  # Remove null class samples
+                mask_to_pred, _ = mask_to_pred.chunk(2, dim=0)
+
+            cur_tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent
+            tokens = cur_tokens.clone()
+
+        # unpatchify
+        tokens = self.unpatchify(tokens)
+        return tokens
+
+
+def mar_base(**kwargs):
+    model = MAR(
+        encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12,
+        decoder_embed_dim=768, decoder_depth=12, decoder_num_heads=12,
+        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+def mar_large(**kwargs):
+    model = MAR(
+        encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
+        decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
+        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+def mar_huge(**kwargs):
+    model = MAR(
+        encoder_embed_dim=1280, encoder_depth=20, encoder_num_heads=16,
+        decoder_embed_dim=1280, decoder_depth=20, decoder_num_heads=16,
+        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
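
A minimal training-side sketch (illustrative; requires CUDA since the model calls `.cuda()` internally). In this variant, `labels` in `forward` is a second VAE latent map used as conditioning and is patchified just like `imgs`; note that `sample_tokens` still references `self.class_emb` and `self.fake_latent`, which are commented out in `__init__` above, so the sampling path assumes those are restored:

import torch
from models.mar import mar_base

model = mar_base().cuda()                   # defaults: img_size=256, vae_stride=16, patch_size=1
imgs = torch.randn(2, 16, 16, 16).cuda()    # VAE latents of the target images
conds = torch.randn(2, 16, 16, 16).cuda()   # VAE latents of the conditioning images
loss = model(imgs, conds)
loss.backward()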
models/vae.py ADDED
@@ -0,0 +1,67 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+from ldm.modules.diffusionmodules.model import Encoder, Decoder
+from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+
+from ldm.util import instantiate_from_config
+
+
+class AutoencoderKL(nn.Module):
+    def __init__(self,
+                 ddconfig,
+                 embed_dim,
+                 ckpt_path=None,
+                 ignore_keys=[],
+                 image_key="image",
+                 colorize_nlabels=None,
+                 monitor=None,
+                 ):
+        super().__init__()
+        self.image_key = image_key
+        self.encoder = Encoder(**ddconfig)
+        self.decoder = Decoder(**ddconfig)
+        assert ddconfig["double_z"]
+        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
+        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+        self.embed_dim = embed_dim
+        if colorize_nlabels is not None:
+            assert type(colorize_nlabels) == int
+            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+        if monitor is not None:
+            self.monitor = monitor
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+    def init_from_ckpt(self, path, ignore_keys=list()):
+        sd = torch.load(path, map_location="cpu")["state_dict"]
+        keys = list(sd.keys())
+        for k in keys:
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    print("Deleting key {} from state_dict.".format(k))
+                    del sd[k]
+        self.load_state_dict(sd, strict=False)
+        print(f"Restored from {path}")
+
+    def encode(self, x):
+        h = self.encoder(x)
+        moments = self.quant_conv(h)
+        posterior = DiagonalGaussianDistribution(moments)
+        return posterior
+
+    def decode(self, z):
+        z = self.post_quant_conv(z)
+        dec = self.decoder(z)
+        return dec
+
+    def forward(self, input, sample_posterior=True):
+        posterior = self.encode(input)
+        if sample_posterior:
+            z = posterior.sample()
+        else:
+            z = posterior.mode()
+        dec = self.decode(z)
+        return dec, posterior
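
A minimal encode/decode sketch (illustrative; it assumes the `ldm` package is installed, and the `ddconfig` values below are borrowed from the LDM KL-f16 autoencoder config — treat them as an assumption. The checkpoint path matches util/download.py below):

import torch
from models.vae import AutoencoderKL

ddconfig = dict(double_z=True, z_channels=16, resolution=256, in_channels=3,
                out_ch=3, ch=128, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2,
                attn_resolutions=(16,), dropout=0.0)  # assumed KL-f16 config
vae = AutoencoderKL(ddconfig, embed_dim=16,
                    ckpt_path="pretrained_models/vae/kl16.ckpt")
x = torch.randn(1, 3, 256, 256)
z = vae.encode(x).sample()   # (1, 16, 16, 16) latent
recon = vae.decode(z)        # (1, 3, 256, 256)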
taming/modules/autoencoder/lpips/vgg.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
+size 7289
util/crop.py ADDED
@@ -0,0 +1,23 @@
+import numpy as np
+from PIL import Image
+
+
+def center_crop_arr(pil_image, image_size):
+    """
+    Center cropping implementation from ADM.
+    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
+    """
+    while min(*pil_image.size) >= 2 * image_size:
+        pil_image = pil_image.resize(
+            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+        )
+
+    scale = image_size / min(*pil_image.size)
+    pil_image = pil_image.resize(
+        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+    )
+
+    arr = np.array(pil_image)
+    crop_y = (arr.shape[0] - image_size) // 2
+    crop_x = (arr.shape[1] - image_size) // 2
+    return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
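
Usage is straightforward (sketch; the file name is hypothetical):

from PIL import Image
from util.crop import center_crop_arr

img = Image.open("example.jpg")              # any input image (hypothetical path)
img = center_crop_arr(img, image_size=256)   # resized and center-cropped to 256x256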
util/download.py ADDED
@@ -0,0 +1,62 @@
+import os
+from tqdm import tqdm
+import requests
+
+
+def download_pretrained_vae(overwrite=False):
+    download_path = "pretrained_models/vae/kl16.ckpt"
+    if not os.path.exists(download_path) or overwrite:
+        headers = {'user-agent': 'Wget/1.16 (linux-gnu)'}
+        os.makedirs("pretrained_models/vae", exist_ok=True)
+        r = requests.get("https://www.dropbox.com/scl/fi/hhmuvaiacrarfg28qxhwz/kl16.ckpt?rlkey=l44xipsezc8atcffdp4q7mwmh&dl=0", stream=True, headers=headers)
+        print("Downloading KL-16 VAE...")
+        with open(download_path, 'wb') as f:
+            for chunk in tqdm(r.iter_content(chunk_size=1024*1024), unit="MB", total=254):
+                if chunk:
+                    f.write(chunk)
+
+
+def download_pretrained_marb(overwrite=False):
+    download_path = "pretrained_models/mar/mar_base/checkpoint-last.pth"
+    if not os.path.exists(download_path) or overwrite:
+        headers = {'user-agent': 'Wget/1.16 (linux-gnu)'}
+        os.makedirs("pretrained_models/mar/mar_base", exist_ok=True)
+        r = requests.get("https://www.dropbox.com/scl/fi/f6dpuyjb7fudzxcyhvrhk/checkpoint-last.pth?rlkey=a6i4bo71vhfo4anp33n9ukujb&dl=0", stream=True, headers=headers)
+        print("Downloading MAR-B...")
+        with open(download_path, 'wb') as f:
+            for chunk in tqdm(r.iter_content(chunk_size=1024*1024), unit="MB", total=1587):
+                if chunk:
+                    f.write(chunk)
+
+
+def download_pretrained_marl(overwrite=False):
+    download_path = "pretrained_models/mar/mar_large/checkpoint-last.pth"
+    if not os.path.exists(download_path) or overwrite:
+        headers = {'user-agent': 'Wget/1.16 (linux-gnu)'}
+        os.makedirs("pretrained_models/mar/mar_large", exist_ok=True)
+        r = requests.get("https://www.dropbox.com/scl/fi/pxacc5b2mrt3ifw4cah6k/checkpoint-last.pth?rlkey=m48ovo6g7ivcbosrbdaz0ehqt&dl=0", stream=True, headers=headers)
+        print("Downloading MAR-L...")
+        with open(download_path, 'wb') as f:
+            for chunk in tqdm(r.iter_content(chunk_size=1024*1024), unit="MB", total=3650):
+                if chunk:
+                    f.write(chunk)
+
+
+def download_pretrained_marh(overwrite=False):
+    download_path = "pretrained_models/mar/mar_huge/checkpoint-last.pth"
+    if not os.path.exists(download_path) or overwrite:
+        headers = {'user-agent': 'Wget/1.16 (linux-gnu)'}
+        os.makedirs("pretrained_models/mar/mar_huge", exist_ok=True)
+        r = requests.get("https://www.dropbox.com/scl/fi/1qmfx6fpy3k7j9vcjjs3s/checkpoint-last.pth?rlkey=4lae281yzxb406atp32vzc83o&dl=0", stream=True, headers=headers)
+        print("Downloading MAR-H...")
+        with open(download_path, 'wb') as f:
+            for chunk in tqdm(r.iter_content(chunk_size=1024*1024), unit="MB", total=7191):
+                if chunk:
+                    f.write(chunk)
+
+
+if __name__ == "__main__":
+    download_pretrained_vae()
+    download_pretrained_marb()
+    download_pretrained_marl()
+    download_pretrained_marh()
util/loader.py ADDED
@@ -0,0 +1,56 @@
+import os
+import numpy as np
+
+import torch
+import torchvision.datasets as datasets
+
+
+class ImageFolderWithFilename(datasets.ImageFolder):
+    def __getitem__(self, index: int):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (sample, target, filename).
+        """
+        path, target = self.samples[index]
+        sample = self.loader(path)
+        if self.transform is not None:
+            sample = self.transform(sample)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        filename = path.split(os.path.sep)[-2:]
+        filename = os.path.join(*filename)
+        return sample, target, filename
+
+
+class CachedFolder(datasets.DatasetFolder):
+    def __init__(
+        self,
+        root: str,
+    ):
+        super().__init__(
+            root,
+            loader=None,
+            extensions=(".npz",),
+        )
+
+    def __getitem__(self, index: int):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (moments, target).
+        """
+        path, target = self.samples[index]
+
+        data = np.load(path)
+        if torch.rand(1) < 0.5:  # randomly hflip
+            moments = data['moments']
+        else:
+            moments = data['moments_flip']
+
+        return moments, target
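
A usage sketch for `CachedFolder` (illustrative; the root is a hypothetical directory of pre-computed `.npz` latents arranged in class subfolders, each file holding 'moments' and 'moments_flip' arrays):

from torch.utils.data import DataLoader
from util.loader import CachedFolder

dataset = CachedFolder("cached_latents/")   # hypothetical cache directory
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)
moments, target = next(iter(loader))        # pre-computed VAE moments + class index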
util/lr_sched.py ADDED
@@ -0,0 +1,21 @@
+import math
+
+
+def adjust_learning_rate(optimizer, epoch, args):
+    """Decay the learning rate with half-cycle cosine after warmup"""
+    if epoch < args.warmup_epochs:
+        lr = args.lr * epoch / args.warmup_epochs
+    else:
+        if args.lr_schedule == "constant":
+            lr = args.lr
+        elif args.lr_schedule == "cosine":
+            lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
+                (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
+        else:
+            raise NotImplementedError
+    for param_group in optimizer.param_groups:
+        if "lr_scale" in param_group:
+            param_group["lr"] = lr * param_group["lr_scale"]
+        else:
+            param_group["lr"] = lr
+    return lr
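
A sketch of how this scheduler is driven (illustrative values; `epoch` may be fractional for per-iteration updates):

import argparse
import torch
from util.lr_sched import adjust_learning_rate

args = argparse.Namespace(lr=1e-4, min_lr=0.0, warmup_epochs=10,
                          epochs=100, lr_schedule="cosine")
opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)
for epoch in range(args.epochs):
    lr = adjust_learning_rate(opt, epoch, args)   # linear warmup, then cosine decay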
util/misc.py ADDED
@@ -0,0 +1,340 @@
+import builtins
+import datetime
+import os
+import time
+from collections import defaultdict, deque
+from pathlib import Path
+
+import torch
+import torch.distributed as dist
+TORCH_MAJOR = int(torch.__version__.split('.')[0])
+TORCH_MINOR = int(torch.__version__.split('.')[1])
+
+if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
+    from torch._six import inf
+else:
+    from torch import inf
+import copy
+
+
+class SmoothedValue(object):
+    """Track a series of values and provide access to smoothed values over a
+    window or the global series average.
+    """
+
+    def __init__(self, window_size=20, fmt=None):
+        if fmt is None:
+            fmt = "{median:.4f} ({global_avg:.4f})"
+        self.deque = deque(maxlen=window_size)
+        self.total = 0.0
+        self.count = 0
+        self.fmt = fmt
+
+    def update(self, value, n=1):
+        self.deque.append(value)
+        self.count += n
+        self.total += value * n
+
+    def synchronize_between_processes(self):
+        """
+        Warning: does not synchronize the deque!
+        """
+        if not is_dist_avail_and_initialized():
+            return
+        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+        dist.barrier()
+        dist.all_reduce(t)
+        t = t.tolist()
+        self.count = int(t[0])
+        self.total = t[1]
+
+    @property
+    def median(self):
+        d = torch.tensor(list(self.deque))
+        return d.median().item()
+
+    @property
+    def avg(self):
+        d = torch.tensor(list(self.deque), dtype=torch.float32)
+        return d.mean().item()
+
+    @property
+    def global_avg(self):
+        return self.total / self.count
+
+    @property
+    def max(self):
+        return max(self.deque)
+
+    @property
+    def value(self):
+        return self.deque[-1]
+
+    def __str__(self):
+        return self.fmt.format(
+            median=self.median,
+            avg=self.avg,
+            global_avg=self.global_avg,
+            max=self.max,
+            value=self.value)
+
+
+class MetricLogger(object):
+    def __init__(self, delimiter="\t"):
+        self.meters = defaultdict(SmoothedValue)
+        self.delimiter = delimiter
+
+    def update(self, **kwargs):
+        for k, v in kwargs.items():
+            if v is None:
+                continue
+            if isinstance(v, torch.Tensor):
+                v = v.item()
+            assert isinstance(v, (float, int))
+            self.meters[k].update(v)
+
+    def __getattr__(self, attr):
+        if attr in self.meters:
+            return self.meters[attr]
+        if attr in self.__dict__:
+            return self.__dict__[attr]
+        raise AttributeError("'{}' object has no attribute '{}'".format(
+            type(self).__name__, attr))
+
+    def __str__(self):
+        loss_str = []
+        for name, meter in self.meters.items():
+            loss_str.append(
+                "{}: {}".format(name, str(meter))
+            )
+        return self.delimiter.join(loss_str)
+
+    def synchronize_between_processes(self):
+        for meter in self.meters.values():
+            meter.synchronize_between_processes()
+
+    def add_meter(self, name, meter):
+        self.meters[name] = meter
+
+    def log_every(self, iterable, print_freq, header=None):
+        i = 0
+        if not header:
+            header = ''
+        start_time = time.time()
+        end = time.time()
+        iter_time = SmoothedValue(fmt='{avg:.4f}')
+        data_time = SmoothedValue(fmt='{avg:.4f}')
+        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+        log_msg = [
+            header,
+            '[{0' + space_fmt + '}/{1}]',
+            'eta: {eta}',
+            '{meters}',
+            'time: {time}',
+            'data: {data}'
+        ]
+        if torch.cuda.is_available():
+            log_msg.append('max mem: {memory:.0f}')
+        log_msg = self.delimiter.join(log_msg)
+        MB = 1024.0 * 1024.0
+        for obj in iterable:
+            data_time.update(time.time() - end)
+            yield obj
+            iter_time.update(time.time() - end)
+            if i % print_freq == 0 or i == len(iterable) - 1:
+                eta_seconds = iter_time.global_avg * (len(iterable) - i)
+                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                if torch.cuda.is_available():
+                    print(log_msg.format(
+                        i, len(iterable), eta=eta_string,
+                        meters=str(self),
+                        time=str(iter_time), data=str(data_time),
+                        memory=torch.cuda.max_memory_allocated() / MB))
+                else:
+                    print(log_msg.format(
+                        i, len(iterable), eta=eta_string,
+                        meters=str(self),
+                        time=str(iter_time), data=str(data_time)))
+            i += 1
+            end = time.time()
+        total_time = time.time() - start_time
+        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+        print('{} Total time: {} ({:.4f} s / it)'.format(
+            header, total_time_str, total_time / len(iterable)))
+
+
+def setup_for_distributed(is_master):
+    """
+    This function disables printing when not in master process
+    """
+    builtin_print = builtins.print
+
+    def print(*args, **kwargs):
+        force = kwargs.pop('force', False)
+        force = force or (get_world_size() > 8)
+        if is_master or force:
+            now = datetime.datetime.now().time()
+            builtin_print('[{}] '.format(now), end='')  # print with time stamp
+            builtin_print(*args, **kwargs)
+
+    builtins.print = print
+
+
+def is_dist_avail_and_initialized():
+    if not dist.is_available():
+        return False
+    if not dist.is_initialized():
+        return False
+    return True
+
+
+def get_world_size():
+    if not is_dist_avail_and_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+def get_rank():
+    if not is_dist_avail_and_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def is_main_process():
+    return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+    if is_main_process():
+        torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+    if args.dist_on_itp:
+        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
+        os.environ['LOCAL_RANK'] = str(args.gpu)
+        os.environ['RANK'] = str(args.rank)
+        os.environ['WORLD_SIZE'] = str(args.world_size)
+        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
+    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+        args.rank = int(os.environ["RANK"])
+        args.world_size = int(os.environ['WORLD_SIZE'])
+        args.gpu = int(os.environ['LOCAL_RANK'])
+    elif 'SLURM_PROCID' in os.environ:
+        args.rank = int(os.environ['SLURM_PROCID'])
+        args.gpu = args.rank % torch.cuda.device_count()
+    else:
+        print('Not using distributed mode')
+        setup_for_distributed(is_master=True)  # hack
+        args.distributed = False
+        return
+
+    args.distributed = True
+
+    torch.cuda.set_device(args.gpu)
+    args.dist_backend = 'nccl'
+    print('| distributed init (rank {}): {}, gpu {}'.format(
+        args.rank, args.dist_url, args.gpu), flush=True)
+    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                         world_size=args.world_size, rank=args.rank)
+    torch.distributed.barrier()
+    setup_for_distributed(args.rank == 0)
+
+
+class NativeScalerWithGradNormCount:
+    state_dict_key = "amp_scaler"
+
+    def __init__(self):
+        self._scaler = torch.cuda.amp.GradScaler()
+
+    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
+        self._scaler.scale(loss).backward(create_graph=create_graph)
+        if update_grad:
+            if clip_grad is not None:
+                assert parameters is not None
+                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
+                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
+            else:
+                self._scaler.unscale_(optimizer)
+                norm = get_grad_norm_(parameters)
+            self._scaler.step(optimizer)
+            self._scaler.update()
+        else:
+            norm = None
+        return norm
+
+    def state_dict(self):
+        return self._scaler.state_dict()
+
+    def load_state_dict(self, state_dict):
+        self._scaler.load_state_dict(state_dict)
+
+
+def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    parameters = [p for p in parameters if p.grad is not None]
+    norm_type = float(norm_type)
+    if len(parameters) == 0:
+        return torch.tensor(0.)
+    device = parameters[0].grad.device
+    if norm_type == inf:
+        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
+    else:
+        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
+    return total_norm
+
+
+def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
+    decay = []
+    no_decay = []
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            continue  # frozen weights
+        if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list or 'diffloss' in name:
+            no_decay.append(param)  # no weight decay on bias, norm and diffloss
+        else:
+            decay.append(param)
+    return [
+        {'params': no_decay, 'weight_decay': 0.},
+        {'params': decay, 'weight_decay': weight_decay}]
+
+
+def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, ema_params=None, epoch_name=None):
+    if epoch_name is None:
+        epoch_name = str(epoch)
+    output_dir = Path(args.output_dir)
+    checkpoint_path = output_dir / ('checkpoint-%s.pth' % epoch_name)
+
+    # ema
+    if ema_params is not None:
+        ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
+        for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
+            assert name in ema_state_dict
+            ema_state_dict[name] = ema_params[i]
+    else:
+        ema_state_dict = None
+
+    to_save = {
+        'model': model_without_ddp.state_dict(),
+        'model_ema': ema_state_dict,
+        'optimizer': optimizer.state_dict(),
+        'epoch': epoch,
+        'scaler': loss_scaler.state_dict(),
+        'args': args,
+    }
+    save_on_master(to_save, checkpoint_path)
+
+
+def all_reduce_mean(x):
+    world_size = get_world_size()
+    if world_size > 1:
+        x_reduce = torch.tensor(x).cuda()
+        dist.all_reduce(x_reduce)
+        x_reduce /= world_size
+        return x_reduce.item()
+    else:
+        return x
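
A minimal AMP training-step sketch showing how `add_weight_decay` and `NativeScalerWithGradNormCount` fit together (illustrative; requires CUDA):

import torch
from util.misc import NativeScalerWithGradNormCount, add_weight_decay

model = torch.nn.Linear(8, 8).cuda()
param_groups = add_weight_decay(model, weight_decay=0.05)   # biases/norms get wd=0
optimizer = torch.optim.AdamW(param_groups, lr=1e-4)
loss_scaler = NativeScalerWithGradNormCount()

x = torch.randn(4, 8).cuda()
with torch.cuda.amp.autocast():
    loss = model(x).pow(2).mean()
optimizer.zero_grad()
grad_norm = loss_scaler(loss, optimizer, clip_grad=1.0,
                        parameters=model.parameters(), update_grad=True)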