from typing import Optional

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from diffusers import DDPMScheduler
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps

from pipeline import Zero123PlusPipeline
from utils import add_white_bg, load_z123_pipe


class VAEProcessor:
    """A helper class to handle encoding and decoding images with the VAE."""

    def __init__(self, pipeline: Zero123PlusPipeline):
        self.pipe = pipeline
        self.image_processor = pipeline.image_processor
        self.vae = pipeline.vae

        # Zero123++ normalization constants: latents are shifted and scaled
        # after VAE encoding, images are rescaled before it, and decode()
        # inverts both transforms.
        self.latent_shift_factor = 0.22
        self.latent_scale_factor = 0.75
        self.image_scale_factor = 0.5 / 0.8

    def encode(self, image: Image.Image) -> torch.Tensor:
        """Encodes a PIL image into the latent space."""
        image_tensor = self.image_processor.preprocess(image).to(self.vae.device).half()
        with torch.autocast("cuda"), torch.inference_mode():
            image_tensor *= self.image_scale_factor
            # Take the mode of the latent distribution for a deterministic encoding.
            denorm = self.vae.encode(image_tensor).latent_dist.mode()
            denorm *= self.vae.config.scaling_factor
            return (denorm - self.latent_shift_factor) * self.latent_scale_factor

    def decode(self, latents: torch.Tensor) -> list[Image.Image]:
        """Decodes latents back into a list of post-processed PIL images."""
        with torch.autocast("cuda"), torch.inference_mode():
            denorm = latents / self.latent_scale_factor + self.latent_shift_factor
            image = self.vae.decode(denorm / self.vae.config.scaling_factor, return_dict=False)[0]
            image /= self.image_scale_factor
            return self.image_processor.postprocess(image)
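

# A minimal round-trip sketch of VAEProcessor usage, assuming `pipe` is an
# already-loaded Zero123PlusPipeline; the file names are hypothetical:
#
#   processor = VAEProcessor(pipe)
#   latents = processor.encode(Image.open("mv_grid.png"))  # image -> latents
#   images = processor.decode(latents)                     # latents -> [PIL image]
#   images[0].save("mv_grid_roundtrip.png")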


class EditAwareDenoiser:
    """Encapsulates the entire Edit-Aware Denoising process."""

    def __init__(self, pipe: Zero123PlusPipeline, scheduler: DDPMScheduler, T_steps: int,
                 src_gs: float, tar_gs: float, n_max: int):
        """Initializes the denoiser with the pipeline and configuration."""
        self.pipe = pipe
        self.scheduler = scheduler
        self.T_steps = T_steps
        self.src_guidance_scale = src_gs
        self.tar_guidance_scale = tar_gs
        self.n_max = n_max

    @staticmethod
    def _mix_cfg(cond: torch.Tensor, uncond: torch.Tensor, cfg: float) -> torch.Tensor:
        """Mixes conditional and unconditional predictions via classifier-free guidance."""
        return uncond + cfg * (cond - uncond)

    def _get_differential_edit_direction(self, t: torch.Tensor, zt_src: torch.Tensor, zt_tar: torch.Tensor) -> torch.Tensor:
        """Computes the differential edit direction (delta v) for a timestep."""
        # Share one noise draw between the source and target condition latents so
        # the two predictions differ only through the conditioning image.
        condition_noise = torch.randn_like(self.src_cond_lat)

        noisy_src_cond_lat = self.pipe.scheduler.scale_model_input(
            self.pipe.scheduler.add_noise(self.src_cond_lat, condition_noise, t), t
        )
        vt_src_uncond, vt_src_cond = self._calc_v_zero(self.src_cond_img, zt_src, t, noisy_src_cond_lat)
        vt_src = self._mix_cfg(vt_src_cond, vt_src_uncond, self.src_guidance_scale)

        noisy_tar_cond_lat = self.pipe.scheduler.scale_model_input(
            self.pipe.scheduler.add_noise(self.tar_cond_lat, condition_noise, t), t
        )
        vt_tar_uncond, vt_tar_cond = self._calc_v_zero(self.tar_cond_img, zt_tar, t, noisy_tar_cond_lat)
        vt_tar = self._mix_cfg(vt_tar_cond, vt_tar_uncond, self.tar_guidance_scale)

        return vt_tar - vt_src
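
    # The propagation step below, written out (notation introduced here, not in
    # the original code): with one shared forward noise draw eps per step,
    #   z_t^src = scale(add_noise(x_src,  eps, t)),
    #   z_t^tar = scale(add_noise(z_edit, eps, t)),
    #   z_edit <- z_edit + dt * (v_tar(z_t^tar) - v_src(z_t^src)),
    # where dt = t_{i+1} - t_i < 0 because the timesteps decrease.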

    def _propagate_for_timestep(self, zt_edit: torch.Tensor, t: torch.Tensor, dt: torch.Tensor) -> torch.Tensor:
        """Performs a single propagation step for the edit."""
        # Noise the source and the current edit with the same forward noise so
        # their predicted velocities are directly comparable.
        fwd_noise = torch.randn_like(self.x_src)
        zt_src = self.scheduler.scale_model_input(self.scheduler.add_noise(self.x_src, fwd_noise, t), t)
        zt_tar = self.scheduler.scale_model_input(self.scheduler.add_noise(zt_edit, fwd_noise, t), t)

        diff_v = self._get_differential_edit_direction(t, zt_src, zt_tar)

        # Accumulate the Euler update in float32 to limit half-precision drift.
        zt_edit_change = dt * diff_v
        zt_edit = zt_edit.to(torch.float32) + zt_edit_change
        return zt_edit.to(diff_v.dtype)

    def _calc_v_zero(self, condition_image: Image.Image, noisy_latent: torch.Tensor, t: torch.Tensor, noised_condition: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Calculates the unconditional and conditional v-prediction from the UNet."""
        # Any guidance scale > 1 makes the pipeline run a CFG-batched UNet pass;
        # the hook below captures the stacked (uncond, cond) prediction, and the
        # actual guidance mixing happens later in _mix_cfg.
        DUMMY_GUIDANCE_SCALE = 2
        model_output = {}

        def hook_fn(module, args, output):
            model_output['v_pred'] = output[0]

        hook_handle = self.pipe.unet.register_forward_hook(hook_fn)

        try:
            self.pipe(
                condition_image,
                latents=noisy_latent,
                num_inference_steps=1,
                guidance_scale=DUMMY_GUIDANCE_SCALE,
                timesteps=[t.item()],
                output_type="latent",
                noisy_cond_lat=noised_condition,
            )
        finally:
            hook_handle.remove()

        # The batch is ordered (unconditional, conditional).
        return model_output['v_pred'].chunk(2)

    @torch.no_grad()
    def denoise(self, x_src: torch.Tensor, src_cond_img: Image.Image, tar_cond_img: Image.Image) -> torch.Tensor:
        """Public method to run the entire denoising process."""
        self.x_src = x_src
        self.src_cond_img = src_cond_img
        self.tar_cond_img = tar_cond_img

        timesteps, _ = retrieve_timesteps(self.scheduler, self.T_steps, self.x_src.device)
        zt_edit = self.x_src.clone()

        self.src_cond_lat = self.pipe.make_condition_lat(self.src_cond_img, guidance_scale=2.0)
        self.tar_cond_lat = self.pipe.make_condition_lat(self.tar_cond_img, guidance_scale=2.0)

        # Skip the noisiest steps and edit only over the last n_max timesteps.
        start_index = max(0, len(timesteps) - self.n_max)

        for i in tqdm(range(start_index, len(timesteps))):
            t = timesteps[i]
            # Normalize timesteps to [0, 1]; the final step integrates down to t = 0.
            t_i = t / 1000.0
            t_im1 = timesteps[i + 1] / 1000.0 if i + 1 < len(timesteps) else torch.zeros_like(t_i)
            dt = t_im1 - t_i

            zt_edit = self._propagate_for_timestep(zt_edit, t, dt)

        return zt_edit


def run_editp23(
    src_condition_path: str,
    tgt_condition_path: str,
    original_mv: str,
    save_path: str,
    device_number: int = 0,
    T_steps: int = 50,
    n_max: int = 31,
    src_guidance_scale: float = 3.5,
    tar_guidance_scale: float = 5.0,
    seed: int = 18,
    pipeline: Optional[Zero123PlusPipeline] = None,
) -> None:
    """Main execution function to run the complete editing pipeline."""
    if pipeline is None:
        pipeline = load_z123_pipe(device_number)

    # Seed both PyTorch and NumPy so the noise draws are reproducible.
    torch.manual_seed(seed)
    np.random.seed(seed)

    vae_processor = VAEProcessor(pipeline)

    # Composite each input over a white background, then encode the source
    # multi-view grid into latent space.
    src_cond_img = add_white_bg(Image.open(src_condition_path))
    tgt_cond_img = add_white_bg(Image.open(tgt_condition_path))
    mv_src = add_white_bg(Image.open(original_mv))
    x0_src = vae_processor.encode(mv_src)

    denoiser = EditAwareDenoiser(
        pipe=pipeline,
        scheduler=pipeline.scheduler,
        T_steps=T_steps,
        src_gs=src_guidance_scale,
        tar_gs=tar_guidance_scale,
        n_max=n_max,
    )
    x0_tar = denoiser.denoise(x0_src, src_cond_img, tgt_cond_img)

    # decode() returns a list of PIL images; save the first (and only) one.
    image_tar = vae_processor.decode(x0_tar)
    image_tar[0].save(save_path)
    print(f"Successfully saved result to {save_path}")