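# SDXL ControlNet built on the original (non-Diffusers) SDXL UNet implementation in this repository:
# SdxlControlNet reuses the UNet encoder (input_blocks + middle_block) to turn a conditioning image
# into per-block residuals, and SdxlControlledUNet adds those residuals back into the skip
# connections of a (typically frozen) UNet during training.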
import math
from types import SimpleNamespace
from typing import Any, Optional

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from einops import rearrange

from .utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

from . import sdxl_original_unet
from .sdxl_model_util import convert_sdxl_unet_state_dict_to_diffusers, convert_diffusers_unet_state_dict_to_sdxl


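# Encodes the 3-channel conditioning image (e.g. a canny edge or depth map): three stride-2
# convolutions bring it to 1/8 resolution, matching the first UNet feature map, and conv_out
# (to 320 channels) is zero-initialized so the embedding contributes nothing at the start of training.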
class ControlNetConditioningEmbedding(nn.Module):
    def __init__(self):
        super().__init__()

        dims = [16, 32, 96, 256]

        self.conv_in = nn.Conv2d(3, dims[0], kernel_size=3, padding=1)
        self.blocks = nn.ModuleList([])

        for i in range(len(dims) - 1):
            channel_in = dims[i]
            channel_out = dims[i + 1]
            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))

        self.conv_out = nn.Conv2d(dims[-1], 320, kernel_size=3, padding=1)
        nn.init.zeros_(self.conv_out.weight)
        nn.init.zeros_(self.conv_out.bias)

    def forward(self, x):
        x = self.conv_in(x)
        x = F.silu(x)
        for block in self.blocks:
            x = block(x)
            x = F.silu(x)
        x = self.conv_out(x)
        return x


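# The ControlNet is the encoder half of the SDXL UNet: the decoder (output_blocks) and the final
# "out" layer are removed, and zero-initialized 1x1 convolutions project each encoder output and
# the middle-block output into residuals for the controlled UNet.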
class SdxlControlNet(sdxl_original_unet.SdxlUNet2DConditionModel):
    def __init__(self, multiplier: Optional[float] = None, **kwargs):
        super().__init__(**kwargs)
        self.multiplier = multiplier

        self.output_blocks = nn.ModuleList([])
        del self.out

        self.controlnet_cond_embedding = ControlNetConditioningEmbedding()

        dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280]
        self.controlnet_down_blocks = nn.ModuleList([])
        for dim in dims:
            self.controlnet_down_blocks.append(nn.Conv2d(dim, dim, kernel_size=1))
            nn.init.zeros_(self.controlnet_down_blocks[-1].weight)
            nn.init.zeros_(self.controlnet_down_blocks[-1].bias)

        self.controlnet_mid_block = nn.Conv2d(1280, 1280, kernel_size=1)
        nn.init.zeros_(self.controlnet_mid_block.weight)
        nn.init.zeros_(self.controlnet_mid_block.bias)

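    # Copy the weights of a trained SDXL UNet into this ControlNet, dropping the UNet's decoder
    # ("output_blocks") and final "out" layer; the controlnet_* modules keep their zero-initialized weights.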
    def init_from_unet(self, unet: sdxl_original_unet.SdxlUNet2DConditionModel):
        unet_sd = unet.state_dict()
        unet_sd = {k: v for k, v in unet_sd.items() if not k.startswith("out")}
        sd = super().state_dict()
        sd.update(unet_sd)
        info = super().load_state_dict(sd, strict=True, assign=True)
        return info

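    # Saved checkpoints keep the UNet part in Diffusers key format; convert it back to the original
    # SDXL naming before loading, leaving the controlnet_* keys untouched.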
    def load_state_dict(self, state_dict: dict, strict: bool = True, assign: bool = True) -> Any:
        unet_sd = {}
        for k in list(state_dict.keys()):
            if not k.startswith("controlnet_"):
                unet_sd[k] = state_dict.pop(k)
        unet_sd = convert_diffusers_unet_state_dict_to_sdxl(unet_sd)
        state_dict.update(unet_sd)
        super().load_state_dict(state_dict, strict=strict, assign=assign)

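    # Mirror of load_state_dict: export the UNet part in Diffusers key format and keep the
    # controlnet_* keys as-is.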
    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state_dict = super().state_dict(destination, prefix, keep_vars)
        control_net_sd = {}
        for k in list(state_dict.keys()):
            if k.startswith("controlnet_"):
                control_net_sd[k] = state_dict.pop(k)
        state_dict = convert_sdxl_unet_state_dict_to_diffusers(state_dict)
        state_dict.update(control_net_sd)
        return state_dict

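    # Returns (hs, h): one residual per input block (after its zero conv, scaled by the multiplier)
    # and the projected middle-block output, consumed by SdxlControlledUNet.forward.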
    def forward(
        self,
        x: torch.Tensor,
        timesteps: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
        cond_image: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[list[torch.Tensor], torch.Tensor]:
        timesteps = timesteps.expand(x.shape[0])

        t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0)
        t_emb = t_emb.to(x.dtype)
        emb = self.time_embed(t_emb)

        assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
        assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
        emb = emb + self.label_emb(y)

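        # dispatch each layer with the inputs it needs: ResNet blocks get the time/label embedding,
        # transformer blocks get the text context, everything else just the hidden states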
        def call_module(module, h, emb, context):
            x = h
            for layer in module:
                if isinstance(layer, sdxl_original_unet.ResnetBlock2D):
                    x = layer(x, emb)
                elif isinstance(layer, sdxl_original_unet.Transformer2DModel):
                    x = layer(x, context)
                else:
                    x = layer(x)
            return x

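        # run the encoder: the conditioning image embedding is added after the first input block,
        # then each block's output passes through its zero conv and is scaled by the multiplier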
        h = x
        multiplier = self.multiplier if self.multiplier is not None else 1.0
        hs = []
        for i, module in enumerate(self.input_blocks):
            h = call_module(module, h, emb, context)
            if i == 0:
                h = self.controlnet_cond_embedding(cond_image) + h
            hs.append(self.controlnet_down_blocks[i](h) * multiplier)

        h = call_module(self.middle_block, h, emb, context)
        h = self.controlnet_mid_block(h) * multiplier

        return hs, h


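# A standard SDXL UNet whose forward additionally takes input_resi_add (per-block residuals) and
# mid_add (middle-block residual) from SdxlControlNet and adds them into the skip connections and
# middle-block output.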
class SdxlControlledUNet(sdxl_original_unet.SdxlUNet2DConditionModel):
    """
    This class is for training purposes only.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs):
        timesteps = timesteps.expand(x.shape[0])

        hs = []
        t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0)
        t_emb = t_emb.to(x.dtype)
        emb = self.time_embed(t_emb)

        assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
        assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
        emb = emb + self.label_emb(y)

        def call_module(module, h, emb, context):
            x = h
            for layer in module:
                if isinstance(layer, sdxl_original_unet.ResnetBlock2D):
                    x = layer(x, emb)
                elif isinstance(layer, sdxl_original_unet.Transformer2DModel):
                    x = layer(x, context)
                else:
                    x = layer(x)
            return x

        h = x
        for module in self.input_blocks:
            h = call_module(module, h, emb, context)
            hs.append(h)

        h = call_module(self.middle_block, h, emb, context)
        h = h + mid_add

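        # the ControlNet residuals are consumed in reverse order, added to each skip connection
        # before it is concatenated onto the decoder features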
        for module in self.output_blocks:
            resi = hs.pop() + input_resi_add.pop()
            h = torch.cat([h, resi], dim=1)
            h = call_module(module, h, emb, context)

        h = h.type(x.dtype)
        h = call_module(self.out, h, emb, context)

        return h


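# Smoke test / micro-benchmark: build both models on CUDA, initialize the ControlNet from the UNet,
# run a few training steps on random data, and save the ControlNet weights.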
if __name__ == "__main__":
    import time

    logger.info("create unet")
    unet = SdxlControlledUNet()
    unet.to("cuda", torch.bfloat16)
    unet.set_use_sdpa(True)
    unet.set_gradient_checkpointing(True)
    unet.train()

    logger.info("create control_net")
    control_net = SdxlControlNet()
    control_net.to("cuda")
    control_net.set_use_sdpa(True)
    control_net.set_gradient_checkpointing(True)
    control_net.train()

    logger.info("Initialize control_net from unet")
    control_net.init_from_unet(unet)

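    # only the ControlNet is trained; the UNet stays frozen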
    unet.requires_grad_(False)
    control_net.requires_grad_(True)

    logger.info("preparing optimizer")

    import bitsandbytes

    optimizer = bitsandbytes.adam.Adam8bit(control_net.parameters(), lr=1e-3)

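    # gradient scaling is not strictly necessary with bfloat16 autocast, but it is harmless here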
    scaler = torch.cuda.amp.GradScaler(enabled=True)

    logger.info("start training")
    steps = 10
    batch_size = 1

    for step in range(steps):
        logger.info(f"step {step}")
        if step == 1:
            time_start = time.perf_counter()

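        # random dummy inputs: 128x128 latents (a 1024x1024 image at the 8x VAE downscale),
        # text context, pooled/size embedding vector, and a full-resolution conditioning image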
        x = torch.randn(batch_size, 4, 128, 128).cuda()
        t = torch.randint(low=0, high=1000, size=(batch_size,), device="cuda")
        txt = torch.randn(batch_size, 77, 2048).cuda()
        vector = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
        cond_img = torch.rand(batch_size, 3, 1024, 1024).cuda()

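        # the ControlNet runs first; its residuals are then fed to the frozen UNet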
        with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
            input_resi_add, mid_add = control_net(x, t, txt, vector, cond_img)
            output = unet(x, t, txt, vector, input_resi_add, mid_add)
            target = torch.randn_like(output)
            loss = torch.nn.functional.mse_loss(output, target)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

    time_end = time.perf_counter()
    logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")

    logger.info("finish training")
    sd = control_net.state_dict()

    from safetensors.torch import save_file

    save_file(sd, r"E:\Work\SD\Tmp\sdxl\ctrl\control_net.safetensors")