from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.controlnet import zero_module
from diffusers.models.embeddings import PatchEmbed
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput, is_torch_version

from pixcell_transformer_2d import PixCellTransformer2DModel


@dataclass
class PixCellControlNetOutput(BaseOutput):
    controlnet_block_samples: Tuple[torch.Tensor]


class PixCellControlNet(ModelMixin, ConfigMixin):
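    """ControlNet-style adapter around a `PixCellTransformer2DModel`.

    The base transformer's patch embedding, AdaLN conditioning, and (optionally
    truncated) block stack are reused. A zero-initialized patch embedding injects the
    spatial conditioning, and one zero-initialized linear projection per block produces
    the residual features returned to the base model, so the control branch starts out
    as a no-op, as in the ControlNet recipe.
    """
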
    def __init__(
        self,
        base_transformer: PixCellTransformer2DModel,
        n_blocks: Optional[int] = None,
    ):
        super().__init__()
        self.n_blocks = n_blocks

        # Base transformer
        self.transformer = base_transformer
        # Input patch embedding is frozen
        # self.transformer.pos_embed.requires_grad = False

        # Condition patch embedding, zero-initialized so the conditioning contributes
        # nothing at the start of training
        interpolation_scale = (
            self.transformer.config.interpolation_scale
            if self.transformer.config.interpolation_scale is not None
            else max(self.transformer.config.sample_size // 64, 1)
        )
        self.cond_pos_embed = zero_module(
            PatchEmbed(
                height=self.transformer.config.sample_size,
                width=self.transformer.config.sample_size,
                patch_size=self.transformer.config.patch_size,
                in_channels=self.transformer.config.in_channels,
                embed_dim=self.transformer.inner_dim,
                interpolation_scale=interpolation_scale,
            )
        )

        # Do not use all transformer blocks for the controlnet
        if self.n_blocks is not None:
            self.transformer.transformer_blocks = self.transformer.transformer_blocks[: self.n_blocks]

        # ControlNet layers: one zero-initialized linear projection per remaining block
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.transformer.transformer_blocks)):
            controlnet_block = zero_module(nn.Linear(self.transformer.inner_dim, self.transformer.inner_dim))
            self.controlnet_blocks.append(controlnet_block)

    def forward(
        self,
        hidden_states: torch.Tensor,
        conditioning: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        conditioning_scale: float = 1.0,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
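        """Run the control branch and return one residual tensor per transformer block.

        `hidden_states` is the noisy latent input, `conditioning` is the spatial control
        input with the same latent shape, and the remaining arguments mirror
        `PixCellTransformer2DModel.forward`. Returns a `PixCellControlNetOutput` (or a
        plain tuple when `return_dict=False`) whose `controlnet_block_samples` are scaled
        by `conditioning_scale`.
        """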
        if self.transformer.use_additional_conditions and added_cond_kwargs is None:
            raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.")

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch, 1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep, 0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0, discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input
        batch_size = hidden_states.shape[0]
        height, width = (
            hidden_states.shape[-2] // self.transformer.config.patch_size,
            hidden_states.shape[-1] // self.transformer.config.patch_size,
        )
        hidden_states = self.transformer.pos_embed(hidden_states)

        # Conditioning
        hidden_states = hidden_states + self.cond_pos_embed(conditioning)

        timestep, embedded_timestep = self.transformer.adaln_single(
            timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
        )

        if self.transformer.caption_projection is not None:
            # Add positional embeddings to conditions if >1 UNI are given
            if self.transformer.y_pos_embed is not None:
                encoder_hidden_states = self.transformer.y_pos_embed(encoder_hidden_states)
            encoder_hidden_states = self.transformer.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

        # 2. Blocks
        block_outputs = ()
        for block in self.transformer.transformer_blocks:
            if torch.is_grad_enabled() and self.transformer.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    cross_attention_kwargs,
                    None,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=None,
                )

            block_outputs = block_outputs + (hidden_states,)

        # 3. ControlNet blocks
        controlnet_outputs = ()
        for t_output, controlnet_block in zip(block_outputs, self.controlnet_blocks):
            b_output = controlnet_block(t_output)
            controlnet_outputs = controlnet_outputs + (b_output,)
        controlnet_outputs = [sample * conditioning_scale for sample in controlnet_outputs]

        if not return_dict:
            return (controlnet_outputs,)

        return PixCellControlNetOutput(controlnet_block_samples=controlnet_outputs)
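

# Minimal usage sketch (assumptions, not part of this module): a pretrained
# PixCellTransformer2DModel is available, `latents` and `conditioning` are latent-space
# tensors of shape [batch, in_channels, sample_size, sample_size], and `uni_embeddings`
# / `timesteps` match the base transformer's expected inputs. Names and the checkpoint
# path are hypothetical. Note the constructor truncates the transformer it is given, so
# a deep copy of the base model is used here.
#
#   import copy
#   transformer = PixCellTransformer2DModel.from_pretrained("path/to/pixcell")  # hypothetical path
#   controlnet = PixCellControlNet(base_transformer=copy.deepcopy(transformer), n_blocks=13)
#   out = controlnet(
#       hidden_states=latents,
#       conditioning=conditioning,
#       encoder_hidden_states=uni_embeddings,
#       timestep=timesteps,
#       added_cond_kwargs={"resolution": None, "aspect_ratio": None},
#   )
#   residuals = out.controlnet_block_samples  # one tensor per controlnet block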