import random
from dataclasses import dataclass

import pytest
import torch

import activation

from .test_poly_norm import poly_norm
from .utils import assert_close

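# Benchmark cases as (input shape, dtype) pairs; each runs as a separate parametrized test.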
CASES = [
    ((1, 2048, 8192), torch.bfloat16),
    ((1, 2048, 16384), torch.bfloat16),
    ((1, 16384, 8192), torch.bfloat16),
    ((1, 16384, 16384), torch.bfloat16),
]
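# Number of timed iterations per measurement (a short warm-up runs first).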
NUM_REP = 100


@dataclass
class PerfResult:
    type: str  # forward or backward
    shape: tuple
    dtype: torch.dtype
    kernel_time_ms: float
    torch_time_ms: float

    @property
    def speedup(self) -> float:
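        """Kernel speedup over the PyTorch reference (>1 means the kernel is faster)."""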
        return self.torch_time_ms / self.kernel_time_ms


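# Results collected across all parametrized cases, for both forward and backward timings.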
PERF_RESULTS: list[PerfResult] = []


@pytest.mark.parametrize("cases", CASES)
@pytest.mark.perf
def test_poly_norm(
    cases: tuple,
    do_plot: bool,  # fixture supplied by the test setup; not used directly in this test
) -> None:
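    """Check PolyNorm kernel correctness against the reference, then time forward and backward."""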
    random.seed(12345)
    torch.manual_seed(12345)

    torch.set_default_device("cuda")

    shape, dtype = cases
    x = torch.randn(shape, dtype=dtype, requires_grad=True)
    weight = torch.randn(3, dtype=dtype, requires_grad=True)
    bias = torch.randn(1, dtype=dtype, requires_grad=True)
    eps = 1e-05

    x.retain_grad()
    weight.retain_grad()
    bias.retain_grad()

    # Clone the inputs so the reference implementation computes gradients independently
    x_ref = x.detach().clone().requires_grad_(True)
    weight_ref = weight.detach().clone().requires_grad_(True)
    bias_ref = bias.detach().clone().requires_grad_(True)

    torch_fn = poly_norm
    layer = activation.layers.PolyNorm(eps)
    layer.weight = torch.nn.Parameter(weight)
    layer.bias = torch.nn.Parameter(bias)

    # Check correctness
    mod_out = layer(x)
    ref_out = torch_fn(x_ref, weight_ref, bias_ref, eps)
    assert_close(mod_out, ref_out)

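    # Use a unit-norm upstream gradient for the backward pass.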
    out_grad = torch.rand_like(ref_out)
    out_grad = out_grad / out_grad.norm()

    ref_out.backward(out_grad, retain_graph=True)
    mod_out.backward(out_grad, retain_graph=True)

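    # Weight/bias gradients reduce over many bf16 elements, so allow a looser relative tolerance.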
    assert_close(x.grad, x_ref.grad)
    assert_close(layer.bias.grad, bias_ref.grad, rtol=0.05)
    assert_close(layer.weight.grad, weight_ref.grad, rtol=0.05)

    def time_cuda(fn):
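        """Average time of fn in milliseconds over NUM_REP runs, measured with CUDA events."""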
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

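        # Warm up so one-time setup costs are excluded from the measurement.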
        for _ in range(5):
            fn()
        start.record()
        for _ in range(NUM_REP):
            fn()
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / NUM_REP

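    # Time the forward pass: custom kernel vs. the PyTorch reference.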
    kernel_time_ms = time_cuda(lambda: layer(x))
    torch_fn_time = time_cuda(lambda: torch_fn(x_ref, weight_ref, bias_ref, eps))

    PERF_RESULTS.append(
        PerfResult(
            type="forward",
            shape=shape,
            dtype=dtype,
            kernel_time_ms=kernel_time_ms,
            torch_time_ms=torch_fn_time,
        )
    )

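    # Time the backward pass; retain_graph=True allows repeated backward calls on the same graph.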
    kernel_time_ms = time_cuda(lambda: mod_out.backward(out_grad, retain_graph=True))
    torch_fn_time = time_cuda(lambda: ref_out.backward(out_grad, retain_graph=True))

    PERF_RESULTS.append(
        PerfResult(
            type="backward",
            shape=shape,
            dtype=dtype,
            kernel_time_ms=kernel_time_ms,
            torch_time_ms=torch_fn_time,
        )
    )