# CrossFlow: libs/model/dimr_t2i.py
# Last commit by QHL067: f9567e5 ("working")
from re import A
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import math
import einops
import torch.utils.checkpoint
from functools import partial
import open_clip
import numpy as np
from PIL import Image
import torch.nn.functional as F
import timm
from timm.models.layers import trunc_normal_, Mlp
from .sigmoid.module import LayerNorm, RMSNorm, AdaRMSNorm, TDRMSNorm, QKNorm, TimeDependentParameter
from .common_layers import Linear, EvenDownInterpolate, ChannelFirst, ChannelLast, Embedding
from .axial_rope import AxialRoPE, make_axial_pos
from .trans_autoencoder import TransEncoder, Adaptor
def check_zip(*args):
    """Zip iterables together, requiring that they all have the same length.

    Every argument is materialized into a list first. One-shot iterators such
    as generators are therefore accepted, and the length check happens
    eagerly, when ``check_zip`` is called rather than during iteration.

    Args:
        *args: Finite iterables to zip.

    Returns:
        An iterator of tuples, exactly like ``zip(*args)``. With no arguments
        the iterator is empty.

    Raises:
        ValueError: If the iterables do not all have the same length. An
            explicit exception replaces the former ``assert``: asserts are
            removed under ``python -O``, and plain ``zip`` would then silently
            truncate to the shortest input.
    """
    materialized = [list(arg) for arg in args]
    lengths = [len(items) for items in materialized]
    # The generator is never evaluated when there are no arguments, so lengths[0] is safe.
    if any(length != lengths[0] for length in lengths):
        raise ValueError(f"check_zip: mismatched argument lengths {lengths}")
    return zip(*materialized)
class PixelShuffleUpsample(nn.Module):
    """Upsample a channels-last feature map by an integer factor.

    A linear projection first widens every pixel to ``dim_out * ratio**2``
    channels. Those channels are then rearranged into a ``ratio x ratio``
    patch of ``dim_out``-channel pixels (a pixel shuffle).

    Args:
        dim_in: Channel count of the input map (last axis).
        dim_out: Channel count of the upsampled map.
        ratio: Spatial upsampling factor, applied to both height and width.
    """

    def __init__(self, dim_in, dim_out, ratio=2):
        super().__init__()
        self.ratio = ratio
        self.kernel = Linear(dim_in, dim_out * self.ratio * self.ratio)

    def forward(self, x):
        """Map ``(B, H, W, dim_in)`` to ``(B, H * ratio, W * ratio, dim_out)``.

        Assumes the project ``Linear`` acts on the last axis (TODO confirm).
        """
        r = self.ratio
        widened = self.kernel(x)
        batch, height, width, channels = widened.shape
        out_channels = channels // r // r
        # Split the widened channel axis into (row offset, column offset, channel).
        patches = widened.reshape(batch, height, width, r, r, out_channels)
        # Move the row offset next to the row axis: (B, H, r, W, r, C').
        patches = patches.permute(0, 1, 3, 2, 4, 5)
        return patches.reshape(batch, height * r, width * r, out_channels)
