Spaces:
Running
on
Zero
Running
on
Zero
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates | |
# // | |
# // Licensed under the Apache License, Version 2.0 (the "License"); | |
# // you may not use this file except in compliance with the License. | |
# // You may obtain a copy of the License at | |
# // | |
# // http://www.apache.org/licenses/LICENSE-2.0 | |
# // | |
# // Unless required by applicable law or agreed to in writing, software | |
# // distributed under the License is distributed on an "AS IS" BASIS, | |
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# // See the License for the specific language governing permissions and | |
# // limitations under the License. | |
""" | |
Schedule base class. | |
""" | |
from abc import ABC, abstractmethod, abstractproperty | |
from typing import Tuple, Union | |
import torch | |
from ..types import PredictionType | |
from ..utils import expand_dims | |
class Schedule(ABC):
    """
    Diffusion schedules are uniquely defined by T, A, B:

        x_t = A(t) * x_0 + B(t) * x_T, where t in [0, T]

    Schedules can be continuous or discrete.

    Subclasses must provide ``T`` (as a property) and implement ``A`` and ``B``;
    instantiating a subclass that omits any of them raises ``TypeError``.
    """

    @property
    @abstractmethod
    def T(self) -> Union[int, float]:
        """
        Maximum timestep inclusive.
        Schedule is continuous if float, discrete if int.
        """

    @abstractmethod
    def A(self, t: torch.Tensor) -> torch.Tensor:
        """
        Interpolation coefficient A (weight of x_0 in x_t).
        Returns tensor with the same shape as t.
        """

    @abstractmethod
    def B(self, t: torch.Tensor) -> torch.Tensor:
        """
        Interpolation coefficient B (weight of x_T in x_t).
        Returns tensor with the same shape as t.
        """

    # ----------------------------------------------------

    def snr(self, t: torch.Tensor) -> torch.Tensor:
        """
        Signal to noise ratio, A(t)**2 / B(t)**2.
        Returns tensor with the same shape as t.

        Where B(t) == 0 the result follows torch division semantics
        (inf, or nan if A(t) is also 0).
        """
        return (self.A(t) ** 2) / (self.B(t) ** 2)

    def isnr(self, snr: torch.Tensor) -> torch.Tensor:
        """
        Inverse signal to noise ratio: map an snr value back to a timestep t.
        Returns tensor with the same shape as snr.
        Subclass may implement.

        Raises:
            NotImplementedError: if the subclass does not provide it.
        """
        raise NotImplementedError

    # ----------------------------------------------------

    def is_continuous(self) -> bool:
        """
        Whether the schedule is continuous, i.e. ``T`` is a float.
        """
        # T is read as an attribute (property), not called.
        return isinstance(self.T, float)

    def forward(self, x_0: torch.Tensor, x_T: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Diffusion forward function: x_t = A(t) * x_0 + B(t) * x_T.

        Args:
            x_0: clean sample.
            x_T: endpoint sample (e.g. noise).
            t: timesteps; expanded to ``x_0.ndim`` dims so A(t)/B(t) broadcast
               against x_0 (relies on ``expand_dims`` from ..utils).
        """
        t = expand_dims(t, x_0.ndim)
        return self.A(t) * x_0 + self.B(t) * x_T

    def convert_from_pred(
        self, pred: torch.Tensor, pred_type: PredictionType, x_t: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert from prediction. Return predicted x_0 and x_T.

        Inverts x_t = A * x_0 + B * x_T given one predicted quantity:
          - x_T:    pred is x_T.
          - x_0:    pred is x_0.
          - v_cos:  pred is v = A * x_T - B * x_0; the inversion below is exact
                    only when A(t)**2 + B(t)**2 == 1.
          - v_lerp: pred is v = x_T - x_0.

        Raises:
            NotImplementedError: for an unsupported ``pred_type``.
        """
        t = expand_dims(t, x_t.ndim)
        A_t = self.A(t)
        B_t = self.B(t)
        if pred_type == PredictionType.x_T:
            pred_x_T = pred
            pred_x_0 = (x_t - B_t * pred_x_T) / A_t
        elif pred_type == PredictionType.x_0:
            pred_x_0 = pred
            pred_x_T = (x_t - A_t * pred_x_0) / B_t
        elif pred_type == PredictionType.v_cos:
            # A*x_t - B*v = (A^2 + B^2) * x_0 and A*v + B*x_t = (A^2 + B^2) * x_T.
            pred_x_0 = A_t * x_t - B_t * pred
            pred_x_T = A_t * pred + B_t * x_t
        elif pred_type == PredictionType.v_lerp:
            # Substituting x_T = x_0 + v gives x_t = (A + B) * x_0 + B * v.
            pred_x_0 = (x_t - B_t * pred) / (A_t + B_t)
            pred_x_T = (x_t + A_t * pred) / (A_t + B_t)
        else:
            raise NotImplementedError(f"Unsupported prediction type: {pred_type}")
        return pred_x_0, pred_x_T

    def convert_to_pred(
        self, x_0: torch.Tensor, x_T: torch.Tensor, t: torch.Tensor, pred_type: PredictionType
    ) -> torch.Tensor:
        """
        Convert to prediction target given x_0 and x_T.

        Inverse of ``convert_from_pred`` for the same ``pred_type``. Only the
        v_cos target depends on t.

        Raises:
            NotImplementedError: for an unsupported ``pred_type``.
        """
        if pred_type == PredictionType.x_T:
            return x_T
        if pred_type == PredictionType.x_0:
            return x_0
        if pred_type == PredictionType.v_cos:
            t = expand_dims(t, x_0.ndim)
            return self.A(t) * x_T - self.B(t) * x_0
        if pred_type == PredictionType.v_lerp:
            return x_T - x_0
        raise NotImplementedError(f"Unsupported prediction type: {pred_type}")