import torch
import numpy as np


def eps_from_v(z_0, z_t, sigma_t):
    """Recover the noise epsilon from z_t = z_0 + sigma_t * eps."""
    return (z_t - z_0) / sigma_t


def v_to_eps(v, t, x_t):
    """Convert the velocity field to the epsilon parametrization.

    Assumes the linear interpolation x_t = t * x_0 + (1 - t) * x_1
    with x_0 ~ N(0, I), so that v = x_0 - x_1 and eps = x_0.
    """
    eps_t = (1 - t) * v + x_t
    return eps_t


def clip_gradients(gradients, clip_value):
    """Clip per-vector gradient norms along the last axis of a (B, N, D) tensor, in place."""
    grad_norm = gradients.norm(dim=2)
    mask = grad_norm > clip_value
    mask_exp = mask[:, :, None].expand_as(gradients)
    # Rescale only the vectors whose norm exceeds clip_value.
    gradients[mask_exp] = (
        gradients[mask_exp]
        / grad_norm[:, :, None].expand_as(gradients)[mask_exp]
        * clip_value
    )
    return gradients


class Adam:
    """Minimal functional Adam optimizer for a single parameter tensor."""

    def __init__(self, parameters, lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.t = 0
        self.m = torch.zeros_like(parameters)  # first-moment estimate
        self.v = torch.zeros_like(parameters)  # second-moment estimate

    def step(self, params, grad) -> torch.Tensor:
        self.t += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * grad
        self.v = self.beta2 * self.v + (1 - self.beta2) * grad**2
        # Bias-corrected moment estimates.
        m_hat = self.m / (1 - self.beta1**self.t)
        v_hat = self.v / (1 - self.beta2**self.t)
        # self.lr may be a schedule, i.e. a callable mapping step -> learning rate.
        if callable(self.lr):
            lr = self.lr(self.t - 1)
        else:
            lr = self.lr
        update = lr * m_hat / (torch.sqrt(v_hat) + self.epsilon)
        return params - update


def make_cosine_decay_schedule(
    init_value: float,
    total_steps: int,
    alpha: float = 0.0,
    exponent: float = 1.0,
    warmup_steps: int = 0,
):
    """Return a schedule with linear warmup followed by cosine decay to alpha * init_value."""

    def schedule(count):
        if count < warmup_steps:
            # Linear warmup from 0 to init_value.
            return (init_value / warmup_steps) * count
        # Half-cosine decay over the remaining steps.
        decay_steps = total_steps - warmup_steps
        count = min(count - warmup_steps, decay_steps)
        cosine_decay = 0.5 * (1 + np.cos(np.pi * count / decay_steps))
        decayed = (1 - alpha) * cosine_decay**exponent + alpha
        return init_value * decayed

    return schedule


def make_linear_decay_schedule(
    init_value: float, total_steps: int, final_value: float = 0.0, warmup_steps: int = 0
):
    """Return a schedule with linear warmup followed by linear decay to final_value."""

    def schedule(count):
        if count < warmup_steps:
            # Linear warmup from 0 to init_value.
            return (init_value / warmup_steps) * count
        # Linear decay over the remaining steps.
        decay_steps = total_steps - warmup_steps
        count = min(count - warmup_steps, decay_steps)
        return init_value - (init_value - final_value) * count / decay_steps

    return schedule


def clip_norm_(tensor, max_norm):
    """In-place global-norm clipping of a single tensor."""
    norm = tensor.norm()
    if norm > max_norm:
        tensor.mul_(max_norm / norm)


def lr_warmup(step, warmup_steps):
    """Linear warmup factor in [0, 1]."""
    return min(1.0, step / max(warmup_steps, 1))


def linear_decay_lambda(step, warmup_steps, decay_steps, total_steps):
    """LR lambda: linear warmup, then linear decay to 0 over decay_steps.

    Note: total_steps is unused here; decay_steps is passed explicitly.
    """
    if step < warmup_steps:
        return min(1.0, step / max(warmup_steps, 1))
    count = min(step - warmup_steps, decay_steps)
    return 1 - count / decay_steps
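

# Sanity check (illustrative sketch, not part of the original utilities): with
# x_t = t * x_0 + (1 - t) * x_1 and v = x_0 - x_1, the conversion
# eps_t = (1 - t) * v + x_t in v_to_eps should recover the noise x_0 exactly,
# since (1 - t) * (x_0 - x_1) + t * x_0 + (1 - t) * x_1 = x_0.
def _check_v_to_eps(t: float = 0.3, dim: int = 4) -> bool:
    x_0 = torch.randn(dim)  # noise sample, x_0 ~ N(0, I)
    x_1 = torch.randn(dim)  # stand-in for a data sample
    x_t = t * x_0 + (1 - t) * x_1
    v = x_0 - x_1  # velocity of the linear interpolation path
    return bool(torch.allclose(v_to_eps(v, t, x_t), x_0, atol=1e-5))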
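

# Illustrative comparison (hypothetical helper, not part of the original
# utilities): print a few values from the cosine and linear schedules to show
# the shared warmup ramp and the differing decay shapes.
def _print_schedule_shapes(total_steps: int = 10, warmup_steps: int = 2) -> None:
    cosine = make_cosine_decay_schedule(1.0, total_steps, warmup_steps=warmup_steps)
    linear = make_linear_decay_schedule(1.0, total_steps, warmup_steps=warmup_steps)
    for step in range(total_steps + 1):
        print(f"step {step:2d}  cosine {cosine(step):.3f}  linear {linear(step):.3f}")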
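

# Usage sketch (hypothetical demo, not part of the original utilities):
# minimize the quadratic ||params - target||^2 with the functional Adam above,
# passing a cosine-decay schedule as the callable learning rate and clipping
# the gradient with clip_norm_; `target` and `params` are demo names.
if __name__ == "__main__":
    torch.manual_seed(0)
    assert _check_v_to_eps(), "v_to_eps should recover the noise x_0"
    _print_schedule_shapes()

    target = torch.randn(8)
    params = torch.zeros(8)
    schedule = make_cosine_decay_schedule(init_value=0.1, total_steps=500, warmup_steps=50)
    opt = Adam(params, lr=schedule)
    for _ in range(500):
        grad = 2 * (params - target)  # analytic gradient of the quadratic loss
        clip_norm_(grad, max_norm=1.0)  # in-place global-norm clipping
        params = opt.step(params, grad)
    print("final loss:", float(((params - target) ** 2).sum()))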