import torch
import numpy as np


def eps_from_v(z_0, z_t, sigma_t):
    # Invert z_t = z_0 + sigma_t * eps to recover the noise eps.
    return (z_t - z_0) / sigma_t


def v_to_eps(v, t, x_t):
    """
    Convert a velocity-field prediction to the epsilon (noise) parametrization,
    assuming the interpolation x_t = t * x_0 + (1 - t) * x_1 with x_0 ~ N(0, I)
    the noise and x_1 the data, so that v = dx_t/dt = x_0 - x_1.
    """
    eps_t = (1 - t) * v + x_t
    return eps_t
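

# Sanity check for v_to_eps (a minimal sketch, not part of the original module):
# with v = x_0 - x_1 and x_t = t * x_0 + (1 - t) * x_1, substituting gives
# x_t + (1 - t) * v = t * x_0 + (1 - t) * x_1 + (1 - t) * (x_0 - x_1) = x_0,
# so the recovered epsilon should match the noise sample x_0 exactly.
def _check_v_to_eps(t: float = 0.3) -> None:
    x_0 = torch.randn(4, 8)          # noise sample (epsilon by definition)
    x_1 = torch.randn(4, 8)          # data sample
    x_t = t * x_0 + (1 - t) * x_1    # interpolant from the docstring
    v = x_0 - x_1                    # ground-truth velocity dx_t/dt
    assert torch.allclose(v_to_eps(v, t, x_t), x_0, atol=1e-5)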


def clip_gradients(gradients, clip_value):
    # Rescale each gradient vector whose norm exceeds clip_value down to norm
    # clip_value. Assumes a 3D (batch, n, dim) tensor with the norm taken over
    # dim=2; the rescaling is applied in place via boolean-mask indexing.
    grad_norm = gradients.norm(dim=2)
    mask = grad_norm > clip_value
    mask_exp = mask[:, :, None].expand_as(gradients)
    gradients[mask_exp] = (
        gradients[mask_exp]
        / grad_norm[:, :, None].expand_as(gradients)[mask_exp]
        * clip_value
    )
    return gradients
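

# Quick shape/behaviour check for clip_gradients (a minimal sketch; the
# (batch, n, dim) layout is an assumption implied by norm(dim=2) and the
# [:, :, None] indexing, not something stated elsewhere in this module).
def _check_clip_gradients() -> None:
    g = 10.0 * torch.randn(2, 16, 3)      # deliberately oversized gradients
    g = clip_gradients(g, clip_value=1.0)
    # every per-vector norm is now at most clip_value (up to rounding)
    assert torch.all(g.norm(dim=2) <= 1.0 + 1e-5)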


class Adam:
    """
    Minimal Adam optimizer for a single parameter tensor. step() returns the
    updated parameters rather than mutating them in place, and lr may be either
    a fixed float or a callable schedule mapping the step count to a learning rate.
    """

    def __init__(self, parameters, lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.t = 0
        self.m = torch.zeros_like(parameters)
        self.v = torch.zeros_like(parameters)

    def step(self, params, grad) -> torch.Tensor:
        self.t += 1

        self.m = self.beta1 * self.m + (1 - self.beta1) * grad
        self.v = self.beta2 * self.v + (1 - self.beta2) * grad**2

        m_hat = self.m / (1 - self.beta1**self.t)
        v_hat = self.v / (1 - self.beta2**self.t)

        # check if self.lr is callable
        if callable(self.lr):
            lr = self.lr(self.t - 1)
        else:
            lr = self.lr
        update = lr * m_hat / (torch.sqrt(v_hat) + self.epsilon)

        return params - update


def make_cosine_decay_schedule(
    init_value: float,
    total_steps: int,
    alpha: float = 0.0,
    exponent: float = 1.0,
    warmup_steps=0,
):
    def schedule(count):
        if count < warmup_steps:
            # linear up
            return (init_value / warmup_steps) * count
        else:
            # half cosine down
            decay_steps = total_steps - warmup_steps
            count = min(count - warmup_steps, decay_steps)
            cosine_decay = 0.5 * (1 + np.cos(np.pi * count / decay_steps))
            decayed = (1 - alpha) * cosine_decay**exponent + alpha
            return init_value * decayed

    return schedule


def make_linear_decay_schedule(
    init_value: float, total_steps: int, final_value: float = 0, warmup_steps=0
):
    def schedule(count):
        if count < warmup_steps:
            # linear up
            return (init_value / warmup_steps) * count
        else:
            # linear down
            decay_steps = total_steps - warmup_steps
            count = min(count - warmup_steps, decay_steps)
            return init_value - (init_value - final_value) * count / decay_steps

    return schedule
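

# Usage sketch tying the pieces above together (hypothetical toy problem, not
# from the original module): the schedule factories return callables mapping a
# step count to a learning rate, which the Adam class accepts directly as lr.
def _fit_with_scheduled_adam(num_steps: int = 100) -> torch.Tensor:
    target = torch.tensor([1.0, -2.0, 0.5])
    params = torch.zeros(3)
    schedule = make_cosine_decay_schedule(
        init_value=0.1, total_steps=num_steps, warmup_steps=10
    )
    opt = Adam(params, lr=schedule)
    for _ in range(num_steps):
        grad = 2.0 * (params - target)    # gradient of ||params - target||^2
        params = opt.step(params, grad)
    return params                         # should end up close to target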


def clip_norm_(tensor, max_norm):
    norm = tensor.norm()
    if norm > max_norm:
        tensor.mul_(max_norm / norm)


def lr_warmup(step, warmup_steps):
    return min(1.0, step / max(warmup_steps, 1))


def linear_decay_lambda(step, warmup_steps, decay_steps, total_steps):
    if step < warmup_steps:
        # linear up
        return min(1.0, step / max(warmup_steps, 1))
    else:
        # linear down
        # decay_steps = total_steps - warmup_steps
        count = min(step - warmup_steps, decay_steps)
        return 1 - count / decay_steps
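

# Usage sketch for the lambda-style schedules above (an assumption about their
# intent: lr_warmup and linear_decay_lambda return multiplicative factors in
# [0, 1], matching the contract of torch.optim.lr_scheduler.LambdaLR; the
# helper below is hypothetical and not part of the original module).
def _make_warmup_scheduler(optimizer, warmup_steps: int = 1000):
    return torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda step: lr_warmup(step, warmup_steps)
    )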