import torch from abc import abstractmethod, ABC try: from torchdyn.core import NeuralODE NEURALODE_INSTALLED = True except ImportError: NEURALODE_INSTALLED = False class SchedulerBase(ABC): def __init__(self) -> None: pass @abstractmethod def set_timesteps(self): pass @abstractmethod def step(self): pass @abstractmethod def add_noise(self): pass class StreamingFlowMatchingScheduler(SchedulerBase): def __init__(self, timesteps=1000, sigma_min=1e-4, ) -> None: super().__init__() self.sigma_min = sigma_min self.timesteps = timesteps self.t_min = 0 self.t_max = 1 - self.sigma_min self.neural_ode = None def set_timesteps(self, timesteps=15): self.timesteps = timesteps def step(self, xt, predicted_v): h = (self.t_max - self.t_min) / self.timesteps h = h * torch.ones(xt.shape[0], dtype=xt.dtype, device=xt.device) xt = xt + h * predicted_v return xt def sample(self, ode_wrapper, time_steps, xt, verbose=False, x0=None): h = (self.t_max - self.t_min) / self.timesteps h = h * torch.ones(xt.shape[0], dtype=xt.dtype, device=xt.device) if verbose: gt_v = x0 - xt for t in time_steps: predicted_v = ode_wrapper(t, xt) if verbose: dist = torch.mean(torch.nn.functional.l1_loss(gt_v, predicted_v)) print("Time: {}, Distance: {}".format(t, dist)) xt = xt + h * predicted_v return xt def sample_by_neuralode(self, ode_wrapper, time_steps, xt, verbose=False, x0=None): if not NEURALODE_INSTALLED: raise ImportError("NeuralODE is not installed, please install it first.") if self.neural_ode is None: self.neural_ode = NeuralODE(ode_wrapper, solver='euler', sensitivity="adjoint", atol=self.sigma_min, rtol=self.sigma_min) eval_points, traj = self.neural_ode(xt, time_steps) return traj[-1] def add_noise(self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor,): ut = original_samples - (1 - self.sigma_min) * noise # 和ut的梯度没关系 t_unsqueeze = timesteps.unsqueeze(1).unsqueeze(1).float() / self.timesteps x_noisy = t_unsqueeze * original_samples + (1. - (1 - self.sigma_min) * t_unsqueeze) * noise return x_noisy, ut