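"""Tests for the poly_norm activation op: raw-op schema checks via opcheck,
functional and module forward paths against a float32 reference, and backward
gradients."""
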
import random

import pytest
import torch

import activation

from .utils import assert_close, opcheck

DTYPES = [torch.float, torch.bfloat16, torch.half]
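# Non-round token counts and an odd hidden size, likely chosen to exercise
# non-power-of-two shapes and any kernel tail handling.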
NUM_TOKENS = [7, 13]
D = [513]
SEEDS = [0]
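# Run on a single device, or on the first two GPUs when more than one is visible.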
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]


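# RMS normalization over the last dimension, shared by all three polynomial terms.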
def norm(x: torch.Tensor, eps: float) -> torch.Tensor:
    return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)


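# Reference implementation: a weighted sum of RMS-normalized powers of x
# (x**3, x**2, x) plus bias, computed in float32 and cast to the weight dtype.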
def poly_norm(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float
) -> torch.Tensor:
    x = x.float()
    return (
        weight[0] * norm(x**3, eps)
        + weight[1] * norm(x**2, eps)
        + weight[2] * norm(x, eps)
        + bias
    ).to(weight.dtype)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS) |
|
@pytest.mark.parametrize("d", D) |
|
@pytest.mark.parametrize("dtype", DTYPES) |
|
@pytest.mark.parametrize("seed", SEEDS) |
|
@pytest.mark.parametrize("device", CUDA_DEVICES) |
|
def test_poly_norm( |
|
num_tokens: int, |
|
d: int, |
|
dtype: torch.dtype, |
|
seed: int, |
|
device: str, |
|
) -> None: |
|
random.seed(seed) |
|
torch.manual_seed(seed) |
|
torch.set_default_device(device) |
|
|
|
    x = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True)
    weight = torch.randn(3, dtype=dtype, requires_grad=True)
    bias = torch.randn(1, dtype=dtype, requires_grad=True)
    eps = 1e-05

    x.retain_grad()
    weight.retain_grad()
    bias.retain_grad()

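    # Detached clones so the reference path accumulates its own gradients.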
    x_ref = x.detach().clone().requires_grad_(True)
    weight_ref = weight.detach().clone().requires_grad_(True)
    bias_ref = bias.detach().clone().requires_grad_(True)

    torch_fn = poly_norm
    op = activation.ops.poly_norm
    fn = activation.poly_norm
    layer = activation.layers.PolyNorm(eps)
    layer.weight = torch.nn.Parameter(weight)
    layer.bias = torch.nn.Parameter(bias)

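    # Schema/meta checks on the raw out-variant op (opcheck is assumed to wrap
    # torch.library.opcheck in .utils).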
    out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
    opcheck(op, (out, x, weight, bias, eps))

    out = fn(x, weight, bias, eps)
    mod_out = layer(x)
    ref_out = torch_fn(x_ref, weight_ref, bias_ref, eps)

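    # The module path must match the functional path bit-for-bit; both are
    # checked against the float32 reference within dtype tolerance.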
    assert_close(out, ref_out)
    assert_close(mod_out, out, atol=0.0, rtol=0.0)

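    # Backward: a unit-norm upstream gradient keeps tolerances scale-free.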
    out_grad = torch.randn_like(out)
    out_grad = out_grad / out_grad.norm()

    ref_out.backward(out_grad)
    mod_out.backward(out_grad)

    assert_close(x.grad, x_ref.grad)
    assert_close(layer.bias.grad, bias_ref.grad, rtol=0.05)
    assert_close(layer.weight.grad, weight_ref.grad, rtol=0.05)