import unittest

import functools as ft
import itertools as it

from apex import amp
import torch
from torch import nn
import torch.nn.functional as F

from utils import common_init, HALF, FLOAT,\
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT


def run_layer_test(test_case, fns, expected, input_shape, test_backward=True):
    """Call each fn on half and float inputs, check that the output type matches
    expected[input_dtype], and (optionally) backprop and check that the gradient
    type matches the input type."""
    for fn, typ in it.product(fns, expected.keys()):
        x = torch.randn(input_shape, dtype=typ).requires_grad_()
        y = fn(x)
        test_case.assertEqual(y.type(), expected[typ])
        if test_backward:
            y.float().sum().backward()
            test_case.assertEqual(x.grad.type(), MATCH_INPUT[typ])
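
# The tests below exercise amp's registered casting rules: FP16-list ops
# (linear, conv2d) should always return half outputs, FP32-list ops (softmax,
# group norm, losses) should always return float outputs, and passthrough ops
# (relu, batch norm) should return outputs matching the input dtype.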
class TestBasicCasts(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def test_linear_is_half(self):
        m = nn.Linear(self.h, self.h)
        f = ft.partial(F.linear, weight=m.weight, bias=m.bias)
        run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.h))

    def test_conv2d_is_half(self):
        m = nn.Conv2d(self.c, self.c, self.k)
        f = ft.partial(F.conv2d, weight=m.weight, bias=m.bias)
        run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.c, self.h, self.h))

    def test_softmax_is_float(self):
        m = nn.Softmax(dim=1)
        f = ft.partial(F.softmax, dim=1)
        run_layer_test(self, [m, f], ALWAYS_FLOAT, (self.b, self.h))

    def test_group_norm_is_float(self):
        m = nn.GroupNorm(num_groups=4, num_channels=self.c)
        run_layer_test(self, [m], ALWAYS_FLOAT, (self.b, self.c, self.h, self.h))

    def test_mse_loss_is_float(self):
        shape = (self.b, self.h)
        target = torch.randn(shape)
        mod = nn.MSELoss()
        m = lambda x: mod(x, target)
        f = ft.partial(F.mse_loss, target=target)
        run_layer_test(self, [m], ALWAYS_FLOAT, shape)

    def test_relu_is_match(self):
        run_layer_test(self, [nn.ReLU(), F.relu], MATCH_INPUT, (self.b, self.h))

    def test_batch_norm_is_match(self):
        m = nn.BatchNorm2d(num_features=self.c)
        f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
                       weight=m.weight, bias=m.bias, training=True)
        run_layer_test(self, [m], MATCH_INPUT, (self.b, self.c, self.h, self.h))

        # Re-test in eval mode: the functional form is included and backward is
        # skipped (forward-only check for batch norm inference).
        m.eval()
        f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
                       weight=m.weight, bias=m.bias, training=False)
        run_layer_test(self, [m, f], MATCH_INPUT, (self.b, self.c, self.h, self.h),
                       test_backward=False)
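
# By default amp raises NotImplementedError for "banned" ops such as
# F.binary_cross_entropy / nn.BCELoss on half inputs; with
# amp.init(allow_banned=True) the call is cast to FP32 instead.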
class TestBannedMethods(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def bce_common(self, assertion):
        shape = (self.b, self.h)
        target = torch.rand(shape)
        mod = nn.BCELoss()
        m = lambda x: mod(x, target)
        f = ft.partial(F.binary_cross_entropy, target=target)
        for fn in [m, f]:
            x = torch.rand(shape, dtype=torch.half)
            assertion(fn, x)

    def test_bce_raises_by_default(self):
        assertion = lambda fn, x: self.assertRaises(NotImplementedError, fn, x)
        self.bce_common(assertion)

    def test_bce_is_float_with_allow_banned(self):
        self.handle._deactivate()
        self.handle = amp.init(enabled=True, allow_banned=True)
        assertion = lambda fn, x: self.assertEqual(fn(x).type(), FLOAT)
        self.bce_common(assertion)
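
# amp also patches torch.Tensor methods and operators, so the same casting rules
# apply to tensor-level calls (matmul / @, pow / **, cpu(), sum()).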
class TestTensorCasts(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def test_matmul_method_is_half(self):
        other = torch.randn(self.h, self.h)
        lhs = lambda x: x.matmul(other)
        rhs = lambda x: other.matmul(x)
        run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h))

    def test_matmul_op_is_half(self):
        other = torch.randn(self.h, self.h)
        lhs = lambda x: x @ other
        rhs = lambda x: other @ x
        run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h))

    def test_pow_method_is_float(self):
        fn = lambda x: x.pow(2.)
        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))

    def test_pow_op_is_float(self):
        fn = lambda x: x ** 2.
        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))

    def test_cpu_is_float(self):
        fn = lambda x: x.cpu()
        always_cpu_float = {torch.float: 'torch.FloatTensor',
                            torch.half: 'torch.FloatTensor'}
        run_layer_test(self, [fn], always_cpu_float, (self.b, self.h))

    def test_sum_is_float(self):
        fn = lambda x: x.sum()
        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))


if __name__ == '__main__':
    unittest.main()