YoadTew committed on
Commit 504c7e8 · 1 Parent(s): 47481b9

Add application file

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.jpg filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,35 @@
1
+ NVIDIA License
2
+
3
+ 1. Definitions
4
+
5
+ “Licensor” means any person or entity that distributes its Work.
6
+ “Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license.
7
+ The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
8
+ Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license.
9
+
10
+ 2. License Grant
11
+
12
+ 2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
13
+
14
+ 3. Limitations
15
+
16
+ 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.
17
+
18
+ 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.
19
+
20
+ 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only.
21
+
22
+ 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately.
23
+
24
+ 3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license.
25
+
26
+ 3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately.
27
+
28
+ 4. Disclaimer of Warranty.
29
+
30
+ THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
31
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
32
+
33
+ 5. Limitation of Liability.
34
+
35
+ EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
README.md CHANGED
@@ -1,13 +1,171 @@
1
  ---
2
- title: Addit
3
- emoji:
4
- colorFrom: pink
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 5.36.2
8
- app_file: app.py
9
- pinned: false
10
- license: other
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # 🎨 Add-it: Training-Free Object Insertion in Images With Pretrained Diffusion Models
2
+
3
+ <div align="center">
4
+
5
+ [![arXiv](https://img.shields.io/badge/arXiv-2411.07232-b31b1b.svg)](https://arxiv.org/abs/2411.07232)
6
+ [![Project Website](https://img.shields.io/badge/🌐-Project%20Website-blue)](https://research.nvidia.com/labs/par/addit/)
7
+
8
+ </div>
9
+
10
+ ## 👥 Authors
11
+
12
+ **Yoad Tewel**<sup>1,2</sup>, **Rinon Gal**<sup>1,2</sup>, **Dvir Samuel**<sup>3</sup>, **Yuval Atzmon**<sup>1</sup>, **Lior Wolf**<sup>2</sup>, **Gal Chechik**<sup>1</sup>
13
+
14
+ <sup>1</sup>NVIDIA • <sup>2</sup>Tel Aviv University • <sup>3</sup>Bar-Ilan University
15
+
16
+ <div align="center">
17
+ <img src="https://research.nvidia.com/labs/par/addit/static/images/Teaser.png" alt="Add-it Teaser" width="800"/>
18
+ </div>
19
+
20
+ ## 📄 Abstract
21
+
22
+ Adding objects into images based on text instructions is a challenging task in semantic image editing, requiring a balance between preserving the original scene and seamlessly integrating the new object in a fitting location. Despite extensive efforts, existing models often struggle with this balance, particularly with finding a natural location for adding an object in complex scenes.
23
+
24
+ We introduce **Add-it**, a training-free approach that extends diffusion models' attention mechanisms to incorporate information from three key sources: the scene image, the text prompt, and the generated image itself. Our weighted extended-attention mechanism maintains structural consistency and fine details while ensuring natural object placement.
25
+
26
+ Without task-specific fine-tuning, Add-it achieves state-of-the-art results on both real and generated image insertion benchmarks, including our newly constructed "Additing Affordance Benchmark" for evaluating object placement plausibility, outperforming supervised methods. Human evaluations show that Add-it is preferred in over 80% of cases, and it also demonstrates improvements in various automated metrics.
27
+
28
+ ---
29
+
30
+ ## 📋 Description
31
+
32
+ This repository contains the official implementation of the Add-it paper, providing tools for seamless object insertion into images using pretrained diffusion models.
33
+
34
+ ---
35
+
36
+ ## 🛠️ Setup
37
+
38
+ ```bash
39
+ conda env create -f environment.yml
40
+ conda activate addit
41
+ ```
42
+
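+ Optionally, you can check that PyTorch sees a GPU before running the demos (a quick sanity check, assuming a CUDA-capable machine):
+
+ ```bash
+ python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
+ ```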
43
  ---
44
+
45
+ ## 🚀 Usage
46
+
47
+ ### 💻 Command Line Interface (CLI)
48
+
49
+ Add-it provides two CLI scripts for different use cases:
50
+
51
+ #### 1. 🎭 Adding Objects to Generated Images
52
+
53
+ Use `run_CLI_addit_generated.py` to add objects to AI-generated images:
54
+
55
+ ```bash
56
+ python run_CLI_addit_generated.py \
57
+ --prompt_source "A photo of a cat sitting on the couch" \
58
+ --prompt_target "A photo of a cat wearing a red hat sitting on the couch" \
59
+ --subject_token "hat"
60
+ ```
61
+
62
+ ##### ⚙️ Options for Generated Images
63
+
64
+ **🔴 Required Arguments:**
65
+ - `--prompt_source`: Source prompt for generating the base image
66
+ - `--prompt_target`: Target prompt describing the desired edited image
67
+ - `--subject_token`: Single token representing the subject to add (must appear in prompt_target)
68
+
69
+ **🔵 Optional Arguments:**
70
+ - `--output_dir`: Directory to save output images (default: "outputs")
71
+ - `--seed_src`: Seed for source generation (default: 6311)
72
+ - `--seed_obj`: Seed for edited image generation (default: 1)
73
+ - `--extended_scale`: Extended attention scale (default: 1.05)
74
+ - `--structure_transfer_step`: Structure transfer step (default: 2)
75
+ - `--blend_steps`: Blend steps (default: [15]). To allow changes to the input image, pass `--blend_steps` with an empty value.
76
+ - `--localization_model`: Localization model (default: "attention_points_sam")
77
+ - **Options:** `attention_points_sam`, `attention`, `attention_box_sam`, `attention_mask_sam`, `grounding_sam`
78
+ - `--show_attention`: Show attention maps using pyplot (flag); the visualization is saved to `attn_vis.png`. A combined example of these options is shown below.
79
+
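+ The optional arguments above can be combined as needed. For example, a run that tries a different edit seed, switches the localization model, and saves the attention visualization might look like this (the values are illustrative, not tuned defaults):
+
+ ```bash
+ python run_CLI_addit_generated.py \
+ --prompt_source "A photo of a cat sitting on the couch" \
+ --prompt_target "A photo of a cat wearing a red hat sitting on the couch" \
+ --subject_token "hat" \
+ --seed_obj 42 \
+ --localization_model "attention" \
+ --show_attention
+ ```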
80
+ #### 2. 📸 Adding Objects to Real Images
81
+
82
+ Use `run_CLI_addit_real.py` to add objects to existing images:
83
+
84
+ ```bash
85
+ python run_CLI_addit_real.py \
86
+ --source_image "images/bed_dark_room.jpg" \
87
+ --prompt_source "A photo of a bed in a dark room" \
88
+ --prompt_target "A photo of a dog lying on a bed in a dark room" \
89
+ --subject_token "dog"
90
+ ```
91
+
92
+ ##### ⚙️ Options for Real Images
93
+
94
+ **🔴 Required Arguments:**
95
+ - `--source_image`: Path to the source image (default: "images/bed_dark_room.jpg")
96
+ - `--prompt_source`: Source prompt describing the original image
97
+ - `--prompt_target`: Target prompt describing the desired edited image
98
+ - `--subject_token`: Subject token to add to the image (must appear in prompt_target)
99
+
100
+ **🔵 Optional Arguments:**
101
+ - `--output_dir`: Directory to save output images (default: "outputs")
102
+ - `--seed_src`: Seed for source generation (default: 6311)
103
+ - `--seed_obj`: Seed for edited image generation (default: 1)
104
+ - `--extended_scale`: Extended attention scale (default: 1.1)
105
+ - `--structure_transfer_step`: Structure transfer step (default: 4)
106
+ - `--blend_steps`: Blend steps (default: [18]). To allow changes to the input image, pass `--blend_steps` with an empty value.
107
+ - `--localization_model`: Localization model (default: "attention")
108
+ - **Options:** `attention_points_sam`, `attention`, `attention_box_sam`, `attention_mask_sam`, `grounding_sam`
109
+ - `--use_offset`: Use offset in processing (flag)
110
+ - `--show_attention`: Show attention maps using pyplot (flag); the visualization is saved to `attn_vis.png`.
111
+ - `--disable_inversion`: Disable source image inversion (flag)
112
+
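+ The optional arguments above can be combined in the same way. For example, a run that uses SAM-refined attention points for localization and a different edit seed (illustrative values):
+
+ ```bash
+ python run_CLI_addit_real.py \
+ --source_image "images/bed_dark_room.jpg" \
+ --prompt_source "A photo of a bed in a dark room" \
+ --prompt_target "A photo of a dog lying on a bed in a dark room" \
+ --subject_token "dog" \
+ --localization_model "attention_points_sam" \
+ --seed_obj 3 \
+ --show_attention
+ ```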
113
+ ---
114
+
115
+ ### 📓 Jupyter Notebooks
116
+
117
+ You can run Add-it in two interactive modes:
118
+
119
+ | Mode | Notebook | Description |
120
+ |------|----------|-------------|
121
+ | 🎭 **Generated Images** | `run_addit_generated.ipynb` | Adding objects to AI-generated images |
122
+ | 📸 **Real Images** | `run_addit_real.ipynb` | Adding objects to existing real images |
123
+
124
+ The notebooks contain examples of different prompts and parameters that can be adjusted to control the object insertion process.
125
+
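+ To open a notebook, launch Jupyter from the activated environment (assuming Jupyter is installed in it), for example:
+
+ ```bash
+ jupyter notebook run_addit_real.ipynb
+ ```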
126
+ ---
127
+
128
+ ## 💡 Tips for Better Results
129
+
130
+ - **Prompt Design**: The `--prompt_target` should be similar to the `--prompt_source`, but include a description of the new object to insert
131
+ - **Seed Variation**: Try different values for `--seed_obj` - some prompts may require a few attempts to get satisfying results
132
+ - **Localization Models**: The most effective `--localization_model` options are `attention_points_sam` and `attention`. Use the `--show_attention` flag to visualize localization performance
133
+ - **Object Placement Issues**: If the object is not added to the image:
134
+ - Try **decreasing** `--structure_transfer_step`
135
+ - Try **increasing** `--extended_scale`
136
+ - **Flexibility**: To allow more flexibility in modifying the source image, pass `--blend_steps` with no value so that an empty list `[]` is used (see the example below)
137
+
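+ As a concrete sketch, if the object is not being added to a real image, the tips above might translate into flags like these (illustrative values; the right settings depend on the image):
+
+ ```bash
+ python run_CLI_addit_real.py \
+ --source_image "images/bed_dark_room.jpg" \
+ --prompt_source "A photo of a bed in a dark room" \
+ --prompt_target "A photo of a dog lying on a bed in a dark room" \
+ --subject_token "dog" \
+ --structure_transfer_step 2 \
+ --extended_scale 1.2 \
+ --blend_steps
+ ```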
138
+ ---
139
+
140
+ ## 📰 News
141
+
142
+ - **🎉 July 2025**: Official Add-it implementation is released!
143
+
144
+ ---
145
+
146
+ ## 📝 TODO
147
+
148
+ - [x] Release code
149
+
150
+ ---
151
+
152
+ ## 📚 Citation
153
+
154
+ If you make use of our work, please cite our paper:
155
+
156
+ ```bibtex
157
+ @misc{tewel2024addit,
158
+ title={Add-it: Training-Free Object Insertion in Images With Pretrained Diffusion Models},
159
+ author={Yoad Tewel and Rinon Gal and Dvir Samuel and Yuval Atzmon and Lior Wolf and Gal Chechik},
160
+ year={2024},
161
+ eprint={2411.07232},
162
+ archivePrefix={arXiv},
163
+ primaryClass={cs.CV}
164
+ }
165
+ ```
166
+
167
  ---
168
 
169
+ <div align="center">
170
+ <strong>🌟 Star this repo if you find it useful! 🌟</strong>
171
+ </div>
addit_attention_processors.py ADDED
@@ -0,0 +1,297 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # Copyright (C) 2025 NVIDIA Corporation. All rights reserved.
16
+ #
17
+ # This work is licensed under the LICENSE file
18
+ # located at the root directory.
19
+
20
+ from collections import defaultdict
21
+ from diffusers.models.attention_processor import Attention, apply_rope
22
+ from typing import Callable, List, Optional, Tuple, Union
23
+
24
+ from addit_attention_store import AttentionStore
25
+ from visualization_utils import show_tensors
26
+
27
+ import torch
28
+ import torch.nn.functional as F
29
+ import numpy as np
30
+ from scipy.optimize import brentq
31
+
32
+ def apply_standard_attention(query, key, value, attn, attention_probs=None):
33
+ batch_size, attn_heads, _, head_dim = query.shape
34
+
35
+ # Do normal attention, to cache the attention scores
36
+ query = query.reshape(batch_size*attn_heads, -1, head_dim)
37
+ key = key.reshape(batch_size*attn_heads, -1, head_dim)
38
+ value = value.reshape(batch_size*attn_heads, -1, head_dim)
39
+
40
+ if attention_probs is None:
41
+ attention_probs = attn.get_attention_scores(query, key)
42
+
43
+ hidden_states = torch.bmm(attention_probs, value)
44
+ hidden_states = hidden_states.view(batch_size, attn_heads, -1, head_dim)
45
+
46
+ return hidden_states, attention_probs
47
+
48
+ def apply_extended_attention(query, key, value, attention_store, attn, layer_name, step_index, extend_type="pixels",
49
+ extended_scale=1., record_attention=False):
50
+ batch_size = query.size(0)
51
+ extend_query = query[1:]
52
+
53
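+ # Batch layout (assumed from the calling pipeline): index 0 is the source-image stream,
+ # index 1 is the edited-image stream. Keys/values from the source are appended to the
+ # edited stream so it can attend to the source: "full" shares all source tokens,
+ # "text" only the first 512 (text) tokens, "pixels" only the source image tokens.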
+ if extend_type == "full":
54
+ added_key = key[0] * extended_scale
55
+ added_value = value[0]
56
+ elif extend_type == "text":
57
+ added_key = key[0, :, :512] * extended_scale
58
+ added_value = value[0, :, :512]
59
+ elif extend_type == "pixels":
60
+ added_key = key[0, :, 512:]
61
+ added_value = value[0, :, 512:]
62
+
63
+ key[1] = key[1] * extended_scale
64
+
65
+ extend_key = torch.cat([added_key, key[1]], dim=1).unsqueeze(0)
66
+ extend_value = torch.cat([added_value, value[1]], dim=1).unsqueeze(0)
67
+
68
+ hidden_states_0 = F.scaled_dot_product_attention(query[:1], key[:1], value[:1], dropout_p=0.0, is_causal=False)
69
+
70
+ if record_attention or attention_store.is_cache_attn_ratio(step_index):
71
+ hidden_states_1, attention_probs_1 = apply_standard_attention(extend_query, extend_key, extend_value, attn)
72
+ else:
73
+ hidden_states_1 = F.scaled_dot_product_attention(extend_query, extend_key, extend_value, dropout_p=0.0, is_causal=False)
74
+
75
+ if record_attention:
76
+ # Store Attention
77
+ seq_len = attention_probs_1.size(2) - attention_probs_1.size(1)
78
+ self_attention_probs_1 = attention_probs_1[:,:,seq_len:]
79
+ attention_store.store_attention(self_attention_probs_1, layer_name, 1, attn.heads)
80
+
81
+ if attention_store.is_cache_attn_ratio(step_index):
82
+ attention_store.store_attention_ratios(attention_probs_1, step_index, layer_name)
83
+
84
+ hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0)
85
+
86
+ return hidden_states
87
+
88
+ def apply_attention(query, key, value, attention_store, attn, layer_name, step_index,
89
+ record_attention, extended_attention, extended_scale):
90
+ if extended_attention:
91
+ hidden_states = apply_extended_attention(query, key, value, attention_store, attn, layer_name, step_index,
92
+ extended_scale=extended_scale,
93
+ record_attention=record_attention)
94
+ else:
95
+ if record_attention:
96
+ hidden_states_0 = F.scaled_dot_product_attention(query[:1], key[:1], value[:1], dropout_p=0.0, is_causal=False)
97
+ hidden_states_1, attention_probs_1 = apply_standard_attention(query[1:], key[1:], value[1:], attn)
98
+ attention_store.store_attention(attention_probs_1, layer_name, 1, attn.heads)
99
+
100
+ hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0)
101
+ else:
102
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
103
+
104
+ return hidden_states
105
+
106
+ class AdditFluxAttnProcessor2_0:
107
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
108
+
109
+ def __init__(self, layer_name: str, attention_store: AttentionStore,
110
+ extended_steps: Tuple[int, int] = (0, 30), **kwargs):
111
+ if not hasattr(F, "scaled_dot_product_attention"):
112
+ raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
113
+
114
+ self.layer_name = layer_name
115
+ self.layer_idx = int(layer_name.split(".")[-1])
116
+ self.attention_store = attention_store
117
+
118
+ self.extended_steps = (0, extended_steps) if isinstance(extended_steps, int) else extended_steps
119
+
120
+ def __call__(
121
+ self,
122
+ attn: Attention,
123
+ hidden_states: torch.FloatTensor,
124
+ encoder_hidden_states: torch.FloatTensor = None,
125
+ attention_mask: Optional[torch.FloatTensor] = None,
126
+ image_rotary_emb: Optional[torch.Tensor] = None,
127
+
128
+ step_index: Optional[int] = None,
129
+ extended_scale: Optional[float] = 1.0,
130
+ ) -> torch.FloatTensor:
131
+ input_ndim = hidden_states.ndim
132
+ if input_ndim == 4:
133
+ batch_size, channel, height, width = hidden_states.shape
134
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
135
+ context_input_ndim = encoder_hidden_states.ndim
136
+ if context_input_ndim == 4:
137
+ batch_size, channel, height, width = encoder_hidden_states.shape
138
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
139
+
140
+ batch_size = encoder_hidden_states.shape[0]
141
+
142
+ # `sample` projections.
143
+ query = attn.to_q(hidden_states)
144
+ key = attn.to_k(hidden_states)
145
+ value = attn.to_v(hidden_states)
146
+
147
+ inner_dim = key.shape[-1]
148
+ head_dim = inner_dim // attn.heads
149
+
150
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
151
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
152
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
153
+
154
+ if attn.norm_q is not None:
155
+ query = attn.norm_q(query)
156
+ if attn.norm_k is not None:
157
+ key = attn.norm_k(key)
158
+
159
+ # `context` projections.
160
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
161
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
162
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
163
+
164
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
165
+ batch_size, -1, attn.heads, head_dim
166
+ ).transpose(1, 2)
167
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
168
+ batch_size, -1, attn.heads, head_dim
169
+ ).transpose(1, 2)
170
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
171
+ batch_size, -1, attn.heads, head_dim
172
+ ).transpose(1, 2)
173
+
174
+ if attn.norm_added_q is not None:
175
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
176
+ if attn.norm_added_k is not None:
177
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
178
+
179
+ # attention
180
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
181
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
182
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
183
+
184
+ if image_rotary_emb is not None:
185
+ # YiYi to-do: update using apply_rotary_emb
186
+ # from ..embeddings import apply_rotary_emb
187
+ # query = apply_rotary_emb(query, image_rotary_emb)
188
+ # key = apply_rotary_emb(key, image_rotary_emb)
189
+ query, key = apply_rope(query, key, image_rotary_emb)
190
+
191
+ record_attention = self.attention_store.is_record_attention(self.layer_name, step_index)
192
+ extend_start, extend_end = self.extended_steps
193
+ extended_attention = extend_start <= step_index <= extend_end
194
+
195
+ hidden_states = apply_attention(query, key, value, self.attention_store, attn, self.layer_name, step_index,
196
+ record_attention, extended_attention, extended_scale)
197
+
198
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
199
+ hidden_states = hidden_states.to(query.dtype)
200
+
201
+ encoder_hidden_states, hidden_states = (
202
+ hidden_states[:, : encoder_hidden_states.shape[1]],
203
+ hidden_states[:, encoder_hidden_states.shape[1] :],
204
+ )
205
+
206
+ # linear proj
207
+ hidden_states = attn.to_out[0](hidden_states)
208
+ # dropout
209
+ hidden_states = attn.to_out[1](hidden_states)
210
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
211
+
212
+ if input_ndim == 4:
213
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
214
+ if context_input_ndim == 4:
215
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
216
+
217
+ return hidden_states, encoder_hidden_states
218
+
219
+ class AdditFluxSingleAttnProcessor2_0:
220
+ r"""
221
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
222
+ """
223
+
224
+ def __init__(self, layer_name: str, attention_store: AttentionStore,
225
+ extended_steps: Tuple[int, int] = (0, 30), **kwargs):
226
+ if not hasattr(F, "scaled_dot_product_attention"):
227
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
228
+
229
+ self.layer_name = layer_name
230
+ self.layer_idx = int(layer_name.split(".")[-1])
231
+ self.attention_store = attention_store
232
+
233
+ self.extended_steps = (0, extended_steps) if isinstance(extended_steps, int) else extended_steps
234
+
235
+ def __call__(
236
+ self,
237
+ attn: Attention,
238
+ hidden_states: torch.Tensor,
239
+ encoder_hidden_states: Optional[torch.Tensor] = None,
240
+ attention_mask: Optional[torch.FloatTensor] = None,
241
+ image_rotary_emb: Optional[torch.Tensor] = None,
242
+ step_index: Optional[int] = None,
243
+ extended_scale: Optional[float] = 1.0,
244
+ ) -> torch.Tensor:
245
+ input_ndim = hidden_states.ndim
246
+
247
+ if input_ndim == 4:
248
+ batch_size, channel, height, width = hidden_states.shape
249
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
250
+
251
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
252
+
253
+ query = attn.to_q(hidden_states)
254
+ if encoder_hidden_states is None:
255
+ encoder_hidden_states = hidden_states
256
+
257
+ key = attn.to_k(encoder_hidden_states)
258
+ value = attn.to_v(encoder_hidden_states)
259
+
260
+ inner_dim = key.shape[-1]
261
+ head_dim = inner_dim // attn.heads
262
+
263
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
264
+
265
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
266
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
267
+
268
+ if attn.norm_q is not None:
269
+ query = attn.norm_q(query)
270
+ if attn.norm_k is not None:
271
+ key = attn.norm_k(key)
272
+
273
+ # Apply RoPE if needed
274
+ if image_rotary_emb is not None:
275
+ # YiYi to-do: update using apply_rotary_emb
276
+ # from ..embeddings import apply_rotary_emb
277
+ # query = apply_rotary_emb(query, image_rotary_emb)
278
+ # key = apply_rotary_emb(key, image_rotary_emb)
279
+ query, key = apply_rope(query, key, image_rotary_emb)
280
+
281
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
282
+ # TODO: add support for attn.scale when we move to Torch 2.1
283
+
284
+ record_attention = self.attention_store.is_record_attention(self.layer_name, step_index)
285
+ extend_start, extend_end = self.extended_steps
286
+ extended_attention = extend_start <= step_index <= extend_end
287
+
288
+ hidden_states = apply_attention(query, key, value, self.attention_store, attn, self.layer_name, step_index,
289
+ record_attention, extended_attention, extended_scale)
290
+
291
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
292
+ hidden_states = hidden_states.to(query.dtype)
293
+
294
+ if input_ndim == 4:
295
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
296
+
297
+ return hidden_states
addit_attention_store.py ADDED
@@ -0,0 +1,316 @@
1
+ # Copyright (C) 2025 NVIDIA Corporation. All rights reserved.
2
+ #
3
+ # This work is licensed under the LICENSE file
4
+ # located at the root directory.
5
+
6
+ import torch
7
+ from skimage import filters
8
+ import cv2
9
+ import torch.nn.functional as F
10
+ from skimage.filters import threshold_li, threshold_yen, threshold_multiotsu
11
+ import numpy as np
12
+ from visualization_utils import show_tensors
13
+ import matplotlib.pyplot as plt
14
+
15
+ def text_to_tokens(text, tokenizer):
16
+ return [tokenizer.decode(x) for x in tokenizer(text, padding="longest", return_tensors="pt").input_ids[0]]
17
+
18
+ def flatten_list(l):
19
+ return [item for sublist in l for item in sublist]
20
+
21
+ def gaussian_blur(heatmap, kernel_size=7, sigma=0):
22
+ # Shape of heatmap: (H, W)
23
+ heatmap = heatmap.cpu().numpy()
24
+ heatmap = cv2.GaussianBlur(heatmap, (kernel_size, kernel_size), sigma)
25
+ heatmap = torch.tensor(heatmap)
26
+
27
+ return heatmap
28
+
29
+ def min_max_norm(x):
30
+ return (x - x.min()) / (x.max() - x.min())
31
+
32
+ class AttentionStore:
33
+ def __init__(self, prompts, tokenizer,
34
+ subject_token=None, record_attention_steps=[],
35
+ is_cache_attn_ratio=False, attn_ratios_steps=[5]):
36
+
37
+ self.text2image_store = {}
38
+ self.image2text_store = {}
39
+ self.count_per_layer = {}
40
+
41
+ self.record_attention_steps = record_attention_steps
42
+ self.record_attention_layers = ["transformer_blocks.13","transformer_blocks.14", "transformer_blocks.18", "single_transformer_blocks.23", "single_transformer_blocks.33"]
43
+
44
+ self.attention_ratios = {}
45
+ self._is_cache_attn_ratio = is_cache_attn_ratio
46
+ self.attn_ratios_steps = attn_ratios_steps
47
+ self.ratio_source = 'text'
48
+
49
+ self.max_tokens_to_record = 10
50
+
51
+ if isinstance(prompts, str):
52
+ prompts = [prompts]
53
+ batch_size = 1
54
+ else:
55
+ batch_size = len(prompts)
56
+
57
+ tokens_per_prompt = []
58
+
59
+ for prompt in prompts:
60
+ tokens = text_to_tokens(prompt, tokenizer)
61
+ tokens_per_prompt.append(tokens)
62
+
63
+ self.tokens_to_record = []
64
+ self.token_idxs_to_record = []
65
+
66
+ if len(record_attention_steps) > 0:
67
+ self.subject_tokens = flatten_list([text_to_tokens(x, tokenizer)[:-1] for x in [subject_token]])
68
+ self.subject_tokens_idx = [tokens_per_prompt[1].index(x) for x in self.subject_tokens]
69
+ self.add_token_idx = self.subject_tokens_idx[-1]
70
+
71
+ def is_record_attention(self, layer_name, step_index):
72
+ is_correct_layer = (self.record_attention_layers is None) or (layer_name in self.record_attention_layers)
73
+
74
+ record_attention = (step_index in self.record_attention_steps) and (is_correct_layer)
75
+
76
+ return record_attention
77
+
78
+ def store_attention(self, attention_probs, layer_name, batch_size, num_heads):
79
+ text_len = 512
80
+ timesteps = len(self.record_attention_steps)
81
+
82
+ # Split batch and heads
83
+ attention_probs = attention_probs.view(batch_size, num_heads, *attention_probs.shape[1:])
84
+
85
+ # Mean over the heads
86
+ attention_probs = attention_probs.mean(dim=1)
87
+
88
+ # Attention: text -> image
89
+ attention_probs_text2image = attention_probs[:, :text_len, text_len:]
90
+ attention_probs_text2image = [attention_probs_text2image[0, self.subject_tokens_idx, :]]
91
+
92
+ # Attention: image -> text
93
+ attention_probs_image2text = attention_probs[:, text_len:, :text_len].transpose(1,2)
94
+ attention_probs_image2text = [attention_probs_image2text[0, self.subject_tokens_idx, :]]
95
+
96
+ if layer_name not in self.text2image_store:
97
+ self.text2image_store[layer_name] = [x for x in attention_probs_text2image]
98
+ self.image2text_store[layer_name] = [x for x in attention_probs_image2text]
99
+ else:
100
+ self.text2image_store[layer_name] = [self.text2image_store[layer_name][i] + x for i, x in enumerate(attention_probs_text2image)]
101
+ self.image2text_store[layer_name] = [self.image2text_store[layer_name][i] + x for i, x in enumerate(attention_probs_image2text)]
102
+
103
+ def is_cache_attn_ratio(self, step_index):
104
+ return (self._is_cache_attn_ratio) and (step_index in self.attn_ratios_steps)
105
+
106
+ def store_attention_ratios(self, attention_probs, step_index, layer_name):
107
+ layer_prefix = layer_name.split(".")[0]
108
+
109
+ if self.ratio_source == 'pixels':
110
+ extended_attention_probs = attention_probs.mean(dim=0)[512:, :]
111
+ extended_attention_probs_source = extended_attention_probs[:,:4096].sum(dim=1).view(64,64).float().cpu()
112
+ extended_attention_probs_text = extended_attention_probs[:,4096:4096+512].sum(dim=1).view(64,64).float().cpu()
113
+ extended_attention_probs_target = extended_attention_probs[:,4096+512:].sum(dim=1).view(64,64).float().cpu()
114
+ token_attention = extended_attention_probs[:,4096+self.add_token_idx].view(64,64).float().cpu()
115
+
116
+ stacked_attention_ratios = torch.cat([extended_attention_probs_source, extended_attention_probs_text, extended_attention_probs_target, token_attention], dim=1)
117
+ elif self.ratio_source == 'text':
118
+ extended_attention_probs = attention_probs.mean(dim=0)[:512, :]
119
+ extended_attention_probs_source = extended_attention_probs[:,:4096].sum(dim=0).view(64,64).float().cpu()
120
+ extended_attention_probs_target = extended_attention_probs[:,4096+512:].sum(dim=0).view(64,64).float().cpu()
121
+
122
+ stacked_attention_ratios = torch.cat([extended_attention_probs_source, extended_attention_probs_target], dim=1)
123
+
124
+ if step_index not in self.attention_ratios:
125
+ self.attention_ratios[step_index] = {}
126
+
127
+ if layer_prefix not in self.attention_ratios[step_index]:
128
+ self.attention_ratios[step_index][layer_prefix] = []
129
+
130
+ self.attention_ratios[step_index][layer_prefix].append(stacked_attention_ratios)
131
+
132
+ def get_attention_ratios(self, step_indices=None, display_imgs=False):
133
+ ratios = []
134
+
135
+ if step_indices is None:
136
+ step_indices = list(self.attention_ratios.keys())
137
+
138
+ if len(step_indices) == 1:
139
+ steps = f"Step: {step_indices[0]}"
140
+ else:
141
+ steps = f"Steps: [{step_indices[0]}-{step_indices[-1]}]"
142
+
143
+ layer_prefixes = list(self.attention_ratios[step_indices[0]].keys())
144
+ scores_per_layer = {}
145
+
146
+ for layer_prefix in layer_prefixes:
147
+ ratios = []
148
+
149
+ for step_index in step_indices:
150
+ if layer_prefix in self.attention_ratios[step_index]:
151
+ step_ratios = self.attention_ratios[step_index][layer_prefix]
152
+ step_ratios = torch.stack(step_ratios).mean(dim=0)
153
+ ratios.append(step_ratios)
154
+
155
+ # Mean over the steps
156
+ ratios = torch.stack(ratios).mean(dim=0)
157
+
158
+ if self.ratio_source == 'pixels':
159
+ source, text, target, token = torch.split(ratios, 64, dim=1)
160
+ title = f"{steps}: Source={source.sum().item():.2f}, Text={text.sum().item():.2f}, Target={target.sum().item():.2f}, Token={token.sum().item():.2f}"
161
+ ratios = min_max_norm(torch.cat([source, text, target], dim=1))
162
+ token = min_max_norm(token)
163
+ ratios = torch.cat([ratios, token], dim=1)
164
+ elif self.ratio_source == 'text':
165
+ source, target = torch.split(ratios, 64, dim=1)
166
+ source_sum = source.sum().item()
167
+ target_sum = target.sum().item()
168
+ text_sum = 512 - (source_sum + target_sum)
169
+
170
+ title = f"{steps}: Source={source_sum:.2f}, Target={target_sum:.2f}"
171
+ ratios = min_max_norm(torch.cat([source, target], dim=1))
172
+
173
+ if display_imgs:
174
+ print(f"Layer: {layer_prefix}")
175
+ show_tensors([ratios], [title])
176
+
177
+ scores_per_layer[layer_prefix] = (source_sum, text_sum, target_sum)
178
+
179
+ return scores_per_layer
180
+
181
+ def plot_attention_ratios(self, step_indices=None):
182
+ steps = list(self.attention_ratios.keys())
183
+ score_per_layer = {
184
+ 'transformer_blocks': {},
185
+ 'single_transformer_blocks': {}
186
+ }
187
+
188
+ for i in steps:
189
+ scores_per_layer = self.get_attention_ratios(step_indices=[i], display_imgs=False)
190
+
191
+ for layer in self.attention_ratios[i]:
192
+ source, text, target = scores_per_layer[layer]
193
+ score_per_layer[layer][i] = (source, text, target)
194
+
195
+ for layer_type in score_per_layer:
196
+ x = list(score_per_layer[layer_type].keys())
197
+ source_sums = [x[0] for x in score_per_layer[layer_type].values()]
198
+ text_sums = [x[1] for x in score_per_layer[layer_type].values()]
199
+ target_sums = [x[2] for x in score_per_layer[layer_type].values()]
200
+
201
+ # Calculate the total sums for each stack (source + text + target)
202
+ total_sums = [source_sums[j] + text_sums[j] + target_sums[j] for j in range(len(source_sums))]
203
+
204
+ # Create stacked bar plots
205
+ fig, ax = plt.subplots(figsize=(10, 6))
206
+ indices = np.arange(len(x))
207
+
208
+ # Plot source at the bottom
209
+ ax.bar(indices, source_sums, label='Source', color='#6A2C70')
210
+
211
+ # Plot text stacked on source
212
+ ax.bar(indices, text_sums, label='Text', color='#B83B5E', bottom=source_sums)
213
+
214
+ # Plot target stacked on text + source
215
+ target_bottom = [source_sums[j] + text_sums[j] for j in range(len(source_sums))]
216
+ ax.bar(indices, target_sums, label='Target', color='#F08A5D', bottom=target_bottom)
217
+
218
+ # Annotate bars with percentage values
219
+ for j, index in enumerate(indices):
220
+
221
+ font_size = 12
222
+
223
+ # Source percentage
224
+ source_percentage = 100 * source_sums[j] / total_sums[j]
225
+ ax.text(index, source_sums[j] / 2, f'{source_percentage:.1f}%',
226
+ ha='center', va='center', rotation=90, color='white',
227
+ fontsize=font_size, fontweight='bold')
228
+
229
+ # Text percentage
230
+ text_percentage = 100 * text_sums[j] / total_sums[j]
231
+ ax.text(index, source_sums[j] + (text_sums[j] / 2), f'{text_percentage:.1f}%',
232
+ ha='center', va='center', rotation=90, color='white',
233
+ fontsize=font_size, fontweight='bold')
234
+
235
+ # Target percentage
236
+ target_percentage = 100 * target_sums[j] / total_sums[j]
237
+ ax.text(index, source_sums[j] + text_sums[j] + (target_sums[j] / 2), f'{target_percentage:.1f}%',
238
+ ha='center', va='center', rotation=90, color='white',
239
+ fontsize=font_size, fontweight='bold')
240
+
241
+
242
+ ax.set_xlabel('Step Index')
243
+ ax.set_ylabel('Attention Ratio')
244
+ ax.set_title(f'Attention Ratios for {layer_type}')
245
+ ax.set_xticks(indices)
246
+ ax.set_xticklabels(x)
247
+
248
+ plt.legend()
249
+ plt.show()
250
+
251
+ def aggregate_attention(self, store, target_layers=None, resolution=None,
252
+ gaussian_kernel=3, thr_type='otsu', thr_number=0.5):
253
+ if target_layers is None:
254
+ store_vals = list(store.values())
255
+ elif isinstance(target_layers, list):
256
+ store_vals = [store[x] for x in target_layers]
257
+ else:
258
+ raise ValueError("target_layers must be a list of layer names or None.")
259
+
260
+ # store vals = List[layers] of Tensor[batch_size, text_tokens, image_tokens]
261
+ batch_size = len(store_vals[0])
262
+
263
+ attention_maps = []
264
+ attention_masks = []
265
+
266
+ for i in range(batch_size):
267
+ # Average over the layers
268
+ agg_vals = torch.stack([x[i] for x in store_vals]).mean(dim=0)
269
+
270
+ if resolution is None:
271
+ size = int(agg_vals.shape[-1] ** 0.5)
272
+ resolution = (size, size)
273
+
274
+ agg_vals = agg_vals.view(agg_vals.shape[0], *resolution)
275
+
276
+ if gaussian_kernel > 0:
277
+ agg_vals = torch.stack([gaussian_blur(x.float(), kernel_size=gaussian_kernel) for x in agg_vals]).to(agg_vals.dtype)
278
+
279
+ mask_vals = agg_vals.clone()
280
+
281
+ for j in range(mask_vals.shape[0]):
282
+ mask_vals[j] = (mask_vals[j] - mask_vals[j].min()) / (mask_vals[j].max() - mask_vals[j].min())
283
+ np_vals = mask_vals[j].float().cpu().numpy()
284
+
285
+ otsu_thr = filters.threshold_otsu(np_vals)
286
+ li_thr = threshold_li(np_vals, initial_guess=otsu_thr)
287
+ yen_thr = threshold_yen(np_vals)
288
+
289
+ if thr_type == 'otsu':
290
+ thr = otsu_thr
291
+ elif thr_type == 'yen':
292
+ thr = yen_thr
293
+ elif thr_type == 'li':
294
+ thr = li_thr
295
+ elif thr_type == 'number':
296
+ thr = thr_number
297
+ elif thr_type == 'multiotsu':
298
+ thrs = threshold_multiotsu(np_vals, classes=3)
299
+
300
+ if thrs[1] > thrs[0] * 3.5:
301
+ thr = thrs[1]
302
+ else:
303
+ thr = thrs[0]
304
+
305
+ # Take the closest threshold to otsu_thr
306
+ # thr = thrs[np.argmin(np.abs(thrs - otsu_thr))]
307
+
308
+ # alpha = 0.8
309
+ # thr = (alpha * thr + (1-alpha) * mask_vals[j].max())
310
+
311
+ mask_vals[j] = (mask_vals[j] > thr).to(mask_vals[j].dtype)
312
+
313
+ attention_maps.append(agg_vals)
314
+ attention_masks.append(mask_vals)
315
+
316
+ return attention_maps, attention_masks, self.tokens_to_record
addit_blending_utils.py ADDED
@@ -0,0 +1,232 @@
1
+ # Copyright (C) 2025 NVIDIA Corporation. All rights reserved.
2
+ #
3
+ # This work is licensed under the LICENSE file
4
+ # located at the root directory.
5
+
6
+ import torch
7
+ import numpy as np
8
+ import torch.nn.functional as F
9
+ from skimage import filters
10
+ import matplotlib.pyplot as plt
11
+ from scipy.ndimage import maximum_filter, label, find_objects
12
+
13
+ def dilate_mask(latents_mask, k, latents_dtype):
14
+ # Reshape the mask to 2D (64x64)
15
+ mask_2d = latents_mask.view(64, 64)
16
+
17
+ # Create a square kernel for dilation
18
+ kernel = torch.ones(2*k+1, 2*k+1, device=mask_2d.device, dtype=mask_2d.dtype)
19
+
20
+ # Add two dimensions to make it compatible with conv2d
21
+ mask_4d = mask_2d.unsqueeze(0).unsqueeze(0)
22
+
23
+ # Perform dilation using conv2d
24
+ dilated_mask = F.conv2d(mask_4d, kernel.unsqueeze(0).unsqueeze(0), padding=k)
25
+
26
+ # Threshold the result to get a binary mask
27
+ dilated_mask = (dilated_mask > 0).to(mask_2d.dtype)
28
+
29
+ # Reshape back to the original shape and convert to the desired dtype
30
+ dilated_mask = dilated_mask.view(4096, 1).to(latents_dtype)
31
+
32
+ return dilated_mask
33
+
34
+ def clipseg_predict(model, processor, image, text, device):
35
+ inputs = processor(text=text, images=image, return_tensors="pt")
36
+ inputs = {k: v.to(device) for k, v in inputs.items()}
37
+
38
+ with torch.no_grad():
39
+ outputs = model(**inputs)
40
+ preds = outputs.logits.unsqueeze(1)
41
+ preds = torch.sigmoid(preds)
42
+
43
+ otsu_thr = filters.threshold_otsu(preds.cpu().numpy())
44
+ subject_mask = (preds > otsu_thr).float()
45
+
46
+ return subject_mask
47
+
48
+ def grounding_sam_predict(model, processor, sam_predictor, image, text, device):
49
+ inputs = processor(images=image, text=text, return_tensors="pt").to(device)
50
+ with torch.no_grad():
51
+ outputs = model(**inputs)
52
+
53
+ results = processor.post_process_grounded_object_detection(
54
+ outputs,
55
+ inputs.input_ids,
56
+ box_threshold=0.4,
57
+ text_threshold=0.3,
58
+ target_sizes=[image.size[::-1]]
59
+ )
60
+
61
+ input_boxes = results[0]["boxes"].cpu().numpy()
62
+
63
+ if input_boxes.shape[0] == 0:
64
+ return torch.ones((64, 64), device=device)
65
+
66
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
67
+ sam_predictor.set_image(image)
68
+ masks, scores, logits = sam_predictor.predict(
69
+ point_coords=None,
70
+ point_labels=None,
71
+ box=input_boxes,
72
+ multimask_output=False,
73
+ )
74
+
75
+ subject_mask = torch.tensor(masks[0], device=device)
76
+
77
+ return subject_mask
78
+
79
+ def mask_to_box_sam_predict(mask, sam_predictor, image, text, device):
80
+ H, W = image.size
81
+
82
+ # Resize clipseg mask to image size
83
+ mask = F.interpolate(mask.view(1, 1, mask.shape[-2], mask.shape[-1]), size=(H, W), mode='bilinear').view(H, W)
84
+ mask_indices = torch.nonzero(mask)
85
+ top_left = mask_indices.min(dim=0)[0]
86
+ bottom_right = mask_indices.max(dim=0)[0]
87
+
88
+ # numpy shape [1,4]
89
+ input_boxes = np.array([[top_left[1].item(), top_left[0].item(), bottom_right[1].item(), bottom_right[0].item()]])
90
+
91
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
92
+ sam_predictor.set_image(image)
93
+ masks, scores, logits = sam_predictor.predict(
94
+ point_coords=None,
95
+ point_labels=None,
96
+ box=input_boxes,
97
+ multimask_output=True,
98
+ )
99
+
100
+ # subject_mask = torch.tensor(masks[0], device=device)
101
+ subject_mask = torch.tensor(np.max(masks, axis=0), device=device)
102
+
103
+ return subject_mask, input_boxes[0]
104
+
105
+ def mask_to_mask_sam_predict(mask, sam_predictor, image, text, device):
106
+ H, W = (256, 256)
107
+
108
+ # Resize clipseg mask to image size
109
+ mask = F.interpolate(mask.view(1, 1, mask.shape[-2], mask.shape[-1]), size=(H, W), mode='bilinear').view(1, H, W)
110
+ mask_input = mask.float().cpu().numpy()
111
+
112
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
113
+ sam_predictor.set_image(image)
114
+ masks, scores, logits = sam_predictor.predict(
115
+ point_coords=None,
116
+ point_labels=None,
117
+ mask_input=mask_input,
118
+ multimask_output=False,
119
+ )
120
+
121
+ subject_mask = torch.tensor(masks[0], device=device)
122
+
123
+ return subject_mask
124
+
125
+ def mask_to_points_sam_predict(mask, sam_predictor, image, text, device):
126
+ H, W = image.size
127
+
128
+ # Resize clipseg mask to image size
129
+ mask = F.interpolate(mask.view(1, 1, mask.shape[-2], mask.shape[-1]), size=(H, W), mode='bilinear').view(H, W)
130
+ mask_indices = torch.nonzero(mask)
131
+
132
+ # Randomly sample 10 points from the mask
133
+ n_points = 2
134
+ point_coords = mask_indices[torch.randperm(mask_indices.shape[0])[:n_points]].float().cpu().numpy()
135
+ point_labels = torch.ones((n_points,)).float().cpu().numpy()
136
+
137
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
138
+ sam_predictor.set_image(image)
139
+ masks, scores, logits = sam_predictor.predict(
140
+ point_coords=point_coords,
141
+ point_labels=point_labels,
142
+ multimask_output=False,
143
+ )
144
+
145
+ subject_mask = torch.tensor(masks[0], device=device)
146
+
147
+ return subject_mask
148
+
149
+ def attention_to_points_sam_predict(subject_attention, subject_mask, sam_predictor, image, text, device):
150
+ H, W = image.size
151
+
152
+ # Resize clipseg mask to image size
153
+ subject_attention = F.interpolate(subject_attention.view(1, 1, subject_attention.shape[-2], subject_attention.shape[-1]), size=(H, W), mode='bilinear').view(H, W)
154
+ subject_mask = F.interpolate(subject_mask.view(1, 1, subject_mask.shape[-2], subject_mask.shape[-1]), size=(H, W), mode='bilinear').view(H, W)
155
+
156
+ # Get mask_bbox
157
+ subject_mask_indices = torch.nonzero(subject_mask)
158
+ top_left = subject_mask_indices.min(dim=0)[0]
159
+ bottom_right = subject_mask_indices.max(dim=0)[0]
160
+ box_width = bottom_right[1] - top_left[1]
161
+ box_height = bottom_right[0] - top_left[0]
162
+
163
+ # Define the number of points and minimum distance between points
164
+ n_points = 3
165
+ max_thr = 0.35
166
+ max_attention = torch.max(subject_attention)
167
+ min_distance = max(box_width, box_height) // (n_points + 1) # Adjust this value to control spread
168
+ # min_distance = max(min_distance, 75)
169
+
170
+ # Initialize list to store selected points
171
+ selected_points = []
172
+
173
+ # Create a copy of the attention map
174
+ remaining_attention = subject_attention.clone()
175
+
176
+ for _ in range(n_points):
177
+ if remaining_attention.max() < max_thr * max_attention:
178
+ break
179
+
180
+ # Find the highest attention point
181
+ point = torch.argmax(remaining_attention)
182
+ y, x = torch.unravel_index(point, remaining_attention.shape)
183
+ y, x = y.item(), x.item()
184
+
185
+ # Add the point to our list
186
+ selected_points.append((x, y))
187
+
188
+ # Zero out the area around the selected point
189
+ y_min = max(0, y - min_distance)
190
+ y_max = min(H, y + min_distance + 1)
191
+ x_min = max(0, x - min_distance)
192
+ x_max = min(W, x + min_distance + 1)
193
+ remaining_attention[y_min:y_max, x_min:x_max] = 0
194
+
195
+ # Convert selected points to numpy array
196
+ point_coords = np.array(selected_points)
197
+ point_labels = np.ones(point_coords.shape[0], dtype=int)
198
+
199
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
200
+ sam_predictor.set_image(image)
201
+ masks, scores, logits = sam_predictor.predict(
202
+ point_coords=point_coords,
203
+ point_labels=point_labels,
204
+ multimask_output=False,
205
+ )
206
+
207
+ subject_mask = torch.tensor(masks[0], device=device)
208
+
209
+ return subject_mask, point_coords
210
+
211
+ def sam_refine_step(mask, sam_predictor, image, device):
212
+ mask_indices = torch.nonzero(mask)
213
+ top_left = mask_indices.min(dim=0)[0]
214
+ bottom_right = mask_indices.max(dim=0)[0]
215
+
216
+ # numpy shape [1,4]
217
+ input_boxes = np.array([[top_left[1].item(), top_left[0].item(), bottom_right[1].item(), bottom_right[0].item()]])
218
+
219
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
220
+ sam_predictor.set_image(image)
221
+ masks, scores, logits = sam_predictor.predict(
222
+ point_coords=None,
223
+ point_labels=None,
224
+ box=input_boxes,
225
+ multimask_output=True,
226
+ )
227
+
228
+ # subject_mask = torch.tensor(masks[0], device=device)
229
+ subject_mask = torch.tensor(np.max(masks, axis=0), device=device)
230
+
231
+ return subject_mask, input_boxes[0]
232
+
addit_flux_pipeline.py ADDED
@@ -0,0 +1,1389 @@
1
+ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # Copyright (C) 2025 NVIDIA Corporation. All rights reserved.
16
+ #
17
+ # This work is licensed under the LICENSE file
18
+ # located at the root directory.
19
+
20
+ from typing import Any, Callable, Dict, List, Optional, Union
21
+ import torch
22
+ import numpy as np
23
+ from PIL import Image
24
+ from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, calculate_shift, retrieve_timesteps
25
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
26
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
27
+ from diffusers.utils.torch_utils import randn_tensor
28
+ import matplotlib.pyplot as plt
29
+
30
+ import torch.fft
31
+ import torch.nn.functional as F
32
+
33
+ from diffusers.models.attention_processor import FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0
34
+ from addit_attention_processors import AdditFluxAttnProcessor2_0, AdditFluxSingleAttnProcessor2_0
35
+ from addit_attention_store import AttentionStore
36
+
37
+ from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
38
+ from skimage import filters
39
+ from visualization_utils import show_image_and_heatmap, show_images, draw_points_on_pil_image, draw_bboxes_on_image
40
+ from addit_blending_utils import clipseg_predict, grounding_sam_predict, mask_to_box_sam_predict, \
41
+ mask_to_mask_sam_predict, attention_to_points_sam_predict
42
+
43
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
44
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
45
+
46
+ from scipy.optimize import brentq
47
+ from scipy.optimize import root_scalar
48
+
49
+ def register_my_attention_processors(transformer, attention_store, extended_steps_multi, extended_steps_single):
50
+ attn_procs = {}
51
+
52
+ for i, (name, processor) in enumerate(transformer.attn_processors.items()):
53
+ layer_name = ".".join(name.split(".")[:2])
54
+
55
+ if layer_name.startswith("transformer_blocks"):
56
+ attn_procs[name] = AdditFluxAttnProcessor2_0(layer_name=layer_name,
57
+ attention_store=attention_store,
58
+ extended_steps=extended_steps_multi)
59
+ elif layer_name.startswith("single_transformer_blocks"):
60
+ attn_procs[name] = AdditFluxSingleAttnProcessor2_0(layer_name=layer_name,
61
+ attention_store=attention_store,
62
+ extended_steps=extended_steps_single)
63
+
64
+ transformer.set_attn_processor(attn_procs)
65
+
66
+ def register_regular_attention_processors(transformer):
67
+ attn_procs = {}
68
+
69
+ for i, (name, processor) in enumerate(transformer.attn_processors.items()):
70
+ layer_name = ".".join(name.split(".")[:2])
71
+
72
+ if layer_name.startswith("transformer_blocks"):
73
+ attn_procs[name] = FluxAttnProcessor2_0()
74
+ elif layer_name.startswith("single_transformer_blocks"):
75
+ attn_procs[name] = FluxSingleAttnProcessor2_0()
76
+
77
+ transformer.set_attn_processor(attn_procs)
78
+
79
+ def img2img_retrieve_latents(
80
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
81
+ ):
82
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
83
+ return encoder_output.latent_dist.sample(generator)
84
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
85
+ return encoder_output.latent_dist.mode()
86
+ elif hasattr(encoder_output, "latents"):
87
+ return encoder_output.latents
88
+ else:
89
+ raise AttributeError("Could not access latents of provided encoder_output")
90
+
91
+ class AdditFluxPipeline(FluxPipeline):
92
+ def prepare_latents(
93
+ self,
94
+ batch_size,
95
+ num_channels_latents,
96
+ height,
97
+ width,
98
+ dtype,
99
+ device,
100
+ generator,
101
+ latents=None,
102
+ ):
103
+ height = 2 * (int(height) // self.vae_scale_factor)
104
+ width = 2 * (int(width) // self.vae_scale_factor)
105
+
106
+ shape = (batch_size, num_channels_latents, height, width)
107
+
108
+ if latents is not None:
109
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
110
+ return latents.to(device=device, dtype=dtype), latent_image_ids
111
+
112
+ if isinstance(generator, list) and len(generator) != batch_size:
113
+ raise ValueError(
114
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
115
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
116
+ )
117
+
118
+ if isinstance(generator, list):
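+             # Each generator draws a full-batch sample and only its own row is kept, so batch
+             # element i depends solely on seed i (useful for per-prompt seeds).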
119
+ latents = torch.empty(shape, device=device, dtype=dtype)
120
+
121
+ latents_list = [randn_tensor(shape, generator=g, device=device, dtype=dtype) for g in generator]
122
+
123
+ for i, l_i in enumerate(latents_list):
124
+ latents[i] = l_i[i]
125
+ else:
126
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
127
+
128
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
129
+
130
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
131
+
132
+ return latents, latent_image_ids
133
+
134
+ @torch.no_grad()
135
+ def __call__(
136
+ self,
137
+ prompt: Union[str, List[str]] = None,
138
+ prompt_2: Optional[Union[str, List[str]]] = None,
139
+ height: Optional[int] = None,
140
+ width: Optional[int] = None,
141
+ num_inference_steps: int = 28,
142
+ timesteps: List[int] = None,
143
+ guidance_scale: Union[float, List[float]] = 7.0,
144
+ num_images_per_prompt: Optional[int] = 1,
145
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
146
+ latents: Optional[torch.FloatTensor] = None,
147
+ prompt_embeds: Optional[torch.FloatTensor] = None,
148
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
149
+ output_type: Optional[str] = "pil",
150
+ return_dict: bool = True,
151
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
152
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
153
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
154
+ max_sequence_length: int = 512,
155
+
156
+ seed: Optional[Union[int, List[int]]] = None,
157
+ same_latent_for_all_prompts: bool = False,
158
+
159
+ # Extended Attention
160
+ extended_steps_multi: Optional[int] = -1,
161
+ extended_steps_single: Optional[int] = -1,
162
+ extended_scale: Optional[Union[float, str]] = 1.0,
163
+
164
+ # Structure Transfer
165
+ source_latents: Optional[torch.FloatTensor] = None,
166
+ structure_transfer_step: int = 5,
167
+
168
+ # Latent Blending
169
+ subject_token: Optional[str] = None,
170
+ localization_model: Optional[str] = "attention_points_sam",
171
+ blend_steps: List[int] = [],
172
+ show_attention: bool = False,
173
+
174
+ # Real Image Source
175
+ is_img_src: bool = False,
176
+ use_offset: bool = False,
177
+ img_src_latents: Optional[List[torch.FloatTensor]] = None,
178
+ ):
179
+ r"""
180
+ Function invoked when calling the pipeline for generation.
181
+
182
+ Args:
183
+ prompt (`str` or `List[str]`, *optional*):
184
+                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
185
+ instead.
186
+ prompt_2 (`str` or `List[str]`, *optional*):
187
+                 The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
188
+                 will be used instead.
189
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
190
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
191
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
192
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
193
+             num_inference_steps (`int`, *optional*, defaults to 28):
194
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
195
+ expense of slower inference.
196
+ timesteps (`List[int]`, *optional*):
197
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
198
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
199
+ passed will be used. Must be in descending order.
200
+ guidance_scale (`float`, *optional*, defaults to 7.0):
201
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
202
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
203
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
204
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
205
+ usually at the expense of lower image quality.
206
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
207
+ The number of images to generate per prompt.
208
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
209
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
210
+ to make generation deterministic.
211
+ latents (`torch.FloatTensor`, *optional*):
212
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
213
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
214
+                 tensor will be generated by sampling using the supplied random `generator`.
215
+ prompt_embeds (`torch.FloatTensor`, *optional*):
216
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
217
+ provided, text embeddings will be generated from `prompt` input argument.
218
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
219
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
220
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
221
+ output_type (`str`, *optional*, defaults to `"pil"`):
222
+                 The output format of the generated image. Choose between
223
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
224
+ return_dict (`bool`, *optional*, defaults to `True`):
225
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
226
+ joint_attention_kwargs (`dict`, *optional*):
227
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
228
+ `self.processor` in
229
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
230
+ callback_on_step_end (`Callable`, *optional*):
231
+                 A function that is called at the end of each denoising step during inference. The function is called
232
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
233
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
234
+ `callback_on_step_end_tensor_inputs`.
235
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
236
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
237
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
238
+ `._callback_tensor_inputs` attribute of your pipeline class.
239
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
240
+
241
+ Examples:
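+             A minimal usage sketch for adding an object to a generated image (the prompt pair,
+             token and step values below are illustrative, not tuned defaults):
+
+             >>> out = pipe(
+             ...     prompt=["A photo of a cat", "A photo of a cat wearing a red hat"],
+             ...     subject_token="hat", seed=0, blend_steps=[15],
+             ... )
+             >>> out.images[1]  # index 1 holds the edited (target) image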
242
+
243
+ Returns:
244
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
245
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
246
+ images.
247
+ """
248
+
249
+ device = self._execution_device
250
+
251
+ # Blend Steps
252
+ blend_models = {}
253
+ if len(blend_steps) > 0:
254
+ if localization_model == "clipseg":
255
+ blend_models["clipseg_processor"] = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
256
+ blend_models["clipseg_model"] = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)
257
+ elif localization_model == "grounding_sam":
258
+ grounding_dino_model_id = "IDEA-Research/grounding-dino-base"
259
+ blend_models["grounding_processor"] = AutoProcessor.from_pretrained(grounding_dino_model_id)
260
+ blend_models["grounding_model"] = AutoModelForZeroShotObjectDetection.from_pretrained(grounding_dino_model_id).to(device)
261
+ blend_models["sam_predictor"] = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
262
+ elif localization_model == "clipseg_sam":
263
+ blend_models["clipseg_processor"] = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
264
+ blend_models["clipseg_model"] = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)
265
+ blend_models["sam_predictor"] = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
266
+ elif localization_model == "attention":
267
+ pass
268
+ elif localization_model in ["attention_box_sam", "attention_mask_sam", "attention_points_sam"]:
269
+ blend_models["sam_predictor"] = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
270
+
271
+ height = height or self.default_sample_size * self.vae_scale_factor
272
+ width = width or self.default_sample_size * self.vae_scale_factor
273
+
274
+ # 1. Check inputs. Raise error if not correct
275
+ self.check_inputs(
276
+ prompt,
277
+ prompt_2,
278
+ height,
279
+ width,
280
+ prompt_embeds=prompt_embeds,
281
+ pooled_prompt_embeds=pooled_prompt_embeds,
282
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
283
+ max_sequence_length=max_sequence_length,
284
+ )
285
+
286
+ self._guidance_scale = guidance_scale
287
+ self._joint_attention_kwargs = joint_attention_kwargs
288
+ self._interrupt = False
289
+
290
+ # 2. Define call parameters
291
+ if prompt is not None and isinstance(prompt, str):
292
+ batch_size = 1
293
+ elif prompt is not None and isinstance(prompt, list):
294
+ batch_size = len(prompt)
295
+ else:
296
+ batch_size = prompt_embeds.shape[0]
297
+
298
+ device = self._execution_device
299
+
300
+ lora_scale = (
301
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
302
+ )
303
+ (
304
+ prompt_embeds,
305
+ pooled_prompt_embeds,
306
+ text_ids,
307
+ ) = self.encode_prompt(
308
+ prompt=prompt,
309
+ prompt_2=prompt_2,
310
+ prompt_embeds=prompt_embeds,
311
+ pooled_prompt_embeds=pooled_prompt_embeds,
312
+ device=device,
313
+ num_images_per_prompt=num_images_per_prompt,
314
+ max_sequence_length=max_sequence_length,
315
+ lora_scale=lora_scale,
316
+ )
317
+
318
+ # 4. Prepare latent variables
319
+ if (generator is None) and seed is not None:
320
+ if isinstance(seed, int):
321
+ generator = torch.Generator(device=device).manual_seed(seed)
322
+ else:
323
+ assert len(seed) == batch_size, "The number of seeds must match the batch size"
324
+ generator = [torch.Generator(device=device).manual_seed(s) for s in seed]
325
+
326
+ num_channels_latents = self.transformer.config.in_channels // 4
327
+ latents, latent_image_ids = self.prepare_latents(
328
+ batch_size * num_images_per_prompt,
329
+ num_channels_latents,
330
+ height,
331
+ width,
332
+ prompt_embeds.dtype,
333
+ device,
334
+ generator,
335
+ latents,
336
+ )
337
+
338
+ if same_latent_for_all_prompts:
339
+ latents = latents[:1].repeat(batch_size * num_images_per_prompt, 1, 1)
340
+
341
+ noise = latents.clone()
342
+
343
+ attention_store_kwargs = {}
344
+
345
+ if extended_scale == "auto":
346
+ is_auto_extend_scale = True
347
+ extended_scale = 1.05
348
+ attention_store_kwargs["is_cache_attn_ratio"] = True
349
+ auto_extended_step = 5
350
+ target_auto_ratio = 1.05
351
+ else:
352
+ is_auto_extend_scale = False
353
+
354
+ if len(blend_steps) > 0:
355
+ attn_steps = range(blend_steps[0] - 2, blend_steps[0] + 1)
356
+ attention_store_kwargs["record_attention_steps"] = attn_steps
357
+
358
+ self.attention_store = AttentionStore(prompts=prompt, tokenizer=self.tokenizer_2, subject_token=subject_token, **attention_store_kwargs)
359
+ register_my_attention_processors(self.transformer, self.attention_store, extended_steps_multi, extended_steps_single)
360
+
361
+ # 5. Prepare timesteps
362
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
363
+ image_seq_len = latents.shape[1]
364
+ mu = calculate_shift(
365
+ image_seq_len,
366
+ self.scheduler.config.base_image_seq_len,
367
+ self.scheduler.config.max_image_seq_len,
368
+ self.scheduler.config.base_shift,
369
+ self.scheduler.config.max_shift,
370
+ )
371
+ timesteps, num_inference_steps = retrieve_timesteps(
372
+ self.scheduler,
373
+ num_inference_steps,
374
+ device,
375
+ timesteps,
376
+ sigmas,
377
+ mu=mu,
378
+ )
379
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
380
+ self._num_timesteps = len(timesteps)
381
+
382
+ # handle guidance
383
+ if self.transformer.config.guidance_embeds:
384
+ if isinstance(guidance_scale, float):
385
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
386
+ guidance = guidance.expand(latents.shape[0])
387
+ elif isinstance(guidance_scale, list):
388
+ assert len(guidance_scale) == latents.shape[0], "The number of guidance scales must match the batch size"
389
+ guidance = torch.tensor(guidance_scale, device=device, dtype=torch.float32)
390
+ else:
391
+ guidance = None
392
+
393
+ if is_img_src and img_src_latents is None:
394
+ assert source_latents is not None, "source_latents must be provided when is_img_src is True"
395
+
396
+ rand_noise = noise[0].clone()
397
+ img_src_latents = []
398
+
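+             # Precompute a noised version of the source latents for every timestep using the
+             # rectified-flow interpolation x_t = (1 - sigma_t) * x_0 + sigma_t * noise; these are
+             # re-injected into the source branch (latents[0]) at each denoising step below.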
399
+ for i in range(timesteps.shape[0]):
400
+ sigma = self.scheduler.sigmas[i]
401
+ img_src_latents.append((1.0 - sigma) * source_latents[0] + sigma * rand_noise)
402
+
403
+ # 6. Denoising loop
404
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
405
+ for i, t in enumerate(timesteps):
406
+ if self.interrupt:
407
+ continue
408
+
409
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
410
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
411
+
412
+ # For denoising from source image
413
+ if is_img_src:
414
+ latents[0] = img_src_latents[i]
415
+
416
+ # For Structure Transfer
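+                 # At structure_transfer_step the target branch (latents[1]) is re-initialized from the
+                 # source latents at the current noise level, so it inherits the source image's layout.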
417
+ if (source_latents is not None) and i == structure_transfer_step:
418
+ sigma = self.scheduler.sigmas[i]
419
+ latents[1] = (1.0 - sigma) * source_latents[0] + sigma * noise[1]
420
+
421
+ if is_auto_extend_scale and i == auto_extended_step:
422
+ def f(gamma):
423
+ self.attention_store.attention_ratios[i] = {}
424
+ noise_pred = self.transformer(
425
+ hidden_states=latents,
426
+ timestep=timestep / 1000,
427
+ guidance=guidance,
428
+ pooled_projections=pooled_prompt_embeds,
429
+ encoder_hidden_states=prompt_embeds,
430
+ txt_ids=text_ids,
431
+ img_ids=latent_image_ids,
432
+ joint_attention_kwargs=self.joint_attention_kwargs,
433
+ return_dict=False,
434
+ proccesor_kwargs={"step_index": i, "extended_scale": gamma},
435
+ )[0]
436
+
437
+ scores_per_layer = self.attention_store.get_attention_ratios(step_indices=[i], display_imgs=False)
438
+ source_sum, text_sum, target_sum = scores_per_layer['transformer_blocks']
439
+
440
+ # We want to find the gamma that makes the ratio equal to K
441
+ ratio = (target_sum / source_sum)
442
+ return (ratio - target_auto_ratio)
443
+
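+                     # Brent's method searches gamma in [1.0, 1.2] for a root of f, i.e. the extended-attention
+                     # scale whose target/source attention ratio matches target_auto_ratio (each evaluation of f
+                     # runs one full transformer forward pass).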
444
+ gamma_sol = brentq(f, 1.0, 1.2, xtol=0.01)
445
+
446
+ print('Chosen gamma:', gamma_sol)
447
+ extended_scale = gamma_sol
448
+ else:
449
+ noise_pred = self.transformer(
450
+ hidden_states=latents,
451
+ timestep=timestep / 1000,
452
+ guidance=guidance,
453
+ pooled_projections=pooled_prompt_embeds,
454
+ encoder_hidden_states=prompt_embeds,
455
+ txt_ids=text_ids,
456
+ img_ids=latent_image_ids,
457
+ joint_attention_kwargs=self.joint_attention_kwargs,
458
+ return_dict=False,
459
+ proccesor_kwargs={"step_index": i, "extended_scale": extended_scale},
460
+ )[0]
461
+
462
+ # compute the previous noisy sample x_t -> x_t-1
463
+ latents_dtype = latents.dtype
464
+ latents, x0 = self.scheduler.step(noise_pred, t, latents, return_dict=False, step_index=i)
465
+
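+                 # Keep the edited branch aligned with the precomputed source trajectory: the correction
+                 # that puts latents[0] back onto img_src_latents is also applied to latents[1].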
466
+ if use_offset and is_img_src and (i+1 < len(img_src_latents)):
467
+ next_latent = img_src_latents[i+1]
468
+ offset = (next_latent - latents[0])
469
+ latents[1] = latents[1] + offset
470
+
471
+ # blend latents
472
+ if i in blend_steps and (subject_token is not None) and (localization_model is not None):
473
+ x0 = self._unpack_latents(x0, height, width, self.vae_scale_factor)
474
+ x0 = (x0 / self.vae.config.scaling_factor) + self.vae.config.shift_factor
475
+ images = self.vae.decode(x0, return_dict=False)[0]
476
+ images = self.image_processor.postprocess(images, output_type="pil")
477
+
478
+ self.do_step_blend(images, latents, subject_token, localization_model, show_attention, i, blend_models)
479
+
480
+ if latents.dtype != latents_dtype:
481
+ if torch.backends.mps.is_available():
482
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
483
+ latents = latents.to(latents_dtype)
484
+
485
+ if callback_on_step_end is not None:
486
+ callback_kwargs = {}
487
+ for k in callback_on_step_end_tensor_inputs:
488
+ callback_kwargs[k] = locals()[k]
489
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
490
+
491
+ latents = callback_outputs.pop("latents", latents)
492
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
493
+
494
+ # call the callback, if provided
495
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
496
+ progress_bar.update()
497
+
498
+ # if XLA_AVAILABLE:
499
+ # xm.mark_step()
500
+
501
+ if output_type == "latent":
502
+ image = latents
503
+ elif output_type == "both":
504
+ return_latents = latents
505
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
506
+
507
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
508
+ image = self.vae.decode(latents, return_dict=False)[0]
509
+ image = self.image_processor.postprocess(image, output_type="pil")
510
+
511
+ return (image, return_latents)
512
+ else:
513
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
514
+
515
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
516
+ image = self.vae.decode(latents, return_dict=False)[0]
517
+ image = self.image_processor.postprocess(image, output_type=output_type)
518
+
519
+ # Offload all models
520
+ self.maybe_free_model_hooks()
521
+
522
+ if not return_dict:
523
+ return (image,)
524
+
525
+ return FluxPipelineOutput(images=image)
526
+
527
+ def do_step_blend(self, images, latents, subject_token, localization_model,
528
+ show_attention, i, blend_models):
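+         # Localize the inserted subject in the decoded target image with the chosen localization
+         # model, then blend the latents so that only the region inside the subject mask keeps the
+         # target branch; everything else is copied from the source branch.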
529
+
530
+ device = latents.device
531
+ latents_dtype = latents.dtype
532
+
533
+ clipseg_processor = blend_models.get("clipseg_processor", None)
534
+ clipseg_model = blend_models.get("clipseg_model", None)
535
+ grounding_processor = blend_models.get("grounding_processor", None)
536
+ grounding_model = blend_models.get("grounding_model", None)
537
+ sam_predictor = blend_models.get("sam_predictor", None)
538
+
539
+ image_to_display = []
540
+ titles_to_display = []
541
+
542
+ if show_attention:
543
+ image_to_display += [images[0], images[1]]
544
+ titles_to_display += ["Source X0", "Target X0"]
545
+
546
+ if localization_model == "clipseg":
547
+ subject_mask = clipseg_predict(clipseg_model, clipseg_processor, [images[-1]], f"A photo of {subject_token}", device)
548
+ elif localization_model == "grounding_sam":
549
+ subject_mask = grounding_sam_predict(grounding_model, grounding_processor, sam_predictor, images[-1], f"A {subject_token}.", device)
550
+ elif localization_model == "clipseg_sam":
551
+ subject_mask = clipseg_predict(clipseg_model, clipseg_processor, [images[-1]], f"A photo of {subject_token}", device)
552
+ subject_mask = mask_to_box_sam_predict(subject_mask, sam_predictor, images[-1], None, device)
553
+ elif localization_model == "attention":
554
+ store = self.attention_store.image2text_store
555
+ attention_maps, attention_masks, tokens = self.attention_store.aggregate_attention(store, target_layers=None, gaussian_kernel=3)
556
+
557
+ subject_mask = attention_masks[0][-1].to(device)
558
+ subject_attention = attention_maps[0][-1].to(device)
559
+
560
+ if show_attention:
561
+ attentioned_image = show_image_and_heatmap(subject_attention.float(), images[1], relevnace_res=512)
562
+ attention_masked_image = show_image_and_heatmap(subject_mask.float(), images[1], relevnace_res=512)
563
+
564
+ image_to_display += [attentioned_image, attention_masked_image]
565
+ titles_to_display += ["Attention", "Attention Mask"]
566
+
567
+ elif localization_model == "attention_box_sam":
568
+ store = self.attention_store.image2text_store
569
+ attention_maps, attention_masks, tokens = self.attention_store.aggregate_attention(store, target_layers=None, gaussian_kernel=3)
570
+
571
+ attention_mask = attention_masks[0][-1].to(device)
572
+ subject_attention = attention_maps[0][-1].to(device)
573
+
574
+ subject_mask, bbox = mask_to_box_sam_predict(attention_mask, sam_predictor, images[-1], None, device)
575
+
576
+ if show_attention:
577
+ attentioned_image = show_image_and_heatmap(subject_attention.float(), images[1], relevnace_res=512)
578
+ attention_masked_image = show_image_and_heatmap(attention_mask.float(), images[1], relevnace_res=512)
579
+
580
+ sam_masked_image = show_image_and_heatmap(subject_mask.float(), images[1], relevnace_res=1024)
581
+ sam_masked_image = draw_bboxes_on_image(sam_masked_image, [bbox.tolist()], color="green", thickness=5)
582
+
583
+ image_to_display += [attentioned_image, attention_masked_image, sam_masked_image]
584
+ titles_to_display += ["Attention", "Attention Mask", "SAM Mask"]
585
+
586
+ elif localization_model == "attention_mask_sam":
587
+ store = self.attention_store.image2text_store
588
+ attention_maps, attention_masks, tokens = self.attention_store.aggregate_attention(store, target_layers=None, gaussian_kernel=3)
589
+
590
+ attention_mask = attention_masks[0][-1].to(device)
591
+ subject_attention = attention_maps[0][-1].to(device)
592
+
593
+ subject_mask = mask_to_mask_sam_predict(attention_mask, sam_predictor, images[-1], None, device)
594
+
595
+ if show_attention:
596
+ print('Attention:')
597
+ attentioned_image = show_image_and_heatmap(subject_attention.float(), images[1], relevnace_res=512)
598
+ attention_masked_image = show_image_and_heatmap(attention_mask.float(), images[1], relevnace_res=512)
599
+ sam_masked_image = show_image_and_heatmap(subject_mask.float(), images[1], relevnace_res=1024)
600
+
601
+ image_to_display += [attentioned_image, attention_masked_image, sam_masked_image]
602
+ titles_to_display += ["Attention", "Attention Mask", "SAM Mask"]
603
+
604
+ elif localization_model == "attention_points_sam":
605
+ store = self.attention_store.image2text_store
606
+ attention_maps, attention_masks, tokens = self.attention_store.aggregate_attention(store, target_layers=None, gaussian_kernel=3)
607
+
608
+ attention_mask = attention_masks[0][-1].to(device)
609
+ subject_attention = attention_maps[0][-1].to(device)
610
+
611
+ subject_mask, point_coords = attention_to_points_sam_predict(subject_attention, attention_mask, sam_predictor, images[1], None, device)
612
+
613
+ if show_attention:
614
+ print('Attention:')
615
+ attentioned_image = show_image_and_heatmap(subject_attention.float(), images[1], relevnace_res=512)
616
+ attention_masked_image = show_image_and_heatmap(attention_mask.float(), images[1], relevnace_res=512)
617
+
618
+ sam_masked_image = show_image_and_heatmap(subject_mask.float(), images[1], relevnace_res=1024)
619
+ sam_masked_image = draw_points_on_pil_image(sam_masked_image, point_coords, point_color="green", radius=10)
620
+
621
+ image_to_display += [attentioned_image, attention_masked_image, sam_masked_image]
622
+ titles_to_display += ["Attention", "Attention Mask", "SAM Mask"]
623
+
624
+ if show_attention:
625
+ show_images(image_to_display, titles_to_display, size=512, save_path="attn_vis.png")
626
+
627
+ # Resize the mask to latents size
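+         # The packed FLUX latent grid here is 64x64 = 4096 tokens (i.e. 1024x1024 outputs with a
+         # VAE scale factor of 8 and 2x2 packing), so the pixel-space mask is downsampled to one
+         # value per latent token and then binarized.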
628
+ latents_mask = torch.nn.functional.interpolate(subject_mask.view(1,1,subject_mask.shape[-2],subject_mask.shape[-1]), size=64, mode='bilinear').view(4096, 1).to(latents_dtype)
629
+ latents_mask[latents_mask > 0.01] = 1
630
+
631
+ latents[1] = latents[1] * latents_mask + latents[0] * (1 - latents_mask)
632
+
633
+ ############# Image to Image Methods #############
634
+ def img2img_encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
635
+ if isinstance(generator, list):
636
+ image_latents = [
637
+ img2img_retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
638
+ for i in range(image.shape[0])
639
+ ]
640
+ image_latents = torch.cat(image_latents, dim=0)
641
+ else:
642
+ image_latents = img2img_retrieve_latents(self.vae.encode(image), generator=generator)
643
+
644
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
645
+
646
+ return image_latents
647
+
648
+ def img2img_prepare_latents(
649
+ self,
650
+ image,
651
+ timestep,
652
+ batch_size,
653
+ num_channels_latents,
654
+ height,
655
+ width,
656
+ dtype,
657
+ device,
658
+ generator,
659
+ latents=None,
660
+ ):
661
+ if isinstance(generator, list) and len(generator) != batch_size:
662
+ raise ValueError(
663
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
664
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
665
+ )
666
+
667
+ height = 2 * (int(height) // self.vae_scale_factor)
668
+ width = 2 * (int(width) // self.vae_scale_factor)
669
+
670
+ shape = (batch_size, num_channels_latents, height, width)
671
+ latent_image_ids = self.img2img_prepare_latent_image_ids(batch_size, height, width, device, dtype)
672
+
673
+ if latents is not None:
674
+ return latents.to(device=device, dtype=dtype), latent_image_ids
675
+
676
+ image = image.to(device=device, dtype=dtype)
677
+ image_latents = self.img2img_encode_vae_image(image=image, generator=generator)
678
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
679
+ # expand init_latents for batch_size
680
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
681
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
682
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
683
+ raise ValueError(
684
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
685
+ )
686
+ else:
687
+ image_latents = torch.cat([image_latents], dim=0)
688
+
689
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
690
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
691
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
692
+ return latents, latent_image_ids
693
+
694
+ def img2img_check_inputs(
695
+ self,
696
+ prompt,
697
+ prompt_2,
698
+ strength,
699
+ height,
700
+ width,
701
+ prompt_embeds=None,
702
+ pooled_prompt_embeds=None,
703
+ callback_on_step_end_tensor_inputs=None,
704
+ max_sequence_length=None,
705
+ ):
706
+ if strength < 0 or strength > 1:
707
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
708
+
709
+ if height % 8 != 0 or width % 8 != 0:
710
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
711
+
712
+ if callback_on_step_end_tensor_inputs is not None and not all(
713
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
714
+ ):
715
+ raise ValueError(
716
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
717
+ )
718
+
719
+ if prompt is not None and prompt_embeds is not None:
720
+ raise ValueError(
721
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
722
+ " only forward one of the two."
723
+ )
724
+ elif prompt_2 is not None and prompt_embeds is not None:
725
+ raise ValueError(
726
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
727
+ " only forward one of the two."
728
+ )
729
+ elif prompt is None and prompt_embeds is None:
730
+ raise ValueError(
731
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
732
+ )
733
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
734
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
735
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
736
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
737
+
738
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
739
+ raise ValueError(
740
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
741
+ )
742
+
743
+ if max_sequence_length is not None and max_sequence_length > 512:
744
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
745
+
746
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
747
+ def img2img_get_timesteps(self, num_inference_steps, strength, device):
748
+ # get the original timestep using init_timestep
749
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
750
+
751
+ t_start = int(max(num_inference_steps - init_timestep, 0))
752
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
753
+ if hasattr(self.scheduler, "set_begin_index"):
754
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
755
+
756
+ return timesteps, num_inference_steps - t_start
757
+
758
+ @staticmethod
759
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
760
+ def img2img_prepare_latent_image_ids(batch_size, height, width, device, dtype):
761
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
762
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
763
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
764
+
765
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
766
+
767
+ latent_image_ids = latent_image_ids.reshape(
768
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
769
+ )
770
+
771
+ return latent_image_ids.to(device=device, dtype=dtype)
772
+
773
+ @torch.no_grad()
774
+ def call_img2img(
775
+ self,
776
+ prompt: Union[str, List[str]] = None,
777
+ prompt_2: Optional[Union[str, List[str]]] = None,
778
+ image: PipelineImageInput = None,
779
+ height: Optional[int] = None,
780
+ width: Optional[int] = None,
781
+ strength: float = 0.6,
782
+ num_inference_steps: int = 28,
783
+ timesteps: List[int] = None,
784
+ guidance_scale: float = 7.0,
785
+ num_images_per_prompt: Optional[int] = 1,
786
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
787
+ latents: Optional[torch.FloatTensor] = None,
788
+ prompt_embeds: Optional[torch.FloatTensor] = None,
789
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
790
+ output_type: Optional[str] = "pil",
791
+ return_dict: bool = True,
792
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
793
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
794
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
795
+ max_sequence_length: int = 512,
796
+ ):
797
+ r"""
798
+ Function invoked when calling the pipeline for generation.
799
+
800
+ Args:
801
+ prompt (`str` or `List[str]`, *optional*):
802
+                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
803
+ instead.
804
+ prompt_2 (`str` or `List[str]`, *optional*):
805
+                 The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
806
+                 will be used instead.
807
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
808
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
809
+                 numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
810
+                 of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
811
+                 list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
812
+ latents as `image`, but if passing latents directly it is not encoded again.
813
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
814
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
815
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
816
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
817
+             strength (`float`, *optional*, defaults to 0.6):
818
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
819
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
820
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
821
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
822
+ essentially ignores `image`.
823
+             num_inference_steps (`int`, *optional*, defaults to 28):
824
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
825
+ expense of slower inference.
826
+ timesteps (`List[int]`, *optional*):
827
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
828
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
829
+ passed will be used. Must be in descending order.
830
+ guidance_scale (`float`, *optional*, defaults to 7.0):
831
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
832
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
833
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
834
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
835
+ usually at the expense of lower image quality.
836
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
837
+ The number of images to generate per prompt.
838
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
839
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
840
+ to make generation deterministic.
841
+ latents (`torch.FloatTensor`, *optional*):
842
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
843
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
844
+                 tensor will be generated by sampling using the supplied random `generator`.
845
+ prompt_embeds (`torch.FloatTensor`, *optional*):
846
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
847
+ provided, text embeddings will be generated from `prompt` input argument.
848
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
849
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
850
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
851
+ output_type (`str`, *optional*, defaults to `"pil"`):
852
+                 The output format of the generated image. Choose between
853
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
854
+ return_dict (`bool`, *optional*, defaults to `True`):
855
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
856
+ joint_attention_kwargs (`dict`, *optional*):
857
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
858
+ `self.processor` in
859
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
860
+ callback_on_step_end (`Callable`, *optional*):
861
+                 A function that is called at the end of each denoising step during inference. The function is called
862
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
863
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
864
+ `callback_on_step_end_tensor_inputs`.
865
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
866
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
867
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
868
+ `._callback_tensor_inputs` attribute of your pipeline class.
869
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
870
+
871
+ Examples:
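+             A minimal usage sketch (the input image and strength value are illustrative):
+
+             >>> init_image = Image.open("cat.jpg").convert("RGB")
+             >>> out = pipe.call_img2img(
+             ...     prompt="A photo of a cat wearing a red hat",
+             ...     image=init_image, strength=0.6,
+             ... )
+             >>> out.images[0]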
872
+
873
+ Returns:
874
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
875
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
876
+ images.
877
+ """
878
+
879
+ height = height or self.default_sample_size * self.vae_scale_factor
880
+ width = width or self.default_sample_size * self.vae_scale_factor
881
+
882
+ # 1. Check inputs. Raise error if not correct
883
+ self.img2img_check_inputs(
884
+ prompt,
885
+ prompt_2,
886
+ strength,
887
+ height,
888
+ width,
889
+ prompt_embeds=prompt_embeds,
890
+ pooled_prompt_embeds=pooled_prompt_embeds,
891
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
892
+ max_sequence_length=max_sequence_length,
893
+ )
894
+
895
+ self._guidance_scale = guidance_scale
896
+ self._joint_attention_kwargs = joint_attention_kwargs
897
+ self._interrupt = False
898
+
899
+ # 2. Preprocess image
900
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
901
+ init_image = init_image.to(dtype=torch.float32)
902
+
903
+ # 3. Define call parameters
904
+ if prompt is not None and isinstance(prompt, str):
905
+ batch_size = 1
906
+ elif prompt is not None and isinstance(prompt, list):
907
+ batch_size = len(prompt)
908
+ else:
909
+ batch_size = prompt_embeds.shape[0]
910
+
911
+ device = self._execution_device
912
+
913
+ lora_scale = (
914
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
915
+ )
916
+ (
917
+ prompt_embeds,
918
+ pooled_prompt_embeds,
919
+ text_ids,
920
+ ) = self.encode_prompt(
921
+ prompt=prompt,
922
+ prompt_2=prompt_2,
923
+ prompt_embeds=prompt_embeds,
924
+ pooled_prompt_embeds=pooled_prompt_embeds,
925
+ device=device,
926
+ num_images_per_prompt=num_images_per_prompt,
927
+ max_sequence_length=max_sequence_length,
928
+ lora_scale=lora_scale,
929
+ )
930
+
931
+ register_regular_attention_processors(self.transformer)
932
+
933
+ # 4.Prepare timesteps
934
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
935
+ image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
936
+ mu = calculate_shift(
937
+ image_seq_len,
938
+ self.scheduler.config.base_image_seq_len,
939
+ self.scheduler.config.max_image_seq_len,
940
+ self.scheduler.config.base_shift,
941
+ self.scheduler.config.max_shift,
942
+ )
943
+ timesteps, num_inference_steps = retrieve_timesteps(
944
+ self.scheduler,
945
+ num_inference_steps,
946
+ device,
947
+ timesteps,
948
+ sigmas,
949
+ mu=mu,
950
+ )
951
+ timesteps, num_inference_steps = self.img2img_get_timesteps(num_inference_steps, strength, device)
952
+
953
+ if num_inference_steps < 1:
954
+ raise ValueError(
955
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
956
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
957
+ )
958
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
959
+
960
+ # 5. Prepare latent variables
961
+ num_channels_latents = self.transformer.config.in_channels // 4
962
+
963
+ latents, latent_image_ids = self.img2img_prepare_latents(
964
+ init_image,
965
+ latent_timestep,
966
+ batch_size * num_images_per_prompt,
967
+ num_channels_latents,
968
+ height,
969
+ width,
970
+ prompt_embeds.dtype,
971
+ device,
972
+ generator,
973
+ latents,
974
+ )
975
+
976
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
977
+ self._num_timesteps = len(timesteps)
978
+
979
+ # handle guidance
980
+ if self.transformer.config.guidance_embeds:
981
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
982
+ guidance = guidance.expand(latents.shape[0])
983
+ else:
984
+ guidance = None
985
+
986
+ text_ids = text_ids.expand(latents.shape[0], -1, -1)
987
+ latent_image_ids = latent_image_ids.expand(latents.shape[0], -1, -1)
988
+
989
+ # 6. Denoising loop
990
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
991
+ for i, t in enumerate(timesteps):
992
+ if self.interrupt:
993
+ continue
994
+
995
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
996
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
997
+ noise_pred = self.transformer(
998
+ hidden_states=latents,
999
+ timestep=timestep / 1000,
1000
+ guidance=guidance,
1001
+ pooled_projections=pooled_prompt_embeds,
1002
+ encoder_hidden_states=prompt_embeds,
1003
+ txt_ids=text_ids,
1004
+ img_ids=latent_image_ids,
1005
+ joint_attention_kwargs=self.joint_attention_kwargs,
1006
+ return_dict=False,
1007
+ )[0]
1008
+
1009
+ # compute the previous noisy sample x_t -> x_t-1
1010
+ latents_dtype = latents.dtype
1011
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1012
+
1013
+ if latents.dtype != latents_dtype:
1014
+ if torch.backends.mps.is_available():
1015
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1016
+ latents = latents.to(latents_dtype)
1017
+
1018
+ if callback_on_step_end is not None:
1019
+ callback_kwargs = {}
1020
+ for k in callback_on_step_end_tensor_inputs:
1021
+ callback_kwargs[k] = locals()[k]
1022
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1023
+
1024
+ latents = callback_outputs.pop("latents", latents)
1025
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1026
+
1027
+ # call the callback, if provided
1028
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1029
+ progress_bar.update()
1030
+
1031
+ # if XLA_AVAILABLE:
1032
+ # xm.mark_step()
1033
+
1034
+ if output_type == "latent":
1035
+ image = latents
1036
+
1037
+ else:
1038
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
1039
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1040
+ image = self.vae.decode(latents, return_dict=False)[0]
1041
+ image = self.image_processor.postprocess(image, output_type=output_type)
1042
+
1043
+ # Offload all models
1044
+ self.maybe_free_model_hooks()
1045
+
1046
+ if not return_dict:
1047
+ return (image,)
1048
+
1049
+ return FluxPipelineOutput(images=image)
1050
+
1051
+ ############# Invert Methods #############
1052
+ def invert_prepare_latents(
1053
+ self,
1054
+ image,
1055
+ timestep,
1056
+ batch_size,
1057
+ num_channels_latents,
1058
+ height,
1059
+ width,
1060
+ dtype,
1061
+ device,
1062
+ generator,
1063
+ latents=None,
1064
+ add_noise=False,
1065
+ ):
1066
+ if isinstance(generator, list) and len(generator) != batch_size:
1067
+ raise ValueError(
1068
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
1069
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
1070
+ )
1071
+
1072
+ height = 2 * (int(height) // self.vae_scale_factor)
1073
+ width = 2 * (int(width) // self.vae_scale_factor)
1074
+
1075
+ shape = (batch_size, num_channels_latents, height, width)
1076
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
1077
+
1078
+ if latents is not None:
1079
+ return latents.to(device=device, dtype=dtype), latent_image_ids
1080
+
1081
+ image = image.to(device=device, dtype=dtype)
1082
+ image_latents = self.img2img_encode_vae_image(image=image, generator=generator)
1083
+
1084
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
1085
+ # expand init_latents for batch_size
1086
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
1087
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
1088
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
1089
+ raise ValueError(
1090
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
1091
+ )
1092
+ else:
1093
+ image_latents = torch.cat([image_latents], dim=0)
1094
+
1095
+ if add_noise:
1096
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
1097
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
1098
+ else:
1099
+ latents = image_latents
1100
+
1101
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
1102
+
1103
+ return latents, latent_image_ids
1104
+
1105
+ @torch.no_grad()
1106
+ def call_invert(
1107
+ self,
1108
+ prompt: Union[str, List[str]] = None,
1109
+ prompt_2: Optional[Union[str, List[str]]] = None,
1110
+ image: PipelineImageInput = None,
1111
+ height: Optional[int] = None,
1112
+ width: Optional[int] = None,
1113
+ num_inference_steps: int = 28,
1114
+ timesteps: List[int] = None,
1115
+ guidance_scale: float = 7.0,
1116
+ num_images_per_prompt: Optional[int] = 1,
1117
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1118
+ latents: Optional[torch.FloatTensor] = None,
1119
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1120
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1121
+ output_type: Optional[str] = "pil",
1122
+ return_dict: bool = True,
1123
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
1124
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
1125
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1126
+ max_sequence_length: int = 512,
1127
+
1128
+ fixed_point_iterations: int = 1,
1129
+ ):
1130
+ r"""
1131
+ Function invoked when calling the pipeline for generation.
1132
+
1133
+ Args:
1134
+ prompt (`str` or `List[str]`, *optional*):
1135
+                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
1136
+ instead.
1137
+ prompt_2 (`str` or `List[str]`, *optional*):
1138
+                 The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
1139
+                 will be used instead.
1140
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1141
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
1142
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1143
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
1144
+             num_inference_steps (`int`, *optional*, defaults to 28):
1145
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1146
+ expense of slower inference.
1147
+ timesteps (`List[int]`, *optional*):
1148
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
1149
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
1150
+ passed will be used. Must be in descending order.
1151
+ guidance_scale (`float`, *optional*, defaults to 7.0):
1152
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1153
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1154
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1155
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1156
+ usually at the expense of lower image quality.
1157
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1158
+ The number of images to generate per prompt.
1159
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1160
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1161
+ to make generation deterministic.
1162
+ latents (`torch.FloatTensor`, *optional*):
1163
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1164
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1165
+                 tensor will be generated by sampling using the supplied random `generator`.
1166
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1167
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1168
+ provided, text embeddings will be generated from `prompt` input argument.
1169
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1170
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1171
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
1172
+ output_type (`str`, *optional*, defaults to `"pil"`):
1173
+                 The output format of the generated image. Choose between
1174
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1175
+ return_dict (`bool`, *optional*, defaults to `True`):
1176
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
1177
+ joint_attention_kwargs (`dict`, *optional*):
1178
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1179
+ `self.processor` in
1180
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1181
+ callback_on_step_end (`Callable`, *optional*):
1182
+                 A function that is called at the end of each denoising step during inference. The function is called
1183
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
1184
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1185
+ `callback_on_step_end_tensor_inputs`.
1186
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1187
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1188
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1189
+ `._callback_tensor_inputs` attribute of your pipeline class.
1190
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
1191
+
1192
+ Examples:
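+             A minimal usage sketch (values are illustrative); the returned list holds the packed
+             latents after every inversion step, starting from the clean image latents and ending
+             with the fully inverted (noisiest) latents:
+
+             >>> inverted = pipe.call_invert(
+             ...     prompt="A photo of a cat", image=source_image,  # source_image: a PIL.Image
+             ...     num_inference_steps=30, fixed_point_iterations=1,
+             ... )
+             >>> len(inverted)  # num_inference_steps + 1 entries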
1193
+
1194
+ Returns:
1195
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
1196
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
1197
+ images.
1198
+ """
1199
+ height = height or self.default_sample_size * self.vae_scale_factor
1200
+ width = width or self.default_sample_size * self.vae_scale_factor
1201
+
1202
+ # 1. Check inputs. Raise error if not correct
1203
+ self.check_inputs(
1204
+ prompt,
1205
+ prompt_2,
1206
+ height,
1207
+ width,
1208
+ prompt_embeds=prompt_embeds,
1209
+ pooled_prompt_embeds=pooled_prompt_embeds,
1210
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
1211
+ max_sequence_length=max_sequence_length,
1212
+ )
1213
+
1214
+ self._guidance_scale = guidance_scale
1215
+ self._joint_attention_kwargs = joint_attention_kwargs
1216
+ self._interrupt = False
1217
+
1218
+ # 1.5. Preprocess image
1219
+ if isinstance(image, Image.Image):
1220
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
1221
+ elif isinstance(image, torch.Tensor):
1222
+ init_image = image
1223
+ latents = image
1224
+ else:
1225
+ raise ValueError("Image must be of type `PIL.Image.Image` or `torch.Tensor`")
1226
+
1227
+ init_image = init_image.to(dtype=torch.float32)
1228
+
1229
+ # 2. Define call parameters
1230
+ if prompt is not None and isinstance(prompt, str):
1231
+ batch_size = 1
1232
+ elif prompt is not None and isinstance(prompt, list):
1233
+ batch_size = len(prompt)
1234
+ else:
1235
+ batch_size = prompt_embeds.shape[0]
1236
+
1237
+ device = self._execution_device
1238
+
1239
+ lora_scale = (
1240
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
1241
+ )
1242
+ (
1243
+ prompt_embeds,
1244
+ pooled_prompt_embeds,
1245
+ text_ids,
1246
+ ) = self.encode_prompt(
1247
+ prompt=prompt,
1248
+ prompt_2=prompt_2,
1249
+ prompt_embeds=prompt_embeds,
1250
+ pooled_prompt_embeds=pooled_prompt_embeds,
1251
+ device=device,
1252
+ num_images_per_prompt=num_images_per_prompt,
1253
+ max_sequence_length=max_sequence_length,
1254
+ lora_scale=lora_scale,
1255
+ )
1256
+
1257
+ # 4. Prepare latent variables
1258
+ num_channels_latents = self.transformer.config.in_channels // 4
1259
+ # latents, latent_image_ids = self.prepare_latents(
1260
+ # batch_size * num_images_per_prompt,
1261
+ # num_channels_latents,
1262
+ # height,
1263
+ # width,
1264
+ # prompt_embeds.dtype,
1265
+ # device,
1266
+ # generator,
1267
+ # latents,
1268
+ # )
1269
+ latents, latent_image_ids = self.invert_prepare_latents(
1270
+ init_image,
1271
+ None,
1272
+ batch_size * num_images_per_prompt,
1273
+ num_channels_latents,
1274
+ height,
1275
+ width,
1276
+ prompt_embeds.dtype,
1277
+ device,
1278
+ generator,
1279
+ latents,
1280
+ False
1281
+ )
1282
+
1283
+ register_regular_attention_processors(self.transformer)
1284
+
1285
+ # 5. Prepare timesteps
1286
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
1287
+ image_seq_len = latents.shape[1]
1288
+ mu = calculate_shift(
1289
+ image_seq_len,
1290
+ self.scheduler.config.base_image_seq_len,
1291
+ self.scheduler.config.max_image_seq_len,
1292
+ self.scheduler.config.base_shift,
1293
+ self.scheduler.config.max_shift,
1294
+ )
1295
+
1296
+ # For Inversion, reverse the sigmas
1297
+ # sigmas = sigmas[::-1]
1298
+
1299
+ timesteps, num_inference_steps = retrieve_timesteps(
1300
+ self.scheduler,
1301
+ num_inference_steps,
1302
+ device,
1303
+ timesteps,
1304
+ sigmas,
1305
+ mu=mu,
1306
+ )
1307
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1308
+ self._num_timesteps = len(timesteps)
1309
+
1310
+ # handle guidance
1311
+ if self.transformer.config.guidance_embeds:
1312
+ guidance = torch.tensor([guidance_scale], device=device)
1313
+ guidance = guidance.expand(latents.shape[0])
1314
+ else:
1315
+ guidance = None
1316
+
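+         # Inversion runs the flow-matching schedule backwards: sigmas and timesteps are reversed so
+         # each scheduler step moves the clean latents toward noise instead of away from it.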
1317
+ self.scheduler.sigmas = reversed(self.scheduler.sigmas)
1318
+
1319
+ timesteps_zero_start = reversed(torch.cat([self.scheduler.timesteps[1:], torch.tensor([0], device=device)]))
1320
+ timesteps_one_start = reversed(self.scheduler.timesteps)
1321
+
1322
+ self.scheduler.timesteps = timesteps_zero_start
1323
+ # self.scheduler.timesteps = timesteps_one_start
1324
+
1325
+ timesteps = self.scheduler.timesteps
1326
+
1327
+ latents_list = []
1328
+ latents_list.append(latents)
1329
+
1330
+ # 6. Denoising loop
1331
+ with self.progress_bar(total=num_inference_steps * fixed_point_iterations) as progress_bar:
1332
+ for i, t in enumerate(timesteps):
1333
+ original_latents = latents.clone()
1334
+ for j in range(fixed_point_iterations):
1335
+ if self.interrupt:
1336
+ continue
1337
+
1338
+ if j == 0:
1339
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1340
+ timestep = timesteps[i].expand(latents.shape[0]).to(latents.dtype)
1341
+ else:
1342
+ timestep = timesteps_one_start[i].expand(latents.shape[0]).to(latents.dtype)
1343
+
1344
+ noise_pred = self.transformer(
1345
+ hidden_states=latents,
1346
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it, but I want to keep the inputs the same for the model for testing)
1347
+ timestep=timestep / 1000,
1348
+ guidance=guidance,
1349
+ pooled_projections=pooled_prompt_embeds,
1350
+ encoder_hidden_states=prompt_embeds,
1351
+ txt_ids=text_ids,
1352
+ img_ids=latent_image_ids,
1353
+ joint_attention_kwargs=self.joint_attention_kwargs,
1354
+ return_dict=False,
1355
+ )[0]
1356
+
1357
+ # compute the previous noisy sample x_t -> x_t-1
1358
+ latents_dtype = latents.dtype
1359
+
1360
+ # noise_pred = -noise_pred
1361
+ latents = self.scheduler.step(noise_pred, t, original_latents, return_dict=False, step_index=i)[0]
1362
+
1363
+ if latents.dtype != latents_dtype:
1364
+ if torch.backends.mps.is_available():
1365
+ # some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
1366
+ latents = latents.to(latents_dtype)
1367
+
1368
+ if callback_on_step_end is not None:
1369
+ callback_kwargs = {}
1370
+ for k in callback_on_step_end_tensor_inputs:
1371
+ callback_kwargs[k] = locals()[k]
1372
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1373
+
1374
+ latents = callback_outputs.pop("latents", latents)
1375
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1376
+
1377
+ # call the callback, if provided
1378
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1379
+ progress_bar.update()
1380
+
1381
+ # if XLA_AVAILABLE:
1382
+ # xm.mark_step()
1383
+
1384
+ latents_list.append(latents)
1385
+
1386
+ # Offload all models
1387
+ self.maybe_free_model_hooks()
1388
+
1389
+ return latents_list
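Note: the `latents_list` returned above is consumed downstream (see `addit_methods.py` later in this commit) roughly as in the following sketch; `pipe`, `prompt_source`, and `source_latents` are assumed to already exist, and the keyword values mirror the defaults used there:

import torch

# Run the fixed-point inversion; one latent tensor is collected per timestep.
latents_list = pipe.call_invert(
    prompt=prompt_source,
    image=source_latents,  # packed latents, e.g. from call_img2img(..., output_type="latent")
    num_inference_steps=30,
    guidance_scale=1,
    fixed_point_iterations=2,
    generator=torch.Generator(device=pipe.device).manual_seed(0),
)
# Keep the first batch element of each step and reverse the list so that index 0
# corresponds to the first denoising step of the forward pass.
img_src_latents = [x[0] for x in latents_list][::-1]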
addit_flux_transformer.py ADDED
@@ -0,0 +1,521 @@
1
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Dict, List, Optional, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
23
+ from diffusers.models.attention import FeedForward
24
+ from diffusers.models.attention_processor import Attention, FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0
25
+ from diffusers.models.modeling_utils import ModelMixin
26
+ from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
27
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
28
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
29
+ from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
30
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
31
+
32
+ from addit_attention_processors import AdditFluxAttnProcessor2_0, AdditFluxSingleAttnProcessor2_0
33
+
34
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
+
36
+
37
+ # YiYi to-do: refactor rope related functions/classes
38
+ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
39
+ assert dim % 2 == 0, "The dimension must be even."
40
+
41
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
42
+ omega = 1.0 / (theta**scale)
43
+
44
+ batch_size, seq_length = pos.shape
45
+ out = torch.einsum("...n,d->...nd", pos, omega)
46
+ cos_out = torch.cos(out)
47
+ sin_out = torch.sin(out)
48
+
49
+ stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
50
+ out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
51
+ return out.float()
52
+
53
+
54
+ # YiYi to-do: refactor rope related functions/classes
55
+ class EmbedND(nn.Module):
56
+ def __init__(self, dim: int, theta: int, axes_dim: List[int]):
57
+ super().__init__()
58
+ self.dim = dim
59
+ self.theta = theta
60
+ self.axes_dim = axes_dim
61
+
62
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
63
+ n_axes = ids.shape[-1]
64
+ emb = torch.cat(
65
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
66
+ dim=-3,
67
+ )
68
+ return emb.unsqueeze(1)
69
+
70
+
71
+ @maybe_allow_in_graph
72
+ class AdditFluxSingleTransformerBlock(nn.Module):
73
+ r"""
74
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
75
+
76
+ Reference: https://arxiv.org/abs/2403.03206
77
+
78
+ Parameters:
79
+ dim (`int`): The number of channels in the input and output.
80
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
81
+ attention_head_dim (`int`): The number of channels in each head.
82
+ mlp_ratio (`float`, defaults to 4.0): The multiplier applied to `dim` to obtain the hidden
83
+ dimension of the MLP.
84
+ """
85
+
86
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
87
+ super().__init__()
88
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
89
+
90
+ self.norm = AdaLayerNormZeroSingle(dim)
91
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
92
+ self.act_mlp = nn.GELU(approximate="tanh")
93
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
94
+
95
+ processor = FluxSingleAttnProcessor2_0()
96
+ self.attn = Attention(
97
+ query_dim=dim,
98
+ cross_attention_dim=None,
99
+ dim_head=attention_head_dim,
100
+ heads=num_attention_heads,
101
+ out_dim=dim,
102
+ bias=True,
103
+ processor=processor,
104
+ qk_norm="rms_norm",
105
+ eps=1e-6,
106
+ pre_only=True,
107
+ )
108
+
109
+ def forward(
110
+ self,
111
+ hidden_states: torch.FloatTensor,
112
+ temb: torch.FloatTensor,
113
+ image_rotary_emb=None,
114
+ proccesor_kwargs=None,
115
+ ):
116
+ residual = hidden_states
117
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
118
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
119
+
120
+ attn_output = self.attn(
121
+ hidden_states=norm_hidden_states,
122
+ image_rotary_emb=image_rotary_emb,
123
+ **(proccesor_kwargs or {}),
124
+ )
125
+
126
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
127
+ gate = gate.unsqueeze(1)
128
+ hidden_states = gate * self.proj_out(hidden_states)
129
+ hidden_states = residual + hidden_states
130
+ if hidden_states.dtype == torch.float16:
131
+ hidden_states = hidden_states.clip(-65504, 65504)
132
+
133
+ return hidden_states
134
+
135
+
136
+ @maybe_allow_in_graph
137
+ class AdditFluxTransformerBlock(nn.Module):
138
+ r"""
139
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
140
+
141
+ Reference: https://arxiv.org/abs/2403.03206
142
+
143
+ Parameters:
144
+ dim (`int`): The number of channels in the input and output.
145
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
146
+ attention_head_dim (`int`): The number of channels in each head.
147
+ qk_norm (`str`, defaults to `"rms_norm"`): The normalization applied to the query and key projections.
148
+ eps (`float`, defaults to 1e-6): The epsilon value used by the normalization layers.
149
+ """
150
+
151
+ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
152
+ super().__init__()
153
+
154
+ self.norm1 = AdaLayerNormZero(dim)
155
+
156
+ self.norm1_context = AdaLayerNormZero(dim)
157
+
158
+ if hasattr(F, "scaled_dot_product_attention"):
159
+ processor = FluxAttnProcessor2_0()
160
+ else:
161
+ raise ValueError(
162
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
163
+ )
164
+ self.attn = Attention(
165
+ query_dim=dim,
166
+ cross_attention_dim=None,
167
+ added_kv_proj_dim=dim,
168
+ dim_head=attention_head_dim,
169
+ heads=num_attention_heads,
170
+ out_dim=dim,
171
+ context_pre_only=False,
172
+ bias=True,
173
+ processor=processor,
174
+ qk_norm=qk_norm,
175
+ eps=eps,
176
+ )
177
+
178
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
179
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
180
+
181
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
182
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
183
+
184
+ # let chunk size default to None
185
+ self._chunk_size = None
186
+ self._chunk_dim = 0
187
+
188
+ def forward(
189
+ self,
190
+ hidden_states: torch.FloatTensor,
191
+ encoder_hidden_states: torch.FloatTensor,
192
+ temb: torch.FloatTensor,
193
+ image_rotary_emb=None,
194
+ proccesor_kwargs=None,
195
+ ):
196
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
197
+
198
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
199
+ encoder_hidden_states, emb=temb
200
+ )
201
+
202
+ # Attention.
203
+ attn_output, context_attn_output = self.attn(
204
+ hidden_states=norm_hidden_states,
205
+ encoder_hidden_states=norm_encoder_hidden_states,
206
+ image_rotary_emb=image_rotary_emb,
207
+ **(proccesor_kwargs or {}),
208
+ )
209
+
210
+ # Process attention outputs for the `hidden_states`.
211
+ attn_output = gate_msa.unsqueeze(1) * attn_output
212
+ hidden_states = hidden_states + attn_output
213
+
214
+ norm_hidden_states = self.norm2(hidden_states)
215
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
216
+
217
+ ff_output = self.ff(norm_hidden_states)
218
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
219
+
220
+ hidden_states = hidden_states + ff_output
221
+
222
+ # Process attention outputs for the `encoder_hidden_states`.
223
+
224
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
225
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
226
+
227
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
228
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
229
+
230
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
231
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
232
+ if encoder_hidden_states.dtype == torch.float16:
233
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
234
+
235
+ return encoder_hidden_states, hidden_states
236
+
237
+
238
+ class AdditFluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
239
+ """
240
+ The Transformer model introduced in Flux.
241
+
242
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
243
+
244
+ Parameters:
245
+ patch_size (`int`): Patch size to turn the input data into small patches.
246
+ in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
247
+ num_layers (`int`, *optional*, defaults to 19): The number of layers of MMDiT blocks to use.
248
+ num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use.
249
+ attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head.
250
+ num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
251
+ joint_attention_dim (`int`, *optional*, defaults to 4096): The number of `encoder_hidden_states` dimensions to use.
252
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
253
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
254
+ """
255
+
256
+ _supports_gradient_checkpointing = True
257
+
258
+ @register_to_config
259
+ def __init__(
260
+ self,
261
+ patch_size: int = 1,
262
+ in_channels: int = 64,
263
+ num_layers: int = 19,
264
+ num_single_layers: int = 38,
265
+ attention_head_dim: int = 128,
266
+ num_attention_heads: int = 24,
267
+ joint_attention_dim: int = 4096,
268
+ pooled_projection_dim: int = 768,
269
+ guidance_embeds: bool = False,
270
+ axes_dims_rope: List[int] = [16, 56, 56],
271
+ ):
272
+ super().__init__()
273
+ self.out_channels = in_channels
274
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
275
+
276
+ self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope)
277
+ text_time_guidance_cls = (
278
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
279
+ )
280
+ self.time_text_embed = text_time_guidance_cls(
281
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
282
+ )
283
+
284
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
285
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
286
+
287
+ self.transformer_blocks = nn.ModuleList(
288
+ [
289
+ AdditFluxTransformerBlock(
290
+ dim=self.inner_dim,
291
+ num_attention_heads=self.config.num_attention_heads,
292
+ attention_head_dim=self.config.attention_head_dim,
293
+ )
294
+ for i in range(self.config.num_layers)
295
+ ]
296
+ )
297
+
298
+ self.single_transformer_blocks = nn.ModuleList(
299
+ [
300
+ AdditFluxSingleTransformerBlock(
301
+ dim=self.inner_dim,
302
+ num_attention_heads=self.config.num_attention_heads,
303
+ attention_head_dim=self.config.attention_head_dim,
304
+ )
305
+ for i in range(self.config.num_single_layers)
306
+ ]
307
+ )
308
+
309
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
310
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
311
+
312
+ self.gradient_checkpointing = False
313
+
314
+ def _set_gradient_checkpointing(self, module, value=False):
315
+ if hasattr(module, "gradient_checkpointing"):
316
+ module.gradient_checkpointing = value
317
+
318
+ @property
319
+ def attn_processors(self):
320
+ r"""
321
+ Returns:
322
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
323
+ indexed by their weight names.
324
+ """
325
+ # set recursively
326
+ processors = {}
327
+
328
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
329
+ if hasattr(module, "get_processor"):
330
+ processors[f"{name}.processor"] = module.get_processor()
331
+
332
+ for sub_name, child in module.named_children():
333
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
334
+
335
+ return processors
336
+
337
+ for name, module in self.named_children():
338
+ fn_recursive_add_processors(name, module, processors)
339
+
340
+ return processors
341
+
342
+ def set_attn_processor(
343
+ self, processor
344
+ ):
345
+ r"""
346
+ Sets the attention processor to use to compute attention.
347
+
348
+ Parameters:
349
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
350
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
351
+ for **all** `Attention` layers.
352
+
353
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
354
+ processor. This is strongly recommended when setting trainable attention processors.
355
+ """
356
+ count = len(self.attn_processors.keys())
357
+
358
+ if isinstance(processor, dict) and len(processor) != count:
359
+ raise ValueError(
360
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
361
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
362
+ )
363
+
364
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
365
+ if hasattr(module, "set_processor"):
366
+ if not isinstance(processor, dict):
367
+ module.set_processor(processor)
368
+ else:
369
+ module.set_processor(processor.pop(f"{name}.processor"))
370
+
371
+ for sub_name, child in module.named_children():
372
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
373
+
374
+ for name, module in self.named_children():
375
+ fn_recursive_attn_processor(name, module, processor)
376
+
377
+ def forward(
378
+ self,
379
+ hidden_states: torch.Tensor,
380
+ encoder_hidden_states: torch.Tensor = None,
381
+ pooled_projections: torch.Tensor = None,
382
+ timestep: torch.LongTensor = None,
383
+ img_ids: torch.Tensor = None,
384
+ txt_ids: torch.Tensor = None,
385
+ guidance: torch.Tensor = None,
386
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
387
+ return_dict: bool = True,
388
+ proccesor_kwargs: Optional[Dict[str, Any]] = None,
389
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
390
+ """
391
+ The [`FluxTransformer2DModel`] forward method.
392
+
393
+ Args:
394
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
395
+ Input `hidden_states`.
396
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
397
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
398
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
399
+ from the embeddings of input conditions.
400
+ timestep ( `torch.LongTensor`):
401
+ Used to indicate denoising step.
402
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
403
+ A list of tensors that if specified are added to the residuals of transformer blocks.
404
+ joint_attention_kwargs (`dict`, *optional*):
405
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
406
+ `self.processor` in
407
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
408
+ return_dict (`bool`, *optional*, defaults to `True`):
409
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
410
+ tuple.
411
+
412
+ Returns:
413
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
414
+ `tuple` where the first element is the sample tensor.
415
+ """
416
+ if joint_attention_kwargs is not None:
417
+ joint_attention_kwargs = joint_attention_kwargs.copy()
418
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
419
+ else:
420
+ lora_scale = 1.0
421
+
422
+ if USE_PEFT_BACKEND:
423
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
424
+ scale_lora_layers(self, lora_scale)
425
+ else:
426
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
427
+ logger.warning(
428
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
429
+ )
430
+ hidden_states = self.x_embedder(hidden_states)
431
+
432
+ timestep = timestep.to(hidden_states.dtype) * 1000
433
+ if guidance is not None:
434
+ guidance = guidance.to(hidden_states.dtype) * 1000
435
+ else:
436
+ guidance = None
437
+ temb = (
438
+ self.time_text_embed(timestep, pooled_projections)
439
+ if guidance is None
440
+ else self.time_text_embed(timestep, guidance, pooled_projections)
441
+ )
442
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
443
+
444
+ ids = torch.cat((txt_ids, img_ids), dim=1)
445
+ image_rotary_emb = self.pos_embed(ids)
446
+
447
+ for index_block, block in enumerate(self.transformer_blocks):
448
+ if self.training and self.gradient_checkpointing:
449
+
450
+ def create_custom_forward(module, return_dict=None):
451
+ def custom_forward(*inputs):
452
+ if return_dict is not None:
453
+ return module(*inputs, return_dict=return_dict)
454
+ else:
455
+ return module(*inputs)
456
+
457
+ return custom_forward
458
+
459
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
460
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
461
+ create_custom_forward(block),
462
+ hidden_states,
463
+ encoder_hidden_states,
464
+ temb,
465
+ image_rotary_emb,
466
+ **ckpt_kwargs,
467
+ )
468
+
469
+ else:
470
+ encoder_hidden_states, hidden_states = block(
471
+ hidden_states=hidden_states,
472
+ encoder_hidden_states=encoder_hidden_states,
473
+ temb=temb,
474
+ image_rotary_emb=image_rotary_emb,
475
+ proccesor_kwargs=proccesor_kwargs,
476
+ )
477
+
478
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
479
+
480
+ for index_block, block in enumerate(self.single_transformer_blocks):
481
+ if self.training and self.gradient_checkpointing:
482
+
483
+ def create_custom_forward(module, return_dict=None):
484
+ def custom_forward(*inputs):
485
+ if return_dict is not None:
486
+ return module(*inputs, return_dict=return_dict)
487
+ else:
488
+ return module(*inputs)
489
+
490
+ return custom_forward
491
+
492
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
493
+ hidden_states = torch.utils.checkpoint.checkpoint(
494
+ create_custom_forward(block),
495
+ hidden_states,
496
+ temb,
497
+ image_rotary_emb,
498
+ **ckpt_kwargs,
499
+ )
500
+
501
+ else:
502
+ hidden_states = block(
503
+ hidden_states=hidden_states,
504
+ temb=temb,
505
+ image_rotary_emb=image_rotary_emb,
506
+ proccesor_kwargs=proccesor_kwargs,
507
+ )
508
+
509
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
510
+
511
+ hidden_states = self.norm_out(hidden_states, temb)
512
+ output = self.proj_out(hidden_states)
513
+
514
+ if USE_PEFT_BACKEND:
515
+ # remove `lora_scale` from each PEFT layer
516
+ unscale_lora_layers(self, lora_scale)
517
+
518
+ if not return_dict:
519
+ return (output,)
520
+
521
+ return Transformer2DModelOutput(sample=output)
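The `attn_processors` property and `set_attn_processor` method above are the hooks used to swap in the Add-it attention processors imported at the top of this file. A minimal sketch of such a registration helper is shown below; the helper name and the argument-free constructors are assumptions, since `addit_attention_processors.py` is not part of this section and the actual helpers used by the pipeline (e.g. `register_regular_attention_processors`) are defined elsewhere in the repository:

from addit_attention_processors import AdditFluxAttnProcessor2_0, AdditFluxSingleAttnProcessor2_0

def register_addit_attention_processors(transformer):
    # Keys look like "transformer_blocks.0.attn.processor" or "single_transformer_blocks.0.attn.processor".
    processors = {}
    for name in transformer.attn_processors.keys():
        if name.startswith("single_transformer_blocks"):
            processors[name] = AdditFluxSingleAttnProcessor2_0()
        else:
            processors[name] = AdditFluxAttnProcessor2_0()
    transformer.set_attn_processor(processors)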
addit_methods.py ADDED
@@ -0,0 +1,186 @@
1
+ # Copyright (C) 2025 NVIDIA Corporation. All rights reserved.
2
+ #
3
+ # This work is licensed under the LICENSE file
4
+ # located at the root directory.
5
+
6
+ import gc
7
+ import torch
8
+ from visualization_utils import show_images
9
+
10
+ def _add_object(
11
+ pipe,
12
+ prompts,
13
+ seed_src,
14
+ seed_obj,
15
+ extended_scale,
16
+ source_latents,
17
+ structure_transfer_step,
18
+ subject_token,
19
+ blend_steps,
20
+ show_attention=False,
21
+ localization_model="attention_points_sam",
22
+ is_img_src=False,
23
+ img_src_latents=None,
24
+ use_offset=False,
25
+ display_output=False,
26
+ ):
27
+ gc.collect()
28
+ torch.cuda.empty_cache()
29
+
30
+ out = pipe(
31
+ prompt=prompts,
32
+ guidance_scale=3.5 if (not is_img_src) else [1,3.5],
33
+ height=1024,
34
+ width=1024,
35
+ max_sequence_length=512,
36
+ num_inference_steps=30,
37
+ seed=[seed_src, seed_obj],
38
+
39
+ # Extended Attention
40
+ extended_scale=extended_scale,
41
+ extended_steps_multi=10,
42
+ extended_steps_single=20,
43
+
44
+ # Structure Transfer
45
+ source_latents=source_latents,
46
+ structure_transfer_step=structure_transfer_step,
47
+
48
+ # Latent Blending
49
+ subject_token=subject_token,
50
+ localization_model=localization_model,
51
+ blend_steps=blend_steps,
52
+ show_attention=show_attention,
53
+
54
+ # Real Image Source
55
+ is_img_src=is_img_src,
56
+ img_src_latents=img_src_latents,
57
+ use_offset=use_offset,
58
+ )
59
+
60
+ if display_output:
61
+ show_images(out.images)
62
+
63
+ return out.images
64
+
65
+ def add_object_generated(
66
+ pipe,
67
+ prompt_source,
68
+ prompt_object,
69
+ subject_token,
70
+ seed_src,
71
+ seed_obj,
72
+ show_attention=False,
73
+ extended_scale=1.05,
74
+ structure_transfer_step=2,
75
+ blend_steps=[15],
76
+ localization_model="attention_points_sam",
77
+ display_output=False
78
+ ):
79
+ gc.collect()
80
+ torch.cuda.empty_cache()
81
+
82
+ # Generate the source image and its per-step latents for the source seed
83
+ print('Generating source image...')
84
+ source_image, source_latents = pipe(
85
+ prompt=[prompt_source],
86
+ guidance_scale=3.5,
87
+ height=1024,
88
+ width=1024,
89
+ max_sequence_length=512,
90
+ num_inference_steps=30,
91
+ seed=[seed_src],
92
+ output_type="both",
93
+ )
94
+ source_image = source_image[0]
95
+
96
+ # Run the core combination logic
97
+ print('Running Addit...')
98
+ src_image, edited_image = _add_object(
99
+ pipe=pipe,
100
+ prompts=[prompt_source, prompt_object],
101
+ subject_token=subject_token,
102
+ seed_src=seed_src,
103
+ seed_obj=seed_obj,
104
+ source_latents=source_latents,
105
+ structure_transfer_step=structure_transfer_step,
106
+ extended_scale=extended_scale,
107
+ blend_steps=blend_steps,
108
+ show_attention=show_attention,
109
+ localization_model=localization_model,
110
+ display_output=display_output
111
+ )
112
+
113
+ return src_image, edited_image
114
+
115
+ def add_object_real(
116
+ pipe,
117
+ source_image,
118
+ prompt_source,
119
+ prompt_object,
120
+ subject_token,
121
+ seed_src,
122
+ seed_obj,
123
+ localization_model="attention_points_sam",
124
+ extended_scale=1.05,
125
+ structure_transfer_step=4,
126
+ blend_steps=[20],
127
+ use_offset=False,
128
+ show_attention=False,
129
+ use_inversion=False,
130
+ display_output=False
131
+ ):
132
+ print('Noising-Denoising Original Image')
133
+ gc.collect()
134
+ torch.cuda.empty_cache()
135
+
136
+ # Get initial latents
137
+ source_latents = pipe.call_img2img(
138
+ prompt=prompt_source,
139
+ image=source_image,
140
+ num_inference_steps=30,
141
+ strength=0.1,
142
+ guidance_scale=3.5,
143
+ output_type="latent",
144
+ generator=torch.Generator(device=pipe.device).manual_seed(0)
145
+ ).images
146
+
147
+ # Optional inversion step
148
+ img_src_latents = None
149
+ if use_inversion:
150
+ print('Inverting Image')
151
+ gc.collect()
152
+ torch.cuda.empty_cache()
153
+
154
+ latents_list = pipe.call_invert(
155
+ prompt=prompt_source,
156
+ image=source_latents,
157
+ num_inference_steps=30,
158
+ guidance_scale=1,
159
+ fixed_point_iterations=2,
160
+ generator=torch.Generator(device=pipe.device).manual_seed(0)
161
+ )
162
+ img_src_latents = [x[0] for x in latents_list][::-1]
163
+
164
+ print('Running Addit')
165
+ gc.collect()
166
+ torch.cuda.empty_cache()
167
+
168
+ src_image, edited_image = _add_object(
169
+ pipe,
170
+ prompts=[prompt_source, prompt_object],
171
+ seed_src=seed_src,
172
+ seed_obj=seed_obj,
173
+ extended_scale=extended_scale,
174
+ source_latents=source_latents,
175
+ structure_transfer_step=structure_transfer_step,
176
+ subject_token=subject_token,
177
+ blend_steps=blend_steps,
178
+ show_attention=show_attention,
179
+ localization_model=localization_model,
180
+ is_img_src=True,
181
+ img_src_latents=img_src_latents,
182
+ use_offset=use_offset,
183
+ display_output=display_output,
184
+ )
185
+
186
+ return src_image, edited_image
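Both entry points above only require the pipeline, the prompt pair, the subject token, and the two seeds; the remaining keyword arguments fall back to the defaults in the signatures. A minimal sketch, assuming `pipe` is an already-initialized `AdditFluxPipeline` (set up as in `app.py` below):

src_image, edited_image = add_object_generated(
    pipe,
    prompt_source="A photo of a cat sitting on the couch",
    prompt_object="A photo of a cat wearing a red hat sitting on the couch",
    subject_token="hat",  # must appear in prompt_object
    seed_src=6311,
    seed_obj=1,
)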
addit_scheduler.py ADDED
@@ -0,0 +1,101 @@
1
+ # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler, FlowMatchEulerDiscreteSchedulerOutput
16
+ from typing import Union, Optional, Tuple
17
+ import torch
18
+
19
+ class AdditFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
20
+ def step(
21
+ self,
22
+ model_output: torch.FloatTensor,
23
+ timestep: Union[float, torch.FloatTensor],
24
+ sample: torch.FloatTensor,
25
+ s_churn: float = 0.0,
26
+ s_tmin: float = 0.0,
27
+ s_tmax: float = float("inf"),
28
+ s_noise: float = 1.0,
29
+ generator: Optional[torch.Generator] = None,
30
+ return_dict: bool = True,
31
+ step_index: Optional[int] = None,
32
+ ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
33
+ """
34
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
35
+ process from the learned model outputs (most often the predicted noise).
36
+
37
+ Args:
38
+ model_output (`torch.FloatTensor`):
39
+ The direct output from learned diffusion model.
40
+ timestep (`float`):
41
+ The current discrete timestep in the diffusion chain.
42
+ sample (`torch.FloatTensor`):
43
+ A current instance of a sample created by the diffusion process.
44
+ s_churn (`float`):
45
+ s_tmin (`float`):
46
+ s_tmax (`float`):
47
+ s_noise (`float`, defaults to 1.0):
48
+ Scaling factor for noise added to the sample.
49
+ generator (`torch.Generator`, *optional*):
50
+ A random number generator.
51
+ return_dict (`bool`):
52
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
53
+ tuple.
54
+
55
+ Returns:
56
+ [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
57
+ If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
58
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
59
+ """
60
+
61
+ if (
62
+ isinstance(timestep, int)
63
+ or isinstance(timestep, torch.IntTensor)
64
+ or isinstance(timestep, torch.LongTensor)
65
+ ):
66
+ raise ValueError(
67
+ (
68
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
69
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
70
+ " one of the `scheduler.timesteps` as a timestep."
71
+ ),
72
+ )
73
+
74
+ if step_index is not None:
75
+ self._step_index = step_index
76
+
77
+ if self.step_index is None:
78
+ self._init_step_index(timestep)
79
+
80
+ # Upcast to avoid precision issues when computing prev_sample
81
+ sample = sample.to(torch.float32)
82
+
83
+ sigma = self.sigmas[self.step_index]
84
+ sigma_next = self.sigmas[self.step_index + 1]
85
+
86
+ prev_sample = sample + (sigma_next - sigma) * model_output
87
+
88
+ # Calculate X_0
89
+ x_0 = sample - sigma * model_output
90
+
91
+ # Cast sample back to model compatible dtype
92
+ prev_sample = prev_sample.to(model_output.dtype)
93
+ x_0 = x_0.to(model_output.dtype)
94
+
95
+ # upon completion increase step index by one
96
+ self._step_index += 1
97
+
98
+ if not return_dict:
99
+ return (prev_sample, x_0)
100
+
101
+ return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
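The `step` above is a plain Euler update in sigma space plus a one-step estimate of the clean latent `x_0`. A toy numeric sketch of the two formulas (made-up tensors, not part of the pipeline):

import torch

sample = torch.zeros(1, 4)      # current noisy latent (toy shape)
velocity = torch.ones(1, 4)     # model_output, interpreted as a velocity field
sigma, sigma_next = 1.0, 0.9

prev_sample = sample + (sigma_next - sigma) * velocity  # Euler step toward sigma_next
x_0 = sample - sigma * velocity                         # one-step clean-latent estimate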
app.py ADDED
@@ -0,0 +1,416 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (C) 2025 NVIDIA Corporation. All rights reserved.
3
+ #
4
+ # This work is licensed under the LICENSE file
5
+ # located at the root directory.
6
+
7
+ import os
8
+ import gradio as gr
9
+ import spaces
10
+ import torch
11
+ import numpy as np
12
+ from PIL import Image
13
+ import tempfile
14
+ import gc
15
+
16
+ from addit_flux_pipeline import AdditFluxPipeline
17
+ from addit_flux_transformer import AdditFluxTransformer2DModel
18
+ from addit_scheduler import AdditFlowMatchEulerDiscreteScheduler
19
+ from addit_methods import add_object_generated, add_object_real
20
+
21
+ # Global variables for model
22
+ pipe = None
23
+ device = None
24
+
25
+ # Initialize model at startup
26
+ print("Initializing ADDIT model...")
27
+ try:
28
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+ print(f"Using device: {device}")
30
+
31
+ # Load transformer
32
+ my_transformer = AdditFluxTransformer2DModel.from_pretrained(
33
+ "black-forest-labs/FLUX.1-dev",
34
+ subfolder="transformer",
35
+ torch_dtype=torch.bfloat16
36
+ )
37
+
38
+ # Load pipeline
39
+ pipe = AdditFluxPipeline.from_pretrained(
40
+ "black-forest-labs/FLUX.1-dev",
41
+ transformer=my_transformer,
42
+ torch_dtype=torch.bfloat16
43
+ ).to(device)
44
+
45
+ # Set scheduler
46
+ pipe.scheduler = AdditFlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)
47
+
48
+ print("Model initialized successfully!")
49
+
50
+ except Exception as e:
51
+ print(f"Error initializing model: {str(e)}")
52
+ print("The application will start but model functionality will be unavailable.")
53
+
54
+ def validate_inputs(prompt_source, prompt_target, subject_token):
55
+ """Validate user inputs"""
56
+ if not prompt_source.strip():
57
+ return "Source prompt cannot be empty"
58
+ if not prompt_target.strip():
59
+ return "Target prompt cannot be empty"
60
+ if not subject_token.strip():
61
+ return "Subject token cannot be empty"
62
+ if subject_token not in prompt_target:
63
+ return f"Subject token '{subject_token}' must appear in the target prompt"
64
+ return None
65
+
66
+ @spaces.GPU
67
+ def process_generated_image(
68
+ prompt_source,
69
+ prompt_target,
70
+ subject_token,
71
+ seed_src,
72
+ seed_obj,
73
+ extended_scale,
74
+ structure_transfer_step,
75
+ blend_steps,
76
+ localization_model,
77
+ progress=gr.Progress(track_tqdm=True)
78
+ ):
79
+ """Process generated image with ADDIT"""
80
+ global pipe
81
+
82
+ if pipe is None:
83
+ return None, None, "Model not initialized. Please restart the application."
84
+
85
+ # Validate inputs
86
+ error_msg = validate_inputs(prompt_source, prompt_target, subject_token)
87
+ if error_msg:
88
+ return None, None, error_msg
89
+
90
+ try:
91
+ # Parse blend steps
92
+ if blend_steps.strip():
93
+ blend_steps_list = [int(x.strip()) for x in blend_steps.split(',') if x.strip()]
94
+ else:
95
+ blend_steps_list = []
96
+
97
+ # Generate images
98
+ src_image, edited_image = add_object_generated(
99
+ pipe=pipe,
100
+ prompt_source=prompt_source,
101
+ prompt_object=prompt_target,
102
+ subject_token=subject_token,
103
+ seed_src=seed_src,
104
+ seed_obj=seed_obj,
105
+ show_attention=False,
106
+ extended_scale=extended_scale,
107
+ structure_transfer_step=structure_transfer_step,
108
+ blend_steps=blend_steps_list,
109
+ localization_model=localization_model,
110
+ display_output=False
111
+ )
112
+
113
+ return src_image, edited_image, "Images generated successfully!"
114
+
115
+ except Exception as e:
116
+ error_msg = f"Error generating images: {str(e)}"
117
+ print(error_msg)
118
+ return None, None, error_msg
119
+
120
+ @spaces.GPU
121
+ def process_real_image(
122
+ source_image,
123
+ prompt_source,
124
+ prompt_target,
125
+ subject_token,
126
+ seed_src,
127
+ seed_obj,
128
+ extended_scale,
129
+ structure_transfer_step,
130
+ blend_steps,
131
+ localization_model,
132
+ use_offset,
133
+ disable_inversion,
134
+ progress=gr.Progress(track_tqdm=True)
135
+ ):
136
+ """Process real image with ADDIT"""
137
+ global pipe
138
+
139
+ if pipe is None:
140
+ return None, None, "Model not initialized. Please restart the application."
141
+
142
+ if source_image is None:
143
+ return None, None, "Please upload a source image"
144
+
145
+ # Validate inputs
146
+ error_msg = validate_inputs(prompt_source, prompt_target, subject_token)
147
+ if error_msg:
148
+ return None, None, error_msg
149
+
150
+ try:
151
+ # Resize source image
152
+ source_image = source_image.resize((1024, 1024))
153
+
154
+ # Parse blend steps
155
+ if blend_steps.strip():
156
+ blend_steps_list = [int(x.strip()) for x in blend_steps.split(',') if x.strip()]
157
+ else:
158
+ blend_steps_list = []
159
+
160
+ # Process image
161
+ src_image, edited_image = add_object_real(
162
+ pipe=pipe,
163
+ source_image=source_image,
164
+ prompt_source=prompt_source,
165
+ prompt_object=prompt_target,
166
+ subject_token=subject_token,
167
+ seed_src=seed_src,
168
+ seed_obj=seed_obj,
169
+ extended_scale=extended_scale,
170
+ structure_transfer_step=structure_transfer_step,
171
+ blend_steps=blend_steps_list,
172
+ localization_model=localization_model,
173
+ use_offset=use_offset,
174
+ show_attention=False,
175
+ use_inversion=not disable_inversion,
176
+ display_output=False
177
+ )
178
+
179
+ return src_image, edited_image, "Image edited successfully!"
180
+
181
+ except Exception as e:
182
+ error_msg = f"Error processing image: {str(e)}"
183
+ print(error_msg)
184
+ return None, None, error_msg
185
+
186
+ def create_interface():
187
+ """Create the Gradio interface"""
188
+
189
+ # Show model status in the interface
190
+ model_status = "Model ready!" if pipe is not None else "Model initialization failed - functionality unavailable"
191
+
192
+ with gr.Blocks(title="🎨 Add-it: Training-Free Object Insertion in Images With Pretrained Diffusion Models", theme=gr.themes.Soft()) as demo:
193
+ gr.HTML(f"""
194
+ <div style="text-align: center; margin-bottom: 20px;">
195
+ <h1>🎨 Add-it: Training-Free Object Insertion</h1>
196
+ <p>Add objects to images using pretrained diffusion models</p>
197
+ <p><a href="https://research.nvidia.com/labs/par/addit/" target="_blank">🌐 Project Website</a> |
198
+ <a href="https://arxiv.org/abs/2411.07232" target="_blank">📄 Paper</a> |
199
+ <a href="https://github.com/NVlabs/addit" target="_blank">💻 Code</a></p>
200
+ <p style="color: {'green' if pipe is not None else 'red'}; font-weight: bold;">Status: {model_status}</p>
201
+ </div>
202
+ """)
203
+
204
+ # Main interface
205
+ with gr.Tabs():
206
+ # Generated Images Tab
207
+ with gr.TabItem("🎭 Generated Images"):
208
+ gr.Markdown("### Generate a base image and add objects to it")
209
+
210
+ with gr.Row():
211
+ with gr.Column(scale=1):
212
+ gen_prompt_source = gr.Textbox(
213
+ label="Source Prompt",
214
+ placeholder="A photo of a cat sitting on the couch",
215
+ value="A photo of a cat sitting on the couch"
216
+ )
217
+ gen_prompt_target = gr.Textbox(
218
+ label="Target Prompt",
219
+ placeholder="A photo of a cat wearing a red hat sitting on the couch",
220
+ value="A photo of a cat wearing a red hat sitting on the couch"
221
+ )
222
+ gen_subject_token = gr.Textbox(
223
+ label="Subject Token",
224
+ placeholder="hat",
225
+ value="hat",
226
+ info="Single token representing the object to add **(must appear in target prompt)**"
227
+ )
228
+
229
+ with gr.Accordion("Advanced Settings", open=False):
230
+ gen_seed_src = gr.Number(label="Source Seed", value=6311, precision=0)
231
+ gen_seed_obj = gr.Number(label="Object Seed", value=1, precision=0)
232
+ gen_extended_scale = gr.Slider(
233
+ label="Extended Scale",
234
+ minimum=1.0,
235
+ maximum=1.3,
236
+ value=1.05,
237
+ step=0.01
238
+ )
239
+ gen_structure_transfer_step = gr.Slider(
240
+ label="Structure Transfer Step",
241
+ minimum=0,
242
+ maximum=10,
243
+ value=2,
244
+ step=1
245
+ )
246
+ gen_blend_steps = gr.Textbox(
247
+ label="Blend Steps",
248
+ value="15",
249
+ info="Comma-separated list of steps (e.g., '15,20') or empty for no blending"
250
+ )
251
+ gen_localization_model = gr.Dropdown(
252
+ label="Localization Model",
253
+ choices=[
254
+ "attention_points_sam",
255
+ "attention",
256
+ "attention_box_sam",
257
+ "attention_mask_sam",
258
+ "grounding_sam"
259
+ ],
260
+ value="attention_points_sam"
261
+ )
262
+
263
+ gen_submit_btn = gr.Button("🎨 Generate & Edit", variant="primary")
264
+
265
+ with gr.Column(scale=2):
266
+ with gr.Row():
267
+ gen_src_output = gr.Image(label="Generated Source Image", type="pil")
268
+ gen_edited_output = gr.Image(label="Edited Image", type="pil")
269
+ gen_status = gr.Textbox(label="Status", interactive=False)
270
+
271
+ gen_submit_btn.click(
272
+ fn=process_generated_image,
273
+ inputs=[
274
+ gen_prompt_source, gen_prompt_target, gen_subject_token,
275
+ gen_seed_src, gen_seed_obj, gen_extended_scale,
276
+ gen_structure_transfer_step, gen_blend_steps,
277
+ gen_localization_model
278
+ ],
279
+ outputs=[gen_src_output, gen_edited_output, gen_status]
280
+ )
281
+
282
+ # Examples for generated images
283
+ gr.Examples(
284
+ examples=[
285
+ ["A photo of a man sitting on a bench", "A photo of a man sitting on a bench with a dog", "dog"],
286
+ ["A photo of a cat sitting on the couch", "A photo of a cat wearing a red hat sitting on the couch", "hat"],
287
+ ["A car driving through an empty street", "A pink car driving through an empty street", "car"]
288
+ ],
289
+ inputs=[
290
+ gen_prompt_source, gen_prompt_target, gen_subject_token
291
+ ],
292
+ label="Example Prompts"
293
+ )
294
+
295
+ # Real Images Tab
296
+ with gr.TabItem("📸 Real Images"):
297
+ gr.Markdown("### Upload an image and add objects to it")
298
+
299
+ with gr.Row():
300
+ with gr.Column(scale=1):
301
+ real_source_image = gr.Image(label="Source Image", type="pil")
302
+ real_prompt_source = gr.Textbox(
303
+ label="Source Prompt",
304
+ placeholder="A photo of a bed in a dark room",
305
+ value="A photo of a bed in a dark room"
306
+ )
307
+ real_prompt_target = gr.Textbox(
308
+ label="Target Prompt",
309
+ placeholder="A photo of a dog lying on a bed in a dark room",
310
+ value="A photo of a dog lying on a bed in a dark room"
311
+ )
312
+ real_subject_token = gr.Textbox(
313
+ label="Subject Token",
314
+ placeholder="dog",
315
+ value="dog",
316
+ info="Single token representing the object to add **(must appear in target prompt)**"
317
+ )
318
+
319
+ with gr.Accordion("Advanced Settings", open=False):
320
+ real_seed_src = gr.Number(label="Source Seed", value=6311, precision=0)
321
+ real_seed_obj = gr.Number(label="Object Seed", value=1, precision=0)
322
+ real_extended_scale = gr.Slider(
323
+ label="Extended Scale",
324
+ minimum=1.0,
325
+ maximum=1.3,
326
+ value=1.1,
327
+ step=0.01
328
+ )
329
+ real_structure_transfer_step = gr.Slider(
330
+ label="Structure Transfer Step",
331
+ minimum=0,
332
+ maximum=10,
333
+ value=4,
334
+ step=1
335
+ )
336
+ real_blend_steps = gr.Textbox(
337
+ label="Blend Steps",
338
+ value="18",
339
+ info="Comma-separated list of steps (e.g., '15,20') or empty for no blending"
340
+ )
341
+ real_localization_model = gr.Dropdown(
342
+ label="Localization Model",
343
+ choices=[
344
+ "attention",
345
+ "attention_points_sam",
346
+ "attention_box_sam",
347
+ "attention_mask_sam",
348
+ "grounding_sam"
349
+ ],
350
+ value="attention"
351
+ )
352
+ real_use_offset = gr.Checkbox(label="Use Offset", value=False)
353
+ real_disable_inversion = gr.Checkbox(label="Disable Inversion", value=False)
354
+
355
+ real_submit_btn = gr.Button("🎨 Edit Image", variant="primary")
356
+
357
+ with gr.Column(scale=2):
358
+ with gr.Row():
359
+ real_src_output = gr.Image(label="Source Image", type="pil")
360
+ real_edited_output = gr.Image(label="Edited Image", type="pil")
361
+ real_status = gr.Textbox(label="Status", interactive=False)
362
+
363
+ real_submit_btn.click(
364
+ fn=process_real_image,
365
+ inputs=[
366
+ real_source_image, real_prompt_source, real_prompt_target, real_subject_token,
367
+ real_seed_src, real_seed_obj, real_extended_scale,
368
+ real_structure_transfer_step, real_blend_steps,
369
+ real_localization_model, real_use_offset,
370
+ real_disable_inversion
371
+ ],
372
+ outputs=[real_src_output, real_edited_output, real_status]
373
+ )
374
+
375
+ # Examples for real images
376
+ gr.Examples(
377
+ examples=[
378
+ [
379
+ "images/bed_dark_room.jpg",
380
+ "A photo of a bed in a dark room",
381
+ "A photo of a dog lying on a bed in a dark room",
382
+ "dog"
383
+ ],
384
+ [
385
+ "images/flower.jpg",
386
+ "A photo of a flower",
387
+ "A bee standing on a flower",
388
+ "bee"
389
+ ]
390
+ ],
391
+ inputs=[
392
+ real_source_image, real_prompt_source, real_prompt_target, real_subject_token
393
+ ],
394
+ label="Example Images & Prompts"
395
+ )
396
+
397
+ # Tips
398
+ with gr.Accordion("💡 Tips for Better Results", open=False):
399
+ gr.Markdown("""
400
+ - **Prompt Design**: The Target Prompt should be similar to the Source Prompt, but include a description of the new object to insert
401
+ - **Seed Variation**: Try different values for Object Seed - some prompts may require a few attempts to get satisfying results
402
+ - **Localization Models**: The most effective options are `attention_points_sam` and `attention`. The `--show_attention` flag in the CLI scripts can be used to visualize localization quality
403
+ - **Object Placement Issues**: If the object is not added to the image:
404
+ - Try **decreasing** Structure Transfer Step
405
+ - Try **increasing** Extended Scale
406
+ - **Flexibility**: To allow more freedom in modifying the source image, leave Blend Steps empty so that no latent blending is applied
407
+ """)
408
+
409
+ return demo
410
+
411
+ demo = create_interface()
412
+ demo.launch(
413
+ server_name="0.0.0.0",
414
+ server_port=7860,
415
+ share=True
416
+ )
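Assuming the gated FLUX.1-dev weights are accessible and the dependencies below are installed, the demo would typically be started locally with `python app.py` and opened at `http://localhost:7860`, matching the `server_port` passed to `demo.launch` above.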
images/bed_dark_room.jpg ADDED

Git LFS Details

  • SHA256: dd0c09288fb6d87a6ca9c0e3687a2748d0c61c885dc928b24073d00929ee76cc
  • Pointer size: 132 Bytes
  • Size of remote file: 4.26 MB
images/flower.jpg ADDED

Git LFS Details

  • SHA256: 72f485cdc855329da5a899f98dae55f51413099abc826857be780c2a4b9bcea7
  • Pointer size: 132 Bytes
  • Size of remote file: 1.77 MB
requirements.txt ADDED
@@ -0,0 +1,20 @@
1
+ torch==2.5.1
2
+ torchvision==0.20.1
3
+ numpy==1.26.4
4
+ scipy==1.14.1
5
+ scikit-image==0.24.0
6
+ pandas==2.2.2
7
+ matplotlib
8
+ transformers==4.44.0
9
+ accelerate==0.33.0
10
+ diffusers @ git+https://github.com/huggingface/diffusers.git@15eb77bc4cf2ccb40781cb630b9a734b43cffcb8
11
+ opencv-python
12
+ pyarrow
13
+ fastparquet
14
+ ipykernel
15
+ sentencepiece==0.2.0
16
+ protobuf==5.27.3
17
+ python-dotenv
18
+ git+https://github.com/facebookresearch/sam2.git
19
+ gradio
20
+ spaces
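For a local run (outside the Space, which installs these pins automatically), the environment would typically be prepared with `pip install -r requirements.txt` before launching `app.py`.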
run_CLI_addit_generated.py ADDED
@@ -0,0 +1,102 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (C) 2025 NVIDIA Corporation. All rights reserved.
3
+ #
4
+ # This work is licensed under the LICENSE file
5
+ # located at the root directory.
6
+
7
+ import os
8
+ import argparse
9
+ import torch
10
+ import random
11
+
12
+ from visualization_utils import show_images
13
+ from addit_flux_pipeline import AdditFluxPipeline
14
+ from addit_flux_transformer import AdditFluxTransformer2DModel
15
+ from addit_scheduler import AdditFlowMatchEulerDiscreteScheduler
16
+ from addit_methods import add_object_generated
17
+
18
+ def main():
19
+ parser = argparse.ArgumentParser(description='Run ADDIT with generated images')
20
+
21
+ # Required arguments
22
+ parser.add_argument('--prompt_source', type=str, default="A photo of a cat sitting on the couch",
23
+ help='Source prompt for generating the base image')
24
+ parser.add_argument('--prompt_target', type=str, default="A photo of a cat wearing a red hat sitting on the couch",
25
+ help='Target prompt describing the desired edited image')
26
+ parser.add_argument('--subject_token', type=str, default="hat",
27
+ help='Single token representing the subject to add to the image, must appear in the prompt_target')
28
+
29
+ # Optional arguments
30
+ parser.add_argument('--output_dir', type=str, default='outputs',
31
+ help='Directory to save output images (default: outputs)')
32
+ parser.add_argument('--seed_src', type=int, default=6311,
33
+ help='Seed for source generation')
34
+ parser.add_argument('--seed_obj', type=int, default=1,
35
+ help='Seed for edited image generation')
36
+ parser.add_argument('--extended_scale', type=float, default=1.05,
37
+ help='Extended attention scale (default: 1.05)')
38
+ parser.add_argument('--structure_transfer_step', type=int, default=2,
39
+ help='Structure transfer step (default: 2)')
40
+ parser.add_argument('--blend_steps', type=int, nargs='*', default=[15],
41
+ help='Blend steps (default: [15])')
42
+ parser.add_argument('--localization_model', type=str, default="attention_points_sam",
43
+ help='Localization model (default: attention_points_sam, Options: [attention_points_sam, attention, attention_box_sam, attention_mask_sam, grounding_sam])')
44
+ parser.add_argument('--show_attention', action='store_true',
45
+ help='Show attention maps')
46
+ parser.add_argument('--display_output', action='store_true',
47
+ help='Display output images during processing')
48
+
49
+ args = parser.parse_args()
50
+
51
+ assert args.subject_token in args.prompt_target, "Subject token must appear in the prompt_target"
52
+
53
+ # Set up device and model
54
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
55
+ print(f"Using device: {device}")
56
+
57
+ my_transformer = AdditFluxTransformer2DModel.from_pretrained(
58
+ "black-forest-labs/FLUX.1-dev",
59
+ subfolder="transformer",
60
+ torch_dtype=torch.bfloat16
61
+ )
62
+
63
+ pipe = AdditFluxPipeline.from_pretrained(
64
+ "black-forest-labs/FLUX.1-dev",
65
+ transformer=my_transformer,
66
+ torch_dtype=torch.bfloat16
67
+ ).to(device)
68
+
69
+ pipe.scheduler = AdditFlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)
70
+
71
+ # Create output directory
72
+ os.makedirs(args.output_dir, exist_ok=True)
73
+
74
+ # Process the seeds
75
+ print(f"\nProcessing with source seed: {args.seed_src}, object seed: {args.seed_obj}")
76
+
77
+ src_image, edited_image = add_object_generated(
78
+ pipe,
79
+ args.prompt_source,
80
+ args.prompt_target,
81
+ args.subject_token,
82
+ args.seed_src,
83
+ args.seed_obj,
84
+ show_attention=args.show_attention,
85
+ extended_scale=args.extended_scale,
86
+ structure_transfer_step=args.structure_transfer_step,
87
+ blend_steps=args.blend_steps,
88
+ localization_model=args.localization_model,
89
+ display_output=args.display_output
90
+ )
91
+
92
+ # Save output images
93
+ src_filename = f"src_{args.prompt_source}_seed-src={args.seed_src}.png"
94
+ edited_filename = f"edited_{args.prompt_target}_seed-src={args.seed_src}_seed-obj={args.seed_obj}.png"
95
+
96
+ src_image.save(os.path.join(args.output_dir, src_filename))
97
+ edited_image.save(os.path.join(args.output_dir, edited_filename))
98
+
99
+ print(f"Saved images: {src_filename}, {edited_filename}")
100
+
101
+ if __name__ == "__main__":
102
+ main()
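Since every argument has a default, a typical invocation of this script only overrides the prompts and the subject token, e.g. `python run_CLI_addit_generated.py --prompt_source "A photo of a cat sitting on the couch" --prompt_target "A photo of a cat wearing a red hat sitting on the couch" --subject_token hat`; outputs are written to `outputs/` by default.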
run_CLI_addit_real.py ADDED
@@ -0,0 +1,121 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (C) 2025 NVIDIA Corporation. All rights reserved.
3
+ #
4
+ # This work is licensed under the LICENSE file
5
+ # located at the root directory.
6
+
7
+ import os
8
+ import argparse
9
+ import torch
10
+ import random
11
+ from PIL import Image
12
+
13
+ from visualization_utils import show_images
14
+ from addit_flux_pipeline import AdditFluxPipeline
15
+ from addit_flux_transformer import AdditFluxTransformer2DModel
16
+ from addit_scheduler import AdditFlowMatchEulerDiscreteScheduler
17
+ from addit_methods import add_object_real
18
+
19
+ def main():
20
+ parser = argparse.ArgumentParser(description='Run ADDIT with real images')
21
+
22
+ # Required arguments
23
+ parser.add_argument('--source_image', type=str, default="images/bed_dark_room.jpg",
24
+ help='Path to the source image')
25
+ parser.add_argument('--prompt_source', type=str, default="A photo of a bed in a dark room",
26
+ help='Source prompt describing the original image')
27
+ parser.add_argument('--prompt_target', type=str, default="A photo of a dog lying on a bed in a dark room",
28
+ help='Target prompt describing the desired edited image')
29
+ parser.add_argument('--subject_token', type=str, default="dog",
30
+ help='Subject token to add to the image')
31
+
32
+ # Optional arguments
33
+ parser.add_argument('--output_dir', type=str, default='outputs',
34
+ help='Directory to save output images (default: outputs)')
35
+ parser.add_argument('--seed_src', type=int, default=6311,
36
+ help='Seed for source generation')
37
+ parser.add_argument('--seed_obj', type=int, default=1,
38
+ help='Seed for edited image generation')
39
+ parser.add_argument('--extended_scale', type=float, default=1.1,
40
+ help='Extended attention scale (default: 1.1)')
41
+ parser.add_argument('--structure_transfer_step', type=int, default=4,
42
+ help='Structure transfer step (default: 4)')
43
+ parser.add_argument('--blend_steps', type=int, nargs='*', default=[18],
44
+ help='Blend steps (default: [18])')
45
+ parser.add_argument('--localization_model', type=str, default="attention",
46
+ help='Localization model (default: attention, Options: [attention_points_sam, attention, attention_box_sam, attention_mask_sam, grounding_sam])')
47
+ parser.add_argument('--use_offset', action='store_true',
48
+ help='Use offset in processing')
49
+ parser.add_argument('--show_attention', action='store_true',
50
+ help='Show attention maps')
51
+ parser.add_argument('--disable_inversion', action='store_true',
52
+ help='Disable source image inversion')
53
+ parser.add_argument('--display_output', action='store_true',
54
+ help='Display output images during processing')
55
+
56
+ args = parser.parse_args()
57
+
58
+ assert args.subject_token in args.prompt_target, "Subject token must appear in the prompt_target"
59
+
60
+ # Set up device and model
61
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62
+ print(f"Using device: {device}")
63
+
64
+ my_transformer = AdditFluxTransformer2DModel.from_pretrained(
65
+ "black-forest-labs/FLUX.1-dev",
66
+ subfolder="transformer",
67
+ torch_dtype=torch.bfloat16
68
+ )
69
+
70
+ pipe = AdditFluxPipeline.from_pretrained(
71
+ "black-forest-labs/FLUX.1-dev",
72
+ transformer=my_transformer,
73
+ torch_dtype=torch.bfloat16
74
+ ).to(device)
75
+
76
+ pipe.scheduler = AdditFlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)
77
+
78
+ # Load and resize source image
79
+ source_image = Image.open(args.source_image).resize((1024, 1024))
80
+ print(f"Loaded source image: {args.source_image}")
81
+
82
+ # Set random seed
83
+ if args.seed_src is None:
84
+ random.seed(0)
85
+ args.seed_src = random.randint(0, 10000)
86
+
87
+ # Create output directory
88
+ os.makedirs(args.output_dir, exist_ok=True)
89
+
90
+ # Process the seeds
91
+ print(f"\nProcessing with source seed: {args.seed_src}, object seed: {args.seed_obj}")
92
+
93
+ src_image, edited_image = add_object_real(
94
+ pipe,
95
+ source_image=source_image,
96
+ prompt_source=args.prompt_source,
97
+ prompt_object=args.prompt_target,
98
+ subject_token=args.subject_token,
99
+ seed_src=args.seed_src,
100
+ seed_obj=args.seed_obj,
101
+ extended_scale=args.extended_scale,
102
+ structure_transfer_step=args.structure_transfer_step,
103
+ blend_steps=args.blend_steps,
104
+ localization_model=args.localization_model,
105
+ use_offset=args.use_offset,
106
+ show_attention=args.show_attention,
107
+ use_inversion=not args.disable_inversion,
108
+ display_output=args.display_output
109
+ )
110
+
111
+ # Save output images
112
+ src_filename = f"src_{args.prompt_source}_seed-src={args.seed_src}.png"
113
+ edited_filename = f"edited_{args.prompt_target}_seed-src={args.seed_src}_seed-obj={args.seed_obj}.png"
114
+
115
+ src_image.save(os.path.join(args.output_dir, src_filename))
116
+ edited_image.save(os.path.join(args.output_dir, edited_filename))
117
+
118
+ print(f"Saved images: {src_filename}, {edited_filename}")
119
+
120
+ if __name__ == "__main__":
121
+ main()
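
Note: as a reference only, an example invocation of the script above; the flag names come from the argparse definitions in this diff, and the values shown are simply the script defaults, so substitute your own image path and prompts. The subject token must appear in --prompt_target, and the FLUX.1-dev weights must be accessible to from_pretrained (they are fetched from the Hugging Face Hub if not already cached).

    python run_CLI_addit_real.py --source_image images/bed_dark_room.jpg --prompt_source "A photo of a bed in a dark room" --prompt_target "A photo of a dog lying on a bed in a dark room" --subject_token dog --output_dir outputs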
run_addit_generated.ipynb ADDED
@@ -0,0 +1,80 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Copyright (C) 2025 NVIDIA Corporation. All rights reserved.\n",
+     "#\n",
+     "# This work is licensed under the LICENSE file\n",
+     "# located at the root directory.\n",
+     "import torch\n",
+     "import random\n",
+     "\n",
+     "from visualization_utils import show_images\n",
+     "from addit_flux_pipeline import AdditFluxPipeline\n",
+     "from addit_flux_transformer import AdditFluxTransformer2DModel\n",
+     "from addit_scheduler import AdditFlowMatchEulerDiscreteScheduler\n",
+     "from addit_methods import add_object_generated\n",
+     "\n",
+     "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+     "my_transformer = AdditFluxTransformer2DModel.from_pretrained(\"black-forest-labs/FLUX.1-dev\", subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n",
+     "\n",
+     "pipe = AdditFluxPipeline.from_pretrained(\"black-forest-labs/FLUX.1-dev\",\n",
+     "                                         transformer=my_transformer,\n",
+     "                                         torch_dtype=torch.bfloat16).to(device)\n",
+     "pipe.scheduler = AdditFlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Reset the GPU memory tracking\n",
+     "torch.cuda.reset_max_memory_allocated(0)\n",
+     "\n",
+     "(prompt1, prompt2), subject_token = [\"A photo of a man sitting on a bench\", \"A photo of a man sitting on a bench with a dog\"], \"dog\"\n",
+     "\n",
+     "\n",
+     "random.seed(0)\n",
+     "seeds_src = [663]\n",
+     "seeds_obj = [0,1,2]\n",
+     "\n",
+     "for seed_src in seeds_src:\n",
+     "    for seed_obj in seeds_obj:\n",
+     "        src_image, edited_image = add_object_generated(pipe, prompt1, prompt2, subject_token, seed_src, seed_obj, show_attention=True,\n",
+     "                                                        extended_scale=1.05, structure_transfer_step=2, blend_steps=[15],\n",
+     "                                                        localization_model=\"attention_points_sam\", display_output=True)\n",
+     "\n",
+     "# Report maximum GPU memory usage in GB\n",
+     "max_memory_used = torch.cuda.max_memory_allocated(0) / (1024**3)  # Convert to GB\n",
+     "print(f\"Maximum GPU memory used: {max_memory_used:.2f} GB\")"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "addit",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.11.9"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
run_addit_real.ipynb ADDED
@@ -0,0 +1,85 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Copyright (C) 2025 NVIDIA Corporation. All rights reserved.\n",
+     "#\n",
+     "# This work is licensed under the LICENSE file\n",
+     "# located at the root directory.\n",
+     "\n",
+     "import torch\n",
+     "import random\n",
+     "from PIL import Image\n",
+     "\n",
+     "from visualization_utils import show_images\n",
+     "from addit_flux_pipeline import AdditFluxPipeline\n",
+     "from addit_flux_transformer import AdditFluxTransformer2DModel\n",
+     "from addit_scheduler import AdditFlowMatchEulerDiscreteScheduler\n",
+     "from addit_methods import add_object_real\n",
+     "\n",
+     "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+     "my_transformer = AdditFluxTransformer2DModel.from_pretrained(\"black-forest-labs/FLUX.1-dev\", subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n",
+     "\n",
+     "pipe = AdditFluxPipeline.from_pretrained(\"black-forest-labs/FLUX.1-dev\",\n",
+     "                                         transformer=my_transformer,\n",
+     "                                         torch_dtype=torch.bfloat16).to(device)\n",
+     "pipe.scheduler = AdditFlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Reset the GPU memory tracking\n",
+     "torch.cuda.reset_max_memory_allocated(0)\n",
+     "\n",
+     "# source_image = Image.open(\"images/cat.jpg\").resize((1024, 1024))\n",
+     "# (prompt_src, prompt_tgt), subject_token = [\"A photo of a cat\", \"A photo of a cat wearing a scarf\"], \"scarf\"\n",
+     "\n",
+     "source_image = Image.open(\"images/bed_dark_room.jpg\").resize((1024, 1024))\n",
+     "(prompt_src, prompt_tgt), subject_token = [\"A photo of a bed in a dark room\", \"A photo of a dog lying on a bed in a dark room\"], \"dog\"\n",
+     "\n",
+     "random.seed(0)\n",
+     "seed_src = random.randint(0, 10000)\n",
+     "seeds_obj = [0,1,2]\n",
+     "\n",
+     "for seed_obj in seeds_obj:\n",
+     "    images_list = add_object_real(pipe, source_image=source_image, prompt_source=prompt_src, prompt_object=prompt_tgt,\n",
+     "                                  subject_token=subject_token, seed_src=seed_src, seed_obj=seed_obj,\n",
+     "                                  extended_scale=1.1, structure_transfer_step=4, blend_steps=[18],  # localization_model=\"attention\",\n",
+     "                                  use_offset=False, show_attention=True, use_inversion=True, display_output=True)\n",
+     "\n",
+     "# Report maximum GPU memory usage in GB\n",
+     "max_memory_used = torch.cuda.max_memory_allocated(0) / (1024**3)  # Convert to GB\n",
+     "print(f\"Maximum GPU memory used: {max_memory_used:.2f} GB\")"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "addit",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.11.9"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
visualization_utils.py ADDED
@@ -0,0 +1,235 @@
+ # Copyright (C) 2025 NVIDIA Corporation. All rights reserved.
+ #
+ # This work is licensed under the LICENSE file
+ # located at the root directory.
+
+ import cv2
+ import numpy as np
+ from PIL import Image, ImageDraw
+ import torch
+ import matplotlib.pyplot as plt
+ from skimage import filters
+ from IPython.display import display
+
+ def gaussian_blur(heatmap, kernel_size=7):
+     # Shape of heatmap: (H, W)
+     heatmap = heatmap.cpu().numpy()
+     heatmap = cv2.GaussianBlur(heatmap, (kernel_size, kernel_size), 0)
+     heatmap = torch.tensor(heatmap)
+
+     return heatmap
+
+ def show_cam_on_image(img, mask):
+     heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
+     heatmap = np.float32(heatmap) / 255
+     cam = heatmap + np.float32(img)
+     cam = cam / np.max(cam)
+     return cam
+
+ def show_image_and_heatmap(heatmap: torch.Tensor, image: Image.Image, relevnace_res: int = 256, interpolation: str = 'bilinear', gassussian_kernel_size: int = 3):
+     image = image.resize((relevnace_res, relevnace_res))
+     image = np.array(image)
+     image = (image - image.min()) / (image.max() - image.min())
+
+     # Apply gaussian blur to heatmap
+     # heatmap = gaussian_blur(heatmap, kernel_size=gassussian_kernel_size)
+
+     # heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
+     # otsu_thr = filters.threshold_otsu(heatmap.cpu().numpy())
+     # heatmap = (heatmap > otsu_thr).to(heatmap.dtype)
+
+     heatmap = heatmap.reshape(1, 1, heatmap.shape[-1], heatmap.shape[-1])
+     heatmap = torch.nn.functional.interpolate(heatmap, size=relevnace_res, mode=interpolation)
+     heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
+     heatmap = heatmap.reshape(relevnace_res, relevnace_res).cpu()
+
+     vis = show_cam_on_image(image, heatmap)
+     vis = np.uint8(255 * vis)
+     vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
+
+     vis = vis.astype(np.uint8)
+     vis = Image.fromarray(vis).resize((relevnace_res, relevnace_res))
+
+     return vis
+
+ def show_only_heatmap(heatmap: torch.Tensor, relevnace_res: int = 256, interpolation: str = 'bilinear', gassussian_kernel_size: int = 3):
+     # Apply gaussian blur to heatmap
+     # heatmap = gaussian_blur(heatmap, kernel_size=gassussian_kernel_size)
+
+     heatmap = heatmap.reshape(1, 1, heatmap.shape[-1], heatmap.shape[-1])
+     heatmap = torch.nn.functional.interpolate(heatmap, size=relevnace_res, mode=interpolation)
+     heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
+     heatmap = heatmap.reshape(relevnace_res, relevnace_res).cpu()
+
+     vis = heatmap
+     vis = np.uint8(255 * vis)
+
+     # Show in black and white
+     vis = cv2.cvtColor(np.array(vis), cv2.COLOR_GRAY2BGR)
+
+     vis = Image.fromarray(vis).resize((relevnace_res, relevnace_res))
+
+     return vis
+
+ def visualize_tokens_attentions(attention, tokens, image, heatmap_interpolation="nearest", show_on_image=True):
+     # Tokens: list of strings
+     # attention: tensor of shape (batch_size, num_tokens, width, height)
+     token_vis = []
+     for j, token in enumerate(tokens):
+         if j >= attention.shape[0]:
+             break
+
+         if show_on_image:
+             vis = show_image_and_heatmap(attention[j], image, relevnace_res=512, interpolation=heatmap_interpolation)
+         else:
+             vis = show_only_heatmap(attention[j], relevnace_res=512, interpolation=heatmap_interpolation)
+
+         token_vis.append((token, vis))
+
+     # Display the token and the attention map in a grid, with K tokens per row
+     K = 4
+     n_rows = (len(token_vis) + K - 1) // K  # Ceiling division
+     fig, axs = plt.subplots(n_rows, K, figsize=(K*5, n_rows*5))
+
+     for i, (token, vis) in enumerate(token_vis):
+         row, col = divmod(i, K)
+         if n_rows > 1:
+             ax = axs[row, col]
+         elif K > 1:
+             ax = axs[col]
+         else:
+             ax = axs
+
+         ax.imshow(vis)
+         ax.set_title(token)
+         ax.axis("off")
+
+     # Hide unused subplots
+     for j in range(i + 1, n_rows * K):
+         row, col = divmod(j, K)
+         if n_rows > 1:
+             axs[row, col].axis('off')
+         elif K > 1:
+             axs[col].axis('off')
+
+     plt.tight_layout()
+
+     # We want to return the figure so that we can save it to a file
+     return fig
+
+ def show_images(images, titles=None, size=1024, max_row_length=5, figsize=None, col_height=10, save_path=None):
+     if isinstance(images, Image.Image):
+         images = [images]
+
+     if len(images) == 1:
+         img = images[0]
+         img = img.resize((size, size))
+         plt.imshow(img)
+         plt.axis('off')
+
+         if titles is not None:
+             plt.title(titles[0])
+
+         if save_path:
+             plt.savefig(save_path, bbox_inches='tight', dpi=150)
+
+         plt.show()
+     else:
+         images = [img.resize((size, size)) for img in images]
+
+         # Check if the number of titles matches the number of images
+         if titles is not None:
+             assert len(images) == len(titles), "Number of titles should match the number of images"
+
+         n_images = len(images)
+         n_cols = min(n_images, max_row_length)
+         n_rows = (n_images + n_cols - 1) // n_cols  # Calculate the number of rows needed
+
+         if figsize is None:
+             figsize = (n_cols * col_height, n_rows * col_height)
+
+         fig, axs = plt.subplots(n_rows, n_cols, figsize=figsize)
+         axs = axs.flatten() if isinstance(axs, np.ndarray) else [axs]
+
+         # Display images in the subplots
+         for i, img in enumerate(images):
+             axs[i].imshow(img)
+             if titles is not None:
+                 axs[i].set_title(titles[i])
+             axs[i].axis("off")
+
+         # Turn off any unused subplots
+         for ax in axs[len(images):]:
+             ax.axis("off")
+
+         if save_path:
+             plt.savefig(save_path, bbox_inches='tight', dpi=150)
+
+         plt.show()
+
+ def show_tensors(tensors, titles=None, size=None, max_row_length=5):
+     # Shape of tensors: List[Tensor[H, W]]
+     if size is not None:
+         tensors = [torch.nn.functional.interpolate(t.unsqueeze(0).unsqueeze(0), size=(size, size), mode='bilinear').squeeze() for t in tensors]
+
+     if len(tensors) == 1:
+         plt.imshow(tensors[0].cpu().numpy())
+         plt.axis('off')
+
+         if titles is not None:
+             plt.title(titles[0])
+
+         plt.show()
+     else:
+         # Check if the number of titles matches the number of images
+         if titles is not None:
+             assert len(tensors) == len(titles), "Number of titles should match the number of images"
+
+         n_tensors = len(tensors)
+         n_cols = min(n_tensors, max_row_length)
+         n_rows = (n_tensors + n_cols - 1) // n_cols
+
+         fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 10, n_rows * 10))
+         axs = axs.flatten() if isinstance(axs, np.ndarray) else [axs]
+
+         for i, tensor in enumerate(tensors):
+             axs[i].imshow(tensor.cpu().numpy())
+             if titles is not None:
+                 axs[i].set_title(titles[i])
+             axs[i].axis("off")
+
+         for ax in axs[len(tensors):]:
+             ax.axis("off")
+
+         plt.show()
+
+ def draw_bboxes_on_image(image, bboxes, color="red", thickness=2):
+     image = image.copy()
+     draw = ImageDraw.Draw(image)
+     for bbox in bboxes:
+         draw.rectangle(bbox, outline=color, width=thickness)
+     return image
+
+ def draw_points_on_pil_image(pil_image, point_coords, point_color="red", radius=5):
+     """
+     Draw points (circles) on a PIL image and return the modified image.
+
+     :param pil_image: PIL Image (e.g., sam_masked_image)
+     :param point_coords: An array-like of shape (N, 2), with x,y coordinates
+     :param point_color: Color of the point (default 'red')
+     :param radius: Radius of the drawn circles
+     :return: PIL Image with points drawn
+     """
+     # Copy so we don't modify the original
+     out_img = pil_image.copy()
+     draw = ImageDraw.Draw(out_img)
+
+     # Draw each point
+     for x, y in point_coords:
+         # Calculate bounding box of the circle
+         left_up_point = (x - radius, y - radius)
+         right_down_point = (x + radius, y + radius)
+         # Draw the circle
+         draw.ellipse([left_up_point, right_down_point], fill=point_color, outline=point_color)
+
+     return out_img
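
Note: a minimal usage sketch of the helpers above. The attention tensor here is random placeholder data, the bounding box and point coordinates are arbitrary, and the output filename is illustrative only; the image path is the repo's example image.

    from PIL import Image
    import torch

    from visualization_utils import (show_images, draw_bboxes_on_image,
                                     draw_points_on_pil_image, visualize_tokens_attentions)

    # Placeholder inputs: any RGB image plus a fake per-token attention stack of shape (num_tokens, H, W).
    image = Image.open("images/bed_dark_room.jpg").resize((1024, 1024))
    tokens = ["a", "photo", "of", "a", "dog"]
    attention = torch.rand(len(tokens), 32, 32)

    # Overlay each token's heatmap on the image and save the returned matplotlib figure.
    fig = visualize_tokens_attentions(attention, tokens, image, show_on_image=True)
    fig.savefig("token_attention.png", bbox_inches="tight")

    # Draw a candidate localization (box + point) and show it next to the source image.
    boxed = draw_bboxes_on_image(image, [(256, 256, 640, 640)])
    dotted = draw_points_on_pil_image(boxed, [(448, 448)], radius=8)
    show_images([image, dotted], titles=["source", "localization"], size=512)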