# click2mask/ldm/image_reconstruction.py
# Repository page header: "omeregev's picture" - "Initial commit" - 6df18f5
import argparse
import copy
import os
from pathlib import Path
from PIL import Image
import lpips
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from general_utils.seamless_cloning import poisson_seamless_clone
from omegaconf import OmegaConf
from torch import optim
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision.transforms import Compose, Resize, ToTensor
from torchvision.utils import make_grid
from ldm.image_editor import load_model_from_config, read_image, read_mask
from ldm.models.diffusion.ddpm import LatentDiffusion
class ImagesDataset(Dataset):
    """Dataset over the image files stored directly inside one directory.

    Each item is opened with PIL, converted to RGB, passed through
    ``transform`` and finally mapped with ``x * 2.0 - 1.0``.
    """

    def __init__(self, source_path, transform, indices=None):
        """Collect the sorted file names found in ``source_path``.

        Args:
            source_path: Directory that holds the images.
            transform: Callable applied to every PIL image. Its output is
                rescaled with ``* 2.0 - 1.0`` (presumably it yields values in
                [0, 1], e.g. ``ToTensor`` - confirm against the caller).
            indices: Optional positions into the sorted file listing; when
                given, only those files are kept, in the given order.
                Negative positions follow normal Python indexing.

        Raises:
            IndexError: If an entry of ``indices`` is outside the listing.
        """
        self.source_path = Path(source_path)
        # Keep regular files only: a sub-directory can never be opened as an
        # image and would also shift the positions ``indices`` refers to.
        self.img_names = sorted(
            name
            for name in os.listdir(self.source_path)
            if (self.source_path / name).is_file()
        )

        if indices is not None:
            total = len(self.img_names)
            selected = []
            for index in indices:
                if not -total <= index < total:
                    raise IndexError(
                        f"Image index {index} is out of range for the {total} "
                        f"file(s) found in '{self.source_path}'"
                    )
                selected.append(self.img_names[index])
            self.img_names = selected

        self.transform = transform

    def __len__(self):
        """Return the number of selected image files."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transform`` and rescale it."""
        image = Image.open(self.source_path / self.img_names[idx]).convert("RGB")
        tensor_image = self.transform(image)
        return tensor_image * 2.0 - 1.0
