# activation/tests/kernels/test_rms_norm.py
import random

import pytest
import torch

import activation

from .utils import assert_close, opcheck

DTYPES = [torch.float, torch.bfloat16, torch.half]
# NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
# D = [512, 13824]  # Arbitrary values for testing
NUM_TOKENS = [7, 13]  # Arbitrary values for testing
D = [513]  # Arbitrary values for testing
SEEDS = [0]
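# Test cuda:0 only when a single GPU is visible; otherwise also exercise cuda:1.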
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_rms_norm(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
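    """Check the fused RMSNorm kernel against torch.nn.RMSNorm, forward and backward."""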
    random.seed(seed)
    torch.manual_seed(seed)
    torch.set_default_device(device)
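
    # With the default device set above, every tensor created below lands on
    # the parametrized CUDA device.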
    x = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True)
    weight = torch.randn(d, dtype=dtype, requires_grad=True)
    eps = 1e-05
    x.retain_grad()
    weight.retain_grad()

    # Clone the inputs so the reference path accumulates its gradients separately.
    x_ref = x.detach().clone().requires_grad_(True)
    weight_ref = weight.detach().clone().requires_grad_(True)
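
    # Reference: torch.nn.RMSNorm computes y = x / sqrt(mean(x**2, dim=-1) + eps) * weight.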
    torch_layer = torch.nn.RMSNorm(d, eps=eps, dtype=dtype)
    torch_layer.weight = torch.nn.Parameter(weight_ref)
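
    # The extension is exercised three ways: the raw op (called with a
    # preallocated output buffer), the functional wrapper, and the
    # nn.Module-style layer.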
    op = activation.ops.rms_norm
    fn = activation.rms_norm
    layer = activation.layers.RMSNorm(d, eps=eps, dtype=dtype)
    layer.weight = torch.nn.Parameter(weight)

    out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
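    # opcheck (presumably a thin wrapper over torch.library.opcheck) validates
    # the custom op's registration and schema rather than its numerics.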
    opcheck(op, (out, x, weight, eps))

    out = fn(x, weight, eps)
    mod_out = layer(x)
    ref_out = torch_layer(x_ref)

    # The layer must match the functional path bit-exactly; both only need to
    # be close to the torch reference.
    assert_close(out, ref_out)
    assert_close(mod_out, out, atol=0.0, rtol=0.0)

    # Backward pass: feed a unit-norm upstream gradient through both paths and
    # compare the input and weight gradients.
    out_grad = torch.randn_like(out)
    out_grad = out_grad / out_grad.norm()
    ref_out.backward(out_grad)
    mod_out.backward(out_grad)

    assert_close(x.grad, x_ref.grad)
    # Looser tolerance here, likely because the weight gradient reduces over
    # all tokens in low precision.
    assert_close(layer.weight.grad, torch_layer.weight.grad, rtol=0.05)
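

# For clarity only (not used by the test above): a minimal eager-mode sketch of
# what the fused kernel is expected to compute. The helper name and the absence
# of any internal upcasting are assumptions, not part of the activation package.
def _reference_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # y = x / sqrt(mean(x**2, dim=-1) + eps) * weight
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight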