import random
from dataclasses import dataclass

import pytest
import torch

import activation

from .test_poly_norm import poly_norm
from .utils import assert_close

# (shape, dtype) pairs exercised by the perf test
CASES = [
    ((1, 2048, 8192), torch.bfloat16),
    ((1, 2048, 16384), torch.bfloat16),
    ((1, 16384, 8192), torch.bfloat16),
    ((1, 16384, 16384), torch.bfloat16),
]

# Number of timed iterations per measurement
NUM_REP = 100


@dataclass
class PerfResult:
    type: str  # "forward" or "backward"
    shape: tuple
    dtype: torch.dtype
    kernel_time_ms: float
    torch_time_ms: float

    @property
    def speedup(self) -> float:
        return self.torch_time_ms / self.kernel_time_ms


PERF_RESULTS: list[PerfResult] = []


@pytest.mark.parametrize("cases", CASES)
@pytest.mark.perf
def test_poly_norm(
    cases: tuple,
    do_plot: bool,
) -> None:
    random.seed(12345)
    torch.manual_seed(12345)
    torch.set_default_device("cuda")

    shape, dtype = cases
    x = torch.randn(shape, dtype=dtype, requires_grad=True)
    weight = torch.randn(3, dtype=dtype, requires_grad=True)
    bias = torch.randn(1, dtype=dtype, requires_grad=True)
    eps = 1e-05

    x.retain_grad()
    weight.retain_grad()
    bias.retain_grad()

    # Clone the inputs so the reference path accumulates its own gradients
    x_ref = x.detach().clone().requires_grad_(True)
    weight_ref = weight.detach().clone().requires_grad_(True)
    bias_ref = bias.detach().clone().requires_grad_(True)

    torch_fn = poly_norm

    layer = activation.layers.PolyNorm(eps)
    layer.weight = torch.nn.Parameter(weight)
    layer.bias = torch.nn.Parameter(bias)

    # Check forward correctness
    mod_out = layer(x)
    ref_out = torch_fn(x_ref, weight_ref, bias_ref, eps)
    assert_close(mod_out, ref_out)

    # Check backward correctness against a normalized random upstream gradient
    out_grad = torch.rand_like(ref_out)
    out_grad = out_grad / out_grad.norm()

    ref_out.backward(out_grad, retain_graph=True)
    mod_out.backward(out_grad, retain_graph=True)
    assert_close(x.grad, x_ref.grad)
    assert_close(layer.bias.grad, bias_ref.grad, rtol=0.05)
    assert_close(layer.weight.grad, weight_ref.grad, rtol=0.05)

    def time_cuda(fn):
        """Return the mean time per call in milliseconds, measured with CUDA events."""
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        # Warm-up iterations
        for _ in range(5):
            fn()
        start.record()
        for _ in range(NUM_REP):
            fn()
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / NUM_REP

    # Benchmark the forward pass
    kernel_time_ms = time_cuda(lambda: layer(x))
    torch_fn_time = time_cuda(lambda: torch_fn(x_ref, weight_ref, bias_ref, eps))

    PERF_RESULTS.append(
        PerfResult(
            type="forward",
            shape=shape,
            dtype=dtype,
            kernel_time_ms=kernel_time_ms,
            torch_time_ms=torch_fn_time,
        )
    )

    # Benchmark the backward pass (graphs are retained above, so backward can be replayed)
    kernel_time_ms = time_cuda(lambda: mod_out.backward(out_grad, retain_graph=True))
    torch_fn_time = time_cuda(lambda: ref_out.backward(out_grad, retain_graph=True))

    PERF_RESULTS.append(
        PerfResult(
            type="backward",
            shape=shape,
            dtype=dtype,
            kernel_time_ms=kernel_time_ms,
            torch_time_ms=torch_fn_time,
        )
    )
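

# The helper below is not part of the original test module; it is a minimal
# illustrative sketch of how the PerfResult records accumulated in PERF_RESULTS
# could be rendered into a plain-text summary (for example, from a reporting
# hook after a perf run). The function name and column layout are assumptions,
# not an existing API in this repository.
def summarize_perf_results(results: list[PerfResult]) -> str:
    """Format collected perf results as an aligned plain-text table."""
    header = (
        f"{'type':<9} {'shape':<20} {'dtype':<15} "
        f"{'kernel_ms':>10} {'torch_ms':>10} {'speedup':>8}"
    )
    rows = [header]
    for r in results:
        rows.append(
            f"{r.type:<9} {str(r.shape):<20} {str(r.dtype):<15} "
            f"{r.kernel_time_ms:>10.3f} {r.torch_time_ms:>10.3f} {r.speedup:>7.2f}x"
        )
    return "\n".join(rows)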