import os

import torch
from PIL import Image
import jsonc as json

from model import (
    DirectionalAttentionControl,
    StableDiffusionXLImg2ImgPipeline,
    register_attention_editor_diffusers,
)
from utils.pipeline_utils import *
from utils import get_args, extract_mask
from src import get_ddpm_inversion_scheduler
from visualization import save_results


def run(
    image_path,
    src_prompt,
    tgt_prompt,
    masks,
    pipeline: StableDiffusionXLImg2ImgPipeline,
    args,
):
    """Run one source->target image edit with DDPM inversion + attention control.

    Args:
        image_path: Path to the input image; it is loaded as RGB and resized
            to 512x512.
        src_prompt: Text prompt describing the source image.
        tgt_prompt: Text prompt describing the desired edit.
        masks: Either a precomputed ``torch.Tensor`` mask or a mask
            specification accepted by ``extract_mask`` (may be ``None`` —
            presumably ``extract_mask`` handles that; verify against its impl).
        pipeline: The SDXL img2img pipeline to run; its scheduler is replaced
            in place by the DDPM-inversion scheduler (side effect on caller's
            pipeline object).
        args: Parsed CLI namespace; reads seed, timesteps, w1, dift_timestep,
            movement_intensifier, apply_dift_correction, alpha, beta, mode,
            structural_alignment, support_new_object.

    Returns:
        ``[original_image, reconstruction, edited_image]`` — the input PIL
        image plus the first and third outputs of the 3-way batched call
        (src, src, tgt prompts).
    """
    seed = args.seed
    num_timesteps = args.timesteps
    torch.manual_seed(seed)
    generator = torch.Generator(device=SAMPLING_DEVICE).manual_seed(seed)

    timesteps, config = set_pipeline(pipeline, num_timesteps, generator, args)

    x_0_image = (
        Image.open(image_path).convert("RGB").resize((512, 512), RESIZE_TYPE)
    )
    x_0 = encode_image(x_0_image, pipeline, generator)

    # Build the noised latent trajectory for inversion, then match the
    # latent dtype (x_0 may be fp16 when the pipeline runs in half precision).
    x_ts = create_xts(
        config.noise_shift_delta,
        config.noise_timesteps,
        generator,
        pipeline.scheduler,
        timesteps,
        x_0,
    )
    x_ts = [xt.to(dtype=x_0.dtype) for xt in x_ts]
    latents = [x_ts[0]]

    # Accept either an already-built tensor mask or a spec to rasterize
    # at the working 512x512 resolution.
    if not isinstance(masks, torch.Tensor):
        mask = extract_mask(masks, 512, 512)
    else:
        mask = masks

    pipeline.scheduler = get_ddpm_inversion_scheduler(
        pipeline.scheduler,
        config,
        timesteps,
        latents,
        x_ts,
        w1=args.w1,
        dift_timestep=args.dift_timestep,
        movement_intensifier=args.movement_intensifier,
        apply_dift_correction=args.apply_dift_correction,
        mask=mask,
    )

    # Attention editing starts at step 0, layer 44 (SDXL layer indexing).
    step, layer = 0, 44
    editor = DirectionalAttentionControl(
        step,
        layer,
        total_steps=11,
        model_type="SDXL",
        alpha=args.alpha,
        mode=args.mode,
        # NOTE: beta is deliberately inverted before being handed to the
        # editor — callers supply beta in [0, 1] and the editor takes 1-beta.
        beta=1 - args.beta,
        structural_alignment=args.structural_alignment,
        support_new_object=args.support_new_object,
    )
    register_attention_editor_diffusers(pipeline, editor)

    # Batch of 3: [source recon, source recon, target edit].
    latent = latents[0].expand(3, -1, -1, -1)
    prompt = [src_prompt, src_prompt, tgt_prompt]
    pipeline.unet.latent_store.reset()
    # Fix: call the pipeline directly instead of pipeline.__call__(...).
    image = pipeline(image=latent, prompt=prompt).images

    return [x_0_image, image[0], image[2]]


if __name__ == "__main__":
    args = get_args()

    # Fix: use a context manager so the prompts file is closed deterministically
    # (was json.load(open(...)) which leaked the file handle).
    with open(args.prompts_file, "r") as f:
        img_paths_to_prompts = json.load(f)

    eval_dataset_folder = args.eval_dataset_folder
    img_paths = [
        f"{eval_dataset_folder}/{img_name}"
        for img_name in img_paths_to_prompts
    ]

    pipeline = load_pipeline(args.fp16, args.cache_dir)
    os.makedirs(args.output_dir, exist_ok=True)

    for img_path in img_paths:
        args.img_path = img_path
        img_name = os.path.basename(img_path)
        entry = img_paths_to_prompts[img_name]

        prompt = entry["src_prompt"]
        edit_prompts = entry["tgt_prompt"]
        # Per-image overrides with the same defaults as before.
        args.alpha = entry.get("alpha", 0.7)
        args.beta = entry.get("beta", 1)
        masks = entry.get("masks", None)
        args.mask = masks
        args.source_prompt = prompt
        args.target_prompt = edit_prompts[0]

        res = run(
            img_path,
            prompt,
            edit_prompts[0],
            masks,
            pipeline=pipeline,
            args=args,
        )
        # Free cached activations between images to keep VRAM bounded.
        torch.cuda.empty_cache()

        save_results(
            args=args,
            source_prompt=prompt,
            target_prompt=edit_prompts[0],
            images=res,
        )