File size: 2,636 Bytes
			
			| 50e6701 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 | import torch
from typing import Optional
import comfy.ldm.modules.diffusionmodules.mmdit
class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
    """SD3-style ControlNet built on top of the MMDiT backbone.

    The backbone is constructed with ``final_layer=False``. On top of it this
    class adds:

    * one linear projection per joint block (``controlnet_blocks``), whose
      outputs form the control signal;
    * a patch embedder for the control hint (``pos_embed_input``), whose
      output is added to the embedded latent before the joint blocks run.
    """

    def __init__(
        self,
        num_blocks = None,
        control_latent_channels = None,
        dtype = None,
        device = None,
        operations = None,
        **kwargs,
    ):
        """Build the backbone, the per-block control projections and the hint embedder.

        Args:
            num_blocks: forwarded to ``MMDiT.__init__``.
            control_latent_channels: channel count of the ``hint`` passed to
                ``forward``. Defaults to ``self.in_channels`` when ``None``.
            dtype: forwarded to every layer constructed here and in the base class.
            device: forwarded to every layer constructed here and in the base class.
            operations: layer-factory namespace. Must provide ``Linear``, which is
                used directly here, and whatever ``MMDiT``/``PatchEmbed`` need.
            **kwargs: remaining ``MMDiT`` constructor arguments.
        """
        super().__init__(dtype=dtype, device=device, operations=operations, final_layer=False, num_blocks=num_blocks, **kwargs)

        # controlnet_blocks: one hidden_size -> hidden_size projection per joint
        # block. Each projected block output becomes part of the control signal.
        self.controlnet_blocks = torch.nn.ModuleList([
            operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype)
            for _ in self.joint_blocks
        ])

        if control_latent_channels is None:
            control_latent_channels = self.in_channels

        # Hint embedder. It uses the backbone's patch_size and hidden_size so
        # that its output can be added to the embedded latent in ``forward``.
        self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
            None,
            self.patch_size,
            control_latent_channels,
            self.hidden_size,
            bias=True,
            strict_img_size=False,
            dtype=dtype,
            device=device,
            operations=operations
        )

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        hint = None,
    ) -> dict:
        """Run the backbone with the control hint and collect per-block control outputs.

        Args:
            x: latent input. Its last two dims are used as the spatial size for
                the cropped positional embedding.
            timesteps: diffusion timesteps fed to ``t_embedder``.
            y: optional pooled conditioning. When given, it is replaced by zeros
                of the same shape/dtype/device before use. May be ``None``.
            context: optional token conditioning. It goes through
                ``context_processor`` (if set) and then ``context_embedder``
                (if not ``None``).
            hint: control input, embedded by ``pos_embed_input``. Required.

        Returns:
            ``{"output": outputs}``, where ``outputs`` is a list of tensors. Each
            joint block's projected output appears ``depth // num_joint_blocks``
            times, except the last block's, which appears one time fewer.
        """
        # SD3-controlnet-specific: the pooled conditioning is zeroed. The
        # guard keeps the ``y=None`` default from crashing in torch.zeros_like.
        if y is not None:
            y = torch.zeros_like(y)

        if self.context_processor is not None:
            context = self.context_processor(context)

        hw = x.shape[-2:]
        x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
        x += self.pos_embed_input(hint)

        c = self.t_embedder(timesteps, dtype=x.dtype)
        if y is not None and self.y_embedder is not None:
            c = c + self.y_embedder(y)

        if context is not None:
            context = self.context_embedder(context)

        output = []
        blocks = len(self.joint_blocks)
        # Loop-invariant repeat count. It is guarded because the original only
        # evaluated the division inside the loop, which never runs when blocks == 0.
        repeats = self.depth // blocks if blocks else 0
        for i, joint_block in enumerate(self.joint_blocks):
            context, x = joint_block(
                context,
                x,
                c=c,
                use_checkpoint=self.use_checkpoint,
            )
            out = self.controlnet_blocks[i](x)
            # The last block contributes one fewer copy. A negative count yields
            # an empty extension, matching the original range(count) loop.
            count = repeats - 1 if i == blocks - 1 else repeats
            output.extend([out] * count)
        return {"output": output}
 | 
 
			
