import numpy as np
import torch
@torch.jit.script
def erf(x):
    """Elementwise approximation of erf(x): sign(x) * sqrt(1 - exp(-4 x^2 / pi))."""
    return torch.sign(x) * torch.sqrt(1 - torch.exp(-4 / torch.pi * x ** 2))
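# Illustrative check (not part of the original file): the approximation above agrees
# with the exact error function to within roughly 1e-2, e.g.
#   >>> x = torch.linspace(-3., 3., 101)
#   >>> (erf(x) - torch.special.erf(x)).abs().max()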
def matmul(a, b):
    """Batched matmul of a [..., n, m] with b [..., m, p] via broadcasting.

    torch.matmul can produce NaNs under fp16, so a broadcasted multiply-and-sum is
    used instead: a[..., None] has shape [..., n, m, 1] and b[..., None, :, :] has
    shape [..., 1, m, p]; their product is summed over the shared m dimension.
    """
    return (a[..., None] * b[..., None, :, :]).sum(dim=-2)
    # return torch.matmul(a, b)  # causes NaNs in fp16
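# Example shapes (illustrative): with a of shape [B, 3, 4] and b of shape [B, 4, 3],
# a[..., None] is [B, 3, 4, 1] and b[..., None, :, :] is [B, 1, 4, 3], so
#   >>> a = torch.randn(2, 3, 4)
#   >>> b = torch.randn(2, 4, 3)
#   >>> matmul(a, b).shape   # torch.Size([2, 3, 3]), same as torch.matmul(a, b)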
def safe_trig_helper(x, fn, t=100 * torch.pi):
    """Helper function used by safe_cos/safe_sin: mods x before sin()/cos()."""
    return fn(torch.where(torch.abs(x) < t, x, x % t))


def safe_cos(x):
    return safe_trig_helper(x, torch.cos)


def safe_sin(x):
    return safe_trig_helper(x, torch.sin)
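# Illustrative: small arguments pass through unchanged; very large arguments, where
# plain float32 sin/cos loses precision, are first reduced modulo t = 100 * pi.
#   >>> safe_sin(torch.tensor(1.0))   # identical to torch.sin(1.0)
#   >>> safe_sin(torch.tensor(1e9))   # argument wrapped into [0, 100 * pi) first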
def safe_exp(x):
    """torch.exp() with the input clamped at 88, just below float32 overflow."""
    return torch.exp(x.clamp_max(88.))
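# Illustrative: exp overflows float32 for inputs above ~88.7, so the clamp keeps the
# result finite, e.g.
#   >>> torch.exp(torch.tensor(100.))   # inf
#   >>> safe_exp(torch.tensor(100.))    # ~1.65e38, finite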
def safe_exp_jvp(primals, tangents):
    """Override safe_exp()'s gradient so that it's large when inputs are large.

    Note: this follows JAX's custom_jvp signature (primals, tangents) and is kept
    for reference; it is not wired into PyTorch's autograd in this file.
    """
    x, = primals
    x_dot, = tangents
    exp_x = safe_exp(x)
    exp_x_dot = exp_x * x_dot
    return exp_x, exp_x_dot
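# A minimal sketch (an assumption, not from the original file) of how the same
# gradient override could be expressed natively in PyTorch with a custom autograd
# Function; used as _SafeExp.apply(x).
class _SafeExp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = safe_exp(x)
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        # The derivative is taken to be safe_exp(x) itself, so large inputs still
        # receive a large (but finite) gradient, mirroring safe_exp_jvp above.
        y, = ctx.saved_tensors
        return grad_output * y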
def log_lerp(t, v0, v1):
    """Interpolate log-linearly from `v0` (t=0) to `v1` (t=1)."""
    if v0 <= 0 or v1 <= 0:
        raise ValueError(f'Interpolants {v0} and {v1} must be positive.')
    lv0 = np.log(v0)
    lv1 = np.log(v1)
    return np.exp(np.clip(t, 0, 1) * (lv1 - lv0) + lv0)
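# Illustrative: the log-linear midpoint of 1 and 100 is their geometric mean 10,
# not the arithmetic mean 50.5:
#   >>> log_lerp(0.5, 1.0, 100.0)   # 10.0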
def learning_rate_decay(step,
                        lr_init,
                        lr_final,
                        max_steps,
                        lr_delay_steps=0,
                        lr_delay_mult=1):
    """Continuous learning rate decay function.

    The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
    is log-linearly interpolated elsewhere (equivalent to exponential decay).
    If lr_delay_steps > 0, the learning rate is scaled by a smooth function of
    lr_delay_mult, such that the initial learning rate is lr_init * lr_delay_mult
    at the beginning of optimization but is eased back to the normal learning rate
    once step > lr_delay_steps.

    Args:
      step: int, the current optimization step.
      lr_init: float, the initial learning rate.
      lr_final: float, the final learning rate.
      max_steps: int, the number of steps during optimization.
      lr_delay_steps: int, the number of steps to delay the full learning rate.
      lr_delay_mult: float, the multiplier on the rate when delaying it.

    Returns:
      lr: the learning rate for the current step.
    """
    if lr_delay_steps > 0:
        # A kind of reverse cosine decay.
        delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
            0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1))
    else:
        delay_rate = 1.
    return delay_rate * log_lerp(step / max_steps, lr_init, lr_final)
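# Illustrative schedule (the specific numbers are arbitrary, not from the original):
#   >>> learning_rate_decay(0, 5e-4, 5e-6, 1000)      # 5e-4
#   >>> learning_rate_decay(500, 5e-4, 5e-6, 1000)    # 5e-5, the geometric midpoint
#   >>> learning_rate_decay(1000, 5e-4, 5e-6, 1000)   # 5e-6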
def sorted_interp(x, xp, fp):
    """A TPU-friendly version of interp(), where xp and fp must be sorted."""
    # Identify the location in `xp` that corresponds to each `x`.
    # The final `True` index in `mask` is the start of the matching interval.
    mask = x[..., None, :] >= xp[..., :, None]

    def find_interval(x):
        # Grab the value where `mask` switches from True to False, and vice versa.
        # This approach takes advantage of the fact that `x` is sorted.
        x0 = torch.max(torch.where(mask, x[..., None], x[..., :1, None]), -2).values
        x1 = torch.min(torch.where(~mask, x[..., None], x[..., -1:, None]), -2).values
        return x0, x1

    fp0, fp1 = find_interval(fp)
    xp0, xp1 = find_interval(xp)

    offset = torch.clip(torch.nan_to_num((x - xp0) / (xp1 - xp0), 0), 0, 1)
    ret = fp0 + offset * (fp1 - fp0)
    return ret
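# Illustrative 1-D usage (in practice the inputs are batched over leading dims):
#   >>> xp = torch.tensor([0., 1., 2., 3.])
#   >>> fp = torch.tensor([0., 10., 20., 30.])
#   >>> sorted_interp(torch.tensor([0.5, 2.5]), xp, fp)   # tensor([ 5., 25.])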
def sorted_interp_quad(x, xp, fpdf, fcdf):
    """Quadratic interpolation of a CDF; xp and fcdf must be sorted.

    `fpdf` holds PDF samples at the knots `xp` and `fcdf` the corresponding CDF
    values. The PDF is linearly interpolated within each interval and integrated
    from the interval start, which makes the returned CDF piecewise quadratic.
    """
    # Identify the location in `xp` that corresponds to each `x`.
    # The final `True` index in `mask` is the start of the matching interval.
    mask = x[..., None, :] >= xp[..., :, None]

    def find_interval(x, return_idx=False):
        # Grab the value where `mask` switches from True to False, and vice versa.
        # This approach takes advantage of the fact that `x` is sorted.
        x0, x0_idx = torch.max(torch.where(mask, x[..., None], x[..., :1, None]), -2)
        x1, x1_idx = torch.min(torch.where(~mask, x[..., None], x[..., -1:, None]), -2)
        if return_idx:
            return x0, x1, x0_idx, x1_idx
        return x0, x1

    fcdf0, fcdf1, fcdf0_idx, fcdf1_idx = find_interval(fcdf, return_idx=True)
    fpdf0 = fpdf.take_along_dim(fcdf0_idx, dim=-1)
    fpdf1 = fpdf.take_along_dim(fcdf1_idx, dim=-1)
    xp0, xp1 = find_interval(xp)

    offset = torch.clip(torch.nan_to_num((x - xp0) / (xp1 - xp0), 0), 0, 1)
    # Trapezoid rule: mean of the PDF at xp0 and at x, times the width (x - xp0).
    ret = fcdf0 + (x - xp0) * (fpdf0 + fpdf1 * offset + fpdf0 * (1 - offset)) / 2
    return ret
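# Illustrative 1-D usage (values chosen for the example): the PDF rises linearly
# from 0 to 1 on [0, 1] and stays at 1 on [1, 2]; fcdf is its exact integral at the
# knots, so querying x = 0.5 recovers the exact integral 0.125:
#   >>> xp = torch.tensor([0., 1., 2.])
#   >>> fpdf = torch.tensor([0., 1., 1.])
#   >>> fcdf = torch.tensor([0., 0.5, 1.5])
#   >>> sorted_interp_quad(torch.tensor([0.5]), xp, fpdf, fcdf)   # tensor([0.1250])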