Add application file
Browse files- .gitattributes +1 -0
- LICENSE +35 -0
- README.md +168 -10
- addit_attention_processors.py +297 -0
- addit_attention_store.py +316 -0
- addit_blending_utils.py +232 -0
- addit_flux_pipeline.py +1389 -0
- addit_flux_transformer.py +521 -0
- addit_methods.py +186 -0
- addit_scheduler.py +101 -0
- app.py +416 -0
- images/bed_dark_room.jpg +3 -0
- images/flower.jpg +3 -0
- requirements.txt +20 -0
- run_CLI_addit_generated.py +102 -0
- run_CLI_addit_real.py +121 -0
- run_addit_generated.ipynb +80 -0
- run_addit_real.ipynb +85 -0
- visualization_utils.py +235 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
NVIDIA License
|
2 |
+
|
3 |
+
1. Definitions
|
4 |
+
|
5 |
+
“Licensor” means any person or entity that distributes its Work.
|
6 |
+
“Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license.
|
7 |
+
The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
|
8 |
+
Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license.
|
9 |
+
|
10 |
+
2. License Grant
|
11 |
+
|
12 |
+
2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
|
13 |
+
|
14 |
+
3. Limitations
|
15 |
+
|
16 |
+
3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.
|
17 |
+
|
18 |
+
3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.
|
19 |
+
|
20 |
+
3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only.
|
21 |
+
|
22 |
+
3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately.
|
23 |
+
|
24 |
+
3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license.
|
25 |
+
|
26 |
+
3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately.
|
27 |
+
|
28 |
+
4. Disclaimer of Warranty.
|
29 |
+
|
30 |
+
THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
|
31 |
+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
|
32 |
+
|
33 |
+
5. Limitation of Liability.
|
34 |
+
|
35 |
+
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
|
README.md
CHANGED
@@ -1,13 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
---
|
12 |
|
13 |
-
|
|
|
|
|
|
1 |
+
# 🎨 Add-it: Training-Free Object Insertion in Images With Pretrained Diffusion Models
|
2 |
+
|
3 |
+
<div align="center">
|
4 |
+
|
5 |
+
[](https://arxiv.org/abs/2411.07232)
|
6 |
+
[](https://research.nvidia.com/labs/par/addit/)
|
7 |
+
|
8 |
+
</div>
|
9 |
+
|
10 |
+
## 👥 Authors
|
11 |
+
|
12 |
+
**Yoad Tewel**<sup>1,2</sup>, **Rinon Gal**<sup>1,2</sup>, **Dvir Samuel**<sup>3</sup>, **Yuval Atzmon**<sup>1</sup>, **Lior Wolf**<sup>2</sup>, **Gal Chechik**<sup>1</sup>
|
13 |
+
|
14 |
+
<sup>1</sup>NVIDIA • <sup>2</sup>Tel Aviv University • <sup>3</sup>Bar-Ilan University
|
15 |
+
|
16 |
+
<div align="center">
|
17 |
+
<img src="https://research.nvidia.com/labs/par/addit/static/images/Teaser.png" alt="Add-it Teaser" width="800"/>
|
18 |
+
</div>
|
19 |
+
|
20 |
+
## 📄 Abstract
|
21 |
+
|
22 |
+
Adding objects into images based on text instructions is a challenging task in semantic image editing, requiring a balance between preserving the original scene and seamlessly integrating the new object in a fitting location. Despite extensive efforts, existing models often struggle with this balance, particularly with finding a natural location for adding an object in complex scenes.
|
23 |
+
|
24 |
+
We introduce **Add-it**, a training-free approach that extends diffusion models' attention mechanisms to incorporate information from three key sources: the scene image, the text prompt, and the generated image itself. Our weighted extended-attention mechanism maintains structural consistency and fine details while ensuring natural object placement.
|
25 |
+
|
26 |
+
Without task-specific fine-tuning, Add-it achieves state-of-the-art results on both real and generated image insertion benchmarks, including our newly constructed "Additing Affordance Benchmark" for evaluating object placement plausibility, outperforming supervised methods. Human evaluations show that Add-it is preferred in over 80% of cases, and it also demonstrates improvements in various automated metrics.
|
27 |
+
|
28 |
+
---
|
29 |
+
|
30 |
+
## 📋 Description
|
31 |
+
|
32 |
+
This repository contains the official implementation of the Add-it paper, providing tools for seamless object insertion into images using pretrained diffusion models.
|
33 |
+
|
34 |
+
---
|
35 |
+
|
36 |
+
## 🛠️ Setup
|
37 |
+
|
38 |
+
```bash
|
39 |
+
conda env create -f environment.yml
|
40 |
+
conda activate addit
|
41 |
+
```
|
42 |
+
|
43 |
---
|
44 |
+
|
45 |
+
## 🚀 Usage
|
46 |
+
|
47 |
+
### 💻 Command Line Interface (CLI)
|
48 |
+
|
49 |
+
Add-it provides two CLI scripts for different use cases:
|
50 |
+
|
51 |
+
#### 1. 🎭 Adding Objects to Generated Images
|
52 |
+
|
53 |
+
Use `run_CLI_addit_generated.py` to add objects to AI-generated images:
|
54 |
+
|
55 |
+
```bash
|
56 |
+
python run_CLI_addit_generated.py \
|
57 |
+
--prompt_source "A photo of a cat sitting on the couch" \
|
58 |
+
--prompt_target "A photo of a cat wearing a red hat sitting on the couch" \
|
59 |
+
--subject_token "hat"
|
60 |
+
```
|
61 |
+
|
62 |
+
##### ⚙️ Options for Generated Images
|
63 |
+
|
64 |
+
**🔴 Required Arguments:**
|
65 |
+
- `--prompt_source`: Source prompt for generating the base image
|
66 |
+
- `--prompt_target`: Target prompt describing the desired edited image
|
67 |
+
- `--subject_token`: Single token representing the subject to add (must appear in prompt_target)
|
68 |
+
|
69 |
+
**🔵 Optional Arguments:**
|
70 |
+
- `--output_dir`: Directory to save output images (default: "outputs")
|
71 |
+
- `--seed_src`: Seed for source generation (default: 6311)
|
72 |
+
- `--seed_obj`: Seed for edited image generation (default: 1)
|
73 |
+
- `--extended_scale`: Extended attention scale (default: 1.05)
|
74 |
+
- `--structure_transfer_step`: Structure transfer step (default: 2)
|
75 |
+
- `--blend_steps`: Blend steps (default: [15]). To allow for changes in the input image pass `--blend_steps` with empty value.
|
76 |
+
- `--localization_model`: Localization model (default: "attention_points_sam")
|
77 |
+
- **Options:** `attention_points_sam`, `attention`, `attention_box_sam`, `attention_mask_sam`, `grounding_sam`
|
78 |
+
- `--show_attention`: Show attention maps using pyplot (flag), will be saved to `attn_vis.png`.
|
79 |
+
|
80 |
+
#### 2. 📸 Adding Objects to Real Images
|
81 |
+
|
82 |
+
Use `run_CLI_addit_real.py` to add objects to existing images:
|
83 |
+
|
84 |
+
```bash
|
85 |
+
python run_CLI_addit_real.py \
|
86 |
+
--source_image "images/bed_dark_room.jpg" \
|
87 |
+
--prompt_source "A photo of a bed in a dark room" \
|
88 |
+
--prompt_target "A photo of a dog lying on a bed in a dark room" \
|
89 |
+
--subject_token "dog"
|
90 |
+
```
|
91 |
+
|
92 |
+
##### ⚙️ Options for Real Images
|
93 |
+
|
94 |
+
**🔴 Required Arguments:**
|
95 |
+
- `--source_image`: Path to the source image (default: "images/bed_dark_room.jpg")
|
96 |
+
- `--prompt_source`: Source prompt describing the original image
|
97 |
+
- `--prompt_target`: Target prompt describing the desired edited image
|
98 |
+
- `--subject_token`: Subject token to add to the image (must appear in prompt_target)
|
99 |
+
|
100 |
+
**🔵 Optional Arguments:**
|
101 |
+
- `--output_dir`: Directory to save output images (default: "outputs")
|
102 |
+
- `--seed_src`: Seed for source generation (default: 6311)
|
103 |
+
- `--seed_obj`: Seed for edited image generation (default: 1)
|
104 |
+
- `--extended_scale`: Extended attention scale (default: 1.1)
|
105 |
+
- `--structure_transfer_step`: Structure transfer step (default: 4)
|
106 |
+
- `--blend_steps`: Blend steps (default: [18]). To allow for changes in the input image pass `--blend_steps` with empty value.
|
107 |
+
- `--localization_model`: Localization model (default: "attention")
|
108 |
+
- **Options:** `attention_points_sam`, `attention`, `attention_box_sam`, `attention_mask_sam`, `grounding_sam`
|
109 |
+
- `--use_offset`: Use offset in processing (flag)
|
110 |
+
- `--show_attention`: Show attention maps using pyplot (flag), will be saved to `attn_vis.png`.
|
111 |
+
- `--disable_inversion`: Disable source image inversion (flag)
|
112 |
+
|
113 |
+
---
|
114 |
+
|
115 |
+
### 📓 Jupyter Notebooks
|
116 |
+
|
117 |
+
You can run Add-it in two interactive modes:
|
118 |
+
|
119 |
+
| Mode | Notebook | Description |
|
120 |
+
|------|----------|-------------|
|
121 |
+
| 🎭 **Generated Images** | `run_addit_generated.ipynb` | Adding objects to AI-generated images |
|
122 |
+
| 📸 **Real Images** | `run_addit_real.ipynb` | Adding objects to existing real images |
|
123 |
+
|
124 |
+
The notebooks contain examples of different prompts and parameters that can be adjusted to control the object insertion process.
|
125 |
+
|
126 |
+
---
|
127 |
+
|
128 |
+
## 💡 Tips for Better Results
|
129 |
+
|
130 |
+
- **Prompt Design**: The `--prompt_target` should be similar to the `--prompt_source`, but include a description of the new object to insert
|
131 |
+
- **Seed Variation**: Try different values for `--seed_obj` - some prompts may require a few attempts to get satisfying results
|
132 |
+
- **Localization Models**: The most effective `--localization_model` options are `attention_points_sam` and `attention`. Use the `--show_attention` flag to visualize localization performance
|
133 |
+
- **Object Placement Issues**: If the object is not added to the image:
|
134 |
+
- Try **decreasing** `--structure_transfer_step`
|
135 |
+
- Try **increasing** `--extended_scale`
|
136 |
+
- **Flexibility**: To allow more flexibility in modifying the source image, set `--blend_steps` to an empty value to send an empty list: `[]`
|
137 |
+
|
138 |
+
---
|
139 |
+
|
140 |
+
## 📰 News
|
141 |
+
|
142 |
+
- **🎉 2025 JUL**: Official Add-it implementation is released!
|
143 |
+
|
144 |
+
---
|
145 |
+
|
146 |
+
## 📝 TODO
|
147 |
+
|
148 |
+
- [x] Release code
|
149 |
+
|
150 |
+
---
|
151 |
+
|
152 |
+
## 📚 Citation
|
153 |
+
|
154 |
+
If you make use of our work, please cite our paper:
|
155 |
+
|
156 |
+
```bibtex
|
157 |
+
@misc{tewel2024addit,
|
158 |
+
title={Add-it: Training-Free Object Insertion in Images With Pretrained Diffusion Models},
|
159 |
+
author={Yoad Tewel and Rinon Gal and Dvir Samuel and Yuval Atzmon and Lior Wolf and Gal Chechik},
|
160 |
+
year={2024},
|
161 |
+
eprint={2411.07232},
|
162 |
+
archivePrefix={arXiv},
|
163 |
+
primaryClass={cs.CV}
|
164 |
+
}
|
165 |
+
```
|
166 |
+
|
167 |
---
|
168 |
|
169 |
+
<div align="center">
|
170 |
+
<strong>🌟 Star this repo if you find it useful! 🌟</strong>
|
171 |
+
</div>
|
addit_attention_processors.py
ADDED
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# Copyright (C) 2025 NVIDIA Corporation. All rights reserved.
|
16 |
+
#
|
17 |
+
# This work is licensed under the LICENSE file
|
18 |
+
# located at the root directory.
|
19 |
+
|
20 |
+
from collections import defaultdict
|
21 |
+
from diffusers.models.attention_processor import Attention, apply_rope
|
22 |
+
from typing import Callable, List, Optional, Tuple, Union
|
23 |
+
|
24 |
+
from addit_attention_store import AttentionStore
|
25 |
+
from visualization_utils import show_tensors
|
26 |
+
|
27 |
+
import torch
|
28 |
+
import torch.nn.functional as F
|
29 |
+
import numpy as np
|
30 |
+
from scipy.optimize import brentq
|
31 |
+
|
32 |
+
def apply_standard_attention(query, key, value, attn, attention_probs=None):
|
33 |
+
batch_size, attn_heads, _, head_dim = query.shape
|
34 |
+
|
35 |
+
# Do normal attention, to cache the attention scores
|
36 |
+
query = query.reshape(batch_size*attn_heads, -1, head_dim)
|
37 |
+
key = key.reshape(batch_size*attn_heads, -1, head_dim)
|
38 |
+
value = value.reshape(batch_size*attn_heads, -1, head_dim)
|
39 |
+
|
40 |
+
if attention_probs is None:
|
41 |
+
attention_probs = attn.get_attention_scores(query, key)
|
42 |
+
|
43 |
+
hidden_states = torch.bmm(attention_probs, value)
|
44 |
+
hidden_states = hidden_states.view(batch_size, attn_heads, -1, head_dim)
|
45 |
+
|
46 |
+
return hidden_states, attention_probs
|
47 |
+
|
48 |
+
def apply_extended_attention(query, key, value, attention_store, attn, layer_name, step_index, extend_type="pixels",
|
49 |
+
extended_scale=1., record_attention=False):
|
50 |
+
batch_size = query.size(0)
|
51 |
+
extend_query = query[1:]
|
52 |
+
|
53 |
+
if extend_type == "full":
|
54 |
+
added_key = key[0] * extended_scale
|
55 |
+
added_value = value[0]
|
56 |
+
elif extend_type == "text":
|
57 |
+
added_key = key[0, :, :512] * extended_scale
|
58 |
+
added_value = value[0, :, :512]
|
59 |
+
elif extend_type == "pixels":
|
60 |
+
added_key = key[0, :, 512:]
|
61 |
+
added_value = value[0, :, 512:]
|
62 |
+
|
63 |
+
key[1] = key[1] * extended_scale
|
64 |
+
|
65 |
+
extend_key = torch.cat([added_key, key[1]], dim=1).unsqueeze(0)
|
66 |
+
extend_value = torch.cat([added_value, value[1]], dim=1).unsqueeze(0)
|
67 |
+
|
68 |
+
hidden_states_0 = F.scaled_dot_product_attention(query[:1], key[:1], value[:1], dropout_p=0.0, is_causal=False)
|
69 |
+
|
70 |
+
if record_attention or attention_store.is_cache_attn_ratio(step_index):
|
71 |
+
hidden_states_1, attention_probs_1 = apply_standard_attention(extend_query, extend_key, extend_value, attn)
|
72 |
+
else:
|
73 |
+
hidden_states_1 = F.scaled_dot_product_attention(extend_query, extend_key, extend_value, dropout_p=0.0, is_causal=False)
|
74 |
+
|
75 |
+
if record_attention:
|
76 |
+
# Store Attention
|
77 |
+
seq_len = attention_probs_1.size(2) - attention_probs_1.size(1)
|
78 |
+
self_attention_probs_1 = attention_probs_1[:,:,seq_len:]
|
79 |
+
attention_store.store_attention(self_attention_probs_1, layer_name, 1, attn.heads)
|
80 |
+
|
81 |
+
if attention_store.is_cache_attn_ratio(step_index):
|
82 |
+
attention_store.store_attention_ratios(attention_probs_1, step_index, layer_name)
|
83 |
+
|
84 |
+
hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0)
|
85 |
+
|
86 |
+
return hidden_states
|
87 |
+
|
88 |
+
def apply_attention(query, key, value, attention_store, attn, layer_name, step_index,
|
89 |
+
record_attention, extended_attention, extended_scale):
|
90 |
+
if extended_attention:
|
91 |
+
hidden_states = apply_extended_attention(query, key, value, attention_store, attn, layer_name, step_index,
|
92 |
+
extended_scale=extended_scale,
|
93 |
+
record_attention=record_attention)
|
94 |
+
else:
|
95 |
+
if record_attention:
|
96 |
+
hidden_states_0 = F.scaled_dot_product_attention(query[:1], key[:1], value[:1], dropout_p=0.0, is_causal=False)
|
97 |
+
hidden_states_1, attention_probs_1 = apply_standard_attention(query[1:], key[1:], value[1:], attn)
|
98 |
+
attention_store.store_attention(attention_probs_1, layer_name, 1, attn.heads)
|
99 |
+
|
100 |
+
hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0)
|
101 |
+
else:
|
102 |
+
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
103 |
+
|
104 |
+
return hidden_states
|
105 |
+
|
106 |
+
class AdditFluxAttnProcessor2_0:
|
107 |
+
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
108 |
+
|
109 |
+
def __init__(self, layer_name: str, attention_store: AttentionStore,
|
110 |
+
extended_steps: Tuple[int, int] = (0, 30), **kwargs):
|
111 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
112 |
+
raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
113 |
+
|
114 |
+
self.layer_name = layer_name
|
115 |
+
self.layer_idx = int(layer_name.split(".")[-1])
|
116 |
+
self.attention_store = attention_store
|
117 |
+
|
118 |
+
self.extended_steps = (0, extended_steps) if isinstance(extended_steps, int) else extended_steps
|
119 |
+
|
120 |
+
def __call__(
|
121 |
+
self,
|
122 |
+
attn: Attention,
|
123 |
+
hidden_states: torch.FloatTensor,
|
124 |
+
encoder_hidden_states: torch.FloatTensor = None,
|
125 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
126 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
127 |
+
|
128 |
+
step_index: Optional[int] = None,
|
129 |
+
extended_scale: Optional[float] = 1.0,
|
130 |
+
) -> torch.FloatTensor:
|
131 |
+
input_ndim = hidden_states.ndim
|
132 |
+
if input_ndim == 4:
|
133 |
+
batch_size, channel, height, width = hidden_states.shape
|
134 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
135 |
+
context_input_ndim = encoder_hidden_states.ndim
|
136 |
+
if context_input_ndim == 4:
|
137 |
+
batch_size, channel, height, width = encoder_hidden_states.shape
|
138 |
+
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
139 |
+
|
140 |
+
batch_size = encoder_hidden_states.shape[0]
|
141 |
+
|
142 |
+
# `sample` projections.
|
143 |
+
query = attn.to_q(hidden_states)
|
144 |
+
key = attn.to_k(hidden_states)
|
145 |
+
value = attn.to_v(hidden_states)
|
146 |
+
|
147 |
+
inner_dim = key.shape[-1]
|
148 |
+
head_dim = inner_dim // attn.heads
|
149 |
+
|
150 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
151 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
152 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
153 |
+
|
154 |
+
if attn.norm_q is not None:
|
155 |
+
query = attn.norm_q(query)
|
156 |
+
if attn.norm_k is not None:
|
157 |
+
key = attn.norm_k(key)
|
158 |
+
|
159 |
+
# `context` projections.
|
160 |
+
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
161 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
162 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
163 |
+
|
164 |
+
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
165 |
+
batch_size, -1, attn.heads, head_dim
|
166 |
+
).transpose(1, 2)
|
167 |
+
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
168 |
+
batch_size, -1, attn.heads, head_dim
|
169 |
+
).transpose(1, 2)
|
170 |
+
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
171 |
+
batch_size, -1, attn.heads, head_dim
|
172 |
+
).transpose(1, 2)
|
173 |
+
|
174 |
+
if attn.norm_added_q is not None:
|
175 |
+
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
176 |
+
if attn.norm_added_k is not None:
|
177 |
+
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
178 |
+
|
179 |
+
# attention
|
180 |
+
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
|
181 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
182 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
183 |
+
|
184 |
+
if image_rotary_emb is not None:
|
185 |
+
# YiYi to-do: update uising apply_rotary_emb
|
186 |
+
# from ..embeddings import apply_rotary_emb
|
187 |
+
# query = apply_rotary_emb(query, image_rotary_emb)
|
188 |
+
# key = apply_rotary_emb(key, image_rotary_emb)
|
189 |
+
query, key = apply_rope(query, key, image_rotary_emb)
|
190 |
+
|
191 |
+
record_attention = self.attention_store.is_record_attention(self.layer_name, step_index)
|
192 |
+
extend_start, extend_end = self.extended_steps
|
193 |
+
extended_attention = extend_start <= step_index <= extend_end
|
194 |
+
|
195 |
+
hidden_states = apply_attention(query, key, value, self.attention_store, attn, self.layer_name, step_index,
|
196 |
+
record_attention, extended_attention, extended_scale)
|
197 |
+
|
198 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
199 |
+
hidden_states = hidden_states.to(query.dtype)
|
200 |
+
|
201 |
+
encoder_hidden_states, hidden_states = (
|
202 |
+
hidden_states[:, : encoder_hidden_states.shape[1]],
|
203 |
+
hidden_states[:, encoder_hidden_states.shape[1] :],
|
204 |
+
)
|
205 |
+
|
206 |
+
# linear proj
|
207 |
+
hidden_states = attn.to_out[0](hidden_states)
|
208 |
+
# dropout
|
209 |
+
hidden_states = attn.to_out[1](hidden_states)
|
210 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
211 |
+
|
212 |
+
if input_ndim == 4:
|
213 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
214 |
+
if context_input_ndim == 4:
|
215 |
+
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
216 |
+
|
217 |
+
return hidden_states, encoder_hidden_states
|
218 |
+
|
219 |
+
class AdditFluxSingleAttnProcessor2_0:
|
220 |
+
r"""
|
221 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
222 |
+
"""
|
223 |
+
|
224 |
+
def __init__(self, layer_name: str, attention_store: AttentionStore,
|
225 |
+
extended_steps: Tuple[int, int] = (0, 30), **kwargs):
|
226 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
227 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
228 |
+
|
229 |
+
self.layer_name = layer_name
|
230 |
+
self.layer_idx = int(layer_name.split(".")[-1])
|
231 |
+
self.attention_store = attention_store
|
232 |
+
|
233 |
+
self.extended_steps = (0, extended_steps) if isinstance(extended_steps, int) else extended_steps
|
234 |
+
|
235 |
+
def __call__(
|
236 |
+
self,
|
237 |
+
attn: Attention,
|
238 |
+
hidden_states: torch.Tensor,
|
239 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
240 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
241 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
242 |
+
step_index: Optional[int] = None,
|
243 |
+
extended_scale: Optional[float] = 1.0,
|
244 |
+
) -> torch.Tensor:
|
245 |
+
input_ndim = hidden_states.ndim
|
246 |
+
|
247 |
+
if input_ndim == 4:
|
248 |
+
batch_size, channel, height, width = hidden_states.shape
|
249 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
250 |
+
|
251 |
+
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
252 |
+
|
253 |
+
query = attn.to_q(hidden_states)
|
254 |
+
if encoder_hidden_states is None:
|
255 |
+
encoder_hidden_states = hidden_states
|
256 |
+
|
257 |
+
key = attn.to_k(encoder_hidden_states)
|
258 |
+
value = attn.to_v(encoder_hidden_states)
|
259 |
+
|
260 |
+
inner_dim = key.shape[-1]
|
261 |
+
head_dim = inner_dim // attn.heads
|
262 |
+
|
263 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
264 |
+
|
265 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
266 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
267 |
+
|
268 |
+
if attn.norm_q is not None:
|
269 |
+
query = attn.norm_q(query)
|
270 |
+
if attn.norm_k is not None:
|
271 |
+
key = attn.norm_k(key)
|
272 |
+
|
273 |
+
# Apply RoPE if needed
|
274 |
+
if image_rotary_emb is not None:
|
275 |
+
# YiYi to-do: update uising apply_rotary_emb
|
276 |
+
# from ..embeddings import apply_rotary_emb
|
277 |
+
# query = apply_rotary_emb(query, image_rotary_emb)
|
278 |
+
# key = apply_rotary_emb(key, image_rotary_emb)
|
279 |
+
query, key = apply_rope(query, key, image_rotary_emb)
|
280 |
+
|
281 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
282 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
283 |
+
|
284 |
+
record_attention = self.attention_store.is_record_attention(self.layer_name, step_index)
|
285 |
+
extend_start, extend_end = self.extended_steps
|
286 |
+
extended_attention = extend_start <= step_index <= extend_end
|
287 |
+
|
288 |
+
hidden_states = apply_attention(query, key, value, self.attention_store, attn, self.layer_name, step_index,
|
289 |
+
record_attention, extended_attention, extended_scale)
|
290 |
+
|
291 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
292 |
+
hidden_states = hidden_states.to(query.dtype)
|
293 |
+
|
294 |
+
if input_ndim == 4:
|
295 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
296 |
+
|
297 |
+
return hidden_states
|
addit_attention_store.py
ADDED
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2025 NVIDIA Corporation. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is licensed under the LICENSE file
|
4 |
+
# located at the root directory.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from skimage import filters
|
8 |
+
import cv2
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from skimage.filters import threshold_li, threshold_yen, threshold_multiotsu
|
11 |
+
import numpy as np
|
12 |
+
from visualization_utils import show_tensors
|
13 |
+
import matplotlib.pyplot as plt
|
14 |
+
|
15 |
+
def text_to_tokens(text, tokenizer):
|
16 |
+
return [tokenizer.decode(x) for x in tokenizer(text, padding="longest", return_tensors="pt").input_ids[0]]
|
17 |
+
|
18 |
+
def flatten_list(l):
|
19 |
+
return [item for sublist in l for item in sublist]
|
20 |
+
|
21 |
+
def gaussian_blur(heatmap, kernel_size=7, sigma=0):
|
22 |
+
# Shape of heatmap: (H, W)
|
23 |
+
heatmap = heatmap.cpu().numpy()
|
24 |
+
heatmap = cv2.GaussianBlur(heatmap, (kernel_size, kernel_size), sigma)
|
25 |
+
heatmap = torch.tensor(heatmap)
|
26 |
+
|
27 |
+
return heatmap
|
28 |
+
|
29 |
+
def min_max_norm(x):
|
30 |
+
return (x - x.min()) / (x.max() - x.min())
|
31 |
+
|
32 |
+
class AttentionStore:
|
33 |
+
def __init__(self, prompts, tokenizer,
|
34 |
+
subject_token=None, record_attention_steps=[],
|
35 |
+
is_cache_attn_ratio=False, attn_ratios_steps=[5]):
|
36 |
+
|
37 |
+
self.text2image_store = {}
|
38 |
+
self.image2text_store = {}
|
39 |
+
self.count_per_layer = {}
|
40 |
+
|
41 |
+
self.record_attention_steps = record_attention_steps
|
42 |
+
self.record_attention_layers = ["transformer_blocks.13","transformer_blocks.14", "transformer_blocks.18", "single_transformer_blocks.23", "single_transformer_blocks.33"]
|
43 |
+
|
44 |
+
self.attention_ratios = {}
|
45 |
+
self._is_cache_attn_ratio = is_cache_attn_ratio
|
46 |
+
self.attn_ratios_steps = attn_ratios_steps
|
47 |
+
self.ratio_source = 'text'
|
48 |
+
|
49 |
+
self.max_tokens_to_record = 10
|
50 |
+
|
51 |
+
if isinstance(prompts, str):
|
52 |
+
prompts = [prompts]
|
53 |
+
batch_size = 1
|
54 |
+
else:
|
55 |
+
batch_size = len(prompts)
|
56 |
+
|
57 |
+
tokens_per_prompt = []
|
58 |
+
|
59 |
+
for prompt in prompts:
|
60 |
+
tokens = text_to_tokens(prompt, tokenizer)
|
61 |
+
tokens_per_prompt.append(tokens)
|
62 |
+
|
63 |
+
self.tokens_to_record = []
|
64 |
+
self.token_idxs_to_record = []
|
65 |
+
|
66 |
+
if len(record_attention_steps) > 0:
|
67 |
+
self.subject_tokens = flatten_list([text_to_tokens(x, tokenizer)[:-1] for x in [subject_token]])
|
68 |
+
self.subject_tokens_idx = [tokens_per_prompt[1].index(x) for x in self.subject_tokens]
|
69 |
+
self.add_token_idx = self.subject_tokens_idx[-1]
|
70 |
+
|
71 |
+
def is_record_attention(self, layer_name, step_index):
|
72 |
+
is_correct_layer = (self.record_attention_layers is None) or (layer_name in self.record_attention_layers)
|
73 |
+
|
74 |
+
record_attention = (step_index in self.record_attention_steps) and (is_correct_layer)
|
75 |
+
|
76 |
+
return record_attention
|
77 |
+
|
78 |
+
def store_attention(self, attention_probs, layer_name, batch_size, num_heads):
|
79 |
+
text_len = 512
|
80 |
+
timesteps = len(self.record_attention_steps)
|
81 |
+
|
82 |
+
# Split batch and heads
|
83 |
+
attention_probs = attention_probs.view(batch_size, num_heads, *attention_probs.shape[1:])
|
84 |
+
|
85 |
+
# Mean over the heads
|
86 |
+
attention_probs = attention_probs.mean(dim=1)
|
87 |
+
|
88 |
+
# Attention: text -> image
|
89 |
+
attention_probs_text2image = attention_probs[:, :text_len, text_len:]
|
90 |
+
attention_probs_text2image = [attention_probs_text2image[0, self.subject_tokens_idx, :]]
|
91 |
+
|
92 |
+
# Attention: image -> text
|
93 |
+
attention_probs_image2text = attention_probs[:, text_len:, :text_len].transpose(1,2)
|
94 |
+
attention_probs_image2text = [attention_probs_image2text[0, self.subject_tokens_idx, :]]
|
95 |
+
|
96 |
+
if layer_name not in self.text2image_store:
|
97 |
+
self.text2image_store[layer_name] = [x for x in attention_probs_text2image]
|
98 |
+
self.image2text_store[layer_name] = [x for x in attention_probs_image2text]
|
99 |
+
else:
|
100 |
+
self.text2image_store[layer_name] = [self.text2image_store[layer_name][i] + x for i, x in enumerate(attention_probs_text2image)]
|
101 |
+
self.image2text_store[layer_name] = [self.text2image_store[layer_name][i] + x for i, x in enumerate(attention_probs_image2text)]
|
102 |
+
|
103 |
+
def is_cache_attn_ratio(self, step_index):
|
104 |
+
return (self._is_cache_attn_ratio) and (step_index in self.attn_ratios_steps)
|
105 |
+
|
106 |
+
def store_attention_ratios(self, attention_probs, step_index, layer_name):
|
107 |
+
layer_prefix = layer_name.split(".")[0]
|
108 |
+
|
109 |
+
if self.ratio_source == 'pixels':
|
110 |
+
extended_attention_probs = attention_probs.mean(dim=0)[512:, :]
|
111 |
+
extended_attention_probs_source = extended_attention_probs[:,:4096].sum(dim=1).view(64,64).float().cpu()
|
112 |
+
extended_attention_probs_text = extended_attention_probs[:,4096:4096+512].sum(dim=1).view(64,64).float().cpu()
|
113 |
+
extended_attention_probs_target = extended_attention_probs[:,4096+512:].sum(dim=1).view(64,64).float().cpu()
|
114 |
+
token_attention = extended_attention_probs[:,4096+self.add_token_idx].view(64,64).float().cpu()
|
115 |
+
|
116 |
+
stacked_attention_ratios = torch.cat([extended_attention_probs_source, extended_attention_probs_text, extended_attention_probs_target, token_attention], dim=1)
|
117 |
+
elif self.ratio_source == 'text':
|
118 |
+
extended_attention_probs = attention_probs.mean(dim=0)[:512, :]
|
119 |
+
extended_attention_probs_source = extended_attention_probs[:,:4096].sum(dim=0).view(64,64).float().cpu()
|
120 |
+
extended_attention_probs_target = extended_attention_probs[:,4096+512:].sum(dim=0).view(64,64).float().cpu()
|
121 |
+
|
122 |
+
stacked_attention_ratios = torch.cat([extended_attention_probs_source, extended_attention_probs_target], dim=1)
|
123 |
+
|
124 |
+
if step_index not in self.attention_ratios:
|
125 |
+
self.attention_ratios[step_index] = {}
|
126 |
+
|
127 |
+
if layer_prefix not in self.attention_ratios[step_index]:
|
128 |
+
self.attention_ratios[step_index][layer_prefix] = []
|
129 |
+
|
130 |
+
self.attention_ratios[step_index][layer_prefix].append(stacked_attention_ratios)
|
131 |
+
|
132 |
+
def get_attention_ratios(self, step_indices=None, display_imgs=False):
|
133 |
+
ratios = []
|
134 |
+
|
135 |
+
if step_indices is None:
|
136 |
+
step_indices = list(self.attention_ratios.keys())
|
137 |
+
|
138 |
+
if len(step_indices) == 1:
|
139 |
+
steps = f"Step: {step_indices[0]}"
|
140 |
+
else:
|
141 |
+
steps = f"Steps: [{step_indices[0]}-{step_indices[-1]}]"
|
142 |
+
|
143 |
+
layer_prefixes = list(self.attention_ratios[step_indices[0]].keys())
|
144 |
+
scores_per_layer = {}
|
145 |
+
|
146 |
+
for layer_prefix in layer_prefixes:
|
147 |
+
ratios = []
|
148 |
+
|
149 |
+
for step_index in step_indices:
|
150 |
+
if layer_prefix in self.attention_ratios[step_index]:
|
151 |
+
step_ratios = self.attention_ratios[step_index][layer_prefix]
|
152 |
+
step_ratios = torch.stack(step_ratios).mean(dim=0)
|
153 |
+
ratios.append(step_ratios)
|
154 |
+
|
155 |
+
# Mean over the steps
|
156 |
+
ratios = torch.stack(ratios).mean(dim=0)
|
157 |
+
|
158 |
+
if self.ratio_source == 'pixels':
|
159 |
+
source, text, target, token = torch.split(ratios, 64, dim=1)
|
160 |
+
title = f"{steps}: Source={source.sum().item():.2f}, Text={text.sum().item():.2f}, Target={target.sum().item():.2f}, Token={token.sum().item():.2f}"
|
161 |
+
ratios = min_max_norm(torch.cat([source, text, target], dim=1))
|
162 |
+
token = min_max_norm(token)
|
163 |
+
ratios = torch.cat([ratios, token], dim=1)
|
164 |
+
elif self.ratio_source == 'text':
|
165 |
+
source, target = torch.split(ratios, 64, dim=1)
|
166 |
+
source_sum = source.sum().item()
|
167 |
+
target_sum = target.sum().item()
|
168 |
+
text_sum = 512 - (source_sum + target_sum)
|
169 |
+
|
170 |
+
title = f"{steps}: Source={source_sum:.2f}, Target={target_sum:.2f}"
|
171 |
+
ratios = min_max_norm(torch.cat([source, target], dim=1))
|
172 |
+
|
173 |
+
if display_imgs:
|
174 |
+
print(f"Layer: {layer_prefix}")
|
175 |
+
show_tensors([ratios], [title])
|
176 |
+
|
177 |
+
scores_per_layer[layer_prefix] = (source_sum, text_sum, target_sum)
|
178 |
+
|
179 |
+
return scores_per_layer
|
180 |
+
|
181 |
+
def plot_attention_ratios(self, step_indices=None):
|
182 |
+
steps = list(self.attention_ratios.keys())
|
183 |
+
score_per_layer = {
|
184 |
+
'transformer_blocks': {},
|
185 |
+
'single_transformer_blocks': {}
|
186 |
+
}
|
187 |
+
|
188 |
+
for i in steps:
|
189 |
+
scores_per_layer = self.get_attention_ratios(step_indices=[i], display_imgs=False)
|
190 |
+
|
191 |
+
for layer in self.attention_ratios[i]:
|
192 |
+
source, text, target = scores_per_layer[layer]
|
193 |
+
score_per_layer[layer][i] = (source, text, target)
|
194 |
+
|
195 |
+
for layer_type in score_per_layer:
|
196 |
+
x = list(score_per_layer[layer_type].keys())
|
197 |
+
source_sums = [x[0] for x in score_per_layer[layer_type].values()]
|
198 |
+
text_sums = [x[1] for x in score_per_layer[layer_type].values()]
|
199 |
+
target_sums = [x[2] for x in score_per_layer[layer_type].values()]
|
200 |
+
|
201 |
+
# Calculate the total sums for each stack (source + text + target)
|
202 |
+
total_sums = [source_sums[j] + text_sums[j] + target_sums[j] for j in range(len(source_sums))]
|
203 |
+
|
204 |
+
# Create stacked bar plots
|
205 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
206 |
+
indices = np.arange(len(x))
|
207 |
+
|
208 |
+
# Plot source at the bottom
|
209 |
+
ax.bar(indices, source_sums, label='Source', color='#6A2C70')
|
210 |
+
|
211 |
+
# Plot text stacked on source
|
212 |
+
ax.bar(indices, text_sums, label='Text', color='#B83B5E', bottom=source_sums)
|
213 |
+
|
214 |
+
# Plot target stacked on text + source
|
215 |
+
target_bottom = [source_sums[j] + text_sums[j] for j in range(len(source_sums))]
|
216 |
+
ax.bar(indices, target_sums, label='Target', color='#F08A5D', bottom=target_bottom)
|
217 |
+
|
218 |
+
# Annotate bars with percentage values
|
219 |
+
for j, index in enumerate(indices):
|
220 |
+
|
221 |
+
font_size = 12
|
222 |
+
|
223 |
+
# Source percentage
|
224 |
+
source_percentage = 100 * source_sums[j] / total_sums[j]
|
225 |
+
ax.text(index, source_sums[j] / 2, f'{source_percentage:.1f}%',
|
226 |
+
ha='center', va='center', rotation=90, color='white',
|
227 |
+
fontsize=font_size, fontweight='bold')
|
228 |
+
|
229 |
+
# Text percentage
|
230 |
+
text_percentage = 100 * text_sums[j] / total_sums[j]
|
231 |
+
ax.text(index, source_sums[j] + (text_sums[j] / 2), f'{text_percentage:.1f}%',
|
232 |
+
ha='center', va='center', rotation=90, color='white',
|
233 |
+
fontsize=font_size, fontweight='bold')
|
234 |
+
|
235 |
+
# Target percentage
|
236 |
+
target_percentage = 100 * target_sums[j] / total_sums[j]
|
237 |
+
ax.text(index, source_sums[j] + text_sums[j] + (target_sums[j] / 2), f'{target_percentage:.1f}%',
|
238 |
+
ha='center', va='center', rotation=90, color='white',
|
239 |
+
fontsize=font_size, fontweight='bold')
|
240 |
+
|
241 |
+
|
242 |
+
ax.set_xlabel('Step Index')
|
243 |
+
ax.set_ylabel('Attention Ratio')
|
244 |
+
ax.set_title(f'Attention Ratios for {layer_type}')
|
245 |
+
ax.set_xticks(indices)
|
246 |
+
ax.set_xticklabels(x)
|
247 |
+
|
248 |
+
plt.legend()
|
249 |
+
plt.show()
|
250 |
+
|
251 |
+
def aggregate_attention(self, store, target_layers=None, resolution=None,
|
252 |
+
gaussian_kernel=3, thr_type='otsu', thr_number=0.5):
|
253 |
+
if target_layers is None:
|
254 |
+
store_vals = list(store.values())
|
255 |
+
elif isinstance(target_layers, list):
|
256 |
+
store_vals = [store[x] for x in target_layers]
|
257 |
+
else:
|
258 |
+
raise ValueError("target_layers must be a list of layer names or None.")
|
259 |
+
|
260 |
+
# store vals = List[layers] of Tensor[batch_size, text_tokens, image_tokens]
|
261 |
+
batch_size = len(store_vals[0])
|
262 |
+
|
263 |
+
attention_maps = []
|
264 |
+
attention_masks = []
|
265 |
+
|
266 |
+
for i in range(batch_size):
|
267 |
+
# Average over the layers
|
268 |
+
agg_vals = torch.stack([x[i] for x in store_vals]).mean(dim=0)
|
269 |
+
|
270 |
+
if resolution is None:
|
271 |
+
size = int(agg_vals.shape[-1] ** 0.5)
|
272 |
+
resolution = (size, size)
|
273 |
+
|
274 |
+
agg_vals = agg_vals.view(agg_vals.shape[0], *resolution)
|
275 |
+
|
276 |
+
if gaussian_kernel > 0:
|
277 |
+
agg_vals = torch.stack([gaussian_blur(x.float(), kernel_size=gaussian_kernel) for x in agg_vals]).to(agg_vals.dtype)
|
278 |
+
|
279 |
+
mask_vals = agg_vals.clone()
|
280 |
+
|
281 |
+
for j in range(mask_vals.shape[0]):
|
282 |
+
mask_vals[j] = (mask_vals[j] - mask_vals[j].min()) / (mask_vals[j].max() - mask_vals[j].min())
|
283 |
+
np_vals = mask_vals[j].float().cpu().numpy()
|
284 |
+
|
285 |
+
otsu_thr = filters.threshold_otsu(np_vals)
|
286 |
+
li_thr = threshold_li(np_vals, initial_guess=otsu_thr)
|
287 |
+
yen_thr = threshold_yen(np_vals)
|
288 |
+
|
289 |
+
if thr_type == 'otsu':
|
290 |
+
thr = otsu_thr
|
291 |
+
elif thr_type == 'yen':
|
292 |
+
thr = yen_thr
|
293 |
+
elif thr_type == 'li':
|
294 |
+
thr = li_thr
|
295 |
+
elif thr_type == 'number':
|
296 |
+
thr = thr_number
|
297 |
+
elif thr_type == 'multiotsu':
|
298 |
+
thrs = threshold_multiotsu(np_vals, classes=3)
|
299 |
+
|
300 |
+
if thrs[1] > thrs[0] * 3.5:
|
301 |
+
thr = thrs[1]
|
302 |
+
else:
|
303 |
+
thr = thrs[0]
|
304 |
+
|
305 |
+
# Take the closest threshold to otsu_thr
|
306 |
+
# thr = thrs[np.argmin(np.abs(thrs - otsu_thr))]
|
307 |
+
|
308 |
+
# alpha = 0.8
|
309 |
+
# thr = (alpha * thr + (1-alpha) * mask_vals[j].max())
|
310 |
+
|
311 |
+
mask_vals[j] = (mask_vals[j] > thr).to(mask_vals[j].dtype)
|
312 |
+
|
313 |
+
attention_maps.append(agg_vals)
|
314 |
+
attention_masks.append(mask_vals)
|
315 |
+
|
316 |
+
return attention_maps, attention_masks, self.tokens_to_record
|
addit_blending_utils.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2025 NVIDIA Corporation. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is licensed under the LICENSE file
|
4 |
+
# located at the root directory.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from skimage import filters
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
from scipy.ndimage import maximum_filter, label, find_objects
|
12 |
+
|
13 |
+
def dilate_mask(latents_mask, k, latents_dtype):
|
14 |
+
# Reshape the mask to 2D (64x64)
|
15 |
+
mask_2d = latents_mask.view(64, 64)
|
16 |
+
|
17 |
+
# Create a square kernel for dilation
|
18 |
+
kernel = torch.ones(2*k+1, 2*k+1, device=mask_2d.device, dtype=mask_2d.dtype)
|
19 |
+
|
20 |
+
# Add two dimensions to make it compatible with conv2d
|
21 |
+
mask_4d = mask_2d.unsqueeze(0).unsqueeze(0)
|
22 |
+
|
23 |
+
# Perform dilation using conv2d
|
24 |
+
dilated_mask = F.conv2d(mask_4d, kernel.unsqueeze(0).unsqueeze(0), padding=k)
|
25 |
+
|
26 |
+
# Threshold the result to get a binary mask
|
27 |
+
dilated_mask = (dilated_mask > 0).to(mask_2d.dtype)
|
28 |
+
|
29 |
+
# Reshape back to the original shape and convert to the desired dtype
|
30 |
+
dilated_mask = dilated_mask.view(4096, 1).to(latents_dtype)
|
31 |
+
|
32 |
+
return dilated_mask
|
33 |
+
|
34 |
+
def clipseg_predict(model, processor, image, text, device):
|
35 |
+
inputs = processor(text=text, images=image, return_tensors="pt")
|
36 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
37 |
+
|
38 |
+
with torch.no_grad():
|
39 |
+
outputs = model(**inputs)
|
40 |
+
preds = outputs.logits.unsqueeze(1)
|
41 |
+
preds = torch.sigmoid(preds)
|
42 |
+
|
43 |
+
otsu_thr = filters.threshold_otsu(preds.cpu().numpy())
|
44 |
+
subject_mask = (preds > otsu_thr).float()
|
45 |
+
|
46 |
+
return subject_mask
|
47 |
+
|
48 |
+
def grounding_sam_predict(model, processor, sam_predictor, image, text, device):
|
49 |
+
inputs = processor(images=image, text=text, return_tensors="pt").to(device)
|
50 |
+
with torch.no_grad():
|
51 |
+
outputs = model(**inputs)
|
52 |
+
|
53 |
+
results = processor.post_process_grounded_object_detection(
|
54 |
+
outputs,
|
55 |
+
inputs.input_ids,
|
56 |
+
box_threshold=0.4,
|
57 |
+
text_threshold=0.3,
|
58 |
+
target_sizes=[image.size[::-1]]
|
59 |
+
)
|
60 |
+
|
61 |
+
input_boxes = results[0]["boxes"].cpu().numpy()
|
62 |
+
|
63 |
+
if input_boxes.shape[0] == 0:
|
64 |
+
return torch.ones((64, 64), device=device)
|
65 |
+
|
66 |
+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
67 |
+
sam_predictor.set_image(image)
|
68 |
+
masks, scores, logits = sam_predictor.predict(
|
69 |
+
point_coords=None,
|
70 |
+
point_labels=None,
|
71 |
+
box=input_boxes,
|
72 |
+
multimask_output=False,
|
73 |
+
)
|
74 |
+
|
75 |
+
subject_mask = torch.tensor(masks[0], device=device)
|
76 |
+
|
77 |
+
return subject_mask
|
78 |
+
|
79 |
+
def mask_to_box_sam_predict(mask, sam_predictor, image, text, device):
|
80 |
+
H, W = image.size
|
81 |
+
|
82 |
+
# Resize clipseg mask to image size
|
83 |
+
mask = F.interpolate(mask.view(1, 1, mask.shape[-2], mask.shape[-1]), size=(H, W), mode='bilinear').view(H, W)
|
84 |
+
mask_indices = torch.nonzero(mask)
|
85 |
+
top_left = mask_indices.min(dim=0)[0]
|
86 |
+
bottom_right = mask_indices.max(dim=0)[0]
|
87 |
+
|
88 |
+
# numpy shape [1,4]
|
89 |
+
input_boxes = np.array([[top_left[1].item(), top_left[0].item(), bottom_right[1].item(), bottom_right[0].item()]])
|
90 |
+
|
91 |
+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
92 |
+
sam_predictor.set_image(image)
|
93 |
+
masks, scores, logits = sam_predictor.predict(
|
94 |
+
point_coords=None,
|
95 |
+
point_labels=None,
|
96 |
+
box=input_boxes,
|
97 |
+
multimask_output=True,
|
98 |
+
)
|
99 |
+
|
100 |
+
# subject_mask = torch.tensor(masks[0], device=device)
|
101 |
+
subject_mask = torch.tensor(np.max(masks, axis=0), device=device)
|
102 |
+
|
103 |
+
return subject_mask, input_boxes[0]
|
104 |
+
|
105 |
+
def mask_to_mask_sam_predict(mask, sam_predictor, image, text, device):
|
106 |
+
H, W = (256, 256)
|
107 |
+
|
108 |
+
# Resize clipseg mask to image size
|
109 |
+
mask = F.interpolate(mask.view(1, 1, mask.shape[-2], mask.shape[-1]), size=(H, W), mode='bilinear').view(1, H, W)
|
110 |
+
mask_input = mask.float().cpu().numpy()
|
111 |
+
|
112 |
+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
113 |
+
sam_predictor.set_image(image)
|
114 |
+
masks, scores, logits = sam_predictor.predict(
|
115 |
+
point_coords=None,
|
116 |
+
point_labels=None,
|
117 |
+
mask_input=mask_input,
|
118 |
+
multimask_output=False,
|
119 |
+
)
|
120 |
+
|
121 |
+
subject_mask = torch.tensor(masks[0], device=device)
|
122 |
+
|
123 |
+
return subject_mask
|
124 |
+
|
125 |
+
def mask_to_points_sam_predict(mask, sam_predictor, image, text, device):
|
126 |
+
H, W = image.size
|
127 |
+
|
128 |
+
# Resize clipseg mask to image size
|
129 |
+
mask = F.interpolate(mask.view(1, 1, mask.shape[-2], mask.shape[-1]), size=(H, W), mode='bilinear').view(H, W)
|
130 |
+
mask_indices = torch.nonzero(mask)
|
131 |
+
|
132 |
+
# Randomly sample 10 points from the mask
|
133 |
+
n_points = 2
|
134 |
+
point_coords = mask_indices[torch.randperm(mask_indices.shape[0])[:n_points]].float().cpu().numpy()
|
135 |
+
point_labels = torch.ones((n_points,)).float().cpu().numpy()
|
136 |
+
|
137 |
+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
138 |
+
sam_predictor.set_image(image)
|
139 |
+
masks, scores, logits = sam_predictor.predict(
|
140 |
+
point_coords=point_coords,
|
141 |
+
point_labels=point_labels,
|
142 |
+
multimask_output=False,
|
143 |
+
)
|
144 |
+
|
145 |
+
subject_mask = torch.tensor(masks[0], device=device)
|
146 |
+
|
147 |
+
return subject_mask
|
148 |
+
|
149 |
+
def attention_to_points_sam_predict(subject_attention, subject_mask, sam_predictor, image, text, device):
|
150 |
+
H, W = image.size
|
151 |
+
|
152 |
+
# Resize clipseg mask to image size
|
153 |
+
subject_attention = F.interpolate(subject_attention.view(1, 1, subject_attention.shape[-2], subject_attention.shape[-1]), size=(H, W), mode='bilinear').view(H, W)
|
154 |
+
subject_mask = F.interpolate(subject_mask.view(1, 1, subject_mask.shape[-2], subject_mask.shape[-1]), size=(H, W), mode='bilinear').view(H, W)
|
155 |
+
|
156 |
+
# Get mask_bbox
|
157 |
+
subject_mask_indices = torch.nonzero(subject_mask)
|
158 |
+
top_left = subject_mask_indices.min(dim=0)[0]
|
159 |
+
bottom_right = subject_mask_indices.max(dim=0)[0]
|
160 |
+
box_width = bottom_right[1] - top_left[1]
|
161 |
+
box_height = bottom_right[0] - top_left[0]
|
162 |
+
|
163 |
+
# Define the number of points and minimum distance between points
|
164 |
+
n_points = 3
|
165 |
+
max_thr = 0.35
|
166 |
+
max_attention = torch.max(subject_attention)
|
167 |
+
min_distance = max(box_width, box_height) // (n_points + 1) # Adjust this value to control spread
|
168 |
+
# min_distance = max(min_distance, 75)
|
169 |
+
|
170 |
+
# Initialize list to store selected points
|
171 |
+
selected_points = []
|
172 |
+
|
173 |
+
# Create a copy of the attention map
|
174 |
+
remaining_attention = subject_attention.clone()
|
175 |
+
|
176 |
+
for _ in range(n_points):
|
177 |
+
if remaining_attention.max() < max_thr * max_attention:
|
178 |
+
break
|
179 |
+
|
180 |
+
# Find the highest attention point
|
181 |
+
point = torch.argmax(remaining_attention)
|
182 |
+
y, x = torch.unravel_index(point, remaining_attention.shape)
|
183 |
+
y, x = y.item(), x.item()
|
184 |
+
|
185 |
+
# Add the point to our list
|
186 |
+
selected_points.append((x, y))
|
187 |
+
|
188 |
+
# Zero out the area around the selected point
|
189 |
+
y_min = max(0, y - min_distance)
|
190 |
+
y_max = min(H, y + min_distance + 1)
|
191 |
+
x_min = max(0, x - min_distance)
|
192 |
+
x_max = min(W, x + min_distance + 1)
|
193 |
+
remaining_attention[y_min:y_max, x_min:x_max] = 0
|
194 |
+
|
195 |
+
# Convert selected points to numpy array
|
196 |
+
point_coords = np.array(selected_points)
|
197 |
+
point_labels = np.ones(point_coords.shape[0], dtype=int)
|
198 |
+
|
199 |
+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
200 |
+
sam_predictor.set_image(image)
|
201 |
+
masks, scores, logits = sam_predictor.predict(
|
202 |
+
point_coords=point_coords,
|
203 |
+
point_labels=point_labels,
|
204 |
+
multimask_output=False,
|
205 |
+
)
|
206 |
+
|
207 |
+
subject_mask = torch.tensor(masks[0], device=device)
|
208 |
+
|
209 |
+
return subject_mask, point_coords
|
210 |
+
|
211 |
+
def sam_refine_step(mask, sam_predictor, image, device):
|
212 |
+
mask_indices = torch.nonzero(mask)
|
213 |
+
top_left = mask_indices.min(dim=0)[0]
|
214 |
+
bottom_right = mask_indices.max(dim=0)[0]
|
215 |
+
|
216 |
+
# numpy shape [1,4]
|
217 |
+
input_boxes = np.array([[top_left[1].item(), top_left[0].item(), bottom_right[1].item(), bottom_right[0].item()]])
|
218 |
+
|
219 |
+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
220 |
+
sam_predictor.set_image(image)
|
221 |
+
masks, scores, logits = sam_predictor.predict(
|
222 |
+
point_coords=None,
|
223 |
+
point_labels=None,
|
224 |
+
box=input_boxes,
|
225 |
+
multimask_output=True,
|
226 |
+
)
|
227 |
+
|
228 |
+
# subject_mask = torch.tensor(masks[0], device=device)
|
229 |
+
subject_mask = torch.tensor(np.max(masks, axis=0), device=device)
|
230 |
+
|
231 |
+
return subject_mask, input_boxes[0]
|
232 |
+
|
addit_flux_pipeline.py
ADDED
@@ -0,0 +1,1389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# Copyright (C) 2025 NVIDIA Corporation. All rights reserved.
|
16 |
+
#
|
17 |
+
# This work is licensed under the LICENSE file
|
18 |
+
# located at the root directory.
|
19 |
+
|
20 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
21 |
+
import torch
|
22 |
+
import numpy as np
|
23 |
+
from PIL import Image
|
24 |
+
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, calculate_shift, retrieve_timesteps
|
25 |
+
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
|
26 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
27 |
+
from diffusers.utils.torch_utils import randn_tensor
|
28 |
+
import matplotlib.pyplot as plt
|
29 |
+
|
30 |
+
import torch.fft
|
31 |
+
import torch.nn.functional as F
|
32 |
+
|
33 |
+
from diffusers.models.attention_processor import FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0
|
34 |
+
from addit_attention_processors import AdditFluxAttnProcessor2_0, AdditFluxSingleAttnProcessor2_0
|
35 |
+
from addit_attention_store import AttentionStore
|
36 |
+
|
37 |
+
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
|
38 |
+
from skimage import filters
|
39 |
+
from visualization_utils import show_image_and_heatmap, show_images, draw_points_on_pil_image, draw_bboxes_on_image
|
40 |
+
from addit_blending_utils import clipseg_predict, grounding_sam_predict, mask_to_box_sam_predict, \
|
41 |
+
mask_to_mask_sam_predict, attention_to_points_sam_predict
|
42 |
+
|
43 |
+
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
44 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
45 |
+
|
46 |
+
from scipy.optimize import brentq
|
47 |
+
from scipy.optimize import root_scalar
|
48 |
+
|
49 |
+
def register_my_attention_processors(transformer, attention_store, extended_steps_multi, extended_steps_single):
|
50 |
+
attn_procs = {}
|
51 |
+
|
52 |
+
for i, (name, processor) in enumerate(transformer.attn_processors.items()):
|
53 |
+
layer_name = ".".join(name.split(".")[:2])
|
54 |
+
|
55 |
+
if layer_name.startswith("transformer_blocks"):
|
56 |
+
attn_procs[name] = AdditFluxAttnProcessor2_0(layer_name=layer_name,
|
57 |
+
attention_store=attention_store,
|
58 |
+
extended_steps=extended_steps_multi)
|
59 |
+
elif layer_name.startswith("single_transformer_blocks"):
|
60 |
+
attn_procs[name] = AdditFluxSingleAttnProcessor2_0(layer_name=layer_name,
|
61 |
+
attention_store=attention_store,
|
62 |
+
extended_steps=extended_steps_single)
|
63 |
+
|
64 |
+
transformer.set_attn_processor(attn_procs)
|
65 |
+
|
66 |
+
def register_regular_attention_processors(transformer):
|
67 |
+
attn_procs = {}
|
68 |
+
|
69 |
+
for i, (name, processor) in enumerate(transformer.attn_processors.items()):
|
70 |
+
layer_name = ".".join(name.split(".")[:2])
|
71 |
+
|
72 |
+
if layer_name.startswith("transformer_blocks"):
|
73 |
+
attn_procs[name] = FluxAttnProcessor2_0()
|
74 |
+
elif layer_name.startswith("single_transformer_blocks"):
|
75 |
+
attn_procs[name] = FluxSingleAttnProcessor2_0()
|
76 |
+
|
77 |
+
transformer.set_attn_processor(attn_procs)
|
78 |
+
|
79 |
+
def img2img_retrieve_latents(
|
80 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
81 |
+
):
|
82 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
83 |
+
return encoder_output.latent_dist.sample(generator)
|
84 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
85 |
+
return encoder_output.latent_dist.mode()
|
86 |
+
elif hasattr(encoder_output, "latents"):
|
87 |
+
return encoder_output.latents
|
88 |
+
else:
|
89 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
90 |
+
|
91 |
+
class AdditFluxPipeline(FluxPipeline):
|
92 |
+
def prepare_latents(
|
93 |
+
self,
|
94 |
+
batch_size,
|
95 |
+
num_channels_latents,
|
96 |
+
height,
|
97 |
+
width,
|
98 |
+
dtype,
|
99 |
+
device,
|
100 |
+
generator,
|
101 |
+
latents=None,
|
102 |
+
):
|
103 |
+
height = 2 * (int(height) // self.vae_scale_factor)
|
104 |
+
width = 2 * (int(width) // self.vae_scale_factor)
|
105 |
+
|
106 |
+
shape = (batch_size, num_channels_latents, height, width)
|
107 |
+
|
108 |
+
if latents is not None:
|
109 |
+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
110 |
+
return latents.to(device=device, dtype=dtype), latent_image_ids
|
111 |
+
|
112 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
113 |
+
raise ValueError(
|
114 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
115 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
116 |
+
)
|
117 |
+
|
118 |
+
if isinstance(generator, list):
|
119 |
+
latents = torch.empty(shape, device=device, dtype=dtype)
|
120 |
+
|
121 |
+
latents_list = [randn_tensor(shape, generator=g, device=device, dtype=dtype) for g in generator]
|
122 |
+
|
123 |
+
for i, l_i in enumerate(latents_list):
|
124 |
+
latents[i] = l_i[i]
|
125 |
+
else:
|
126 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
127 |
+
|
128 |
+
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
129 |
+
|
130 |
+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
131 |
+
|
132 |
+
return latents, latent_image_ids
|
133 |
+
|
134 |
+
@torch.no_grad()
|
135 |
+
def __call__(
|
136 |
+
self,
|
137 |
+
prompt: Union[str, List[str]] = None,
|
138 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
139 |
+
height: Optional[int] = None,
|
140 |
+
width: Optional[int] = None,
|
141 |
+
num_inference_steps: int = 28,
|
142 |
+
timesteps: List[int] = None,
|
143 |
+
guidance_scale: Union[float, List[float]] = 7.0,
|
144 |
+
num_images_per_prompt: Optional[int] = 1,
|
145 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
146 |
+
latents: Optional[torch.FloatTensor] = None,
|
147 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
148 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
149 |
+
output_type: Optional[str] = "pil",
|
150 |
+
return_dict: bool = True,
|
151 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
152 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
153 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
154 |
+
max_sequence_length: int = 512,
|
155 |
+
|
156 |
+
seed: Optional[Union[int, List[int]]] = None,
|
157 |
+
same_latent_for_all_prompts: bool = False,
|
158 |
+
|
159 |
+
# Extended Attention
|
160 |
+
extended_steps_multi: Optional[int] = -1,
|
161 |
+
extended_steps_single: Optional[int] = -1,
|
162 |
+
extended_scale: Optional[Union[float, str]] = 1.0,
|
163 |
+
|
164 |
+
# Structure Transfer
|
165 |
+
source_latents: Optional[torch.FloatTensor] = None,
|
166 |
+
structure_transfer_step: int = 5,
|
167 |
+
|
168 |
+
# Latent Blending
|
169 |
+
subject_token: Optional[str] = None,
|
170 |
+
localization_model: Optional[str] = "attention_points_sam",
|
171 |
+
blend_steps: List[int] = [],
|
172 |
+
show_attention: bool = False,
|
173 |
+
|
174 |
+
# Real Image Source
|
175 |
+
is_img_src: bool = False,
|
176 |
+
use_offset: bool = False,
|
177 |
+
img_src_latents: Optional[List[torch.FloatTensor]] = None,
|
178 |
+
):
|
179 |
+
r"""
|
180 |
+
Function invoked when calling the pipeline for generation.
|
181 |
+
|
182 |
+
Args:
|
183 |
+
prompt (`str` or `List[str]`, *optional*):
|
184 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
185 |
+
instead.
|
186 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
187 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
188 |
+
will be used instead
|
189 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
190 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
191 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
192 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
193 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
194 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
195 |
+
expense of slower inference.
|
196 |
+
timesteps (`List[int]`, *optional*):
|
197 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
198 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
199 |
+
passed will be used. Must be in descending order.
|
200 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
201 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
202 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
203 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
204 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
205 |
+
usually at the expense of lower image quality.
|
206 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
207 |
+
The number of images to generate per prompt.
|
208 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
209 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
210 |
+
to make generation deterministic.
|
211 |
+
latents (`torch.FloatTensor`, *optional*):
|
212 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
213 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
214 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
215 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
216 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
217 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
218 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
219 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
220 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
221 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
222 |
+
The output format of the generate image. Choose between
|
223 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
224 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
225 |
+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
226 |
+
joint_attention_kwargs (`dict`, *optional*):
|
227 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
228 |
+
`self.processor` in
|
229 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
230 |
+
callback_on_step_end (`Callable`, *optional*):
|
231 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
232 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
233 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
234 |
+
`callback_on_step_end_tensor_inputs`.
|
235 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
236 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
237 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
238 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
239 |
+
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
240 |
+
|
241 |
+
Examples:
|
242 |
+
|
243 |
+
Returns:
|
244 |
+
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
245 |
+
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
246 |
+
images.
|
247 |
+
"""
|
248 |
+
|
249 |
+
device = self._execution_device
|
250 |
+
|
251 |
+
# Blend Steps
|
252 |
+
blend_models = {}
|
253 |
+
if len(blend_steps) > 0:
|
254 |
+
if localization_model == "clipseg":
|
255 |
+
blend_models["clipseg_processor"] = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
|
256 |
+
blend_models["clipseg_model"] = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)
|
257 |
+
elif localization_model == "grounding_sam":
|
258 |
+
grounding_dino_model_id = "IDEA-Research/grounding-dino-base"
|
259 |
+
blend_models["grounding_processor"] = AutoProcessor.from_pretrained(grounding_dino_model_id)
|
260 |
+
blend_models["grounding_model"] = AutoModelForZeroShotObjectDetection.from_pretrained(grounding_dino_model_id).to(device)
|
261 |
+
blend_models["sam_predictor"] = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
|
262 |
+
elif localization_model == "clipseg_sam":
|
263 |
+
blend_models["clipseg_processor"] = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
|
264 |
+
blend_models["clipseg_model"] = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)
|
265 |
+
blend_models["sam_predictor"] = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
|
266 |
+
elif localization_model == "attention":
|
267 |
+
pass
|
268 |
+
elif localization_model in ["attention_box_sam", "attention_mask_sam", "attention_points_sam"]:
|
269 |
+
blend_models["sam_predictor"] = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
|
270 |
+
|
271 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
272 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
273 |
+
|
274 |
+
# 1. Check inputs. Raise error if not correct
|
275 |
+
self.check_inputs(
|
276 |
+
prompt,
|
277 |
+
prompt_2,
|
278 |
+
height,
|
279 |
+
width,
|
280 |
+
prompt_embeds=prompt_embeds,
|
281 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
282 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
283 |
+
max_sequence_length=max_sequence_length,
|
284 |
+
)
|
285 |
+
|
286 |
+
self._guidance_scale = guidance_scale
|
287 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
288 |
+
self._interrupt = False
|
289 |
+
|
290 |
+
# 2. Define call parameters
|
291 |
+
if prompt is not None and isinstance(prompt, str):
|
292 |
+
batch_size = 1
|
293 |
+
elif prompt is not None and isinstance(prompt, list):
|
294 |
+
batch_size = len(prompt)
|
295 |
+
else:
|
296 |
+
batch_size = prompt_embeds.shape[0]
|
297 |
+
|
298 |
+
device = self._execution_device
|
299 |
+
|
300 |
+
lora_scale = (
|
301 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
302 |
+
)
|
303 |
+
(
|
304 |
+
prompt_embeds,
|
305 |
+
pooled_prompt_embeds,
|
306 |
+
text_ids,
|
307 |
+
) = self.encode_prompt(
|
308 |
+
prompt=prompt,
|
309 |
+
prompt_2=prompt_2,
|
310 |
+
prompt_embeds=prompt_embeds,
|
311 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
312 |
+
device=device,
|
313 |
+
num_images_per_prompt=num_images_per_prompt,
|
314 |
+
max_sequence_length=max_sequence_length,
|
315 |
+
lora_scale=lora_scale,
|
316 |
+
)
|
317 |
+
|
318 |
+
# 4. Prepare latent variables
|
319 |
+
if (generator is None) and seed is not None:
|
320 |
+
if isinstance(seed, int):
|
321 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
322 |
+
else:
|
323 |
+
assert len(seed) == batch_size, "The number of seeds must match the batch size"
|
324 |
+
generator = [torch.Generator(device=device).manual_seed(s) for s in seed]
|
325 |
+
|
326 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
327 |
+
latents, latent_image_ids = self.prepare_latents(
|
328 |
+
batch_size * num_images_per_prompt,
|
329 |
+
num_channels_latents,
|
330 |
+
height,
|
331 |
+
width,
|
332 |
+
prompt_embeds.dtype,
|
333 |
+
device,
|
334 |
+
generator,
|
335 |
+
latents,
|
336 |
+
)
|
337 |
+
|
338 |
+
if same_latent_for_all_prompts:
|
339 |
+
latents = latents[:1].repeat(batch_size * num_images_per_prompt, 1, 1)
|
340 |
+
|
341 |
+
noise = latents.clone()
|
342 |
+
|
343 |
+
attention_store_kwargs = {}
|
344 |
+
|
345 |
+
if extended_scale == "auto":
|
346 |
+
is_auto_extend_scale = True
|
347 |
+
extended_scale = 1.05
|
348 |
+
attention_store_kwargs["is_cache_attn_ratio"] = True
|
349 |
+
auto_extended_step = 5
|
350 |
+
target_auto_ratio = 1.05
|
351 |
+
else:
|
352 |
+
is_auto_extend_scale = False
|
353 |
+
|
354 |
+
if len(blend_steps) > 0:
|
355 |
+
attn_steps = range(blend_steps[0] - 2, blend_steps[0] + 1)
|
356 |
+
attention_store_kwargs["record_attention_steps"] = attn_steps
|
357 |
+
|
358 |
+
self.attention_store = AttentionStore(prompts=prompt, tokenizer=self.tokenizer_2, subject_token=subject_token, **attention_store_kwargs)
|
359 |
+
register_my_attention_processors(self.transformer, self.attention_store, extended_steps_multi, extended_steps_single)
|
360 |
+
|
361 |
+
# 5. Prepare timesteps
|
362 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
363 |
+
image_seq_len = latents.shape[1]
|
364 |
+
mu = calculate_shift(
|
365 |
+
image_seq_len,
|
366 |
+
self.scheduler.config.base_image_seq_len,
|
367 |
+
self.scheduler.config.max_image_seq_len,
|
368 |
+
self.scheduler.config.base_shift,
|
369 |
+
self.scheduler.config.max_shift,
|
370 |
+
)
|
371 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
372 |
+
self.scheduler,
|
373 |
+
num_inference_steps,
|
374 |
+
device,
|
375 |
+
timesteps,
|
376 |
+
sigmas,
|
377 |
+
mu=mu,
|
378 |
+
)
|
379 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
380 |
+
self._num_timesteps = len(timesteps)
|
381 |
+
|
382 |
+
# handle guidance
|
383 |
+
if self.transformer.config.guidance_embeds:
|
384 |
+
if isinstance(guidance_scale, float):
|
385 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
386 |
+
guidance = guidance.expand(latents.shape[0])
|
387 |
+
elif isinstance(guidance_scale, list):
|
388 |
+
assert len(guidance_scale) == latents.shape[0], "The number of guidance scales must match the batch size"
|
389 |
+
guidance = torch.tensor(guidance_scale, device=device, dtype=torch.float32)
|
390 |
+
else:
|
391 |
+
guidance = None
|
392 |
+
|
393 |
+
if is_img_src and img_src_latents is None:
|
394 |
+
assert source_latents is not None, "source_latents must be provided when is_img_src is True"
|
395 |
+
|
396 |
+
rand_noise = noise[0].clone()
|
397 |
+
img_src_latents = []
|
398 |
+
|
399 |
+
for i in range(timesteps.shape[0]):
|
400 |
+
sigma = self.scheduler.sigmas[i]
|
401 |
+
img_src_latents.append((1.0 - sigma) * source_latents[0] + sigma * rand_noise)
|
402 |
+
|
403 |
+
# 6. Denoising loop
|
404 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
405 |
+
for i, t in enumerate(timesteps):
|
406 |
+
if self.interrupt:
|
407 |
+
continue
|
408 |
+
|
409 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
410 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
411 |
+
|
412 |
+
# For denoising from source image
|
413 |
+
if is_img_src:
|
414 |
+
latents[0] = img_src_latents[i]
|
415 |
+
|
416 |
+
# For Structure Transfer
|
417 |
+
if (source_latents is not None) and i == structure_transfer_step:
|
418 |
+
sigma = self.scheduler.sigmas[i]
|
419 |
+
latents[1] = (1.0 - sigma) * source_latents[0] + sigma * noise[1]
|
420 |
+
|
421 |
+
if is_auto_extend_scale and i == auto_extended_step:
|
422 |
+
def f(gamma):
|
423 |
+
self.attention_store.attention_ratios[i] = {}
|
424 |
+
noise_pred = self.transformer(
|
425 |
+
hidden_states=latents,
|
426 |
+
timestep=timestep / 1000,
|
427 |
+
guidance=guidance,
|
428 |
+
pooled_projections=pooled_prompt_embeds,
|
429 |
+
encoder_hidden_states=prompt_embeds,
|
430 |
+
txt_ids=text_ids,
|
431 |
+
img_ids=latent_image_ids,
|
432 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
433 |
+
return_dict=False,
|
434 |
+
proccesor_kwargs={"step_index": i, "extended_scale": gamma},
|
435 |
+
)[0]
|
436 |
+
|
437 |
+
scores_per_layer = self.attention_store.get_attention_ratios(step_indices=[i], display_imgs=False)
|
438 |
+
source_sum, text_sum, target_sum = scores_per_layer['transformer_blocks']
|
439 |
+
|
440 |
+
# We want to find the gamma that makes the ratio equal to K
|
441 |
+
ratio = (target_sum / source_sum)
|
442 |
+
return (ratio - target_auto_ratio)
|
443 |
+
|
444 |
+
gamma_sol = brentq(f, 1.0, 1.2, xtol=0.01)
|
445 |
+
|
446 |
+
print('Chosen gamma:', gamma_sol)
|
447 |
+
extended_scale = gamma_sol
|
448 |
+
else:
|
449 |
+
noise_pred = self.transformer(
|
450 |
+
hidden_states=latents,
|
451 |
+
timestep=timestep / 1000,
|
452 |
+
guidance=guidance,
|
453 |
+
pooled_projections=pooled_prompt_embeds,
|
454 |
+
encoder_hidden_states=prompt_embeds,
|
455 |
+
txt_ids=text_ids,
|
456 |
+
img_ids=latent_image_ids,
|
457 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
458 |
+
return_dict=False,
|
459 |
+
proccesor_kwargs={"step_index": i, "extended_scale": extended_scale},
|
460 |
+
)[0]
|
461 |
+
|
462 |
+
# compute the previous noisy sample x_t -> x_t-1
|
463 |
+
latents_dtype = latents.dtype
|
464 |
+
latents, x0 = self.scheduler.step(noise_pred, t, latents, return_dict=False, step_index=i)
|
465 |
+
|
466 |
+
if use_offset and is_img_src and (i+1 < len(img_src_latents)):
|
467 |
+
next_latent = img_src_latents[i+1]
|
468 |
+
offset = (next_latent - latents[0])
|
469 |
+
latents[1] = latents[1] + offset
|
470 |
+
|
471 |
+
# blend latents
|
472 |
+
if i in blend_steps and (subject_token is not None) and (localization_model is not None):
|
473 |
+
x0 = self._unpack_latents(x0, height, width, self.vae_scale_factor)
|
474 |
+
x0 = (x0 / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
475 |
+
images = self.vae.decode(x0, return_dict=False)[0]
|
476 |
+
images = self.image_processor.postprocess(images, output_type="pil")
|
477 |
+
|
478 |
+
self.do_step_blend(images, latents, subject_token, localization_model, show_attention, i, blend_models)
|
479 |
+
|
480 |
+
if latents.dtype != latents_dtype:
|
481 |
+
if torch.backends.mps.is_available():
|
482 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
483 |
+
latents = latents.to(latents_dtype)
|
484 |
+
|
485 |
+
if callback_on_step_end is not None:
|
486 |
+
callback_kwargs = {}
|
487 |
+
for k in callback_on_step_end_tensor_inputs:
|
488 |
+
callback_kwargs[k] = locals()[k]
|
489 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
490 |
+
|
491 |
+
latents = callback_outputs.pop("latents", latents)
|
492 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
493 |
+
|
494 |
+
# call the callback, if provided
|
495 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
496 |
+
progress_bar.update()
|
497 |
+
|
498 |
+
# if XLA_AVAILABLE:
|
499 |
+
# xm.mark_step()
|
500 |
+
|
501 |
+
if output_type == "latent":
|
502 |
+
image = latents
|
503 |
+
elif output_type == "both":
|
504 |
+
return_latents = latents
|
505 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
506 |
+
|
507 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
508 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
509 |
+
image = self.image_processor.postprocess(image, output_type="pil")
|
510 |
+
|
511 |
+
return (image, return_latents)
|
512 |
+
else:
|
513 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
514 |
+
|
515 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
516 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
517 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
518 |
+
|
519 |
+
# Offload all models
|
520 |
+
self.maybe_free_model_hooks()
|
521 |
+
|
522 |
+
if not return_dict:
|
523 |
+
return (image,)
|
524 |
+
|
525 |
+
return FluxPipelineOutput(images=image)
|
526 |
+
|
527 |
+
def do_step_blend(self, images, latents, subject_token, localization_model,
|
528 |
+
show_attention, i, blend_models):
|
529 |
+
|
530 |
+
device = latents.device
|
531 |
+
latents_dtype = latents.dtype
|
532 |
+
|
533 |
+
clipseg_processor = blend_models.get("clipseg_processor", None)
|
534 |
+
clipseg_model = blend_models.get("clipseg_model", None)
|
535 |
+
grounding_processor = blend_models.get("grounding_processor", None)
|
536 |
+
grounding_model = blend_models.get("grounding_model", None)
|
537 |
+
sam_predictor = blend_models.get("sam_predictor", None)
|
538 |
+
|
539 |
+
image_to_display = []
|
540 |
+
titles_to_display = []
|
541 |
+
|
542 |
+
if show_attention:
|
543 |
+
image_to_display += [images[0], images[1]]
|
544 |
+
titles_to_display += ["Source X0", "Target X0"]
|
545 |
+
|
546 |
+
if localization_model == "clipseg":
|
547 |
+
subject_mask = clipseg_predict(clipseg_model, clipseg_processor, [images[-1]], f"A photo of {subject_token}", device)
|
548 |
+
elif localization_model == "grounding_sam":
|
549 |
+
subject_mask = grounding_sam_predict(grounding_model, grounding_processor, sam_predictor, images[-1], f"A {subject_token}.", device)
|
550 |
+
elif localization_model == "clipseg_sam":
|
551 |
+
subject_mask = clipseg_predict(clipseg_model, clipseg_processor, [images[-1]], f"A photo of {subject_token}", device)
|
552 |
+
subject_mask = mask_to_box_sam_predict(subject_mask, sam_predictor, images[-1], None, device)
|
553 |
+
elif localization_model == "attention":
|
554 |
+
store = self.attention_store.image2text_store
|
555 |
+
attention_maps, attention_masks, tokens = self.attention_store.aggregate_attention(store, target_layers=None, gaussian_kernel=3)
|
556 |
+
|
557 |
+
subject_mask = attention_masks[0][-1].to(device)
|
558 |
+
subject_attention = attention_maps[0][-1].to(device)
|
559 |
+
|
560 |
+
if show_attention:
|
561 |
+
attentioned_image = show_image_and_heatmap(subject_attention.float(), images[1], relevnace_res=512)
|
562 |
+
attention_masked_image = show_image_and_heatmap(subject_mask.float(), images[1], relevnace_res=512)
|
563 |
+
|
564 |
+
image_to_display += [attentioned_image, attention_masked_image]
|
565 |
+
titles_to_display += ["Attention", "Attention Mask"]
|
566 |
+
|
567 |
+
elif localization_model == "attention_box_sam":
|
568 |
+
store = self.attention_store.image2text_store
|
569 |
+
attention_maps, attention_masks, tokens = self.attention_store.aggregate_attention(store, target_layers=None, gaussian_kernel=3)
|
570 |
+
|
571 |
+
attention_mask = attention_masks[0][-1].to(device)
|
572 |
+
subject_attention = attention_maps[0][-1].to(device)
|
573 |
+
|
574 |
+
subject_mask, bbox = mask_to_box_sam_predict(attention_mask, sam_predictor, images[-1], None, device)
|
575 |
+
|
576 |
+
if show_attention:
|
577 |
+
attentioned_image = show_image_and_heatmap(subject_attention.float(), images[1], relevnace_res=512)
|
578 |
+
attention_masked_image = show_image_and_heatmap(attention_mask.float(), images[1], relevnace_res=512)
|
579 |
+
|
580 |
+
sam_masked_image = show_image_and_heatmap(subject_mask.float(), images[1], relevnace_res=1024)
|
581 |
+
sam_masked_image = draw_bboxes_on_image(sam_masked_image, [bbox.tolist()], color="green", thickness=5)
|
582 |
+
|
583 |
+
image_to_display += [attentioned_image, attention_masked_image, sam_masked_image]
|
584 |
+
titles_to_display += ["Attention", "Attention Mask", "SAM Mask"]
|
585 |
+
|
586 |
+
elif localization_model == "attention_mask_sam":
|
587 |
+
store = self.attention_store.image2text_store
|
588 |
+
attention_maps, attention_masks, tokens = self.attention_store.aggregate_attention(store, target_layers=None, gaussian_kernel=3)
|
589 |
+
|
590 |
+
attention_mask = attention_masks[0][-1].to(device)
|
591 |
+
subject_attention = attention_maps[0][-1].to(device)
|
592 |
+
|
593 |
+
subject_mask = mask_to_mask_sam_predict(attention_mask, sam_predictor, images[-1], None, device)
|
594 |
+
|
595 |
+
if show_attention:
|
596 |
+
print('Attention:')
|
597 |
+
attentioned_image = show_image_and_heatmap(subject_attention.float(), images[1], relevnace_res=512)
|
598 |
+
attention_masked_image = show_image_and_heatmap(attention_mask.float(), images[1], relevnace_res=512)
|
599 |
+
sam_masked_image = show_image_and_heatmap(subject_mask.float(), images[1], relevnace_res=1024)
|
600 |
+
|
601 |
+
image_to_display += [attentioned_image, attention_masked_image, sam_masked_image]
|
602 |
+
titles_to_display += ["Attention", "Attention Mask", "SAM Mask"]
|
603 |
+
|
604 |
+
elif localization_model == "attention_points_sam":
|
605 |
+
store = self.attention_store.image2text_store
|
606 |
+
attention_maps, attention_masks, tokens = self.attention_store.aggregate_attention(store, target_layers=None, gaussian_kernel=3)
|
607 |
+
|
608 |
+
attention_mask = attention_masks[0][-1].to(device)
|
609 |
+
subject_attention = attention_maps[0][-1].to(device)
|
610 |
+
|
611 |
+
subject_mask, point_coords = attention_to_points_sam_predict(subject_attention, attention_mask, sam_predictor, images[1], None, device)
|
612 |
+
|
613 |
+
if show_attention:
|
614 |
+
print('Attention:')
|
615 |
+
attentioned_image = show_image_and_heatmap(subject_attention.float(), images[1], relevnace_res=512)
|
616 |
+
attention_masked_image = show_image_and_heatmap(attention_mask.float(), images[1], relevnace_res=512)
|
617 |
+
|
618 |
+
sam_masked_image = show_image_and_heatmap(subject_mask.float(), images[1], relevnace_res=1024)
|
619 |
+
sam_masked_image = draw_points_on_pil_image(sam_masked_image, point_coords, point_color="green", radius=10)
|
620 |
+
|
621 |
+
image_to_display += [attentioned_image, attention_masked_image, sam_masked_image]
|
622 |
+
titles_to_display += ["Attention", "Attention Mask", "SAM Mask"]
|
623 |
+
|
624 |
+
if show_attention:
|
625 |
+
show_images(image_to_display, titles_to_display, size=512, save_path="attn_vis.png")
|
626 |
+
|
627 |
+
# Resize the mask to latents size
|
628 |
+
latents_mask = torch.nn.functional.interpolate(subject_mask.view(1,1,subject_mask.shape[-2],subject_mask.shape[-1]), size=64, mode='bilinear').view(4096, 1).to(latents_dtype)
|
629 |
+
latents_mask[latents_mask > 0.01] = 1
|
630 |
+
|
631 |
+
latents[1] = latents[1] * latents_mask + latents[0] * (1 - latents_mask)
|
632 |
+
|
633 |
+
############# Image to Image Methods #############
|
634 |
+
def img2img_encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
635 |
+
if isinstance(generator, list):
|
636 |
+
image_latents = [
|
637 |
+
img2img_retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
638 |
+
for i in range(image.shape[0])
|
639 |
+
]
|
640 |
+
image_latents = torch.cat(image_latents, dim=0)
|
641 |
+
else:
|
642 |
+
image_latents = img2img_retrieve_latents(self.vae.encode(image), generator=generator)
|
643 |
+
|
644 |
+
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
645 |
+
|
646 |
+
return image_latents
|
647 |
+
|
648 |
+
def img2img_prepare_latents(
|
649 |
+
self,
|
650 |
+
image,
|
651 |
+
timestep,
|
652 |
+
batch_size,
|
653 |
+
num_channels_latents,
|
654 |
+
height,
|
655 |
+
width,
|
656 |
+
dtype,
|
657 |
+
device,
|
658 |
+
generator,
|
659 |
+
latents=None,
|
660 |
+
):
|
661 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
662 |
+
raise ValueError(
|
663 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
664 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
665 |
+
)
|
666 |
+
|
667 |
+
height = 2 * (int(height) // self.vae_scale_factor)
|
668 |
+
width = 2 * (int(width) // self.vae_scale_factor)
|
669 |
+
|
670 |
+
shape = (batch_size, num_channels_latents, height, width)
|
671 |
+
latent_image_ids = self.img2img_prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
672 |
+
|
673 |
+
if latents is not None:
|
674 |
+
return latents.to(device=device, dtype=dtype), latent_image_ids
|
675 |
+
|
676 |
+
image = image.to(device=device, dtype=dtype)
|
677 |
+
image_latents = self.img2img_encode_vae_image(image=image, generator=generator)
|
678 |
+
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
679 |
+
# expand init_latents for batch_size
|
680 |
+
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
681 |
+
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
682 |
+
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
683 |
+
raise ValueError(
|
684 |
+
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
685 |
+
)
|
686 |
+
else:
|
687 |
+
image_latents = torch.cat([image_latents], dim=0)
|
688 |
+
|
689 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
690 |
+
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
|
691 |
+
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
692 |
+
return latents, latent_image_ids
|
693 |
+
|
694 |
+
def img2img_check_inputs(
|
695 |
+
self,
|
696 |
+
prompt,
|
697 |
+
prompt_2,
|
698 |
+
strength,
|
699 |
+
height,
|
700 |
+
width,
|
701 |
+
prompt_embeds=None,
|
702 |
+
pooled_prompt_embeds=None,
|
703 |
+
callback_on_step_end_tensor_inputs=None,
|
704 |
+
max_sequence_length=None,
|
705 |
+
):
|
706 |
+
if strength < 0 or strength > 1:
|
707 |
+
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
708 |
+
|
709 |
+
if height % 8 != 0 or width % 8 != 0:
|
710 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
711 |
+
|
712 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
713 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
714 |
+
):
|
715 |
+
raise ValueError(
|
716 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
717 |
+
)
|
718 |
+
|
719 |
+
if prompt is not None and prompt_embeds is not None:
|
720 |
+
raise ValueError(
|
721 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
722 |
+
" only forward one of the two."
|
723 |
+
)
|
724 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
725 |
+
raise ValueError(
|
726 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
727 |
+
" only forward one of the two."
|
728 |
+
)
|
729 |
+
elif prompt is None and prompt_embeds is None:
|
730 |
+
raise ValueError(
|
731 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
732 |
+
)
|
733 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
734 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
735 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
736 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
737 |
+
|
738 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
739 |
+
raise ValueError(
|
740 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
741 |
+
)
|
742 |
+
|
743 |
+
if max_sequence_length is not None and max_sequence_length > 512:
|
744 |
+
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
745 |
+
|
746 |
+
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
|
747 |
+
def img2img_get_timesteps(self, num_inference_steps, strength, device):
|
748 |
+
# get the original timestep using init_timestep
|
749 |
+
init_timestep = min(num_inference_steps * strength, num_inference_steps)
|
750 |
+
|
751 |
+
t_start = int(max(num_inference_steps - init_timestep, 0))
|
752 |
+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
753 |
+
if hasattr(self.scheduler, "set_begin_index"):
|
754 |
+
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
755 |
+
|
756 |
+
return timesteps, num_inference_steps - t_start
|
757 |
+
|
758 |
+
@staticmethod
|
759 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
|
760 |
+
def img2img_prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
761 |
+
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
762 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
763 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
764 |
+
|
765 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
766 |
+
|
767 |
+
latent_image_ids = latent_image_ids.reshape(
|
768 |
+
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
769 |
+
)
|
770 |
+
|
771 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
772 |
+
|
773 |
+
@torch.no_grad()
|
774 |
+
def call_img2img(
|
775 |
+
self,
|
776 |
+
prompt: Union[str, List[str]] = None,
|
777 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
778 |
+
image: PipelineImageInput = None,
|
779 |
+
height: Optional[int] = None,
|
780 |
+
width: Optional[int] = None,
|
781 |
+
strength: float = 0.6,
|
782 |
+
num_inference_steps: int = 28,
|
783 |
+
timesteps: List[int] = None,
|
784 |
+
guidance_scale: float = 7.0,
|
785 |
+
num_images_per_prompt: Optional[int] = 1,
|
786 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
787 |
+
latents: Optional[torch.FloatTensor] = None,
|
788 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
789 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
790 |
+
output_type: Optional[str] = "pil",
|
791 |
+
return_dict: bool = True,
|
792 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
793 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
794 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
795 |
+
max_sequence_length: int = 512,
|
796 |
+
):
|
797 |
+
r"""
|
798 |
+
Function invoked when calling the pipeline for generation.
|
799 |
+
|
800 |
+
Args:
|
801 |
+
prompt (`str` or `List[str]`, *optional*):
|
802 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
803 |
+
instead.
|
804 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
805 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
806 |
+
will be used instead
|
807 |
+
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
808 |
+
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
809 |
+
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
|
810 |
+
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
811 |
+
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
|
812 |
+
latents as `image`, but if passing latents directly it is not encoded again.
|
813 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
814 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
815 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
816 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
817 |
+
strength (`float`, *optional*, defaults to 1.0):
|
818 |
+
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
819 |
+
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
|
820 |
+
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
|
821 |
+
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
|
822 |
+
essentially ignores `image`.
|
823 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
824 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
825 |
+
expense of slower inference.
|
826 |
+
timesteps (`List[int]`, *optional*):
|
827 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
828 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
829 |
+
passed will be used. Must be in descending order.
|
830 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
831 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
832 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
833 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
834 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
835 |
+
usually at the expense of lower image quality.
|
836 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
837 |
+
The number of images to generate per prompt.
|
838 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
839 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
840 |
+
to make generation deterministic.
|
841 |
+
latents (`torch.FloatTensor`, *optional*):
|
842 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
843 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
844 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
845 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
846 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
847 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
848 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
849 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
850 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
851 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
852 |
+
The output format of the generate image. Choose between
|
853 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
854 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
855 |
+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
856 |
+
joint_attention_kwargs (`dict`, *optional*):
|
857 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
858 |
+
`self.processor` in
|
859 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
860 |
+
callback_on_step_end (`Callable`, *optional*):
|
861 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
862 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
863 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
864 |
+
`callback_on_step_end_tensor_inputs`.
|
865 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
866 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
867 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
868 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
869 |
+
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
870 |
+
|
871 |
+
Examples:
|
872 |
+
|
873 |
+
Returns:
|
874 |
+
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
875 |
+
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
876 |
+
images.
|
877 |
+
"""
|
878 |
+
|
879 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
880 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
881 |
+
|
882 |
+
# 1. Check inputs. Raise error if not correct
|
883 |
+
self.img2img_check_inputs(
|
884 |
+
prompt,
|
885 |
+
prompt_2,
|
886 |
+
strength,
|
887 |
+
height,
|
888 |
+
width,
|
889 |
+
prompt_embeds=prompt_embeds,
|
890 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
891 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
892 |
+
max_sequence_length=max_sequence_length,
|
893 |
+
)
|
894 |
+
|
895 |
+
self._guidance_scale = guidance_scale
|
896 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
897 |
+
self._interrupt = False
|
898 |
+
|
899 |
+
# 2. Preprocess image
|
900 |
+
init_image = self.image_processor.preprocess(image, height=height, width=width)
|
901 |
+
init_image = init_image.to(dtype=torch.float32)
|
902 |
+
|
903 |
+
# 3. Define call parameters
|
904 |
+
if prompt is not None and isinstance(prompt, str):
|
905 |
+
batch_size = 1
|
906 |
+
elif prompt is not None and isinstance(prompt, list):
|
907 |
+
batch_size = len(prompt)
|
908 |
+
else:
|
909 |
+
batch_size = prompt_embeds.shape[0]
|
910 |
+
|
911 |
+
device = self._execution_device
|
912 |
+
|
913 |
+
lora_scale = (
|
914 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
915 |
+
)
|
916 |
+
(
|
917 |
+
prompt_embeds,
|
918 |
+
pooled_prompt_embeds,
|
919 |
+
text_ids,
|
920 |
+
) = self.encode_prompt(
|
921 |
+
prompt=prompt,
|
922 |
+
prompt_2=prompt_2,
|
923 |
+
prompt_embeds=prompt_embeds,
|
924 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
925 |
+
device=device,
|
926 |
+
num_images_per_prompt=num_images_per_prompt,
|
927 |
+
max_sequence_length=max_sequence_length,
|
928 |
+
lora_scale=lora_scale,
|
929 |
+
)
|
930 |
+
|
931 |
+
register_regular_attention_processors(self.transformer)
|
932 |
+
|
933 |
+
# 4.Prepare timesteps
|
934 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
935 |
+
image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
|
936 |
+
mu = calculate_shift(
|
937 |
+
image_seq_len,
|
938 |
+
self.scheduler.config.base_image_seq_len,
|
939 |
+
self.scheduler.config.max_image_seq_len,
|
940 |
+
self.scheduler.config.base_shift,
|
941 |
+
self.scheduler.config.max_shift,
|
942 |
+
)
|
943 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
944 |
+
self.scheduler,
|
945 |
+
num_inference_steps,
|
946 |
+
device,
|
947 |
+
timesteps,
|
948 |
+
sigmas,
|
949 |
+
mu=mu,
|
950 |
+
)
|
951 |
+
timesteps, num_inference_steps = self.img2img_get_timesteps(num_inference_steps, strength, device)
|
952 |
+
|
953 |
+
if num_inference_steps < 1:
|
954 |
+
raise ValueError(
|
955 |
+
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
|
956 |
+
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
|
957 |
+
)
|
958 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
959 |
+
|
960 |
+
# 5. Prepare latent variables
|
961 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
962 |
+
|
963 |
+
latents, latent_image_ids = self.img2img_prepare_latents(
|
964 |
+
init_image,
|
965 |
+
latent_timestep,
|
966 |
+
batch_size * num_images_per_prompt,
|
967 |
+
num_channels_latents,
|
968 |
+
height,
|
969 |
+
width,
|
970 |
+
prompt_embeds.dtype,
|
971 |
+
device,
|
972 |
+
generator,
|
973 |
+
latents,
|
974 |
+
)
|
975 |
+
|
976 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
977 |
+
self._num_timesteps = len(timesteps)
|
978 |
+
|
979 |
+
# handle guidance
|
980 |
+
if self.transformer.config.guidance_embeds:
|
981 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
982 |
+
guidance = guidance.expand(latents.shape[0])
|
983 |
+
else:
|
984 |
+
guidance = None
|
985 |
+
|
986 |
+
text_ids = text_ids.expand(latents.shape[0], -1, -1)
|
987 |
+
latent_image_ids = latent_image_ids.expand(latents.shape[0], -1, -1)
|
988 |
+
|
989 |
+
# 6. Denoising loop
|
990 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
991 |
+
for i, t in enumerate(timesteps):
|
992 |
+
if self.interrupt:
|
993 |
+
continue
|
994 |
+
|
995 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
996 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
997 |
+
noise_pred = self.transformer(
|
998 |
+
hidden_states=latents,
|
999 |
+
timestep=timestep / 1000,
|
1000 |
+
guidance=guidance,
|
1001 |
+
pooled_projections=pooled_prompt_embeds,
|
1002 |
+
encoder_hidden_states=prompt_embeds,
|
1003 |
+
txt_ids=text_ids,
|
1004 |
+
img_ids=latent_image_ids,
|
1005 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
1006 |
+
return_dict=False,
|
1007 |
+
)[0]
|
1008 |
+
|
1009 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1010 |
+
latents_dtype = latents.dtype
|
1011 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
1012 |
+
|
1013 |
+
if latents.dtype != latents_dtype:
|
1014 |
+
if torch.backends.mps.is_available():
|
1015 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
1016 |
+
latents = latents.to(latents_dtype)
|
1017 |
+
|
1018 |
+
if callback_on_step_end is not None:
|
1019 |
+
callback_kwargs = {}
|
1020 |
+
for k in callback_on_step_end_tensor_inputs:
|
1021 |
+
callback_kwargs[k] = locals()[k]
|
1022 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
1023 |
+
|
1024 |
+
latents = callback_outputs.pop("latents", latents)
|
1025 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
1026 |
+
|
1027 |
+
# call the callback, if provided
|
1028 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1029 |
+
progress_bar.update()
|
1030 |
+
|
1031 |
+
# if XLA_AVAILABLE:
|
1032 |
+
# xm.mark_step()
|
1033 |
+
|
1034 |
+
if output_type == "latent":
|
1035 |
+
image = latents
|
1036 |
+
|
1037 |
+
else:
|
1038 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
1039 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
1040 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
1041 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
1042 |
+
|
1043 |
+
# Offload all models
|
1044 |
+
self.maybe_free_model_hooks()
|
1045 |
+
|
1046 |
+
if not return_dict:
|
1047 |
+
return (image,)
|
1048 |
+
|
1049 |
+
return FluxPipelineOutput(images=image)
|
1050 |
+
|
1051 |
+
############# Invert Methods #############
|
1052 |
+
def invert_prepare_latents(
|
1053 |
+
self,
|
1054 |
+
image,
|
1055 |
+
timestep,
|
1056 |
+
batch_size,
|
1057 |
+
num_channels_latents,
|
1058 |
+
height,
|
1059 |
+
width,
|
1060 |
+
dtype,
|
1061 |
+
device,
|
1062 |
+
generator,
|
1063 |
+
latents=None,
|
1064 |
+
add_noise=False,
|
1065 |
+
):
|
1066 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
1067 |
+
raise ValueError(
|
1068 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
1069 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
1070 |
+
)
|
1071 |
+
|
1072 |
+
height = 2 * (int(height) // self.vae_scale_factor)
|
1073 |
+
width = 2 * (int(width) // self.vae_scale_factor)
|
1074 |
+
|
1075 |
+
shape = (batch_size, num_channels_latents, height, width)
|
1076 |
+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
1077 |
+
|
1078 |
+
if latents is not None:
|
1079 |
+
return latents.to(device=device, dtype=dtype), latent_image_ids
|
1080 |
+
|
1081 |
+
image = image.to(device=device, dtype=dtype)
|
1082 |
+
image_latents = self.img2img_encode_vae_image(image=image, generator=generator)
|
1083 |
+
|
1084 |
+
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
1085 |
+
# expand init_latents for batch_size
|
1086 |
+
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
1087 |
+
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
1088 |
+
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
1089 |
+
raise ValueError(
|
1090 |
+
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
1091 |
+
)
|
1092 |
+
else:
|
1093 |
+
image_latents = torch.cat([image_latents], dim=0)
|
1094 |
+
|
1095 |
+
if add_noise:
|
1096 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
1097 |
+
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
|
1098 |
+
else:
|
1099 |
+
latents = image_latents
|
1100 |
+
|
1101 |
+
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
1102 |
+
|
1103 |
+
return latents, latent_image_ids
|
1104 |
+
|
1105 |
+
@torch.no_grad()
|
1106 |
+
def call_invert(
|
1107 |
+
self,
|
1108 |
+
prompt: Union[str, List[str]] = None,
|
1109 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
1110 |
+
image: PipelineImageInput = None,
|
1111 |
+
height: Optional[int] = None,
|
1112 |
+
width: Optional[int] = None,
|
1113 |
+
num_inference_steps: int = 28,
|
1114 |
+
timesteps: List[int] = None,
|
1115 |
+
guidance_scale: float = 7.0,
|
1116 |
+
num_images_per_prompt: Optional[int] = 1,
|
1117 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
1118 |
+
latents: Optional[torch.FloatTensor] = None,
|
1119 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
1120 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
1121 |
+
output_type: Optional[str] = "pil",
|
1122 |
+
return_dict: bool = True,
|
1123 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1124 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
1125 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
1126 |
+
max_sequence_length: int = 512,
|
1127 |
+
|
1128 |
+
fixed_point_iterations: int = 1,
|
1129 |
+
):
|
1130 |
+
r"""
|
1131 |
+
Function invoked when calling the pipeline for generation.
|
1132 |
+
|
1133 |
+
Args:
|
1134 |
+
prompt (`str` or `List[str]`, *optional*):
|
1135 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
1136 |
+
instead.
|
1137 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
1138 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
1139 |
+
will be used instead
|
1140 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
1141 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
1142 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
1143 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
1144 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
1145 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
1146 |
+
expense of slower inference.
|
1147 |
+
timesteps (`List[int]`, *optional*):
|
1148 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
1149 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
1150 |
+
passed will be used. Must be in descending order.
|
1151 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
1152 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
1153 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
1154 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
1155 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
1156 |
+
usually at the expense of lower image quality.
|
1157 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
1158 |
+
The number of images to generate per prompt.
|
1159 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
1160 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
1161 |
+
to make generation deterministic.
|
1162 |
+
latents (`torch.FloatTensor`, *optional*):
|
1163 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
1164 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
1165 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
1166 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
1167 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
1168 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
1169 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
1170 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
1171 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
1172 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1173 |
+
The output format of the generate image. Choose between
|
1174 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
1175 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1176 |
+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
1177 |
+
joint_attention_kwargs (`dict`, *optional*):
|
1178 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
1179 |
+
`self.processor` in
|
1180 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
1181 |
+
callback_on_step_end (`Callable`, *optional*):
|
1182 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
1183 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
1184 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
1185 |
+
`callback_on_step_end_tensor_inputs`.
|
1186 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
1187 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
1188 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
1189 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
1190 |
+
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
1191 |
+
|
1192 |
+
Examples:
|
1193 |
+
|
1194 |
+
Returns:
|
1195 |
+
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
1196 |
+
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
1197 |
+
images.
|
1198 |
+
"""
|
1199 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
1200 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
1201 |
+
|
1202 |
+
# 1. Check inputs. Raise error if not correct
|
1203 |
+
self.check_inputs(
|
1204 |
+
prompt,
|
1205 |
+
prompt_2,
|
1206 |
+
height,
|
1207 |
+
width,
|
1208 |
+
prompt_embeds=prompt_embeds,
|
1209 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
1210 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
1211 |
+
max_sequence_length=max_sequence_length,
|
1212 |
+
)
|
1213 |
+
|
1214 |
+
self._guidance_scale = guidance_scale
|
1215 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
1216 |
+
self._interrupt = False
|
1217 |
+
|
1218 |
+
# 1.5. Preprocess image
|
1219 |
+
if isinstance(image, Image.Image):
|
1220 |
+
init_image = self.image_processor.preprocess(image, height=height, width=width)
|
1221 |
+
elif isinstance(image, torch.Tensor):
|
1222 |
+
init_image = image
|
1223 |
+
latents = image
|
1224 |
+
else:
|
1225 |
+
raise ValueError("Image must be of type `PIL.Image.Image` or `torch.Tensor`")
|
1226 |
+
|
1227 |
+
init_image = init_image.to(dtype=torch.float32)
|
1228 |
+
|
1229 |
+
# 2. Define call parameters
|
1230 |
+
if prompt is not None and isinstance(prompt, str):
|
1231 |
+
batch_size = 1
|
1232 |
+
elif prompt is not None and isinstance(prompt, list):
|
1233 |
+
batch_size = len(prompt)
|
1234 |
+
else:
|
1235 |
+
batch_size = prompt_embeds.shape[0]
|
1236 |
+
|
1237 |
+
device = self._execution_device
|
1238 |
+
|
1239 |
+
lora_scale = (
|
1240 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
1241 |
+
)
|
1242 |
+
(
|
1243 |
+
prompt_embeds,
|
1244 |
+
pooled_prompt_embeds,
|
1245 |
+
text_ids,
|
1246 |
+
) = self.encode_prompt(
|
1247 |
+
prompt=prompt,
|
1248 |
+
prompt_2=prompt_2,
|
1249 |
+
prompt_embeds=prompt_embeds,
|
1250 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
1251 |
+
device=device,
|
1252 |
+
num_images_per_prompt=num_images_per_prompt,
|
1253 |
+
max_sequence_length=max_sequence_length,
|
1254 |
+
lora_scale=lora_scale,
|
1255 |
+
)
|
1256 |
+
|
1257 |
+
# 4. Prepare latent variables
|
1258 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
1259 |
+
# latents, latent_image_ids = self.prepare_latents(
|
1260 |
+
# batch_size * num_images_per_prompt,
|
1261 |
+
# num_channels_latents,
|
1262 |
+
# height,
|
1263 |
+
# width,
|
1264 |
+
# prompt_embeds.dtype,
|
1265 |
+
# device,
|
1266 |
+
# generator,
|
1267 |
+
# latents,
|
1268 |
+
# )
|
1269 |
+
latents, latent_image_ids = self.invert_prepare_latents(
|
1270 |
+
init_image,
|
1271 |
+
None,
|
1272 |
+
batch_size * num_images_per_prompt,
|
1273 |
+
num_channels_latents,
|
1274 |
+
height,
|
1275 |
+
width,
|
1276 |
+
prompt_embeds.dtype,
|
1277 |
+
device,
|
1278 |
+
generator,
|
1279 |
+
latents,
|
1280 |
+
False
|
1281 |
+
)
|
1282 |
+
|
1283 |
+
register_regular_attention_processors(self.transformer)
|
1284 |
+
|
1285 |
+
# 5. Prepare timesteps
|
1286 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
1287 |
+
image_seq_len = latents.shape[1]
|
1288 |
+
mu = calculate_shift(
|
1289 |
+
image_seq_len,
|
1290 |
+
self.scheduler.config.base_image_seq_len,
|
1291 |
+
self.scheduler.config.max_image_seq_len,
|
1292 |
+
self.scheduler.config.base_shift,
|
1293 |
+
self.scheduler.config.max_shift,
|
1294 |
+
)
|
1295 |
+
|
1296 |
+
# For Inversion, reverse the sigmas
|
1297 |
+
# sigmas = sigmas[::-1]
|
1298 |
+
|
1299 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
1300 |
+
self.scheduler,
|
1301 |
+
num_inference_steps,
|
1302 |
+
device,
|
1303 |
+
timesteps,
|
1304 |
+
sigmas,
|
1305 |
+
mu=mu,
|
1306 |
+
)
|
1307 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
1308 |
+
self._num_timesteps = len(timesteps)
|
1309 |
+
|
1310 |
+
# handle guidance
|
1311 |
+
if self.transformer.config.guidance_embeds:
|
1312 |
+
guidance = torch.tensor([guidance_scale], device=device)
|
1313 |
+
guidance = guidance.expand(latents.shape[0])
|
1314 |
+
else:
|
1315 |
+
guidance = None
|
1316 |
+
|
1317 |
+
self.scheduler.sigmas = reversed(self.scheduler.sigmas)
|
1318 |
+
|
1319 |
+
timesteps_zero_start = reversed(torch.cat([self.scheduler.timesteps[1:], torch.tensor([0], device=device)]))
|
1320 |
+
timesteps_one_start = reversed(self.scheduler.timesteps)
|
1321 |
+
|
1322 |
+
self.scheduler.timesteps = timesteps_zero_start
|
1323 |
+
# self.scheduler.timesteps = timesteps_one_start
|
1324 |
+
|
1325 |
+
timesteps = self.scheduler.timesteps
|
1326 |
+
|
1327 |
+
latents_list = []
|
1328 |
+
latents_list.append(latents)
|
1329 |
+
|
1330 |
+
# 6. Denoising loop
|
1331 |
+
with self.progress_bar(total=num_inference_steps * fixed_point_iterations) as progress_bar:
|
1332 |
+
for i, t in enumerate(timesteps):
|
1333 |
+
original_latents = latents.clone()
|
1334 |
+
for j in range(fixed_point_iterations):
|
1335 |
+
if self.interrupt:
|
1336 |
+
continue
|
1337 |
+
|
1338 |
+
if j == 0:
|
1339 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
1340 |
+
timestep = timesteps[i].expand(latents.shape[0]).to(latents.dtype)
|
1341 |
+
else:
|
1342 |
+
timestep = timesteps_one_start[i].expand(latents.shape[0]).to(latents.dtype)
|
1343 |
+
|
1344 |
+
noise_pred = self.transformer(
|
1345 |
+
hidden_states=latents,
|
1346 |
+
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
|
1347 |
+
timestep=timestep / 1000,
|
1348 |
+
guidance=guidance,
|
1349 |
+
pooled_projections=pooled_prompt_embeds,
|
1350 |
+
encoder_hidden_states=prompt_embeds,
|
1351 |
+
txt_ids=text_ids,
|
1352 |
+
img_ids=latent_image_ids,
|
1353 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
1354 |
+
return_dict=False,
|
1355 |
+
)[0]
|
1356 |
+
|
1357 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1358 |
+
latents_dtype = latents.dtype
|
1359 |
+
|
1360 |
+
# noise_pred = -noise_pred
|
1361 |
+
latents = self.scheduler.step(noise_pred, t, original_latents, return_dict=False, step_index=i)[0]
|
1362 |
+
|
1363 |
+
if latents.dtype != latents_dtype:
|
1364 |
+
if torch.backends.mps.is_available():
|
1365 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
1366 |
+
latents = latents.to(latents_dtype)
|
1367 |
+
|
1368 |
+
if callback_on_step_end is not None:
|
1369 |
+
callback_kwargs = {}
|
1370 |
+
for k in callback_on_step_end_tensor_inputs:
|
1371 |
+
callback_kwargs[k] = locals()[k]
|
1372 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
1373 |
+
|
1374 |
+
latents = callback_outputs.pop("latents", latents)
|
1375 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
1376 |
+
|
1377 |
+
# call the callback, if provided
|
1378 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1379 |
+
progress_bar.update()
|
1380 |
+
|
1381 |
+
# if XLA_AVAILABLE:
|
1382 |
+
# xm.mark_step()
|
1383 |
+
|
1384 |
+
latents_list.append(latents)
|
1385 |
+
|
1386 |
+
# Offload all models
|
1387 |
+
self.maybe_free_model_hooks()
|
1388 |
+
|
1389 |
+
return latents_list
|
addit_flux_transformer.py
ADDED
@@ -0,0 +1,521 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Black Forest Labs, The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from typing import Any, Dict, List, Optional, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
import torch.nn.functional as F
|
20 |
+
|
21 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
22 |
+
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
|
23 |
+
from diffusers.models.attention import FeedForward
|
24 |
+
from diffusers.models.attention_processor import Attention, FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0
|
25 |
+
from diffusers.models.modeling_utils import ModelMixin
|
26 |
+
from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
27 |
+
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
28 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
29 |
+
from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
|
30 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
31 |
+
|
32 |
+
from addit_attention_processors import AdditFluxAttnProcessor2_0, AdditFluxSingleAttnProcessor2_0
|
33 |
+
|
34 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
35 |
+
|
36 |
+
|
37 |
+
# YiYi to-do: refactor rope related functions/classes
|
38 |
+
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
39 |
+
assert dim % 2 == 0, "The dimension must be even."
|
40 |
+
|
41 |
+
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
42 |
+
omega = 1.0 / (theta**scale)
|
43 |
+
|
44 |
+
batch_size, seq_length = pos.shape
|
45 |
+
out = torch.einsum("...n,d->...nd", pos, omega)
|
46 |
+
cos_out = torch.cos(out)
|
47 |
+
sin_out = torch.sin(out)
|
48 |
+
|
49 |
+
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
|
50 |
+
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
|
51 |
+
return out.float()
|
52 |
+
|
53 |
+
|
54 |
+
# YiYi to-do: refactor rope related functions/classes
|
55 |
+
class EmbedND(nn.Module):
|
56 |
+
def __init__(self, dim: int, theta: int, axes_dim: List[int]):
|
57 |
+
super().__init__()
|
58 |
+
self.dim = dim
|
59 |
+
self.theta = theta
|
60 |
+
self.axes_dim = axes_dim
|
61 |
+
|
62 |
+
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
63 |
+
n_axes = ids.shape[-1]
|
64 |
+
emb = torch.cat(
|
65 |
+
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
66 |
+
dim=-3,
|
67 |
+
)
|
68 |
+
return emb.unsqueeze(1)
|
69 |
+
|
70 |
+
|
71 |
+
@maybe_allow_in_graph
|
72 |
+
class AdditFluxSingleTransformerBlock(nn.Module):
|
73 |
+
r"""
|
74 |
+
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
|
75 |
+
|
76 |
+
Reference: https://arxiv.org/abs/2403.03206
|
77 |
+
|
78 |
+
Parameters:
|
79 |
+
dim (`int`): The number of channels in the input and output.
|
80 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
81 |
+
attention_head_dim (`int`): The number of channels in each head.
|
82 |
+
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
|
83 |
+
processing of `context` conditions.
|
84 |
+
"""
|
85 |
+
|
86 |
+
def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
|
87 |
+
super().__init__()
|
88 |
+
self.mlp_hidden_dim = int(dim * mlp_ratio)
|
89 |
+
|
90 |
+
self.norm = AdaLayerNormZeroSingle(dim)
|
91 |
+
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
|
92 |
+
self.act_mlp = nn.GELU(approximate="tanh")
|
93 |
+
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
|
94 |
+
|
95 |
+
processor = FluxSingleAttnProcessor2_0()
|
96 |
+
self.attn = Attention(
|
97 |
+
query_dim=dim,
|
98 |
+
cross_attention_dim=None,
|
99 |
+
dim_head=attention_head_dim,
|
100 |
+
heads=num_attention_heads,
|
101 |
+
out_dim=dim,
|
102 |
+
bias=True,
|
103 |
+
processor=processor,
|
104 |
+
qk_norm="rms_norm",
|
105 |
+
eps=1e-6,
|
106 |
+
pre_only=True,
|
107 |
+
)
|
108 |
+
|
109 |
+
def forward(
|
110 |
+
self,
|
111 |
+
hidden_states: torch.FloatTensor,
|
112 |
+
temb: torch.FloatTensor,
|
113 |
+
image_rotary_emb=None,
|
114 |
+
proccesor_kwargs=None,
|
115 |
+
):
|
116 |
+
residual = hidden_states
|
117 |
+
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
118 |
+
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
119 |
+
|
120 |
+
attn_output = self.attn(
|
121 |
+
hidden_states=norm_hidden_states,
|
122 |
+
image_rotary_emb=image_rotary_emb,
|
123 |
+
**(proccesor_kwargs or {}),
|
124 |
+
)
|
125 |
+
|
126 |
+
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
127 |
+
gate = gate.unsqueeze(1)
|
128 |
+
hidden_states = gate * self.proj_out(hidden_states)
|
129 |
+
hidden_states = residual + hidden_states
|
130 |
+
if hidden_states.dtype == torch.float16:
|
131 |
+
hidden_states = hidden_states.clip(-65504, 65504)
|
132 |
+
|
133 |
+
return hidden_states
|
134 |
+
|
135 |
+
|
136 |
+
@maybe_allow_in_graph
|
137 |
+
class AdditFluxTransformerBlock(nn.Module):
|
138 |
+
r"""
|
139 |
+
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
|
140 |
+
|
141 |
+
Reference: https://arxiv.org/abs/2403.03206
|
142 |
+
|
143 |
+
Parameters:
|
144 |
+
dim (`int`): The number of channels in the input and output.
|
145 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
146 |
+
attention_head_dim (`int`): The number of channels in each head.
|
147 |
+
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
|
148 |
+
processing of `context` conditions.
|
149 |
+
"""
|
150 |
+
|
151 |
+
def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
|
152 |
+
super().__init__()
|
153 |
+
|
154 |
+
self.norm1 = AdaLayerNormZero(dim)
|
155 |
+
|
156 |
+
self.norm1_context = AdaLayerNormZero(dim)
|
157 |
+
|
158 |
+
if hasattr(F, "scaled_dot_product_attention"):
|
159 |
+
processor = FluxAttnProcessor2_0()
|
160 |
+
else:
|
161 |
+
raise ValueError(
|
162 |
+
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
|
163 |
+
)
|
164 |
+
self.attn = Attention(
|
165 |
+
query_dim=dim,
|
166 |
+
cross_attention_dim=None,
|
167 |
+
added_kv_proj_dim=dim,
|
168 |
+
dim_head=attention_head_dim,
|
169 |
+
heads=num_attention_heads,
|
170 |
+
out_dim=dim,
|
171 |
+
context_pre_only=False,
|
172 |
+
bias=True,
|
173 |
+
processor=processor,
|
174 |
+
qk_norm=qk_norm,
|
175 |
+
eps=eps,
|
176 |
+
)
|
177 |
+
|
178 |
+
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
179 |
+
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
180 |
+
|
181 |
+
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
182 |
+
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
183 |
+
|
184 |
+
# let chunk size default to None
|
185 |
+
self._chunk_size = None
|
186 |
+
self._chunk_dim = 0
|
187 |
+
|
188 |
+
def forward(
|
189 |
+
self,
|
190 |
+
hidden_states: torch.FloatTensor,
|
191 |
+
encoder_hidden_states: torch.FloatTensor,
|
192 |
+
temb: torch.FloatTensor,
|
193 |
+
image_rotary_emb=None,
|
194 |
+
proccesor_kwargs=None,
|
195 |
+
):
|
196 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
197 |
+
|
198 |
+
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
199 |
+
encoder_hidden_states, emb=temb
|
200 |
+
)
|
201 |
+
|
202 |
+
# Attention.
|
203 |
+
attn_output, context_attn_output = self.attn(
|
204 |
+
hidden_states=norm_hidden_states,
|
205 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
206 |
+
image_rotary_emb=image_rotary_emb,
|
207 |
+
**(proccesor_kwargs or {}),
|
208 |
+
)
|
209 |
+
|
210 |
+
# Process attention outputs for the `hidden_states`.
|
211 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
212 |
+
hidden_states = hidden_states + attn_output
|
213 |
+
|
214 |
+
norm_hidden_states = self.norm2(hidden_states)
|
215 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
216 |
+
|
217 |
+
ff_output = self.ff(norm_hidden_states)
|
218 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
219 |
+
|
220 |
+
hidden_states = hidden_states + ff_output
|
221 |
+
|
222 |
+
# Process attention outputs for the `encoder_hidden_states`.
|
223 |
+
|
224 |
+
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
|
225 |
+
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
226 |
+
|
227 |
+
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
228 |
+
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
229 |
+
|
230 |
+
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
231 |
+
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
232 |
+
if encoder_hidden_states.dtype == torch.float16:
|
233 |
+
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
234 |
+
|
235 |
+
return encoder_hidden_states, hidden_states
|
236 |
+
|
237 |
+
|
238 |
+
class AdditFluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
239 |
+
"""
|
240 |
+
The Transformer model introduced in Flux.
|
241 |
+
|
242 |
+
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
243 |
+
|
244 |
+
Parameters:
|
245 |
+
patch_size (`int`): Patch size to turn the input data into small patches.
|
246 |
+
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
|
247 |
+
num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
|
248 |
+
num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
|
249 |
+
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
250 |
+
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
|
251 |
+
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
252 |
+
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
|
253 |
+
guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
|
254 |
+
"""
|
255 |
+
|
256 |
+
_supports_gradient_checkpointing = True
|
257 |
+
|
258 |
+
@register_to_config
|
259 |
+
def __init__(
|
260 |
+
self,
|
261 |
+
patch_size: int = 1,
|
262 |
+
in_channels: int = 64,
|
263 |
+
num_layers: int = 19,
|
264 |
+
num_single_layers: int = 38,
|
265 |
+
attention_head_dim: int = 128,
|
266 |
+
num_attention_heads: int = 24,
|
267 |
+
joint_attention_dim: int = 4096,
|
268 |
+
pooled_projection_dim: int = 768,
|
269 |
+
guidance_embeds: bool = False,
|
270 |
+
axes_dims_rope: List[int] = [16, 56, 56],
|
271 |
+
):
|
272 |
+
super().__init__()
|
273 |
+
self.out_channels = in_channels
|
274 |
+
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
275 |
+
|
276 |
+
self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope)
|
277 |
+
text_time_guidance_cls = (
|
278 |
+
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
|
279 |
+
)
|
280 |
+
self.time_text_embed = text_time_guidance_cls(
|
281 |
+
embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
|
282 |
+
)
|
283 |
+
|
284 |
+
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
|
285 |
+
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
|
286 |
+
|
287 |
+
self.transformer_blocks = nn.ModuleList(
|
288 |
+
[
|
289 |
+
AdditFluxTransformerBlock(
|
290 |
+
dim=self.inner_dim,
|
291 |
+
num_attention_heads=self.config.num_attention_heads,
|
292 |
+
attention_head_dim=self.config.attention_head_dim,
|
293 |
+
)
|
294 |
+
for i in range(self.config.num_layers)
|
295 |
+
]
|
296 |
+
)
|
297 |
+
|
298 |
+
self.single_transformer_blocks = nn.ModuleList(
|
299 |
+
[
|
300 |
+
AdditFluxSingleTransformerBlock(
|
301 |
+
dim=self.inner_dim,
|
302 |
+
num_attention_heads=self.config.num_attention_heads,
|
303 |
+
attention_head_dim=self.config.attention_head_dim,
|
304 |
+
)
|
305 |
+
for i in range(self.config.num_single_layers)
|
306 |
+
]
|
307 |
+
)
|
308 |
+
|
309 |
+
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
|
310 |
+
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
311 |
+
|
312 |
+
self.gradient_checkpointing = False
|
313 |
+
|
314 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
315 |
+
if hasattr(module, "gradient_checkpointing"):
|
316 |
+
module.gradient_checkpointing = value
|
317 |
+
|
318 |
+
@property
|
319 |
+
def attn_processors(self):
|
320 |
+
r"""
|
321 |
+
Returns:
|
322 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
323 |
+
indexed by its weight name.
|
324 |
+
"""
|
325 |
+
# set recursively
|
326 |
+
processors = {}
|
327 |
+
|
328 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
|
329 |
+
if hasattr(module, "get_processor"):
|
330 |
+
processors[f"{name}.processor"] = module.get_processor()
|
331 |
+
|
332 |
+
for sub_name, child in module.named_children():
|
333 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
334 |
+
|
335 |
+
return processors
|
336 |
+
|
337 |
+
for name, module in self.named_children():
|
338 |
+
fn_recursive_add_processors(name, module, processors)
|
339 |
+
|
340 |
+
return processors
|
341 |
+
|
342 |
+
def set_attn_processor(
|
343 |
+
self, processor
|
344 |
+
):
|
345 |
+
r"""
|
346 |
+
Sets the attention processor to use to compute attention.
|
347 |
+
|
348 |
+
Parameters:
|
349 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
350 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
351 |
+
for **all** `Attention` layers.
|
352 |
+
|
353 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
354 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
355 |
+
"""
|
356 |
+
count = len(self.attn_processors.keys())
|
357 |
+
|
358 |
+
if isinstance(processor, dict) and len(processor) != count:
|
359 |
+
raise ValueError(
|
360 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
361 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
362 |
+
)
|
363 |
+
|
364 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
365 |
+
if hasattr(module, "set_processor"):
|
366 |
+
if not isinstance(processor, dict):
|
367 |
+
module.set_processor(processor)
|
368 |
+
else:
|
369 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
370 |
+
|
371 |
+
for sub_name, child in module.named_children():
|
372 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
373 |
+
|
374 |
+
for name, module in self.named_children():
|
375 |
+
fn_recursive_attn_processor(name, module, processor)
|
376 |
+
|
377 |
+
def forward(
|
378 |
+
self,
|
379 |
+
hidden_states: torch.Tensor,
|
380 |
+
encoder_hidden_states: torch.Tensor = None,
|
381 |
+
pooled_projections: torch.Tensor = None,
|
382 |
+
timestep: torch.LongTensor = None,
|
383 |
+
img_ids: torch.Tensor = None,
|
384 |
+
txt_ids: torch.Tensor = None,
|
385 |
+
guidance: torch.Tensor = None,
|
386 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
387 |
+
return_dict: bool = True,
|
388 |
+
proccesor_kwargs: Optional[Dict[str, Any]] = None,
|
389 |
+
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
390 |
+
"""
|
391 |
+
The [`FluxTransformer2DModel`] forward method.
|
392 |
+
|
393 |
+
Args:
|
394 |
+
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
395 |
+
Input `hidden_states`.
|
396 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
397 |
+
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
398 |
+
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
399 |
+
from the embeddings of input conditions.
|
400 |
+
timestep ( `torch.LongTensor`):
|
401 |
+
Used to indicate denoising step.
|
402 |
+
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
403 |
+
A list of tensors that if specified are added to the residuals of transformer blocks.
|
404 |
+
joint_attention_kwargs (`dict`, *optional*):
|
405 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
406 |
+
`self.processor` in
|
407 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
408 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
409 |
+
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
410 |
+
tuple.
|
411 |
+
|
412 |
+
Returns:
|
413 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
414 |
+
`tuple` where the first element is the sample tensor.
|
415 |
+
"""
|
416 |
+
if joint_attention_kwargs is not None:
|
417 |
+
joint_attention_kwargs = joint_attention_kwargs.copy()
|
418 |
+
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
419 |
+
else:
|
420 |
+
lora_scale = 1.0
|
421 |
+
|
422 |
+
if USE_PEFT_BACKEND:
|
423 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
424 |
+
scale_lora_layers(self, lora_scale)
|
425 |
+
else:
|
426 |
+
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
427 |
+
logger.warning(
|
428 |
+
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
429 |
+
)
|
430 |
+
hidden_states = self.x_embedder(hidden_states)
|
431 |
+
|
432 |
+
timestep = timestep.to(hidden_states.dtype) * 1000
|
433 |
+
if guidance is not None:
|
434 |
+
guidance = guidance.to(hidden_states.dtype) * 1000
|
435 |
+
else:
|
436 |
+
guidance = None
|
437 |
+
temb = (
|
438 |
+
self.time_text_embed(timestep, pooled_projections)
|
439 |
+
if guidance is None
|
440 |
+
else self.time_text_embed(timestep, guidance, pooled_projections)
|
441 |
+
)
|
442 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
443 |
+
|
444 |
+
ids = torch.cat((txt_ids, img_ids), dim=1)
|
445 |
+
image_rotary_emb = self.pos_embed(ids)
|
446 |
+
|
447 |
+
for index_block, block in enumerate(self.transformer_blocks):
|
448 |
+
if self.training and self.gradient_checkpointing:
|
449 |
+
|
450 |
+
def create_custom_forward(module, return_dict=None):
|
451 |
+
def custom_forward(*inputs):
|
452 |
+
if return_dict is not None:
|
453 |
+
return module(*inputs, return_dict=return_dict)
|
454 |
+
else:
|
455 |
+
return module(*inputs)
|
456 |
+
|
457 |
+
return custom_forward
|
458 |
+
|
459 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
460 |
+
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
461 |
+
create_custom_forward(block),
|
462 |
+
hidden_states,
|
463 |
+
encoder_hidden_states,
|
464 |
+
temb,
|
465 |
+
image_rotary_emb,
|
466 |
+
**ckpt_kwargs,
|
467 |
+
)
|
468 |
+
|
469 |
+
else:
|
470 |
+
encoder_hidden_states, hidden_states = block(
|
471 |
+
hidden_states=hidden_states,
|
472 |
+
encoder_hidden_states=encoder_hidden_states,
|
473 |
+
temb=temb,
|
474 |
+
image_rotary_emb=image_rotary_emb,
|
475 |
+
proccesor_kwargs=proccesor_kwargs,
|
476 |
+
)
|
477 |
+
|
478 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
479 |
+
|
480 |
+
for index_block, block in enumerate(self.single_transformer_blocks):
|
481 |
+
if self.training and self.gradient_checkpointing:
|
482 |
+
|
483 |
+
def create_custom_forward(module, return_dict=None):
|
484 |
+
def custom_forward(*inputs):
|
485 |
+
if return_dict is not None:
|
486 |
+
return module(*inputs, return_dict=return_dict)
|
487 |
+
else:
|
488 |
+
return module(*inputs)
|
489 |
+
|
490 |
+
return custom_forward
|
491 |
+
|
492 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
493 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
494 |
+
create_custom_forward(block),
|
495 |
+
hidden_states,
|
496 |
+
temb,
|
497 |
+
image_rotary_emb,
|
498 |
+
**ckpt_kwargs,
|
499 |
+
)
|
500 |
+
|
501 |
+
else:
|
502 |
+
hidden_states = block(
|
503 |
+
hidden_states=hidden_states,
|
504 |
+
temb=temb,
|
505 |
+
image_rotary_emb=image_rotary_emb,
|
506 |
+
proccesor_kwargs=proccesor_kwargs,
|
507 |
+
)
|
508 |
+
|
509 |
+
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
510 |
+
|
511 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
512 |
+
output = self.proj_out(hidden_states)
|
513 |
+
|
514 |
+
if USE_PEFT_BACKEND:
|
515 |
+
# remove `lora_scale` from each PEFT layer
|
516 |
+
unscale_lora_layers(self, lora_scale)
|
517 |
+
|
518 |
+
if not return_dict:
|
519 |
+
return (output,)
|
520 |
+
|
521 |
+
return Transformer2DModelOutput(sample=output)
|
addit_methods.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2025 NVIDIA Corporation. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is licensed under the LICENSE file
|
4 |
+
# located at the root directory.
|
5 |
+
|
6 |
+
import gc
|
7 |
+
import torch
|
8 |
+
from visualization_utils import show_images
|
9 |
+
|
10 |
+
def _add_object(
|
11 |
+
pipe,
|
12 |
+
prompts,
|
13 |
+
seed_src,
|
14 |
+
seed_obj,
|
15 |
+
extended_scale,
|
16 |
+
source_latents,
|
17 |
+
structure_transfer_step,
|
18 |
+
subject_token,
|
19 |
+
blend_steps,
|
20 |
+
show_attention=False,
|
21 |
+
localization_model="attention_points_sam",
|
22 |
+
is_img_src=False,
|
23 |
+
img_src_latents=None,
|
24 |
+
use_offset=False,
|
25 |
+
display_output=False,
|
26 |
+
):
|
27 |
+
gc.collect()
|
28 |
+
torch.cuda.empty_cache()
|
29 |
+
|
30 |
+
out = pipe(
|
31 |
+
prompt=prompts,
|
32 |
+
guidance_scale=3.5 if (not is_img_src) else [1,3.5],
|
33 |
+
height=1024,
|
34 |
+
width=1024,
|
35 |
+
max_sequence_length=512,
|
36 |
+
num_inference_steps=30,
|
37 |
+
seed=[seed_src, seed_obj],
|
38 |
+
|
39 |
+
# Extended Attention
|
40 |
+
extended_scale=extended_scale,
|
41 |
+
extended_steps_multi=10,
|
42 |
+
extended_steps_single=20,
|
43 |
+
|
44 |
+
# Structure Transfer
|
45 |
+
source_latents=source_latents,
|
46 |
+
structure_transfer_step=structure_transfer_step,
|
47 |
+
|
48 |
+
# Latent Blending
|
49 |
+
subject_token=subject_token,
|
50 |
+
localization_model=localization_model,
|
51 |
+
blend_steps=blend_steps,
|
52 |
+
show_attention=show_attention,
|
53 |
+
|
54 |
+
# Real Image Source
|
55 |
+
is_img_src=is_img_src,
|
56 |
+
img_src_latents=img_src_latents,
|
57 |
+
use_offset=use_offset,
|
58 |
+
)
|
59 |
+
|
60 |
+
if display_output:
|
61 |
+
show_images(out.images)
|
62 |
+
|
63 |
+
return out.images
|
64 |
+
|
65 |
+
def add_object_generated(
|
66 |
+
pipe,
|
67 |
+
prompt_source,
|
68 |
+
prompt_object,
|
69 |
+
subject_token,
|
70 |
+
seed_src,
|
71 |
+
seed_obj,
|
72 |
+
show_attention=False,
|
73 |
+
extended_scale=1.05,
|
74 |
+
structure_transfer_step=2,
|
75 |
+
blend_steps=[15],
|
76 |
+
localization_model="attention_points_sam",
|
77 |
+
display_output=False
|
78 |
+
):
|
79 |
+
gc.collect()
|
80 |
+
torch.cuda.empty_cache()
|
81 |
+
|
82 |
+
# Generate source image and latents for each seed1
|
83 |
+
print('Generating source image...')
|
84 |
+
source_image, source_latents = pipe(
|
85 |
+
prompt=[prompt_source],
|
86 |
+
guidance_scale=3.5,
|
87 |
+
height=1024,
|
88 |
+
width=1024,
|
89 |
+
max_sequence_length=512,
|
90 |
+
num_inference_steps=30,
|
91 |
+
seed=[seed_src],
|
92 |
+
output_type="both",
|
93 |
+
)
|
94 |
+
source_image = source_image[0]
|
95 |
+
|
96 |
+
# Run the core combination logic
|
97 |
+
print('Running Addit...')
|
98 |
+
src_image, edited_image = _add_object(
|
99 |
+
pipe=pipe,
|
100 |
+
prompts=[prompt_source, prompt_object],
|
101 |
+
subject_token=subject_token,
|
102 |
+
seed_src=seed_src,
|
103 |
+
seed_obj=seed_obj,
|
104 |
+
source_latents=source_latents,
|
105 |
+
structure_transfer_step=structure_transfer_step,
|
106 |
+
extended_scale=extended_scale,
|
107 |
+
blend_steps=blend_steps,
|
108 |
+
show_attention=show_attention,
|
109 |
+
localization_model=localization_model,
|
110 |
+
display_output=display_output
|
111 |
+
)
|
112 |
+
|
113 |
+
return src_image, edited_image
|
114 |
+
|
115 |
+
def add_object_real(
|
116 |
+
pipe,
|
117 |
+
source_image,
|
118 |
+
prompt_source,
|
119 |
+
prompt_object,
|
120 |
+
subject_token,
|
121 |
+
seed_src,
|
122 |
+
seed_obj,
|
123 |
+
localization_model="attention_points_sam",
|
124 |
+
extended_scale=1.05,
|
125 |
+
structure_transfer_step=4,
|
126 |
+
blend_steps=[20],
|
127 |
+
use_offset=False,
|
128 |
+
show_attention=False,
|
129 |
+
use_inversion=False,
|
130 |
+
display_output=False
|
131 |
+
):
|
132 |
+
print('Noising-Denoising Original Image')
|
133 |
+
gc.collect()
|
134 |
+
torch.cuda.empty_cache()
|
135 |
+
|
136 |
+
# Get initial latents
|
137 |
+
source_latents = pipe.call_img2img(
|
138 |
+
prompt=prompt_source,
|
139 |
+
image=source_image,
|
140 |
+
num_inference_steps=30,
|
141 |
+
strength=0.1,
|
142 |
+
guidance_scale=3.5,
|
143 |
+
output_type="latent",
|
144 |
+
generator=torch.Generator(device=pipe.device).manual_seed(0)
|
145 |
+
).images
|
146 |
+
|
147 |
+
# Optional inversion step
|
148 |
+
img_src_latents = None
|
149 |
+
if use_inversion:
|
150 |
+
print('Inverting Image')
|
151 |
+
gc.collect()
|
152 |
+
torch.cuda.empty_cache()
|
153 |
+
|
154 |
+
latents_list = pipe.call_invert(
|
155 |
+
prompt=prompt_source,
|
156 |
+
image=source_latents,
|
157 |
+
num_inference_steps=30,
|
158 |
+
guidance_scale=1,
|
159 |
+
fixed_point_iterations=2,
|
160 |
+
generator=torch.Generator(device=pipe.device).manual_seed(0)
|
161 |
+
)
|
162 |
+
img_src_latents = [x[0] for x in latents_list][::-1]
|
163 |
+
|
164 |
+
print('Running Addit')
|
165 |
+
gc.collect()
|
166 |
+
torch.cuda.empty_cache()
|
167 |
+
|
168 |
+
src_image, edited_image = _add_object(
|
169 |
+
pipe,
|
170 |
+
prompts=[prompt_source, prompt_object],
|
171 |
+
seed_src=seed_src,
|
172 |
+
seed_obj=seed_obj,
|
173 |
+
extended_scale=extended_scale,
|
174 |
+
source_latents=source_latents,
|
175 |
+
structure_transfer_step=structure_transfer_step,
|
176 |
+
subject_token=subject_token,
|
177 |
+
blend_steps=blend_steps,
|
178 |
+
show_attention=show_attention,
|
179 |
+
localization_model=localization_model,
|
180 |
+
is_img_src=True,
|
181 |
+
img_src_latents=img_src_latents,
|
182 |
+
use_offset=use_offset,
|
183 |
+
display_output=display_output,
|
184 |
+
)
|
185 |
+
|
186 |
+
return src_image, edited_image
|
addit_scheduler.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler, FlowMatchEulerDiscreteSchedulerOutput
|
16 |
+
from typing import Union, Optional, Tuple
|
17 |
+
import torch
|
18 |
+
|
19 |
+
class AdditFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
|
20 |
+
def step(
|
21 |
+
self,
|
22 |
+
model_output: torch.FloatTensor,
|
23 |
+
timestep: Union[float, torch.FloatTensor],
|
24 |
+
sample: torch.FloatTensor,
|
25 |
+
s_churn: float = 0.0,
|
26 |
+
s_tmin: float = 0.0,
|
27 |
+
s_tmax: float = float("inf"),
|
28 |
+
s_noise: float = 1.0,
|
29 |
+
generator: Optional[torch.Generator] = None,
|
30 |
+
return_dict: bool = True,
|
31 |
+
step_index: Optional[int] = None,
|
32 |
+
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
|
33 |
+
"""
|
34 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
35 |
+
process from the learned model outputs (most often the predicted noise).
|
36 |
+
|
37 |
+
Args:
|
38 |
+
model_output (`torch.FloatTensor`):
|
39 |
+
The direct output from learned diffusion model.
|
40 |
+
timestep (`float`):
|
41 |
+
The current discrete timestep in the diffusion chain.
|
42 |
+
sample (`torch.FloatTensor`):
|
43 |
+
A current instance of a sample created by the diffusion process.
|
44 |
+
s_churn (`float`):
|
45 |
+
s_tmin (`float`):
|
46 |
+
s_tmax (`float`):
|
47 |
+
s_noise (`float`, defaults to 1.0):
|
48 |
+
Scaling factor for noise added to the sample.
|
49 |
+
generator (`torch.Generator`, *optional*):
|
50 |
+
A random number generator.
|
51 |
+
return_dict (`bool`):
|
52 |
+
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
|
53 |
+
tuple.
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
[`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
|
57 |
+
If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
|
58 |
+
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
59 |
+
"""
|
60 |
+
|
61 |
+
if (
|
62 |
+
isinstance(timestep, int)
|
63 |
+
or isinstance(timestep, torch.IntTensor)
|
64 |
+
or isinstance(timestep, torch.LongTensor)
|
65 |
+
):
|
66 |
+
raise ValueError(
|
67 |
+
(
|
68 |
+
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
69 |
+
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
70 |
+
" one of the `scheduler.timesteps` as a timestep."
|
71 |
+
),
|
72 |
+
)
|
73 |
+
|
74 |
+
if step_index is not None:
|
75 |
+
self._step_index = step_index
|
76 |
+
|
77 |
+
if self.step_index is None:
|
78 |
+
self._init_step_index(timestep)
|
79 |
+
|
80 |
+
# Upcast to avoid precision issues when computing prev_sample
|
81 |
+
sample = sample.to(torch.float32)
|
82 |
+
|
83 |
+
sigma = self.sigmas[self.step_index]
|
84 |
+
sigma_next = self.sigmas[self.step_index + 1]
|
85 |
+
|
86 |
+
prev_sample = sample + (sigma_next - sigma) * model_output
|
87 |
+
|
88 |
+
# Calculate X_0
|
89 |
+
x_0 = sample - sigma * model_output
|
90 |
+
|
91 |
+
# Cast sample back to model compatible dtype
|
92 |
+
prev_sample = prev_sample.to(model_output.dtype)
|
93 |
+
x_0 = x_0.to(model_output.dtype)
|
94 |
+
|
95 |
+
# upon completion increase step index by one
|
96 |
+
self._step_index += 1
|
97 |
+
|
98 |
+
if not return_dict:
|
99 |
+
return (prev_sample, x_0)
|
100 |
+
|
101 |
+
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
|
app.py
ADDED
@@ -0,0 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright (C) 2025 NVIDIA Corporation. All rights reserved.
|
3 |
+
#
|
4 |
+
# This work is licensed under the LICENSE file
|
5 |
+
# located at the root directory.
|
6 |
+
|
7 |
+
import os
|
8 |
+
import gradio as gr
|
9 |
+
import spaces
|
10 |
+
import torch
|
11 |
+
import numpy as np
|
12 |
+
from PIL import Image
|
13 |
+
import tempfile
|
14 |
+
import gc
|
15 |
+
|
16 |
+
from addit_flux_pipeline import AdditFluxPipeline
|
17 |
+
from addit_flux_transformer import AdditFluxTransformer2DModel
|
18 |
+
from addit_scheduler import AdditFlowMatchEulerDiscreteScheduler
|
19 |
+
from addit_methods import add_object_generated, add_object_real
|
20 |
+
|
21 |
+
# Global variables for model
|
22 |
+
pipe = None
|
23 |
+
device = None
|
24 |
+
|
25 |
+
# Initialize model at startup
|
26 |
+
print("Initializing ADDIT model...")
|
27 |
+
try:
|
28 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
29 |
+
print(f"Using device: {device}")
|
30 |
+
|
31 |
+
# Load transformer
|
32 |
+
my_transformer = AdditFluxTransformer2DModel.from_pretrained(
|
33 |
+
"black-forest-labs/FLUX.1-dev",
|
34 |
+
subfolder="transformer",
|
35 |
+
torch_dtype=torch.bfloat16
|
36 |
+
)
|
37 |
+
|
38 |
+
# Load pipeline
|
39 |
+
pipe = AdditFluxPipeline.from_pretrained(
|
40 |
+
"black-forest-labs/FLUX.1-dev",
|
41 |
+
transformer=my_transformer,
|
42 |
+
torch_dtype=torch.bfloat16
|
43 |
+
).to(device)
|
44 |
+
|
45 |
+
# Set scheduler
|
46 |
+
pipe.scheduler = AdditFlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)
|
47 |
+
|
48 |
+
print("Model initialized successfully!")
|
49 |
+
|
50 |
+
except Exception as e:
|
51 |
+
print(f"Error initializing model: {str(e)}")
|
52 |
+
print("The application will start but model functionality will be unavailable.")
|
53 |
+
|
54 |
+
def validate_inputs(prompt_source, prompt_target, subject_token):
|
55 |
+
"""Validate user inputs"""
|
56 |
+
if not prompt_source.strip():
|
57 |
+
return "Source prompt cannot be empty"
|
58 |
+
if not prompt_target.strip():
|
59 |
+
return "Target prompt cannot be empty"
|
60 |
+
if not subject_token.strip():
|
61 |
+
return "Subject token cannot be empty"
|
62 |
+
if subject_token not in prompt_target:
|
63 |
+
return f"Subject token '{subject_token}' must appear in the target prompt"
|
64 |
+
return None
|
65 |
+
|
66 |
+
@spaces.GPU
|
67 |
+
def process_generated_image(
|
68 |
+
prompt_source,
|
69 |
+
prompt_target,
|
70 |
+
subject_token,
|
71 |
+
seed_src,
|
72 |
+
seed_obj,
|
73 |
+
extended_scale,
|
74 |
+
structure_transfer_step,
|
75 |
+
blend_steps,
|
76 |
+
localization_model,
|
77 |
+
progress=gr.Progress(track_tqdm=True)
|
78 |
+
):
|
79 |
+
"""Process generated image with ADDIT"""
|
80 |
+
global pipe
|
81 |
+
|
82 |
+
if pipe is None:
|
83 |
+
return None, None, "Model not initialized. Please restart the application."
|
84 |
+
|
85 |
+
# Validate inputs
|
86 |
+
error_msg = validate_inputs(prompt_source, prompt_target, subject_token)
|
87 |
+
if error_msg:
|
88 |
+
return None, None, error_msg
|
89 |
+
|
90 |
+
try:
|
91 |
+
# Parse blend steps
|
92 |
+
if blend_steps.strip():
|
93 |
+
blend_steps_list = [int(x.strip()) for x in blend_steps.split(',') if x.strip()]
|
94 |
+
else:
|
95 |
+
blend_steps_list = []
|
96 |
+
|
97 |
+
# Generate images
|
98 |
+
src_image, edited_image = add_object_generated(
|
99 |
+
pipe=pipe,
|
100 |
+
prompt_source=prompt_source,
|
101 |
+
prompt_object=prompt_target,
|
102 |
+
subject_token=subject_token,
|
103 |
+
seed_src=seed_src,
|
104 |
+
seed_obj=seed_obj,
|
105 |
+
show_attention=False,
|
106 |
+
extended_scale=extended_scale,
|
107 |
+
structure_transfer_step=structure_transfer_step,
|
108 |
+
blend_steps=blend_steps_list,
|
109 |
+
localization_model=localization_model,
|
110 |
+
display_output=False
|
111 |
+
)
|
112 |
+
|
113 |
+
return src_image, edited_image, "Images generated successfully!"
|
114 |
+
|
115 |
+
except Exception as e:
|
116 |
+
error_msg = f"Error generating images: {str(e)}"
|
117 |
+
print(error_msg)
|
118 |
+
return None, None, error_msg
|
119 |
+
|
120 |
+
@spaces.GPU
|
121 |
+
def process_real_image(
|
122 |
+
source_image,
|
123 |
+
prompt_source,
|
124 |
+
prompt_target,
|
125 |
+
subject_token,
|
126 |
+
seed_src,
|
127 |
+
seed_obj,
|
128 |
+
extended_scale,
|
129 |
+
structure_transfer_step,
|
130 |
+
blend_steps,
|
131 |
+
localization_model,
|
132 |
+
use_offset,
|
133 |
+
disable_inversion,
|
134 |
+
progress=gr.Progress(track_tqdm=True)
|
135 |
+
):
|
136 |
+
"""Process real image with ADDIT"""
|
137 |
+
global pipe
|
138 |
+
|
139 |
+
if pipe is None:
|
140 |
+
return None, None, "Model not initialized. Please restart the application."
|
141 |
+
|
142 |
+
if source_image is None:
|
143 |
+
return None, None, "Please upload a source image"
|
144 |
+
|
145 |
+
# Validate inputs
|
146 |
+
error_msg = validate_inputs(prompt_source, prompt_target, subject_token)
|
147 |
+
if error_msg:
|
148 |
+
return None, None, error_msg
|
149 |
+
|
150 |
+
try:
|
151 |
+
# Resize source image
|
152 |
+
source_image = source_image.resize((1024, 1024))
|
153 |
+
|
154 |
+
# Parse blend steps
|
155 |
+
if blend_steps.strip():
|
156 |
+
blend_steps_list = [int(x.strip()) for x in blend_steps.split(',') if x.strip()]
|
157 |
+
else:
|
158 |
+
blend_steps_list = []
|
159 |
+
|
160 |
+
# Process image
|
161 |
+
src_image, edited_image = add_object_real(
|
162 |
+
pipe=pipe,
|
163 |
+
source_image=source_image,
|
164 |
+
prompt_source=prompt_source,
|
165 |
+
prompt_object=prompt_target,
|
166 |
+
subject_token=subject_token,
|
167 |
+
seed_src=seed_src,
|
168 |
+
seed_obj=seed_obj,
|
169 |
+
extended_scale=extended_scale,
|
170 |
+
structure_transfer_step=structure_transfer_step,
|
171 |
+
blend_steps=blend_steps_list,
|
172 |
+
localization_model=localization_model,
|
173 |
+
use_offset=use_offset,
|
174 |
+
show_attention=False,
|
175 |
+
use_inversion=not disable_inversion,
|
176 |
+
display_output=False
|
177 |
+
)
|
178 |
+
|
179 |
+
return src_image, edited_image, "Image edited successfully!"
|
180 |
+
|
181 |
+
except Exception as e:
|
182 |
+
error_msg = f"Error processing image: {str(e)}"
|
183 |
+
print(error_msg)
|
184 |
+
return None, None, error_msg
|
185 |
+
|
186 |
+
def create_interface():
|
187 |
+
"""Create the Gradio interface"""
|
188 |
+
|
189 |
+
# Show model status in the interface
|
190 |
+
model_status = "Model ready!" if pipe is not None else "Model initialization failed - functionality unavailable"
|
191 |
+
|
192 |
+
with gr.Blocks(title="🎨 Add-it: Training-Free Object Insertion in Images With Pretrained Diffusion Models", theme=gr.themes.Soft()) as demo:
|
193 |
+
gr.HTML(f"""
|
194 |
+
<div style="text-align: center; margin-bottom: 20px;">
|
195 |
+
<h1>🎨 Add-it: Training-Free Object Insertion</h1>
|
196 |
+
<p>Add objects to images using pretrained diffusion models</p>
|
197 |
+
<p><a href="https://research.nvidia.com/labs/par/addit/" target="_blank">🌐 Project Website</a> |
|
198 |
+
<a href="https://arxiv.org/abs/2411.07232" target="_blank">📄 Paper</a> |
|
199 |
+
<a href="https://github.com/NVlabs/addit" target="_blank">💻 Code</a></p>
|
200 |
+
<p style="color: {'green' if pipe is not None else 'red'}; font-weight: bold;">Status: {model_status}</p>
|
201 |
+
</div>
|
202 |
+
""")
|
203 |
+
|
204 |
+
# Main interface
|
205 |
+
with gr.Tabs():
|
206 |
+
# Generated Images Tab
|
207 |
+
with gr.TabItem("🎭 Generated Images"):
|
208 |
+
gr.Markdown("### Generate a base image and add objects to it")
|
209 |
+
|
210 |
+
with gr.Row():
|
211 |
+
with gr.Column(scale=1):
|
212 |
+
gen_prompt_source = gr.Textbox(
|
213 |
+
label="Source Prompt",
|
214 |
+
placeholder="A photo of a cat sitting on the couch",
|
215 |
+
value="A photo of a cat sitting on the couch"
|
216 |
+
)
|
217 |
+
gen_prompt_target = gr.Textbox(
|
218 |
+
label="Target Prompt",
|
219 |
+
placeholder="A photo of a cat wearing a red hat sitting on the couch",
|
220 |
+
value="A photo of a cat wearing a red hat sitting on the couch"
|
221 |
+
)
|
222 |
+
gen_subject_token = gr.Textbox(
|
223 |
+
label="Subject Token",
|
224 |
+
placeholder="hat",
|
225 |
+
value="hat",
|
226 |
+
info="Single token representing the object to add **(must appear in target prompt)**"
|
227 |
+
)
|
228 |
+
|
229 |
+
with gr.Accordion("Advanced Settings", open=False):
|
230 |
+
gen_seed_src = gr.Number(label="Source Seed", value=6311, precision=0)
|
231 |
+
gen_seed_obj = gr.Number(label="Object Seed", value=1, precision=0)
|
232 |
+
gen_extended_scale = gr.Slider(
|
233 |
+
label="Extended Scale",
|
234 |
+
minimum=1.0,
|
235 |
+
maximum=1.3,
|
236 |
+
value=1.05,
|
237 |
+
step=0.01
|
238 |
+
)
|
239 |
+
gen_structure_transfer_step = gr.Slider(
|
240 |
+
label="Structure Transfer Step",
|
241 |
+
minimum=0,
|
242 |
+
maximum=10,
|
243 |
+
value=2,
|
244 |
+
step=1
|
245 |
+
)
|
246 |
+
gen_blend_steps = gr.Textbox(
|
247 |
+
label="Blend Steps",
|
248 |
+
value="15",
|
249 |
+
info="Comma-separated list of steps (e.g., '15,20') or empty for no blending"
|
250 |
+
)
|
251 |
+
gen_localization_model = gr.Dropdown(
|
252 |
+
label="Localization Model",
|
253 |
+
choices=[
|
254 |
+
"attention_points_sam",
|
255 |
+
"attention",
|
256 |
+
"attention_box_sam",
|
257 |
+
"attention_mask_sam",
|
258 |
+
"grounding_sam"
|
259 |
+
],
|
260 |
+
value="attention_points_sam"
|
261 |
+
)
|
262 |
+
|
263 |
+
gen_submit_btn = gr.Button("🎨 Generate & Edit", variant="primary")
|
264 |
+
|
265 |
+
with gr.Column(scale=2):
|
266 |
+
with gr.Row():
|
267 |
+
gen_src_output = gr.Image(label="Generated Source Image", type="pil")
|
268 |
+
gen_edited_output = gr.Image(label="Edited Image", type="pil")
|
269 |
+
gen_status = gr.Textbox(label="Status", interactive=False)
|
270 |
+
|
271 |
+
gen_submit_btn.click(
|
272 |
+
fn=process_generated_image,
|
273 |
+
inputs=[
|
274 |
+
gen_prompt_source, gen_prompt_target, gen_subject_token,
|
275 |
+
gen_seed_src, gen_seed_obj, gen_extended_scale,
|
276 |
+
gen_structure_transfer_step, gen_blend_steps,
|
277 |
+
gen_localization_model
|
278 |
+
],
|
279 |
+
outputs=[gen_src_output, gen_edited_output, gen_status]
|
280 |
+
)
|
281 |
+
|
282 |
+
# Examples for generated images
|
283 |
+
gr.Examples(
|
284 |
+
examples=[
|
285 |
+
["A photo of a man sitting on a bench", "A photo of a man sitting on a bench with a dog", "dog"],
|
286 |
+
["A photo of a cat sitting on the couch", "A photo of a cat wearing a red hat sitting on the couch", "hat"],
|
287 |
+
["A car driving through an empty street", "A pink car driving through an empty street", "car"]
|
288 |
+
],
|
289 |
+
inputs=[
|
290 |
+
gen_prompt_source, gen_prompt_target, gen_subject_token
|
291 |
+
],
|
292 |
+
label="Example Prompts"
|
293 |
+
)
|
294 |
+
|
295 |
+
# Real Images Tab
|
296 |
+
with gr.TabItem("📸 Real Images"):
|
297 |
+
gr.Markdown("### Upload an image and add objects to it")
|
298 |
+
|
299 |
+
with gr.Row():
|
300 |
+
with gr.Column(scale=1):
|
301 |
+
real_source_image = gr.Image(label="Source Image", type="pil")
|
302 |
+
real_prompt_source = gr.Textbox(
|
303 |
+
label="Source Prompt",
|
304 |
+
placeholder="A photo of a bed in a dark room",
|
305 |
+
value="A photo of a bed in a dark room"
|
306 |
+
)
|
307 |
+
real_prompt_target = gr.Textbox(
|
308 |
+
label="Target Prompt",
|
309 |
+
placeholder="A photo of a dog lying on a bed in a dark room",
|
310 |
+
value="A photo of a dog lying on a bed in a dark room"
|
311 |
+
)
|
312 |
+
real_subject_token = gr.Textbox(
|
313 |
+
label="Subject Token",
|
314 |
+
placeholder="dog",
|
315 |
+
value="dog",
|
316 |
+
info="Single token representing the object to add **(must appear in target prompt)**"
|
317 |
+
)
|
318 |
+
|
319 |
+
with gr.Accordion("Advanced Settings", open=False):
|
320 |
+
real_seed_src = gr.Number(label="Source Seed", value=6311, precision=0)
|
321 |
+
real_seed_obj = gr.Number(label="Object Seed", value=1, precision=0)
|
322 |
+
real_extended_scale = gr.Slider(
|
323 |
+
label="Extended Scale",
|
324 |
+
minimum=1.0,
|
325 |
+
maximum=1.3,
|
326 |
+
value=1.1,
|
327 |
+
step=0.01
|
328 |
+
)
|
329 |
+
real_structure_transfer_step = gr.Slider(
|
330 |
+
label="Structure Transfer Step",
|
331 |
+
minimum=0,
|
332 |
+
maximum=10,
|
333 |
+
value=4,
|
334 |
+
step=1
|
335 |
+
)
|
336 |
+
real_blend_steps = gr.Textbox(
|
337 |
+
label="Blend Steps",
|
338 |
+
value="18",
|
339 |
+
info="Comma-separated list of steps (e.g., '15,20') or empty for no blending"
|
340 |
+
)
|
341 |
+
real_localization_model = gr.Dropdown(
|
342 |
+
label="Localization Model",
|
343 |
+
choices=[
|
344 |
+
"attention",
|
345 |
+
"attention_points_sam",
|
346 |
+
"attention_box_sam",
|
347 |
+
"attention_mask_sam",
|
348 |
+
"grounding_sam"
|
349 |
+
],
|
350 |
+
value="attention"
|
351 |
+
)
|
352 |
+
real_use_offset = gr.Checkbox(label="Use Offset", value=False)
|
353 |
+
real_disable_inversion = gr.Checkbox(label="Disable Inversion", value=False)
|
354 |
+
|
355 |
+
real_submit_btn = gr.Button("🎨 Edit Image", variant="primary")
|
356 |
+
|
357 |
+
with gr.Column(scale=2):
|
358 |
+
with gr.Row():
|
359 |
+
real_src_output = gr.Image(label="Source Image", type="pil")
|
360 |
+
real_edited_output = gr.Image(label="Edited Image", type="pil")
|
361 |
+
real_status = gr.Textbox(label="Status", interactive=False)
|
362 |
+
|
363 |
+
real_submit_btn.click(
|
364 |
+
fn=process_real_image,
|
365 |
+
inputs=[
|
366 |
+
real_source_image, real_prompt_source, real_prompt_target, real_subject_token,
|
367 |
+
real_seed_src, real_seed_obj, real_extended_scale,
|
368 |
+
real_structure_transfer_step, real_blend_steps,
|
369 |
+
real_localization_model, real_use_offset,
|
370 |
+
real_disable_inversion
|
371 |
+
],
|
372 |
+
outputs=[real_src_output, real_edited_output, real_status]
|
373 |
+
)
|
374 |
+
|
375 |
+
# Examples for real images
|
376 |
+
gr.Examples(
|
377 |
+
examples=[
|
378 |
+
[
|
379 |
+
"images/bed_dark_room.jpg",
|
380 |
+
"A photo of a bed in a dark room",
|
381 |
+
"A photo of a dog lying on a bed in a dark room",
|
382 |
+
"dog"
|
383 |
+
],
|
384 |
+
[
|
385 |
+
"images/flower.jpg",
|
386 |
+
"A photo of a flower",
|
387 |
+
"A bee standing on a flower",
|
388 |
+
"bee"
|
389 |
+
]
|
390 |
+
],
|
391 |
+
inputs=[
|
392 |
+
real_source_image, real_prompt_source, real_prompt_target, real_subject_token
|
393 |
+
],
|
394 |
+
label="Example Images & Prompts"
|
395 |
+
)
|
396 |
+
|
397 |
+
# Tips
|
398 |
+
with gr.Accordion("💡 Tips for Better Results", open=False):
|
399 |
+
gr.Markdown("""
|
400 |
+
- **Prompt Design**: The Target Prompt should be similar to the Source Prompt, but include a description of the new object to insert
|
401 |
+
- **Seed Variation**: Try different values for Object Seed - some prompts may require a few attempts to get satisfying results
|
402 |
+
- **Localization Models**: The most effective options are `attention_points_sam` and `attention`. Use Show Attention to visualize localization performance
|
403 |
+
- **Object Placement Issues**: If the object is not added to the image:
|
404 |
+
- Try **decreasing** Structure Transfer Step
|
405 |
+
- Try **increasing** Extended Scale
|
406 |
+
- **Flexibility**: To allow more flexibility in modifying the source image, leave Blend Steps empty to send an empty list
|
407 |
+
""")
|
408 |
+
|
409 |
+
return demo
|
410 |
+
|
411 |
+
demo = create_interface()
|
412 |
+
demo.launch(
|
413 |
+
server_name="0.0.0.0",
|
414 |
+
server_port=7860,
|
415 |
+
share=True
|
416 |
+
)
|
images/bed_dark_room.jpg
ADDED
![]() |
Git LFS Details
|
images/flower.jpg
ADDED
![]() |
Git LFS Details
|
requirements.txt
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.5.1
|
2 |
+
torchvision==0.20.1
|
3 |
+
numpy==1.26.4
|
4 |
+
scipy==1.14.1
|
5 |
+
scikit-image==0.24.0
|
6 |
+
pandas==2.2.2
|
7 |
+
matplotlib
|
8 |
+
transformers==4.44.0
|
9 |
+
accelerate==0.33.0
|
10 |
+
diffusers @ git+https://github.com/huggingface/diffusers.git@15eb77bc4cf2ccb40781cb630b9a734b43cffcb8
|
11 |
+
opencv-python
|
12 |
+
pyarrow
|
13 |
+
fastparquet
|
14 |
+
ipykernel
|
15 |
+
sentencepiece==0.2.0
|
16 |
+
protobuf==5.27.3
|
17 |
+
python-dotenv
|
18 |
+
git+https://github.com/facebookresearch/sam2.git
|
19 |
+
gradio
|
20 |
+
spaces
|
run_CLI_addit_generated.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright (C) 2025 NVIDIA Corporation. All rights reserved.
|
3 |
+
#
|
4 |
+
# This work is licensed under the LICENSE file
|
5 |
+
# located at the root directory.
|
6 |
+
|
7 |
+
import os
|
8 |
+
import argparse
|
9 |
+
import torch
|
10 |
+
import random
|
11 |
+
|
12 |
+
from visualization_utils import show_images
|
13 |
+
from addit_flux_pipeline import AdditFluxPipeline
|
14 |
+
from addit_flux_transformer import AdditFluxTransformer2DModel
|
15 |
+
from addit_scheduler import AdditFlowMatchEulerDiscreteScheduler
|
16 |
+
from addit_methods import add_object_generated
|
17 |
+
|
18 |
+
def main():
|
19 |
+
parser = argparse.ArgumentParser(description='Run ADDIT with generated images')
|
20 |
+
|
21 |
+
# Required arguments
|
22 |
+
parser.add_argument('--prompt_source', type=str, default="A photo of a cat sitting on the couch",
|
23 |
+
help='Source prompt for generating the base image')
|
24 |
+
parser.add_argument('--prompt_target', type=str, default="A photo of a cat wearing a red hat sitting on the couch",
|
25 |
+
help='Target prompt describing the desired edited image')
|
26 |
+
parser.add_argument('--subject_token', type=str, default="hat",
|
27 |
+
help='Single token representing the subject to add to the image, must appear in the prompt_target')
|
28 |
+
|
29 |
+
# Optional arguments
|
30 |
+
parser.add_argument('--output_dir', type=str, default='outputs',
|
31 |
+
help='Directory to save output images (default: outputs)')
|
32 |
+
parser.add_argument('--seed_src', type=int, default=6311,
|
33 |
+
help='Seed for source generation')
|
34 |
+
parser.add_argument('--seed_obj', type=int, default=1,
|
35 |
+
help='Seed for edited image generation')
|
36 |
+
parser.add_argument('--extended_scale', type=float, default=1.05,
|
37 |
+
help='Extended attention scale (default: 1.05)')
|
38 |
+
parser.add_argument('--structure_transfer_step', type=int, default=2,
|
39 |
+
help='Structure transfer step (default: 2)')
|
40 |
+
parser.add_argument('--blend_steps', type=int, nargs='*', default=[15],
|
41 |
+
help='Blend steps (default: [15])')
|
42 |
+
parser.add_argument('--localization_model', type=str, default="attention_points_sam",
|
43 |
+
help='Localization model (default: attention_points_sam, Options: [attention_points_sam, attention, attention_box_sam, attention_mask_sam, grounding_sam])')
|
44 |
+
parser.add_argument('--show_attention', action='store_true',
|
45 |
+
help='Show attention maps')
|
46 |
+
parser.add_argument('--display_output', action='store_true',
|
47 |
+
help='Display output images during processing')
|
48 |
+
|
49 |
+
args = parser.parse_args()
|
50 |
+
|
51 |
+
assert args.subject_token in args.prompt_target, "Subject token must appear in the prompt_target"
|
52 |
+
|
53 |
+
# Set up device and model
|
54 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
55 |
+
print(f"Using device: {device}")
|
56 |
+
|
57 |
+
my_transformer = AdditFluxTransformer2DModel.from_pretrained(
|
58 |
+
"black-forest-labs/FLUX.1-dev",
|
59 |
+
subfolder="transformer",
|
60 |
+
torch_dtype=torch.bfloat16
|
61 |
+
)
|
62 |
+
|
63 |
+
pipe = AdditFluxPipeline.from_pretrained(
|
64 |
+
"black-forest-labs/FLUX.1-dev",
|
65 |
+
transformer=my_transformer,
|
66 |
+
torch_dtype=torch.bfloat16
|
67 |
+
).to(device)
|
68 |
+
|
69 |
+
pipe.scheduler = AdditFlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)
|
70 |
+
|
71 |
+
# Create output directory
|
72 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
73 |
+
|
74 |
+
# Process the seeds
|
75 |
+
print(f"\nProcessing with source seed: {args.seed_src}, object seed: {args.seed_obj}")
|
76 |
+
|
77 |
+
src_image, edited_image = add_object_generated(
|
78 |
+
pipe,
|
79 |
+
args.prompt_source,
|
80 |
+
args.prompt_target,
|
81 |
+
args.subject_token,
|
82 |
+
args.seed_src,
|
83 |
+
args.seed_obj,
|
84 |
+
show_attention=args.show_attention,
|
85 |
+
extended_scale=args.extended_scale,
|
86 |
+
structure_transfer_step=args.structure_transfer_step,
|
87 |
+
blend_steps=args.blend_steps,
|
88 |
+
localization_model=args.localization_model,
|
89 |
+
display_output=args.display_output
|
90 |
+
)
|
91 |
+
|
92 |
+
# Save output images
|
93 |
+
src_filename = f"src_{args.prompt_source}_seed-src={args.seed_src}.png"
|
94 |
+
edited_filename = f"edited_{args.prompt_target}_seed-src={args.seed_src}_seed-obj={args.seed_obj}.png"
|
95 |
+
|
96 |
+
src_image.save(os.path.join(args.output_dir, src_filename))
|
97 |
+
edited_image.save(os.path.join(args.output_dir, edited_filename))
|
98 |
+
|
99 |
+
print(f"Saved images: {src_filename}, {edited_filename}")
|
100 |
+
|
101 |
+
if __name__ == "__main__":
|
102 |
+
main()
|
run_CLI_addit_real.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright (C) 2025 NVIDIA Corporation. All rights reserved.
|
3 |
+
#
|
4 |
+
# This work is licensed under the LICENSE file
|
5 |
+
# located at the root directory.
|
6 |
+
|
7 |
+
import os
|
8 |
+
import argparse
|
9 |
+
import torch
|
10 |
+
import random
|
11 |
+
from PIL import Image
|
12 |
+
|
13 |
+
from visualization_utils import show_images
|
14 |
+
from addit_flux_pipeline import AdditFluxPipeline
|
15 |
+
from addit_flux_transformer import AdditFluxTransformer2DModel
|
16 |
+
from addit_scheduler import AdditFlowMatchEulerDiscreteScheduler
|
17 |
+
from addit_methods import add_object_real
|
18 |
+
|
19 |
+
def main():
|
20 |
+
parser = argparse.ArgumentParser(description='Run ADDIT with real images')
|
21 |
+
|
22 |
+
# Required arguments
|
23 |
+
parser.add_argument('--source_image', type=str, default="images/bed_dark_room.jpg",
|
24 |
+
help='Path to the source image')
|
25 |
+
parser.add_argument('--prompt_source', type=str, default="A photo of a bed in a dark room",
|
26 |
+
help='Source prompt describing the original image')
|
27 |
+
parser.add_argument('--prompt_target', type=str, default="A photo of a dog lying on a bed in a dark room",
|
28 |
+
help='Target prompt describing the desired edited image')
|
29 |
+
parser.add_argument('--subject_token', type=str, default="dog",
|
30 |
+
help='Subject token to add to the image')
|
31 |
+
|
32 |
+
# Optional arguments
|
33 |
+
parser.add_argument('--output_dir', type=str, default='outputs',
|
34 |
+
help='Directory to save output images (default: outputs)')
|
35 |
+
parser.add_argument('--seed_src', type=int, default=6311,
|
36 |
+
help='Seed for source generation')
|
37 |
+
parser.add_argument('--seed_obj', type=int, default=1,
|
38 |
+
help='Seed for edited image generation')
|
39 |
+
parser.add_argument('--extended_scale', type=float, default=1.1,
|
40 |
+
help='Extended attention scale (default: 1.1)')
|
41 |
+
parser.add_argument('--structure_transfer_step', type=int, default=4,
|
42 |
+
help='Structure transfer step (default: 4)')
|
43 |
+
parser.add_argument('--blend_steps', type=int, nargs='*', default=[18],
|
44 |
+
help='Blend steps (default: [18])')
|
45 |
+
parser.add_argument('--localization_model', type=str, default="attention",
|
46 |
+
help='Localization model (default: attention, Options: [attention_points_sam, attention, attention_box_sam, attention_mask_sam, grounding_sam])')
|
47 |
+
parser.add_argument('--use_offset', action='store_true',
|
48 |
+
help='Use offset in processing')
|
49 |
+
parser.add_argument('--show_attention', action='store_true',
|
50 |
+
help='Show attention maps')
|
51 |
+
parser.add_argument('--disable_inversion', action='store_true',
|
52 |
+
help='Disable source image inversion')
|
53 |
+
parser.add_argument('--display_output', action='store_true',
|
54 |
+
help='Display output images during processing')
|
55 |
+
|
56 |
+
args = parser.parse_args()
|
57 |
+
|
58 |
+
assert args.subject_token in args.prompt_target, "Subject token must appear in the prompt_target"
|
59 |
+
|
60 |
+
# Set up device and model
|
61 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
62 |
+
print(f"Using device: {device}")
|
63 |
+
|
64 |
+
my_transformer = AdditFluxTransformer2DModel.from_pretrained(
|
65 |
+
"black-forest-labs/FLUX.1-dev",
|
66 |
+
subfolder="transformer",
|
67 |
+
torch_dtype=torch.bfloat16
|
68 |
+
)
|
69 |
+
|
70 |
+
pipe = AdditFluxPipeline.from_pretrained(
|
71 |
+
"black-forest-labs/FLUX.1-dev",
|
72 |
+
transformer=my_transformer,
|
73 |
+
torch_dtype=torch.bfloat16
|
74 |
+
).to(device)
|
75 |
+
|
76 |
+
pipe.scheduler = AdditFlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)
|
77 |
+
|
78 |
+
# Load and resize source image
|
79 |
+
source_image = Image.open(args.source_image).resize((1024, 1024))
|
80 |
+
print(f"Loaded source image: {args.source_image}")
|
81 |
+
|
82 |
+
# Set random seed
|
83 |
+
if args.seed_src is None:
|
84 |
+
random.seed(0)
|
85 |
+
args.seed_src = random.randint(0, 10000)
|
86 |
+
|
87 |
+
# Create output directory
|
88 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
89 |
+
|
90 |
+
# Process the seeds
|
91 |
+
print(f"\nProcessing with source seed: {args.seed_src}, object seed: {args.seed_obj}")
|
92 |
+
|
93 |
+
src_image, edited_image = add_object_real(
|
94 |
+
pipe,
|
95 |
+
source_image=source_image,
|
96 |
+
prompt_source=args.prompt_source,
|
97 |
+
prompt_object=args.prompt_target,
|
98 |
+
subject_token=args.subject_token,
|
99 |
+
seed_src=args.seed_src,
|
100 |
+
seed_obj=args.seed_obj,
|
101 |
+
extended_scale=args.extended_scale,
|
102 |
+
structure_transfer_step=args.structure_transfer_step,
|
103 |
+
blend_steps=args.blend_steps,
|
104 |
+
localization_model=args.localization_model,
|
105 |
+
use_offset=args.use_offset,
|
106 |
+
show_attention=args.show_attention,
|
107 |
+
use_inversion=not args.disable_inversion,
|
108 |
+
display_output=args.display_output
|
109 |
+
)
|
110 |
+
|
111 |
+
# Save output images
|
112 |
+
src_filename = f"src_{args.prompt_source}_seed-src={args.seed_src}.png"
|
113 |
+
edited_filename = f"edited_{args.prompt_target}_seed-src={args.seed_src}_seed-obj={args.seed_obj}.png"
|
114 |
+
|
115 |
+
src_image.save(os.path.join(args.output_dir, src_filename))
|
116 |
+
edited_image.save(os.path.join(args.output_dir, edited_filename))
|
117 |
+
|
118 |
+
print(f"Saved images: {src_filename}, {edited_filename}")
|
119 |
+
|
120 |
+
if __name__ == "__main__":
|
121 |
+
main()
|
run_addit_generated.ipynb
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"# Copyright (C) 2025 NVIDIA Corporation. All rights reserved.\n",
|
10 |
+
"#\n",
|
11 |
+
"# This work is licensed under the LICENSE file\n",
|
12 |
+
"# located at the root directory.\n",
|
13 |
+
"import torch\n",
|
14 |
+
"import random\n",
|
15 |
+
"\n",
|
16 |
+
"from visualization_utils import show_images\n",
|
17 |
+
"from addit_flux_pipeline import AdditFluxPipeline\n",
|
18 |
+
"from addit_flux_transformer import AdditFluxTransformer2DModel\n",
|
19 |
+
"from addit_scheduler import AdditFlowMatchEulerDiscreteScheduler\n",
|
20 |
+
"from addit_methods import add_object_generated\n",
|
21 |
+
"\n",
|
22 |
+
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
23 |
+
"my_transformer = AdditFluxTransformer2DModel.from_pretrained(\"black-forest-labs/FLUX.1-dev\", subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n",
|
24 |
+
"\n",
|
25 |
+
"pipe = AdditFluxPipeline.from_pretrained(\"black-forest-labs/FLUX.1-dev\", \n",
|
26 |
+
" transformer=my_transformer,\n",
|
27 |
+
" torch_dtype=torch.bfloat16).to(device)\n",
|
28 |
+
"pipe.scheduler = AdditFlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)"
|
29 |
+
]
|
30 |
+
},
|
31 |
+
{
|
32 |
+
"cell_type": "code",
|
33 |
+
"execution_count": null,
|
34 |
+
"metadata": {},
|
35 |
+
"outputs": [],
|
36 |
+
"source": [
|
37 |
+
"# Reset the GPU memory tracking\n",
|
38 |
+
"torch.cuda.reset_max_memory_allocated(0)\n",
|
39 |
+
"\n",
|
40 |
+
"(prompt1, prompt2), subject_token = [\"A photo of a man sitting on a bench\", \"A photo of a man sitting on a bench with a dog\"], \"dog\"\n",
|
41 |
+
"\n",
|
42 |
+
"\n",
|
43 |
+
"random.seed(0)\n",
|
44 |
+
"seeds_src = [663]\n",
|
45 |
+
"seeds_obj = [0,1,2]\n",
|
46 |
+
"\n",
|
47 |
+
"for seed_src in seeds_src:\n",
|
48 |
+
" for seed_obj in seeds_obj:\n",
|
49 |
+
" src_image, edited_image = add_object_generated(pipe, prompt1, prompt2, subject_token, seed_src, seed_obj, show_attention=True, \n",
|
50 |
+
" extended_scale=1.05, structure_transfer_step=2, blend_steps=[15], \n",
|
51 |
+
" localization_model=\"attention_points_sam\", display_output=True)\n",
|
52 |
+
"\n",
|
53 |
+
"# Report maximum GPU memory usage in GB\n",
|
54 |
+
"max_memory_used = torch.cuda.max_memory_allocated(0) / (1024**3) # Convert to GB\n",
|
55 |
+
"print(f\"Maximum GPU memory used: {max_memory_used:.2f} GB\")"
|
56 |
+
]
|
57 |
+
}
|
58 |
+
],
|
59 |
+
"metadata": {
|
60 |
+
"kernelspec": {
|
61 |
+
"display_name": "addit",
|
62 |
+
"language": "python",
|
63 |
+
"name": "python3"
|
64 |
+
},
|
65 |
+
"language_info": {
|
66 |
+
"codemirror_mode": {
|
67 |
+
"name": "ipython",
|
68 |
+
"version": 3
|
69 |
+
},
|
70 |
+
"file_extension": ".py",
|
71 |
+
"mimetype": "text/x-python",
|
72 |
+
"name": "python",
|
73 |
+
"nbconvert_exporter": "python",
|
74 |
+
"pygments_lexer": "ipython3",
|
75 |
+
"version": "3.11.9"
|
76 |
+
}
|
77 |
+
},
|
78 |
+
"nbformat": 4,
|
79 |
+
"nbformat_minor": 2
|
80 |
+
}
|
run_addit_real.ipynb
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"# Copyright (C) 2025 NVIDIA Corporation. All rights reserved.\n",
|
10 |
+
"#\n",
|
11 |
+
"# This work is licensed under the LICENSE file\n",
|
12 |
+
"# located at the root directory.\n",
|
13 |
+
"\n",
|
14 |
+
"import torch\n",
|
15 |
+
"import random\n",
|
16 |
+
"from PIL import Image\n",
|
17 |
+
"\n",
|
18 |
+
"from visualization_utils import show_images\n",
|
19 |
+
"from addit_flux_pipeline import AdditFluxPipeline\n",
|
20 |
+
"from addit_flux_transformer import AdditFluxTransformer2DModel\n",
|
21 |
+
"from addit_scheduler import AdditFlowMatchEulerDiscreteScheduler\n",
|
22 |
+
"from addit_methods import add_object_real\n",
|
23 |
+
"\n",
|
24 |
+
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
25 |
+
"my_transformer = AdditFluxTransformer2DModel.from_pretrained(\"black-forest-labs/FLUX.1-dev\", subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n",
|
26 |
+
"\n",
|
27 |
+
"pipe = AdditFluxPipeline.from_pretrained(\"black-forest-labs/FLUX.1-dev\", \n",
|
28 |
+
" transformer=my_transformer,\n",
|
29 |
+
" torch_dtype=torch.bfloat16).to(device)\n",
|
30 |
+
"pipe.scheduler = AdditFlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)"
|
31 |
+
]
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"cell_type": "code",
|
35 |
+
"execution_count": null,
|
36 |
+
"metadata": {},
|
37 |
+
"outputs": [],
|
38 |
+
"source": [
|
39 |
+
"# Reset the GPU memory tracking\n",
|
40 |
+
"torch.cuda.reset_max_memory_allocated(0)\n",
|
41 |
+
"\n",
|
42 |
+
"# source_image = Image.open(\"images/cat.jpg\").resize((1024, 1024))\n",
|
43 |
+
"# (prompt_src, prompt_tgt), subject_token = [\"A photo of a cat\", \"A photo of a cat wearing a scarf\"], \"scarf\"\n",
|
44 |
+
"\n",
|
45 |
+
"source_image = Image.open(\"images/bed_dark_room.jpg\").resize((1024, 1024))\n",
|
46 |
+
"(prompt_src, prompt_tgt), subject_token = [\"A photo of a bed in a dark room\", \"A photo of a dog lying on a bed in a dark room\"], \"dog\"\n",
|
47 |
+
"\n",
|
48 |
+
"random.seed(0)\n",
|
49 |
+
"seed_src = random.randint(0, 10000)\n",
|
50 |
+
"seeds_obj = [0,1,2]\n",
|
51 |
+
"\n",
|
52 |
+
"for seed_obj in seeds_obj:\n",
|
53 |
+
" images_list = add_object_real(pipe, source_image=source_image, prompt_source=prompt_src, prompt_object=prompt_tgt, \n",
|
54 |
+
" subject_token=subject_token, seed_src=seed_src, seed_obj=seed_obj, \n",
|
55 |
+
" extended_scale =1.1, structure_transfer_step=4, blend_steps = [18], #localization_model=\"attention\",\n",
|
56 |
+
" use_offset=False, show_attention=True, use_inversion=True, display_output=True)\n",
|
57 |
+
"\n",
|
58 |
+
"# Report maximum GPU memory usage in GB\n",
|
59 |
+
"max_memory_used = torch.cuda.max_memory_allocated(0) / (1024**3) # Convert to GB\n",
|
60 |
+
"print(f\"Maximum GPU memory used: {max_memory_used:.2f} GB\")"
|
61 |
+
]
|
62 |
+
}
|
63 |
+
],
|
64 |
+
"metadata": {
|
65 |
+
"kernelspec": {
|
66 |
+
"display_name": "addit",
|
67 |
+
"language": "python",
|
68 |
+
"name": "python3"
|
69 |
+
},
|
70 |
+
"language_info": {
|
71 |
+
"codemirror_mode": {
|
72 |
+
"name": "ipython",
|
73 |
+
"version": 3
|
74 |
+
},
|
75 |
+
"file_extension": ".py",
|
76 |
+
"mimetype": "text/x-python",
|
77 |
+
"name": "python",
|
78 |
+
"nbconvert_exporter": "python",
|
79 |
+
"pygments_lexer": "ipython3",
|
80 |
+
"version": "3.11.9"
|
81 |
+
}
|
82 |
+
},
|
83 |
+
"nbformat": 4,
|
84 |
+
"nbformat_minor": 2
|
85 |
+
}
|
visualization_utils.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2025 NVIDIA Corporation. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is licensed under the LICENSE file
|
4 |
+
# located at the root directory.
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
from PIL import Image, ImageDraw
|
9 |
+
import torch
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
from skimage import filters
|
12 |
+
from IPython.display import display
|
13 |
+
|
14 |
+
def gaussian_blur(heatmap, kernel_size=7):
|
15 |
+
# Shape of heatmap: (H, W)
|
16 |
+
heatmap = heatmap.cpu().numpy()
|
17 |
+
heatmap = cv2.GaussianBlur(heatmap, (kernel_size, kernel_size), 0)
|
18 |
+
heatmap = torch.tensor(heatmap)
|
19 |
+
|
20 |
+
return heatmap
|
21 |
+
|
22 |
+
def show_cam_on_image(img, mask):
|
23 |
+
heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
|
24 |
+
heatmap = np.float32(heatmap) / 255
|
25 |
+
cam = heatmap + np.float32(img)
|
26 |
+
cam = cam / np.max(cam)
|
27 |
+
return cam
|
28 |
+
|
29 |
+
def show_image_and_heatmap(heatmap: torch.Tensor, image: Image.Image, relevnace_res: int = 256, interpolation: str = 'bilinear', gassussian_kernel_size: int = 3):
|
30 |
+
image = image.resize((relevnace_res, relevnace_res))
|
31 |
+
image = np.array(image)
|
32 |
+
image = (image - image.min()) / (image.max() - image.min())
|
33 |
+
|
34 |
+
# Apply gaussian blur to heatmap
|
35 |
+
# heatmap = gaussian_blur(heatmap, kernel_size=gassussian_kernel_size)
|
36 |
+
|
37 |
+
# heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
|
38 |
+
# otsu_thr = filters.threshold_otsu(heatmap.cpu().numpy())
|
39 |
+
# heatmap = (heatmap > otsu_thr).to(heatmap.dtype)
|
40 |
+
|
41 |
+
heatmap = heatmap.reshape(1, 1, heatmap.shape[-1], heatmap.shape[-1])
|
42 |
+
heatmap = torch.nn.functional.interpolate(heatmap, size=relevnace_res, mode=interpolation)
|
43 |
+
heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
|
44 |
+
heatmap = heatmap.reshape(relevnace_res, relevnace_res).cpu()
|
45 |
+
|
46 |
+
vis = show_cam_on_image(image, heatmap)
|
47 |
+
vis = np.uint8(255 * vis)
|
48 |
+
vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
|
49 |
+
|
50 |
+
vis = vis.astype(np.uint8)
|
51 |
+
vis = Image.fromarray(vis).resize((relevnace_res, relevnace_res))
|
52 |
+
|
53 |
+
return vis
|
54 |
+
|
55 |
+
def show_only_heatmap(heatmap: torch.Tensor, relevnace_res: int = 256, interpolation: str = 'bilinear', gassussian_kernel_size: int = 3):
|
56 |
+
# Apply gaussian blur to heatmap
|
57 |
+
# heatmap = gaussian_blur(heatmap, kernel_size=gassussian_kernel_size)
|
58 |
+
|
59 |
+
heatmap = heatmap.reshape(1, 1, heatmap.shape[-1], heatmap.shape[-1])
|
60 |
+
heatmap = torch.nn.functional.interpolate(heatmap, size=relevnace_res, mode=interpolation)
|
61 |
+
heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
|
62 |
+
heatmap = heatmap.reshape(relevnace_res, relevnace_res).cpu()
|
63 |
+
|
64 |
+
vis = heatmap
|
65 |
+
vis = np.uint8(255 * vis)
|
66 |
+
|
67 |
+
# Show in black and white
|
68 |
+
vis = cv2.cvtColor(np.array(vis), cv2.COLOR_GRAY2BGR)
|
69 |
+
|
70 |
+
vis = Image.fromarray(vis).resize((relevnace_res, relevnace_res))
|
71 |
+
|
72 |
+
return vis
|
73 |
+
|
74 |
+
def visualize_tokens_attentions(attention, tokens, image, heatmap_interpolation="nearest", show_on_image=True):
|
75 |
+
# Tokens: list of strings
|
76 |
+
# attention: tensor of shape (batch_size, num_tokens, width, height)
|
77 |
+
token_vis = []
|
78 |
+
for j, token in enumerate(tokens):
|
79 |
+
if j >= attention.shape[0]:
|
80 |
+
break
|
81 |
+
|
82 |
+
if show_on_image:
|
83 |
+
vis = show_image_and_heatmap(attention[j], image, relevnace_res=512, interpolation=heatmap_interpolation)
|
84 |
+
else:
|
85 |
+
vis = show_only_heatmap(attention[j], relevnace_res=512, interpolation=heatmap_interpolation)
|
86 |
+
|
87 |
+
token_vis.append((token, vis))
|
88 |
+
|
89 |
+
# Display the token and the attention map in a grid, with K tokens per row
|
90 |
+
K = 4
|
91 |
+
n_rows = (len(token_vis) + K - 1) // K # Ceiling division
|
92 |
+
fig, axs = plt.subplots(n_rows, K, figsize=(K*5, n_rows*5))
|
93 |
+
|
94 |
+
for i, (token, vis) in enumerate(token_vis):
|
95 |
+
row, col = divmod(i, K)
|
96 |
+
if n_rows > 1:
|
97 |
+
ax = axs[row, col]
|
98 |
+
elif K > 1:
|
99 |
+
ax = axs[col]
|
100 |
+
else:
|
101 |
+
ax = axs
|
102 |
+
|
103 |
+
ax.imshow(vis)
|
104 |
+
ax.set_title(token)
|
105 |
+
ax.axis("off")
|
106 |
+
|
107 |
+
# Hide unused subplots
|
108 |
+
for j in range(i + 1, n_rows * K):
|
109 |
+
row, col = divmod(j, K)
|
110 |
+
if n_rows > 1:
|
111 |
+
axs[row, col].axis('off')
|
112 |
+
elif K > 1:
|
113 |
+
axs[col].axis('off')
|
114 |
+
|
115 |
+
plt.tight_layout()
|
116 |
+
|
117 |
+
# We want to return the figure so that we can save it to a file
|
118 |
+
return fig
|
119 |
+
|
120 |
+
def show_images(images, titles=None, size=1024, max_row_length=5, figsize=None, col_height=10, save_path=None):
|
121 |
+
if isinstance(images, Image.Image):
|
122 |
+
images = [images]
|
123 |
+
|
124 |
+
if len(images) == 1:
|
125 |
+
img = images[0]
|
126 |
+
img = img.resize((size, size))
|
127 |
+
plt.imshow(img)
|
128 |
+
plt.axis('off')
|
129 |
+
|
130 |
+
if titles is not None:
|
131 |
+
plt.title(titles[0])
|
132 |
+
|
133 |
+
if save_path:
|
134 |
+
plt.savefig(save_path, bbox_inches='tight', dpi=150)
|
135 |
+
|
136 |
+
plt.show()
|
137 |
+
else:
|
138 |
+
images = [img.resize((size, size)) for img in images]
|
139 |
+
|
140 |
+
# Check if the number of titles matches the number of images
|
141 |
+
if titles is not None:
|
142 |
+
assert len(images) == len(titles), "Number of titles should match the number of images"
|
143 |
+
|
144 |
+
n_images = len(images)
|
145 |
+
n_cols = min(n_images, max_row_length)
|
146 |
+
n_rows = (n_images + n_cols - 1) // n_cols # Calculate the number of rows needed
|
147 |
+
|
148 |
+
if figsize is None:
|
149 |
+
figsize=(n_cols * col_height, n_rows * col_height)
|
150 |
+
|
151 |
+
fig, axs = plt.subplots(n_rows, n_cols, figsize=figsize)
|
152 |
+
axs = axs.flatten() if isinstance(axs, np.ndarray) else [axs]
|
153 |
+
|
154 |
+
# Display images in the subplots
|
155 |
+
for i, img in enumerate(images):
|
156 |
+
axs[i].imshow(img)
|
157 |
+
if titles is not None:
|
158 |
+
axs[i].set_title(titles[i])
|
159 |
+
axs[i].axis("off")
|
160 |
+
|
161 |
+
# Turn off any unused subplots
|
162 |
+
for ax in axs[len(images):]:
|
163 |
+
ax.axis("off")
|
164 |
+
|
165 |
+
if save_path:
|
166 |
+
plt.savefig(save_path, bbox_inches='tight', dpi=150)
|
167 |
+
|
168 |
+
plt.show()
|
169 |
+
|
170 |
+
def show_tensors(tensors, titles=None, size=None, max_row_length=5):
|
171 |
+
# Shape of tensors: List[Tensor[H, W]]
|
172 |
+
if size is not None:
|
173 |
+
tensors = [torch.nn.functional.interpolate(t.unsqueeze(0).unsqueeze(0), size=(size, size), mode='bilinear').squeeze() for t in tensors]
|
174 |
+
|
175 |
+
if len(tensors) == 1:
|
176 |
+
plt.imshow(tensors[0].cpu().numpy())
|
177 |
+
plt.axis('off')
|
178 |
+
|
179 |
+
if titles is not None:
|
180 |
+
plt.title(titles[0])
|
181 |
+
|
182 |
+
plt.show()
|
183 |
+
else:
|
184 |
+
# Check if the number of titles matches the number of images
|
185 |
+
if titles is not None:
|
186 |
+
assert len(tensors) == len(titles), "Number of titles should match the number of images"
|
187 |
+
|
188 |
+
n_tensors = len(tensors)
|
189 |
+
n_cols = min(n_tensors, max_row_length)
|
190 |
+
n_rows = (n_tensors + n_cols - 1) // n_cols
|
191 |
+
|
192 |
+
fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 10, n_rows * 10))
|
193 |
+
axs = axs.flatten() if isinstance(axs, np.ndarray) else [axs]
|
194 |
+
|
195 |
+
for i, tensor in enumerate(tensors):
|
196 |
+
axs[i].imshow(tensor.cpu().numpy())
|
197 |
+
if titles is not None:
|
198 |
+
axs[i].set_title(titles[i])
|
199 |
+
axs[i].axis("off")
|
200 |
+
|
201 |
+
for ax in axs[len(tensors):]:
|
202 |
+
ax.axis("off")
|
203 |
+
|
204 |
+
plt.show()
|
205 |
+
|
206 |
+
def draw_bboxes_on_image(image, bboxes, color="red", thickness=2):
|
207 |
+
image = image.copy()
|
208 |
+
draw = ImageDraw.Draw(image)
|
209 |
+
for bbox in bboxes:
|
210 |
+
draw.rectangle(bbox, outline=color, width=thickness)
|
211 |
+
return image
|
212 |
+
|
213 |
+
def draw_points_on_pil_image(pil_image, point_coords, point_color="red", radius=5):
|
214 |
+
"""
|
215 |
+
Draw points (circles) on a PIL image and return the modified image.
|
216 |
+
|
217 |
+
:param pil_image: PIL Image (e.g., sam_masked_image)
|
218 |
+
:param point_coords: An array-like of shape (N, 2), with x,y coordinates
|
219 |
+
:param point_color: Color of the point (default 'red')
|
220 |
+
:param radius: Radius of the drawn circles
|
221 |
+
:return: PIL Image with points drawn
|
222 |
+
"""
|
223 |
+
# Copy so we don't modify the original
|
224 |
+
out_img = pil_image.copy()
|
225 |
+
draw = ImageDraw.Draw(out_img)
|
226 |
+
|
227 |
+
# Draw each point
|
228 |
+
for x, y in point_coords:
|
229 |
+
# Calculate bounding box of the circle
|
230 |
+
left_up_point = (x - radius, y - radius)
|
231 |
+
right_down_point = (x + radius, y + radius)
|
232 |
+
# Draw the circle
|
233 |
+
draw.ellipse([left_up_point, right_down_point], fill=point_color, outline=point_color)
|
234 |
+
|
235 |
+
return out_img
|