Spaces:
Build error
Build error
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# SPDX-License-Identifier: Apache-2.0 | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" | |
Implementation of multistep methods for solving the ODE in the diffusion model.
""" | |
from typing import Callable, List, Tuple | |
import torch | |
from cosmos_predict1.diffusion.functional.runge_kutta import reg_x0_euler_step, res_x0_rk2_step | |
def order2_fn(
    x_s: torch.Tensor,
    s: torch.Tensor,
    t: torch.Tensor,
    x0_s: torch.Tensor,
    x0_preds: List[Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
    """
    Take one step of the second-order multistep method from
    https://arxiv.org/pdf/2308.02157 (Adams-Bashforth approach), moving from
    time ``s`` to time ``t``.

    Args:
        x_s: Sample at time ``s`` (presumably the current ODE state -- confirm).
        s: Current time.
        t: Target time.
        x0_s: x0 prediction associated with ``(x_s, s)`` (assumed to come from
            the denoiser -- confirm against caller).
        x0_preds: History of earlier x0 predictions as ``(x0, time)`` tuples;
            only the first entry is read. If empty (or otherwise falsy), no
            history is available and a first-order Euler step is taken instead.

    Returns:
        ``(x_t, history)`` where ``x_t`` is the updated sample and
        ``history`` is ``[(x0_s, s)]`` -- the same shape as ``x0_preds``, so it
        can be passed back in on the following call.
    """
    if not x0_preds:
        # No previous x0 prediction yet: fall back to a single Euler step.
        # The helper's result is indexed and only element 0 is kept.
        x_t = reg_x0_euler_step(x_s, s, t, x0_s)[0]
    else:
        x0_s1, s1 = x0_preds[0]
        # NOTE(review): t is passed before s here, unlike reg_x0_euler_step
        # above; this follows res_x0_rk2_step's signature -- confirm there.
        x_t = res_x0_rk2_step(x_s, t, s, x0_s, s1, x0_s1)
    return x_t, [(x0_s, s)]
# Registry of the supported multistep solvers.
# Keys are "<order><algorithm abbreviation>" (e.g. "2ab" = second-order
# Adams-Bashforth); values are the corresponding step functions.
MULTISTEP_FNs = {"2ab": order2_fn}
def get_multi_step_fn(name: str) -> Callable:
    """
    Look up a multistep solver step function by name.

    Args:
        name: Registry key, e.g. ``"2ab"``.

    Returns:
        The function registered under ``name`` in ``MULTISTEP_FNs``.

    Raises:
        RuntimeError: If ``name`` is not registered. The message names the
            rejected key and lists every supported key on its own
            tab-indented line.
    """
    try:
        return MULTISTEP_FNs[name]
    except KeyError:
        # Prefix every entry, including the first, with "\n\t" so all
        # supported names line up in the message.
        methods = "".join(f"\n\t{key}" for key in MULTISTEP_FNs)
        # `from None` hides the internal KeyError; callers only see the
        # RuntimeError they already handle.
        raise RuntimeError(
            f"Unsupported multistep method {name!r}; supported methods:{methods}"
        ) from None
def is_multi_step_fn_supported(name: str) -> bool:
    """
    Return True when ``name`` is a registered key of ``MULTISTEP_FNs``.
    """
    registered_names = MULTISTEP_FNs.keys()
    return name in registered_names