File size: 5,359 Bytes
7758cff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# modified from https://github.com/Stability-AI/sd3.5/blob/main/sd3_impls.py#L23

import torch
import torch.nn as nn
import numpy as np


class ModelSamplingDiscreteFlow(nn.Module):
    """Helper for sampler scheduling (i.e. timestep/sigma calculations) for Discrete Flow models.

    The class keeps a table of sigmas (one per training timestep) as a buffer
    and offers utilities to convert between timesteps and sigmas, to noise
    clean samples for training, and to take ODE solver steps at inference.

    Args:
        num_train_timesteps: Number of discrete training timesteps.
        shift: Time-shift factor applied in `to_sigma`; `1.0` means no shift.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, num_train_timesteps=1000, shift=1.0, **kwargs):
        super().__init__()
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift
        # Both are populated by `set_timesteps`. They are initialised here so
        # that `step` can detect "not configured yet" instead of failing with
        # an AttributeError on a missing attribute.
        self.num_inference_steps = None
        self.timesteps = None
        # Sigmas for timesteps 1..num_train_timesteps; with shift == 1.0 this
        # is the range [1 / num_train_timesteps, 1].
        ts = self.to_sigma(torch.arange(1, num_train_timesteps + 1, 1))
        self.register_buffer("sigmas", ts)

    @property
    def sigma_min(self):
        """Smallest sigma in the table (first entry of `sigmas`)."""
        return self.sigmas[0]

    @property
    def sigma_max(self):
        """Largest sigma in the table (last entry of `sigmas`)."""
        return self.sigmas[-1]

    def to_timestep(self, sigma):
        """Scale a sigma by `num_train_timesteps` to obtain a timestep."""
        return sigma * self.num_train_timesteps

    def to_sigma(self, timestep: torch.Tensor):
        """Map a timestep to a sigma, applying the time shift when `shift != 1.0`."""
        timestep = timestep / self.num_train_timesteps
        if self.shift == 1.0:
            return timestep
        return self.shift * timestep / (1 + (self.shift - 1) * timestep)

    def uniform_sample_t(self, batch_size, device):
        """Draw `batch_size` sigmas uniformly from [sigma_min, sigma_max) on `device`."""
        ts = (self.sigma_max - self.sigma_min) * torch.rand(batch_size, device=device) + self.sigma_min
        return ts

    def calculate_denoised(self, sigma, model_output, model_input):
        """Return `model_input - model_output * sigma`, broadcasting `sigma` per batch item."""
        # model output, vector field, v = dx = (x_1 - x_0)
        # Reshape sigma to (B, 1, ..., 1) so it broadcasts over model_output.
        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
        return model_input - model_output * sigma

    def noise_scaling(self, sigma, noise, latent_image):
        """Interpolate between `latent_image` and `noise` with weight `sigma`.

        `sigma` is used as given (no reshaping), so it must already broadcast
        against `noise` and `latent_image`.
        """
        return sigma * noise + (1.0 - sigma) * latent_image

    def add_noise(self, sample, noise=None, timesteps=None):
        """Noise a clean sample for training.

        Args:
            sample: Clean sample whose first dimension is the batch, e.g.
                `(B, L, D)`. Any rank is supported.
            noise: Optional noise tensor with the same shape as `sample`.
                When `None`, standard Gaussian noise is drawn.
            timesteps: Optional timesteps (one per batch item). When `None`,
                sigmas are sampled uniformly and converted to timesteps.

        Returns:
            Tuple `(noisy_samples, noise, target, timesteps)` where
            `noisy_samples = sigma * noise + (1 - sigma) * sample` and
            `target = noise - sample`.
        """
        if timesteps is None:
            # Sample one sigma per batch item, then derive the timestep.
            batch_size = sample.shape[0]
            sigmas = self.uniform_sample_t(batch_size, device=sample.device).to(dtype=sample.dtype)  # (B,)
            timesteps = self.to_timestep(sigmas)
        else:
            timesteps = timesteps.to(device=sample.device, dtype=sample.dtype)
            sigmas = self.to_sigma(timesteps)

        # (B, 1, ..., 1): broadcast over every non-batch dimension of `sample`
        # instead of assuming a 3-D input.
        sigmas = sigmas.view(-1, *([1] * (sample.ndim - 1)))
        # Respect caller-provided noise; only draw fresh noise when absent.
        if noise is None:
            noise = torch.randn_like(sample)
        noisy_samples = sigmas * noise + (1.0 - sigmas) * sample
        return noisy_samples, noise, noise - sample, timesteps

    def set_timesteps(self, num_inference_steps, device=None):
        """Build the inference timestep schedule.

        The schedule is `num_inference_steps` values linearly spaced from
        `to_timestep(sigma_max)` down to `to_timestep(sigma_min)`, stored in
        `self.timesteps` on `device`.

        Raises:
            ValueError: If `num_inference_steps` is smaller than 1 or larger
                than `num_train_timesteps`.
        """
        if num_inference_steps < 1:
            raise ValueError(f"`num_inference_steps` must be at least 1, got {num_inference_steps}.")
        if num_inference_steps > self.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:"
                f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps

        # `.item()` turns the 0-dim buffer entries into plain Python floats so
        # the schedule can be created directly, without a NumPy round-trip.
        start = self.to_timestep(self.sigma_max).item()
        end = self.to_timestep(self.sigma_min).item()
        self.timesteps = torch.linspace(start, end, num_inference_steps).to(device=device)

    def append_dims(self, x, target_dims):
        """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
        dims_to_append = target_dims - x.ndim
        return x[(...,) + (None,) * dims_to_append]

    def to_d(self, x, sigma, denoised):
        """Converts a denoiser output to a Karras ODE derivative."""
        return (x - denoised) / self.append_dims(sigma, x.ndim)

    @torch.no_grad()
    def step(self, model_output, timestep, sample, method="euler", **kwargs):
        """Advance the sample by one solver step.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model, direction (noise - x_0).
            timestep (`float`):
                The current discrete timestep in the diffusion chain. Must be a
                scalar (Python number or 0-dim tensor): the clamp below uses a
                plain boolean comparison.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process, x_t.
            method (`str`):
                ODE solver, `euler` or `dpmpp_2m` (`dpmpp_2m` is not implemented).

        Returns:
            `tuple`:
                `(prev_sample, pred_original_sample)`: the sample after the
                step and the clean-sample estimate `sample - model_output * sigma`.

        Raises:
            ValueError: If `set_timesteps` has not been called, or `method` is unknown.
            NotImplementedError: If `method == "dpmpp_2m"`.
        """

        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )
        # NOTE(review): `set_timesteps` derives timesteps from sigmas that already
        # include the shift, and `to_sigma` applies the shift again here; for
        # shift != 1.0 confirm this is the intended schedule.
        sigma = self.to_sigma(timestep)
        # Uniform sigma decrement per step. With a single inference step the
        # division yields inf, which the clamp below maps to 0.
        prev_sigma = sigma - (self.sigma_max - self.sigma_min) / (self.num_inference_steps - 1)
        prev_sigma = 0.0 if prev_sigma < 0.0 else prev_sigma

        if method == "euler":
            # Implements Algorithm 2 (Euler steps) from Karras et al. (2022).
            dt = prev_sigma - sigma
            prev_sample = sample + model_output * dt
        elif method == "dpmpp_2m":
            # DPM-Solver++(2M).
            raise NotImplementedError
        else:
            raise ValueError(f"Unsupported ode solver: {method}, only supports `euler` or `dpmpp_2m`")

        pred_original_sample = sample - model_output * sigma

        return (
            prev_sample,
            pred_original_sample
        )

    def get_pred_original_sample(self, model_output, timestep, sample):
        """Estimate the clean sample as `sample - model_output * sigma(timestep)`.

        `timestep` must be a tensor (one entry per batch item, or a single
        entry); its sigma is reshaped to broadcast over all non-batch
        dimensions of `sample`.
        """
        sigma = self.to_sigma(timestep).view(-1, *([1] * (sample.ndim - 1)))
        pred_original_sample = sample - model_output * sigma

        return pred_original_sample