# click2mask / scripts / dyn_mask.py
# (omeregev — initial commit, 6df18f5)
import torch
from torchvision import transforms
import numpy as np
import skfmm
from PIL import Image
import torch.nn as nn
import cv2
import scipy
from scipy.ndimage.filters import gaussian_filter
import kornia
import warnings
warnings.filterwarnings("ignore", message="PyTorch version 1.7.1 or higher is recommended")
import alpha_clip
from augmentations import ImageAugmentations
from constants import Const, N
@torch.no_grad()
def get_dist_field(dist_from, device, as_squeezed_np=False):
    """Signed distance field of a binary mask, via the fast marching method.

    `dist_from` is a {0,1} mask (numpy array or torch tensor, values <= 1).
    The zero level set is placed on the mask boundary: points inside the mask
    get negative distances, points outside get positive ones.

    Returns a numpy array when `as_squeezed_np` is True, otherwise a torch
    tensor on `device`.
    """
    mask_np = dist_from if isinstance(dist_from, np.ndarray) else dist_from.cpu().numpy()
    assert np.max(mask_np) <= 1
    # Level-set function: -0.5 inside the mask, +0.5 outside, so the zero
    # contour sits exactly on the mask boundary.
    phi = np.where(mask_np, -0.5, 0.5)
    dist_field = skfmm.distance(phi, dx=1)
    if as_squeezed_np:
        return dist_field
    return torch.tensor(dist_field).to(device)
def get_surround(surround_from, surround_width, device, as_squeezed_np=False):
    """Dilate a binary mask tensor by `surround_width` (in distance-field units).

    Every point whose signed distance to the mask boundary is at most
    `surround_width` is included, so the result is the mask plus a surrounding
    band. Returns a numpy array when `as_squeezed_np` is True, otherwise a
    tensor with the same dtype as `surround_from`.
    """
    dist_field = get_dist_field(surround_from, device)
    dilated = (dist_field <= surround_width).to(surround_from.dtype)
    return dilated.cpu().numpy() if as_squeezed_np else dilated
