import random
from dataclasses import dataclass

import pytest
import torch

import activation

from .test_poly_norm import poly_norm
from .utils import assert_close

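# (shape, dtype) cases exercised by the perf test; NUM_REP is the number of
# timed iterations per measurement.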
CASES = [
    ((1, 2048, 8192), torch.bfloat16),
    ((1, 2048, 16384), torch.bfloat16),
    ((1, 16384, 8192), torch.bfloat16),
    ((1, 16384, 16384), torch.bfloat16),
]
NUM_REP = 100


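# A single timing measurement; `speedup` > 1.0 means the kernel beats the
# PyTorch reference.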
@dataclass
class PerfResult:
    type: str
    shape: tuple
    dtype: torch.dtype
    kernel_time_ms: float
    torch_time_ms: float

    @property
    def speedup(self) -> float:
        return self.torch_time_ms / self.kernel_time_ms


PERF_RESULTS: list[PerfResult] = []


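# NOTE: `do_plot` is assumed to be a boolean fixture provided by conftest.py;
# it is not referenced inside the test body itself.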
@pytest.mark.parametrize("cases", CASES)
@pytest.mark.perf
def test_poly_norm(
    cases: tuple,
    do_plot: bool,
) -> None:
    random.seed(12345)
    torch.manual_seed(12345)

    torch.set_default_device("cuda")

    shape, dtype = cases
    x = torch.randn(shape, dtype=dtype, requires_grad=True)
    weight = torch.randn(3, dtype=dtype, requires_grad=True)
    bias = torch.randn(1, dtype=dtype, requires_grad=True)
    eps = 1e-05

    x.retain_grad()
    weight.retain_grad()
    bias.retain_grad()

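    # Detached clones feed the pure-PyTorch reference so both implementations
    # see identical inputs while keeping independent autograd graphs.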
    x_ref = x.detach().clone().requires_grad_(True)
    weight_ref = weight.detach().clone().requires_grad_(True)
    bias_ref = bias.detach().clone().requires_grad_(True)

    torch_fn = poly_norm
    layer = activation.layers.PolyNorm(eps)
    layer.weight = torch.nn.Parameter(weight)
    layer.bias = torch.nn.Parameter(bias)

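    # Correctness check before timing: forward output and gradients of the
    # kernel-backed layer must match the reference implementation.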
    mod_out = layer(x)
    ref_out = torch_fn(x_ref, weight_ref, bias_ref, eps)
    assert_close(mod_out, ref_out)

    out_grad = torch.rand_like(ref_out)
    out_grad = out_grad / out_grad.norm()

    ref_out.backward(out_grad, retain_graph=True)
    mod_out.backward(out_grad, retain_graph=True)

    assert_close(x.grad, x_ref.grad)
    assert_close(layer.bias.grad, bias_ref.grad, rtol=0.05)
    assert_close(layer.weight.grad, weight_ref.grad, rtol=0.05)

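    # CUDA-event timer: 5 warmup calls, then NUM_REP timed iterations; returns
    # the mean time per call in milliseconds.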
    def time_cuda(fn):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        for _ in range(5):
            fn()
        start.record()
        for _ in range(NUM_REP):
            fn()
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / NUM_REP

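    # Forward-pass timing: kernel-backed layer vs. the pure-PyTorch reference.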
    kernel_time_ms = time_cuda(lambda: layer(x))
    torch_fn_time = time_cuda(lambda: torch_fn(x_ref, weight_ref, bias_ref, eps))

    PERF_RESULTS.append(
        PerfResult(
            type="forward",
            shape=shape,
            dtype=dtype,
            kernel_time_ms=kernel_time_ms,
            torch_time_ms=torch_fn_time,
        )
    )

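    # Backward-pass timing. retain_graph=True keeps both graphs alive across
    # repetitions; gradients accumulate while timing, but correctness was
    # already asserted above.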
    kernel_time_ms = time_cuda(lambda: mod_out.backward(out_grad, retain_graph=True))
    torch_fn_time = time_cuda(lambda: ref_out.backward(out_grad, retain_graph=True))

    PERF_RESULTS.append(
        PerfResult(
            type="backward",
            shape=shape,
            dtype=dtype,
            kernel_time_ms=kernel_time_ms,
            torch_time_ms=torch_fn_time,
        )
    )