# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# //     http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

"""
Euler ODE solver.
"""

from typing import Callable
import torch
from einops import rearrange
from torch.nn import functional as F

from models.dit_v2 import na

from ..types import PredictionType
from ..utils import expand_dims
from .base import Sampler, SamplerModelArgs


class EulerSampler(Sampler):
    """
    First-order (Euler) ODE sampler.

    Each step turns the model output into estimates of both schedule
    endpoints (x_0 and x_T) via ``schedule.convert_from_pred`` and then
    re-evaluates ``schedule.forward`` at the target timestep.
    """

    def sample(
        self,
        x: torch.Tensor,
        f: Callable[[SamplerModelArgs], torch.Tensor],
    ) -> torch.Tensor:
        """
        Integrate ``x`` across every consecutive pair of timesteps.

        Args:
            x: sample at the first timestep of ``self.timesteps.timesteps``.
            f: model callback; called with ``SamplerModelArgs(x, t, index)``
                where ``index`` is the zero-based step counter.

        Returns:
            The sample after the final step, or, when ``self.return_endpoint``
            is set, the result of ``self.get_endpoint`` at the last timestep.
        """
        schedule_ts = self.timesteps.timesteps
        bar = self.get_progress_bar()

        pairs = zip(schedule_ts[:-1], schedule_ts[1:])
        for step_idx, (t_now, t_next) in enumerate(pairs):
            model_out = f(SamplerModelArgs(x, t_now, step_idx))
            x = self.step_to(model_out, x, t_now, t_next)
            bar.update()

        if self.return_endpoint:
            t_last = schedule_ts[-1]
            # The endpoint query gets the index that follows the last regular
            # step, i.e. the number of steps taken by the loop above.
            model_out = f(SamplerModelArgs(x, t_last, len(schedule_ts) - 1))
            x = self.get_endpoint(model_out, x, t_last)
            bar.update()

        return x

    def step(
        self,
        pred: torch.Tensor,
        x_t: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """
        Advance ``x_t`` from ``t`` to ``self.get_next_timestep(t)``.
        """
        next_t = self.get_next_timestep(t)
        return self.step_to(pred, x_t, t, next_t)

    def step_to(
        self,
        pred: torch.Tensor,
        x_t: torch.Tensor,
        t: torch.Tensor,
        s: torch.Tensor,
    ) -> torch.Tensor:
        """
        Move from x_t at timestep t to x_s at timestep s and return x_s.

        Targets outside [0, T] (T = ``self.schedule.T``) snap to the matching
        predicted endpoint: s below 0 gives the predicted x_0, s above T gives
        the predicted x_T.
        """
        ndim = x_t.ndim
        t = expand_dims(t, ndim)
        s = expand_dims(s, ndim)
        T = self.schedule.T

        # Recover both endpoint estimates from the raw model output.
        pred_x_0, pred_x_T = self.schedule.convert_from_pred(
            pred, self.prediction_type, x_t, t
        )

        # Evaluate at a clamped copy of s so forward() only sees s in [0, T].
        x_s = self.schedule.forward(pred_x_0, pred_x_T, s.clamp(0, T))

        # Out-of-range targets fall back to the corresponding endpoint.
        x_s = torch.where(s >= 0, x_s, pred_x_0)
        x_s = torch.where(s <= T, x_s, pred_x_T)
        return x_s