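# ReFlex / img_edit.py
#
# The lines below are residue from the web file viewer this script was copied
# from; they are kept as comments so that the module parses as Python.
#   SahilCarterr's picture
#   Upload 77 files
#   f056744 verified
#   raw
#   history blame
#   13 kB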
import argparse
import gc
import os
import random
import re
import time
from distutils.util import strtobool
import pandas as pd
# ---------------------------------------------------------------------------
# Command-line interface.
#
# Arguments are parsed at module level (not inside ``main``) so that
# CUDA_VISIBLE_DEVICES is exported before the torch / diffusers imports that
# follow this section.
#
# NOTE(review): ``strtobool`` comes from ``distutils``, which was removed from
# the standard library in Python 3.12. It returns the ints 1 / 0, which the
# rest of the script only uses in boolean context.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()

# --- Inputs / outputs ------------------------------------------------------
parser.add_argument(
    "--img_path",
    type=str,
)
parser.add_argument(
    "--target_prompt",
    type=str,
)
parser.add_argument(
    "--source_prompt",
    type=str,
    default='',
)
parser.add_argument(
    "--blend_word",
    type=str,
    default='',
)
parser.add_argument(
    "--mask_path",
    type=str,
    default=None,
)
parser.add_argument(
    "--gpu",
    type=str,
    default="0",
)
parser.add_argument(
    "--seed",
    type=int,
    default=0,
)
parser.add_argument(
    "--results_dir",
    type=str,
    default='results',
)
parser.add_argument(
    "--model",
    type=str,
    default='flux',
    choices=['flux'],
)

# --- Number of steps for each kind of injection ----------------------------
parser.add_argument(
    "--ca_steps",
    type=int,
    default=10,
    help="Number of steps to apply I2T-CA adaptation and injection.",
)
parser.add_argument(
    "--sa_steps",
    type=int,
    default=7,  # the original was missing this comma (SyntaxError)
    help="Number of steps to apply I2I-SA adaptation and injection.",
)
parser.add_argument(
    "--feature_steps",
    type=int,
    default=5,  # the original was missing this comma (SyntaxError)
    help="Number of steps to inject residual features.",
)

# --- Layer ranges (used as half-open slices [from:to) in main) -------------
parser.add_argument(
    "--ca_attn_layer_from",
    type=int,
    default=13,
    help="Layers to apply I2T-CA adaptation and injection.",
)
parser.add_argument(
    "--ca_attn_layer_to",
    type=int,
    default=45,
    help="Layers to apply I2T-CA adaptation and injection.",
)
parser.add_argument(
    "--sa_attn_layer_from",
    type=int,
    default=20,
    help="Layers to apply I2I-SA adaptation and injection.",
)
parser.add_argument(
    "--sa_attn_layer_to",
    type=int,
    default=45,
    help="Layers to apply I2I-SA adaptation and injection.",
)
parser.add_argument(
    "--feature_layer_from",
    type=int,
    default=13,
    help="Layers to inject residual features.",
)
parser.add_argument(
    "--feature_layer_to",
    type=int,
    default=20,
    help="Layers to inject residual features.",
)

# --- Inversion / sampling --------------------------------------------------
parser.add_argument(
    "--flow_steps",
    type=int,
    default=7,
    help="Steps to apply forward step before inversion",
)
parser.add_argument(
    "--step_start",
    type=int,
    default=0,
)
parser.add_argument(
    "--num_inference_steps",
    type=int,
    default=28,
)
parser.add_argument(
    "--guidance_scale",
    type=float,
    default=3.5,
)

# --- Attention adaptation hyperparameters ----------------------------------
parser.add_argument(
    "--attn_topk",
    type=int,
    default=20,
    help="Hyperparameter for I2I-SA adaptaion.",
)
parser.add_argument(
    "--text_scale",
    type=float,
    default=4,
    help="Hyperparameter for I2T-CA adaptaion.",
)
parser.add_argument(
    "--mid_step_index",
    type=int,
    default=14,
    help="Hyperparameter for mid-step feature extraction.",
)

# --- Latent blending mask --------------------------------------------------
parser.add_argument(
    "--use_mask",
    type=strtobool,
    default=True,
)
parser.add_argument(
    "--use_ca_mask",
    type=strtobool,
    default=True,
)
parser.add_argument(
    "--mask_steps",
    type=int,
    default=18,
    help="Steps to apply latent blending",
)
parser.add_argument(
    "--mask_dilation",
    type=int,
    default=3,
)
parser.add_argument(
    "--mask_nbins",
    type=int,
    default=128,
)

args = parser.parse_args()

# Restrict the visible GPUs before the CUDA-backed libraries are imported.
os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.gpu}"
import gc
import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.utils.torch_utils import randn_tensor
from PIL import Image
from src.attn_utils.attn_utils import AttentionAdapter, AttnCollector
from src.attn_utils.flux_attn_processor import NewFluxAttnProcessor2_0
from src.attn_utils.seq_aligner import get_refinement_mapper
from src.callback.callback_fn import CallbackAll
from src.inversion.inverse import get_inversed_latent_list
from src.inversion.scheduling_flow_inverse import \
FlowMatchEulerDiscreteForwardScheduler
from src.pipeline.flux_pipeline import NewFluxPipeline
from src.transformer_utils.transformer_utils import (FeatureCollector,
FeatureReplace)
from src.utils import (find_token_id_differences, find_word_token_indices,
get_flux_pipeline, mask_decode, mask_interpolate)
def fix_seed(random_seed):
    """
    Seed every random number generator this script relies on.

    The same ``random_seed`` is handed to Python's ``random`` module, NumPy,
    and torch (CPU, the current CUDA device and all CUDA devices). cuDNN is
    additionally switched to deterministic mode with benchmarking disabled.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # if use multi-GPU
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(random_seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def _resolve_mask_settings(args, pipe, prompts, device):
    """Decide which latent-blending mask (if any) the edit should use.

    Args:
        args: Parsed command-line namespace (reads ``use_mask``, ``mask_path``,
            ``use_ca_mask`` and ``blend_word``).
        pipe: The pipeline; only ``pipe.tokenizer_2`` is used.
        prompts: ``[source_prompt, target_prompt]``.
        device: Device the user-supplied mask tensor is moved to.

    Returns:
        Tuple ``(use_mask, use_ca_mask, mask, source_ca_index, target_ca_index)``.
        ``mask`` is a boolean tensor only when ``--mask_path`` was given; the
        two index values are token indices used to build a cross-attention
        mask, and at most one of them is not ``None``.
    """
    source_prompt = prompts[0]

    # Masking disabled altogether.
    if not args.use_mask:
        return False, False, None, None, None

    # A user-supplied mask file takes precedence over the attention mask.
    if args.mask_path is not None:
        with Image.open(args.mask_path) as mask_img:
            mask = torch.tensor(np.array(mask_img)).bool()
        return True, False, mask.to(device), None, None

    # NOTE(review): in the fall-through cases below ``use_mask`` stays True
    # although there is neither a mask tensor nor CA token indices (this is
    # what the original did) — confirm CallbackAll tolerates that combination.
    if not (args.use_ca_mask and source_prompt):
        return True, False, None, None, None

    if args.blend_word and args.blend_word in source_prompt:
        source_token_index = find_word_token_indices(
            source_prompt, args.blend_word, pipe.tokenizer_2
        )
        target_token_index = None
    else:
        tokens_info = find_token_id_differences(*prompts, pipe.tokenizer_2)
        source_token_index = tokens_info['prompt_1']['index']
        target_token_index = tokens_info['prompt_2']['index']

    # Source-side tokens win over target-side tokens; with neither available
    # the cross-attention mask is switched off.
    if source_token_index:
        return True, True, None, source_token_index, None
    if target_token_index:
        return True, True, None, None, target_token_index
    return True, False, None, None, None


def _save_results(result_img_dir, source_img, images, source_prompt, target_prompt):
    """Write the source image, edited image(s), prompt files and overview plot.

    ``images`` is the list returned by the pipeline; its first entry belongs to
    the source prompt and is only shown in the overview figure, the remaining
    entries are saved as ``target_{i}.png``.
    """
    source_img.save(os.path.join(result_img_dir, "source.png"))
    for i, img in enumerate(images[1:]):
        img.save(os.path.join(result_img_dir, f"target_{i}.png"))

    # Explicit encoding so that non-ASCII prompts are written identically on
    # every platform.
    with open(os.path.join(result_img_dir, "target_prompts.txt"), 'w', encoding='utf-8') as file:
        file.write(target_prompt + '\n')
    with open(os.path.join(result_img_dir, "source_prompt.txt"), 'w', encoding='utf-8') as file:
        file.write(source_prompt + '\n')

    # Side-by-side overview: input image, source-prompt image, edited image.
    panels = [source_img] + images
    fs = 3
    n = len(panels)
    fig, ax = plt.subplots(1, n, figsize=(n * fs, 1 * fs))
    for axis, img in zip(ax, panels):
        axis.imshow(img)
    ax[0].set_title('source')
    ax[1].set_title(source_prompt, fontsize=7)
    ax[2].set_title(target_prompt, fontsize=7)
    fig.savefig(os.path.join(result_img_dir, "overall.png"), bbox_inches='tight')
    plt.close(fig)


def main(args):
    """Edit one image with the FLUX pipeline and save the results.

    The source image is inverted into a list of latents, attention and
    residual-feature controllers are attached to the transformer, and the
    pipeline is run on ``[source_prompt, target_prompt]``. Outputs are written
    to ``{results_dir}/seed_{seed}/{target_prompt}``.

    Args:
        args: Namespace produced by the module-level argument parser.

    Raises:
        ValueError: If ``--img_path`` or ``--target_prompt`` is missing. This
            is checked up front so the failure happens before the model is
            loaded.
    """
    if not args.img_path:
        raise ValueError("--img_path is required")
    if args.target_prompt is None:
        raise ValueError("--target_prompt is required")

    fix_seed(args.seed)
    device = torch.device('cuda')

    pipe = get_flux_pipeline(pipeline_class=NewFluxPipeline)
    pipe = pipe.to(device)

    # Half-open layer ranges for each kind of intervention.
    # 57 is presumably the number of transformer blocks — TODO confirm.
    layer_order = range(57)
    ca_layer_list = layer_order[args.ca_attn_layer_from:args.ca_attn_layer_to]
    # The original sliced from ``args.feature_layer_to`` here, which silently
    # ignored --sa_attn_layer_from. Both default to 20, so default runs are
    # unaffected by this fix.
    sa_layer_list = layer_order[args.sa_attn_layer_from:args.sa_attn_layer_to]
    feature_layer_list = layer_order[args.feature_layer_from:args.feature_layer_to]

    # Load eagerly inside a context manager so the file handle is released.
    with Image.open(args.img_path) as opened_img:
        source_img = opened_img.resize((1024, 1024)).convert("RGB")

    result_img_dir = f"{args.results_dir}/seed_{args.seed}/{args.target_prompt}"
    source_prompt = args.source_prompt
    target_prompt = args.target_prompt
    prompts = [source_prompt, target_prompt]
    print(prompts)

    use_mask, use_ca_mask, mask, source_ca_index, target_ca_index = _resolve_mask_settings(
        args, pipe, prompts, device
    )
    if mask is not None:
        # Increase the latent blending steps if the ground truth mask is used.
        args.mask_steps = int(args.num_inference_steps * 0.9)

    if source_prompt:
        # Use I2T-CA injection
        mappers, alphas = get_refinement_mapper(prompts, pipe.tokenizer_2, max_len=512)
        mappers = mappers.to(device=device)
        alphas = alphas.to(device=device, dtype=pipe.dtype)
        alphas = alphas[:, None, None, :]
        ca_steps = args.ca_steps
        attn_adj_from = 1
    else:
        # Not use I2T-CA injection
        mappers = None
        alphas = None
        ca_steps = 0
        attn_adj_from = 3

    sa_steps = args.sa_steps
    feature_steps = args.feature_steps

    attn_controller = AttentionAdapter(
        ca_layer_list=ca_layer_list,
        sa_layer_list=sa_layer_list,
        ca_steps=ca_steps,
        sa_steps=sa_steps,
        method='replace_topk',
        topk=args.attn_topk,
        text_scale=args.text_scale,
        mappers=mappers,
        alphas=alphas,
        attn_adj_from=attn_adj_from,
        save_source_ca=source_ca_index is not None,
        save_target_ca=target_ca_index is not None,
    )
    attn_collector = AttnCollector(
        transformer=pipe.transformer,
        controller=attn_controller,
        attn_processor_class=NewFluxAttnProcessor2_0,
    )
    feature_controller = FeatureReplace(
        layer_list=feature_layer_list,
        feature_steps=feature_steps,
    )
    feature_collector = FeatureCollector(
        transformer=pipe.transformer,
        controller=feature_controller,
    )

    num_prompts = len(prompts)

    # Random starting noise; (1, 16, 128, 128) presumably is the latent shape
    # matching the 1024x1024 input — TODO confirm.
    shape = (1, 16, 128, 128)
    generator = torch.Generator(device=device).manual_seed(args.seed)
    latents = randn_tensor(shape, device=device, generator=generator)
    latents = pipe._pack_latents(latents, *latents.shape)

    # Inversion runs on the unmodified transformer.
    attn_collector.restore_orig_attention()
    feature_collector.restore_orig_transformer()

    t0 = time.perf_counter()

    inv_latents = get_inversed_latent_list(
        pipe,
        source_img,
        random_noise=latents,
        num_inference_steps=args.num_inference_steps,
        backward_method="ode",
        use_prompt_for_inversion=False,
        guidance_scale_for_inversion=0,
        prompt_for_inversion='',
        flow_steps=args.flow_steps,
    )

    # Both branches consume the inverted latents in reverse order.
    source_latents = inv_latents[::-1]
    target_latents = inv_latents[::-1]

    # Attach the controllers for the editing pass.
    attn_collector.register_attention_control()
    feature_collector.register_transformer_control()

    callback_fn = CallbackAll(
        latents=source_latents,
        attn_collector=attn_collector,
        feature_collector=feature_collector,
        feature_inject_steps=feature_steps,
        mid_step_index=args.mid_step_index,
        step_start=args.step_start,
        use_mask=use_mask,
        use_ca_mask=use_ca_mask,
        source_ca_index=source_ca_index,
        target_ca_index=target_ca_index,
        mask_kwargs={'dilation': args.mask_dilation},
        mask_steps=args.mask_steps,
        mask=mask,
    )

    # One starting latent per prompt: slot 0 (source prompt) starts from the
    # mid-step latent, the remaining slot from the ``step_start`` latent.
    init_latent = target_latents[args.step_start]
    init_latent = init_latent.repeat(num_prompts, 1, 1)
    init_latent[0] = source_latents[args.mid_step_index]

    os.makedirs(result_img_dir, exist_ok=True)

    pipe.scheduler = FlowMatchEulerDiscreteForwardScheduler.from_config(
        pipe.scheduler.config,
        step_start=args.step_start,
        margin_index_from_image=0
    )

    attn_controller.reset()
    feature_controller.reset()
    attn_controller.text_scale = args.text_scale
    attn_controller.cur_step = args.step_start
    feature_controller.cur_step = args.step_start

    with torch.no_grad():
        images = pipe(
            prompts,
            latents=init_latent,
            num_images_per_prompt=1,
            guidance_scale=args.guidance_scale,
            num_inference_steps=args.num_inference_steps,
            generator=generator,
            callback_on_step_end=callback_fn,
            mid_step_index=args.mid_step_index,
            step_start=args.step_start,
            callback_on_step_end_tensor_inputs=['latents'],
        ).images

    t1 = time.perf_counter()
    print(f"Done in {t1 - t0:.1f}s.")

    _save_results(result_img_dir, source_img, images, source_prompt, target_prompt)

    mask_save_dir = os.path.join(result_img_dir, "mask")
    os.makedirs(mask_save_dir, exist_ok=True)

    if use_ca_mask:
        # Save the mask held by the callback as an 8-bit grayscale image.
        ca_mask_path = os.path.join(mask_save_dir, "mask_ca.png")
        mask_img = Image.fromarray((callback_fn.mask.cpu().float().numpy() * 255).astype(np.uint8)).convert('L')
        mask_img.save(ca_mask_path)

    # Drop the large latent tensors and release cached GPU memory.
    del inv_latents
    del init_latent
    gc.collect()
    torch.cuda.empty_cache()


if __name__ == '__main__':
    main(args)