AlexGraikos committed
Commit 62a25d2 · verified · 1 Parent(s): 86d49c5

Create pixcell_controlnet.py

Files changed (1)
  1. pixcell_controlnet.py +176 -0
pixcell_controlnet.py ADDED
@@ -0,0 +1,176 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin
from diffusers.models.controlnet import zero_module
from diffusers.models.embeddings import PatchEmbed
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput, is_torch_version

from pixcell_transformer_2d import PixCellTransformer2DModel

@dataclass
class PixCellControlNetOutput(BaseOutput):
    """Output of `PixCellControlNet`: one residual tensor per ControlNet block."""

    controlnet_block_samples: Tuple[torch.Tensor]

class PixCellControlNet(ModelMixin, ConfigMixin):
    def __init__(
        self,
        base_transformer: PixCellTransformer2DModel,
        n_blocks: Optional[int] = None,
    ):
        super().__init__()

        self.n_blocks = n_blocks

        # Base transformer
        self.transformer = base_transformer

        # Input patch embedding is frozen
        # self.transformer.pos_embed.requires_grad = False

        # Condition patch embedding (zero-initialized so the ControlNet starts as a no-op)
        interpolation_scale = (
            self.transformer.config.interpolation_scale
            if self.transformer.config.interpolation_scale is not None
            else max(self.transformer.config.sample_size // 64, 1)
        )
        self.cond_pos_embed = zero_module(PatchEmbed(
            height=self.transformer.config.sample_size,
            width=self.transformer.config.sample_size,
            patch_size=self.transformer.config.patch_size,
            in_channels=self.transformer.config.in_channels,
            embed_dim=self.transformer.inner_dim,
            interpolation_scale=interpolation_scale,
        ))

        # Do not use all transformer blocks for the ControlNet
        if self.n_blocks is not None:
            self.transformer.transformer_blocks = self.transformer.transformer_blocks[:self.n_blocks]

        # ControlNet layers: one zero-initialized linear projection per retained block
        self.controlnet_blocks = nn.ModuleList([])
        for i in range(len(self.transformer.transformer_blocks)):
            controlnet_block = nn.Linear(self.transformer.inner_dim, self.transformer.inner_dim)
            controlnet_block = zero_module(controlnet_block)
            self.controlnet_blocks.append(controlnet_block)

            if self.n_blocks is not None:
                if i + 1 == self.n_blocks:
                    break

    def forward(
        self,
        hidden_states: torch.Tensor,
        conditioning: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        conditioning_scale: float = 1.0,
        added_cond_kwargs: Dict[str, torch.Tensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        if self.transformer.use_additional_conditions and added_cond_kwargs is None:
            raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.")

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch, 1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep, 0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0, discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input
        batch_size = hidden_states.shape[0]
        height, width = (
            hidden_states.shape[-2] // self.transformer.config.patch_size,
            hidden_states.shape[-1] // self.transformer.config.patch_size,
        )
        hidden_states = self.transformer.pos_embed(hidden_states)

        # Conditioning
        hidden_states = hidden_states + self.cond_pos_embed(conditioning)

        timestep, embedded_timestep = self.transformer.adaln_single(
            timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
        )

        if self.transformer.caption_projection is not None:
            # Add positional embeddings to conditions if >1 UNI are given
            if self.transformer.y_pos_embed is not None:
                encoder_hidden_states = self.transformer.y_pos_embed(encoder_hidden_states)
            encoder_hidden_states = self.transformer.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

        # 2. Blocks
        block_outputs = ()

        for block in self.transformer.transformer_blocks:
            if torch.is_grad_enabled() and self.transformer.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    cross_attention_kwargs,
                    None,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=None,
                )

            block_outputs = block_outputs + (hidden_states,)

        # 3. ControlNet blocks
        controlnet_outputs = ()
        for t_output, controlnet_block in zip(block_outputs, self.controlnet_blocks):
            b_output = controlnet_block(t_output)
            controlnet_outputs = controlnet_outputs + (b_output,)

        # Scale the per-block residuals by the conditioning strength
        controlnet_outputs = [sample * conditioning_scale for sample in controlnet_outputs]

        if not return_dict:
            return (controlnet_outputs,)

        return PixCellControlNetOutput(controlnet_block_samples=controlnet_outputs)
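
For reference, a minimal usage sketch of the module added above. It is not part of the committed file: the checkpoint path, the `n_blocks` value, the UNI embedding dimension, and the PixArt-style `resolution`/`aspect_ratio` keys in `added_cond_kwargs` are illustrative assumptions; only the constructor and `forward` signatures come from the code above.

# Hypothetical usage sketch; shapes, paths, and config values are assumptions.
import torch

from pixcell_transformer_2d import PixCellTransformer2DModel
from pixcell_controlnet import PixCellControlNet, PixCellControlNetOutput

# Assumed: a pretrained PixCell base transformer (placeholder path).
base = PixCellTransformer2DModel.from_pretrained("path/to/pixcell_transformer")

# Reuse only the first 14 base blocks for the ControlNet copy (value is illustrative).
controlnet = PixCellControlNet(base_transformer=base, n_blocks=14)

batch = 2
size = base.config.sample_size        # latent spatial size expected by the patch embeds
channels = base.config.in_channels

noisy_latents = torch.randn(batch, channels, size, size)
control_latents = torch.randn(batch, channels, size, size)  # conditioning signal in latent space
uni_embeddings = torch.randn(batch, 1, 1024)                 # UNI embedding; dimension is an assumption
timesteps = torch.randint(0, 1000, (batch,))

out = controlnet(
    hidden_states=noisy_latents,
    conditioning=control_latents,
    encoder_hidden_states=uni_embeddings,
    timestep=timesteps,
    conditioning_scale=1.0,
    added_cond_kwargs={"resolution": None, "aspect_ratio": None},  # PixArt-style keys; assumption
)
assert isinstance(out, PixCellControlNetOutput)
# One residual per retained transformer block, already scaled by conditioning_scale.
print(len(out.controlnet_block_samples))

Because both `cond_pos_embed` and the per-block linear projections are wrapped in `zero_module`, the returned residuals are all zeros before any training, so adding them into the base transformer leaves its behavior unchanged at initialization; the base model is presumably expected to consume `controlnet_block_samples` block by block, in the usual ControlNet fashion.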