from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin
from diffusers.models.controlnet import zero_module
from diffusers.models.embeddings import PatchEmbed
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput, is_torch_version

from pixcell_transformer_2d import PixCellTransformer2DModel


@dataclass
class PixCellControlNetOutput(BaseOutput):
    controlnet_block_samples: Tuple[torch.Tensor]


class PixCellControlNet(ModelMixin, ConfigMixin):
    def __init__(
        self,
        base_transformer: PixCellTransformer2DModel,
        n_blocks: int = None,
    ):
        super().__init__()

        self.n_blocks = n_blocks

        # Base transformer
        self.transformer = base_transformer

        # Input patch embedding is frozen
        # self.transformer.pos_embed.requires_grad = False

        # Condition patch embedding
        interpolation_scale = (
            self.transformer.config.interpolation_scale
            if self.transformer.config.interpolation_scale is not None
            else max(self.transformer.config.sample_size // 64, 1)
        )
        self.cond_pos_embed = zero_module(
            PatchEmbed(
                height=self.transformer.config.sample_size,
                width=self.transformer.config.sample_size,
                patch_size=self.transformer.config.patch_size,
                in_channels=self.transformer.config.in_channels,
                embed_dim=self.transformer.inner_dim,
                interpolation_scale=interpolation_scale,
            )
        )

        # Do not use all transformer blocks for controlnet
        if self.n_blocks is not None:
            self.transformer.transformer_blocks = self.transformer.transformer_blocks[: self.n_blocks]

        # ControlNet layers
        self.controlnet_blocks = nn.ModuleList([])
        for i in range(len(self.transformer.transformer_blocks)):
            controlnet_block = nn.Linear(self.transformer.inner_dim, self.transformer.inner_dim)
            controlnet_block = zero_module(controlnet_block)
            self.controlnet_blocks.append(controlnet_block)
            if self.n_blocks is not None:
                if i + 1 == self.n_blocks:
                    break

    def forward(
        self,
        hidden_states: torch.Tensor,
        conditioning: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        conditioning_scale: float = 1.0,
        added_cond_kwargs: Dict[str, torch.Tensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        if self.transformer.use_additional_conditions and added_cond_kwargs is None:
            raise ValueError(
                "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
            )

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input
        batch_size = hidden_states.shape[0]
        height, width = (
            hidden_states.shape[-2] // self.transformer.config.patch_size,
            hidden_states.shape[-1] // self.transformer.config.patch_size,
        )
        hidden_states = self.transformer.pos_embed(hidden_states)

        # Conditioning
        hidden_states = hidden_states + self.cond_pos_embed(conditioning)

        timestep, embedded_timestep = self.transformer.adaln_single(
            timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
        )

        if self.transformer.caption_projection is not None:
            # Add positional embeddings to conditions if >1 UNI are given
            if self.transformer.y_pos_embed is not None:
                encoder_hidden_states = self.transformer.y_pos_embed(encoder_hidden_states)

            encoder_hidden_states = self.transformer.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

        # 2. Blocks
        block_outputs = ()
        for block in self.transformer.transformer_blocks:
            if torch.is_grad_enabled() and self.transformer.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    cross_attention_kwargs,
                    None,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=None,
                )

            block_outputs = block_outputs + (hidden_states,)

        # 3. controlnet blocks
        controlnet_outputs = ()
        for t_output, controlnet_block in zip(block_outputs, self.controlnet_blocks):
            b_output = controlnet_block(t_output)
            controlnet_outputs = controlnet_outputs + (b_output,)

        controlnet_outputs = [sample * conditioning_scale for sample in controlnet_outputs]

        if not return_dict:
            return (controlnet_outputs,)

        return PixCellControlNetOutput(controlnet_block_samples=controlnet_outputs)
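

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). Everything below is an assumption rather
# than part of this module: the checkpoint path, the UNI embedding width of
# 1024, the tensor shapes, and the `added_cond_kwargs` keys are placeholder
# values chosen only to show how the ControlNet wraps a base PixCell
# transformer and what its forward pass returns. Adapt them to the actual
# PixCell configuration before running.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical checkpoint path for the base transformer.
    base = PixCellTransformer2DModel.from_pretrained("path/to/pixcell-transformer")
    controlnet = PixCellControlNet(base_transformer=base, n_blocks=None)

    batch = 2
    sample = base.config.sample_size
    channels = base.config.in_channels

    # Noisy latents and the spatial conditioning input share the latent shape.
    latents = torch.randn(batch, channels, sample, sample)
    conditioning = torch.randn(batch, channels, sample, sample)

    # Assumed UNI conditioning of shape (batch, num_uni_tokens, 1024).
    uni_embeds = torch.randn(batch, 1, 1024)
    timesteps = torch.randint(0, 1000, (batch,))

    out = controlnet(
        hidden_states=latents,
        conditioning=conditioning,
        encoder_hidden_states=uni_embeds,
        timestep=timesteps,
        # Placeholder keys for a model without additional conditions; a model
        # with `use_additional_conditions=True` needs real resolution/aspect values.
        added_cond_kwargs={"resolution": None, "aspect_ratio": None},
    )

    # One zero-initialized residual per transformer block, each intended to be
    # added to the matching block output in the base model's denoising pass.
    print(len(out.controlnet_block_samples), out.controlnet_block_samples[0].shape)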