# click2mask/scripts/text_editing_click2mask.py
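"""Text-driven local image editing with Click2Mask: a mask is evolved from a
single click during blended latent diffusion, and candidate outputs are scored
with an alpha-CLIP loss."""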
import random
import os
import gc
import numpy as np
from PIL import Image
import cv2
from diffusers import DDIMScheduler, StableDiffusionPipeline
from pytorch_lightning import seed_everything
import torch
from scipy.ndimage import gaussian_filter
import sys
sys.path.append("./scripts")
from dyn_mask import DynMask, get_surround
from arguments import parse_args
from clicker import ClickCreate, ClickDraw
from augmentations import ImageAugmentations
from constants import Const, N
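
# Load a PIL image, resize it to `dest_size`, and return a [-1, 1] NCHW float
# tensor on `device`.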
def read_image(image: Image.Image, device, dest_size):
    image = image.convert("RGB")
    image = image.resize(dest_size, Image.LANCZOS) if dest_size != image.size else image
    image = np.array(image)
    image = image.astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image).to(device)
    image = image * 2.0 - 1.0
    return image

class Click2Mask:
    def __init__(self):
        self.args = parse_args()
        self.device = torch.device(f"cuda:{self.args.gpu_id}")
        self.load_models()

    def load_models(self):
        pipe = StableDiffusionPipeline.from_pretrained(
            self.args.model_path, torch_dtype=torch.float16
        )
        self.vae = pipe.vae.to(self.device)
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder.to(self.device)
        self.unet = pipe.unet.to(self.device)
        self.scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
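
    # Core blended-diffusion loop, used in three modes:
    #   1) create_dyn_mask=True -- evolve the mask during denoising, rerunning
    #      denoising up to the current step whenever the mask changes;
    #   2) rerun mode (rerun_return_during_step_i is set) -- replay with plain
    #      dilation and return early at the requested step;
    #   3) final mode -- plain dilation from the last thresholded latent mask.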
    @torch.enable_grad()
    def blended_latent_diffusion(
        self,
        dyn_mask,
        create_dyn_mask,
        seed,
        original_rand_latents,
        scheduler,
        blending_percentage,
        total_steps,
        source_latents,
        text_embeddings,
        guidance_scale,
        dyn_start_step_i=None,
        dyn_cond_stop_step_i=None,
        dyn_final_stop_step_i=None,
        max_area_ratio_for_dilation=None,
        last_step_threshed_latent_mask=None,
        rerun_return_during_step_i=None,
    ):
        seed_everything(seed)
        use_plain_dilation_from_latent_mask = not create_dyn_mask
        blending_steps_t = scheduler.timesteps[
            int(len(scheduler.timesteps) * blending_percentage) :
        ]
        latents = original_rand_latents
        if create_dyn_mask:
            update_steps = list(range(dyn_start_step_i, dyn_cond_stop_step_i + 1))
            update_steps = [
                u for u in update_steps if u != 0 and u < len(blending_steps_t)
            ]
            first_update_step, orig_last_update_step = update_steps[0], update_steps[-1]
            best_step_i = orig_last_update_step
        if last_step_threshed_latent_mask is not None:
            latent_mask = last_step_threshed_latent_mask
        for step_i, t in enumerate(blending_steps_t):
            # Expand the latents if we are doing classifier-free guidance,
            # to avoid doing two forward passes.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = scheduler.scale_model_input(
                latent_model_input, timestep=t
            )
            # Predict the noise residual.
            with torch.no_grad():
                noise_pred = self.unet(
                    latent_model_input, t, encoder_hidden_states=text_embeddings
                ).sample
            # Perform guidance.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
            # Take the scheduler step once and reuse its outputs: the predicted
            # x_0 and the previous noisy sample x_t -> x_t-1.
            step_output = scheduler.step(noise_pred, t, latents)
            latent_pred_z0 = step_output.pred_original_sample
            latents = step_output.prev_sample
            if rerun_return_during_step_i == step_i:
                return latents, latent_mask
            # Dilation for rerun + final runs.
            elif use_plain_dilation_from_latent_mask:
                latent_mask = dyn_mask.get_plain_dilated_latent_mask(
                    last_step_latent_mask=last_step_threshed_latent_mask,
                    step_i=step_i,
                    total_steps=total_steps,
                    max_area_ratio_for_dilation=max_area_ratio_for_dilation,
                    rerun_dyn_start_step_i=None
                    if not rerun_return_during_step_i
                    else dyn_start_step_i,
                )
            # Mask evolution.
            elif create_dyn_mask:
                if step_i in update_steps:
                    latent_mask = dyn_mask.evolve_mask(
                        step_i=step_i,
                        decoder=self.vae.decode,
                        latent_pred_z0=latent_pred_z0,
                        source_latents=source_latents,
                        return_only=N.LATENT_MASK,
                    )
                    # Rerun: replay denoising up to this step under the newly
                    # evolved mask, so the latents stay consistent with it.
                    latents, _ = self.blended_latent_diffusion(
                        dyn_mask,
                        create_dyn_mask=False,
                        seed=seed,
                        original_rand_latents=original_rand_latents,
                        scheduler=scheduler,
                        blending_percentage=blending_percentage,
                        total_steps=total_steps,
                        source_latents=source_latents,
                        text_embeddings=text_embeddings,
                        guidance_scale=guidance_scale,
                        dyn_start_step_i=dyn_start_step_i,
                        max_area_ratio_for_dilation=Const.RERUN_MAX_AREA_RATIO_FOR_DILATION,
                        last_step_threshed_latent_mask=latent_mask,
                        rerun_return_during_step_i=step_i,
                    )
                elif step_i < first_update_step:  # initial dilation
                    latent_mask = dyn_mask.set_cur_masks(
                        step_i=step_i, return_only=N.LATENT_MASK
                    )
            # Blending.
            noise_source_latents = scheduler.add_noise(
                source_latents, torch.randn_like(latents), t
            )
            latents = latents * latent_mask + noise_source_latents * (1 - latent_mask)
            if create_dyn_mask:
                if step_i >= orig_last_update_step:
                    dyn_mask.make_cached_masks_clones(name=step_i)
                    dyn_mask.latents_hist[step_i] = latents
                    dyn_mask.latent_masks_hist[step_i] = latent_mask
                    if step_i >= orig_last_update_step + 2:
                        step_prev1_better = (
                            dyn_mask.closs_hist[step_i - 1]
                            < dyn_mask.closs_hist[step_i - 2]
                        )
                        if step_prev1_better:
                            best_step_i = step_i - 1
                        if (not step_prev1_better) or (step_i > dyn_final_stop_step_i):
                            # We need an extra step to calculate the CLIP loss
                            # for the last evolved mask.
                            latents = dyn_mask.latents_hist[best_step_i]
                            latent_mask = dyn_mask.latent_masks_hist[best_step_i]
                            dyn_mask.set_masks_from_cached_masks_clones(
                                name=best_step_i
                            )
                            break
                    update_steps.append(step_i + 1)
        return latents, latent_mask
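
    # Full pipeline for one click: encode the source image, evolve the dynamic
    # mask once, then render several candidates (one per seed) and score them.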
    @torch.no_grad()
    def edit_image(
        self,
        image_pil,
        click_pil,
        prompts,
        height,
        width,
        num_inference_steps,
        num_static_inference_steps,
        guidance_scale,
        seed,
        blending_percentage,
    ):
        generator = torch.manual_seed(seed)
        batch_size = len(prompts)
        self.scheduler.set_timesteps(num_inference_steps)
        # PIL's resize/size use (width, height); height and width are equal
        # here, so this only fixes the argument order.
        image_pil = image_pil.resize((width, height), Image.LANCZOS)
        image_np = np.array(image_pil)[:, :, :3]
        source_latents = self._image2latent(image_np)
        init_image_tensor = read_image(
            image=image_pil, device=self.device, dest_size=(width, height)
        )
        total_steps = num_inference_steps - int(
            len(self.scheduler.timesteps) * blending_percentage
        )
        dyn_mask = DynMask(
            click_pil, self.args, init_image_tensor, self.device, total_steps
        )
        text_input = self.tokenizer(
            prompts,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
        max_length = text_input.input_ids.shape[-1]
        uncond_input = self.tokenizer(
            [""] * batch_size,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        )
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        latents = torch.randn(
            (batch_size, self.unet.config.in_channels, height // 8, width // 8),
            generator=generator,
        )
        latents = latents.to(self.device).half()
        original_rand_latents = latents
        # Constants > 1 are absolute step indices; otherwise they are fractions
        # of total_steps.
        dyn_start_step_i = (
            Const.DYN_START
            if Const.DYN_START > 1
            else round(Const.DYN_START * total_steps)
        )
        dyn_cond_stop_step_i = (
            Const.DYN_COND_STOP
            if Const.DYN_COND_STOP > 1
            else round(Const.DYN_COND_STOP * total_steps)
        )
        dyn_final_stop_step_i = (
            Const.DYN_FINAL_STOP
            if Const.DYN_FINAL_STOP > 1
            else round(Const.DYN_FINAL_STOP * total_steps)
        )
        # Evolve mask.
        self.blended_latent_diffusion(
            dyn_mask=dyn_mask,
            create_dyn_mask=True,
            seed=seed,
            original_rand_latents=original_rand_latents,
            scheduler=self.scheduler,
            blending_percentage=blending_percentage,
            total_steps=total_steps,
            source_latents=source_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale,
            dyn_start_step_i=dyn_start_step_i,
            dyn_cond_stop_step_i=dyn_cond_stop_step_i,
            dyn_final_stop_step_i=dyn_final_stop_step_i,
        )
        # Final runs with the evolved mask, using a shorter static schedule for
        # all outputs after the first.
        self.static_scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        self.static_scheduler.set_timesteps(num_static_inference_steps)
        total_static_steps = num_static_inference_steps - int(
            len(self.static_scheduler.timesteps) * blending_percentage
        )
        latents_list = []
        latent_masks_list = []
        seeds_list = []
        seeds_to_run = random.sample(
            range(1, Const.MAX_SEED), Const.N_OUTS_FOR_DYN_MASK - 1
        )
        print(f"running output (from {Const.N_OUTS_FOR_DYN_MASK}): ", end="")
        for out_i in range(Const.N_OUTS_FOR_DYN_MASK):
            print(f"{out_i + 1}", end="... ")
            orig_l = original_rand_latents
            seed_i = seed
            if out_i > 0:
                seed_i = seeds_to_run[out_i - 1]
                orig_l = torch.randn(
                    (batch_size, self.unet.config.in_channels, height // 8, width // 8),
                    generator=torch.manual_seed(seed_i),
                )
                orig_l = orig_l.to(self.device).half()
            latents, latent_mask = self.blended_latent_diffusion(
                dyn_mask=dyn_mask,
                create_dyn_mask=False,
                seed=seed_i,
                original_rand_latents=orig_l,
                scheduler=self.static_scheduler if out_i > 0 else self.scheduler,
                blending_percentage=blending_percentage,
                total_steps=total_static_steps if out_i > 0 else total_steps,
                source_latents=source_latents,
                text_embeddings=text_embeddings,
                guidance_scale=guidance_scale,
                max_area_ratio_for_dilation=Const.MAX_AREA_RATIO_FOR_DILATION,
                last_step_threshed_latent_mask=dyn_mask.get_curr_masks(
                    return_only=N.LATENT_MASK
                ),
            )
            latents_list.append(latents)
            latent_masks_list.append(latent_mask)
            seeds_list.append(seed_i)
        print("scoring...")
        results = self.score_and_arrange_results(
            dyn_mask=dyn_mask,
            latents_list=latents_list,
            latent_masks_list=latent_masks_list,
            n_runs=Const.N_RUNS_ON_SCORES,
            aug_num=Const.N_AUGS_ON_SCORES,
            alpha_mask_dilation_on_512=Const.ALPHA_MASK_DILATION_ON_512,
        )
        return results
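
    # Encode an HxWx3 uint8 image into Stable Diffusion latents; 0.18215 is the
    # SD VAE latent scaling factor.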
    @torch.no_grad()
    def _image2latent(self, image):
        image = torch.from_numpy(image).float() / 127.5 - 1
        image = image.permute(2, 0, 1).unsqueeze(0).to(self.device)
        image = image.half()
        latents = self.vae.encode(image)["latent_dist"].mean
        latents = latents * 0.18215
        return latents
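
    # Background preservation: blend the decoded edit into the original image
    # through a Gaussian-feathered version of the upsampled latent mask.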
    def back_preserve_with_gauss(self, decoded_img, latent_mask, dyn_mask):
        upsampled_mask = latent_mask.cpu().numpy().squeeze()
        # cv2.resize takes the interpolation as a keyword argument (its third
        # positional argument is `dst`), and needs OpenCV's Lanczos flag rather
        # than PIL's.
        upsampled_mask = cv2.resize(
            upsampled_mask.astype(np.float32),
            dyn_mask.decoded_size,
            interpolation=cv2.INTER_LANCZOS4,
        )
        upsampled_mask = upsampled_mask > 0.5
        g_mask = gaussian_filter(
            upsampled_mask.astype(float), sigma=Const.BACK_PRES_SIGMA
        )
        g_mask = torch.from_numpy(g_mask).half().to(self.device)
        g_mask = (g_mask * Const.BACK_PRES_SCALE).clip(0, 1)
        g_mask[upsampled_mask] = 1
        blended = decoded_img * g_mask + dyn_mask.init_image * (1 - g_mask)
        return blended
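
    # Score each decoded candidate by alpha-CLIP similarity to the prompt over
    # the mask surround, averaged across n_runs augmentation rounds.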
    def score_and_arrange_results(
        self,
        dyn_mask,
        latents_list,
        latent_masks_list,
        n_runs,
        aug_num,
        alpha_mask_dilation_on_512,
    ):
        results = []
        raw_d_prompt = np.zeros((n_runs, len(latents_list)))
        for i, (latents, latent_mask) in enumerate(
            zip(latents_list, latent_masks_list)
        ):
            latents = 1 / 0.18215 * latents
            with torch.no_grad():
                img = self.vae.decode(latents).sample
            img = self.back_preserve_with_gauss(img, latent_mask, dyn_mask)
            results.append({"im": img, "latent_mask": latent_mask})
            alpha_mask = get_surround(
                latent_mask,
                alpha_mask_dilation_on_512 * (latent_mask.shape[-1] / 512.0),
                self.device,
            )
            if aug_num is not None:
                image_augmentations = ImageAugmentations(
                    self.args.alpha_clip_scale, aug_num
                )
            else:
                image_augmentations = None
            for run_i in range(n_runs):
                raw_d_prompt[run_i][i] = dyn_mask.alpha_clip_loss(
                    img,
                    alpha_mask,
                    dyn_mask.text_features,
                    image_augmentations=image_augmentations,
                    augs_with_orig=(run_i == 0),
                    return_as_similarity=True,
                )
        raw_d_prompt = raw_d_prompt.mean(axis=0)
        for i, res in enumerate(results):
            # "dist" holds a CLIP similarity (higher is better); results are
            # later sorted in descending order.
            res["dist"] = float(raw_d_prompt[i])
        return results
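
# App entry point. A minimal usage sketch (hypothetical paths and click point;
# `point512` appears to be an (x, y) point in the 512x512 frame used by
# ClickDraw):
#
#   out = click2mask_app(
#       prompt="flowers",
#       image_pil=Image.open("inputs/example.jpg"),
#       point512=np.array([256, 256]),
#   )
#   Image.fromarray(out).save("example_out.jpg")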
def click2mask_app(prompt: str, image_pil: Image.Image, point512: np.ndarray):
    c2m = Click2Mask()
    c2m.args.prompt = prompt
    results = []
    for mask_i in range(c2m.args.n_masks):
        print(f"\nEvolving mask {mask_i + 1}...")
        seed = (
            c2m.args.seed
            if (c2m.args.seed and mask_i == 0)
            else random.sample(range(1, Const.MAX_SEED), 1)[0]
        )
        seed_everything(seed)
        click_draw = ClickDraw()
        click_pil, _ = click_draw(image_pil, point512=point512)
        mask_i_results = c2m.edit_image(
            image_pil=image_pil,
            click_pil=click_pil,
            prompts=[c2m.args.prompt] * Const.BATCH_SIZE,
            height=Const.H,
            width=Const.W,
            num_inference_steps=Const.NUM_INFERENCE_STEPS,
            num_static_inference_steps=Const.NUM_STATIC_INFERENCE_STEPS,
            guidance_scale=Const.GUIDANCE_SCALE,
            seed=seed,
            blending_percentage=Const.BLENDING_START_PERCENTAGE,
        )
        results += mask_i_results
    sorted_results = sorted(results, key=lambda k: k["dist"], reverse=True)
    out_img = sorted_results[0]["im"]
    out_img = (out_img / 2 + 0.5).clamp(0, 1)
    out_img = out_img.detach().cpu().permute(0, 2, 3, 1).numpy().squeeze()
    out_img = (out_img * 255).round().astype(np.uint8)
    torch.cuda.empty_cache()
    gc.collect()
    print("\nCompleted.")
    return out_img
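
# CLI entry point: evolves n_masks masks for the image given in the parsed
# arguments and saves the best-scoring edit to the output directory.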
if __name__ == "__main__":
    c2m = Click2Mask()
    img_dir = os.path.dirname(c2m.args.image_path)
    img_name = os.path.basename(os.path.normpath(c2m.args.image_path))
    img_base_name = os.path.splitext(img_name)[0]
    results = []
    for mask_i in range(c2m.args.n_masks):
        print(f"\nEvolving mask {mask_i + 1}...")
        seed = (
            c2m.args.seed
            if (c2m.args.seed and mask_i == 0)
            else random.sample(range(1, Const.MAX_SEED), 1)[0]
        )
        seed_everything(seed)
        # Reuse an existing "<image>_click.<ext>" file if present; otherwise
        # (or when a refresh is requested) let the user create a new click.
        click_ext = [
            ext
            for ext in ("jpg", "JPG", "JPEG", "jpeg", "png", "PNG")
            if os.path.exists(os.path.join(img_dir, f"{img_base_name}_click.{ext}"))
        ]
        if (not click_ext) or (mask_i == 0 and c2m.args.refresh_click):
            click_create = ClickCreate()
            c2m.args.click_path = click_create(
                c2m.args.image_path, os.path.join(img_dir, f"{img_base_name}_click.jpg")
            )
        else:
            c2m.args.click_path = os.path.join(
                img_dir, f"{img_base_name}_click.{click_ext[0]}"
            )
        mask_i_results = c2m.edit_image(
            image_pil=Image.open(c2m.args.image_path),
            click_pil=Image.open(c2m.args.click_path),
            prompts=[c2m.args.prompt] * Const.BATCH_SIZE,
            height=Const.H,
            width=Const.W,
            num_inference_steps=Const.NUM_INFERENCE_STEPS,
            num_static_inference_steps=Const.NUM_STATIC_INFERENCE_STEPS,
            guidance_scale=Const.GUIDANCE_SCALE,
            seed=seed,
            blending_percentage=Const.BLENDING_START_PERCENTAGE,
        )
        results += mask_i_results
    os.makedirs(c2m.args.output_dir, exist_ok=True)
    # Keep the single best-scoring result across all evolved masks.
    sorted_results = sorted(results, key=lambda k: k["dist"], reverse=True)
    out_img = sorted_results[0]["im"]
    out_img = (out_img / 2 + 0.5).clamp(0, 1)
    out_img = out_img.detach().cpu().permute(0, 2, 3, 1).numpy().squeeze()
    out_img = (out_img * 255).round().astype(np.uint8)
    out_path = os.path.join(c2m.args.output_dir, f"{img_base_name}_out.jpg")
    Image.fromarray(out_img).save(out_path, quality=95)
    print(f"\nCompleted.\nOutput image path:\n{os.path.abspath(out_path)}")