# File size: 11,287 Bytes
# ea88892
import torch
import math
import numpy as np
import torch.nn as nn
from einops import rearrange
from transformers.cache_utils import DynamicCache
from src.builder import BUILDER
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence
def build_mlp(hidden_size: int, projector_dim: int, z_dim: int) -> nn.Sequential:
    """Build a two-layer projector: Linear -> SiLU -> Linear.

    Args:
        hidden_size: Feature size of the input.
        projector_dim: Width of the hidden layer.
        z_dim: Feature size of the output.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    return nn.Sequential(
        nn.Linear(hidden_size, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, z_dim),
    )
def mask_by_order(mask_len, order, bsz, seq_len):
    """Return a boolean mask marking the first ``mask_len`` positions of ``order``.

    Args:
        mask_len: Number of leading entries of every row of ``order`` to mark.
            May be a Python number or a one-element tensor; it is truncated
            to an integer.
        order: Integer index tensor with one row of positions per sample
            (indexed as ``order[:, :mask_len]``).
        bsz: Number of rows of the returned mask.
        seq_len: Number of columns of the returned mask.

    Returns:
        A ``(bsz, seq_len)`` bool tensor on ``order.device`` that is ``True``
        exactly at the positions listed in ``order[:, :mask_len]``.
    """
    num_masked = int(mask_len)
    masking = torch.zeros(bsz, seq_len, device=order.device)
    # Scatter a scalar instead of materialising a full tensor of ones.
    masking.scatter_(-1, order[:, :num_masked], 1.0)
    return masking.bool()
class SkyworkUnipic(nn.Module):
    """Image generation model combining a VAE, an LLM and a MAR model.

    All sub-modules are created from config objects through ``BUILDER.build``.
    Image latents are encoded by the MAR encoder, projected into the LLM
    hidden size (``proj_in``), run through the LLM together with the text
    context, projected back (``proj_out``) and added to the MAR encoder output
    as a residual before the MAR decoder / diffusion head samples tokens.
    """

    def __init__(self, vae, vae_scale, llm, mar, tokenizer, prompt_template, siglip2):
        """Build all sub-modules.

        Args:
            vae: Config passed to ``BUILDER.build``; the result is frozen.
            vae_scale: Factor applied to VAE latents in ``encode`` and undone
                in ``decode``.
            llm: Config for the language model.
            mar: Config for the MAR model.
            tokenizer: Config for the tokenizer; an ``<image>`` special token
                is added to it.
            prompt_template: Mapping with an ``"INSTRUCTION"`` format string
                that takes an ``input`` field.
            siglip2: Config for the SigLIP2 encoder.
        """
        super().__init__()

        # VAE (kept frozen)
        self.vae = BUILDER.build(vae)
        self.vae.requires_grad_(False)
        self.vae_scale = vae_scale

        # LLM
        self.llm = BUILDER.build(llm)
        self.tokenizer = BUILDER.build(tokenizer)
        self.tokenizer.add_tokens(["<image>"], special_tokens=True)
        self.image_token_idx = self.tokenizer.convert_tokens_to_ids("<image>")
        self.prompt_template = prompt_template

        # MAR
        self.mar = BUILDER.build(mar)

        # projection layers between the MAR encoder width and the LLM width
        llm_dim = self.llm.config.hidden_size
        mar_dim = self.mar.encoder_embed_dim
        self.proj_in = build_mlp(
            hidden_size=mar_dim,
            projector_dim=llm_dim,
            z_dim=llm_dim,
        )
        self.proj_out = build_mlp(
            hidden_size=llm_dim,
            projector_dim=llm_dim,
            z_dim=mar_dim,
        )

        # siglip
        self.siglip2 = BUILDER.build(siglip2)
        # NOTE(review): 1152 is presumably the SigLIP2 feature width; confirm
        # against the configured encoder.
        self.siglip2_proj = build_mlp(
            hidden_size=1152,
            projector_dim=llm_dim,
            z_dim=llm_dim,
        )

    @property
    def llm_model(self):
        """The inner ``model`` attribute of the LLM wrapper."""
        return self.llm.model

    @property
    def device(self):
        """Device reported by the LLM."""
        return self.llm.device

    @property
    def dtype(self):
        """Dtype reported by the LLM."""
        return self.llm.dtype

    @property
    def gen_seq_len(self):
        """Default number of image tokens (``mar.seq_len``)."""
        return self.mar.seq_len

    @property
    def token_embed_dim(self):
        """Channels per image token: VAE embed dim times patch area."""
        return self.vae.embed_dim * (self.mar.patch_size**2)

    @torch.no_grad()
    def encode(self, x):
        """Encode images to patchified, scaled VAE latents.

        Returns:
            Tensor laid out as ``b m n (c p q)`` where ``p = q`` is the MAR
            patch size.
        """
        posterior = self.vae.encode(x)
        # Out-of-place so the tensor owned by the posterior is not modified.
        z = posterior.mode() * self.vae_scale
        patch = self.mar.patch_size
        return rearrange(z, "b c (m p) (n q) -> b m n (c p q)", p=patch, q=patch)

    @torch.no_grad()
    def decode(self, z):
        """Decode patchified latents (``b m n (c p q)``) with the VAE.

        The input tensor is left untouched.
        """
        # Out-of-place division: callers may pass a view of a live buffer.
        z = z / self.vae_scale
        patch = self.mar.patch_size
        z = rearrange(z, "b m n (c p q) -> b c (m p) (n q)", p=patch, q=patch)
        return self.vae.decode(z)

    def prepare_forward_input(
        self,
        x,
        inputs_embeds=None,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
    ):
        """Assemble LLM inputs for context tokens followed by image tokens.

        Args:
            x: Image embeddings of shape ``(batch, num_image_tokens, dim)``.
            inputs_embeds: Context embeddings; only used without a cache.
            input_ids: Context token ids, embedded when ``inputs_embeds`` is
                missing and no cache is given.
            attention_mask: Mask over the context tokens only (required); it
                is extended with ones for the image tokens.
            past_key_values: Cache of the context. When given, only ``x`` is
                fed to the LLM and position ids are cut to the image part.

        Returns:
            Dict with ``inputs_embeds``, ``attention_mask``, ``position_ids``
            and ``past_key_values``.

        Raises:
            ValueError: If ``attention_mask`` is ``None``.
        """
        if attention_mask is None:
            raise ValueError("`attention_mask` for the context tokens is required.")

        batch, num_image_tokens, _ = x.shape
        context_len = attention_mask.shape[1]
        attention_mask = attention_mask.to(device=self.device, dtype=torch.bool)
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones(batch, num_image_tokens)], dim=1
        )
        # Positions count attended tokens only; leading padding stays at 0.
        position_ids = (torch.cumsum(attention_mask, dim=1) - 1).clamp_(min=0)

        if past_key_values is not None:
            # The context is already cached: feed the image tokens only.
            inputs_embeds = x
            position_ids = position_ids[:, context_len:]
        else:
            if inputs_embeds is None:
                input_ids = input_ids.to(self.device)
                inputs_embeds = self.llm.get_input_embeddings()(input_ids)
            inputs_embeds = torch.cat([inputs_embeds, x], dim=1)

        return dict(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
        )

    def extract_visual_feature(self, x, mask=None, detach=False):
        """Run the MAR encoder and project its output to the LLM width.

        Args:
            x: Latents of shape ``(b, m, n, c)``.
            mask: Per-token mask of shape ``(b, m * n)``; defaults to zeros.
            detach: Detach both outputs from the autograd graph.

        Returns:
            ``(x_enc, z_enc)``: the raw MAR encoder output and its projection,
            the latter with the first ``mar.buffer_size`` tokens moved to the
            end of the sequence.
        """
        b, m, n, _ = x.shape
        x = x.reshape(b, m * n, -1)  # b (m n) c
        if mask is None:
            mask = torch.zeros_like(x[..., 0])
        null_embeds = self.mar.fake_latent.expand(b, -1)
        x_enc = self.mar.forward_mae_encoder(x, mask, null_embeds, image_shape=(m, n))
        z_enc = self.proj_in(x_enc)

        # Move buffers to the end of the image sequence
        num_buffer = self.mar.buffer_size
        z_enc = torch.cat([z_enc[:, num_buffer:], z_enc[:, :num_buffer]], dim=1)

        if detach:
            x_enc = x_enc.detach()
            z_enc = z_enc.detach()
        return x_enc, z_enc

    def forward_mae_encoder(self, x, mask, detach=False, **context):
        """Encode image latents and refine them with the LLM.

        Args:
            x: Latents of shape ``(b, m, n, c)``.
            mask: Per-token mask forwarded to ``extract_visual_feature``.
            detach: Forwarded to ``extract_visual_feature``.
            **context: Keyword arguments for ``prepare_forward_input``
                (``attention_mask``, ``past_key_values``, ...).

        Returns:
            MAR encoder output plus the projected LLM hidden states.
        """
        x_enc, z_enc = self.extract_visual_feature(x, mask=mask, detach=detach)
        llm_inputs = self.prepare_forward_input(x=z_enc, **context)
        output = self.llm_model(**llm_inputs, return_dict=True)
        z_llm = output.last_hidden_state[:, -z_enc.shape[1] :]

        # move buffers back to the start of the image sequence
        num_buffer = self.mar.buffer_size
        z_llm = torch.cat([z_llm[:, -num_buffer:], z_llm[:, :-num_buffer]], dim=1)

        # residual learning
        return x_enc + self.proj_out(z_llm)

    @staticmethod
    def curtail_cache(past_key_values, cur_len):
        """Truncate every cached key/value tensor to ``cur_len`` in place.

        Iterating ``past_key_values`` must yield ``(keys, values)`` pairs; the
        tensors are cut along dimension 2.
        """
        for keys, values in past_key_values:
            keys.data = keys.data[:, :, :cur_len]
            values.data = values.data[:, :, :cur_len]

    @torch.no_grad()
    def prepare_text_conditions(self, prompt, cfg_prompt="Generate an image."):
        """Tokenize the prompt and the guidance prompt into one padded batch.

        Row 0 holds ``prompt`` and row 1 holds ``cfg_prompt``, both formatted
        with ``prompt_template["INSTRUCTION"]``. Sequences are right-padded
        with the tokenizer's EOS id.

        Returns:
            Dict with ``input_ids`` and a boolean ``attention_mask`` on
            ``self.device``.
        """
        template = self.prompt_template["INSTRUCTION"]
        token_seqs = [
            self.tokenizer.encode(
                template.format(input=text),
                add_special_tokens=True,
                return_tensors="pt",
            )[0]
            for text in (prompt, cfg_prompt)
        ]
        valid_lens = torch.tensor([len(seq) for seq in token_seqs])
        input_ids = pad_sequence(
            token_seqs, batch_first=True, padding_value=self.tokenizer.eos_token_id
        )
        # True for real tokens, False for the right padding.
        positions = torch.arange(input_ids.shape[1])
        attention_mask = positions[None, :] < valid_lens[:, None]

        return dict(
            input_ids=input_ids.to(self.device),
            attention_mask=attention_mask.to(self.device),
        )

    @torch.no_grad()
    def sample(
        self,
        input_ids=None,
        inputs_embeds=None,
        attention_mask=None,
        num_iter=64,
        cfg=1.0,
        cfg_schedule="constant",
        temperature=1.0,
        progress=False,
        mask=None,
        past_key_values=None,
        image_shape=None,
        x_con=None,
        **kwargs,
    ):
        """Iteratively sample image latents and decode them to images.

        Args:
            input_ids: Context token ids (used if ``inputs_embeds`` is None).
            inputs_embeds: Context embeddings.
            attention_mask: Mask over the context tokens (required); its first
                dimension defines the batch size.
            num_iter: Number of sampling iterations.
            cfg: Guidance scale. Any value other than ``1.0`` requires an even
                batch whose second half mirrors the first half's orders, masks
                and tokens; only the first half is returned.
            cfg_schedule: ``"constant"`` or ``"linear"``.
            temperature: Passed to ``mar.diffloss.sample``.
            progress: Wrap the iteration in a ``tqdm`` progress bar.
            mask: Optional initial mask, reshaped to ``(bsz, m * n)``; non-zero
                entries are the tokens to generate. Defaults to all ones.
            past_key_values: Optional pre-computed cache of the context
                (usually in multi-turn editing). It is computed from
                ``inputs_embeds`` when missing.
            image_shape: Token grid ``(m, n)``; defaults to a square grid with
                ``gen_seq_len`` tokens.
            x_con: Forwarded to ``mar.forward_mae_decoder``.
            **kwargs: Accepted and ignored.

        Returns:
            The output of ``decode`` for the sampled tokens.

        Raises:
            ValueError: If ``attention_mask`` is missing, or if neither a
                context nor a cache is supplied.
            NotImplementedError: If ``cfg_schedule`` is not supported.
        """
        if attention_mask is None:
            raise ValueError("`attention_mask` for the context tokens is required.")
        # Fail before the expensive prefill rather than inside the loop.
        if cfg_schedule not in ("linear", "constant"):
            raise NotImplementedError(f"Unsupported cfg_schedule: {cfg_schedule!r}")

        if inputs_embeds is None and input_ids is not None:
            inputs_embeds = self.llm.get_input_embeddings()(input_ids)
        if inputs_embeds is None and past_key_values is None:
            raise ValueError(
                "Provide `input_ids`, `inputs_embeds` or `past_key_values`."
            )

        use_cfg = cfg != 1.0
        bsz = attention_mask.shape[0]
        if use_cfg:
            assert bsz % 2 == 0, "cfg != 1.0 requires an even batch size"

        if image_shape is None:
            m = n = int(self.gen_seq_len**0.5)
        else:
            m, n = image_shape
        seq_len = m * n

        if mask is None:
            mask = torch.ones(bsz, seq_len, device=self.device, dtype=self.dtype)
        else:
            mask = mask.reshape(bsz, seq_len)

        tokens = torch.zeros(
            bsz, seq_len, self.token_embed_dim, device=self.device, dtype=self.dtype
        )
        orders = self.mar.sample_orders(bsz, seq_len=seq_len)
        if use_cfg:
            # Both halves must reveal tokens in the same order.
            orders[bsz // 2 :] = orders[: bsz // 2]

        steps = range(num_iter)
        if progress:
            steps = tqdm(steps)

        # Length the cache is cut back to after every step. With an external
        # cache and no embeddings, the context mask gives the same length.
        if inputs_embeds is not None:
            context_len = inputs_embeds.shape[1]
        else:
            context_len = attention_mask.shape[1]

        # past key values can be prepared outside (usually in multi-turn editing)
        if past_key_values is None:
            output = self.llm_model(
                inputs_embeds=inputs_embeds,
                attention_mask=None,
                position_ids=None,
                past_key_values=DynamicCache.from_legacy_cache(),
                return_dict=True,
                use_cache=True,
            )
            past_key_values = output.past_key_values

        one = torch.tensor([1.0], device=self.device)

        # generate latents
        for step in steps:
            mask_in = mask.to(self.dtype)
            x_enc = self.forward_mae_encoder(
                tokens.view(bsz, m, n, -1),
                mask_in,
                past_key_values=past_key_values,
                attention_mask=attention_mask,
            )
            # Drop the image tokens appended to the cache by this step.
            self.curtail_cache(past_key_values, context_len)

            z = self.mar.forward_mae_decoder(
                x_enc, mask_in, image_shape=(m, n), x_con=x_con
            )

            # mask ratio for the next round, following MaskGIT and MAGE.
            mask_ratio = np.cos(math.pi / 2.0 * (step + 1) / num_iter)
            mask_len = torch.tensor(
                [np.floor(seq_len * mask_ratio)],
                dtype=torch.float32,
                device=self.device,
            )
            # masks out at least one for the next iteration
            mask_len = torch.maximum(
                one,
                torch.minimum(torch.sum(mask, dim=-1, keepdim=True) - 1, mask_len),
            )

            # get masking for next iteration and locations to be predicted in
            # this iteration
            mask_next = mask_by_order(mask_len[0], orders, bsz, seq_len).to(
                self.device
            )
            if use_cfg:
                mask_next[bsz // 2 :] = mask_next[: bsz // 2]

            if step >= num_iter - 1:
                # Last step: predict everything that is still masked.
                mask_to_pred = mask[:bsz].bool()
            else:
                mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool())
            mask = mask_next

            # sample token latents for this step
            pred_index = mask_to_pred.nonzero(as_tuple=True)
            z = z[pred_index]

            # cfg schedule follow Muse
            if cfg_schedule == "linear":
                cfg_iter = 1 + (cfg - 1) * (seq_len - mask_len[0]) / seq_len
            else:
                cfg_iter = cfg

            sampled_token_latent = self.mar.diffloss.sample(
                z, temperature, cfg_iter
            ).to(self.dtype)

            cur_tokens = tokens.clone()
            cur_tokens[pred_index] = sampled_token_latent
            if use_cfg:
                cur_tokens[bsz // 2 :] = cur_tokens[: bsz // 2]
            tokens = cur_tokens

        pred = self.decode(tokens.view(bsz, m, n, -1))
        if use_cfg:
            pred = pred[: bsz // 2]
        return pred
|