class PositionEmbeddings(nn.Module):
    """Add a learned absolute 2-D position embedding to a channels-last map.

    The table holds ``max_height * max_width`` vectors laid out row-major,
    with row stride ``max_width``. An ``H x W`` input (``H <= max_height``,
    ``W <= max_width``) uses the entries of the table's top-left ``H x W``
    window.

    Args:
        max_height: Largest supported height (rows) of the input map.
        max_width: Largest supported width (columns) of the input map.
        dim: Embedding size. It should match the input's last axis for the
            broadcast addition, assuming the project ``Embedding`` behaves like
            ``nn.Embedding``.
    """

    def __init__(self, max_height, max_width, dim):
        super().__init__()
        self.max_height = max_height
        self.max_width = max_width
        self.position_embeddings = Embedding(self.max_height * self.max_width, dim)

    def forward(self, x):
        """Return ``x`` plus the position embeddings, for ``x`` unpacked as ``(B, H, W, C)``.

        Raises:
            ValueError: If ``H > max_height`` or ``W > max_width``. Without
                this check, a too-wide map silently reused indices from the
                next row of the table, and a too-tall map indexed past its end.
        """
        _, H, W, _ = x.shape
        if H > self.max_height or W > self.max_width:
            raise ValueError(
                f"PositionEmbeddings: input map {H}x{W} exceeds the supported "
                f"{self.max_height}x{self.max_width}"
            )
        rows = torch.arange(H, device=x.device)[:, None]
        cols = torch.arange(W, device=x.device)[None, :]
        # Broadcasting (H, 1) against (1, W) yields the (H, W) row-major index grid.
        idxes = rows * self.max_width + cols
        return x + self.position_embeddings(idxes[None])
class TextPositionEmbeddings(nn.Module):
    """Add a learned absolute position embedding to each token of a sequence.

    Args:
        num_embeddings: Size of the position table. Positions
            ``0 .. seq_len - 1`` are looked up in it, so ``seq_len`` may not
            exceed this value.
        embedding_dim: Width of each position vector.
    """

    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.embedding = Embedding(num_embeddings, embedding_dim)

    def forward(self, x):
        """Return ``x`` plus the embeddings of positions ``0 .. seq_len - 1``.

        ``x`` is unpacked as ``(batch, seq_len, width)``. The same position ids
        are used for every batch element.
        """
        batch, seq_len, _ = x.shape
        position_ids = torch.arange(seq_len, device=x.device)
        position_ids = position_ids[None, :].expand(batch, -1)
        return x + self.embedding(position_ids)
class MLPBlock(nn.Module):
    """Pre-normalized gated-GELU feed-forward sub-block.

    Computes ``w2(gelu(w0(n)) * w1(n))``, where ``n = norm(x)``. No residual
    connection is applied here; the enclosing block adds it.

    Config fields read: ``norm_type`` (one of ``'LN'``, ``'RMSN'``,
    ``'TDRMSN'``, ``'ADARMSN'``), ``dim`` and ``hidden_dim``.

    Raises:
        ValueError: If ``config.norm_type`` is not a supported value.
            Previously an unknown value left ``norm_type``/``norm`` unset, and
            the error only surfaced later, inside ``forward``.
    """

    def __init__(self, config):
        super().__init__()
        self.norm_type = config.norm_type
        if config.norm_type == 'LN':
            self.norm = LayerNorm(config.dim)
        elif config.norm_type == 'RMSN':
            self.norm = RMSNorm(config.dim)
        elif config.norm_type == 'TDRMSN':
            self.norm = TDRMSNorm(config.dim)
        elif config.norm_type == 'ADARMSN':
            self.norm = AdaRMSNorm(config.dim, config.dim)
        else:
            raise ValueError(
                f"Unsupported norm_type {config.norm_type!r}; "
                "expected one of 'LN', 'RMSN', 'TDRMSN', 'ADARMSN'"
            )
        self.act = nn.GELU()
        self.w0 = Linear(config.dim, config.hidden_dim)  # activated branch
        self.w1 = Linear(config.dim, config.hidden_dim)  # multiplicative (gate) branch
        self.w2 = Linear(config.hidden_dim, config.dim)  # projection back to dim

    def forward(self, x):
        """Apply norm, then the gated GELU MLP. The residual is not added."""
        if self.norm_type == 'ADARMSN':
            # The adaptive norm is conditioned on x[:, 0], taken before
            # normalization. For (B, N, C) input this is the first token.
            x = self.norm(x, x[:, 0])
        else:
            x = self.norm(x)
        x = self.act(self.w0(x)) * self.w1(x)
        return self.w2(x)
