# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
"""
This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
See the original Unsloth repository at https://github.com/unslothai/unsloth.
The following line
https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23
is based on code from Unsloth, located at:
https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43
Modifications made by Yanning Chen, 2024.
"""

import functools
import importlib
import operator
from typing import Callable

import torch
import triton
import triton.language as tl
from packaging.version import Version


def is_hip() -> bool:
    return torch.version.hip is not None


def ensure_contiguous(fn):
    """Decorator that makes every tensor argument contiguous before calling ``fn``."""

    @functools.wraps(fn)
    def wrapper(ctx, *args, **kwargs):
        def maybe_to_contiguous(x):
            return x.contiguous() if isinstance(x, torch.Tensor) else x

        args = [maybe_to_contiguous(arg) for arg in args]
        kwargs = {k: maybe_to_contiguous(v) for k, v in kwargs.items()}
        return fn(ctx, *args, **kwargs)

    return wrapper
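
# Illustrative sketch (not part of the original module): `ensure_contiguous` is
# typically stacked on the methods of a torch.autograd.Function so that Triton
# kernels always see contiguous tensors. The class name below is hypothetical.
#
#     class MyFusedOp(torch.autograd.Function):
#         @staticmethod
#         @ensure_contiguous
#         def forward(ctx, x):
#             # x is guaranteed contiguous here before any pointer arithmetic
#             ...
#
#         @staticmethod
#         @ensure_contiguous
#         def backward(ctx, grad_output):
#             ...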


def calculate_settings(n):
    # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43

    MAX_FUSED_SIZE = 65536
    BLOCK_SIZE = triton.next_power_of_2(n)
    if BLOCK_SIZE > MAX_FUSED_SIZE:
        raise RuntimeError(
            f"Cannot launch Triton kernel since n = {n} exceeds "
            f"the recommended Triton blocksize = {MAX_FUSED_SIZE}."
        )

    num_warps = 4
    if BLOCK_SIZE >= 32768:
        num_warps = 32 if not is_hip() else 16
    elif BLOCK_SIZE >= 8192:
        num_warps = 16
    elif BLOCK_SIZE >= 2048:
        num_warps = 8
    return BLOCK_SIZE, num_warps
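
# Illustrative sketch (assumption, not from the original file): a typical call site
# derives BLOCK_SIZE and num_warps from the row width before launching a row-wise
# kernel. `my_row_kernel`, `x`, and `out` are hypothetical names.
#
#     n_rows, n_cols = x.shape
#     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
#     my_row_kernel[(n_rows,)](
#         x, x.stride(0),
#         out, out.stride(0),
#         n_cols,
#         BLOCK_SIZE=BLOCK_SIZE,
#         num_warps=num_warps,
#     )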


def compare_version(package: str, operator: Callable, target: str):
    try:
        pkg = importlib.import_module(package)
    except ImportError:
        return False
    pkg_version = Version(pkg.__version__)
    return operator(pkg_version, Version(target))


def get_amp_custom_fwd_bwd() -> Callable:
    if compare_version("torch", operator.ge, "2.4.0"):
        return (
            functools.partial(torch.amp.custom_fwd, device_type="cuda"),
            functools.partial(torch.amp.custom_bwd, device_type="cuda"),
        )
    return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd


amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()
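
# Illustrative sketch (assumption): the resolved wrappers are used as decorators on
# autograd Function methods so autocast behaves the same on torch >= 2.4 and older
# releases. `MyFusedOp` is a hypothetical class name.
#
#     class MyFusedOp(torch.autograd.Function):
#         @staticmethod
#         @amp_custom_fwd
#         def forward(ctx, x):
#             ...
#
#         @staticmethod
#         @amp_custom_bwd
#         def backward(ctx, grad_output):
#             ...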


torch_to_triton_dtype = {
    torch.float32: tl.float32,
    torch.float16: tl.float16,
    torch.bfloat16: tl.bfloat16,
}
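
# Illustrative sketch (assumption): the mapping converts a tensor's torch dtype into
# the matching Triton dtype when it has to be passed to a kernel as a constexpr.
#
#     triton_dtype = torch_to_triton_dtype[x.dtype]  # e.g. tl.bfloat16 for a bf16 tensor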


@triton.jit
def element_mul_kernel(
    X_ptr,
    X_stride,
    grad_output_ptr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    This kernel multiplies each element of the tensor pointed to by X_ptr with the value pointed to by grad_output_ptr.
    The multiplication is performed in-place on the tensor pointed to by X_ptr.

    Parameters:
    X_ptr: Pointer to the input tensor.
    X_stride (int): The stride of the input tensor.
    grad_output_ptr: Pointer to the gradient output value.
    n_cols (int): The number of columns in the input tensor.
    BLOCK_SIZE (int): The block size for Triton operations.
    """

    # Get the program ID and convert it to int64 to avoid overflow
    program_id = tl.program_id(0).to(tl.int64)

    # Locate the start index of this program's row
    X_ptr += program_id * X_stride

    # Load the gradient output value
    grad_output = tl.load(grad_output_ptr)

    # Perform the element-wise multiplication, one BLOCK_SIZE chunk at a time
    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
        tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
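
# Illustrative launch sketch (assumption, not from the original file): in a backward
# pass this kernel is typically used to scale per-row gradients in place by the scalar
# incoming grad_output. `grad_input` is a hypothetical (n_rows, n_cols) tensor.
#
#     n_rows, n_cols = grad_input.shape
#     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
#     element_mul_kernel[(n_rows,)](
#         grad_input,
#         grad_input.stride(0),
#         grad_output,          # 1-element tensor holding the scalar gradient
#         n_cols,
#         BLOCK_SIZE=BLOCK_SIZE,
#         num_warps=num_warps,
#     )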