from typing import Optional

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from diffusers import DDPMScheduler
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps

from pipeline import Zero123PlusPipeline
from utils import add_white_bg, load_z123_pipe

class VAEProcessor:
    """A helper class to handle encoding and decoding images with the VAE."""
    def __init__(self, pipeline: Zero123PlusPipeline):
        self.pipe = pipeline
        self.image_processor = pipeline.image_processor
        self.vae = pipeline.vae
        
        # These constants mirror Zero123++'s latent/image normalization helpers
        # (scale_latents/unscale_latents and scale_image/unscale_image).
        self.latent_shift_factor = 0.22
        self.latent_scale_factor = 0.75
        self.image_scale_factor = 0.5 / 0.8

    def encode(self, image: Image.Image) -> torch.Tensor:
        """Encodes a PIL image into the normalized latent space."""
        image_tensor = self.image_processor.preprocess(image).to(device=self.vae.device, dtype=self.vae.dtype)
        with torch.autocast("cuda"), torch.inference_mode():
            image_tensor *= self.image_scale_factor
            latents = self.vae.encode(image_tensor).latent_dist.mode()
            latents *= self.vae.config.scaling_factor
        return (latents - self.latent_shift_factor) * self.latent_scale_factor

    def decode(self, latents: torch.Tensor) -> list[Image.Image]:
        """Decodes latents back into a list of post-processed PIL images."""
        with torch.autocast("cuda"), torch.inference_mode():
            denorm = latents / self.latent_scale_factor + self.latent_shift_factor
            image = self.vae.decode(denorm / self.vae.config.scaling_factor, return_dict=False)[0]
            image /= self.image_scale_factor
        return self.image_processor.postprocess(image)


class EditAwareDenoiser:
    """Encapsulates the entire Edit-Aware Denoising process."""
    def __init__(self, pipe: Zero123PlusPipeline, scheduler: DDPMScheduler, T_steps: int, src_gs: float, tar_gs: float, n_max: int):
        """Initializes the denoiser with the pipeline and configuration."""
        self.pipe = pipe
        self.scheduler = scheduler
        self.T_steps = T_steps
        self.src_guidance_scale = src_gs
        self.tar_guidance_scale = tar_gs
        self.n_max = n_max

    @staticmethod
    def _mix_cfg(cond: torch.Tensor, uncond: torch.Tensor, cfg: float) -> torch.Tensor:
        """Mixes conditional and unconditional predictions."""
        return uncond + cfg * (cond - uncond)

    def _get_differential_edit_direction(self, t: torch.Tensor, zt_src: torch.Tensor, zt_tar: torch.Tensor) -> torch.Tensor:
        """Computes the differential edit direction (delta v) for a timestep."""
        condition_noise = torch.randn_like(self.src_cond_lat)

        noisy_src_cond_lat = self.pipe.scheduler.scale_model_input(
            self.pipe.scheduler.add_noise(self.src_cond_lat, condition_noise, t), t
        )
        vt_src_uncond, vt_src_cond = self._calc_v_zero(self.src_cond_img, zt_src, t, noisy_src_cond_lat)
        vt_src = self._mix_cfg(vt_src_cond, vt_src_uncond, self.src_guidance_scale)

        noisy_tar_cond_lat = self.pipe.scheduler.scale_model_input(
            self.pipe.scheduler.add_noise(self.tar_cond_lat, condition_noise, t), t
        )
        vt_tar_uncond, vt_tar_cond = self._calc_v_zero(self.tar_cond_img, zt_tar, t, noisy_tar_cond_lat)
        vt_tar = self._mix_cfg(vt_tar_cond, vt_tar_uncond, self.tar_guidance_scale)

        return vt_tar - vt_src

    def _propagate_for_timestep(self, zt_edit: torch.Tensor, t: torch.Tensor, dt: torch.Tensor) -> torch.Tensor:
        """Performs a single propagation step for the edit."""
        fwd_noise = torch.randn_like(self.x_src)
        zt_src = self.scheduler.scale_model_input(self.scheduler.add_noise(self.x_src, fwd_noise, t), t)
        zt_tar = self.scheduler.scale_model_input(self.scheduler.add_noise(zt_edit, fwd_noise, t), t)

        diff_v = self._get_differential_edit_direction(t, zt_src, zt_tar)
        
        zt_edit_change = dt * diff_v
        zt_edit = zt_edit.to(torch.float32) + zt_edit_change
        return zt_edit.to(diff_v.dtype)

    def _calc_v_zero(self, condition_image: Image.Image, noisy_latent: torch.Tensor, t: torch.Tensor, noised_condition: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Calculates the unconditional and conditional v-prediction from the UNet."""
        DUMMY_GUIDANCE_SCALE = 2
        model_output = {}

        def hook_fn(module, args, output):
            model_output['v_pred'] = output[0]

        hook_handle = self.pipe.unet.register_forward_hook(hook_fn)

        try:
            self.pipe(
                condition_image,
                latents=noisy_latent,
                num_inference_steps=1,
                guidance_scale=DUMMY_GUIDANCE_SCALE,
                timesteps=[t.item()],
                output_type="latent",
                noisy_cond_lat=noised_condition,
            )
        finally:
            hook_handle.remove()

        # diffusers concatenates the unconditional batch first, so chunk(2)
        # yields (uncond, cond).
        return model_output['v_pred'].chunk(2)

    @torch.no_grad()
    def denoise(self, x_src: torch.Tensor, src_cond_img: Image.Image, tar_cond_img: Image.Image) -> torch.Tensor:
        """Public method to run the entire denoising process."""
        self.x_src = x_src
        self.src_cond_img = src_cond_img
        self.tar_cond_img = tar_cond_img
        
        timesteps, _ = retrieve_timesteps(self.scheduler, self.T_steps, self.x_src.device)
        zt_edit = self.x_src.clone()

        self.src_cond_lat = self.pipe.make_condition_lat(self.src_cond_img, guidance_scale=2.0)
        self.tar_cond_lat = self.pipe.make_condition_lat(self.tar_cond_img, guidance_scale=2.0)
        
        # Edit over only the last n_max of the T_steps timesteps; the earliest,
        # noisiest steps are skipped.
        start_index = max(0, len(timesteps) - self.n_max)

        for i in tqdm(range(start_index, len(timesteps))):
            t = timesteps[i]
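            # Timesteps lie on a 0-1000 grid and descend, so dt < 0; the Euler
            # step in _propagate_for_timestep integrates the edit ODE toward t=0.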
            t_i = t / 1000.0
            t_im1 = timesteps[i + 1] / 1000.0 if i + 1 < len(timesteps) else torch.zeros_like(t_i)
            dt = t_im1 - t_i
            
            zt_edit = self._propagate_for_timestep(zt_edit, t, dt)

        return zt_edit


def run_editp23(
    src_condition_path: str,
    tgt_condition_path: str,
    original_mv: str,
    save_path: str,
    device_number: int = 0,
    T_steps: int = 50,
    n_max: int = 31,
    src_guidance_scale: float = 3.5,
    tar_guidance_scale: float = 5.0,
    seed: int = 18,
    pipeline: Optional[Zero123PlusPipeline] = None,
) -> None:
    """Main execution function to run the complete editing pipeline."""
    if pipeline is None:
        pipeline = load_z123_pipe(device_number)
    
    torch.manual_seed(seed)
    np.random.seed(seed)

    vae_processor = VAEProcessor(pipeline)

    src_cond_img = add_white_bg(Image.open(src_condition_path))
    tgt_cond_img = add_white_bg(Image.open(tgt_condition_path))
    mv_src = add_white_bg(Image.open(original_mv))
    x0_src = vae_processor.encode(mv_src)

    denoiser = EditAwareDenoiser(
        pipe=pipeline,
        scheduler=pipeline.scheduler,
        T_steps=T_steps,
        src_gs=src_guidance_scale,
        tar_gs=tar_guidance_scale,
        n_max=n_max
    )
    x0_tar = denoiser.denoise(x0_src, src_cond_img, tgt_cond_img)

    images_tar = vae_processor.decode(x0_tar)
    images_tar[0].save(save_path)
    print(f"Successfully saved result to {save_path}")