Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,712 Bytes
42f2c22 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
"""
Euler ODE solver.
"""
from typing import Callable
import torch
from einops import rearrange
from torch.nn import functional as F
from models.dit_v2 import na
from ..types import PredictionType
from ..utils import expand_dims
from .base import Sampler, SamplerModelArgs
class EulerSampler(Sampler):
    """
    First-order (explicit Euler) ODE sampler.

    Each step asks the model for a prediction at the current timestep, decodes
    it into the schedule's two endpoint estimates, and re-evaluates the
    schedule's forward process at the next timestep.
    <https://en.wikipedia.org/wiki/Euler_method>
    """

    def sample(
        self,
        x: torch.Tensor,
        f: Callable[[SamplerModelArgs], torch.Tensor],
    ) -> torch.Tensor:
        """
        Integrate ``x`` across ``self.timesteps.timesteps``.

        Args:
            x: Starting state at the first timestep.
            f: Model callback. It is called with ``SamplerModelArgs(x, t, index)``,
                where ``index`` counts model calls from 0. Its output is passed
                to ``step_to`` (or ``get_endpoint`` for the final call).

        Returns:
            The state after one ``step_to`` per consecutive timestep pair. When
            ``self.return_endpoint`` is truthy, the model is evaluated once more
            at the last timestep and the result of ``get_endpoint`` is returned.
        """
        schedule_steps = self.timesteps.timesteps
        bar = self.get_progress_bar()

        # Number of completed steps; reused as the index of the endpoint call.
        steps_taken = 0
        consecutive = zip(schedule_steps[:-1], schedule_steps[1:])
        for index, (t_now, t_next) in enumerate(consecutive):
            prediction = f(SamplerModelArgs(x, t_now, index))
            x = self.step_to(prediction, x, t_now, t_next)
            steps_taken = index + 1
            bar.update()

        if not self.return_endpoint:
            return x

        # Extra model call at the final timestep; its index continues the
        # count from the loop above (0 if the loop never ran).
        t_last = schedule_steps[-1]
        prediction = f(SamplerModelArgs(x, t_last, steps_taken))
        x = self.get_endpoint(prediction, x, t_last)
        bar.update()
        return x

    def step(
        self,
        pred: torch.Tensor,
        x_t: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """
        Advance ``x_t`` from ``t`` to ``self.get_next_timestep(t)``.
        """
        next_t = self.get_next_timestep(t)
        return self.step_to(pred, x_t, t, next_t)

    def step_to(
        self,
        pred: torch.Tensor,
        x_t: torch.Tensor,
        t: torch.Tensor,
        s: torch.Tensor,
    ) -> torch.Tensor:
        """
        Move from ``x_t`` at timestep ``t`` to an estimate of ``x_s`` at ``s``.

        The model output ``pred`` is decoded (per ``self.prediction_type``) into
        the endpoint estimates ``pred_x_0`` and ``pred_x_T``. The schedule's
        forward process is then evaluated at ``s`` clamped to ``[0, T]``.

        Wherever ``s <= T`` does not hold, the result is ``pred_x_T``. Otherwise,
        wherever ``s >= 0`` does not hold, it is ``pred_x_0``.

        Returns:
            The estimate of ``x_s``.
        """
        # expand_dims presumably reshapes t/s to broadcast against x_t -- TODO confirm.
        ndim = x_t.ndim
        t = expand_dims(t, ndim)
        s = expand_dims(s, ndim)
        T = self.schedule.T

        pred_x_0, pred_x_T = self.schedule.convert_from_pred(
            pred, self.prediction_type, x_t, t
        )
        pred_x_s = self.schedule.forward(pred_x_0, pred_x_T, s.clamp(0, T))

        # Snap out-of-range s to the nearest endpoint estimate. The outer
        # (upper-bound) test takes precedence over the inner (lower-bound) one.
        lower_snapped = torch.where(s >= 0, pred_x_s, pred_x_0)
        return torch.where(s <= T, lower_snapped, pred_x_T)
|