SahilCarterr committed
Commit f056744 · verified · 1 Parent(s): e7f743d

Upload 77 files

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +10 -0
  2. README.md +68 -14
  3. data/images/bear.jpeg +0 -0
  4. data/images/bird.jpg +0 -0
  5. data/images/bird_painting.jpg +0 -0
  6. data/images/cabin.jpg +3 -0
  7. data/images/car.jpg +0 -0
  8. data/images/cat_hat.jpg +0 -0
  9. data/images/cat_mirror.jpg +0 -0
  10. data/images/cat_poly.jpg +0 -0
  11. data/images/dancing.jpeg +0 -0
  12. data/images/flower.jpg +0 -0
  13. data/images/fruit.jpg +3 -0
  14. data/images/girl_mountain.jpg +0 -0
  15. data/images/koala.jpg +3 -0
  16. data/images/man_tree.jpg +3 -0
  17. data/images/meditation.png +3 -0
  18. data/images/old_couple.jpg +3 -0
  19. data/images/owl_heart.jpg +0 -0
  20. data/images/raven.jpg +0 -0
  21. data/images/real_karate.jpeg +0 -0
  22. data/images/santa.jpg +0 -0
  23. data/images/squirrel.jpg +0 -0
  24. data/images/statue.jpg +3 -0
  25. data/images/steak.jpg +3 -0
  26. data/images/tennis.jpg +0 -0
  27. data/images/woman_book.jpg +3 -0
  28. data/masks/cat_hat.jpg +0 -0
  29. data/masks/cat_mirror.jpg +0 -0
  30. data/masks/girl_mountain.jpg +0 -0
  31. data/masks/man_tree.jpg +0 -0
  32. data/masks/old_couple.jpg +0 -0
  33. data/masks/raven.jpg +0 -0
  34. data/masks/santa.jpg +0 -0
  35. images/main_figure.png +3 -0
  36. img_edit.py +492 -0
  37. requirements.txt +12 -0
  38. scripts/w_ca/run_bird.sh +20 -0
  39. scripts/w_ca/run_cabin.sh +20 -0
  40. scripts/w_ca/run_car.sh +21 -0
  41. scripts/w_ca/run_cat_poly.sh +21 -0
  42. scripts/w_ca/run_flower.sh +21 -0
  43. scripts/w_ca/run_fruit.sh +20 -0
  44. scripts/w_ca/run_koala.sh +20 -0
  45. scripts/w_ca/run_owl_heart.sh +20 -0
  46. scripts/w_ca/run_statue.sh +21 -0
  47. scripts/w_ca/run_steak.sh +20 -0
  48. scripts/w_ca/run_tennis.sh +21 -0
  49. scripts/w_ca/run_woman_book.sh +20 -0
  50. scripts/w_mask/run_cat_hat.sh +21 -0
.gitattributes CHANGED
@@ -33,3 +33,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ data/images/cabin.jpg filter=lfs diff=lfs merge=lfs -text
+ data/images/fruit.jpg filter=lfs diff=lfs merge=lfs -text
+ data/images/koala.jpg filter=lfs diff=lfs merge=lfs -text
+ data/images/man_tree.jpg filter=lfs diff=lfs merge=lfs -text
+ data/images/meditation.png filter=lfs diff=lfs merge=lfs -text
+ data/images/old_couple.jpg filter=lfs diff=lfs merge=lfs -text
+ data/images/statue.jpg filter=lfs diff=lfs merge=lfs -text
+ data/images/steak.jpg filter=lfs diff=lfs merge=lfs -text
+ data/images/woman_book.jpg filter=lfs diff=lfs merge=lfs -text
+ images/main_figure.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,14 +1,68 @@
- ---
- title: ReFlex
- emoji: 📚
- colorFrom: red
- colorTo: yellow
- sdk: gradio
- sdk_version: 5.38.0
- app_file: app.py
- pinned: false
- license: mit
- short_description: Text-Guided Editing of Real Images
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # ReFlex: Text-Guided Editing of Real Images in Rectified Flow via Mid-Step Feature Extraction and Attention Adaptation
+
+ ### [ICCV 2025] Official PyTorch implementation of the paper: "ReFlex: Text-Guided Editing of Real Images in Rectified Flow via Mid-Step Feature Extraction and Attention Adaptation"
+ by Jimyeon Kim, Jungwon Park, Yeji Song, Nojun Kwak, Wonjong Rhee†.
+
+ Seoul National University
+
+ [arXiv](https://arxiv.org/abs/2507.01496) | [Project Page](https://wlaud1001.github.io/ReFlex/)
+
+ ![main](./images/main_figure.png)
+
+ ## Setup
+ ```
+ git clone https://github.com/wlaud1001/ReFlex.git
+ cd ReFlex
+
+ conda create -n reflex python=3.10
+ conda activate reflex
+ pip install -r requirements.txt
+ ```
+
+ ## Run
+
+ ### Run example
+ ```
+ python img_edit.py \
+     --gpu {gpu} \
+     --seed {seed} \
+     --img_path {source_img_path} \
+     --source_prompt {source_prompt} \
+     --target_prompt {target_prompt} \
+     --results_dir {results_dir} \
+     --feature_steps {feature_steps} \
+     --attn_topk {attn_topk}
+ ```
+ ### Arguments
+ - --gpu: Index of the GPU to use.
+ - --seed: Random seed.
+ - --img_path: Path to the input real image to be edited.
+ - --mask_path (optional): Path to a ground-truth mask for local editing.
+     - If provided, this mask is used directly.
+     - If omitted, the editing mask is automatically generated from attention maps.
+ - --source_prompt (optional): Text prompt describing the content of the input image.
+     - If provided, mask generation and latent blending are applied.
+     - If omitted, editing proceeds without latent blending.
+ - --target_prompt: Text prompt describing the desired edited image.
+ - --blend_word (optional): Word in --source_prompt that guides mask generation via its I2T-CA map.
+     - If omitted, the blend word is automatically inferred by comparing --source_prompt and --target_prompt (a word-level sketch follows this list).
+ - --ca_steps: Number of steps to apply I2T-CA adaptation and injection.
+ - --sa_steps: Number of steps to apply I2I-SA adaptation and injection.
+ - --feature_steps: Number of steps to inject residual features.
+ - --attn_topk: Hyperparameter for I2I-SA adaptation.
+ - --results_dir: Directory to save the output images.
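+
+ For intuition, the automatic blend-word inference can be approximated at the word level, as in the sketch below. This is an illustration only; the actual implementation diffs T5 token ids (see `find_token_id_differences` in img_edit.py).
+ ```
+ def infer_blend_words(source_prompt, target_prompt):
+     """Return source-prompt words that do not appear in the target prompt."""
+     target_words = set(target_prompt.split())
+     return [w for w in source_prompt.split() if w not in target_words]
+
+ # 'a cat wearing a pink hat' vs. 'a tiger wearing a pink hat' -> ['cat']
+ print(infer_blend_words('a cat wearing a pink hat',
+                         'a tiger wearing a pink hat'))
+ ```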
+
+ ### Scripts
+ We also provide example scripts in the ./scripts directory for common use cases and reproducible experiments.
+ #### Script Categories
+ - scripts/wo_ca/: The source prompt is not given; I2T-CA adaptation and latent blending are not applied.
+ - scripts/w_ca/: The source prompt is given, and the editing mask for latent blending is generated automatically from the attention map.
+ - scripts/w_mask/: A ground-truth mask for local editing is provided and used directly for latent blending.
+
+ You can run a script as follows:
+ ```
+ ./scripts/wo_ca/run_bear.sh
+ ./scripts/w_ca/run_bird.sh
+ ./scripts/w_mask/run_cat_hat.sh
+ ```
data/images/bear.jpeg ADDED
data/images/bird.jpg ADDED
data/images/bird_painting.jpg ADDED
data/images/cabin.jpg ADDED

Git LFS Details

  • SHA256: 57c526d303939ec8fa1e6fe6780ba1d8be5aacfe0ce6c4eeaf1b2771e29a534f
  • Pointer size: 131 Bytes
  • Size of remote file: 123 kB
data/images/car.jpg ADDED
data/images/cat_hat.jpg ADDED
data/images/cat_mirror.jpg ADDED
data/images/cat_poly.jpg ADDED
data/images/dancing.jpeg ADDED
data/images/flower.jpg ADDED
data/images/fruit.jpg ADDED

Git LFS Details

  • SHA256: e2dfeda0bba2b887ac5b082771b74bbe990110a712e0ebaed2c3c6abca2d8630
  • Pointer size: 131 Bytes
  • Size of remote file: 139 kB
data/images/girl_mountain.jpg ADDED
data/images/koala.jpg ADDED

Git LFS Details

  • SHA256: be9ab5f91b329a5cc53e55bac9eba350aaf80b39a04e8e6a03d147713a5eb283
  • Pointer size: 131 Bytes
  • Size of remote file: 150 kB
data/images/man_tree.jpg ADDED

Git LFS Details

  • SHA256: 6d53f9d74aeb377b65ca9fac3684dd5495451cb09cc4aeacb614d912ec89f462
  • Pointer size: 131 Bytes
  • Size of remote file: 102 kB
data/images/meditation.png ADDED

Git LFS Details

  • SHA256: 7c1ebb8230cee73caa80b9a9b5ec1ae0c89d12742f06be789f60a53f9177f9c1
  • Pointer size: 131 Bytes
  • Size of remote file: 288 kB
data/images/old_couple.jpg ADDED

Git LFS Details

  • SHA256: 405cc22840c86e79aeef24f36ce0a6a1e90491bf3badabfd1c16d0cc300c17f2
  • Pointer size: 131 Bytes
  • Size of remote file: 151 kB
data/images/owl_heart.jpg ADDED
data/images/raven.jpg ADDED
data/images/real_karate.jpeg ADDED
data/images/santa.jpg ADDED
data/images/squirrel.jpg ADDED
data/images/statue.jpg ADDED

Git LFS Details

  • SHA256: d7a02cb1cfb21a69bfb3bed2d56c74799385860c625a78f2f9c9527d0b96d123
  • Pointer size: 131 Bytes
  • Size of remote file: 214 kB
data/images/steak.jpg ADDED

Git LFS Details

  • SHA256: 60a98952c0d657c652d7c686d6eb93419cb3dff1495aca93a4ddcbcd2c30af32
  • Pointer size: 131 Bytes
  • Size of remote file: 160 kB
data/images/tennis.jpg ADDED
data/images/woman_book.jpg ADDED

Git LFS Details

  • SHA256: aaa44eba168cbbec858b846ba3f801fd67e5e4d4a7d8f76d28b56661ceaac992
  • Pointer size: 131 Bytes
  • Size of remote file: 113 kB
data/masks/cat_hat.jpg ADDED
data/masks/cat_mirror.jpg ADDED
data/masks/girl_mountain.jpg ADDED
data/masks/man_tree.jpg ADDED
data/masks/old_couple.jpg ADDED
data/masks/raven.jpg ADDED
data/masks/santa.jpg ADDED
images/main_figure.png ADDED

Git LFS Details

  • SHA256: 15cdc45b0a49a939fa22c167d9392cdd147d451f519ab616bd065c018860722e
  • Pointer size: 133 Bytes
  • Size of remote file: 15.4 MB
img_edit.py ADDED
@@ -0,0 +1,492 @@
+ import argparse
+ import gc
+ import os
+ import random
+ import re
+ import time
+ from distutils.util import strtobool
+
+ import pandas as pd
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+     "--img_path",
+     type=str,
+ )
+ parser.add_argument(
+     "--target_prompt",
+     type=str,
+ )
+ parser.add_argument(
+     "--source_prompt",
+     type=str,
+     default=''
+ )
+ parser.add_argument(
+     "--blend_word",
+     type=str,
+     default=''
+ )
+ parser.add_argument(
+     "--mask_path",
+     type=str,
+     default=None
+ )
+
+ parser.add_argument(
+     "--gpu",
+     type=str,
+     default="0",
+ )
+ parser.add_argument(
+     "--seed",
+     type=int,
+     default=0
+ )
+ parser.add_argument(
+     "--results_dir",
+     type=str,
+     default='results'
+ )
+
+ parser.add_argument(
+     "--model",
+     type=str,
+     default='flux',
+     choices=['flux']
+ )
+
+ parser.add_argument(
+     "--ca_steps",
+     type=int,
+     default=10,
+     help="Number of steps to apply I2T-CA adaptation and injection.",
+ )
+ parser.add_argument(
+     "--sa_steps",
+     type=int,
+     default=7,
+     help="Number of steps to apply I2I-SA adaptation and injection.",
+ )
+ parser.add_argument(
+     "--feature_steps",
+     type=int,
+     default=5,
+     help="Number of steps to inject residual features.",
+ )
+
+ parser.add_argument(
+     "--ca_attn_layer_from",
+     type=int,
+     default=13,
+     help="First layer to apply I2T-CA adaptation and injection.",
+ )
+ parser.add_argument(
+     "--ca_attn_layer_to",
+     type=int,
+     default=45,
+     help="Last layer (exclusive) to apply I2T-CA adaptation and injection.",
+ )
+
+ parser.add_argument(
+     "--sa_attn_layer_from",
+     type=int,
+     default=20,
+     help="First layer to apply I2I-SA adaptation and injection.",
+ )
+ parser.add_argument(
+     "--sa_attn_layer_to",
+     type=int,
+     default=45,
+     help="Last layer (exclusive) to apply I2I-SA adaptation and injection.",
+ )
+
+ parser.add_argument(
+     "--feature_layer_from",
+     type=int,
+     default=13,
+     help="First layer to inject residual features.",
+ )
+ parser.add_argument(
+     "--feature_layer_to",
+     type=int,
+     default=20,
+     help="Last layer (exclusive) to inject residual features.",
+ )
+
+ parser.add_argument(
+     "--flow_steps",
+     type=int,
+     default=7,
+     help="Number of forward steps to apply before inversion.",
+ )
+ parser.add_argument(
+     "--step_start",
+     type=int,
+     default=0
+ )
+
+ parser.add_argument(
+     "--num_inference_steps",
+     type=int,
+     default=28
+ )
+ parser.add_argument(
+     "--guidance_scale",
+     type=float,
+     default=3.5,
+ )
+
+ parser.add_argument(
+     "--attn_topk",
+     type=int,
+     default=20,
+     help="Hyperparameter for I2I-SA adaptation."
+ )
+
+ parser.add_argument(
+     "--text_scale",
+     type=float,
+     default=4,
+     help="Hyperparameter for I2T-CA adaptation."
+ )
+
+ parser.add_argument(
+     "--mid_step_index",
+     type=int,
+     default=14,
+     help="Hyperparameter for mid-step feature extraction."
+ )
+
+ parser.add_argument(
+     "--use_mask",
+     type=strtobool,
+     default=True
+ )
+ parser.add_argument(
+     "--use_ca_mask",
+     type=strtobool,
+     default=True
+ )
+ parser.add_argument(
+     "--mask_steps",
+     type=int,
+     default=18,
+     help="Number of steps to apply latent blending."
+ )
+ parser.add_argument(
+     "--mask_dilation",
+     type=int,
+     default=3
+ )
+ parser.add_argument(
+     "--mask_nbins",
+     type=int,
+     default=128
+ )
+
+ args = parser.parse_args()
+
+ # Select the GPU before importing torch so CUDA only sees that device.
+ os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.gpu}"
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+ import yaml
+ from diffusers import FlowMatchEulerDiscreteScheduler
+ from diffusers.utils.torch_utils import randn_tensor
+ from PIL import Image
+
+ from src.attn_utils.attn_utils import AttentionAdapter, AttnCollector
+ from src.attn_utils.flux_attn_processor import NewFluxAttnProcessor2_0
+ from src.attn_utils.seq_aligner import get_refinement_mapper
+ from src.callback.callback_fn import CallbackAll
+ from src.inversion.inverse import get_inversed_latent_list
+ from src.inversion.scheduling_flow_inverse import \
+     FlowMatchEulerDiscreteForwardScheduler
+ from src.pipeline.flux_pipeline import NewFluxPipeline
+ from src.transformer_utils.transformer_utils import (FeatureCollector,
+                                                      FeatureReplace)
+ from src.utils import (find_token_id_differences, find_word_token_indices,
+                        get_flux_pipeline, mask_decode, mask_interpolate)
+
+
+ def fix_seed(random_seed):
+     """Fix every source of randomness so experiment results are stable."""
+     torch.manual_seed(random_seed)
+     torch.cuda.manual_seed(random_seed)
+     torch.cuda.manual_seed_all(random_seed)  # if using multi-GPU
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+     np.random.seed(random_seed)
+     random.seed(random_seed)
+
+
+ def main(args):
+     fix_seed(args.seed)
+     device = torch.device('cuda')
+
+     pipe = get_flux_pipeline(pipeline_class=NewFluxPipeline)
+     attn_proc = NewFluxAttnProcessor2_0
+     pipe = pipe.to(device)
+
+     layer_order = range(57)
+
+     ca_layer_list = layer_order[args.ca_attn_layer_from:args.ca_attn_layer_to]
+     sa_layer_list = layer_order[args.feature_layer_to:args.sa_attn_layer_to]
+     feature_layer_list = layer_order[args.feature_layer_from:args.feature_layer_to]
+
+     img_path = args.img_path
+     source_img = Image.open(img_path).resize((1024, 1024)).convert("RGB")
+     img_base_name = os.path.splitext(img_path)[0].split('/')[-1]
+     result_img_dir = f"{args.results_dir}/seed_{args.seed}/{args.target_prompt}"
+
+     source_prompt = args.source_prompt
+     target_prompt = args.target_prompt
+     prompts = [source_prompt, target_prompt]
+
+     print(prompts)
+     mask = None
+
+     if args.use_mask:
+         use_mask = True
+
+         if args.mask_path is not None:
+             mask = Image.open(args.mask_path)
+             mask = torch.tensor(np.array(mask)).bool()
+             mask = mask.to(device)
+
+             # Increase the latent blending steps if the ground-truth mask is used.
+             args.mask_steps = int(args.num_inference_steps * 0.9)
+
+             source_ca_index = None
+             target_ca_index = None
+             use_ca_mask = False
+
+         elif args.use_ca_mask and source_prompt:
+             mask = None
+             if args.blend_word and args.blend_word in source_prompt:
+                 editing_source_token_index = find_word_token_indices(source_prompt, args.blend_word, pipe.tokenizer_2)
+                 editing_target_token_index = None
+             else:
+                 editing_tokens_info = find_token_id_differences(*prompts, pipe.tokenizer_2)
+                 editing_source_token_index = editing_tokens_info['prompt_1']['index']
+                 editing_target_token_index = editing_tokens_info['prompt_2']['index']
+
+             use_ca_mask = True
+             if editing_source_token_index:
+                 source_ca_index = editing_source_token_index
+                 target_ca_index = None
+             elif editing_target_token_index:
+                 source_ca_index = None
+                 target_ca_index = editing_target_token_index
+             else:
+                 source_ca_index = None
+                 target_ca_index = None
+                 use_ca_mask = False
+
+         else:
+             source_ca_index = None
+             target_ca_index = None
+             use_ca_mask = False
+
+     else:
+         use_mask = False
+         use_ca_mask = False
+         source_ca_index = None
+         target_ca_index = None
+
+     if source_prompt:
+         # Use I2T-CA injection.
+         mappers, alphas = get_refinement_mapper(prompts, pipe.tokenizer_2, max_len=512)
+         mappers = mappers.to(device=device)
+         alphas = alphas.to(device=device, dtype=pipe.dtype)
+         alphas = alphas[:, None, None, :]
+
+         ca_steps = args.ca_steps
+         attn_adj_from = 1
+
+     else:
+         # Skip I2T-CA injection when no source prompt is given.
+         mappers = None
+         alphas = None
+
+         ca_steps = 0
+         attn_adj_from = 3
+
+     sa_steps = args.sa_steps
+     feature_steps = args.feature_steps
+
+     attn_controller = AttentionAdapter(
+         ca_layer_list=ca_layer_list,
+         sa_layer_list=sa_layer_list,
+         ca_steps=ca_steps,
+         sa_steps=sa_steps,
+         method='replace_topk',
+         topk=args.attn_topk,
+         text_scale=args.text_scale,
+         mappers=mappers,
+         alphas=alphas,
+         attn_adj_from=attn_adj_from,
+         save_source_ca=source_ca_index is not None,
+         save_target_ca=target_ca_index is not None,
+     )
+
+     attn_collector = AttnCollector(
+         transformer=pipe.transformer,
+         controller=attn_controller,
+         attn_processor_class=NewFluxAttnProcessor2_0,
+     )
+
+     feature_controller = FeatureReplace(
+         layer_list=feature_layer_list,
+         feature_steps=feature_steps,
+     )
+
+     feature_collector = FeatureCollector(
+         transformer=pipe.transformer,
+         controller=feature_controller,
+     )
+
+     num_prompts = len(prompts)
+
+     shape = (1, 16, 128, 128)
+     generator = torch.Generator(device=device).manual_seed(args.seed)
+     latents = randn_tensor(shape, device=device, generator=generator)
+     latents = pipe._pack_latents(latents, *latents.shape)
+
+     attn_collector.restore_orig_attention()
+     feature_collector.restore_orig_transformer()
+
+     t0 = time.perf_counter()
+
+     inv_latents = get_inversed_latent_list(
+         pipe,
+         source_img,
+         random_noise=latents,
+         num_inference_steps=args.num_inference_steps,
+         backward_method="ode",
+         use_prompt_for_inversion=False,
+         guidance_scale_for_inversion=0,
+         prompt_for_inversion='',
+         flow_steps=args.flow_steps,
+     )
+
+     source_latents = inv_latents[::-1]
+     target_latents = inv_latents[::-1]
+
+     attn_collector.register_attention_control()
+     feature_collector.register_transformer_control()
+
+     callback_fn = CallbackAll(
+         latents=source_latents,
+         attn_collector=attn_collector,
+         feature_collector=feature_collector,
+         feature_inject_steps=feature_steps,
+         mid_step_index=args.mid_step_index,
+         step_start=args.step_start,
+         use_mask=use_mask,
+         use_ca_mask=use_ca_mask,
+         source_ca_index=source_ca_index,
+         target_ca_index=target_ca_index,
+         mask_kwargs={'dilation': args.mask_dilation},
+         mask_steps=args.mask_steps,
+         mask=mask,
+     )
+
+     init_latent = target_latents[args.step_start]
+     init_latent = init_latent.repeat(num_prompts, 1, 1)
+     init_latent[0] = source_latents[args.mid_step_index]
+
+     os.makedirs(result_img_dir, exist_ok=True)
+     pipe.scheduler = FlowMatchEulerDiscreteForwardScheduler.from_config(
+         pipe.scheduler.config,
+         step_start=args.step_start,
+         margin_index_from_image=0
+     )
+
+     attn_controller.reset()
+     feature_controller.reset()
+     attn_controller.text_scale = args.text_scale
+     attn_controller.cur_step = args.step_start
+     feature_controller.cur_step = args.step_start
+
+     with torch.no_grad():
+         images = pipe(
+             prompts,
+             latents=init_latent,
+             num_images_per_prompt=1,
+             guidance_scale=args.guidance_scale,
+             num_inference_steps=args.num_inference_steps,
+             generator=generator,
+             callback_on_step_end=callback_fn,
+             mid_step_index=args.mid_step_index,
+             step_start=args.step_start,
+             callback_on_step_end_tensor_inputs=['latents'],
+         ).images
+
+     t1 = time.perf_counter()
+     print(f"Done in {t1 - t0:.1f}s.")
+
+     source_img_path = os.path.join(result_img_dir, "source.png")
+     source_img.save(source_img_path)
+
+     for i, img in enumerate(images[1:]):
+         target_img_path = os.path.join(result_img_dir, f"target_{i}.png")
+         img.save(target_img_path)
+
+     target_text_path = os.path.join(result_img_dir, "target_prompts.txt")
+     with open(target_text_path, 'w') as file:
+         file.write(target_prompt + '\n')
+
+     source_text_path = os.path.join(result_img_dir, "source_prompt.txt")
+     with open(source_text_path, 'w') as file:
+         file.write(source_prompt + '\n')
+
+     images = [source_img] + images
+
+     fs = 3
+     n = len(images)
+     fig, ax = plt.subplots(1, n, figsize=(n * fs, 1 * fs))
+
+     for i, img in enumerate(images):
+         ax[i].imshow(img)
+
+     ax[0].set_title('source')
+     ax[1].set_title(source_prompt, fontsize=7)
+     ax[2].set_title(target_prompt, fontsize=7)
+
+     overall_img_path = os.path.join(result_img_dir, "overall.png")
+     plt.savefig(overall_img_path, bbox_inches='tight')
+     plt.close()
+
+     mask_save_dir = os.path.join(result_img_dir, "mask")
+     os.makedirs(mask_save_dir, exist_ok=True)
+
+     if use_ca_mask:
+         ca_mask_path = os.path.join(mask_save_dir, "mask_ca.png")
+         mask_img = Image.fromarray((callback_fn.mask.cpu().float().numpy() * 255).astype(np.uint8)).convert('L')
+         mask_img.save(ca_mask_path)
+
+     del inv_latents
+     del init_latent
+     gc.collect()
+     torch.cuda.empty_cache()
+
+
+ if __name__ == '__main__':
+     main(args)
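
Note on custom masks: the loader above casts the decoded image to bool (`torch.tensor(np.array(mask)).bool()`), so every nonzero pixel becomes foreground, and the bundled masks are JPEGs, where compression noise can leave stray nonzero pixels. The following is a minimal sketch of a more robust loader, assuming a grayscale mask file; the path in the comment is a placeholder, not part of the repository code:

```
import numpy as np
from PIL import Image

def load_binary_mask(path, threshold=127):
    """Binarize a mask by thresholding rather than casting to bool,
    which is more robust to JPEG compression artifacts."""
    mask = np.array(Image.open(path).convert("L"))
    return mask > threshold

# Placeholder usage:
# binary = load_binary_mask("data/masks/cat_hat.jpg")
```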
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ diffusers==0.31.0
+ torch==2.4.1
+ pandas
+ matplotlib
+ transformers==4.44.2
+ torchao
+ torchvision
+ opencv-python
+ scikit-image
+ accelerate
+ sentencepiece
+ protobuf
scripts/w_ca/run_bird.sh ADDED
@@ -0,0 +1,20 @@
+ source_prompt='a blue and white bird sits on a branch'
+ target_prompt='a blue and white butterfly sits on a branch'
+
+ ca_steps=10
+ sa_steps=7
+ feature_steps=5
+
+ attn_topk=20
+
+ python img_edit.py \
+     --gpu 3 \
+     --seed 0 \
+     --img_path 'data/images/bird.jpg' \
+     --source_prompt "$source_prompt" \
+     --target_prompt "$target_prompt" \
+     --results_dir 'results/bird' \
+     --ca_steps $ca_steps \
+     --sa_steps $sa_steps \
+     --feature_steps $feature_steps \
+     --attn_topk $attn_topk
scripts/w_ca/run_cabin.sh ADDED
@@ -0,0 +1,20 @@
+ source_prompt='a painting of a cabin in the snow with mountains in the background'
+ target_prompt='a painting of a car in the snow with mountains in the background'
+
+ ca_steps=10
+ sa_steps=7
+ feature_steps=5
+
+ attn_topk=40
+
+ python img_edit.py \
+     --gpu 3 \
+     --seed 0 \
+     --img_path 'data/images/cabin.jpg' \
+     --source_prompt "$source_prompt" \
+     --target_prompt "$target_prompt" \
+     --results_dir 'results/cabin' \
+     --ca_steps $ca_steps \
+     --sa_steps $sa_steps \
+     --feature_steps $feature_steps \
+     --attn_topk $attn_topk
scripts/w_ca/run_car.sh ADDED
@@ -0,0 +1,21 @@
+ source_prompt='a sports car driving down the street'
+ target_prompt='stained glass window of a sports car driving down the street'
+
+ ca_steps=10
+ sa_steps=7
+ feature_steps=5
+
+ attn_topk=10
+
+ python img_edit.py \
+     --gpu 1 \
+     --seed 0 \
+     --img_path 'data/images/car.jpg' \
+     --source_prompt "$source_prompt" \
+     --target_prompt "$target_prompt" \
+     --results_dir 'results/car' \
+     --ca_steps $ca_steps \
+     --sa_steps $sa_steps \
+     --feature_steps $feature_steps \
+     --use_mask 0 \
+     --attn_topk $attn_topk
scripts/w_ca/run_cat_poly.sh ADDED
@@ -0,0 +1,21 @@
+ source_prompt='a cat is shown in a low polygonal style'
+ target_prompt='a fox is shown in a low polygonal style'
+
+ ca_steps=10
+ sa_steps=7
+ feature_steps=5
+
+ attn_topk=20
+
+ python img_edit.py \
+     --gpu 1 \
+     --seed 0 \
+     --img_path 'data/images/cat_poly.jpg' \
+     --source_prompt "$source_prompt" \
+     --target_prompt "$target_prompt" \
+     --results_dir 'results/cat_poly' \
+     --ca_steps $ca_steps \
+     --sa_steps $sa_steps \
+     --feature_steps $feature_steps \
+     --attn_topk $attn_topk
+
scripts/w_ca/run_flower.sh ADDED
@@ -0,0 +1,21 @@
+ source_prompt='a pink flower with yellow center in the middle'
+ target_prompt='a blue flower with red center in the middle'
+
+ ca_steps=10
+ sa_steps=7
+ feature_steps=5
+
+ attn_topk=20
+
+ python img_edit.py \
+     --gpu 1 \
+     --seed 0 \
+     --img_path 'data/images/flower.jpg' \
+     --source_prompt "$source_prompt" \
+     --target_prompt "$target_prompt" \
+     --results_dir 'results/flower' \
+     --ca_steps $ca_steps \
+     --sa_steps $sa_steps \
+     --feature_steps $feature_steps \
+     --attn_topk $attn_topk \
+     --blend_word 'flower'
scripts/w_ca/run_fruit.sh ADDED
@@ -0,0 +1,20 @@
+ source_prompt='white plate with fruits on it'
+ target_prompt='white plate with pizza on it'
+
+ ca_steps=10
+ sa_steps=7
+ feature_steps=5
+
+ attn_topk=40
+
+ python img_edit.py \
+     --gpu 0 \
+     --seed 0 \
+     --img_path 'data/images/fruit.jpg' \
+     --source_prompt "$source_prompt" \
+     --target_prompt "$target_prompt" \
+     --results_dir 'results/fruit' \
+     --ca_steps $ca_steps \
+     --sa_steps $sa_steps \
+     --feature_steps $feature_steps \
+     --attn_topk $attn_topk
scripts/w_ca/run_koala.sh ADDED
@@ -0,0 +1,20 @@
+ source_prompt='a koala is sitting on a tree'
+ target_prompt='a koala and a bird is sitting on a tree'
+
+ ca_steps=10
+ sa_steps=7
+ feature_steps=5
+
+ attn_topk=40
+
+ python img_edit.py \
+     --gpu 3 \
+     --seed 0 \
+     --img_path 'data/images/koala.jpg' \
+     --source_prompt "$source_prompt" \
+     --target_prompt "$target_prompt" \
+     --results_dir 'results/koala' \
+     --ca_steps $ca_steps \
+     --sa_steps $sa_steps \
+     --feature_steps $feature_steps \
+     --attn_topk $attn_topk
scripts/w_ca/run_owl_heart.sh ADDED
@@ -0,0 +1,20 @@
+ source_prompt='a cartoon painting of a cute owl with a heart on its body'
+ target_prompt='a cartoon painting of a cute owl with a circle on its body'
+
+ ca_steps=10
+ sa_steps=7
+ feature_steps=5
+
+ attn_topk=20
+
+ python img_edit.py \
+     --gpu 1 \
+     --seed 0 \
+     --img_path 'data/images/owl_heart.jpg' \
+     --source_prompt "$source_prompt" \
+     --target_prompt "$target_prompt" \
+     --results_dir 'results/owl_heart' \
+     --ca_steps $ca_steps \
+     --sa_steps $sa_steps \
+     --feature_steps $feature_steps \
+     --attn_topk $attn_topk
scripts/w_ca/run_statue.sh ADDED
@@ -0,0 +1,21 @@
+ source_prompt='photo of a statue in front view'
+ target_prompt='photo of a statue in side view'
+
+ ca_steps=10
+ sa_steps=7
+ feature_steps=5
+
+ attn_topk=60
+
+ python img_edit.py \
+     --gpu 0 \
+     --seed 0 \
+     --img_path 'data/images/statue.jpg' \
+     --source_prompt "$source_prompt" \
+     --target_prompt "$target_prompt" \
+     --results_dir 'results/statue' \
+     --ca_steps $ca_steps \
+     --sa_steps $sa_steps \
+     --feature_steps $feature_steps \
+     --attn_topk $attn_topk \
+     --blend_word 'statue'
scripts/w_ca/run_steak.sh ADDED
@@ -0,0 +1,20 @@
+ source_prompt='a plate with steak on it'
+ target_prompt='a plate with salmon on it'
+
+ ca_steps=10
+ sa_steps=7
+ feature_steps=5
+
+ attn_topk=40
+
+ python img_edit.py \
+     --gpu 0 \
+     --seed 0 \
+     --img_path 'data/images/steak.jpg' \
+     --source_prompt "$source_prompt" \
+     --target_prompt "$target_prompt" \
+     --results_dir 'results/steak' \
+     --ca_steps $ca_steps \
+     --sa_steps $sa_steps \
+     --feature_steps $feature_steps \
+     --attn_topk $attn_topk
scripts/w_ca/run_tennis.sh ADDED
@@ -0,0 +1,21 @@
+ source_prompt='a woman in a black tank top and pink shorts is about to hit a tennis ball'
+ target_prompt='a iron woman robot in a black tank top and pink shorts is about to hit a tennis ball'
+
+ ca_steps=10
+ sa_steps=7
+ feature_steps=5
+
+ attn_topk=20
+
+ python img_edit.py \
+     --gpu 0 \
+     --seed 0 \
+     --img_path 'data/images/tennis.jpg' \
+     --source_prompt "$source_prompt" \
+     --target_prompt "$target_prompt" \
+     --results_dir 'results/tennis' \
+     --ca_steps $ca_steps \
+     --sa_steps $sa_steps \
+     --feature_steps $feature_steps \
+     --attn_topk $attn_topk \
+     --blend_word 'woman'
scripts/w_ca/run_woman_book.sh ADDED
@@ -0,0 +1,20 @@
+ source_prompt='a woman sitting in the grass with a book'
+ target_prompt='a woman sitting in the grass with a laptop'
+
+ ca_steps=10
+ sa_steps=7
+ feature_steps=5
+
+ attn_topk=20
+
+ python img_edit.py \
+     --gpu 1 \
+     --seed 0 \
+     --img_path 'data/images/woman_book.jpg' \
+     --source_prompt "$source_prompt" \
+     --target_prompt "$target_prompt" \
+     --results_dir 'results/woman_book' \
+     --ca_steps $ca_steps \
+     --sa_steps $sa_steps \
+     --feature_steps $feature_steps \
+     --attn_topk $attn_topk
scripts/w_mask/run_cat_hat.sh ADDED
@@ -0,0 +1,21 @@
+ source_prompt='a cat wearing a pink hat'
+ target_prompt='a tiger wearing a pink hat'
+
+ ca_steps=10
+ sa_steps=7
+ feature_steps=5
+
+ attn_topk=20
+
+ python img_edit.py \
+     --gpu 3 \
+     --seed 0 \
+     --img_path 'data/images/cat_hat.jpg' \
+     --mask_path 'data/masks/cat_hat.jpg' \
+     --source_prompt "$source_prompt" \
+     --target_prompt "$target_prompt" \
+     --results_dir 'results/cat_hat' \
+     --ca_steps $ca_steps \
+     --sa_steps $sa_steps \
+     --feature_steps $feature_steps \
+     --attn_topk $attn_topk
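
Across the scripts above, the main hyperparameter that varies is `attn_topk` (10 for car, 20 for bird, 40 for cabin/fruit/koala/steak, 60 for statue). If you want to find a good value for a new image, a minimal sweep sketch is shown below; the paths and prompts reuse the bird example and the results directories are illustrative, not part of the repository:

```
import subprocess

# Sweep attn_topk over the values used across the provided scripts.
for topk in (10, 20, 40, 60):
    subprocess.run([
        "python", "img_edit.py",
        "--gpu", "0",
        "--seed", "0",
        "--img_path", "data/images/bird.jpg",
        "--source_prompt", "a blue and white bird sits on a branch",
        "--target_prompt", "a blue and white butterfly sits on a branch",
        "--results_dir", f"results/bird_topk_{topk}",
        "--attn_topk", str(topk),
    ], check=True)
```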