class SelfAttention(nn.Module):
    """Pre-normalized multi-head self-attention, optionally joined by condition tokens.

    Queries come only from ``x``. When ``condition_embeds`` is given, keys and
    values projected from it are appended after the keys/values of ``x``. A
    boolean attention mask is then passed to ``scaled_dot_product_attention``:
    every ``x`` position is attendable, and condition positions follow
    ``condition_masks``.

    Config fields read: ``dim``, ``num_attention_heads``, ``norm_type``,
    ``pe_type``, and ``clip_dim`` when ``self_att_prompt`` is truthy.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_attention_heads``, or
            if ``norm_type`` is not one of ``'LN'``, ``'RMSN'``, ``'TDRMSN'``,
            ``'ADARMSN'``. An unknown norm type previously surfaced only as an
            ``AttributeError`` inside ``forward``.
    """

    def __init__(self, config):
        super().__init__()
        if config.dim % config.num_attention_heads != 0:
            raise ValueError(
                f"dim ({config.dim}) must be divisible by "
                f"num_attention_heads ({config.num_attention_heads})"
            )
        self.num_heads = config.num_attention_heads
        self.head_dim = config.dim // config.num_attention_heads
        if hasattr(config, "self_att_prompt") and config.self_att_prompt:
            # Projects condition embeddings (width clip_dim) to extra keys and values.
            self.condition_key_value = Linear(config.clip_dim, 2 * config.dim, bias = False)
        self.norm_type = config.norm_type
        if config.norm_type == 'LN':
            self.norm = LayerNorm(config.dim)
        elif config.norm_type == 'RMSN':
            self.norm = RMSNorm(config.dim)
        elif config.norm_type == 'TDRMSN':
            self.norm = TDRMSNorm(config.dim)
        elif config.norm_type == 'ADARMSN':
            self.norm = AdaRMSNorm(config.dim, config.dim)
        else:
            raise ValueError(
                f"Unsupported norm_type {config.norm_type!r}; "
                "expected one of 'LN', 'RMSN', 'TDRMSN', 'ADARMSN'"
            )
        self.pe_type = config.pe_type
        if config.pe_type == 'Axial_RoPE':
            self.pos_emb = AxialRoPE(self.head_dim, self.num_heads)
            self.qk_norm = QKNorm(self.num_heads)
        self.query_key_value = Linear(config.dim, 3 * config.dim, bias = False)
        self.dense = Linear(config.dim, config.dim)

    def forward(self, x, condition_embeds, condition_masks, pos=None):
        """Attend over ``x`` and, optionally, extra condition tokens.

        Args:
            x: Token tensor unpacked as ``(B, N, C)``, with ``C == dim``.
            condition_embeds: ``None``, or a ``(B, L, clip_dim)`` tensor. A
                tensor requires the layer to have been built with
                ``self_att_prompt``.
            condition_masks: ``(B, L)`` mask for the condition tokens. It is
                concatenated after an all-``True`` bool mask, so it is
                presumably bool. For a bool ``attn_mask``, SDPA treats
                ``True`` as "takes part in attention".
            pos: Positions forwarded to ``AxialRoPE`` when
                ``pe_type == 'Axial_RoPE'``.

        Returns:
            A ``(B, N, C)`` tensor after the output projection ``dense``.
        """
        B, N, C = x.shape
        if self.norm_type == 'ADARMSN':
            # The adaptive norm is conditioned on the first token, taken before normalization.
            normed = self.norm(x, x[:, 0])
        else:
            normed = self.norm(x)
        qkv = self.query_key_value(normed)
        # (B, N, 3*heads, head_dim) -> (B, 3*heads, N, head_dim) -> q, k, v of
        # shape (B, heads, N, head_dim). Attention is computed in float32.
        q, k, v = qkv.reshape(B, N, 3 * self.num_heads, self.head_dim).permute(0, 2, 1, 3).float().chunk(3, dim = 1)
        if self.pe_type == 'Axial_RoPE':
            q = self.pos_emb(self.qk_norm(q), pos)
            k = self.pos_emb(self.qk_norm(k), pos)
        if condition_embeds is not None:
            _, L, _ = condition_embeds.shape
            kcvc = self.condition_key_value(condition_embeds)
            kc, vc = kcvc.reshape(B, L, 2 * self.num_heads, self.head_dim).permute(0, 2, 1, 3).float().chunk(2, dim = 1)
            # Condition keys/values go after the N image-token keys/values.
            k = torch.cat([k, kc], dim = 2)
            v = torch.cat([v, vc], dim = 2)
            mask = torch.cat([torch.ones(B, N, dtype = torch.bool, device = condition_masks.device), condition_masks], dim = -1)
            # Broadcast over heads and query positions: (B, 1, 1, N + L).
            mask = mask[:, None, None, :]
        else:
            mask = None
        x = F.scaled_dot_product_attention(q, k, v, attn_mask = mask)
        x = self.dense(x.permute(0, 2, 1, 3).reshape(B, N, C))
        return x
