IceClear
upload files
42f2c22
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
"""
Schedule base class.
"""
from abc import ABC, abstractmethod, abstractproperty
from typing import Tuple, Union
import torch
from ..types import PredictionType
from ..utils import expand_dims
class Schedule(ABC):
    """
    Base class for diffusion schedules.

    Diffusion schedules are uniquely defined by T, A, B:

        x_t = A(t) * x_0 + B(t) * x_T,  where t in [0, T]

    Schedules can be continuous (``T`` is a float) or discrete (``T`` is an int).
    Subclasses must implement ``T``, ``A`` and ``B``; ``isnr`` is optional.
    """

    # NOTE: ``@property`` + ``@abstractmethod`` replaces the deprecated
    # ``abc.abstractproperty``; subclasses may override ``T`` with a property
    # exactly as before.
    @property
    @abstractmethod
    def T(self) -> Union[int, float]:
        """
        Maximum timestep, inclusive.

        The schedule is continuous if this is a float and discrete if it is an
        int (see ``is_continuous``).
        """

    @abstractmethod
    def A(self, t: torch.Tensor) -> torch.Tensor:
        """
        Interpolation coefficient A, which multiplies x_0.

        Returns a tensor with the same shape as ``t``.
        """

    @abstractmethod
    def B(self, t: torch.Tensor) -> torch.Tensor:
        """
        Interpolation coefficient B, which multiplies x_T.

        Returns a tensor with the same shape as ``t``.
        """

    # ----------------------------------------------------

    def snr(self, t: torch.Tensor) -> torch.Tensor:
        """
        Signal-to-noise ratio, ``A(t) ** 2 / B(t) ** 2``.

        Returns a tensor with the same shape as ``t``. Where ``B(t) == 0`` the
        result follows torch division semantics (inf or nan).
        """
        return (self.A(t) ** 2) / (self.B(t) ** 2)

    def isnr(self, snr: torch.Tensor) -> torch.Tensor:
        """
        Inverse of ``snr`` (presumably returns the t whose SNR equals the input).

        Returns a tensor with the same shape as ``snr``.
        Subclasses may implement; the base class raises NotImplementedError.
        """
        raise NotImplementedError

    # ----------------------------------------------------

    def is_continuous(self) -> bool:
        """
        Whether the schedule is continuous, i.e. ``T`` is a float.
        """
        return isinstance(self.T, float)

    def forward(self, x_0: torch.Tensor, x_T: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Diffusion forward function: ``x_t = A(t) * x_0 + B(t) * x_T``.

        Args:
            x_0: Sample at t = 0.
            x_T: Sample at t = T.
            t: Timesteps. Expanded with ``expand_dims`` to ``x_0.ndim``
                dimensions so the coefficients broadcast against the samples
                (assumes expand_dims appends singleton dims — confirm).

        Returns:
            The interpolated sample x_t.
        """
        t = expand_dims(t, x_0.ndim)
        return self.A(t) * x_0 + self.B(t) * x_T

    def convert_from_pred(
        self, pred: torch.Tensor, pred_type: PredictionType, x_t: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert a model prediction into predicted (x_0, x_T).

        Args:
            pred: Model output, interpreted according to ``pred_type``.
            pred_type: Which quantity ``pred`` represents.
            x_t: The sample at timestep ``t``.
            t: Timesteps, expanded to ``x_t.ndim`` dimensions before use.

        Returns:
            Tuple ``(pred_x_0, pred_x_T)``.

        Raises:
            NotImplementedError: If ``pred_type`` is not supported.
        """
        t = expand_dims(t, x_t.ndim)
        A_t = self.A(t)
        B_t = self.B(t)
        if pred_type == PredictionType.x_T:
            pred_x_T = pred
            # Solve x_t = A x_0 + B x_T for x_0; undefined where A_t == 0.
            pred_x_0 = (x_t - B_t * pred_x_T) / A_t
        elif pred_type == PredictionType.x_0:
            pred_x_0 = pred
            # Solve x_t = A x_0 + B x_T for x_T; undefined where B_t == 0.
            pred_x_T = (x_t - A_t * pred_x_0) / B_t
        elif pred_type == PredictionType.v_cos:
            # Inverts v = A x_T - B x_0 (see convert_to_pred). This closed form
            # is exact only when A_t ** 2 + B_t ** 2 == 1.
            pred_x_0 = A_t * x_t - B_t * pred
            pred_x_T = A_t * pred + B_t * x_t
        elif pred_type == PredictionType.v_lerp:
            # Inverts v = x_T - x_0 for any A, B; undefined where A_t + B_t == 0.
            pred_x_0 = (x_t - B_t * pred) / (A_t + B_t)
            pred_x_T = (x_t + A_t * pred) / (A_t + B_t)
        else:
            raise NotImplementedError(f"Unsupported prediction type: {pred_type}")
        return pred_x_0, pred_x_T

    def convert_to_pred(
        self, x_0: torch.Tensor, x_T: torch.Tensor, t: torch.Tensor, pred_type: PredictionType
    ) -> torch.Tensor:
        """
        Convert (x_0, x_T) into the prediction target for ``pred_type``.

        Args:
            x_0: Sample at t = 0.
            x_T: Sample at t = T.
            t: Timesteps; only used (after expansion to ``x_0.ndim``
                dimensions) for ``PredictionType.v_cos``.
            pred_type: Which prediction target to produce.

        Returns:
            ``x_T``, ``x_0``, ``A(t) * x_T - B(t) * x_0`` (v_cos) or
            ``x_T - x_0`` (v_lerp).

        Raises:
            NotImplementedError: If ``pred_type`` is not supported.
        """
        if pred_type == PredictionType.x_T:
            return x_T
        if pred_type == PredictionType.x_0:
            return x_0
        if pred_type == PredictionType.v_cos:
            t = expand_dims(t, x_0.ndim)
            return self.A(t) * x_T - self.B(t) * x_0
        if pred_type == PredictionType.v_lerp:
            return x_T - x_0
        raise NotImplementedError(f"Unsupported prediction type: {pred_type}")