MMD-Coder committed · Commit 66a6dae · verified · 1 parent: 51df25e

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .DS_Store +0 -0
  2. .gitattributes +37 -35
  3. .gitignore +3 -0
  4. MagicQuill/.DS_Store +0 -0
  5. MagicQuill/__pycache__/brushnet_nodes.cpython-311.pyc +0 -0
  6. MagicQuill/__pycache__/comfyui_utils.cpython-311.pyc +0 -0
  7. MagicQuill/__pycache__/folder_paths.cpython-311.pyc +0 -0
  8. MagicQuill/__pycache__/latent_preview.cpython-311.pyc +0 -0
  9. MagicQuill/__pycache__/magic_utils.cpython-311.pyc +0 -0
  10. MagicQuill/__pycache__/model_patch.cpython-311.pyc +0 -0
  11. MagicQuill/__pycache__/pidi.cpython-311.pyc +0 -0
  12. MagicQuill/__pycache__/scribble_color_edit.cpython-311.pyc +0 -0
  13. MagicQuill/brushnet/__pycache__/brushnet.cpython-311.pyc +0 -0
  14. MagicQuill/brushnet/__pycache__/brushnet_ca.cpython-311.pyc +0 -0
  15. MagicQuill/brushnet/__pycache__/powerpaint_utils.cpython-311.pyc +0 -0
  16. MagicQuill/brushnet/__pycache__/unet_2d_blocks.cpython-311.pyc +3 -0
  17. MagicQuill/brushnet/__pycache__/unet_2d_condition.cpython-311.pyc +0 -0
  18. MagicQuill/brushnet/brushnet.json +58 -0
  19. MagicQuill/brushnet/brushnet.py +949 -0
  20. MagicQuill/brushnet/brushnet_ca.py +983 -0
  21. MagicQuill/brushnet/brushnet_xl.json +63 -0
  22. MagicQuill/brushnet/powerpaint.json +57 -0
  23. MagicQuill/brushnet/powerpaint_utils.py +496 -0
  24. MagicQuill/brushnet/unet_2d_blocks.py +0 -0
  25. MagicQuill/brushnet/unet_2d_condition.py +1355 -0
  26. MagicQuill/brushnet_nodes.py +1094 -0
  27. MagicQuill/comfy/.DS_Store +0 -0
  28. MagicQuill/comfy/__pycache__/checkpoint_pickle.cpython-311.pyc +0 -0
  29. MagicQuill/comfy/__pycache__/cli_args.cpython-311.pyc +0 -0
  30. MagicQuill/comfy/__pycache__/clip_model.cpython-311.pyc +0 -0
  31. MagicQuill/comfy/__pycache__/clip_vision.cpython-311.pyc +0 -0
  32. MagicQuill/comfy/__pycache__/conds.cpython-311.pyc +0 -0
  33. MagicQuill/comfy/__pycache__/controlnet.cpython-311.pyc +0 -0
  34. MagicQuill/comfy/__pycache__/diffusers_convert.cpython-311.pyc +0 -0
  35. MagicQuill/comfy/__pycache__/diffusers_load.cpython-311.pyc +0 -0
  36. MagicQuill/comfy/__pycache__/gligen.cpython-311.pyc +0 -0
  37. MagicQuill/comfy/__pycache__/latent_formats.cpython-311.pyc +0 -0
  38. MagicQuill/comfy/__pycache__/lora.cpython-311.pyc +0 -0
  39. MagicQuill/comfy/__pycache__/model_base.cpython-311.pyc +0 -0
  40. MagicQuill/comfy/__pycache__/model_detection.cpython-311.pyc +0 -0
  41. MagicQuill/comfy/__pycache__/model_management.cpython-311.pyc +0 -0
  42. MagicQuill/comfy/__pycache__/model_patcher.cpython-311.pyc +0 -0
  43. MagicQuill/comfy/__pycache__/model_sampling.cpython-311.pyc +0 -0
  44. MagicQuill/comfy/__pycache__/ops.cpython-311.pyc +0 -0
  45. MagicQuill/comfy/__pycache__/options.cpython-311.pyc +0 -0
  46. MagicQuill/comfy/__pycache__/sa_t5.cpython-311.pyc +0 -0
  47. MagicQuill/comfy/__pycache__/sample.cpython-311.pyc +0 -0
  48. MagicQuill/comfy/__pycache__/sampler_helpers.cpython-311.pyc +0 -0
  49. MagicQuill/comfy/__pycache__/samplers.cpython-311.pyc +0 -0
  50. MagicQuill/comfy/__pycache__/sd.cpython-311.pyc +0 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
.gitattributes CHANGED
@@ -1,35 +1,37 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs text
+ *.arrow filter=lfs diff=lfs merge=lfs text
+ *.bin filter=lfs diff=lfs merge=lfs text
+ *.bz2 filter=lfs diff=lfs merge=lfs text
+ *.ckpt filter=lfs diff=lfs merge=lfs text
+ *.ftz filter=lfs diff=lfs merge=lfs text
+ *.gz filter=lfs diff=lfs merge=lfs text
+ *.h5 filter=lfs diff=lfs merge=lfs text
+ *.joblib filter=lfs diff=lfs merge=lfs text
+ *.lfs.* filter=lfs diff=lfs merge=lfs text
+ *.mlmodel filter=lfs diff=lfs merge=lfs text
+ *.model filter=lfs diff=lfs merge=lfs text
+ *.msgpack filter=lfs diff=lfs merge=lfs text
+ *.npy filter=lfs diff=lfs merge=lfs text
+ *.npz filter=lfs diff=lfs merge=lfs text
+ *.onnx filter=lfs diff=lfs merge=lfs text
+ *.ot filter=lfs diff=lfs merge=lfs text
+ *.parquet filter=lfs diff=lfs merge=lfs text
+ *.pb filter=lfs diff=lfs merge=lfs text
+ *.pickle filter=lfs diff=lfs merge=lfs text
+ *.pkl filter=lfs diff=lfs merge=lfs text
+ *.pt filter=lfs diff=lfs merge=lfs text
+ *.pth filter=lfs diff=lfs merge=lfs text
+ *.rar filter=lfs diff=lfs merge=lfs text
+ *.safetensors filter=lfs diff=lfs merge=lfs text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs text
+ *.tar.* filter=lfs diff=lfs merge=lfs text
+ *.tar filter=lfs diff=lfs merge=lfs text
+ *.tflite filter=lfs diff=lfs merge=lfs text
+ *.tgz filter=lfs diff=lfs merge=lfs text
+ *.wasm filter=lfs diff=lfs merge=lfs text
+ *.xz filter=lfs diff=lfs merge=lfs text
+ *.zip filter=lfs diff=lfs merge=lfs text
+ *.zst filter=lfs diff=lfs merge=lfs text
+ *tfevents* filter=lfs diff=lfs merge=lfs text
+ MagicQuill/brushnet/__pycache__/unet_2d_blocks.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+ gradio_magicquill-0.0.1-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
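Note: the two appended patterns route the cached unet_2d_blocks bytecode and the bundled gradio_magicquill wheel through Git LFS, alongside the stock Hugging Face patterns above. As a rough way to check which repository paths such attribute lines cover, here is a small Python sketch (illustrative only, not part of the commit; real .gitattributes matching also honors directory-scoping rules that fnmatch ignores):

from fnmatch import fnmatch

# Patterns copied from the .gitattributes entries above (subset shown).
LFS_PATTERNS = [
    "*.bin", "*.ckpt", "*.pt", "*.pth", "*.safetensors",
    "MagicQuill/brushnet/__pycache__/unet_2d_blocks.cpython-311.pyc",
    "gradio_magicquill-0.0.1-py3-none-any.whl",
]

def tracked_by_lfs(path: str) -> bool:
    # Match against the full path and against the basename, as Git does for bare patterns.
    name = path.rsplit("/", 1)[-1]
    return any(fnmatch(path, pat) or fnmatch(name, pat) for pat in LFS_PATTERNS)

print(tracked_by_lfs("MagicQuill/brushnet/__pycache__/unet_2d_blocks.cpython-311.pyc"))  # True
print(tracked_by_lfs("MagicQuill/brushnet/brushnet.py"))  # False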
.gitignore ADDED
@@ -0,0 +1,3 @@
+ upload.py
+ download.py
+ models/
MagicQuill/.DS_Store ADDED
Binary file (6.15 kB).
 
MagicQuill/__pycache__/brushnet_nodes.cpython-311.pyc ADDED
Binary file (55.2 kB).
 
MagicQuill/__pycache__/comfyui_utils.cpython-311.pyc ADDED
Binary file (28.4 kB).
 
MagicQuill/__pycache__/folder_paths.cpython-311.pyc ADDED
Binary file (16.9 kB).
 
MagicQuill/__pycache__/latent_preview.cpython-311.pyc ADDED
Binary file (6.83 kB).
 
MagicQuill/__pycache__/magic_utils.cpython-311.pyc ADDED
Binary file (15 kB).
 
MagicQuill/__pycache__/model_patch.cpython-311.pyc ADDED
Binary file (6.9 kB).
 
MagicQuill/__pycache__/pidi.cpython-311.pyc ADDED
Binary file (31.8 kB).
 
MagicQuill/__pycache__/scribble_color_edit.cpython-311.pyc ADDED
Binary file (10 kB).
 
MagicQuill/brushnet/__pycache__/brushnet.cpython-311.pyc ADDED
Binary file (47.7 kB).
 
MagicQuill/brushnet/__pycache__/brushnet_ca.cpython-311.pyc ADDED
Binary file (48.3 kB).
 
MagicQuill/brushnet/__pycache__/powerpaint_utils.cpython-311.pyc ADDED
Binary file (26.5 kB).
 
MagicQuill/brushnet/__pycache__/unet_2d_blocks.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0b79c2a8ec9c715f51689a6e0aa557ed2de8fd797159ad7615a58c465255cca5
+ size 122668
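The three added lines are a standard Git LFS pointer: the actual .pyc payload lives in LFS storage, addressed by its SHA-256 object id, with the size recorded in bytes. A minimal, illustrative parser for such a pointer file (assuming the working tree still holds the pointer text rather than the smudged binary):

from pathlib import Path

def read_lfs_pointer(path: str) -> dict:
    # Each pointer line is "<key> <value>", e.g. "oid sha256:..." or "size 122668".
    fields = {}
    for line in Path(path).read_text().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields

ptr = read_lfs_pointer("MagicQuill/brushnet/__pycache__/unet_2d_blocks.cpython-311.pyc")
print(ptr["oid"], int(ptr["size"]))  # sha256:0b79c2a8... 122668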
MagicQuill/brushnet/__pycache__/unet_2d_condition.cpython-311.pyc ADDED
Binary file (65.9 kB).
 
MagicQuill/brushnet/brushnet.json ADDED
@@ -0,0 +1,58 @@
+ {
+   "_class_name": "BrushNetModel",
+   "_diffusers_version": "0.27.0.dev0",
+   "_name_or_path": "runs/logs/brushnet_randommask/checkpoint-100000",
+   "act_fn": "silu",
+   "addition_embed_type": null,
+   "addition_embed_type_num_heads": 64,
+   "addition_time_embed_dim": null,
+   "attention_head_dim": 8,
+   "block_out_channels": [
+     320,
+     640,
+     1280,
+     1280
+   ],
+   "brushnet_conditioning_channel_order": "rgb",
+   "class_embed_type": null,
+   "conditioning_channels": 5,
+   "conditioning_embedding_out_channels": [
+     16,
+     32,
+     96,
+     256
+   ],
+   "cross_attention_dim": 768,
+   "down_block_types": [
+     "DownBlock2D",
+     "DownBlock2D",
+     "DownBlock2D",
+     "DownBlock2D"
+   ],
+   "downsample_padding": 1,
+   "encoder_hid_dim": null,
+   "encoder_hid_dim_type": null,
+   "flip_sin_to_cos": true,
+   "freq_shift": 0,
+   "global_pool_conditions": false,
+   "in_channels": 4,
+   "layers_per_block": 2,
+   "mid_block_scale_factor": 1,
+   "mid_block_type": "MidBlock2D",
+   "norm_eps": 1e-05,
+   "norm_num_groups": 32,
+   "num_attention_heads": null,
+   "num_class_embeds": null,
+   "only_cross_attention": false,
+   "projection_class_embeddings_input_dim": null,
+   "resnet_time_scale_shift": "default",
+   "transformer_layers_per_block": 1,
+   "up_block_types": [
+     "UpBlock2D",
+     "UpBlock2D",
+     "UpBlock2D",
+     "UpBlock2D"
+   ],
+   "upcast_attention": false,
+   "use_linear_projection": false
+ }
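brushnet.json is a diffusers-style config for the BrushNetModel added below: four attention-free DownBlock2D/UpBlock2D stages, a MidBlock2D middle block, 5 conditioning channels (masked-image latents plus the mask), and a cross_attention_dim of 768 (the SD 1.5 text-embedding width). A minimal sketch of building the model from this file, assuming the package layout added in this commit is importable (not an official loading path):

import json

# Hypothetical import path based on the files added in this commit.
from MagicQuill.brushnet.brushnet import BrushNetModel

with open("MagicQuill/brushnet/brushnet.json") as f:
    config = json.load(f)

# ConfigMixin.from_config skips bookkeeping keys such as "_class_name".
model = BrushNetModel.from_config(config)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")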
MagicQuill/brushnet/brushnet.py ADDED
@@ -0,0 +1,949 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.utils import BaseOutput, logging
10
+ from diffusers.models.attention_processor import (
11
+ ADDED_KV_ATTENTION_PROCESSORS,
12
+ CROSS_ATTENTION_PROCESSORS,
13
+ AttentionProcessor,
14
+ AttnAddedKVProcessor,
15
+ AttnProcessor,
16
+ )
17
+ from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
18
+ from diffusers.models.modeling_utils import ModelMixin
19
+
20
+ from .unet_2d_blocks import (
21
+ CrossAttnDownBlock2D,
22
+ DownBlock2D,
23
+ UNetMidBlock2D,
24
+ UNetMidBlock2DCrossAttn,
25
+ get_down_block,
26
+ get_mid_block,
27
+ get_up_block,
28
+ MidBlock2D
29
+ )
30
+
31
+ from .unet_2d_condition import UNet2DConditionModel
32
+
33
+
34
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
+
36
+
37
+ @dataclass
38
+ class BrushNetOutput(BaseOutput):
39
+ """
40
+ The output of [`BrushNetModel`].
41
+
42
+ Args:
43
+ up_block_res_samples (`tuple[torch.Tensor]`):
44
+ A tuple of upsample activations at different resolutions for each upsampling block. Each tensor should
45
+ be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
46
+ used to condition the original UNet's upsampling activations.
47
+ down_block_res_samples (`tuple[torch.Tensor]`):
48
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
49
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
50
+ used to condition the original UNet's downsampling activations.
51
+ mid_block_res_sample (`torch.Tensor`):
52
+ The activation of the middle block (the lowest sample resolution). The tensor should be of shape
53
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
54
+ Output can be used to condition the original UNet's middle block activation.
55
+ """
56
+
57
+ up_block_res_samples: Tuple[torch.Tensor]
58
+ down_block_res_samples: Tuple[torch.Tensor]
59
+ mid_block_res_sample: torch.Tensor
60
+
61
+
62
+ class BrushNetModel(ModelMixin, ConfigMixin):
63
+ """
64
+ A BrushNet model.
65
+
66
+ Args:
67
+ in_channels (`int`, defaults to 4):
68
+ The number of channels in the input sample.
69
+ flip_sin_to_cos (`bool`, defaults to `True`):
70
+ Whether to flip the sin to cos in the time embedding.
71
+ freq_shift (`int`, defaults to 0):
72
+ The frequency shift to apply to the time embedding.
73
+ down_block_types (`tuple[str]`, defaults to `("DownBlock2D", "DownBlock2D", "DownBlock2D", "DownBlock2D")`):
74
+ The tuple of downsample blocks to use.
75
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
76
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
77
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
78
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D")`):
79
+ The tuple of upsample blocks to use.
80
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
81
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
82
+ The tuple of output channels for each block.
83
+ layers_per_block (`int`, defaults to 2):
84
+ The number of layers per block.
85
+ downsample_padding (`int`, defaults to 1):
86
+ The padding to use for the downsampling convolution.
87
+ mid_block_scale_factor (`float`, defaults to 1):
88
+ The scale factor to use for the mid block.
89
+ act_fn (`str`, defaults to "silu"):
90
+ The activation function to use.
91
+ norm_num_groups (`int`, *optional*, defaults to 32):
92
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
93
+ in post-processing.
94
+ norm_eps (`float`, defaults to 1e-5):
95
+ The epsilon to use for the normalization.
96
+ cross_attention_dim (`int`, defaults to 1280):
97
+ The dimension of the cross attention features.
98
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
99
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
100
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
101
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
102
+ encoder_hid_dim (`int`, *optional*, defaults to None):
103
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
104
+ dimension to `cross_attention_dim`.
105
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
106
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
107
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
108
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
109
+ The dimension of the attention heads.
110
+ use_linear_projection (`bool`, defaults to `False`):
111
+ class_embed_type (`str`, *optional*, defaults to `None`):
112
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
113
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
114
+ addition_embed_type (`str`, *optional*, defaults to `None`):
115
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
116
+ "text". "text" will use the `TextTimeEmbedding` layer.
117
+ num_class_embeds (`int`, *optional*, defaults to 0):
118
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
119
+ class conditioning with `class_embed_type` equal to `None`.
120
+ upcast_attention (`bool`, defaults to `False`):
121
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
122
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
123
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
124
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
125
+ `class_embed_type="projection"`.
126
+ brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
127
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
128
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
129
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
130
+ global_pool_conditions (`bool`, defaults to `False`):
131
+ TODO(Patrick) - unused parameter.
132
+ addition_embed_type_num_heads (`int`, defaults to 64):
133
+ The number of heads to use for the `TextTimeEmbedding` layer.
134
+ """
135
+
136
+ _supports_gradient_checkpointing = True
137
+
138
+ @register_to_config
139
+ def __init__(
140
+ self,
141
+ in_channels: int = 4,
142
+ conditioning_channels: int = 5,
143
+ flip_sin_to_cos: bool = True,
144
+ freq_shift: int = 0,
145
+ down_block_types: Tuple[str, ...] = (
146
+ "DownBlock2D",
147
+ "DownBlock2D",
148
+ "DownBlock2D",
149
+ "DownBlock2D",
150
+ ),
151
+ mid_block_type: Optional[str] = "UNetMidBlock2D",
152
+ up_block_types: Tuple[str, ...] = (
153
+ "UpBlock2D",
154
+ "UpBlock2D",
155
+ "UpBlock2D",
156
+ "UpBlock2D",
157
+ ),
158
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
159
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
160
+ layers_per_block: int = 2,
161
+ downsample_padding: int = 1,
162
+ mid_block_scale_factor: float = 1,
163
+ act_fn: str = "silu",
164
+ norm_num_groups: Optional[int] = 32,
165
+ norm_eps: float = 1e-5,
166
+ cross_attention_dim: int = 1280,
167
+ transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
168
+ encoder_hid_dim: Optional[int] = None,
169
+ encoder_hid_dim_type: Optional[str] = None,
170
+ attention_head_dim: Union[int, Tuple[int, ...]] = 8,
171
+ num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
172
+ use_linear_projection: bool = False,
173
+ class_embed_type: Optional[str] = None,
174
+ addition_embed_type: Optional[str] = None,
175
+ addition_time_embed_dim: Optional[int] = None,
176
+ num_class_embeds: Optional[int] = None,
177
+ upcast_attention: bool = False,
178
+ resnet_time_scale_shift: str = "default",
179
+ projection_class_embeddings_input_dim: Optional[int] = None,
180
+ brushnet_conditioning_channel_order: str = "rgb",
181
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
182
+ global_pool_conditions: bool = False,
183
+ addition_embed_type_num_heads: int = 64,
184
+ ):
185
+ super().__init__()
186
+
187
+ # If `num_attention_heads` is not defined (which is the case for most models)
188
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
189
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
190
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
191
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
192
+ # which is why we correct for the naming here.
193
+ num_attention_heads = num_attention_heads or attention_head_dim
194
+
195
+ # Check inputs
196
+ if len(down_block_types) != len(up_block_types):
197
+ raise ValueError(
198
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
199
+ )
200
+
201
+ if len(block_out_channels) != len(down_block_types):
202
+ raise ValueError(
203
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
204
+ )
205
+
206
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
207
+ raise ValueError(
208
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
209
+ )
210
+
211
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
212
+ raise ValueError(
213
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
214
+ )
215
+
216
+ if isinstance(transformer_layers_per_block, int):
217
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
218
+
219
+ # input
220
+ conv_in_kernel = 3
221
+ conv_in_padding = (conv_in_kernel - 1) // 2
222
+ self.conv_in_condition = nn.Conv2d(
223
+ in_channels+conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
224
+ )
225
+
226
+ # time
227
+ time_embed_dim = block_out_channels[0] * 4
228
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
229
+ timestep_input_dim = block_out_channels[0]
230
+ self.time_embedding = TimestepEmbedding(
231
+ timestep_input_dim,
232
+ time_embed_dim,
233
+ act_fn=act_fn,
234
+ )
235
+
236
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
237
+ encoder_hid_dim_type = "text_proj"
238
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
239
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
240
+
241
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
242
+ raise ValueError(
243
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
244
+ )
245
+
246
+ if encoder_hid_dim_type == "text_proj":
247
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
248
+ elif encoder_hid_dim_type == "text_image_proj":
249
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
250
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
251
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
252
+ self.encoder_hid_proj = TextImageProjection(
253
+ text_embed_dim=encoder_hid_dim,
254
+ image_embed_dim=cross_attention_dim,
255
+ cross_attention_dim=cross_attention_dim,
256
+ )
257
+
258
+ elif encoder_hid_dim_type is not None:
259
+ raise ValueError(
260
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
261
+ )
262
+ else:
263
+ self.encoder_hid_proj = None
264
+
265
+ # class embedding
266
+ if class_embed_type is None and num_class_embeds is not None:
267
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
268
+ elif class_embed_type == "timestep":
269
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
270
+ elif class_embed_type == "identity":
271
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
272
+ elif class_embed_type == "projection":
273
+ if projection_class_embeddings_input_dim is None:
274
+ raise ValueError(
275
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
276
+ )
277
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
278
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
279
+ # 2. it projects from an arbitrary input dimension.
280
+ #
281
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
282
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
283
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
284
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
285
+ else:
286
+ self.class_embedding = None
287
+
288
+ if addition_embed_type == "text":
289
+ if encoder_hid_dim is not None:
290
+ text_time_embedding_from_dim = encoder_hid_dim
291
+ else:
292
+ text_time_embedding_from_dim = cross_attention_dim
293
+
294
+ self.add_embedding = TextTimeEmbedding(
295
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
296
+ )
297
+ elif addition_embed_type == "text_image":
298
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
299
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
300
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
301
+ self.add_embedding = TextImageTimeEmbedding(
302
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
303
+ )
304
+ elif addition_embed_type == "text_time":
305
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
306
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
307
+
308
+ elif addition_embed_type is not None:
309
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
310
+
311
+ self.down_blocks = nn.ModuleList([])
312
+ self.brushnet_down_blocks = nn.ModuleList([])
313
+
314
+ if isinstance(only_cross_attention, bool):
315
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
316
+
317
+ if isinstance(attention_head_dim, int):
318
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
319
+
320
+ if isinstance(num_attention_heads, int):
321
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
322
+
323
+ # down
324
+ output_channel = block_out_channels[0]
325
+
326
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
327
+ brushnet_block = zero_module(brushnet_block)
328
+ self.brushnet_down_blocks.append(brushnet_block)
329
+
330
+ for i, down_block_type in enumerate(down_block_types):
331
+ input_channel = output_channel
332
+ output_channel = block_out_channels[i]
333
+ is_final_block = i == len(block_out_channels) - 1
334
+
335
+ down_block = get_down_block(
336
+ down_block_type,
337
+ num_layers=layers_per_block,
338
+ transformer_layers_per_block=transformer_layers_per_block[i],
339
+ in_channels=input_channel,
340
+ out_channels=output_channel,
341
+ temb_channels=time_embed_dim,
342
+ add_downsample=not is_final_block,
343
+ resnet_eps=norm_eps,
344
+ resnet_act_fn=act_fn,
345
+ resnet_groups=norm_num_groups,
346
+ cross_attention_dim=cross_attention_dim,
347
+ num_attention_heads=num_attention_heads[i],
348
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
349
+ downsample_padding=downsample_padding,
350
+ use_linear_projection=use_linear_projection,
351
+ only_cross_attention=only_cross_attention[i],
352
+ upcast_attention=upcast_attention,
353
+ resnet_time_scale_shift=resnet_time_scale_shift,
354
+ )
355
+ self.down_blocks.append(down_block)
356
+
357
+ for _ in range(layers_per_block):
358
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
359
+ brushnet_block = zero_module(brushnet_block)
360
+ self.brushnet_down_blocks.append(brushnet_block)
361
+
362
+ if not is_final_block:
363
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
364
+ brushnet_block = zero_module(brushnet_block)
365
+ self.brushnet_down_blocks.append(brushnet_block)
366
+
367
+ # mid
368
+ mid_block_channel = block_out_channels[-1]
369
+
370
+ brushnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
371
+ brushnet_block = zero_module(brushnet_block)
372
+ self.brushnet_mid_block = brushnet_block
373
+
374
+ self.mid_block = get_mid_block(
375
+ mid_block_type,
376
+ transformer_layers_per_block=transformer_layers_per_block[-1],
377
+ in_channels=mid_block_channel,
378
+ temb_channels=time_embed_dim,
379
+ resnet_eps=norm_eps,
380
+ resnet_act_fn=act_fn,
381
+ output_scale_factor=mid_block_scale_factor,
382
+ resnet_time_scale_shift=resnet_time_scale_shift,
383
+ cross_attention_dim=cross_attention_dim,
384
+ num_attention_heads=num_attention_heads[-1],
385
+ resnet_groups=norm_num_groups,
386
+ use_linear_projection=use_linear_projection,
387
+ upcast_attention=upcast_attention,
388
+ )
389
+
390
+ # count how many layers upsample the images
391
+ self.num_upsamplers = 0
392
+
393
+ # up
394
+ reversed_block_out_channels = list(reversed(block_out_channels))
395
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
396
+ reversed_transformer_layers_per_block = (list(reversed(transformer_layers_per_block)))
397
+ only_cross_attention = list(reversed(only_cross_attention))
398
+
399
+ output_channel = reversed_block_out_channels[0]
400
+
401
+ self.up_blocks = nn.ModuleList([])
402
+ self.brushnet_up_blocks = nn.ModuleList([])
403
+
404
+ for i, up_block_type in enumerate(up_block_types):
405
+ is_final_block = i == len(block_out_channels) - 1
406
+
407
+ prev_output_channel = output_channel
408
+ output_channel = reversed_block_out_channels[i]
409
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
410
+
411
+ # add upsample block for all BUT final layer
412
+ if not is_final_block:
413
+ add_upsample = True
414
+ self.num_upsamplers += 1
415
+ else:
416
+ add_upsample = False
417
+
418
+ up_block = get_up_block(
419
+ up_block_type,
420
+ num_layers=layers_per_block+1,
421
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
422
+ in_channels=input_channel,
423
+ out_channels=output_channel,
424
+ prev_output_channel=prev_output_channel,
425
+ temb_channels=time_embed_dim,
426
+ add_upsample=add_upsample,
427
+ resnet_eps=norm_eps,
428
+ resnet_act_fn=act_fn,
429
+ resolution_idx=i,
430
+ resnet_groups=norm_num_groups,
431
+ cross_attention_dim=cross_attention_dim,
432
+ num_attention_heads=reversed_num_attention_heads[i],
433
+ use_linear_projection=use_linear_projection,
434
+ only_cross_attention=only_cross_attention[i],
435
+ upcast_attention=upcast_attention,
436
+ resnet_time_scale_shift=resnet_time_scale_shift,
437
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
438
+ )
439
+ self.up_blocks.append(up_block)
440
+ prev_output_channel = output_channel
441
+
442
+ for _ in range(layers_per_block+1):
443
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
444
+ brushnet_block = zero_module(brushnet_block)
445
+ self.brushnet_up_blocks.append(brushnet_block)
446
+
447
+ if not is_final_block:
448
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
449
+ brushnet_block = zero_module(brushnet_block)
450
+ self.brushnet_up_blocks.append(brushnet_block)
451
+
452
+
453
+ @classmethod
454
+ def from_unet(
455
+ cls,
456
+ unet: UNet2DConditionModel,
457
+ brushnet_conditioning_channel_order: str = "rgb",
458
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
459
+ load_weights_from_unet: bool = True,
460
+ conditioning_channels: int = 5,
461
+ ):
462
+ r"""
463
+ Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
464
+
465
+ Parameters:
466
+ unet (`UNet2DConditionModel`):
467
+ The UNet model weights to copy to the [`BrushNetModel`]. All configuration options are also copied
468
+ where applicable.
469
+ """
470
+ transformer_layers_per_block = (
471
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
472
+ )
473
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
474
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
475
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
476
+ addition_time_embed_dim = (
477
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
478
+ )
479
+
480
+ brushnet = cls(
481
+ in_channels=unet.config.in_channels,
482
+ conditioning_channels=conditioning_channels,
483
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
484
+ freq_shift=unet.config.freq_shift,
485
+ down_block_types=["DownBlock2D" for block_name in unet.config.down_block_types],
486
+ mid_block_type='MidBlock2D',
487
+ up_block_types=["UpBlock2D" for block_name in unet.config.down_block_types],
488
+ only_cross_attention=unet.config.only_cross_attention,
489
+ block_out_channels=unet.config.block_out_channels,
490
+ layers_per_block=unet.config.layers_per_block,
491
+ downsample_padding=unet.config.downsample_padding,
492
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
493
+ act_fn=unet.config.act_fn,
494
+ norm_num_groups=unet.config.norm_num_groups,
495
+ norm_eps=unet.config.norm_eps,
496
+ cross_attention_dim=unet.config.cross_attention_dim,
497
+ transformer_layers_per_block=transformer_layers_per_block,
498
+ encoder_hid_dim=encoder_hid_dim,
499
+ encoder_hid_dim_type=encoder_hid_dim_type,
500
+ attention_head_dim=unet.config.attention_head_dim,
501
+ num_attention_heads=unet.config.num_attention_heads,
502
+ use_linear_projection=unet.config.use_linear_projection,
503
+ class_embed_type=unet.config.class_embed_type,
504
+ addition_embed_type=addition_embed_type,
505
+ addition_time_embed_dim=addition_time_embed_dim,
506
+ num_class_embeds=unet.config.num_class_embeds,
507
+ upcast_attention=unet.config.upcast_attention,
508
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
509
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
510
+ brushnet_conditioning_channel_order=brushnet_conditioning_channel_order,
511
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
512
+ )
513
+
514
+ if load_weights_from_unet:
515
+ conv_in_condition_weight=torch.zeros_like(brushnet.conv_in_condition.weight)
516
+ conv_in_condition_weight[:,:4,...]=unet.conv_in.weight
517
+ conv_in_condition_weight[:,4:8,...]=unet.conv_in.weight
518
+ brushnet.conv_in_condition.weight=torch.nn.Parameter(conv_in_condition_weight)
519
+ brushnet.conv_in_condition.bias=unet.conv_in.bias
520
+
521
+ brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
522
+ brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
523
+
524
+ if brushnet.class_embedding:
525
+ brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
526
+
527
+ brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(),strict=False)
528
+ brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(),strict=False)
529
+ brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(),strict=False)
530
+
531
+ return brushnet
532
+
533
+ @property
534
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
535
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
536
+ r"""
537
+ Returns:
538
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
539
+ indexed by its weight name.
540
+ """
541
+ # set recursively
542
+ processors = {}
543
+
544
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
545
+ if hasattr(module, "get_processor"):
546
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
547
+
548
+ for sub_name, child in module.named_children():
549
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
550
+
551
+ return processors
552
+
553
+ for name, module in self.named_children():
554
+ fn_recursive_add_processors(name, module, processors)
555
+
556
+ return processors
557
+
558
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
559
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
560
+ r"""
561
+ Sets the attention processor to use to compute attention.
562
+
563
+ Parameters:
564
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
565
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
566
+ for **all** `Attention` layers.
567
+
568
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
569
+ processor. This is strongly recommended when setting trainable attention processors.
570
+
571
+ """
572
+ count = len(self.attn_processors.keys())
573
+
574
+ if isinstance(processor, dict) and len(processor) != count:
575
+ raise ValueError(
576
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
577
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
578
+ )
579
+
580
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
581
+ if hasattr(module, "set_processor"):
582
+ if not isinstance(processor, dict):
583
+ module.set_processor(processor)
584
+ else:
585
+ module.set_processor(processor.pop(f"{name}.processor"))
586
+
587
+ for sub_name, child in module.named_children():
588
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
589
+
590
+ for name, module in self.named_children():
591
+ fn_recursive_attn_processor(name, module, processor)
592
+
593
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
594
+ def set_default_attn_processor(self):
595
+ """
596
+ Disables custom attention processors and sets the default attention implementation.
597
+ """
598
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
599
+ processor = AttnAddedKVProcessor()
600
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
601
+ processor = AttnProcessor()
602
+ else:
603
+ raise ValueError(
604
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
605
+ )
606
+
607
+ self.set_attn_processor(processor)
608
+
609
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
610
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
611
+ r"""
612
+ Enable sliced attention computation.
613
+
614
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
615
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
616
+
617
+ Args:
618
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
619
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
620
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
621
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
622
+ must be a multiple of `slice_size`.
623
+ """
624
+ sliceable_head_dims = []
625
+
626
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
627
+ if hasattr(module, "set_attention_slice"):
628
+ sliceable_head_dims.append(module.sliceable_head_dim)
629
+
630
+ for child in module.children():
631
+ fn_recursive_retrieve_sliceable_dims(child)
632
+
633
+ # retrieve number of attention layers
634
+ for module in self.children():
635
+ fn_recursive_retrieve_sliceable_dims(module)
636
+
637
+ num_sliceable_layers = len(sliceable_head_dims)
638
+
639
+ if slice_size == "auto":
640
+ # half the attention head size is usually a good trade-off between
641
+ # speed and memory
642
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
643
+ elif slice_size == "max":
644
+ # make smallest slice possible
645
+ slice_size = num_sliceable_layers * [1]
646
+
647
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
648
+
649
+ if len(slice_size) != len(sliceable_head_dims):
650
+ raise ValueError(
651
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
652
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
653
+ )
654
+
655
+ for i in range(len(slice_size)):
656
+ size = slice_size[i]
657
+ dim = sliceable_head_dims[i]
658
+ if size is not None and size > dim:
659
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
660
+
661
+ # Recursively walk through all the children.
662
+ # Any children which exposes the set_attention_slice method
663
+ # gets the message
664
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
665
+ if hasattr(module, "set_attention_slice"):
666
+ module.set_attention_slice(slice_size.pop())
667
+
668
+ for child in module.children():
669
+ fn_recursive_set_attention_slice(child, slice_size)
670
+
671
+ reversed_slice_size = list(reversed(slice_size))
672
+ for module in self.children():
673
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
674
+
675
+ def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
676
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
677
+ module.gradient_checkpointing = value
678
+
679
+ def forward(
680
+ self,
681
+ sample: torch.FloatTensor,
682
+ encoder_hidden_states: torch.Tensor,
683
+ brushnet_cond: torch.FloatTensor,
684
+ timestep = None,
685
+ time_emb = None,
686
+ conditioning_scale: float = 1.0,
687
+ class_labels: Optional[torch.Tensor] = None,
688
+ timestep_cond: Optional[torch.Tensor] = None,
689
+ attention_mask: Optional[torch.Tensor] = None,
690
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
691
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
692
+ guess_mode: bool = False,
693
+ return_dict: bool = True,
694
+ debug = False,
695
+ ) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
696
+ """
697
+ The [`BrushNetModel`] forward method.
698
+
699
+ Args:
700
+ sample (`torch.FloatTensor`):
701
+ The noisy input tensor.
702
+ timestep (`Union[torch.Tensor, float, int]`):
703
+ The number of timesteps to denoise an input.
704
+ encoder_hidden_states (`torch.Tensor`):
705
+ The encoder hidden states.
706
+ brushnet_cond (`torch.FloatTensor`):
707
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
708
+ conditioning_scale (`float`, defaults to `1.0`):
709
+ The scale factor for BrushNet outputs.
710
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
711
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
712
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
713
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
714
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
715
+ embeddings.
716
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
717
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
718
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
719
+ negative values to the attention scores corresponding to "discard" tokens.
720
+ added_cond_kwargs (`dict`):
721
+ Additional conditions for the Stable Diffusion XL UNet.
722
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
723
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
724
+ guess_mode (`bool`, defaults to `False`):
725
+ In this mode, the BrushNet encoder tries its best to recognize the content of the input even if
726
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
727
+ return_dict (`bool`, defaults to `True`):
728
+ Whether or not to return a [`~models.brushnet.BrushNetOutput`] instead of a plain tuple.
729
+
730
+ Returns:
731
+ [`~models.brushnet.BrushNetOutput`] **or** `tuple`:
732
+ If `return_dict` is `True`, a [`~models.brushnet.BrushNetOutput`] is returned, otherwise a tuple is
733
+ returned where the first element is the sample tensor.
734
+ """
735
+
736
+ # check channel order
737
+ channel_order = self.config.brushnet_conditioning_channel_order
738
+
739
+ if channel_order == "rgb":
740
+ # in rgb order by default
741
+ ...
742
+ elif channel_order == "bgr":
743
+ brushnet_cond = torch.flip(brushnet_cond, dims=[1])
744
+ else:
745
+ raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
746
+
747
+ # prepare attention_mask
748
+ if attention_mask is not None:
749
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
750
+ attention_mask = attention_mask.unsqueeze(1)
751
+
752
+ if timestep is None and time_emb is None:
753
+ raise ValueError(f"`timestep` and `emb` are both None")
754
+
755
+ #print("BN: sample.device", sample.device)
756
+ #print("BN: TE.device", self.time_embedding.linear_1.weight.device)
757
+
758
+ if timestep is not None:
759
+ # 1. time
760
+ timesteps = timestep
761
+ if not torch.is_tensor(timesteps):
762
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
763
+ # This would be a good case for the `match` statement (Python 3.10+)
764
+ is_mps = sample.device.type == "mps"
765
+ if isinstance(timestep, float):
766
+ dtype = torch.float32 if is_mps else torch.float64
767
+ else:
768
+ dtype = torch.int32 if is_mps else torch.int64
769
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
770
+ elif len(timesteps.shape) == 0:
771
+ timesteps = timesteps[None].to(sample.device)
772
+
773
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
774
+ timesteps = timesteps.expand(sample.shape[0])
775
+
776
+ t_emb = self.time_proj(timesteps)
777
+
778
+ # timesteps does not contain any weights and will always return f32 tensors
779
+ # but time_embedding might actually be running in fp16. so we need to cast here.
780
+ # there might be better ways to encapsulate this.
781
+ t_emb = t_emb.to(dtype=sample.dtype)
782
+
783
+ #print("t_emb.device =",t_emb.device)
784
+
785
+ emb = self.time_embedding(t_emb, timestep_cond)
786
+ aug_emb = None
787
+
788
+ #print('emb.shape', emb.shape)
789
+
790
+ if self.class_embedding is not None:
791
+ if class_labels is None:
792
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
793
+
794
+ if self.config.class_embed_type == "timestep":
795
+ class_labels = self.time_proj(class_labels)
796
+
797
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
798
+ emb = emb + class_emb
799
+
800
+ if self.config.addition_embed_type is not None:
801
+ if self.config.addition_embed_type == "text":
802
+ aug_emb = self.add_embedding(encoder_hidden_states)
803
+
804
+ elif self.config.addition_embed_type == "text_time":
805
+ if "text_embeds" not in added_cond_kwargs:
806
+ raise ValueError(
807
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
808
+ )
809
+ text_embeds = added_cond_kwargs.get("text_embeds")
810
+ if "time_ids" not in added_cond_kwargs:
811
+ raise ValueError(
812
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
813
+ )
814
+ time_ids = added_cond_kwargs.get("time_ids")
815
+ time_embeds = self.add_time_proj(time_ids.flatten())
816
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
817
+
818
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
819
+ add_embeds = add_embeds.to(emb.dtype)
820
+ aug_emb = self.add_embedding(add_embeds)
821
+
822
+ #print('text_embeds', text_embeds.shape, 'time_ids', time_ids.shape, 'time_embeds', time_embeds.shape, 'add__embeds', add_embeds.shape, 'aug_emb', aug_emb.shape)
823
+
824
+ emb = emb + aug_emb if aug_emb is not None else emb
825
+ else:
826
+ emb = time_emb
827
+
828
+ # 2. pre-process
829
+
830
+ brushnet_cond=torch.concat([sample,brushnet_cond],1)
831
+ sample = self.conv_in_condition(brushnet_cond)
832
+
833
+ # 3. down
834
+ down_block_res_samples = (sample,)
835
+ for downsample_block in self.down_blocks:
836
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
837
+ sample, res_samples = downsample_block(
838
+ hidden_states=sample,
839
+ temb=emb,
840
+ encoder_hidden_states=encoder_hidden_states,
841
+ attention_mask=attention_mask,
842
+ cross_attention_kwargs=cross_attention_kwargs,
843
+ )
844
+ else:
845
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
846
+
847
+ down_block_res_samples += res_samples
848
+
849
+ # 4. PaintingNet down blocks
850
+ brushnet_down_block_res_samples = ()
851
+ for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
852
+ down_block_res_sample = brushnet_down_block(down_block_res_sample)
853
+ brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
854
+
855
+
856
+ # 5. mid
857
+ if self.mid_block is not None:
858
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
859
+ sample = self.mid_block(
860
+ sample,
861
+ emb,
862
+ encoder_hidden_states=encoder_hidden_states,
863
+ attention_mask=attention_mask,
864
+ cross_attention_kwargs=cross_attention_kwargs,
865
+ )
866
+ else:
867
+ sample = self.mid_block(sample, emb)
868
+
869
+ # 6. BrushNet mid blocks
870
+ brushnet_mid_block_res_sample = self.brushnet_mid_block(sample)
871
+
872
+ # 7. up
873
+ up_block_res_samples = ()
874
+ for i, upsample_block in enumerate(self.up_blocks):
875
+ is_final_block = i == len(self.up_blocks) - 1
876
+
877
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
878
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
879
+
880
+ # if we have not reached the final block and need to forward the
881
+ # upsample size, we do it here
882
+ if not is_final_block:
883
+ upsample_size = down_block_res_samples[-1].shape[2:]
884
+
885
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
886
+ sample, up_res_samples = upsample_block(
887
+ hidden_states=sample,
888
+ temb=emb,
889
+ res_hidden_states_tuple=res_samples,
890
+ encoder_hidden_states=encoder_hidden_states,
891
+ cross_attention_kwargs=cross_attention_kwargs,
892
+ upsample_size=upsample_size,
893
+ attention_mask=attention_mask,
894
+ return_res_samples=True
895
+ )
896
+ else:
897
+ sample, up_res_samples = upsample_block(
898
+ hidden_states=sample,
899
+ temb=emb,
900
+ res_hidden_states_tuple=res_samples,
901
+ upsample_size=upsample_size,
902
+ return_res_samples=True
903
+ )
904
+
905
+ up_block_res_samples += up_res_samples
906
+
907
+ # 8. BrushNet up blocks
908
+ brushnet_up_block_res_samples = ()
909
+ for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
910
+ up_block_res_sample = brushnet_up_block(up_block_res_sample)
911
+ brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
912
+
913
+ # 9. scaling
914
+ if guess_mode and not self.config.global_pool_conditions:
915
+ scales = torch.logspace(-1, 0, len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples), device=sample.device) # 0.1 to 1.0
916
+ scales = scales * conditioning_scale
917
+
918
+ brushnet_down_block_res_samples = [sample * scale for sample, scale in zip(brushnet_down_block_res_samples, scales[:len(brushnet_down_block_res_samples)])]
919
+ brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
920
+ brushnet_up_block_res_samples = [sample * scale for sample, scale in zip(brushnet_up_block_res_samples, scales[len(brushnet_down_block_res_samples)+1:])]
921
+ else:
922
+ brushnet_down_block_res_samples = [sample * conditioning_scale for sample in brushnet_down_block_res_samples]
923
+ brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
924
+ brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
925
+
926
+
927
+ if self.config.global_pool_conditions:
928
+ brushnet_down_block_res_samples = [
929
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
930
+ ]
931
+ brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
932
+ brushnet_up_block_res_samples = [
933
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
934
+ ]
935
+
936
+ if not return_dict:
937
+ return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
938
+
939
+ return BrushNetOutput(
940
+ down_block_res_samples=brushnet_down_block_res_samples,
941
+ mid_block_res_sample=brushnet_mid_block_res_sample,
942
+ up_block_res_samples=brushnet_up_block_res_samples
943
+ )
944
+
945
+
946
+ def zero_module(module):
947
+ for p in module.parameters():
948
+ nn.init.zeros_(p)
949
+ return module
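Taken together, from_unet copies the base UNet's conv_in/time/down/mid/up weights into this attention-free branch, and forward returns zero-initialized 1x1-convolution residuals for the down, mid and up paths that the host UNet can add to its own activations. A hedged usage sketch (the SD 1.5 checkpoint id and the tensor shapes below are illustrative assumptions, not part of this commit):

import torch
from diffusers import UNet2DConditionModel

# Hypothetical import path for the module added above.
from MagicQuill.brushnet.brushnet import BrushNetModel

# Assumed SD 1.5 base weights; any 4-latent-channel UNet2DConditionModel works the same way.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
brushnet = BrushNetModel.from_unet(unet)  # clones conv_in/time/down/mid/up weights, 5 cond. channels

latents = torch.randn(1, 4, 64, 64)   # noisy latent sample
cond = torch.randn(1, 5, 64, 64)      # masked-image latents (4) + mask (1)
text = torch.randn(1, 77, 768)        # CLIP text embeddings (unused by the attention-free blocks)

out = brushnet(latents, encoder_hidden_states=text, brushnet_cond=cond,
               timestep=torch.tensor([999]), conditioning_scale=1.0)
# out.down_block_res_samples, out.mid_block_res_sample and out.up_block_res_samples are the
# per-resolution residuals that get added to the corresponding UNet activations while denoising.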
MagicQuill/brushnet/brushnet_ca.py ADDED
@@ -0,0 +1,983 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.utils import BaseOutput, logging
9
+ from diffusers.models.attention_processor import (
10
+ ADDED_KV_ATTENTION_PROCESSORS,
11
+ CROSS_ATTENTION_PROCESSORS,
12
+ AttentionProcessor,
13
+ AttnAddedKVProcessor,
14
+ AttnProcessor,
15
+ )
16
+ from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
17
+ from diffusers.models.modeling_utils import ModelMixin
18
+
19
+ from .unet_2d_blocks import (
20
+ CrossAttnDownBlock2D,
21
+ DownBlock2D,
22
+ UNetMidBlock2D,
23
+ UNetMidBlock2DCrossAttn,
24
+ get_down_block,
25
+ get_mid_block,
26
+ get_up_block,
27
+ MidBlock2D
28
+ )
29
+
30
+ from .unet_2d_condition import UNet2DConditionModel
31
+
32
+
33
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
34
+
35
+
36
+ @dataclass
37
+ class BrushNetOutput(BaseOutput):
38
+ """
39
+ The output of [`BrushNetModel`].
40
+
41
+ Args:
42
+ up_block_res_samples (`tuple[torch.Tensor]`):
43
+ A tuple of upsample activations at different resolutions for each upsampling block. Each tensor should
44
+ be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
45
+ used to condition the original UNet's upsampling activations.
46
+ down_block_res_samples (`tuple[torch.Tensor]`):
47
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
48
+ be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
49
+ used to condition the original UNet's downsampling activations.
50
+ mid_block_res_sample (`torch.Tensor`):
51
+ The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
52
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
53
+ Output can be used to condition the original UNet's middle block activation.
54
+ """
55
+
56
+ up_block_res_samples: Tuple[torch.Tensor]
57
+ down_block_res_samples: Tuple[torch.Tensor]
58
+ mid_block_res_sample: torch.Tensor
59
+
60
+
61
+ class BrushNetModel(ModelMixin, ConfigMixin):
62
+ """
63
+ A BrushNet model.
64
+
65
+ Args:
66
+ in_channels (`int`, defaults to 4):
67
+ The number of channels in the input sample.
68
+ flip_sin_to_cos (`bool`, defaults to `True`):
69
+ Whether to flip the sin to cos in the time embedding.
70
+ freq_shift (`int`, defaults to 0):
71
+ The frequency shift to apply to the time embedding.
72
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
73
+ The tuple of downsample blocks to use.
74
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
75
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
76
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
77
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
78
+ The tuple of upsample blocks to use.
79
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
80
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
81
+ The tuple of output channels for each block.
82
+ layers_per_block (`int`, defaults to 2):
83
+ The number of layers per block.
84
+ downsample_padding (`int`, defaults to 1):
85
+ The padding to use for the downsampling convolution.
86
+ mid_block_scale_factor (`float`, defaults to 1):
87
+ The scale factor to use for the mid block.
88
+ act_fn (`str`, defaults to "silu"):
89
+ The activation function to use.
90
+ norm_num_groups (`int`, *optional*, defaults to 32):
91
+ The number of groups to use for the normalization. If None, normalization and activation layers are skipped
92
+ in post-processing.
93
+ norm_eps (`float`, defaults to 1e-5):
94
+ The epsilon to use for the normalization.
95
+ cross_attention_dim (`int`, defaults to 1280):
96
+ The dimension of the cross attention features.
97
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
98
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
99
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
100
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
101
+ encoder_hid_dim (`int`, *optional*, defaults to None):
102
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
103
+ dimension to `cross_attention_dim`.
104
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
105
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
106
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
107
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
108
+ The dimension of the attention heads.
109
+ use_linear_projection (`bool`, defaults to `False`):
110
+ class_embed_type (`str`, *optional*, defaults to `None`):
111
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
112
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
113
+ addition_embed_type (`str`, *optional*, defaults to `None`):
114
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
115
+ "text". "text" will use the `TextTimeEmbedding` layer.
116
+ num_class_embeds (`int`, *optional*, defaults to 0):
117
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
118
+ class conditioning with `class_embed_type` equal to `None`.
119
+ upcast_attention (`bool`, defaults to `False`):
120
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
121
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
122
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
123
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
124
+ `class_embed_type="projection"`.
125
+ brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
126
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
127
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
128
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
129
+ global_pool_conditions (`bool`, defaults to `False`):
130
+ TODO(Patrick) - unused parameter.
131
+ addition_embed_type_num_heads (`int`, defaults to 64):
132
+ The number of heads to use for the `TextTimeEmbedding` layer.
133
+ """
134
+
135
+ _supports_gradient_checkpointing = True
136
+
137
+ @register_to_config
138
+ def __init__(
139
+ self,
140
+ in_channels: int = 4,
141
+ conditioning_channels: int = 5,
142
+ flip_sin_to_cos: bool = True,
143
+ freq_shift: int = 0,
144
+ down_block_types: Tuple[str, ...] = (
145
+ "CrossAttnDownBlock2D",
146
+ "CrossAttnDownBlock2D",
147
+ "CrossAttnDownBlock2D",
148
+ "DownBlock2D",
149
+ ),
150
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
151
+ up_block_types: Tuple[str, ...] = (
152
+ "UpBlock2D",
153
+ "CrossAttnUpBlock2D",
154
+ "CrossAttnUpBlock2D",
155
+ "CrossAttnUpBlock2D",
156
+ ),
157
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
158
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
159
+ layers_per_block: int = 2,
160
+ downsample_padding: int = 1,
161
+ mid_block_scale_factor: float = 1,
162
+ act_fn: str = "silu",
163
+ norm_num_groups: Optional[int] = 32,
164
+ norm_eps: float = 1e-5,
165
+ cross_attention_dim: int = 1280,
166
+ transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
167
+ encoder_hid_dim: Optional[int] = None,
168
+ encoder_hid_dim_type: Optional[str] = None,
169
+ attention_head_dim: Union[int, Tuple[int, ...]] = 8,
170
+ num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
171
+ use_linear_projection: bool = False,
172
+ class_embed_type: Optional[str] = None,
173
+ addition_embed_type: Optional[str] = None,
174
+ addition_time_embed_dim: Optional[int] = None,
175
+ num_class_embeds: Optional[int] = None,
176
+ upcast_attention: bool = False,
177
+ resnet_time_scale_shift: str = "default",
178
+ projection_class_embeddings_input_dim: Optional[int] = None,
179
+ brushnet_conditioning_channel_order: str = "rgb",
180
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
181
+ global_pool_conditions: bool = False,
182
+ addition_embed_type_num_heads: int = 64,
183
+ ):
184
+ super().__init__()
185
+
186
+ # If `num_attention_heads` is not defined (which is the case for most models)
187
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
188
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
189
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
190
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
191
+ # which is why we correct for the naming here.
192
+ num_attention_heads = num_attention_heads or attention_head_dim
193
+
194
+ # Check inputs
195
+ if len(down_block_types) != len(up_block_types):
196
+ raise ValueError(
197
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
198
+ )
199
+
200
+ if len(block_out_channels) != len(down_block_types):
201
+ raise ValueError(
202
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
203
+ )
204
+
205
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
206
+ raise ValueError(
207
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
208
+ )
209
+
210
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
211
+ raise ValueError(
212
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
213
+ )
214
+
215
+ if isinstance(transformer_layers_per_block, int):
216
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
217
+
218
+ # input
219
+ conv_in_kernel = 3
220
+ conv_in_padding = (conv_in_kernel - 1) // 2
221
+ self.conv_in_condition = nn.Conv2d(
222
+ in_channels + conditioning_channels,
223
+ block_out_channels[0],
224
+ kernel_size=conv_in_kernel,
225
+ padding=conv_in_padding,
226
+ )
227
+
228
+ # time
229
+ time_embed_dim = block_out_channels[0] * 4
230
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
231
+ timestep_input_dim = block_out_channels[0]
232
+ self.time_embedding = TimestepEmbedding(
233
+ timestep_input_dim,
234
+ time_embed_dim,
235
+ act_fn=act_fn,
236
+ )
237
+
238
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
239
+ encoder_hid_dim_type = "text_proj"
240
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
241
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
242
+
243
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
244
+ raise ValueError(
245
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
246
+ )
247
+
248
+ if encoder_hid_dim_type == "text_proj":
249
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
250
+ elif encoder_hid_dim_type == "text_image_proj":
251
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
252
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
253
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
254
+ self.encoder_hid_proj = TextImageProjection(
255
+ text_embed_dim=encoder_hid_dim,
256
+ image_embed_dim=cross_attention_dim,
257
+ cross_attention_dim=cross_attention_dim,
258
+ )
259
+
260
+ elif encoder_hid_dim_type is not None:
261
+ raise ValueError(
262
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
263
+ )
264
+ else:
265
+ self.encoder_hid_proj = None
266
+
267
+ # class embedding
268
+ if class_embed_type is None and num_class_embeds is not None:
269
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
270
+ elif class_embed_type == "timestep":
271
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
272
+ elif class_embed_type == "identity":
273
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
274
+ elif class_embed_type == "projection":
275
+ if projection_class_embeddings_input_dim is None:
276
+ raise ValueError(
277
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
278
+ )
279
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
280
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
281
+ # 2. it projects from an arbitrary input dimension.
282
+ #
283
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
284
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
285
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
286
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
287
+ else:
288
+ self.class_embedding = None
289
+
290
+ if addition_embed_type == "text":
291
+ if encoder_hid_dim is not None:
292
+ text_time_embedding_from_dim = encoder_hid_dim
293
+ else:
294
+ text_time_embedding_from_dim = cross_attention_dim
295
+
296
+ self.add_embedding = TextTimeEmbedding(
297
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
298
+ )
299
+ elif addition_embed_type == "text_image":
300
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
301
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
302
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
303
+ self.add_embedding = TextImageTimeEmbedding(
304
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
305
+ )
306
+ elif addition_embed_type == "text_time":
307
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
308
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
309
+
310
+ elif addition_embed_type is not None:
311
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
312
+
313
+ self.down_blocks = nn.ModuleList([])
314
+ self.brushnet_down_blocks = nn.ModuleList([])
315
+
316
+ if isinstance(only_cross_attention, bool):
317
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
318
+
319
+ if isinstance(attention_head_dim, int):
320
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
321
+
322
+ if isinstance(num_attention_heads, int):
323
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
324
+
325
+ # down
326
+ output_channel = block_out_channels[0]
327
+
328
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
329
+ brushnet_block = zero_module(brushnet_block)
330
+ self.brushnet_down_blocks.append(brushnet_block)
331
+
332
+ for i, down_block_type in enumerate(down_block_types):
333
+ input_channel = output_channel
334
+ output_channel = block_out_channels[i]
335
+ is_final_block = i == len(block_out_channels) - 1
336
+
337
+ down_block = get_down_block(
338
+ down_block_type,
339
+ num_layers=layers_per_block,
340
+ transformer_layers_per_block=transformer_layers_per_block[i],
341
+ in_channels=input_channel,
342
+ out_channels=output_channel,
343
+ temb_channels=time_embed_dim,
344
+ add_downsample=not is_final_block,
345
+ resnet_eps=norm_eps,
346
+ resnet_act_fn=act_fn,
347
+ resnet_groups=norm_num_groups,
348
+ cross_attention_dim=cross_attention_dim,
349
+ num_attention_heads=num_attention_heads[i],
350
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
351
+ downsample_padding=downsample_padding,
352
+ use_linear_projection=use_linear_projection,
353
+ only_cross_attention=only_cross_attention[i],
354
+ upcast_attention=upcast_attention,
355
+ resnet_time_scale_shift=resnet_time_scale_shift,
356
+ )
357
+ self.down_blocks.append(down_block)
358
+
359
+ for _ in range(layers_per_block):
360
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
361
+ brushnet_block = zero_module(brushnet_block)
362
+ self.brushnet_down_blocks.append(brushnet_block)
363
+
364
+ if not is_final_block:
365
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
366
+ brushnet_block = zero_module(brushnet_block)
367
+ self.brushnet_down_blocks.append(brushnet_block)
368
+
369
+ # mid
370
+ mid_block_channel = block_out_channels[-1]
371
+
372
+ brushnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
373
+ brushnet_block = zero_module(brushnet_block)
374
+ self.brushnet_mid_block = brushnet_block
375
+
376
+ self.mid_block = get_mid_block(
377
+ mid_block_type,
378
+ transformer_layers_per_block=transformer_layers_per_block[-1],
379
+ in_channels=mid_block_channel,
380
+ temb_channels=time_embed_dim,
381
+ resnet_eps=norm_eps,
382
+ resnet_act_fn=act_fn,
383
+ output_scale_factor=mid_block_scale_factor,
384
+ resnet_time_scale_shift=resnet_time_scale_shift,
385
+ cross_attention_dim=cross_attention_dim,
386
+ num_attention_heads=num_attention_heads[-1],
387
+ resnet_groups=norm_num_groups,
388
+ use_linear_projection=use_linear_projection,
389
+ upcast_attention=upcast_attention,
390
+ )
391
+
392
+ # count how many layers upsample the images
393
+ self.num_upsamplers = 0
394
+
395
+ # up
396
+ reversed_block_out_channels = list(reversed(block_out_channels))
397
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
398
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
399
+ only_cross_attention = list(reversed(only_cross_attention))
400
+
401
+ output_channel = reversed_block_out_channels[0]
402
+
403
+ self.up_blocks = nn.ModuleList([])
404
+ self.brushnet_up_blocks = nn.ModuleList([])
405
+
406
+ for i, up_block_type in enumerate(up_block_types):
407
+ is_final_block = i == len(block_out_channels) - 1
408
+
409
+ prev_output_channel = output_channel
410
+ output_channel = reversed_block_out_channels[i]
411
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
412
+
413
+ # add upsample block for all BUT final layer
414
+ if not is_final_block:
415
+ add_upsample = True
416
+ self.num_upsamplers += 1
417
+ else:
418
+ add_upsample = False
419
+
420
+ up_block = get_up_block(
421
+ up_block_type,
422
+ num_layers=layers_per_block + 1,
423
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
424
+ in_channels=input_channel,
425
+ out_channels=output_channel,
426
+ prev_output_channel=prev_output_channel,
427
+ temb_channels=time_embed_dim,
428
+ add_upsample=add_upsample,
429
+ resnet_eps=norm_eps,
430
+ resnet_act_fn=act_fn,
431
+ resolution_idx=i,
432
+ resnet_groups=norm_num_groups,
433
+ cross_attention_dim=cross_attention_dim,
434
+ num_attention_heads=reversed_num_attention_heads[i],
435
+ use_linear_projection=use_linear_projection,
436
+ only_cross_attention=only_cross_attention[i],
437
+ upcast_attention=upcast_attention,
438
+ resnet_time_scale_shift=resnet_time_scale_shift,
439
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
440
+ )
441
+ self.up_blocks.append(up_block)
442
+ prev_output_channel = output_channel
443
+
444
+ for _ in range(layers_per_block + 1):
445
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
446
+ brushnet_block = zero_module(brushnet_block)
447
+ self.brushnet_up_blocks.append(brushnet_block)
448
+
449
+ if not is_final_block:
450
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
451
+ brushnet_block = zero_module(brushnet_block)
452
+ self.brushnet_up_blocks.append(brushnet_block)
453
+
454
+ @classmethod
455
+ def from_unet(
456
+ cls,
457
+ unet: UNet2DConditionModel,
458
+ brushnet_conditioning_channel_order: str = "rgb",
459
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
460
+ load_weights_from_unet: bool = True,
461
+ conditioning_channels: int = 5,
462
+ ):
463
+ r"""
464
+ Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
465
+
466
+ Parameters:
467
+ unet (`UNet2DConditionModel`):
468
+ The UNet model weights to copy to the [`BrushNetModel`]. All configuration options are also copied
469
+ where applicable.
470
+ """
471
+ transformer_layers_per_block = (
472
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
473
+ )
474
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
475
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
476
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
477
+ addition_time_embed_dim = (
478
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
479
+ )
480
+
481
+ brushnet = cls(
482
+ in_channels=unet.config.in_channels,
483
+ conditioning_channels=conditioning_channels,
484
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
485
+ freq_shift=unet.config.freq_shift,
486
+ # down_block_types=['DownBlock2D','DownBlock2D','DownBlock2D','DownBlock2D'],
487
+ down_block_types=[
488
+ "CrossAttnDownBlock2D",
489
+ "CrossAttnDownBlock2D",
490
+ "CrossAttnDownBlock2D",
491
+ "DownBlock2D",
492
+ ],
493
+ # mid_block_type='MidBlock2D',
494
+ mid_block_type="UNetMidBlock2DCrossAttn",
495
+ # up_block_types=['UpBlock2D','UpBlock2D','UpBlock2D','UpBlock2D'],
496
+ up_block_types=["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
497
+ only_cross_attention=unet.config.only_cross_attention,
498
+ block_out_channels=unet.config.block_out_channels,
499
+ layers_per_block=unet.config.layers_per_block,
500
+ downsample_padding=unet.config.downsample_padding,
501
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
502
+ act_fn=unet.config.act_fn,
503
+ norm_num_groups=unet.config.norm_num_groups,
504
+ norm_eps=unet.config.norm_eps,
505
+ cross_attention_dim=unet.config.cross_attention_dim,
506
+ transformer_layers_per_block=transformer_layers_per_block,
507
+ encoder_hid_dim=encoder_hid_dim,
508
+ encoder_hid_dim_type=encoder_hid_dim_type,
509
+ attention_head_dim=unet.config.attention_head_dim,
510
+ num_attention_heads=unet.config.num_attention_heads,
511
+ use_linear_projection=unet.config.use_linear_projection,
512
+ class_embed_type=unet.config.class_embed_type,
513
+ addition_embed_type=addition_embed_type,
514
+ addition_time_embed_dim=addition_time_embed_dim,
515
+ num_class_embeds=unet.config.num_class_embeds,
516
+ upcast_attention=unet.config.upcast_attention,
517
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
518
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
519
+ brushnet_conditioning_channel_order=brushnet_conditioning_channel_order,
520
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
521
+ )
522
+
523
+ if load_weights_from_unet:
524
+ conv_in_condition_weight = torch.zeros_like(brushnet.conv_in_condition.weight)
525
+ conv_in_condition_weight[:, :4, ...] = unet.conv_in.weight
526
+ conv_in_condition_weight[:, 4:8, ...] = unet.conv_in.weight
527
+ brushnet.conv_in_condition.weight = torch.nn.Parameter(conv_in_condition_weight)
528
+ brushnet.conv_in_condition.bias = unet.conv_in.bias
529
+
530
+ brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
531
+ brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
532
+
533
+ if brushnet.class_embedding:
534
+ brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
535
+
536
+ brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
537
+ brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
538
+ brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(), strict=False)
539
+
540
+ return brushnet.to(unet.dtype)
541
+
542
+ @property
543
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
544
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
545
+ r"""
546
+ Returns:
547
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
548
+ indexed by its weight name.
549
+ """
550
+ # set recursively
551
+ processors = {}
552
+
553
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
554
+ if hasattr(module, "get_processor"):
555
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
556
+
557
+ for sub_name, child in module.named_children():
558
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
559
+
560
+ return processors
561
+
562
+ for name, module in self.named_children():
563
+ fn_recursive_add_processors(name, module, processors)
564
+
565
+ return processors
566
+
567
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
568
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
569
+ r"""
570
+ Sets the attention processor to use to compute attention.
571
+
572
+ Parameters:
573
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
574
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
575
+ for **all** `Attention` layers.
576
+
577
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
578
+ processor. This is strongly recommended when setting trainable attention processors.
579
+
580
+ """
581
+ count = len(self.attn_processors.keys())
582
+
583
+ if isinstance(processor, dict) and len(processor) != count:
584
+ raise ValueError(
585
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
586
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
587
+ )
588
+
589
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
590
+ if hasattr(module, "set_processor"):
591
+ if not isinstance(processor, dict):
592
+ module.set_processor(processor)
593
+ else:
594
+ module.set_processor(processor.pop(f"{name}.processor"))
595
+
596
+ for sub_name, child in module.named_children():
597
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
598
+
599
+ for name, module in self.named_children():
600
+ fn_recursive_attn_processor(name, module, processor)
601
+
602
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
603
+ def set_default_attn_processor(self):
604
+ """
605
+ Disables custom attention processors and sets the default attention implementation.
606
+ """
607
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
608
+ processor = AttnAddedKVProcessor()
609
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
610
+ processor = AttnProcessor()
611
+ else:
612
+ raise ValueError(
613
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
614
+ )
615
+
616
+ self.set_attn_processor(processor)
617
+
618
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
619
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
620
+ r"""
621
+ Enable sliced attention computation.
622
+
623
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
624
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
625
+
626
+ Args:
627
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
628
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
629
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
630
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
631
+ must be a multiple of `slice_size`.
632
+ """
633
+ sliceable_head_dims = []
634
+
635
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
636
+ if hasattr(module, "set_attention_slice"):
637
+ sliceable_head_dims.append(module.sliceable_head_dim)
638
+
639
+ for child in module.children():
640
+ fn_recursive_retrieve_sliceable_dims(child)
641
+
642
+ # retrieve number of attention layers
643
+ for module in self.children():
644
+ fn_recursive_retrieve_sliceable_dims(module)
645
+
646
+ num_sliceable_layers = len(sliceable_head_dims)
647
+
648
+ if slice_size == "auto":
649
+ # half the attention head size is usually a good trade-off between
650
+ # speed and memory
651
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
652
+ elif slice_size == "max":
653
+ # make smallest slice possible
654
+ slice_size = num_sliceable_layers * [1]
655
+
656
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
657
+
658
+ if len(slice_size) != len(sliceable_head_dims):
659
+ raise ValueError(
660
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
661
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
662
+ )
663
+
664
+ for i in range(len(slice_size)):
665
+ size = slice_size[i]
666
+ dim = sliceable_head_dims[i]
667
+ if size is not None and size > dim:
668
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
669
+
670
+ # Recursively walk through all the children.
671
+ # Any children which exposes the set_attention_slice method
672
+ # gets the message
673
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
674
+ if hasattr(module, "set_attention_slice"):
675
+ module.set_attention_slice(slice_size.pop())
676
+
677
+ for child in module.children():
678
+ fn_recursive_set_attention_slice(child, slice_size)
679
+
680
+ reversed_slice_size = list(reversed(slice_size))
681
+ for module in self.children():
682
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
683
+
684
+ def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
685
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
686
+ module.gradient_checkpointing = value
687
+
688
+ def forward(
689
+ self,
690
+ sample: torch.FloatTensor,
691
+ timestep: Union[torch.Tensor, float, int],
692
+ encoder_hidden_states: torch.Tensor,
693
+ brushnet_cond: torch.FloatTensor,
694
+ conditioning_scale: float = 1.0,
695
+ class_labels: Optional[torch.Tensor] = None,
696
+ timestep_cond: Optional[torch.Tensor] = None,
697
+ attention_mask: Optional[torch.Tensor] = None,
698
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
699
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
700
+ guess_mode: bool = False,
701
+ return_dict: bool = True,
702
+ debug=False,
703
+ ) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
704
+ """
705
+ The [`BrushNetModel`] forward method.
706
+
707
+ Args:
708
+ sample (`torch.FloatTensor`):
709
+ The noisy input tensor.
710
+ timestep (`Union[torch.Tensor, float, int]`):
711
+ The number of timesteps to denoise an input.
712
+ encoder_hidden_states (`torch.Tensor`):
713
+ The encoder hidden states.
714
+ brushnet_cond (`torch.FloatTensor`):
715
+ The conditional input tensor of shape `(batch_size, conditioning_channels, height, width)` (masked-image latents concatenated with the mask).
716
+ conditioning_scale (`float`, defaults to `1.0`):
717
+ The scale factor for BrushNet outputs.
718
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
719
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
720
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
721
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
722
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
723
+ embeddings.
724
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
725
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
726
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
727
+ negative values to the attention scores corresponding to "discard" tokens.
728
+ added_cond_kwargs (`dict`):
729
+ Additional conditions for the Stable Diffusion XL UNet.
730
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
731
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
732
+ guess_mode (`bool`, defaults to `False`):
733
+ In this mode, the BrushNet encoder tries its best to recognize the content of the input even if
734
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
735
+ return_dict (`bool`, defaults to `True`):
736
+ Whether or not to return a [`~models.brushnet.BrushNetOutput`] instead of a plain tuple.
737
+
738
+ Returns:
739
+ [`~models.brushnet.BrushNetOutput`] **or** `tuple`:
740
+ If `return_dict` is `True`, a [`~models.brushnet.BrushNetOutput`] is returned, otherwise a tuple is
741
+ returned where the first element is the sample tensor.
742
+ """
743
+ # check channel order
744
+ channel_order = self.config.brushnet_conditioning_channel_order
745
+
746
+ if channel_order == "rgb":
747
+ # in rgb order by default
748
+ ...
749
+ elif channel_order == "bgr":
750
+ brushnet_cond = torch.flip(brushnet_cond, dims=[1])
751
+ else:
752
+ raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
753
+
754
+ if debug: print('BrushNet CA: attn mask')
755
+
756
+ # prepare attention_mask
757
+ if attention_mask is not None:
758
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
759
+ attention_mask = attention_mask.unsqueeze(1)
760
+
761
+ if debug: print('BrushNet CA: time')
762
+
763
+ # 1. time
764
+ timesteps = timestep
765
+ if not torch.is_tensor(timesteps):
766
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
767
+ # This would be a good case for the `match` statement (Python 3.10+)
768
+ is_mps = sample.device.type == "mps"
769
+ if isinstance(timestep, float):
770
+ dtype = torch.float32 if is_mps else torch.float64
771
+ else:
772
+ dtype = torch.int32 if is_mps else torch.int64
773
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
774
+ elif len(timesteps.shape) == 0:
775
+ timesteps = timesteps[None].to(sample.device)
776
+
777
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
778
+ timesteps = timesteps.expand(sample.shape[0])
779
+
780
+ t_emb = self.time_proj(timesteps)
781
+
782
+ # timesteps does not contain any weights and will always return f32 tensors
783
+ # but time_embedding might actually be running in fp16. so we need to cast here.
784
+ # there might be better ways to encapsulate this.
785
+ t_emb = t_emb.to(dtype=sample.dtype)
786
+
787
+ emb = self.time_embedding(t_emb, timestep_cond)
788
+ aug_emb = None
789
+
790
+ if self.class_embedding is not None:
791
+ if class_labels is None:
792
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
793
+
794
+ if self.config.class_embed_type == "timestep":
795
+ class_labels = self.time_proj(class_labels)
796
+
797
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
798
+ emb = emb + class_emb
799
+
800
+ if self.config.addition_embed_type is not None:
801
+ if self.config.addition_embed_type == "text":
802
+ aug_emb = self.add_embedding(encoder_hidden_states)
803
+
804
+ elif self.config.addition_embed_type == "text_time":
805
+ if "text_embeds" not in added_cond_kwargs:
806
+ raise ValueError(
807
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
808
+ )
809
+ text_embeds = added_cond_kwargs.get("text_embeds")
810
+ if "time_ids" not in added_cond_kwargs:
811
+ raise ValueError(
812
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
813
+ )
814
+ time_ids = added_cond_kwargs.get("time_ids")
815
+ time_embeds = self.add_time_proj(time_ids.flatten())
816
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
817
+
818
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
819
+ add_embeds = add_embeds.to(emb.dtype)
820
+ aug_emb = self.add_embedding(add_embeds)
821
+
822
+ emb = emb + aug_emb if aug_emb is not None else emb
823
+
824
+ if debug: print('BrushNet CA: pre-process')
825
+
826
+
827
+ # 2. pre-process
828
+ brushnet_cond = torch.concat([sample, brushnet_cond], 1)
829
+ sample = self.conv_in_condition(brushnet_cond)
830
+
831
+ if debug: print('BrushNet CA: down')
832
+
833
+ # 3. down
834
+ down_block_res_samples = (sample,)
835
+ for downsample_block in self.down_blocks:
836
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
837
+ if debug: print('BrushNet CA (down block with XA): ', type(downsample_block))
838
+ sample, res_samples = downsample_block(
839
+ hidden_states=sample,
840
+ temb=emb,
841
+ encoder_hidden_states=encoder_hidden_states,
842
+ attention_mask=attention_mask,
843
+ cross_attention_kwargs=cross_attention_kwargs,
844
+ debug=debug,
845
+ )
846
+ else:
847
+ if debug: print('BrushNet CA (down block): ', type(downsample_block))
848
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, debug=debug)
849
+
850
+ down_block_res_samples += res_samples
851
+
852
+ if debug: print('BrushNet CA: PP down')
853
+
854
+ # 4. BrushNet down blocks
855
+ brushnet_down_block_res_samples = ()
856
+ for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
857
+ down_block_res_sample = brushnet_down_block(down_block_res_sample)
858
+ brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
859
+
860
+ if debug: print('BrushNet CA: PP mid')
861
+
862
+ # 5. mid
863
+ if self.mid_block is not None:
864
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
865
+ sample = self.mid_block(
866
+ sample,
867
+ emb,
868
+ encoder_hidden_states=encoder_hidden_states,
869
+ attention_mask=attention_mask,
870
+ cross_attention_kwargs=cross_attention_kwargs,
871
+ )
872
+ else:
873
+ sample = self.mid_block(sample, emb)
874
+
875
+ if debug: print('BrushNet CA: mid')
876
+
877
+ # 6. BrushNet mid blocks
878
+ brushnet_mid_block_res_sample = self.brushnet_mid_block(sample)
879
+
880
+ if debug: print('BrushNet CA: PP up')
881
+
882
+ # 7. up
883
+ up_block_res_samples = ()
884
+ for i, upsample_block in enumerate(self.up_blocks):
885
+ is_final_block = i == len(self.up_blocks) - 1
886
+
887
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
888
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
889
+
890
+ # if we have not reached the final block and need to forward the
891
+ # upsample size, we do it here
892
+ if not is_final_block:
893
+ upsample_size = down_block_res_samples[-1].shape[2:]
894
+
895
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
896
+ sample, up_res_samples = upsample_block(
897
+ hidden_states=sample,
898
+ temb=emb,
899
+ res_hidden_states_tuple=res_samples,
900
+ encoder_hidden_states=encoder_hidden_states,
901
+ cross_attention_kwargs=cross_attention_kwargs,
902
+ upsample_size=upsample_size,
903
+ attention_mask=attention_mask,
904
+ return_res_samples=True,
905
+ )
906
+ else:
907
+ sample, up_res_samples = upsample_block(
908
+ hidden_states=sample,
909
+ temb=emb,
910
+ res_hidden_states_tuple=res_samples,
911
+ upsample_size=upsample_size,
912
+ return_res_samples=True,
913
+ )
914
+
915
+ up_block_res_samples += up_res_samples
916
+
917
+ if debug: print('BrushNet CA: up')
918
+
919
+ # 8. BrushNet up blocks
920
+ brushnet_up_block_res_samples = ()
921
+ for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
922
+ up_block_res_sample = brushnet_up_block(up_block_res_sample)
923
+ brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
924
+
925
+ if debug: print('BrushNet CA: scaling')
926
+
927
+ # 9. scaling
928
+ if guess_mode and not self.config.global_pool_conditions:
929
+ scales = torch.logspace(
930
+ -1,
931
+ 0,
932
+ len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples),
933
+ device=sample.device,
934
+ ) # 0.1 to 1.0
935
+ scales = scales * conditioning_scale
936
+
937
+ brushnet_down_block_res_samples = [
938
+ sample * scale
939
+ for sample, scale in zip(
940
+ brushnet_down_block_res_samples, scales[: len(brushnet_down_block_res_samples)]
941
+ )
942
+ ]
943
+ brushnet_mid_block_res_sample = (
944
+ brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
945
+ )
946
+ brushnet_up_block_res_samples = [
947
+ sample * scale
948
+ for sample, scale in zip(
949
+ brushnet_up_block_res_samples, scales[len(brushnet_down_block_res_samples) + 1 :]
950
+ )
951
+ ]
952
+ else:
953
+ brushnet_down_block_res_samples = [
954
+ sample * conditioning_scale for sample in brushnet_down_block_res_samples
955
+ ]
956
+ brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
957
+ brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
958
+
959
+ if self.config.global_pool_conditions:
960
+ brushnet_down_block_res_samples = [
961
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
962
+ ]
963
+ brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
964
+ brushnet_up_block_res_samples = [
965
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
966
+ ]
967
+
968
+ if debug: print('BrushNet CA: finish')
969
+
970
+ if not return_dict:
971
+ return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
972
+
973
+ return BrushNetOutput(
974
+ down_block_res_samples=brushnet_down_block_res_samples,
975
+ mid_block_res_sample=brushnet_mid_block_res_sample,
976
+ up_block_res_samples=brushnet_up_block_res_samples,
977
+ )
978
+
979
+
980
+ def zero_module(module):
981
+ for p in module.parameters():
982
+ nn.init.zeros_(p)
983
+ return module
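A minimal sketch of how this cross-attention BrushNet variant is typically instantiated and called; the model id and tensor shapes below are illustrative, and `from_unet` expects a 5-channel condition (4 masked-image latent channels plus 1 mask channel), which the forward pass concatenates with the noisy latent before `conv_in_condition`:

import torch
from diffusers import UNet2DConditionModel
from MagicQuill.brushnet.brushnet_ca import BrushNetModel  # path as uploaded in this commit

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
brushnet = BrushNetModel.from_unet(unet, conditioning_channels=5)

latents = torch.randn(1, 4, 64, 64)   # noisy latent sample
cond = torch.randn(1, 5, 64, 64)      # masked-image latent (4) + mask (1)
text = torch.randn(1, 77, 768)        # CLIP text embeddings for SD 1.5

out = brushnet(latents, timestep=10, encoder_hidden_states=text,
               brushnet_cond=cond, conditioning_scale=1.0)
# out.down_block_res_samples, out.mid_block_res_sample and out.up_block_res_samples
# are the per-block residuals that get added into the corresponding UNet blocks.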
MagicQuill/brushnet/brushnet_xl.json ADDED
@@ -0,0 +1,63 @@
1
+ {
2
+ "_class_name": "BrushNetModel",
3
+ "_diffusers_version": "0.27.0.dev0",
4
+ "_name_or_path": "runs/logs/brushnetsdxl_randommask/checkpoint-80000",
5
+ "act_fn": "silu",
6
+ "addition_embed_type": "text_time",
7
+ "addition_embed_type_num_heads": 64,
8
+ "addition_time_embed_dim": 256,
9
+ "attention_head_dim": [
10
+ 5,
11
+ 10,
12
+ 20
13
+ ],
14
+ "block_out_channels": [
15
+ 320,
16
+ 640,
17
+ 1280
18
+ ],
19
+ "brushnet_conditioning_channel_order": "rgb",
20
+ "class_embed_type": null,
21
+ "conditioning_channels": 5,
22
+ "conditioning_embedding_out_channels": [
23
+ 16,
24
+ 32,
25
+ 96,
26
+ 256
27
+ ],
28
+ "cross_attention_dim": 2048,
29
+ "down_block_types": [
30
+ "DownBlock2D",
31
+ "DownBlock2D",
32
+ "DownBlock2D"
33
+ ],
34
+ "downsample_padding": 1,
35
+ "encoder_hid_dim": null,
36
+ "encoder_hid_dim_type": null,
37
+ "flip_sin_to_cos": true,
38
+ "freq_shift": 0,
39
+ "global_pool_conditions": false,
40
+ "in_channels": 4,
41
+ "layers_per_block": 2,
42
+ "mid_block_scale_factor": 1,
43
+ "mid_block_type": "MidBlock2D",
44
+ "norm_eps": 1e-05,
45
+ "norm_num_groups": 32,
46
+ "num_attention_heads": null,
47
+ "num_class_embeds": null,
48
+ "only_cross_attention": false,
49
+ "projection_class_embeddings_input_dim": 2816,
50
+ "resnet_time_scale_shift": "default",
51
+ "transformer_layers_per_block": [
52
+ 1,
53
+ 2,
54
+ 10
55
+ ],
56
+ "up_block_types": [
57
+ "UpBlock2D",
58
+ "UpBlock2D",
59
+ "UpBlock2D"
60
+ ],
61
+ "upcast_attention": null,
62
+ "use_linear_projection": true
63
+ }
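This SDXL-flavoured config uses plain `DownBlock2D`/`UpBlock2D` stacks and a `MidBlock2D`, so the BrushNet branch itself carries no cross-attention (the 2048-dim text conditioning is still declared for the time/text embeddings). A config file like this is normally consumed through the diffusers config machinery; a rough sketch, with an illustrative local path and the non-CA `BrushNetModel` defined in `brushnet.py`:

from MagicQuill.brushnet.brushnet import BrushNetModel  # non-CA variant from brushnet.py above

config = BrushNetModel.load_config("MagicQuill/brushnet/brushnet_xl.json")
brushnet_xl = BrushNetModel.from_config(config)  # randomly initialized until a checkpoint is loaded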
MagicQuill/brushnet/powerpaint.json ADDED
@@ -0,0 +1,57 @@
1
+ {
2
+ "_class_name": "BrushNetModel",
3
+ "_diffusers_version": "0.27.2",
4
+ "act_fn": "silu",
5
+ "addition_embed_type": null,
6
+ "addition_embed_type_num_heads": 64,
7
+ "addition_time_embed_dim": null,
8
+ "attention_head_dim": 8,
9
+ "block_out_channels": [
10
+ 320,
11
+ 640,
12
+ 1280,
13
+ 1280
14
+ ],
15
+ "brushnet_conditioning_channel_order": "rgb",
16
+ "class_embed_type": null,
17
+ "conditioning_channels": 5,
18
+ "conditioning_embedding_out_channels": [
19
+ 16,
20
+ 32,
21
+ 96,
22
+ 256
23
+ ],
24
+ "cross_attention_dim": 768,
25
+ "down_block_types": [
26
+ "CrossAttnDownBlock2D",
27
+ "CrossAttnDownBlock2D",
28
+ "CrossAttnDownBlock2D",
29
+ "DownBlock2D"
30
+ ],
31
+ "downsample_padding": 1,
32
+ "encoder_hid_dim": null,
33
+ "encoder_hid_dim_type": null,
34
+ "flip_sin_to_cos": true,
35
+ "freq_shift": 0,
36
+ "global_pool_conditions": false,
37
+ "in_channels": 4,
38
+ "layers_per_block": 2,
39
+ "mid_block_scale_factor": 1,
40
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
41
+ "norm_eps": 1e-05,
42
+ "norm_num_groups": 32,
43
+ "num_attention_heads": null,
44
+ "num_class_embeds": null,
45
+ "only_cross_attention": false,
46
+ "projection_class_embeddings_input_dim": null,
47
+ "resnet_time_scale_shift": "default",
48
+ "transformer_layers_per_block": 1,
49
+ "up_block_types": [
50
+ "UpBlock2D",
51
+ "CrossAttnUpBlock2D",
52
+ "CrossAttnUpBlock2D",
53
+ "CrossAttnUpBlock2D"
54
+ ],
55
+ "upcast_attention": false,
56
+ "use_linear_projection": false
57
+ }
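The PowerPaint config mirrors the SD-1.5 UNet layout (768-dim cross-attention, four block levels, cross-attention kept in the BrushNet branch), which is what allows `from_unet` to copy the base UNet weights block for block. A quick sanity check of that correspondence against the values shown above (the file path is illustrative):

import json

with open("MagicQuill/brushnet/powerpaint.json") as f:
    cfg = json.load(f)

# values as listed in the config above
assert cfg["cross_attention_dim"] == 768
assert cfg["block_out_channels"] == [320, 640, 1280, 1280]
assert cfg["down_block_types"][:3] == ["CrossAttnDownBlock2D"] * 3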
MagicQuill/brushnet/powerpaint_utils.py ADDED
@@ -0,0 +1,496 @@
1
+ import copy
2
+ import random
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from transformers import CLIPTokenizer
7
+ from typing import Any, List, Optional, Union
8
+
9
+ class TokenizerWrapper:
10
+ """Tokenizer wrapper for CLIPTokenizer. Only support CLIPTokenizer
11
+ currently. This wrapper is modified from https://github.com/huggingface/dif
12
+ fusers/blob/e51f19aee82c8dd874b715a09dbc521d88835d68/src/diffusers/loaders.
13
+ py#L358 # noqa.
14
+
15
+ Args:
16
+ from_pretrained (Union[str, os.PathLike], optional): The *model id*
17
+ of a pretrained model or a path to a *directory* containing
18
+ model weights and config. Defaults to None.
19
+ from_config (Union[str, os.PathLike], optional): The *model id*
20
+ of a pretrained model or a path to a *directory* containing
21
+ model weights and config. Defaults to None.
22
+
23
+ *args, **kwargs: If `from_pretrained` is passed, *args and **kwargs
24
+ will be passed to `from_pretrained` function. Otherwise, *args
25
+ and **kwargs will be used to initialize the model by
26
+ `self._module_cls(*args, **kwargs)`.
27
+ """
28
+
29
+ def __init__(self, tokenizer: CLIPTokenizer):
30
+ self.wrapped = tokenizer
31
+ self.token_map = {}
32
+
33
+ def __getattr__(self, name: str) -> Any:
34
+ if name in self.__dict__:
35
+ return getattr(self, name)
36
+ #if name == "wrapped":
37
+ # return getattr(self, 'wrapped')#super().__getattr__("wrapped")
38
+
39
+ try:
40
+ return getattr(self.wrapped, name)
41
+ except AttributeError:
42
+ raise AttributeError(
43
+ "'name' cannot be found in both "
44
+ f"'{self.__class__.__name__}' and "
45
+ f"'{self.__class__.__name__}.tokenizer'."
46
+ )
47
+
48
+ def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs):
49
+ """Attempt to add tokens to the tokenizer.
50
+
51
+ Args:
52
+ tokens (Union[str, List[str]]): The tokens to be added.
53
+ """
54
+ num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
55
+ assert num_added_tokens != 0, (
56
+ f"The tokenizer already contains the token {tokens}. Please pass "
57
+ "a different `placeholder_token` that is not already in the "
58
+ "tokenizer."
59
+ )
60
+
61
+ def get_token_info(self, token: str) -> dict:
62
+ """Get the information of a token, including its start and end index in
63
+ the current tokenizer.
64
+
65
+ Args:
66
+ token (str): The token to be queried.
67
+
68
+ Returns:
69
+ dict: The information of the token, including its start and end
70
+ index in current tokenizer.
71
+ """
72
+ token_ids = self.__call__(token).input_ids
73
+ start, end = token_ids[1], token_ids[-2] + 1
74
+ return {"name": token, "start": start, "end": end}
75
+
76
+ def add_placeholder_token(self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs):
77
+ """Add placeholder tokens to the tokenizer.
78
+
79
+ Args:
80
+ placeholder_token (str): The placeholder token to be added.
81
+ num_vec_per_token (int, optional): The number of vectors of
82
+ the added placeholder token.
83
+ *args, **kwargs: The arguments for `self.wrapped.add_tokens`.
84
+ """
85
+ output = []
86
+ if num_vec_per_token == 1:
87
+ self.try_adding_tokens(placeholder_token, *args, **kwargs)
88
+ output.append(placeholder_token)
89
+ else:
90
+ output = []
91
+ for i in range(num_vec_per_token):
92
+ ith_token = placeholder_token + f"_{i}"
93
+ self.try_adding_tokens(ith_token, *args, **kwargs)
94
+ output.append(ith_token)
95
+
96
+ for token in self.token_map:
97
+ if token in placeholder_token:
98
+ raise ValueError(
99
+ f"The tokenizer already has placeholder token {token} "
100
+ f"that can get confused with {placeholder_token} "
101
+ "keep placeholder tokens independent"
102
+ )
103
+ self.token_map[placeholder_token] = output
104
+
105
+ def replace_placeholder_tokens_in_text(
106
+ self, text: Union[str, List[str]], vector_shuffle: bool = False, prop_tokens_to_load: float = 1.0
107
+ ) -> Union[str, List[str]]:
108
+ """Replace the keywords in text with placeholder tokens. This function
109
+ will be called in `self.__call__` and `self.encode`.
110
+
111
+ Args:
112
+ text (Union[str, List[str]]): The text to be processed.
113
+ vector_shuffle (bool, optional): Whether to shuffle the vectors.
114
+ Defaults to False.
115
+ prop_tokens_to_load (float, optional): The proportion of tokens to
116
+ be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.
117
+
118
+ Returns:
119
+ Union[str, List[str]]: The processed text.
120
+ """
121
+ if isinstance(text, list):
122
+ output = []
123
+ for i in range(len(text)):
124
+ output.append(self.replace_placeholder_tokens_in_text(text[i], vector_shuffle=vector_shuffle))
125
+ return output
126
+
127
+ for placeholder_token in self.token_map:
128
+ if placeholder_token in text:
129
+ tokens = self.token_map[placeholder_token]
130
+ tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
131
+ if vector_shuffle:
132
+ tokens = copy.copy(tokens)
133
+ random.shuffle(tokens)
134
+ text = text.replace(placeholder_token, " ".join(tokens))
135
+ return text
136
+
137
+ def replace_text_with_placeholder_tokens(self, text: Union[str, List[str]]) -> Union[str, List[str]]:
138
+ """Replace the placeholder tokens in text with the original keywords.
139
+ This function will be called in `self.decode`.
140
+
141
+ Args:
142
+ text (Union[str, List[str]]): The text to be processed.
143
+
144
+ Returns:
145
+ Union[str, List[str]]: The processed text.
146
+ """
147
+ if isinstance(text, list):
148
+ output = []
149
+ for i in range(len(text)):
150
+ output.append(self.replace_text_with_placeholder_tokens(text[i]))
151
+ return output
152
+
153
+ for placeholder_token, tokens in self.token_map.items():
154
+ merged_tokens = " ".join(tokens)
155
+ if merged_tokens in text:
156
+ text = text.replace(merged_tokens, placeholder_token)
157
+ return text
158
+
159
+ def __call__(
160
+ self,
161
+ text: Union[str, List[str]],
162
+ *args,
163
+ vector_shuffle: bool = False,
164
+ prop_tokens_to_load: float = 1.0,
165
+ **kwargs,
166
+ ):
167
+ """The call function of the wrapper.
168
+
169
+ Args:
170
+ text (Union[str, List[str]]): The text to be tokenized.
171
+ vector_shuffle (bool, optional): Whether to shuffle the vectors.
172
+ Defaults to False.
173
+ prop_tokens_to_load (float, optional): The proportion of tokens to
174
+ be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0
175
+ *args, **kwargs: The arguments for `self.wrapped.__call__`.
176
+ """
177
+ replaced_text = self.replace_placeholder_tokens_in_text(
178
+ text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
179
+ )
180
+
181
+ return self.wrapped.__call__(replaced_text, *args, **kwargs)
182
+
183
+ def encode(self, text: Union[str, List[str]], *args, **kwargs):
184
+ """Encode the passed text to token index.
185
+
186
+ Args:
187
+ text (Union[str, List[str]]): The text to be encoded.
188
+ *args, **kwargs: The arguments for `self.wrapped.__call__`.
189
+ """
190
+ replaced_text = self.replace_placeholder_tokens_in_text(text)
191
+ return self.wrapped(replaced_text, *args, **kwargs)
192
+
193
+ def decode(self, token_ids, return_raw: bool = False, *args, **kwargs) -> Union[str, List[str]]:
194
+ """Decode the token index to text.
195
+
196
+ Args:
197
+ token_ids: The token index to be decoded.
198
+ return_raw: Whether to keep the placeholder tokens in the text.
199
+ Defaults to False.
200
+ *args, **kwargs: The arguments for `self.wrapped.decode`.
201
+
202
+ Returns:
203
+ Union[str, List[str]]: The decoded text.
204
+ """
205
+ text = self.wrapped.decode(token_ids, *args, **kwargs)
206
+ if return_raw:
207
+ return text
208
+ replaced_text = self.replace_text_with_placeholder_tokens(text)
209
+ return replaced_text
210
+
211
+ def __repr__(self):
212
+ """The representation of the wrapper."""
213
+ s = super().__repr__()
214
+ prefix = f"Wrapped Module Class: {self._module_cls}\n"
215
+ prefix += f"Wrapped Module Name: {self._module_name}\n"
216
+ if self._from_pretrained:
217
+ prefix += f"From Pretrained: {self._from_pretrained}\n"
218
+ s = prefix + s
219
+ return s
220
+
221
+
222
+ class EmbeddingLayerWithFixes(nn.Module):
223
+ """The revised embedding layer to support external embeddings. This design
224
+ of this class is inspired by https://github.com/AUTOMATIC1111/stable-
225
+ diffusion-webui/blob/22bcc7be428c94e9408f589966c2040187245d81/modules/sd_hi
226
+ jack.py#L224 # noqa.
227
+
228
+ Args:
229
+ wrapped (nn.Embedding): The embedding layer to be wrapped.
230
+ external_embeddings (Union[dict, List[dict]], optional): The external
231
+ embeddings added to this layer. Defaults to None.
232
+ """
233
+
234
+ def __init__(self, wrapped: nn.Embedding, external_embeddings: Optional[Union[dict, List[dict]]] = None):
235
+ super().__init__()
236
+ self.wrapped = wrapped
237
+ self.num_embeddings = wrapped.weight.shape[0]
238
+
239
+ self.external_embeddings = []
240
+ if external_embeddings:
241
+ self.add_embeddings(external_embeddings)
242
+
243
+ self.trainable_embeddings = nn.ParameterDict()
244
+
245
+ @property
246
+ def weight(self):
247
+ """Get the weight of wrapped embedding layer."""
248
+ return self.wrapped.weight
249
+
250
+ def check_duplicate_names(self, embeddings: List[dict]):
251
+ """Check whether duplicate names exist in list of 'external
252
+ embeddings'.
253
+
254
+ Args:
255
+ embeddings (List[dict]): A list of embeddings to be checked.
256
+ """
257
+ names = [emb["name"] for emb in embeddings]
258
+ assert len(names) == len(set(names)), (
259
+ "Found duplicated names in 'external_embeddings'. Name list: " f"'{names}'"
260
+ )
261
+
262
+ def check_ids_overlap(self, embeddings):
263
+ """Check whether overlap exist in token ids of 'external_embeddings'.
264
+
265
+ Args:
266
+ embeddings (List[dict]): A list of embeddings to be checked.
267
+ """
268
+ ids_range = [[emb["start"], emb["end"], emb["name"]] for emb in embeddings]
269
+ ids_range.sort() # sort by 'start'
270
+ # check if 'end' has overlapping
271
+ for idx in range(len(ids_range) - 1):
272
+ name1, name2 = ids_range[idx][-1], ids_range[idx + 1][-1]
273
+ assert ids_range[idx][1] <= ids_range[idx + 1][0], (
274
+ f"Found ids overlapping between embeddings '{name1}' " f"and '{name2}'."
275
+ )
276
+
277
+ def add_embeddings(self, embeddings: Optional[Union[dict, List[dict]]]):
278
+ """Add external embeddings to this layer.
279
+
280
+ Use case:
281
+
282
+ >>> 1. Add token to tokenizer and get the token id.
283
+ >>> tokenizer = TokenizerWrapper('openai/clip-vit-base-patch32')
284
+ >>> # 'how much' in kiswahili
285
+ >>> tokenizer.add_placeholder_tokens('ngapi', num_vec_per_token=4)
286
+ >>>
287
+ >>> 2. Add external embeddings to the model.
288
+ >>> new_embedding = {
289
+ >>> 'name': 'ngapi', # 'how much' in kiswahili
290
+ >>> 'embedding': torch.ones(1, 15) * 4,
291
+ >>> 'start': tokenizer.get_token_info('ngapi')['start'],
292
+ >>> 'end': tokenizer.get_token_info('ngapi')['end'],
293
+ >>> 'trainable': False # if True, will registry as a parameter
294
+ >>> }
295
+ >>> embedding_layer = nn.Embedding(10, 15)
296
+ >>> embedding_layer_wrapper = EmbeddingLayerWithFixes(embedding_layer)
297
+ >>> embedding_layer_wrapper.add_embeddings(new_embedding)
298
+ >>>
299
+ >>> 3. Forward tokenizer and embedding layer!
300
+ >>> input_text = ['hello, ngapi!', 'hello my friend, ngapi?']
301
+ >>> input_ids = tokenizer(
302
+ >>> input_text, padding='max_length', truncation=True,
303
+ >>> return_tensors='pt')['input_ids']
304
+ >>> out_feat = embedding_layer_wrapper(input_ids)
305
+ >>>
306
+ >>> 4. Let's validate the result!
307
+ >>> assert (out_feat[0, 3: 7] == 4).all()
308
+ >>> assert (out_feat[1, 5: 9] == 4).all()
309
+
310
+ Args:
311
+ embeddings (Union[dict, list[dict]]): The external embeddings to
312
+ be added. Each dict must contain the following 4 fields: 'name'
313
+ (the name of this embedding), 'embedding' (the embedding
314
+ tensor), 'start' (the start token id of this embedding), 'end'
315
+ (the end token id of this embedding). For example:
316
+ `{name: NAME, start: START, end: END, embedding: torch.Tensor}`
317
+ """
318
+ if isinstance(embeddings, dict):
319
+ embeddings = [embeddings]
320
+
321
+ self.external_embeddings += embeddings
322
+ self.check_duplicate_names(self.external_embeddings)
323
+ self.check_ids_overlap(self.external_embeddings)
324
+
325
+ # set for trainable
326
+ added_trainable_emb_info = []
327
+ for embedding in embeddings:
328
+ trainable = embedding.get("trainable", False)
329
+ if trainable:
330
+ name = embedding["name"]
331
+ embedding["embedding"] = torch.nn.Parameter(embedding["embedding"])
332
+ self.trainable_embeddings[name] = embedding["embedding"]
333
+ added_trainable_emb_info.append(name)
334
+
335
+ added_emb_info = [emb["name"] for emb in embeddings]
336
+ added_emb_info = ", ".join(added_emb_info)
337
+ print(f"Successfully add external embeddings: {added_emb_info}.", "current")
338
+
339
+ if added_trainable_emb_info:
340
+ added_trainable_emb_info = ", ".join(added_trainable_emb_info)
341
+ print("Successfully add trainable external embeddings: " f"{added_trainable_emb_info}", "current")
342
+
343
+ def replace_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
344
+ """Replace external input ids to 0.
345
+
346
+ Args:
347
+ input_ids (torch.Tensor): The input ids to be replaced.
348
+
349
+ Returns:
350
+ torch.Tensor: The replaced input ids.
351
+ """
352
+ input_ids_fwd = input_ids.clone()
353
+ input_ids_fwd[input_ids_fwd >= self.num_embeddings] = 0
354
+ return input_ids_fwd
355
+
356
+ def replace_embeddings(
357
+ self, input_ids: torch.Tensor, embedding: torch.Tensor, external_embedding: dict
358
+ ) -> torch.Tensor:
359
+ """Replace external embedding to the embedding layer. Noted that, in
360
+ this function we use `torch.cat` to avoid inplace modification.
361
+
362
+ Args:
363
+ input_ids (torch.Tensor): The original token ids. Shape like
364
+ [LENGTH, ].
365
+ embedding (torch.Tensor): The embedding of token ids after
366
+ `replace_input_ids` function.
367
+ external_embedding (dict): The external embedding to be replaced.
368
+
369
+ Returns:
370
+ torch.Tensor: The replaced embedding.
371
+ """
372
+ new_embedding = []
373
+
374
+ name = external_embedding["name"]
375
+ start = external_embedding["start"]
376
+ end = external_embedding["end"]
377
+ target_ids_to_replace = [i for i in range(start, end)]
378
+ ext_emb = external_embedding["embedding"]
379
+
380
+ # do not need to replace
381
+ if not (input_ids == start).any():
382
+ return embedding
383
+
384
+ # start replace
385
+ s_idx, e_idx = 0, 0
386
+ while e_idx < len(input_ids):
387
+ if input_ids[e_idx] == start:
388
+ if e_idx != 0:
389
+ # append the embeddings that do not need to be replaced
390
+ new_embedding.append(embedding[s_idx:e_idx])
391
+
392
+ # check that the span of ids to be replaced is valid
393
+ actually_ids_to_replace = [int(i) for i in input_ids[e_idx : e_idx + end - start]]
394
+ assert actually_ids_to_replace == target_ids_to_replace, (
395
+ f"Invalid 'input_ids' in position: {s_idx} to {e_idx}. "
396
+ f"Expect '{target_ids_to_replace}' for embedding "
397
+ f"'{name}' but found '{actually_ids_to_replace}'."
398
+ )
399
+
400
+ new_embedding.append(ext_emb)
401
+
402
+ s_idx = e_idx + end - start
403
+ e_idx = s_idx + 1
404
+ else:
405
+ e_idx += 1
406
+
407
+ if e_idx == len(input_ids):
408
+ new_embedding.append(embedding[s_idx:e_idx])
409
+
410
+ return torch.cat(new_embedding, dim=0)
411
+
412
+ def forward(self, input_ids: torch.Tensor, external_embeddings: Optional[List[dict]] = None):
413
+ """The forward function.
414
+
415
+ Args:
416
+ input_ids (torch.Tensor): The token ids, with shape [bz, LENGTH] or
417
+ [LENGTH, ].
418
+ external_embeddings (Optional[List[dict]]): The external
419
+ embeddings. If not passed, only `self.external_embeddings`
420
+ will be used. Defaults to None.
421
+
422
+ Returns:
+ torch.Tensor: The output embeddings, with placeholder positions replaced by the external embeddings.
423
+ """
424
+ assert input_ids.ndim in [1, 2]
425
+ if input_ids.ndim == 1:
426
+ input_ids = input_ids.unsqueeze(0)
427
+
428
+ if external_embeddings is None and not self.external_embeddings:
429
+ return self.wrapped(input_ids)
430
+
431
+ input_ids_fwd = self.replace_input_ids(input_ids)
432
+ inputs_embeds = self.wrapped(input_ids_fwd)
433
+
434
+ vecs = []
435
+
436
+ if external_embeddings is None:
437
+ external_embeddings = []
438
+ elif isinstance(external_embeddings, dict):
439
+ external_embeddings = [external_embeddings]
440
+ embeddings = self.external_embeddings + external_embeddings
441
+
442
+ for input_id, embedding in zip(input_ids, inputs_embeds):
443
+ new_embedding = embedding
444
+ for external_embedding in embeddings:
445
+ new_embedding = self.replace_embeddings(input_id, new_embedding, external_embedding)
446
+ vecs.append(new_embedding)
447
+
448
+ return torch.stack(vecs)
449
+
450
+
451
+
452
+ def add_tokens(
453
+ tokenizer, text_encoder, placeholder_tokens: list, initialize_tokens: list = None, num_vectors_per_token: int = 1
454
+ ):
455
+ """Add token for training.
456
+
457
+ # TODO: support add tokens as dict, then we can load pretrained tokens.
458
+ """
459
+ if initialize_tokens is not None:
460
+ assert len(initialize_tokens) == len(
461
+ placeholder_tokens
462
+ ), "placeholder_token should be the same length as initialize_token"
463
+ for ii in range(len(placeholder_tokens)):
464
+ tokenizer.add_placeholder_token(placeholder_tokens[ii], num_vec_per_token=num_vectors_per_token)
465
+
466
+ # text_encoder.set_embedding_layer()
467
+ embedding_layer = text_encoder.text_model.embeddings.token_embedding
468
+ text_encoder.text_model.embeddings.token_embedding = EmbeddingLayerWithFixes(embedding_layer)
469
+ embedding_layer = text_encoder.text_model.embeddings.token_embedding
470
+
471
+ assert embedding_layer is not None, (
472
+ "Do not support get embedding layer for current text encoder. " "Please check your configuration."
473
+ )
474
+ initialize_embedding = []
475
+ if initialize_tokens is not None:
476
+ for ii in range(len(placeholder_tokens)):
477
+ init_id = tokenizer(initialize_tokens[ii]).input_ids[1]
478
+ temp_embedding = embedding_layer.weight[init_id]
479
+ initialize_embedding.append(temp_embedding[None, ...].repeat(num_vectors_per_token, 1))
480
+ else:
481
+ for ii in range(len(placeholder_tokens)):
482
+ init_id = tokenizer("a").input_ids[1]
483
+ temp_embedding = embedding_layer.weight[init_id]
484
+ len_emb = temp_embedding.shape[0]
485
+ init_weight = (torch.rand(num_vectors_per_token, len_emb) - 0.5) / 2.0
486
+ initialize_embedding.append(init_weight)
487
+
488
+ # initialize_embedding = torch.cat(initialize_embedding,dim=0)
489
+
490
+ token_info_all = []
491
+ for ii in range(len(placeholder_tokens)):
492
+ token_info = tokenizer.get_token_info(placeholder_tokens[ii])
493
+ token_info["embedding"] = initialize_embedding[ii]
494
+ token_info["trainable"] = True
495
+ token_info_all.append(token_info)
496
+ embedding_layer.add_embeddings(token_info_all)
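
A minimal usage sketch for the helpers above (illustrative only, not part of the uploaded file). It assumes the repository root is on PYTHONPATH, that `TokenizerWrapper` accepts a pretrained CLIP tokenizer id as in the docstring example earlier in this file, and that the text encoder is a `transformers.CLIPTextModel`; the placeholder name `P_obj`, the init word `object`, and the import path are assumptions.

# Hypothetical usage sketch; names, checkpoint id, and import path below are assumptions.
import torch
from transformers import CLIPTextModel
from MagicQuill.brushnet.powerpaint_utils import TokenizerWrapper, add_tokens

# Wrap the CLIP tokenizer so placeholder tokens can be registered.
tokenizer = TokenizerWrapper("openai/clip-vit-base-patch32")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

# Register one placeholder backed by 4 trainable vectors, initialized from the embedding of "object".
add_tokens(
    tokenizer,
    text_encoder,
    placeholder_tokens=["P_obj"],   # hypothetical placeholder name
    initialize_tokens=["object"],   # hypothetical init word
    num_vectors_per_token=4,
)

# The wrapped embedding layer now injects the trainable vectors wherever the placeholder appears.
input_ids = tokenizer(
    "a photo of P_obj on a table",
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)["input_ids"]
with torch.no_grad():
    embeds = text_encoder.text_model.embeddings.token_embedding(input_ids)
print(embeds.shape)  # expected [batch, seq_len, hidden_size], e.g. (1, 77, 512) for this checkpoint

After this, prompts containing the placeholder are expanded into its per-vector sub-tokens by the wrapper, and the vectors registered in `trainable_embeddings` can be optimized while the rest of the text encoder stays frozen.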
MagicQuill/brushnet/unet_2d_blocks.py ADDED
The diff for this file is too large to render. See raw diff
 
MagicQuill/brushnet/unet_2d_condition.py ADDED
@@ -0,0 +1,1355 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.checkpoint
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
23
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
24
+ from diffusers.models.activations import get_activation
25
+ from diffusers.models.attention_processor import (
26
+ ADDED_KV_ATTENTION_PROCESSORS,
27
+ CROSS_ATTENTION_PROCESSORS,
28
+ Attention,
29
+ AttentionProcessor,
30
+ AttnAddedKVProcessor,
31
+ AttnProcessor,
32
+ )
33
+ from diffusers.models.embeddings import (
34
+ GaussianFourierProjection,
35
+ GLIGENTextBoundingboxProjection,
36
+ ImageHintTimeEmbedding,
37
+ ImageProjection,
38
+ ImageTimeEmbedding,
39
+ TextImageProjection,
40
+ TextImageTimeEmbedding,
41
+ TextTimeEmbedding,
42
+ TimestepEmbedding,
43
+ Timesteps,
44
+ )
45
+ from diffusers.models.modeling_utils import ModelMixin
46
+ from .unet_2d_blocks import (
47
+ get_down_block,
48
+ get_mid_block,
49
+ get_up_block,
50
+ )
51
+
52
+
53
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
54
+
55
+
56
+ @dataclass
57
+ class UNet2DConditionOutput(BaseOutput):
58
+ """
59
+ The output of [`UNet2DConditionModel`].
60
+
61
+ Args:
62
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
63
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
64
+ """
65
+
66
+ sample: torch.FloatTensor = None
67
+
68
+
69
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
70
+ r"""
71
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
72
+ shaped output.
73
+
74
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
75
+ for all models (such as downloading or saving).
76
+
77
+ Parameters:
78
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
79
+ Height and width of input/output sample.
80
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
81
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
82
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
83
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
84
+ Whether to flip the sin to cos in the time embedding.
85
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
86
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
87
+ The tuple of downsample blocks to use.
88
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
89
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
90
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
91
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
92
+ The tuple of upsample blocks to use.
93
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
94
+ Whether to include self-attention in the basic transformer blocks, see
95
+ [`~models.attention.BasicTransformerBlock`].
96
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
97
+ The tuple of output channels for each block.
98
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
99
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
100
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
101
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
102
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
103
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
104
+ If `None`, normalization and activation layers are skipped in post-processing.
105
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
106
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
107
+ The dimension of the cross attention features.
108
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
109
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
110
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
111
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
112
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
113
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
114
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
115
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
116
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
117
+ encoder_hid_dim (`int`, *optional*, defaults to None):
118
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
119
+ dimension to `cross_attention_dim`.
120
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
121
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
122
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
123
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
124
+ num_attention_heads (`int`, *optional*):
125
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
126
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
127
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
128
+ class_embed_type (`str`, *optional*, defaults to `None`):
129
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
130
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
131
+ addition_embed_type (`str`, *optional*, defaults to `None`):
132
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
133
+ "text". "text" will use the `TextTimeEmbedding` layer.
134
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
135
+ Dimension for the timestep embeddings.
136
+ num_class_embeds (`int`, *optional*, defaults to `None`):
137
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
138
+ class conditioning with `class_embed_type` equal to `None`.
139
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
140
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
141
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
142
+ An optional override for the dimension of the projected time embedding.
143
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
144
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
145
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
146
+ timestep_post_act (`str`, *optional*, defaults to `None`):
147
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
148
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
149
+ The dimension of `cond_proj` layer in the timestep embedding.
150
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
151
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
152
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
153
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
154
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
155
+ embeddings with the class embeddings.
156
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
157
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
158
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
159
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Defaults to `False`
160
+ otherwise.
161
+ """
162
+
163
+ _supports_gradient_checkpointing = True
164
+
165
+ @register_to_config
166
+ def __init__(
167
+ self,
168
+ sample_size: Optional[int] = None,
169
+ in_channels: int = 4,
170
+ out_channels: int = 4,
171
+ center_input_sample: bool = False,
172
+ flip_sin_to_cos: bool = True,
173
+ freq_shift: int = 0,
174
+ down_block_types: Tuple[str] = (
175
+ "CrossAttnDownBlock2D",
176
+ "CrossAttnDownBlock2D",
177
+ "CrossAttnDownBlock2D",
178
+ "DownBlock2D",
179
+ ),
180
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
181
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
182
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
183
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
184
+ layers_per_block: Union[int, Tuple[int]] = 2,
185
+ downsample_padding: int = 1,
186
+ mid_block_scale_factor: float = 1,
187
+ dropout: float = 0.0,
188
+ act_fn: str = "silu",
189
+ norm_num_groups: Optional[int] = 32,
190
+ norm_eps: float = 1e-5,
191
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
192
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
193
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
194
+ encoder_hid_dim: Optional[int] = None,
195
+ encoder_hid_dim_type: Optional[str] = None,
196
+ attention_head_dim: Union[int, Tuple[int]] = 8,
197
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
198
+ dual_cross_attention: bool = False,
199
+ use_linear_projection: bool = False,
200
+ class_embed_type: Optional[str] = None,
201
+ addition_embed_type: Optional[str] = None,
202
+ addition_time_embed_dim: Optional[int] = None,
203
+ num_class_embeds: Optional[int] = None,
204
+ upcast_attention: bool = False,
205
+ resnet_time_scale_shift: str = "default",
206
+ resnet_skip_time_act: bool = False,
207
+ resnet_out_scale_factor: float = 1.0,
208
+ time_embedding_type: str = "positional",
209
+ time_embedding_dim: Optional[int] = None,
210
+ time_embedding_act_fn: Optional[str] = None,
211
+ timestep_post_act: Optional[str] = None,
212
+ time_cond_proj_dim: Optional[int] = None,
213
+ conv_in_kernel: int = 3,
214
+ conv_out_kernel: int = 3,
215
+ projection_class_embeddings_input_dim: Optional[int] = None,
216
+ attention_type: str = "default",
217
+ class_embeddings_concat: bool = False,
218
+ mid_block_only_cross_attention: Optional[bool] = None,
219
+ cross_attention_norm: Optional[str] = None,
220
+ addition_embed_type_num_heads: int = 64,
221
+ ):
222
+ super().__init__()
223
+
224
+ self.sample_size = sample_size
225
+
226
+ if num_attention_heads is not None:
227
+ raise ValueError(
228
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
229
+ )
230
+
231
+ # If `num_attention_heads` is not defined (which is the case for most models)
232
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
233
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
234
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
235
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
236
+ # which is why we correct for the naming here.
237
+ num_attention_heads = num_attention_heads or attention_head_dim
238
+
239
+ # Check inputs
240
+ self._check_config(
241
+ down_block_types=down_block_types,
242
+ up_block_types=up_block_types,
243
+ only_cross_attention=only_cross_attention,
244
+ block_out_channels=block_out_channels,
245
+ layers_per_block=layers_per_block,
246
+ cross_attention_dim=cross_attention_dim,
247
+ transformer_layers_per_block=transformer_layers_per_block,
248
+ reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
249
+ attention_head_dim=attention_head_dim,
250
+ num_attention_heads=num_attention_heads,
251
+ )
252
+
253
+ # input
254
+ conv_in_padding = (conv_in_kernel - 1) // 2
255
+ self.conv_in = nn.Conv2d(
256
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
257
+ )
258
+
259
+ # time
260
+ time_embed_dim, timestep_input_dim = self._set_time_proj(
261
+ time_embedding_type,
262
+ block_out_channels=block_out_channels,
263
+ flip_sin_to_cos=flip_sin_to_cos,
264
+ freq_shift=freq_shift,
265
+ time_embedding_dim=time_embedding_dim,
266
+ )
267
+
268
+ self.time_embedding = TimestepEmbedding(
269
+ timestep_input_dim,
270
+ time_embed_dim,
271
+ act_fn=act_fn,
272
+ post_act_fn=timestep_post_act,
273
+ cond_proj_dim=time_cond_proj_dim,
274
+ )
275
+
276
+ self._set_encoder_hid_proj(
277
+ encoder_hid_dim_type,
278
+ cross_attention_dim=cross_attention_dim,
279
+ encoder_hid_dim=encoder_hid_dim,
280
+ )
281
+
282
+ # class embedding
283
+ self._set_class_embedding(
284
+ class_embed_type,
285
+ act_fn=act_fn,
286
+ num_class_embeds=num_class_embeds,
287
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
288
+ time_embed_dim=time_embed_dim,
289
+ timestep_input_dim=timestep_input_dim,
290
+ )
291
+
292
+ self._set_add_embedding(
293
+ addition_embed_type,
294
+ addition_embed_type_num_heads=addition_embed_type_num_heads,
295
+ addition_time_embed_dim=addition_time_embed_dim,
296
+ cross_attention_dim=cross_attention_dim,
297
+ encoder_hid_dim=encoder_hid_dim,
298
+ flip_sin_to_cos=flip_sin_to_cos,
299
+ freq_shift=freq_shift,
300
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
301
+ time_embed_dim=time_embed_dim,
302
+ )
303
+
304
+ if time_embedding_act_fn is None:
305
+ self.time_embed_act = None
306
+ else:
307
+ self.time_embed_act = get_activation(time_embedding_act_fn)
308
+
309
+ self.down_blocks = nn.ModuleList([])
310
+ self.up_blocks = nn.ModuleList([])
311
+
312
+ if isinstance(only_cross_attention, bool):
313
+ if mid_block_only_cross_attention is None:
314
+ mid_block_only_cross_attention = only_cross_attention
315
+
316
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
317
+
318
+ if mid_block_only_cross_attention is None:
319
+ mid_block_only_cross_attention = False
320
+
321
+ if isinstance(num_attention_heads, int):
322
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
323
+
324
+ if isinstance(attention_head_dim, int):
325
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
326
+
327
+ if isinstance(cross_attention_dim, int):
328
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
329
+
330
+ if isinstance(layers_per_block, int):
331
+ layers_per_block = [layers_per_block] * len(down_block_types)
332
+
333
+ if isinstance(transformer_layers_per_block, int):
334
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
335
+
336
+ if class_embeddings_concat:
337
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
338
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
339
+ # regular time embeddings
340
+ blocks_time_embed_dim = time_embed_dim * 2
341
+ else:
342
+ blocks_time_embed_dim = time_embed_dim
343
+
344
+ # down
345
+ output_channel = block_out_channels[0]
346
+ for i, down_block_type in enumerate(down_block_types):
347
+ input_channel = output_channel
348
+ output_channel = block_out_channels[i]
349
+ is_final_block = i == len(block_out_channels) - 1
350
+
351
+ down_block = get_down_block(
352
+ down_block_type,
353
+ num_layers=layers_per_block[i],
354
+ transformer_layers_per_block=transformer_layers_per_block[i],
355
+ in_channels=input_channel,
356
+ out_channels=output_channel,
357
+ temb_channels=blocks_time_embed_dim,
358
+ add_downsample=not is_final_block,
359
+ resnet_eps=norm_eps,
360
+ resnet_act_fn=act_fn,
361
+ resnet_groups=norm_num_groups,
362
+ cross_attention_dim=cross_attention_dim[i],
363
+ num_attention_heads=num_attention_heads[i],
364
+ downsample_padding=downsample_padding,
365
+ dual_cross_attention=dual_cross_attention,
366
+ use_linear_projection=use_linear_projection,
367
+ only_cross_attention=only_cross_attention[i],
368
+ upcast_attention=upcast_attention,
369
+ resnet_time_scale_shift=resnet_time_scale_shift,
370
+ attention_type=attention_type,
371
+ resnet_skip_time_act=resnet_skip_time_act,
372
+ resnet_out_scale_factor=resnet_out_scale_factor,
373
+ cross_attention_norm=cross_attention_norm,
374
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
375
+ dropout=dropout,
376
+ )
377
+ self.down_blocks.append(down_block)
378
+
379
+ # mid
380
+ self.mid_block = get_mid_block(
381
+ mid_block_type,
382
+ temb_channels=blocks_time_embed_dim,
383
+ in_channels=block_out_channels[-1],
384
+ resnet_eps=norm_eps,
385
+ resnet_act_fn=act_fn,
386
+ resnet_groups=norm_num_groups,
387
+ output_scale_factor=mid_block_scale_factor,
388
+ transformer_layers_per_block=transformer_layers_per_block[-1],
389
+ num_attention_heads=num_attention_heads[-1],
390
+ cross_attention_dim=cross_attention_dim[-1],
391
+ dual_cross_attention=dual_cross_attention,
392
+ use_linear_projection=use_linear_projection,
393
+ mid_block_only_cross_attention=mid_block_only_cross_attention,
394
+ upcast_attention=upcast_attention,
395
+ resnet_time_scale_shift=resnet_time_scale_shift,
396
+ attention_type=attention_type,
397
+ resnet_skip_time_act=resnet_skip_time_act,
398
+ cross_attention_norm=cross_attention_norm,
399
+ attention_head_dim=attention_head_dim[-1],
400
+ dropout=dropout,
401
+ )
402
+
403
+ # count how many layers upsample the images
404
+ self.num_upsamplers = 0
405
+
406
+ # up
407
+ reversed_block_out_channels = list(reversed(block_out_channels))
408
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
409
+ reversed_layers_per_block = list(reversed(layers_per_block))
410
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
411
+ reversed_transformer_layers_per_block = (
412
+ list(reversed(transformer_layers_per_block))
413
+ if reverse_transformer_layers_per_block is None
414
+ else reverse_transformer_layers_per_block
415
+ )
416
+ only_cross_attention = list(reversed(only_cross_attention))
417
+
418
+ output_channel = reversed_block_out_channels[0]
419
+ for i, up_block_type in enumerate(up_block_types):
420
+ is_final_block = i == len(block_out_channels) - 1
421
+
422
+ prev_output_channel = output_channel
423
+ output_channel = reversed_block_out_channels[i]
424
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
425
+
426
+ # add upsample block for all BUT final layer
427
+ if not is_final_block:
428
+ add_upsample = True
429
+ self.num_upsamplers += 1
430
+ else:
431
+ add_upsample = False
432
+
433
+ up_block = get_up_block(
434
+ up_block_type,
435
+ num_layers=reversed_layers_per_block[i] + 1,
436
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
437
+ in_channels=input_channel,
438
+ out_channels=output_channel,
439
+ prev_output_channel=prev_output_channel,
440
+ temb_channels=blocks_time_embed_dim,
441
+ add_upsample=add_upsample,
442
+ resnet_eps=norm_eps,
443
+ resnet_act_fn=act_fn,
444
+ resolution_idx=i,
445
+ resnet_groups=norm_num_groups,
446
+ cross_attention_dim=reversed_cross_attention_dim[i],
447
+ num_attention_heads=reversed_num_attention_heads[i],
448
+ dual_cross_attention=dual_cross_attention,
449
+ use_linear_projection=use_linear_projection,
450
+ only_cross_attention=only_cross_attention[i],
451
+ upcast_attention=upcast_attention,
452
+ resnet_time_scale_shift=resnet_time_scale_shift,
453
+ attention_type=attention_type,
454
+ resnet_skip_time_act=resnet_skip_time_act,
455
+ resnet_out_scale_factor=resnet_out_scale_factor,
456
+ cross_attention_norm=cross_attention_norm,
457
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
458
+ dropout=dropout,
459
+ )
460
+ self.up_blocks.append(up_block)
461
+ prev_output_channel = output_channel
462
+
463
+ # out
464
+ if norm_num_groups is not None:
465
+ self.conv_norm_out = nn.GroupNorm(
466
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
467
+ )
468
+
469
+ self.conv_act = get_activation(act_fn)
470
+
471
+ else:
472
+ self.conv_norm_out = None
473
+ self.conv_act = None
474
+
475
+ conv_out_padding = (conv_out_kernel - 1) // 2
476
+ self.conv_out = nn.Conv2d(
477
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
478
+ )
479
+
480
+ self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
481
+
482
+ def _check_config(
483
+ self,
484
+ down_block_types: Tuple[str],
485
+ up_block_types: Tuple[str],
486
+ only_cross_attention: Union[bool, Tuple[bool]],
487
+ block_out_channels: Tuple[int],
488
+ layers_per_block: Union[int, Tuple[int]],
489
+ cross_attention_dim: Union[int, Tuple[int]],
490
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
491
+ reverse_transformer_layers_per_block: bool,
492
+ attention_head_dim: int,
493
+ num_attention_heads: Optional[Union[int, Tuple[int]]],
494
+ ):
495
+ if len(down_block_types) != len(up_block_types):
496
+ raise ValueError(
497
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
498
+ )
499
+
500
+ if len(block_out_channels) != len(down_block_types):
501
+ raise ValueError(
502
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
503
+ )
504
+
505
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
506
+ raise ValueError(
507
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
508
+ )
509
+
510
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
511
+ raise ValueError(
512
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
513
+ )
514
+
515
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
516
+ raise ValueError(
517
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
518
+ )
519
+
520
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
521
+ raise ValueError(
522
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
523
+ )
524
+
525
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
526
+ raise ValueError(
527
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
528
+ )
529
+ if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
530
+ for layer_number_per_block in transformer_layers_per_block:
531
+ if isinstance(layer_number_per_block, list):
532
+ raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
533
+
534
+ def _set_time_proj(
535
+ self,
536
+ time_embedding_type: str,
537
+ block_out_channels: int,
538
+ flip_sin_to_cos: bool,
539
+ freq_shift: float,
540
+ time_embedding_dim: int,
541
+ ) -> Tuple[int, int]:
542
+ if time_embedding_type == "fourier":
543
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
544
+ if time_embed_dim % 2 != 0:
545
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
546
+ self.time_proj = GaussianFourierProjection(
547
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
548
+ )
549
+ timestep_input_dim = time_embed_dim
550
+ elif time_embedding_type == "positional":
551
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
552
+
553
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
554
+ timestep_input_dim = block_out_channels[0]
555
+ else:
556
+ raise ValueError(
557
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
558
+ )
559
+
560
+ return time_embed_dim, timestep_input_dim
561
+
562
+ def _set_encoder_hid_proj(
563
+ self,
564
+ encoder_hid_dim_type: Optional[str],
565
+ cross_attention_dim: Union[int, Tuple[int]],
566
+ encoder_hid_dim: Optional[int],
567
+ ):
568
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
569
+ encoder_hid_dim_type = "text_proj"
570
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
571
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
572
+
573
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
574
+ raise ValueError(
575
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
576
+ )
577
+
578
+ if encoder_hid_dim_type == "text_proj":
579
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
580
+ elif encoder_hid_dim_type == "text_image_proj":
581
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
582
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
583
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
584
+ self.encoder_hid_proj = TextImageProjection(
585
+ text_embed_dim=encoder_hid_dim,
586
+ image_embed_dim=cross_attention_dim,
587
+ cross_attention_dim=cross_attention_dim,
588
+ )
589
+ elif encoder_hid_dim_type == "image_proj":
590
+ # Kandinsky 2.2
591
+ self.encoder_hid_proj = ImageProjection(
592
+ image_embed_dim=encoder_hid_dim,
593
+ cross_attention_dim=cross_attention_dim,
594
+ )
595
+ elif encoder_hid_dim_type is not None:
596
+ raise ValueError(
597
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
598
+ )
599
+ else:
600
+ self.encoder_hid_proj = None
601
+
602
+ def _set_class_embedding(
603
+ self,
604
+ class_embed_type: Optional[str],
605
+ act_fn: str,
606
+ num_class_embeds: Optional[int],
607
+ projection_class_embeddings_input_dim: Optional[int],
608
+ time_embed_dim: int,
609
+ timestep_input_dim: int,
610
+ ):
611
+ if class_embed_type is None and num_class_embeds is not None:
612
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
613
+ elif class_embed_type == "timestep":
614
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
615
+ elif class_embed_type == "identity":
616
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
617
+ elif class_embed_type == "projection":
618
+ if projection_class_embeddings_input_dim is None:
619
+ raise ValueError(
620
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
621
+ )
622
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
623
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
624
+ # 2. it projects from an arbitrary input dimension.
625
+ #
626
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
627
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
628
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
629
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
630
+ elif class_embed_type == "simple_projection":
631
+ if projection_class_embeddings_input_dim is None:
632
+ raise ValueError(
633
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
634
+ )
635
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
636
+ else:
637
+ self.class_embedding = None
638
+
639
+ def _set_add_embedding(
640
+ self,
641
+ addition_embed_type: str,
642
+ addition_embed_type_num_heads: int,
643
+ addition_time_embed_dim: Optional[int],
644
+ flip_sin_to_cos: bool,
645
+ freq_shift: float,
646
+ cross_attention_dim: Optional[int],
647
+ encoder_hid_dim: Optional[int],
648
+ projection_class_embeddings_input_dim: Optional[int],
649
+ time_embed_dim: int,
650
+ ):
651
+ if addition_embed_type == "text":
652
+ if encoder_hid_dim is not None:
653
+ text_time_embedding_from_dim = encoder_hid_dim
654
+ else:
655
+ text_time_embedding_from_dim = cross_attention_dim
656
+
657
+ self.add_embedding = TextTimeEmbedding(
658
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
659
+ )
660
+ elif addition_embed_type == "text_image":
661
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
662
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
663
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
664
+ self.add_embedding = TextImageTimeEmbedding(
665
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
666
+ )
667
+ elif addition_embed_type == "text_time":
668
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
669
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
670
+ elif addition_embed_type == "image":
671
+ # Kandinsky 2.2
672
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
673
+ elif addition_embed_type == "image_hint":
674
+ # Kandinsky 2.2 ControlNet
675
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
676
+ elif addition_embed_type is not None:
677
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
678
+
679
+ def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
680
+ if attention_type in ["gated", "gated-text-image"]:
681
+ positive_len = 768
682
+ if isinstance(cross_attention_dim, int):
683
+ positive_len = cross_attention_dim
684
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
685
+ positive_len = cross_attention_dim[0]
686
+
687
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
688
+ self.position_net = GLIGENTextBoundingboxProjection(
689
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
690
+ )
691
+
692
+ @property
693
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
694
+ r"""
695
+ Returns:
696
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
697
+ indexed by their weight names.
698
+ """
699
+ # set recursively
700
+ processors = {}
701
+
702
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
703
+ if hasattr(module, "get_processor"):
704
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
705
+
706
+ for sub_name, child in module.named_children():
707
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
708
+
709
+ return processors
710
+
711
+ for name, module in self.named_children():
712
+ fn_recursive_add_processors(name, module, processors)
713
+
714
+ return processors
715
+
716
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
717
+ r"""
718
+ Sets the attention processor to use to compute attention.
719
+
720
+ Parameters:
721
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
722
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
723
+ for **all** `Attention` layers.
724
+
725
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
726
+ processor. This is strongly recommended when setting trainable attention processors.
727
+
728
+ """
729
+ count = len(self.attn_processors.keys())
730
+
731
+ if isinstance(processor, dict) and len(processor) != count:
732
+ raise ValueError(
733
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
734
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
735
+ )
736
+
737
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
738
+ if hasattr(module, "set_processor"):
739
+ if not isinstance(processor, dict):
740
+ module.set_processor(processor)
741
+ else:
742
+ module.set_processor(processor.pop(f"{name}.processor"))
743
+
744
+ for sub_name, child in module.named_children():
745
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
746
+
747
+ for name, module in self.named_children():
748
+ fn_recursive_attn_processor(name, module, processor)
749
+
750
+ def set_default_attn_processor(self):
751
+ """
752
+ Disables custom attention processors and sets the default attention implementation.
753
+ """
754
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
755
+ processor = AttnAddedKVProcessor()
756
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
757
+ processor = AttnProcessor()
758
+ else:
759
+ raise ValueError(
760
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
761
+ )
762
+
763
+ self.set_attn_processor(processor)
764
+
765
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
766
+ r"""
767
+ Enable sliced attention computation.
768
+
769
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
770
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
771
+
772
+ Args:
773
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
774
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
775
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
776
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
777
+ must be a multiple of `slice_size`.
778
+ """
779
+ sliceable_head_dims = []
780
+
781
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
782
+ if hasattr(module, "set_attention_slice"):
783
+ sliceable_head_dims.append(module.sliceable_head_dim)
784
+
785
+ for child in module.children():
786
+ fn_recursive_retrieve_sliceable_dims(child)
787
+
788
+ # retrieve number of attention layers
789
+ for module in self.children():
790
+ fn_recursive_retrieve_sliceable_dims(module)
791
+
792
+ num_sliceable_layers = len(sliceable_head_dims)
793
+
794
+ if slice_size == "auto":
795
+ # half the attention head size is usually a good trade-off between
796
+ # speed and memory
797
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
798
+ elif slice_size == "max":
799
+ # make smallest slice possible
800
+ slice_size = num_sliceable_layers * [1]
801
+
802
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
803
+
804
+ if len(slice_size) != len(sliceable_head_dims):
805
+ raise ValueError(
806
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
807
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
808
+ )
809
+
810
+ for i in range(len(slice_size)):
811
+ size = slice_size[i]
812
+ dim = sliceable_head_dims[i]
813
+ if size is not None and size > dim:
814
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
815
+
816
+ # Recursively walk through all the children.
817
+ # Any children which exposes the set_attention_slice method
818
+ # gets the message
819
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
820
+ if hasattr(module, "set_attention_slice"):
821
+ module.set_attention_slice(slice_size.pop())
822
+
823
+ for child in module.children():
824
+ fn_recursive_set_attention_slice(child, slice_size)
825
+
826
+ reversed_slice_size = list(reversed(slice_size))
827
+ for module in self.children():
828
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
829
+
830
+ def _set_gradient_checkpointing(self, module, value=False):
831
+ if hasattr(module, "gradient_checkpointing"):
832
+ module.gradient_checkpointing = value
833
+
834
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
835
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
836
+
837
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
838
+
839
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
840
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
841
+
842
+ Args:
843
+ s1 (`float`):
844
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
845
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
846
+ s2 (`float`):
847
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
848
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
849
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
850
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
851
+ """
852
+ for i, upsample_block in enumerate(self.up_blocks):
853
+ setattr(upsample_block, "s1", s1)
854
+ setattr(upsample_block, "s2", s2)
855
+ setattr(upsample_block, "b1", b1)
856
+ setattr(upsample_block, "b2", b2)
857
+
858
+ def disable_freeu(self):
859
+ """Disables the FreeU mechanism."""
860
+ freeu_keys = {"s1", "s2", "b1", "b2"}
861
+ for i, upsample_block in enumerate(self.up_blocks):
862
+ for k in freeu_keys:
863
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
864
+ setattr(upsample_block, k, None)
865
+
866
+ def fuse_qkv_projections(self):
867
+ """
868
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
869
+ are fused. For cross-attention modules, key and value projection matrices are fused.
870
+
871
+ <Tip warning={true}>
872
+
873
+ This API is 🧪 experimental.
874
+
875
+ </Tip>
876
+ """
877
+ self.original_attn_processors = None
878
+
879
+ for _, attn_processor in self.attn_processors.items():
880
+ if "Added" in str(attn_processor.__class__.__name__):
881
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
882
+
883
+ self.original_attn_processors = self.attn_processors
884
+
885
+ for module in self.modules():
886
+ if isinstance(module, Attention):
887
+ module.fuse_projections(fuse=True)
888
+
889
+ def unfuse_qkv_projections(self):
890
+ """Disables the fused QKV projection if enabled.
891
+
892
+ <Tip warning={true}>
893
+
894
+ This API is 🧪 experimental.
895
+
896
+ </Tip>
897
+
898
+ """
899
+ if self.original_attn_processors is not None:
900
+ self.set_attn_processor(self.original_attn_processors)
901
+
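Because the original processors are stashed in `original_attn_processors`, fusion is reversible, so a scoped pattern is natural; a minimal sketch with a hypothetical `unet` instance:

from contextlib import contextmanager

@contextmanager
def fused_qkv(unet):
    # Fuse QKV projections for the duration of a block of inference calls, then restore.
    unet.fuse_qkv_projections()
    try:
        yield unet
    finally:
        unet.unfuse_qkv_projections()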
902
+ def unload_lora(self):
903
+ """Unloads LoRA weights."""
904
+ deprecate(
905
+ "unload_lora",
906
+ "0.28.0",
907
+ "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
908
+ )
909
+ for module in self.modules():
910
+ if hasattr(module, "set_lora_layer"):
911
+ module.set_lora_layer(None)
912
+
913
+ def get_time_embed(
914
+ self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
915
+ ) -> Optional[torch.Tensor]:
916
+ timesteps = timestep
917
+ if not torch.is_tensor(timesteps):
918
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
919
+ # This would be a good case for the `match` statement (Python 3.10+)
920
+ is_mps = sample.device.type == "mps"
921
+ if isinstance(timestep, float):
922
+ dtype = torch.float32 if is_mps else torch.float64
923
+ else:
924
+ dtype = torch.int32 if is_mps else torch.int64
925
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
926
+ elif len(timesteps.shape) == 0:
927
+ timesteps = timesteps[None].to(sample.device)
928
+
929
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
930
+ timesteps = timesteps.expand(sample.shape[0])
931
+
932
+ t_emb = self.time_proj(timesteps)
933
+ # `Timesteps` does not contain any weights and will always return f32 tensors
934
+ # but time_embedding might actually be running in fp16. so we need to cast here.
935
+ # there might be better ways to encapsulate this.
936
+ t_emb = t_emb.to(dtype=sample.dtype)
937
+ return t_emb
938
+
939
+ def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
940
+ class_emb = None
941
+ if self.class_embedding is not None:
942
+ if class_labels is None:
943
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
944
+
945
+ if self.config.class_embed_type == "timestep":
946
+ class_labels = self.time_proj(class_labels)
947
+
948
+ # `Timesteps` does not contain any weights and will always return f32 tensors
949
+ # there might be better ways to encapsulate this.
950
+ class_labels = class_labels.to(dtype=sample.dtype)
951
+
952
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
953
+ return class_emb
954
+
955
+ def get_aug_embed(
956
+ self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
957
+ ) -> Optional[torch.Tensor]:
958
+ aug_emb = None
959
+ if self.config.addition_embed_type == "text":
960
+ aug_emb = self.add_embedding(encoder_hidden_states)
961
+ elif self.config.addition_embed_type == "text_image":
962
+ # Kandinsky 2.1 - style
963
+ if "image_embeds" not in added_cond_kwargs:
964
+ raise ValueError(
965
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
966
+ )
967
+
968
+ image_embs = added_cond_kwargs.get("image_embeds")
969
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
970
+ aug_emb = self.add_embedding(text_embs, image_embs)
971
+ elif self.config.addition_embed_type == "text_time":
972
+ # SDXL - style
973
+ if "text_embeds" not in added_cond_kwargs:
974
+ raise ValueError(
975
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
976
+ )
977
+ text_embeds = added_cond_kwargs.get("text_embeds")
978
+ if "time_ids" not in added_cond_kwargs:
979
+ raise ValueError(
980
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
981
+ )
982
+ time_ids = added_cond_kwargs.get("time_ids")
983
+ time_embeds = self.add_time_proj(time_ids.flatten())
984
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
985
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
986
+ add_embeds = add_embeds.to(emb.dtype)
987
+ aug_emb = self.add_embedding(add_embeds)
988
+ elif self.config.addition_embed_type == "image":
989
+ # Kandinsky 2.2 - style
990
+ if "image_embeds" not in added_cond_kwargs:
991
+ raise ValueError(
992
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
993
+ )
994
+ image_embs = added_cond_kwargs.get("image_embeds")
995
+ aug_emb = self.add_embedding(image_embs)
996
+ elif self.config.addition_embed_type == "image_hint":
997
+ # Kandinsky 2.2 - style
998
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
999
+ raise ValueError(
1000
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1001
+ )
1002
+ image_embs = added_cond_kwargs.get("image_embeds")
1003
+ hint = added_cond_kwargs.get("hint")
1004
+ aug_emb = self.add_embedding(image_embs, hint)
1005
+ return aug_emb
1006
+
1007
+ def process_encoder_hidden_states(
1008
+ self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
1009
+ ) -> torch.Tensor:
1010
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1011
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1012
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1013
+ # Kandinsky 2.1 - style
1014
+ if "image_embeds" not in added_cond_kwargs:
1015
+ raise ValueError(
1016
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1017
+ )
1018
+
1019
+ image_embeds = added_cond_kwargs.get("image_embeds")
1020
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1021
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1022
+ # Kandinsky 2.2 - style
1023
+ if "image_embeds" not in added_cond_kwargs:
1024
+ raise ValueError(
1025
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1026
+ )
1027
+ image_embeds = added_cond_kwargs.get("image_embeds")
1028
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1029
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
1030
+ if "image_embeds" not in added_cond_kwargs:
1031
+ raise ValueError(
1032
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1033
+ )
1034
+ image_embeds = added_cond_kwargs.get("image_embeds")
1035
+ image_embeds = self.encoder_hid_proj(image_embeds)
1036
+ encoder_hidden_states = (encoder_hidden_states, image_embeds)
1037
+ return encoder_hidden_states
1038
+
1039
+ def forward(
1040
+ self,
1041
+ sample: torch.FloatTensor,
1042
+ timestep: Union[torch.Tensor, float, int],
1043
+ encoder_hidden_states: torch.Tensor,
1044
+ class_labels: Optional[torch.Tensor] = None,
1045
+ timestep_cond: Optional[torch.Tensor] = None,
1046
+ attention_mask: Optional[torch.Tensor] = None,
1047
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1048
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
1049
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1050
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
1051
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1052
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1053
+ return_dict: bool = True,
1054
+ down_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
1055
+ mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None,
1056
+ up_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
1057
+ ) -> Union[UNet2DConditionOutput, Tuple]:
1058
+ r"""
1059
+ The [`UNet2DConditionModel`] forward method.
1060
+
1061
+ Args:
1062
+ sample (`torch.FloatTensor`):
1063
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
1064
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
1065
+ encoder_hidden_states (`torch.FloatTensor`):
1066
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
1067
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
1068
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
1069
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
1070
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
1071
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
1072
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
1073
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
1074
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
1075
+ negative values to the attention scores corresponding to "discard" tokens.
1076
+ cross_attention_kwargs (`dict`, *optional*):
1077
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1078
+ `self.processor` in
1079
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1080
+ added_cond_kwargs: (`dict`, *optional*):
1081
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
1082
+ are passed along to the UNet blocks.
1083
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
1084
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
1085
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
1086
+ A tensor that if specified is added to the residual of the middle unet block.
1087
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
1088
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
1089
+ encoder_attention_mask (`torch.Tensor`):
1090
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
1091
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
1092
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
1093
+ return_dict (`bool`, *optional*, defaults to `True`):
1094
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
1095
+ tuple.
1096
+
1097
+ Returns:
1098
+ [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
1099
+ If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
1100
+ otherwise a `tuple` is returned where the first element is the sample tensor.
1101
+ """
1102
+ # By default samples have to be at least a multiple of the overall upsampling factor.
1103
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
1104
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
1105
+ # on the fly if necessary.
1106
+ default_overall_up_factor = 2**self.num_upsamplers
1107
+
1108
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
1109
+ forward_upsample_size = False
1110
+ upsample_size = None
1111
+
1112
+ for dim in sample.shape[-2:]:
1113
+ if dim % default_overall_up_factor != 0:
1114
+ # Forward upsample size to force interpolation output size.
1115
+ forward_upsample_size = True
1116
+ break
1117
+
1118
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
1119
+ # expects mask of shape:
1120
+ # [batch, key_tokens]
1121
+ # adds singleton query_tokens dimension:
1122
+ # [batch, 1, key_tokens]
1123
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
1124
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
1125
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
1126
+ if attention_mask is not None:
1127
+ # assume that mask is expressed as:
1128
+ # (1 = keep, 0 = discard)
1129
+ # convert mask into a bias that can be added to attention scores:
1130
+ # (keep = +0, discard = -10000.0)
1131
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
1132
+ attention_mask = attention_mask.unsqueeze(1)
1133
+
1134
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
1135
+ if encoder_attention_mask is not None:
1136
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
1137
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
1138
+
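A quick worked example of the mask-to-bias conversion just above (not part of the committed file):

import torch

mask = torch.tensor([[1.0, 1.0, 0.0]])   # keep, keep, discard
bias = (1 - mask) * -10000.0             # tensor([[0., 0., -10000.]])
bias = bias.unsqueeze(1)                 # (batch=1, 1, key_tokens=3), broadcasts over query tokens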
1139
+ # 0. center input if necessary
1140
+ if self.config.center_input_sample:
1141
+ sample = 2 * sample - 1.0
1142
+
1143
+ # 1. time
1144
+ t_emb = self.get_time_embed(sample=sample, timestep=timestep)
1145
+ emb = self.time_embedding(t_emb, timestep_cond)
1146
+ aug_emb = None
1147
+
1148
+ class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
1149
+ if class_emb is not None:
1150
+ if self.config.class_embeddings_concat:
1151
+ emb = torch.cat([emb, class_emb], dim=-1)
1152
+ else:
1153
+ emb = emb + class_emb
1154
+
1155
+ aug_emb = self.get_aug_embed(
1156
+ emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1157
+ )
1158
+ if self.config.addition_embed_type == "image_hint":
1159
+ aug_emb, hint = aug_emb
1160
+ sample = torch.cat([sample, hint], dim=1)
1161
+
1162
+ emb = emb + aug_emb if aug_emb is not None else emb
1163
+
1164
+ if self.time_embed_act is not None:
1165
+ emb = self.time_embed_act(emb)
1166
+
1167
+ encoder_hidden_states = self.process_encoder_hidden_states(
1168
+ encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1169
+ )
1170
+
1171
+ # 2. pre-process
1172
+ sample = self.conv_in(sample)
1173
+
1174
+ # 2.5 GLIGEN position net
1175
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
1176
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1177
+ gligen_args = cross_attention_kwargs.pop("gligen")
1178
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
1179
+
1180
+ # 3. down
1181
+ # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
1182
+ # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
1183
+ if cross_attention_kwargs is not None:
1184
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1185
+ lora_scale = cross_attention_kwargs.pop("scale", 1.0)
1186
+ else:
1187
+ lora_scale = 1.0
1188
+
1189
+ if USE_PEFT_BACKEND:
1190
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1191
+ scale_lora_layers(self, lora_scale)
1192
+
1193
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1194
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1195
+ is_adapter = down_intrablock_additional_residuals is not None
1196
+ # maintain backward compatibility for legacy usage, where
1197
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1198
+ # but can only use one or the other
1199
+ is_brushnet = down_block_add_samples is not None and mid_block_add_sample is not None and up_block_add_samples is not None
1200
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
1201
+ deprecate(
1202
+ "T2I should not use down_block_additional_residuals",
1203
+ "1.3.0",
1204
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1205
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1206
+ for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
1207
+ standard_warn=False,
1208
+ )
1209
+ down_intrablock_additional_residuals = down_block_additional_residuals
1210
+ is_adapter = True
1211
+
1212
+ down_block_res_samples = (sample,)
1213
+
1214
+ if is_brushnet:
1215
+ sample = sample + down_block_add_samples.pop(0)
1216
+
1217
+ for downsample_block in self.down_blocks:
1218
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1219
+ # For t2i-adapter CrossAttnDownBlock2D
1220
+ additional_residuals = {}
1221
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1222
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
1223
+
1224
+ i = len(down_block_add_samples)
1225
+
1226
+ if is_brushnet and len(down_block_add_samples)>0:
1227
+ additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
1228
+ for _ in range(len(downsample_block.resnets)+(downsample_block.downsamplers !=None))]
1229
+
1230
+ sample, res_samples = downsample_block(
1231
+ hidden_states=sample,
1232
+ temb=emb,
1233
+ encoder_hidden_states=encoder_hidden_states,
1234
+ attention_mask=attention_mask,
1235
+ cross_attention_kwargs=cross_attention_kwargs,
1236
+ encoder_attention_mask=encoder_attention_mask,
1237
+ **additional_residuals,
1238
+ )
1239
+ else:
1240
+ additional_residuals = {}
1241
+
1242
+ i = len(down_block_add_samples)
1243
+
1244
+ if is_brushnet and len(down_block_add_samples)>0:
1245
+ additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
1246
+ for _ in range(len(downsample_block.resnets)+(downsample_block.downsamplers !=None))]
1247
+
1248
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, **additional_residuals)
1249
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1250
+ sample += down_intrablock_additional_residuals.pop(0)
1251
+
1252
+ down_block_res_samples += res_samples
1253
+
1254
+ if is_controlnet:
1255
+ new_down_block_res_samples = ()
1256
+
1257
+ for down_block_res_sample, down_block_additional_residual in zip(
1258
+ down_block_res_samples, down_block_additional_residuals
1259
+ ):
1260
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1261
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1262
+
1263
+ down_block_res_samples = new_down_block_res_samples
1264
+
1265
+ # 4. mid
1266
+ if self.mid_block is not None:
1267
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
1268
+ sample = self.mid_block(
1269
+ sample,
1270
+ emb,
1271
+ encoder_hidden_states=encoder_hidden_states,
1272
+ attention_mask=attention_mask,
1273
+ cross_attention_kwargs=cross_attention_kwargs,
1274
+ encoder_attention_mask=encoder_attention_mask,
1275
+ )
1276
+ else:
1277
+ sample = self.mid_block(sample, emb)
1278
+
1279
+ # To support T2I-Adapter-XL
1280
+ if (
1281
+ is_adapter
1282
+ and len(down_intrablock_additional_residuals) > 0
1283
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1284
+ ):
1285
+ sample += down_intrablock_additional_residuals.pop(0)
1286
+
1287
+ if is_controlnet:
1288
+ sample = sample + mid_block_additional_residual
1289
+
1290
+ if is_brushnet:
1291
+ sample = sample + mid_block_add_sample
1292
+
1293
+ # 5. up
1294
+ for i, upsample_block in enumerate(self.up_blocks):
1295
+ is_final_block = i == len(self.up_blocks) - 1
1296
+
1297
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1298
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1299
+
1300
+ # if we have not reached the final block and need to forward the
1301
+ # upsample size, we do it here
1302
+ if not is_final_block and forward_upsample_size:
1303
+ upsample_size = down_block_res_samples[-1].shape[2:]
1304
+
1305
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1306
+ additional_residuals = {}
1307
+
1308
+ i = len(up_block_add_samples)
1309
+
1310
+ if is_brushnet and len(up_block_add_samples)>0:
1311
+ additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
1312
+ for _ in range(len(upsample_block.resnets)+(upsample_block.upsamplers !=None))]
1313
+
1314
+ sample = upsample_block(
1315
+ hidden_states=sample,
1316
+ temb=emb,
1317
+ res_hidden_states_tuple=res_samples,
1318
+ encoder_hidden_states=encoder_hidden_states,
1319
+ cross_attention_kwargs=cross_attention_kwargs,
1320
+ upsample_size=upsample_size,
1321
+ attention_mask=attention_mask,
1322
+ encoder_attention_mask=encoder_attention_mask,
1323
+ **additional_residuals,
1324
+ )
1325
+ else:
1326
+ additional_residuals = {}
1327
+
1328
+ i = len(up_block_add_samples)
1329
+
1330
+ if is_brushnet and len(up_block_add_samples)>0:
1331
+ additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
1332
+ for _ in range(len(upsample_block.resnets)+(upsample_block.upsamplers !=None))]
1333
+
1334
+ sample = upsample_block(
1335
+ hidden_states=sample,
1336
+ temb=emb,
1337
+ res_hidden_states_tuple=res_samples,
1338
+ upsample_size=upsample_size,
1339
+ **additional_residuals,
1340
+ )
1341
+
1342
+ # 6. post-process
1343
+ if self.conv_norm_out:
1344
+ sample = self.conv_norm_out(sample)
1345
+ sample = self.conv_act(sample)
1346
+ sample = self.conv_out(sample)
1347
+
1348
+ if USE_PEFT_BACKEND:
1349
+ # remove `lora_scale` from each PEFT layer
1350
+ unscale_lora_layers(self, lora_scale)
1351
+
1352
+ if not return_dict:
1353
+ return (sample,)
1354
+
1355
+ return UNet2DConditionOutput(sample=sample)
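The `down_block_add_samples`, `mid_block_add_sample`, and `up_block_add_samples` arguments are the BrushNet-specific additions to this otherwise stock diffusers forward: each entry is popped front-to-back and added to the hidden states of the matching resnet/downsampler/upsampler. A minimal sketch of how a caller is assumed to wire a BrushNet pass into this UNet (consistent with the `brushnet_inference` helper added in `brushnet_nodes.py` below; all names are illustrative):

def denoise_step(unet, brushnet, latents, t, text_emb, brushnet_cond, scale=1.0):
    # BrushNet sees the noisy latents plus the masked-image latents/mask and returns
    # per-block additive samples; the UNet consumes them via the extra arguments above.
    down_samples, mid_sample, up_samples = brushnet(
        latents,
        timestep=t,
        encoder_hidden_states=text_emb,
        brushnet_cond=brushnet_cond,
        conditioning_scale=scale,
        return_dict=False,
    )
    return unet(
        latents,
        t,
        encoder_hidden_states=text_emb,
        down_block_add_samples=list(down_samples),
        mid_block_add_sample=mid_sample,
        up_block_add_samples=list(up_samples),
        return_dict=False,
    )[0]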
MagicQuill/brushnet_nodes.py ADDED
@@ -0,0 +1,1094 @@
1
+ import os
2
+ import types
3
+ from typing import Tuple
4
+
5
+ import torch
6
+ import torchvision.transforms as T
7
+ import torch.nn.functional as F
8
+ from accelerate import init_empty_weights, load_checkpoint_and_dispatch
9
+ import sys
10
+
11
+ import comfy.sd
12
+ import comfy.utils
13
+ import comfy.model_management
14
+ import comfy.sd1_clip
15
+ import comfy.ldm.models.autoencoder
16
+ import comfy.supported_models
17
+
18
+ import folder_paths
19
+
20
+ from .model_patch import add_model_patch_option, patch_model_function_wrapper
21
+ from .brushnet.brushnet import BrushNetModel
22
+ from .brushnet.brushnet_ca import BrushNetModel as PowerPaintModel
23
+ from .brushnet.powerpaint_utils import TokenizerWrapper, add_tokens
24
+
25
+ current_directory = os.path.dirname(os.path.abspath(__file__))
26
+ brushnet_config_file = os.path.join(current_directory, 'brushnet', 'brushnet.json')
27
+ brushnet_xl_config_file = os.path.join(current_directory, 'brushnet', 'brushnet_xl.json')
28
+ powerpaint_config_file = os.path.join(current_directory,'brushnet', 'powerpaint.json')
29
+
30
+ sd15_scaling_factor = 0.18215
31
+ sdxl_scaling_factor = 0.13025
32
+
33
+ print(sys.path)
34
+ ModelsToUnload = [comfy.sd1_clip.SD1ClipModel,
35
+ comfy.ldm.models.autoencoder.AutoencoderKL
36
+ ]
37
+
38
+
39
+ class BrushNetLoader:
40
+ @classmethod
41
+ def INPUT_TYPES(self):
42
+ self.inpaint_files = get_files_with_extension('inpaint')
43
+ return {"required":
44
+ {
45
+ "brushnet": ([file for file in self.inpaint_files], ),
46
+ "dtype": (['float16', 'bfloat16', 'float32', 'float64'], ),
47
+ },
48
+ }
49
+
50
+ CATEGORY = "inpaint"
51
+ RETURN_TYPES = ("BRMODEL",)
52
+ RETURN_NAMES = ("brushnet",)
53
+
54
+ FUNCTION = "brushnet_loading"
55
+
56
+ def brushnet_loading(self, brushnet, dtype):
57
+ brushnet_file = os.path.join(self.inpaint_files[brushnet], brushnet)
58
+ print('BrushNet model file:', brushnet_file)
59
+ is_SDXL = False
60
+ is_PP = False
61
+ sd = comfy.utils.load_torch_file(brushnet_file)
62
+ brushnet_down_block, brushnet_mid_block, brushnet_up_block, keys = brushnet_blocks(sd)
63
+ del sd
64
+ if brushnet_down_block == 24 and brushnet_mid_block == 2 and brushnet_up_block == 30:
65
+ is_SDXL = False
66
+ if keys == 322:
67
+ is_PP = False
68
+ print('BrushNet model type: SD1.5')
69
+ else:
70
+ is_PP = True
71
+ print('PowerPaint model type: SD1.5')
72
+ elif brushnet_down_block == 18 and brushnet_mid_block == 2 and brushnet_up_block == 22:
73
+ print('BrushNet model type: SDXL')
74
+ is_SDXL = True
75
+ is_PP = False
76
+ else:
77
+ raise Exception("Unknown BrushNet model")
78
+
79
+ with init_empty_weights():
80
+ if is_SDXL:
81
+ brushnet_config = BrushNetModel.load_config(brushnet_xl_config_file)
82
+ brushnet_model = BrushNetModel.from_config(brushnet_config)
83
+ elif is_PP:
84
+ brushnet_config = PowerPaintModel.load_config(powerpaint_config_file)
85
+ brushnet_model = PowerPaintModel.from_config(brushnet_config)
86
+ else:
87
+ brushnet_config = BrushNetModel.load_config(brushnet_config_file)
88
+ brushnet_model = BrushNetModel.from_config(brushnet_config)
89
+
90
+ if is_PP:
91
+ print("PowerPaint model file:", brushnet_file)
92
+ else:
93
+ print("BrushNet model file:", brushnet_file)
94
+
95
+ if dtype == 'float16':
96
+ torch_dtype = torch.float16
97
+ elif dtype == 'bfloat16':
98
+ torch_dtype = torch.bfloat16
99
+ elif dtype == 'float32':
100
+ torch_dtype = torch.float32
101
+ else:
102
+ torch_dtype = torch.float64
103
+
104
+ brushnet_model = load_checkpoint_and_dispatch(
105
+ brushnet_model,
106
+ brushnet_file,
107
+ device_map="sequential",
108
+ max_memory=None,
109
+ offload_folder=None,
110
+ offload_state_dict=False,
111
+ dtype=torch_dtype,
112
+ force_hooks=False,
113
+ )
114
+
115
+ if is_PP:
116
+ print("PowerPaint model is loaded")
117
+ elif is_SDXL:
118
+ print("BrushNet SDXL model is loaded")
119
+ else:
120
+ print("BrushNet SD1.5 model is loaded")
121
+
122
+ return ({"brushnet": brushnet_model, "SDXL": is_SDXL, "PP": is_PP, "dtype": torch_dtype}, )
123
+
124
+
125
+ class PowerPaintCLIPLoader:
126
+
127
+ @classmethod
128
+ def INPUT_TYPES(self):
129
+ self.inpaint_files = get_files_with_extension('inpaint', ['.bin'])
130
+ self.clip_files = get_files_with_extension('clip')
131
+ return {"required":
132
+ {
133
+ "base": ([file for file in self.clip_files], ),
134
+ "powerpaint": ([file for file in self.inpaint_files], ),
135
+ },
136
+ }
137
+
138
+ CATEGORY = "inpaint"
139
+ RETURN_TYPES = ("CLIP",)
140
+ RETURN_NAMES = ("clip",)
141
+
142
+ FUNCTION = "ppclip_loading"
143
+
144
+ def ppclip_loading(self, base, powerpaint):
145
+ base_CLIP_file = os.path.join(self.clip_files[base], base)
146
+ pp_CLIP_file = os.path.join(self.inpaint_files[powerpaint], powerpaint)
147
+
148
+ pp_clip = comfy.sd.load_clip(ckpt_paths=[base_CLIP_file])
149
+
150
+ print('PowerPaint base CLIP file: ', base_CLIP_file)
151
+
152
+ pp_tokenizer = TokenizerWrapper(pp_clip.tokenizer.clip_l.tokenizer)
153
+ pp_text_encoder = pp_clip.patcher.model.clip_l.transformer
154
+
155
+ add_tokens(
156
+ tokenizer = pp_tokenizer,
157
+ text_encoder = pp_text_encoder,
158
+ placeholder_tokens = ["P_ctxt", "P_shape", "P_obj"],
159
+ initialize_tokens = ["a", "a", "a"],
160
+ num_vectors_per_token = 10,
161
+ )
162
+
163
+ pp_text_encoder.load_state_dict(comfy.utils.load_torch_file(pp_CLIP_file), strict=False)
164
+
165
+ print('PowerPaint CLIP file: ', pp_CLIP_file)
166
+
167
+ pp_clip.tokenizer.clip_l.tokenizer = pp_tokenizer
168
+ pp_clip.patcher.model.clip_l.transformer = pp_text_encoder
169
+
170
+ return (pp_clip,)
171
+
172
+
173
+ class PowerPaint:
174
+
175
+ @classmethod
176
+ def INPUT_TYPES(s):
177
+ return {"required":
178
+ {
179
+ "model": ("MODEL",),
180
+ "vae": ("VAE", ),
181
+ "image": ("IMAGE",),
182
+ "mask": ("MASK",),
183
+ "powerpaint": ("BRMODEL", ),
184
+ "clip": ("CLIP", ),
185
+ "positive": ("CONDITIONING", ),
186
+ "negative": ("CONDITIONING", ),
187
+ "fitting" : ("FLOAT", {"default": 1.0, "min": 0.3, "max": 1.0}),
188
+ "function": (['text guided', 'shape guided', 'object removal', 'context aware', 'image outpainting'], ),
189
+ "scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
190
+ "start_at": ("INT", {"default": 0, "min": 0, "max": 10000}),
191
+ "end_at": ("INT", {"default": 10000, "min": 0, "max": 10000}),
192
+ "save_memory": (['none', 'auto', 'max'], ),
193
+ },
194
+ }
195
+
196
+ CATEGORY = "inpaint"
197
+ RETURN_TYPES = ("MODEL","CONDITIONING","CONDITIONING","LATENT",)
198
+ RETURN_NAMES = ("model","positive","negative","latent",)
199
+
200
+ FUNCTION = "model_update"
201
+
202
+ def model_update(self, model, vae, image, mask, powerpaint, clip, positive, negative, fitting, function, scale, start_at, end_at, save_memory):
203
+
204
+ is_SDXL, is_PP = check_compatibilty(model, powerpaint)
205
+ if not is_PP:
206
+ raise Exception("BrushNet model was loaded, please use BrushNet node")
207
+
208
+ # Make a copy of the model so that we're not patching it everywhere in the workflow.
209
+ model = model.clone()
210
+
211
+ # prepare image and mask
212
+ # no batches for original image and mask
213
+ masked_image, mask = prepare_image(image, mask)
214
+
215
+ batch = masked_image.shape[0]
216
+ #width = masked_image.shape[2]
217
+ #height = masked_image.shape[1]
218
+
219
+ if hasattr(model.model.model_config, 'latent_format') and hasattr(model.model.model_config.latent_format, 'scale_factor'):
220
+ scaling_factor = model.model.model_config.latent_format.scale_factor
221
+ else:
222
+ scaling_factor = sd15_scaling_factor
223
+
224
+ torch_dtype = powerpaint['dtype']
225
+
226
+ # prepare conditioning latents
227
+ conditioning_latents = get_image_latents(masked_image, mask, vae, scaling_factor)
228
+ conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
229
+ conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
230
+
231
+ # prepare embeddings
232
+
233
+ if function == "object removal":
234
+ promptA = "P_ctxt"
235
+ promptB = "P_ctxt"
236
+ negative_promptA = "P_obj"
237
+ negative_promptB = "P_obj"
238
+ print('You should add "empty scene blur" to the positive prompt')
239
+ #positive = positive + " empty scene blur"
240
+ elif function == "context aware":
241
+ promptA = "P_ctxt"
242
+ promptB = "P_ctxt"
243
+ negative_promptA = ""
244
+ negative_promptB = ""
245
+ #positive = positive + " empty scene"
246
+ print('You should add "empty scene" to the positive prompt')
247
+ elif function == "shape guided":
248
+ promptA = "P_shape"
249
+ promptB = "P_ctxt"
250
+ negative_promptA = "P_shape"
251
+ negative_promptB = "P_ctxt"
252
+ elif function == "image outpainting":
253
+ promptA = "P_ctxt"
254
+ promptB = "P_ctxt"
255
+ negative_promptA = "P_obj"
256
+ negative_promptB = "P_obj"
257
+ #positive = positive + " empty scene"
258
+ print('You should add "empty scene" to the positive prompt')
259
+ else:
260
+ promptA = "P_obj"
261
+ promptB = "P_obj"
262
+ negative_promptA = "P_obj"
263
+ negative_promptB = "P_obj"
264
+
265
+ tokens = clip.tokenize(promptA)
266
+ prompt_embedsA = clip.encode_from_tokens(tokens, return_pooled=False)
267
+
268
+ tokens = clip.tokenize(negative_promptA)
269
+ negative_prompt_embedsA = clip.encode_from_tokens(tokens, return_pooled=False)
270
+
271
+ tokens = clip.tokenize(promptB)
272
+ prompt_embedsB = clip.encode_from_tokens(tokens, return_pooled=False)
273
+
274
+ tokens = clip.tokenize(negative_promptB)
275
+ negative_prompt_embedsB = clip.encode_from_tokens(tokens, return_pooled=False)
276
+
277
+ prompt_embeds_pp = (prompt_embedsA * fitting + (1.0 - fitting) * prompt_embedsB).to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
278
+ negative_prompt_embeds_pp = (negative_prompt_embedsA * fitting + (1.0 - fitting) * negative_prompt_embedsB).to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
279
+
280
+ # unload vae and CLIPs
281
+ del vae
282
+ del clip
283
+ for loaded_model in comfy.model_management.current_loaded_models:
284
+ if type(loaded_model.model.model) in ModelsToUnload:
285
+ comfy.model_management.current_loaded_models.remove(loaded_model)
286
+ loaded_model.model_unload()
287
+ del loaded_model
288
+
289
+ # apply patch to model
290
+
291
+ brushnet_conditioning_scale = scale
292
+ control_guidance_start = start_at
293
+ control_guidance_end = end_at
294
+
295
+ if save_memory != 'none':
296
+ powerpaint['brushnet'].set_attention_slice(save_memory)
297
+
298
+ add_brushnet_patch(model,
299
+ powerpaint['brushnet'],
300
+ torch_dtype,
301
+ conditioning_latents,
302
+ (brushnet_conditioning_scale, control_guidance_start, control_guidance_end),
303
+ negative_prompt_embeds_pp, prompt_embeds_pp,
304
+ None, None, None,
305
+ False)
306
+
307
+ latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]], device=powerpaint['brushnet'].device)
308
+
309
+ return (model, positive, negative, {"samples":latent},)
310
+
311
+
312
+ class BrushNet:
313
+
314
+ @classmethod
315
+ def INPUT_TYPES(s):
316
+ return {"required":
317
+ {
318
+ "model": ("MODEL",),
319
+ "vae": ("VAE", ),
320
+ "image": ("IMAGE",),
321
+ "mask": ("MASK",),
322
+ "brushnet": ("BRMODEL", ),
323
+ "positive": ("CONDITIONING", ),
324
+ "negative": ("CONDITIONING", ),
325
+ "scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
326
+ "start_at": ("INT", {"default": 0, "min": 0, "max": 10000}),
327
+ "end_at": ("INT", {"default": 10000, "min": 0, "max": 10000}),
328
+ },
329
+ }
330
+
331
+ CATEGORY = "inpaint"
332
+ RETURN_TYPES = ("MODEL","CONDITIONING","CONDITIONING","LATENT",)
333
+ RETURN_NAMES = ("model","positive","negative","latent",)
334
+
335
+ FUNCTION = "model_update"
336
+
337
+ def model_update(self, model, vae, image, mask, brushnet, positive, negative, scale, start_at, end_at):
338
+
339
+ is_SDXL, is_PP = check_compatibilty(model, brushnet)
340
+
341
+ if is_PP:
342
+ raise Exception("PowerPaint model was loaded, please use PowerPaint node")
343
+
344
+ # Make a copy of the model so that we're not patching it everywhere in the workflow.
345
+ model = model.clone()
346
+
347
+ # prepare image and mask
348
+ # no batches for original image and mask
349
+ masked_image, mask = prepare_image(image, mask)
350
+
351
+ batch = masked_image.shape[0]
352
+ width = masked_image.shape[2]
353
+ height = masked_image.shape[1]
354
+
355
+ if hasattr(model.model.model_config, 'latent_format') and hasattr(model.model.model_config.latent_format, 'scale_factor'):
356
+ scaling_factor = model.model.model_config.latent_format.scale_factor
357
+ elif is_SDXL:
358
+ scaling_factor = sdxl_scaling_factor
359
+ else:
360
+ scaling_factor = sd15_scaling_factor
361
+
362
+ torch_dtype = brushnet['dtype']
363
+
364
+ # prepare conditioning latents
365
+ conditioning_latents = get_image_latents(masked_image, mask, vae, scaling_factor)
366
+ conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
367
+ conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
368
+
369
+ # unload vae
370
+ del vae
371
+ for loaded_model in comfy.model_management.current_loaded_models:
372
+ if type(loaded_model.model.model) in ModelsToUnload:
373
+ comfy.model_management.current_loaded_models.remove(loaded_model)
374
+ loaded_model.model_unload()
375
+ del loaded_model
376
+
377
+ # prepare embeddings
378
+
379
+ prompt_embeds = positive[0][0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
380
+ negative_prompt_embeds = negative[0][0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
381
+
382
+ max_tokens = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
383
+ if prompt_embeds.shape[1] < max_tokens:
384
+ multiplier = max_tokens // 77 - prompt_embeds.shape[1] // 77
385
+ prompt_embeds = torch.concat([prompt_embeds] + [prompt_embeds[:,-77:,:]] * multiplier, dim=1)
386
+ print('BrushNet: negative prompt has more than 75 tokens:', negative_prompt_embeds.shape, 'multiplying prompt_embeds to match')
387
+ if negative_prompt_embeds.shape[1] < max_tokens:
388
+ multiplier = max_tokens // 77 - negative_prompt_embeds.shape[1] // 77
389
+ negative_prompt_embeds = torch.concat([negative_prompt_embeds] + [negative_prompt_embeds[:,-77:,:]] * multiplier, dim=1)
390
+ print('BrushNet: positive prompt has more than 75 tokens:', prompt_embeds.shape, 'multiplying negative_prompt_embeds to match')
391
+
392
+ if len(positive[0]) > 1 and 'pooled_output' in positive[0][1] and positive[0][1]['pooled_output'] is not None:
393
+ pooled_prompt_embeds = positive[0][1]['pooled_output'].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
394
+ else:
395
+ print('BrushNet: positive conditioning has no pooled_output')
396
+ if is_SDXL:
397
+ print('BrushNet will not produce correct results')
398
+ pooled_prompt_embeds = torch.empty([2, 1280], device=brushnet['brushnet'].device).to(dtype=torch_dtype)
399
+
400
+ if len(negative[0]) > 1 and 'pooled_output' in negative[0][1] and negative[0][1]['pooled_output'] is not None:
401
+ negative_pooled_prompt_embeds = negative[0][1]['pooled_output'].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
402
+ else:
403
+ print('BrushNet: negative conditioning has no pooled_output')
404
+ if is_SDXL:
405
+ print('BrushNet will not produce correct results')
406
+ negative_pooled_prompt_embeds = torch.empty([1, pooled_prompt_embeds.shape[1]], device=brushnet['brushnet'].device).to(dtype=torch_dtype)
407
+
408
+ time_ids = torch.FloatTensor([[height, width, 0., 0., height, width]]).to(dtype=torch_dtype).to(brushnet['brushnet'].device)
409
+
410
+ if not is_SDXL:
411
+ pooled_prompt_embeds = None
412
+ negative_pooled_prompt_embeds = None
413
+ time_ids = None
414
+
415
+ # apply patch to model
416
+
417
+ brushnet_conditioning_scale = scale
418
+ control_guidance_start = start_at
419
+ control_guidance_end = end_at
420
+
421
+ add_brushnet_patch(model,
422
+ brushnet['brushnet'],
423
+ torch_dtype,
424
+ conditioning_latents,
425
+ (brushnet_conditioning_scale, control_guidance_start, control_guidance_end),
426
+ prompt_embeds, negative_prompt_embeds,
427
+ pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids,
428
+ False)
429
+
430
+ latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]], device=brushnet['brushnet'].device)
431
+
432
+ return (model, positive, negative, {"samples":latent},)
433
+
434
+
435
+ class BlendInpaint:
436
+
437
+ @classmethod
438
+ def INPUT_TYPES(s):
439
+ return {"required":
440
+ {
441
+ "inpaint": ("IMAGE",),
442
+ "original": ("IMAGE",),
443
+ "mask": ("MASK",),
444
+ "kernel": ("INT", {"default": 10, "min": 1, "max": 1000}),
445
+ "sigma": ("FLOAT", {"default": 10.0, "min": 0.01, "max": 1000}),
446
+ },
447
+ "optional":
448
+ {
449
+ "origin": ("VECTOR",),
450
+ },
451
+ }
452
+
453
+ CATEGORY = "inpaint"
454
+ RETURN_TYPES = ("IMAGE","MASK",)
455
+ RETURN_NAMES = ("image","MASK",)
456
+
457
+ FUNCTION = "blend_inpaint"
458
+
459
+ def blend_inpaint(self, inpaint: torch.Tensor, original: torch.Tensor, mask, kernel: int, sigma: float, origin=None) -> Tuple[torch.Tensor, torch.Tensor]:
460
+
461
+ original, mask = check_image_mask(original, mask, 'Blend Inpaint')
462
+
463
+ if len(inpaint.shape) < 4:
464
+ # image tensor shape should be [B, H, W, C], but batch somehow is missing
465
+ inpaint = inpaint[None,:,:,:]
466
+
467
+ if inpaint.shape[0] < original.shape[0]:
468
+ print("Blend Inpaint gets batch of original images (%d) but only (%d) inpaint images" % (original.shape[0], inpaint.shape[0]))
469
+ original= original[:inpaint.shape[0],:,:]
470
+ mask = mask[:inpaint.shape[0],:,:]
471
+
472
+ if inpaint.shape[0] > original.shape[0]:
473
+ # batch over inpaint
474
+ count = 0
475
+ original_list = []
476
+ mask_list = []
477
+ origin_list = []
478
+ while (count < inpaint.shape[0]):
479
+ for i in range(original.shape[0]):
480
+ original_list.append(original[i][None,:,:,:])
481
+ mask_list.append(mask[i][None,:,:])
482
+ if origin is not None:
483
+ origin_list.append(origin[i][None,:])
484
+ count += 1
485
+ if count >= inpaint.shape[0]:
486
+ break
487
+ original = torch.concat(original_list, dim=0)
488
+ mask = torch.concat(mask_list, dim=0)
489
+ if origin is not None:
490
+ origin = torch.concat(origin_list, dim=0)
491
+
492
+ if kernel % 2 == 0:
493
+ kernel += 1
494
+ transform = T.GaussianBlur(kernel_size=(kernel, kernel), sigma=(sigma, sigma))
495
+
496
+ ret = []
497
+ blurred = []
498
+ for i in range(inpaint.shape[0]):
499
+ if origin is None:
500
+ blurred_mask = transform(mask[i][None,None,:,:]).to(original.device).to(original.dtype)
501
+ blurred.append(blurred_mask[0])
502
+
503
+ result = torch.nn.functional.interpolate(
504
+ inpaint[i][None,:,:,:].permute(0, 3, 1, 2),
505
+ size=(
506
+ original[i].shape[0],
507
+ original[i].shape[1],
508
+ )
509
+ ).permute(0, 2, 3, 1).to(original.device).to(original.dtype)
510
+ else:
511
+ # got mask from CutForInpaint
512
+ height, width, _ = original[i].shape
513
+ x0 = origin[i][0].item()
514
+ y0 = origin[i][1].item()
515
+
516
+ if mask[i].shape[0] < height or mask[i].shape[1] < width:
517
+ padded_mask = F.pad(input=mask[i], pad=(x0, width-x0-mask[i].shape[1],
518
+ y0, height-y0-mask[i].shape[0]), mode='constant', value=0)
519
+ else:
520
+ padded_mask = mask[i]
521
+ blurred_mask = transform(padded_mask[None,None,:,:]).to(original.device).to(original.dtype)
522
+ blurred.append(blurred_mask[0][0])
523
+
524
+ result = F.pad(input=inpaint[i], pad=(0, 0, x0, width-x0-inpaint[i].shape[1],
525
+ y0, height-y0-inpaint[i].shape[0]), mode='constant', value=0)
526
+ result = result[None,:,:,:].to(original.device).to(original.dtype)
527
+
528
+ ret.append(original[i] * (1.0 - blurred_mask[0][0][:,:,None]) + result[0] * blurred_mask[0][0][:,:,None])
529
+
530
+ return (torch.stack(ret), torch.stack(blurred), )
531
+
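The loop above reduces to a soft compositing formula, out = original * (1 - blurred_mask) + inpaint * blurred_mask; a self-contained single-image sketch (names and shapes illustrative: mask as an [H, W] float tensor, images as [H, W, C]):

import torch
import torchvision.transforms as T

def soft_blend(original, inpaint, mask, kernel=11, sigma=10.0):
    # Blur the binary mask so the inpainted region feathers into the original image.
    blurred = T.GaussianBlur(kernel_size=(kernel, kernel), sigma=(sigma, sigma))(mask[None, None])[0, 0]
    return original * (1.0 - blurred[:, :, None]) + inpaint * blurred[:, :, None]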
532
+
533
+ class CutForInpaint:
534
+
535
+ @classmethod
536
+ def INPUT_TYPES(s):
537
+ return {"required":
538
+ {
539
+ "image": ("IMAGE",),
540
+ "mask": ("MASK",),
541
+ "width": ("INT", {"default": 512, "min": 64, "max": 2048}),
542
+ "height": ("INT", {"default": 512, "min": 64, "max": 2048}),
543
+ },
544
+ }
545
+
546
+ CATEGORY = "inpaint"
547
+ RETURN_TYPES = ("IMAGE","MASK","VECTOR",)
548
+ RETURN_NAMES = ("image","mask","origin",)
549
+
550
+ FUNCTION = "cut_for_inpaint"
551
+
552
+ def cut_for_inpaint(self, image: torch.Tensor, mask: torch.Tensor, width: int, height: int):
553
+
554
+ image, mask = check_image_mask(image, mask, 'BrushNet')
555
+
556
+ ret = []
557
+ msk = []
558
+ org = []
559
+ for i in range(image.shape[0]):
560
+ x0, y0, w, h = cut_with_mask(mask[i], width, height)
561
+ ret.append((image[i][y0:y0+h,x0:x0+w,:]))
562
+ msk.append((mask[i][y0:y0+h,x0:x0+w]))
563
+ org.append(torch.IntTensor([x0,y0]))
564
+
565
+ return (torch.stack(ret), torch.stack(msk), torch.stack(org), )
566
+
567
+
568
+ #### Utility function
569
+
570
+ def get_files_with_extension(folder_name, extension=['.safetensors']):
571
+
572
+ try:
573
+ folders = folder_paths.get_folder_paths(folder_name)
574
+ except:
575
+ folders = []
576
+
577
+ if not folders:
578
+ folders = [os.path.join(folder_paths.models_dir, folder_name)]
579
+ if not os.path.isdir(folders[0]):
580
+ folders = [os.path.join(folder_paths.base_path, folder_name)]
581
+ if not os.path.isdir(folders[0]):
582
+ return {}
583
+
584
+ filtered_folders = []
585
+ for x in folders:
586
+ if not os.path.isdir(x):
587
+ continue
588
+ the_same = False
589
+ for y in filtered_folders:
590
+ if os.path.samefile(x, y):
591
+ the_same = True
592
+ break
593
+ if not the_same:
594
+ filtered_folders.append(x)
595
+
596
+ if not filtered_folders:
597
+ return {}
598
+
599
+ output = {}
600
+ for x in filtered_folders:
601
+ files, folders_all = folder_paths.recursive_search(x, excluded_dir_names=[".git"])
602
+ filtered_files = folder_paths.filter_files_extensions(files, extension)
603
+
604
+ for f in filtered_files:
605
+ output[f] = x
606
+
607
+ return output
608
+
609
+
610
+ # get blocks from state_dict so we could know which model it is
611
+ def brushnet_blocks(sd):
612
+ brushnet_down_block = 0
613
+ brushnet_mid_block = 0
614
+ brushnet_up_block = 0
615
+ for key in sd:
616
+ if 'brushnet_down_block' in key:
617
+ brushnet_down_block += 1
618
+ if 'brushnet_mid_block' in key:
619
+ brushnet_mid_block += 1
620
+ if 'brushnet_up_block' in key:
621
+ brushnet_up_block += 1
622
+ return (brushnet_down_block, brushnet_mid_block, brushnet_up_block, len(sd))
623
+
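These counts are what `BrushNetLoader` above keys on to tell checkpoints apart; the same decision, condensed into one sketch (thresholds copied from the loader):

def classify_brushnet(sd):
    down, mid, up, n_keys = brushnet_blocks(sd)
    if (down, mid, up) == (24, 2, 30):
        return "SD1.5 BrushNet" if n_keys == 322 else "SD1.5 PowerPaint"
    if (down, mid, up) == (18, 2, 22):
        return "SDXL BrushNet"
    raise ValueError("Unknown BrushNet checkpoint")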
624
+
625
+ # Check models compatibility
626
+ def check_compatibilty(model, brushnet):
627
+ is_SDXL = False
628
+ is_PP = False
629
+ if isinstance(model.model.model_config, comfy.supported_models.SD15):
630
+ print('Base model type: SD1.5')
631
+ is_SDXL = False
632
+ if brushnet["SDXL"]:
633
+ raise Exception("Base model is SD15, but BrushNet is SDXL type")
634
+ if brushnet["PP"]:
635
+ is_PP = True
636
+ elif isinstance(model.model.model_config, comfy.supported_models.SDXL):
637
+ print('Base model type: SDXL')
638
+ is_SDXL = True
639
+ if not brushnet["SDXL"]:
640
+ raise Exception("Base model is SDXL, but BrushNet is SD15 type")
641
+ else:
642
+ print('Base model type: ', type(model.model.model_config))
643
+ raise Exception("Unsupported model type: " + str(type(model.model.model_config)))
644
+
645
+ return (is_SDXL, is_PP)
646
+
647
+
648
+ def check_image_mask(image, mask, name):
649
+ if len(image.shape) < 4:
650
+ # image tensor shape should be [B, H, W, C], but batch somehow is missing
651
+ image = image[None,:,:,:]
652
+
653
+ if len(mask.shape) > 3:
654
+ # mask tensor shape should be [B, H, W] but we got [B, H, W, C]; this is probably an image
655
+ # take first mask, red channel
656
+ mask = (mask[:,:,:,0])[:,:,:]
657
+ elif len(mask.shape) < 3:
658
+ # mask tensor shape should be [B, H, W] but batch somehow is missing
659
+ mask = mask[None,:,:]
660
+
661
+ if image.shape[0] > mask.shape[0]:
662
+ print(name, "gets batch of images (%d) but only %d masks" % (image.shape[0], mask.shape[0]))
663
+ if mask.shape[0] == 1:
664
+ print(name, "will copy the mask to fill batch")
665
+ mask = torch.cat([mask] * image.shape[0], dim=0)
666
+ else:
667
+ print(name, "will add empty masks to fill batch")
668
+ empty_mask = torch.zeros([image.shape[0] - mask.shape[0], mask.shape[1], mask.shape[2]])
669
+ mask = torch.cat([mask, empty_mask], dim=0)
670
+ elif image.shape[0] < mask.shape[0]:
671
+ print(name, "gets batch of images (%d) but too many (%d) masks" % (image.shape[0], mask.shape[0]))
672
+ mask = mask[:image.shape[0],:,:]
673
+
674
+ return (image, mask)
675
+
676
+
677
+ # Prepare image and mask
678
+ def prepare_image(image, mask):
679
+
680
+ image, mask = check_image_mask(image, mask, 'BrushNet')
681
+
682
+ print("BrushNet image.shape =", image.shape, "mask.shape =", mask.shape)
683
+
684
+ if mask.shape[2] != image.shape[2] or mask.shape[1] != image.shape[1]:
685
+ raise Exception("Image and mask should be the same size")
686
+
687
+ # As a suggestion of inferno46n2 (https://github.com/nullquant/ComfyUI-BrushNet/issues/64)
688
+ mask = mask.round()
689
+
690
+ masked_image = image * (1.0 - mask[:,:,:,None])
691
+
692
+ return (masked_image, mask)
693
+
694
+
695
+ # Get origin of the mask
696
+ def cut_with_mask(mask, width, height):
697
+ iy, ix = (mask == 1).nonzero(as_tuple=True)
698
+
699
+ h0, w0 = mask.shape
700
+
701
+ if iy.numel() == 0:
702
+ x_c = w0 / 2.0
703
+ y_c = h0 / 2.0
704
+ else:
705
+ x_min = ix.min().item()
706
+ x_max = ix.max().item()
707
+ y_min = iy.min().item()
708
+ y_max = iy.max().item()
709
+
710
+ if x_max - x_min > width or y_max - y_min > height:
711
+ raise Exception("Masked area is bigger than provided dimensions")
712
+
713
+ x_c = (x_min + x_max) / 2.0
714
+ y_c = (y_min + y_max) / 2.0
715
+
716
+ width2 = width / 2.0
717
+ height2 = height / 2.0
718
+
719
+ if w0 <= width:
720
+ x0 = 0
721
+ w = w0
722
+ else:
723
+ x0 = max(0, x_c - width2)
724
+ w = width
725
+ if x0 + width > w0:
726
+ x0 = w0 - width
727
+
728
+ if h0 <= height:
729
+ y0 = 0
730
+ h = h0
731
+ else:
732
+ y0 = max(0, y_c - height2)
733
+ h = height
734
+ if y0 + height > h0:
735
+ y0 = h0 - height
736
+
737
+ return (int(x0), int(y0), int(w), int(h))
738
+
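A worked example of the clamping above, with illustrative numbers:

# 768x512 image (W x H), mask pixels spanning x in [600, 700] and y in [100, 180],
# requested 512x512 crop:
#   center = (650, 140)
#   x0     = max(0, 650 - 256) = 394, then clamped to 768 - 512 = 256 so the crop stays inside
#   y0     = 0 and h = 512, because the image is no taller than the requested height
# cut_with_mask(mask, 512, 512) -> (256, 0, 512, 512)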
739
+
740
+ # Prepare conditioning_latents
741
+ @torch.inference_mode()
742
+ def get_image_latents(masked_image, mask, vae, scaling_factor):
743
+ processed_image = masked_image.to(vae.device)
744
+ image_latents = vae.encode(processed_image[:,:,:,:3]) * scaling_factor
745
+ processed_mask = 1. - mask[:,None,:,:]
746
+ interpolated_mask = torch.nn.functional.interpolate(
747
+ processed_mask,
748
+ size=(
749
+ image_latents.shape[-2],
750
+ image_latents.shape[-1]
751
+ )
752
+ )
753
+ interpolated_mask = interpolated_mask.to(image_latents.device)
754
+
755
+ conditioning_latents = [image_latents, interpolated_mask]
756
+
757
+ print('BrushNet CL: image_latents shape =', image_latents.shape, 'interpolated_mask shape =', interpolated_mask.shape)
758
+
759
+ return conditioning_latents
760
+
761
+
762
+ # Main function where magic happens
763
+ @torch.inference_mode()
764
+ def brushnet_inference(x, timesteps, transformer_options, debug):
765
+ if 'model_patch' not in transformer_options:
766
+ print('BrushNet inference: there is no model_patch key in transformer_options')
767
+ return ([], 0, [])
768
+ mp = transformer_options['model_patch']
769
+ if 'brushnet' not in mp:
770
+ print('BrushNet inference: there is no brushnet key in model_patch')
771
+ return ([], 0, [])
772
+ bo = mp['brushnet']
773
+ if 'model' not in bo:
774
+ print('BrushNet inference: there is no model key in brushnet')
775
+ return ([], 0, [])
776
+ brushnet = bo['model']
777
+ if not (isinstance(brushnet, BrushNetModel) or isinstance(brushnet, PowerPaintModel)):
778
+ print('BrushNet inference: model is not a BrushNetModel or PowerPaintModel instance')
779
+ return ([], 0, [])
780
+
781
+ torch_dtype = bo['dtype']
782
+ cl_list = bo['latents']
783
+ brushnet_conditioning_scale, control_guidance_start, control_guidance_end = bo['controls']
784
+ pe = bo['prompt_embeds']
785
+ npe = bo['negative_prompt_embeds']
786
+ ppe, nppe, time_ids = bo['add_embeds']
787
+
788
+ #do_classifier_free_guidance = mp['free_guidance']
789
+ do_classifier_free_guidance = len(transformer_options['cond_or_uncond']) > 1
790
+
791
+ x = x.detach().clone()
792
+ x = x.to(torch_dtype).to(brushnet.device)
793
+
794
+ timesteps = timesteps.detach().clone()
795
+ timesteps = timesteps.to(torch_dtype).to(brushnet.device)
796
+
797
+ total_steps = mp['total_steps']
798
+ step = mp['step']
799
+
800
+ added_cond_kwargs = {}
801
+
802
+ if do_classifier_free_guidance and step == 0:
803
+ print('BrushNet inference: do_classifier_free_guidance is True')
804
+
805
+ sub_idx = None
806
+ if 'ad_params' in transformer_options and 'sub_idxs' in transformer_options['ad_params']:
807
+ sub_idx = transformer_options['ad_params']['sub_idxs']
808
+
809
+ # we have batch input images
810
+ batch = cl_list[0].shape[0]
811
+ # we have incoming latents
812
+ latents_incoming = x.shape[0]
813
+ # and we already got some
814
+ latents_got = bo['latent_id']
815
+ if step == 0 or batch > 1:
816
+ print('BrushNet inference, step = %d: image batch = %d, got %d latents, starting from %d' \
817
+ % (step, batch, latents_incoming, latents_got))
818
+
819
+ image_latents = []
820
+ masks = []
821
+ prompt_embeds = []
822
+ negative_prompt_embeds = []
823
+ pooled_prompt_embeds = []
824
+ negative_pooled_prompt_embeds = []
825
+ if sub_idx:
826
+ # AnimateDiff indexes detected
827
+ if step == 0:
828
+ print('BrushNet inference: AnimateDiff indexes detected and applied')
829
+
830
+ batch = len(sub_idx)
831
+
832
+ if do_classifier_free_guidance:
833
+ for i in sub_idx:
834
+ image_latents.append(cl_list[0][i][None,:,:,:])
835
+ masks.append(cl_list[1][i][None,:,:,:])
836
+ prompt_embeds.append(pe)
837
+ negative_prompt_embeds.append(npe)
838
+ pooled_prompt_embeds.append(ppe)
839
+ negative_pooled_prompt_embeds.append(nppe)
840
+ for i in sub_idx:
841
+ image_latents.append(cl_list[0][i][None,:,:,:])
842
+ masks.append(cl_list[1][i][None,:,:,:])
843
+ else:
844
+ for i in sub_idx:
845
+ image_latents.append(cl_list[0][i][None,:,:,:])
846
+ masks.append(cl_list[1][i][None,:,:,:])
847
+ prompt_embeds.append(pe)
848
+ pooled_prompt_embeds.append(ppe)
849
+ else:
850
+ # do_classifier_free_guidance = 2 passes, 1st pass is cond, 2nd is uncond
851
+ continue_batch = True
852
+ for i in range(latents_incoming):
853
+ number = latents_got + i
854
+ if number < batch:
855
+ # 1st pass, cond
856
+ image_latents.append(cl_list[0][number][None,:,:,:])
857
+ masks.append(cl_list[1][number][None,:,:,:])
858
+ prompt_embeds.append(pe)
859
+ pooled_prompt_embeds.append(ppe)
860
+ elif do_classifier_free_guidance and number < batch * 2:
861
+ # 2nd pass, uncond
862
+ image_latents.append(cl_list[0][number-batch][None,:,:,:])
863
+ masks.append(cl_list[1][number-batch][None,:,:,:])
864
+ negative_prompt_embeds.append(npe)
865
+ negative_pooled_prompt_embeds.append(nppe)
866
+ else:
867
+ # latent batch
868
+ image_latents.append(cl_list[0][0][None,:,:,:])
869
+ masks.append(cl_list[1][0][None,:,:,:])
870
+ prompt_embeds.append(pe)
871
+ pooled_prompt_embeds.append(ppe)
872
+ latents_got = -i
873
+ continue_batch = False
874
+
875
+ if continue_batch:
876
+ # we don't have full batch yet
877
+ if do_classifier_free_guidance:
878
+ if number < batch * 2 - 1:
879
+ bo['latent_id'] = number + 1
880
+ else:
881
+ bo['latent_id'] = 0
882
+ else:
883
+ if number < batch - 1:
884
+ bo['latent_id'] = number + 1
885
+ else:
886
+ bo['latent_id'] = 0
887
+ else:
888
+ bo['latent_id'] = 0
889
+
890
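+ # concatenate each image latent with its mask along the channel dimension to build the BrushNet conditioning batch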
+ cl = []
891
+ for il, m in zip(image_latents, masks):
892
+ cl.append(torch.concat([il, m], dim=1))
893
+ cl2apply = torch.concat(cl, dim=0)
894
+
895
+ conditioning_latents = cl2apply.to(torch_dtype).to(brushnet.device)
896
+
897
+ # print("BrushNet CL: conditioning_latents shape =", conditioning_latents.shape)
898
+ # print("BrushNet CL: x shape =", x.shape)
899
+
900
+ prompt_embeds.extend(negative_prompt_embeds)
901
+ prompt_embeds = torch.concat(prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)
902
+
903
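+ # SDXL-style models supply pooled embeddings and time ids as added conditioning; otherwise ppe is None and no added_cond_kwargs are passed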
+ if ppe is not None:
904
+ added_cond_kwargs = {}
905
+ added_cond_kwargs['time_ids'] = torch.concat([time_ids] * latents_incoming, dim = 0).to(torch_dtype).to(brushnet.device)
906
+
907
+ pooled_prompt_embeds.extend(negative_pooled_prompt_embeds)
908
+ pooled_prompt_embeds = torch.concat(pooled_prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)
909
+ added_cond_kwargs['text_embeds'] = pooled_prompt_embeds
910
+ else:
911
+ added_cond_kwargs = None
912
+
913
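+ # the conditioning latent can differ in size from the sampled latent (e.g. after latent upscaling); resize it to match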
+ if x.shape[2] != conditioning_latents.shape[2] or x.shape[3] != conditioning_latents.shape[3]:
914
+ if step == 0:
915
+ print('BrushNet inference: image', conditioning_latents.shape, 'and latent', x.shape, 'have different size, resizing image')
916
+ conditioning_latents = torch.nn.functional.interpolate(
917
+ conditioning_latents, size=(
918
+ x.shape[2],
919
+ x.shape[3],
920
+ ), mode='bicubic',
921
+ ).to(torch_dtype).to(brushnet.device)
922
+
923
+ if step == 0:
924
+ print('BrushNet inference: sample', x.shape, ', CL', conditioning_latents.shape, 'dtype', torch_dtype)
925
+
926
+ if debug: print('BrushNet: step =', step)
927
+
928
+ if step < control_guidance_start or step > control_guidance_end:
929
+ cond_scale = 0.0
930
+ else:
931
+ cond_scale = brushnet_conditioning_scale
932
+
933
+ return brushnet(x,
934
+ encoder_hidden_states=prompt_embeds,
935
+ brushnet_cond=conditioning_latents,
936
+ timestep = timesteps,
937
+ conditioning_scale=cond_scale,
938
+ guess_mode=False,
939
+ added_cond_kwargs=added_cond_kwargs,
940
+ return_dict=False,
941
+ debug=debug,
942
+ )
943
+
944
+
945
+ # This is the main patch function
946
+ def add_brushnet_patch(model, brushnet, torch_dtype, conditioning_latents,
947
+ controls,
948
+ prompt_embeds, negative_prompt_embeds,
949
+ pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids,
950
+ debug):
951
+
952
+ is_SDXL = isinstance(model.model.model_config, comfy.supported_models.SDXL)
953
+
954
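+ # per-block (index, layer type) map: the last layer of that type in each UNet block receives the corresponding BrushNet residual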
+ if is_SDXL:
955
+ input_blocks = [[0, comfy.ops.disable_weight_init.Conv2d],
956
+ [1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
957
+ [2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
958
+ [3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
959
+ [4, comfy.ldm.modules.attention.SpatialTransformer],
960
+ [5, comfy.ldm.modules.attention.SpatialTransformer],
961
+ [6, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
962
+ [7, comfy.ldm.modules.attention.SpatialTransformer],
963
+ [8, comfy.ldm.modules.attention.SpatialTransformer]]
964
+ middle_block = [0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]
965
+ output_blocks = [[0, comfy.ldm.modules.attention.SpatialTransformer],
966
+ [1, comfy.ldm.modules.attention.SpatialTransformer],
967
+ [2, comfy.ldm.modules.attention.SpatialTransformer],
968
+ [2, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
969
+ [3, comfy.ldm.modules.attention.SpatialTransformer],
970
+ [4, comfy.ldm.modules.attention.SpatialTransformer],
971
+ [5, comfy.ldm.modules.attention.SpatialTransformer],
972
+ [5, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
973
+ [6, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
974
+ [7, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
975
+ [8, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]]
976
+ else:
977
+ input_blocks = [[0, comfy.ops.disable_weight_init.Conv2d],
978
+ [1, comfy.ldm.modules.attention.SpatialTransformer],
979
+ [2, comfy.ldm.modules.attention.SpatialTransformer],
980
+ [3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
981
+ [4, comfy.ldm.modules.attention.SpatialTransformer],
982
+ [5, comfy.ldm.modules.attention.SpatialTransformer],
983
+ [6, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
984
+ [7, comfy.ldm.modules.attention.SpatialTransformer],
985
+ [8, comfy.ldm.modules.attention.SpatialTransformer],
986
+ [9, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
987
+ [10, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
988
+ [11, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]]
989
+ middle_block = [0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]
990
+ output_blocks = [[0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
991
+ [1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
992
+ [2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
993
+ [2, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
994
+ [3, comfy.ldm.modules.attention.SpatialTransformer],
995
+ [4, comfy.ldm.modules.attention.SpatialTransformer],
996
+ [5, comfy.ldm.modules.attention.SpatialTransformer],
997
+ [5, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
998
+ [6, comfy.ldm.modules.attention.SpatialTransformer],
999
+ [7, comfy.ldm.modules.attention.SpatialTransformer],
1000
+ [8, comfy.ldm.modules.attention.SpatialTransformer],
1001
+ [8, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
1002
+ [9, comfy.ldm.modules.attention.SpatialTransformer],
1003
+ [10, comfy.ldm.modules.attention.SpatialTransformer],
1004
+ [11, comfy.ldm.modules.attention.SpatialTransformer]]
1005
+
1006
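+ # return the index of the last layer of type tp inside a block, plus the collected layer types for error reporting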
+ def last_layer_index(block, tp):
1007
+ layer_list = []
1008
+ for layer in block:
1009
+ layer_list.append(type(layer))
1010
+ layer_list.reverse()
1011
+ if tp not in layer_list:
1012
+ return -1, layer_list
1013
+ return len(layer_list) - 1 - layer_list.index(tp), layer_list
1014
+
1015
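+ # runs before each UNet pass: compute BrushNet residuals (if patched) and stash them on the matching layers via add_sample_after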
+ def brushnet_forward(model, x, timesteps, transformer_options, control):
1016
+ if 'brushnet' not in transformer_options['model_patch']:
1017
+ input_samples = []
1018
+ mid_sample = 0
1019
+ output_samples = []
1020
+ else:
1021
+ # brushnet inference
1022
+ input_samples, mid_sample, output_samples = brushnet_inference(x, timesteps, transformer_options, debug)
1023
+
1024
+ # give additional samples to blocks
1025
+ for i, tp in input_blocks:
1026
+ idx, layer_list = last_layer_index(model.input_blocks[i], tp)
1027
+ if idx < 0:
1028
+ print("BrushNet can't find", tp, "layer in", i,"input block:", layer_list)
1029
+ continue
1030
+ model.input_blocks[i][idx].add_sample_after = input_samples.pop(0) if input_samples else 0
1031
+
1032
+ idx, layer_list = last_layer_index(model.middle_block, middle_block[1])
1033
+ if idx < 0:
1034
+ print("BrushNet can't find", middle_block[1], "layer in middle block", layer_list)
1035
+ model.middle_block[idx].add_sample_after = mid_sample
1036
+
1037
+ for i, tp in output_blocks:
1038
+ idx, layer_list = last_layer_index(model.output_blocks[i], tp)
1039
+ if idx < 0:
1040
+ print("BrushNet can't find", tp, "layer in", i,"outnput block:", layer_list)
1041
+ continue
1042
+ model.output_blocks[i][idx].add_sample_after = output_samples.pop(0) if output_samples else 0
1043
+
1044
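+ # register brushnet_forward through the wrapper helper so it is invoked for every sampling call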
+ patch_model_function_wrapper(model, brushnet_forward)
1045
+
1046
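+ # store BrushNet state in the model patch options so brushnet_inference can read it back from transformer_options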
+ to = add_model_patch_option(model)
1047
+ mp = to['model_patch']
1048
+ if 'brushnet' not in mp:
1049
+ mp['brushnet'] = {}
1050
+ bo = mp['brushnet']
1051
+
1052
+ bo['model'] = brushnet
1053
+ bo['dtype'] = torch_dtype
1054
+ bo['latents'] = conditioning_latents
1055
+ bo['controls'] = controls
1056
+ bo['prompt_embeds'] = prompt_embeds
1057
+ bo['negative_prompt_embeds'] = negative_prompt_embeds
1058
+ bo['add_embeds'] = (pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids)
1059
+ bo['latent_id'] = 0
1060
+
1061
+ # patch layers `forward` so we can apply brushnet
1062
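+ # each patched layer adds its pending BrushNet sample to its output once, then resets the sample to 0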
+ def forward_patched_by_brushnet(self, x, *args, **kwargs):
1063
+ h = self.original_forward(x, *args, **kwargs)
1064
+ if hasattr(self, 'add_sample_after'):
1065
+ to_add = self.add_sample_after
1066
+ if torch.is_tensor(to_add):
1067
+ # interpolate due to RAUNet
1068
+ if h.shape[2] != to_add.shape[2] or h.shape[3] != to_add.shape[3]:
1069
+ to_add = torch.nn.functional.interpolate(to_add, size=(h.shape[2], h.shape[3]), mode='bicubic')
1070
+ h += to_add.to(h.dtype).to(h.device)
1071
+ else:
1072
+ h += self.add_sample_after
1073
+ self.add_sample_after = 0
1074
+ return h
1075
+
1076
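+ # wrap every layer of the input, middle and output blocks; original_forward is saved only once so repeated patching does not stack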
+ for i, block in enumerate(model.model.diffusion_model.input_blocks):
1077
+ for j, layer in enumerate(block):
1078
+ if not hasattr(layer, 'original_forward'):
1079
+ layer.original_forward = layer.forward
1080
+ layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
1081
+ layer.add_sample_after = 0
1082
+
1083
+ for j, layer in enumerate(model.model.diffusion_model.middle_block):
1084
+ if not hasattr(layer, 'original_forward'):
1085
+ layer.original_forward = layer.forward
1086
+ layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
1087
+ layer.add_sample_after = 0
1088
+
1089
+ for i, block in enumerate(model.model.diffusion_model.output_blocks):
1090
+ for j, layer in enumerate(block):
1091
+ if not hasattr(layer, 'original_forward'):
1092
+ layer.original_forward = layer.forward
1093
+ layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
1094
+ layer.add_sample_after = 0
MagicQuill/comfy/.DS_Store ADDED
Binary file (6.15 kB). View file
 
MagicQuill/comfy/__pycache__/checkpoint_pickle.cpython-311.pyc ADDED
Binary file (1.09 kB). View file
 
MagicQuill/comfy/__pycache__/cli_args.cpython-311.pyc ADDED
Binary file (12.6 kB). View file
 
MagicQuill/comfy/__pycache__/clip_model.cpython-311.pyc ADDED
Binary file (17.5 kB). View file
 
MagicQuill/comfy/__pycache__/clip_vision.cpython-311.pyc ADDED
Binary file (10.2 kB). View file
 
MagicQuill/comfy/__pycache__/conds.cpython-311.pyc ADDED
Binary file (5.5 kB). View file
 
MagicQuill/comfy/__pycache__/controlnet.cpython-311.pyc ADDED
Binary file (34.7 kB). View file
 
MagicQuill/comfy/__pycache__/diffusers_convert.cpython-311.pyc ADDED
Binary file (13.1 kB). View file
 
MagicQuill/comfy/__pycache__/diffusers_load.cpython-311.pyc ADDED
Binary file (2.37 kB). View file
 
MagicQuill/comfy/__pycache__/gligen.cpython-311.pyc ADDED
Binary file (22.1 kB). View file
 
MagicQuill/comfy/__pycache__/latent_formats.cpython-311.pyc ADDED
Binary file (7.88 kB). View file
 
MagicQuill/comfy/__pycache__/lora.cpython-311.pyc ADDED
Binary file (14.3 kB). View file
 
MagicQuill/comfy/__pycache__/model_base.cpython-311.pyc ADDED
Binary file (48 kB). View file
 
MagicQuill/comfy/__pycache__/model_detection.cpython-311.pyc ADDED
Binary file (24.9 kB). View file
 
MagicQuill/comfy/__pycache__/model_management.cpython-311.pyc ADDED
Binary file (40.5 kB). View file
 
MagicQuill/comfy/__pycache__/model_patcher.cpython-311.pyc ADDED
Binary file (32.6 kB). View file
 
MagicQuill/comfy/__pycache__/model_sampling.cpython-311.pyc ADDED
Binary file (19.2 kB). View file
 
MagicQuill/comfy/__pycache__/ops.cpython-311.pyc ADDED
Binary file (13.2 kB). View file
 
MagicQuill/comfy/__pycache__/options.cpython-311.pyc ADDED
Binary file (332 Bytes). View file
 
MagicQuill/comfy/__pycache__/sa_t5.cpython-311.pyc ADDED
Binary file (3.9 kB). View file
 
MagicQuill/comfy/__pycache__/sample.cpython-311.pyc ADDED
Binary file (4.75 kB). View file
 
MagicQuill/comfy/__pycache__/sampler_helpers.cpython-311.pyc ADDED
Binary file (4.64 kB). View file
 
MagicQuill/comfy/__pycache__/samplers.cpython-311.pyc ADDED
Binary file (43.9 kB). View file
 
MagicQuill/comfy/__pycache__/sd.cpython-311.pyc ADDED
Binary file (46.2 kB). View file