# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

"""
This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
See the original Unsloth repository at https://github.com/unslothai/unsloth.

The following line
https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23
is based on code from Unsloth, located at:
https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43

Modifications made by Yanning Chen, 2024.
"""

import functools
import importlib
import operator
from typing import Callable

import torch
import triton
import triton.language as tl
from packaging.version import Version


def is_hip() -> bool:
    return torch.version.hip is not None


def ensure_contiguous(fn):
    @functools.wraps(fn)
    def wrapper(ctx, *args, **kwargs):
        def maybe_to_contiguous(x):
            return x.contiguous() if isinstance(x, torch.Tensor) else x

        args = [maybe_to_contiguous(arg) for arg in args]
        kwargs = {k: maybe_to_contiguous(v) for k, v in kwargs.items()}
        return fn(ctx, *args, **kwargs)

    return wrapper
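

# Usage sketch (hypothetical example, not part of the original file):
# `ensure_contiguous` is meant to decorate the forward/backward methods of a
# torch.autograd.Function so that Triton kernels invoked inside can assume
# contiguous memory. `_ExampleScale` and its factor are illustrative only.
class _ExampleScale(torch.autograd.Function):
    @staticmethod
    @ensure_contiguous
    def forward(ctx, x):
        # x is contiguous here even if the caller passed a transposed view.
        return x * 2.0

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output):
        return grad_output * 2.0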


def calculate_settings(n):
    # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43

    MAX_FUSED_SIZE = 65536
    BLOCK_SIZE = triton.next_power_of_2(n)
    if BLOCK_SIZE > MAX_FUSED_SIZE:
        raise RuntimeError(
            f"Cannot launch Triton kernel since n = {n} exceeds "
            f"the recommended Triton blocksize = {MAX_FUSED_SIZE}."
        )

    num_warps = 4
    if BLOCK_SIZE >= 32768:
        num_warps = 32 if not is_hip() else 16
    elif BLOCK_SIZE >= 8192:
        num_warps = 16
    elif BLOCK_SIZE >= 2048:
        num_warps = 8
    return BLOCK_SIZE, num_warps
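

# Worked example (a sketch; the helper below is not part of the original file):
# a row of 4097 elements rounds up to BLOCK_SIZE = 8192, which the thresholds
# above map to num_warps = 16 on both CUDA and ROCm.
def _example_calculate_settings() -> None:
    block_size, num_warps = calculate_settings(4097)
    assert (block_size, num_warps) == (8192, 16)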


def compare_version(package: str, operator: Callable, target: str):
    try:
        pkg = importlib.import_module(package)
    except ImportError:
        return False
    pkg_version = Version(pkg.__version__)
    return operator(pkg_version, Version(target))
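

# Usage sketch (hypothetical version string): gate a code path on the installed
# Triton version. A missing package yields False, so the check degrades
# gracefully when the dependency is absent.
def _example_compare_version() -> bool:
    return compare_version("triton", operator.ge, "3.0.0")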


def get_amp_custom_fwd_bwd() -> tuple[Callable, Callable]:
    # torch.amp.custom_fwd/custom_bwd became the preferred entry points in
    # PyTorch 2.4.0 (taking an explicit `device_type`); fall back to the
    # deprecated torch.cuda.amp variants on older versions.
    if compare_version("torch", operator.ge, "2.4.0"):
        return (
            functools.partial(torch.amp.custom_fwd, device_type="cuda"),
            functools.partial(torch.amp.custom_bwd, device_type="cuda"),
        )
    return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd


amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()
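

# Usage sketch (hypothetical; `_AmpExample` is illustrative only): the selected
# wrappers are applied like torch.cuda.amp.custom_fwd/custom_bwd, keeping mixed
# precision behavior consistent across PyTorch versions.
class _AmpExample(torch.autograd.Function):
    @staticmethod
    @amp_custom_fwd
    def forward(ctx, x):
        return x + 1.0

    @staticmethod
    @amp_custom_bwd
    def backward(ctx, grad_output):
        return grad_output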


torch_to_triton_dtype = {
    torch.float32: tl.float32,
    torch.float16: tl.float16,
    torch.bfloat16: tl.bfloat16,
}
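
# Usage sketch (hypothetical kernel and parameter names): the mapping lets host
# code forward a tensor's dtype to a kernel as a compile-time constant, e.g.
#     some_kernel[grid](..., DTYPE=torch_to_triton_dtype[x.dtype])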


@triton.jit
def element_mul_kernel(
    X_ptr,
    X_stride,
    grad_output_ptr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    This kernel multiplies each element of the tensor pointed to by X_ptr by the
    scalar value pointed to by grad_output_ptr. The multiplication is performed
    in place on the tensor pointed to by X_ptr.

    Parameters:
    X_ptr: Pointer to the input tensor.
    X_stride (int): The stride of the input tensor.
    grad_output_ptr: Pointer to the gradient output value.
    n_cols (int): The number of columns in the input tensor.
    BLOCK_SIZE (int): The block size for Triton operations.
    """
    # Get the program ID and convert it to int64 to avoid overflow
    program_id = tl.program_id(0).to(tl.int64)

    # Locate the start index of this program's row
    X_ptr += program_id * X_stride

    # Load the scalar gradient output value
    grad_output = tl.load(grad_output_ptr)

    # Perform the element-wise multiplication one block at a time
    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
        tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
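

# Host-side launch sketch (hypothetical helper; shapes, names, and values are
# assumptions, not part of the original file): scale each row of a 2-D CUDA
# tensor in place by a scalar grad_output, one Triton program per row.
def _example_element_mul(X: torch.Tensor, grad_output: torch.Tensor) -> None:
    # X is assumed contiguous and on the GPU; grad_output is a 0-dim tensor.
    n_rows, n_cols = X.shape
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    element_mul_kernel[(n_rows,)](
        X,
        X.stride(0),  # elements between consecutive rows
        grad_output,
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )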