Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,126 Bytes
42f2c22 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
"""
Schedule base class.
"""
from abc import ABC, abstractmethod, abstractproperty
from typing import Tuple, Union
import torch
from ..types import PredictionType
from ..utils import expand_dims
class Schedule(ABC):
    """
    Abstract base class for diffusion schedules.

    A schedule is uniquely defined by T, A, B:
        x_t = A(t) * x_0 + B(t) * x_T, where t in [0, T]
    Schedules can be continuous or discrete.

    Subclasses must implement ``T``, ``A`` and ``B``; they may implement ``isnr``.
    """

    # ``abc.abstractproperty`` is deprecated since Python 3.3; stacking
    # ``@property`` on ``@abstractmethod`` is the documented replacement.
    @property
    @abstractmethod
    def T(self) -> Union[int, float]:
        """
        Maximum timestep inclusive.
        Schedule is continuous if float, discrete if int.
        """

    @abstractmethod
    def A(self, t: torch.Tensor) -> torch.Tensor:
        """
        Interpolation coefficient A (the multiplier of x_0).
        Returns tensor with the same shape as t.
        """

    @abstractmethod
    def B(self, t: torch.Tensor) -> torch.Tensor:
        """
        Interpolation coefficient B (the multiplier of x_T).
        Returns tensor with the same shape as t.
        """

    # ----------------------------------------------------

    def snr(self, t: torch.Tensor) -> torch.Tensor:
        """
        Signal to noise ratio, computed as A(t)^2 / B(t)^2.
        Returns tensor with the same shape as t.

        There is no guard against B(t) == 0; the division is performed as-is.
        """
        return (self.A(t) ** 2) / (self.B(t) ** 2)

    def isnr(self, snr: torch.Tensor) -> torch.Tensor:
        """
        Inverse signal to noise ratio.
        Returns tensor with the same shape as snr.
        Subclass may implement.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    # ----------------------------------------------------

    def is_continuous(self) -> bool:
        """
        Whether the schedule is continuous, i.e. whether ``T`` is a float.
        """
        return isinstance(self.T, float)

    def forward(self, x_0: torch.Tensor, x_T: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Diffusion forward function.

        Returns x_t = A(t) * x_0 + B(t) * x_T.
        """
        # NOTE(review): expand_dims presumably pads t with singleton dims up to
        # x_0.ndim so the coefficients broadcast against x_0 -- confirm in utils.
        t = expand_dims(t, x_0.ndim)
        return self.A(t) * x_0 + self.B(t) * x_T

    def convert_from_pred(
        self, pred: torch.Tensor, pred_type: PredictionType, x_t: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert from prediction. Return predicted x_0 and x_T.

        Each branch solves x_t = A(t) * x_0 + B(t) * x_T together with the
        definition of the prediction target (see ``convert_to_pred``).

        Args:
            pred: The model prediction, interpreted according to ``pred_type``.
            pred_type: Which quantity ``pred`` represents.
            x_t: The sample at timestep ``t``.
            t: The timestep(s); expanded to ``x_t.ndim`` before use.

        Returns:
            Tuple ``(pred_x_0, pred_x_T)``.

        Raises:
            NotImplementedError: if ``pred_type`` is not one of the handled types.
        """
        t = expand_dims(t, x_t.ndim)
        A_t = self.A(t)
        B_t = self.B(t)
        if pred_type == PredictionType.x_T:
            pred_x_T = pred
            # Divides by A(t); no guard for A(t) == 0.
            pred_x_0 = (x_t - B_t * pred_x_T) / A_t
        elif pred_type == PredictionType.x_0:
            pred_x_0 = pred
            # Divides by B(t); no guard for B(t) == 0.
            pred_x_T = (x_t - A_t * pred_x_0) / B_t
        elif pred_type == PredictionType.v_cos:
            # With v = A * x_T - B * x_0, these expressions recover x_0 and x_T
            # exactly only when A(t)^2 + B(t)^2 == 1.
            pred_x_0 = A_t * x_t - B_t * pred
            pred_x_T = A_t * pred + B_t * x_t
        elif pred_type == PredictionType.v_lerp:
            # With v = x_T - x_0, x_t - B * v == (A + B) * x_0 and
            # x_t + A * v == (A + B) * x_T, hence the division by A + B.
            pred_x_0 = (x_t - B_t * pred) / (A_t + B_t)
            pred_x_T = (x_t + A_t * pred) / (A_t + B_t)
        else:
            raise NotImplementedError
        return pred_x_0, pred_x_T

    def convert_to_pred(
        self, x_0: torch.Tensor, x_T: torch.Tensor, t: torch.Tensor, pred_type: PredictionType
    ) -> torch.Tensor:
        """
        Convert to prediction target given x_0 and x_T.

        Args:
            x_0: The value at t = 0.
            x_T: The value at t = T.
            t: The timestep(s); only used for ``v_cos``.
            pred_type: Which prediction target to build.

        Returns:
            ``x_T`` or ``x_0`` unchanged for the matching types,
            ``A(t) * x_T - B(t) * x_0`` for ``v_cos``,
            ``x_T - x_0`` for ``v_lerp``.

        Raises:
            NotImplementedError: if ``pred_type`` is not one of the handled types.
        """
        if pred_type == PredictionType.x_T:
            return x_T
        if pred_type == PredictionType.x_0:
            return x_0
        if pred_type == PredictionType.v_cos:
            t = expand_dims(t, x_0.ndim)
            return self.A(t) * x_T - self.B(t) * x_0
        if pred_type == PredictionType.v_lerp:
            return x_T - x_0
        raise NotImplementedError
|