class ImageReconstruction:
    """Reconstruct the background of edited images and save the results.

    Constructing an instance runs the whole pipeline: it parses the
    command-line options, loads the latent-diffusion model, reads the source
    image and mask, reconstructs every image found under
    ``<images_path>/images`` and writes the outputs back under ``images_path``.
    """

    def __init__(
        self,
        verbose: bool = False,
    ):
        """Run the reconstruction pipeline.

        Args:
            verbose: When true, the optimization loop prints the loss at every
                step and plots intermediate results every 25 steps.
        """
        self.opt = self.get_arguments()
        config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-eval.yaml")

        self.device = (
            torch.device(f"cuda:{self.opt.gpu_id}")
            if torch.cuda.is_available()
            else torch.device("cpu")
        )
        self.model = load_model_from_config(
            config=config, ckpt="models/ldm/text2img-large/model.ckpt", device=self.device
        )
        self.model = self.model.to(self.device)

        img_size = (self.opt.W, self.opt.H)
        # The mask is also read at 1/8 of the image resolution (presumably the
        # latent resolution of the first-stage model - confirm).
        mask_size = (self.opt.W // 8, self.opt.H // 8)
        self.init_image = read_image(
            img_path=self.opt.init_image, device=self.device, dest_size=img_size
        )
        self.mask, self.org_mask = read_mask(
            mask_path=self.opt.mask, device=self.device, dest_size=mask_size, img_size=img_size
        )
        if self.opt.invert_mask:
            self.mask = 1 - self.mask
            self.org_mask = 1 - self.org_mask

        self.verbose = verbose
        # NOTE: an LPIPS-based loss (lpips.LPIPS(net="vgg")) was left disabled
        # by the original author; only the MSE loss is used.

        samples_dataset = ImagesDataset(
            source_path=os.path.join(self.opt.images_path, "images"),
            transform=ToTensor(),
            indices=self.opt.selected_indices,
        )
        reconstructed_samples = self._reconstruct_background(samples_dataset)
        self._save_visualization(reconstructed_samples)

    def get_arguments(self, args=None):
        """Parse the command-line options.

        Args:
            args: Optional list of argument strings. ``None`` (the default)
                makes argparse read ``sys.argv[1:]``.

        Returns:
            The parsed ``argparse.Namespace``.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument("--init_image", type=str, default="", help="a source image to edit")
        parser.add_argument("--mask", type=str, default="", help="a mask to edit the image")
        parser.add_argument(
            "--invert_mask",
            help="Indicator enabling inverting the input mask",
            action="store_true",
            dest="invert_mask",
        )
        parser.add_argument(
            "--images_path",
            type=str,
            default="outputs/edit_results/samples/",
            help="The path for the images to reconstruct",
        )
        parser.add_argument(
            "--H",
            type=int,
            default=256,
            help="image height, in pixel space",
        )
        parser.add_argument(
            "--W",
            type=int,
            default=256,
            help="image width, in pixel space",
        )
        # NOTE: this option is not read by any method of this class.
        parser.add_argument(
            "--batch_size",
            type=int,
            default=16,
            help="The batch size",
        )
        parser.add_argument(
            "--optimization_steps",
            type=int,
            default=75,
            help="The number of optimization steps in case of optimization",
        )
        parser.add_argument(
            "--reconstruction_type",
            type=str,
            help="The background reconstruction type",
            default="optimization",
            choices=["optimization", "pixel", "poisson"],
        )
        parser.add_argument(
            "--optimization_mode",
            type=str,
            help="The optimization mode in case of optimization reconstruction type",
            default="weights",
            choices=["weights", "latents"],
        )
        parser.add_argument(
            "--selected_indices",
            type=int,
            nargs="+",
            default=None,
            help="The indices to reconstruct, if not given - will reconstruct all the images",
        )

        # Misc
        parser.add_argument(
            "--gpu_id",
            type=int,
            default=0,
            help="The GPU specific id",
        )

        return parser.parse_args(args)

    def _reconstruct_background(self, samples):
        """Reconstruct the background of every sample and stack the results.

        Args:
            samples: Iterable of image tensors without a batch axis (the
                dataset built in ``__init__`` yields values mapped with
                ``x * 2 - 1``).

        Returns:
            All reconstructed images concatenated along the first axis.

        Raises:
            ValueError: If the reconstruction type is unknown or ``samples``
                yields no image.
        """
        reconstructed_samples = []

        if self.opt.reconstruction_type == "pixel":
            # Plain composite: sample inside the mask, source image outside.
            mask = self.org_mask[0]
            background = self.init_image * (1 - mask)  # loop invariant
            for sample in samples:
                sample = sample.to(self.device) * mask + background
                sample = torch.clamp((sample + 1.0) / 2.0, min=0.0, max=1.0)
                reconstructed_samples.append(sample)
        elif self.opt.reconstruction_type == "poisson":
            mask_numpy = self.org_mask.squeeze().cpu().numpy()
            init_image_numpy = rearrange(
                ((self.init_image + 1) / 2).squeeze().cpu().numpy(), "c h w -> h w c"
            )
            for sample in samples:
                sample = torch.clamp((sample + 1.0) / 2.0, min=0.0, max=1.0)
                curr_sample = rearrange(sample.cpu().numpy(), "c h w -> h w c")
                cloned_sample = poisson_seamless_clone(
                    source_image=curr_sample,
                    destination_image=init_image_numpy,
                    mask=mask_numpy,
                )
                # Back to a batched, channels-first tensor.
                cloned_sample = torch.from_numpy(
                    cloned_sample[np.newaxis, ...].transpose(0, 3, 1, 2)
                ).to(self.device)
                reconstructed_samples.append(cloned_sample)
        elif self.opt.reconstruction_type == "optimization":
            for sample in samples:
                optimized_sample = self.reconstruct_image_by_optimization(
                    fg_image=sample.to(self.device).unsqueeze(0),
                    bg_image=self.init_image,
                    mask=self.org_mask,
                )
                optimized_sample = torch.clamp(optimized_sample, min=0.0, max=1.0)
                reconstructed_samples.append(optimized_sample)
        else:
            raise ValueError("Missing reconstruction type")

        if not reconstructed_samples:
            # torch.cat rejects an empty list with a far less helpful message.
            raise ValueError("No images were found to reconstruct")

        return torch.cat(reconstructed_samples)

    def loss(
        self,
        fg_image: torch.Tensor,
        bg_image: torch.Tensor,
        curr_latent: torch.Tensor,
        mask: torch.Tensor,
        preservation_ratio: float = 100,
    ):
        """Return the masked reconstruction loss of the decoded latent.

        The loss is the MSE between ``fg_image`` and the decoded latent inside
        the mask, plus ``preservation_ratio`` times the MSE between
        ``bg_image`` and the decoded latent outside the mask.

        Args:
            fg_image: Image to match where ``mask`` is set.
            bg_image: Image to match where ``mask`` is not set.
            curr_latent: Latent decoded with the model's first stage.
            mask: Weights of the foreground term; ``1 - mask`` weights the
                background term.
            preservation_ratio: Weight of the background term.
        """
        curr_reconstruction = self.model.decode_first_stage(curr_latent)
        foreground_term = F.mse_loss(fg_image * mask, curr_reconstruction * mask)
        background_term = F.mse_loss(bg_image * (1 - mask), curr_reconstruction * (1 - mask))
        return foreground_term + background_term * preservation_ratio

    @torch.no_grad()
    def get_curr_reconstruction(self, curr_latent):
        """Decode ``curr_latent`` and map the result from [-1, 1] to [0, 1] (clamped)."""
        curr_reconstruction = self.model.decode_first_stage(curr_latent)
        curr_reconstruction = torch.clamp((curr_reconstruction + 1.0) / 2.0, min=0.0, max=1.0)
        return curr_reconstruction

    @torch.no_grad()
    def plot_reconstructed_image(self, curr_latent, fg_image, bg_image, mask):
        """Show background, foreground, naive composite and current decode side by side.

        Only the first element of each batch is plotted.
        """
        curr_reconstruction = self.get_curr_reconstruction(curr_latent=curr_latent)
        curr_reconstruction = curr_reconstruction[0].cpu().numpy().transpose(1, 2, 0)

        fg_image = torch.clamp((fg_image + 1.0) / 2.0, min=0.0, max=1.0)
        fg_image = fg_image[0].cpu().numpy().transpose(1, 2, 0)
        bg_image = torch.clamp((bg_image + 1.0) / 2.0, min=0.0, max=1.0)
        bg_image = bg_image[0].cpu().numpy().transpose(1, 2, 0)
        mask = mask[0].detach().cpu().numpy().transpose(1, 2, 0)
        composed = fg_image * mask + bg_image * (1 - mask)

        plt.imshow(np.hstack([bg_image, fg_image, composed, curr_reconstruction]))
        plt.axis("off")
        plt.tight_layout()
        plt.show()

    def _save_visualization(self, samples, images_per_row: int = 6):
        """Save every sample and a summary grid image.

        When a source image is available, the grid starts with a row holding
        the source image, its first-stage encode/decode round trip, the
        full-resolution mask and the resized low-resolution mask, padded with
        blank images up to ``images_per_row``.

        Args:
            samples: Batch of reconstructed images with values in [0, 1].
            images_per_row: Number of images per grid row.
        """
        self._save_images(samples)

        # Add source image and mask to visualization
        if self.init_image is not None:
            blank_image = torch.ones_like(self.init_image)
            # Work on locals so that the stored masks keep their shape and the
            # method can be called more than once.
            if self.mask is None:
                org_mask = blank_image
                resized_mask = blank_image
            else:
                org_mask = self.org_mask.repeat((1, 3, 1, 1))
                resized_mask = F.interpolate(self.mask, size=(self.opt.H, self.opt.W))
                resized_mask = resized_mask.repeat((1, 3, 1, 1))

            # Visualization only: no autograd graph is needed.
            with torch.no_grad():
                encoder_posterior = self.model.encode_first_stage(self.init_image)
                encoder_result = self.model.get_first_stage_encoding(encoder_posterior)
                reconstructed_image = self.model.decode_first_stage(encoder_result)
            reconstructed_image = torch.clamp((reconstructed_image + 1.0) / 2.0, min=0.0, max=1.0)

            inputs_row = [
                (self.init_image + 1) / 2,
                reconstructed_image,
                org_mask,
                resized_mask,
            ]
            pad_row = [blank_image for _ in range(images_per_row - len(inputs_row))]
            samples = torch.cat([torch.cat(inputs_row + pad_row), samples])

        grid = make_grid(samples, nrow=images_per_row)
        grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
        Image.fromarray(grid.astype(np.uint8)).save(
            os.path.join(self.opt.images_path, f"reconstructed_{self.opt.reconstruction_type}.png")
        )

    def _save_images(self, samples):
        """Write each sample as ``<images_path>/reconstructed_<type>/<index>.png``.

        Args:
            samples: Iterable of channels-first images; values are scaled by
                255 and truncated to ``uint8``.
        """
        samples_dir = os.path.join(
            self.opt.images_path,
            f"reconstructed_{self.opt.reconstruction_type}",
        )
        os.makedirs(samples_dir, exist_ok=True)
        for i, sample in enumerate(samples):
            pixels = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
            Image.fromarray(pixels.astype(np.uint8)).save(os.path.join(samples_dir, f"{i:04}.png"))

    def reconstruct_image_by_optimization(
        self, fg_image: torch.Tensor, bg_image: torch.Tensor, mask: torch.Tensor
    ):
        """Blend ``fg_image`` into ``bg_image`` by minimizing :meth:`loss`.

        In "weights" mode the first-stage decoder is tuned with Adam
        (lr=0.0001) while the latent stays fixed; otherwise the latent itself
        is optimized (lr=0.1). The original decoder is always put back before
        returning, also when the optimization raises.

        Args:
            fg_image: Batched image to keep inside the mask.
            bg_image: Batched image to keep outside the mask.
            mask: Mask passed unchanged to :meth:`loss`.

        Returns:
            The decoded result, mapped to [0, 1].
        """
        encoder_posterior = self.model.encode_first_stage(fg_image)
        initial_latent = self.model.get_first_stage_encoding(encoder_posterior)

        optimize_weights = self.opt.optimization_mode == "weights"
        decoder_copy = None
        if optimize_weights:
            curr_latent = initial_latent.clone().detach()
            # Snapshot taken before enabling gradients, so the restored
            # decoder also gets its original requires_grad flags back.
            decoder_copy = copy.deepcopy(self.model.first_stage_model.decoder)
            self.model.first_stage_model.decoder.requires_grad_(True)
            optimizer = optim.Adam(self.model.first_stage_model.decoder.parameters(), lr=0.0001)
        else:
            curr_latent = initial_latent.clone().detach().requires_grad_(True)
            optimizer = optim.Adam([curr_latent], lr=0.1)

        try:
            for i in tqdm(range(self.opt.optimization_steps), desc="Reconstruction optimization"):
                if self.verbose and i % 25 == 0:
                    self.plot_reconstructed_image(
                        curr_latent=curr_latent,
                        fg_image=fg_image,
                        bg_image=bg_image,
                        mask=mask,
                    )

                optimizer.zero_grad()
                loss = self.loss(
                    fg_image=fg_image, bg_image=bg_image, curr_latent=curr_latent, mask=mask
                )

                if self.verbose:
                    print(f"Iteration {i}: Curr loss is {loss}")

                loss.backward()
                optimizer.step()

            reconstructed_result = self.get_curr_reconstruction(curr_latent=curr_latent)
        finally:
            if optimize_weights:
                # Drop the tuned decoder, then re-attach the pristine copy so
                # the next image starts from the original weights.
                self.model.first_stage_model.decoder = None
                self.model.first_stage_model.decoder = decoder_copy

        return reconstructed_result