Commit
·
e47d403
1
Parent(s):
239ab44
Upload k_diffusion_dpmpp.diff
Browse files- k_diffusion_dpmpp.diff +145 -0
k_diffusion_dpmpp.diff
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
diff --git a/README.md b/README.md
|
2 |
+
index 4f7c92f..e386624 100644
|
3 |
+
--- a/README.md
|
4 |
+
+++ b/README.md
|
5 |
+
@@ -1,3 +1,12 @@
|
6 |
+
+# THIS IS A FORK
|
7 |
+
+
|
8 |
+
+Forked from https://github.com/crowsonkb/k-diffusion
|
9 |
+
+
|
10 |
+
+Changes:
|
11 |
+
+
|
12 |
+
+1. Add DPM++ 2M sampling fix by @hallatore https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457
|
13 |
+
+2. Add MPS fix for MacOS by @brkirch https://github.com/brkirch/k-diffusion
|
14 |
+
+
|
15 |
+
# k-diffusion
|
16 |
+
|
17 |
+
An implementation of [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) (Karras et al., 2022) for PyTorch. The patching method in [Improving Diffusion Model Efficiency Through Patching](https://arxiv.org/abs/2207.04316) is implemented as well.
|
18 |
+
diff --git a/k_diffusion/external.py b/k_diffusion/external.py
|
19 |
+
index 79b51ce..b41d0eb 100644
|
20 |
+
--- a/k_diffusion/external.py
|
21 |
+
+++ b/k_diffusion/external.py
|
22 |
+
@@ -79,7 +79,9 @@ class DiscreteSchedule(nn.Module):
|
23 |
+
|
24 |
+
def t_to_sigma(self, t):
|
25 |
+
t = t.float()
|
26 |
+
- low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
|
27 |
+
+ low_idx = t.floor().long()
|
28 |
+
+ high_idx = t.ceil().long()
|
29 |
+
+ w = t - low_idx if t.device.type == 'mps' else t.frac()
|
30 |
+
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
31 |
+
return log_sigma.exp()
|
32 |
+
|
33 |
+
diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py
|
34 |
+
index f050f88..9f859d4 100644
|
35 |
+
--- a/k_diffusion/sampling.py
|
36 |
+
+++ b/k_diffusion/sampling.py
|
37 |
+
@@ -16,7 +16,7 @@ def append_zero(x):
|
38 |
+
|
39 |
+
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
|
40 |
+
"""Constructs the noise schedule of Karras et al. (2022)."""
|
41 |
+
- ramp = torch.linspace(0, 1, n)
|
42 |
+
+ ramp = torch.linspace(0, 1, n, device=device)
|
43 |
+
min_inv_rho = sigma_min ** (1 / rho)
|
44 |
+
max_inv_rho = sigma_max ** (1 / rho)
|
45 |
+
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
46 |
+
@@ -400,7 +400,13 @@ class DPMSolver(nn.Module):
|
47 |
+
|
48 |
+
for i in range(len(orders)):
|
49 |
+
eps_cache = {}
|
50 |
+
- t, t_next = ts[i], ts[i + 1]
|
51 |
+
+
|
52 |
+
+ # MacOS fix
|
53 |
+
+ if torch.backends.mps.is_available() and torch.backends.mps.is_built():
|
54 |
+
+ t, t_next = ts[i].detach().clone(), ts[i + 1].detach().clone()
|
55 |
+
+ else:
|
56 |
+
+ t, t_next = ts[i], ts[i + 1]
|
57 |
+
+
|
58 |
+
if eta:
|
59 |
+
sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
|
60 |
+
t_next_ = torch.minimum(t_end, self.t(sd))
|
61 |
+
@@ -512,7 +518,12 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
|
62 |
+
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
63 |
+
s_in = x.new_ones([x.shape[0]])
|
64 |
+
sigma_fn = lambda t: t.neg().exp()
|
65 |
+
- t_fn = lambda sigma: sigma.log().neg()
|
66 |
+
+
|
67 |
+
+ # MacOS fix
|
68 |
+
+ if torch.backends.mps.is_available() and torch.backends.mps.is_built():
|
69 |
+
+ t_fn = lambda sigma: sigma.detach().clone().log().neg()
|
70 |
+
+ else:
|
71 |
+
+ t_fn = lambda sigma: sigma.log().neg()
|
72 |
+
|
73 |
+
for i in trange(len(sigmas) - 1, disable=disable):
|
74 |
+
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
75 |
+
@@ -547,7 +558,12 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
|
76 |
+
extra_args = {} if extra_args is None else extra_args
|
77 |
+
s_in = x.new_ones([x.shape[0]])
|
78 |
+
sigma_fn = lambda t: t.neg().exp()
|
79 |
+
- t_fn = lambda sigma: sigma.log().neg()
|
80 |
+
+
|
81 |
+
+ # MacOS fix
|
82 |
+
+ if torch.backends.mps.is_available() and torch.backends.mps.is_built():
|
83 |
+
+ t_fn = lambda sigma: sigma.detach().clone().log().neg()
|
84 |
+
+ else:
|
85 |
+
+ t_fn = lambda sigma: sigma.log().neg()
|
86 |
+
|
87 |
+
for i in trange(len(sigmas) - 1, disable=disable):
|
88 |
+
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
89 |
+
@@ -587,7 +603,13 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
|
90 |
+
extra_args = {} if extra_args is None else extra_args
|
91 |
+
s_in = x.new_ones([x.shape[0]])
|
92 |
+
sigma_fn = lambda t: t.neg().exp()
|
93 |
+
- t_fn = lambda sigma: sigma.log().neg()
|
94 |
+
+
|
95 |
+
+ # MacOS fix
|
96 |
+
+ if torch.backends.mps.is_available() and torch.backends.mps.is_built():
|
97 |
+
+ t_fn = lambda sigma: sigma.detach().clone().log().neg()
|
98 |
+
+ else:
|
99 |
+
+ t_fn = lambda sigma: sigma.log().neg()
|
100 |
+
+
|
101 |
+
old_denoised = None
|
102 |
+
|
103 |
+
for i in trange(len(sigmas) - 1, disable=disable):
|
104 |
+
@@ -596,12 +618,22 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
|
105 |
+
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
106 |
+
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
|
107 |
+
h = t_next - t
|
108 |
+
+
|
109 |
+
+ t_min = min(sigma_fn(t_next), sigma_fn(t))
|
110 |
+
+ t_max = max(sigma_fn(t_next), sigma_fn(t))
|
111 |
+
+
|
112 |
+
if old_denoised is None or sigmas[i + 1] == 0:
|
113 |
+
- x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
|
114 |
+
+ x = (t_min / t_max) * x - (-h).expm1() * denoised
|
115 |
+
else:
|
116 |
+
h_last = t - t_fn(sigmas[i - 1])
|
117 |
+
- r = h_last / h
|
118 |
+
+
|
119 |
+
+ h_min = min(h_last, h)
|
120 |
+
+ h_max = max(h_last, h)
|
121 |
+
+ r = h_max / h_min
|
122 |
+
+
|
123 |
+
+ h_d = (h_max + h_min) / 2
|
124 |
+
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
|
125 |
+
- x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
|
126 |
+
+ x = (t_min / t_max) * x - (-h_d).expm1() * denoised_d
|
127 |
+
+
|
128 |
+
old_denoised = denoised
|
129 |
+
return x
|
130 |
+
diff --git a/k_diffusion/utils.py b/k_diffusion/utils.py
|
131 |
+
index 9afedb9..ce6014b 100644
|
132 |
+
--- a/k_diffusion/utils.py
|
133 |
+
+++ b/k_diffusion/utils.py
|
134 |
+
@@ -42,7 +42,10 @@ def append_dims(x, target_dims):
|
135 |
+
dims_to_append = target_dims - x.ndim
|
136 |
+
if dims_to_append < 0:
|
137 |
+
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
|
138 |
+
- return x[(...,) + (None,) * dims_to_append]
|
139 |
+
+ expanded = x[(...,) + (None,) * dims_to_append]
|
140 |
+
+ # MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
|
141 |
+
+ # https://github.com/pytorch/pytorch/issues/84364
|
142 |
+
+ return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
|
143 |
+
|
144 |
+
|
145 |
+
def n_params(module):
|