import torch


class CrossScan(torch.autograd.Function):
    # Unfolds a (B, C, H, W) feature map into four 1-D scan orders:
    # row-major, column-major, and the reverses of both, as (B, 4, C, H * W).
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        B, C, H, W = x.shape
        ctx.shape = (B, C, H, W)
        xs = x.new_empty((B, 4, C, H * W))
        xs[:, 0] = x.flatten(2, 3)                              # row-major scan
        xs[:, 1] = x.transpose(dim0=2, dim1=3).flatten(2, 3)    # column-major scan
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])          # reversed copies of both
        return xs

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        # ys: (B, 4, C, L); fold the gradients of the four scan orders back
        # into a single (B, C, H, W) gradient.
        B, C, H, W = ctx.shape
        L = H * W
        ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
        y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
        return y.view(B, -1, H, W)


class CrossMerge(torch.autograd.Function):
    # Inverse aggregation of CrossScan: folds (B, K=4, D, H, W) scan outputs
    # back into a single flattened (B, D, H * W) feature map by summation.
    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        B, K, D, H, W = ys.shape
        ctx.shape = (H, W)
        ys = ys.view(B, K, D, -1)
        ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
        return y

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        # x: (B, C, L); re-expand the incoming gradient into the four scan orders.
        H, W = ctx.shape
        B, C, L = x.shape
        xs = x.new_empty((B, 4, C, L))
        xs[:, 0] = x
        xs[:, 1] = x.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3)
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
        xs = xs.view(B, 4, C, H, W)
        return xs


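# Illustrative sanity check (a hypothetical helper, not part of the original
# module): CrossMerge simply sums the four scan orders produced by CrossScan,
# so the round trip equals 4 * x (flattened), and gradcheck verifies the
# hand-written backward passes above.
def _demo_cross_scan_merge():
    B, C, H, W = 2, 3, 4, 5
    x = torch.randn(B, C, H, W, dtype=torch.float64, requires_grad=True)
    xs = CrossScan.apply(x)                       # (B, 4, C, H * W)
    y = CrossMerge.apply(xs.view(B, 4, C, H, W))  # (B, C, H * W)
    assert torch.allclose(y, 4 * x.flatten(2))
    ys = torch.randn(B, 4, C, H, W, dtype=torch.float64, requires_grad=True)
    assert torch.autograd.gradcheck(CrossScan.apply, (x,))
    assert torch.autograd.gradcheck(CrossMerge.apply, (ys,))

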
class CrossScan_Ab_2direction(torch.autograd.Function):
    # Ablation variant: only the forward and reversed row-major scans, each
    # duplicated so the (B, 4, C, H * W) interface of CrossScan is preserved.
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        B, C, H, W = x.shape
        ctx.shape = (B, C, H, W)
        x = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
        x = torch.cat([x, x.flip(dims=[-1])], dim=1)
        return x

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        B, C, H, W = ctx.shape
        L = H * W
        ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
        return ys.sum(1).view(B, -1, H, W)


class CrossMerge_Ab_2direction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        B, K, D, H, W = ys.shape
        ctx.shape = (H, W)
        ys = ys.view(B, K, D, -1)
        ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        return ys.contiguous().sum(1)

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        H, W = ctx.shape
        B, C, L = x.shape
        x = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
        x = torch.cat([x, x.flip(dims=[-1])], dim=1)
        return x.view(B, 4, C, H, W)


class CrossScan_Ab_1direction(torch.autograd.Function):
    # Ablation variant: the row-major scan repeated four times.
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        B, C, H, W = x.shape
        ctx.shape = (B, C, H, W)
        x = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1)
        return x

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        B, C, H, W = ctx.shape
        return ys.view(B, 4, -1, H, W).sum(1)


class CrossMerge_Ab_1direction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        B, K, C, H, W = ys.shape
        ctx.shape = (B, C, H, W)
        return ys.view(B, 4, -1, H * W).sum(1)

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        B, C, H, W = ctx.shape
        return x.view(B, 1, C, H, W).repeat(1, 4, 1, 1, 1)


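# Illustrative check for the ablation variants (hypothetical helper, not part
# of the original module): both reduced-direction scan/merge pairs also
# compose to 4 * x (flattened), matching the full four-direction pair above.
def _demo_ablation_scan_merge():
    B, C, H, W = 2, 3, 4, 5
    x = torch.randn(B, C, H, W, dtype=torch.float64)
    xs2 = CrossScan_Ab_2direction.apply(x).view(B, 4, C, H, W)
    xs1 = CrossScan_Ab_1direction.apply(x).view(B, 4, C, H, W)
    assert torch.allclose(CrossMerge_Ab_2direction.apply(xs2), 4 * x.flatten(2))
    assert torch.allclose(CrossMerge_Ab_1direction.apply(xs1), 4 * x.flatten(2))

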
try:
    import selective_scan_cuda_oflex
except Exception as e:
    print("WARNING: cannot import selective_scan_cuda_oflex.", flush=True)
    print(e, flush=True)

try:
    import selective_scan_cuda_core
except Exception as e:
    print("WARNING: cannot import selective_scan_cuda_core.", flush=True)
    print(e, flush=True)

try:
    import selective_scan_cuda
except Exception as e:
    print("WARNING: cannot import selective_scan_cuda.", flush=True)
    print(e, flush=True)


def check_nan_inf(tag: str, x: torch.Tensor, enable=True):
    # Debug helper: report and drop into pdb as soon as a tensor contains NaN or Inf.
    if enable:
        if torch.isinf(x).any() or torch.isnan(x).any():
            print(tag, torch.isinf(x).any(), torch.isnan(x).any(), flush=True)
            import pdb; pdb.set_trace()


def flops_selective_scan_fn(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_complex=False):
    """
    u: r(B D L)
    delta: r(B D L)
    A: r(D N)
    B: r(B N L)
    C: r(B N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32

    ignores:
        [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
    """
    assert not with_complex
    # Closed-form estimate of the selective-scan cost; see flops_selective_scan_ref
    # below for the einsum-based derivation it approximates.
    flops = 9 * B * L * D * N
    if with_D:
        flops += B * D * L
    if with_Z:
        flops += B * D * L
    return flops
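

# Worked example (hypothetical helper, not part of the original module): with
# the default sizes B=1, L=256, D=768, N=16, the closed-form estimate gives
# 9 * 1 * 256 * 768 * 16 = 28,311,552 FLOPs for the scan itself plus
# B * D * L = 196,608 for the D (skip) term, about 2.85e7 FLOPs in total.
def _demo_flops_estimate():
    flops = flops_selective_scan_fn(B=1, L=256, D=768, N=16, with_D=True)
    assert flops == 9 * 256 * 768 * 16 + 768 * 256
    return flops

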
def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
    """
    u: r(B D L)
    delta: r(B D L)
    A: r(D N)
    B: r(B N L)
    C: r(B N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32

    ignores:
        [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
    """
    import numpy as np

    def get_flops_einsum(input_shapes, equation):
        np_arrs = [np.zeros(s) for s in input_shapes]
        optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
        for line in optim.split("\n"):
            if "optimized flop" in line.lower():
                # einsum_path counts multiplies and adds separately;
                # halve to count multiply-accumulates.
                flop = float(np.floor(float(line.split(":")[-1]) / 2))
                return flop

    assert not with_complex

    flops = 0
    # discretization: einsum over delta and A (the deltaA term)
    flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln")
    # discretization: einsum over delta, B and u (the deltaB_u term)
    if with_Group:
        flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln")
    else:
        flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln")

    # per time step: the state update plus the contraction of the state with C
    in_for_flops = B * D * N
    if with_Group:
        in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd")
    else:
        in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd")
    flops += L * in_for_flops
    if with_D:
        flops += B * D * L
    if with_Z:
        flops += B * D * L
    return flops
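

# Illustrative comparison (hypothetical helper, not part of the original
# module): the closed-form estimate and the einsum-based reference count
# should land in the same ballpark for the default sizes.
def _demo_compare_flop_counts():
    approx = flops_selective_scan_fn(B=1, L=256, D=768, N=16)
    ref = flops_selective_scan_ref(B=1, L=256, D=768, N=16)
    print("closed-form:", approx, "einsum reference:", ref, flush=True)
    return approx, ref

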
def print_jit_input_names(inputs):
    # Prints the debug names of up to the first ten inputs of a traced JIT graph.
    print("input params: ", end=" ", flush=True)
    try:
        for i in range(10):
            print(inputs[i].debugName(), end=" ", flush=True)
    except Exception:
        pass
    print("", flush=True)


class SelectiveScanMamba(torch.autograd.Function):
    # Autograd wrapper around the selective_scan_cuda extension; nrows, backnrows
    # and oflex are accepted only for interface compatibility and are unused here.
    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True):
        ctx.delta_softplus = delta_softplus
        out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus)
        ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
        return out

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dout, *args):
        u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
            u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus,
            False
        )
        return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None)


class SelectiveScanCore(torch.autograd.Function):
    # Autograd wrapper around the selective_scan_cuda_core extension. The unused
    # nrows/backnrows/oflex arguments keep the signature aligned with the other
    # wrappers and with the eleven gradients returned by backward.
    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True):
        ctx.delta_softplus = delta_softplus
        out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1)
        ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
        return out

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dout, *args):
        u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_core.bwd(
            u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
        )
        return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None)


class SelectiveScanOflex(torch.autograd.Function):
    # Autograd wrapper around the selective_scan_cuda_oflex extension.
    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True):
        ctx.delta_softplus = delta_softplus
        out, x, *rest = selective_scan_cuda_oflex.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, oflex)
        ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
        return out

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dout, *args):
        u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_oflex.bwd(
            u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
        )
        return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None)
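

# Minimal pure-PyTorch sketch of the recurrence the CUDA wrappers above expose
# (a hypothetical reference, not the kernels' exact numerics or signature;
# assumes ungrouped B/C of shape (batch, state, length)). Handy for checking
# the wrappers on small inputs.
def _selective_scan_reference(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False):
    # u, delta: (b, d, l); A: (d, n); B, C: (b, n, l); D: (d,)
    if delta_bias is not None:
        delta = delta + delta_bias[..., None]
    if delta_softplus:
        delta = torch.nn.functional.softplus(delta)
    b, d, l = u.shape
    n = A.shape[1]
    deltaA = torch.exp(delta.unsqueeze(-1) * A[None, :, None, :])                       # (b, d, l, n)
    deltaB_u = delta.unsqueeze(-1) * B.permute(0, 2, 1).unsqueeze(1) * u.unsqueeze(-1)  # (b, d, l, n)
    h = u.new_zeros((b, d, n))
    ys = []
    for t in range(l):
        h = deltaA[:, :, t] * h + deltaB_u[:, :, t]           # state update
        ys.append((h * C[:, :, t].unsqueeze(1)).sum(-1))      # project the state with C_t
    y = torch.stack(ys, dim=-1)                               # (b, d, l)
    if D is not None:
        y = y + u * D[None, :, None]                          # skip connection
    return y

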
def selective_scan_flop_jit(inputs, outputs, flops_fn=flops_selective_scan_fn):
    # FLOP handler for the selective-scan ops when traced by a JIT-based FLOP
    # counter: reads (B, D, L) from u and N from A, then delegates to flops_fn.
    print_jit_input_names(inputs)
    B, D, L = inputs[0].type().sizes()
    N = inputs[2].type().sizes()[1]
    flops = flops_fn(B=B, L=L, D=D, N=N, with_D=True, with_Z=False)
    return flops
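

# Hedged usage sketch (hypothetical, not part of the original module): with
# fvcore, custom autograd Functions surface as "prim::PythonOp.<ClassName>"
# ops, so selective_scan_flop_jit can be registered as their FLOP handler.
# The model and sample_input arguments are placeholders.
def _demo_count_flops_with_fvcore(model, sample_input):
    from fvcore.nn import flop_count
    supported_ops = {
        "prim::PythonOp.SelectiveScanMamba": selective_scan_flop_jit,
        "prim::PythonOp.SelectiveScanCore": selective_scan_flop_jit,
        "prim::PythonOp.SelectiveScanOflex": selective_scan_flop_jit,
    }
    gflops, unsupported = flop_count(model=model, inputs=(sample_input,), supported_ops=supported_ops)
    return sum(gflops.values()), unsupported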