AlexGraikos committed
Commit 8ed3158 · verified · 1 Parent(s): ea72b80

Delete pixcell_controlnet.py

Files changed (1)
  1. pixcell_controlnet.py +0 -176
pixcell_controlnet.py DELETED
@@ -1,176 +0,0 @@
- from dataclasses import dataclass
- from diffusers.configuration_utils import ConfigMixin
- from diffusers.models.modeling_utils import ModelMixin
- import torch
- import torch.nn as nn
- from typing import Any, Dict, Optional, Tuple
- from pixcell_transformer_2d import PixCellTransformer2DModel
-
- from diffusers.models.controlnet import zero_module
- from diffusers.models.embeddings import PatchEmbed
- from diffusers.utils import BaseOutput, is_torch_version
-
- @dataclass
- class PixCellControlNetOutput(BaseOutput):
-     controlnet_block_samples: Tuple[torch.Tensor]
-
- class PixCellControlNet(ModelMixin, ConfigMixin):
-     def __init__(
-         self,
-         base_transformer: PixCellTransformer2DModel,
-         n_blocks: int = None,
-     ):
-         super().__init__()
-
-         self.n_blocks = n_blocks
-
-         # Base transformer
-         self.transformer = base_transformer
-
-         # Input patch embedding is frozen
-         # self.transformer.pos_embed.requires_grad = False
-
-         # Condition patch embedding
-         interpolation_scale = (
-             self.transformer.config.interpolation_scale
-             if self.transformer.config.interpolation_scale is not None
-             else max(self.transformer.config.sample_size // 64, 1)
-         )
-         self.cond_pos_embed = zero_module(PatchEmbed(
-             height=self.transformer.config.sample_size,
-             width=self.transformer.config.sample_size,
-             patch_size=self.transformer.config.patch_size,
-             in_channels=self.transformer.config.in_channels,
-             embed_dim=self.transformer.inner_dim,
-             interpolation_scale=interpolation_scale,
-         ))
-
-         # Do not use all transformer blocks for controlnet
-         if self.n_blocks is not None:
-             self.transformer.transformer_blocks = self.transformer.transformer_blocks[:self.n_blocks]
-
-         # ControlNet layers
-         self.controlnet_blocks = nn.ModuleList([])
-         for i in range(len(self.transformer.transformer_blocks)):
-             controlnet_block = nn.Linear(self.transformer.inner_dim, self.transformer.inner_dim)
-             controlnet_block = zero_module(controlnet_block)
-             self.controlnet_blocks.append(controlnet_block)
-
-             if self.n_blocks is not None:
-                 if i+1 == self.n_blocks:
-                     break
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         conditioning: torch.Tensor,
-         encoder_hidden_states: Optional[torch.Tensor] = None,
-         timestep: Optional[torch.LongTensor] = None,
-         conditioning_scale: float = 1.0,
-         added_cond_kwargs: Dict[str, torch.Tensor] = None,
-         cross_attention_kwargs: Dict[str, Any] = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         encoder_attention_mask: Optional[torch.Tensor] = None,
-         return_dict: bool = True,
-     ):
-         if self.transformer.use_additional_conditions and added_cond_kwargs is None:
-             raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.")
-
-         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
-         # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
-         # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
-         # expects mask of shape:
-         #   [batch, key_tokens]
-         # adds singleton query_tokens dimension:
-         #   [batch, 1, key_tokens]
-         # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
-         #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
-         #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
-         if attention_mask is not None and attention_mask.ndim == 2:
-             # assume that mask is expressed as:
-             #   (1 = keep, 0 = discard)
-             # convert mask into a bias that can be added to attention scores:
-             #   (keep = +0, discard = -10000.0)
-             attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
-             attention_mask = attention_mask.unsqueeze(1)
-
-         # convert encoder_attention_mask to a bias the same way we do for attention_mask
-         if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
-             encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
-             encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
-
-         # 1. Input
-         batch_size = hidden_states.shape[0]
-         height, width = (
-             hidden_states.shape[-2] // self.transformer.config.patch_size,
-             hidden_states.shape[-1] // self.transformer.config.patch_size,
-         )
-         hidden_states = self.transformer.pos_embed(hidden_states)
-
-         # Conditioning
-         hidden_states = hidden_states + self.cond_pos_embed(conditioning)
-
-         timestep, embedded_timestep = self.transformer.adaln_single(
-             timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
-         )
-
-         if self.transformer.caption_projection is not None:
-             # Add positional embeddings to conditions if >1 UNI are given
-             if self.transformer.y_pos_embed is not None:
-                 encoder_hidden_states = self.transformer.y_pos_embed(encoder_hidden_states)
-             encoder_hidden_states = self.transformer.caption_projection(encoder_hidden_states)
-             encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
-
-         # 2. Blocks
-         block_outputs = ()
-
-         for block in self.transformer.transformer_blocks:
-             if torch.is_grad_enabled() and self.transformer.gradient_checkpointing:
-
-                 def create_custom_forward(module, return_dict=None):
-                     def custom_forward(*inputs):
-                         if return_dict is not None:
-                             return module(*inputs, return_dict=return_dict)
-                         else:
-                             return module(*inputs)
-
-                     return custom_forward
-
-                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                 hidden_states = torch.utils.checkpoint.checkpoint(
-                     create_custom_forward(block),
-                     hidden_states,
-                     attention_mask,
-                     encoder_hidden_states,
-                     encoder_attention_mask,
-                     timestep,
-                     cross_attention_kwargs,
-                     None,
-                     **ckpt_kwargs,
-                 )
-             else:
-                 hidden_states = block(
-                     hidden_states,
-                     attention_mask=attention_mask,
-                     encoder_hidden_states=encoder_hidden_states,
-                     encoder_attention_mask=encoder_attention_mask,
-                     timestep=timestep,
-                     cross_attention_kwargs=cross_attention_kwargs,
-                     class_labels=None,
-                 )
-
-             block_outputs = block_outputs + (hidden_states,)
-
-         # 3. controlnet blocks
-         controlnet_outputs = ()
-         for t_output, controlnet_block in zip(block_outputs, self.controlnet_blocks):
-             b_output = controlnet_block(t_output)
-             controlnet_outputs = controlnet_outputs + (b_output,)
-
-         controlnet_outputs = [sample * conditioning_scale for sample in controlnet_outputs]
-
-         if not return_dict:
-             return (controlnet_outputs,)
-
-         return PixCellControlNetOutput(controlnet_block_samples=controlnet_outputs)
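For reference, a minimal usage sketch of the class removed by this commit, assuming a local pixcell_transformer_2d.py providing PixCellTransformer2DModel and a copy of the deleted pixcell_controlnet.py are still importable. The checkpoint path, n_blocks value, UNI embedding shape, and added_cond_kwargs keys below are illustrative assumptions, not values taken from this repository.

    # Illustrative sketch only: paths, n_blocks, and tensor shapes are assumptions.
    import torch
    from pixcell_transformer_2d import PixCellTransformer2DModel
    from pixcell_controlnet import PixCellControlNet  # module deleted by this commit

    # Hypothetical checkpoint path for the base PixCell transformer.
    base = PixCellTransformer2DModel.from_pretrained("path/to/pixcell-transformer")

    # Wraps the base transformer and keeps only its first n_blocks blocks (n_blocks=13 is arbitrary here).
    controlnet = PixCellControlNet(base_transformer=base, n_blocks=13)

    # Noisy latents and a control signal with the same latent layout (assumed B x C x H x W).
    latents = torch.randn(1, base.config.in_channels, base.config.sample_size, base.config.sample_size)
    conditioning = torch.randn_like(latents)
    timestep = torch.tensor([500])
    # Placeholder UNI conditioning; the true (num_embeddings, dim) depends on the base model config.
    uni = torch.randn(1, 1, 1024)

    out = controlnet(
        hidden_states=latents,
        conditioning=conditioning,
        encoder_hidden_states=uni,
        timestep=timestep,
        conditioning_scale=1.0,
        added_cond_kwargs={"resolution": None, "aspect_ratio": None},  # assumed keys, PixArt-style
    )
    # out.controlnet_block_samples holds one residual per retained transformer block
    # (all zeros at initialization, since the projection layers are zero-initialized).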