netrunner-exe committed on
Commit
e47d403
·
1 Parent(s): 239ab44

Upload k_diffusion_dpmpp.diff

Browse files
Files changed (1) hide show
  1. k_diffusion_dpmpp.diff +145 -0
k_diffusion_dpmpp.diff ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/README.md b/README.md
2
+ index 4f7c92f..e386624 100644
3
+ --- a/README.md
4
+ +++ b/README.md
5
+ @@ -1,3 +1,12 @@
6
+ +# THIS IS A FORK
7
+ +
8
+ +Forked from https://github.com/crowsonkb/k-diffusion
9
+ +
10
+ +Changes:
11
+ +
12
+ +1. Add DPM++ 2M sampling fix by @hallatore https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457
13
+ +2. Add MPS fix for MacOS by @brkirch https://github.com/brkirch/k-diffusion
14
+ +
15
+ # k-diffusion
16
+
17
+ An implementation of [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) (Karras et al., 2022) for PyTorch. The patching method in [Improving Diffusion Model Efficiency Through Patching](https://arxiv.org/abs/2207.04316) is implemented as well.
18
+ diff --git a/k_diffusion/external.py b/k_diffusion/external.py
19
+ index 79b51ce..b41d0eb 100644
20
+ --- a/k_diffusion/external.py
21
+ +++ b/k_diffusion/external.py
22
+ @@ -79,7 +79,9 @@ class DiscreteSchedule(nn.Module):
23
+
24
+ def t_to_sigma(self, t):
25
+ t = t.float()
26
+ - low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
27
+ + low_idx = t.floor().long()
28
+ + high_idx = t.ceil().long()
29
+ + w = t - low_idx if t.device.type == 'mps' else t.frac()
30
+ log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
31
+ return log_sigma.exp()
32
+
33
+ diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py
34
+ index f050f88..9f859d4 100644
35
+ --- a/k_diffusion/sampling.py
36
+ +++ b/k_diffusion/sampling.py
37
+ @@ -16,7 +16,7 @@ def append_zero(x):
38
+
39
+ def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
40
+ """Constructs the noise schedule of Karras et al. (2022)."""
41
+ - ramp = torch.linspace(0, 1, n)
42
+ + ramp = torch.linspace(0, 1, n, device=device)
43
+ min_inv_rho = sigma_min ** (1 / rho)
44
+ max_inv_rho = sigma_max ** (1 / rho)
45
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
46
+ @@ -400,7 +400,13 @@ class DPMSolver(nn.Module):
47
+
48
+ for i in range(len(orders)):
49
+ eps_cache = {}
50
+ - t, t_next = ts[i], ts[i + 1]
51
+ +
52
+ + # MacOS fix
53
+ + if torch.backends.mps.is_available() and torch.backends.mps.is_built():
54
+ + t, t_next = ts[i].detach().clone(), ts[i + 1].detach().clone()
55
+ + else:
56
+ + t, t_next = ts[i], ts[i + 1]
57
+ +
58
+ if eta:
59
+ sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
60
+ t_next_ = torch.minimum(t_end, self.t(sd))
61
+ @@ -512,7 +518,12 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
62
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
63
+ s_in = x.new_ones([x.shape[0]])
64
+ sigma_fn = lambda t: t.neg().exp()
65
+ - t_fn = lambda sigma: sigma.log().neg()
66
+ +
67
+ + # MacOS fix
68
+ + if torch.backends.mps.is_available() and torch.backends.mps.is_built():
69
+ + t_fn = lambda sigma: sigma.detach().clone().log().neg()
70
+ + else:
71
+ + t_fn = lambda sigma: sigma.log().neg()
72
+
73
+ for i in trange(len(sigmas) - 1, disable=disable):
74
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
75
+ @@ -547,7 +558,12 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
76
+ extra_args = {} if extra_args is None else extra_args
77
+ s_in = x.new_ones([x.shape[0]])
78
+ sigma_fn = lambda t: t.neg().exp()
79
+ - t_fn = lambda sigma: sigma.log().neg()
80
+ +
81
+ + # MacOS fix
82
+ + if torch.backends.mps.is_available() and torch.backends.mps.is_built():
83
+ + t_fn = lambda sigma: sigma.detach().clone().log().neg()
84
+ + else:
85
+ + t_fn = lambda sigma: sigma.log().neg()
86
+
87
+ for i in trange(len(sigmas) - 1, disable=disable):
88
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
89
+ @@ -587,7 +603,13 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
90
+ extra_args = {} if extra_args is None else extra_args
91
+ s_in = x.new_ones([x.shape[0]])
92
+ sigma_fn = lambda t: t.neg().exp()
93
+ - t_fn = lambda sigma: sigma.log().neg()
94
+ +
95
+ + # MacOS fix
96
+ + if torch.backends.mps.is_available() and torch.backends.mps.is_built():
97
+ + t_fn = lambda sigma: sigma.detach().clone().log().neg()
98
+ + else:
99
+ + t_fn = lambda sigma: sigma.log().neg()
100
+ +
101
+ old_denoised = None
102
+
103
+ for i in trange(len(sigmas) - 1, disable=disable):
104
+ @@ -596,12 +618,22 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
105
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
106
+ t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
107
+ h = t_next - t
108
+ +
109
+ + t_min = min(sigma_fn(t_next), sigma_fn(t))
110
+ + t_max = max(sigma_fn(t_next), sigma_fn(t))
111
+ +
112
+ if old_denoised is None or sigmas[i + 1] == 0:
113
+ - x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
114
+ + x = (t_min / t_max) * x - (-h).expm1() * denoised
115
+ else:
116
+ h_last = t - t_fn(sigmas[i - 1])
117
+ - r = h_last / h
118
+ +
119
+ + h_min = min(h_last, h)
120
+ + h_max = max(h_last, h)
121
+ + r = h_max / h_min
122
+ +
123
+ + h_d = (h_max + h_min) / 2
124
+ denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
125
+ - x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
126
+ + x = (t_min / t_max) * x - (-h_d).expm1() * denoised_d
127
+ +
128
+ old_denoised = denoised
129
+ return x
130
+ diff --git a/k_diffusion/utils.py b/k_diffusion/utils.py
131
+ index 9afedb9..ce6014b 100644
132
+ --- a/k_diffusion/utils.py
133
+ +++ b/k_diffusion/utils.py
134
+ @@ -42,7 +42,10 @@ def append_dims(x, target_dims):
135
+ dims_to_append = target_dims - x.ndim
136
+ if dims_to_append < 0:
137
+ raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
138
+ - return x[(...,) + (None,) * dims_to_append]
139
+ + expanded = x[(...,) + (None,) * dims_to_append]
140
+ + # MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
141
+ + # https://github.com/pytorch/pytorch/issues/84364
142
+ + return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
143
+
144
+
145
+ def n_params(module):