File size: 9,376 Bytes
b6af722
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Tuple

import torch

from cosmos_predict1.diffusion.functional.batch_ops import batch_mul


def phi1(t: torch.Tensor) -> torch.Tensor:
    """
    Evaluate the first-order phi function, (exp(t) - 1) / t.

    ``torch.expm1`` is used for the numerator so small ``t`` does not lose
    precision, and the whole computation is carried out in float64 before
    casting back to the dtype of ``t``. At exactly ``t == 0`` the division
    yields NaN (callers are expected to avoid or handle that point).

    Args:
        t: Input tensor.

    Returns:
        Tensor: phi1(t), with the same dtype as ``t``.
    """
    out_dtype = t.dtype
    t64 = t.to(dtype=torch.float64)
    numerator = torch.expm1(t64)
    ratio = numerator / t64
    return ratio.to(dtype=out_dtype)


def phi2(t: torch.Tensor) -> torch.Tensor:
    """
    Evaluate the second-order phi function, (phi1(t) - 1) / t.

    The input is promoted to float64, passed through ``phi1`` at that
    precision, and the final quotient is cast back to the dtype of ``t``.

    Args:
        t: Input tensor.

    Returns:
        Tensor: phi2(t), with the same dtype as ``t``.
    """
    out_dtype = t.dtype
    t64 = t.to(dtype=torch.float64)
    shifted = phi1(t64) - 1.0
    return (shifted / t64).to(dtype=out_dtype)


def res_x0_rk2_step(
    x_s: torch.Tensor,
    t: torch.Tensor,
    s: torch.Tensor,
    x0_s: torch.Tensor,
    s1: torch.Tensor,
    x0_s1: torch.Tensor,
) -> torch.Tensor:
    """
    Advance from time ``s`` to time ``t`` with a residual (exponential-integrator
    style) second-order Runge-Kutta update built from two x0 predictions.

    The times are mapped to ``-log(.)`` space; the state is decayed by
    ``exp(-h)`` and corrected by a phi-function-weighted combination of the
    x0 predictions at ``s`` and at the intermediate time ``s1``.

    Args:
        x_s: State tensor at time ``s``.
        t: Target time tensor.
        s: Current time tensor.
        x0_s: x0 prediction at time ``s``.
        s1: Intermediate time tensor.
        x0_s1: x0 prediction at time ``s1``.

    Returns:
        Tensor: State tensor at time ``t``.

    Raises:
        AssertionError: If the full step or the step to ``s1`` is (nearly) zero
            in ``-log`` space.
    """
    # Change of variables: lambda = -log(time).
    lam_s = -torch.log(s)
    lam_t = -torch.log(t)
    lam_mid = -torch.log(s1)

    h = lam_t - lam_s
    assert not torch.any(torch.isclose(h, torch.zeros_like(h), atol=1e-6)), "Step size is too small"
    h_mid = lam_mid - lam_s
    assert not torch.any(torch.isclose(h_mid, torch.zeros_like(h), atol=1e-6)), "Step size is too small"

    # Fractional position of the intermediate point within the step.
    ratio = h_mid / h
    neg_h = -h
    p1 = phi1(neg_h)
    p2 = phi2(neg_h)

    # Weight of the intermediate prediction; the current prediction gets the remainder of phi1.
    w2_raw = 1.0 / ratio * p2
    # Defensive: zero out NaN coefficients (the asserts above should already
    # reject the degenerate t == s == s1 case).
    w1 = torch.nan_to_num(p1 - w2_raw, nan=0.0)
    w2 = torch.nan_to_num(w2_raw, nan=0.0)

    decay = torch.exp(neg_h)
    correction = batch_mul(w1, x0_s) + batch_mul(w2, x0_s1)
    return batch_mul(decay, x_s) + batch_mul(h, correction)


