"""Tests for the apex C++ MLP extension."""
from itertools import product
from time import time

import torch
from torch import nn
from torch.testing._internal import common_utils
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_device_type import onlyCUDA

from apex.mlp import MLP

batch_size = 1024
mlp_sizes = [480, 1024, 1024, 512, 256, 1]
num_iters = 10


class TestMLP(common_utils.TestCase):
    def test_creation(self):
        MLP(mlp_sizes)

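    # Checks forward/backward numerical parity between the extension MLP and an
    # equivalent torch.nn reference built from nn.Linear + nn.ReLU layers.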
    def test_numeric(self):
        mlp = MLP(mlp_sizes).cuda()

        # Copy the reference layer parameters into the extension MLP so both
        # networks start from identical weights and biases.
        mlp_layers = []
        for i in range(mlp.num_layers):
            linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1])
            with torch.no_grad():
                mlp.weights[i].copy_(linear.weight)
                mlp.biases[i].copy_(linear.bias)
            mlp_layers.append(linear)
            mlp_layers.append(nn.ReLU())

        ref_mlp = nn.Sequential(*mlp_layers).cuda()

        test_input = (
            torch.empty(batch_size, mlp_sizes[0], device="cuda")
            .uniform_(-1.0, 1.0)
            .requires_grad_()
        )
        ref_input = test_input.clone().detach().requires_grad_()
        mlp_out = mlp(test_input)
        ref_out = ref_mlp(ref_input)
        self.assertEqual(mlp_out, ref_out)

        # Scalar loss: mean of the outputs, scaled by 10.
        mlp_out.mean().mul(10.0).backward()
        ref_out.mean().mul(10.0).backward()
        self.assertEqual(test_input.grad, ref_input.grad)
        self.assertEqual(mlp.biases[0].grad, ref_mlp[0].bias.grad)

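    # Shared body for the parametrized tests below: builds the extension MLP and a
    # matching torch.nn reference, then compares outputs and gradients (only the
    # output dtypes are checked when running under autocast).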
    def _test_mlp_impl(self, use_activation: str, bias: bool, enable_autocast: bool):
        mlp = MLP(mlp_sizes, bias=bias, activation=use_activation).cuda()

        # Copy the reference layer parameters into the extension MLP.
        mlp_layers = []
        for i in range(mlp.num_layers):
            linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1], bias=bias)
            with torch.no_grad():
                mlp.weights[i].copy_(linear.weight)
                if bias:
                    mlp.biases[i].copy_(linear.bias)
            mlp_layers.append(linear)
            if use_activation == "relu":
                mlp_layers.append(nn.ReLU())
            if use_activation == "sigmoid":
                mlp_layers.append(nn.Sigmoid())

        ref_mlp = nn.Sequential(*mlp_layers).cuda()

        test_input = (
            torch.empty(batch_size, mlp_sizes[0], device="cuda")
            .uniform_(-1.0, 1.0)
            .requires_grad_()
        )
        ref_input = test_input.clone().detach().requires_grad_()

        with torch.cuda.amp.autocast(enabled=enable_autocast):
            mlp_out = mlp(test_input)
            mlp_loss = mlp_out.mean().mul(10.0)

            ref_out = ref_mlp(ref_input)
            ref_loss = ref_out.mean().mul(10.0)

        mlp_loss.backward()
        ref_loss.backward()
        if enable_autocast:
            self.assertEqual(mlp_out.dtype, torch.float16)
            self.assertEqual(ref_out.dtype, torch.float16)
        else:
            self.assertEqual(mlp_out, ref_out)
            self.assertEqual(test_input.grad, ref_input.grad)
            self.assertEqual(mlp.weights[0].grad, ref_mlp[0].weight.grad)

    @common_utils.parametrize(
        "use_activation,bias",
        list(product(("none", "relu", "sigmoid"), (True, False))),
    )
    def test_mlp(self, use_activation: str, bias: bool):
        self._test_mlp_impl(use_activation, bias, enable_autocast=False)

    @common_utils.parametrize(
        "use_activation,bias",
        list(product(("none", "relu", "sigmoid"), (True, False))),
    )
    def test_mlp_autocast_fp16(self, use_activation: str, bias: bool):
        self._test_mlp_impl(use_activation, bias, enable_autocast=True)

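    # Checks the backward pass when the inputs themselves do not require grad.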
    def test_no_grad(self):
        mlp = MLP(mlp_sizes).cuda()

        mlp_layers = []
        for i in range(mlp.num_layers):
            linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1])
            with torch.no_grad():
                mlp.weights[i].copy_(linear.weight)
                mlp.biases[i].copy_(linear.bias)
            mlp_layers.append(linear)
            mlp_layers.append(nn.ReLU(inplace=True))

        ref_mlp = nn.Sequential(*mlp_layers).cuda()

        test_input = torch.empty(batch_size, mlp_sizes[0], device="cuda").uniform_(-1.0, 1.0)
        ref_input = test_input.clone().detach()
        mlp_out = mlp(test_input)
        ref_out = ref_mlp(ref_input)
        self.assertEqual(mlp_out, ref_out)

        mlp_out.mean().mul(10.0).backward()
        ref_out.mean().mul(10.0).backward()
        self.assertEqual(mlp.weights[0].grad, ref_mlp[0].weight.grad)

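    # Rough benchmark in fp16: the extension is expected to be at least as fast
    # as the pure-PyTorch reference over num_iters iterations.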
    def test_performance_half(self):
        mlp = MLP(mlp_sizes).cuda().half()

        mlp_layers = []
        for i in range(mlp.num_layers):
            linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1])
            with torch.no_grad():
                mlp.weights[i].copy_(linear.weight)
                mlp.biases[i].copy_(linear.bias)
            mlp_layers.append(linear)
            mlp_layers.append(nn.ReLU(inplace=True))

        ref_mlp = nn.Sequential(*mlp_layers).cuda().half()

        test_input = (
            torch.empty(batch_size, mlp_sizes[0], device="cuda", dtype=torch.half)
            .fill_(10.0)
            .requires_grad_()
        )
        ref_input = (
            torch.empty(batch_size, mlp_sizes[0], device="cuda", dtype=torch.half)
            .fill_(10.0)
            .requires_grad_()
        )

        # Warm up both networks before timing.
        for _ in range(100):
            ref_out = ref_mlp(ref_input)
            ref_loss = ref_out.mean()
            ref_mlp.zero_grad()
            ref_loss.backward()
            mlp_out = mlp(test_input)
            test_loss = mlp_out.mean()
            mlp.zero_grad()
            test_loss.backward()

        # Time the PyTorch reference.
        torch.cuda.profiler.start()
        torch.cuda.synchronize()
        start_time = time()
        for _ in range(num_iters):
            ref_out = ref_mlp(ref_input)
            ref_loss = ref_out.mean()
            ref_mlp.zero_grad()
            ref_loss.backward()
        torch.cuda.synchronize()
        stop_time = time()
        ref_time = (stop_time - start_time) * 1000.0 / num_iters
        print(f"\nPyTorch MLP time {ref_time:.4f} ms")

        # Time the C++ extension.
        torch.cuda.synchronize()
        start_time = time()
        for _ in range(num_iters):
            mlp_out = mlp(test_input)
            test_loss = mlp_out.mean()
            mlp.zero_grad()
            test_loss.backward()
        torch.cuda.synchronize()
        stop_time = time()
        actual_time = (stop_time - start_time) * 1000.0 / num_iters
        print(f"C++ MLP time {actual_time:.4f} ms")
        torch.cuda.profiler.stop()
        self.assertLessEqual(
            actual_time,
            ref_time,
            msg=f"Custom extension took {actual_time:.4f} ms while PyTorch took {ref_time:.4f} ms",
        )


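# Register the device-type test variants for CUDA only.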
instantiate_device_type_tests(TestMLP, globals(), only_for=("cuda",))


if __name__ == "__main__":
    common_utils.run_tests()