import torch

# pytorch cross scan =============
# CrossScan unfolds a (B, C, H, W) feature map into 4 traversal orders
# (row-major, column-major, and their reversals); CrossMerge folds the
# 4 direction-wise outputs back into a single (B, C, H*W) sequence.
class CrossScan(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        B, C, H, W = x.shape
        ctx.shape = (B, C, H, W)
        xs = x.new_empty((B, 4, C, H * W))
        xs[:, 0] = x.flatten(2, 3)
        xs[:, 1] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
        return xs

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        # out: (b, k, d, l)
        B, C, H, W = ctx.shape
        L = H * W
        ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
        y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
        return y.view(B, -1, H, W)


class CrossMerge(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        B, K, D, H, W = ys.shape
        ctx.shape = (H, W)
        ys = ys.view(B, K, D, -1)
        ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
        return y

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        # B, D, L = x.shape
        # out: (b, k, d, l)
        H, W = ctx.shape
        B, C, L = x.shape
        xs = x.new_empty((B, 4, C, L))
        xs[:, 0] = x
        xs[:, 1] = x.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3)
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
        xs = xs.view(B, 4, C, H, W)
        return xs


# these are for ablations =============
# 2-direction ablation: only the row-major scan and its reverse, each duplicated
# so downstream code still sees K=4 sequences.
class CrossScan_Ab_2direction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        B, C, H, W = x.shape
        ctx.shape = (B, C, H, W)
        x = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
        x = torch.cat([x, x.flip(dims=[-1])], dim=1)
        return x

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        B, C, H, W = ctx.shape
        L = H * W
        ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
        return ys.sum(1).view(B, -1, H, W)


class CrossMerge_Ab_2direction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        B, K, D, H, W = ys.shape
        ctx.shape = (H, W)
        ys = ys.view(B, K, D, -1)
        ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        return ys.contiguous().sum(1)

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        H, W = ctx.shape
        B, C, L = x.shape
        x = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
        x = torch.cat([x, x.flip(dims=[-1])], dim=1)
        return x.view(B, 4, C, H, W)


# 1-direction ablation: the row-major scan replicated 4 times.
class CrossScan_Ab_1direction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        B, C, H, W = x.shape
        ctx.shape = (B, C, H, W)
        x = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1)
        return x

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        B, C, H, W = ctx.shape
        return ys.view(B, 4, -1, H, W).sum(1)


class CrossMerge_Ab_1direction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        B, K, C, H, W = ys.shape
        ctx.shape = (B, C, H, W)
        return ys.view(B, 4, -1, H * W).sum(1)

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        B, C, H, W = ctx.shape
        return x.view(B, 1, C, H, W).repeat(1, 4, 1, 1, 1)


# import selective scan ==============================
try:
    import selective_scan_cuda_oflex
except Exception as e:
    ...
    print(f"WARNING: can not import selective_scan_cuda_oflex.", flush=True)
    print(e, flush=True)

try:
    import selective_scan_cuda_core
except Exception as e:
    ...
    print(f"WARNING: can not import selective_scan_cuda_core.", flush=True)
    print(e, flush=True)

try:
    import selective_scan_cuda
except Exception as e:
    ...
    print(f"WARNING: can not import selective_scan_cuda.", flush=True)
    print(e, flush=True)
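
# Illustrative sketch (not part of the original module): a minimal sanity check
# of how CrossScan and CrossMerge are expected to compose. The helper name
# `_sanity_check_cross_scan_merge` and the toy shapes are hypothetical; it only
# needs CPU tensors, since both Functions are plain PyTorch ops.
def _sanity_check_cross_scan_merge():
    x = torch.randn(2, 8, 4, 4, requires_grad=True)
    xs = CrossScan.apply(x)                       # (2, 4, 8, 16): 4 traversal orders
    y = CrossMerge.apply(xs.view(2, 4, 8, 4, 4))  # (2, 8, 16): directions re-aligned and summed
    assert xs.shape == (2, 4, 8, 16) and y.shape == (2, 8, 16)
    # merging the unmodified scans adds up four copies of the same map, so y == 4 * x
    assert torch.allclose(y.view(2, 8, 4, 4), 4 * x)
    y.sum().backward()                            # gradients flow through both Functions
    assert x.grad is not None and x.grad.shape == x.shape
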
print(f"WARNING: can not import selective_scan_cuda.", flush=True) print(e, flush=True) def check_nan_inf(tag: str, x: torch.Tensor, enable=True): if enable: if torch.isinf(x).any() or torch.isnan(x).any(): print(tag, torch.isinf(x).any(), torch.isnan(x).any(), flush=True) import pdb; pdb.set_trace() # fvcore flops ======================================= def flops_selective_scan_fn(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_complex=False): """ u: r(B D L) delta: r(B D L) A: r(D N) B: r(B N L) C: r(B N L) D: r(D) z: r(B D L) delta_bias: r(D), fp32 ignores: [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu] """ assert not with_complex # https://github.com/state-spaces/mamba/issues/110 flops = 9 * B * L * D * N if with_D: flops += B * D * L if with_Z: flops += B * D * L return flops # this is only for selective_scan_ref... def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False): """ u: r(B D L) delta: r(B D L) A: r(D N) B: r(B N L) C: r(B N L) D: r(D) z: r(B D L) delta_bias: r(D), fp32 ignores: [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu] """ import numpy as np # fvcore.nn.jit_handles def get_flops_einsum(input_shapes, equation): np_arrs = [np.zeros(s) for s in input_shapes] optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1] for line in optim.split("\n"): if "optimized flop" in line.lower(): # divided by 2 because we count MAC (multiply-add counted as one flop) flop = float(np.floor(float(line.split(":")[-1]) / 2)) return flop assert not with_complex flops = 0 # below code flops = 0 flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln") if with_Group: flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln") else: flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln") in_for_flops = B * D * N if with_Group: in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd") else: in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd") flops += L * in_for_flops if with_D: flops += B * D * L if with_Z: flops += B * D * L return flops def print_jit_input_names(inputs): print("input params: ", end=" ", flush=True) try: for i in range(10): print(inputs[i].debugName(), end=" ", flush=True) except Exception as e: pass print("", flush=True) # cross selective scan =============================== # comment all checks if inside cross_selective_scan class SelectiveScanMamba(torch.autograd.Function): @staticmethod @torch.cuda.amp.custom_fwd def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True): ctx.delta_softplus = delta_softplus out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus) ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) return out @staticmethod @torch.cuda.amp.custom_bwd def backward(ctx, dout, *args): u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors if dout.stride(-1) != 1: dout = dout.contiguous() du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus, False ) return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None) class SelectiveScanCore(torch.autograd.Function): @staticmethod @torch.cuda.amp.custom_fwd def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False): ctx.delta_softplus = delta_softplus # out, x, *rest 

# cross selective scan ===============================
# comment all checks if inside cross_selective_scan
# Autograd wrappers around the different selective-scan CUDA extensions; they
# share the same interface so the backend can be swapped transparently.
class SelectiveScanMamba(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True):
        ctx.delta_softplus = delta_softplus
        out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus)
        ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
        return out

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dout, *args):
        u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
            u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus,
            False
        )
        return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None)


class SelectiveScanCore(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True):
        # nrows/backnrows/oflex are kept for interface parity with the other backends but unused here
        ctx.delta_softplus = delta_softplus
        # out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1)
        # print(u.type, delta.type, A.type, B.type, C.type, D.type)
        # print(delta_bias, delta_softplus)
        out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1)
        ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
        return out

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dout, *args):
        u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_core.bwd(
            u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
        )
        return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None)


class SelectiveScanOflex(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True):
        ctx.delta_softplus = delta_softplus
        out, x, *rest = selective_scan_cuda_oflex.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, oflex)
        ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
        return out

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dout, *args):
        u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_oflex.bwd(
            u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
        )
        return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None)


def selective_scan_flop_jit(inputs, outputs, flops_fn=flops_selective_scan_fn):
    print_jit_input_names(inputs)
    B, D, L = inputs[0].type().sizes()
    N = inputs[2].type().sizes()[1]
    flops = flops_fn(B=B, L=L, D=D, N=N, with_D=True, with_Z=False)
    return flops
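
# Usage sketch (illustrative, not part of the original module): one way the jit
# handle above could be plugged into fvcore's FlopCountAnalysis. The op-name
# string assumes the traced model calls SelectiveScanOflex (custom autograd
# Functions appear in the trace as "prim::PythonOp.<ClassName>"); `model` and
# `inputs` are placeholders for the user's own module and a sample batch, and
# the helper name is hypothetical.
def _count_flops_with_fvcore(model, inputs):
    from functools import partial
    from fvcore.nn import FlopCountAnalysis

    analysis = FlopCountAnalysis(model, inputs)
    analysis.set_op_handle(
        "prim::PythonOp.SelectiveScanOflex",
        partial(selective_scan_flop_jit, flops_fn=flops_selective_scan_fn),
    )
    return analysis.total()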