# Copyright (c) 2024, Tri Dao, Albert Gu.
"""We want triton==2.1.0 or 2.2.0 for this
"""
import math
from packaging import version
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from einops import rearrange, repeat
from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
def init_to_zero(names):
return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
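# init_to_zero is used below as an autotune pre_hook (e.g. pre_hook=init_to_zero(["ddA_cumsum_ptr"])):
# output buffers that the kernels accumulate into with tl.atomic_add are zeroed before each
# autotuned launch so that re-running a config does not double-count contributions.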
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
],
key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'],
)
@triton.jit
def _chunk_scan_fwd_kernel(
# Pointers to matrices
cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, prev_states_ptr, D_ptr,
# Matrix dimensions
chunk_size, hdim, dstate,
batch, seqlen, nheads_ngroups_ratio,
# Strides
stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,
stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim,
stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,
stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
stride_seq_idx_batch, stride_seq_idx_seqlen,
stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,
stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
stride_D_head,
# Meta-parameters
IS_CAUSAL: tl.constexpr,
HAS_D: tl.constexpr,
D_HAS_HDIM: tl.constexpr,
HAS_Z: tl.constexpr,
HAS_SEQ_IDX: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
BLOCK_SIZE_DSTATE: tl.constexpr,
IS_TRITON_22: tl.constexpr,
):
pid_bc = tl.program_id(axis=1)
pid_c = pid_bc // batch
pid_b = pid_bc - pid_c * batch
pid_h = tl.program_id(axis=2)
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head
prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head
if HAS_SEQ_IDX:
seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
if HAS_SEQ_IDX:
seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
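# The chunk-scan output is the sum of two terms:
#   (1) the carry-in from previous chunks: C @ prev_states, scaled per row by exp(dA_cumsum)
#       (computed in the block right below), and
#   (2) the intra-chunk causal term: (CB * decay * dt, lower-triangular) @ x
#       (computed in the K-loop further down).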
# Without the if (pid_c > -1), with Triton 2.1.0, I get
# Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mma layout conversion"' failed.
# With Triton 2.2.0, this works
if IS_TRITON_22 or pid_c > -1:
# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate)
prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate)
if not HAS_SEQ_IDX:
scale_m = tl.exp(dA_cs_m)
else:
scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
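# If the whole state dimension fits in one tile (BLOCK_SIZE_DSTATE, typically dstate rounded
# up to a power of 2, is <= 128), do the C @ prev_states contraction with a single tl.dot;
# otherwise loop over dstate in BLOCK_SIZE_K slices and apply scale_m once at the end.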
if BLOCK_SIZE_DSTATE <= 128:
C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0)
prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
prev_states = prev_states.to(C_ptr.dtype.element_ty)
acc = tl.dot(C, prev_states) * scale_m[:, None]
else:
for k in range(0, dstate, BLOCK_SIZE_K):
C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate - k), other=0.0)
# C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty)
prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)
prev_states = prev_states.to(C_ptr.dtype.element_ty)
acc += tl.dot(C, prev_states)
C_ptrs += BLOCK_SIZE_K
prev_states_ptrs += BLOCK_SIZE_K
acc *= scale_m[:, None]
offs_k = tl.arange(0, BLOCK_SIZE_K)
cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)
x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
K_MAX = chunk_size_limit if not IS_CAUSAL else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
for k in range(0, K_MAX, BLOCK_SIZE_K):
cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k), other=0.0).to(tl.float32)
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
# If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
# So we don't need masking wrt seq_idx here.
cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :]))
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
cb *= dt_k
if IS_CAUSAL:
mask = offs_m[:, None] >= k + offs_k[None, :]
cb = tl.where(mask, cb, 0.0)
cb = cb.to(x_ptr.dtype.element_ty)
x = tl.load(x_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim), other=0.0)
acc += tl.dot(cb, x)
cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
if HAS_D:
if D_HAS_HDIM:
D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
else:
D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim),
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
acc += x_residual * D
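# Optional gating: the output is multiplied by z * sigmoid(z) (a SiLU gate). The ungated
# output is written to out_x first so the backward pass can recompute the gate gradient.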
if HAS_Z:
out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head
out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :])
tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim))
z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head
z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :])
z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32)
acc *= z * tl.sigmoid(z)
out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head
out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim)
tl.store(out_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim))
@triton.autotune(
configs=[
# triton.Config({'BLOCK_SIZE_N': 256}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_SIZE_N': 128}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=8),
],
key=['chunk_size', 'hdim', 'dstate'],
)
@triton.jit
def _chunk_scan_fwd_kernel_wip(
# Pointers to matrices
cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, B_ptr, prev_states_ptr, D_ptr,
# Matrix dimensions
chunk_size, hdim, dstate,
batch, seqlen, nheads_ngroups_ratio,
# Strides
stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,
stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim,
stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,
stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
stride_seq_idx_batch, stride_seq_idx_seqlen,
stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,
stride_B_batch, stride_B_seqlen, stride_B_head, stride_B_dstate,
stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
stride_D_head,
# Meta-parameters
HAS_D: tl.constexpr,
D_HAS_HDIM: tl.constexpr,
HAS_Z: tl.constexpr,
HAS_SEQ_IDX: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_DSTATE: tl.constexpr,
):
pid_bc = tl.program_id(axis=1)
pid_c = pid_bc // batch
pid_b = pid_bc - pid_c * batch
pid_h = tl.program_id(axis=2)
pid_n = tl.program_id(axis=0)
cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head
B_ptr += pid_b * stride_B_batch + pid_c * chunk_size * stride_B_seqlen + (pid_h // nheads_ngroups_ratio) * stride_B_head
prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head
if HAS_SEQ_IDX:
seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head
offs_m = tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE)
C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate)
B_ptrs = B_ptr + (offs_m[None, :] * stride_B_seqlen + offs_k_dstate[:, None] * stride_B_dstate)
prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate)
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_m[None, :] * stride_cb_csize_k)
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
dt_ptrs = dt_ptr + offs_m * stride_dt_csize
out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim)
prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
# if pid_c == 0:
# if pid_b == 0:
# if pid_h == 0:
# tl.device_print("", prev_states)
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
# dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
# scale_m = tl.exp(dA_cs_m)
# C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0)
# acc = tl.dot(C, prev_states.to(C_ptr.dtype.element_ty)) * scale_m[:, None]
# cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_m[None, :] < chunk_size), other=0.0).to(tl.float32)
# cb *= tl.exp((dA_cs_m[:, None] - dA_cs_m[None, :]))
# dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
# cb *= dt_m
# mask = offs_m[:, None] >= offs_m[None, :]
# cb = tl.where(mask, cb, 0.0)
# cb = cb.to(x_ptr.dtype.element_ty)
# x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0)
# acc += tl.dot(cb, x)
# if HAS_D:
# if D_HAS_HDIM:
# D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
# else:
# D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
# acc += x.to(tl.float32) * D
# tl.store(out_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))
for start_m in range(0, chunk_size_limit, BLOCK_SIZE_M):
start_m = tl.multiple_of(start_m, BLOCK_SIZE_M)
dA_cs_m = tl.load(dA_cumsum_ptr + (start_m + offs_m) * stride_dA_cs_csize, mask=offs_m < chunk_size - start_m, other=0.0).to(tl.float32)
if HAS_SEQ_IDX:
seq_idx_prev = tl.load(seq_idx_ptr + start_m - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)
seq_idx_m = tl.load(seq_idx_ptr + (start_m + offs_m) * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit - start_m, other=-1)
if not HAS_SEQ_IDX:
scale_m = tl.exp(dA_cs_m)
else:
scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_k_dstate[None, :] < dstate), other=0.0)
acc = tl.dot(C, prev_states.to(C_ptr.dtype.element_ty)) * scale_m[:, None]
# cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size - start_m) & (offs_m[None, :] < chunk_size - start_m), other=0.0).to(tl.float32)
# cb *= tl.exp((dA_cs_m[:, None] - dA_cs_m[None, :]))
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size - start_m, other=0.0).to(tl.float32)
# cb *= dt_m
# mask = offs_m[:, None] >= offs_m[None, :]
# cb = tl.where(mask, cb, 0.0)
# cb = cb.to(x_ptr.dtype.element_ty)
x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_n[None, :] < hdim), other=0.0)
# acc += tl.dot(cb, x)
if HAS_D:
if D_HAS_HDIM:
D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
else:
D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
acc += x.to(tl.float32) * D
# if HAS_Z:
# out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head
# out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :])
# tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim))
# z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head
# z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :])
# z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32)
# acc *= z * tl.sigmoid(z)
tl.store(out_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_n[None, :] < hdim))
# TODO: this is not correct, and quite a bit slower
if start_m + BLOCK_SIZE_M < chunk_size_limit:
# B = tl.load(B_ptrs, mask=(offs_m[None, :] < chunk_size_limit - start_m) & (offs_k_dstate[:, None] < dstate), other=0.0).to(tl.float32)
B = tl.load(B_ptrs, mask=(offs_m[None, :] < chunk_size_limit - start_m) & (offs_k_dstate[:, None] < dstate), other=0.0)
dA_cs_last = tl.load(dA_cumsum_ptr + (start_m + BLOCK_SIZE_M) * stride_dA_cs_csize).to(tl.float32)
# TODO: seq_idx
scale = tl.exp((dA_cs_last - dA_cs_m)) * dt_m
# B *= scale
B = B.to(x_ptr.dtype.element_ty)
tmp = tl.dot(B, x)
prev_states += tmp.to(prev_states.dtype)
C_ptrs += BLOCK_SIZE_M * stride_C_seqlen
B_ptrs += BLOCK_SIZE_M * stride_B_seqlen
cb_ptrs += BLOCK_SIZE_M * stride_cb_csize_m + BLOCK_SIZE_M * stride_cb_csize_k
x_ptrs += BLOCK_SIZE_M * stride_x_seqlen
dt_ptrs += BLOCK_SIZE_M * stride_dt_csize
out_ptrs += BLOCK_SIZE_M * stride_out_seqlen
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 32}),
triton.Config({'BLOCK_SIZE_M': 64}),
triton.Config({'BLOCK_SIZE_M': 128}),
triton.Config({'BLOCK_SIZE_M': 256}),
],
key=["chunk_size", "hdim"],
)
@triton.jit
def _chunk_scan_bwd_dz_kernel(
# Pointers to matrices
dout_ptr, out_ptr, z_ptr, x_ptr, D_ptr, outz_ptr, dz_ptr, dout_x_ptr, dD_ptr, ddA_cumsum_ptr,
# Matrix dimensions
chunk_size, hdim,
batch, seqlen,
# Strides
stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,
stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim,
stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
stride_D_head,
stride_outz_batch, stride_outz_seqlen, stride_outz_head, stride_outz_hdim,
stride_dz_batch, stride_dz_seqlen, stride_dz_head, stride_dz_hdim,
stride_doutx_batch, stride_doutx_seqlen, stride_doutx_head, stride_doutx_hdim,
stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,
stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
# Meta-parameters
HAS_D: tl.constexpr,
D_HAS_HDIM: tl.constexpr,
HAS_DDACS: tl.constexpr,
RECOMPUTE_OUTPUT: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
):
pid_bc = tl.program_id(axis=1)
pid_c = pid_bc // batch
pid_b = pid_bc - pid_c * batch
pid_h = tl.program_id(axis=2)
pid_m = tl.program_id(axis=0)
dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
dout_x_ptr += pid_b * stride_doutx_batch + pid_c * chunk_size * stride_doutx_seqlen + pid_h * stride_doutx_head
out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head
z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head
dz_ptr += pid_b * stride_dz_batch + pid_c * chunk_size * stride_dz_seqlen + pid_h * stride_dz_head
if RECOMPUTE_OUTPUT:
outz_ptr += pid_b * stride_outz_batch + pid_c * chunk_size * stride_outz_seqlen + pid_h * stride_outz_head
if HAS_DDACS:
ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head
if HAS_D:
x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = tl.arange(0, BLOCK_SIZE_N)
dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
dout_x_ptrs = dout_x_ptr + (offs_m[:, None] * stride_doutx_seqlen + offs_n[None, :] * stride_doutx_hdim)
out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim)
z_ptrs = z_ptr + (offs_m[:, None] * stride_z_seqlen + offs_n[None, :] * stride_z_hdim)
dz_ptrs = dz_ptr + (offs_m[:, None] * stride_dz_seqlen + offs_n[None, :] * stride_dz_hdim)
if RECOMPUTE_OUTPUT:
outz_ptrs = outz_ptr + (offs_m[:, None] * stride_outz_seqlen + offs_n[None, :] * stride_outz_hdim)
if HAS_D:
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
if D_HAS_HDIM:
dD_ptrs = dD_ptr + offs_n * stride_dD_hdim
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
out = tl.load(out_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
z = tl.load(z_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
z_sigmoid = tl.sigmoid(z)
if RECOMPUTE_OUTPUT:
outz = out * z * z_sigmoid
tl.store(outz_ptrs, outz, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))
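# d/dz [z * sigmoid(z)] = sigmoid(z) * (1 + z * (1 - sigmoid(z))), so the gradient w.r.t. z is
# dout * out times that factor, and the gradient w.r.t. the ungated output (stored to dout_x
# and consumed by the rest of the backward pass) is dout * z * sigmoid(z).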
dz = dout * out * z_sigmoid * (1 + z * (1 - z_sigmoid))
tl.store(dz_ptrs, dz, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))
dout *= z * z_sigmoid
tl.store(dout_x_ptrs, dout, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))
if HAS_D:
x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
if D_HAS_HDIM:
dD = tl.sum(dout * x, axis=0)
tl.store(dD_ptrs, dD, mask=offs_n < hdim)
D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
else:
dD = tl.sum(dout * x)
tl.store(dD_ptr, dD)
D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
out -= x * D
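# With the D*x residual removed, sum(dout * out) over hdim gives the per-position ddA_cumsum
# estimate; this matches the (numerically less stable) computation in
# _chunk_scan_bwd_ddAcs_unstable_kernel below.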
if HAS_DDACS:
ddA_cs = tl.sum(dout * out, axis=1)
tl.store(ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size)
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
],
key=['hdim', 'dstate', 'chunk_size'],
)
@triton.jit
def _chunk_scan_bwd_dstates_kernel(
# Pointers to matrices
dout_ptr, c_ptr, dprev_states_ptr, dA_cumsum_ptr, seq_idx_ptr,
# Matrix dimensions
hdim, dstate, chunk_size,
batch, seqlen, nchunks, nheads_ngroups_ratio,
# Strides
stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
stride_c_batch, stride_c_seqlen, stride_c_head, stride_c_dstate,
stride_dprev_states_batch, stride_dprev_states_chunk, stride_dprev_states_head, stride_dprev_states_hdim, stride_dprev_states_dstate,
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
stride_seq_idx_batch, stride_seq_idx_seqlen,
# Meta-parameters
HAS_SEQ_IDX: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
pid_bc = tl.program_id(axis=1)
pid_c = pid_bc // batch
pid_b = pid_bc - pid_c * batch
pid_h = tl.program_id(axis=2)
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
c_ptr += pid_b * stride_c_batch + pid_c * chunk_size * stride_c_seqlen + (pid_h // nheads_ngroups_ratio) * stride_c_head
dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
if HAS_SEQ_IDX:
seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_hdim + offs_k[None, :] * stride_dout_seqlen)
c_ptrs = c_ptr + (offs_n[None, :] * stride_c_dstate + offs_k[:, None] * stride_c_seqlen)
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
if HAS_SEQ_IDX:
seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
if HAS_SEQ_IDX:
seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)
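# dprev_states[h, s] = sum over positions k in the chunk of scale_k * dout[k, h] * C[k, s],
# i.e. (scale * dout)^T @ C, accumulated over the chunk in BLOCK_SIZE_K slices.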
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0).to(tl.float32)
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
if not HAS_SEQ_IDX:
scale_k = tl.exp(dA_cs_k)
else:
seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1)
scale_k = tl.where(seq_idx_k == seq_idx_prev, tl.exp(dA_cs_k), 0.0)
dout = (dout * scale_k).to(dout_ptr.dtype.element_ty)
c = tl.load(c_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0)
acc += tl.dot(dout, c)
dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen
c_ptrs += BLOCK_SIZE_K * stride_c_seqlen
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
if HAS_SEQ_IDX:
seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
out = acc.to(dprev_states_ptr.dtype.element_ty)
dprev_states_ptr += pid_b * stride_dprev_states_batch + pid_c * stride_dprev_states_chunk + pid_h * stride_dprev_states_head
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dprev_states_ptrs = dprev_states_ptr + (offs_m[:, None] * stride_dprev_states_hdim + offs_n[None, :] * stride_dprev_states_dstate)
tl.store(dprev_states_ptrs, out, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate))
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
],
key=['chunk_size', 'dstate', 'hdim'],
)
@triton.jit
def _chunk_scan_bwd_dc_kernel(
# Pointers to matrices
dout_ptr, prev_states_ptr, C_ptr, dA_cumsum_ptr, seq_idx_ptr,
dc_ptr, ddA_cumsum_ptr,
# Matrix dimensions
chunk_size, dstate, hdim,
batch, seqlen, nheads, nheads_per_program, ngroups,
# Strides
stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
stride_prev_states_batch, stride_prev_states_chunk, stride_prev_states_head, stride_prev_states_hdim, stride_prev_states_dstate,
stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
stride_seq_idx_batch, stride_seq_idx_seqlen,
stride_dc_batch, stride_dc_seqlen, stride_dc_split, stride_dc_group, stride_dc_dstate,
stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
# Meta-parameters
HAS_DDA_CS: tl.constexpr,
HAS_SEQ_IDX: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
pid_bc = tl.program_id(axis=1)
pid_c = pid_bc // batch
pid_b = pid_bc - pid_c * batch
pid_sg = tl.program_id(axis=2)
pid_s = pid_sg // ngroups
pid_g = pid_sg - pid_s * ngroups
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dout_head
dc_ptr += pid_b * stride_dc_batch + pid_c * chunk_size * stride_dc_seqlen + pid_g * stride_dc_group + pid_s * stride_dc_split
prev_states_ptr += pid_b * stride_prev_states_batch + pid_c * stride_prev_states_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_prev_states_head
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head
if HAS_DDA_CS:
C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + pid_g * stride_C_head
ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head
if HAS_SEQ_IDX:
seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_prev_states_dstate + offs_k[:, None] * stride_prev_states_hdim)
dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
if HAS_DDA_CS:
C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_n[None, :] * stride_C_dstate)
ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
if HAS_DDA_CS:
c = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
if HAS_SEQ_IDX:
seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program)
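# Each program handles one (split, group) pair and sums the dC contributions of up to
# nheads_per_program heads of that group into acc; the split axis (stride_dc_split) is
# presumably reduced by the caller afterwards.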
for h in range(nheads_iter):
dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
prev_states = tl.load(prev_states_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0)
prev_states = prev_states.to(dout_ptrs.dtype.element_ty)
dc = tl.dot(dout, prev_states)
dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
if not HAS_SEQ_IDX:
scale = tl.exp(dA_cs_m)
else:
scale = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
dc *= scale[:, None]
if HAS_DDA_CS:
ddA_cs = tl.sum(dc * c, axis=1)
tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
acc += dc
dout_ptrs += stride_dout_head
prev_states_ptrs += stride_prev_states_head
dA_cumsum_ptrs += stride_dA_cs_head
if HAS_DDA_CS:
ddA_cumsum_ptrs += stride_ddA_cs_head
# if HAS_SEQ_IDX:
# seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)
# seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
# acc = tl.where(seq_idx_m[:, None] == seq_idx_prev, acc, 0.0)
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dc_ptrs = dc_ptr + (offs_m[:, None] * stride_dc_seqlen + offs_n[None, :] * stride_dc_dstate)
tl.store(dc_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate))
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
],
key=['chunk_size', 'hdim'],
)
@triton.jit
def _chunk_scan_bwd_dx_kernel(
# Pointers to matrices
x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, D_ptr,
dx_ptr, ddt_ptr, # dD_ptr,
# Matrix dimensions
chunk_size, hdim,
batch, seqlen, nheads_ngroups_ratio,
# Strides
stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,
stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
stride_D_head,
stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,
stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,
# stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_hdim, stride_dD_csize,
# Meta-parameters
HAS_D: tl.constexpr,
D_HAS_HDIM: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
pid_bc = tl.program_id(axis=1)
pid_c = pid_bc // batch
pid_b = pid_bc - pid_c * batch
pid_h = tl.program_id(axis=2)
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
# if HAS_D:
# dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)
dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
# Not sure why limiting K_MAX gives wrong results; is it a Triton bug?
# K_MAX = min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
K_MAX = chunk_size_limit
for k in range(0, K_MAX, BLOCK_SIZE_K):
# For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower
cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)
dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)
cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
# If we don't have the (k + offs_k[None, :] < K_MAX) mask, then for indices outside this range
# the masked load gives dA_cs_k = 0.0 while dA_cs_m can be very negative, so tl.exp returns inf.
# Multiplying that inf by cb, which is 0.0 outside the range, makes the result NaN.
# The NaN then propagates into acc, and hence into dx and ddt.
mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)
cb = tl.where(mask, cb, 0.0)
cb = cb.to(dout_ptr.dtype.element_ty)
acc += tl.dot(cb, dout)
cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dt_ptrs = dt_ptr + offs_m * stride_dt_csize
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
dx = acc * dt_m[:, None]
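# acc[m, h] is the gradient w.r.t. the product dt[m] * x[m, h] from the intra-chunk term,
# so dx = acc * dt (plus the D skip-connection term added below) and ddt = sum over hdim of acc * x.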
dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head
dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)
if HAS_D:
dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
if D_HAS_HDIM:
D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
else:
D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
dx += dout_res * D
tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
ddt = tl.sum(acc * x, axis=1)
ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
# if HAS_D:
# dout_new_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize + offs_n[None, :] * stride_dout_hdim)
# dout = tl.load(dout_new_ptrs, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), other=0.0).to(tl.float32)
# dD = tl.sum(x * dout, axis=0)
# tl.store(dD_ptr + offs_n * stride_dD_hdim, dD, mask=offs_n < N)
# Disabling HAS_DDA_CS for now since it's much slower
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 16}, num_stages=3, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 32}, num_stages=3, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 64}, num_stages=3, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 128}, num_stages=3, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 16}, num_stages=4, num_warps=8),
# triton.Config({'BLOCK_SIZE_M': 32}, num_stages=4, num_warps=8),
# triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=8),
# triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8),
],
key=['chunk_size', 'hdim'],
)
# @triton.heuristics({"BLOCK_SIZE_N": lambda args: max(triton.next_power_of_2(args["chunk_size"]), 16)})
# @triton.heuristics({"BLOCK_SIZE_N": lambda args: 32})
@triton.jit
def _chunk_scan_bwd_dcb_kernel(
# Pointers to matrices
x_ptr, dout_ptr, cb_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
dcb_ptr, ddA_cumsum_ptr,
# Matrix dimensions
chunk_size, hdim,
batch, seqlen, nheads, nheads_per_program, ngroups,
# Strides
stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n,
stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
stride_seq_idx_batch, stride_seq_idx_seqlen,
stride_dcb_batch, stride_dcb_chunk, stride_dcb_split, stride_dcb_group, stride_dcb_csize_m, stride_dcb_csize_n,
stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize_m, stride_ddA_cs_csize_n,
# Meta-parameters
HAS_DDA_CS: tl.constexpr,
HAS_SEQ_IDX: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
pid_bc = tl.program_id(axis=1)
pid_c = pid_bc // batch
pid_b = pid_bc - pid_c * batch
pid_sg = tl.program_id(axis=2)
pid_s = pid_sg // ngroups
pid_g = pid_sg - pid_s * ngroups
num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head
dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dout_head
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head
if HAS_DDA_CS:
cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + pid_g * stride_cb_head
ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + pid_m * stride_ddA_cs_csize_m
if HAS_SEQ_IDX:
seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim)
dt_ptrs = dt_ptr + offs_n * stride_dt_csize
if HAS_DDA_CS:
cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n)
ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n
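# Tiles that lie entirely above the (block) diagonal are zeroed by the causal mask anyway,
# so just write zeros for this dcb tile and return early.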
if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
dcb_ptr += pid_b * stride_dcb_batch + pid_c * stride_dcb_chunk + pid_g * stride_dcb_group + pid_s * stride_dcb_split
dcb_ptrs = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n)
tl.store(dcb_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=dcb_ptr.dtype.element_ty), mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))
return
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
if HAS_DDA_CS:
cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), other=0.0).to(tl.float32)
nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program)
for h in range(nheads_iter):
dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0)
dcb = tl.dot(dout, x)
dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32)
dcb *= dt_n
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size_limit, other=0.0).to(tl.float32)
dcb *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
if HAS_DDA_CS:
tl.static_assert(not HAS_SEQ_IDX, "HAS_SEQ_IDX not supported with HAS_DDA_CS yet")
ddA_cs = dcb * cb
mask = offs_m[:, None] >= offs_n[None, :] + 1
ddA_cs = tl.where(mask, ddA_cs, 0.0)
ddA_cs = tl.cumsum(ddA_cs, axis=1)
ddA_cs = tl.where(mask, ddA_cs, 0.0)
ddA_cs = tl.sum(ddA_cs, axis=0)
tl.store(ddA_cumsum_ptrs + stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - 1)
tl.store(ddA_cumsum_ptr, 0.0)
acc += dcb
dout_ptrs += stride_dout_head
x_ptrs += stride_x_head
dt_ptrs += stride_dt_head
dA_cumsum_ptr += stride_dA_cs_head
if HAS_DDA_CS:
ddA_cumsum_ptr += stride_ddA_cs_head
ddA_cumsum_ptrs += stride_ddA_cs_head
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
if HAS_SEQ_IDX:
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2)
acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
mask = offs_m[:, None] >= offs_n[None, :]
acc = tl.where(mask, acc, 0.0)
dcb_ptr += pid_b * stride_dcb_batch + pid_c * stride_dcb_chunk + pid_g * stride_dcb_group + pid_s * stride_dcb_split
dcb_ptrs = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n)
tl.store(dcb_ptrs, acc, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))
# Not numerically stable and should not be used. Leaving here for reference.
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 32}),
triton.Config({'BLOCK_SIZE_M': 64}),
triton.Config({'BLOCK_SIZE_M': 128}),
triton.Config({'BLOCK_SIZE_M': 256}),
],
key=["chunk_size", "hdim"],
)
@triton.jit
def _chunk_scan_bwd_ddAcs_unstable_kernel(
# Pointers to matrices
dout_ptr, out_ptr, dt_ptr, ddt_ptr, x_ptr, D_ptr,
ddA_cumsum_ptr, dD_ptr,
# Matrix dimensions
chunk_size, hdim,
batch, seqlen,
# Strides
stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,
stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,
stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
stride_D_head,
stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,
# Meta-parameters
HAS_D: tl.constexpr,
D_HAS_HDIM: tl.constexpr,
SUBTRACT_DDTDT: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
):
pid_bc = tl.program_id(axis=1)
pid_c = pid_bc // batch
pid_b = pid_bc - pid_c * batch
pid_h = tl.program_id(axis=2)
pid_m = tl.program_id(axis=0)
dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head
if HAS_D:
x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = tl.arange(0, BLOCK_SIZE_N)
dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim)
if HAS_D:
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
if D_HAS_HDIM:
dD_ptrs = dD_ptr + offs_n * stride_dD_hdim
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
out = tl.load(out_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
if HAS_D:
x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
if D_HAS_HDIM:
dD = tl.sum(dout * x, axis=0)
tl.store(dD_ptrs, dD, mask=offs_n < hdim)
D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
else:
dD = tl.sum(dout * x)
tl.store(dD_ptr, dD)
D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
out -= x * D
ddA_cs = tl.sum(dout * out, axis=1)
if SUBTRACT_DDTDT:
dt = tl.load(dt_ptr + offs_m * stride_dt_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
ddt = tl.load(ddt_ptr + offs_m * stride_ddt_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
ddA_cs -= dt * ddt
tl.store(ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size)
@triton.autotune(
configs=[
# triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8),
# triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8),
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8),
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 16}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 32}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 16}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 32}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8),
],
key=['chunk_size', 'hdim'],
)
@triton.jit
def _chunk_scan_bwd_ddAcs_stable_kernel_old(
# Pointers to matrices
x_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, cb_ptr,
ddAcs_ptr,
# Matrix dimensions
chunk_size, hdim,
batch, seqlen, nheads_ngroups_ratio,
# Strides
stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n,
stride_ddAcs_batch, stride_ddAcs_chunk, stride_ddAcs_head, stride_ddAcs_csize_m, stride_ddAcs_csize_n,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
pid_bc = tl.program_id(axis=1)
pid_c = pid_bc // batch
pid_b = pid_bc - pid_c * batch
pid_h = tl.program_id(axis=2)
num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim)
dt_ptrs = dt_ptr + offs_n * stride_dt_csize
cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n)
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M)
# Doing a matmul loop with cumsum later on will cause Triton to crash
# Instead we do just one big matmul
# acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
# for k in range(0, hdim, BLOCK_SIZE_K):
# dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim - k), other=0.0)
# x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim - k) & (offs_n[None, :] < chunk_size_limit), other=0.0)
# acc += tl.dot(dout, x)
# dout_ptrs += BLOCK_SIZE_K * stride_dout_hdim
# x_ptrs += BLOCK_SIZE_K * stride_x_hdim
dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0)
acc = tl.dot(dout, x)
cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), other=0.0).to(tl.float32)
acc *= cb
dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32)
acc *= dt_n
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size, other=0.0).to(tl.float32)
acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
mask = offs_m[:, None] >= offs_n[None, :] + 1
acc = tl.where(mask, acc, 0.0)
acc = tl.cumsum(acc, axis=1)
acc = tl.where(mask, acc, 0.0)
ddA_cs = tl.sum(acc, axis=0)
ddAcs_ptr += pid_b * stride_ddAcs_batch + pid_c * stride_ddAcs_chunk + pid_h * stride_ddAcs_head + pid_m * stride_ddAcs_csize_m
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
ddAcs_ptrs = ddAcs_ptr + offs_n * stride_ddAcs_csize_n
tl.store(ddAcs_ptrs + stride_ddAcs_csize_n, ddA_cs, mask=offs_n < chunk_size - 1)
tl.store(ddAcs_ptr, 0.0)
# offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, 64)
# offs_k = tl.arange(0, BLOCK_SIZE_K)
# dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
# x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim)
# dt_ptrs = dt_ptr + offs_n * stride_dt_csize
# cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n)
# chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
# chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M)
# rowsum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
# dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
# dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
# ddAcs_ptr += pid_b * stride_ddAcs_batch + pid_c * stride_ddAcs_chunk + pid_h * stride_ddAcs_head + pid_m * stride_ddAcs_csize_m
# ddAcs_ptrs = ddAcs_ptr + offs_n * stride_ddAcs_csize_n
# for n in range(0, chunk_size_limit_n, 64):
# x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n - n), other=0.0)
# acc = tl.dot(dout, x)
# cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - n), other=0.0).to(tl.float32)
# acc *= cb
# dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size - n, other=0.0).to(tl.float32)
# acc *= dt_n
# dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size - n, other=0.0).to(tl.float32)
# acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
# mask = offs_m[:, None] >= offs_n[None, :] + 1 + n
# acc = tl.where(mask, acc, 0.0)
# acc = tl.cumsum(acc, axis=1)
# acc = tl.where(mask, acc, 0.0)
# ddA_cs = tl.sum(acc, axis=0)
# tl.store(ddAcs_ptrs, ddA_cs, mask=offs_n < chunk_size - 1 - n)
# # tl.store(ddAcs_ptr, 0.0)
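# Stable ddA_cumsum backward kernel: each program owns one BLOCK_SIZE_M block of rows and
# sweeps the column blocks up to the causal boundary, carrying a running `rowsum` so the
# per-row cumulative sum can be assembled block by block instead of in one cumsum over the
# whole chunk (as in the `_old` kernel above).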
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),
],
key=['chunk_size', 'hdim'],
)
@triton.jit
def _chunk_scan_bwd_ddAcs_stable_kernel(
# Pointers to matrices
x_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, cb_ptr,
ddA_cumsum_ptr,
# Matrix dimensions
chunk_size, hdim,
batch, seqlen, nheads_ngroups_ratio,
# Strides
stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n,
stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize_m, stride_ddA_cs_csize_n,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
pid_bc = tl.program_id(axis=1)
pid_c = pid_bc // batch
pid_b = pid_bc - pid_c * batch
pid_h = tl.program_id(axis=2)
pid_m = tl.program_id(axis=0)
x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + pid_m * stride_ddA_cs_csize_m
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim)
dt_ptrs = dt_ptr + offs_n * stride_dt_csize
cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n)
ddAcs_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n
tl.store(ddA_cumsum_ptr, 0.0)
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
rowsum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
    # The exact upper bound for hi is (pid_m + 1) * BLOCK_SIZE_M - 1, but subtracting 1 makes the kernel slower
lo, hi = 0, (pid_m + 1) * BLOCK_SIZE_M
# lo, hi = 0, chunk_size
for start_n in range(lo, hi, BLOCK_SIZE_N):
start_n = tl.multiple_of(start_n, BLOCK_SIZE_N)
# Doing a matmul loop with cumsum later on will cause Triton to crash
# Instead we do just one big matmul
# acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
# for k in range(0, hdim, BLOCK_SIZE_K):
# dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim - k), other=0.0)
# x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim - k) & (offs_n[None, :] < chunk_size_limit), other=0.0)
# acc += tl.dot(dout, x)
# dout_ptrs += BLOCK_SIZE_K * stride_dout_hdim
# x_ptrs += BLOCK_SIZE_K * stride_x_hdim
# x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0)
x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit - start_n), other=0.0)
acc = tl.dot(dout, x)
dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32)
acc *= dt_n
        # If seq_idx is present, cb[i, j] has already been zeroed out for seq_idx[i] != seq_idx[j]
cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - start_n), other=0.0).to(tl.float32)
acc *= cb
        dA_cs_n = tl.load(dA_cumsum_ptr + (start_n + offs_n) * stride_dA_cs_csize, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32)
acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
mask = offs_m[:, None] >= start_n + offs_n[None, :] + 1
acc = tl.where(mask, acc, 0.0)
rowsum_new = rowsum + tl.sum(acc, axis=1)
acc = rowsum[:, None] + tl.cumsum(acc, axis=1)
rowsum = rowsum_new
acc = tl.where(mask, acc, 0.0)
ddA_cs = tl.sum(acc, axis=0)
tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - start_n - 1)
x_ptrs += BLOCK_SIZE_N * stride_x_seqlen
dt_ptrs += BLOCK_SIZE_N * stride_dt_csize
cb_ptrs += BLOCK_SIZE_N * stride_cb_csize_n
ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n
# Need to zero out the rest, since we'll be summing the rows together
for start_n in range(hi, chunk_size, BLOCK_SIZE_N):
tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32), mask=offs_n < chunk_size - start_n - 1)
ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n
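# ddA_cumsum contribution from the previous chunk's states. The kernel accumulates its
# result with tl.atomic_add across the dstate blocks, so every autotune config zero-fills
# the output first via the init_to_zero(["ddA_cumsum_ptr"]) pre_hook.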
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
],
key=['chunk_size', 'dstate', 'hdim'],
)
@triton.jit
def _chunk_scan_bwd_ddAcs_prev_kernel(
# Pointers to matrices
dout_ptr, prev_states_ptr, C_ptr, dA_cumsum_ptr, seq_idx_ptr,
ddA_cumsum_ptr,
# Matrix dimensions
chunk_size, dstate, hdim,
batch, seqlen, nchunks, nheads_ngroups_ratio,
# Strides
stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
stride_prev_states_batch, stride_prev_states_chunk, stride_prev_states_head, stride_prev_states_hdim, stride_prev_states_dstate,
stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
stride_seq_idx_batch, stride_seq_idx_seqlen,
stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
# Meta-parameters
HAS_SEQ_IDX: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
pid_bc = tl.program_id(axis=1)
pid_c = pid_bc // batch
pid_b = pid_bc - pid_c * batch
pid_h = tl.program_id(axis=2)
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
prev_states_ptr += pid_b * stride_prev_states_batch + pid_c * stride_prev_states_chunk + pid_h * stride_prev_states_head
C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head
ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
if HAS_SEQ_IDX:
seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_prev_states_dstate + offs_k[:, None] * stride_prev_states_hdim)
C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_n[None, :] * stride_C_dstate)
dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
prev_states = tl.load(prev_states_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0)
prev_states = prev_states.to(dout_ptrs.dtype.element_ty)
acc = tl.dot(dout, prev_states)
c = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
ddA_cs = tl.sum(acc * c, axis=1)
dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
if not HAS_SEQ_IDX:
scale = tl.exp(dA_cs_m)
if HAS_SEQ_IDX:
seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
scale = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
ddA_cs *= scale
ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
def _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=None):
batch, seqlen, nheads, headdim = x.shape
_, _, nchunks, chunk_size = dt.shape
_, _, ngroups, dstate = C.shape
assert nheads % ngroups == 0
assert C.shape == (batch, seqlen, ngroups, dstate)
assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
if z is not None:
assert z.shape == x.shape
if D is not None:
assert D.shape == (nheads, headdim) or D.shape == (nheads,)
assert dt.shape == (batch, nheads, nchunks, chunk_size)
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
assert states.shape == (batch, nchunks, nheads, headdim, dstate)
if seq_idx is not None:
assert seq_idx.shape == (batch, seqlen)
# Allocates output.
out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)
if z is not None:
out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)
assert out_x.stride() == out.stride()
else:
out_x = None
grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
batch * nchunks, nheads)
z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3))
if z is not None else (0, 0, 0, 0))
_chunk_scan_fwd_kernel[grid](
cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, states, D,
chunk_size, headdim, dstate,
batch, seqlen, nheads // ngroups,
cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),
x.stride(0), x.stride(1), x.stride(2), x.stride(3),
z_strides[0], z_strides[1], z_strides[2], z_strides[3],
out.stride(0), out.stride(1), out.stride(2), out.stride(3),
dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
C.stride(0), C.stride(1), C.stride(2), C.stride(3),
states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),
D.stride(0) if D is not None else 0,
        True,  # IS_CAUSAL
        D is not None,  # HAS_D
        D.dim() == 2 if D is not None else True,  # D_HAS_HDIM
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
HAS_Z=z is not None,
HAS_SEQ_IDX=seq_idx is not None,
IS_TRITON_22=TRITON_22,
)
return out, out_x
def _chunk_scan_fwd_wip(cb, x, dt, dA_cumsum, C, B, states, D=None, z=None, seq_idx=None):
batch, seqlen, nheads, headdim = x.shape
_, _, nchunks, chunk_size = dt.shape
_, _, ngroups, dstate = C.shape
assert nheads % ngroups == 0
assert C.shape == (batch, seqlen, ngroups, dstate)
assert B.shape == C.shape
assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
if z is not None:
assert z.shape == x.shape
if D is not None:
assert D.shape == (nheads, headdim) or D.shape == (nheads,)
assert dt.shape == (batch, nheads, nchunks, chunk_size)
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
assert states.shape == (batch, nchunks, nheads, headdim, dstate)
if seq_idx is not None:
assert seq_idx.shape == (batch, seqlen)
# Allocates output.
out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)
if z is not None:
out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)
assert out_x.stride() == out.stride()
else:
out_x = None
grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads)
z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3))
if z is not None else (0, 0, 0, 0))
_chunk_scan_fwd_kernel_wip[grid](
cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, B, states, D,
chunk_size, headdim, dstate,
batch, seqlen, nheads // ngroups,
cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),
x.stride(0), x.stride(1), x.stride(2), x.stride(3),
z_strides[0], z_strides[1], z_strides[2], z_strides[3],
out.stride(0), out.stride(1), out.stride(2), out.stride(3),
dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
C.stride(0), C.stride(1), C.stride(2), C.stride(3),
B.stride(0), B.stride(1), B.stride(2), B.stride(3),
states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),
D.stride(0) if D is not None else 0,
D is not None,
D.dim() == 2 if D is not None else True,
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
BLOCK_SIZE_M=128,
HAS_Z=z is not None,
HAS_SEQ_IDX=seq_idx is not None,
)
return out, out_x
def _chunk_scan_bwd_dz(x, z, out, dout, chunk_size, has_ddAcs=True, D=None, dz=None, recompute_output=False):
batch, seqlen, nheads, headdim = x.shape
assert z.shape == x.shape
assert out.shape == x.shape
assert dout.shape == out.shape
nchunks = math.ceil(seqlen / chunk_size)
if D is not None:
assert D.shape == (nheads, headdim) or D.shape == (nheads,)
assert D.stride(-1) == 1
if has_ddAcs:
ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32)
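    # Partial dD is written to a separate slice per BLOCK_SIZE_M row block (same workaround
    # as in _chunk_scan_bwd_ddAcs_unstable: Triton gives wrong results if programs write to
    # the same location); the valid slices are summed once the autotuned block size is known.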
if D is not None:
BLOCK_SIZE_min = 32
dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads,
headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32)
else:
dD = None
if dz is not None:
assert dz.shape == z.shape
else:
dz = torch.empty_like(z)
if recompute_output:
outz = torch.empty_like(x)
dout_x = torch.empty_like(dout)
dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))
if D is not None else (0, 0, 0, 0, 0))
grid_dz = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads)
with torch.cuda.device(x.device.index):
_chunk_scan_bwd_dz_kernel[grid_dz](
dout, out, z, x, D, outz if recompute_output else None,
dz, dout_x, dD, ddA_cumsum if has_ddAcs else None,
chunk_size, headdim,
batch, seqlen,
dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
out.stride(0), out.stride(1), out.stride(2), out.stride(3),
z.stride(0), z.stride(1), z.stride(2), z.stride(3),
x.stride(0), x.stride(1), x.stride(2), x.stride(3),
D.stride(0) if D is not None else 0,
*((outz.stride(0), outz.stride(1), outz.stride(2), outz.stride(3)) if recompute_output else (0, 0, 0, 0)),
dz.stride(0), dz.stride(1), dz.stride(2), dz.stride(3),
dout_x.stride(0), dout_x.stride(1), dout_x.stride(2), dout_x.stride(3),
dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],
*((ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3))
if has_ddAcs else (0, 0, 0, 0)),
D is not None,
D.dim() == 2 if D is not None else True,
has_ddAcs,
BLOCK_SIZE_N=max(triton.next_power_of_2(headdim), 16),
RECOMPUTE_OUTPUT=recompute_output,
)
if D is not None:
BLOCK_SIZE_actual = _chunk_scan_bwd_dz_kernel.best_config.kwargs["BLOCK_SIZE_M"]
n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
if D.dim() == 1:
dD = rearrange(dD, "h 1 -> h")
return_vals = (dz, dout_x, dD, ddA_cumsum) if has_ddAcs else (dz, dout_x, dD)
return return_vals if not recompute_output else (*return_vals, outz)
def _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=None, dtype=None):
batch, seqlen, nheads, headdim = dout.shape
_, _, nchunks, chunk_size = dA_cumsum.shape
_, _, ngroups, dstate = C.shape
assert nheads % ngroups == 0
assert C.shape == (batch, seqlen, ngroups, dstate)
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
if seq_idx is not None:
assert seq_idx.shape == (batch, seqlen)
dtype = C.dtype if dtype is None else dtype
dprev_states = torch.empty(batch, nchunks, nheads, headdim, dstate, device=C.device, dtype=dtype)
grid_dstates = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
batch * nchunks, nheads)
with torch.cuda.device(C.device.index):
_chunk_scan_bwd_dstates_kernel[grid_dstates](
dout, C, dprev_states, dA_cumsum, seq_idx,
headdim, dstate, chunk_size,
batch, seqlen, nchunks, nheads // ngroups,
dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
C.stride(0), C.stride(1), C.stride(2), C.stride(3),
dprev_states.stride(0), dprev_states.stride(1), dprev_states.stride(2), dprev_states.stride(3), dprev_states.stride(4),
dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
HAS_SEQ_IDX=seq_idx is not None,
)
return dprev_states
def _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, seq_idx=None, C=None, ngroups=1):
batch, nchunks, nheads, headdim, dstate = prev_states.shape
_, seqlen, _, _ = dout.shape
_, _, _, chunk_size = dA_cumsum.shape
assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate)
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
assert dout.shape == (batch, seqlen, nheads, headdim)
if seq_idx is not None:
assert seq_idx.shape == (batch, seqlen)
if C is not None:
assert C.shape == (batch, seqlen, ngroups, dstate)
C_strides = (C.stride(0), C.stride(1), C.stride(2), C.stride(3))
ddA_cumsum_prev = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)
ddA_cumsum_prev_strides = (ddA_cumsum_prev.stride(0), ddA_cumsum_prev.stride(2), ddA_cumsum_prev.stride(1), ddA_cumsum_prev.stride(3))
else:
C_strides = (0, 0, 0, 0)
ddA_cumsum_prev = None
ddA_cumsum_prev_strides = (0, 0, 0, 0)
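    # Split the reduction over the heads within each group across up to `nsplits` programs,
    # sized so the grid roughly fills the SMs; the partial results along that extra split
    # dimension are summed after the kernel launch.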
nheads_ngroups_ratio = nheads // ngroups
sm_count = torch.cuda.get_device_properties(dout.device).multi_processor_count
nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1)
nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
dC = torch.empty(batch, seqlen, nsplits, ngroups, dstate, device=dout.device, dtype=torch.float32)
grid_dc = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
batch * nchunks, nsplits * ngroups)
with torch.cuda.device(dout.device.index):
_chunk_scan_bwd_dc_kernel[grid_dc](
dout, prev_states, C, dA_cumsum, seq_idx, dC, ddA_cumsum_prev,
chunk_size, dstate, headdim,
batch, seqlen, nheads, nheads_per_program, ngroups,
dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
prev_states.stride(0), prev_states.stride(1), prev_states.stride(2), prev_states.stride(3), prev_states.stride(4),
*C_strides,
dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
dC.stride(0), dC.stride(1), dC.stride(2), dC.stride(3), dC.stride(4),
*ddA_cumsum_prev_strides,
HAS_DDA_CS=ddA_cumsum_prev is not None,
HAS_SEQ_IDX=seq_idx is not None,
BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
)
dC = dC.sum(2)
return dC if C is None else (dC, ddA_cumsum_prev)
def _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=None, CB=None, ngroups=1):
batch, seqlen, nheads, headdim = x.shape
_, _, nchunks, chunk_size = dt.shape
assert dt.shape == (batch, nheads, nchunks, chunk_size)
assert dA_cumsum.shape == dt.shape
assert dout.shape == x.shape
if seq_idx is not None:
assert seq_idx.shape == (batch, seqlen)
if CB is not None:
assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
CB_strides = (CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(3), CB.stride(4))
BLOCK_SIZE_M_min = 16
ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min),
chunk_size, device=x.device, dtype=torch.float32)
ddA_cumsum_strides = (ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4))
else:
CB_strides = (0, 0, 0, 0, 0)
ddA_cumsum = None
ddA_cumsum_strides = (0, 0, 0, 0, 0)
nheads_ngroups_ratio = nheads // ngroups
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1)
nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
dcb = torch.empty(batch, nchunks, nsplits, ngroups, chunk_size, chunk_size, device=x.device, dtype=torch.float32)
grid_dcb = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),
batch * nchunks, nsplits * ngroups)
with torch.cuda.device(x.device.index):
_chunk_scan_bwd_dcb_kernel[grid_dcb](
x, dout, CB, dt, dA_cumsum, seq_idx, dcb, ddA_cumsum,
chunk_size, headdim,
batch, seqlen, nheads, nheads_per_program, ngroups,
x.stride(0), x.stride(1), x.stride(2), x.stride(3),
dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
*CB_strides,
dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
dcb.stride(0), dcb.stride(1), dcb.stride(2), dcb.stride(3), dcb.stride(4), dcb.stride(5),
*ddA_cumsum_strides,
HAS_DDA_CS=ddA_cumsum is not None,
HAS_SEQ_IDX=seq_idx is not None,
BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
)
dcb = dcb.sum(2)
if ddA_cumsum is not None:
BLOCK_SIZE_M_actual = _chunk_scan_bwd_dcb_kernel.best_config.kwargs["BLOCK_SIZE_M"]
n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual
ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3)
return dcb if CB is None else (dcb, ddA_cumsum)
def _chunk_scan_bwd_dx(cb, x, dt, dA_cumsum, dout, D=None):
batch, seqlen, nheads, headdim = x.shape
_, _, nchunks, chunk_size = dt.shape
ngroups = cb.shape[2]
assert nheads % ngroups == 0
assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
assert dt.shape == (batch, nheads, nchunks, chunk_size)
assert dA_cumsum.shape == dt.shape
assert dout.shape == x.shape
# if D is not None:
# BLOCK_SIZE_M_min = 32
# dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_M_min), batch, nchunks, nheads, headdim, device=D.device, dtype=torch.float32)
# else:
# dD = None
dx = torch.empty_like(x)
ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)
grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
batch * nchunks, nheads)
with torch.cuda.device(x.device.index):
_chunk_scan_bwd_dx_kernel[grid_dx](
x, cb, dout, dt, dA_cumsum, D, dx, ddt, # dD,
chunk_size, headdim,
batch, seqlen, nheads // ngroups,
x.stride(0), x.stride(1), x.stride(2), x.stride(3),
cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(-1), cb.stride(-2),
dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
D.stride(0) if D is not None else 0,
dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),
ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3),
# dD.stride(1) if dD is not None else 0, dD.stride(2) if dD is not None else 0, dD.stride(3) if dD is not None else 0, dD.stride(4) if dD is not None else 0, dD.stride(0) if dD is not None else 0,
D is not None,
D.dim() == 2 if D is not None else True,
)
# if D is not None:
# BLOCK_SIZE_actual = _chunk_scan_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"]
# n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
# dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
return dx, ddt.to(dtype=dt.dtype)
def _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=None, subtract_ddtdt=True):
"""Not numerically stable and should not be used. Leaving here for reference.
"""
batch, seqlen, nheads, headdim = x.shape
_, _, nchunks, chunk_size = dt.shape
assert dt.shape == (batch, nheads, nchunks, chunk_size)
assert ddt.shape == dt.shape
assert out.shape == x.shape
assert dout.shape == x.shape
if D is not None:
assert D.shape == (nheads, headdim) or D.shape == (nheads,)
ddA_cumsum = torch.empty_like(dt)
grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads)
if D is not None: # Triton gives wrong results if we write to the same location
BLOCK_SIZE_min = 32
dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads,
headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32)
else:
dD = None
dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))
if D is not None else (0, 0, 0, 0, 0))
with torch.cuda.device(x.device.index):
_chunk_scan_bwd_ddAcs_unstable_kernel[grid_ddtcs](
dout, out, dt, ddt, x, D, ddA_cumsum, dD,
chunk_size, headdim,
batch, seqlen,
dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
out.stride(0), out.stride(1), out.stride(2), out.stride(3),
dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3),
x.stride(0), x.stride(1), x.stride(2), x.stride(3),
D.stride(0) if D is not None else 0,
ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3),
dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],
D is not None,
D.dim() == 2 if D is not None else True,
subtract_ddtdt,
BLOCK_SIZE_N=max(triton.next_power_of_2(headdim), 16),
)
if D is not None:
BLOCK_SIZE_actual = _chunk_scan_bwd_ddAcs_unstable_kernel.best_config.kwargs["BLOCK_SIZE_M"]
n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
if D.dim() == 1:
dD = rearrange(dD, "h 1 -> h")
return ddA_cumsum, dD
def _chunk_scan_bwd_ddAcs_stable_old(x, dt, dA_cumsum, dout, cb):
batch, seqlen, nheads, headdim = x.shape
_, _, nchunks, chunk_size = dt.shape
assert dt.shape == (batch, nheads, nchunks, chunk_size)
assert dout.shape == x.shape
assert dA_cumsum.shape == dt.shape
ngroups = cb.shape[2]
assert nheads % ngroups == 0
assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
BLOCK_SIZE_M_min = 16
ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min),
chunk_size, device=x.device, dtype=torch.float32)
grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads)
with torch.cuda.device(x.device.index):
_chunk_scan_bwd_ddAcs_stable_kernel_old[grid_ddtcs](
x, dout, dt, dA_cumsum, cb, ddA_cumsum,
chunk_size, headdim,
batch, seqlen, nheads // ngroups,
x.stride(0), x.stride(1), x.stride(2), x.stride(3),
dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),
ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4),
BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
BLOCK_SIZE_N=max(triton.next_power_of_2(chunk_size), 16),
)
BLOCK_SIZE_M_actual = _chunk_scan_bwd_ddAcs_stable_kernel_old.best_config.kwargs["BLOCK_SIZE_M"]
n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual
ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3)
return ddA_cumsum
def _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, cb):
batch, seqlen, nheads, headdim = x.shape
_, _, nchunks, chunk_size = dt.shape
assert dt.shape == (batch, nheads, nchunks, chunk_size)
assert dout.shape == x.shape
assert dA_cumsum.shape == dt.shape
ngroups = cb.shape[2]
assert nheads % ngroups == 0
assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
BLOCK_SIZE_M_min = 32
ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min),
chunk_size, device=x.device, dtype=torch.float32)
grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads)
with torch.cuda.device(x.device.index):
_chunk_scan_bwd_ddAcs_stable_kernel[grid_ddtcs](
x, dout, dt, dA_cumsum, cb, ddA_cumsum,
chunk_size, headdim,
batch, seqlen, nheads // ngroups,
x.stride(0), x.stride(1), x.stride(2), x.stride(3),
dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),
ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4),
BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
)
BLOCK_SIZE_M_actual = _chunk_scan_bwd_ddAcs_stable_kernel.best_config.kwargs["BLOCK_SIZE_M"]
n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual
ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3)
return ddA_cumsum
def _chunk_scan_bwd_ddAcs_prev(prev_states, C, dout, dA_cumsum, seq_idx=None):
batch, nchunks, nheads, headdim, dstate = prev_states.shape
_, seqlen, _, _ = dout.shape
_, _, _, chunk_size = dA_cumsum.shape
assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate)
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
assert dout.shape == (batch, seqlen, nheads, headdim)
ngroups = C.shape[2]
assert nheads % ngroups == 0
assert C.shape == (batch, seqlen, ngroups, dstate)
if seq_idx is not None:
assert seq_idx.shape == (batch, seqlen)
ddA_cumsum_prev = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)
grid_ddAcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
batch * nchunks, nheads)
with torch.cuda.device(dout.device.index):
_chunk_scan_bwd_ddAcs_prev_kernel[grid_ddAcs](
dout, prev_states, C, dA_cumsum, seq_idx, ddA_cumsum_prev,
chunk_size, dstate, headdim,
batch, seqlen, nchunks, nheads // ngroups,
dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
prev_states.stride(0), prev_states.stride(1), prev_states.stride(2), prev_states.stride(3), prev_states.stride(4),
C.stride(0), C.stride(1), C.stride(2), C.stride(3),
dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
ddA_cumsum_prev.stride(0), ddA_cumsum_prev.stride(2), ddA_cumsum_prev.stride(1), ddA_cumsum_prev.stride(3),
HAS_SEQ_IDX=seq_idx is not None,
BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
)
return ddA_cumsum_prev
class ChunkScanFn(torch.autograd.Function):
@staticmethod
def forward(ctx, B, C, x, dt, dA_cumsum, prev_states, D=None, z=None):
# Check constraints.
batch, seqlen, nheads, headdim = x.shape
_, _, ngroups, dstate = B.shape
assert B.shape == (batch, seqlen, ngroups, dstate)
_, _, nchunks, chunk_size = dt.shape
assert seqlen == nchunks * chunk_size
assert C.shape == B.shape
if z is not None:
assert z.shape == x.shape
if D is not None:
assert D.shape == (nheads, headdim) or D.shape == (nheads,)
assert dt.shape == (batch, nheads, nchunks, chunk_size)
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate)
if B.stride(-1) != 1:
B = B.contiguous()
if C.stride(-1) != 1:
C = C.contiguous()
if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous
x = x.contiguous()
if z is not None and z.stride(-1) != 1 and z.stride(1) != 1: # Either M or K dimension should be contiguous
z = z.contiguous()
if D is not None and D.stride(-1) != 1:
D = D.contiguous()
CB = _bmm_chunk_fwd(C, B, chunk_size)
out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, prev_states, D=D, z=z)
ctx.save_for_backward(out if z is None else out_x, B, C, CB, x, dt, dA_cumsum, prev_states, D, z)
return out
@staticmethod
def backward(ctx, dout):
if dout.stride(-1) != 1:
dout = dout.contiguous()
out, B, C, CB, x, dt, dA_cumsum, prev_states, D, z = ctx.saved_tensors
batch, seqlen, nheads, headdim = x.shape
_, _, nchunks, chunk_size = dt.shape
_, _, ngroups, dstate = B.shape
assert dout.shape == (batch, seqlen, nheads, headdim)
if z is not None:
dz, dout, dD, ddA_cumsum = _chunk_scan_bwd_dz(x, z, out, dout, chunk_size=chunk_size, D=D)
else:
dz = None
dprev_states = _chunk_scan_bwd_dstates(C, dA_cumsum, dout, dtype=prev_states.dtype)
dC = _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, ngroups=ngroups)
dC = dC.to(C.dtype)
dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, ngroups=ngroups)
dCB = dCB.to(CB.dtype)
dB = _bmm_chunk_bwd(C, dCB)
dC = _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC)
dx, ddt = _chunk_scan_bwd_dx(CB, x, dt, dA_cumsum, dout, D=D)
# Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D.
# ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt
        if z is not None:  # _chunk_scan_bwd_dz has already computed ddA_cumsum and dD in this case
            ddA_cumsum -= ddt * dt
        else:
            ddA_cumsum, dD = _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=D)
ddA_cumsum = ddA_cumsum.to(dA_cumsum.dtype)
return dB, dC, dx, ddt, ddA_cumsum, dprev_states, dD, dz
def chunk_scan(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None):
"""
prev_states contains the initial_states at index 0, and the state for the next-to-last chunk at index -1.
Argument:
B: (batch, seqlen, ngroups, dstate)
C: (batch, seqlen, ngroups, dstate)
x: (batch, seqlen, nheads, headdim)
dt: (batch, nheads, nchunks, chunk_size)
dA_cumsum: (batch, nheads, nchunks, chunk_size)
prev_states: (batch, nchunks, nheads, headdim, dstate)
D: (nheads, headdim) or (nheads,)
z: (batch, seqlen, nheads, headdim)
Return:
out: (batch, seqlen, nheads, headdim)
"""
return ChunkScanFn.apply(B, C, x, dt, dA_cumsum, prev_states, D, z)
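# Minimal usage sketch for `chunk_scan` (hypothetical helper, not part of the library API).
# Shapes and dtypes are illustrative; chunk_size must divide seqlen, and in a real model
# dt, dA_cumsum and prev_states come from the chunked SSD pipeline rather than from random
# tensors (dA_cumsum below corresponds to A = -1, purely for illustration).
def _example_chunk_scan_usage(device="cuda", dtype=torch.float16):
    batch, seqlen, chunk_size = 2, 256, 64
    nheads, headdim, ngroups, dstate = 8, 64, 1, 16
    nchunks = seqlen // chunk_size
    B = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype)
    C = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype)
    x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
    z = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
    D = torch.randn(nheads, device=device, dtype=torch.float32)
    dt = torch.rand(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32)
    dA_cumsum = (-dt).cumsum(dim=-1)  # cumulative sum of dt * A with A = -1 (illustrative)
    prev_states = torch.zeros(batch, nchunks, nheads, headdim, dstate, device=device, dtype=dtype)
    return chunk_scan(B, C, x, dt, dA_cumsum, prev_states, D=D, z=z)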
def chunk_scan_ref(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None):
"""
Argument:
B: (batch, seqlen, ngroups, dstate)
C: (batch, seqlen, ngroups, dstate)
x: (batch, seqlen, nheads, headdim)
dt: (batch, nheads, nchunks, chunk_size)
dA_cumsum: (batch, nheads, nchunks, chunk_size)
prev_states: (batch, nchunks, nheads, headdim, dstate)
D: (nheads, headdim) or (nheads,)
z: (batch, seqlen, nheads, headdim)
Return:
out: (batch, seqlen, nheads, headdim)
"""
batch, seqlen, nheads, headdim = x.shape
_, _, ngroups, dstate = B.shape
assert B.shape == (batch, seqlen, ngroups, dstate)
_, _, nchunks, chunk_size = dt.shape
assert seqlen == nchunks * chunk_size
assert C.shape == B.shape
B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups)
CB = torch.einsum("bclhn,bcshn->bchls", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks),
rearrange(B, "b (c s) h n -> b c s h n", c=nchunks))
# (batch, nheads, nchunks, chunksize, chunksize)
dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :]
decay = torch.exp(dt_segment_sum)
scores_decay = CB * rearrange(decay, "b h c l s -> b c h l s")
causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
scores_decay = scores_decay.masked_fill(~causal_mask, 0)
out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype),
rearrange(x, "b (c s) h p -> b c s h p", c=nchunks))
state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1"))
out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange(C, "b (c l) h n -> b c l h n", c=nchunks),
prev_states.to(C.dtype)) * state_decay_out
out = out + out_prev
out = rearrange(out, "b c l h p -> b (c l) h p")
if D is not None:
if D.dim() == 1:
D = rearrange(D, "h -> h 1")
out = out + x * D
return out if z is None else out * F.silu(z)
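# Self-check sketch (hypothetical, not part of the library): compare the Triton path against
# the pure-PyTorch reference above on random inputs. The tolerances are a rough guess for
# fp16 inputs; prev_states is taken as zeros for simplicity.
def _check_chunk_scan_against_ref(atol=1e-2, rtol=1e-2):
    torch.manual_seed(0)
    batch, seqlen, chunk_size = 2, 128, 64
    nheads, headdim, ngroups, dstate = 4, 64, 1, 16
    nchunks = seqlen // chunk_size
    device, dtype = "cuda", torch.float16
    B = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype)
    C = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype)
    x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
    dt = torch.rand(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32)
    dA_cumsum = (-dt).cumsum(dim=-1)
    prev_states = torch.zeros(batch, nchunks, nheads, headdim, dstate, device=device, dtype=dtype)
    out = chunk_scan(B, C, x, dt, dA_cumsum, prev_states)
    out_ref = chunk_scan_ref(B, C, x, dt, dA_cumsum, prev_states)
    max_err = (out.float() - out_ref.float()).abs().max().item()
    assert torch.allclose(out.float(), out_ref.float(), atol=atol, rtol=rtol), max_err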