# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

"""
This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
See the original Unsloth repository at https://github.com/unslothai/unsloth.

The following line
https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23
is based on code from Unsloth, located at:
https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43

Modifications made by Yanning Chen, 2024.
"""

import functools
import importlib
import importlib.metadata
import operator
from typing import Callable

import torch
import triton
import triton.language as tl
from packaging.version import Version


def is_hip() -> bool:
    """Return True when the running PyTorch build targets HIP (ROCm)."""
    return torch.version.hip is not None


def ensure_contiguous(fn):
    """
    Decorator that makes every tensor argument contiguous before calling ``fn``.

    The first positional argument (``ctx``) is passed through untouched; every
    other positional argument and every keyword-argument value that is a
    ``torch.Tensor`` is replaced by ``.contiguous()``. Non-tensor values are
    forwarded as-is. Presumably intended for ``torch.autograd.Function``
    ``forward``/``backward`` static methods -- confirm at call sites.
    """

    @functools.wraps(fn)
    def wrapper(ctx, *args, **kwargs):
        def maybe_to_contiguous(x):
            return x.contiguous() if isinstance(x, torch.Tensor) else x

        args = [maybe_to_contiguous(arg) for arg in args]
        kwargs = {k: maybe_to_contiguous(v) for k, v in kwargs.items()}
        return fn(ctx, *args, **kwargs)

    return wrapper


def calculate_settings(n):
    """
    Choose a Triton block size and warp count for a row of ``n`` elements.

    Args:
        n: Number of elements a single program must cover.

    Returns:
        ``(BLOCK_SIZE, num_warps)``, where ``BLOCK_SIZE`` is
        ``triton.next_power_of_2(n)`` and ``num_warps`` grows with it:
        4 below 2048, 8 from 2048, 16 from 8192, and 32 from 32768
        (16 on HIP builds).

    Raises:
        RuntimeError: if ``BLOCK_SIZE`` exceeds 65536.
    """
    # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43
    MAX_FUSED_SIZE = 65536
    BLOCK_SIZE = triton.next_power_of_2(n)
    if BLOCK_SIZE > MAX_FUSED_SIZE:
        raise RuntimeError(
            f"Cannot launch Triton kernel since n = {n} exceeds "
            f"the recommended Triton blocksize = {MAX_FUSED_SIZE}."
        )

    num_warps = 4
    if BLOCK_SIZE >= 32768:
        # NOTE(review): presumably capped at 16 on HIP because AMD wavefronts are
        # 64 lanes wide, so 16 of them already reach 1024 threads -- confirm.
        num_warps = 32 if not is_hip() else 16
    elif BLOCK_SIZE >= 8192:
        num_warps = 16
    elif BLOCK_SIZE >= 2048:
        num_warps = 8

    return BLOCK_SIZE, num_warps


def compare_version(package: str, operator: Callable, target: str):
    """
    Compare the installed version of ``package`` against ``target``.

    Args:
        package: Importable module name, e.g. ``"torch"``.
        operator: Binary comparison such as ``operator.ge``; it is called as
            ``operator(installed_version, Version(target))``.
        target: Version string to compare against.

    Returns:
        The result of the comparison. Returns False when the package cannot be
        imported, or when its version cannot be determined.
    """
    try:
        pkg = importlib.import_module(package)
    except ImportError:
        return False

    raw_version = getattr(pkg, "__version__", None)
    if raw_version is None:
        # Not every package exposes __version__; fall back to the installed
        # distribution metadata (assumes distribution name == import name).
        try:
            raw_version = importlib.metadata.version(package)
        except importlib.metadata.PackageNotFoundError:
            # Version unknown: treat the requirement as unsatisfied, matching
            # the "not installed" case above instead of crashing.
            return False

    pkg_version = Version(raw_version)
    return operator(pkg_version, Version(target))


def get_amp_custom_fwd_bwd() -> Callable:
    """
    Return the ``(custom_fwd, custom_bwd)`` AMP decorators for this torch version.

    On torch >= 2.4.0 these are ``torch.amp.custom_fwd``/``custom_bwd`` with
    ``device_type="cuda"`` pre-bound via ``functools.partial``. Older versions
    fall back to the ``torch.cuda.amp`` decorators.
    """
    if compare_version("torch", operator.ge, "2.4.0"):
        return (
            functools.partial(torch.amp.custom_fwd, device_type="cuda"),
            functools.partial(torch.amp.custom_bwd, device_type="cuda"),
        )
    return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd


# Resolved once at import time so kernels can decorate with them directly.
amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()


# Maps the torch floating dtypes this module supports to Triton dtypes.
torch_to_triton_dtype = {
    torch.float32: tl.float32,
    torch.float16: tl.float16,
    torch.bfloat16: tl.bfloat16,
}


@triton.jit
def element_mul_kernel(
    X_ptr,
    X_stride,
    grad_output_ptr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Multiply every element of one row of X, in place, by a single scalar.

    Each program instance handles the row starting at
    ``X_ptr + program_id * X_stride`` and walks its ``n_cols`` columns in
    chunks of BLOCK_SIZE. The scalar is the single value loaded from
    grad_output_ptr.

    Parameters:
    X_ptr: Pointer to the input tensor (overwritten in place).
    X_stride (int): Offset, in elements, between the starts of consecutive rows.
    grad_output_ptr: Pointer to the gradient output value (one element is read).
    n_cols (int): The number of columns in the input tensor.
    BLOCK_SIZE (int): Number of columns processed per loop iteration.
    """
    # Convert the program ID to int64 so program_id * X_stride does not
    # overflow 32-bit arithmetic on large tensors.
    program_id = tl.program_id(0).to(tl.int64)

    # Locate the start of this program's row.
    X_ptr += program_id * X_stride

    # Load the scalar multiplier once; it is reused for every chunk.
    grad_output = tl.load(grad_output_ptr)

    # Perform the element-wise multiplication chunk by chunk.
    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        # Compute the bounds mask once and share it between load and store;
        # the last chunk may extend past n_cols.
        mask = X_offsets < n_cols
        X_block = tl.load(X_ptr + X_offsets, mask=mask)
        tl.store(X_ptr + X_offsets, X_block * grad_output, mask=mask)