# Source: cpu-casuallm/mamba/mamba_ssm/ops/triton/ssd_state_passing.py
# (uploaded by Somunia in "Upload 116 files", revision 306b4ac)
# Copyright (c) 2024, Tri Dao, Albert Gu.
"""We want triton==2.1.0 or 2.2.0 for this
"""
import math
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from einops import rearrange, repeat
# Forward chunk state passing.
#
# Each program handles one BLOCK_SIZE slice of `dim` for one (batch, head) pair:
#   program_id(0) -> slice of dim, program_id(1) -> batch, program_id(2) -> head.
# It walks the chunks in order and keeps the running state in float32:
#     state = exp(dA_cs[c]) * state + states[c]
#
# Outputs:
#   out[0]         - the initial state (zeros when HAS_INITSTATES is False)
#   out[c + 1]     - the state after chunk c, for c < nchunks - 1
#   final_states   - the state after the last chunk
#
# With HAS_SEQ_IDX, the decay factor is replaced by 0 whenever the sequence id at the
# last position of chunk c differs from the id at the end of the previous chunk
# (taken as 0 before the first chunk). This resets the running state at a
# sequence-id change.
#
# Autotuned over BLOCK_SIZE; the tuning is keyed on `dim`.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 64}),
        triton.Config({'BLOCK_SIZE': 128}),
        triton.Config({'BLOCK_SIZE': 256}),
        triton.Config({'BLOCK_SIZE': 512}),
        triton.Config({'BLOCK_SIZE': 1024}),
        triton.Config({'BLOCK_SIZE': 2048}),
    ],
    key=['dim'],
)
@triton.jit
def _state_passing_fwd_kernel(
    # Pointers to matrices
    states_ptr, out_ptr, final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr,
    # Matrix dimensions
    dim, nchunks, seqlen, chunk_size,
    # Strides
    stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim,
    stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,
    stride_final_states_batch, stride_final_states_head, stride_final_states_dim,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,
    stride_initstates_batch, stride_initstates_head, stride_initstates_dim,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    # Meta-parameters
    HAS_INITSTATES: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    pid_m = tl.program_id(axis=0)
    # Move every base pointer to this program's (batch, head); chunk 0 is the starting point.
    states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
    dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head
    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
    final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head
    if HAS_INITSTATES:
        initstates_ptr += pid_b * stride_initstates_batch + pid_h * stride_initstates_head
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch
    # Offsets along `dim` owned by this program; `offs_m < dim` masks the tail block.
    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    states_ptrs = states_ptr + offs_m * stride_states_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim
    final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim
    if not HAS_INITSTATES:
        states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
    else:
        initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim
        states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    # out[0] holds the state entering chunk 0.
    tl.store(out_ptrs, states, mask=offs_m < dim)
    out_ptrs += stride_out_chunk
    # Sequence id assumed before the first chunk: if chunk 0 ends in a non-zero id, the
    # initial state is dropped (scale becomes 0 below).
    seq_idx = 0
    for c in range(nchunks):
        new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
        scale = tl.exp(dA_cs)
        if HAS_SEQ_IDX:
            # Sequence id at the last position of chunk c (clamped to seqlen for a short last chunk).
            seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)
            scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
            seq_idx = seq_idx_new
        states = scale * states + new_states
        # out has only nchunks slots and slot 0 holds the initial state, so the state after
        # the last chunk goes to final_states instead.
        if c < nchunks - 1:
            tl.store(out_ptrs, states, mask=offs_m < dim)
        else:
            tl.store(final_states_ptrs, states, mask=offs_m < dim)
        states_ptrs += stride_states_chunk
        dA_cs_ptr += stride_dA_cs_chunk
        out_ptrs += stride_out_chunk