class TransformerBlock(nn.Module):
    """Transformer layer: residual self-attention followed by a residual MLP.

    Computes ``x = x + dropout(attn(x, ...))`` and then
    ``x = x + dropout(mlp(x))``. Normalization lives inside the two
    sub-blocks. When ``config.gradient_checking`` is truthy, the layer runs
    under activation checkpointing: activations are recomputed during the
    backward pass instead of being stored.
    """

    def __init__(self, config):
        super().__init__()
        self.block1 = SelfAttention(config)
        self.block2 = MLPBlock(config)
        self.dropout = nn.Dropout(config.dropout_prob)
        # The name ("checking", not "checkpointing") matches the config field,
        # which MRModel copies into each stage config under this name.
        self.gradient_checking = config.gradient_checking

    def forward(self, x, condition_embeds, condition_masks, pos):
        """Run the layer, optionally under activation checkpointing."""
        if self.gradient_checking:
            # use_reentrant is passed explicitly. Recent PyTorch releases warn
            # when it is omitted, and True selects the same (reentrant)
            # implementation that the implicit default used.
            return torch.utils.checkpoint.checkpoint(
                self._forward, x, condition_embeds, condition_masks, pos, use_reentrant=True
            )
        return self._forward(x, condition_embeds, condition_masks, pos)

    def _forward(self, x, condition_embeds, condition_masks, pos):
        x = x + self.dropout(self.block1(x, condition_embeds, condition_masks, pos))
        x = x + self.dropout(self.block2(x))
        return x
class ConvNeXtBlock(nn.Module):
    """ConvNeXt-style layer on a channels-last map: residual depthwise conv, then residual MLP.

    ``block1`` switches to channels-first, applies a depthwise ``nn.Conv2d``
    and switches back. The conv uses ``groups = dim``,
    ``kernel_size = config.kernel_size`` and stride 1, with ``kernel_size // 2``
    padding so odd kernels preserve the spatial size. ``block2`` is an
    ``MLPBlock``. The conditioning arguments of ``forward`` are accepted for
    interface parity with ``TransformerBlock`` but are ignored.
    """

    def __init__(self, config):
        super().__init__()
        self.block1 = nn.Sequential(
            ChannelFirst(),
            nn.Conv2d(config.dim, config.dim, kernel_size = config.kernel_size, padding = config.kernel_size // 2, stride = 1, groups = config.dim),
            ChannelLast()
        )
        self.block2 = MLPBlock(config)
        self.dropout = nn.Dropout(config.dropout_prob)
        # Field name kept as-is; MRModel copies it into each stage config under this name.
        self.gradient_checking = config.gradient_checking

    def forward(self, x, condition_embeds, condition_masks, pos):
        """Run the layer on ``x``, optionally under activation checkpointing."""
        if self.gradient_checking:
            # use_reentrant is passed explicitly. Recent PyTorch releases warn
            # when it is omitted, and True selects the same (reentrant)
            # implementation that the implicit default used.
            return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=True)
        return self._forward(x)

    def _forward(self, x):
        x = x + self.dropout(self.block1(x))
        x = x + self.dropout(self.block2(x))
        return x
