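"""EditP23-style edit-aware denoising on top of the Zero123++ pipeline.

Encodes a source multi-view image into VAE latents, propagates an edit
defined by a source/target condition-image pair through the denoiser, and
decodes the edited latents back to an image.
"""
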
from typing import Optional

import numpy as np
import torch
from diffusers import DDPMScheduler
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
from PIL import Image
from tqdm import tqdm

from pipeline import Zero123PlusPipeline
from utils import add_white_bg, load_z123_pipe


class VAEProcessor:
    """A helper class to handle encoding and decoding images with the VAE."""

    def __init__(self, pipeline: Zero123PlusPipeline):
        self.pipe = pipeline
        self.image_processor = pipeline.image_processor
        self.vae = pipeline.vae
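        # These constants appear to mirror Zero123++'s own latent/image
        # normalization (its scale_latents/unscale_latents and scale_image
        # helpers): latents are shifted by 0.22 and scaled by 0.75, and
        # images are scaled by 0.5 / 0.8.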
        self.latent_shift_factor = 0.22
        self.latent_scale_factor = 0.75
        self.image_scale_factor = 0.5 / 0.8

    def encode(self, image: Image.Image) -> torch.Tensor:
        """Encodes a PIL image into the latent space."""
        # Cast to half precision to match the VAE weights.
        image_tensor = self.image_processor.preprocess(image).to(self.vae.device).half()
        with torch.autocast("cuda"), torch.inference_mode():
            image_tensor *= self.image_scale_factor
            denorm = self.vae.encode(image_tensor).latent_dist.mode()
            denorm *= self.vae.config.scaling_factor
            return (denorm - self.latent_shift_factor) * self.latent_scale_factor

    def decode(self, latents: torch.Tensor) -> list[Image.Image]:
        """Decodes latents back into a list of post-processed PIL images."""
        with torch.autocast("cuda"), torch.inference_mode():
            denorm = latents / self.latent_scale_factor + self.latent_shift_factor
            image = self.vae.decode(denorm / self.vae.config.scaling_factor, return_dict=False)[0]
            image /= self.image_scale_factor
            return self.image_processor.postprocess(image)
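
# A minimal VAEProcessor round trip (the pipeline object and path below are
# hypothetical placeholders):
#     processor = VAEProcessor(pipe)
#     latents = processor.encode(Image.open("multiview.png"))
#     images = processor.decode(latents)  # list of PIL images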


class EditAwareDenoiser:
    """Encapsulates the entire Edit-Aware Denoising process."""

    def __init__(self, pipe: Zero123PlusPipeline, scheduler: DDPMScheduler, T_steps: int, src_gs: float, tar_gs: float, n_max: int):
        """Initializes the denoiser with the pipeline and configuration."""
        self.pipe = pipe
        self.scheduler = scheduler
        self.T_steps = T_steps
        self.src_guidance_scale = src_gs
        self.tar_guidance_scale = tar_gs
        self.n_max = n_max

    @staticmethod
    def _mix_cfg(cond: torch.Tensor, uncond: torch.Tensor, cfg: float) -> torch.Tensor:
        """Mixes conditional and unconditional predictions via classifier-free guidance."""
        return uncond + cfg * (cond - uncond)

    def _get_differential_edit_direction(self, t: torch.Tensor, zt_src: torch.Tensor, zt_tar: torch.Tensor) -> torch.Tensor:
        """Computes the differential edit direction (delta v) for a timestep."""
        # Noise both condition latents with the same draw so the source and
        # target predictions differ only through their conditioning images.
        condition_noise = torch.randn_like(self.src_cond_lat)
        noisy_src_cond_lat = self.pipe.scheduler.scale_model_input(
            self.pipe.scheduler.add_noise(self.src_cond_lat, condition_noise, t), t
        )
        vt_src_uncond, vt_src_cond = self._calc_v_zero(self.src_cond_img, zt_src, t, noisy_src_cond_lat)
        vt_src = self._mix_cfg(vt_src_cond, vt_src_uncond, self.src_guidance_scale)

        noisy_tar_cond_lat = self.pipe.scheduler.scale_model_input(
            self.pipe.scheduler.add_noise(self.tar_cond_lat, condition_noise, t), t
        )
        vt_tar_uncond, vt_tar_cond = self._calc_v_zero(self.tar_cond_img, zt_tar, t, noisy_tar_cond_lat)
        vt_tar = self._mix_cfg(vt_tar_cond, vt_tar_uncond, self.tar_guidance_scale)

        return vt_tar - vt_src

    def _propagate_for_timestep(self, zt_edit: torch.Tensor, t: torch.Tensor, dt: torch.Tensor) -> torch.Tensor:
        """Performs a single propagation step for the edit."""
        # Share the forward noise between the source and edited latents so
        # their trajectories stay coupled at this timestep.
        fwd_noise = torch.randn_like(self.x_src)
        zt_src = self.scheduler.scale_model_input(self.scheduler.add_noise(self.x_src, fwd_noise, t), t)
        zt_tar = self.scheduler.scale_model_input(self.scheduler.add_noise(zt_edit, fwd_noise, t), t)

        diff_v = self._get_differential_edit_direction(t, zt_src, zt_tar)

        # Euler step: accumulate the edit in float32 for numerical stability,
        # then cast back to the model dtype.
        zt_edit_change = dt * diff_v
        zt_edit = zt_edit.to(torch.float32) + zt_edit_change
        return zt_edit.to(diff_v.dtype)

    def _calc_v_zero(self, condition_image: Image.Image, noisy_latent: torch.Tensor, t: torch.Tensor, noised_condition: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Calculates the unconditional and conditional v-prediction from the UNet."""
        # The guidance scale only needs to exceed 1 so the pipeline batches
        # the unconditional and conditional branches; the actual mixing
        # happens later in _mix_cfg.
        DUMMY_GUIDANCE_SCALE = 2
        model_output = {}

        # Capture the raw UNet output with a forward hook rather than using
        # the pipeline's own post-processed result.
        def hook_fn(module, args, output):
            model_output['v_pred'] = output[0]

        hook_handle = self.pipe.unet.register_forward_hook(hook_fn)
        try:
            self.pipe(
                condition_image,
                latents=noisy_latent,
                num_inference_steps=1,
                guidance_scale=DUMMY_GUIDANCE_SCALE,
                timesteps=[t.item()],
                output_type="latent",
                noisy_cond_lat=noised_condition,
            )
        finally:
            hook_handle.remove()

        # diffusers convention: the batch is ordered (unconditional, conditional).
        return model_output['v_pred'].chunk(2)

    @torch.no_grad()
    def denoise(self, x_src: torch.Tensor, src_cond_img: Image.Image, tar_cond_img: Image.Image) -> torch.Tensor:
        """Public method to run the entire denoising process."""
        self.x_src = x_src
        self.src_cond_img = src_cond_img
        self.tar_cond_img = tar_cond_img

        timesteps, _ = retrieve_timesteps(self.scheduler, self.T_steps, self.x_src.device)
        zt_edit = self.x_src.clone()
        self.src_cond_lat = self.pipe.make_condition_lat(self.src_cond_img, guidance_scale=2.0)
        self.tar_cond_lat = self.pipe.make_condition_lat(self.tar_cond_img, guidance_scale=2.0)

        # Only the last n_max timesteps are propagated; the earliest
        # (noisiest) steps are skipped.
        start_index = max(0, len(timesteps) - self.n_max)
        for i in tqdm(range(start_index, len(timesteps))):
            t = timesteps[i]
            # Normalize timesteps to [0, 1]; dt is negative because timesteps
            # run from high noise to low noise.
            t_i = t / 1000.0
            t_im1 = timesteps[i + 1] / 1000.0 if i + 1 < len(timesteps) else torch.zeros_like(t_i)
            dt = t_im1 - t_i
            zt_edit = self._propagate_for_timestep(zt_edit, t, dt)
        return zt_edit
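
# Viewed continuously, the loop in EditAwareDenoiser.denoise is an Euler
# integration of
#     d(zt_edit)/dt = v_tar(zt_tar, t) - v_src(zt_src, t)
# so only the *difference* between the target- and source-conditioned
# predictions is accumulated into the edited latent; content shared by both
# conditions is left untouched.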


def run_editp23(
    src_condition_path: str,
    tgt_condition_path: str,
    original_mv: str,
    save_path: str,
    device_number: int = 0,
    T_steps: int = 50,
    n_max: int = 31,
    src_guidance_scale: float = 3.5,
    tar_guidance_scale: float = 5.0,
    seed: int = 18,
    pipeline: Optional[Zero123PlusPipeline] = None,
) -> None:
    """Main execution function to run the complete editing pipeline."""
    if pipeline is None:
        pipeline = load_z123_pipe(device_number)

    torch.manual_seed(seed)
    np.random.seed(seed)

    vae_processor = VAEProcessor(pipeline)
    src_cond_img = add_white_bg(Image.open(src_condition_path))
    tgt_cond_img = add_white_bg(Image.open(tgt_condition_path))
    mv_src = add_white_bg(Image.open(original_mv))

    x0_src = vae_processor.encode(mv_src)
    denoiser = EditAwareDenoiser(
        pipe=pipeline,
        scheduler=pipeline.scheduler,
        T_steps=T_steps,
        src_gs=src_guidance_scale,
        tar_gs=tar_guidance_scale,
        n_max=n_max,
    )
    x0_tar = denoiser.denoise(x0_src, src_cond_img, tgt_cond_img)
    image_tar = vae_processor.decode(x0_tar)
    image_tar[0].save(save_path)
    print(f"Successfully saved result to {save_path}")