class DynMask:
    """Dynamically evolving edit mask driven by Alpha-CLIP gradients.

    Starting from a single user click, a smooth "potential" field is built
    over the latent grid; thresholding the potential at `thresh_val` yields
    the binary latent mask. At each diffusion step the potential is nudged
    (only inside a ring around the current mask boundary) by gradients of an
    Alpha-CLIP text-matching loss, so the mask migrates toward the image
    region that best matches the prompt.
    """

    def __init__(self, click_pil, args, init_image_tensor, device, total_steps):
        """Load the Alpha-CLIP backbone, encode the prompt, and build the
        initial potential/mask from the click image.

        `args` must expose `.alpha_clip_scale` (336 selects the ViT-L/14@336px
        checkpoint, anything else the ViT-L/14 one) and `.prompt`.
        `click_pil` is a PIL image marking the clicked location.
        """
        self.args = args
        self.device = device
        self.init_image = init_image_tensor
        self.total_steps = total_steps
        # Square input resolution expected by the Alpha-CLIP visual tower.
        self.ac_size = (self.args.alpha_clip_scale, self.args.alpha_clip_scale)
        if self.args.alpha_clip_scale == 336:
            self.ac_model, self.ac_preprocess = alpha_clip.load(
                "ViT-L/14@336px",
                alpha_vision_ckpt_pth="./checkpoints/clip_l14_336_grit1m_fultune_8xe.pth",
                device=self.device,
            )
        else:
            self.ac_model, self.ac_preprocess = alpha_clip.load(
                "ViT-L/14",
                alpha_vision_ckpt_pth="./checkpoints/clip_l14_grit20m_fultune_2xe.pth",
                device=self.device,
            )
        self.image_augmentations = ImageAugmentations(
            self.args.alpha_clip_scale, Const.AUG_NUM
        )
        self.text_features = self.get_text_features([self.args.prompt])
        self.latent_size = Const.LATENT_SIZE
        self.decoded_size = (Const.H, Const.W)
        self.thresh_val = Const.THRESH_VAL
        # Potential-field state; populated by set_init_masks/set_cur_masks.
        self.base_potential = None
        self.potential = None
        self.latent_mask = None
        self.set_init_masks(click_pil)
        # Named snapshots (see make_cached_masks_clones) and per-step history.
        self.cached_masks_clones = {}
        self.closs_hist = {}
        self.latents_hist = {}
        self.latent_masks_hist = {}

    @torch.no_grad()
    def normalize_point_size(self, click, radius_for64=1.367):
        """Replace the clicked blob with a disc of fixed radius at its centroid.

        `click` is a 2D {0,1}-ish numpy array. The radius is `radius_for64`
        scaled by the grid size relative to 64 (minus a small shrink margin).
        Returns a numpy array. NOTE(review): assumes `click` has at least one
        positive pixel — an empty click would make the centroid NaN.
        """
        threshed = (click > 0.5).astype(float)
        x, y = np.where(threshed)
        center = int(x.mean().round()), int(y.mean().round())
        norm_threshed = np.zeros_like(threshed)
        norm_threshed[center[0], center[1]] = 1
        # Grow the single centroid pixel into a disc via the distance field.
        norm_threshed = get_surround(
            torch.tensor(norm_threshed).to(self.device),
            click.shape[0] / 64 * radius_for64 - 0.3,
            self.device,
            as_squeezed_np=True,
        )
        return norm_threshed

    @torch.no_grad()
    def calc_potential(self, click_pil, sigma_for_shape64):
        """Build a [0, 1] Gaussian potential bump centered on the click.

        The click image is downsampled to the latent grid, normalized to a
        fixed-size disc, then Gaussian-blurred (sigma scaled with grid size)
        and min-max normalized. Returns a (1, 1, H, W) half tensor on device.
        """
        dest_size = self.latent_size
        click = click_pil.convert("L").resize(dest_size, Image.NEAREST)
        click = (np.array(click) > 125).astype(float)
        click = self.normalize_point_size(
            click, radius_for64=Const.POINT_ON_LATENT_RADIUS
        )
        potential = gaussian_filter(
            click, sigma=sigma_for_shape64 * (click.shape[0]) / 64
        )
        # Min-max normalize to [0, 1]; the 1e-8 floor guards a flat field.
        potential = (potential - np.min(potential)) / max(
            np.max(potential) - np.min(potential), 1e-8
        )
        potential = potential[np.newaxis, np.newaxis, ...]
        potential = torch.from_numpy(potential).half().to(self.device)
        return potential

    @torch.no_grad()
    def set_init_masks(self, click_pil, stretch_factor=1.0):
        """Initialize `base_potential` from the click and set step-0 masks.

        The [0, 1] potential is affinely remapped to [-1, POTENTIAL_PEAK]
        (so the field is negative away from the click) and optionally
        stretched by `stretch_factor`.
        """
        potential = self.calc_potential(
            click_pil, sigma_for_shape64=Const.SIGMA_FOR_SHAPE64
        )
        self.base_potential = potential.detach().to(torch.float64)
        if self.base_potential.ndim == 2:
            self.base_potential = self.base_potential.unsqueeze(0).unsqueeze(0)
        # Map [0, 1] -> [-1, POTENTIAL_PEAK].
        self.base_potential = self.base_potential * (Const.POTENTIAL_PEAK - (-1)) - 1
        self.base_potential = stretch_factor * self.base_potential
        self.set_cur_masks(step_i=0)

    @torch.no_grad()
    def set_cur_masks(
        self, step_i, grads_to_update=None, surround_ring=None, return_only=None
    ):
        """Recompute `potential` and `latent_mask` for the given step.

        potential = base + step-decayed dilation bias, optionally plus the
        loss gradients applied only on `surround_ring` (the band around the
        current mask boundary), then Gaussian-smoothed. If the field collapses
        to all-negative (mask vanished) or all-positive (mask swallowed the
        frame), it is shifted back by ADDITION_IN_COLLAPSE.
        """
        potential = self.base_potential + self.get_bias(step_i)
        if grads_to_update is not None:
            potential = potential + (surround_ring * Const.MASK_LR * grads_to_update)
        potential = transforms.GaussianBlur(
            Const.GAUSS_K_MASK, sigma=Const.GAUSS_SIGMA_MASK
        )(potential)
        if torch.all(potential <= 0):
            potential += Const.ADDITION_IN_COLLAPSE
            print(
                f"{'*' * 10} Mask shrunk entirely, added {Const.ADDITION_IN_COLLAPSE}"
            )
        elif torch.all(potential >= 0):
            potential -= Const.ADDITION_IN_COLLAPSE
            print(
                f"{'*' * 10} Mask expanded entirely, reduced {Const.ADDITION_IN_COLLAPSE}"
            )
        self.potential = potential.half()
        self.latent_mask = self.get_threshed_mask(self.potential)
        return self.get_curr_masks(return_only=return_only)

    @torch.no_grad()
    def get_curr_masks(self, return_only=None):
        """Return (potential, latent_mask), or just one when `return_only`
        names N.POTENTIAL or N.LATENT_MASK."""
        if return_only is not None:
            if return_only == N.POTENTIAL:
                return self.potential
            elif return_only == N.LATENT_MASK:
                return self.latent_mask
            else:
                raise ValueError(f"return_only should be in ('{N.POTENTIAL}', '{N.LATENT_MASK}')")
        return self.potential, self.latent_mask

    @torch.no_grad()
    def make_cached_masks_clones(self, name):
        """Snapshot the current potential/mask under `name` for later restore."""
        self.cached_masks_clones[name] = {
            N.POTENTIAL: self.potential.detach().clone(),
            N.LATENT_MASK: self.latent_mask.detach().clone(),
        }

    @torch.no_grad()
    def set_masks_from_cached_masks_clones(self, name):
        """Restore the potential/mask previously snapshotted under `name`."""
        self.potential = self.cached_masks_clones[name][N.POTENTIAL]
        self.latent_mask = self.cached_masks_clones[name][N.LATENT_MASK]

    @torch.no_grad()
    def evolve_mask(
        self, step_i, decoder, latent_pred_z0, source_latents, return_only=None
    ):
        """One mask-evolution step: push the potential along Alpha-CLIP grads.

        The gradients are taken in magnitude, blurred, standardized, and
        rectified (only positive pushes — the mask can only be encouraged to
        grow where the loss cares), then applied within the boundary ring.
        Returns the updated masks (see get_curr_masks).
        """
        potential, latent_mask = self.get_curr_masks()
        surround_ring = self.get_ring(latent_mask)
        grads_latent = self.calc_grads(
            latent_pred_z0=latent_pred_z0,
            source_latents=source_latents,
            potential=potential,
            step_i=step_i,
            decoder=decoder,
        )
        grads_latent = torch.abs(grads_latent)
        grads_latent = transforms.GaussianBlur(
            Const.GAUSS_K_GRADS, sigma=Const.GAUSS_SIGMA_GRADS
        )(grads_latent)
        # Standardize; the 1e-6 floor guards a zero-variance gradient map.
        grads_latent = (grads_latent - grads_latent.mean()) / max(
            grads_latent.std(), 1e-6
        )
        grads_latent = torch.maximum(grads_latent, torch.tensor(0.0).to(self.device))
        self.set_cur_masks(
            step_i=step_i, grads_to_update=grads_latent, surround_ring=surround_ring
        )
        return self.get_curr_masks(return_only=return_only)

    def calc_grads(self, latent_pred_z0, source_latents, potential, step_i, decoder):
        """Gradient of the Alpha-CLIP loss w.r.t. the (thresholded) latent mask.

        Blends predicted and source latents by the mask, decodes the blend,
        and evaluates the prompt loss with a dilated alpha mask; autograd then
        gives d(loss)/d(mask). Also records the loss in `closs_hist`.
        Returns a detached float64 gradient tensor.
        """
        with torch.enable_grad():
            latent_mask = self.get_threshed_mask(potential)
            latent_mask = latent_mask.detach().requires_grad_()
            blend_predz0_origz0 = latent_pred_z0 * latent_mask + (
                source_latents * (1 - latent_mask)
            )
            # 0.18215 — presumably the SD VAE latent scaling factor; verify
            # against the decoder's config.
            scaled_blend_pred_z0_origz0 = 1 / 0.18215 * blend_predz0_origz0
            decoded_blend_predz0_origz0 = decoder(
                scaled_blend_pred_z0_origz0
            ).sample.to(torch.float32)
            # Upsample the mask to image resolution (interpolation=0: nearest).
            alpha_mask = transforms.Resize(self.decoded_size, interpolation=0)(
                latent_mask
            )
            alpha_mask = (alpha_mask > 0.5).half().clone().detach()
            alpha_mask = get_surround(
                alpha_mask,
                Const.ALPHA_MASK_DILATION_ON_512 * (Const.HW / 512.0),
                self.device,
            )
            alpha_loss = self.alpha_clip_loss(
                decoded_blend_predz0_origz0,
                alpha_mask,
                self.text_features,
                self.image_augmentations,
                augs_with_orig=True,
            )
            self.closs_hist[
                step_i - 1
            ] = alpha_loss.detach()  # The mask used for the loss is prev step mask
            grads_latent = torch.autograd.grad(alpha_loss, latent_mask)[0].to(
                torch.float64
            )
        return grads_latent.detach()

    def alpha_clip_loss(
        self,
        image,
        mask,
        text_features,
        image_augmentations,
        augs_with_orig=True,
        return_as_similarity=False,
    ):
        """Alpha-CLIP loss (or similarity) between the masked image and prompt.

        image and mask in range 0.0 to 1.0
        (the image is expected in [-1, 1] on entry and is remapped below).

        When `image_augmentations` is given, image/alpha pairs are augmented
        jointly (optionally keeping the originals) before normalization;
        otherwise they are resized/pooled to the Alpha-CLIP input size.
        Returns mean(1 - cosine similarity), or the similarity itself when
        `return_as_similarity` is True.
        """
        assert mask.min() >= 0 and mask.max() <= 1
        # Alpha channel normalization uses mean 0.5 / std 0.26 (Alpha-CLIP
        # convention); image uses the standard CLIP normalization stats.
        mask_transform = transforms.Compose(
            [nn.AdaptiveAvgPool2d(self.ac_size), transforms.Normalize(0.5, 0.26)]
        )
        mask_normalize = transforms.Normalize(0.5, 0.26)
        image_transform = transforms.Compose(
            [
                transforms.Resize(self.ac_size, interpolation=Image.BICUBIC),
                transforms.Normalize(
                    (0.48145466, 0.4578275, 0.40821073),
                    (0.26862954, 0.26130258, 0.27577711),
                ),
            ]
        )
        image_normalize = transforms.Normalize(
            (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
        )
        # Remap image from [-1, 1] to [0, 1].
        image = image.add(1).div(2)
        if image.ndim == 3:
            image = image.unsqueeze(0)
        alpha = mask
        if alpha.ndim == 3:
            alpha = alpha.unsqueeze(dim=0)
        if image_augmentations is not None:
            image, alpha = image_augmentations(image, alpha, with_orig=augs_with_orig)
            image = image_normalize(image).half()
            alpha = mask_normalize(alpha).half()
        else:
            image = image_transform(image).half()
            alpha = mask_transform(alpha).half()
        image_features = self.ac_model.visual(image, alpha)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        if return_as_similarity:
            alpha_loss = image_features @ text_features.T
        else:
            alpha_loss = 1 - image_features @ text_features.T
        # Average over the augmentation batch.
        alpha_loss = alpha_loss.mean(dim=0)
        return alpha_loss

    def get_text_features(self, prompt):
        """Encode a list/tuple of prompt strings into L2-normalized CLIP
        text features on the model's device."""
        assert type(prompt) in (list, tuple)
        text = alpha_clip.tokenize(prompt).to(self.device)
        text_features = self.ac_model.encode_text(text)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features

    @torch.no_grad()
    def get_bias(self, step_i):
        """Step-decayed dilation bias added to the base potential.

        Decays geometrically with the step; additionally shrunk (x0.9 loop)
        while it would push the whole field positive, i.e. make the mask
        cover everything.
        """
        bias = Const.BIAS_DILATION_VAL * (Const.BIAS_DILATION_DEC_FACTOR**step_i)
        while torch.all(self.base_potential + bias > 0) and bias > 1e-8:
            bias *= 0.9
        return bias

    def get_threshed_mask(self, potential):
        """Binarize the potential and clean the result into a single solid mask.

        Pipeline: threshold -> fill holes -> bridge disconnected components
        (close_gaps_with_connection) -> morphological closing -> fill holes
        again -> Gaussian blur + re-threshold to smooth jagged edges.
        Returns a (1, 1, H, W) half tensor of {0, 1}.
        """
        thresh_val = self.thresh_val
        t_m = (potential > thresh_val).half()
        t_m = t_m.cpu().numpy().squeeze().astype(np.uint8)
        t_m = scipy.ndimage.binary_fill_holes(t_m)
        t_m = torch.tensor(t_m).to(self.device).unsqueeze(0).unsqueeze(0).half()
        t_m = self.close_gaps_with_connection(
            t_m, thickness=Const.CLOSE_GAPS_WITH_CONNECTION_THICKNESS
        )
        t_m = kornia.morphology.closing(
            t_m, torch.ones(Const.CLOSING_K, Const.CLOSING_K).to(self.device)
        )
        t_m = t_m.cpu().numpy().squeeze().astype(np.uint8)
        t_m = scipy.ndimage.binary_fill_holes(t_m)
        t_m = torch.tensor(t_m).to(self.device).unsqueeze(0).unsqueeze(0).half()
        t_m = transforms.GaussianBlur(
            Const.GAUSS_K_THRESHED, sigma=Const.GAUSS_SIGMA_THRESHED
        )(t_m)
        t_m = (t_m > Const.THRESH_POST_GAUSS).half()
        return t_m

    @torch.no_grad()
    def close_gaps_with_connection(self, threshed_mask, thickness):
        # also cleans small contours
        """Merge disconnected mask components into one connected region.

        Components smaller than 0.1% of the grid are dropped. Each remaining
        secondary component is drawn filled, then connected to the largest
        component by drawing its simplified convex hull augmented with the
        nearest points on the main contour, both as an outline of the given
        `thickness` and filled. Returns a (1, 1, H, W) half {0, 1} tensor;
        if the mask is already a single contour it is returned unchanged.
        """
        given_threshed_mask = threshed_mask
        threshed_mask = threshed_mask.cpu().numpy().squeeze().astype(np.uint8)
        connected_mask = threshed_mask * 0
        contours, hierarchy = cv2.findContours(
            threshed_mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE
        )
        if len(contours) == 1:
            return given_threshed_mask
        # Largest first; drop components below 0.1% of the grid area.
        contours = sorted(contours, key=lambda x: cv2.contourArea(x), reverse=True)
        contours = [
            cnt
            for cnt in contours
            if cv2.contourArea(cnt)
            > threshed_mask.shape[-1] * threshed_mask.shape[-2] * 0.001
        ]
        cv2.drawContours(connected_mask, contours, 0, 255, -1)
        for i in range(1, len(contours)):
            cv2.drawContours(connected_mask, contours, i, 255, -1)
            hull = cv2.convexHull(contours[i])  # Convex hull of contour
            hull = cv2.approxPolyDP(hull, 0.1 * cv2.arcLength(hull, True), True)
            connect = hull.copy()
            # Extend the hull polygon with the closest points on the main
            # contour so drawing it bridges the gap between the components.
            for hp in hull:
                dists = np.linalg.norm(contours[0] - hp, axis=2).squeeze()
                min_points = np.where(dists == dists.min())[0]
                for mp in min_points:
                    connect = np.append(
                        connect, np.expand_dims(contours[0][mp], axis=0), axis=0
                    )
            connected_mask = cv2.drawContours(
                connected_mask, [connect], -1, color=255, thickness=thickness
            )
            connected_mask = cv2.drawContours(
                connected_mask, [connect], -1, color=255, thickness=-1
            )
        connected_mask = (
            ((torch.tensor(connected_mask).to(self.device)) > 125)
            .unsqueeze(0)
            .unsqueeze(0)
            .half()
        )
        return connected_mask

    @torch.no_grad()
    def get_plain_dilated_latent_mask(
        self,
        last_step_latent_mask,
        step_i,
        total_steps,
        max_area_ratio_for_dilation=None,
        rerun_dyn_start_step_i=None,
    ):
        """Dilate the final mask by a per-step schedule of shrinking widths.

        Used after mask evolution stops: early steps get a wide dilation that
        decays to zero toward the end (the last 10 steps, or after
        RERUN_STOP_DILATION on a rerun). Masks already covering more than
        `max_area_ratio_for_dilation` of the grid are returned untouched.
        The initial width is capped so the dilated mask stays under 75% of
        the latent area. Returns a half tensor.
        """
        max_area_ratio_for_dilation = (
            Const.MAX_AREA_RATIO_FOR_DILATION
            if max_area_ratio_for_dilation is None
            else max_area_ratio_for_dilation
        )
        if (
            last_step_latent_mask.sum()
            > max_area_ratio_for_dilation * last_step_latent_mask.nelement()
        ):
            return last_step_latent_mask
        first_k = self.latent_size[-1] // 2
        while (
            get_surround(last_step_latent_mask, first_k, self.device).sum()
            > 0.75 * self.latent_size[-1] ** 2
        ):
            first_k -= 1
        if rerun_dyn_start_step_i:
            # Rerun: linear ramp down to zero, zero-padded to total_steps.
            plain_dilation_ws = np.linspace(
                first_k, 0, rerun_dyn_start_step_i + 2 - Const.RERUN_STOP_DILATION
            ).round()
            plain_dilation_ws = np.pad(
                plain_dilation_ws, (0, total_steps - len(plain_dilation_ws))
            )
        else:
            # First run: harmonic-like decay first_k / (i/3), zeroed at the end.
            plain_dilation_ws = np.array(
                [first_k / max(1, (i / 3)) for i in range(0, total_steps)]
            ).round()
            plain_dilation_ws[-10:] = 0
        return get_surround(
            last_step_latent_mask, plain_dilation_ws[step_i], self.device
        ).half()

    @torch.no_grad()
    def get_ring(self, latent_mask):
        """Binary ring around the mask boundary where updates are allowed.

        Union of: an inner band (signed distance in (-IN_ON_RING_WIDTH, -1]),
        the boundary pixels themselves (distance >= -1 inside the mask), and
        an outer band (distance in (0, OUT_RING_WIDTH]). Returns a uint8
        tensor on device with the mask's original (1, 1, H, W-like) shape.
        """
        assert (latent_mask.min() >= 0) and (latent_mask.max() <= 1)
        out_ring_width = Const.OUT_RING_WIDTH
        in_on_ring_width = Const.IN_ON_RING_WIDTH
        latent_mask = (latent_mask.cpu().numpy() >= 0.5).astype(np.float16)
        dists = get_dist_field(latent_mask, self.device, as_squeezed_np=True)
        in_ring_width = in_on_ring_width - 1
        # Inner band: strictly inside the mask, within in_ring_width of the edge.
        in_ring = dists.copy()
        in_ring[in_ring > -1] = 0
        in_ring[in_ring <= -in_ring_width - 1] = 0
        in_ring[in_ring != 0] = 1
        # Boundary pixels: mask pixels at distance >= -1.
        on_ring = latent_mask.copy()
        on_ring[dists < -1] = 0
        in_on_ring = in_ring.astype(bool) | on_ring.astype(bool)
        # Outer band: strictly outside, within out_ring_width of the edge.
        out_ring = dists.copy()
        out_ring[out_ring <= 0] = 0
        out_ring[out_ring > out_ring_width] = 0
        out_ring[out_ring != 0] = 1
        surround_ring = in_on_ring.astype(np.uint8) | out_ring.astype(np.uint8)
        surround_ring = torch.tensor(surround_ring).to(self.device)
        return surround_ring