# Backward of the chunk state passing recurrence.
#
# Uses the same program layout as the forward kernel:
#   program_id(0) -> BLOCK_SIZE slice of dim, 1 -> batch, 2 -> head.
# Chunks are visited in reverse order.
#
# Inputs:
#   dout          - gradient of the forward `out`
#   out           - the forward `out`: the state entering each chunk, with the initial
#                   state at index 0
#   dA_cs         - per-chunk decay logits
#   dfinal_states - optional
#   seq_idx       - optional
#
# Outputs:
#   dstates[c]       - running gradient w.r.t. chunk c's state contribution
#                      (accumulated in float32)
#   ddA_cs           - this program's partial gradient w.r.t. dA_cs[c], stored at column
#                      pid_m of the last axis; the partials must be summed over that axis
#   dinitstates      - gradient w.r.t. the initial state (only when HAS_DINITSTATES)
#   states_converted - optional copy of `out`, addressed with the `out` strides and
#                      stored in the destination buffer's element type
#
# Autotuned over BLOCK_SIZE; the tuning is keyed on `dim`.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 64}),
        triton.Config({'BLOCK_SIZE': 128}),
        triton.Config({'BLOCK_SIZE': 256}),
        triton.Config({'BLOCK_SIZE': 512}),
        triton.Config({'BLOCK_SIZE': 1024}),
        triton.Config({'BLOCK_SIZE': 2048}),
    ],
    key=['dim'],
)
@triton.jit
def _state_passing_bwd_kernel(
    # Pointers to matrices
    dout_ptr, out_ptr, dA_cs_ptr, dfinal_states_ptr, seq_idx_ptr,
    dstates_ptr, ddA_cs_ptr, dinitstates_ptr, states_converted_ptr,
    # Matrix dimensions
    dim, nchunks, seqlen, chunk_size,
    # Strides
    stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_dim,
    stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,
    stride_dfinal_states_batch, stride_dfinal_states_head, stride_dfinal_states_dim,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_dim,
    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head,
    stride_dinitstates_batch, stride_dinitstates_head, stride_dinitstates_dim,
    # Meta-parameters
    CONVERT_STATES: tl.constexpr,
    HAS_DFINAL_STATES: tl.constexpr,
    HAS_DINITSTATES: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    pid_m = tl.program_id(axis=0)
    # Point every per-chunk tensor at the LAST chunk of this (batch, head); the loop walks backwards.
    dstates_ptr += pid_b * stride_dstates_batch + pid_h * stride_dstates_head + (nchunks - 1) * stride_dstates_chunk
    dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (nchunks - 1) * stride_dA_cs_chunk
    # `+ pid_m`: each dim-block writes its own partial ddA in the last axis (unit stride assumed).
    ddA_cs_ptr += pid_b * stride_ddA_cs_batch + pid_h * stride_ddA_cs_head + (nchunks - 1) * stride_ddA_cs_chunk + pid_m
    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk
    dout_ptr += pid_b * stride_dout_batch + pid_h * stride_dout_head + (nchunks - 1) * stride_dout_chunk
    if CONVERT_STATES:
        states_converted_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk
    if HAS_DFINAL_STATES:
        dfinal_states_ptr += pid_b * stride_dfinal_states_batch + pid_h * stride_dfinal_states_head
    if HAS_DINITSTATES:
        dinitstates_ptr += pid_b * stride_dinitstates_batch + pid_h * stride_dinitstates_head
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch
    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    dstates_ptrs = dstates_ptr + offs_m * stride_dstates_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim
    dout_ptrs = dout_ptr + offs_m * stride_dout_dim
    if CONVERT_STATES:
        states_converted_ptrs = states_converted_ptr + offs_m * stride_out_dim
    # The last chunk's contribution only feeds final_states, so its gradient is dfinal_states.
    if HAS_DFINAL_STATES:
        dstates = tl.load(dfinal_states_ptr + offs_m * stride_dfinal_states_dim, mask=offs_m < dim, other=0.0).to(tl.float32)
    else:
        dstates = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
    tl.store(dstates_ptrs, dstates, mask=offs_m < dim)
    if HAS_SEQ_IDX:
        # Sequence id at the very last position (end of the last chunk).
        seq_idx = tl.load(seq_idx_ptr + (seqlen - 1) * stride_seq_idx_seqlen)
    dstates_ptrs -= stride_dstates_chunk
    # Iteration c processes chunk k = nchunks - 1 - c (k = nchunks-1 .. 1); `dstates` holds dstates[k].
    for c in range(nchunks - 1):
        dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
        scale = tl.exp(dA_cs)
        if HAS_SEQ_IDX:
            # Sequence id at the end of chunk k - 1; the forward zeroed chunk k's decay if it differs.
            seq_idx_new = tl.load(seq_idx_ptr + (((nchunks - c - 1) * chunk_size - 1) * stride_seq_idx_seqlen))
            scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
            seq_idx = seq_idx_new
        # State entering chunk k.
        out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if CONVERT_STATES:
            tl.store(states_converted_ptrs, out, mask=offs_m < dim)
        # d(scale * state)/d(dA_cs) = scale * state, reduced over this slice of dim.
        ddA = tl.sum(out * dstates) * scale
        tl.store(ddA_cs_ptr, ddA)
        dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        # Gradient flowing into the state entering chunk k (= dstates[k - 1]).
        dstates = scale * dstates + dout
        tl.store(dstates_ptrs, dstates, mask=offs_m < dim)
        dout_ptrs -= stride_dout_chunk
        dstates_ptrs -= stride_dstates_chunk
        dA_cs_ptr -= stride_dA_cs_chunk
        ddA_cs_ptr -= stride_ddA_cs_chunk
        out_ptrs -= stride_out_chunk
        if CONVERT_STATES:
            states_converted_ptrs -= stride_out_chunk
    # Pointers now sit on chunk 0, whose incoming state out[0] is the initial state.
    if CONVERT_STATES:
        out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        tl.store(states_converted_ptrs, out, mask=offs_m < dim)
    if not HAS_DINITSTATES:
        tl.store(ddA_cs_ptr, 0.0)
    else:
        dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
        scale = tl.exp(dA_cs)
        if HAS_SEQ_IDX:
            # Mirrors the forward kernel, which starts from seq_idx = 0 before chunk 0.
            scale = tl.where(seq_idx == 0, scale, 0.0)
        out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        ddA = tl.sum(out * dstates) * scale
        tl.store(ddA_cs_ptr, ddA)
        # The forward also stores the initial state directly into out[0], hence `+ dout`.
        dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        dstates = scale * dstates + dout
        tl.store(dinitstates_ptr + offs_m * stride_dinitstates_dim, dstates, mask=offs_m < dim)
def _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None,
out_dtype=None):
batch, nchunks, nheads, dim = states.shape
assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
if initial_states is not None:
assert initial_states.shape == (batch, nheads, dim)
if seq_idx is not None:
assert chunk_size is not None
seqlen = seq_idx.shape[-1]
assert seq_idx.shape == (batch, seqlen)
out_dtype = states.dtype if out_dtype is None else out_dtype
out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype)
final_states = torch.empty((batch, nheads, dim), device=states.device, dtype=torch.float32)
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
with torch.cuda.device(states.device.index):
_state_passing_fwd_kernel[grid](
states, out, final_states, dA_chunk_cumsum, initial_states, seq_idx,
dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,
states.stride(0), states.stride(1), states.stride(2), states.stride(3),
out.stride(0), out.stride(1), out.stride(2), out.stride(3),
final_states.stride(0), final_states.stride(1), final_states.stride(2),
dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),
*((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2))
if initial_states is not None else (0, 0, 0)),
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
HAS_INITSTATES=initial_states is not None,
HAS_SEQ_IDX=seq_idx is not None,
)
return out, final_states
def _state_passing_bwd(
states, dA_chunk_cumsum, dout, dfinal_states=None, seq_idx=None, has_initial_states=None,
dstates_dtype=None, states_dtype=None, chunk_size=None
):
"""
states contains the initial_states at index 0. The final states are not included in states.
"""
batch, nchunks, nheads, dim = states.shape
assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
assert dout.shape == (batch, nchunks, nheads, dim)
if seq_idx is not None:
assert chunk_size is not None
seqlen = seq_idx.shape[-1]
assert seq_idx.shape == (batch, seqlen)
dstates = torch.empty_like(dout, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)
if states_dtype is not None and states_dtype != states.dtype:
states_converted = torch.empty_like(states, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)
assert states_converted.stride() == states.stride()
else:
states_converted = None
if has_initial_states:
dinitstates = torch.empty_like(dstates[:, 0])
else:
dinitstates = None
if dfinal_states is not None:
assert dfinal_states.shape == (batch, nheads, dim)
BLOCK_SIZE_min = 64
n_blocks = (dim + BLOCK_SIZE_min - 1) // BLOCK_SIZE_min
ddA_chunk_cumsum = torch.empty(batch, nheads, nchunks, n_blocks,
dtype=torch.float32, device=dA_chunk_cumsum.device)
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
with torch.cuda.device(dout.device.index):
_state_passing_bwd_kernel[grid](
dout, states, dA_chunk_cumsum, dfinal_states, seq_idx,
dstates, ddA_chunk_cumsum, dinitstates, states_converted,
dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,
dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
states.stride(0), states.stride(1), states.stride(2), states.stride(3),
dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),
*((dfinal_states.stride(0), dfinal_states.stride(1), dfinal_states.stride(2))
if dfinal_states is not None else (0, 0, 0)),
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3),
ddA_chunk_cumsum.stride(0), ddA_chunk_cumsum.stride(2), ddA_chunk_cumsum.stride(1),
*((dinitstates.stride(0), dinitstates.stride(1), dinitstates.stride(2))
if dinitstates is not None else (0, 0, 0)),
CONVERT_STATES=states_converted is not None,
HAS_DFINAL_STATES=dfinal_states is not None,
HAS_DINITSTATES=dinitstates is not None,
HAS_SEQ_IDX=seq_idx is not None,
)
BLOCK_SIZE_actual = _state_passing_bwd_kernel.best_config.kwargs["BLOCK_SIZE"]
n_valid_blocks = (dim + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
ddA_chunk_cumsum = ddA_chunk_cumsum[..., :n_valid_blocks].sum(dim=-1).to(dtype=dA_chunk_cumsum.dtype)
if states_dtype is not None and states_dtype == states.dtype:
states_converted = states
return (dstates, ddA_chunk_cumsum, dinitstates) if states_dtype is None else (dstates, ddA_chunk_cumsum, dinitstates, states_converted)
class StatePassingFn(torch.autograd.Function):
    """Differentiable chunk state passing backed by the Triton kernels.

    forward(states, dA_chunk_cumsum, initial_states=None) -> (out, final_states)
        states: (batch, nchunks, nheads, dim)
        dA_chunk_cumsum: (batch, nheads, nchunks)
        initial_states: optional (batch, nheads, dim)
    seq_idx is not exposed through this autograd path.
    """

    @staticmethod
    def forward(ctx, states, dA_chunk_cumsum, initial_states=None):
        b, nc, nh, d = states.shape
        assert dA_chunk_cumsum.shape == (b, nh, nc)
        # Copy only when the innermost dimension is not already unit-stride.
        states = states if states.stride(-1) == 1 else states.contiguous()
        ctx.has_initial_states = initial_states is not None
        out, final_states = _state_passing_fwd(states, dA_chunk_cumsum, initial_states)
        # The backward pass recomputes nothing: `out` (state entering each chunk) suffices.
        ctx.save_for_backward(out, dA_chunk_cumsum)
        return out, final_states

    @staticmethod
    def backward(ctx, dout, dfinal_states):
        out, dA_chunk_cumsum = ctx.saved_tensors
        b, nc, nh, d = out.shape
        assert dout.shape == (b, nc, nh, d)
        assert dA_chunk_cumsum.shape == (b, nh, nc)
        assert dfinal_states.shape == (b, nh, d)
        dout = dout if dout.stride(-1) == 1 else dout.contiguous()
        dstates, ddA_chunk_cumsum, dinitstates = _state_passing_bwd(
            out,
            dA_chunk_cumsum,
            dout,
            dfinal_states=dfinal_states,
            has_initial_states=ctx.has_initial_states,
        )
        # One gradient per forward input: states, dA_chunk_cumsum, initial_states.
        return dstates, ddA_chunk_cumsum, dinitstates
def state_passing(states, dA_chunk_cumsum, initial_states=None):
    """
    Differentiable chunk state passing (Triton implementation via StatePassingFn).

    Per (batch, head): S_0 = initial_states (zeros if None) and
    S_{c+1} = exp(dA_chunk_cumsum[..., c]) * S_c + states[:, c].

    Argument:
        states: (batch, nchunks, nheads, dim)
        dA_chunk_cumsum: (batch, nheads, nchunks)
        initial_states: (batch, nheads, dim)
    Return:
        out: (batch, nchunks, nheads, dim) -- S_0 .. S_{nchunks-1}
        final_states: (batch, nheads, dim) -- S_nchunks
    """
    apply_fn = StatePassingFn.apply
    return apply_fn(states, dA_chunk_cumsum, initial_states)
def state_passing_ref(states, dA_chunk_cumsum, initial_states=None):
    """
    Pure-PyTorch reference for chunk state passing.

    Per (batch, head): S_0 = initial_states (zeros if None) and
    S_{c+1} = exp(dA_chunk_cumsum[..., c]) * S_c + states[:, c].

    Argument:
        states: (batch, nchunks, nheads, dim)
        dA_chunk_cumsum: (batch, nheads, nchunks)
        initial_states: (batch, nheads, dim)
    Return:
        out: (batch, nchunks, nheads, dim) -- S_0 .. S_{nchunks-1}
        final_states: (batch, nheads, dim) -- S_nchunks
    """
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, 0])
    # Prepend the initial state as chunk 0: (batch, nchunks + 1, nheads, dim).
    states = torch.cat([initial_states.unsqueeze(1), states], dim=1)
    # Prefix sums with a leading zero: cs[..., j] = sum_{k < j} dA[..., k].
    dA_chunk_cumsum = torch.cumsum(F.pad(dA_chunk_cumsum, (1, 0)), dim=-1)
    nchunks = dA_chunk_cumsum.shape[-1]
    # (batch, nheads, nchunks, nchunks): total log-decay from chunk c (col) to chunk z (row).
    segment_sum = dA_chunk_cumsum[:, :, :, None] - dA_chunk_cumsum[:, :, None, :]
    causal_mask = torch.tril(torch.ones(nchunks, nchunks, device=states.device, dtype=torch.bool), diagonal=0)
    # Mask with -inf BEFORE exponentiating. The non-causal entries are positive whenever dA
    # is negative, so exp() can overflow to inf there. Masking afterwards keeps the forward
    # correct, but backward then computes 0 * inf = NaN for the dA gradient.
    decay_chunk = torch.exp(segment_sum.masked_fill(~causal_mask, float("-inf")))
    out = torch.einsum("bhzc,bchd->bzhd", decay_chunk.to(dtype=states.dtype), states)
    return out[:, :-1], out[:, -1]