import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import torchvision.transforms as transforms
import einops
import numpy as np
import open_clip
import timm
from functools import partial
from PIL import Image
from timm.models.layers import trunc_normal_, Mlp

from .sigmoid.module import LayerNorm, RMSNorm, AdaRMSNorm, TDRMSNorm, QKNorm, TimeDependentParameter
from .common_layers import Linear, EvenDownInterpolate, ChannelFirst, ChannelLast, Embedding
from .axial_rope import AxialRoPE, make_axial_pos
from .trans_autoencoder import TransEncoder, Adaptor


def check_zip(*args):
    """zip() that first asserts all iterables have the same length."""
    args = [list(arg) for arg in args]
    length = len(args[0])
    for arg in args:
        assert len(arg) == length
    return zip(*args)


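# Usage sketch (illustrative only): check_zip pairs equal-length sequences and raises an
# AssertionError on a length mismatch, unlike plain zip(), which silently truncates.
#
#   list(check_zip([1, 2], ["a", "b"]))   # -> [(1, "a"), (2, "b")]
#   list(check_zip([1, 2], ["a"]))        # -> AssertionError
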
class PixelShuffleUpsample(nn.Module):
    """Channel-last pixel shuffle: project to dim_out * ratio**2 channels, then fold them into space."""

    def __init__(self, dim_in, dim_out, ratio = 2):
        super().__init__()
        self.ratio = ratio
        self.kernel = Linear(dim_in, dim_out * self.ratio * self.ratio)

    def forward(self, x):
        x = self.kernel(x)
        B, H, W, C = x.shape
        # (B, H, W, r * r * d) -> (B, H, r, W, r, d) -> (B, H * r, W * r, d)
        x = x.reshape(B, H, W, self.ratio, self.ratio, C // self.ratio // self.ratio)
        x = x.transpose(2, 3)
        x = x.reshape(B, H * self.ratio, W * self.ratio, C // self.ratio // self.ratio)
        return x


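# Shape sketch (illustrative only, assumes the default ratio of 2):
#
#   up = PixelShuffleUpsample(dim_in=64, dim_out=32)
#   y = up(torch.randn(2, 8, 8, 64))   # y.shape == (2, 16, 16, 32)
#
# The Linear first maps 64 -> 32 * 2 * 2 = 128 channels; the reshape/transpose then
# redistributes those 128 channels into a 2x2 spatial neighbourhood of 32 channels each.
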
class PositionEmbeddings(nn.Module):
    """Learned 2D absolute position embeddings, indexed row-major on a max_height x max_width grid."""

    def __init__(self, max_height, max_width, dim):
        super().__init__()
        self.max_height = max_height
        self.max_width = max_width
        self.position_embeddings = Embedding(self.max_height * self.max_width, dim)

    def forward(self, x):
        B, H, W, C = x.shape
        height_idxes = torch.arange(H, device = x.device)[:, None].repeat(1, W)
        width_idxes = torch.arange(W, device = x.device)[None, :].repeat(H, 1)
        idxes = height_idxes * self.max_width + width_idxes
        x = x + self.position_embeddings(idxes[None])
        return x


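# Index sketch (illustrative only): with max_height=4, max_width=4 and a 2x3 input,
# the flattened grid indices are
#
#   idxes = [[0, 1, 2],
#            [4, 5, 6]]
#
# i.e. row-major positions in the full 4x4 grid, so embeddings stay aligned when the
# module is evaluated on feature maps smaller than (max_height, max_width).
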
class TextPositionEmbeddings(nn.Module):
    """Learned 1D position embeddings for token sequences."""

    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.embedding = Embedding(num_embeddings, embedding_dim)

    def forward(self, x):
        batch_size, num_embeddings, embedding_dim = x.shape
        positions = torch.arange(num_embeddings, device=x.device).unsqueeze(0).expand(batch_size, num_embeddings)
        x = x + self.embedding(positions)
        return x


class MLPBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.norm_type == 'LN':
            self.norm_type = 'LN'
            self.norm = LayerNorm(config.dim)
        elif config.norm_type == 'RMSN':
            self.norm_type = 'RMSN'
            self.norm = RMSNorm(config.dim)
        elif config.norm_type == 'TDRMSN':
            self.norm_type = 'TDRMSN'
            self.norm = TDRMSNorm(config.dim)
        elif config.norm_type == 'ADARMSN':
            self.norm_type = 'ADARMSN'
            self.norm = AdaRMSNorm(config.dim, config.dim)
        else:
            raise ValueError(f"unknown norm_type: {config.norm_type}")
        self.act = nn.GELU()
        # GEGLU-style gated MLP: act(w0(x)) * w1(x), projected back down by w2.
        self.w0 = Linear(config.dim, config.hidden_dim)
        self.w1 = Linear(config.dim, config.hidden_dim)
        self.w2 = Linear(config.hidden_dim, config.dim)

    def forward(self, x):
        if self.norm_type in ('LN', 'RMSN', 'TDRMSN'):
            x = self.norm(x)
        elif self.norm_type == 'ADARMSN':
            # AdaRMSNorm takes the first token as its conditioning vector.
            condition = x[:, 0]
            x = self.norm(x, condition)
        x = self.act(self.w0(x)) * self.w1(x)
        x = self.w2(x)
        return x


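# Note (informal): the w0/w1/w2 triplet above implements a GEGLU-style feed-forward,
# roughly
#
#   FFN(x) = w2( GELU(w0(x)) * w1(x) )
#
# so hidden_dim is the width of the gated intermediate, not of a plain two-layer MLP.
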
class SelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.dim % config.num_attention_heads == 0

        self.num_heads = config.num_attention_heads
        self.head_dim = config.dim // config.num_attention_heads

        if hasattr(config, "self_att_prompt") and config.self_att_prompt:
            # Projects prompt embeddings into extra key/value pairs that image tokens can attend to.
            self.condition_key_value = Linear(config.clip_dim, 2 * config.dim, bias = False)

        if config.norm_type == 'LN':
            self.norm_type = 'LN'
            self.norm = LayerNorm(config.dim)
        elif config.norm_type == 'RMSN':
            self.norm_type = 'RMSN'
            self.norm = RMSNorm(config.dim)
        elif config.norm_type == 'TDRMSN':
            self.norm_type = 'TDRMSN'
            self.norm = TDRMSNorm(config.dim)
        elif config.norm_type == 'ADARMSN':
            self.norm_type = 'ADARMSN'
            self.norm = AdaRMSNorm(config.dim, config.dim)
        else:
            raise ValueError(f"unknown norm_type: {config.norm_type}")

        self.pe_type = config.pe_type
        if config.pe_type == 'Axial_RoPE':
            self.pos_emb = AxialRoPE(self.head_dim, self.num_heads)
            self.qk_norm = QKNorm(self.num_heads)

        self.query_key_value = Linear(config.dim, 3 * config.dim, bias = False)
        self.dense = Linear(config.dim, config.dim)

    def forward(self, x, condition_embeds, condition_masks, pos=None):
        B, N, C = x.shape

        if self.norm_type in ('LN', 'RMSN', 'TDRMSN'):
            qkv = self.query_key_value(self.norm(x))
        elif self.norm_type == 'ADARMSN':
            condition = x[:, 0]
            qkv = self.query_key_value(self.norm(x, condition))
        q, k, v = qkv.reshape(B, N, 3 * self.num_heads, self.head_dim).permute(0, 2, 1, 3).float().chunk(3, dim = 1)

        if self.pe_type == 'Axial_RoPE':
            q = self.pos_emb(self.qk_norm(q), pos)
            k = self.pos_emb(self.qk_norm(k), pos)

        if condition_embeds is not None:
            # Append prompt-derived keys/values and extend the attention mask accordingly.
            _, L, D = condition_embeds.shape
            kcvc = self.condition_key_value(condition_embeds)
            kc, vc = kcvc.reshape(B, L, 2 * self.num_heads, self.head_dim).permute(0, 2, 1, 3).float().chunk(2, dim = 1)
            k = torch.cat([k, kc], dim = 2)
            v = torch.cat([v, vc], dim = 2)
            mask = torch.cat([torch.ones(B, N, dtype = torch.bool, device = condition_masks.device), condition_masks], dim = -1)
            mask = mask[:, None, None, :]
        else:
            mask = None

        x = F.scaled_dot_product_attention(q, k, v, attn_mask = mask)
        x = self.dense(x.permute(0, 2, 1, 3).reshape(B, N, C))

        return x


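# Shape sketch (illustrative only): for B images with N tokens of width C = num_heads * head_dim,
# q, k and v are each (B, num_heads, N, head_dim). With L prompt tokens, k and v grow to
# (B, num_heads, N + L, head_dim) while q stays (B, num_heads, N, head_dim), so image tokens
# attend over both image and prompt positions; condition_masks hides padded prompt tokens.
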
class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.block1 = SelfAttention(config)
        self.block2 = MLPBlock(config)
        self.dropout = nn.Dropout(config.dropout_prob)
        self.gradient_checking = config.gradient_checking

    def forward(self, x, condition_embeds, condition_masks, pos):
        if self.gradient_checking:
            return torch.utils.checkpoint.checkpoint(self._forward, x, condition_embeds, condition_masks, pos)
        else:
            return self._forward(x, condition_embeds, condition_masks, pos)

    def _forward(self, x, condition_embeds, condition_masks, pos):
        # Pre-norm residual layout: attention, then the gated MLP, each with dropout.
        x = x + self.dropout(self.block1(x, condition_embeds, condition_masks, pos))
        x = x + self.dropout(self.block2(x))
        return x


class ConvNeXtBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.block1 = nn.Sequential(
            ChannelFirst(),
            nn.Conv2d(config.dim, config.dim, kernel_size = config.kernel_size, padding = config.kernel_size // 2, stride = 1, groups = config.dim),
            ChannelLast()
        )
        self.block2 = MLPBlock(config)
        self.dropout = nn.Dropout(config.dropout_prob)
        self.gradient_checking = config.gradient_checking

    def forward(self, x, condition_embeds, condition_masks, pos):
        # condition_embeds, condition_masks and pos are accepted only for interface
        # parity with TransformerBlock; the convolutional block ignores them.
        if self.gradient_checking:
            return torch.utils.checkpoint.checkpoint(self._forward, x)
        else:
            return self._forward(x)

    def _forward(self, x):
        x = x + self.dropout(self.block1(x))
        x = x + self.dropout(self.block2(x))
        return x


class Stage(nn.Module):
    def __init__(self, channels, config, lowres_dim = None, lowres_height = None):
        super().__init__()
        if config.block_type == "TransformerBlock":
            self.encoder_cls = TransformerBlock
        elif config.block_type == "ConvNeXtBlock":
            self.encoder_cls = ConvNeXtBlock
        else:
            raise ValueError(f"unknown block_type: {config.block_type}")

        self.pe_type = config.pe_type

        self.input_layer = nn.Sequential(
            EvenDownInterpolate(config.image_input_ratio),
            nn.Conv2d(channels, config.dim, kernel_size = config.input_feature_ratio, stride = config.input_feature_ratio),
            ChannelLast(),
            PositionEmbeddings(config.max_height, config.max_width, config.dim)
        )

        if lowres_dim is not None:
            # Fuse the concatenated skips of the previous (lower-resolution) stage.
            ratio = config.max_height // lowres_height
            self.upsample = nn.Sequential(
                LayerNorm(lowres_dim),
                PixelShuffleUpsample(lowres_dim, config.dim, ratio = ratio),
                LayerNorm(config.dim),
            )

        # num_blocks // 2 "in" blocks, one middle block, num_blocks // 2 "out" blocks.
        self.blocks = nn.ModuleList([self.encoder_cls(config) for _ in range(config.num_blocks // 2 * 2 + 1)])
        self.skip_denses = nn.ModuleList([Linear(config.dim * 2, config.dim) for _ in range(config.num_blocks // 2)])

        self.output_layer = nn.Sequential(
            LayerNorm(config.dim),
            ChannelFirst(),
            nn.Conv2d(config.dim, channels, kernel_size = config.final_kernel_size, padding = config.final_kernel_size // 2),
        )

        # Learned scalar values for the prepended indicator token that marks whether the conditioning was nulled out.
        self.tensor_true = torch.nn.Parameter(torch.tensor([-1.0])) if self.encoder_cls is TransformerBlock else None
        self.tensor_false = torch.nn.Parameter(torch.tensor([1.0])) if self.encoder_cls is TransformerBlock else None

    def forward(self, images, lowres_skips = None, condition_context = None, condition_embeds = None, condition_masks = None, null_indicator=None):
        x = self.input_layer(images)
        if self.pe_type == 'Axial_RoPE' and self.encoder_cls is TransformerBlock:
            _, H, W, _ = x.shape
            pos = make_axial_pos(H, W)
        else:
            pos = None

        if lowres_skips is not None:
            x = x + self.upsample(lowres_skips)

        if self.encoder_cls is TransformerBlock:
            # Flatten to a token sequence; optionally prepend the null-conditioning indicator token.
            B, H, W, C = x.shape
            x = x.reshape(B, H * W, C)

            if null_indicator is not None:
                indicator_tensor = torch.where(null_indicator, self.tensor_true, self.tensor_false)
                indicator_tensor = indicator_tensor.view(B, 1, 1).expand(-1, -1, C)
                x = torch.cat([indicator_tensor, x], dim = 1)

        external_skips = [x]

        num_blocks = len(self.blocks)
        in_blocks = self.blocks[:(num_blocks // 2)]
        mid_block = self.blocks[(num_blocks // 2)]
        out_blocks = self.blocks[(num_blocks // 2 + 1):]

        # U-Net-style wiring within the stage: the "in" blocks stash skips that the
        # "out" blocks consume through the skip_denses projections.
        skips = []
        for block in in_blocks:
            x = block(x, condition_embeds, condition_masks, pos=pos)
            external_skips.append(x)
            skips.append(x)

        x = mid_block(x, condition_embeds, condition_masks, pos=pos)
        external_skips.append(x)

        for dense, block in check_zip(self.skip_denses, out_blocks):
            x = dense(torch.cat([x, skips.pop()], dim = -1))
            x = block(x, condition_embeds, condition_masks, pos=pos)
            external_skips.append(x)

        if self.encoder_cls is TransformerBlock:
            if null_indicator is not None:
                # Drop the indicator token before reshaping back to a feature map.
                x = x[:, 1:, :]
                external_skips = [skip[:, 1:, :] for skip in external_skips]

            x = x.reshape(B, H, W, C)
            external_skips = [skip.reshape(B, H, W, C) for skip in external_skips]

        output = self.output_layer(x)

        return output, external_skips


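# Skip-count sketch (illustrative only): with num_blocks = 4 a Stage builds
# 2 "in" blocks + 1 middle block + 2 "out" blocks = 5 blocks and 2 skip_denses.
# external_skips then holds 6 tensors (the input plus one per block), which matches the
# lowres_dims multiplier `dim * (num_blocks // 2 * 2 + 2)` used in MRModel below,
# since the next stage concatenates all of them along the channel dimension.
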
class MRModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.channels = config.channels
        self.block_grad_to_lowres = config.block_grad_to_lowres

        # Propagate shared, optional settings from the top-level config into each stage config.
        for stage_config in config.stage_configs:
            if hasattr(config, "use_t2i"):
                stage_config.use_t2i = config.use_t2i
            if hasattr(config, "clip_dim"):
                stage_config.clip_dim = config.clip_dim
            if hasattr(config, "num_clip_token"):
                stage_config.num_clip_token = config.num_clip_token
            if hasattr(config, "gradient_checking"):
                stage_config.gradient_checking = config.gradient_checking
            if hasattr(config, "pe_type"):
                stage_config.pe_type = config.pe_type
            else:
                stage_config.pe_type = 'APE'
            if hasattr(config, "norm_type"):
                stage_config.norm_type = config.norm_type
            else:
                stage_config.norm_type = 'LN'

        if hasattr(config, "not_training_diff") and config.not_training_diff:
            self.has_diff = False
        else:
            self.has_diff = True

        # Each stage consumes the channel-concatenated skips of the previous, lower-resolution stage.
        lowres_dims = [None] + [stage_config.dim * (stage_config.num_blocks // 2 * 2 + 2) for stage_config in config.stage_configs[:-1]]
        lowres_heights = [None] + [stage_config.max_height for stage_config in config.stage_configs[:-1]]
        self.stages = nn.ModuleList([
            Stage(self.channels, stage_config, lowres_dim = lowres_dim, lowres_height = lowres_height)
            for stage_config, lowres_dim, lowres_height in check_zip(config.stage_configs, lowres_dims, lowres_heights)]
        )

        if hasattr(config.textVAE, "num_down_sample_block"):
            down_sample_block = config.textVAE.num_down_sample_block
        else:
            down_sample_block = 3

        # Text VAE encoder: outputs mean and log-variance (hence the factor of 2) of a latent
        # sized like the last stage's feature map.
        self.context_encoder = TransEncoder(d_model=config.clip_dim, N=config.textVAE.num_blocks, num_token=config.num_clip_token,
                                            head_num=config.textVAE.num_attention_heads, d_ff=config.textVAE.hidden_dim,
                                            latten_size=config.channels*config.stage_configs[-1].max_height*config.stage_configs[-1].max_width * 2,
                                            down_sample_block=down_sample_block, dropout=config.textVAE.dropout_prob, last_norm=False)

        # SigLIP image tower; the text tower and logit bias are removed since only image embeddings are used here.
        self.open_clip, _, self.open_clip_preprocess = open_clip.create_model_and_transforms('ViT-L-16-SigLIP-256', pretrained=None)
        if config.stage_configs[-1].max_width == 32:
            self.open_clip_output = Mlp(in_features=1024,
                                        hidden_features=config.channels*config.stage_configs[-1].max_height*config.stage_configs[-1].max_width,
                                        out_features=config.channels*config.stage_configs[-1].max_height*config.stage_configs[-1].max_width,
                                        norm_layer=nn.LayerNorm,
                                        )
        else:
            self.open_clip_output = Adaptor(input_dim=1024,
                                            tar_dim=config.channels*config.stage_configs[-1].max_height*config.stage_configs[-1].max_width
                                            )
        del self.open_clip.text
        del self.open_clip.logit_bias

    def _forward(self, images, log_snr, condition_context = None, condition_text_embeds = None, condition_text_masks = None, condition_drop_prob = None, null_indicator=None):
        if self.has_diff:
            TimeDependentParameter.seed_time(self, log_snr)

            # Text conditioning inputs must be None on this path; the branch below is kept
            # for interface compatibility but is unreachable given these asserts.
            assert condition_context is None
            assert condition_text_embeds is None

            if condition_text_embeds is not None:
                condition_embeds = self.text_conditioning(condition_text_embeds)
                condition_masks = condition_text_masks
            else:
                condition_embeds = None
                condition_masks = None

            outputs = []
            lowres_skips = None
            for stage in self.stages:
                output, lowres_skips = stage(images, lowres_skips = lowres_skips, condition_context = condition_context, condition_embeds = condition_embeds, condition_masks = condition_masks, null_indicator=null_indicator)
                outputs.append(output)
                # Hand all of this stage's skips to the next (higher-resolution) stage.
                lowres_skips = torch.cat(lowres_skips, dim = -1)
                if self.block_grad_to_lowres:
                    lowres_skips = lowres_skips.detach()

            return outputs

        else:
            return [images]

    def _reparameterize(self, mu, logvar):
        # VAE reparameterization trick: z = mu + sigma * eps with sigma = exp(0.5 * logvar).
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def _text_encoder(self, condition_context, tar_shape, mask):
        output = self.context_encoder(condition_context, mask)
        mu, log_var = torch.chunk(output, 2, dim=-1)
        z = self._reparameterize(mu, log_var)
        return [z, mu, log_var]

    def _text_decoder(self, condition_embedding, tar_shape):
        # Note: relies on self.context_decoder, which is not constructed in this module.
        context_token = self.context_decoder(condition_embedding)
        return context_token

    def _img_clip(self, image_input):
        image_latent = self.open_clip.encode_image(image_input)
        image_latent = self.open_clip_output(image_latent)
        return image_latent, self.open_clip.logit_scale

    def forward(self, x, t = None, log_snr = None, text_encoder=False, text_decoder=False, image_clip=False, shape=None, mask=None, null_indicator=None):
        # Dispatch between the auxiliary encoders/decoders and the main diffusion backbone.
        if text_encoder:
            return self._text_encoder(condition_context = x, tar_shape=shape, mask=mask)
        elif text_decoder:
            return self._text_decoder(condition_embedding = x, tar_shape=shape)
        elif image_clip:
            return self._img_clip(image_input = x)
        else:
            assert log_snr.dtype == torch.float32
            return self._forward(images = x, log_snr = log_snr, null_indicator=null_indicator)
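# Usage sketch (illustrative only; exact shapes depend on the config):
#
#   model = MRModel(config)
#   outs = model(images, log_snr=log_snr)                                    # one denoiser output per stage
#   z, mu, log_var = model(text_tokens, text_encoder=True, mask=text_mask)   # text VAE latent
#   img_lat, logit_scale = model(clip_images, image_clip=True)               # SigLIP image embedding head
#
# The text_decoder path additionally requires a context_decoder module that is not
# defined in this file.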