# maisi_ct_generative/scripts/rectified_flow.py
from __future__ import annotations

from typing import Any

import numpy as np
import torch
from monai.networks.schedulers import Scheduler
from torch.distributions import LogisticNormal
# code modified from https://github.com/hpcaitech/Open-Sora/blob/main/opensora/schedulers/rf/rectified_flow.py
def timestep_transform(
t, input_img_size, base_img_size=32 * 32 * 32, scale=1.0, num_train_timesteps=1000, spatial_dim=3
):
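    """
    Shift a (possibly batched) timestep according to the spatial size of the input,
    following the resolution-dependent timestep shift in Open-Sora's rectified-flow
    scheduler: volumes larger than the reference size are pushed toward noisier
    timesteps.

    Worked example (illustrative): for a 64*64*64 volume against the default
    32*32*32 base, ratio_space = 8 ** (1/3) = 2, so a normalized t = 0.5 maps to
    2 * 0.5 / (1 + (2 - 1) * 0.5) ~= 0.667 before rescaling by num_train_timesteps.
    """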
    # work in normalized time t in [0, 1]
    t = t / num_train_timesteps
    # per-axis scale factor: the spatial_dim-th root of the voxel-count ratio
    # (`**` is used instead of Tensor.pow so plain Python numbers also work)
    ratio_space = (input_img_size / base_img_size) ** (1.0 / spatial_dim)
    ratio = ratio_space * scale
    # monotone map that pushes t toward 1 (more noise) when ratio > 1
    new_t = ratio * t / (1 + (ratio - 1) * t)
    new_t = new_t * num_train_timesteps
    return new_t
class RFlowScheduler(Scheduler):
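    """
    Rectified-flow scheduler adapted from Open-Sora. The forward (noising) process
    is a straight-line interpolation between data and Gaussian noise, and sampling
    integrates the learned velocity field with explicit Euler steps.
    """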
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        num_inference_steps: int = 10,
        use_discrete_timesteps: bool = False,
        sample_method: str = "uniform",
        loc: float = 0.0,
        scale: float = 1.0,
        use_timestep_transform: bool = False,
        transform_scale: float = 1.0,
        steps_offset: int = 0,
    ):
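        """
        Args:
            num_train_timesteps: number of timesteps used to train the model.
            num_inference_steps: default number of sampling steps.
            use_discrete_timesteps: if True, round timesteps to integers.
            sample_method: training-time timestep sampling, "uniform" or "logit-normal".
            loc: location parameter of the logit-normal distribution.
            scale: scale parameter of the logit-normal distribution.
            use_timestep_transform: if True, apply the resolution-dependent timestep shift.
            transform_scale: extra scale factor applied inside the timestep transform.
            steps_offset: constant offset added to the inference timesteps.
        """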
self.num_train_timesteps = num_train_timesteps
self.num_inference_steps = num_inference_steps
self.use_discrete_timesteps = use_discrete_timesteps
        # timestep sampling method; upstream Open-Sora restricted discrete timesteps
        # to uniform sampling, but that restriction is relaxed here
        if sample_method not in ("uniform", "logit-normal"):
            raise ValueError("sample_method must be 'uniform' or 'logit-normal'")
        self.sample_method = sample_method
if sample_method == "logit-normal":
self.distribution = LogisticNormal(torch.tensor([loc]), torch.tensor([scale]))
self.sample_t = lambda x: self.distribution.sample((x.shape[0],))[:, 0].to(x.device)
# timestep transform
self.use_timestep_transform = use_timestep_transform
self.transform_scale = transform_scale
self.steps_offset = steps_offset
def add_noise(
self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
) -> torch.FloatTensor:
"""
compatible with diffusers add_noise()
"""
timepoints = timesteps.float() / self.num_train_timesteps
timepoints = 1 - timepoints # [1,1/1000]
# timepoint (bsz) noise: (bsz, 4, frame, w ,h)
# expand timepoint to noise shape
timepoints = timepoints.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1)
timepoints = timepoints.repeat(1, noise.shape[1], noise.shape[2], noise.shape[3], noise.shape[4])
return timepoints * original_samples + (1 - timepoints) * noise
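    # Illustrative sanity check (hypothetical shapes): for x0 and eps of shape
    # (2, 4, 64, 64, 64) with timesteps = torch.tensor([0, 999]), add_noise returns
    # x0 exactly for the first batch element (t = 0) and an almost-pure-noise
    # mixture 0.001 * x0 + 0.999 * eps for the second (t = 999).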
def set_timesteps(
self,
num_inference_steps: int,
device: str | torch.device | None = None,
input_img_size: int | None = None,
base_img_size: int = 32 * 32 * 32,
) -> None:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
device: target device to put the data.
input_img_size: int, H*W*D of the image, used with self.use_timestep_transform is True.
base_img_size: int, reference H*W*D size, used with self.use_timestep_transform is True.
"""
if num_inference_steps > self.num_train_timesteps:
raise ValueError(
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:"
f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle"
f" maximal {self.num_train_timesteps} timesteps."
)
self.num_inference_steps = num_inference_steps
        # uniformly spaced timesteps running from num_train_timesteps down toward 0
        timesteps = [
            (1.0 - i / self.num_inference_steps) * self.num_train_timesteps for i in range(self.num_inference_steps)
        ]
if self.use_discrete_timesteps:
timesteps = [int(round(t)) for t in timesteps]
if self.use_timestep_transform:
timesteps = [
timestep_transform(
t,
input_img_size=input_img_size,
base_img_size=base_img_size,
num_train_timesteps=self.num_train_timesteps,
)
for t in timesteps
]
timesteps = np.array(timesteps).astype(np.float16)
if self.use_discrete_timesteps:
timesteps = timesteps.astype(np.int64)
self.timesteps = torch.from_numpy(timesteps).to(device)
self.timesteps += self.steps_offset
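        # Example (illustrative): with num_train_timesteps=1000, num_inference_steps=5,
        # and no transform, self.timesteps becomes [1000., 800., 600., 400., 200.],
        # i.e. sampling starts at pure noise (t = T) and integrates toward data (t = 0).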
def sample_timesteps(self, x_start):
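        """
        Draw per-sample training timesteps: uniform over [0, num_train_timesteps) or
        logit-normal rescaled to the same range, optionally discretized and/or shifted
        by the resolution-dependent timestep transform.
        """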
if self.sample_method == "uniform":
t = torch.rand((x_start.shape[0],), device=x_start.device) * self.num_train_timesteps
elif self.sample_method == "logit-normal":
t = self.sample_t(x_start) * self.num_train_timesteps
if self.use_discrete_timesteps:
t = t.long()
if self.use_timestep_transform:
input_img_size = torch.prod(torch.tensor(x_start.shape[-3:]))
base_img_size = 32 * 32 * 32
t = timestep_transform(
t,
input_img_size=input_img_size,
base_img_size=base_img_size,
num_train_timesteps=self.num_train_timesteps,
)
return t
def step(
self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep=None
) -> tuple[torch.Tensor, Any]:
"""
Predict the sample at the previous timestep. Core function to propagate the diffusion
process from the learned model outputs.
Args:
model_output: direct output from learned diffusion model.
timestep: current discrete timestep in the diffusion chain.
sample: current instance of sample being created by diffusion process.
Returns:
pred_prev_sample: Predicted previous sample
None
"""
        v_pred = model_output
        if next_timestep is None:
            # uniform step size over the whole trajectory
            dt = 1.0 / self.num_inference_steps
        else:
            # normalized gap between the current and the next timestep
            dt = timestep - next_timestep
            dt = dt / self.num_train_timesteps
        # explicit Euler update along the predicted velocity
        z = sample + v_pred * dt
        return z, None
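

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only): the zero "velocity" below is a
    # stand-in for a trained diffusion model's prediction, not part of the bundle.
    scheduler = RFlowScheduler(num_train_timesteps=1000, num_inference_steps=10)
    scheduler.set_timesteps(num_inference_steps=10)
    sample = torch.randn(1, 4, 32, 32, 32)  # start from pure noise at t = T
    for i, t in enumerate(scheduler.timesteps):
        v_pred = torch.zeros_like(sample)  # a real model would predict velocity here
        next_t = scheduler.timesteps[i + 1] if i + 1 < len(scheduler.timesteps) else 0
        sample, _ = scheduler.step(v_pred, t, sample, next_timestep=next_t)
    print(sample.shape)  # torch.Size([1, 4, 32, 32, 32])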