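# Click2Mask: point-based local image editing. A single click is grown into an
# edit mask during Blended Latent Diffusion sampling, guided by the text prompt.
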
import gc
import os
import random
import sys

import cv2
import numpy as np
import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline
from PIL import Image
from pytorch_lightning import seed_everything
from scipy.ndimage import gaussian_filter

sys.path.append("./scripts")
from dyn_mask import DynMask, get_surround
from arguments import parse_args
from clicker import ClickCreate, ClickDraw
from augmentations import ImageAugmentations
from constants import Const, N
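

# Load a PIL image as a normalized NCHW float tensor in [-1, 1] on `device`,
# resized to dest_size (width, height) with LANCZOS resampling when needed.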
def read_image(image: Image.Image, device, dest_size):
    image = image.convert("RGB")
    image = image.resize(dest_size, Image.LANCZOS) if dest_size != image.size else image
    image = np.array(image)
    image = image.astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)  # HWC -> NCHW
    image = torch.from_numpy(image).to(device)
    image = image * 2.0 - 1.0  # [0, 1] -> [-1, 1]
    return image


class Click2Mask:
    def __init__(self):
        self.args = parse_args()
        self.device = torch.device(f"cuda:{self.args.gpu_id}")
        self.load_models()

    def load_models(self):
        pipe = StableDiffusionPipeline.from_pretrained(
            self.args.model_path, torch_dtype=torch.float16
        )
        self.vae = pipe.vae.to(self.device)
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder.to(self.device)
        self.unet = pipe.unet.to(self.device)
        self.scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
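
    # Core denoising loop (Blended Latent Diffusion). With create_dyn_mask=True it
    # evolves the edit mask during sampling (dyn_mask.evolve_mask) and recursively
    # reruns itself with the frozen mask; with create_dyn_mask=False it only dilates
    # a given latent mask and blends the edited latents with the noised source.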
    def blended_latent_diffusion(
        self,
        dyn_mask,
        create_dyn_mask,
        seed,
        original_rand_latents,
        scheduler,
        blending_percentage,
        total_steps,
        source_latents,
        text_embeddings,
        guidance_scale,
        dyn_start_step_i=None,
        dyn_cond_stop_step_i=None,
        dyn_final_stop_step_i=None,
        max_area_ratio_for_dilation=None,
        last_step_threshed_latent_mask=None,
        rerun_return_during_step_i=None,
    ):
        seed_everything(seed)
        use_plain_dilation_from_latent_mask = not create_dyn_mask
        # Skip the first blending_percentage of the timesteps; blending starts after.
        blending_steps_t = scheduler.timesteps[
            int(len(scheduler.timesteps) * blending_percentage) :
        ]
        latents = original_rand_latents
        if create_dyn_mask:
            # Steps at which the mask is evolved (step 0 excluded).
            update_steps = list(range(dyn_start_step_i, dyn_cond_stop_step_i + 1))
            update_steps = [u for u in update_steps if u != 0 and u < len(blending_steps_t)]
            first_update_step, orig_last_update_step = update_steps[0], update_steps[-1]
            best_step_i = orig_last_update_step
        if last_step_threshed_latent_mask is not None:
            latent_mask = last_step_threshed_latent_mask
        for step_i, t in enumerate(blending_steps_t):
            # Expand the latents for classifier-free guidance to avoid two forward passes.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = scheduler.scale_model_input(
                latent_model_input, timestep=t
            )

            # Predict the noise residual.
            with torch.no_grad():
                noise_pred = self.unet(
                    latent_model_input, t, encoder_hidden_states=text_embeddings
                ).sample

            # Perform classifier-free guidance.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

            # A single scheduler step yields both the predicted x_0 and the previous
            # noisy sample x_t -> x_t-1 (no need to call step twice).
            scheduler_out = scheduler.step(noise_pred, t, latents)
            latent_pred_z0 = scheduler_out.pred_original_sample
            latents = scheduler_out.prev_sample
            if rerun_return_during_step_i == step_i:
                # This invocation is a rerun: hand the partially denoised latents
                # back to the caller at the requested step.
                return latents, latent_mask
            # Dilation for rerun + final runs
            elif use_plain_dilation_from_latent_mask:
                latent_mask = dyn_mask.get_plain_dilated_latent_mask(
                    last_step_latent_mask=last_step_threshed_latent_mask,
                    step_i=step_i,
                    total_steps=total_steps,
                    max_area_ratio_for_dilation=max_area_ratio_for_dilation,
                    rerun_dyn_start_step_i=None
                    if not rerun_return_during_step_i
                    else dyn_start_step_i,
                )
            # Mask evolution
            elif create_dyn_mask:
                if step_i in update_steps:
                    latent_mask = dyn_mask.evolve_mask(
                        step_i=step_i,
                        decoder=self.vae.decode,
                        latent_pred_z0=latent_pred_z0,
                        source_latents=source_latents,
                        return_only=N.LATENT_MASK,
                    )
                    # Rerun: re-denoise from the original noise with the newly
                    # evolved mask (plain dilation only), returning at this step.
                    latents, _ = self.blended_latent_diffusion(
                        dyn_mask,
                        create_dyn_mask=False,
                        seed=seed,
                        original_rand_latents=original_rand_latents,
                        scheduler=scheduler,
                        blending_percentage=blending_percentage,
                        total_steps=total_steps,
                        source_latents=source_latents,
                        text_embeddings=text_embeddings,
                        guidance_scale=guidance_scale,
                        dyn_start_step_i=dyn_start_step_i,
                        max_area_ratio_for_dilation=Const.RERUN_MAX_AREA_RATIO_FOR_DILATION,
                        last_step_threshed_latent_mask=latent_mask,
                        rerun_return_during_step_i=step_i,
                    )
                elif step_i < first_update_step:  # initial dilation
                    latent_mask = dyn_mask.set_cur_masks(
                        step_i=step_i, return_only=N.LATENT_MASK
                    )
            # Blending: keep edited content inside the mask, noised source outside.
            noise_source_latents = scheduler.add_noise(
                source_latents, torch.randn_like(latents), t
            )
            latents = latents * latent_mask + noise_source_latents * (1 - latent_mask)

            if create_dyn_mask:
                if step_i >= orig_last_update_step:
                    dyn_mask.make_cached_masks_clones(name=step_i)
                    dyn_mask.latents_hist[step_i] = latents
                    dyn_mask.latent_masks_hist[step_i] = latent_mask
                    if step_i >= orig_last_update_step + 2:
                        step_prev1_better = (
                            dyn_mask.closs_hist[step_i - 1]
                            < dyn_mask.closs_hist[step_i - 2]
                        )
                        if step_prev1_better:
                            best_step_i = step_i - 1
                        if (not step_prev1_better) or (step_i > dyn_final_stop_step_i):
                            # We need an extra step to calculate the CLIP loss of the
                            # last evolved mask; roll back to the best step and stop.
                            latents = dyn_mask.latents_hist[best_step_i]
                            latent_mask = dyn_mask.latent_masks_hist[best_step_i]
                            dyn_mask.set_masks_from_cached_masks_clones(
                                name=best_step_i
                            )
                            break
                    # Keep evolving one step at a time while the loss improves.
                    update_steps.append(step_i + 1)
        return latents, latent_mask
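
    # Full editing pipeline for one click: encode the image, evolve the mask once,
    # then run several final blended-diffusion passes with different seeds and
    # score the outputs.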
    def edit_image(
        self,
        image_pil,
        click_pil,
        prompts,
        height,
        width,
        num_inference_steps,
        num_static_inference_steps,
        guidance_scale,
        seed,
        blending_percentage,
    ):
        generator = torch.manual_seed(seed)
        batch_size = len(prompts)
        self.scheduler.set_timesteps(num_inference_steps)
        # PIL sizes are (width, height).
        image_pil = image_pil.resize((width, height), Image.LANCZOS)
        image_np = np.array(image_pil)[:, :, :3]
        source_latents = self._image2latent(image_np)
        init_image_tensor = read_image(
            image=image_pil, device=self.device, dest_size=(width, height)
        )
        total_steps = num_inference_steps - int(
            len(self.scheduler.timesteps) * blending_percentage
        )
        dyn_mask = DynMask(
            click_pil, self.args, init_image_tensor, self.device, total_steps
        )
        # Conditional and unconditional text embeddings for classifier-free guidance.
        text_input = self.tokenizer(
            prompts,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        max_length = text_input.input_ids.shape[-1]
        uncond_input = self.tokenizer(
            [""] * batch_size,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        )
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        latents = torch.randn(
            (batch_size, self.unet.config.in_channels, height // 8, width // 8),
            generator=generator,
        )
        latents = latents.to(self.device).half()
        original_rand_latents = latents
        # DYN_* constants are absolute step indices if > 1, else fractions of total_steps.
        dyn_start_step_i = (
            Const.DYN_START
            if Const.DYN_START > 1
            else round(Const.DYN_START * total_steps)
        )
        dyn_cond_stop_step_i = (
            Const.DYN_COND_STOP
            if Const.DYN_COND_STOP > 1
            else round(Const.DYN_COND_STOP * total_steps)
        )
        dyn_final_stop_step_i = (
            Const.DYN_FINAL_STOP
            if Const.DYN_FINAL_STOP > 1
            else round(Const.DYN_FINAL_STOP * total_steps)
        )
        # Evolve mask
        self.blended_latent_diffusion(
            dyn_mask=dyn_mask,
            create_dyn_mask=True,
            seed=seed,
            original_rand_latents=original_rand_latents,
            scheduler=self.scheduler,
            blending_percentage=blending_percentage,
            total_steps=total_steps,
            source_latents=source_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale,
            dyn_start_step_i=dyn_start_step_i,
            dyn_cond_stop_step_i=dyn_cond_stop_step_i,
            dyn_final_stop_step_i=dyn_final_stop_step_i,
        )
        # Final runs: generate several outputs with the now-frozen mask. Runs after
        # the first use a fresh seed and a scheduler with num_static_inference_steps.
        self.static_scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        self.static_scheduler.set_timesteps(num_static_inference_steps)
        total_static_steps = num_static_inference_steps - int(
            len(self.static_scheduler.timesteps) * blending_percentage
        )

        latents_list = []
        latent_masks_list = []
        seeds_list = []
        seeds_to_run = random.sample(
            range(1, Const.MAX_SEED), Const.N_OUTS_FOR_DYN_MASK - 1
        )
        print(f"running output (from {Const.N_OUTS_FOR_DYN_MASK}): ", end="")
        for out_i in range(Const.N_OUTS_FOR_DYN_MASK):
            print(f"{out_i + 1}", end="... ")
            orig_l = original_rand_latents
            seed_i = seed
            if out_i > 0:
                seed_i = seeds_to_run[out_i - 1]
                orig_l = torch.randn(
                    (batch_size, self.unet.config.in_channels, height // 8, width // 8),
                    generator=torch.manual_seed(seed_i),
                )
                orig_l = orig_l.to(self.device).half()
            latents, latent_mask = self.blended_latent_diffusion(
                dyn_mask=dyn_mask,
                create_dyn_mask=False,
                seed=seed_i,
                original_rand_latents=orig_l,
                scheduler=self.static_scheduler if out_i > 0 else self.scheduler,
                blending_percentage=blending_percentage,
                total_steps=total_static_steps if out_i > 0 else total_steps,
                source_latents=source_latents,
                text_embeddings=text_embeddings,
                guidance_scale=guidance_scale,
                max_area_ratio_for_dilation=Const.MAX_AREA_RATIO_FOR_DILATION,
                last_step_threshed_latent_mask=dyn_mask.get_curr_masks(
                    return_only=N.LATENT_MASK
                ),
            )
            latents_list.append(latents)
            latent_masks_list.append(latent_mask)
            seeds_list.append(seed_i)
print("scoring...") | |
results = self.score_and_arrange_results( | |
dyn_mask=dyn_mask, | |
latents_list=latents_list, | |
latent_masks_list=latent_masks_list, | |
n_runs=Const.N_RUNS_ON_SCORES, | |
aug_num=Const.N_AUGS_ON_SCORES, | |
alpha_mask_dilation_on_512=Const.ALPHA_MASK_DILATION_ON_512, | |
) | |
return results | |
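
    # Encode an RGB uint8 image to scaled VAE latents (0.18215 is Stable Diffusion's
    # latent scaling factor).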
    def _image2latent(self, image):
        image = torch.from_numpy(image).float() / 127.5 - 1
        image = image.permute(2, 0, 1).unsqueeze(0).to(self.device)
        image = image.half()
        latents = self.vae.encode(image)["latent_dist"].mean
        latents = latents * 0.18215
        return latents
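
    # Blend the decoded edit back into the original image: outside a Gaussian-feathered
    # version of the upsampled mask, the original pixels are preserved.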
    def back_preserve_with_gauss(self, decoded_img, latent_mask, dyn_mask):
        upsampled_mask = latent_mask.cpu().numpy().squeeze()
        # cv2.resize takes the interpolation method as a keyword argument
        # (a positional third argument would be treated as dst).
        upsampled_mask = cv2.resize(
            upsampled_mask.astype(np.float32),
            dyn_mask.decoded_size,
            interpolation=cv2.INTER_LANCZOS4,
        )
        upsampled_mask = upsampled_mask > 0.5
        g_mask = gaussian_filter(
            upsampled_mask.astype(float), sigma=Const.BACK_PRES_SIGMA
        )
        g_mask = torch.from_numpy(g_mask).half().to(self.device)
        g_mask = (g_mask * Const.BACK_PRES_SCALE).clip(0, 1)
        g_mask[upsampled_mask] = 1
        blended = decoded_img * g_mask + dyn_mask.init_image * (1 - g_mask)
        return blended
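
    # Decode each candidate's latents, restore the background, and attach a
    # CLIP-based prompt-similarity score to each result.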
    def score_and_arrange_results(
        self,
        dyn_mask,
        latents_list,
        latent_masks_list,
        n_runs,
        aug_num,
        alpha_mask_dilation_on_512,
    ):
        results = []
        raw_d_prompt = np.zeros((n_runs, len(latents_list)))
        for i, (latents, latent_mask) in enumerate(
            zip(latents_list, latent_masks_list)
        ):
            latents = 1 / 0.18215 * latents  # undo the VAE latent scaling
            with torch.no_grad():
                img = self.vae.decode(latents).sample
            img = self.back_preserve_with_gauss(img, latent_mask, dyn_mask)
            results.append({"im": img, "latent_mask": latent_mask})

            # Score within a dilated (alpha) mask region, averaged over n_runs
            # augmented evaluations.
            alpha_mask = get_surround(
                latent_mask,
                alpha_mask_dilation_on_512 * (latent_mask.shape[-1] / 512.0),
                self.device,
            )
            if aug_num is not None:
                image_augmentations = ImageAugmentations(
                    self.args.alpha_clip_scale, aug_num
                )
            else:
                image_augmentations = None
            for run_i in range(n_runs):
                raw_d_prompt[run_i, i] = dyn_mask.alpha_clip_loss(
                    img,
                    alpha_mask,
                    dyn_mask.text_features,
                    image_augmentations=image_augmentations,
                    augs_with_orig=(run_i == 0),
                    return_as_similarity=True,
                )
        raw_d_prompt = raw_d_prompt.mean(axis=0)
        for i, res in enumerate(results):
            # "dist" actually holds a similarity (higher is better); callers sort
            # with reverse=True accordingly.
            res["dist"] = float(raw_d_prompt[i])
        return results
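

# App entry point (e.g., for a hosted demo): edit `image_pil` at click location
# `point512` according to `prompt`, and return the best-scoring result as a
# uint8 HWC array.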
def click2mask_app(prompt: str, image_pil: Image.Image, point512: np.ndarray):
    c2m = Click2Mask()
    c2m.args.prompt = prompt
    results = []
    for mask_i in range(c2m.args.n_masks):
        print(f"\nEvolving mask {mask_i + 1}...")
        # Use the provided seed for the first mask only; sample fresh seeds after.
        seed = (
            c2m.args.seed
            if (c2m.args.seed and mask_i == 0)
            else random.sample(range(1, Const.MAX_SEED), 1)[0]
        )
        seed_everything(seed)
        click_draw = ClickDraw()
        click_pil, _ = click_draw(image_pil, point512=point512)
        mask_i_results = c2m.edit_image(
            image_pil=image_pil,
            click_pil=click_pil,
            prompts=[c2m.args.prompt] * Const.BATCH_SIZE,
            height=Const.H,
            width=Const.W,
            num_inference_steps=Const.NUM_INFERENCE_STEPS,
            num_static_inference_steps=Const.NUM_STATIC_INFERENCE_STEPS,
            guidance_scale=Const.GUIDANCE_SCALE,
            seed=seed,
            blending_percentage=Const.BLENDING_START_PERCENTAGE,
        )
        results += mask_i_results

    # Pick the highest-similarity result and convert it to a uint8 image.
    sorted_results = sorted(results, key=lambda k: k["dist"], reverse=True)
    out_img = sorted_results[0]["im"]
    out_img = (out_img / 2 + 0.5).clamp(0, 1)
    out_img = out_img.detach().cpu().permute(0, 2, 3, 1).numpy().squeeze()
    out_img = (out_img * 255).round().astype(np.uint8)
    torch.cuda.empty_cache()
    gc.collect()
    print("\nCompleted.")
    return out_img
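

# CLI entry point: read the image and click location from disk (creating the
# click file interactively if needed), run the editor, and save the best result.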
if __name__ == "__main__":
    c2m = Click2Mask()
    img_dir = os.path.dirname(c2m.args.image_path)
    img_name = os.path.basename(os.path.normpath(c2m.args.image_path))
    img_base_name = os.path.splitext(img_name)[0]
    results = []
    for mask_i in range(c2m.args.n_masks):
        print(f"\nEvolving mask {mask_i + 1}...")
        seed = (
            c2m.args.seed
            if (c2m.args.seed and mask_i == 0)
            else random.sample(range(1, Const.MAX_SEED), 1)[0]
        )
        seed_everything(seed)
        # Reuse an existing <image>_click.<ext> file if present, unless a refresh
        # is requested for the first mask.
        click_ext = [
            ext
            for ext in ("jpg", "JPG", "JPEG", "jpeg", "png", "PNG")
            if os.path.exists(os.path.join(img_dir, f"{img_base_name}_click.{ext}"))
        ]
        if (not click_ext) or (mask_i == 0 and c2m.args.refresh_click):
            click_create = ClickCreate()
            c2m.args.click_path = click_create(
                c2m.args.image_path, os.path.join(img_dir, f"{img_base_name}_click.jpg")
            )
        else:
            c2m.args.click_path = os.path.join(
                img_dir, f"{img_base_name}_click.{click_ext[0]}"
            )
        mask_i_results = c2m.edit_image(
            image_pil=Image.open(c2m.args.image_path),
            click_pil=Image.open(c2m.args.click_path),
            prompts=[c2m.args.prompt] * Const.BATCH_SIZE,
            height=Const.H,
            width=Const.W,
            num_inference_steps=Const.NUM_INFERENCE_STEPS,
            num_static_inference_steps=Const.NUM_STATIC_INFERENCE_STEPS,
            guidance_scale=Const.GUIDANCE_SCALE,
            seed=seed,
            blending_percentage=Const.BLENDING_START_PERCENTAGE,
        )
        results += mask_i_results

    os.makedirs(c2m.args.output_dir, exist_ok=True)
    sorted_results = sorted(results, key=lambda k: k["dist"], reverse=True)
    out_img = sorted_results[0]["im"]
    out_img = (out_img / 2 + 0.5).clamp(0, 1)
    out_img = out_img.detach().cpu().permute(0, 2, 3, 1).numpy().squeeze()
    out_img = (out_img * 255).round().astype(np.uint8)
    out_path = os.path.join(c2m.args.output_dir, f"{img_base_name}_out.jpg")
    Image.fromarray(out_img).save(out_path, quality=95)
    print(f"\nCompleted.\nOutput image path:\n{os.path.abspath(out_path)}")