from typing import Optional

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from diffusers import DDPMScheduler
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps

from pipeline import Zero123PlusPipeline
from utils import add_white_bg, load_z123_pipe

class VAEProcessor:
    """A helper class to handle encoding and decoding images with the VAE."""

    def __init__(self, pipeline: Zero123PlusPipeline):
        self.pipe = pipeline
        self.image_processor = pipeline.image_processor
        self.vae = pipeline.vae
        # Normalization constants used by Zero123++ to map between raw VAE
        # latents and the latent distribution the UNet was trained on.
        self.latent_shift_factor = 0.22
        self.latent_scale_factor = 0.75
        self.image_scale_factor = 0.5 / 0.8

    def encode(self, image: Image.Image) -> torch.Tensor:
        """Encodes a PIL image into the latent space."""
        image_tensor = self.image_processor.preprocess(image).to(self.vae.device).half()
        with torch.autocast("cuda"), torch.inference_mode():
            image_tensor *= self.image_scale_factor
            latents = self.vae.encode(image_tensor).latent_dist.mode()
            latents *= self.vae.config.scaling_factor
            # Shift and scale into the UNet's expected latent distribution.
            return (latents - self.latent_shift_factor) * self.latent_scale_factor

    def decode(self, latents: torch.Tensor) -> list[Image.Image]:
        """Decodes latents back into a list of post-processed PIL images."""
        with torch.autocast("cuda"), torch.inference_mode():
            # Invert the shift/scale applied in encode() before decoding.
            denorm = latents / self.latent_scale_factor + self.latent_shift_factor
            image = self.vae.decode(denorm / self.vae.config.scaling_factor, return_dict=False)[0]
            image /= self.image_scale_factor
            return self.image_processor.postprocess(image)
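
# A minimal round-trip sketch of the processor above (hypothetical paths and
# device index; assumes load_z123_pipe and add_white_bg from utils):
#
#     pipe = load_z123_pipe(0)
#     proc = VAEProcessor(pipe)
#     lat = proc.encode(add_white_bg(Image.open("examples/mv.png")))
#     img = proc.decode(lat)[0]  # decode() returns a list of PIL images
#     img.save("examples/mv_roundtrip.png")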

class EditAwareDenoiser:
    """Encapsulates the entire Edit-Aware Denoising process."""

    def __init__(self, pipe: Zero123PlusPipeline, scheduler: DDPMScheduler, T_steps: int, src_gs: float, tar_gs: float, n_max: int):
        """Initializes the denoiser with the pipeline and configuration."""
        self.pipe = pipe
        self.scheduler = scheduler
        self.T_steps = T_steps
        self.src_guidance_scale = src_gs
        self.tar_guidance_scale = tar_gs
        self.n_max = n_max
    @staticmethod
    def _mix_cfg(cond: torch.Tensor, uncond: torch.Tensor, cfg: float) -> torch.Tensor:
        """Mixes conditional and unconditional predictions via classifier-free guidance."""
        return uncond + cfg * (cond - uncond)
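
    # A worked instance with hypothetical numbers: for uncond = 1.0 and
    # cond = 2.0, the mix is 1.0 + cfg * (2.0 - 1.0). So cfg = 0 returns the
    # unconditional prediction (1.0), cfg = 1 the conditional one (2.0), and
    # cfg = 3.5 extrapolates past it to 4.5: standard classifier-free guidance.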

    def _get_differential_edit_direction(self, t: torch.Tensor, zt_src: torch.Tensor, zt_tar: torch.Tensor) -> torch.Tensor:
        """Computes the differential edit direction (delta v) for a timestep."""
        # Noise both condition latents with the same sample so the source and
        # target predictions differ only in their conditioning.
        condition_noise = torch.randn_like(self.src_cond_lat)
        noisy_src_cond_lat = self.pipe.scheduler.scale_model_input(
            self.pipe.scheduler.add_noise(self.src_cond_lat, condition_noise, t), t
        )
        vt_src_uncond, vt_src_cond = self._calc_v_zero(self.src_cond_img, zt_src, t, noisy_src_cond_lat)
        vt_src = self._mix_cfg(vt_src_cond, vt_src_uncond, self.src_guidance_scale)
        noisy_tar_cond_lat = self.pipe.scheduler.scale_model_input(
            self.pipe.scheduler.add_noise(self.tar_cond_lat, condition_noise, t), t
        )
        vt_tar_uncond, vt_tar_cond = self._calc_v_zero(self.tar_cond_img, zt_tar, t, noisy_tar_cond_lat)
        vt_tar = self._mix_cfg(vt_tar_cond, vt_tar_uncond, self.tar_guidance_scale)
        return vt_tar - vt_src

    def _propagate_for_timestep(self, zt_edit: torch.Tensor, t: torch.Tensor, dt: torch.Tensor) -> torch.Tensor:
        """Performs a single propagation step for the edit."""
        # Noise the source and the current edit with the same forward noise,
        # so their difference isolates the edit rather than the noise.
        fwd_noise = torch.randn_like(self.x_src)
        zt_src = self.scheduler.scale_model_input(self.scheduler.add_noise(self.x_src, fwd_noise, t), t)
        zt_tar = self.scheduler.scale_model_input(self.scheduler.add_noise(zt_edit, fwd_noise, t), t)
        diff_v = self._get_differential_edit_direction(t, zt_src, zt_tar)
        # Euler step: accumulate in float32 for numerical stability, then
        # cast back to the model dtype.
        zt_edit_change = dt * diff_v
        zt_edit = zt_edit.to(torch.float32) + zt_edit_change
        return zt_edit.to(diff_v.dtype)

    def _calc_v_zero(self, condition_image: Image.Image, noisy_latent: torch.Tensor, t: torch.Tensor, noised_condition: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Calculates the unconditional and conditional v-prediction from the UNet."""
        # Any guidance scale > 1 makes the pipeline run the UNet on the
        # (unconditional, conditional) batch; the actual value is irrelevant
        # because we intercept the raw prediction before CFG is applied.
        DUMMY_GUIDANCE_SCALE = 2
        model_output = {}

        def hook_fn(module, args, output):
            # Capture the raw UNet output; the pipeline itself only returns
            # the final (CFG-mixed, stepped) latents.
            model_output['v_pred'] = output[0]

        hook_handle = self.pipe.unet.register_forward_hook(hook_fn)
        try:
            self.pipe(
                condition_image,
                latents=noisy_latent,
                num_inference_steps=1,
                guidance_scale=DUMMY_GUIDANCE_SCALE,
                timesteps=[t.item()],
                output_type="latent",
                noisy_cond_lat=noised_condition,
            )
        finally:
            hook_handle.remove()
        # The batch is (unconditional, conditional); split it accordingly.
        return model_output['v_pred'].chunk(2)
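
    # The hook pattern above, in isolation (a generic sketch, independent of
    # this pipeline): forward hooks receive (module, inputs, output) and can
    # stash intermediate results without modifying the wrapped call.
    #
    #     captured = {}
    #     handle = model.register_forward_hook(
    #         lambda module, args, output: captured.setdefault("out", output)
    #     )
    #     try:
    #         model(inputs)
    #     finally:
    #         handle.remove()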

    @torch.no_grad()
    def denoise(self, x_src: torch.Tensor, src_cond_img: Image.Image, tar_cond_img: Image.Image) -> torch.Tensor:
        """Runs the full edit-aware denoising process and returns the edited latents."""
        self.x_src = x_src
        self.src_cond_img = src_cond_img
        self.tar_cond_img = tar_cond_img
        timesteps, _ = retrieve_timesteps(self.scheduler, self.T_steps, self.x_src.device)
        zt_edit = self.x_src.clone()
        self.src_cond_lat = self.pipe.make_condition_lat(self.src_cond_img, guidance_scale=2.0)
        self.tar_cond_lat = self.pipe.make_condition_lat(self.tar_cond_img, guidance_scale=2.0)
        # Only the last n_max timesteps are propagated; skipping the earliest
        # (highest-noise) steps limits how far the edit drifts from the source.
        start_index = max(0, len(timesteps) - self.n_max)
        for i in tqdm(range(start_index, len(timesteps))):
            t = timesteps[i]
            # Normalize the scheduler's integer timesteps to [0, 1] so that
            # dt matches the continuous-time step size.
            t_i = t / 1000.0
            t_im1 = timesteps[i + 1] / 1000.0 if i + 1 < len(timesteps) else torch.zeros_like(t_i)
            dt = t_im1 - t_i
            zt_edit = self._propagate_for_timestep(zt_edit, t, dt)
        return zt_edit
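
# The loop in denoise() amounts to an Euler integration of the differential
# edit direction (a summary of the code above, not part of the original):
#
#     zt_edit <- zt_edit + dt * (vt_tar - vt_src),
#     dt = (t_{i+1} - t_i) / 1000  (negative, since timesteps descend)
#
# where vt_src and vt_tar are the CFG-mixed v-predictions conditioned on the
# source and target views, respectively.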

def run_editp23(
    src_condition_path: str,
    tgt_condition_path: str,
    original_mv: str,
    save_path: str,
    device_number: int = 0,
    T_steps: int = 50,
    n_max: int = 31,
    src_guidance_scale: float = 3.5,
    tar_guidance_scale: float = 5.0,
    seed: int = 18,
    pipeline: Optional[Zero123PlusPipeline] = None,
) -> None:
    """Main execution function to run the complete editing pipeline."""
    if pipeline is None:
        pipeline = load_z123_pipe(device_number)
    torch.manual_seed(seed)
    np.random.seed(seed)
    vae_processor = VAEProcessor(pipeline)
    # Composite all inputs onto a white background before encoding.
    src_cond_img = add_white_bg(Image.open(src_condition_path))
    tgt_cond_img = add_white_bg(Image.open(tgt_condition_path))
    mv_src = add_white_bg(Image.open(original_mv))
    x0_src = vae_processor.encode(mv_src)
    denoiser = EditAwareDenoiser(
        pipe=pipeline,
        scheduler=pipeline.scheduler,
        T_steps=T_steps,
        src_gs=src_guidance_scale,
        tar_gs=tar_guidance_scale,
        n_max=n_max,
    )
    x0_tar = denoiser.denoise(x0_src, src_cond_img, tgt_cond_img)
    image_tar = vae_processor.decode(x0_tar)
    image_tar[0].save(save_path)
    print(f"Successfully saved result to {save_path}")