diff --git a/build/torch-universal/triton_kernels/__init__.py b/build/torch-universal/triton_kernels/__init__.py
old mode 100644
new mode 100755
index 49694721c244cf71af767e95e7dac841d76dc74b..54a0de0988e7b39c0aeb387cbe7b0e7fdd899fe7
--- a/build/torch-universal/triton_kernels/__init__.py
+++ b/build/torch-universal/triton_kernels/__init__.py
@@ -1,3 +1,9 @@
+# Make sure to add this in the build folder as this won't build if we put that here
+# docker run --rm \
+#   -v $(pwd):/app \
+#   -w /app \
+#   ghcr.io/huggingface/kernel-builder:main
+
 from . import matmul_ogs, tensor_details, numerics_details, tensor, swiglu, routing
 
 __all__ = ["matmul_ogs" , "tensor_details", "numerics_details", "tensor", "swiglu", "routing"]
\ No newline at end of file
diff --git a/build/torch-universal/triton_kernels/__pycache__/__init__.cpython-312.pyc b/build/torch-universal/triton_kernels/__pycache__/__init__.cpython-312.pyc
old mode 100644
new mode 100755
index e54deb49b23d4ec893717db535096144a013e0cd..4f29e23e3ea99a89a2e34c5692caa31be96832d5
Binary files a/build/torch-universal/triton_kernels/__pycache__/__init__.cpython-312.pyc and b/build/torch-universal/triton_kernels/__pycache__/__init__.cpython-312.pyc differ
diff --git a/build/torch-universal/triton_kernels/_ops.py b/build/torch-universal/triton_kernels/_ops.py
old mode 100644
new mode 100755
index 76aa747cdbf5dda6325c8108bc189e2547607463..9f81446b1a3c1897e48bde2d8311def1c6ffd93a
--- a/build/torch-universal/triton_kernels/_ops.py
+++ b/build/torch-universal/triton_kernels/_ops.py
@@ -1,8 +1,8 @@
 import torch
-ops = torch.ops._triton_kernels_10e8091_dirty
+ops = torch.ops._triton_kernels_a32f88a_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_triton_kernels_10e8091_dirty::{op_name}"
\ No newline at end of file
+    return f"_triton_kernels_a32f88a_dirty::{op_name}"
\ No newline at end of file
diff --git a/build/torch-universal/triton_kernels/compaction.py b/build/torch-universal/triton_kernels/compaction.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/compaction_details/_masked_compaction.py b/build/torch-universal/triton_kernels/compaction_details/_masked_compaction.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/matmul_ogs.py b/build/torch-universal/triton_kernels/matmul_ogs.py
old mode 100644
new mode 100755
index bde692a04d47d3e0c851d837c99c0473ca65e74b..e2dc00246646a9914bcb63a754788c99218540c5
--- a/build/torch-universal/triton_kernels/matmul_ogs.py
+++ b/build/torch-universal/triton_kernels/matmul_ogs.py
@@ -602,6 +602,7 @@ def matmul_ogs_torch(x, w, bias,
                      betas = None,
                      gammas = None,
                      round_x = None, round_y = None,
+                     device: str = "cuda",
                      ):
     is_input_batched = x.ndim == 3
     assert x.dtype.itemsize > 1
@@ -641,7 +642,7 @@ def matmul_ogs_torch(x, w, bias,
         else:
             idx = gather_indx.src_indx[lo:hi] // n_expts_act
         batch = i if is_input_batched else 0
-        out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device="cuda")).float(),
+        out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device=device)).float(),
                            w[i].float())
         if bias is not None:
             out += bias[i, :] if betas is None else bias[i, :] * betas[lo:hi, None]
diff --git a/build/torch-universal/triton_kernels/matmul_ogs_details/_common.py b/build/torch-universal/triton_kernels/matmul_ogs_details/_common.py
old mode 100644
new mode 100755
index 25755d3105ffe2b119f9fc65c96be858a987fd3c..fde71f8a936ae9c6cefc44bd83aae85f900c03a5
--- a/build/torch-universal/triton_kernels/matmul_ogs_details/_common.py
+++ b/build/torch-universal/triton_kernels/matmul_ogs_details/_common.py
@@ -7,9 +7,21 @@ from triton.tools.tensor_descriptor import TensorDescriptor
 # -----------------------------------------------------------------------------
 # Utilities
 # -----------------------------------------------------------------------------
