# Provenance: iamwyldecat, commit 44e9845 "feat(poly-norm): Add PolyNorm"
"""Kernel test utils"""
import unittest
from typing import Any, Dict, Optional, Sequence, Tuple, Union
import torch
from torch._prims_common import TensorLikeType
from .allclose_default import get_default_atol, get_default_rtol
# Restricted set of opcheck sub-tests. "test_aot_dispatch_dynamic" is left
# out here for now because of bugs related to that test in PyTorch 2.4.
DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
)

# Full set of opcheck sub-tests: the restricted set above plus
# "test_aot_dispatch_dynamic".
ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = DEFAULT_OPCHECK_TEST_UTILS + (
    "test_aot_dispatch_dynamic",
)
def assert_close(
    a: TensorLikeType,
    b: TensorLikeType,
    atol: Optional[float] = None,
    rtol: Optional[float] = None,
) -> None:
    """Assert that ``a`` and ``b`` are close, filling in default tolerances.

    Args:
        a: Actual value. When a tolerance is not given, it is looked up from
            ``a`` via ``get_default_atol`` / ``get_default_rtol``.
        b: Expected value.
        atol: Absolute tolerance; ``None`` means ``get_default_atol(a)``.
        rtol: Relative tolerance; ``None`` means ``get_default_rtol(a)``.

    Raises:
        AssertionError: Propagated from ``torch.testing.assert_close`` when
            the tensors are not close.
    """
    # Both tolerances are always passed explicitly; torch.testing.assert_close
    # requires atol and rtol to be given together or not at all.
    atol = atol if atol is not None else get_default_atol(a)
    rtol = rtol if rtol is not None else get_default_rtol(a)
    torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
# Copied/modified from torch._refs.__init__.py
def fp8_allclose(
a: TensorLikeType,
b: TensorLikeType,
rtol: float = 1e-05,
atol: float = 1e-08,
equal_nan: bool = False,
) -> bool:
"""
Reference implementation of torch.allclose
"""
torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)
return bool(
torch.all(
torch.isclose(
a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan
)
).item()
)
# A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types.
def opcheck(
op: Union[
torch._ops.OpOverload,
torch._ops.OpOverloadPacket,
torch._library.custom_ops.CustomOpDef,
],
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
*,
test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
raise_exception: bool = True,
cond: bool = True,
) -> Dict[str, str]:
with unittest.mock.patch("torch.allclose", new=fp8_allclose):
return (
torch.library.opcheck(
op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
)
if cond
else {}
)