"""Kernel test utils""" import unittest from typing import Any, Dict, Optional, Sequence, Tuple, Union import torch from torch._prims_common import TensorLikeType from .allclose_default import get_default_atol, get_default_rtol # For now, disable "test_aot_dispatch_dynamic" since there are some # bugs related to this test in PyTorch 2.4. DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = ( "test_schema", "test_autograd_registration", "test_faketensor", ) ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = ( "test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic", ) def assert_close( a: TensorLikeType, b: TensorLikeType, atol: float | None = None, rtol: float | None = None, ) -> None: atol = atol if atol is not None else get_default_atol(a) rtol = rtol if rtol is not None else get_default_rtol(a) torch.testing.assert_close(a, b, atol=atol, rtol=rtol) # Copied/modified from torch._refs.__init__.py def fp8_allclose( a: TensorLikeType, b: TensorLikeType, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False, ) -> bool: """ Reference implementation of torch.allclose """ torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol) return bool( torch.all( torch.isclose( a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan ) ).item() ) # A special version of op check that has a restricted default set of test_utils # and a patched version of allclose that supports fp8 types. def opcheck( op: Union[ torch._ops.OpOverload, torch._ops.OpOverloadPacket, torch._library.custom_ops.CustomOpDef, ], args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None, *, test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS, raise_exception: bool = True, cond: bool = True, ) -> Dict[str, str]: with unittest.mock.patch("torch.allclose", new=fp8_allclose): return ( torch.library.opcheck( op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception ) if cond else {} )