SahilCarterr committed on
Commit a49b34f · verified · 1 Parent(s): da6a35b

Create app.py

Files changed (1)
  1. app.py +569 -0
app.py ADDED
@@ -0,0 +1,569 @@
+ import gc
+ import os
+ import random
+ import shutil
+ import time
+ from typing import List, Optional, Union
+
+ import gradio as gr
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import spaces
+ import torch
+ from diffusers.utils.torch_utils import randn_tensor
+ from PIL import Image
+
+ from src.attn_utils.attn_utils import AttentionAdapter, AttnCollector
+ from src.attn_utils.flux_attn_processor import NewFluxAttnProcessor2_0
+ from src.attn_utils.seq_aligner import get_refinement_mapper
+ from src.callback.callback_fn import CallbackAll
+ from src.inversion.inverse import get_inversed_latent_list
+ from src.inversion.scheduling_flow_inverse import FlowMatchEulerDiscreteForwardScheduler
+ from src.pipeline.flux_pipeline import NewFluxPipeline
+ from src.transformer_utils.transformer_utils import FeatureCollector, FeatureReplace
+ from src.utils import (find_token_id_differences, find_word_token_indices,
+                        get_flux_pipeline)
+
+ # Load the Flux editing pipeline once at startup and keep it on the GPU.
+ pipe = get_flux_pipeline(pipeline_class=NewFluxPipeline)
+ pipe = pipe.to("cuda")
+
+ def fix_seed(random_seed):
+     """
+     Fix the random seed everywhere (torch, CUDA, numpy, random) so that
+     results are reproducible across runs.
+     """
+     torch.manual_seed(random_seed)
+     torch.cuda.manual_seed(random_seed)
+     torch.cuda.manual_seed_all(random_seed)  # if multiple GPUs are used
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+     np.random.seed(random_seed)
+     random.seed(random_seed)
+
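+ # Roughly, infer() runs the full editing pass: the source image is inverted
+ # back to noise along the rectified-flow ODE, then a [source, target] pair is
+ # denoised jointly while cross-/self-attention maps and intermediate
+ # transformer features from the source branch are injected into the target branch.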
+ @spaces.GPU
+ def infer(
+     input_image: Union[str, Image.Image],       # ⬅️ Main UI (uploaded image)
+     target_prompt: Union[str, List[str]] = '',  # ⬅️ Main UI (text prompt)
+     source_prompt: Union[str, List[str]] = '',  # ⬅️ Advanced accordion
+     seed: int = 0,                              # ⬅️ Advanced accordion
+     ca_steps: int = 10,                         # ⬅️ Advanced accordion
+     sa_steps: int = 7,                          # ⬅️ Advanced accordion
+     feature_steps: int = 5,                     # ⬅️ Advanced accordion
+     attn_topk: int = 20,                        # ⬅️ Advanced accordion
+     mask_image: Optional[Image.Image] = None,   # ⬅️ Advanced (optional upload)
+
+     # Everything below is backend-related or defaults, not exposed in the UI.
+     blend_word: str = '',
+     results_dir: str = 'results',
+     model: str = 'flux',
+
+     ca_attn_layer_from: int = 13,
+     ca_attn_layer_to: int = 45,
+     sa_attn_layer_from: int = 20,
+     sa_attn_layer_to: int = 45,
+     feature_layer_from: int = 13,
+     feature_layer_to: int = 20,
+     flow_steps: int = 7,
+     step_start: int = 0,
+     num_inference_steps: int = 28,
+     guidance_scale: float = 3.5,
+     text_scale: float = 4.0,
+     mid_step_index: int = 14,
+     use_mask: bool = True,
+     use_ca_mask: bool = True,
+     mask_steps: int = 18,
+     mask_dilation: int = 3,
+     mask_nbins: int = 128
+ ):
+     if isinstance(mask_image, Image.Image):
+         # Ensure the mask is single channel.
+         if mask_image.mode != "L":
+             mask_image = mask_image.convert("L")
+
+     fix_seed(seed)
+     device = torch.device('cuda')
+
+     # Transformer block indices used for cross-attn, self-attn, and feature injection.
+     layer_order = range(57)
+
+     ca_layer_list = layer_order[ca_attn_layer_from:ca_attn_layer_to]
+     sa_layer_list = layer_order[feature_layer_to:sa_attn_layer_to]
+     feature_layer_list = layer_order[feature_layer_from:feature_layer_to]
+
+     source_img = input_image.resize((1024, 1024)).convert("RGB")
+     result_img_dir = f"{results_dir}/seed_{seed}/{target_prompt}"
+
+     prompts = [source_prompt, target_prompt]
+     mask_path = mask_image
+     print(prompts)
+     mask = None
+
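+     # Mask selection (below): a user-supplied mask is used directly for latent
+     # blending; otherwise, when a source prompt is given, a mask can be derived
+     # from the cross-attention of the tokens that differ between the source and
+     # target prompts.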
+     if use_mask:
+         use_mask = True
+
+         if mask_path is not None:
+             mask = mask_path
+             mask = torch.tensor(np.array(mask)).bool()
+             mask = mask.to(device)
+
+             # Increase the latent blending steps if the ground-truth mask is used.
+             mask_steps = int(num_inference_steps * 0.9)
+
+             source_ca_index = None
+             target_ca_index = None
+             use_ca_mask = False
+
+         elif use_ca_mask and source_prompt:
+             mask = None
+             if blend_word and blend_word in source_prompt:
+                 editing_source_token_index = find_word_token_indices(source_prompt, blend_word, pipe.tokenizer_2)
+                 editing_target_token_index = None
+             else:
+                 editing_tokens_info = find_token_id_differences(*prompts, pipe.tokenizer_2)
+                 editing_source_token_index = editing_tokens_info['prompt_1']['index']
+                 editing_target_token_index = editing_tokens_info['prompt_2']['index']
+
+             use_ca_mask = True
+             if editing_source_token_index:
+                 source_ca_index = editing_source_token_index
+                 target_ca_index = None
+             elif editing_target_token_index:
+                 source_ca_index = None
+                 target_ca_index = editing_target_token_index
+             else:
+                 source_ca_index = None
+                 target_ca_index = None
+                 use_ca_mask = False
+
+         else:
+             source_ca_index = None
+             target_ca_index = None
+             use_ca_mask = False
+
+     else:
+         use_mask = False
+         use_ca_mask = False
+         source_ca_index = None
+         target_ca_index = None
+
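+     # When a source prompt is available, token-level mappers and blend weights
+     # are built so that source cross-attention maps can be aligned with the
+     # target prompt's tokens (the I2T-CA injection referenced below).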
+     if source_prompt:
+         # Use I2T-CA injection.
+         mappers, alphas = get_refinement_mapper(prompts, pipe.tokenizer_2, max_len=512)
+         mappers = mappers.to(device=device)
+         alphas = alphas.to(device=device, dtype=pipe.dtype)
+         alphas = alphas[:, None, None, :]
+
+         attn_adj_from = 1
+     else:
+         # No I2T-CA injection without a source prompt.
+         mappers = None
+         alphas = None
+
+         ca_steps = 0
+         attn_adj_from = 3
+
+     # Attention controller: injects source attention into the target branch
+     # for the configured layers and number of steps.
+     attn_controller = AttentionAdapter(
+         ca_layer_list=ca_layer_list,
+         sa_layer_list=sa_layer_list,
+         ca_steps=ca_steps,
+         sa_steps=sa_steps,
+         method='replace_topk',
+         topk=attn_topk,
+         text_scale=text_scale,
+         mappers=mappers,
+         alphas=alphas,
+         attn_adj_from=attn_adj_from,
+         save_source_ca=source_ca_index is not None,
+         save_target_ca=target_ca_index is not None,
+     )
+
+     attn_collector = AttnCollector(
+         transformer=pipe.transformer,
+         controller=attn_controller,
+         attn_processor_class=NewFluxAttnProcessor2_0,
+     )
+
+     feature_controller = FeatureReplace(
+         layer_list=feature_layer_list,
+         feature_steps=feature_steps,
+     )
+
+     feature_collector = FeatureCollector(
+         transformer=pipe.transformer,
+         controller=feature_controller,
+     )
+
+     num_prompts = len(prompts)
+
+     # Sample the initial Gaussian latent and pack it into Flux's patch layout.
+     shape = (1, 16, 128, 128)
+     generator = torch.Generator(device=device).manual_seed(seed)
+     latents = randn_tensor(shape, device=device, generator=generator)
+     latents = pipe._pack_latents(latents, *latents.shape)
+
+     # Run inversion with the vanilla attention/transformer implementations.
+     attn_collector.restore_orig_attention()
+     feature_collector.restore_orig_transformer()
+
+     t0 = time.perf_counter()
+
+     # Invert the source image back to noise along the rectified-flow ODE.
+     inv_latents = get_inversed_latent_list(
+         pipe,
+         source_img,
+         random_noise=latents,
+         num_inference_steps=num_inference_steps,
+         backward_method="ode",
+         use_prompt_for_inversion=False,
+         guidance_scale_for_inversion=0,
+         prompt_for_inversion='',
+         flow_steps=flow_steps,
+     )
+
+     # Reverse the trajectory so index 0 corresponds to the first denoising step.
+     source_latents = inv_latents[::-1]
+     target_latents = inv_latents[::-1]
+
+     # Patch the transformer so the controllers see every attention/feature call.
+     attn_collector.register_attention_control()
+     feature_collector.register_transformer_control()
+
+     # Per-step callback handling feature injection and mask-based latent blending.
+     callback_fn = CallbackAll(
+         latents=source_latents,
+         attn_collector=attn_collector,
+         feature_collector=feature_collector,
+         feature_inject_steps=feature_steps,
+         mid_step_index=mid_step_index,
+         step_start=step_start,
+         use_mask=use_mask,
+         use_ca_mask=use_ca_mask,
+         source_ca_index=source_ca_index,
+         target_ca_index=target_ca_index,
+         mask_kwargs={'dilation': mask_dilation},
+         mask_steps=mask_steps,
+         mask=mask,
+     )
+
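+     # The denoising batch below is [source, target]: the target row starts from
+     # the inverted latent at step_start, while the source row is reset to the
+     # latent of the intermediate (mid) step so source features and attention can
+     # be extracted from there.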
+     init_latent = target_latents[step_start]
+     init_latent = init_latent.repeat(num_prompts, 1, 1)
+     init_latent[0] = source_latents[mid_step_index]
+
+     os.makedirs(result_img_dir, exist_ok=True)
+     pipe.scheduler = FlowMatchEulerDiscreteForwardScheduler.from_config(
+         pipe.scheduler.config,
+         step_start=step_start,
+         margin_index_from_image=0
+     )
+
+     # Reset controller state before the joint source/target denoising pass.
+     attn_controller.reset()
+     feature_controller.reset()
+     attn_controller.text_scale = text_scale
+     attn_controller.cur_step = step_start
+     feature_controller.cur_step = step_start
+
+     with torch.no_grad():
+         images = pipe(
+             prompts,
+             latents=init_latent,
+             num_images_per_prompt=1,
+             guidance_scale=guidance_scale,
+             num_inference_steps=num_inference_steps,
+             generator=generator,
+             callback_on_step_end=callback_fn,
+             mid_step_index=mid_step_index,
+             step_start=step_start,
+             callback_on_step_end_tensor_inputs=['latents'],
+         ).images
+
+     t1 = time.perf_counter()
+     print(f"Done in {t1 - t0:.1f}s.")
+
+     # Save the source and edited images.
+     source_img_path = os.path.join(result_img_dir, "source.png")
+     source_img.save(source_img_path)
+     final_image = input_image
+     for i, img in enumerate(images[1:]):
+         target_img_path = os.path.join(result_img_dir, f"target_{i}.png")
+         img.save(target_img_path)
+         final_image = img
+
+     target_text_path = os.path.join(result_img_dir, "target_prompts.txt")
+     with open(target_text_path, 'w') as file:
+         file.write(target_prompt + '\n')
+
+     source_text_path = os.path.join(result_img_dir, "source_prompt.txt")
+     with open(source_text_path, 'w') as file:
+         file.write(source_prompt + '\n')
+
+     # Side-by-side overview figure: source | reconstruction | edit.
+     images = [source_img] + images
+
+     fs = 3
+     n = len(images)
+     fig, ax = plt.subplots(1, n, figsize=(n * fs, 1 * fs))
+
+     for i, img in enumerate(images):
+         ax[i].imshow(img)
+
+     ax[0].set_title('source')
+     ax[1].set_title(source_prompt, fontsize=7)
+     ax[2].set_title(target_prompt, fontsize=7)
+
+     overall_img_path = os.path.join(result_img_dir, "overall.png")
+     plt.savefig(overall_img_path, bbox_inches='tight')
+     plt.close()
+
+     mask_save_dir = os.path.join(result_img_dir, "mask")
+     os.makedirs(mask_save_dir, exist_ok=True)
+
+     if use_ca_mask:
+         ca_mask_path = os.path.join(mask_save_dir, "mask_ca.png")
+         mask_img = Image.fromarray((callback_fn.mask.cpu().float().numpy() * 255).astype(np.uint8)).convert('L')
+         mask_img.save(ca_mask_path)
+
+     # Free GPU memory and remove the temporary result directory.
+     del inv_latents
+     del init_latent
+     gc.collect()
+     torch.cuda.empty_cache()
+     shutil.rmtree(result_img_dir)
+     shutil.rmtree(results_dir)
+
+     return final_image, seed, gr.Button(visible=True)
+
+ MAX_SEED = np.iinfo(np.int32).max
+
+ @spaces.GPU
+ def infer_example(input_image, target_prompt, source_prompt, seed, ca_steps, sa_steps, feature_steps, attn_topk, mask_image=None):
+     # Thin wrapper used by gr.Examples: returns only the image and seed.
+     img, seed, _ = infer(
+         input_image=input_image,
+         target_prompt=target_prompt,
+         source_prompt=source_prompt,
+         seed=seed,
+         ca_steps=ca_steps,
+         sa_steps=sa_steps,
+         feature_steps=feature_steps,
+         attn_topk=attn_topk,
+         mask_image=mask_image
+     )
+     return img, seed
+
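+ # Gradio UI: image and prompts on the left, advanced sliders in an accordion,
+ # the edited result plus a "Reuse this image" button on the right.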
+ with gr.Blocks() as demo:
+     with gr.Column(elem_id="col-container"):
+         gr.Markdown("""# ReFlex
+ Text-Guided Editing of Real Images in Rectified Flow via Mid-Step Feature Extraction and Attention Adaptation
+ [[blog]](https://wlaud1001.github.io/ReFlex/) | [[Github]](https://github.com/wlaud1001/ReFlex)
+ """)
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.Image(label="Upload the image for editing", type="pil")
+                 mask_image = gr.Image(label="Upload optional mask", type="pil")
+
+                 with gr.Row():
+                     target_prompt = gr.Text(
+                         label="Target Prompt",
+                         show_label=False,
+                         max_lines=1,
+                         placeholder="Describe the Edited Image",
+                         container=False,
+                     )
+
+                 with gr.Column():
+                     source_prompt = gr.Text(
+                         label="Source Prompt",
+                         show_label=False,
+                         max_lines=1,
+                         placeholder="Enter source prompt (optional): Describe the Input Image",
+                         container=False,
+                     )
+                 run_button = gr.Button("Run", scale=10)
+
+                 with gr.Accordion("Advanced Settings", open=False):
+                     seed = gr.Slider(
+                         label="Seed",
+                         minimum=0,
+                         maximum=MAX_SEED,
+                         step=1,
+                         value=0,
+                     )
+                     ca_steps = gr.Slider(
+                         label="Cross-Attn (CA) Steps",
+                         minimum=0,
+                         maximum=20,
+                         step=1,
+                         value=10
+                     )
+                     sa_steps = gr.Slider(
+                         label="Self-Attn (SA) Steps",
+                         minimum=0,
+                         maximum=20,
+                         step=1,
+                         value=7
+                     )
+                     feature_steps = gr.Slider(
+                         label="Feature Injection Steps",
+                         minimum=0,
+                         maximum=20,
+                         step=1,
+                         value=5
+                     )
+                     attn_topk = gr.Slider(
+                         label="Attention Top-K",
+                         minimum=1,
+                         maximum=64,
+                         step=1,
+                         value=20
+                     )
+
+             with gr.Column():
+                 result = gr.Image(label="Result", show_label=False, interactive=False)
+                 reuse_button = gr.Button("Reuse this image", visible=False)
+
+     examples = gr.Examples(
+         examples=[
+             # Examples without a source prompt or mask
+             [
+                 "data/images/bear.jpeg",
+                 "an image of Paddington the bear",
+                 "",
+                 0, 0, 12, 7, 20,
+                 None
+             ],
+             [
+                 "data/images/bird_painting.jpg",
+                 "a photo of an eagle in the sky",
+                 "",
+                 0, 0, 12, 7, 20,
+                 None
+             ],
+             [
+                 "data/images/dancing.jpeg",
+                 "a couple of silver robots dancing in the garden",
+                 "",
+                 0, 0, 12, 7, 20,
+                 None
+             ],
+             [
+                 "data/images/real_karate.jpeg",
+                 "a silver robot in the snow",
+                 "",
+                 0, 0, 12, 7, 20,
+                 None
+             ],
+             # Examples with a source prompt, without a mask
+             [
+                 "data/images/woman_book.jpg",
+                 "a woman sitting in the grass with a laptop",
+                 "a woman sitting in the grass with a book",
+                 0, 10, 7, 5, 20,
+                 None
+             ],
+             [
+                 "data/images/statue.jpg",
+                 "photo of a statue in side view",
+                 "photo of a statue in front view",
+                 0, 10, 7, 5, 60,
+                 None
+             ],
+             [
+                 "data/images/tennis.jpg",
+                 "a iron woman robot in a black tank top and pink shorts is about to hit a tennis ball",
+                 "a woman in a black tank top and pink shorts is about to hit a tennis ball",
+                 0, 10, 7, 5, 20,
+                 None
+             ],
+             [
+                 "data/images/owl_heart.jpg",
+                 "a cartoon painting of a cute owl with a circle on its body",
+                 "a cartoon painting of a cute owl with a heart on its body",
+                 0, 10, 7, 5, 20,
+                 None
+             ],
+             # Examples with a ground-truth mask
+             [
+                 "data/images/girl_mountain.jpg",
+                 "a woman with her arms outstretched in front of the NewYork",
+                 "a woman with her arms outstretched on top of a mountain",
+                 0, 10, 7, 5, 20,
+                 "data/masks/girl_mountain.jpg"
+             ],
+             [
+                 "data/images/santa.jpg",
+                 "the christmas illustration of a santa's angry face",
+                 "the christmas illustration of a santa's laughing face",
+                 0, 10, 7, 5, 20,
+                 "data/masks/santa.jpg"
+             ],
+             [
+                 "data/images/cat_mirror.jpg",
+                 "a tiger sitting next to a mirror",
+                 "a cat sitting next to a mirror",
+                 0, 10, 7, 5, 20,
+                 "data/masks/cat_mirror.jpg"
+             ],
+         ],
+         inputs=[
+             input_image,
+             target_prompt,
+             source_prompt,
+             seed,
+             ca_steps,
+             sa_steps,
+             feature_steps,
+             attn_topk,
+             mask_image
+         ],
+         outputs=[result, seed],
+         fn=infer_example,
+         cache_examples="lazy"
+     )
+
+     # Run the edit when the button is clicked or the target prompt is submitted.
+     gr.on(
+         triggers=[run_button.click, target_prompt.submit],
+         fn=infer,
+         inputs=[
+             input_image,
+             target_prompt,
+             source_prompt,
+             seed,
+             ca_steps,
+             sa_steps,
+             feature_steps,
+             attn_topk,
+             mask_image
+         ],
+         outputs=[result, seed, reuse_button]
+     )
+
+     # Feed the result back in as the next input image.
+     reuse_button.click(
+         fn=lambda image: image,
+         inputs=[result],
+         outputs=[input_image]
+     )
+
+ demo.launch(share=True, debug=True)