Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| import torch as th | |
| import numpy as np | |
| from functools import partial | |
def expand_t_like_x(t, x):
    """Reshape a per-sample time vector so it broadcasts against ``x``.

    Args:
        t: [batch_dim,], time vector
        x: [batch_dim, ...], data point

    Returns:
        ``t`` viewed with shape ``[batch_dim, 1, ..., 1]``: the leading axis
        keeps ``t.size(0)`` entries and one singleton axis is appended for
        every non-batch axis of ``x``.
    """
    # One singleton axis per axis of x beyond the leading (batch) axis.
    singleton_axes = (1,) * (len(x.size()) - 1)
    return t.view(t.size(0), *singleton_axes)
| #################### Coupling Plans #################### | |
class ICPlan:
    """Linear Coupling Plan.

    The path is ``x_t = alpha_t * x1 + sigma_t * x0`` with ``alpha_t = t`` and
    ``sigma_t = 1 - t``. Subclasses override ``compute_alpha_t``,
    ``compute_sigma_t`` and ``compute_d_alpha_alpha_ratio_t`` to change the
    path. Methods that take both a data tensor and a time vector broadcast
    the time vector with the module-level ``expand_t_like_x`` helper.
    """

    def __init__(self, sigma=0.0):
        # Kept on the instance; no method of this class reads it.
        self.sigma = sigma

    def compute_alpha_t(self, t):
        """Compute the data coefficient along the path.

        Returns:
            Tuple ``(alpha_t, d_alpha_t)``; for the linear plan ``(t, 1)``.
        """
        return t, 1

    def compute_sigma_t(self, t):
        """Compute the noise coefficient along the path.

        Returns:
            Tuple ``(sigma_t, d_sigma_t)``; for the linear plan ``(1 - t, -1)``.
        """
        return 1 - t, -1

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Compute the ratio between d_alpha and alpha (``1 / t`` here)."""
        return 1 / t

    def compute_drift(self, x, t):
        """We always output sde according to score parametrization.

        Args:
            x: [batch_dim, ...], data point
            t: [batch_dim,], time vector

        Returns:
            Tuple ``(-alpha_ratio * x, diffusion)`` where ``alpha_ratio`` is
            ``compute_d_alpha_alpha_ratio_t(t)`` and
            ``diffusion = alpha_ratio * sigma_t**2 - sigma_t * d_sigma_t``.
        """
        t = expand_t_like_x(t, x)
        alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        drift = alpha_ratio * x
        diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t
        return -drift, diffusion

    def compute_diffusion(self, x, t, form="constant", norm=1.0):
        """Compute the diffusion term of the SDE.

        Only the requested schedule is evaluated, and an unknown ``form`` is
        rejected before any tensor work is done.

        Args:
            x: [batch_dim, ...], data point
            t: [batch_dim,], time vector
            form: str, form of the diffusion term. One of ``"constant"``,
                ``"SBDM"``, ``"sigma"``, ``"linear"``, ``"decreasing"``,
                ``"increasing-decreasing"`` (the legacy misspelling
                ``"inccreasing-decreasing"`` is still accepted).
            norm: float, norm of the diffusion term

        Returns:
            ``norm`` itself for ``"constant"``; otherwise the selected
            schedule evaluated on ``t`` broadcast against ``x``.

        Raises:
            NotImplementedError: if ``form`` is not a known schedule.
        """
        if form == "constant":
            # Independent of x and t, so nothing needs to be broadcast.
            return norm

        def sin_squared(tt):
            return norm * th.sin(np.pi * tt) ** 2

        schedules = {
            # Second element of compute_drift is the diffusion coefficient.
            "SBDM": lambda tt: norm * self.compute_drift(x, tt)[1],
            "sigma": lambda tt: norm * self.compute_sigma_t(tt)[0],
            "linear": lambda tt: norm * (1 - tt),
            "decreasing": lambda tt: 0.25 * (norm * th.cos(np.pi * tt) + 1) ** 2,
            "increasing-decreasing": sin_squared,
            # Legacy key with the original typo; kept so existing callers work.
            "inccreasing-decreasing": sin_squared,
        }
        schedule = schedules.get(form)
        if schedule is None:
            raise NotImplementedError(f"Diffusion form {form} not implemented")
        return schedule(expand_t_like_x(t, x))

    def get_score_from_velocity(self, velocity, x, t):
        """Wrapper function: transform velocity prediction model to score.

        Args:
            velocity: [batch_dim, ...] shaped tensor; velocity model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor

        Returns:
            ``(r * velocity - x) / (sigma_t**2 - r * d_sigma_t * sigma_t)``
            with ``r = alpha_t / d_alpha_t``.
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        var = sigma_t ** 2 - reverse_alpha_ratio * d_sigma_t * sigma_t
        score = (reverse_alpha_ratio * velocity - mean) / var
        return score

    def get_noise_from_velocity(self, velocity, x, t):
        """Wrapper function: transform velocity prediction model to denoiser.

        Args:
            velocity: [batch_dim, ...] shaped tensor; velocity model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor

        Returns:
            ``(r * velocity - x) / (r * d_sigma_t - sigma_t)`` with
            ``r = alpha_t / d_alpha_t``.
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        var = reverse_alpha_ratio * d_sigma_t - sigma_t
        noise = (reverse_alpha_ratio * velocity - mean) / var
        return noise

    def get_velocity_from_score(self, score, x, t):
        """Wrapper function: transform score prediction model to velocity.

        Args:
            score: [batch_dim, ...] shaped tensor; score model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor

        Returns:
            ``var * score - drift`` where ``(drift, var)`` is the pair
            returned by ``compute_drift(x, t)``.
        """
        t = expand_t_like_x(t, x)
        drift, var = self.compute_drift(x, t)
        velocity = var * score - drift
        return velocity

    def compute_mu_t(self, t, x0, x1):
        """Compute the mean of time-dependent density p_t.

        Returns:
            ``alpha_t * x1 + sigma_t * x0`` with ``t`` broadcast against ``x1``.
        """
        t = expand_t_like_x(t, x1)
        alpha_t, _ = self.compute_alpha_t(t)
        sigma_t, _ = self.compute_sigma_t(t)
        return alpha_t * x1 + sigma_t * x0

    def compute_xt(self, t, x0, x1):
        """Return x_t for the pair (x0, x1).

        This is deterministic: it returns ``compute_mu_t(t, x0, x1)`` and
        draws no additional randomness.
        """
        xt = self.compute_mu_t(t, x0, x1)
        return xt

    def compute_ut(self, t, x0, x1, xt):
        """Compute the vector field corresponding to p_t.

        Returns:
            ``d_alpha_t * x1 + d_sigma_t * x0``. ``xt`` is accepted but not
            used by this implementation.
        """
        t = expand_t_like_x(t, x1)
        _, d_alpha_t = self.compute_alpha_t(t)
        _, d_sigma_t = self.compute_sigma_t(t)
        return d_alpha_t * x1 + d_sigma_t * x0

    def plan(self, t, x0, x1):
        """Return ``(t, xt, ut)`` for the given time vector and endpoint pair."""
        xt = self.compute_xt(t, x0, x1)
        ut = self.compute_ut(t, x0, x1, xt)
        return t, xt, ut
class VPCPlan(ICPlan):
    """class for VP path flow matching

    ``alpha_t = exp(log_mean_coeff(t))`` and
    ``sigma_t = sqrt(1 - exp(2 * log_mean_coeff(t)))``, parameterized by
    ``sigma_min`` and ``sigma_max``.
    """

    def __init__(self, sigma_min=0.1, sigma_max=20.0):
        # Run the base initializer so the instance has the same attributes
        # (``sigma``) as the other plans.
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    # ``log_mean_coeff`` / ``d_log_mean_coeff`` are regular methods rather
    # than lambdas stored on the instance: the call syntax is unchanged, but
    # instances no longer carry unpicklable function objects in ``__dict__``.
    def log_mean_coeff(self, t):
        """Return log(alpha_t)."""
        return (
            -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min)
            - 0.5 * (1 - t) * self.sigma_min
        )

    def d_log_mean_coeff(self, t):
        """Return the derivative of ``log_mean_coeff`` with respect to ``t``."""
        return 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min

    def compute_alpha_t(self, t):
        """Compute coefficient of x1

        Returns:
            Tuple ``(alpha_t, d_alpha_t)`` with
            ``d_alpha_t = alpha_t * d_log_mean_coeff(t)``.
        """
        alpha_t = self.log_mean_coeff(t)
        alpha_t = th.exp(alpha_t)
        d_alpha_t = alpha_t * self.d_log_mean_coeff(t)
        return alpha_t, d_alpha_t

    def compute_sigma_t(self, t):
        """Compute coefficient of x0

        Returns:
            Tuple ``(sigma_t, d_sigma_t)``.
        """
        p_sigma_t = 2 * self.log_mean_coeff(t)
        sigma_t = th.sqrt(1 - th.exp(p_sigma_t))
        # NOTE: divides by sigma_t, which is 0 where log_mean_coeff(t) == 0
        # (t == 1), so the derivative is not finite there.
        d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
        return sigma_t, d_sigma_t

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
        # d_alpha_t / alpha_t is d_log_mean_coeff(t) by construction, so the
        # exponential never has to be formed.
        return self.d_log_mean_coeff(t)

    def compute_drift(self, x, t):
        """Compute the drift term of the SDE

        Args:
            x: [batch_dim, ...], data point
            t: [batch_dim,], time vector

        Returns:
            Tuple ``(-0.5 * beta_t * x, beta_t / 2)`` with
            ``beta_t = sigma_min + (1 - t) * (sigma_max - sigma_min)``.
        """
        t = expand_t_like_x(t, x)
        beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
        return -0.5 * beta_t * x, beta_t / 2
class GVPCPlan(ICPlan):
    """Coupling plan with ``alpha_t = sin(pi * t / 2)`` and ``sigma_t = cos(pi * t / 2)``."""

    def __init__(self, sigma=0.0):
        # Nothing plan-specific to set up; forward to the base initializer.
        super().__init__(sigma)

    def compute_alpha_t(self, t):
        """Compute coefficient of x1

        Returns:
            Tuple ``(alpha_t, d_alpha_t)``.
        """
        angle = t * np.pi / 2
        return th.sin(angle), np.pi / 2 * th.cos(angle)

    def compute_sigma_t(self, t):
        """Compute coefficient of x0

        Returns:
            Tuple ``(sigma_t, d_sigma_t)``.
        """
        angle = t * np.pi / 2
        return th.cos(angle), -np.pi / 2 * th.sin(angle)

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
        # (pi/2) * cos / sin written with a single tan.
        tangent = th.tan(t * np.pi / 2)
        return np.pi / (2 * tangent)
