# activation/tests/kernels/test_poly_norm_perf.py
import random
from dataclasses import dataclass

import pytest
import torch

import activation

from .test_poly_norm import poly_norm
from .utils import assert_close
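
# Benchmark cases: (input shape, dtype).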
CASES = [
    ((1, 2048, 8192), torch.bfloat16),
    ((1, 2048, 16384), torch.bfloat16),
    ((1, 16384, 8192), torch.bfloat16),
    ((1, 16384, 16384), torch.bfloat16),
]
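
# Timed iterations per measurement (after the warmup runs in time_cuda).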
NUM_REP = 100


@dataclass
class PerfResult:
    type: str  # forward or backward
    shape: tuple
    dtype: torch.dtype
    kernel_time_ms: float
    torch_time_ms: float

    @property
    def speedup(self) -> float:
        return self.torch_time_ms / self.kernel_time_ms
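

# Results accumulated across all parametrized forward/backward timings.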
PERF_RESULTS: list[PerfResult] = []


@pytest.mark.parametrize("cases", CASES)
@pytest.mark.perf
def test_poly_norm(
    cases: tuple,
    do_plot: bool,
) -> None:
    random.seed(12345)
    torch.manual_seed(12345)
    torch.set_default_device("cuda")

    shape, dtype = cases
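
    # PolyNorm takes a 3-element weight and a 1-element bias, matching the
    # poly_norm reference implementation.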
    x = torch.randn(shape, dtype=dtype, requires_grad=True)
    weight = torch.randn(3, dtype=dtype, requires_grad=True)
    bias = torch.randn(1, dtype=dtype, requires_grad=True)
    eps = 1e-05

    x.retain_grad()
    weight.retain_grad()
    bias.retain_grad()

    # To separate gradient computation, clone the inputs
    x_ref = x.detach().clone().requires_grad_(True)
    weight_ref = weight.detach().clone().requires_grad_(True)
    bias_ref = bias.detach().clone().requires_grad_(True)
    torch_fn = poly_norm
    layer = activation.layers.PolyNorm(eps)
    layer.weight = torch.nn.Parameter(weight)
    layer.bias = torch.nn.Parameter(bias)

    # Check correctness
    mod_out = layer(x)
    ref_out = torch_fn(x_ref, weight_ref, bias_ref, eps)
    assert_close(mod_out, ref_out)
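
    # Backward check against the reference, driven by a unit-norm upstream
    # gradient.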
    out_grad = torch.rand_like(ref_out)
    out_grad = out_grad / out_grad.norm()

    ref_out.backward(out_grad, retain_graph=True)
    mod_out.backward(out_grad, retain_graph=True)
    assert_close(x.grad, x_ref.grad)
    assert_close(layer.bias.grad, bias_ref.grad, rtol=0.05)
    assert_close(layer.weight.grad, weight_ref.grad, rtol=0.05)
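
    # Measure average per-call latency with CUDA events: 5 warmup calls,
    # then NUM_REP timed calls (result in milliseconds).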
    def time_cuda(fn):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        for _ in range(5):
            fn()

        start.record()
        for _ in range(NUM_REP):
            fn()
        end.record()
        torch.cuda.synchronize()

        return start.elapsed_time(end) / NUM_REP
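
    # Forward timing: fused kernel vs. PyTorch reference.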
    kernel_time_ms = time_cuda(lambda: layer(x))
    torch_fn_time = time_cuda(lambda: torch_fn(x_ref, weight_ref, bias_ref, eps))
    PERF_RESULTS.append(
        PerfResult(
            type="forward",
            shape=shape,
            dtype=dtype,
            kernel_time_ms=kernel_time_ms,
            torch_time_ms=torch_fn_time,
        )
    )
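
    # Backward timing: retain_graph lets the same graph be replayed for every
    # repetition.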
    kernel_time_ms = time_cuda(lambda: mod_out.backward(out_grad, retain_graph=True))
    torch_fn_time = time_cuda(lambda: ref_out.backward(out_grad, retain_graph=True))
    PERF_RESULTS.append(
        PerfResult(
            type="backward",
            shape=shape,
            dtype=dtype,
            kernel_time_ms=kernel_time_ms,
            torch_time_ms=torch_fn_time,
        )
    )