class Stage(nn.Module):
    """One resolution level: embed the input, run a U-shaped block stack with long skips, project out.

    ``blocks`` has length ``2 * (num_blocks // 2) + 1``. The first half are
    "in" blocks whose outputs are saved as skips, then comes one middle block,
    then the "out" blocks. Before each out block, the running activation is
    concatenated with the most recently saved unused skip and fused back to
    ``dim`` by the matching ``skip_denses`` projection (``2 * dim -> dim``).

    Args:
        channels: Channel count of the input to ``input_layer`` and of the
            output of ``output_layer``.
        config: Stage config. Reads ``block_type`` (``"TransformerBlock"`` or
            ``"ConvNeXtBlock"``), ``pe_type``, ``dim``, ``num_blocks``,
            ``image_input_ratio``, ``input_feature_ratio``, ``max_height``,
            ``max_width`` and ``final_kernel_size``, plus whatever the blocks
            themselves read.
        lowres_dim: Channel width of the previous (lower-resolution) stage's
            concatenated skips, or ``None`` when there is no previous stage.
        lowres_height: ``max_height`` of the previous stage. The skip
            upsampling ratio is ``config.max_height // lowres_height``.

    Raises:
        ValueError: If ``config.block_type`` is not supported. This was
            previously a bare ``Exception()`` with no message.
    """

    def __init__(self, channels, config, lowres_dim = None, lowres_height = None):
        super().__init__()
        if config.block_type == "TransformerBlock":
            self.encoder_cls = TransformerBlock
        elif config.block_type == "ConvNeXtBlock":
            self.encoder_cls = ConvNeXtBlock
        else:
            raise ValueError(
                f"Unsupported block_type {config.block_type!r}; "
                "expected 'TransformerBlock' or 'ConvNeXtBlock'"
            )
        self.pe_type = config.pe_type
        self.input_layer = nn.Sequential(
            EvenDownInterpolate(config.image_input_ratio),
            nn.Conv2d(channels, config.dim, kernel_size = config.input_feature_ratio, stride = config.input_feature_ratio),
            ChannelLast(),
            PositionEmbeddings(config.max_height, config.max_width, config.dim)
        )
        if lowres_dim is not None:
            ratio = config.max_height // lowres_height
            self.upsample = nn.Sequential(
                LayerNorm(lowres_dim),
                PixelShuffleUpsample(lowres_dim, config.dim, ratio = ratio),
                LayerNorm(config.dim),
            )
        num_skips = config.num_blocks // 2
        self.blocks = nn.ModuleList([self.encoder_cls(config) for _ in range(2 * num_skips + 1)])
        self.skip_denses = nn.ModuleList([Linear(config.dim * 2, config.dim) for _ in range(num_skips)])
        self.output_layer = nn.Sequential(
            LayerNorm(config.dim),
            ChannelFirst(),
            nn.Conv2d(config.dim, channels, kernel_size = config.final_kernel_size, padding = config.final_kernel_size // 2),
        )
        # Learned scalars used to build the null-indicator token (transformer stages only).
        is_transformer = self.encoder_cls is TransformerBlock
        self.tensor_true = torch.nn.Parameter(torch.tensor([-1.0])) if is_transformer else None
        self.tensor_false = torch.nn.Parameter(torch.tensor([1.0])) if is_transformer else None

    def forward(self, images, lowres_skips = None, condition_context = None, condition_embeds = None, condition_masks = None, null_indicator=None):
        """Run the stage.

        Args:
            images: Input to ``input_layer``. Presumably ``(B, channels, H, W)``,
                given the ``nn.Conv2d`` that precedes ``ChannelLast`` (TODO confirm).
            lowres_skips: ``None``, or the previous stage's concatenated skips.
                They are upsampled and added to the embedded input.
            condition_context: Unused.
            condition_embeds, condition_masks: Forwarded to every block.
            null_indicator: Optional bool tensor with one entry per batch
                element. For transformer stages, a token is prepended to the
                sequence and stripped again before returning. The token is the
                learned scalar ``tensor_true`` where the entry is True and
                ``tensor_false`` otherwise, broadcast to width ``C``.

        Returns:
            ``(output, external_skips)``. ``output`` is ``output_layer`` applied
            to the final activation. ``external_skips`` lists the embedded input
            followed by every block's output, each with the ``(B, H, W, C)``
            layout produced by ``input_layer``.
        """
        is_transformer = self.encoder_cls is TransformerBlock
        x = self.input_layer(images)
        if self.pe_type == 'Axial_RoPE' and is_transformer:
            _, H, W, _ = x.shape
            # NOTE(review): pos covers H * W positions, but with null_indicator
            # the sequence holds H * W + 1 tokens. Confirm that AxialRoPE
            # handles the extra token, and that the positions end up on x's device.
            pos = make_axial_pos(H, W)
        else:
            pos = None
        if lowres_skips is not None:
            x = x + self.upsample(lowres_skips)
        if is_transformer:
            B, H, W, C = x.shape
            x = x.reshape(B, H * W, C)
            if null_indicator is not None:
                indicator_tensor = torch.where(null_indicator, self.tensor_true, self.tensor_false)
                indicator_tensor = indicator_tensor.view(B, 1, 1).expand(-1, -1, C)
                x = torch.cat([indicator_tensor, x], dim = 1)
        external_skips = [x]
        half = len(self.blocks) // 2
        in_blocks = self.blocks[:half]
        mid_block = self.blocks[half]
        out_blocks = self.blocks[half + 1:]
        skips = []
        for block in in_blocks:
            x = block(x, condition_embeds, condition_masks, pos=pos)
            external_skips.append(x)
            skips.append(x)
        x = mid_block(x, condition_embeds, condition_masks, pos=pos)
        external_skips.append(x)
        # Out blocks consume the saved skips in reverse order (U-Net style).
        for dense, block in check_zip(self.skip_denses, out_blocks):
            x = dense(torch.cat([x, skips.pop()], dim = -1))
            x = block(x, condition_embeds, condition_masks, pos=pos)
            external_skips.append(x)
        if is_transformer:
            if null_indicator is not None:
                # Drop the prepended indicator token before restoring the spatial layout.
                x = x[:, 1:, :]
                external_skips = [skip[:, 1:, :] for skip in external_skips]
            x = x.reshape(B, H, W, C)
            external_skips = [skip.reshape(B, H, W, C) for skip in external_skips]
        output = self.output_layer(x)
        return output, external_skips
class MRModel(nn.Module):
    """Multi-resolution model bundling three sub-networks, selected by ``forward`` flags.

    * ``stages``: one ``Stage`` per entry of ``config.stage_configs``. They are
      only built when ``has_diff`` is True and are run by ``_forward``.
    * ``context_encoder``: a ``TransEncoder``. Its output is split into a mean
      and a log-variance and sampled with the reparameterization trick
      (``_text_encoder``).
    * ``open_clip`` + ``open_clip_output``: an open_clip
      ``'ViT-L-16-SigLIP-256'`` model, with its text tower and ``logit_bias``
      deleted, followed by a projection head (``_img_clip``).
    """
    def __init__(self, config):
        super().__init__()
        self.channels = config.channels
        # If truthy, _forward detaches the skips handed from one stage to the next.
        self.block_grad_to_lowres = config.block_grad_to_lowres
        # Copy model-wide options into every per-stage config. This mutates the
        # config objects in place. pe_type defaults to 'APE', norm_type to 'LN'.
        for stage_config in config.stage_configs:
            if hasattr(config, "use_t2i"):
                stage_config.use_t2i = config.use_t2i
            if hasattr(config, "clip_dim"):
                stage_config.clip_dim = config.clip_dim
            if hasattr(config, "num_clip_token"):
                stage_config.num_clip_token = config.num_clip_token
            if hasattr(config, "gradient_checking"):
                stage_config.gradient_checking = config.gradient_checking
            if hasattr(config, "pe_type"):
                stage_config.pe_type = config.pe_type
            else:
                stage_config.pe_type = 'APE'
            if hasattr(config, "norm_type"):
                stage_config.norm_type = config.norm_type
            else:
                stage_config.norm_type = 'LN'
        #### diffusion model
        if hasattr(config, "not_training_diff") and config.not_training_diff:
            # No stages are built; _forward then returns [images] unchanged.
            self.has_diff = False
        else:
            self.has_diff = True
            # Stage i receives all external skips of stage i-1 concatenated on
            # the last axis. A Stage returns its embedded input plus one output
            # per block, i.e. num_blocks // 2 * 2 + 2 tensors of width dim.
            lowres_dims = [None] + [stage_config.dim * (stage_config.num_blocks // 2 * 2 + 2) for stage_config in config.stage_configs[:-1]]
            # Stage i computes its skip-upsampling ratio from the previous stage's max_height.
            lowres_heights = [None] + [stage_config.max_height for stage_config in config.stage_configs[:-1]]
            self.stages = nn.ModuleList([
                Stage(self.channels, stage_config, lowres_dim = lowres_dim, lowres_height=lowres_height)
                for stage_config, lowres_dim, lowres_height in check_zip(config.stage_configs, lowres_dims, lowres_heights)]
            )
        #### Text VE
        # Use 3 down-sample blocks when the textVAE config does not specify a count.
        if hasattr(config.textVAE, "num_down_sample_block"):
            down_sample_block = config.textVAE.num_down_sample_block
        else:
            down_sample_block = 3
        # latten_size is twice channels * max_height * max_width of the last
        # stage, because _text_encoder splits the encoder output in half along
        # the last axis into mu and log_var. This presumes latten_size is the
        # encoder's output width (TODO confirm in TransEncoder).
        self.context_encoder = TransEncoder(d_model=config.clip_dim, N=config.textVAE.num_blocks, num_token=config.num_clip_token,
                                            head_num=config.textVAE.num_attention_heads, d_ff=config.textVAE.hidden_dim,
                                            latten_size=config.channels*config.stage_configs[-1].max_height*config.stage_configs[-1].max_width * 2,
                                            down_sample_block=down_sample_block, dropout=config.textVAE.dropout_prob, last_norm=False)
        #### image encoder to train VE
        # Built with pretrained=None, so no pretrained weights are requested here.
        self.open_clip, _, self.open_clip_preprocess = open_clip.create_model_and_transforms('ViT-L-16-SigLIP-256', pretrained=None)
        # Either head maps the 1024-d image feature to
        # channels * max_height * max_width of the last stage.
        if config.stage_configs[-1].max_width==32:
            # for 256px generation
            self.open_clip_output = Mlp(in_features=1024,
                                        hidden_features=config.channels*config.stage_configs[-1].max_height*config.stage_configs[-1].max_width,
                                        out_features=config.channels*config.stage_configs[-1].max_height*config.stage_configs[-1].max_width,
                                        norm_layer=nn.LayerNorm,
                                        )
        else:
            # for 512px generation
            self.open_clip_output = Adaptor(input_dim=1024,
                                            tar_dim=config.channels*config.stage_configs[-1].max_height*config.stage_configs[-1].max_width
                                            )
        # Only the image tower is used by this class; drop the text tower and logit bias.
        del self.open_clip.text
        del self.open_clip.logit_bias
    def _forward(self, images, log_snr, condition_context = None, condition_text_embeds = None, condition_text_masks = None, condition_drop_prob = None, null_indicator=None):
        """Run every stage in order and return the per-stage outputs.

        Args:
            images: Passed unchanged to every stage; each Stage embeds it with
                its own input_layer.
            log_snr: Handed to ``TimeDependentParameter.seed_time`` before the
                stages run.
            condition_context, condition_text_embeds: Must be ``None`` (asserted).
            condition_text_masks, condition_drop_prob: Unused on the reachable path.
            null_indicator: Forwarded to every stage.

        Returns:
            A list with one output per stage, or ``[images]`` when ``has_diff``
            is False.
        """
        if self.has_diff:
            TimeDependentParameter.seed_time(self, log_snr)
            assert condition_context is None
            assert condition_text_embeds is None
            # NOTE(review): given the assert above, this branch is reachable only
            # when assertions are disabled (python -O). `text_conditioning` is
            # also not defined in this class.
            if condition_text_embeds is not None:
                condition_embeds = self.text_conditioning(condition_text_embeds)
                condition_masks = condition_text_masks
            else:
                condition_embeds = None
                condition_masks = None
            outputs = []
            lowres_skips = None
            for stage in self.stages:
                output, lowres_skips = stage(images, lowres_skips = lowres_skips, condition_context = condition_context, condition_embeds = condition_embeds, condition_masks = condition_masks, null_indicator=null_indicator)
                outputs.append(output)
                # Concatenate this stage's skips on the last axis to feed the next stage's upsampler.
                lowres_skips = torch.cat(lowres_skips, dim = -1)
                if self.block_grad_to_lowres:
                    lowres_skips = lowres_skips.detach()
            return outputs
        else:
            return [images]
    def _reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps`` drawn from a standard normal."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu
    def _text_encoder(self, condition_context, tar_shape, mask):
        """Encode condition tokens and return ``[z, mu, log_var]``.

        The encoder output is split in half along its last axis into ``mu``
        and ``log_var``. ``tar_shape`` is unused.
        """
        output = self.context_encoder(condition_context, mask)
        mu, log_var = torch.chunk(output, 2, dim=-1)
        z = self._reparameterize(mu, log_var)
        return [z, mu, log_var]
    def _text_decoder(self, condition_enbedding, tar_shape):
        """Decode an embedding with ``self.context_decoder``. ``tar_shape`` is unused."""
        # NOTE(review): `context_decoder` is not created in __init__. Unless it
        # is attached elsewhere, this raises AttributeError.
        context_token = self.context_decoder(condition_enbedding)
        return context_token
    def _img_clip(self, image_input):
        """Return ``(projected image latent, open_clip.logit_scale)`` for ``image_input``."""
        image_latent = self.open_clip.encode_image(image_input)
        image_latent = self.open_clip_output(image_latent)
        return image_latent, self.open_clip.logit_scale
    def forward(self, x, t = None, log_snr = None, text_encoder=False, text_decoder=False, image_clip=False, shape=None, mask=None, null_indicator=None):
        """Dispatch to one sub-network according to the flags, checked in this order.

        * ``text_encoder``: ``_text_encoder`` on ``x`` with ``shape`` and ``mask``;
          returns ``[z, mu, log_var]``.
        * ``text_decoder``: ``_text_decoder`` on ``x`` with ``shape``.
        * ``image_clip``: ``_img_clip`` on ``x``; returns ``(image_latent, logit_scale)``.
        * otherwise: ``_forward`` on ``x`` with ``log_snr`` and ``null_indicator``.
          ``log_snr`` must be a float32 tensor (asserted).

        ``t`` is unused.
        """
        if text_encoder:
            return self._text_encoder(condition_context = x, tar_shape=shape, mask=mask)
        elif text_decoder:
            return self._text_decoder(condition_enbedding = x, tar_shape=shape) # mask is not needed for decoder
        elif image_clip:
            return self._img_clip(image_input = x)
        else:
            assert log_snr.dtype == torch.float32
            return self._forward(images = x, log_snr = log_snr, null_indicator=null_indicator)