def reg_x0_euler_step(
    x_s: torch.Tensor,
    s: torch.Tensor,
    t: torch.Tensor,
    x0_s: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Take a first-order (Euler) step from time ``s`` to ``t`` driven by an x0 prediction.

    The new state is the linear blend ``(s - t) / s * x0_s + t / s * x_s``.

    Args:
        x_s: State tensor at time ``s``.
        s: Current time tensor.
        t: Target time tensor.
        x0_s: x0 prediction used to drive the step.

    Returns:
        Tuple[Tensor, Tensor]: The state at time ``t`` and ``x0_s`` unchanged.
    """
    weight_x0 = (s - t) / s
    weight_xs = t / s
    x_t = batch_mul(weight_x0, x0_s) + batch_mul(weight_xs, x_s)
    return x_t, x0_s


def reg_eps_euler_step(
    x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, eps_s: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Take a first-order (Euler) step from time ``s`` to ``t`` driven by an epsilon estimate.

    Args:
        x_s: State tensor at time ``s``.
        s: Current time tensor.
        t: Target time tensor.
        eps_s: Epsilon estimate used as the slope.

    Returns:
        Tuple[Tensor, Tensor]: ``x_s + eps_s * (t - s)`` (the state at ``t``) and
        ``x_s + eps_s * (0 - s)`` (the matching x0 estimate).
    """
    x_t = x_s + batch_mul(eps_s, t - s)
    # Stepping all the way to time 0 gives the implied x0 estimate.
    x0_hat = x_s + batch_mul(eps_s, 0 - s)
    return x_t, x0_hat


def rk1_euler(
    x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    First-order Runge-Kutta (Euler) step from time ``s`` to ``t``.

    Needs a single ``x0_fn`` evaluation. Suggested for guided sampling or
    under-trained models, where it tends to be more stable at the price of
    slightly slower convergence.

    Args:
        x_s: State tensor at time ``s``.
        s: Current time tensor.
        t: Target time tensor.
        x0_fn: Callable ``(x, time) -> x0 prediction``.

    Returns:
        Tuple[Tensor, Tensor]: The state at time ``t`` and the x0 prediction at ``s``.
    """
    x0_pred = x0_fn(x_s, s)
    return reg_x0_euler_step(x_s, s, t, x0_pred)


def rk2_mid_stable(
    x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Stabilised second-order midpoint step from time ``s`` to ``t``.

    An Euler step reaches the midpoint ``sqrt(s * t)``. The x0 prediction made
    there then drives a single x0-Euler step over the whole interval.

    Args:
        x_s: State tensor at time ``s``.
        s: Current time tensor.
        t: Target time tensor.
        x0_fn: Callable ``(x, time) -> x0 prediction``.

    Returns:
        Tuple[Tensor, Tensor]: The state at time ``t`` and the midpoint x0 prediction.
    """
    # Geometric mean of the two times, i.e. the midpoint in log space.
    s_mid = torch.sqrt(s * t)
    x_mid = rk1_euler(x_s, s, s_mid, x0_fn)[0]

    # Unlike rk2_mid, the full step is a plain x0-Euler step using only the midpoint prediction.
    x0_mid = x0_fn(x_mid, s_mid)
    return reg_x0_euler_step(x_s, s, t, x0_mid)


def rk2_mid(x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Second-order midpoint step from time ``s`` to ``t``.

    An Euler step reaches the midpoint ``sqrt(s * t)``. The x0 predictions at
    ``s`` and at the midpoint are then combined by ``res_x0_rk2_step``.

    Args:
        x_s: State tensor at time ``s``.
        s: Current time tensor.
        t: Target time tensor.
        x0_fn: Callable ``(x, time) -> x0 prediction``.

    Returns:
        Tuple[Tensor, Tensor]: The state at time ``t`` and the midpoint x0 prediction.
    """
    # Geometric mean of the two times, i.e. the midpoint in log space.
    s_mid = torch.sqrt(s * t)
    # Euler to the midpoint; keep the x0 prediction made at s for the RK2 combination.
    x_mid, x0_start = rk1_euler(x_s, s, s_mid, x0_fn)
    x0_mid = x0_fn(x_mid, s_mid)

    # Note: res_x0_rk2_step takes the target time before the current time.
    x_t = res_x0_rk2_step(x_s, t, s, x0_start, s_mid, x0_mid)
    return x_t, x0_mid


def rk_2heun_naive(
    x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Perform a naive second-order Runge-Kutta (Heun's method) step.
    Impl based on rho-rk-deis solvers, https://github.com/qsh-zh/deis
    Recommended for diffusion models without guidance and relative large NFE

    Heun's method averages two slopes. The first is taken at the start point
    ``(x_s, s)`` and the second at the Euler-predicted end point ``(x_t, t)``.
    Each slope is an epsilon estimate ``(x - x0) / time`` evaluated at its own
    state/time pair.

    Args:
        x_s: Current state tensor.
        s: Current time tensor.
        t: Target time tensor.
        x0_fn: Callable ``(x, time) -> x0 prediction``.

    Returns:
        Tuple[Tensor, Tensor]: The state at time ``t`` and the x0 estimate that
        ``reg_eps_euler_step`` derives from the averaged epsilon.
    """
    # Euler predictor; also yields the x0 prediction at the current state.
    x_t, x0_s = rk1_euler(x_s, s, t, x0_fn)
    # Start-point slope must pair x_s with s (not the predicted x_t). This is
    # exactly the slope the Euler predictor used:
    # t/s * x_s + (s-t)/s * x0_s == x_s + (t - s) * (x_s - x0_s) / s.
    # rk_3kutta_naive computes its first slope the same way.
    eps_s = batch_mul(1.0 / s, x_s - x0_s)
    # End-point slope at the predicted state.
    x0_t = x0_fn(x_t, t)
    eps_t = batch_mul(1.0 / t, x_t - x0_t)

    avg_eps = (eps_s + eps_t) / 2

    # Corrector: redo the step from x_s with the averaged slope.
    return reg_eps_euler_step(x_s, s, t, avg_eps)


def rk_2heun_edm(
    x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Second-order Heun-style step, based on the EDM second-order Heun method.

    An Euler predictor reaches time ``t``, and ``x0_fn`` is evaluated there. The
    step from ``x_s`` is then repeated with the mean of the two x0 predictions.
    Unlike ``rk_2heun_naive``, it averages x0 predictions rather than epsilon
    estimates.

    Args:
        x_s: Current state tensor.
        s: Current time tensor.
        t: Target time tensor.
        x0_fn: Callable ``(x, time) -> x0 prediction``.

    Returns:
        Tuple[Tensor, Tensor]: The state at time ``t`` and the averaged x0 prediction.
    """
    # Predictor: Euler to t, keeping the x0 prediction at s.
    x_pred, x0_start = rk1_euler(x_s, s, t, x0_fn)
    x0_end = x0_fn(x_pred, t)

    # Corrector: re-take the step from x_s with the mean prediction.
    x0_mean = (x0_start + x0_end) / 2
    return reg_x0_euler_step(x_s, s, t, x0_mean)


def rk_3kutta_naive(
    x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Naive third-order Runge-Kutta (Kutta) step from time ``s`` to ``t``.

    Impl based on rho-rk-deis solvers, https://github.com/qsh-zh/deis
    Suggested for unguided diffusion models with a relatively large number of
    function evaluations. The three slopes are epsilon estimates
    ``(x - x0) / time``, each taken at its own state/time pair.

    Args:
        x_s: Current state tensor.
        s: Current time tensor.
        t: Target time tensor.
        x0_fn: Callable ``(x, time) -> x0 prediction``.

    Returns:
        Tuple[Tensor, Tensor]: The state at time ``t`` and the x0 estimate that
        ``reg_eps_euler_step`` derives from the weighted slope.
    """
    # Butcher tableau of Kutta's third-order method (a21 = c2 is implied by
    # the Euler step to the half point).
    c2, c3 = 0.5, 1.0
    a31, a32 = -1.0, 2.0
    b1, b2, b3 = 1.0 / 6, 4.0 / 6, 1.0 / 6

    h = t - s
    s_half = c2 * h + s
    s_end = c3 * h + s

    # Stage 1: slope at the start; Euler to the half point.
    x_half, x0_start = rk1_euler(x_s, s, s_half, x0_fn)
    k1 = batch_mul(1.0 / s, x_s - x0_start)

    # Stage 2: slope at the half point.
    x0_half = x0_fn(x_half, s_half)
    k2 = batch_mul(1.0 / s_half, x_half - x0_half)

    # Stage 3: step to the end with the a3* combination, then take its slope.
    x_end, _ = reg_eps_euler_step(x_s, s, s_end, a31 * k1 + a32 * k2)
    x0_end = x0_fn(x_end, s_end)
    k3 = batch_mul(1.0 / s_end, x_end - x0_end)

    weighted = b1 * k1 + b2 * k2 + b3 * k3
    return reg_eps_euler_step(x_s, s, t, weighted)


# Registry of available solvers. Each key is "<order><name>": the leading
# digit is the method's order, the remainder names the scheme.
RK_FNs = {
    # first order
    "1euler": rk1_euler,
    # second order
    "2mid": rk2_mid,
    "2mid_stable": rk2_mid_stable,
    "2heun_edm": rk_2heun_edm,
    "2heun_naive": rk_2heun_naive,
    # third order
    "3kutta_naive": rk_3kutta_naive,
}


def get_runge_kutta_fn(name: str) -> Callable:
    """
    Look up a Runge-Kutta step function by its registry key.

    Args:
        name: Key in ``RK_FNs`` (e.g. ``"2mid"``).

    Returns:
        Callable: The registered step function.

    Raises:
        RuntimeError: If ``name`` is not registered; the message lists every
            supported key.
    """
    if name not in RK_FNs:
        supported = "\n\t".join(RK_FNs)
        raise RuntimeError(f"Only support the following Runge-Kutta methods:\n\t{supported}")
    return RK_FNs[name]


def is_runge_kutta_fn_supported(name: str) -> bool:
    """
    Report whether ``name`` is a registered Runge-Kutta method.

    Args:
        name: Candidate key for ``RK_FNs``.

    Returns:
        bool: ``True`` when ``name`` is a key of ``RK_FNs``, otherwise ``False``.
    """
    supported = name in RK_FNs
    return supported