+try:
+    _ver_str = getattr(triton, "__version__", "0.0.0").split("+")[0]
+    _parts = _ver_str.split(".")
+    _ver_tuple = tuple(int(p) for p in (_parts + ["0", "0", "0"])[:3])
+except Exception:
+    _ver_tuple = (0, 0, 0)
+if _ver_tuple > (3, 4, 0) and hasattr(triton, "constexpr_function"):
+    _constexpr_function = triton.constexpr_function
+else:
+    _constexpr_function = tl.constexpr_function
 
 
-@tl.constexpr_function
+
+
+@_constexpr_function
 def get_scaled_dot_format_string(dtype: tl.dtype):
     mapping = {
         tl.float16: "fp16",
diff --git a/build/torch-universal/triton_kernels/matmul_ogs_details/_finalize_matmul.py b/build/torch-universal/triton_kernels/matmul_ogs_details/_finalize_matmul.py
old mode 100644
new mode 100755
index c0369f71826cc2a41c4ea825370c54a94a034bf6..d1d2f3cd6add8427a8b9b23b2435264e5c45a8a4
--- a/build/torch-universal/triton_kernels/matmul_ogs_details/_finalize_matmul.py
+++ b/build/torch-universal/triton_kernels/matmul_ogs_details/_finalize_matmul.py
@@ -4,25 +4,26 @@ from ..numerics_details.flexpoint import float_to_flex, load_scale, update_scale
 from ..numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
 from ..target_info import cuda_capability_geq as _cuda_capability_geq
 from ..target_info import is_hip as _is_hip
+from ._common import _constexpr_function
 
 # fmt: off
 
-@tl.constexpr_function
+@_constexpr_function
 def is_hip():
     return _is_hip()
 
 
-@tl.constexpr_function
+@_constexpr_function
 def cuda_capability_geq(x, y):
     return _cuda_capability_geq(x, y)
 
 
-@tl.constexpr_function
+@_constexpr_function
 def log2(n):
     return len(bin(n)) - 3
 
 
-@tl.constexpr_function
+@_constexpr_function
 def _permute_to_end_order(n: int, axis: int):
     """
     Returns the order of the axes of a tensor to permute `axis` to the end.
@@ -105,7 +106,7 @@ def _finalize_matmul_launch_metadata(grid, kernel, args):
     return ret
 
 
-@tl.constexpr_function
+@_constexpr_function
 def _accumulate_f16_into_f32_and_track_absmax_ptx(n_inputs: int, src_type: str, absmax_reg_name: str | None):
     """
     Generate PTX code to take fp16 inputs and sum them into an f32 accumulator using mixed-precision
diff --git a/build/torch-universal/triton_kernels/matmul_ogs_details/_matmul_ogs.py b/build/torch-universal/triton_kernels/matmul_ogs_details/_matmul_ogs.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py b/build/torch-universal/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py
old mode 100644
new mode 100755
index 952d0c01286de0dff7de87172ba5b87fe9827721..e393d7467bbef15af7428b5c2f42726c89e0e17c
--- a/build/torch-universal/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py
+++ b/build/torch-universal/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py
@@ -12,14 +12,14 @@ from ..numerics_details.flexpoint import (
     compute_scale,
 )
 from ..numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
-from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string
+from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string, _constexpr_function
 
 
-@tl.constexpr_function
+@_constexpr_function
 def cuda_capability_geq(major, minor):
     return target_info.cuda_capability_geq(major, minor)
 
-@tl.constexpr_function
+@_constexpr_function
 def get_dtype(tensor_or_desc: tl.tensor | tl.tensor_descriptor) -> tl.dtype:
     if isinstance(tensor_or_desc, tl.tensor):
         return tensor_or_desc.dtype.element_ty
diff --git a/build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags.py b/build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags.py
old mode 100644
new mode 100755
index 32caabc5d4db505159381d64801550b1abd4921c..cbe846db00430cc80325559898a022c8b2b2a6ae
--- a/build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags.py
+++ b/build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags.py
@@ -4,7 +4,7 @@ from dataclasses import dataclass
 import triton
 from ..target_info import get_cdna_version
 import torch
-from .opt_flags_details import opt_flags_amd, opt_flags_nvidia
+from .opt_flags_details import opt_flags_amd, opt_flags_nvidia, opt_flags_intel
 
 
 @dataclass
@@ -30,6 +30,83 @@ class OptFlags:
             raise ValueError("Not supported")
+def make_default_opt_flags_intel(
+    out_dtype,
+    lhs_dtype,
+    rhs_dtype,
+    precision_config,
+    m,
+    n,
+    k,
+    routing_data,
+    can_use_persistent_tma,
+    can_use_fused_scatter,
+    enforce_bitwise_invariance,
+    epilogue_effective_itemsize,
+    constraints,
+):
+    constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages"]
+    assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
+    # tokens per expert
+    if routing_data is None:
+        tokens_per_expt = m
+    elif routing_data.expected_tokens_per_expt is None:
+        tokens_per_expt = max(1, m // routing_data.n_expts_tot)
+    else:
+        tokens_per_expt = routing_data.expected_tokens_per_expt
+    # pid swizzling
+    group_m = 8
+    xcd_swizzle = 1
+    # block_m
+    if constraints.get("block_m", None):
+        block_m = constraints["block_m"]
+    elif enforce_bitwise_invariance:
+        block_m = 128
+    else:
+        block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
+    # block n
+    block_n = opt_flags_intel.compute_block_n(n)
+    # is_persistent
+    is_persistent = constraints.get("is_persistent", False)
+    # block k
+    if constraints.get("block_k", None) is not None:
+        block_k = constraints["block_k"]
+    else:
+        block_k = opt_flags_intel.compute_block_k(k, is_persistent, precision_config)
+    # split_k
+    if constraints.get("split_k", None) is not None:
+        split_k = constraints["split_k"]
+    elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None:
+        split_k = 1
+    else:
+        estimated_actual_grid_size = opt_flags_intel.compute_grid_size(None, m, n, block_m, block_n)
+        split_k = opt_flags_intel.compute_split_k(block_k, k, estimated_actual_grid_size)
+
+    epilogue_subtile = constraints.get('epilogue_subtile', None)
+    if epilogue_subtile is None:
+        epilogue_subtile = 1
+
+    ret = OptFlags(
+        block_m=block_m,
+        block_n=block_n,
+        block_k=block_k,
+        num_warps=opt_flags_intel.compute_num_warps(block_m, block_n),
+        num_stages=constraints.get("num_stages", 2),
+        fused_scatter=constraints.get('fused_scatter', False),
+        group_m=group_m,
+        xcd_swizzle=xcd_swizzle,
+        w_cache_modifier=None,
+        split_k=split_k,
+        is_persistent=is_persistent,
+        epilogue_subtile=epilogue_subtile,
+        arch=None,
+        target_kernel_kwargs=dict(),
+        idle_sms=0,
+    )
+    # check constraints
+    assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
+    return ret
+
 
 
 def make_default_opt_flags_amd(
     out_dtype,
     lhs_dtype,
@@ -292,6 +369,8 @@ def make_opt_flags(
             enforce_bitwise_invariance, epilogue_effective_itemsize, _opt_flags_constraints]
     backend = triton.runtime.driver.active.get_current_target().backend
+    if backend == "xpu":
+        return make_default_opt_flags_intel(*args)
     if backend == "hip":
         return make_default_opt_flags_amd(*args)
     if backend == "cuda":
diff --git a/build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py b/build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_intel.py b/build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_intel.py
new file mode 100755
index 0000000000000000000000000000000000000000..0013c3ca666f4e7e4835562f620b1c1dd6ae9607
--- /dev/null
+++ b/build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_intel.py
@@ -0,0 +1,41 @@
+import torch
+import triton
+
+
+def compute_grid_size(routing_data, m, n, block_m, block_n):
+    if routing_data is not None:
+        grid_m = routing_data.n_blocks(m, block_m)
+    else:
+        grid_m = triton.cdiv(m, block_m)
+    grid_n = (n + block_n - 1) // block_n
+    return grid_m * grid_n
+
+
+def compute_block_n(n: int):
+    # block_n:
+    return max(16, min(128, triton.next_power_of_2(n)))
+
+
+def compute_block_k(k: int | None, is_persistent: bool, precision_config):
+    if k is not None:
+        block_k = max(32, min(128, triton.next_power_of_2(k)))
+    has_mx_weight_scale = precision_config is not None and precision_config.weight_scale is not None
+    if is_persistent and has_mx_weight_scale:
+        block_k = min(block_k, 128)
+    return block_k
+
+
+def compute_split_k(block_k: int, k: int | None, grid_size: int) -> int:
+    device_props = torch.xpu.get_device_properties(0)
+    n_sms = device_props.gpu_subslice_count
+    split_k = n_sms // grid_size
+    if k is not None:
+        # avoid split_k for small k
+        num_block_k = triton.cdiv(k, block_k)
+        split_k = min(split_k, num_block_k // 4)
+    split_k = max(split_k, 1)
+    return split_k
+
+
+def compute_num_warps(block_m, block_n):
+    return max(block_m * block_n // 4096, 4)
\ No newline at end of file
diff --git a/build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py b/build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/numerics.py b/build/torch-universal/triton_kernels/numerics.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/numerics_details/__init__.py b/build/torch-universal/triton_kernels/numerics_details/__init__.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/numerics_details/flexpoint.py b/build/torch-universal/triton_kernels/numerics_details/flexpoint.py
old mode 100644
new mode 100755
index daa78496d95ec9206cdf183a35dea57122dfbb79..cebe7f721f6945171602e771c951ceee55bcfd27
--- a/build/torch-universal/triton_kernels/numerics_details/flexpoint.py
+++ b/build/torch-universal/triton_kernels/numerics_details/flexpoint.py
@@ -1,5 +1,6 @@
 from ..numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
 from .. import target_info
+from ..matmul_ogs_details._common import _constexpr_function
 import triton
 import triton.language as tl
 
@@ -52,7 +53,7 @@ def rcp_max_finite(dtype):
         tl.static_assert(tl.constexpr(False), f"{dtype} not supported in flexpoint")
 
 
-@tl.constexpr_function
+@_constexpr_function
 def cuda_capability_geq(major, minor):
     return target_info.cuda_capability_geq(major, minor)
 
diff --git a/build/torch-universal/triton_kernels/numerics_details/mxfp.py b/build/torch-universal/triton_kernels/numerics_details/mxfp.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py b/build/torch-universal/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py b/build/torch-universal/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/proton_opts.py b/build/torch-universal/triton_kernels/proton_opts.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/reduction_details/reduce_bitmatrix.py b/build/torch-universal/triton_kernels/reduction_details/reduce_bitmatrix.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/routing.py b/build/torch-universal/triton_kernels/routing.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/routing_details/_expt_data.py b/build/torch-universal/triton_kernels/routing_details/_expt_data.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/routing_details/_routing_compute.py b/build/torch-universal/triton_kernels/routing_details/_routing_compute.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/specialize.py b/build/torch-universal/triton_kernels/specialize.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/swiglu.py b/build/torch-universal/triton_kernels/swiglu.py
old mode 100644
new mode 100755
index 773070db92e9d03114c0e2b341c1995b969eba85..c0257260e54e0020c296b66b9941dda384848431
--- a/build/torch-universal/triton_kernels/swiglu.py
+++ b/build/torch-universal/triton_kernels/swiglu.py
@@ -35,7 +35,7 @@ class SwiGLU(torch.autograd.Function):
         # optimization hyperparameters
         BLOCK_M, BLOCK_N = 32 // a.itemsize, 128
         num_warps = 4
-        kwargs = {'maxnreg': 64} if not target_info.is_hip() else {}
+        kwargs = {'maxnreg': 64} if not target_info.is_hip() and not target_info.is_xpu() else {}
         # launch semi-persistent kernel
         N_BLOCKS = triton.cdiv(N // 2, BLOCK_N)
         num_sms = target_info.num_sms()
diff --git a/build/torch-universal/triton_kernels/swiglu_details/_swiglu.py b/build/torch-universal/triton_kernels/swiglu_details/_swiglu.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/target_info.py b/build/torch-universal/triton_kernels/target_info.py
old mode 100644
new mode 100755
index 9beae7108e8d94e1bf04747e36c7d01a5650e5ca..c1912dc26f0b77f6172f0f6ec36597539b6cd224
--- a/build/torch-universal/triton_kernels/target_info.py
+++ b/build/torch-universal/triton_kernels/target_info.py
@@ -1,54 +1,70 @@
 import torch
 import triton
 
-cached_capabilities = {}
+from .matmul_ogs_details._common import _constexpr_function
+from triton.runtime import driver
 
 
+def current_target():
+    try:
+        active_driver = driver.active
+    except RuntimeError:
+        # If there is no active driver, return None
+        return None
+    return active_driver.get_current_target()
+current_target.__triton_builtin__ = True
+
+
+@_constexpr_function
 def is_cuda():
-    if "is_cuda" not in cached_capabilities:
-        target = triton.runtime.driver.active.get_current_target()
-        cached_capabilities["is_cuda"] = False if target is None else target.backend == "cuda"
-    return cached_capabilities["is_cuda"]
+    target = current_target()
+    return target is not None and target.backend == "cuda"
 
 
+@_constexpr_function
 def is_hip():
-    if "is_hip" not in cached_capabilities:
-        cached_capabilities["is_hip"] = torch.cuda.is_available() and bool(torch.version.hip)
-    return cached_capabilities["is_hip"]
+    target = current_target()
+    return target is not None and target.backend == "hip"
+
+
+@_constexpr_function
+def is_xpu():
+    target = current_target()
+    return target is not None and target.backend == "xpu"
 
 
+@_constexpr_function
 def is_hip_cdna3():
-    if "is_hip_cdna3" not in cached_capabilities:
-        target = triton.runtime.driver.active.get_current_target()
-        cached_capabilities["is_hip_cdna3"] = (target is not None and target.backend == 'hip'
-                                               and target.arch == 'gfx942')
-    return cached_capabilities["is_hip_cdna3"]
+    target = current_target()
+    return target is not None and target.arch == "gfx942"
 
 
+@_constexpr_function
 def is_hip_cdna4():
-    if "is_hip_cdna4" not in cached_capabilities:
-        target = triton.runtime.driver.active.get_current_target()
-        cached_capabilities["is_hip_cdna4"] = (target is not None and target.backend == 'hip'
-                                               and target.arch == 'gfx950')
-    return cached_capabilities["is_hip_cdna4"]
+    target = current_target()
+    return target is not None and target.arch == "gfx950"
 
 
+@_constexpr_function
 def cuda_capability_geq(major, minor=0):
     """
     Determines whether we have compute capability >= (major, minor) and
     returns this as a constexpr boolean. This can be used for guarding
    inline asm implementations that require a certain compute capability.
     """
-    if is_hip():
+    """
+    Determines whether we have compute capability >= (major, minor) and
+    returns this as a constexpr boolean. This can be used for guarding
+    inline asm implementations that require a certain compute capability.
+    """
+    target = current_target()
+    if target is None or target.backend != "cuda":
         return False
-    if "cuda" not in cached_capabilities:
-        if torch.cuda.is_available():
-            cached_capabilities["cuda"] = torch.cuda.get_device_capability()
-        else:
-            cached_capabilities["cuda"] = (0, 0)
-    return cached_capabilities["cuda"] >= (major, minor)
+    assert isinstance(target.arch, int)
+    return target.arch >= major * 10 + minor
 
 
+@_constexpr_function
 def get_cdna_version():
     """
     Gets the AMD architecture version, i.e. CDNA3 or CDNA4, currently
@@ -65,13 +81,18 @@ def get_cdna_version():
     return -1
 
 
+@_constexpr_function
 def has_tma_gather():
     return cuda_capability_geq(10, 0)
 
 
+@_constexpr_function
 def has_native_mxfp():
     return cuda_capability_geq(10, 0)
 
 
 def num_sms():
-    return torch.cuda.get_device_properties(0).multi_processor_count
+    if is_cuda():
+        return torch.cuda.get_device_properties(0).multi_processor_count
+    if is_xpu():
+        return torch.xpu.get_device_properties(0).max_compute_units
diff --git a/build/torch-universal/triton_kernels/tensor.py b/build/torch-universal/triton_kernels/tensor.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/tensor_details/layout.py b/build/torch-universal/triton_kernels/tensor_details/layout.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/tensor_details/layout_details/base.py b/build/torch-universal/triton_kernels/tensor_details/layout_details/base.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/tensor_details/layout_details/blackwell_scale.py b/build/torch-universal/triton_kernels/tensor_details/layout_details/blackwell_scale.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/tensor_details/layout_details/hopper_scale.py b/build/torch-universal/triton_kernels/tensor_details/layout_details/hopper_scale.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/tensor_details/layout_details/hopper_value.py b/build/torch-universal/triton_kernels/tensor_details/layout_details/hopper_value.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/tensor_details/layout_details/strided.py b/build/torch-universal/triton_kernels/tensor_details/layout_details/strided.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/testing.py b/build/torch-universal/triton_kernels/testing.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/topk.py b/build/torch-universal/triton_kernels/topk.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/topk_details/__init__.py b/build/torch-universal/triton_kernels/topk_details/__init__.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/topk_details/_topk_backward.py b/build/torch-universal/triton_kernels/topk_details/_topk_backward.py
old mode 100644
new mode 100755
diff --git a/build/torch-universal/triton_kernels/topk_details/_topk_forward.py b/build/torch-universal/triton_kernels/topk_details/_topk_forward.py
old mode 100644
new mode 100755
diff --git a/result b/result
new file mode 120000
index 0000000000000000000000000000000000000000..ea0ae38812b2d257c05983935c051c55bf57aac3
--- /dev/null
+++ b/result
@@ -0,0 +1 @@
+/nix/store/jkq2iihqbwik7pdn215w2ysgzhsgj3sc-torch-ext-bundle
\ No newline at end of file
diff --git a/tests/conftest.py b/tests/conftest.py
index 6a8982884245200f2d7d5aa837b0846518bd53b3..c8458b95ddb44ebce272d80ffd0ba2f4274a83ac 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,5 +1,5 @@
 import pytest
-
+import triton
 
 def pytest_addoption(parser):
     parser.addoption("--device", action="store", default="cuda")
@@ -12,8 +12,19 @@ def device(request):
 
 @pytest.fixture
 def fresh_knobs(monkeypatch):
+    try:
+        _ver_str = getattr(triton, "__version__", "0.0.0").split("+")[0]
+        _parts = _ver_str.split(".")
+        _ver_tuple = tuple(int(p) for p in (_parts + ["0", "0", "0"])[:3])
+    except Exception:
+        _ver_tuple = (0, 0, 0)
+
     from triton._internal_testing import _fresh_knobs_impl
-    fresh_function, reset_function = _fresh_knobs_impl(monkeypatch)
+    if _ver_tuple > (3, 4, 0):
+        fresh_function, reset_function = _fresh_knobs_impl()
+    else:
+        fresh_function, reset_function = _fresh_knobs_impl(monkeypatch)
+
     try:
         yield fresh_function()
     finally:
diff --git a/tests/test_matmul.py b/tests/test_matmul.py
index 7b10f317b363d844adf004bf49925f5b935785aa..e9f8473103a85a749b532fc0a9b70720d4043de7 100644
--- a/tests/test_matmul.py
+++ b/tests/test_matmul.py
@@ -20,7 +20,7 @@ from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_m
 # testing utilities
 from triton_kernels.testing import assert_close, compute_actual_scale
 # target-specific utilities
-from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4
+from triton_kernels.target_info import is_hip, is_xpu, is_hip_cdna3, is_cuda, is_hip_cdna4
 
 # ---------------
 # initialize data
@@ -70,7 +70,7 @@ def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_sh
     if mode == 'batched' or (not has_y_gammas) or (has_y_gammas and (gindx is not None) and act_dtype.itemsize >= 2):
         gs0 = None
         gs1 = None
-    if "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10:
+    if is_cuda() and "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10:
         w = w.transpose(-1, -2).contiguous().transpose(-1, -2)
     return x, w, bias, gs0, gs1
 
@@ -291,14 +291,15 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     if hbm_swizzling:
         if is_hip():
             pytest.skip("NYI. HBM swizzling just implemented for CUDA.")
-        if torch.cuda.get_device_capability()[0] < 9:
-            pytest.skip("NYI. Ampere swizzling.")
-        if torch.cuda.get_device_capability()[0] < 10:
-            if "mxfloat4" not in weight_dtype_str:
-                pytest.skip("NYI. Hopper swizzling just implemented for mxfp4.")
-            if k % 64 != 0 or n % 64 != 0:
-                # Automatic padding not implemented for Hopper swizzle
-                pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")
+        if is_cuda():
+            if torch.cuda.get_device_capability()[0] < 9:
+                pytest.skip("NYI. Ampere swizzling.")
+            if torch.cuda.get_device_capability()[0] < 10:
+                if "mxfloat4" not in weight_dtype_str:
+                    pytest.skip("NYI. Hopper swizzling just implemented for mxfp4.")
+                if k % 64 != 0 or n % 64 != 0:
+                    # Automatic padding not implemented for Hopper swizzle
+                    pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")
 
     # launch metadata for batched / mx types may not work yet.
     test_launch_metadata = (mode == "ragged") and ("mx" not in weight_dtype_str)
@@ -306,7 +307,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     torch.manual_seed(0)
 
     block_k = None
-    if is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10:
+    if is_cuda() and is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10:
         # Override block_k for testing correctness. The default is temporarily 128 for
         # performance reasons which doesn't work with persistent matmul.
         # TODO: revisit when Triton is better for H100 + MXFP4
@@ -462,7 +463,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     round_y = lambda y: (y / y_scale).to(act_dtype).to(torch.float32) * y_scale if sep_scatter else y
     ref_y = matmul_ogs_torch(x_ref, w_ref, bias_ref,  #
-                             rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref)
+                             rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref, device=device)
     scale = lambda val, scal: val if scal is None else val / scal
     if n_expt_shards > 1:
         if do_scatter:
diff --git a/torch-ext/triton_kernels/__init__.py b/torch-ext/triton_kernels/__init__.py
index 4f596996a9d676f61636dcb5d84f7a52d02c684e..1185492a0f0e67ce68f54754d409b811a9d43687 100644
--- a/torch-ext/triton_kernels/__init__.py
+++ b/torch-ext/triton_kernels/__init__.py
@@ -1,4 +1,11 @@
 # Make sure to add this in the build folder as this won't build if we put that here
+
 # from . import matmul_ogs, tensor_details, numerics_details, tensor, swiglu, routing
 # __all__ = ["matmul_ogs" , "tensor_details", "numerics_details", "tensor", "swiglu", "routing"]
 
+
+# Then, run the following code to build the kernels:
+# docker run --rm \
+#   -v $(pwd):/app \
+#   -w /app \
+#   ghcr.io/huggingface/kernel-builder:main
diff --git a/torch-ext/triton_kernels/matmul_ogs.py b/torch-ext/triton_kernels/matmul_ogs.py
index bde692a04d47d3e0c851d837c99c0473ca65e74b..e2dc00246646a9914bcb63a754788c99218540c5 100644
--- a/torch-ext/triton_kernels/matmul_ogs.py
+++ b/torch-ext/triton_kernels/matmul_ogs.py
@@ -602,6 +602,7 @@ def matmul_ogs_torch(x, w, bias,
                      betas = None,
                      gammas = None,
                      round_x = None, round_y = None,
+                     device: str = "cuda",
                      ):
     is_input_batched = x.ndim == 3
     assert x.dtype.itemsize > 1
@@ -641,7 +642,7 @@ def matmul_ogs_torch(x, w, bias,
         else:
             idx = gather_indx.src_indx[lo:hi] // n_expts_act
         batch = i if is_input_batched else 0
-        out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device="cuda")).float(),
+        out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device=device)).float(),
                            w[i].float())
         if bias is not None:
             out += bias[i, :] if betas is None else bias[i, :] * betas[lo:hi, None]
diff --git a/torch-ext/triton_kernels/matmul_ogs_details/_common.py b/torch-ext/triton_kernels/matmul_ogs_details/_common.py
index 25755d3105ffe2b119f9fc65c96be858a987fd3c..fde71f8a936ae9c6cefc44bd83aae85f900c03a5 100644
--- a/torch-ext/triton_kernels/matmul_ogs_details/_common.py
+++ b/torch-ext/triton_kernels/matmul_ogs_details/_common.py
@@ -7,9 +7,21 @@ from triton.tools.tensor_descriptor import TensorDescriptor
 # -----------------------------------------------------------------------------
 # Utilities
 # -----------------------------------------------------------------------------
+try:
+    _ver_str = getattr(triton, "__version__", "0.0.0").split("+")[0]
+    _parts = _ver_str.split(".")
+    _ver_tuple = tuple(int(p) for p in (_parts + ["0", "0", "0"])[:3])
+except Exception:
+    _ver_tuple = (0, 0, 0)
+if _ver_tuple > (3, 4, 0) and hasattr(triton, "constexpr_function"):
+    _constexpr_function = triton.constexpr_function
+else:
+    _constexpr_function = tl.constexpr_function
 
 
-@tl.constexpr_function
+
+
+@_constexpr_function
 def get_scaled_dot_format_string(dtype: tl.dtype):
     mapping = {
         tl.float16: "fp16",
diff --git a/torch-ext/triton_kernels/matmul_ogs_details/_finalize_matmul.py b/torch-ext/triton_kernels/matmul_ogs_details/_finalize_matmul.py
index c0369f71826cc2a41c4ea825370c54a94a034bf6..d1d2f3cd6add8427a8b9b23b2435264e5c45a8a4 100644
--- a/torch-ext/triton_kernels/matmul_ogs_details/_finalize_matmul.py
+++ b/torch-ext/triton_kernels/matmul_ogs_details/_finalize_matmul.py
@@ -4,25 +4,26 @@ from ..numerics_details.flexpoint import float_to_flex, load_scale, update_scale
 from ..numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
 from ..target_info import cuda_capability_geq as _cuda_capability_geq
 from ..target_info import is_hip as _is_hip
+from ._common import _constexpr_function
 
 # fmt: off
 
-@tl.constexpr_function
+@_constexpr_function
 def is_hip():
     return _is_hip()
 
 
-@tl.constexpr_function
+@_constexpr_function
 def cuda_capability_geq(x, y):
     return _cuda_capability_geq(x, y)
 
 
-@tl.constexpr_function
+@_constexpr_function
 def log2(n):
     return len(bin(n)) - 3
 
 
-@tl.constexpr_function
+@_constexpr_function
 def _permute_to_end_order(n: int, axis: int):
     """
     Returns the order of the axes of a tensor to permute `axis` to the end.
@@ -105,7 +106,7 @@ def _finalize_matmul_launch_metadata(grid, kernel, args):
     return ret
 
 
-@tl.constexpr_function
+@_constexpr_function
 def _accumulate_f16_into_f32_and_track_absmax_ptx(n_inputs: int, src_type: str, absmax_reg_name: str | None):
     """
     Generate PTX code to take fp16 inputs and sum them into an f32 accumulator using mixed-precision
diff --git a/torch-ext/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py b/torch-ext/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py
index 952d0c01286de0dff7de87172ba5b87fe9827721..e393d7467bbef15af7428b5c2f42726c89e0e17c 100644
--- a/torch-ext/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py
+++ b/torch-ext/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py
@@ -12,14 +12,14 @@ from ..numerics_details.flexpoint import (
     compute_scale,
 )
 from ..numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
-from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string
+from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string, _constexpr_function
 
 
-@tl.constexpr_function
+@_constexpr_function
 def cuda_capability_geq(major, minor):
     return target_info.cuda_capability_geq(major, minor)
 
-@tl.constexpr_function
+@_constexpr_function
 def get_dtype(tensor_or_desc: tl.tensor | tl.tensor_descriptor) -> tl.dtype:
     if isinstance(tensor_or_desc, tl.tensor):
         return tensor_or_desc.dtype.element_ty
diff --git a/torch-ext/triton_kernels/matmul_ogs_details/opt_flags.py b/torch-ext/triton_kernels/matmul_ogs_details/opt_flags.py
index 32caabc5d4db505159381d64801550b1abd4921c..cbe846db00430cc80325559898a022c8b2b2a6ae 100644
--- a/torch-ext/triton_kernels/matmul_ogs_details/opt_flags.py
+++ b/torch-ext/triton_kernels/matmul_ogs_details/opt_flags.py
@@ -4,7 +4,7 @@ from dataclasses import dataclass
 import triton
 from ..target_info import get_cdna_version
 import torch
-from .opt_flags_details import opt_flags_amd, opt_flags_nvidia
+from .opt_flags_details import opt_flags_amd, opt_flags_nvidia, opt_flags_intel
 
 
 @dataclass
@@ -30,6 +30,83 @@ class OptFlags:
             raise ValueError("Not supported")
+def make_default_opt_flags_intel(
+    out_dtype,
+    lhs_dtype,
+    rhs_dtype,
+    precision_config,
+    m,
+    n,
+    k,
+    routing_data,
+    can_use_persistent_tma,
+    can_use_fused_scatter,
+    enforce_bitwise_invariance,
+    epilogue_effective_itemsize,
+    constraints,
+):
+    constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages"]
+    assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
+    # tokens per expert
+    if routing_data is None:
+        tokens_per_expt = m
+    elif routing_data.expected_tokens_per_expt is None:
+        tokens_per_expt = max(1, m // routing_data.n_expts_tot)
+    else:
+        tokens_per_expt = routing_data.expected_tokens_per_expt
+    # pid swizzling
+    group_m = 8
+    xcd_swizzle = 1
+    # block_m
+    if constraints.get("block_m", None):
+        block_m = constraints["block_m"]
+    elif enforce_bitwise_invariance:
+        block_m = 128
+    else:
+        block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
+    # block n
+    block_n = opt_flags_intel.compute_block_n(n)
+    # is_persistent
+    is_persistent = constraints.get("is_persistent", False)
+    # block k
+    if constraints.get("block_k", None) is not None:
+        block_k = constraints["block_k"]
+    else:
+        block_k = opt_flags_intel.compute_block_k(k, is_persistent, precision_config)
+    # split_k
+    if constraints.get("split_k", None) is not None:
+        split_k = constraints["split_k"]
+    elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None:
+        split_k = 1
+    else:
+        estimated_actual_grid_size = opt_flags_intel.compute_grid_size(None, m, n, block_m, block_n)
+        split_k = opt_flags_intel.compute_split_k(block_k, k, estimated_actual_grid_size)
+
+    epilogue_subtile = constraints.get('epilogue_subtile', None)
+    if epilogue_subtile is None:
+        epilogue_subtile = 1
+
+    ret = OptFlags(
+        block_m=block_m,
+        block_n=block_n,
+        block_k=block_k,
+        num_warps=opt_flags_intel.compute_num_warps(block_m, block_n),
+        num_stages=constraints.get("num_stages", 2),
+        fused_scatter=constraints.get('fused_scatter', False),
+        group_m=group_m,
+        xcd_swizzle=xcd_swizzle,
+        w_cache_modifier=None,
+        split_k=split_k,
+        is_persistent=is_persistent,
+        epilogue_subtile=epilogue_subtile,
+        arch=None,
+        target_kernel_kwargs=dict(),
+        idle_sms=0,
+    )
+    # check constraints
+    assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
+    return ret
+
 
 
 def make_default_opt_flags_amd(
     out_dtype,
     lhs_dtype,
@@ -292,6 +369,8 @@ def make_opt_flags(
             enforce_bitwise_invariance, epilogue_effective_itemsize, _opt_flags_constraints]
     backend = triton.runtime.driver.active.get_current_target().backend
+    if backend == "xpu":
+        return make_default_opt_flags_intel(*args)
     if backend == "hip":
         return make_default_opt_flags_amd(*args)
     if backend == "cuda":
diff --git a/torch-ext/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_intel.py b/torch-ext/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_intel.py
new file mode 100644
index 0000000000000000000000000000000000000000..0013c3ca666f4e7e4835562f620b1c1dd6ae9607
--- /dev/null
+++ b/torch-ext/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_intel.py
@@ -0,0 +1,41 @@
+import torch
+import triton
+
+
+def compute_grid_size(routing_data, m, n, block_m, block_n):
+    if routing_data is not None:
+        grid_m = routing_data.n_blocks(m, block_m)
+    else:
+        grid_m = triton.cdiv(m, block_m)
+    grid_n = (n + block_n - 1) // block_n
+    return grid_m * grid_n
+
+
+def compute_block_n(n: int):
+    # block_n:
+    return max(16, min(128, triton.next_power_of_2(n)))
+
+
+def compute_block_k(k: int | None, is_persistent: bool, precision_config):
+    if k is not None:
+        block_k = max(32, min(128, triton.next_power_of_2(k)))
+    has_mx_weight_scale = precision_config is not None and precision_config.weight_scale is not None
+    if is_persistent and has_mx_weight_scale:
+        block_k = min(block_k, 128)
+    return block_k
+
+
+def compute_split_k(block_k: int, k: int | None, grid_size: int) -> int:
+    device_props = torch.xpu.get_device_properties(0)
+    n_sms = device_props.gpu_subslice_count
+    split_k = n_sms // grid_size
+    if k is not None:
+        # avoid split_k for small k
+        num_block_k = triton.cdiv(k, block_k)
+        split_k = min(split_k, num_block_k // 4)
+    split_k = max(split_k, 1)
+    return split_k
+
+
+def compute_num_warps(block_m, block_n):
+    return max(block_m * block_n // 4096, 4)
\ No newline at end of file
diff --git a/torch-ext/triton_kernels/numerics_details/flexpoint.py b/torch-ext/triton_kernels/numerics_details/flexpoint.py
index daa78496d95ec9206cdf183a35dea57122dfbb79..cebe7f721f6945171602e771c951ceee55bcfd27 100644
--- a/torch-ext/triton_kernels/numerics_details/flexpoint.py
+++ b/torch-ext/triton_kernels/numerics_details/flexpoint.py
@@ -1,5 +1,6 @@
 from ..numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
 from .. import target_info
+from ..matmul_ogs_details._common import _constexpr_function
 import triton
 import triton.language as tl
 
@@ -52,7 +53,7 @@ def rcp_max_finite(dtype):
         tl.static_assert(tl.constexpr(False), f"{dtype} not supported in flexpoint")
 
 
-@tl.constexpr_function
+@_constexpr_function
 def cuda_capability_geq(major, minor):
     return target_info.cuda_capability_geq(major, minor)
 
diff --git a/torch-ext/triton_kernels/swiglu.py b/torch-ext/triton_kernels/swiglu.py
index 773070db92e9d03114c0e2b341c1995b969eba85..c0257260e54e0020c296b66b9941dda384848431 100644
--- a/torch-ext/triton_kernels/swiglu.py
+++ b/torch-ext/triton_kernels/swiglu.py
@@ -35,7 +35,7 @@ class SwiGLU(torch.autograd.Function):
         # optimization hyperparameters
         BLOCK_M, BLOCK_N = 32 // a.itemsize, 128
         num_warps = 4
-        kwargs = {'maxnreg': 64} if not target_info.is_hip() else {}
+        kwargs = {'maxnreg': 64} if not target_info.is_hip() and not target_info.is_xpu() else {}
        # launch semi-persistent kernel
         N_BLOCKS = triton.cdiv(N // 2, BLOCK_N)
         num_sms = target_info.num_sms()
diff --git a/torch-ext/triton_kernels/target_info.py b/torch-ext/triton_kernels/target_info.py
index 9beae7108e8d94e1bf04747e36c7d01a5650e5ca..c1912dc26f0b77f6172f0f6ec36597539b6cd224 100644
--- a/torch-ext/triton_kernels/target_info.py
+++ b/torch-ext/triton_kernels/target_info.py
@@ -1,54 +1,70 @@
 import torch
 import triton
 
-cached_capabilities = {}
+from .matmul_ogs_details._common import _constexpr_function
+from triton.runtime import driver
 
 
+def current_target():
+    try:
+        active_driver = driver.active
+    except RuntimeError:
+        # If there is no active driver, return None
+        return None
+    return active_driver.get_current_target()
+current_target.__triton_builtin__ = True
+
+
+@_constexpr_function
 def is_cuda():
-    if "is_cuda" not in cached_capabilities:
-        target = triton.runtime.driver.active.get_current_target()
-        cached_capabilities["is_cuda"] = False if target is None else target.backend == "cuda"
-    return cached_capabilities["is_cuda"]
+    target = current_target()
+    return target is not None and target.backend == "cuda"
 
 
+@_constexpr_function
 def is_hip():
-    if "is_hip" not in cached_capabilities:
-        cached_capabilities["is_hip"] = torch.cuda.is_available() and bool(torch.version.hip)
-    return cached_capabilities["is_hip"]
+    target = current_target()
+    return target is not None and target.backend == "hip"
+
+
+@_constexpr_function
+def is_xpu():
+    target = current_target()
+    return target is not None and target.backend == "xpu"
 
 
+@_constexpr_function
 def is_hip_cdna3():
-    if "is_hip_cdna3" not in cached_capabilities:
-        target = triton.runtime.driver.active.get_current_target()
-        cached_capabilities["is_hip_cdna3"] = (target is not None and target.backend == 'hip'
-                                               and target.arch == 'gfx942')
-    return cached_capabilities["is_hip_cdna3"]
+    target = current_target()
+    return target is not None and target.arch == "gfx942"
 
 
+@_constexpr_function
 def is_hip_cdna4():
-    if "is_hip_cdna4" not in cached_capabilities:
-        target = triton.runtime.driver.active.get_current_target()
-        cached_capabilities["is_hip_cdna4"] = (target is not None and target.backend == 'hip'
-                                               and target.arch == 'gfx950')
-    return cached_capabilities["is_hip_cdna4"]
+    target = current_target()
+    return target is not None and target.arch == "gfx950"
 
 
+@_constexpr_function
 def cuda_capability_geq(major, minor=0):
     """
     Determines whether we have compute capability >= (major, minor) and
     returns this as a constexpr boolean. This can be used for guarding
     inline asm implementations that require a certain compute capability.
     """
-    if is_hip():
+    """
+    Determines whether we have compute capability >= (major, minor) and
+    returns this as a constexpr boolean. This can be used for guarding
+    inline asm implementations that require a certain compute capability.
+    """
+    target = current_target()
+    if target is None or target.backend != "cuda":
         return False
-    if "cuda" not in cached_capabilities:
-        if torch.cuda.is_available():
-            cached_capabilities["cuda"] = torch.cuda.get_device_capability()
-        else:
-            cached_capabilities["cuda"] = (0, 0)
-    return cached_capabilities["cuda"] >= (major, minor)
+    assert isinstance(target.arch, int)
+    return target.arch >= major * 10 + minor
 
 
+@_constexpr_function
 def get_cdna_version():
     """
     Gets the AMD architecture version, i.e. CDNA3 or CDNA4, currently
@@ -65,13 +81,18 @@ def get_cdna_version():
     return -1
 
 
+@_constexpr_function
 def has_tma_gather():
     return cuda_capability_geq(10, 0)
 
 
+@_constexpr_function
 def has_native_mxfp():
     return cuda_capability_geq(10, 0)
 
 
 def num_sms():
-    return torch.cuda.get_device_properties(0).multi_processor_count
+    if is_cuda():
+        return torch.cuda.get_device_properties(0).multi_processor_count
+    if is_xpu():
+        return torch.xpu.get_device_properties(0).max_compute_units