diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..f859a2b22e3caa16e4387be2c3c89f051d8d7f5e 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data/images/cabin.jpg filter=lfs diff=lfs merge=lfs -text
+data/images/fruit.jpg filter=lfs diff=lfs merge=lfs -text
+data/images/koala.jpg filter=lfs diff=lfs merge=lfs -text
+data/images/man_tree.jpg filter=lfs diff=lfs merge=lfs -text
+data/images/meditation.png filter=lfs diff=lfs merge=lfs -text
+data/images/old_couple.jpg filter=lfs diff=lfs merge=lfs -text
+data/images/statue.jpg filter=lfs diff=lfs merge=lfs -text
+data/images/steak.jpg filter=lfs diff=lfs merge=lfs -text
+data/images/woman_book.jpg filter=lfs diff=lfs merge=lfs -text
+images/main_figure.png filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
index 77564a5add06c0a3569e5fc9f1c4ac76a1584932..73d10bc1adf43916d74498aaed12e32e1a01f62b 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,68 @@
----
-title: ReFlex
-emoji: 📚
-colorFrom: red
-colorTo: yellow
-sdk: gradio
-sdk_version: 5.38.0
-app_file: app.py
-pinned: false
-license: mit
-short_description: Text-Guided Editing of Real Images
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# ReFlex: Text-Guided Editing of Real Images in Rectified Flow via Mid-Step Feature Extraction and Attention Adaptation
+
+### [ICCV 2025] Official PyTorch implementation of the paper: "ReFlex: Text-Guided Editing of Real Images in Rectified Flow via Mid-Step Feature Extraction and Attention Adaptation"
+by Jimyeon Kim, Jungwon Park, Yeji Song, Nojun Kwak, Wonjong Rhee†.
+
+Seoul National University
+
+[arXiv](https://arxiv.org/abs/2507.01496)
+ 
+[Project Page](https://wlaud1001.github.io/ReFlex/)
+
+
+
+![main](./images/main_figure.png)
+
+## Setup
+```
+git clone https://github.com/wlaud1001/ReFlex.git
+cd ReFlex
+
+conda create -n reflex python=3.10
+conda activate reflex
+pip install -r requirements.txt
+```
+
+## Run
+
+### Run example
+```
+python img_edit.py \
+    --gpu {gpu} \
+    --seed {seed} \
+    --img_path {source_img_path} \
+    --source_prompt {source_prompt} \
+    --target_prompt {target_prompt} \
+    --results_dir {results_dir} \
+    --feature_steps {feature_steps} \
+    --attn_topk {attn_topk}
+```
+### Arguments
+- --gpu: Index of the GPU to use.
+- --seed: Random seed.
+- --img_path: Path to the input real image to be edited.
+- --mask_path (optional): Path to a ground-truth mask for local editing.
+  - If provided, this mask is used directly.
+  - If omitted, the editing mask is automatically generated from attention maps.
+- --source_prompt (optional): Text prompt describing the content of the input image.
+  - If provided, mask generation and latent blending will be applied.
+  - If omitted, editing proceeds without latent blending.
+- --target_prompt: Text prompt describing the desired edited image.
+- --blend_word (optional): Word in --source_prompt used to guide mask generation via its I2T-CA map.
+  - If omitted, the blend word is automatically inferred by comparing source_prompt and target_prompt.
+- --results_dir: Directory to save the output images.
+
+### Scripts
+We also provide several example scripts in the `./scripts` directory for common use cases and reproducible experiments.
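+
+For reference, running `scripts/w_ca/run_bird.sh` is equivalent to the following direct call (all values are taken verbatim from that script; adjust `--gpu` to an available device index on your machine):
+```
+python img_edit.py \
+    --gpu 3 \
+    --seed 0 \
+    --img_path 'data/images/bird.jpg' \
+    --source_prompt 'a blue and white bird sits on a branch' \
+    --target_prompt 'a blue and white butterfly sits on a branch' \
+    --results_dir 'results/bird' \
+    --ca_steps 10 \
+    --sa_steps 7 \
+    --feature_steps 5 \
+    --attn_topk 20
+```
+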
+#### Script Categories +- scripts/wo_ca/: Cases where the source prompt is not given. I2T-CA adaptation and latent blending are not applied. +- scripts/w_ca/: Cases where the source prompt is given, and the editing mask for latent blending is automatically generated from the attention map. +- scripts/w_mask/: Cases where a ground-truth mask for local editing is provided and directly used for latent blending. + +You can run a script as follows: +``` +./scripts/wo_ca/run_bear.sh +./scripts/w_ca/run_bird.sh +./scripts/w_mask/run_cat_hat.sh +``` \ No newline at end of file diff --git a/data/images/bear.jpeg b/data/images/bear.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..9ce8073fb8a086892479fcc44820dc3b030288ce Binary files /dev/null and b/data/images/bear.jpeg differ diff --git a/data/images/bird.jpg b/data/images/bird.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9d2da7cf7c5307c73398829439ce126c4779b49d Binary files /dev/null and b/data/images/bird.jpg differ diff --git a/data/images/bird_painting.jpg b/data/images/bird_painting.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fbb57a478340e6404f9237dedc0febedff8771e7 Binary files /dev/null and b/data/images/bird_painting.jpg differ diff --git a/data/images/cabin.jpg b/data/images/cabin.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c11cb5660537b448192f2e9e61fa72ca8e41e885 --- /dev/null +++ b/data/images/cabin.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:57c526d303939ec8fa1e6fe6780ba1d8be5aacfe0ce6c4eeaf1b2771e29a534f +size 123301 diff --git a/data/images/car.jpg b/data/images/car.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f8dfcc69a579255635e3a7404bfdd36ef9477450 Binary files /dev/null and b/data/images/car.jpg differ diff --git a/data/images/cat_hat.jpg b/data/images/cat_hat.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ed414590e530cd73b2d3fd00aa58cda2048706af Binary files /dev/null and b/data/images/cat_hat.jpg differ diff --git a/data/images/cat_mirror.jpg b/data/images/cat_mirror.jpg new file mode 100644 index 0000000000000000000000000000000000000000..dae2241dd3a83a77005a21603bd676288d611675 Binary files /dev/null and b/data/images/cat_mirror.jpg differ diff --git a/data/images/cat_poly.jpg b/data/images/cat_poly.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b955ac71b7daf040b211376020d65737b6a65939 Binary files /dev/null and b/data/images/cat_poly.jpg differ diff --git a/data/images/dancing.jpeg b/data/images/dancing.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..f8e13b8aa42278bd4e75fe3f89645ffc884872dc Binary files /dev/null and b/data/images/dancing.jpeg differ diff --git a/data/images/flower.jpg b/data/images/flower.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ec9f8c860317fa4c104e5d72ba50fe04c6c062a1 Binary files /dev/null and b/data/images/flower.jpg differ diff --git a/data/images/fruit.jpg b/data/images/fruit.jpg new file mode 100644 index 0000000000000000000000000000000000000000..38b264df988139326eb9ce8525df784d0d9a7890 --- /dev/null +++ b/data/images/fruit.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e2dfeda0bba2b887ac5b082771b74bbe990110a712e0ebaed2c3c6abca2d8630 +size 139142 diff --git a/data/images/girl_mountain.jpg b/data/images/girl_mountain.jpg new file mode 100644 index 
0000000000000000000000000000000000000000..e9e17124f329a455cb60730a96858e216457b35a Binary files /dev/null and b/data/images/girl_mountain.jpg differ diff --git a/data/images/koala.jpg b/data/images/koala.jpg new file mode 100644 index 0000000000000000000000000000000000000000..57fca8f0753751ab4421e4c5e1d2711a0a10485f --- /dev/null +++ b/data/images/koala.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:be9ab5f91b329a5cc53e55bac9eba350aaf80b39a04e8e6a03d147713a5eb283 +size 149969 diff --git a/data/images/man_tree.jpg b/data/images/man_tree.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6e5be6234b6c5cc089aa47c8e33929136671b8d0 --- /dev/null +++ b/data/images/man_tree.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6d53f9d74aeb377b65ca9fac3684dd5495451cb09cc4aeacb614d912ec89f462 +size 101634 diff --git a/data/images/meditation.png b/data/images/meditation.png new file mode 100644 index 0000000000000000000000000000000000000000..39719def6abd12bec210f7f79b682854b1db30cb --- /dev/null +++ b/data/images/meditation.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7c1ebb8230cee73caa80b9a9b5ec1ae0c89d12742f06be789f60a53f9177f9c1 +size 288440 diff --git a/data/images/old_couple.jpg b/data/images/old_couple.jpg new file mode 100644 index 0000000000000000000000000000000000000000..06f61e0a7b92ea4ce020fef3ef4ba9487fd00716 --- /dev/null +++ b/data/images/old_couple.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:405cc22840c86e79aeef24f36ce0a6a1e90491bf3badabfd1c16d0cc300c17f2 +size 151138 diff --git a/data/images/owl_heart.jpg b/data/images/owl_heart.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e562026497d52db2c5049369e0188b8332a5d626 Binary files /dev/null and b/data/images/owl_heart.jpg differ diff --git a/data/images/raven.jpg b/data/images/raven.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c95a7aa9df26ba6a58f5914138f6ef678646b120 Binary files /dev/null and b/data/images/raven.jpg differ diff --git a/data/images/real_karate.jpeg b/data/images/real_karate.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..1dcb0225247e52558b33ded15557236f69a32d19 Binary files /dev/null and b/data/images/real_karate.jpeg differ diff --git a/data/images/santa.jpg b/data/images/santa.jpg new file mode 100644 index 0000000000000000000000000000000000000000..dbc619df33762eeca98d9a94fc93edb325a1e025 Binary files /dev/null and b/data/images/santa.jpg differ diff --git a/data/images/squirrel.jpg b/data/images/squirrel.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c4d4bd5743705cceef20021e64c586e6af5c760e Binary files /dev/null and b/data/images/squirrel.jpg differ diff --git a/data/images/statue.jpg b/data/images/statue.jpg new file mode 100644 index 0000000000000000000000000000000000000000..25737e1a61104d971d51172fc6ac91957a62d856 --- /dev/null +++ b/data/images/statue.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d7a02cb1cfb21a69bfb3bed2d56c74799385860c625a78f2f9c9527d0b96d123 +size 214234 diff --git a/data/images/steak.jpg b/data/images/steak.jpg new file mode 100644 index 0000000000000000000000000000000000000000..278f3240f6367715fc604e48334b41c7826517c9 --- /dev/null +++ b/data/images/steak.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:60a98952c0d657c652d7c686d6eb93419cb3dff1495aca93a4ddcbcd2c30af32 +size 159623 diff --git 
a/data/images/tennis.jpg b/data/images/tennis.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f6aed9fe15d9f2ae8689186bd71cea8337e66de1 Binary files /dev/null and b/data/images/tennis.jpg differ diff --git a/data/images/woman_book.jpg b/data/images/woman_book.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a4f1d9fac43f49e67256d28027bd460567fd14ac --- /dev/null +++ b/data/images/woman_book.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aaa44eba168cbbec858b846ba3f801fd67e5e4d4a7d8f76d28b56661ceaac992 +size 112555 diff --git a/data/masks/cat_hat.jpg b/data/masks/cat_hat.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b11b4d68ab11db414ae24fcd349741cccdd70bfe Binary files /dev/null and b/data/masks/cat_hat.jpg differ diff --git a/data/masks/cat_mirror.jpg b/data/masks/cat_mirror.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3bd2b48daaeb2c14c7959f14816bfec1abbe01f3 Binary files /dev/null and b/data/masks/cat_mirror.jpg differ diff --git a/data/masks/girl_mountain.jpg b/data/masks/girl_mountain.jpg new file mode 100644 index 0000000000000000000000000000000000000000..43c1ab2fb0b586cf9fc9d0a97d52ef815e360469 Binary files /dev/null and b/data/masks/girl_mountain.jpg differ diff --git a/data/masks/man_tree.jpg b/data/masks/man_tree.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5e7a265b34bf3af1e5141ca4f2097f2e2e9758a1 Binary files /dev/null and b/data/masks/man_tree.jpg differ diff --git a/data/masks/old_couple.jpg b/data/masks/old_couple.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d459d96c54c0bc2c7247f9e2f66cf8672e6ad6fb Binary files /dev/null and b/data/masks/old_couple.jpg differ diff --git a/data/masks/raven.jpg b/data/masks/raven.jpg new file mode 100644 index 0000000000000000000000000000000000000000..27b1d3ebb578251e79fe01b029ef72153e22b4a8 Binary files /dev/null and b/data/masks/raven.jpg differ diff --git a/data/masks/santa.jpg b/data/masks/santa.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ece225885064abd5514cf7ffe23813ac2b0b128c Binary files /dev/null and b/data/masks/santa.jpg differ diff --git a/images/main_figure.png b/images/main_figure.png new file mode 100644 index 0000000000000000000000000000000000000000..0e3ad4e91751ddef7ef7f0cd56b9be5c7c3c2279 --- /dev/null +++ b/images/main_figure.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:15cdc45b0a49a939fa22c167d9392cdd147d451f519ab616bd065c018860722e +size 15438913 diff --git a/img_edit.py b/img_edit.py new file mode 100644 index 0000000000000000000000000000000000000000..15801fe3ca6eb15bbd76c001eb73b85dccf05581 --- /dev/null +++ b/img_edit.py @@ -0,0 +1,492 @@ +import argparse +import gc +import os +import random +import re +import time +from distutils.util import strtobool + +import pandas as pd + +parser = argparse.ArgumentParser() +parser.add_argument( + "--img_path", + type=str, +) +parser.add_argument( + "--target_prompt", + type=str, +) +parser.add_argument( + "--source_prompt", + type=str, + default='' +) +parser.add_argument( + "--blend_word", + type=str, + default='' +) +parser.add_argument( + "--mask_path", + type=str, + default=None +) + + +parser.add_argument( + "--gpu", + type=str, + default="0", +) +parser.add_argument( + "--seed", + type=int, + default=0 +) +parser.add_argument( + "--results_dir", + type=str, + default='results' +) + + +parser.add_argument( + "--model", + type=str, + 
default='flux',
+    choices=['flux']
+)
+
+parser.add_argument(
+    "--ca_steps",
+    type=int,
+    default=10,
+    help="Number of steps to apply I2T-CA adaptation and injection.",
+)
+
+parser.add_argument(
+    "--sa_steps",
+    type=int,
+    default=7,
+    help="Number of steps to apply I2I-SA adaptation and injection.",
+)
+
+parser.add_argument(
+    "--feature_steps",
+    type=int,
+    default=5,
+    help="Number of steps to inject residual features.",
+)
+
+
+parser.add_argument(
+    "--ca_attn_layer_from",
+    type=int,
+    default=13,
+    help="Layers to apply I2T-CA adaptation and injection.",
+)
+parser.add_argument(
+    "--ca_attn_layer_to",
+    type=int,
+    default=45,
+    help="Layers to apply I2T-CA adaptation and injection.",
+)
+
+parser.add_argument(
+    "--sa_attn_layer_from",
+    type=int,
+    default=20,
+    help="Layers to apply I2I-SA adaptation and injection.",
+)
+parser.add_argument(
+    "--sa_attn_layer_to",
+    type=int,
+    default=45,
+    help="Layers to apply I2I-SA adaptation and injection.",
+)
+
+parser.add_argument(
+    "--feature_layer_from",
+    type=int,
+    default=13,
+    help="Layers to inject residual features.",
+)
+parser.add_argument(
+    "--feature_layer_to",
+    type=int,
+    default=20,
+    help="Layers to inject residual features.",
+)
+
+parser.add_argument(
+    "--flow_steps",
+    type=int,
+    default=7,
+    help="Number of forward steps to apply before inversion.",
+)
+parser.add_argument(
+    "--step_start",
+    type=int,
+    default=0
+)
+
+
+parser.add_argument(
+    "--num_inference_steps",
+    type=int,
+    default=28
+)
+parser.add_argument(
+    "--guidance_scale",
+    type=float,
+    default=3.5,
+)
+
+parser.add_argument(
+    "--attn_topk",
+    type=int,
+    default=20,
+    help="Hyperparameter for I2I-SA adaptation."
+)
+
+parser.add_argument(
+    "--text_scale",
+    type=float,
+    default=4,
+    help="Hyperparameter for I2T-CA adaptation."
+)
+
+parser.add_argument(
+    "--mid_step_index",
+    type=int,
+    default=14,
+    help="Hyperparameter for mid-step feature extraction."
+)
+
+
+parser.add_argument(
+    "--use_mask",
+    type=strtobool,
+    default=True
+)
+
+parser.add_argument(
+    "--use_ca_mask",
+    type=strtobool,
+    default=True
+)
+
+parser.add_argument(
+    "--mask_steps",
+    type=int,
+    default=18,
+    help="Steps to apply latent blending."
+)
+
+parser.add_argument(
+    "--mask_dilation",
+    type=int,
+    default=3
+)
+parser.add_argument(
+    "--mask_nbins",
+    type=int,
+    default=128
+)
+
+args = parser.parse_args()
+
+os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.gpu}"
+
+import gc
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import yaml
+from diffusers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils.torch_utils import randn_tensor
+from PIL import Image
+
+from src.attn_utils.attn_utils import AttentionAdapter, AttnCollector
+from src.attn_utils.flux_attn_processor import NewFluxAttnProcessor2_0
+from src.attn_utils.seq_aligner import get_refinement_mapper
+from src.callback.callback_fn import CallbackAll
+from src.inversion.inverse import get_inversed_latent_list
+from src.inversion.scheduling_flow_inverse import \
+    FlowMatchEulerDiscreteForwardScheduler
+from src.pipeline.flux_pipeline import NewFluxPipeline
+from src.transformer_utils.transformer_utils import (FeatureCollector,
+                                                     FeatureReplace)
+from src.utils import (find_token_id_differences, find_word_token_indices,
+                       get_flux_pipeline, mask_decode, mask_interpolate)
+
+
+def fix_seed(random_seed):
+    """
+    Fix seeds to control any randomness in the code
+    (for reproducibility of the experiments' results).
+ """ + torch.manual_seed(random_seed) + torch.cuda.manual_seed(random_seed) + torch.cuda.manual_seed_all(random_seed) # if use multi-GPU + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + np.random.seed(random_seed) + random.seed(random_seed) + +def main(args): + fix_seed(args.seed) + device = torch.device('cuda') + + pipe = get_flux_pipeline(pipeline_class=NewFluxPipeline) + attn_proc = NewFluxAttnProcessor2_0 + pipe = pipe.to(device) + + layer_order = range(57) + + ca_layer_list = layer_order[args.ca_attn_layer_from:args.ca_attn_layer_to] + sa_layer_list = layer_order[args.feature_layer_to:args.sa_attn_layer_to] + feature_layer_list = layer_order[args.feature_layer_from:args.feature_layer_to] + + + img_path = args.img_path + source_img = Image.open(img_path).resize((1024, 1024)).convert("RGB") + img_base_name = os.path.splitext(img_path)[0].split('/')[-1] + result_img_dir = f"{args.results_dir}/seed_{args.seed}/{args.target_prompt}" + + source_prompt = args.source_prompt + target_prompt = args.target_prompt + prompts = [source_prompt, target_prompt] + + print(prompts) + mask = None + + if args.use_mask: + use_mask = True + + if args.mask_path is not None: + mask = Image.open(args.mask_path) + mask = torch.tensor(np.array(mask)).bool() + mask = mask.to(device) + + # Increase the latent blending steps if the ground truth mask is used. + args.mask_steps = int(args.num_inference_steps * 0.9) + + source_ca_index = None + target_ca_index = None + use_ca_mask = False + + elif args.use_ca_mask and source_prompt: + mask = None + if args.blend_word and args.blend_word in source_prompt: + editing_source_token_index = find_word_token_indices(source_prompt, args.blend_word, pipe.tokenizer_2) + editing_target_token_index = None + else: + editing_tokens_info = find_token_id_differences(*prompts, pipe.tokenizer_2) + editing_source_token_index = editing_tokens_info['prompt_1']['index'] + editing_target_token_index = editing_tokens_info['prompt_2']['index'] + + use_ca_mask = True + if editing_source_token_index: + source_ca_index = editing_source_token_index + target_ca_index = None + elif editing_target_token_index: + source_ca_index = None + target_ca_index = editing_target_token_index + else: + source_ca_index = None + target_ca_index = None + use_ca_mask = False + + else: + source_ca_index = None + target_ca_index = None + use_ca_mask = False + + else: + use_mask = False + use_ca_mask = False + source_ca_index = None + target_ca_index = None + + if source_prompt: + # Use I2T-CA injection + mappers, alphas = get_refinement_mapper(prompts, pipe.tokenizer_2, max_len=512) + mappers = mappers.to(device=device) + alphas = alphas.to(device=device, dtype=pipe.dtype) + alphas = alphas[:, None, None, :] + + ca_steps = args.ca_steps + attn_adj_from = 1 + + else: + # Not use I2T-CA injection + mappers = None + alphas = None + + ca_steps = 0 + attn_adj_from=3 + + sa_steps = args.sa_steps + feature_steps = args.feature_steps + + attn_controller = AttentionAdapter( + ca_layer_list=ca_layer_list, + sa_layer_list=sa_layer_list, + ca_steps=ca_steps, + sa_steps=sa_steps, + method='replace_topk', + topk=args.attn_topk, + text_scale=args.text_scale, + mappers=mappers, + alphas=alphas, + attn_adj_from=attn_adj_from, + save_source_ca=source_ca_index is not None, + save_target_ca=target_ca_index is not None, + ) + + attn_collector = AttnCollector( + transformer=pipe.transformer, + controller=attn_controller, + attn_processor_class=NewFluxAttnProcessor2_0, + ) + + feature_controller = 
FeatureReplace( + layer_list=feature_layer_list, + feature_steps=feature_steps, + ) + + feature_collector = FeatureCollector( + transformer=pipe.transformer, + controller=feature_controller, + ) + + num_prompts=len(prompts) + + shape = (1, 16, 128, 128) + generator = torch.Generator(device=device).manual_seed(args.seed) + latents = randn_tensor(shape, device=device, generator=generator) + latents = pipe._pack_latents(latents, *latents.shape) + + attn_collector.restore_orig_attention() + feature_collector.restore_orig_transformer() + + t0 = time.perf_counter() + + inv_latents = get_inversed_latent_list( + pipe, + source_img, + random_noise=latents, + num_inference_steps=args.num_inference_steps, + backward_method="ode", + use_prompt_for_inversion=False, + guidance_scale_for_inversion=0, + prompt_for_inversion='', + flow_steps=args.flow_steps, + ) + + source_latents = inv_latents[::-1] + target_latents = inv_latents[::-1] + + attn_collector.register_attention_control() + feature_collector.register_transformer_control() + + callback_fn = CallbackAll( + latents=source_latents, + attn_collector=attn_collector, + feature_collector=feature_collector, + feature_inject_steps=feature_steps, + mid_step_index=args.mid_step_index, + step_start=args.step_start, + use_mask=use_mask, + use_ca_mask=use_ca_mask, + source_ca_index=source_ca_index, + target_ca_index=target_ca_index, + mask_kwargs={'dilation': args.mask_dilation}, + mask_steps=args.mask_steps, + mask=mask, + ) + + init_latent = target_latents[args.step_start] + init_latent = init_latent.repeat(num_prompts, 1, 1) + init_latent[0] = source_latents[args.mid_step_index] + + os.makedirs(result_img_dir, exist_ok=True) + pipe.scheduler = FlowMatchEulerDiscreteForwardScheduler.from_config( + pipe.scheduler.config, + step_start=args.step_start, + margin_index_from_image=0 + ) + + attn_controller.reset() + feature_controller.reset() + attn_controller.text_scale = args.text_scale + attn_controller.cur_step = args.step_start + feature_controller.cur_step = args.step_start + + with torch.no_grad(): + images = pipe( + prompts, + latents=init_latent, + num_images_per_prompt=1, + guidance_scale=args.guidance_scale, + num_inference_steps=args.num_inference_steps, + generator=generator, + callback_on_step_end=callback_fn, + mid_step_index=args.mid_step_index, + step_start=args.step_start, + callback_on_step_end_tensor_inputs=['latents'], + ).images + + t1 = time.perf_counter() + print(f"Done in {t1 - t0:.1f}s.") + + source_img_path = os.path.join(result_img_dir, f"source.png") + source_img.save(source_img_path) + + for i, img in enumerate(images[1:]): + target_img_path = os.path.join(result_img_dir, f"target_{i}.png") + img.save(target_img_path) + + target_text_path = os.path.join(result_img_dir, f"target_prompts.txt") + with open(target_text_path, 'w') as file: + file.write(target_prompt + '\n') + + source_text_path = os.path.join(result_img_dir, f"source_prompt.txt") + with open(source_text_path, 'w') as file: + file.write(source_prompt + '\n') + + images = [source_img] + images + + fs=3 + n = len(images) + fig, ax = plt.subplots(1, n, figsize=(n*fs, 1*fs)) + + for i, img in enumerate(images): + ax[i].imshow(img) + + ax[0].set_title('source') + ax[1].set_title(source_prompt, fontsize=7) + ax[2].set_title(target_prompt, fontsize=7) + + overall_img_path = os.path.join(result_img_dir, f"overall.png") + plt.savefig(overall_img_path, bbox_inches='tight') + plt.close() + + mask_save_dir = os.path.join(result_img_dir, f"mask") + os.makedirs(mask_save_dir, 
exist_ok=True) + + if use_ca_mask: + ca_mask_path = os.path.join(mask_save_dir, f"mask_ca.png") + mask_img = Image.fromarray((callback_fn.mask.cpu().float().numpy() * 255).astype(np.uint8)).convert('L') + mask_img.save(ca_mask_path) + + del inv_latents + del init_latent + gc.collect() + torch.cuda.empty_cache() + +if __name__ == '__main__': + main(args) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..bccd636b997fd1eb97ad781b5731fe4bb442d7eb --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +diffusers==0.31.0 +torch==2.4.1 +pandas +matplotlib +transformers==4.44.2 +torchao +torchvision +opencv-python +scikit-image +accelerate +sentencepiece +protobuf \ No newline at end of file diff --git a/scripts/w_ca/run_bird.sh b/scripts/w_ca/run_bird.sh new file mode 100644 index 0000000000000000000000000000000000000000..f58c58e0dc9dbaafa2765376dd0a58c7faeedb3c --- /dev/null +++ b/scripts/w_ca/run_bird.sh @@ -0,0 +1,20 @@ +source_prompt='a blue and white bird sits on a branch' +target_prompt='a blue and white butterfly sits on a branch' + +ca_steps=10 +sa_steps=7 +feature_steps=5 + +attn_topk=20 + +python img_edit.py \ + --gpu 3 \ + --seed 0 \ + --img_path 'data/images/bird.jpg' \ + --source_prompt "$source_prompt" \ + --target_prompt "$target_prompt" \ + --results_dir 'results/bird' \ + --ca_steps $ca_steps \ + --sa_steps $sa_steps \ + --feature_steps $feature_steps \ + --attn_topk $attn_topk diff --git a/scripts/w_ca/run_cabin.sh b/scripts/w_ca/run_cabin.sh new file mode 100644 index 0000000000000000000000000000000000000000..25b55ae6ce4e6f2f5ac588b22890b7a366a7a7d6 --- /dev/null +++ b/scripts/w_ca/run_cabin.sh @@ -0,0 +1,20 @@ +source_prompt='a painting of a cabin in the snow with mountains in the background' +target_prompt='a painting of a car in the snow with mountains in the background' + +ca_steps=10 +sa_steps=7 +feature_steps=5 + +attn_topk=40 + +python img_edit.py \ + --gpu 3 \ + --seed 0 \ + --img_path 'data/images/cabin.jpg' \ + --source_prompt "$source_prompt" \ + --target_prompt "$target_prompt" \ + --results_dir 'results/cabin' \ + --ca_steps $ca_steps \ + --sa_steps $sa_steps \ + --feature_steps $feature_steps \ + --attn_topk $attn_topk diff --git a/scripts/w_ca/run_car.sh b/scripts/w_ca/run_car.sh new file mode 100644 index 0000000000000000000000000000000000000000..d56d72f75c61681a86307476d9e4b39c44ff0ba3 --- /dev/null +++ b/scripts/w_ca/run_car.sh @@ -0,0 +1,21 @@ +source_prompt='a sports car driving down the street' +target_prompt='stained glass window of a sports car driving down the street' + +ca_steps=10 +sa_steps=7 +feature_steps=5 + +attn_topk=10 + +python img_edit.py \ + --gpu 1 \ + --seed 0 \ + --img_path 'data/images/car.jpg' \ + --source_prompt "$source_prompt" \ + --target_prompt "$target_prompt" \ + --results_dir 'results/car' \ + --ca_steps $ca_steps \ + --sa_steps $sa_steps \ + --feature_steps $feature_steps \ + --use_mask 0 \ + --attn_topk $attn_topk diff --git a/scripts/w_ca/run_cat_poly.sh b/scripts/w_ca/run_cat_poly.sh new file mode 100644 index 0000000000000000000000000000000000000000..7c743653a23c1068c4349aaa00b8b16a3a76b2d3 --- /dev/null +++ b/scripts/w_ca/run_cat_poly.sh @@ -0,0 +1,21 @@ +source_prompt='a cat is shown in a low polygonal style' +target_prompt='a fox is shown in a low polygonal style' + +ca_steps=10 +sa_steps=7 +feature_steps=5 + +attn_topk=20 + +python img_edit.py \ + --gpu 1 \ + --seed 0 \ + --img_path 'data/images/cat_poly.jpg' \ + --source_prompt 
"$source_prompt" \ + --target_prompt "$target_prompt" \ + --results_dir 'results/cat_poly' \ + --ca_steps $ca_steps \ + --sa_steps $sa_steps \ + --feature_steps $feature_steps \ + --attn_topk $attn_topk + \ No newline at end of file diff --git a/scripts/w_ca/run_flower.sh b/scripts/w_ca/run_flower.sh new file mode 100644 index 0000000000000000000000000000000000000000..39aeb7d6f358233ae64f06eff6e00ecbe4d92637 --- /dev/null +++ b/scripts/w_ca/run_flower.sh @@ -0,0 +1,21 @@ +source_prompt='a pink flower with yellow center in the middle' +target_prompt='a blue flower with red center in the middle' + +ca_steps=10 +sa_steps=7 +feature_steps=5 + +attn_topk=20 + +python img_edit.py \ + --gpu 1 \ + --seed 0 \ + --img_path 'data/images/flower.jpg' \ + --source_prompt "$source_prompt" \ + --target_prompt "$target_prompt" \ + --results_dir 'results/flower' \ + --ca_steps $ca_steps \ + --sa_steps $sa_steps \ + --feature_steps $feature_steps \ + --attn_topk $attn_topk \ + --blend_word 'flower' diff --git a/scripts/w_ca/run_fruit.sh b/scripts/w_ca/run_fruit.sh new file mode 100644 index 0000000000000000000000000000000000000000..a56c02c04ceeb5e60adac9b4eb444d10eef5b64f --- /dev/null +++ b/scripts/w_ca/run_fruit.sh @@ -0,0 +1,20 @@ +source_prompt='white plate with fruits on it' +target_prompt='white plate with pizza on it' + +ca_steps=10 +sa_steps=7 +feature_steps=5 + +attn_topk=40 + +python img_edit.py \ + --gpu 0 \ + --seed 0 \ + --img_path 'data/images/fruit.jpg' \ + --source_prompt "$source_prompt" \ + --target_prompt "$target_prompt" \ + --results_dir 'results/fruit' \ + --ca_steps $ca_steps \ + --sa_steps $sa_steps \ + --feature_steps $feature_steps \ + --attn_topk $attn_topk \ No newline at end of file diff --git a/scripts/w_ca/run_koala.sh b/scripts/w_ca/run_koala.sh new file mode 100644 index 0000000000000000000000000000000000000000..38c6e5c7d58f3f8e5db15453c7cd4e7ea3a7e61c --- /dev/null +++ b/scripts/w_ca/run_koala.sh @@ -0,0 +1,20 @@ +source_prompt='a koala is sitting on a tree' +target_prompt='a koala and a bird is sitting on a tree' + +ca_steps=10 +sa_steps=7 +feature_steps=5 + +attn_topk=40 + +python img_edit.py \ + --gpu 3 \ + --seed 0 \ + --img_path 'data/images/koala.jpg' \ + --source_prompt "$source_prompt" \ + --target_prompt "$target_prompt" \ + --results_dir 'results/koala' \ + --ca_steps $ca_steps \ + --sa_steps $sa_steps \ + --feature_steps $feature_steps \ + --attn_topk $attn_topk diff --git a/scripts/w_ca/run_owl_heart.sh b/scripts/w_ca/run_owl_heart.sh new file mode 100644 index 0000000000000000000000000000000000000000..879ffac9dad185105f0840d2a97da1760c0b8f45 --- /dev/null +++ b/scripts/w_ca/run_owl_heart.sh @@ -0,0 +1,20 @@ +source_prompt='a cartoon painting of a cute owl with a heart on its body' +target_prompt='a cartoon painting of a cute owl with a circle on its body' + +ca_steps=10 +sa_steps=7 +feature_steps=5 + +attn_topk=20 + +python img_edit.py \ + --gpu 1 \ + --seed 0 \ + --img_path 'data/images/owl_heart.jpg' \ + --source_prompt "$source_prompt" \ + --target_prompt "$target_prompt" \ + --results_dir 'results/owl_heart' \ + --ca_steps $ca_steps \ + --sa_steps $sa_steps \ + --feature_steps $feature_steps \ + --attn_topk $attn_topk diff --git a/scripts/w_ca/run_statue.sh b/scripts/w_ca/run_statue.sh new file mode 100644 index 0000000000000000000000000000000000000000..20870119ca3ff8f92bbbff1b52f978ad7b4875d9 --- /dev/null +++ b/scripts/w_ca/run_statue.sh @@ -0,0 +1,21 @@ +source_prompt='photo of a statue in front view' +target_prompt='photo of a statue in side view' + 
+ca_steps=10 +sa_steps=7 +feature_steps=5 + +attn_topk=60 + +python img_edit.py \ + --gpu 0 \ + --seed 0 \ + --img_path 'data/images/statue.jpg' \ + --source_prompt "$source_prompt" \ + --target_prompt "$target_prompt" \ + --results_dir 'results/statue' \ + --ca_steps $ca_steps \ + --sa_steps $sa_steps \ + --feature_steps $feature_steps \ + --attn_topk $attn_topk \ + --blend_word 'statue' diff --git a/scripts/w_ca/run_steak.sh b/scripts/w_ca/run_steak.sh new file mode 100644 index 0000000000000000000000000000000000000000..a70118ad2f548e6645fd1a3dff6d289643533866 --- /dev/null +++ b/scripts/w_ca/run_steak.sh @@ -0,0 +1,20 @@ +source_prompt='a plate with steak on it' +target_prompt='a plate with salmon on it' + +ca_steps=10 +sa_steps=7 +feature_steps=5 + +attn_topk=40 + +python img_edit.py \ + --gpu 0 \ + --seed 0 \ + --img_path 'data/images/steak.jpg' \ + --source_prompt "$source_prompt" \ + --target_prompt "$target_prompt" \ + --results_dir 'results/steak' \ + --ca_steps $ca_steps \ + --sa_steps $sa_steps \ + --feature_steps $feature_steps \ + --attn_topk $attn_topk diff --git a/scripts/w_ca/run_tennis.sh b/scripts/w_ca/run_tennis.sh new file mode 100644 index 0000000000000000000000000000000000000000..5bb3f0ddaf3e86dbc6bcb5f79052c97f4104f189 --- /dev/null +++ b/scripts/w_ca/run_tennis.sh @@ -0,0 +1,21 @@ +source_prompt='a woman in a black tank top and pink shorts is about to hit a tennis ball' +target_prompt='a iron woman robot in a black tank top and pink shorts is about to hit a tennis ball' + +ca_steps=10 +sa_steps=7 +feature_steps=5 + +attn_topk=20 + +python img_edit.py \ + --gpu 0 \ + --seed 0 \ + --img_path 'data/images/tennis.jpg' \ + --source_prompt "$source_prompt" \ + --target_prompt "$target_prompt" \ + --results_dir 'results/tennis' \ + --ca_steps $ca_steps \ + --sa_steps $sa_steps \ + --feature_steps $feature_steps \ + --attn_topk $attn_topk \ + --blend_word 'woman' diff --git a/scripts/w_ca/run_woman_book.sh b/scripts/w_ca/run_woman_book.sh new file mode 100644 index 0000000000000000000000000000000000000000..d92d031548f8218f8545fe83ee9d8de756da1036 --- /dev/null +++ b/scripts/w_ca/run_woman_book.sh @@ -0,0 +1,20 @@ +source_prompt='a woman sitting in the grass with a book' +target_prompt='a woman sitting in the grass with a laptop' + +ca_steps=10 +sa_steps=7 +feature_steps=5 + +attn_topk=20 + +python img_edit.py \ + --gpu 1 \ + --seed 0 \ + --img_path 'data/images/woman_book.jpg' \ + --source_prompt "$source_prompt" \ + --target_prompt "$target_prompt" \ + --results_dir 'results/woman_book' \ + --ca_steps $ca_steps \ + --sa_steps $sa_steps \ + --feature_steps $feature_steps \ + --attn_topk $attn_topk diff --git a/scripts/w_mask/run_cat_hat.sh b/scripts/w_mask/run_cat_hat.sh new file mode 100644 index 0000000000000000000000000000000000000000..e8f5495ad23487f391074b52b900f3da8f8a3882 --- /dev/null +++ b/scripts/w_mask/run_cat_hat.sh @@ -0,0 +1,21 @@ +source_prompt='a cat wearing a pink hat' +target_prompt='a tiger wearing a pink hat' + +ca_steps=10 +sa_steps=7 +feature_steps=5 + +attn_topk=20 + +python img_edit.py \ + --gpu 3 \ + --seed 0 \ + --img_path 'data/images/cat_hat.jpg' \ + --mask_path 'data/masks/cat_hat.jpg' \ + --source_prompt "$source_prompt" \ + --target_prompt "$target_prompt" \ + --results_dir 'results/cat_hat' \ + --ca_steps $ca_steps \ + --sa_steps $sa_steps \ + --feature_steps $feature_steps \ + --attn_topk $attn_topk diff --git a/scripts/w_mask/run_cat_mirror.sh b/scripts/w_mask/run_cat_mirror.sh new file mode 100644 index 
0000000000000000000000000000000000000000..3713ef8e05b3abed49740d58221fde9425c5416b --- /dev/null +++ b/scripts/w_mask/run_cat_mirror.sh @@ -0,0 +1,21 @@ +source_prompt='a cat sitting next to a mirror' +target_prompt='a tiger sitting next to a mirror' + +ca_steps=10 +sa_steps=7 +feature_steps=5 + +attn_topk=20 + +python img_edit.py \ + --gpu 3 \ + --seed 0 \ + --img_path 'data/images/cat_mirror.jpg' \ + --mask_path 'data/masks/cat_mirror.jpg' \ + --source_prompt "$source_prompt" \ + --target_prompt "$target_prompt" \ + --results_dir 'results/cat_mirror' \ + --ca_steps $ca_steps \ + --sa_steps $sa_steps \ + --feature_steps $feature_steps \ + --attn_topk $attn_topk diff --git a/scripts/w_mask/run_dancing.sh b/scripts/w_mask/run_dancing.sh new file mode 100644 index 0000000000000000000000000000000000000000..71d7886a47341b30cc6bcb929b6ca330ceb212aa --- /dev/null +++ b/scripts/w_mask/run_dancing.sh @@ -0,0 +1,21 @@ +source_prompt='a photo of couples dancing' +target_prompt='a photo of silver robots dancing' + +ca_steps=10 +sa_steps=7 +feature_steps=5 + +attn_topk=20 + +python img_edit.py \ + --gpu 2 \ + --seed 0 \ + --img_path 'data/images/dancing.jpeg' \ + --mask_path 'data/masks/dancing.jpg' \ + --source_prompt "$source_prompt" \ + --target_prompt "$target_prompt" \ + --results_dir 'results/dancing' \ + --ca_steps $ca_steps \ + --sa_steps $sa_steps \ + --feature_steps $feature_steps \ + --attn_topk $attn_topk diff --git a/scripts/w_mask/run_girl_mountain.sh b/scripts/w_mask/run_girl_mountain.sh new file mode 100644 index 0000000000000000000000000000000000000000..6b73ac62d895dd6534b5faf5d58edfc0b6519684 --- /dev/null +++ b/scripts/w_mask/run_girl_mountain.sh @@ -0,0 +1,21 @@ +source_prompt='a woman with her arms outstretched on top of a mountain' +target_prompt='a woman with her arms outstretched in front of the NewYork' + +ca_steps=10 +sa_steps=7 +feature_steps=5 + +attn_topk=20 + +python img_edit.py \ + --gpu 2 \ + --seed 0 \ + --img_path 'data/images/girl_mountain.jpg' \ + --mask_path 'data/masks/girl_mountain.jpg' \ + --source_prompt "$source_prompt" \ + --target_prompt "$target_prompt" \ + --results_dir 'results/girl_mountain' \ + --ca_steps $ca_steps \ + --sa_steps $sa_steps \ + --feature_steps $feature_steps \ + --attn_topk $attn_topk diff --git a/scripts/w_mask/run_man_tree.sh b/scripts/w_mask/run_man_tree.sh new file mode 100644 index 0000000000000000000000000000000000000000..a56461c0f224aa9b3873acd785fd34568d6722de --- /dev/null +++ b/scripts/w_mask/run_man_tree.sh @@ -0,0 +1,21 @@ +source_prompt='a man sitting on a rock with trees in the background' +target_prompt='a man sitting on a rock with a city in the background' + +ca_steps=10 +sa_steps=7 +feature_steps=5 + +attn_topk=20 + +python img_edit.py \ + --gpu 2 \ + --seed 0 \ + --img_path 'data/images/man_tree.jpg' \ + --mask_path 'data/masks/man_tree.jpg' \ + --source_prompt "$source_prompt" \ + --target_prompt "$target_prompt" \ + --results_dir 'results/man_tree' \ + --ca_steps $ca_steps \ + --sa_steps $sa_steps \ + --feature_steps $feature_steps \ + --attn_topk $attn_topk diff --git a/scripts/w_mask/run_old_couple.sh b/scripts/w_mask/run_old_couple.sh new file mode 100644 index 0000000000000000000000000000000000000000..db8a6a79cb4118a14d0320f5aa42cfb7456a459a --- /dev/null +++ b/scripts/w_mask/run_old_couple.sh @@ -0,0 +1,21 @@ +source_prompt='an older couple walking down a narrow dirt road' +target_prompt='an older couple walking down a snow coverd road' + +ca_steps=10 +sa_steps=7 +feature_steps=5 + +attn_topk=20 + +python 
img_edit.py \ + --gpu 2 \ + --seed 0 \ + --img_path 'data/images/old_couple.jpg' \ + --mask_path 'data/masks/old_couple.jpg' \ + --source_prompt "$source_prompt" \ + --target_prompt "$target_prompt" \ + --results_dir 'results/old_couple' \ + --ca_steps $ca_steps \ + --sa_steps $sa_steps \ + --feature_steps $feature_steps \ + --attn_topk $attn_topk diff --git a/scripts/w_mask/run_raven.sh b/scripts/w_mask/run_raven.sh new file mode 100644 index 0000000000000000000000000000000000000000..e5e51e5b3726a74a4efe1404e791cfb8578c416c --- /dev/null +++ b/scripts/w_mask/run_raven.sh @@ -0,0 +1,21 @@ +source_prompt='a black raven sits on a tree stump in the rain' +target_prompt='a white raven sits on a tree stump in the rain' + +ca_steps=10 +sa_steps=7 +feature_steps=5 + +attn_topk=20 + +python img_edit.py \ + --gpu 2 \ + --seed 0 \ + --img_path 'data/images/raven.jpg' \ + --mask_path 'data/masks/raven.jpg' \ + --source_prompt "$source_prompt" \ + --target_prompt "$target_prompt" \ + --results_dir 'results/raven' \ + --ca_steps $ca_steps \ + --sa_steps $sa_steps \ + --feature_steps $feature_steps \ + --attn_topk $attn_topk diff --git a/scripts/w_mask/run_santa.sh b/scripts/w_mask/run_santa.sh new file mode 100644 index 0000000000000000000000000000000000000000..73cb74859b6d6561eef0422d988f3604bf30ba55 --- /dev/null +++ b/scripts/w_mask/run_santa.sh @@ -0,0 +1,21 @@ +source_prompt="the christmas illustration of a santa's laughing face" +target_prompt="the christmas illustration of a santa's angry face" + +ca_steps=10 +sa_steps=7 +feature_steps=5 + +attn_topk=20 + +python img_edit.py \ + --gpu 0 \ + --seed 0 \ + --img_path 'data/images/santa.jpg' \ + --mask_path 'data/masks/santa.jpg' \ + --source_prompt "$source_prompt" \ + --target_prompt "$target_prompt" \ + --results_dir 'results/santa' \ + --ca_steps $ca_steps \ + --sa_steps $sa_steps \ + --feature_steps $feature_steps \ + --attn_topk $attn_topk diff --git a/scripts/wo_ca/run_bear.sh b/scripts/wo_ca/run_bear.sh new file mode 100644 index 0000000000000000000000000000000000000000..33c2ad0a4a8ee11642e4c2d900ce21c8ba5d2b28 --- /dev/null +++ b/scripts/wo_ca/run_bear.sh @@ -0,0 +1,21 @@ +source_prompt='' +target_prompt='an image of Paddington the bear' + +ca_steps=0 +sa_steps=12 +feature_steps=7 + +attn_topk=20 + +python img_edit.py \ + --gpu 0 \ + --seed 0 \ + --img_path 'data/images/bear.jpeg' \ + --source_prompt "$source_prompt" \ + --target_prompt "$target_prompt" \ + --results_dir 'results/bear' \ + --ca_steps $ca_steps \ + --sa_steps $sa_steps \ + --feature_steps $feature_steps \ + --attn_topk $attn_topk + \ No newline at end of file diff --git a/scripts/wo_ca/run_bird_painting.sh b/scripts/wo_ca/run_bird_painting.sh new file mode 100644 index 0000000000000000000000000000000000000000..02aaf7c03f50284495ddd931c51bfaaf150c1743 --- /dev/null +++ b/scripts/wo_ca/run_bird_painting.sh @@ -0,0 +1,20 @@ +source_prompt='' +target_prompt='a photo of an eagle in the sky' + +ca_steps=0 +sa_steps=12 +feature_steps=7 + +attn_topk=20 + +python img_edit.py \ + --gpu 0 \ + --seed 0 \ + --img_path 'data/images/bird_painting.jpg' \ + --source_prompt "$source_prompt" \ + --target_prompt "$target_prompt" \ + --results_dir 'results/bird_painting' \ + --ca_steps $ca_steps \ + --sa_steps $sa_steps \ + --feature_steps $feature_steps \ + --attn_topk $attn_topk diff --git a/scripts/wo_ca/run_dancing.sh b/scripts/wo_ca/run_dancing.sh new file mode 100644 index 0000000000000000000000000000000000000000..c0dc873094d7b74199c14376b18892607352352b --- /dev/null +++ 
b/scripts/wo_ca/run_dancing.sh @@ -0,0 +1,20 @@ +source_prompt='' +target_prompt='a couple of silver robots dancing in the garden' + +ca_steps=0 +sa_steps=12 +feature_steps=7 + +attn_topk=20 + +python img_edit.py \ + --gpu 3 \ + --seed 0 \ + --img_path 'data/images/dancing.jpeg' \ + --source_prompt "$source_prompt" \ + --target_prompt "$target_prompt" \ + --results_dir 'results/dancing' \ + --ca_steps $ca_steps \ + --sa_steps $sa_steps \ + --feature_steps $feature_steps \ + --attn_topk $attn_topk \ No newline at end of file diff --git a/scripts/wo_ca/run_karate.sh b/scripts/wo_ca/run_karate.sh new file mode 100644 index 0000000000000000000000000000000000000000..382c9d083de5d0a014a569f7332dc8db47091523 --- /dev/null +++ b/scripts/wo_ca/run_karate.sh @@ -0,0 +1,22 @@ +source_prompt='' +target_prompt='a silver robot in the snow' + +ca_steps=0 +sa_steps=12 +feature_steps=7 + +attn_topk=20 + +python img_edit.py \ + --gpu 0 \ + --seed 0 \ + --img_path 'data/images/real_karate.jpeg' \ + --source_prompt "$source_prompt" \ + --target_prompt "$target_prompt" \ + --results_dir 'results/karate' \ + --ca_steps $ca_steps \ + --sa_steps $sa_steps \ + --feature_steps $feature_steps \ + --attn_topk $attn_topk \ + + \ No newline at end of file diff --git a/scripts/wo_ca/run_meditation.sh b/scripts/wo_ca/run_meditation.sh new file mode 100644 index 0000000000000000000000000000000000000000..34a584786b2582c432978ecaa0ca2ea6a07ad31b --- /dev/null +++ b/scripts/wo_ca/run_meditation.sh @@ -0,0 +1,20 @@ +source_prompt='' +target_prompt='a photo of a golden statue in a temple' + +ca_steps=0 +sa_steps=12 +feature_steps=7 + +attn_topk=20 + +python img_edit.py \ + --gpu 1 \ + --seed 10 \ + --img_path 'data/images/meditation.png' \ + --source_prompt "$source_prompt" \ + --target_prompt "$target_prompt" \ + --results_dir 'results/meditation' \ + --ca_steps $ca_steps \ + --sa_steps $sa_steps \ + --feature_steps $feature_steps \ + --attn_topk $attn_topk diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/attn_utils/__init__.py b/src/attn_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/attn_utils/attn_utils.py b/src/attn_utils/attn_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..47d5d3f15b0d8a100aea3a58e996ea3f836ba9d5 --- /dev/null +++ b/src/attn_utils/attn_utils.py @@ -0,0 +1,300 @@ +import abc +import gc +import math +import numbers +from collections import defaultdict +from difflib import SequenceMatcher +from typing import Dict, List, Optional, Tuple, Union + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.models.attention_processor import FluxAttnProcessor2_0 +from PIL import Image +from scipy.ndimage import binary_dilation +from skimage.filters import threshold_otsu + + +class AttentionControl(abc.ABC): + def __init__(self,): + self.cur_step = 0 + self.num_att_layers = -1 + self.cur_att_layer = 0 + + self.get_model_info() + + def get_model_info(self): + t5_dim = 512 + latent_dim = 4096 + attn_dim = t5_dim + latent_dim + index_all = torch.arange(attn_dim) + t5_index, latent_index = index_all.split([t5_dim, latent_dim]) + patch_order = ['t5', 'latent'] + + self.model_info = { + 't5_dim': t5_dim, + 'latent_dim': latent_dim, + 'attn_dim': attn_dim, + 't5_index': t5_index, + 
'latent_index': latent_index, + 'patch_order': patch_order + } + + def step_callback(self, x_t): + return x_t + + def between_steps(self): + return + + @abc.abstractmethod + def forward(self, q, k, v, place_in_transformer: str): + raise NotImplementedError + + @torch.no_grad() + def __call__(self, q, k, v, place_in_transformer: str): + hs = self.forward(q, k, v, place_in_transformer) + + self.cur_att_layer += 1 + + if self.cur_att_layer == self.num_att_layers: + self.cur_att_layer = 0 + self.cur_step += 1 + + self.between_steps() + return hs + + def reset(self): + self.cur_step = 0 + self.cur_att_layer = 0 + + def scaled_dot_product_attention(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + if is_causal: + assert attn_mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + return attn_weight + + def split_attn(self, attn, q='latent', k='latent'): + patch_order = self.model_info['patch_order'] + t5_dim = self.model_info['t5_dim'] + latent_dim = self.model_info['latent_dim'] + clip_dim = self.model_info.get('clip_dim', None) + + idx_q = patch_order.index(q) + idx_k = patch_order.index(k) + split = [t5_dim, latent_dim] + + return attn.split(split, dim=-2)[idx_q].split(split, dim=-1)[idx_k].clone() + + +class AttentionAdapter(AttentionControl): + def __init__( + self, + ca_layer_list=list(range(13,45)), + sa_layer_list=list(range(22,45)), + method='replace_topk', + topk=1, + text_scale=1, + mappers=None, + alphas=None, + ca_steps=10, + sa_steps=7, + save_source_ca=False, + save_target_ca=False, + use_sa_replace=True, + attn_adj_from=0, + ): + super(AttentionAdapter, self).__init__() + self.ca_layer_list = ca_layer_list + self.sa_layer_list = sa_layer_list + self.method = method + self.topk = topk + self.text_scale = text_scale + self.use_sa_replace = use_sa_replace + + self.ca_steps = ca_steps + self.sa_steps = sa_steps + + self.mappers = mappers + self.alphas = alphas + + self.save_source_ca = save_source_ca + self.save_target_ca = save_target_ca + + self.attn_adj_from = attn_adj_from + + self.source_ca = None + + self.source_attn = {} + + @staticmethod + def get_empty_store(): + return defaultdict(list) + + def refine_ca(self, source_ca, target_ca): + source_ca_replace = source_ca[:, :, self.mappers].permute(2, 0, 1, 3) + new_ca = source_ca_replace * self.alphas + target_ca * (1 - self.alphas) * self.text_scale + return new_ca + + def replace_ca(self, source_ca, target_ca): + new_ca = torch.einsum('hpw,bwn->bhpn', source_ca, self.mappers) + return new_ca + + def get_index_from_source(self, attn, topk): + if self.method == 'replace_topk': + sa_max = torch.topk(attn, k=topk, dim=-1)[0][..., [-1]] + idx_from_source = (attn > sa_max) + elif self.method == 'replace_z': + log_attn = torch.log(attn) + idx_from_source = log_attn > (log_attn.mean(-1, keepdim=True) + self.z_value * 
log_attn.std(-1, keepdim=True)) + else: + print("No method") + return idx_from_source + + def forward(self, q, k, v, place_in_transformer: str): + layer_index = int(place_in_transformer.split('_')[-1]) + + use_ca_replace = False + use_sa_replace = False + if (layer_index in self.ca_layer_list) and (self.cur_step in range(0, self.ca_steps)): + if self.mappers is not None: + use_ca_replace = True + if (layer_index in self.sa_layer_list) and (self.cur_step in range(0, self.sa_steps)): + use_sa_replace = True + + if not (use_sa_replace or use_ca_replace): + return F.scaled_dot_product_attention(q, k, v) + + latent_index = self.model_info['latent_index'] + t5_index = self.model_info['t5_index'] + clip_index = self.model_info.get('clip_index', None) + + # Get attention map first + attn = self.scaled_dot_product_attention(q, k, v) + source_attn = attn[:1] + target_attn = attn[1:] + source_hs = source_attn @ v[:1] + + source_ca = self.split_attn(source_attn, 'latent', 't5') + target_ca = self.split_attn(target_attn, 'latent', 't5') + + if use_ca_replace: + if self.save_source_ca: + if layer_index == self.ca_layer_list[0]: + self.source_ca = source_ca / source_ca.sum(dim=-1, keepdim=True) + else: + self.source_ca += source_ca / source_ca.sum(dim=-1, keepdim=True) + + if self.save_target_ca: + if layer_index == self.ca_layer_list[0]: + self.target_ca = target_ca / target_ca.sum(dim=-1, keepdim=True) + else: + self.target_ca += target_ca / target_ca.sum(dim=-1, keepdim=True) + + if self.alphas is None: + target_ca = self.replace_ca(source_ca[0], target_ca) + else: + target_ca = self.refine_ca(source_ca[0], target_ca) + + target_sa = self.split_attn(target_attn, 'latent', 'latent') + if use_sa_replace: + if self.cur_step < self.attn_adj_from: + topk = 1 + else: + topk = self.topk + + if self.method == 'base': + new_sa = self.split_attn(target_attn, 'latent', 'latent') + else: + source_sa = self.split_attn(source_attn, 'latent', 'latent') + if topk <= 1: + new_sa = source_sa.clone().repeat(len(target_attn), 1, 1, 1) + else: + idx_from_source = self.get_index_from_source(source_sa, topk) + # Get top-k attention values from target SA + new_sa = target_sa.clone() + new_sa.mul_(idx_from_source) + # Normalize + new_sa.div_(new_sa.sum(-1,keepdim=True)) + new_sa.nan_to_num_(0) + new_sa.mul_((source_sa * idx_from_source).sum(-1, keepdim=True)) + # Get rest attention vlaues from source SA + new_sa.add_(source_sa * idx_from_source.logical_not()) + # Additional normalize (Optional) + # new_sa.mul_((target_sa.sum(dim=(-1), keepdim=True) / new_sa.sum(dim=(-1), keepdim=True))) + target_sa = new_sa + + target_l_to_l = target_sa @ v[1:, :, latent_index] + target_l_to_t = target_ca @ v[1:, :, t5_index] + + if self.alphas is None: + target_latent_hs = target_l_to_l + target_l_to_t * self.text_scale + else: + # text scaling is already performed in self.refine_ca() + target_latent_hs = target_l_to_l + target_l_to_t + + target_text_hs = target_attn[:,:, t5_index,:] @ v[1:] + target_hs = torch.cat([target_text_hs, target_latent_hs], dim=-2) + hs = torch.cat([source_hs, target_hs]) + return hs + + def reset(self): + super(AttentionAdapter, self).reset() + del self.source_attn + gc.collect() + torch.cuda.empty_cache() + + self.source_attn = {} + +class AttnCollector: + def __init__(self, transformer, controller, attn_processor_class, layer_list=[]): + self.transformer = transformer + self.controller = controller + self.attn_processor_class = attn_processor_class + + def restore_orig_attention(self): + attn_procs = {} + 
place='' + for i, (name, attn_proc) in enumerate(self.transformer.attn_processors.items()): + attn_proc = self.attn_processor_class( + controller=None, place_in_transformer=place, + ) + attn_procs[name] = attn_proc + self.transformer.set_attn_processor(attn_procs) + self.controller.num_att_layers = 0 + + def register_attention_control(self): + attn_procs = {} + count = 0 + for i, (name, attn_proc) in enumerate(self.transformer.attn_processors.items()): + if 'single' in name: + place = f'single_{i}' + else: + place = f'joint_{i}' + count += 1 + + attn_proc = self.attn_processor_class( + controller=self.controller, place_in_transformer=place, + ) + attn_procs[name] = attn_proc + + self.transformer.set_attn_processor(attn_procs) + self.controller.num_att_layers = count diff --git a/src/attn_utils/flux_attn_processor.py b/src/attn_utils/flux_attn_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..4124da720affe49934cbe5f51d3edb736ff6a3fa --- /dev/null +++ b/src/attn_utils/flux_attn_processor.py @@ -0,0 +1,100 @@ +import math +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + + +class NewFluxAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self, controller, place_in_transformer): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + self.controller = controller + self.place_in_transformer = place_in_transformer + + def __call__( + self, + attn, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. 
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + from diffusers.models.embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + if self.controller is None: + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + else: + hidden_states = self.controller(query, key, value, self.place_in_transformer) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + diff --git a/src/attn_utils/mask_utils.py b/src/attn_utils/mask_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5325a91c0f57638d4fce51871ec2e6d68b90150f --- /dev/null +++ b/src/attn_utils/mask_utils.py @@ -0,0 +1,64 @@ +import torch +import torch.nn.functional as F +from scipy.ndimage import binary_dilation +from skimage.filters import threshold_otsu + + +def gaussian_blur(image, kernel_size=7, sigma=2): + """ + Apply Gaussian blur to a binary mask image. + + Args: + image (torch.Tensor): Input binary mask (1x1xHxW or HxW) as a PyTorch tensor. + kernel_size (int): Size of the Gaussian kernel. Should be odd. + sigma (float): Standard deviation of the Gaussian kernel. + + Returns: + torch.Tensor: Blurred mask image. 
+ """ + # Ensure kernel size is odd + if kernel_size % 2 == 0: + kernel_size += 1 + + # Generate Gaussian kernel + x = torch.arange(kernel_size, device=image.device, dtype=image.dtype) - kernel_size // 2 + gaussian_1d = torch.exp(-(x**2) / (2 * sigma**2)) + gaussian_1d = gaussian_1d / gaussian_1d.sum() + gaussian_kernel = gaussian_1d[:, None] * gaussian_1d[None, :] + gaussian_kernel = gaussian_kernel / gaussian_kernel.sum() # Normalize + + # Reshape to fit convolution: (out_channels, in_channels, kH, kW) + gaussian_kernel = gaussian_kernel.unsqueeze(0).unsqueeze(0) + + # Ensure image is 4D (BxCxHxW) + if image.ndim == 2: # HxW + image = image.unsqueeze(0).unsqueeze(0) # Add batch and channel dimensions + elif image.ndim == 3: # CxHxW + image = image.unsqueeze(0) # Add batch dimension + + # Convolve image with Gaussian kernel + blurred_image = F.conv2d(image, gaussian_kernel, padding=kernel_size // 2) + + return blurred_image.squeeze() # Remove extra dimensions + +def mask_interpolate(mask, size=128): + mask = torch.tensor(mask) + mask = F.interpolate(mask[None, None, ...], size, mode='bicubic') + mask = mask.squeeze() + return mask + +def get_mask(ca, ca_index, gb_kernel=11, gb_sigma=2, dilation=1, nbins=64): + if ca is None: + return None + else: + ca = ca[0].mean(0) + token_ca = ca[..., ca_index].mean(dim=-1).reshape(64, 64) + token_ca = gaussian_blur(token_ca, kernel_size=gb_kernel, sigma=gb_sigma) + token_ca = mask_interpolate(token_ca, size=1024) + thres = threshold_otsu(token_ca.float().cpu().numpy(), nbins=nbins) + mask = token_ca > thres + mask = mask_interpolate(mask.to(ca.dtype), 128) + if dilation: + mask = binary_dilation(mask.float().cpu().numpy(), iterations=dilation) + mask = torch.tensor(mask, device=ca.device, dtype=ca.dtype) + return mask \ No newline at end of file diff --git a/src/attn_utils/seq_aligner.py b/src/attn_utils/seq_aligner.py new file mode 100644 index 0000000000000000000000000000000000000000..92888dd4b1db77d0e89c392820bc82b83a1ed410 --- /dev/null +++ b/src/attn_utils/seq_aligner.py @@ -0,0 +1,195 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import numpy as np +import torch + + +class ScoreParams: + + def __init__(self, gap, match, mismatch): + self.gap = gap + self.match = match + self.mismatch = mismatch + + def mis_match_char(self, x, y): + if x != y: + return self.mismatch + else: + return self.match + + +def get_matrix(size_x, size_y, gap): + matrix = [] + for i in range(len(size_x) + 1): + sub_matrix = [] + for j in range(len(size_y) + 1): + sub_matrix.append(0) + matrix.append(sub_matrix) + for j in range(1, len(size_y) + 1): + matrix[0][j] = j*gap + for i in range(1, len(size_x) + 1): + matrix[i][0] = i*gap + return matrix + + +def get_matrix(size_x, size_y, gap): + matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32) + matrix[0, 1:] = (np.arange(size_y) + 1) * gap + matrix[1:, 0] = (np.arange(size_x) + 1) * gap + return matrix + + +def get_traceback_matrix(size_x, size_y): + matrix = np.zeros((size_x + 1, size_y +1), dtype=np.int32) + matrix[0, 1:] = 1 + matrix[1:, 0] = 2 + matrix[0, 0] = 4 + return matrix + + +def global_align(x, y, score): + matrix = get_matrix(len(x), len(y), score.gap) + trace_back = get_traceback_matrix(len(x), len(y)) + for i in range(1, len(x) + 1): + for j in range(1, len(y) + 1): + left = matrix[i, j - 1] + score.gap + up = matrix[i - 1, j] + score.gap + diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1]) + matrix[i, j] = max(left, up, diag) + if matrix[i, j] == left: + trace_back[i, j] = 1 + elif matrix[i, j] == up: + trace_back[i, j] = 2 + else: + trace_back[i, j] = 3 + return matrix, trace_back + + +def get_aligned_sequences(x, y, trace_back): + x_seq = [] + y_seq = [] + i = len(x) + j = len(y) + mapper_y_to_x = [] + while i > 0 or j > 0: + if trace_back[i, j] == 3: + x_seq.append(x[i-1]) + y_seq.append(y[j-1]) + i = i-1 + j = j-1 + mapper_y_to_x.append((j, i)) + elif trace_back[i][j] == 1: + x_seq.append('-') + y_seq.append(y[j-1]) + j = j-1 + mapper_y_to_x.append((j, -1)) + elif trace_back[i][j] == 2: + x_seq.append(x[i-1]) + y_seq.append('-') + i = i-1 + elif trace_back[i][j] == 4: + break + mapper_y_to_x.reverse() + return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64) + + +def get_mapper(x: str, y: str, tokenizer, max_len=77): + x_seq = tokenizer.encode(x) + y_seq = tokenizer.encode(y) + score = ScoreParams(0, 1, -1) + matrix, trace_back = global_align(x_seq, y_seq, score) + mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1] + alphas = torch.ones(max_len) + alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float() + mapper = torch.zeros(max_len, dtype=torch.int64) + mapper[:mapper_base.shape[0]] = mapper_base[:, 1] + mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq)) + return mapper, alphas + + +def get_refinement_mapper(prompts, tokenizer, max_len=77): + x_seq = prompts[0] + mappers, alphas = [], [] + for i in range(1, len(prompts)): + mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len) + mappers.append(mapper) + alphas.append(alpha) + return torch.stack(mappers), torch.stack(alphas) + + +def get_word_inds(text: str, word_place: int, tokenizer): + split_text = text.split(" ") + if type(word_place) is str: + word_place = [i for i, word in enumerate(split_text) if word_place == word] + elif type(word_place) is int: + word_place = [word_place] + out = [] + if len(word_place) > 0: + words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] + cur_len, ptr = 0, 0 + + for i in range(len(words_encode)): + cur_len += len(words_encode[i]) + if ptr in 
word_place: + out.append(i + 1) + if cur_len >= len(split_text[ptr]): + ptr += 1 + cur_len = 0 + return np.array(out) + + +def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77): + words_x = x.split(' ') + words_y = y.split(' ') + if len(words_x) != len(words_y): + raise ValueError(f"attention replacement edit can only be applied on prompts with the same length" + f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.") + inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]] + inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace] + inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace] + mapper = np.zeros((max_len, max_len)) + i = j = 0 + cur_inds = 0 + while i < max_len and j < max_len: + if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i: + inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds] + if len(inds_source_) == len(inds_target_): + mapper[inds_source_, inds_target_] = 1 + else: + ratio = 1 / len(inds_target_) + for i_t in inds_target_: + mapper[inds_source_, i_t] = ratio + cur_inds += 1 + i += len(inds_source_) + j += len(inds_target_) + elif cur_inds < len(inds_source): + mapper[i, j] = 1 + i += 1 + j += 1 + else: + mapper[j, j] = 1 + i += 1 + j += 1 + + return torch.from_numpy(mapper).float() + + + +def get_replacement_mapper(prompts, tokenizer, max_len=77): + x_seq = prompts[0] + mappers = [] + for i in range(1, len(prompts)): + mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len) + mappers.append(mapper) + return torch.stack(mappers) \ No newline at end of file diff --git a/src/callback/__init__.py b/src/callback/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/callback/callback_fn.py b/src/callback/callback_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..13c9b4226f66ef71cf5c4291db9715e654a0e3b8 --- /dev/null +++ b/src/callback/callback_fn.py @@ -0,0 +1,102 @@ +import torch +import torch.nn.functional as F +from diffusers.callbacks import PipelineCallback +from scipy.ndimage import binary_dilation +from skimage.filters import threshold_otsu + +from ..attn_utils.mask_utils import get_mask + + +class CallbackLatentStore(PipelineCallback): + tensor_inputs = ['latents'] + + def __init__(self): + self.latents = [] + + def callback_fn(self, pipeline, step_index, timestep, callback_kwargs): + self.latents.append(callback_kwargs['latents']) + return callback_kwargs + +class CallbackAll(PipelineCallback): + tensor_inputs = ['latents'] + def __init__( + self, + latents, + attn_collector, + feature_collector, + feature_inject_steps, + mid_step_index=0, + step_start=0, + use_mask=False, + use_ca_mask=False, + source_ca_index=None, + target_ca_index=None, + mask_steps=18, + mask_kwargs={}, + mask=None, + ): + self.latents = latents + + self.attn_collector = attn_collector + self.feature_collector = feature_collector + self.feature_inject_steps = feature_inject_steps + + self.mid_step_index = mid_step_index + self.step_start = step_start + + self.mask = mask + self.mask_steps = mask_steps + + self.use_mask = use_mask + self.use_ca_mask = use_ca_mask + self.source_ca_index = source_ca_index + self.target_ca_index = target_ca_index + self.mask_kwargs = mask_kwargs + + def latent_blend(self, s, t, mask): + return s * (1-mask) + t * mask + # return s * mask.logical_not() + t * mask + + def callback_fn(self, pipeline, step_index, timestep, 
callback_kwargs): + cur_step = step_index + self.step_start + + if self.latents is None: + pass + elif cur_step < self.mid_step_index: + inject_latent = self.latents[self.mid_step_index] + callback_kwargs['latents'][:1] = inject_latent + + if self.use_mask: + if self.use_ca_mask: + if self.source_ca_index is not None: + source_ca = self.attn_collector.controller.source_ca + mask = get_mask(source_ca, self.source_ca_index, **self.mask_kwargs) + elif self.target_ca_index is not None: + if cur_step < 5: + return callback_kwargs + target_ca = self.attn_collector.controller.target_ca + mask = get_mask(target_ca, self.target_ca_index, **self.mask_kwargs) + self.mask = mask + elif self.mask is not None: + mask = self.mask + else: + return callback_kwargs + + if (cur_step < self.mask_steps): + mask = mask.to(pipeline.dtype) + target_latent = callback_kwargs['latents'][1:] + blend_latent = self.latents[cur_step+1] + # if cur_step + 1 < self.mid_step_index: + # blend_latent = self.latents[cur_step+1] + # else: + # blend_latent = callback_kwargs['latents'][:1] + + new_latent = self.latent_blend( + pipeline._unpack_latents(blend_latent, 1024, 1024, pipeline.vae_scale_factor), + pipeline._unpack_latents(target_latent, 1024, 1024, pipeline.vae_scale_factor), + mask + ) + new_latent = pipeline._pack_latents(new_latent, *new_latent.shape) + callback_kwargs['latents'][1:] = new_latent + + return callback_kwargs \ No newline at end of file diff --git a/src/inversion/__init__.py b/src/inversion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/inversion/inverse.py b/src/inversion/inverse.py new file mode 100644 index 0000000000000000000000000000000000000000..b45d082dcff1d0f4044b5ab761c56b659eb1baa7 --- /dev/null +++ b/src/inversion/inverse.py @@ -0,0 +1,115 @@ +import gc + +import numpy as np +import torch +from diffusers.pipelines.flux.pipeline_flux import calculate_shift +from diffusers.utils.torch_utils import randn_tensor +from PIL import Image +from torchvision import transforms + +from ..callback.callback_fn import CallbackLatentStore +from .scheduling_flow_inverse import (FlowMatchEulerDiscreteBackwardScheduler, + FlowMatchEulerDiscreteForwardScheduler) + + +@torch.no_grad() +def img_to_latent(img, vae): + normalize = transforms.Normalize(mean=[0.5],std=[0.5]) + trans = transforms.Compose([ + transforms.ToTensor(), + normalize, + ]) + + tensor_img = trans(img)[None, ...] 
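+    # tensor_img: (1, 3, H, W); ToTensor maps pixels to [0, 1] and Normalize(mean=0.5, std=0.5)
+    # rescales them to [-1, 1], the input range expected by the VAE encoder.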
+ tensor_img = tensor_img.to(dtype=vae.dtype, device=vae.device) + posterior = vae.encode(tensor_img).latent_dist + latents = (posterior.mean - vae.config.shift_factor) * vae.config.scaling_factor + # latents = posterior.mean + return latents + + +@torch.no_grad() +def get_inversed_latent_list( + pipe, + image: Image, + random_noise=None, + num_inference_steps: int = 28, + backward_method: str = 'ode', + model_name: str = 'flux', + res=(1024, 1024), + use_prompt_for_inversion=False, + guidance_scale_for_inversion=0, + prompt_for_inversion=None, + seed=0, + flow_steps=1, + ode_steps=1, + intermediate_steps=None +): + img = image.resize(res) + img_latent = img_to_latent(image, pipe.vae) + device = img_latent.device + + generator = torch.Generator(device=device).manual_seed(seed) + + + if random_noise is None: + random_noise = randn_tensor(img_latent.shape, device=device, generator=generator) + if model_name == 'flux': + random_noise = pipe._pack_latents(random_noise, *random_noise.shape) + if model_name == 'flux': + img_latent = pipe._pack_latents(img_latent, *img_latent.shape) + + pipe.scheduler = FlowMatchEulerDiscreteBackwardScheduler.from_config( + pipe.scheduler.config, + margin_index_from_noise=0, + margin_index_from_image=0, + intermediate_steps=intermediate_steps + ) + if model_name == 'flux': + image_seq_len = img_latent.shape[1] + mu = calculate_shift( + image_seq_len, + pipe.scheduler.config.base_image_seq_len, + pipe.scheduler.config.max_image_seq_len, + pipe.scheduler.config.base_shift, + pipe.scheduler.config.max_shift, + ) + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + else: + mu = None + sigmas = None + pipe.scheduler.set_timesteps(num_inference_steps=num_inference_steps, mu=mu, sigmas=sigmas) + + sigmas = pipe.scheduler.sigmas + timesteps = pipe.scheduler.timesteps + + if backward_method == 'flow': + inv_latents = [img_latent] + for sigma in sigmas: + inv_latent = (1 - sigma) * img_latent + sigma * random_noise + inv_latents.append(inv_latent) + + elif backward_method == 'ode': + inv_latents = [img_latent] + img_latent_new = img_latent.to(pipe.dtype) + random_noise = random_noise.to(pipe.dtype) + + callback_fn = CallbackLatentStore() + inv_latent = pipe.inversion( + latents=img_latent_new, + rand_latents=random_noise, + flow_steps=flow_steps, + prompt=prompt_for_inversion if use_prompt_for_inversion else '', + num_images_per_prompt=1, + output_type='latent', + width=res[0], height=res[1], + guidance_scale=guidance_scale_for_inversion, + num_inference_steps=num_inference_steps, + callback_on_step_end=callback_fn + ).images + inv_latents = inv_latents + callback_fn.latents + del img_latent + gc.collect() + torch.cuda.empty_cache() + + return inv_latents \ No newline at end of file diff --git a/src/inversion/scheduling_flow_inverse.py b/src/inversion/scheduling_flow_inverse.py new file mode 100644 index 0000000000000000000000000000000000000000..ed566b65275c646510212ef727f2704243169cc3 --- /dev/null +++ b/src/inversion/scheduling_flow_inverse.py @@ -0,0 +1,613 @@ +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers import (FlowMatchEulerDiscreteScheduler, + FlowMatchHeunDiscreteScheduler) +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.utils.torch_utils import randn_tensor + + +@dataclass +class FlowMatchHeunDiscreteSchedulerOutput(BaseOutput): + """ + Output class 
for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class FlowMatchEulerDiscreteBackwardScheduler(FlowMatchEulerDiscreteScheduler): + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + use_dynamic_shifting=False, + base_shift: Optional[float] = 0.5, + max_shift: Optional[float] = 1.15, + base_image_seq_len: Optional[int] = 256, + max_image_seq_len: Optional[int] = 4096, + margin_index_from_noise: int = 3, + margin_index_from_image: int = 1, + intermediate_steps=None + ): + super().__init__( + num_train_timesteps=num_train_timesteps, + shift=shift, + use_dynamic_shifting=use_dynamic_shifting, + base_shift=base_shift, + max_shift=max_shift, + base_image_seq_len=base_image_seq_len, + max_image_seq_len=max_image_seq_len, + ) + self.margin_index_from_noise = margin_index_from_noise + self.margin_index_from_image = margin_index_from_image + self.intermediate_steps = intermediate_steps + + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") + + if sigmas is None: + self.num_inference_steps = num_inference_steps + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps + ) + + sigmas = timesteps / self.config.num_train_timesteps + + if num_inference_steps is None: + num_inference_steps = len(sigmas) + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + timesteps = sigmas * self.config.num_train_timesteps + + self.timesteps = torch.cat([timesteps, torch.zeros(1, device=timesteps.device)]) + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + self.timesteps = self.timesteps.flip(0) + self.sigmas = self.sigmas.flip(0) + + self.timesteps = self.timesteps[ + self.config.margin_index_from_image : num_inference_steps - self.config.margin_index_from_noise + ] + self.sigmas = self.sigmas[ + self.config.margin_index_from_image : num_inference_steps - self.config.margin_index_from_noise + 1 + ] + + if self.config.intermediate_steps is not None: + # self.timesteps = torch.linspace(self.timesteps[0], self.timesteps[-1], self.config.intermediate_steps).to(self.timesteps.device) + self.sigmas = torch.linspace(self.sigmas[0], self.sigmas[-1], self.config.intermediate_steps + 1).to(self.timesteps.device) + self.timesteps = self.sigmas[:-1] * 1000 + + + self._step_index = None + self._begin_index = None + + +class 
FlowMatchEulerDiscreteForwardScheduler(FlowMatchEulerDiscreteScheduler): + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + use_dynamic_shifting=False, + base_shift: Optional[float] = 0.5, + max_shift: Optional[float] = 1.15, + base_image_seq_len: Optional[int] = 256, + max_image_seq_len: Optional[int] = 4096, + margin_index_from_noise: int = 3, + margin_index_from_image: int = 0, + ): + super().__init__( + num_train_timesteps=num_train_timesteps, + shift=shift, + use_dynamic_shifting=use_dynamic_shifting, + base_shift=base_shift, + max_shift=max_shift, + base_image_seq_len=base_image_seq_len, + max_image_seq_len=max_image_seq_len, + ) + self.margin_index_from_noise = margin_index_from_noise + self.margin_index_from_image = margin_index_from_image + + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") + + if sigmas is None: + self.num_inference_steps = num_inference_steps + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps + ) + + sigmas = timesteps / self.config.num_train_timesteps + + if num_inference_steps is None: + num_inference_steps = len(sigmas) + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + timesteps = sigmas * self.config.num_train_timesteps + + self.timesteps = timesteps.to(device=device) + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + self.timesteps = self.timesteps[ + self.config.margin_index_from_noise : num_inference_steps - self.config.margin_index_from_image + ] + self.sigmas = self.sigmas[ + self.config.margin_index_from_noise : num_inference_steps - self.config.margin_index_from_image + 1 + ] + + self._step_index = None + self._begin_index = None + + + +class FlowMatchHeunDiscreteForwardScheduler(FlowMatchHeunDiscreteScheduler): + _compatibles = [] + order = 2 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + margin_index: int = 0, + use_dynamic_shifting = False + ): + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + + sigmas = timesteps / num_train_timesteps + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.timesteps = sigmas * num_train_timesteps + + self._step_index = None + self._begin_index = None + + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + self.use_dynamic_shifting = use_dynamic_shifting + self.margin_index = margin_index + + def time_shift(self, mu: float, sigma: 
float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + + + if sigmas is None: + self.num_inference_steps = num_inference_steps + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps + ) + sigmas = timesteps / self.config.num_train_timesteps + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + + timesteps = sigmas * self.config.num_train_timesteps + timesteps = timesteps[self.config.margin_index:] + timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)]) + self.timesteps = timesteps.to(device=device) + + sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + sigmas = sigmas[self.config.margin_index:] + self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]) + + # empty dt and derivative + self.prev_derivative = None + self.dt = None + + self._step_index = None + self._begin_index = None + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[FlowMatchHeunDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `HeunDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." 
+ ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + if self.state_in_first_order: + sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] + else: + # 2nd order / Heun's method + sigma = self.sigmas[self.step_index - 1] + sigma_next = self.sigmas[self.step_index] + + gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 + + noise = randn_tensor( + model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator + ) + + eps = noise * s_noise + sigma_hat = sigma * (gamma + 1) + + if gamma > 0: + sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 + + if self.state_in_first_order: + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + denoised = sample - model_output * sigma + # 2. convert to an ODE derivative for 1st order + derivative = (sample - denoised) / sigma_hat + # 3. Delta timestep + dt = sigma_next - sigma_hat + + # store for 2nd order step + self.prev_derivative = derivative + self.dt = dt + self.sample = sample + else: + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + denoised = sample - model_output * sigma_next + # 2. 2nd order / Heun's method + derivative = (sample - denoised) / sigma_next + derivative = 0.5 * (self.prev_derivative + derivative) + + # 3. take prev timestep & sample + dt = self.dt + sample = self.sample + + # free dt and derivative + # Note, this puts the scheduler in "first order mode" + self.prev_derivative = None + self.dt = None + self.sample = None + + prev_sample = sample + derivative * dt + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return prev_sample + + +class FlowMatchHeunDiscreteBackwardScheduler(FlowMatchHeunDiscreteScheduler): + _compatibles = [] + order = 2 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + margin_index: int = 0, + use_dynamic_shifting = False + ): + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + + sigmas = timesteps / num_train_timesteps + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.timesteps = sigmas * num_train_timesteps + + self._step_index = None + self._begin_index = None + + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + self.use_dynamic_shifting = use_dynamic_shifting + self.margin_index = margin_index + + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. 
If `None`, the timesteps are not moved. + """ + + + if sigmas is None: + self.num_inference_steps = num_inference_steps + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps + ) + sigmas = timesteps / self.config.num_train_timesteps + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + + timesteps = sigmas * self.config.num_train_timesteps + timesteps = timesteps[self.config.margin_index:].flip(0) + timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)]) + self.timesteps = timesteps.to(device=device) + + sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + sigmas = sigmas[self.config.margin_index:].flip(0) + self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]) + + # empty dt and derivative + self.prev_derivative = None + self.dt = None + + self._step_index = None + self._begin_index = None + + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[FlowMatchHeunDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `HeunDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." 
+ ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + if self.state_in_first_order: + sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] + else: + # 2nd order / Heun's method + sigma = self.sigmas[self.step_index - 1] + sigma_next = self.sigmas[self.step_index] + + if sigma == 0: + prev_sample = sample + (sigma_next - sigma) * model_output + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 2 + + if not return_dict: + return (prev_sample,) + + return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) + + + gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 + + noise = randn_tensor( + model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator + ) + + eps = noise * s_noise + sigma_hat = sigma * (gamma + 1) + + if gamma > 0: + sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 + + if self.state_in_first_order: + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + denoised = sample - model_output * sigma + # 2. convert to an ODE derivative for 1st order + derivative = (sample - denoised) / sigma_hat + # 3. Delta timestep + dt = sigma_next - sigma_hat + + # store for 2nd order step + self.prev_derivative = derivative + self.dt = dt + self.sample = sample + else: + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + denoised = sample - model_output * sigma_next + # 2. 2nd order / Heun's method + derivative = (sample - denoised) / sigma_next + derivative = 0.5 * (self.prev_derivative + derivative) + + # 3. take prev timestep & sample + dt = self.dt + sample = self.sample + + # free dt and derivative + # Note, this puts the scheduler in "first order mode" + self.prev_derivative = None + self.dt = None + self.sample = None + + prev_sample = sample + derivative * dt + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) \ No newline at end of file diff --git a/src/pipeline/flux_pipeline.py b/src/pipeline/flux_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..dc63df8c69916f655aa46c771c399565266f0731 --- /dev/null +++ b/src/pipeline/flux_pipeline.py @@ -0,0 +1,451 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
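+# NewFluxPipeline subclasses diffusers' FluxPipeline with two additions:
+# (1) __call__ accepts `mid_step_index` / `step_start`; for steps before the mid step, the source branch
+#     (batch index 0) is fed the mid-step timestep so mid-step features and attention can be extracted
+#     while the target branch denoises on the normal schedule.
+# (2) inversion() noises the image latents by linear interpolation toward `rand_latents` for the first
+#     `flow_steps` steps and then follows the reversed-scheduler ODE, with intermediate latents recorded
+#     through the step-end callback.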
+ +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from diffusers import FluxPipeline +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin +from diffusers.models.autoencoders import AutoencoderKL +from diffusers.models.transformers import FluxTransformer2DModel +from diffusers.pipelines.flux.pipeline_flux import (calculate_shift, + retrieve_timesteps) +from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import (USE_PEFT_BACKEND, is_torch_xla_available, logging, + replace_example_docstring, scale_lora_layers, + unscale_lora_layers) +from diffusers.utils.torch_utils import randn_tensor +from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel, + T5TokenizerFast) + + +class NewFluxPipeline(FluxPipeline): + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + mid_step_index=0, + step_start=0 + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). 
Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # # 1. Check inputs. 
Raise error if not correct + # self.check_inputs( + # prompt, + # prompt_2, + # height, + # width, + # prompt_embeds=prompt_embeds, + # pooled_prompt_embeds=pooled_prompt_embeds, + # callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + # max_sequence_length=max_sequence_length, + # ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.repeat(latents.shape[0]).to(latents.dtype) + if i + step_start < mid_step_index: + timestep[0] = timesteps[mid_step_index] + + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", pooled_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) + + @torch.no_grad() + def inversion( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + rand_latents: Optional[torch.FloatTensor] = None, + flow_steps: int = 0, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. 
Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, _ = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + # num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + num_warmup_steps = 0 + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. Denoising loop + img_latents = latents.clone() + flowed_latents = img_latents + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps[:num_inference_steps]): + if self.interrupt: + continue + + if i < flow_steps: + t_next = timesteps[i+1] + sigma_next = t_next / 1000 + latents = img_latents * (1 - sigma_next) + rand_latents * sigma_next + else: + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", pooled_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) \ No newline at end of file diff --git a/src/transformer_utils/__init__.py b/src/transformer_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/transformer_utils/flux_transformer_forward.py b/src/transformer_utils/flux_transformer_forward.py new file mode 100644 index 0000000000000000000000000000000000000000..071ff8aec706906fbfff8411a26c2935abf68a3b --- /dev/null +++ b/src/transformer_utils/flux_transformer_forward.py @@ -0,0 +1,96 @@ +import abc +import types + +import torch +from diffusers.models.transformers.transformer_flux import ( + FluxSingleTransformerBlock, FluxTransformerBlock) + + +def joint_transformer_forward(self, controller, place_in_transformer): + def forward( + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + temb: torch.FloatTensor, + image_rotary_emb=None, + joint_attention_kwargs=None + ): + + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + # Attention. + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + if controller is not None: + ff_output = controller(ff_output, place_in_transformer) + hidden_states = hidden_states + ff_output + + # Process attention outputs for the `encoder_hidden_states`. 
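+        # The context (text-token) branch below mirrors the image branch: gated attention residual,
+        # AdaLN-modulated feed-forward, and a float16 clamp. Note that the controller hook above is
+        # applied only to the image-token `ff_output`.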
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output + + encoder_hidden_states = encoder_hidden_states + context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + return forward + + +def single_transformer_forward(self, controller, place_in_transformer): + def forward( + hidden_states: torch.FloatTensor, + temb: torch.FloatTensor, + image_rotary_emb=None, + joint_attention_kwargs=None + ): + residual = hidden_states + + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_input = norm_hidden_states + + mlp_hidden_states = self.act_mlp(self.proj_mlp(mlp_input)) + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) + hidden_states = gate * self.proj_out(hidden_states) + + # Change here + if controller is not None: + hidden_states = controller(hidden_states, place_in_transformer) + + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + return hidden_states + + return forward \ No newline at end of file diff --git a/src/transformer_utils/transformer_utils.py b/src/transformer_utils/transformer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6c56194012077fff26b6ae8eb83d157cfafc8f24 --- /dev/null +++ b/src/transformer_utils/transformer_utils.py @@ -0,0 +1,112 @@ +import abc +import types + +import torch +from diffusers.models.transformers.transformer_flux import ( + FluxSingleTransformerBlock, FluxTransformerBlock) + +from .flux_transformer_forward import (joint_transformer_forward, + single_transformer_forward) + + +class FeatureCollector: + def __init__(self, transformer, controller, layer_list=[]): + self.transformer = transformer + self.controller = controller + self.layer_list = layer_list + + def register_transformer_control(self): + index = 0 + for joint_transformer in self.transformer.transformer_blocks: + place_in_transformer = f'joint_{index}' + joint_transformer.forward = joint_transformer_forward(joint_transformer, self.controller, place_in_transformer) + index +=1 + + for i, single_transformer in enumerate(self.transformer.single_transformer_blocks): + place_in_transformer = f'single_{index}' + single_transformer.forward = single_transformer_forward(single_transformer, self.controller, place_in_transformer) + index +=1 + + self.controller.num_layers = index + + def restore_orig_transformer(self): + place_in_transformer='' + + for joint_transformer in self.transformer.transformer_blocks: + joint_transformer.forward = joint_transformer_forward(joint_transformer, None, place_in_transformer) + + for i, single_transformer in enumerate(self.transformer.single_transformer_blocks): + single_transformer.forward = single_transformer_forward(single_transformer, None, place_in_transformer) + + +class FeatureControl(abc.ABC): + def __init__(self): + self.cur_step = 0 + self.num_layers = -1 + 
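+        # num_layers is filled in by FeatureCollector.register_transformer_control(); cur_step / cur_layer
+        # track progress through the blocks so __call__ can detect when one full denoising step has finished.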
+        self.cur_layer = 0
+
+    def step_callback(self, x_t):
+        return x_t
+
+    def between_steps(self):
+        return
+
+    @abc.abstractmethod
+    def forward(self, attn, place_in_transformer: str):
+        raise NotImplementedError
+
+    @torch.no_grad()
+    def __call__(self, hidden_state, place_in_transformer: str):
+        hidden_state = self.forward(hidden_state, place_in_transformer)
+        self.cur_layer = self.cur_layer + 1
+
+        if self.cur_layer == self.num_layers:
+            self.cur_layer = 0
+            self.cur_step = self.cur_step + 1
+            self.between_steps()
+
+        return hidden_state
+
+    def reset(self):
+        self.cur_step = 0
+        self.cur_layer = 0
+
+
+class FeatureReplace(FeatureControl):
+    def __init__(
+        self,
+        layer_list=[],
+        feature_steps=7
+    ):
+        super(FeatureReplace, self).__init__()
+        self.layer_list = layer_list
+        self.feature_steps = feature_steps
+
+
+    def forward(self, hidden_states, place_in_transformer):
+        layer_index = int(place_in_transformer.split('_')[-1])
+        if (layer_index not in self.layer_list) or (self.cur_step not in range(0, self.feature_steps)):
+            return hidden_states
+
+        hs_dim = hidden_states.shape[1]
+
+        t5_dim = 512
+        latent_dim = 4096
+        attn_dim = t5_dim + latent_dim
+        index_all = torch.arange(attn_dim)
+        t5_index, latent_index = index_all.split([t5_dim, latent_dim])
+
+        if 'single' in place_in_transformer:
+            mask = torch.ones(hs_dim).to(device=hidden_states.device, dtype=hidden_states.dtype)
+            mask[t5_index] = 0 # Only use image latent
+        else:
+            mask = torch.ones(hs_dim).to(device=hidden_states.device, dtype=hidden_states.dtype)
+
+        mask = mask[None, :, None]
+
+        source_hs = hidden_states[:1]
+        target_hs = hidden_states[1:]
+
+        target_hs = source_hs * mask + target_hs * (1 - mask)
+        hidden_states[1:] = target_hs
+        return hidden_states
diff --git a/src/utils.py b/src/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8843f0fe69af605c3b821b90b07e362d5f278022
--- /dev/null
+++ b/src/utils.py
@@ -0,0 +1,144 @@
+import base64
+import difflib
+import json
+import os
+
+import diffusers
+import numpy as np
+import requests
+import torch
+import torch.nn.functional as F
+import transformers
+from diffusers import (AutoencoderKL, DiffusionPipeline,
+                       FlowMatchEulerDiscreteScheduler, FluxPipeline,
+                       FluxTransformer2DModel, SD3Transformer2DModel,
+                       StableDiffusion3Pipeline)
+from diffusers.callbacks import PipelineCallback
+from torchao.quantization import int8_weight_only, quantize_
+from torchvision import transforms
+from transformers import (AutoModelForCausalLM, AutoProcessor, CLIPTextModel,
+                          CLIPTextModelWithProjection, T5EncoderModel)
+
+
+def get_flux_pipeline(
+    model_id="black-forest-labs/FLUX.1-dev",
+    pipeline_class=FluxPipeline,
+    torch_dtype=torch.bfloat16,
+    quantize=False
+):
+    ############ Diffusion Transformer ############
+    transformer = FluxTransformer2DModel.from_pretrained(
+        model_id, subfolder="transformer", torch_dtype=torch_dtype
+    )
+
+    ############ Text Encoder ############
+    text_encoder = CLIPTextModel.from_pretrained(
+        model_id, subfolder="text_encoder", torch_dtype=torch_dtype
+    )
+
+    ############ Text Encoder 2 ############
+    text_encoder_2 = T5EncoderModel.from_pretrained(
+        model_id, subfolder="text_encoder_2", torch_dtype=torch_dtype
+    )
+
+    ############ VAE ############
+    vae = AutoencoderKL.from_pretrained(
+        model_id, subfolder="vae", torch_dtype=torch_dtype
+    )
+
+    if quantize:
+        quantize_(transformer, int8_weight_only())
+        quantize_(text_encoder, int8_weight_only())
+        quantize_(text_encoder_2, int8_weight_only())
+        quantize_(vae, int8_weight_only())
+
+    # Initialize the pipeline now.
+    pipe = pipeline_class.from_pretrained(
+        model_id,
+        transformer=transformer,
+        vae=vae,
+        text_encoder=text_encoder,
+        text_encoder_2=text_encoder_2,
+        torch_dtype=torch_dtype
+    )
+    return pipe
+
+def mask_decode(encoded_mask,image_shape=[512,512]):
+    length=image_shape[0]*image_shape[1]
+    mask_array=np.zeros((length,))
+
+    for i in range(0,len(encoded_mask),2):
+        splice_len=min(encoded_mask[i+1],length-encoded_mask[i])
+        for j in range(splice_len):
+            mask_array[encoded_mask[i]+j]=1
+
+    mask_array=mask_array.reshape(image_shape[0], image_shape[1])
+    # to avoid annotation errors in boundary
+    mask_array[0,:]=1
+    mask_array[-1,:]=1
+    mask_array[:,0]=1
+    mask_array[:,-1]=1
+
+    return mask_array
+
+def mask_interpolate(mask, size=128):
+    mask = torch.tensor(mask)
+    mask = F.interpolate(mask[None, None, ...], size, mode='bicubic')
+    mask = mask.squeeze()
+    return mask
+
+def get_blend_word_index(prompt, word, tokenizer):
+    input_ids = tokenizer(prompt).input_ids
+    blend_ids = tokenizer(word, add_special_tokens=False).input_ids
+
+    index = []
+    for i, id in enumerate(input_ids):
+        # Ignore common token
+        if id < 100:
+            continue
+        if id in blend_ids:
+            index.append(i)
+
+    return index
+
+def find_token_id_differences(prompt1, prompt2, tokenizer):
+    # Tokenize inputs and get input IDs
+    tokens1 = tokenizer.encode(prompt1, add_special_tokens=False)
+    tokens2 = tokenizer.encode(prompt2, add_special_tokens=False)
+
+    # Get sequence matcher output
+    seq_matcher = difflib.SequenceMatcher(None, tokens1, tokens2)
+
+    diff1_indices, diff1_ids = [], []
+    diff2_indices, diff2_ids = [], []
+
+    for opcode, a_start, a_end, b_start, b_end in seq_matcher.get_opcodes():
+        if opcode in ['replace', 'delete']:
+            diff1_indices.extend(range(a_start, a_end))
+            diff1_ids.extend(tokens1[a_start:a_end])
+        if opcode in ['replace', 'insert']:
+            diff2_indices.extend(range(b_start, b_end))
+            diff2_ids.extend(tokens2[b_start:b_end])
+
+    return {
+        'prompt_1': {'index': diff1_indices, 'id': diff1_ids},
+        'prompt_2': {'index': diff2_indices, 'id': diff2_ids}
+    }
+
+def find_word_token_indices(prompt, word, tokenizer):
+    # Tokenize with offsets to track word positions
+    encoding = tokenizer(prompt, return_offsets_mapping=True, add_special_tokens=False)
+    tokens = encoding.tokens()
+    offsets = encoding.offset_mapping  # Start and end positions of tokens in the original text
+
+    word_indices = []
+
+    # Normalize the word for comparison
+    word_tokens = tokenizer(word, add_special_tokens=False).tokens()
+
+    # Find matching token sequences
+    for i in range(len(tokens) - len(word_tokens) + 1):
+        if tokens[i : i + len(word_tokens)] == word_tokens:
+            word_indices.extend(range(i, i + len(word_tokens)))
+
+    return word_indices
\ No newline at end of file
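
The `transformer_utils` files above only define the feature controller and the patched block forwards; the diff does not show how they are attached to a pipeline. The sketch below is a hypothetical wiring example, not part of the patch: it assumes the repository is used from its root so the `src.*` imports resolve, that the actual editing loop lives elsewhere (e.g. `img_edit.py`), and that the layer indices and device are placeholders rather than the project's defaults.

```
from src.transformer_utils.transformer_utils import FeatureCollector, FeatureReplace
from src.utils import get_flux_pipeline

# Load FLUX.1-dev; int8 weight-only quantization (torchao) is optional.
pipe = get_flux_pipeline(quantize=True)
pipe.to("cuda")  # device choice is an assumption

# Copy source-branch features into the target branch in the listed blocks for the
# first `feature_steps` denoising steps. The indices are illustrative only
# (FLUX has 19 joint + 38 single blocks, so global indices run 0-56).
controller = FeatureReplace(layer_list=list(range(20, 45)), feature_steps=7)

collector = FeatureCollector(pipe.transformer, controller, layer_list=controller.layer_list)
collector.register_transformer_control()  # monkey-patches the block forwards

# ... run the editing pipeline here with a batch ordered as [source, target];
#     FeatureReplace.forward treats hidden_states[:1] as source and [1:] as target ...

collector.restore_orig_transformer()  # detach the controller
controller.reset()                    # reset step/layer counters before reuse
```

The batch-ordering comment reflects how `FeatureReplace.forward` splits the batch, and `controller.num_layers` is set by `register_transformer_control`, which is why the controller must be registered before any forward pass.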
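`mask_decode` expects a run-length style encoding over the flattened image (pairs of start index and run length) and forces the one-pixel border to 1, while `mask_interpolate` resizes the decoded mask to the latent resolution. A minimal round-trip sketch with made-up numbers:

```
from src.utils import mask_decode, mask_interpolate

# Run-length pairs (start_index, run_length) over a flattened 512x512 grid;
# the values are purely illustrative.
encoded_mask = [
    512 * 200 + 100, 50,  # 50 masked pixels in row 200 starting at column 100
    512 * 201 + 100, 50,  # same span one row below
]

mask = mask_decode(encoded_mask)             # (512, 512) array of 0/1, border forced to 1
mask_128 = mask_interpolate(mask, size=128)  # (128, 128) torch tensor at latent resolution
print(mask.shape, mask_128.shape)
```

Because the resize uses bicubic interpolation, the result can slightly overshoot the [0, 1] range, so downstream code may want to clamp or threshold it before blending.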
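The tokenizer helpers operate on a Hugging Face fast tokenizer; `find_word_token_indices` in particular relies on `return_offsets_mapping` and `tokens()`, which only fast tokenizers provide. The sketch below assumes the T5 tokenizer can be loaded from the FLUX.1-dev repository's `tokenizer_2` subfolder; the prompts are illustrative.

```
from transformers import T5TokenizerFast

from src.utils import (find_token_id_differences, find_word_token_indices,
                       get_blend_word_index)

# Assumption: this is the T5 tokenizer that produces the 512-token text stream.
tokenizer = T5TokenizerFast.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="tokenizer_2"
)

source_prompt = "a photo of a cat sitting on the grass"
target_prompt = "a photo of a dog sitting on the grass"

# Positions where the two prompts' token sequences diverge ("cat" vs. "dog"),
# the kind of signal usable for picking a blend word automatically.
diff = find_token_id_differences(source_prompt, target_prompt, tokenizer)
print(diff['prompt_1']['index'], diff['prompt_2']['index'])

# Positions of an explicit word inside a prompt, e.g. to select its I2T-CA map.
print(find_word_token_indices(source_prompt, "cat", tokenizer))
print(get_blend_word_index(source_prompt, "cat", tokenizer))
```

Note that `get_blend_word_index` tokenizes the prompt with special tokens while the other two helpers do not, so the returned indices refer to slightly different token sequences.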