diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py index a31770ba179fed06abf2da10102ccaeed1d3ee4e..0e1d956704840aa4daf7d1d71d24e051567feab9 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py @@ -4,7 +4,7 @@ from typing import Any, Callable, Union import torch -from stk import Matrix +from ..stk import Matrix def act_fn( diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py index 561fd7982254513e9a9b4f0514edd326b0561256..6d0375a4df2f27134c4127e60be04f3b45693050 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py @@ -2,15 +2,22 @@ # SPDX-License-Identifier: Apache-2.0 import numpy as np -import stk.ops import torch -from stk import Matrix + +# try: +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.', +# ) # import megablocks.ops as ops # # from megablocks.ops import ops # from megablocks.layers import common, dmlp_registry, moe, mpu # from megablocks.layers.arguments import Arguments +from .. import stk from .. import ops from . import common, dmlp_registry, moe, mpu from .arguments import Arguments diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py index 40b601d4a8eb59e80e1090f31f4172b8f7fb7549..c4c9e6532798615b5c12c96694241a4c18ee8f7b 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py @@ -1,7 +1,16 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -import stk +# try: +# import stk +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.', +# ) + +from .. import stk + import torch import torch.nn.functional as F diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py index 69ef9ead5461a5b91ace107652fd8d1c3d300b1e..5f297a41ff6a1a2a285f5b461951672364b898da 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py @@ -1,7 +1,17 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -import stk.ops +# import stk.ops +# try: +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.', +# ) + +from .. 
import stk + import torch # from megablocks import grouped_gemm_util as gg diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py index 0b6747b1eec2047eb9480592ca550b6c9e868f5e..c99afb9904c24a8b6a83e79059cd1251dbbfd99e 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py @@ -3,9 +3,18 @@ from typing import Any -import stk -import stk.backend.triton_kernels -import stk.ops +# try: +# import stk +# import stk.backend.triton_kernels +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.', +# ) + +from .. import stk + import torch from packaging import version diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so deleted file mode 100755 index cce51bdbf0ec2d65660f1b3cfb1580b616156d14..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:7fbec6fa49d1b926d45b39b7e8393e06ee9622d0012501adaec213cb5802c86d -size 10517576 diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_63599de.abi3.so b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_63599de.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..d535659f08d5f3e23652bca3c418b66a447a60fa --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_63599de.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9b35f3f60e0cbf0ce9e84e1224754d353f9de646cf30df5828168222889d312f +size 10517576 diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py index f7aa5a700bc5673c743a3b0fb74107fab793fbad..03e1e323ce98ecea8dda071c089f0d7ba75551b6 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_0586ba6 -ops = torch.ops._megablocks_0586ba6 +from . import _megablocks_63599de +ops = torch.ops._megablocks_63599de def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_0586ba6::{op_name}" \ No newline at end of file + return f"_megablocks_63599de::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py index 9b2c7f062071758def741bf3af2392d3a4855f89..7ccc5dcec5e9a663794fad944c45285869c4d1c1 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -3,7 +3,19 @@ import unittest -import stk + +# import stk + +# try: +# import stk +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.', +# ) + +from .. 
import stk + import torch from absl.testing import parameterized diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/__init__.py @@ -0,0 +1,7 @@ +# import stk.random +# import stk.ops +# from stk.matrix import Matrix + +from . import random +from . import ops +from .matrix import Matrix diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/autocast.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/autocast.py new file mode 100644 index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/autocast.py @@ -0,0 +1,37 @@ +import functools +import torch + + +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, map): + return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, list) or isinstance(x, tuple): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def custom_fwd(fwd): + """Wrap a custom autograd function that always uses autocast dtype.""" + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + return decorate_fwd + + +def custom_bwd(bwd): + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + return decorate_bwd diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/sputnik.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/sputnik.py new file mode 100644 index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/sputnik.py @@ -0,0 +1,316 @@ +import torch + +from ..backend import triton_kernels as backend +from ..backend.autocast import custom_bwd, custom_fwd + + +def _standardize_shape(x, transpose): + if transpose: + return torch.Size((x[1], x[0])) + return x + + +def _sparse_transpose(x): + return (torch.Size((x[0][1], x[0][0])), ) + x[1:] + + +def _transpose_helper(x, transpose): + if isinstance(x, torch.Tensor): + return x.t() if transpose else x + if transpose: + x = _sparse_transpose(x) + return x + (transpose,) + + +def _wrap(x): + if isinstance(x, torch.Tensor): + return (x,) + return x + + +def _is_transposed(x): + return (not x.is_contiguous() and + x.stride()[0] == 1 and + x.stride()[1] == x.size()[0]) + + +def _call_helper(op, out, a, b, trans_a, trans_b): + args = (_wrap(_transpose_helper(a, trans_a)) + + _wrap(_transpose_helper(b, trans_b))) + if isinstance(out, tuple): + args = args + out + return op(*args) 
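
Aside (illustrative sketch, not part of the patch): the sputnik backend above infers "transposed" operands from strides instead of carrying an explicit flag on dense tensors. A minimal standalone check of that stride test, assuming only a CPU build of torch:

import torch

def is_transposed(x):
    # Same stride test as _is_transposed above: a 2D view produced by .t()
    # is non-contiguous with strides (1, nrows).
    return (not x.is_contiguous()
            and x.stride()[0] == 1
            and x.stride()[1] == x.size()[0])

x = torch.randn(4, 8)
assert not is_transposed(x)                   # row-major, strides (8, 1)
assert is_transposed(x.t())                   # transposed view, strides (1, 8)
assert not is_transposed(x.t().contiguous())  # materialized copy is row-major again
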
+ + +def _preprocess_inputs(lhs, rhs, dy): + if isinstance(lhs, torch.Tensor) and _is_transposed(lhs): + lhs = lhs.t() + if isinstance(rhs, torch.Tensor) and _is_transposed(rhs): + rhs = rhs.t() + if (isinstance(dy, torch.Tensor) and + not dy.is_contiguous() and + not _is_transposed(dy)): + dy = dy.contiguous() + if isinstance(dy, tuple) and not dy[1].is_contiguous(): + dy = (dy[0], dy[1].contiguous()) + dy[2:] + return lhs, rhs, dy + + +def _postprocess_outputs(x, transpose, grad): + if isinstance(x, torch.Tensor) and transpose: + return grad.t() + return grad + + +def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): + lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) + + a, b = (rhs, dy) if trans_lhs else (dy, rhs) + trans_a = trans_lhs and trans_rhs + trans_b = trans_lhs or not trans_rhs + out = _call_helper(op, lhs, a, b, trans_a, trans_b) + return _postprocess_outputs(lhs, trans_lhs, out) + + +def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): + lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) + + a, b = (dy, lhs) if trans_rhs else (lhs, dy) + trans_a = not trans_lhs or trans_rhs + trans_b = trans_lhs and trans_rhs + out = _call_helper(op, rhs, a, b, trans_a, trans_b) + return _postprocess_outputs(rhs, trans_rhs, out) + + +class DSD(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs): + ctx.save_for_backward(data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + rhs) + ctx.shape = _standardize_shape(shape, transpose_a) + ctx.transpose_a = transpose_a + + out = torch.empty( + (shape[0], rhs.size()[1]), + dtype=rhs.dtype, + device=rhs.device) + + backend.dsd(shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs, + out) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs = (ctx.shape,) + saved_tensors[:-1] + rhs = saved_tensors[-1] + trans_a = ctx.transpose_a + trans_b = _is_transposed(rhs) + + ddata = None + if ctx.needs_input_grad[1]: + ddata = _lhs_gradient(sdd, + lhs, + rhs, + dy, + trans_a, + trans_b) + drhs = None + if ctx.needs_input_grad[-1]: + op = dds if trans_b else dsd + drhs = _rhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + return None, ddata, None, None, None, None, None, None, None, drhs + + +dsd = DSD.apply + + +class DDS(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b): + ctx.save_for_backward(lhs, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t) + ctx.shape = _standardize_shape(shape, transpose_b) + ctx.transpose_b = transpose_b + out = torch.empty((lhs.size()[0], shape[1]), + dtype=lhs.dtype, + device=lhs.device) + backend.dds(lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b, + out) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs = saved_tensors[0] + rhs = (ctx.shape,) + saved_tensors[1:] + trans_a = _is_transposed(lhs) + trans_b = ctx.transpose_b + + dlhs = None + if ctx.needs_input_grad[0]: + op = dsd if trans_a else dds + dlhs = 
_lhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + ddata = None + if ctx.needs_input_grad[2]: + ddata = _rhs_gradient(sdd, + lhs, + rhs, + dy, + trans_a, + trans_b) + return dlhs, None, ddata, None, None, None, None, None, None, None + + +dds = DDS.apply + + +class SDD(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + lhs, + rhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t): + ctx.save_for_backward( + lhs, + rhs, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t) + ctx.shape = shape + out = torch.empty( + data.shape, + dtype=lhs.dtype, + device=lhs.device) + backend.sdd(lhs, + rhs, + shape, + out, + offsets, + row_indices, + column_indices) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs, rhs = saved_tensors[:2] + dy = (ctx.shape, dy) + saved_tensors[2:] + trans_a = _is_transposed(lhs) + trans_b = _is_transposed(rhs) + + dlhs = None + if ctx.needs_input_grad[0]: + op = dds if trans_a else dsd + dlhs = _lhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + drhs = None + if ctx.needs_input_grad[1]: + op = dsd if trans_b else dds + drhs = _rhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + return dlhs, drhs, None, None, None, None, None, None, None, None + + +sdd = SDD.apply + +class RowIndices(torch.autograd.Function): + + @staticmethod + def forward(ctx, shape, data, offsets, column_indices): + out = torch.empty( + column_indices.shape, + dtype=column_indices.dtype, + device=column_indices.device) + backend.row_indices(shape, data, offsets, column_indices, out) + return out + + +row_indices = RowIndices.apply diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/triton_kernels.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/triton_kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/triton_kernels.py @@ -0,0 +1,393 @@ +import torch +import triton +import triton.language as tl +from dataclasses import dataclass + +@dataclass +class TritonConfig: + BLOCK_M: int = 128 + BLOCK_N: int = 128 + BLOCK_K: int = 32 + BLOCK_SIZE: int = 128 + NUM_STAGES: int = 4 + NUM_WARPS: int = 4 + +def _validate_matmul_dims(M: int, K: int, N: int): + error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}" + assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M) + assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K) + assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _sdd_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + # matrix multiplication + pid = tl.program_id(0) + pid_m = 
tl.load(row_indices + pid) + pid_n = tl.load(column_indices + pid) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(A) + b = tl.load(B) + acc += tl.dot(a, b) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + #Store to sparse matrix + acc = acc.to(C.dtype.element_ty) + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + cm = tl.arange(0, BLOCK_M) + cn = tl.arange(0, BLOCK_N) + C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _dsd_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, offsets, + block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + + # matrix multiplication + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) + + start_inx = tl.load(offsets + pid_m) + end_inx = tl.load(offsets + pid_m + 1) + + # pointers to sparse matrix + rm = tl.arange(0, BLOCK_M) + rak = tl.arange(0, BLOCK_K) + + A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) + + # pointers to dense matrix + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rbk = tl.arange(0, BLOCK_K) + B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) + + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) + + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + ak_sub_incr = BLOCK_K * stride_ak + bk_sub_incr = BLOCK_K * stride_bk + bk_block_incr = BLOCK_SIZE * stride_bk + + for k in range(nsub_blocks * (end_inx - start_inx)): + sub_block_inx = k % nsub_blocks + block_inx = k // nsub_blocks + + if trans_A: + ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr + else: + ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr + + ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr + + a = tl.load(ptr_A) + b = tl.load(ptr_B) + acc += tl.dot(a, b) + + acc = acc.to(C.dtype.element_ty) + + cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': 
TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _dds_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, offsets, + block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + + # matrix multiplication + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) + + start_inx = tl.load(offsets + pid_n) + end_inx = tl.load(offsets + pid_n + 1) + + # pointers to dense matrix + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rak = tl.arange(0, BLOCK_K) + + A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) + + # pointers to sparse matrix + rn = tl.arange(0, BLOCK_N) + rbk = tl.arange(0, BLOCK_K) + B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) + + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) + + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + + ak_sub_incr = BLOCK_K * stride_ak + ak_block_incr = BLOCK_SIZE * stride_ak + bk_sub_incr = BLOCK_K * stride_bk + + for k in range(nsub_blocks * (end_inx - start_inx)): + sub_block_inx = k % nsub_blocks + block_inx = k // nsub_blocks + + if trans_B: + ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr + else: + ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr + + ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr + a = tl.load(ptr_A) + b = tl.load(ptr_B) + acc += tl.dot(a, b) + + acc = acc.to(C.dtype.element_ty) + cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +def dsd(shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs, + out + ): + + device = rhs.device + trans_A = transpose_a + trans_B = False + + if rhs.stride(0) > 1 and rhs.stride(1) > 1: + trans_B = True + + # checks constraints + assert shape[1] == rhs.shape[0], "incompatible dimensions" + M, K = shape + _, N = rhs.shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + stride_am, stride_ak = data.stride(1), data.stride(2) + stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) + a_column_indices = column_indices + a_offsets = offsets + + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) + + if trans_A: + stride_am, stride_ak = data.stride(2), data.stride(1) + a_column_indices, a_offsets = column_indices_t, offsets_t + + if trans_B: + stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) + + _dsd_kernel[grid]( + data.data, rhs, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(0), out.stride(1), + row_indices, a_column_indices, a_offsets, + block_offsets_t, trans_A, trans_B, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + # return out + +def dds(lhs, + shape, + data, + 
offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b, + out + ): + + device = lhs.device + trans_B = transpose_b + trans_A = False + + if lhs.stride(0) > 1 and lhs.stride(1) > 1: + trans_A = True + + # checks constraints + assert lhs.shape[1] == shape[0], "incompatible dimensions" + M, K = lhs.shape + _, N = shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + stride_am, stride_ak = lhs.stride(0), lhs.stride(1) + stride_bk, stride_bn = data.stride(1), data.stride(2) + b_column_indices = column_indices_t + b_offsets = offsets_t + + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) + + if trans_A: + stride_am, stride_ak = lhs.stride(1), lhs.stride(0) + if trans_B: + stride_bk, stride_bn = data.stride(2), data.stride(1) + b_column_indices, b_offsets = column_indices, offsets + + _dds_kernel[grid]( + lhs, data, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(0), out.stride(1), + row_indices, b_column_indices, b_offsets, + block_offsets_t, trans_A, trans_B, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + +def sdd(lhs, + rhs, + shape, + out, + offsets, + row_indices, + column_indices + ): + + device = out.device + trans_A = False + trans_B = False + + if lhs.stride(0) > 1 and lhs.stride(1) > 1: + trans_A = True + if rhs.stride(0) > 1 and rhs.stride(1) > 1: + trans_B = True + + # checks constraints + assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions" + M, K = lhs.shape + _, N = rhs.shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + # launch kernel + nnz_blocks = len(row_indices) + grid = lambda META: (nnz_blocks,) + + stride_am, stride_ak = lhs.stride(0), lhs.stride(1) + stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) + + if trans_A: + stride_am, stride_ak = lhs.stride(1), lhs.stride(0) + if trans_B: + stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) + + _sdd_kernel[grid]( + lhs, rhs, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(1), out.stride(2), + row_indices, column_indices, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + +@triton.jit +def _row_indices_kernel(offsets, out): + pid = tl.program_id(0) + row_offset = tl.load(offsets + pid) + nnz_blocks = tl.load(offsets + pid + 1) - row_offset + for nnz_block in range(nnz_blocks): + tl.store(out + row_offset + nnz_block, pid) + +def row_indices( + shape, data, offsets, column_indices, out +): + block_rows = len(offsets) - 1 + _row_indices_kernel[(block_rows, )](offsets, out) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/matrix.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/matrix.py @@ -0,0 +1,329 @@ +import numpy as np +import torch + +# 1. Add heavyweight (data) validation helper. +# 2. Add construction helpers +# 3. Make indentation consistent +# 4. Replace asserts with descriptive errors. + +## +### Validation helpers. 
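
Aside (illustrative sketch, not part of the patch): the Matrix class defined below stores a block-compressed sparse row (BCSR) layout. A toy example of the buffers the validator expects, for a 4x4 matrix with 2x2 blocking and nonzero blocks at (0,0), (0,1) and (1,1); the asserts mirror the invariants `_validate_matrix` enforces:

import torch

blocking, shape = 2, (4, 4)
data = torch.randn(3, blocking, blocking).half()              # one [block, block] tile per nonzero block, fp16
row_indices = torch.tensor([0, 0, 1], dtype=torch.int16)      # block-row of each nonzero block
column_indices = torch.tensor([0, 1, 1], dtype=torch.int16)   # block-column of each nonzero block
offsets = torch.tensor([0, 2, 3], dtype=torch.int32)          # CSR-style row pointers over blocks

block_rows = shape[0] // blocking
assert shape[0] % blocking == 0 and shape[1] % blocking == 0
assert offsets.numel() == block_rows + 1                      # one offset per block row, plus one
assert row_indices.numel() == column_indices.numel() == data.shape[0]
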
+## + + +def _validate_matrix(shape, data, row_indices, column_indices, offsets): + # Data should be [nnz, block_size, block_size] + if data.dim() == 1: + data = torch.reshape(data, [data.numel(), 1, 1]) + + # Blocks should be square. + if data.shape[-2] != data.shape[-1]: + raise ValueError( + "Expected square blocking in data. " + f"Got block shape {[data.shape[-2], data.shape[-1]]}") + + # Flatten batch dimensions on data - original shape preserved + # in shape argument. + block_size = data.shape[-1] + data = data.view([-1, block_size, block_size]) + + if data.dim() != 3: + raise ValueError( + "Expected 3D shape for data (nnz, block, block). " + f"Got shape {data.dim()}D shape.") + + block_size = data.shape[1] + if shape[-2] % block_size != 0 or shape[-1] % block_size != 0: + raise ValueError( + "Matrix shape must be dividible by blocking. " + f"Got shape {shape} with " + f"{[block_size, block_size]} blocking.") + + if np.prod(shape) < data.numel(): + raise ValueError( + "Invalid matrix. Number of nonzeros exceeds matrix capacity " + f"({data.numel()} v. {np.prod(shape)})") + + if row_indices.dim() != 1: + raise ValueError( + f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.") + + if column_indices.dim() != 1: + raise ValueError( + f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.") + + if offsets.dim() != 1: + raise ValueError( + f"Expected 1D offsets. Got {offsets.dim()}D offsets.") + + if row_indices.numel() != data.shape[0]: + raise ValueError( + "Expected 1 index per nonzero block. " + f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks") + + if column_indices.numel() != data.shape[0]: + raise ValueError( + "Expected 1 index per nonzero block. " + f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks") + + block_rows = np.prod(shape[:-1]) / block_size + if offsets.numel() != block_rows + 1: + raise ValueError( + "Expected one offset per block row plus one. " + f"Got {offsets.numel()} offsets with {block_rows} block rows.") + + is_cuda = (data.is_cuda and + row_indices.is_cuda and + column_indices.is_cuda and + offsets.is_cuda) + is_cpu = (not data.is_cuda and + not row_indices.is_cuda and + not column_indices.is_cuda and + not offsets.is_cuda) + if not (is_cuda or is_cpu): + raise ValueError( + "Expected data & meta-data on common device. " + f"Got data on {data.device}, row_indices on {row_indices.device} " + f"column_indices on {column_indices.device} and " + f"offsets on {offsets.device}.") + + if data.dtype != torch.float16: + raise ValueError( + f"Expected float16 data. Got {data.dtype} data.") + if row_indices.dtype != torch.int16: + raise ValueError( + f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.") + if column_indices.dtype != torch.int16: + raise ValueError( + f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.") + if offsets.dtype != torch.int32: + raise ValueError( + f"Expected int32 offsets. Got {offsets.dtype} offsets.") + return data + + +def _transpose(size, data, row_indices, column_indices, offsets): + block_columns = size[1] // data.shape[1] + + # Sort row indices by column indices to get the transposed matrix's + # column indices. + gather_indices = column_indices.argsort() + column_indices_t = row_indices.gather(0, gather_indices) + block_offsets_t = gather_indices.int() + + # NOTE: Histogram is not implemented for any integer type on CPU. Do + # the histogram in 32-bit float, which can exactly represent 16-bit + # integers. 
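
Aside (illustrative sketch, not part of the patch): the float32 cast used below is lossless here, since float32 represents every int16 value exactly; a quick standalone check of the per-column histogram it feeds:

import torch

col = torch.tensor([0, 1, 1, 3], dtype=torch.int16)       # block-column index of each nonzero block
counts = col.float().histc(bins=4, min=0, max=4).int()    # nonzero blocks per block-column
assert counts.tolist() == [1, 2, 0, 1]
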
+ column_indices_float = column_indices.float() + + zero = torch.zeros((1,), dtype=torch.int32, device=data.device) + nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns) + nnz_per_column = nnz_per_column.int() + offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)]) + return column_indices_t, offsets_t, block_offsets_t + + +class Matrix(torch.nn.Module): + """A matrix stored in sparse format. + + Underlying format is block compressed sparse row (BCSR). + + TODO(tgale): Make this mirror torch.Tensor API as much as possible. + """ + + def __init__(self, + size, + data, + row_indices, + column_indices, + offsets, + column_indices_t=None, + offsets_t=None, + block_offsets_t=None): + super().__init__() + self._size = size + self._data = data + self._row_indices = row_indices + self._column_indices = column_indices + self._offsets = offsets + + # Produce the transpose meta-data if it is not passed in. + if ((column_indices_t is None) or (offsets_t is None) or + (block_offsets_t is None)): + column_indices_t, offsets_t, block_offsets_t = _transpose( + size, data, row_indices, column_indices, offsets) + self._column_indices_t = column_indices_t + self._offsets_t = offsets_t + self._block_offsets_t = block_offsets_t + + self._transposed = False + + # Validate that our metadata will not overflow. + max_dim = np.iinfo(np.int16).max * self.blocking + if column_indices.dtype == torch.int16: + if size[0] > max_dim or size[1] > max_dim: + raise ValueError( + "Sparse matrix with shape {size} exceeds representable " + "size with 16-bit indices.") + + def validate(self): + _validate_matrix(self._size, + self._data, + self._row_indices, + self._column_indices, + self._offsets) + + # TODO(tgale): Add heavyweight data validation. + + def to(self, device): + # TODO(tgale): Handle type conversions here. We + # need to set the appropriate meta-data type for + # the given floating-point type. 
+ self._data = self._data.to(device) + self._row_indices = self._row_indices.to(device) + self._column_indices = self._column_indices.to(device) + self._offsets = self._offsets.to(device) + self._column_indices_t = self._column_indices_t.to(device) + self._offsets_t = self._offsets_t.to(device) + self._block_offsets_t = self._block_offsets_t.to(device) + return self + + def cuda(self): + return self.to(torch.cuda.current_device()) + + def clone(self): + return Matrix( + self.size(), + self.data.clone(), + self.row_indices.clone(), + self.column_indices.clone(), + self.offsets.clone(), + self.column_indices_t.clone(), + self.offsets_t.clone(), + self.block_offsets_t.clone()) + + def t(self): + if self.dim() != 2: + raise ValueError( + "t() expects a tensor with <= 2 dimensions, " + f"but self is {self.dim()}D.") + out = Matrix(self.size(), + self.data, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + out._transposed = not self._transposed + out._size = torch.Size((self._size[1], self._size[0])) + return out + + def contiguous(self): + raise ValueError("Not yet implemented.") + + def is_contiguous(self): + return not self._transposed + + @property + def is_cuda(self): + return self._data.is_cuda + + @property + def device(self): + return self._data.device + + def size(self): + return self._size + + @property + def shape(self): + return self.size() + + def dim(self): + return len(self._size) + + @property + def data(self): + return self._data + + @property + def row_indices(self): + return self._row_indices + + @property + def column_indices(self): + return self._column_indices + + @property + def offsets(self): + return self._offsets + + @property + def offsets_t(self): + return self._offsets_t + + @property + def column_indices_t(self): + return self._column_indices_t + + @property + def block_offsets_t(self): + return self._block_offsets_t + + @property + def dtype(self): + return self.data.dtype + + @property + def nnz(self): + return self.data.numel() + + @property + def blocking(self): + return self.data.shape[1] + + @property + def requires_grad(self): + return self.data.requires_grad + + def requires_grad_(self, x): + self.data.requires_grad_(x) + return self + + def view(self, *shape): + assert self.is_contiguous() + if shape[-1] != self.size()[-1]: + raise ValueError( + "Can't change view on compressed dimension. " + f"{self.size()[-1]} v. {shape[-1]}.") + if np.prod(shape) != np.prod(self.size()): + raise ValueError( + "Mismatch in numel of Matrix and new shape. " + f"{np.prod(self.size())} v. {np.prod(shape)}") + return Matrix(shape, + self.data, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + + @property + def grad(self): + # TODO(tgale): Make sure this mirrors torch.Tensor + # behavior in the case where we ask for the gradient + # of a non-contiguous tensor. 
+ size = self.size() + if not self.is_contiguous(): + size = torch.Size((size[1], size[0])) + out = Matrix(size, + self.data.grad, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + return out if self.is_contiguous() else out.t() diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/__init__.py @@ -0,0 +1,3 @@ +from .linear_ops import dds, dsd, sdd +from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse +from .eltwise_ops import mul diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops.py @@ -0,0 +1,28 @@ +from ..matrix import Matrix + +def mul(a, b): + """Performs element-wise multiplication of matrices a and b. + + It is the user's responsibility to make sure that a and b + follow the same matrix topology. This function assumes it is safe + to use the topoplogy of a. + + Args: + a: stk.Matrix. + b: stk.Matrix with a's matrix topology. + + Returns: + stk.Matrix where the entries correspond to torch.mul(a, b). + """ + assert isinstance(a, Matrix) + assert isinstance(b, Matrix) + assert a.size() == b.size() + + return Matrix(a.size(), + a.data * b.data, + a.row_indices, + a.column_indices, + a.offsets, + a.column_indices_t, + a.offsets_t, + a.block_offsets_t) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..66bfd4f6af77042d3c5bdb1fe18d00e457478d46 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py @@ -0,0 +1,86 @@ +import unittest +import itertools +import torch +from absl.testing import parameterized + +import stk +from stk.ops.linear_ops_test import allclose, _dense_and_sparse + +_MATRIX_SIZES = ( + (128, 128, 0.0), + (256, 256, 0.5), + (2048, 1024, 0.8), + (512, 128, 0.0), + (128, 512, 0.0), + (1024, 512, 0.0), + (1024, 512, 0.5), + (1024, 512, 0.75), + (512, 1024, 0.0), + (512, 1024, 0.5), + (512, 1024, 0.75), + (1024, 1024, 0.0), + (1024, 1024, 0.5), + (1024, 1024, 0.75), +) + +_DTYPE = ( + torch.float16, torch.bfloat16 +) + +def _generate_testcases(): + testcases = itertools.product(_MATRIX_SIZES, _DTYPE) + testcases = [(*size, 128, dtype) for + (size, dtype) in testcases] + return testcases + +_ELTWISE_OP_TESTS = _generate_testcases() + +def _dense_and_sparse_like(x, std=0.1): + dense_data = torch.randn_like(x.data, device=x.device) * std + sparse = stk.Matrix(x.size(), + dense_data, + x.row_indices, + x.column_indices, + x.offsets) + dense = stk.ops.to_dense(sparse) + + return (dense.requires_grad_(True), + sparse.requires_grad_(True)) + +@parameterized.parameters(_ELTWISE_OP_TESTS) +class EltwiseOpsTest(parameterized.TestCase): + + def testEltwiseMul(self, m, n, sparsity, blocking, dtype): + + a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype) + b_dense, b = 
_dense_and_sparse_like(a) + + out = stk.ops.mul(a, b) + expected_out = torch.mul(a_dense, b_dense) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + stk.ops.sum(out).backward() + + # Validate the results. + out = stk.ops.to_dense(out) + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size(), out.size()) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = stk.ops.to_dense(a.grad) + expected_grad = a_dense.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size(), grad.size()) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = stk.ops.to_dense(b.grad) + expected_grad = b_dense.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size(), grad.size()) + self.assertTrue(allclose(grad, expected_grad)) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops.py @@ -0,0 +1,59 @@ +import torch + +from ..backend import sputnik +from ..matrix import Matrix + + +def dsd(a, b): + assert isinstance(a, Matrix) + assert isinstance(b, torch.Tensor) + return sputnik.dsd( + a.size(), + a.data, a.offsets, + a.row_indices, + a.column_indices, + a.offsets_t, + a.column_indices_t, + a.block_offsets_t, + not a.is_contiguous(), + b) + + +def dds(a, b): + assert isinstance(a, torch.Tensor) + assert isinstance(b, Matrix) + return sputnik.dds( + a, + b.size(), + b.data, b.offsets, + b.row_indices, + b.column_indices, + b.offsets_t, + b.column_indices_t, + b.block_offsets_t, + not b.is_contiguous()) + + +def sdd(a, b, topo): + assert isinstance(a, torch.Tensor) + assert isinstance(b, torch.Tensor) + assert isinstance(topo, Matrix) + assert topo.is_contiguous() + out = sputnik.sdd( + a, b, + topo.size(), + topo.data, + topo.offsets, + topo.row_indices, + topo.column_indices, + topo.offsets_t, + topo.column_indices_t, + topo.block_offsets_t) + return Matrix(topo.size(), + out, + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops_test.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ced1d782fbc9f9ca16b3449239f1588dc5ff5e00 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops_test.py @@ -0,0 +1,216 @@ +import unittest +import itertools +import numpy as np +import torch +from absl.testing import parameterized + +import stk + + +def allclose(x, y, pct=0.25): + mask = torch.isclose(x, y, rtol=5e-2) + pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 + if pct_diff > pct: + print("{:.2f}% of values not close.".format(pct_diff)) + return False + return True + + +# An assortment of problems designed to make sure +# the bindings are operating correctly. 
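
Aside (illustrative sketch, not part of the patch): an end-to-end use of the vendored module, assuming it is importable as `megablocks.stk` once the wheel is installed, that a CUDA device is available, and 128x128 blocking with fp16 inputs and dimensions divisible by the kernel block sizes (as the validator above requires):

import torch
from megablocks import stk   # assumed import path for the vendored package

m, k, n, blocking = 256, 256, 256, 128
mask = stk.random.dense_mask(m, n, sparsity=0.5, blocking=blocking)
topo = stk.ops.to_sparse(mask.half(), blocking).cuda()     # sparsity pattern as a BCSR Matrix
a = torch.randn(m, k, dtype=torch.float16, device="cuda")
b = torch.randn(k, n, dtype=torch.float16, device="cuda")
out = stk.ops.sdd(a, b, topo)     # dense x dense -> sparse output with topo's pattern
dense = stk.ops.to_dense(out)     # materialize the sparse result as a regular torch.Tensor
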
+_MATRIX_SIZES = ( + (128, 128, 128, 0.0), + (256, 256, 256, 0.5), + (2048, 1024, 512, 0.8), + (512, 128, 128, 0.0), + (128, 128, 512, 0.0), + (1024, 512, 512, 0.0), + (1024, 512, 512, 0.5), + (1024, 512, 512, 0.75), + (512, 512, 1024, 0.0), + (512, 512, 1024, 0.5), + (512, 512, 1024, 0.75), + (1024, 1024, 1024, 0.0), + (1024, 1024, 1024, 0.5), + (1024, 1024, 1024, 0.75), +) + +_TRANSPOSE = ( + (False, False), + (False, True), + (True, False), + (True, True), +) + +_DTYPE = ( + torch.float16, torch.bfloat16 +) + +def _generate_testcases(): + testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE) + testcases = [(*size, *trans, 128, dtype) for + (size, trans, dtype) in testcases] + return testcases + +_LINEAR_OP_TESTS = _generate_testcases() + +def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + dense = (torch.randn(rows, cols) * std * mask).type(dtype) + sparse = stk.ops.to_sparse(dense, blocking) + cuda_device = torch.device("cuda") + return (dense.to(cuda_device).requires_grad_(True), + sparse.to(cuda_device).requires_grad_(True)) + + +def _dense(rows, cols, dtype, std=0.1): + cuda_device = torch.device("cuda") + out = (torch.randn(rows, cols) * std).type(dtype) + return out.to(cuda_device).requires_grad_(True) + + +def _dense_2x(rows, cols, dtype): + a = _dense(rows, cols, dtype) + return a, a.detach().requires_grad_(True) + + +def _with_transpose(op, a, b, trans_a, trans_b): + a = a.t() if trans_a else a + b = b.t() if trans_b else b + return op(a, b) + + +def _mmm(a, b, topo): + mask = stk.ops.to_dense(stk.ops.ones_like(topo)) + return torch.mm(a, b) * mask + + +def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b): + a = a.t() if trans_a else a + b = b.t() if trans_b else b + return op(a, b, topo) + + +def _mask(x, mask): + mask = stk.ops.to_dense(stk.ops.ones_like(mask)) + return x * mask + + +@parameterized.parameters(*_LINEAR_OP_TESTS) +class LinearOpsTest(parameterized.TestCase): + + def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. + a_shape = (k, m) if trans_a else (m, k) + a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype) + b_shape = (n, k) if trans_b else (k, n) + b, bcp = _dense_2x(*b_shape, dtype) + + # Execute the matmul. + out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b) + expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + out.sum().backward() + + # Validate the results. + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = stk.ops.to_dense(a.grad) + expected_grad = _mask(a_dense.grad, a.grad) + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = b.grad + expected_grad = bcp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. 
+ a_shape = (k, m) if trans_a else (m, k) + a, acp = _dense_2x(*a_shape, dtype) + b_shape = (n, k) if trans_b else (k, n) + b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype) + + # Execute the matmul. + out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b) + expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + out.sum().backward() + + # Validate the results. + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = a.grad + expected_grad = acp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = stk.ops.to_dense(b.grad) + expected_grad = _mask(b_dense.grad, b.grad) + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. + a_shape = (k, m) if trans_a else (m, k) + a, acp = _dense_2x(*a_shape, dtype) + b_shape = (n, k) if trans_b else (k, n) + b, bcp = _dense_2x(*b_shape, dtype) + _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype) + + # Execute the matmul. + out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b) + expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + stk.ops.sum(out).backward() + + # Validate the results. + out = stk.ops.to_dense(out) + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = a.grad + expected_grad = acp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = b.grad + expected_grad = bcp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops.py @@ -0,0 +1,98 @@ +from ..backend import sputnik +from ..matrix import Matrix +import torch +import numpy as np + + +@torch.no_grad() +def row_indices(shape, data, offsets, column_indices): + return sputnik.row_indices(shape, data, offsets, column_indices) + + +# TODO(tgale): Replace this helper with a custom kernel. This operation +# is much simpler to do than how it's currently implemented. 
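
Aside (illustrative sketch, not part of the patch): the helper below expands each nonzero block index into the element coordinates that block covers; a standalone version of that expansion for a single block with blocking 2:

import torch

blocking = 2
block_row, block_col = 1, 2                             # one nonzero block
rows = block_row * blocking + torch.arange(blocking)    # element rows 2, 3
cols = block_col * blocking + torch.arange(blocking)    # element cols 4, 5
coords = torch.cartesian_prod(rows, cols)
assert coords.tolist() == [[2, 4], [2, 5], [3, 4], [3, 5]]
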
+@torch.no_grad() +def _expand_for_blocking(idxs, blocking): + # Duplicate for block column dimension. + idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1) + + # Update the column indices. + idxs[:, :, 1] *= blocking + idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking]) + + # Duplicate for block row dimension. + idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2]) + idxs = idxs.repeat(1, blocking, 1, 1) + + # Update the row indices. + idxs[:, :, :, 0] *= blocking + idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1]) + idxs = torch.reshape(idxs, [-1, 2]) + return idxs + + +# TODO(tgale): Add input type checking. +@torch.no_grad() +def to_dense(x): + assert isinstance(x, Matrix) + + shape = (np.prod(x.shape[:-1]), x.shape[-1]) + row_idxs = x.row_indices.type(torch.int32) + col_idxs = x.column_indices.type(torch.int32) + indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking) + indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64) + + out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device) + out.scatter_(0, indices, x.data.flatten()) + return out.reshape(x.size()) + + +@torch.no_grad() +def _mask(x, blocking=1): + assert x.dim() == 2 + assert x.size()[0] % blocking == 0 + assert x.size()[1] % blocking == 0 + block_rows = x.size()[0] // blocking + block_cols = x.size()[1] // blocking + x = torch.reshape(x, [block_rows, blocking, block_cols, blocking]) + x = torch.sum(torch.abs(x), dim=(1, 3)) + return x != 0 + + +# TODO(tgale): Add input type checking. +@torch.no_grad() +def to_sparse(x, blocking=1): + m = _mask(x, blocking) + + # TODO(tgale): Set to appropriate type for input matrix. + row_nnzs = torch.sum(m, dim=1).type(torch.int32) + zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device) + offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)]) + offsets = offsets.type(torch.int32) + + indices = torch.nonzero(m).type(torch.int16) + row_indices = indices[:, 0] + column_indices = indices[:, 1] + + # Nonzero indices in the dense matrix. + nonzero_indices = torch.nonzero(m) + nonzero_indices = _expand_for_blocking(nonzero_indices, blocking) + nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1] + + # Gather the data and construct the sparse matrix. 
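
Aside (illustrative sketch, not part of the patch): the gather below addresses a row-major flattening, where element (r, c) of an R x C matrix sits at flat index r * C + c:

import torch

x = torch.arange(12.).reshape(3, 4)
r, c = 1, 2
assert x.flatten()[r * x.size(1) + c] == x[r, c]
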
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices) + data = torch.reshape(data, [-1, blocking, blocking]) + return Matrix(x.size(), data, row_indices, column_indices, offsets) + + +@torch.no_grad() +def ones_like(x): + return Matrix(x.size(), + torch.ones_like(x.data), + x.row_indices, + x.column_indices, x.offsets) + + +def sum(x): + assert isinstance(x, Matrix) + return x.data.sum() diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3af04c0760483e578f93303dc457415948a2a34c --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py @@ -0,0 +1,62 @@ +import unittest + +from absl.testing import parameterized +import stk +import torch + + +@parameterized.parameters( + (8, 16, 0.0, 1), + (8, 16, 0.5, 1), + (8, 16, .95, 1), + (16, 8, 0.0, 1), + (16, 8, 0.5, 1), + (16, 8, .95, 1), + (8, 16, 0.0, 8), + (8, 16, 0.5, 8), + (8, 16, 1.0, 8), + (16, 8, 0.0, 8), + (16, 8, 0.5, 8), + (16, 8, 1.0, 8), + (128, 256, 0.5, 16), + (256, 128, 0.75, 32), + (512, 512, .875, 128)) +class MatrixOpsTest(parameterized.TestCase): + + def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + x = (torch.randn(rows, cols) * mask).type(torch.float16) + + # Convert the matrix to sparse format. + sparse_x = stk.ops.to_sparse(x, blocking) + + # Validate the matrix. + sparse_x.validate() + + # Validate the shape. + self.assertEqual(sparse_x.dim(), 2) + self.assertEqual(sparse_x.size()[0], rows) + self.assertEqual(sparse_x.size()[1], cols) + + # Validate the sparsity. + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual(sparse_x.nnz, nnz) + + # Convert back to dense format. + dense_x = stk.ops.to_dense(sparse_x) + + # Validate the shape. + self.assertEqual(dense_x.dim(), 2) + self.assertEqual(dense_x.size()[0], rows) + self.assertEqual(dense_x.size()[1], cols) + + # Validate the sparsity + self.assertEqual(torch.count_nonzero(dense_x).item(), nnz) + + # Validate the output. 
+ self.assertTrue(torch.all(torch.eq(x, dense_x))) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/random/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/random/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/random/__init__.py @@ -0,0 +1,2 @@ +# from stk.random.random_ops import dense_mask, mask, randn +from .random_ops import dense_mask, mask, randn diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops.py @@ -0,0 +1,36 @@ +import numpy as np +import torch +from ..ops import matrix_ops + + +@torch.no_grad() +def dense_mask(rows, cols, sparsity, blocking=1): + assert sparsity >= 0.0 and sparsity <= 1.0 + assert rows % blocking == 0 and cols % blocking == 0 + + block_rows, block_cols = (rows // blocking, cols // blocking) + nnz = round(block_rows * block_cols * (1 - sparsity)) + + out = np.ones(block_rows * block_cols) + mask = np.random.choice(out.size, out.size - nnz, replace=False) + out[mask] = 0.0 + + out = np.tile( + np.reshape(out, [block_rows, 1, block_cols, 1]), + (1, blocking, 1, blocking)) + out = np.reshape(out, [rows, cols]) + return torch.from_numpy(out.astype(np.float32)) + + +@torch.no_grad() +def mask(m, n, sparsity, blocking=1): + out = dense_mask(m, n, sparsity, blocking).type(torch.float16) + return matrix_ops.to_sparse(out, blocking=blocking) + + +@torch.no_grad() +def randn(shape, sparsity, blocking=1): + shape_2d = (np.prod(shape[:-1]), shape[-1]) + out = mask(*shape_2d, sparsity, blocking) + out.data.copy_(torch.randn(*out.data.shape)) + return out.view(*shape) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops_test.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..587b44ec890c861879c6296b8f9028f5d99ab82f --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops_test.py @@ -0,0 +1,73 @@ +import unittest + +from absl.testing import parameterized +from . import random +import torch + + +@parameterized.parameters( + (8, 16, 0.0, 1), + (8, 16, 0.5, 1), + (8, 16, .95, 1), + (16, 8, 0.0, 1), + (16, 8, 0.5, 1), + (16, 8, .95, 1), + (8, 16, 0.0, 8), + (8, 16, 0.5, 8), + (8, 16, 1.0, 8), + (16, 8, 0.0, 8), + (16, 8, 0.5, 8), + (16, 8, 1.0, 8), + (128, 256, 0.5, 16), + (256, 128, 0.75, 32), + (512, 512, .875, 128)) +class RandomOpsTest(parameterized.TestCase): + + def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking): + mask = random.dense_mask( + rows, cols, sparsity, blocking) + + # Validate the shape. + self.assertEqual(mask.dim(), 2) + self.assertEqual(mask.size()[0], rows) + self.assertEqual(mask.size()[1], cols) + + # Validate the sparsity + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual( + torch.count_nonzero(mask).item(), + nnz) + + # Check values are zero or one. 
+ self.assertTrue( + torch.all(torch.logical_or( + torch.eq(mask, 0), + torch.eq(mask, 1)))) + + def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking): + mask = random.mask( + rows, cols, sparsity, blocking) + + # Validate the matrix. + mask.validate() + + # Validate the shape. + self.assertEqual(mask.dim(), 2) + self.assertEqual(mask.size()[0], rows) + self.assertEqual(mask.size()[1], cols) + + # Validate the sparsity. + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual(mask.nnz, nnz) + + # Check values are zero or one. + self.assertTrue( + torch.all(torch.logical_or( + torch.eq(mask.data, 0), + torch.eq(mask.data, 1)))) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/activation_fn.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/activation_fn.py index a31770ba179fed06abf2da10102ccaeed1d3ee4e..0e1d956704840aa4daf7d1d71d24e051567feab9 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/activation_fn.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/activation_fn.py @@ -4,7 +4,7 @@ from typing import Any, Callable, Union import torch -from stk import Matrix +from ..stk import Matrix def act_fn( diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/dmoe.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/dmoe.py index 561fd7982254513e9a9b4f0514edd326b0561256..6d0375a4df2f27134c4127e60be04f3b45693050 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/dmoe.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/dmoe.py @@ -2,15 +2,22 @@ # SPDX-License-Identifier: Apache-2.0 import numpy as np -import stk.ops import torch -from stk import Matrix + +# try: +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.', +# ) # import megablocks.ops as ops # # from megablocks.ops import ops # from megablocks.layers import common, dmlp_registry, moe, mpu # from megablocks.layers.arguments import Arguments +from .. import stk from .. import ops from . import common, dmlp_registry, moe, mpu from .arguments import Arguments diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/gelu.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/gelu.py index 40b601d4a8eb59e80e1090f31f4172b8f7fb7549..c4c9e6532798615b5c12c96694241a4c18ee8f7b 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/gelu.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/gelu.py @@ -1,7 +1,16 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -import stk +# try: +# import stk +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.', +# ) + +from .. 
import stk + import torch import torch.nn.functional as F diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/glu.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/glu.py index 69ef9ead5461a5b91ace107652fd8d1c3d300b1e..5f297a41ff6a1a2a285f5b461951672364b898da 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/glu.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/glu.py @@ -1,7 +1,17 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -import stk.ops +# import stk.ops +# try: +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.', +# ) + +from .. import stk + import torch # from megablocks import grouped_gemm_util as gg diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/mlp.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/mlp.py index 0b6747b1eec2047eb9480592ca550b6c9e868f5e..c99afb9904c24a8b6a83e79059cd1251dbbfd99e 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/mlp.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/mlp.py @@ -3,9 +3,18 @@ from typing import Any -import stk -import stk.backend.triton_kernels -import stk.ops +# try: +# import stk +# import stk.backend.triton_kernels +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.', +# ) + +from .. import stk + import torch from packaging import version diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so deleted file mode 100755 index 240f5140b5ff0f7c813fcb49a45c561b4980891e..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:16141033c118b488348a29f3436f778764f8f4275fe510dc36badb7c152e0d42 -size 11869392 diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_63599de.abi3.so b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_63599de.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..ba97f73d3ce1bc8069367a1459239e814ba244af --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_63599de.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:05d38f81524501b75940bfad8686f4f502b5c6af1de85fb1fe5b20da765d4c3c +size 11869392 diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py index f7aa5a700bc5673c743a3b0fb74107fab793fbad..03e1e323ce98ecea8dda071c089f0d7ba75551b6 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_0586ba6 -ops = torch.ops._megablocks_0586ba6 +from . import _megablocks_63599de +ops = torch.ops._megablocks_63599de def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_megablocks_0586ba6::{op_name}" \ No newline at end of file + return f"_megablocks_63599de::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py index 9b2c7f062071758def741bf3af2392d3a4855f89..7ccc5dcec5e9a663794fad944c45285869c4d1c1 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -3,7 +3,19 @@ import unittest -import stk + +# import stk + +# try: +# import stk +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.', +# ) + +from .. import stk + import torch from absl.testing import parameterized diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/__init__.py @@ -0,0 +1,7 @@ +# import stk.random +# import stk.ops +# from stk.matrix import Matrix + +from . import random +from . import ops +from .matrix import Matrix diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/autocast.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/autocast.py new file mode 100644 index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/autocast.py @@ -0,0 +1,37 @@ +import functools +import torch + + +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, map): + return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, list) or isinstance(x, tuple): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def custom_fwd(fwd): + """Wrap a custom autograd function that always uses autocast dtype.""" + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + return decorate_fwd + + +def custom_bwd(bwd): + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + return decorate_bwd diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/sputnik.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/sputnik.py new file mode 100644 index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/sputnik.py @@ -0,0 +1,316 @@ +import torch + +from ..backend import triton_kernels as backend +from ..backend.autocast 
import custom_bwd, custom_fwd + + +def _standardize_shape(x, transpose): + if transpose: + return torch.Size((x[1], x[0])) + return x + + +def _sparse_transpose(x): + return (torch.Size((x[0][1], x[0][0])), ) + x[1:] + + +def _transpose_helper(x, transpose): + if isinstance(x, torch.Tensor): + return x.t() if transpose else x + if transpose: + x = _sparse_transpose(x) + return x + (transpose,) + + +def _wrap(x): + if isinstance(x, torch.Tensor): + return (x,) + return x + + +def _is_transposed(x): + return (not x.is_contiguous() and + x.stride()[0] == 1 and + x.stride()[1] == x.size()[0]) + + +def _call_helper(op, out, a, b, trans_a, trans_b): + args = (_wrap(_transpose_helper(a, trans_a)) + + _wrap(_transpose_helper(b, trans_b))) + if isinstance(out, tuple): + args = args + out + return op(*args) + + +def _preprocess_inputs(lhs, rhs, dy): + if isinstance(lhs, torch.Tensor) and _is_transposed(lhs): + lhs = lhs.t() + if isinstance(rhs, torch.Tensor) and _is_transposed(rhs): + rhs = rhs.t() + if (isinstance(dy, torch.Tensor) and + not dy.is_contiguous() and + not _is_transposed(dy)): + dy = dy.contiguous() + if isinstance(dy, tuple) and not dy[1].is_contiguous(): + dy = (dy[0], dy[1].contiguous()) + dy[2:] + return lhs, rhs, dy + + +def _postprocess_outputs(x, transpose, grad): + if isinstance(x, torch.Tensor) and transpose: + return grad.t() + return grad + + +def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): + lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) + + a, b = (rhs, dy) if trans_lhs else (dy, rhs) + trans_a = trans_lhs and trans_rhs + trans_b = trans_lhs or not trans_rhs + out = _call_helper(op, lhs, a, b, trans_a, trans_b) + return _postprocess_outputs(lhs, trans_lhs, out) + + +def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): + lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) + + a, b = (dy, lhs) if trans_rhs else (lhs, dy) + trans_a = not trans_lhs or trans_rhs + trans_b = trans_lhs and trans_rhs + out = _call_helper(op, rhs, a, b, trans_a, trans_b) + return _postprocess_outputs(rhs, trans_rhs, out) + + +class DSD(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs): + ctx.save_for_backward(data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + rhs) + ctx.shape = _standardize_shape(shape, transpose_a) + ctx.transpose_a = transpose_a + + out = torch.empty( + (shape[0], rhs.size()[1]), + dtype=rhs.dtype, + device=rhs.device) + + backend.dsd(shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs, + out) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs = (ctx.shape,) + saved_tensors[:-1] + rhs = saved_tensors[-1] + trans_a = ctx.transpose_a + trans_b = _is_transposed(rhs) + + ddata = None + if ctx.needs_input_grad[1]: + ddata = _lhs_gradient(sdd, + lhs, + rhs, + dy, + trans_a, + trans_b) + drhs = None + if ctx.needs_input_grad[-1]: + op = dds if trans_b else dsd + drhs = _rhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + return None, ddata, None, None, None, None, None, None, None, drhs + + +dsd = DSD.apply + + +class DDS(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + 
block_offsets_t, + transpose_b): + ctx.save_for_backward(lhs, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t) + ctx.shape = _standardize_shape(shape, transpose_b) + ctx.transpose_b = transpose_b + out = torch.empty((lhs.size()[0], shape[1]), + dtype=lhs.dtype, + device=lhs.device) + backend.dds(lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b, + out) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs = saved_tensors[0] + rhs = (ctx.shape,) + saved_tensors[1:] + trans_a = _is_transposed(lhs) + trans_b = ctx.transpose_b + + dlhs = None + if ctx.needs_input_grad[0]: + op = dsd if trans_a else dds + dlhs = _lhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + ddata = None + if ctx.needs_input_grad[2]: + ddata = _rhs_gradient(sdd, + lhs, + rhs, + dy, + trans_a, + trans_b) + return dlhs, None, ddata, None, None, None, None, None, None, None + + +dds = DDS.apply + + +class SDD(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + lhs, + rhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t): + ctx.save_for_backward( + lhs, + rhs, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t) + ctx.shape = shape + out = torch.empty( + data.shape, + dtype=lhs.dtype, + device=lhs.device) + backend.sdd(lhs, + rhs, + shape, + out, + offsets, + row_indices, + column_indices) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs, rhs = saved_tensors[:2] + dy = (ctx.shape, dy) + saved_tensors[2:] + trans_a = _is_transposed(lhs) + trans_b = _is_transposed(rhs) + + dlhs = None + if ctx.needs_input_grad[0]: + op = dds if trans_a else dsd + dlhs = _lhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + drhs = None + if ctx.needs_input_grad[1]: + op = dsd if trans_b else dds + drhs = _rhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + return dlhs, drhs, None, None, None, None, None, None, None, None + + +sdd = SDD.apply + +class RowIndices(torch.autograd.Function): + + @staticmethod + def forward(ctx, shape, data, offsets, column_indices): + out = torch.empty( + column_indices.shape, + dtype=column_indices.dtype, + device=column_indices.device) + backend.row_indices(shape, data, offsets, column_indices, out) + return out + + +row_indices = RowIndices.apply diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/triton_kernels.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/triton_kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/triton_kernels.py @@ -0,0 +1,393 @@ +import torch +import triton +import triton.language as tl +from dataclasses import dataclass + +@dataclass +class TritonConfig: + BLOCK_M: int = 128 + BLOCK_N: int = 128 + BLOCK_K: int = 32 + BLOCK_SIZE: int = 128 + NUM_STAGES: int = 4 + NUM_WARPS: int = 4 + +def _validate_matmul_dims(M: int, K: int, N: int): + error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}" + assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M) + assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, 
TritonConfig.BLOCK_K) + assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _sdd_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + # matrix multiplication + pid = tl.program_id(0) + pid_m = tl.load(row_indices + pid) + pid_n = tl.load(column_indices + pid) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(A) + b = tl.load(B) + acc += tl.dot(a, b) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + #Store to sparse matrix + acc = acc.to(C.dtype.element_ty) + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + cm = tl.arange(0, BLOCK_M) + cn = tl.arange(0, BLOCK_N) + C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _dsd_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, offsets, + block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + + # matrix multiplication + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) + + start_inx = tl.load(offsets + pid_m) + end_inx = tl.load(offsets + pid_m + 1) + + # pointers to sparse matrix + rm = tl.arange(0, BLOCK_M) + rak = tl.arange(0, BLOCK_K) + + A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) + + # pointers to dense matrix + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rbk = tl.arange(0, BLOCK_K) + B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) + + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) + + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + ak_sub_incr = BLOCK_K * stride_ak + bk_sub_incr = BLOCK_K * stride_bk + bk_block_incr = BLOCK_SIZE * stride_bk + + for k in range(nsub_blocks * (end_inx - start_inx)): + sub_block_inx = k % nsub_blocks + block_inx = k // nsub_blocks + 
+ if trans_A: + ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr + else: + ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr + + ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr + + a = tl.load(ptr_A) + b = tl.load(ptr_B) + acc += tl.dot(a, b) + + acc = acc.to(C.dtype.element_ty) + + cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _dds_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, offsets, + block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + + # matrix multiplication + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) + + start_inx = tl.load(offsets + pid_n) + end_inx = tl.load(offsets + pid_n + 1) + + # pointers to dense matrix + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rak = tl.arange(0, BLOCK_K) + + A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) + + # pointers to sparse matrix + rn = tl.arange(0, BLOCK_N) + rbk = tl.arange(0, BLOCK_K) + B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) + + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) + + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + + ak_sub_incr = BLOCK_K * stride_ak + ak_block_incr = BLOCK_SIZE * stride_ak + bk_sub_incr = BLOCK_K * stride_bk + + for k in range(nsub_blocks * (end_inx - start_inx)): + sub_block_inx = k % nsub_blocks + block_inx = k // nsub_blocks + + if trans_B: + ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr + else: + ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr + + ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr + a = tl.load(ptr_A) + b = tl.load(ptr_B) + acc += tl.dot(a, b) + + acc = acc.to(C.dtype.element_ty) + cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +def dsd(shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs, + out + ): + + device = rhs.device + trans_A = transpose_a + trans_B = False + + if rhs.stride(0) > 1 and rhs.stride(1) > 1: + trans_B = True + + # checks constraints + assert shape[1] == rhs.shape[0], "incompatible dimensions" + M, K = shape + _, N = rhs.shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else 
tl.int32 + + stride_am, stride_ak = data.stride(1), data.stride(2) + stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) + a_column_indices = column_indices + a_offsets = offsets + + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) + + if trans_A: + stride_am, stride_ak = data.stride(2), data.stride(1) + a_column_indices, a_offsets = column_indices_t, offsets_t + + if trans_B: + stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) + + _dsd_kernel[grid]( + data.data, rhs, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(0), out.stride(1), + row_indices, a_column_indices, a_offsets, + block_offsets_t, trans_A, trans_B, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + # return out + +def dds(lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b, + out + ): + + device = lhs.device + trans_B = transpose_b + trans_A = False + + if lhs.stride(0) > 1 and lhs.stride(1) > 1: + trans_A = True + + # checks constraints + assert lhs.shape[1] == shape[0], "incompatible dimensions" + M, K = lhs.shape + _, N = shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + stride_am, stride_ak = lhs.stride(0), lhs.stride(1) + stride_bk, stride_bn = data.stride(1), data.stride(2) + b_column_indices = column_indices_t + b_offsets = offsets_t + + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) + + if trans_A: + stride_am, stride_ak = lhs.stride(1), lhs.stride(0) + if trans_B: + stride_bk, stride_bn = data.stride(2), data.stride(1) + b_column_indices, b_offsets = column_indices, offsets + + _dds_kernel[grid]( + lhs, data, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(0), out.stride(1), + row_indices, b_column_indices, b_offsets, + block_offsets_t, trans_A, trans_B, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + +def sdd(lhs, + rhs, + shape, + out, + offsets, + row_indices, + column_indices + ): + + device = out.device + trans_A = False + trans_B = False + + if lhs.stride(0) > 1 and lhs.stride(1) > 1: + trans_A = True + if rhs.stride(0) > 1 and rhs.stride(1) > 1: + trans_B = True + + # checks constraints + assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions" + M, K = lhs.shape + _, N = rhs.shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + # launch kernel + nnz_blocks = len(row_indices) + grid = lambda META: (nnz_blocks,) + + stride_am, stride_ak = lhs.stride(0), lhs.stride(1) + stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) + + if trans_A: + stride_am, stride_ak = lhs.stride(1), lhs.stride(0) + if trans_B: + stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) + + _sdd_kernel[grid]( + lhs, rhs, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(1), out.stride(2), + row_indices, column_indices, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + +@triton.jit +def _row_indices_kernel(offsets, out): + pid = tl.program_id(0) + row_offset = tl.load(offsets + pid) + nnz_blocks = tl.load(offsets + pid + 1) - row_offset + for nnz_block in range(nnz_blocks): + tl.store(out + row_offset + nnz_block, pid) + +def row_indices( + shape, data, offsets, column_indices, out +): + block_rows = len(offsets) - 1 + _row_indices_kernel[(block_rows, )](offsets, 
out) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/matrix.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/matrix.py @@ -0,0 +1,329 @@ +import numpy as np +import torch + +# 1. Add heavyweight (data) validation helper. +# 2. Add construction helpers +# 3. Make indentation consistent +# 4. Replace asserts with descriptive errors. + +## +### Validation helpers. +## + + +def _validate_matrix(shape, data, row_indices, column_indices, offsets): + # Data should be [nnz, block_size, block_size] + if data.dim() == 1: + data = torch.reshape(data, [data.numel(), 1, 1]) + + # Blocks should be square. + if data.shape[-2] != data.shape[-1]: + raise ValueError( + "Expected square blocking in data. " + f"Got block shape {[data.shape[-2], data.shape[-1]]}") + + # Flatten batch dimensions on data - original shape preserved + # in shape argument. + block_size = data.shape[-1] + data = data.view([-1, block_size, block_size]) + + if data.dim() != 3: + raise ValueError( + "Expected 3D shape for data (nnz, block, block). " + f"Got shape {data.dim()}D shape.") + + block_size = data.shape[1] + if shape[-2] % block_size != 0 or shape[-1] % block_size != 0: + raise ValueError( + "Matrix shape must be dividible by blocking. " + f"Got shape {shape} with " + f"{[block_size, block_size]} blocking.") + + if np.prod(shape) < data.numel(): + raise ValueError( + "Invalid matrix. Number of nonzeros exceeds matrix capacity " + f"({data.numel()} v. {np.prod(shape)})") + + if row_indices.dim() != 1: + raise ValueError( + f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.") + + if column_indices.dim() != 1: + raise ValueError( + f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.") + + if offsets.dim() != 1: + raise ValueError( + f"Expected 1D offsets. Got {offsets.dim()}D offsets.") + + if row_indices.numel() != data.shape[0]: + raise ValueError( + "Expected 1 index per nonzero block. " + f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks") + + if column_indices.numel() != data.shape[0]: + raise ValueError( + "Expected 1 index per nonzero block. " + f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks") + + block_rows = np.prod(shape[:-1]) / block_size + if offsets.numel() != block_rows + 1: + raise ValueError( + "Expected one offset per block row plus one. " + f"Got {offsets.numel()} offsets with {block_rows} block rows.") + + is_cuda = (data.is_cuda and + row_indices.is_cuda and + column_indices.is_cuda and + offsets.is_cuda) + is_cpu = (not data.is_cuda and + not row_indices.is_cuda and + not column_indices.is_cuda and + not offsets.is_cuda) + if not (is_cuda or is_cpu): + raise ValueError( + "Expected data & meta-data on common device. " + f"Got data on {data.device}, row_indices on {row_indices.device} " + f"column_indices on {column_indices.device} and " + f"offsets on {offsets.device}.") + + if data.dtype != torch.float16: + raise ValueError( + f"Expected float16 data. Got {data.dtype} data.") + if row_indices.dtype != torch.int16: + raise ValueError( + f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.") + if column_indices.dtype != torch.int16: + raise ValueError( + f"Expected int16 column_indices. 
Got {column_indices.dtype} column_indices.") + if offsets.dtype != torch.int32: + raise ValueError( + f"Expected int32 offsets. Got {offsets.dtype} offsets.") + return data + + +def _transpose(size, data, row_indices, column_indices, offsets): + block_columns = size[1] // data.shape[1] + + # Sort row indices by column indices to get the transposed matrix's + # column indices. + gather_indices = column_indices.argsort() + column_indices_t = row_indices.gather(0, gather_indices) + block_offsets_t = gather_indices.int() + + # NOTE: Histogram is not implemented for any integer type on CPU. Do + # the histogram in 32-bit float, which can exactly represent 16-bit + # integers. + column_indices_float = column_indices.float() + + zero = torch.zeros((1,), dtype=torch.int32, device=data.device) + nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns) + nnz_per_column = nnz_per_column.int() + offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)]) + return column_indices_t, offsets_t, block_offsets_t + + +class Matrix(torch.nn.Module): + """A matrix stored in sparse format. + + Underlying format is block compressed sparse row (BCSR). + + TODO(tgale): Make this mirror torch.Tensor API as much as possible. + """ + + def __init__(self, + size, + data, + row_indices, + column_indices, + offsets, + column_indices_t=None, + offsets_t=None, + block_offsets_t=None): + super().__init__() + self._size = size + self._data = data + self._row_indices = row_indices + self._column_indices = column_indices + self._offsets = offsets + + # Produce the transpose meta-data if it is not passed in. + if ((column_indices_t is None) or (offsets_t is None) or + (block_offsets_t is None)): + column_indices_t, offsets_t, block_offsets_t = _transpose( + size, data, row_indices, column_indices, offsets) + self._column_indices_t = column_indices_t + self._offsets_t = offsets_t + self._block_offsets_t = block_offsets_t + + self._transposed = False + + # Validate that our metadata will not overflow. + max_dim = np.iinfo(np.int16).max * self.blocking + if column_indices.dtype == torch.int16: + if size[0] > max_dim or size[1] > max_dim: + raise ValueError( + "Sparse matrix with shape {size} exceeds representable " + "size with 16-bit indices.") + + def validate(self): + _validate_matrix(self._size, + self._data, + self._row_indices, + self._column_indices, + self._offsets) + + # TODO(tgale): Add heavyweight data validation. + + def to(self, device): + # TODO(tgale): Handle type conversions here. We + # need to set the appropriate meta-data type for + # the given floating-point type. 
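# A hand-built example of the BCSR layout that _validate_matrix checks above:
# a 4x6 matrix with 2x2 blocks and three stored blocks at block coordinates
# (0,0), (0,2) and (1,1). Illustrative only; assumes the vendored import path.
import torch
from megablocks.stk import Matrix

data = torch.ones(3, 2, 2, dtype=torch.float16)            # [nnz_blocks, block, block]
row_indices = torch.tensor([0, 0, 1], dtype=torch.int16)   # block row of each stored block
column_indices = torch.tensor([0, 2, 1], dtype=torch.int16)
offsets = torch.tensor([0, 2, 3], dtype=torch.int32)       # CSR-style block-row pointers
m = Matrix((4, 6), data, row_indices, column_indices, offsets)
m.validate()                                               # passes the checks above
mt = m.t()                                                 # metadata-only transpose
assert mt.size() == (6, 4) and not mt.is_contiguous()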
+ self._data = self._data.to(device) + self._row_indices = self._row_indices.to(device) + self._column_indices = self._column_indices.to(device) + self._offsets = self._offsets.to(device) + self._column_indices_t = self._column_indices_t.to(device) + self._offsets_t = self._offsets_t.to(device) + self._block_offsets_t = self._block_offsets_t.to(device) + return self + + def cuda(self): + return self.to(torch.cuda.current_device()) + + def clone(self): + return Matrix( + self.size(), + self.data.clone(), + self.row_indices.clone(), + self.column_indices.clone(), + self.offsets.clone(), + self.column_indices_t.clone(), + self.offsets_t.clone(), + self.block_offsets_t.clone()) + + def t(self): + if self.dim() != 2: + raise ValueError( + "t() expects a tensor with <= 2 dimensions, " + f"but self is {self.dim()}D.") + out = Matrix(self.size(), + self.data, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + out._transposed = not self._transposed + out._size = torch.Size((self._size[1], self._size[0])) + return out + + def contiguous(self): + raise ValueError("Not yet implemented.") + + def is_contiguous(self): + return not self._transposed + + @property + def is_cuda(self): + return self._data.is_cuda + + @property + def device(self): + return self._data.device + + def size(self): + return self._size + + @property + def shape(self): + return self.size() + + def dim(self): + return len(self._size) + + @property + def data(self): + return self._data + + @property + def row_indices(self): + return self._row_indices + + @property + def column_indices(self): + return self._column_indices + + @property + def offsets(self): + return self._offsets + + @property + def offsets_t(self): + return self._offsets_t + + @property + def column_indices_t(self): + return self._column_indices_t + + @property + def block_offsets_t(self): + return self._block_offsets_t + + @property + def dtype(self): + return self.data.dtype + + @property + def nnz(self): + return self.data.numel() + + @property + def blocking(self): + return self.data.shape[1] + + @property + def requires_grad(self): + return self.data.requires_grad + + def requires_grad_(self, x): + self.data.requires_grad_(x) + return self + + def view(self, *shape): + assert self.is_contiguous() + if shape[-1] != self.size()[-1]: + raise ValueError( + "Can't change view on compressed dimension. " + f"{self.size()[-1]} v. {shape[-1]}.") + if np.prod(shape) != np.prod(self.size()): + raise ValueError( + "Mismatch in numel of Matrix and new shape. " + f"{np.prod(self.size())} v. {np.prod(shape)}") + return Matrix(shape, + self.data, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + + @property + def grad(self): + # TODO(tgale): Make sure this mirrors torch.Tensor + # behavior in the case where we ask for the gradient + # of a non-contiguous tensor. 
+ size = self.size() + if not self.is_contiguous(): + size = torch.Size((size[1], size[0])) + out = Matrix(size, + self.data.grad, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + return out if self.is_contiguous() else out.t() diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/__init__.py @@ -0,0 +1,3 @@ +from .linear_ops import dds, dsd, sdd +from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse +from .eltwise_ops import mul diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops.py @@ -0,0 +1,28 @@ +from ..matrix import Matrix + +def mul(a, b): + """Performs element-wise multiplication of matrices a and b. + + It is the user's responsibility to make sure that a and b + follow the same matrix topology. This function assumes it is safe + to use the topoplogy of a. + + Args: + a: stk.Matrix. + b: stk.Matrix with a's matrix topology. + + Returns: + stk.Matrix where the entries correspond to torch.mul(a, b). + """ + assert isinstance(a, Matrix) + assert isinstance(b, Matrix) + assert a.size() == b.size() + + return Matrix(a.size(), + a.data * b.data, + a.row_indices, + a.column_indices, + a.offsets, + a.column_indices_t, + a.offsets_t, + a.block_offsets_t) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..66bfd4f6af77042d3c5bdb1fe18d00e457478d46 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py @@ -0,0 +1,86 @@ +import unittest +import itertools +import torch +from absl.testing import parameterized + +import stk +from stk.ops.linear_ops_test import allclose, _dense_and_sparse + +_MATRIX_SIZES = ( + (128, 128, 0.0), + (256, 256, 0.5), + (2048, 1024, 0.8), + (512, 128, 0.0), + (128, 512, 0.0), + (1024, 512, 0.0), + (1024, 512, 0.5), + (1024, 512, 0.75), + (512, 1024, 0.0), + (512, 1024, 0.5), + (512, 1024, 0.75), + (1024, 1024, 0.0), + (1024, 1024, 0.5), + (1024, 1024, 0.75), +) + +_DTYPE = ( + torch.float16, torch.bfloat16 +) + +def _generate_testcases(): + testcases = itertools.product(_MATRIX_SIZES, _DTYPE) + testcases = [(*size, 128, dtype) for + (size, dtype) in testcases] + return testcases + +_ELTWISE_OP_TESTS = _generate_testcases() + +def _dense_and_sparse_like(x, std=0.1): + dense_data = torch.randn_like(x.data, device=x.device) * std + sparse = stk.Matrix(x.size(), + dense_data, + x.row_indices, + x.column_indices, + x.offsets) + dense = stk.ops.to_dense(sparse) + + return (dense.requires_grad_(True), + sparse.requires_grad_(True)) + +@parameterized.parameters(_ELTWISE_OP_TESTS) +class EltwiseOpsTest(parameterized.TestCase): + + def testEltwiseMul(self, m, n, sparsity, blocking, dtype): + + a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype) + b_dense, b = 
_dense_and_sparse_like(a) + + out = stk.ops.mul(a, b) + expected_out = torch.mul(a_dense, b_dense) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + stk.ops.sum(out).backward() + + # Validate the results. + out = stk.ops.to_dense(out) + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size(), out.size()) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = stk.ops.to_dense(a.grad) + expected_grad = a_dense.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size(), grad.size()) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = stk.ops.to_dense(b.grad) + expected_grad = b_dense.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size(), grad.size()) + self.assertTrue(allclose(grad, expected_grad)) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/linear_ops.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/linear_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/linear_ops.py @@ -0,0 +1,59 @@ +import torch + +from ..backend import sputnik +from ..matrix import Matrix + + +def dsd(a, b): + assert isinstance(a, Matrix) + assert isinstance(b, torch.Tensor) + return sputnik.dsd( + a.size(), + a.data, a.offsets, + a.row_indices, + a.column_indices, + a.offsets_t, + a.column_indices_t, + a.block_offsets_t, + not a.is_contiguous(), + b) + + +def dds(a, b): + assert isinstance(a, torch.Tensor) + assert isinstance(b, Matrix) + return sputnik.dds( + a, + b.size(), + b.data, b.offsets, + b.row_indices, + b.column_indices, + b.offsets_t, + b.column_indices_t, + b.block_offsets_t, + not b.is_contiguous()) + + +def sdd(a, b, topo): + assert isinstance(a, torch.Tensor) + assert isinstance(b, torch.Tensor) + assert isinstance(topo, Matrix) + assert topo.is_contiguous() + out = sputnik.sdd( + a, b, + topo.size(), + topo.data, + topo.offsets, + topo.row_indices, + topo.column_indices, + topo.offsets_t, + topo.column_indices_t, + topo.block_offsets_t) + return Matrix(topo.size(), + out, + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/linear_ops_test.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/linear_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ced1d782fbc9f9ca16b3449239f1588dc5ff5e00 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/linear_ops_test.py @@ -0,0 +1,216 @@ +import unittest +import itertools +import numpy as np +import torch +from absl.testing import parameterized + +import stk + + +def allclose(x, y, pct=0.25): + mask = torch.isclose(x, y, rtol=5e-2) + pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 + if pct_diff > pct: + print("{:.2f}% of values not close.".format(pct_diff)) + return False + return True + + +# An assortment of problems designed to make sure +# the bindings are operating correctly. 
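# Hedged sketch of the wrappers defined above (sdd: dense x dense -> block-sparse,
# dsd: block-sparse x dense -> dense). This path needs a CUDA device, fp16/bf16
# operands, blocking of 128 (the BLOCK_SIZE the Triton kernels are configured for),
# and dimensions divisible by the 128/128/32 tile sizes. Vendored import path assumed.
import torch
from megablocks import stk

m, k, n, blocking = 512, 512, 512, 128
topo = stk.random.mask(m, n, sparsity=0.75, blocking=blocking).cuda()
a = torch.randn(m, k, dtype=torch.float16, device="cuda")
b = torch.randn(k, n, dtype=torch.float16, device="cuda")
out = stk.ops.sdd(a, b, topo)                      # Matrix with topo's sparsity pattern
c = torch.randn(n, 128, dtype=torch.float16, device="cuda")
y = stk.ops.dsd(out, c)                            # dense (m, 128) result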
+_MATRIX_SIZES = ( + (128, 128, 128, 0.0), + (256, 256, 256, 0.5), + (2048, 1024, 512, 0.8), + (512, 128, 128, 0.0), + (128, 128, 512, 0.0), + (1024, 512, 512, 0.0), + (1024, 512, 512, 0.5), + (1024, 512, 512, 0.75), + (512, 512, 1024, 0.0), + (512, 512, 1024, 0.5), + (512, 512, 1024, 0.75), + (1024, 1024, 1024, 0.0), + (1024, 1024, 1024, 0.5), + (1024, 1024, 1024, 0.75), +) + +_TRANSPOSE = ( + (False, False), + (False, True), + (True, False), + (True, True), +) + +_DTYPE = ( + torch.float16, torch.bfloat16 +) + +def _generate_testcases(): + testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE) + testcases = [(*size, *trans, 128, dtype) for + (size, trans, dtype) in testcases] + return testcases + +_LINEAR_OP_TESTS = _generate_testcases() + +def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + dense = (torch.randn(rows, cols) * std * mask).type(dtype) + sparse = stk.ops.to_sparse(dense, blocking) + cuda_device = torch.device("cuda") + return (dense.to(cuda_device).requires_grad_(True), + sparse.to(cuda_device).requires_grad_(True)) + + +def _dense(rows, cols, dtype, std=0.1): + cuda_device = torch.device("cuda") + out = (torch.randn(rows, cols) * std).type(dtype) + return out.to(cuda_device).requires_grad_(True) + + +def _dense_2x(rows, cols, dtype): + a = _dense(rows, cols, dtype) + return a, a.detach().requires_grad_(True) + + +def _with_transpose(op, a, b, trans_a, trans_b): + a = a.t() if trans_a else a + b = b.t() if trans_b else b + return op(a, b) + + +def _mmm(a, b, topo): + mask = stk.ops.to_dense(stk.ops.ones_like(topo)) + return torch.mm(a, b) * mask + + +def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b): + a = a.t() if trans_a else a + b = b.t() if trans_b else b + return op(a, b, topo) + + +def _mask(x, mask): + mask = stk.ops.to_dense(stk.ops.ones_like(mask)) + return x * mask + + +@parameterized.parameters(*_LINEAR_OP_TESTS) +class LinearOpsTest(parameterized.TestCase): + + def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. + a_shape = (k, m) if trans_a else (m, k) + a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype) + b_shape = (n, k) if trans_b else (k, n) + b, bcp = _dense_2x(*b_shape, dtype) + + # Execute the matmul. + out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b) + expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + out.sum().backward() + + # Validate the results. + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = stk.ops.to_dense(a.grad) + expected_grad = _mask(a_dense.grad, a.grad) + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = b.grad + expected_grad = bcp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. 
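# The tolerance used by allclose() above is percentage-based: up to pct=0.25% of
# entries may fall outside rtol=5e-2 before a comparison is rejected. A small
# self-contained illustration of that bookkeeping:
import torch

x = torch.randn(1000)
y = x.clone()
y[:2] += 100.0                                     # 0.2% of entries are far off
mask = torch.isclose(x, y, rtol=5e-2)
pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
print(pct_diff)                                    # tensor(0.2000) -> allclose still accepts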
+ a_shape = (k, m) if trans_a else (m, k) + a, acp = _dense_2x(*a_shape, dtype) + b_shape = (n, k) if trans_b else (k, n) + b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype) + + # Execute the matmul. + out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b) + expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + out.sum().backward() + + # Validate the results. + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = a.grad + expected_grad = acp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = stk.ops.to_dense(b.grad) + expected_grad = _mask(b_dense.grad, b.grad) + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. + a_shape = (k, m) if trans_a else (m, k) + a, acp = _dense_2x(*a_shape, dtype) + b_shape = (n, k) if trans_b else (k, n) + b, bcp = _dense_2x(*b_shape, dtype) + _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype) + + # Execute the matmul. + out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b) + expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + stk.ops.sum(out).backward() + + # Validate the results. + out = stk.ops.to_dense(out) + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = a.grad + expected_grad = acp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = b.grad + expected_grad = bcp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops.py @@ -0,0 +1,98 @@ +from ..backend import sputnik +from ..matrix import Matrix +import torch +import numpy as np + + +@torch.no_grad() +def row_indices(shape, data, offsets, column_indices): + return sputnik.row_indices(shape, data, offsets, column_indices) + + +# TODO(tgale): Replace this helper with a custom kernel. This operation +# is much simpler to do than how it's currently implemented. 
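# Companion sketch for dds (dense x block-sparse -> dense), the case exercised by
# testLinearOps_Dds above. Same CUDA / fp16 / blocking=128 assumptions as the
# earlier sdd/dsd sketch; import path is assumed.
import torch
from megablocks import stk

w = stk.random.randn((512, 1024), 0.5, blocking=128).cuda()   # block-sparse "weight"
x = torch.randn(256, 512, dtype=torch.float16, device="cuda")
y = stk.ops.dds(x, w)                                         # dense (256, 1024)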
+@torch.no_grad() +def _expand_for_blocking(idxs, blocking): + # Duplicate for block column dimension. + idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1) + + # Update the column indices. + idxs[:, :, 1] *= blocking + idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking]) + + # Duplicate for block row dimension. + idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2]) + idxs = idxs.repeat(1, blocking, 1, 1) + + # Update the row indices. + idxs[:, :, :, 0] *= blocking + idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1]) + idxs = torch.reshape(idxs, [-1, 2]) + return idxs + + +# TODO(tgale): Add input type checking. +@torch.no_grad() +def to_dense(x): + assert isinstance(x, Matrix) + + shape = (np.prod(x.shape[:-1]), x.shape[-1]) + row_idxs = x.row_indices.type(torch.int32) + col_idxs = x.column_indices.type(torch.int32) + indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking) + indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64) + + out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device) + out.scatter_(0, indices, x.data.flatten()) + return out.reshape(x.size()) + + +@torch.no_grad() +def _mask(x, blocking=1): + assert x.dim() == 2 + assert x.size()[0] % blocking == 0 + assert x.size()[1] % blocking == 0 + block_rows = x.size()[0] // blocking + block_cols = x.size()[1] // blocking + x = torch.reshape(x, [block_rows, blocking, block_cols, blocking]) + x = torch.sum(torch.abs(x), dim=(1, 3)) + return x != 0 + + +# TODO(tgale): Add input type checking. +@torch.no_grad() +def to_sparse(x, blocking=1): + m = _mask(x, blocking) + + # TODO(tgale): Set to appropriate type for input matrix. + row_nnzs = torch.sum(m, dim=1).type(torch.int32) + zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device) + offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)]) + offsets = offsets.type(torch.int32) + + indices = torch.nonzero(m).type(torch.int16) + row_indices = indices[:, 0] + column_indices = indices[:, 1] + + # Nonzero indices in the dense matrix. + nonzero_indices = torch.nonzero(m) + nonzero_indices = _expand_for_blocking(nonzero_indices, blocking) + nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1] + + # Gather the data and construct the sparse matrix. 
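# Sketch of the block-occupancy reduction performed by _mask above: a block is
# kept if any of its entries is nonzero. Illustrative values only.
import torch

blocking = 2
x = torch.zeros(4, 4, dtype=torch.float16)
x[0, 0] = 1.0                                           # single nonzero, top-left block
blocks = x.reshape(2, blocking, 2, blocking)            # [block_rows, blocking, block_cols, blocking]
occupancy = torch.sum(torch.abs(blocks), dim=(1, 3)) != 0
print(occupancy)                                        # [[True, False], [False, False]]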
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices) + data = torch.reshape(data, [-1, blocking, blocking]) + return Matrix(x.size(), data, row_indices, column_indices, offsets) + + +@torch.no_grad() +def ones_like(x): + return Matrix(x.size(), + torch.ones_like(x.data), + x.row_indices, + x.column_indices, x.offsets) + + +def sum(x): + assert isinstance(x, Matrix) + return x.data.sum() diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3af04c0760483e578f93303dc457415948a2a34c --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py @@ -0,0 +1,62 @@ +import unittest + +from absl.testing import parameterized +import stk +import torch + + +@parameterized.parameters( + (8, 16, 0.0, 1), + (8, 16, 0.5, 1), + (8, 16, .95, 1), + (16, 8, 0.0, 1), + (16, 8, 0.5, 1), + (16, 8, .95, 1), + (8, 16, 0.0, 8), + (8, 16, 0.5, 8), + (8, 16, 1.0, 8), + (16, 8, 0.0, 8), + (16, 8, 0.5, 8), + (16, 8, 1.0, 8), + (128, 256, 0.5, 16), + (256, 128, 0.75, 32), + (512, 512, .875, 128)) +class MatrixOpsTest(parameterized.TestCase): + + def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + x = (torch.randn(rows, cols) * mask).type(torch.float16) + + # Convert the matrix to sparse format. + sparse_x = stk.ops.to_sparse(x, blocking) + + # Validate the matrix. + sparse_x.validate() + + # Validate the shape. + self.assertEqual(sparse_x.dim(), 2) + self.assertEqual(sparse_x.size()[0], rows) + self.assertEqual(sparse_x.size()[1], cols) + + # Validate the sparsity. + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual(sparse_x.nnz, nnz) + + # Convert back to dense format. + dense_x = stk.ops.to_dense(sparse_x) + + # Validate the shape. + self.assertEqual(dense_x.dim(), 2) + self.assertEqual(dense_x.size()[0], rows) + self.assertEqual(dense_x.size()[1], cols) + + # Validate the sparsity + self.assertEqual(torch.count_nonzero(dense_x).item(), nnz) + + # Validate the output. 
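# Sketch of the ones_like/sum helpers above; the tests use them to recover a dense
# 0/1 topology mask and to reduce a sparse output for backprop. Vendored import
# path assumed.
import torch
from megablocks import stk

topo = stk.random.randn((64, 64), 0.5, blocking=8)
pattern = stk.ops.to_dense(stk.ops.ones_like(topo))   # dense 0/1 copy of the topology
total = stk.ops.sum(topo)                             # scalar over stored blocks only
print(pattern.sum().item(), total.item())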
+ self.assertTrue(torch.all(torch.eq(x, dense_x))) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/random/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/random/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/random/__init__.py @@ -0,0 +1,2 @@ +# from stk.random.random_ops import dense_mask, mask, randn +from .random_ops import dense_mask, mask, randn diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/random/random_ops.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/random/random_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/random/random_ops.py @@ -0,0 +1,36 @@ +import numpy as np +import torch +from ..ops import matrix_ops + + +@torch.no_grad() +def dense_mask(rows, cols, sparsity, blocking=1): + assert sparsity >= 0.0 and sparsity <= 1.0 + assert rows % blocking == 0 and cols % blocking == 0 + + block_rows, block_cols = (rows // blocking, cols // blocking) + nnz = round(block_rows * block_cols * (1 - sparsity)) + + out = np.ones(block_rows * block_cols) + mask = np.random.choice(out.size, out.size - nnz, replace=False) + out[mask] = 0.0 + + out = np.tile( + np.reshape(out, [block_rows, 1, block_cols, 1]), + (1, blocking, 1, blocking)) + out = np.reshape(out, [rows, cols]) + return torch.from_numpy(out.astype(np.float32)) + + +@torch.no_grad() +def mask(m, n, sparsity, blocking=1): + out = dense_mask(m, n, sparsity, blocking).type(torch.float16) + return matrix_ops.to_sparse(out, blocking=blocking) + + +@torch.no_grad() +def randn(shape, sparsity, blocking=1): + shape_2d = (np.prod(shape[:-1]), shape[-1]) + out = mask(*shape_2d, sparsity, blocking) + out.data.copy_(torch.randn(*out.data.shape)) + return out.view(*shape) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/random/random_ops_test.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/random/random_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..587b44ec890c861879c6296b8f9028f5d99ab82f --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/random/random_ops_test.py @@ -0,0 +1,73 @@ +import unittest + +from absl.testing import parameterized +from . import random +import torch + + +@parameterized.parameters( + (8, 16, 0.0, 1), + (8, 16, 0.5, 1), + (8, 16, .95, 1), + (16, 8, 0.0, 1), + (16, 8, 0.5, 1), + (16, 8, .95, 1), + (8, 16, 0.0, 8), + (8, 16, 0.5, 8), + (8, 16, 1.0, 8), + (16, 8, 0.0, 8), + (16, 8, 0.5, 8), + (16, 8, 1.0, 8), + (128, 256, 0.5, 16), + (256, 128, 0.75, 32), + (512, 512, .875, 128)) +class RandomOpsTest(parameterized.TestCase): + + def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking): + mask = random.dense_mask( + rows, cols, sparsity, blocking) + + # Validate the shape. + self.assertEqual(mask.dim(), 2) + self.assertEqual(mask.size()[0], rows) + self.assertEqual(mask.size()[1], cols) + + # Validate the sparsity + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual( + torch.count_nonzero(mask).item(), + nnz) + + # Check values are zero or one. 
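# The sparsity bookkeeping these tests assert, spelled out for one parameter set:
rows, cols, sparsity, blocking = 128, 256, 0.5, 16
numblocks = (rows // blocking) * (cols // blocking)        # 8 * 16 = 128 blocks
nnz = round(numblocks * (1 - sparsity)) * blocking ** 2    # 64 blocks * 256 = 16384
assert nnz == 16384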
+ self.assertTrue( + torch.all(torch.logical_or( + torch.eq(mask, 0), + torch.eq(mask, 1)))) + + def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking): + mask = random.mask( + rows, cols, sparsity, blocking) + + # Validate the matrix. + mask.validate() + + # Validate the shape. + self.assertEqual(mask.dim(), 2) + self.assertEqual(mask.size()[0], rows) + self.assertEqual(mask.size()[1], cols) + + # Validate the sparsity. + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual(mask.nnz, nnz) + + # Check values are zero or one. + self.assertTrue( + torch.all(torch.logical_or( + torch.eq(mask.data, 0), + torch.eq(mask.data, 1)))) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/activation_fn.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/activation_fn.py index a31770ba179fed06abf2da10102ccaeed1d3ee4e..0e1d956704840aa4daf7d1d71d24e051567feab9 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/activation_fn.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/activation_fn.py @@ -4,7 +4,7 @@ from typing import Any, Callable, Union import torch -from stk import Matrix +from ..stk import Matrix def act_fn( diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/dmoe.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/dmoe.py index 561fd7982254513e9a9b4f0514edd326b0561256..6d0375a4df2f27134c4127e60be04f3b45693050 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/dmoe.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/dmoe.py @@ -2,15 +2,22 @@ # SPDX-License-Identifier: Apache-2.0 import numpy as np -import stk.ops import torch -from stk import Matrix + +# try: +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.', +# ) # import megablocks.ops as ops # # from megablocks.ops import ops # from megablocks.layers import common, dmlp_registry, moe, mpu # from megablocks.layers.arguments import Arguments +from .. import stk from .. import ops from . import common, dmlp_registry, moe, mpu from .arguments import Arguments diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/gelu.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/gelu.py index 40b601d4a8eb59e80e1090f31f4172b8f7fb7549..c4c9e6532798615b5c12c96694241a4c18ee8f7b 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/gelu.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/gelu.py @@ -1,7 +1,16 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -import stk +# try: +# import stk +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.', +# ) + +from .. 
import stk + import torch import torch.nn.functional as F diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/glu.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/glu.py index 69ef9ead5461a5b91ace107652fd8d1c3d300b1e..5f297a41ff6a1a2a285f5b461951672364b898da 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/glu.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/glu.py @@ -1,7 +1,17 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -import stk.ops +# import stk.ops +# try: +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.', +# ) + +from .. import stk + import torch # from megablocks import grouped_gemm_util as gg diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/mlp.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/mlp.py index 0b6747b1eec2047eb9480592ca550b6c9e868f5e..c99afb9904c24a8b6a83e79059cd1251dbbfd99e 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/mlp.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/mlp.py @@ -3,9 +3,18 @@ from typing import Any -import stk -import stk.backend.triton_kernels -import stk.ops +# try: +# import stk +# import stk.backend.triton_kernels +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.', +# ) + +from .. import stk + import torch from packaging import version diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so deleted file mode 100755 index 9c0b2b801efed567532179b12a8f3d04450e858e..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:3ea768d3d4780563159dd50075ed14d51166e5c3de9f5bd132047cfa6a23ef48 -size 11931048 diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_63599de.abi3.so b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_63599de.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..d7641c05ee43548e4c3ddd503eab29f2858276f7 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_63599de.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e9e392427d3157216b82014570075137082c5ec5c5bd6b63c1458d509ed4ff3 +size 11931048 diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py index f7aa5a700bc5673c743a3b0fb74107fab793fbad..03e1e323ce98ecea8dda071c089f0d7ba75551b6 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_0586ba6 -ops = torch.ops._megablocks_0586ba6 +from . import _megablocks_63599de +ops = torch.ops._megablocks_63599de def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_megablocks_0586ba6::{op_name}" \ No newline at end of file + return f"_megablocks_63599de::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py index 9b2c7f062071758def741bf3af2392d3a4855f89..7ccc5dcec5e9a663794fad944c45285869c4d1c1 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -3,7 +3,19 @@ import unittest -import stk + +# import stk + +# try: +# import stk +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.', +# ) + +from .. import stk + import torch from absl.testing import parameterized diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/__init__.py @@ -0,0 +1,7 @@ +# import stk.random +# import stk.ops +# from stk.matrix import Matrix + +from . import random +from . import ops +from .matrix import Matrix diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/backend/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/backend/autocast.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/backend/autocast.py new file mode 100644 index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/backend/autocast.py @@ -0,0 +1,37 @@ +import functools +import torch + + +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, map): + return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, list) or isinstance(x, tuple): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def custom_fwd(fwd): + """Wrap a custom autograd function that always uses autocast dtype.""" + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + return decorate_fwd + + +def custom_bwd(bwd): + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + return decorate_bwd diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/backend/sputnik.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/backend/sputnik.py new file mode 100644 index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/backend/sputnik.py @@ -0,0 +1,316 @@ +import torch + +from ..backend import triton_kernels as backend +from ..backend.autocast 
import custom_bwd, custom_fwd + + +def _standardize_shape(x, transpose): + if transpose: + return torch.Size((x[1], x[0])) + return x + + +def _sparse_transpose(x): + return (torch.Size((x[0][1], x[0][0])), ) + x[1:] + + +def _transpose_helper(x, transpose): + if isinstance(x, torch.Tensor): + return x.t() if transpose else x + if transpose: + x = _sparse_transpose(x) + return x + (transpose,) + + +def _wrap(x): + if isinstance(x, torch.Tensor): + return (x,) + return x + + +def _is_transposed(x): + return (not x.is_contiguous() and + x.stride()[0] == 1 and + x.stride()[1] == x.size()[0]) + + +def _call_helper(op, out, a, b, trans_a, trans_b): + args = (_wrap(_transpose_helper(a, trans_a)) + + _wrap(_transpose_helper(b, trans_b))) + if isinstance(out, tuple): + args = args + out + return op(*args) + + +def _preprocess_inputs(lhs, rhs, dy): + if isinstance(lhs, torch.Tensor) and _is_transposed(lhs): + lhs = lhs.t() + if isinstance(rhs, torch.Tensor) and _is_transposed(rhs): + rhs = rhs.t() + if (isinstance(dy, torch.Tensor) and + not dy.is_contiguous() and + not _is_transposed(dy)): + dy = dy.contiguous() + if isinstance(dy, tuple) and not dy[1].is_contiguous(): + dy = (dy[0], dy[1].contiguous()) + dy[2:] + return lhs, rhs, dy + + +def _postprocess_outputs(x, transpose, grad): + if isinstance(x, torch.Tensor) and transpose: + return grad.t() + return grad + + +def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): + lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) + + a, b = (rhs, dy) if trans_lhs else (dy, rhs) + trans_a = trans_lhs and trans_rhs + trans_b = trans_lhs or not trans_rhs + out = _call_helper(op, lhs, a, b, trans_a, trans_b) + return _postprocess_outputs(lhs, trans_lhs, out) + + +def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): + lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) + + a, b = (dy, lhs) if trans_rhs else (lhs, dy) + trans_a = not trans_lhs or trans_rhs + trans_b = trans_lhs and trans_rhs + out = _call_helper(op, rhs, a, b, trans_a, trans_b) + return _postprocess_outputs(rhs, trans_rhs, out) + + +class DSD(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs): + ctx.save_for_backward(data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + rhs) + ctx.shape = _standardize_shape(shape, transpose_a) + ctx.transpose_a = transpose_a + + out = torch.empty( + (shape[0], rhs.size()[1]), + dtype=rhs.dtype, + device=rhs.device) + + backend.dsd(shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs, + out) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs = (ctx.shape,) + saved_tensors[:-1] + rhs = saved_tensors[-1] + trans_a = ctx.transpose_a + trans_b = _is_transposed(rhs) + + ddata = None + if ctx.needs_input_grad[1]: + ddata = _lhs_gradient(sdd, + lhs, + rhs, + dy, + trans_a, + trans_b) + drhs = None + if ctx.needs_input_grad[-1]: + op = dds if trans_b else dsd + drhs = _rhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + return None, ddata, None, None, None, None, None, None, None, drhs + + +dsd = DSD.apply + + +class DDS(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + 
block_offsets_t, + transpose_b): + ctx.save_for_backward(lhs, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t) + ctx.shape = _standardize_shape(shape, transpose_b) + ctx.transpose_b = transpose_b + out = torch.empty((lhs.size()[0], shape[1]), + dtype=lhs.dtype, + device=lhs.device) + backend.dds(lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b, + out) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs = saved_tensors[0] + rhs = (ctx.shape,) + saved_tensors[1:] + trans_a = _is_transposed(lhs) + trans_b = ctx.transpose_b + + dlhs = None + if ctx.needs_input_grad[0]: + op = dsd if trans_a else dds + dlhs = _lhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + ddata = None + if ctx.needs_input_grad[2]: + ddata = _rhs_gradient(sdd, + lhs, + rhs, + dy, + trans_a, + trans_b) + return dlhs, None, ddata, None, None, None, None, None, None, None + + +dds = DDS.apply + + +class SDD(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + lhs, + rhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t): + ctx.save_for_backward( + lhs, + rhs, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t) + ctx.shape = shape + out = torch.empty( + data.shape, + dtype=lhs.dtype, + device=lhs.device) + backend.sdd(lhs, + rhs, + shape, + out, + offsets, + row_indices, + column_indices) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs, rhs = saved_tensors[:2] + dy = (ctx.shape, dy) + saved_tensors[2:] + trans_a = _is_transposed(lhs) + trans_b = _is_transposed(rhs) + + dlhs = None + if ctx.needs_input_grad[0]: + op = dds if trans_a else dsd + dlhs = _lhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + drhs = None + if ctx.needs_input_grad[1]: + op = dsd if trans_b else dds + drhs = _rhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + return dlhs, drhs, None, None, None, None, None, None, None, None + + +sdd = SDD.apply + +class RowIndices(torch.autograd.Function): + + @staticmethod + def forward(ctx, shape, data, offsets, column_indices): + out = torch.empty( + column_indices.shape, + dtype=column_indices.dtype, + device=column_indices.device) + backend.row_indices(shape, data, offsets, column_indices, out) + return out + + +row_indices = RowIndices.apply diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/backend/triton_kernels.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/backend/triton_kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/backend/triton_kernels.py @@ -0,0 +1,393 @@ +import torch +import triton +import triton.language as tl +from dataclasses import dataclass + +@dataclass +class TritonConfig: + BLOCK_M: int = 128 + BLOCK_N: int = 128 + BLOCK_K: int = 32 + BLOCK_SIZE: int = 128 + NUM_STAGES: int = 4 + NUM_WARPS: int = 4 + +def _validate_matmul_dims(M: int, K: int, N: int): + error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}" + assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M) + assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, 
TritonConfig.BLOCK_K) + assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _sdd_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + # matrix multiplication + pid = tl.program_id(0) + pid_m = tl.load(row_indices + pid) + pid_n = tl.load(column_indices + pid) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(A) + b = tl.load(B) + acc += tl.dot(a, b) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + #Store to sparse matrix + acc = acc.to(C.dtype.element_ty) + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + cm = tl.arange(0, BLOCK_M) + cn = tl.arange(0, BLOCK_N) + C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _dsd_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, offsets, + block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + + # matrix multiplication + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) + + start_inx = tl.load(offsets + pid_m) + end_inx = tl.load(offsets + pid_m + 1) + + # pointers to sparse matrix + rm = tl.arange(0, BLOCK_M) + rak = tl.arange(0, BLOCK_K) + + A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) + + # pointers to dense matrix + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rbk = tl.arange(0, BLOCK_K) + B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) + + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) + + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + ak_sub_incr = BLOCK_K * stride_ak + bk_sub_incr = BLOCK_K * stride_bk + bk_block_incr = BLOCK_SIZE * stride_bk + + for k in range(nsub_blocks * (end_inx - start_inx)): + sub_block_inx = k % nsub_blocks + block_inx = k // nsub_blocks + 
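+ # Each nonzero BLOCK_SIZE x BLOCK_SIZE block of A is consumed in nsub_blocks
+ # slices of width BLOCK_K: block_inx selects the nonzero block within this
+ # block row, sub_block_inx selects the K-slice inside that block.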
+ if trans_A: + ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr + else: + ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr + + ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr + + a = tl.load(ptr_A) + b = tl.load(ptr_B) + acc += tl.dot(a, b) + + acc = acc.to(C.dtype.element_ty) + + cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _dds_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, offsets, + block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + + # matrix multiplication + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) + + start_inx = tl.load(offsets + pid_n) + end_inx = tl.load(offsets + pid_n + 1) + + # pointers to dense matrix + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rak = tl.arange(0, BLOCK_K) + + A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) + + # pointers to sparse matrix + rn = tl.arange(0, BLOCK_N) + rbk = tl.arange(0, BLOCK_K) + B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) + + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) + + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + + ak_sub_incr = BLOCK_K * stride_ak + ak_block_incr = BLOCK_SIZE * stride_ak + bk_sub_incr = BLOCK_K * stride_bk + + for k in range(nsub_blocks * (end_inx - start_inx)): + sub_block_inx = k % nsub_blocks + block_inx = k // nsub_blocks + + if trans_B: + ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr + else: + ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr + + ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr + a = tl.load(ptr_A) + b = tl.load(ptr_B) + acc += tl.dot(a, b) + + acc = acc.to(C.dtype.element_ty) + cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +def dsd(shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs, + out + ): + + device = rhs.device + trans_A = transpose_a + trans_B = False + + if rhs.stride(0) > 1 and rhs.stride(1) > 1: + trans_B = True + + # checks constraints + assert shape[1] == rhs.shape[0], "incompatible dimensions" + M, K = shape + _, N = rhs.shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else 
tl.int32 + + stride_am, stride_ak = data.stride(1), data.stride(2) + stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) + a_column_indices = column_indices + a_offsets = offsets + + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) + + if trans_A: + stride_am, stride_ak = data.stride(2), data.stride(1) + a_column_indices, a_offsets = column_indices_t, offsets_t + + if trans_B: + stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) + + _dsd_kernel[grid]( + data.data, rhs, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(0), out.stride(1), + row_indices, a_column_indices, a_offsets, + block_offsets_t, trans_A, trans_B, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + # return out + +def dds(lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b, + out + ): + + device = lhs.device + trans_B = transpose_b + trans_A = False + + if lhs.stride(0) > 1 and lhs.stride(1) > 1: + trans_A = True + + # checks constraints + assert lhs.shape[1] == shape[0], "incompatible dimensions" + M, K = lhs.shape + _, N = shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + stride_am, stride_ak = lhs.stride(0), lhs.stride(1) + stride_bk, stride_bn = data.stride(1), data.stride(2) + b_column_indices = column_indices_t + b_offsets = offsets_t + + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) + + if trans_A: + stride_am, stride_ak = lhs.stride(1), lhs.stride(0) + if trans_B: + stride_bk, stride_bn = data.stride(2), data.stride(1) + b_column_indices, b_offsets = column_indices, offsets + + _dds_kernel[grid]( + lhs, data, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(0), out.stride(1), + row_indices, b_column_indices, b_offsets, + block_offsets_t, trans_A, trans_B, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + +def sdd(lhs, + rhs, + shape, + out, + offsets, + row_indices, + column_indices + ): + + device = out.device + trans_A = False + trans_B = False + + if lhs.stride(0) > 1 and lhs.stride(1) > 1: + trans_A = True + if rhs.stride(0) > 1 and rhs.stride(1) > 1: + trans_B = True + + # checks constraints + assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions" + M, K = lhs.shape + _, N = rhs.shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + # launch kernel + nnz_blocks = len(row_indices) + grid = lambda META: (nnz_blocks,) + + stride_am, stride_ak = lhs.stride(0), lhs.stride(1) + stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) + + if trans_A: + stride_am, stride_ak = lhs.stride(1), lhs.stride(0) + if trans_B: + stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) + + _sdd_kernel[grid]( + lhs, rhs, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(1), out.stride(2), + row_indices, column_indices, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + +@triton.jit +def _row_indices_kernel(offsets, out): + pid = tl.program_id(0) + row_offset = tl.load(offsets + pid) + nnz_blocks = tl.load(offsets + pid + 1) - row_offset + for nnz_block in range(nnz_blocks): + tl.store(out + row_offset + nnz_block, pid) + +def row_indices( + shape, data, offsets, column_indices, out +): + block_rows = len(offsets) - 1 + _row_indices_kernel[(block_rows, )](offsets, 
out) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/matrix.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/matrix.py @@ -0,0 +1,329 @@ +import numpy as np +import torch + +# 1. Add heavyweight (data) validation helper. +# 2. Add construction helpers +# 3. Make indentation consistent +# 4. Replace asserts with descriptive errors. + +## +### Validation helpers. +## + + +def _validate_matrix(shape, data, row_indices, column_indices, offsets): + # Data should be [nnz, block_size, block_size] + if data.dim() == 1: + data = torch.reshape(data, [data.numel(), 1, 1]) + + # Blocks should be square. + if data.shape[-2] != data.shape[-1]: + raise ValueError( + "Expected square blocking in data. " + f"Got block shape {[data.shape[-2], data.shape[-1]]}") + + # Flatten batch dimensions on data - original shape preserved + # in shape argument. + block_size = data.shape[-1] + data = data.view([-1, block_size, block_size]) + + if data.dim() != 3: + raise ValueError( + "Expected 3D shape for data (nnz, block, block). " + f"Got shape {data.dim()}D shape.") + + block_size = data.shape[1] + if shape[-2] % block_size != 0 or shape[-1] % block_size != 0: + raise ValueError( + "Matrix shape must be dividible by blocking. " + f"Got shape {shape} with " + f"{[block_size, block_size]} blocking.") + + if np.prod(shape) < data.numel(): + raise ValueError( + "Invalid matrix. Number of nonzeros exceeds matrix capacity " + f"({data.numel()} v. {np.prod(shape)})") + + if row_indices.dim() != 1: + raise ValueError( + f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.") + + if column_indices.dim() != 1: + raise ValueError( + f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.") + + if offsets.dim() != 1: + raise ValueError( + f"Expected 1D offsets. Got {offsets.dim()}D offsets.") + + if row_indices.numel() != data.shape[0]: + raise ValueError( + "Expected 1 index per nonzero block. " + f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks") + + if column_indices.numel() != data.shape[0]: + raise ValueError( + "Expected 1 index per nonzero block. " + f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks") + + block_rows = np.prod(shape[:-1]) / block_size + if offsets.numel() != block_rows + 1: + raise ValueError( + "Expected one offset per block row plus one. " + f"Got {offsets.numel()} offsets with {block_rows} block rows.") + + is_cuda = (data.is_cuda and + row_indices.is_cuda and + column_indices.is_cuda and + offsets.is_cuda) + is_cpu = (not data.is_cuda and + not row_indices.is_cuda and + not column_indices.is_cuda and + not offsets.is_cuda) + if not (is_cuda or is_cpu): + raise ValueError( + "Expected data & meta-data on common device. " + f"Got data on {data.device}, row_indices on {row_indices.device} " + f"column_indices on {column_indices.device} and " + f"offsets on {offsets.device}.") + + if data.dtype != torch.float16: + raise ValueError( + f"Expected float16 data. Got {data.dtype} data.") + if row_indices.dtype != torch.int16: + raise ValueError( + f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.") + if column_indices.dtype != torch.int16: + raise ValueError( + f"Expected int16 column_indices. 
Got {column_indices.dtype} column_indices.") + if offsets.dtype != torch.int32: + raise ValueError( + f"Expected int32 offsets. Got {offsets.dtype} offsets.") + return data + + +def _transpose(size, data, row_indices, column_indices, offsets): + block_columns = size[1] // data.shape[1] + + # Sort row indices by column indices to get the transposed matrix's + # column indices. + gather_indices = column_indices.argsort() + column_indices_t = row_indices.gather(0, gather_indices) + block_offsets_t = gather_indices.int() + + # NOTE: Histogram is not implemented for any integer type on CPU. Do + # the histogram in 32-bit float, which can exactly represent 16-bit + # integers. + column_indices_float = column_indices.float() + + zero = torch.zeros((1,), dtype=torch.int32, device=data.device) + nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns) + nnz_per_column = nnz_per_column.int() + offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)]) + return column_indices_t, offsets_t, block_offsets_t + + +class Matrix(torch.nn.Module): + """A matrix stored in sparse format. + + Underlying format is block compressed sparse row (BCSR). + + TODO(tgale): Make this mirror torch.Tensor API as much as possible. + """ + + def __init__(self, + size, + data, + row_indices, + column_indices, + offsets, + column_indices_t=None, + offsets_t=None, + block_offsets_t=None): + super().__init__() + self._size = size + self._data = data + self._row_indices = row_indices + self._column_indices = column_indices + self._offsets = offsets + + # Produce the transpose meta-data if it is not passed in. + if ((column_indices_t is None) or (offsets_t is None) or + (block_offsets_t is None)): + column_indices_t, offsets_t, block_offsets_t = _transpose( + size, data, row_indices, column_indices, offsets) + self._column_indices_t = column_indices_t + self._offsets_t = offsets_t + self._block_offsets_t = block_offsets_t + + self._transposed = False + + # Validate that our metadata will not overflow. + max_dim = np.iinfo(np.int16).max * self.blocking + if column_indices.dtype == torch.int16: + if size[0] > max_dim or size[1] > max_dim: + raise ValueError( + "Sparse matrix with shape {size} exceeds representable " + "size with 16-bit indices.") + + def validate(self): + _validate_matrix(self._size, + self._data, + self._row_indices, + self._column_indices, + self._offsets) + + # TODO(tgale): Add heavyweight data validation. + + def to(self, device): + # TODO(tgale): Handle type conversions here. We + # need to set the appropriate meta-data type for + # the given floating-point type. 
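+ # Moves the block data and all BCSR metadata (row/column indices, offsets
+ # and the transpose metadata) to the target device in place.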
+ self._data = self._data.to(device) + self._row_indices = self._row_indices.to(device) + self._column_indices = self._column_indices.to(device) + self._offsets = self._offsets.to(device) + self._column_indices_t = self._column_indices_t.to(device) + self._offsets_t = self._offsets_t.to(device) + self._block_offsets_t = self._block_offsets_t.to(device) + return self + + def cuda(self): + return self.to(torch.cuda.current_device()) + + def clone(self): + return Matrix( + self.size(), + self.data.clone(), + self.row_indices.clone(), + self.column_indices.clone(), + self.offsets.clone(), + self.column_indices_t.clone(), + self.offsets_t.clone(), + self.block_offsets_t.clone()) + + def t(self): + if self.dim() != 2: + raise ValueError( + "t() expects a tensor with <= 2 dimensions, " + f"but self is {self.dim()}D.") + out = Matrix(self.size(), + self.data, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + out._transposed = not self._transposed + out._size = torch.Size((self._size[1], self._size[0])) + return out + + def contiguous(self): + raise ValueError("Not yet implemented.") + + def is_contiguous(self): + return not self._transposed + + @property + def is_cuda(self): + return self._data.is_cuda + + @property + def device(self): + return self._data.device + + def size(self): + return self._size + + @property + def shape(self): + return self.size() + + def dim(self): + return len(self._size) + + @property + def data(self): + return self._data + + @property + def row_indices(self): + return self._row_indices + + @property + def column_indices(self): + return self._column_indices + + @property + def offsets(self): + return self._offsets + + @property + def offsets_t(self): + return self._offsets_t + + @property + def column_indices_t(self): + return self._column_indices_t + + @property + def block_offsets_t(self): + return self._block_offsets_t + + @property + def dtype(self): + return self.data.dtype + + @property + def nnz(self): + return self.data.numel() + + @property + def blocking(self): + return self.data.shape[1] + + @property + def requires_grad(self): + return self.data.requires_grad + + def requires_grad_(self, x): + self.data.requires_grad_(x) + return self + + def view(self, *shape): + assert self.is_contiguous() + if shape[-1] != self.size()[-1]: + raise ValueError( + "Can't change view on compressed dimension. " + f"{self.size()[-1]} v. {shape[-1]}.") + if np.prod(shape) != np.prod(self.size()): + raise ValueError( + "Mismatch in numel of Matrix and new shape. " + f"{np.prod(self.size())} v. {np.prod(shape)}") + return Matrix(shape, + self.data, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + + @property + def grad(self): + # TODO(tgale): Make sure this mirrors torch.Tensor + # behavior in the case where we ask for the gradient + # of a non-contiguous tensor. 
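+ # Build the gradient Matrix in the underlying storage orientation, then
+ # transpose it back if this Matrix is itself a transposed view.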
+ size = self.size() + if not self.is_contiguous(): + size = torch.Size((size[1], size[0])) + out = Matrix(size, + self.data.grad, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + return out if self.is_contiguous() else out.t() diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/__init__.py @@ -0,0 +1,3 @@ +from .linear_ops import dds, dsd, sdd +from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse +from .eltwise_ops import mul diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops.py @@ -0,0 +1,28 @@ +from ..matrix import Matrix + +def mul(a, b): + """Performs element-wise multiplication of matrices a and b. + + It is the user's responsibility to make sure that a and b + follow the same matrix topology. This function assumes it is safe + to use the topoplogy of a. + + Args: + a: stk.Matrix. + b: stk.Matrix with a's matrix topology. + + Returns: + stk.Matrix where the entries correspond to torch.mul(a, b). + """ + assert isinstance(a, Matrix) + assert isinstance(b, Matrix) + assert a.size() == b.size() + + return Matrix(a.size(), + a.data * b.data, + a.row_indices, + a.column_indices, + a.offsets, + a.column_indices_t, + a.offsets_t, + a.block_offsets_t) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..66bfd4f6af77042d3c5bdb1fe18d00e457478d46 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py @@ -0,0 +1,86 @@ +import unittest +import itertools +import torch +from absl.testing import parameterized + +import stk +from stk.ops.linear_ops_test import allclose, _dense_and_sparse + +_MATRIX_SIZES = ( + (128, 128, 0.0), + (256, 256, 0.5), + (2048, 1024, 0.8), + (512, 128, 0.0), + (128, 512, 0.0), + (1024, 512, 0.0), + (1024, 512, 0.5), + (1024, 512, 0.75), + (512, 1024, 0.0), + (512, 1024, 0.5), + (512, 1024, 0.75), + (1024, 1024, 0.0), + (1024, 1024, 0.5), + (1024, 1024, 0.75), +) + +_DTYPE = ( + torch.float16, torch.bfloat16 +) + +def _generate_testcases(): + testcases = itertools.product(_MATRIX_SIZES, _DTYPE) + testcases = [(*size, 128, dtype) for + (size, dtype) in testcases] + return testcases + +_ELTWISE_OP_TESTS = _generate_testcases() + +def _dense_and_sparse_like(x, std=0.1): + dense_data = torch.randn_like(x.data, device=x.device) * std + sparse = stk.Matrix(x.size(), + dense_data, + x.row_indices, + x.column_indices, + x.offsets) + dense = stk.ops.to_dense(sparse) + + return (dense.requires_grad_(True), + sparse.requires_grad_(True)) + +@parameterized.parameters(_ELTWISE_OP_TESTS) +class EltwiseOpsTest(parameterized.TestCase): + + def testEltwiseMul(self, m, n, sparsity, blocking, dtype): + + a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype) + b_dense, b = 
_dense_and_sparse_like(a) + + out = stk.ops.mul(a, b) + expected_out = torch.mul(a_dense, b_dense) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + stk.ops.sum(out).backward() + + # Validate the results. + out = stk.ops.to_dense(out) + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size(), out.size()) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = stk.ops.to_dense(a.grad) + expected_grad = a_dense.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size(), grad.size()) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = stk.ops.to_dense(b.grad) + expected_grad = b_dense.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size(), grad.size()) + self.assertTrue(allclose(grad, expected_grad)) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops.py @@ -0,0 +1,59 @@ +import torch + +from ..backend import sputnik +from ..matrix import Matrix + + +def dsd(a, b): + assert isinstance(a, Matrix) + assert isinstance(b, torch.Tensor) + return sputnik.dsd( + a.size(), + a.data, a.offsets, + a.row_indices, + a.column_indices, + a.offsets_t, + a.column_indices_t, + a.block_offsets_t, + not a.is_contiguous(), + b) + + +def dds(a, b): + assert isinstance(a, torch.Tensor) + assert isinstance(b, Matrix) + return sputnik.dds( + a, + b.size(), + b.data, b.offsets, + b.row_indices, + b.column_indices, + b.offsets_t, + b.column_indices_t, + b.block_offsets_t, + not b.is_contiguous()) + + +def sdd(a, b, topo): + assert isinstance(a, torch.Tensor) + assert isinstance(b, torch.Tensor) + assert isinstance(topo, Matrix) + assert topo.is_contiguous() + out = sputnik.sdd( + a, b, + topo.size(), + topo.data, + topo.offsets, + topo.row_indices, + topo.column_indices, + topo.offsets_t, + topo.column_indices_t, + topo.block_offsets_t) + return Matrix(topo.size(), + out, + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops_test.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ced1d782fbc9f9ca16b3449239f1588dc5ff5e00 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops_test.py @@ -0,0 +1,216 @@ +import unittest +import itertools +import numpy as np +import torch +from absl.testing import parameterized + +import stk + + +def allclose(x, y, pct=0.25): + mask = torch.isclose(x, y, rtol=5e-2) + pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 + if pct_diff > pct: + print("{:.2f}% of values not close.".format(pct_diff)) + return False + return True + + +# An assortment of problems designed to make sure +# the bindings are operating correctly. 
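+# A minimal usage sketch of the dsd wrapper exercised below comes first; the
+# test problem sizes follow. The sketch is illustrative only: it assumes a CUDA
+# device with triton available, relies on this module's `import torch` and
+# `import stk`, and uses shapes that satisfy the block-size divisibility checks
+# in the Triton kernels. `_example_dsd` is a hypothetical helper, not part of
+# the test suite.
+def _example_dsd(blocking=128):
+    # 256 x 512 block-sparse operand with roughly half of its 128 x 128 blocks
+    # kept, stored as a BCSR stk.Matrix.
+    mask = stk.random.dense_mask(256, 512, 0.5, blocking)
+    lhs = stk.ops.to_sparse((torch.randn(256, 512) * mask).half(), blocking).cuda()
+    rhs = torch.randn(512, 128, dtype=torch.float16, device="cuda")
+    # Sparse x dense -> dense output of shape (256, 128).
+    return stk.ops.dsd(lhs, rhs)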
+_MATRIX_SIZES = ( + (128, 128, 128, 0.0), + (256, 256, 256, 0.5), + (2048, 1024, 512, 0.8), + (512, 128, 128, 0.0), + (128, 128, 512, 0.0), + (1024, 512, 512, 0.0), + (1024, 512, 512, 0.5), + (1024, 512, 512, 0.75), + (512, 512, 1024, 0.0), + (512, 512, 1024, 0.5), + (512, 512, 1024, 0.75), + (1024, 1024, 1024, 0.0), + (1024, 1024, 1024, 0.5), + (1024, 1024, 1024, 0.75), +) + +_TRANSPOSE = ( + (False, False), + (False, True), + (True, False), + (True, True), +) + +_DTYPE = ( + torch.float16, torch.bfloat16 +) + +def _generate_testcases(): + testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE) + testcases = [(*size, *trans, 128, dtype) for + (size, trans, dtype) in testcases] + return testcases + +_LINEAR_OP_TESTS = _generate_testcases() + +def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + dense = (torch.randn(rows, cols) * std * mask).type(dtype) + sparse = stk.ops.to_sparse(dense, blocking) + cuda_device = torch.device("cuda") + return (dense.to(cuda_device).requires_grad_(True), + sparse.to(cuda_device).requires_grad_(True)) + + +def _dense(rows, cols, dtype, std=0.1): + cuda_device = torch.device("cuda") + out = (torch.randn(rows, cols) * std).type(dtype) + return out.to(cuda_device).requires_grad_(True) + + +def _dense_2x(rows, cols, dtype): + a = _dense(rows, cols, dtype) + return a, a.detach().requires_grad_(True) + + +def _with_transpose(op, a, b, trans_a, trans_b): + a = a.t() if trans_a else a + b = b.t() if trans_b else b + return op(a, b) + + +def _mmm(a, b, topo): + mask = stk.ops.to_dense(stk.ops.ones_like(topo)) + return torch.mm(a, b) * mask + + +def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b): + a = a.t() if trans_a else a + b = b.t() if trans_b else b + return op(a, b, topo) + + +def _mask(x, mask): + mask = stk.ops.to_dense(stk.ops.ones_like(mask)) + return x * mask + + +@parameterized.parameters(*_LINEAR_OP_TESTS) +class LinearOpsTest(parameterized.TestCase): + + def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. + a_shape = (k, m) if trans_a else (m, k) + a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype) + b_shape = (n, k) if trans_b else (k, n) + b, bcp = _dense_2x(*b_shape, dtype) + + # Execute the matmul. + out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b) + expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + out.sum().backward() + + # Validate the results. + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = stk.ops.to_dense(a.grad) + expected_grad = _mask(a_dense.grad, a.grad) + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = b.grad + expected_grad = bcp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. 
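+ # Operand shapes are chosen so that the optional transposes yield an (m, k) x (k, n) product.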
+ a_shape = (k, m) if trans_a else (m, k) + a, acp = _dense_2x(*a_shape, dtype) + b_shape = (n, k) if trans_b else (k, n) + b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype) + + # Execute the matmul. + out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b) + expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + out.sum().backward() + + # Validate the results. + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = a.grad + expected_grad = acp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = stk.ops.to_dense(b.grad) + expected_grad = _mask(b_dense.grad, b.grad) + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. + a_shape = (k, m) if trans_a else (m, k) + a, acp = _dense_2x(*a_shape, dtype) + b_shape = (n, k) if trans_b else (k, n) + b, bcp = _dense_2x(*b_shape, dtype) + _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype) + + # Execute the matmul. + out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b) + expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + stk.ops.sum(out).backward() + + # Validate the results. + out = stk.ops.to_dense(out) + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = a.grad + expected_grad = acp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = b.grad + expected_grad = bcp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops.py @@ -0,0 +1,98 @@ +from ..backend import sputnik +from ..matrix import Matrix +import torch +import numpy as np + + +@torch.no_grad() +def row_indices(shape, data, offsets, column_indices): + return sputnik.row_indices(shape, data, offsets, column_indices) + + +# TODO(tgale): Replace this helper with a custom kernel. This operation +# is much simpler to do than how it's currently implemented. 
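+# For intuition (hypothetical values): with blocking=2, a single block
+# coordinate pair expands to the element coordinates of that block in
+# row-major order, e.g.
+#
+#   _expand_for_blocking(torch.tensor([[1, 0]]), blocking=2)
+#   -> [[2, 0], [2, 1], [3, 0], [3, 1]]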
+@torch.no_grad() +def _expand_for_blocking(idxs, blocking): + # Duplicate for block column dimension. + idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1) + + # Update the column indices. + idxs[:, :, 1] *= blocking + idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking]) + + # Duplicate for block row dimension. + idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2]) + idxs = idxs.repeat(1, blocking, 1, 1) + + # Update the row indices. + idxs[:, :, :, 0] *= blocking + idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1]) + idxs = torch.reshape(idxs, [-1, 2]) + return idxs + + +# TODO(tgale): Add input type checking. +@torch.no_grad() +def to_dense(x): + assert isinstance(x, Matrix) + + shape = (np.prod(x.shape[:-1]), x.shape[-1]) + row_idxs = x.row_indices.type(torch.int32) + col_idxs = x.column_indices.type(torch.int32) + indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking) + indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64) + + out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device) + out.scatter_(0, indices, x.data.flatten()) + return out.reshape(x.size()) + + +@torch.no_grad() +def _mask(x, blocking=1): + assert x.dim() == 2 + assert x.size()[0] % blocking == 0 + assert x.size()[1] % blocking == 0 + block_rows = x.size()[0] // blocking + block_cols = x.size()[1] // blocking + x = torch.reshape(x, [block_rows, blocking, block_cols, blocking]) + x = torch.sum(torch.abs(x), dim=(1, 3)) + return x != 0 + + +# TODO(tgale): Add input type checking. +@torch.no_grad() +def to_sparse(x, blocking=1): + m = _mask(x, blocking) + + # TODO(tgale): Set to appropriate type for input matrix. + row_nnzs = torch.sum(m, dim=1).type(torch.int32) + zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device) + offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)]) + offsets = offsets.type(torch.int32) + + indices = torch.nonzero(m).type(torch.int16) + row_indices = indices[:, 0] + column_indices = indices[:, 1] + + # Nonzero indices in the dense matrix. + nonzero_indices = torch.nonzero(m) + nonzero_indices = _expand_for_blocking(nonzero_indices, blocking) + nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1] + + # Gather the data and construct the sparse matrix. 
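+ # A single gather over the flattened input collects every nonzero element;
+ # reshaping to [num_blocks, blocking, blocking] yields the BCSR block data.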
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices) + data = torch.reshape(data, [-1, blocking, blocking]) + return Matrix(x.size(), data, row_indices, column_indices, offsets) + + +@torch.no_grad() +def ones_like(x): + return Matrix(x.size(), + torch.ones_like(x.data), + x.row_indices, + x.column_indices, x.offsets) + + +def sum(x): + assert isinstance(x, Matrix) + return x.data.sum() diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3af04c0760483e578f93303dc457415948a2a34c --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py @@ -0,0 +1,62 @@ +import unittest + +from absl.testing import parameterized +import stk +import torch + + +@parameterized.parameters( + (8, 16, 0.0, 1), + (8, 16, 0.5, 1), + (8, 16, .95, 1), + (16, 8, 0.0, 1), + (16, 8, 0.5, 1), + (16, 8, .95, 1), + (8, 16, 0.0, 8), + (8, 16, 0.5, 8), + (8, 16, 1.0, 8), + (16, 8, 0.0, 8), + (16, 8, 0.5, 8), + (16, 8, 1.0, 8), + (128, 256, 0.5, 16), + (256, 128, 0.75, 32), + (512, 512, .875, 128)) +class MatrixOpsTest(parameterized.TestCase): + + def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + x = (torch.randn(rows, cols) * mask).type(torch.float16) + + # Convert the matrix to sparse format. + sparse_x = stk.ops.to_sparse(x, blocking) + + # Validate the matrix. + sparse_x.validate() + + # Validate the shape. + self.assertEqual(sparse_x.dim(), 2) + self.assertEqual(sparse_x.size()[0], rows) + self.assertEqual(sparse_x.size()[1], cols) + + # Validate the sparsity. + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual(sparse_x.nnz, nnz) + + # Convert back to dense format. + dense_x = stk.ops.to_dense(sparse_x) + + # Validate the shape. + self.assertEqual(dense_x.dim(), 2) + self.assertEqual(dense_x.size()[0], rows) + self.assertEqual(dense_x.size()[1], cols) + + # Validate the sparsity + self.assertEqual(torch.count_nonzero(dense_x).item(), nnz) + + # Validate the output. 
+ self.assertTrue(torch.all(torch.eq(x, dense_x))) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/random/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/random/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/random/__init__.py @@ -0,0 +1,2 @@ +# from stk.random.random_ops import dense_mask, mask, randn +from .random_ops import dense_mask, mask, randn diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops.py @@ -0,0 +1,36 @@ +import numpy as np +import torch +from ..ops import matrix_ops + + +@torch.no_grad() +def dense_mask(rows, cols, sparsity, blocking=1): + assert sparsity >= 0.0 and sparsity <= 1.0 + assert rows % blocking == 0 and cols % blocking == 0 + + block_rows, block_cols = (rows // blocking, cols // blocking) + nnz = round(block_rows * block_cols * (1 - sparsity)) + + out = np.ones(block_rows * block_cols) + mask = np.random.choice(out.size, out.size - nnz, replace=False) + out[mask] = 0.0 + + out = np.tile( + np.reshape(out, [block_rows, 1, block_cols, 1]), + (1, blocking, 1, blocking)) + out = np.reshape(out, [rows, cols]) + return torch.from_numpy(out.astype(np.float32)) + + +@torch.no_grad() +def mask(m, n, sparsity, blocking=1): + out = dense_mask(m, n, sparsity, blocking).type(torch.float16) + return matrix_ops.to_sparse(out, blocking=blocking) + + +@torch.no_grad() +def randn(shape, sparsity, blocking=1): + shape_2d = (np.prod(shape[:-1]), shape[-1]) + out = mask(*shape_2d, sparsity, blocking) + out.data.copy_(torch.randn(*out.data.shape)) + return out.view(*shape) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops_test.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..587b44ec890c861879c6296b8f9028f5d99ab82f --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops_test.py @@ -0,0 +1,73 @@ +import unittest + +from absl.testing import parameterized +from . import random +import torch + + +@parameterized.parameters( + (8, 16, 0.0, 1), + (8, 16, 0.5, 1), + (8, 16, .95, 1), + (16, 8, 0.0, 1), + (16, 8, 0.5, 1), + (16, 8, .95, 1), + (8, 16, 0.0, 8), + (8, 16, 0.5, 8), + (8, 16, 1.0, 8), + (16, 8, 0.0, 8), + (16, 8, 0.5, 8), + (16, 8, 1.0, 8), + (128, 256, 0.5, 16), + (256, 128, 0.75, 32), + (512, 512, .875, 128)) +class RandomOpsTest(parameterized.TestCase): + + def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking): + mask = random.dense_mask( + rows, cols, sparsity, blocking) + + # Validate the shape. + self.assertEqual(mask.dim(), 2) + self.assertEqual(mask.size()[0], rows) + self.assertEqual(mask.size()[1], cols) + + # Validate the sparsity + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual( + torch.count_nonzero(mask).item(), + nnz) + + # Check values are zero or one. 
+ self.assertTrue( + torch.all(torch.logical_or( + torch.eq(mask, 0), + torch.eq(mask, 1)))) + + def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking): + mask = random.mask( + rows, cols, sparsity, blocking) + + # Validate the matrix. + mask.validate() + + # Validate the shape. + self.assertEqual(mask.dim(), 2) + self.assertEqual(mask.size()[0], rows) + self.assertEqual(mask.size()[1], cols) + + # Validate the sparsity. + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual(mask.nnz, nnz) + + # Check values are zero or one. + self.assertTrue( + torch.all(torch.logical_or( + torch.eq(mask.data, 0), + torch.eq(mask.data, 1)))) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/activation_fn.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/activation_fn.py index a31770ba179fed06abf2da10102ccaeed1d3ee4e..0e1d956704840aa4daf7d1d71d24e051567feab9 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/activation_fn.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/activation_fn.py @@ -4,7 +4,7 @@ from typing import Any, Callable, Union import torch -from stk import Matrix +from ..stk import Matrix def act_fn( diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/dmoe.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/dmoe.py index 561fd7982254513e9a9b4f0514edd326b0561256..6d0375a4df2f27134c4127e60be04f3b45693050 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/dmoe.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/dmoe.py @@ -2,15 +2,22 @@ # SPDX-License-Identifier: Apache-2.0 import numpy as np -import stk.ops import torch -from stk import Matrix + +# try: +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.', +# ) # import megablocks.ops as ops # # from megablocks.ops import ops # from megablocks.layers import common, dmlp_registry, moe, mpu # from megablocks.layers.arguments import Arguments +from .. import stk from .. import ops from . import common, dmlp_registry, moe, mpu from .arguments import Arguments diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/gelu.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/gelu.py index 40b601d4a8eb59e80e1090f31f4172b8f7fb7549..c4c9e6532798615b5c12c96694241a4c18ee8f7b 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/gelu.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/gelu.py @@ -1,7 +1,16 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -import stk +# try: +# import stk +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.', +# ) + +from .. 
import stk + import torch import torch.nn.functional as F diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/glu.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/glu.py index 69ef9ead5461a5b91ace107652fd8d1c3d300b1e..5f297a41ff6a1a2a285f5b461951672364b898da 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/glu.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/glu.py @@ -1,7 +1,17 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -import stk.ops +# import stk.ops +# try: +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.', +# ) + +from .. import stk + import torch # from megablocks import grouped_gemm_util as gg diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/mlp.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/mlp.py index 0b6747b1eec2047eb9480592ca550b6c9e868f5e..c99afb9904c24a8b6a83e79059cd1251dbbfd99e 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/mlp.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/mlp.py @@ -3,9 +3,18 @@ from typing import Any -import stk -import stk.backend.triton_kernels -import stk.ops +# try: +# import stk +# import stk.backend.triton_kernels +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.', +# ) + +from .. import stk + import torch from packaging import version diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so deleted file mode 100755 index 5380a6cbd20a15cdbae5b3d385ace2dc2e9da000..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:baacdb2bd8bcd004a86f63b0dc2754bac21214c9432bf6c00c464ccc26c25a83 -size 10510040 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_63599de.abi3.so b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_63599de.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..0ba04bb16f1d0f9bf0c242d2d3e0cd926b15c1da --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_63599de.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b2451173cb1d000c6d270b59b2aaab1aa0e54025422ba81b1ee990621c90a823 +size 10510040 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py index f7aa5a700bc5673c743a3b0fb74107fab793fbad..03e1e323ce98ecea8dda071c089f0d7ba75551b6 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_0586ba6 -ops = torch.ops._megablocks_0586ba6 +from . import _megablocks_63599de +ops = torch.ops._megablocks_63599de def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_megablocks_0586ba6::{op_name}" \ No newline at end of file + return f"_megablocks_63599de::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py index 9b2c7f062071758def741bf3af2392d3a4855f89..7ccc5dcec5e9a663794fad944c45285869c4d1c1 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -3,7 +3,19 @@ import unittest -import stk + +# import stk + +# try: +# import stk +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.', +# ) + +from .. import stk + import torch from absl.testing import parameterized diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/__init__.py @@ -0,0 +1,7 @@ +# import stk.random +# import stk.ops +# from stk.matrix import Matrix + +from . import random +from . import ops +from .matrix import Matrix diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/backend/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/backend/autocast.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/backend/autocast.py new file mode 100644 index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/backend/autocast.py @@ -0,0 +1,37 @@ +import functools +import torch + + +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, map): + return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, list) or isinstance(x, tuple): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def custom_fwd(fwd): + """Wrap a custom autograd function that always uses autocast dtype.""" + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + return decorate_fwd + + +def custom_bwd(bwd): + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + return decorate_bwd diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/backend/sputnik.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/backend/sputnik.py new file mode 100644 index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/backend/sputnik.py @@ -0,0 +1,316 @@ +import torch + +from ..backend import triton_kernels as backend +from ..backend.autocast 
import custom_bwd, custom_fwd + + +def _standardize_shape(x, transpose): + if transpose: + return torch.Size((x[1], x[0])) + return x + + +def _sparse_transpose(x): + return (torch.Size((x[0][1], x[0][0])), ) + x[1:] + + +def _transpose_helper(x, transpose): + if isinstance(x, torch.Tensor): + return x.t() if transpose else x + if transpose: + x = _sparse_transpose(x) + return x + (transpose,) + + +def _wrap(x): + if isinstance(x, torch.Tensor): + return (x,) + return x + + +def _is_transposed(x): + return (not x.is_contiguous() and + x.stride()[0] == 1 and + x.stride()[1] == x.size()[0]) + + +def _call_helper(op, out, a, b, trans_a, trans_b): + args = (_wrap(_transpose_helper(a, trans_a)) + + _wrap(_transpose_helper(b, trans_b))) + if isinstance(out, tuple): + args = args + out + return op(*args) + + +def _preprocess_inputs(lhs, rhs, dy): + if isinstance(lhs, torch.Tensor) and _is_transposed(lhs): + lhs = lhs.t() + if isinstance(rhs, torch.Tensor) and _is_transposed(rhs): + rhs = rhs.t() + if (isinstance(dy, torch.Tensor) and + not dy.is_contiguous() and + not _is_transposed(dy)): + dy = dy.contiguous() + if isinstance(dy, tuple) and not dy[1].is_contiguous(): + dy = (dy[0], dy[1].contiguous()) + dy[2:] + return lhs, rhs, dy + + +def _postprocess_outputs(x, transpose, grad): + if isinstance(x, torch.Tensor) and transpose: + return grad.t() + return grad + + +def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): + lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) + + a, b = (rhs, dy) if trans_lhs else (dy, rhs) + trans_a = trans_lhs and trans_rhs + trans_b = trans_lhs or not trans_rhs + out = _call_helper(op, lhs, a, b, trans_a, trans_b) + return _postprocess_outputs(lhs, trans_lhs, out) + + +def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): + lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) + + a, b = (dy, lhs) if trans_rhs else (lhs, dy) + trans_a = not trans_lhs or trans_rhs + trans_b = trans_lhs and trans_rhs + out = _call_helper(op, rhs, a, b, trans_a, trans_b) + return _postprocess_outputs(rhs, trans_rhs, out) + + +class DSD(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs): + ctx.save_for_backward(data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + rhs) + ctx.shape = _standardize_shape(shape, transpose_a) + ctx.transpose_a = transpose_a + + out = torch.empty( + (shape[0], rhs.size()[1]), + dtype=rhs.dtype, + device=rhs.device) + + backend.dsd(shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs, + out) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs = (ctx.shape,) + saved_tensors[:-1] + rhs = saved_tensors[-1] + trans_a = ctx.transpose_a + trans_b = _is_transposed(rhs) + + ddata = None + if ctx.needs_input_grad[1]: + ddata = _lhs_gradient(sdd, + lhs, + rhs, + dy, + trans_a, + trans_b) + drhs = None + if ctx.needs_input_grad[-1]: + op = dds if trans_b else dsd + drhs = _rhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + return None, ddata, None, None, None, None, None, None, None, drhs + + +dsd = DSD.apply + + +class DDS(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + 
block_offsets_t, + transpose_b): + ctx.save_for_backward(lhs, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t) + ctx.shape = _standardize_shape(shape, transpose_b) + ctx.transpose_b = transpose_b + out = torch.empty((lhs.size()[0], shape[1]), + dtype=lhs.dtype, + device=lhs.device) + backend.dds(lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b, + out) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs = saved_tensors[0] + rhs = (ctx.shape,) + saved_tensors[1:] + trans_a = _is_transposed(lhs) + trans_b = ctx.transpose_b + + dlhs = None + if ctx.needs_input_grad[0]: + op = dsd if trans_a else dds + dlhs = _lhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + ddata = None + if ctx.needs_input_grad[2]: + ddata = _rhs_gradient(sdd, + lhs, + rhs, + dy, + trans_a, + trans_b) + return dlhs, None, ddata, None, None, None, None, None, None, None + + +dds = DDS.apply + + +class SDD(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + lhs, + rhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t): + ctx.save_for_backward( + lhs, + rhs, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t) + ctx.shape = shape + out = torch.empty( + data.shape, + dtype=lhs.dtype, + device=lhs.device) + backend.sdd(lhs, + rhs, + shape, + out, + offsets, + row_indices, + column_indices) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs, rhs = saved_tensors[:2] + dy = (ctx.shape, dy) + saved_tensors[2:] + trans_a = _is_transposed(lhs) + trans_b = _is_transposed(rhs) + + dlhs = None + if ctx.needs_input_grad[0]: + op = dds if trans_a else dsd + dlhs = _lhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + drhs = None + if ctx.needs_input_grad[1]: + op = dsd if trans_b else dds + drhs = _rhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + return dlhs, drhs, None, None, None, None, None, None, None, None + + +sdd = SDD.apply + +class RowIndices(torch.autograd.Function): + + @staticmethod + def forward(ctx, shape, data, offsets, column_indices): + out = torch.empty( + column_indices.shape, + dtype=column_indices.dtype, + device=column_indices.device) + backend.row_indices(shape, data, offsets, column_indices, out) + return out + + +row_indices = RowIndices.apply diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/backend/triton_kernels.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/backend/triton_kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/backend/triton_kernels.py @@ -0,0 +1,393 @@ +import torch +import triton +import triton.language as tl +from dataclasses import dataclass + +@dataclass +class TritonConfig: + BLOCK_M: int = 128 + BLOCK_N: int = 128 + BLOCK_K: int = 32 + BLOCK_SIZE: int = 128 + NUM_STAGES: int = 4 + NUM_WARPS: int = 4 + +def _validate_matmul_dims(M: int, K: int, N: int): + error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}" + assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M) + assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, 
TritonConfig.BLOCK_K) + assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _sdd_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + # matrix multiplication + pid = tl.program_id(0) + pid_m = tl.load(row_indices + pid) + pid_n = tl.load(column_indices + pid) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(A) + b = tl.load(B) + acc += tl.dot(a, b) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + #Store to sparse matrix + acc = acc.to(C.dtype.element_ty) + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + cm = tl.arange(0, BLOCK_M) + cn = tl.arange(0, BLOCK_N) + C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _dsd_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, offsets, + block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + + # matrix multiplication + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) + + start_inx = tl.load(offsets + pid_m) + end_inx = tl.load(offsets + pid_m + 1) + + # pointers to sparse matrix + rm = tl.arange(0, BLOCK_M) + rak = tl.arange(0, BLOCK_K) + + A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) + + # pointers to dense matrix + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rbk = tl.arange(0, BLOCK_K) + B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) + + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) + + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + ak_sub_incr = BLOCK_K * stride_ak + bk_sub_incr = BLOCK_K * stride_bk + bk_block_incr = BLOCK_SIZE * stride_bk + + for k in range(nsub_blocks * (end_inx - start_inx)): + sub_block_inx = k % nsub_blocks + block_inx = k // nsub_blocks + 
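+        # Each iteration covers one BLOCK_K-wide slice of one nonzero block:
+        # `block_inx` picks the nonzero block within the current block row, and
+        # `sub_block_inx` walks across that block in steps of BLOCK_K.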
+ if trans_A: + ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr + else: + ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr + + ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr + + a = tl.load(ptr_A) + b = tl.load(ptr_B) + acc += tl.dot(a, b) + + acc = acc.to(C.dtype.element_ty) + + cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _dds_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, offsets, + block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + + # matrix multiplication + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) + + start_inx = tl.load(offsets + pid_n) + end_inx = tl.load(offsets + pid_n + 1) + + # pointers to dense matrix + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rak = tl.arange(0, BLOCK_K) + + A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) + + # pointers to sparse matrix + rn = tl.arange(0, BLOCK_N) + rbk = tl.arange(0, BLOCK_K) + B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) + + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) + + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + + ak_sub_incr = BLOCK_K * stride_ak + ak_block_incr = BLOCK_SIZE * stride_ak + bk_sub_incr = BLOCK_K * stride_bk + + for k in range(nsub_blocks * (end_inx - start_inx)): + sub_block_inx = k % nsub_blocks + block_inx = k // nsub_blocks + + if trans_B: + ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr + else: + ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr + + ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr + a = tl.load(ptr_A) + b = tl.load(ptr_B) + acc += tl.dot(a, b) + + acc = acc.to(C.dtype.element_ty) + cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +def dsd(shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs, + out + ): + + device = rhs.device + trans_A = transpose_a + trans_B = False + + if rhs.stride(0) > 1 and rhs.stride(1) > 1: + trans_B = True + + # checks constraints + assert shape[1] == rhs.shape[0], "incompatible dimensions" + M, K = shape + _, N = rhs.shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else 
tl.int32 + + stride_am, stride_ak = data.stride(1), data.stride(2) + stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) + a_column_indices = column_indices + a_offsets = offsets + + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) + + if trans_A: + stride_am, stride_ak = data.stride(2), data.stride(1) + a_column_indices, a_offsets = column_indices_t, offsets_t + + if trans_B: + stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) + + _dsd_kernel[grid]( + data.data, rhs, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(0), out.stride(1), + row_indices, a_column_indices, a_offsets, + block_offsets_t, trans_A, trans_B, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + # return out + +def dds(lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b, + out + ): + + device = lhs.device + trans_B = transpose_b + trans_A = False + + if lhs.stride(0) > 1 and lhs.stride(1) > 1: + trans_A = True + + # checks constraints + assert lhs.shape[1] == shape[0], "incompatible dimensions" + M, K = lhs.shape + _, N = shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + stride_am, stride_ak = lhs.stride(0), lhs.stride(1) + stride_bk, stride_bn = data.stride(1), data.stride(2) + b_column_indices = column_indices_t + b_offsets = offsets_t + + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) + + if trans_A: + stride_am, stride_ak = lhs.stride(1), lhs.stride(0) + if trans_B: + stride_bk, stride_bn = data.stride(2), data.stride(1) + b_column_indices, b_offsets = column_indices, offsets + + _dds_kernel[grid]( + lhs, data, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(0), out.stride(1), + row_indices, b_column_indices, b_offsets, + block_offsets_t, trans_A, trans_B, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + +def sdd(lhs, + rhs, + shape, + out, + offsets, + row_indices, + column_indices + ): + + device = out.device + trans_A = False + trans_B = False + + if lhs.stride(0) > 1 and lhs.stride(1) > 1: + trans_A = True + if rhs.stride(0) > 1 and rhs.stride(1) > 1: + trans_B = True + + # checks constraints + assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions" + M, K = lhs.shape + _, N = rhs.shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + # launch kernel + nnz_blocks = len(row_indices) + grid = lambda META: (nnz_blocks,) + + stride_am, stride_ak = lhs.stride(0), lhs.stride(1) + stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) + + if trans_A: + stride_am, stride_ak = lhs.stride(1), lhs.stride(0) + if trans_B: + stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) + + _sdd_kernel[grid]( + lhs, rhs, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(1), out.stride(2), + row_indices, column_indices, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + +@triton.jit +def _row_indices_kernel(offsets, out): + pid = tl.program_id(0) + row_offset = tl.load(offsets + pid) + nnz_blocks = tl.load(offsets + pid + 1) - row_offset + for nnz_block in range(nnz_blocks): + tl.store(out + row_offset + nnz_block, pid) + +def row_indices( + shape, data, offsets, column_indices, out +): + block_rows = len(offsets) - 1 + _row_indices_kernel[(block_rows, )](offsets, 
out) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/matrix.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/matrix.py @@ -0,0 +1,329 @@ +import numpy as np +import torch + +# 1. Add heavyweight (data) validation helper. +# 2. Add construction helpers +# 3. Make indentation consistent +# 4. Replace asserts with descriptive errors. + +## +### Validation helpers. +## + + +def _validate_matrix(shape, data, row_indices, column_indices, offsets): + # Data should be [nnz, block_size, block_size] + if data.dim() == 1: + data = torch.reshape(data, [data.numel(), 1, 1]) + + # Blocks should be square. + if data.shape[-2] != data.shape[-1]: + raise ValueError( + "Expected square blocking in data. " + f"Got block shape {[data.shape[-2], data.shape[-1]]}") + + # Flatten batch dimensions on data - original shape preserved + # in shape argument. + block_size = data.shape[-1] + data = data.view([-1, block_size, block_size]) + + if data.dim() != 3: + raise ValueError( + "Expected 3D shape for data (nnz, block, block). " + f"Got shape {data.dim()}D shape.") + + block_size = data.shape[1] + if shape[-2] % block_size != 0 or shape[-1] % block_size != 0: + raise ValueError( + "Matrix shape must be dividible by blocking. " + f"Got shape {shape} with " + f"{[block_size, block_size]} blocking.") + + if np.prod(shape) < data.numel(): + raise ValueError( + "Invalid matrix. Number of nonzeros exceeds matrix capacity " + f"({data.numel()} v. {np.prod(shape)})") + + if row_indices.dim() != 1: + raise ValueError( + f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.") + + if column_indices.dim() != 1: + raise ValueError( + f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.") + + if offsets.dim() != 1: + raise ValueError( + f"Expected 1D offsets. Got {offsets.dim()}D offsets.") + + if row_indices.numel() != data.shape[0]: + raise ValueError( + "Expected 1 index per nonzero block. " + f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks") + + if column_indices.numel() != data.shape[0]: + raise ValueError( + "Expected 1 index per nonzero block. " + f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks") + + block_rows = np.prod(shape[:-1]) / block_size + if offsets.numel() != block_rows + 1: + raise ValueError( + "Expected one offset per block row plus one. " + f"Got {offsets.numel()} offsets with {block_rows} block rows.") + + is_cuda = (data.is_cuda and + row_indices.is_cuda and + column_indices.is_cuda and + offsets.is_cuda) + is_cpu = (not data.is_cuda and + not row_indices.is_cuda and + not column_indices.is_cuda and + not offsets.is_cuda) + if not (is_cuda or is_cpu): + raise ValueError( + "Expected data & meta-data on common device. " + f"Got data on {data.device}, row_indices on {row_indices.device} " + f"column_indices on {column_indices.device} and " + f"offsets on {offsets.device}.") + + if data.dtype != torch.float16: + raise ValueError( + f"Expected float16 data. Got {data.dtype} data.") + if row_indices.dtype != torch.int16: + raise ValueError( + f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.") + if column_indices.dtype != torch.int16: + raise ValueError( + f"Expected int16 column_indices. 
Got {column_indices.dtype} column_indices.") + if offsets.dtype != torch.int32: + raise ValueError( + f"Expected int32 offsets. Got {offsets.dtype} offsets.") + return data + + +def _transpose(size, data, row_indices, column_indices, offsets): + block_columns = size[1] // data.shape[1] + + # Sort row indices by column indices to get the transposed matrix's + # column indices. + gather_indices = column_indices.argsort() + column_indices_t = row_indices.gather(0, gather_indices) + block_offsets_t = gather_indices.int() + + # NOTE: Histogram is not implemented for any integer type on CPU. Do + # the histogram in 32-bit float, which can exactly represent 16-bit + # integers. + column_indices_float = column_indices.float() + + zero = torch.zeros((1,), dtype=torch.int32, device=data.device) + nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns) + nnz_per_column = nnz_per_column.int() + offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)]) + return column_indices_t, offsets_t, block_offsets_t + + +class Matrix(torch.nn.Module): + """A matrix stored in sparse format. + + Underlying format is block compressed sparse row (BCSR). + + TODO(tgale): Make this mirror torch.Tensor API as much as possible. + """ + + def __init__(self, + size, + data, + row_indices, + column_indices, + offsets, + column_indices_t=None, + offsets_t=None, + block_offsets_t=None): + super().__init__() + self._size = size + self._data = data + self._row_indices = row_indices + self._column_indices = column_indices + self._offsets = offsets + + # Produce the transpose meta-data if it is not passed in. + if ((column_indices_t is None) or (offsets_t is None) or + (block_offsets_t is None)): + column_indices_t, offsets_t, block_offsets_t = _transpose( + size, data, row_indices, column_indices, offsets) + self._column_indices_t = column_indices_t + self._offsets_t = offsets_t + self._block_offsets_t = block_offsets_t + + self._transposed = False + + # Validate that our metadata will not overflow. + max_dim = np.iinfo(np.int16).max * self.blocking + if column_indices.dtype == torch.int16: + if size[0] > max_dim or size[1] > max_dim: + raise ValueError( + "Sparse matrix with shape {size} exceeds representable " + "size with 16-bit indices.") + + def validate(self): + _validate_matrix(self._size, + self._data, + self._row_indices, + self._column_indices, + self._offsets) + + # TODO(tgale): Add heavyweight data validation. + + def to(self, device): + # TODO(tgale): Handle type conversions here. We + # need to set the appropriate meta-data type for + # the given floating-point type. 
+ self._data = self._data.to(device) + self._row_indices = self._row_indices.to(device) + self._column_indices = self._column_indices.to(device) + self._offsets = self._offsets.to(device) + self._column_indices_t = self._column_indices_t.to(device) + self._offsets_t = self._offsets_t.to(device) + self._block_offsets_t = self._block_offsets_t.to(device) + return self + + def cuda(self): + return self.to(torch.cuda.current_device()) + + def clone(self): + return Matrix( + self.size(), + self.data.clone(), + self.row_indices.clone(), + self.column_indices.clone(), + self.offsets.clone(), + self.column_indices_t.clone(), + self.offsets_t.clone(), + self.block_offsets_t.clone()) + + def t(self): + if self.dim() != 2: + raise ValueError( + "t() expects a tensor with <= 2 dimensions, " + f"but self is {self.dim()}D.") + out = Matrix(self.size(), + self.data, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + out._transposed = not self._transposed + out._size = torch.Size((self._size[1], self._size[0])) + return out + + def contiguous(self): + raise ValueError("Not yet implemented.") + + def is_contiguous(self): + return not self._transposed + + @property + def is_cuda(self): + return self._data.is_cuda + + @property + def device(self): + return self._data.device + + def size(self): + return self._size + + @property + def shape(self): + return self.size() + + def dim(self): + return len(self._size) + + @property + def data(self): + return self._data + + @property + def row_indices(self): + return self._row_indices + + @property + def column_indices(self): + return self._column_indices + + @property + def offsets(self): + return self._offsets + + @property + def offsets_t(self): + return self._offsets_t + + @property + def column_indices_t(self): + return self._column_indices_t + + @property + def block_offsets_t(self): + return self._block_offsets_t + + @property + def dtype(self): + return self.data.dtype + + @property + def nnz(self): + return self.data.numel() + + @property + def blocking(self): + return self.data.shape[1] + + @property + def requires_grad(self): + return self.data.requires_grad + + def requires_grad_(self, x): + self.data.requires_grad_(x) + return self + + def view(self, *shape): + assert self.is_contiguous() + if shape[-1] != self.size()[-1]: + raise ValueError( + "Can't change view on compressed dimension. " + f"{self.size()[-1]} v. {shape[-1]}.") + if np.prod(shape) != np.prod(self.size()): + raise ValueError( + "Mismatch in numel of Matrix and new shape. " + f"{np.prod(self.size())} v. {np.prod(shape)}") + return Matrix(shape, + self.data, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + + @property + def grad(self): + # TODO(tgale): Make sure this mirrors torch.Tensor + # behavior in the case where we ask for the gradient + # of a non-contiguous tensor. 
+ size = self.size() + if not self.is_contiguous(): + size = torch.Size((size[1], size[0])) + out = Matrix(size, + self.data.grad, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + return out if self.is_contiguous() else out.t() diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/__init__.py @@ -0,0 +1,3 @@ +from .linear_ops import dds, dsd, sdd +from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse +from .eltwise_ops import mul diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops.py @@ -0,0 +1,28 @@ +from ..matrix import Matrix + +def mul(a, b): + """Performs element-wise multiplication of matrices a and b. + + It is the user's responsibility to make sure that a and b + follow the same matrix topology. This function assumes it is safe + to use the topoplogy of a. + + Args: + a: stk.Matrix. + b: stk.Matrix with a's matrix topology. + + Returns: + stk.Matrix where the entries correspond to torch.mul(a, b). + """ + assert isinstance(a, Matrix) + assert isinstance(b, Matrix) + assert a.size() == b.size() + + return Matrix(a.size(), + a.data * b.data, + a.row_indices, + a.column_indices, + a.offsets, + a.column_indices_t, + a.offsets_t, + a.block_offsets_t) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..66bfd4f6af77042d3c5bdb1fe18d00e457478d46 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py @@ -0,0 +1,86 @@ +import unittest +import itertools +import torch +from absl.testing import parameterized + +import stk +from stk.ops.linear_ops_test import allclose, _dense_and_sparse + +_MATRIX_SIZES = ( + (128, 128, 0.0), + (256, 256, 0.5), + (2048, 1024, 0.8), + (512, 128, 0.0), + (128, 512, 0.0), + (1024, 512, 0.0), + (1024, 512, 0.5), + (1024, 512, 0.75), + (512, 1024, 0.0), + (512, 1024, 0.5), + (512, 1024, 0.75), + (1024, 1024, 0.0), + (1024, 1024, 0.5), + (1024, 1024, 0.75), +) + +_DTYPE = ( + torch.float16, torch.bfloat16 +) + +def _generate_testcases(): + testcases = itertools.product(_MATRIX_SIZES, _DTYPE) + testcases = [(*size, 128, dtype) for + (size, dtype) in testcases] + return testcases + +_ELTWISE_OP_TESTS = _generate_testcases() + +def _dense_and_sparse_like(x, std=0.1): + dense_data = torch.randn_like(x.data, device=x.device) * std + sparse = stk.Matrix(x.size(), + dense_data, + x.row_indices, + x.column_indices, + x.offsets) + dense = stk.ops.to_dense(sparse) + + return (dense.requires_grad_(True), + sparse.requires_grad_(True)) + +@parameterized.parameters(_ELTWISE_OP_TESTS) +class EltwiseOpsTest(parameterized.TestCase): + + def testEltwiseMul(self, m, n, sparsity, blocking, dtype): + + a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype) + b_dense, b = 
_dense_and_sparse_like(a) + + out = stk.ops.mul(a, b) + expected_out = torch.mul(a_dense, b_dense) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + stk.ops.sum(out).backward() + + # Validate the results. + out = stk.ops.to_dense(out) + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size(), out.size()) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = stk.ops.to_dense(a.grad) + expected_grad = a_dense.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size(), grad.size()) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = stk.ops.to_dense(b.grad) + expected_grad = b_dense.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size(), grad.size()) + self.assertTrue(allclose(grad, expected_grad)) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/linear_ops.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/linear_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/linear_ops.py @@ -0,0 +1,59 @@ +import torch + +from ..backend import sputnik +from ..matrix import Matrix + + +def dsd(a, b): + assert isinstance(a, Matrix) + assert isinstance(b, torch.Tensor) + return sputnik.dsd( + a.size(), + a.data, a.offsets, + a.row_indices, + a.column_indices, + a.offsets_t, + a.column_indices_t, + a.block_offsets_t, + not a.is_contiguous(), + b) + + +def dds(a, b): + assert isinstance(a, torch.Tensor) + assert isinstance(b, Matrix) + return sputnik.dds( + a, + b.size(), + b.data, b.offsets, + b.row_indices, + b.column_indices, + b.offsets_t, + b.column_indices_t, + b.block_offsets_t, + not b.is_contiguous()) + + +def sdd(a, b, topo): + assert isinstance(a, torch.Tensor) + assert isinstance(b, torch.Tensor) + assert isinstance(topo, Matrix) + assert topo.is_contiguous() + out = sputnik.sdd( + a, b, + topo.size(), + topo.data, + topo.offsets, + topo.row_indices, + topo.column_indices, + topo.offsets_t, + topo.column_indices_t, + topo.block_offsets_t) + return Matrix(topo.size(), + out, + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/linear_ops_test.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/linear_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ced1d782fbc9f9ca16b3449239f1588dc5ff5e00 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/linear_ops_test.py @@ -0,0 +1,216 @@ +import unittest +import itertools +import numpy as np +import torch +from absl.testing import parameterized + +import stk + + +def allclose(x, y, pct=0.25): + mask = torch.isclose(x, y, rtol=5e-2) + pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 + if pct_diff > pct: + print("{:.2f}% of values not close.".format(pct_diff)) + return False + return True + + +# An assortment of problems designed to make sure +# the bindings are operating correctly. 
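For orientation, the pattern these tests exercise is: build a dense operand together with its blocked-sparse counterpart, run the stk op, and compare against a plain (masked) torch.mm. A minimal sketch of that flow, not part of the upstream test file, assuming a CUDA device, Triton, and that the standalone `stk` package is importable as these tests assume:

    import torch
    import stk

    # Paired dense/sparse operands: 512x512 with 128x128 blocks, half the blocks zeroed.
    mask = stk.random.dense_mask(512, 512, sparsity=0.5, blocking=128)
    lhs_dense = (torch.randn(512, 512) * 0.1 * mask).half()
    lhs_sparse = stk.ops.to_sparse(lhs_dense, blocking=128)
    lhs_dense, lhs_sparse = lhs_dense.cuda(), lhs_sparse.cuda()
    rhs = torch.randn(512, 256, dtype=torch.float16, device="cuda")

    # Sparse x dense -> dense, checked against the dense reference.
    out = stk.ops.dsd(lhs_sparse, rhs)
    ref = torch.mm(lhs_dense, rhs)
    print("max abs err:", (out - ref).abs().max().item())

The tests below follow the same recipe for dsd, dds, and sdd, additionally sweeping transposes, dtypes, and sparsities.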
+_MATRIX_SIZES = ( + (128, 128, 128, 0.0), + (256, 256, 256, 0.5), + (2048, 1024, 512, 0.8), + (512, 128, 128, 0.0), + (128, 128, 512, 0.0), + (1024, 512, 512, 0.0), + (1024, 512, 512, 0.5), + (1024, 512, 512, 0.75), + (512, 512, 1024, 0.0), + (512, 512, 1024, 0.5), + (512, 512, 1024, 0.75), + (1024, 1024, 1024, 0.0), + (1024, 1024, 1024, 0.5), + (1024, 1024, 1024, 0.75), +) + +_TRANSPOSE = ( + (False, False), + (False, True), + (True, False), + (True, True), +) + +_DTYPE = ( + torch.float16, torch.bfloat16 +) + +def _generate_testcases(): + testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE) + testcases = [(*size, *trans, 128, dtype) for + (size, trans, dtype) in testcases] + return testcases + +_LINEAR_OP_TESTS = _generate_testcases() + +def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + dense = (torch.randn(rows, cols) * std * mask).type(dtype) + sparse = stk.ops.to_sparse(dense, blocking) + cuda_device = torch.device("cuda") + return (dense.to(cuda_device).requires_grad_(True), + sparse.to(cuda_device).requires_grad_(True)) + + +def _dense(rows, cols, dtype, std=0.1): + cuda_device = torch.device("cuda") + out = (torch.randn(rows, cols) * std).type(dtype) + return out.to(cuda_device).requires_grad_(True) + + +def _dense_2x(rows, cols, dtype): + a = _dense(rows, cols, dtype) + return a, a.detach().requires_grad_(True) + + +def _with_transpose(op, a, b, trans_a, trans_b): + a = a.t() if trans_a else a + b = b.t() if trans_b else b + return op(a, b) + + +def _mmm(a, b, topo): + mask = stk.ops.to_dense(stk.ops.ones_like(topo)) + return torch.mm(a, b) * mask + + +def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b): + a = a.t() if trans_a else a + b = b.t() if trans_b else b + return op(a, b, topo) + + +def _mask(x, mask): + mask = stk.ops.to_dense(stk.ops.ones_like(mask)) + return x * mask + + +@parameterized.parameters(*_LINEAR_OP_TESTS) +class LinearOpsTest(parameterized.TestCase): + + def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. + a_shape = (k, m) if trans_a else (m, k) + a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype) + b_shape = (n, k) if trans_b else (k, n) + b, bcp = _dense_2x(*b_shape, dtype) + + # Execute the matmul. + out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b) + expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + out.sum().backward() + + # Validate the results. + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = stk.ops.to_dense(a.grad) + expected_grad = _mask(a_dense.grad, a.grad) + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = b.grad + expected_grad = bcp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. 
+ a_shape = (k, m) if trans_a else (m, k) + a, acp = _dense_2x(*a_shape, dtype) + b_shape = (n, k) if trans_b else (k, n) + b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype) + + # Execute the matmul. + out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b) + expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + out.sum().backward() + + # Validate the results. + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = a.grad + expected_grad = acp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = stk.ops.to_dense(b.grad) + expected_grad = _mask(b_dense.grad, b.grad) + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. + a_shape = (k, m) if trans_a else (m, k) + a, acp = _dense_2x(*a_shape, dtype) + b_shape = (n, k) if trans_b else (k, n) + b, bcp = _dense_2x(*b_shape, dtype) + _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype) + + # Execute the matmul. + out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b) + expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + stk.ops.sum(out).backward() + + # Validate the results. + out = stk.ops.to_dense(out) + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = a.grad + expected_grad = acp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = b.grad + expected_grad = bcp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops.py @@ -0,0 +1,98 @@ +from ..backend import sputnik +from ..matrix import Matrix +import torch +import numpy as np + + +@torch.no_grad() +def row_indices(shape, data, offsets, column_indices): + return sputnik.row_indices(shape, data, offsets, column_indices) + + +# TODO(tgale): Replace this helper with a custom kernel. This operation +# is much simpler to do than how it's currently implemented. 
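The conversion helpers below (`to_sparse`, `to_dense`, `ones_like`, `sum`) are the main entry points for moving between dense tensors and the BCSR `Matrix`. A minimal round-trip sketch, not part of the upstream module, mirroring the format-conversion test and assuming the package is importable as plain `stk`:

    import torch
    import stk

    # Dense fp16 matrix with an 8x8 block pattern; half of the blocks are zero.
    mask = stk.random.dense_mask(16, 32, sparsity=0.5, blocking=8)
    x = (torch.randn(16, 32) * mask).half()

    # Dense -> BCSR and back; the nonzero blocks survive the round trip exactly.
    sp = stk.ops.to_sparse(x, blocking=8)
    sp.validate()
    dense = stk.ops.to_dense(sp)
    assert torch.equal(x, dense)

    # ones_like keeps the sparsity pattern but fills the stored blocks with ones;
    # sum reduces only the stored (nonzero-block) entries.
    topo = stk.ops.ones_like(sp)
    total = stk.ops.sum(sp)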
+@torch.no_grad() +def _expand_for_blocking(idxs, blocking): + # Duplicate for block column dimension. + idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1) + + # Update the column indices. + idxs[:, :, 1] *= blocking + idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking]) + + # Duplicate for block row dimension. + idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2]) + idxs = idxs.repeat(1, blocking, 1, 1) + + # Update the row indices. + idxs[:, :, :, 0] *= blocking + idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1]) + idxs = torch.reshape(idxs, [-1, 2]) + return idxs + + +# TODO(tgale): Add input type checking. +@torch.no_grad() +def to_dense(x): + assert isinstance(x, Matrix) + + shape = (np.prod(x.shape[:-1]), x.shape[-1]) + row_idxs = x.row_indices.type(torch.int32) + col_idxs = x.column_indices.type(torch.int32) + indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking) + indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64) + + out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device) + out.scatter_(0, indices, x.data.flatten()) + return out.reshape(x.size()) + + +@torch.no_grad() +def _mask(x, blocking=1): + assert x.dim() == 2 + assert x.size()[0] % blocking == 0 + assert x.size()[1] % blocking == 0 + block_rows = x.size()[0] // blocking + block_cols = x.size()[1] // blocking + x = torch.reshape(x, [block_rows, blocking, block_cols, blocking]) + x = torch.sum(torch.abs(x), dim=(1, 3)) + return x != 0 + + +# TODO(tgale): Add input type checking. +@torch.no_grad() +def to_sparse(x, blocking=1): + m = _mask(x, blocking) + + # TODO(tgale): Set to appropriate type for input matrix. + row_nnzs = torch.sum(m, dim=1).type(torch.int32) + zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device) + offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)]) + offsets = offsets.type(torch.int32) + + indices = torch.nonzero(m).type(torch.int16) + row_indices = indices[:, 0] + column_indices = indices[:, 1] + + # Nonzero indices in the dense matrix. + nonzero_indices = torch.nonzero(m) + nonzero_indices = _expand_for_blocking(nonzero_indices, blocking) + nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1] + + # Gather the data and construct the sparse matrix. 
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices) + data = torch.reshape(data, [-1, blocking, blocking]) + return Matrix(x.size(), data, row_indices, column_indices, offsets) + + +@torch.no_grad() +def ones_like(x): + return Matrix(x.size(), + torch.ones_like(x.data), + x.row_indices, + x.column_indices, x.offsets) + + +def sum(x): + assert isinstance(x, Matrix) + return x.data.sum() diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3af04c0760483e578f93303dc457415948a2a34c --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py @@ -0,0 +1,62 @@ +import unittest + +from absl.testing import parameterized +import stk +import torch + + +@parameterized.parameters( + (8, 16, 0.0, 1), + (8, 16, 0.5, 1), + (8, 16, .95, 1), + (16, 8, 0.0, 1), + (16, 8, 0.5, 1), + (16, 8, .95, 1), + (8, 16, 0.0, 8), + (8, 16, 0.5, 8), + (8, 16, 1.0, 8), + (16, 8, 0.0, 8), + (16, 8, 0.5, 8), + (16, 8, 1.0, 8), + (128, 256, 0.5, 16), + (256, 128, 0.75, 32), + (512, 512, .875, 128)) +class MatrixOpsTest(parameterized.TestCase): + + def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + x = (torch.randn(rows, cols) * mask).type(torch.float16) + + # Convert the matrix to sparse format. + sparse_x = stk.ops.to_sparse(x, blocking) + + # Validate the matrix. + sparse_x.validate() + + # Validate the shape. + self.assertEqual(sparse_x.dim(), 2) + self.assertEqual(sparse_x.size()[0], rows) + self.assertEqual(sparse_x.size()[1], cols) + + # Validate the sparsity. + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual(sparse_x.nnz, nnz) + + # Convert back to dense format. + dense_x = stk.ops.to_dense(sparse_x) + + # Validate the shape. + self.assertEqual(dense_x.dim(), 2) + self.assertEqual(dense_x.size()[0], rows) + self.assertEqual(dense_x.size()[1], cols) + + # Validate the sparsity + self.assertEqual(torch.count_nonzero(dense_x).item(), nnz) + + # Validate the output. 
+ self.assertTrue(torch.all(torch.eq(x, dense_x))) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/random/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/random/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/random/__init__.py @@ -0,0 +1,2 @@ +# from stk.random.random_ops import dense_mask, mask, randn +from .random_ops import dense_mask, mask, randn diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/random/random_ops.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/random/random_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/random/random_ops.py @@ -0,0 +1,36 @@ +import numpy as np +import torch +from ..ops import matrix_ops + + +@torch.no_grad() +def dense_mask(rows, cols, sparsity, blocking=1): + assert sparsity >= 0.0 and sparsity <= 1.0 + assert rows % blocking == 0 and cols % blocking == 0 + + block_rows, block_cols = (rows // blocking, cols // blocking) + nnz = round(block_rows * block_cols * (1 - sparsity)) + + out = np.ones(block_rows * block_cols) + mask = np.random.choice(out.size, out.size - nnz, replace=False) + out[mask] = 0.0 + + out = np.tile( + np.reshape(out, [block_rows, 1, block_cols, 1]), + (1, blocking, 1, blocking)) + out = np.reshape(out, [rows, cols]) + return torch.from_numpy(out.astype(np.float32)) + + +@torch.no_grad() +def mask(m, n, sparsity, blocking=1): + out = dense_mask(m, n, sparsity, blocking).type(torch.float16) + return matrix_ops.to_sparse(out, blocking=blocking) + + +@torch.no_grad() +def randn(shape, sparsity, blocking=1): + shape_2d = (np.prod(shape[:-1]), shape[-1]) + out = mask(*shape_2d, sparsity, blocking) + out.data.copy_(torch.randn(*out.data.shape)) + return out.view(*shape) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/random/random_ops_test.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/random/random_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..587b44ec890c861879c6296b8f9028f5d99ab82f --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/random/random_ops_test.py @@ -0,0 +1,73 @@ +import unittest + +from absl.testing import parameterized +from . import random +import torch + + +@parameterized.parameters( + (8, 16, 0.0, 1), + (8, 16, 0.5, 1), + (8, 16, .95, 1), + (16, 8, 0.0, 1), + (16, 8, 0.5, 1), + (16, 8, .95, 1), + (8, 16, 0.0, 8), + (8, 16, 0.5, 8), + (8, 16, 1.0, 8), + (16, 8, 0.0, 8), + (16, 8, 0.5, 8), + (16, 8, 1.0, 8), + (128, 256, 0.5, 16), + (256, 128, 0.75, 32), + (512, 512, .875, 128)) +class RandomOpsTest(parameterized.TestCase): + + def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking): + mask = random.dense_mask( + rows, cols, sparsity, blocking) + + # Validate the shape. + self.assertEqual(mask.dim(), 2) + self.assertEqual(mask.size()[0], rows) + self.assertEqual(mask.size()[1], cols) + + # Validate the sparsity + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual( + torch.count_nonzero(mask).item(), + nnz) + + # Check values are zero or one. 
+ self.assertTrue( + torch.all(torch.logical_or( + torch.eq(mask, 0), + torch.eq(mask, 1)))) + + def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking): + mask = random.mask( + rows, cols, sparsity, blocking) + + # Validate the matrix. + mask.validate() + + # Validate the shape. + self.assertEqual(mask.dim(), 2) + self.assertEqual(mask.size()[0], rows) + self.assertEqual(mask.size()[1], cols) + + # Validate the sparsity. + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual(mask.nnz, nnz) + + # Check values are zero or one. + self.assertTrue( + torch.all(torch.logical_or( + torch.eq(mask.data, 0), + torch.eq(mask.data, 1)))) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/activation_fn.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/activation_fn.py index a31770ba179fed06abf2da10102ccaeed1d3ee4e..0e1d956704840aa4daf7d1d71d24e051567feab9 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/activation_fn.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/activation_fn.py @@ -4,7 +4,7 @@ from typing import Any, Callable, Union import torch -from stk import Matrix +from ..stk import Matrix def act_fn( diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/dmoe.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/dmoe.py index 561fd7982254513e9a9b4f0514edd326b0561256..6d0375a4df2f27134c4127e60be04f3b45693050 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/dmoe.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/dmoe.py @@ -2,15 +2,22 @@ # SPDX-License-Identifier: Apache-2.0 import numpy as np -import stk.ops import torch -from stk import Matrix + +# try: +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.', +# ) # import megablocks.ops as ops # # from megablocks.ops import ops # from megablocks.layers import common, dmlp_registry, moe, mpu # from megablocks.layers.arguments import Arguments +from .. import stk from .. import ops from . import common, dmlp_registry, moe, mpu from .arguments import Arguments diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/gelu.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/gelu.py index 40b601d4a8eb59e80e1090f31f4172b8f7fb7549..c4c9e6532798615b5c12c96694241a4c18ee8f7b 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/gelu.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/gelu.py @@ -1,7 +1,16 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -import stk +# try: +# import stk +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.', +# ) + +from .. 
import stk + import torch import torch.nn.functional as F diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/glu.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/glu.py index 69ef9ead5461a5b91ace107652fd8d1c3d300b1e..5f297a41ff6a1a2a285f5b461951672364b898da 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/glu.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/glu.py @@ -1,7 +1,17 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -import stk.ops +# import stk.ops +# try: +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.', +# ) + +from .. import stk + import torch # from megablocks import grouped_gemm_util as gg diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/mlp.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/mlp.py index 0b6747b1eec2047eb9480592ca550b6c9e868f5e..c99afb9904c24a8b6a83e79059cd1251dbbfd99e 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/mlp.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/mlp.py @@ -3,9 +3,18 @@ from typing import Any -import stk -import stk.backend.triton_kernels -import stk.ops +# try: +# import stk +# import stk.backend.triton_kernels +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.', +# ) + +from .. import stk + import torch from packaging import version diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so deleted file mode 100755 index b7fe786b99bfb5b162e6633b91ccdd41c041c897..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:20a8e0a793ac29bc168d10e1c9e465082c2adb1582ff79d1a083798f9a955a5f -size 11857920 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_63599de.abi3.so b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_63599de.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..fb0e92fb6d95d232987dbe717d79e998334259c5 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_63599de.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f8bfaaeb2a5e226a80403463d15f2c762ac8cb70ca7a44d2156aadfac63ab0d1 +size 11857920 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py index f7aa5a700bc5673c743a3b0fb74107fab793fbad..03e1e323ce98ecea8dda071c089f0d7ba75551b6 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_0586ba6 -ops = torch.ops._megablocks_0586ba6 +from . import _megablocks_63599de +ops = torch.ops._megablocks_63599de def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_megablocks_0586ba6::{op_name}" \ No newline at end of file + return f"_megablocks_63599de::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py index 9b2c7f062071758def741bf3af2392d3a4855f89..7ccc5dcec5e9a663794fad944c45285869c4d1c1 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -3,7 +3,19 @@ import unittest -import stk + +# import stk + +# try: +# import stk +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.', +# ) + +from .. import stk + import torch from absl.testing import parameterized diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/__init__.py @@ -0,0 +1,7 @@ +# import stk.random +# import stk.ops +# from stk.matrix import Matrix + +from . import random +from . import ops +from .matrix import Matrix diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/backend/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/backend/autocast.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/backend/autocast.py new file mode 100644 index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/backend/autocast.py @@ -0,0 +1,37 @@ +import functools +import torch + + +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, map): + return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, list) or isinstance(x, tuple): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def custom_fwd(fwd): + """Wrap a custom autograd function that always uses autocast dtype.""" + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + return decorate_fwd + + +def custom_bwd(bwd): + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + return decorate_bwd diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/backend/sputnik.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/backend/sputnik.py new file mode 100644 index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/backend/sputnik.py @@ -0,0 +1,316 @@ +import torch + +from ..backend import triton_kernels as backend +from ..backend.autocast 
import custom_bwd, custom_fwd + + +def _standardize_shape(x, transpose): + if transpose: + return torch.Size((x[1], x[0])) + return x + + +def _sparse_transpose(x): + return (torch.Size((x[0][1], x[0][0])), ) + x[1:] + + +def _transpose_helper(x, transpose): + if isinstance(x, torch.Tensor): + return x.t() if transpose else x + if transpose: + x = _sparse_transpose(x) + return x + (transpose,) + + +def _wrap(x): + if isinstance(x, torch.Tensor): + return (x,) + return x + + +def _is_transposed(x): + return (not x.is_contiguous() and + x.stride()[0] == 1 and + x.stride()[1] == x.size()[0]) + + +def _call_helper(op, out, a, b, trans_a, trans_b): + args = (_wrap(_transpose_helper(a, trans_a)) + + _wrap(_transpose_helper(b, trans_b))) + if isinstance(out, tuple): + args = args + out + return op(*args) + + +def _preprocess_inputs(lhs, rhs, dy): + if isinstance(lhs, torch.Tensor) and _is_transposed(lhs): + lhs = lhs.t() + if isinstance(rhs, torch.Tensor) and _is_transposed(rhs): + rhs = rhs.t() + if (isinstance(dy, torch.Tensor) and + not dy.is_contiguous() and + not _is_transposed(dy)): + dy = dy.contiguous() + if isinstance(dy, tuple) and not dy[1].is_contiguous(): + dy = (dy[0], dy[1].contiguous()) + dy[2:] + return lhs, rhs, dy + + +def _postprocess_outputs(x, transpose, grad): + if isinstance(x, torch.Tensor) and transpose: + return grad.t() + return grad + + +def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): + lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) + + a, b = (rhs, dy) if trans_lhs else (dy, rhs) + trans_a = trans_lhs and trans_rhs + trans_b = trans_lhs or not trans_rhs + out = _call_helper(op, lhs, a, b, trans_a, trans_b) + return _postprocess_outputs(lhs, trans_lhs, out) + + +def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): + lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) + + a, b = (dy, lhs) if trans_rhs else (lhs, dy) + trans_a = not trans_lhs or trans_rhs + trans_b = trans_lhs and trans_rhs + out = _call_helper(op, rhs, a, b, trans_a, trans_b) + return _postprocess_outputs(rhs, trans_rhs, out) + + +class DSD(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs): + ctx.save_for_backward(data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + rhs) + ctx.shape = _standardize_shape(shape, transpose_a) + ctx.transpose_a = transpose_a + + out = torch.empty( + (shape[0], rhs.size()[1]), + dtype=rhs.dtype, + device=rhs.device) + + backend.dsd(shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs, + out) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs = (ctx.shape,) + saved_tensors[:-1] + rhs = saved_tensors[-1] + trans_a = ctx.transpose_a + trans_b = _is_transposed(rhs) + + ddata = None + if ctx.needs_input_grad[1]: + ddata = _lhs_gradient(sdd, + lhs, + rhs, + dy, + trans_a, + trans_b) + drhs = None + if ctx.needs_input_grad[-1]: + op = dds if trans_b else dsd + drhs = _rhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + return None, ddata, None, None, None, None, None, None, None, drhs + + +dsd = DSD.apply + + +class DDS(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + 
block_offsets_t, + transpose_b): + ctx.save_for_backward(lhs, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t) + ctx.shape = _standardize_shape(shape, transpose_b) + ctx.transpose_b = transpose_b + out = torch.empty((lhs.size()[0], shape[1]), + dtype=lhs.dtype, + device=lhs.device) + backend.dds(lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b, + out) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs = saved_tensors[0] + rhs = (ctx.shape,) + saved_tensors[1:] + trans_a = _is_transposed(lhs) + trans_b = ctx.transpose_b + + dlhs = None + if ctx.needs_input_grad[0]: + op = dsd if trans_a else dds + dlhs = _lhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + ddata = None + if ctx.needs_input_grad[2]: + ddata = _rhs_gradient(sdd, + lhs, + rhs, + dy, + trans_a, + trans_b) + return dlhs, None, ddata, None, None, None, None, None, None, None + + +dds = DDS.apply + + +class SDD(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + lhs, + rhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t): + ctx.save_for_backward( + lhs, + rhs, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t) + ctx.shape = shape + out = torch.empty( + data.shape, + dtype=lhs.dtype, + device=lhs.device) + backend.sdd(lhs, + rhs, + shape, + out, + offsets, + row_indices, + column_indices) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs, rhs = saved_tensors[:2] + dy = (ctx.shape, dy) + saved_tensors[2:] + trans_a = _is_transposed(lhs) + trans_b = _is_transposed(rhs) + + dlhs = None + if ctx.needs_input_grad[0]: + op = dds if trans_a else dsd + dlhs = _lhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + drhs = None + if ctx.needs_input_grad[1]: + op = dsd if trans_b else dds + drhs = _rhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + return dlhs, drhs, None, None, None, None, None, None, None, None + + +sdd = SDD.apply + +class RowIndices(torch.autograd.Function): + + @staticmethod + def forward(ctx, shape, data, offsets, column_indices): + out = torch.empty( + column_indices.shape, + dtype=column_indices.dtype, + device=column_indices.device) + backend.row_indices(shape, data, offsets, column_indices, out) + return out + + +row_indices = RowIndices.apply diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/backend/triton_kernels.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/backend/triton_kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/backend/triton_kernels.py @@ -0,0 +1,393 @@ +import torch +import triton +import triton.language as tl +from dataclasses import dataclass + +@dataclass +class TritonConfig: + BLOCK_M: int = 128 + BLOCK_N: int = 128 + BLOCK_K: int = 32 + BLOCK_SIZE: int = 128 + NUM_STAGES: int = 4 + NUM_WARPS: int = 4 + +def _validate_matmul_dims(M: int, K: int, N: int): + error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}" + assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M) + assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, 
TritonConfig.BLOCK_K) + assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _sdd_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + # matrix multiplication + pid = tl.program_id(0) + pid_m = tl.load(row_indices + pid) + pid_n = tl.load(column_indices + pid) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(A) + b = tl.load(B) + acc += tl.dot(a, b) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + #Store to sparse matrix + acc = acc.to(C.dtype.element_ty) + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + cm = tl.arange(0, BLOCK_M) + cn = tl.arange(0, BLOCK_N) + C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _dsd_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, offsets, + block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + + # matrix multiplication + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) + + start_inx = tl.load(offsets + pid_m) + end_inx = tl.load(offsets + pid_m + 1) + + # pointers to sparse matrix + rm = tl.arange(0, BLOCK_M) + rak = tl.arange(0, BLOCK_K) + + A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) + + # pointers to dense matrix + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rbk = tl.arange(0, BLOCK_K) + B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) + + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) + + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + ak_sub_incr = BLOCK_K * stride_ak + bk_sub_incr = BLOCK_K * stride_bk + bk_block_incr = BLOCK_SIZE * stride_bk + + for k in range(nsub_blocks * (end_inx - start_inx)): + sub_block_inx = k % nsub_blocks + block_inx = k // nsub_blocks + 
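+        # Sparse operand A: when transposed, block_offsets_t maps each block to its position in the transposed ordering; otherwise nonzero blocks are stored contiguously along the row.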
+ if trans_A: + ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr + else: + ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr + + ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr + + a = tl.load(ptr_A) + b = tl.load(ptr_B) + acc += tl.dot(a, b) + + acc = acc.to(C.dtype.element_ty) + + cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _dds_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, offsets, + block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + + # matrix multiplication + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) + + start_inx = tl.load(offsets + pid_n) + end_inx = tl.load(offsets + pid_n + 1) + + # pointers to dense matrix + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rak = tl.arange(0, BLOCK_K) + + A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) + + # pointers to sparse matrix + rn = tl.arange(0, BLOCK_N) + rbk = tl.arange(0, BLOCK_K) + B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) + + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) + + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + + ak_sub_incr = BLOCK_K * stride_ak + ak_block_incr = BLOCK_SIZE * stride_ak + bk_sub_incr = BLOCK_K * stride_bk + + for k in range(nsub_blocks * (end_inx - start_inx)): + sub_block_inx = k % nsub_blocks + block_inx = k // nsub_blocks + + if trans_B: + ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr + else: + ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr + + ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr + a = tl.load(ptr_A) + b = tl.load(ptr_B) + acc += tl.dot(a, b) + + acc = acc.to(C.dtype.element_ty) + cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +def dsd(shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs, + out + ): + + device = rhs.device + trans_A = transpose_a + trans_B = False + + if rhs.stride(0) > 1 and rhs.stride(1) > 1: + trans_B = True + + # checks constraints + assert shape[1] == rhs.shape[0], "incompatible dimensions" + M, K = shape + _, N = rhs.shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else 
tl.int32 + + stride_am, stride_ak = data.stride(1), data.stride(2) + stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) + a_column_indices = column_indices + a_offsets = offsets + + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) + + if trans_A: + stride_am, stride_ak = data.stride(2), data.stride(1) + a_column_indices, a_offsets = column_indices_t, offsets_t + + if trans_B: + stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) + + _dsd_kernel[grid]( + data.data, rhs, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(0), out.stride(1), + row_indices, a_column_indices, a_offsets, + block_offsets_t, trans_A, trans_B, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + # return out + +def dds(lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b, + out + ): + + device = lhs.device + trans_B = transpose_b + trans_A = False + + if lhs.stride(0) > 1 and lhs.stride(1) > 1: + trans_A = True + + # checks constraints + assert lhs.shape[1] == shape[0], "incompatible dimensions" + M, K = lhs.shape + _, N = shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + stride_am, stride_ak = lhs.stride(0), lhs.stride(1) + stride_bk, stride_bn = data.stride(1), data.stride(2) + b_column_indices = column_indices_t + b_offsets = offsets_t + + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) + + if trans_A: + stride_am, stride_ak = lhs.stride(1), lhs.stride(0) + if trans_B: + stride_bk, stride_bn = data.stride(2), data.stride(1) + b_column_indices, b_offsets = column_indices, offsets + + _dds_kernel[grid]( + lhs, data, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(0), out.stride(1), + row_indices, b_column_indices, b_offsets, + block_offsets_t, trans_A, trans_B, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + +def sdd(lhs, + rhs, + shape, + out, + offsets, + row_indices, + column_indices + ): + + device = out.device + trans_A = False + trans_B = False + + if lhs.stride(0) > 1 and lhs.stride(1) > 1: + trans_A = True + if rhs.stride(0) > 1 and rhs.stride(1) > 1: + trans_B = True + + # checks constraints + assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions" + M, K = lhs.shape + _, N = rhs.shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + # launch kernel + nnz_blocks = len(row_indices) + grid = lambda META: (nnz_blocks,) + + stride_am, stride_ak = lhs.stride(0), lhs.stride(1) + stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) + + if trans_A: + stride_am, stride_ak = lhs.stride(1), lhs.stride(0) + if trans_B: + stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) + + _sdd_kernel[grid]( + lhs, rhs, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(1), out.stride(2), + row_indices, column_indices, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + +@triton.jit +def _row_indices_kernel(offsets, out): + pid = tl.program_id(0) + row_offset = tl.load(offsets + pid) + nnz_blocks = tl.load(offsets + pid + 1) - row_offset + for nnz_block in range(nnz_blocks): + tl.store(out + row_offset + nnz_block, pid) + +def row_indices( + shape, data, offsets, column_indices, out +): + block_rows = len(offsets) - 1 + _row_indices_kernel[(block_rows, )](offsets, 
out) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/matrix.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/matrix.py @@ -0,0 +1,329 @@ +import numpy as np +import torch + +# 1. Add heavyweight (data) validation helper. +# 2. Add construction helpers +# 3. Make indentation consistent +# 4. Replace asserts with descriptive errors. + +## +### Validation helpers. +## + + +def _validate_matrix(shape, data, row_indices, column_indices, offsets): + # Data should be [nnz, block_size, block_size] + if data.dim() == 1: + data = torch.reshape(data, [data.numel(), 1, 1]) + + # Blocks should be square. + if data.shape[-2] != data.shape[-1]: + raise ValueError( + "Expected square blocking in data. " + f"Got block shape {[data.shape[-2], data.shape[-1]]}") + + # Flatten batch dimensions on data - original shape preserved + # in shape argument. + block_size = data.shape[-1] + data = data.view([-1, block_size, block_size]) + + if data.dim() != 3: + raise ValueError( + "Expected 3D shape for data (nnz, block, block). " + f"Got shape {data.dim()}D shape.") + + block_size = data.shape[1] + if shape[-2] % block_size != 0 or shape[-1] % block_size != 0: + raise ValueError( + "Matrix shape must be dividible by blocking. " + f"Got shape {shape} with " + f"{[block_size, block_size]} blocking.") + + if np.prod(shape) < data.numel(): + raise ValueError( + "Invalid matrix. Number of nonzeros exceeds matrix capacity " + f"({data.numel()} v. {np.prod(shape)})") + + if row_indices.dim() != 1: + raise ValueError( + f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.") + + if column_indices.dim() != 1: + raise ValueError( + f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.") + + if offsets.dim() != 1: + raise ValueError( + f"Expected 1D offsets. Got {offsets.dim()}D offsets.") + + if row_indices.numel() != data.shape[0]: + raise ValueError( + "Expected 1 index per nonzero block. " + f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks") + + if column_indices.numel() != data.shape[0]: + raise ValueError( + "Expected 1 index per nonzero block. " + f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks") + + block_rows = np.prod(shape[:-1]) / block_size + if offsets.numel() != block_rows + 1: + raise ValueError( + "Expected one offset per block row plus one. " + f"Got {offsets.numel()} offsets with {block_rows} block rows.") + + is_cuda = (data.is_cuda and + row_indices.is_cuda and + column_indices.is_cuda and + offsets.is_cuda) + is_cpu = (not data.is_cuda and + not row_indices.is_cuda and + not column_indices.is_cuda and + not offsets.is_cuda) + if not (is_cuda or is_cpu): + raise ValueError( + "Expected data & meta-data on common device. " + f"Got data on {data.device}, row_indices on {row_indices.device} " + f"column_indices on {column_indices.device} and " + f"offsets on {offsets.device}.") + + if data.dtype != torch.float16: + raise ValueError( + f"Expected float16 data. Got {data.dtype} data.") + if row_indices.dtype != torch.int16: + raise ValueError( + f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.") + if column_indices.dtype != torch.int16: + raise ValueError( + f"Expected int16 column_indices. 
Got {column_indices.dtype} column_indices.") + if offsets.dtype != torch.int32: + raise ValueError( + f"Expected int32 offsets. Got {offsets.dtype} offsets.") + return data + + +def _transpose(size, data, row_indices, column_indices, offsets): + block_columns = size[1] // data.shape[1] + + # Sort row indices by column indices to get the transposed matrix's + # column indices. + gather_indices = column_indices.argsort() + column_indices_t = row_indices.gather(0, gather_indices) + block_offsets_t = gather_indices.int() + + # NOTE: Histogram is not implemented for any integer type on CPU. Do + # the histogram in 32-bit float, which can exactly represent 16-bit + # integers. + column_indices_float = column_indices.float() + + zero = torch.zeros((1,), dtype=torch.int32, device=data.device) + nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns) + nnz_per_column = nnz_per_column.int() + offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)]) + return column_indices_t, offsets_t, block_offsets_t + + +class Matrix(torch.nn.Module): + """A matrix stored in sparse format. + + Underlying format is block compressed sparse row (BCSR). + + TODO(tgale): Make this mirror torch.Tensor API as much as possible. + """ + + def __init__(self, + size, + data, + row_indices, + column_indices, + offsets, + column_indices_t=None, + offsets_t=None, + block_offsets_t=None): + super().__init__() + self._size = size + self._data = data + self._row_indices = row_indices + self._column_indices = column_indices + self._offsets = offsets + + # Produce the transpose meta-data if it is not passed in. + if ((column_indices_t is None) or (offsets_t is None) or + (block_offsets_t is None)): + column_indices_t, offsets_t, block_offsets_t = _transpose( + size, data, row_indices, column_indices, offsets) + self._column_indices_t = column_indices_t + self._offsets_t = offsets_t + self._block_offsets_t = block_offsets_t + + self._transposed = False + + # Validate that our metadata will not overflow. + max_dim = np.iinfo(np.int16).max * self.blocking + if column_indices.dtype == torch.int16: + if size[0] > max_dim or size[1] > max_dim: + raise ValueError( + "Sparse matrix with shape {size} exceeds representable " + "size with 16-bit indices.") + + def validate(self): + _validate_matrix(self._size, + self._data, + self._row_indices, + self._column_indices, + self._offsets) + + # TODO(tgale): Add heavyweight data validation. + + def to(self, device): + # TODO(tgale): Handle type conversions here. We + # need to set the appropriate meta-data type for + # the given floating-point type. 
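+        # Move the block data together with all BCSR metadata (row/column indices, offsets, and the transpose metadata) to the target device.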
+ self._data = self._data.to(device) + self._row_indices = self._row_indices.to(device) + self._column_indices = self._column_indices.to(device) + self._offsets = self._offsets.to(device) + self._column_indices_t = self._column_indices_t.to(device) + self._offsets_t = self._offsets_t.to(device) + self._block_offsets_t = self._block_offsets_t.to(device) + return self + + def cuda(self): + return self.to(torch.cuda.current_device()) + + def clone(self): + return Matrix( + self.size(), + self.data.clone(), + self.row_indices.clone(), + self.column_indices.clone(), + self.offsets.clone(), + self.column_indices_t.clone(), + self.offsets_t.clone(), + self.block_offsets_t.clone()) + + def t(self): + if self.dim() != 2: + raise ValueError( + "t() expects a tensor with <= 2 dimensions, " + f"but self is {self.dim()}D.") + out = Matrix(self.size(), + self.data, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + out._transposed = not self._transposed + out._size = torch.Size((self._size[1], self._size[0])) + return out + + def contiguous(self): + raise ValueError("Not yet implemented.") + + def is_contiguous(self): + return not self._transposed + + @property + def is_cuda(self): + return self._data.is_cuda + + @property + def device(self): + return self._data.device + + def size(self): + return self._size + + @property + def shape(self): + return self.size() + + def dim(self): + return len(self._size) + + @property + def data(self): + return self._data + + @property + def row_indices(self): + return self._row_indices + + @property + def column_indices(self): + return self._column_indices + + @property + def offsets(self): + return self._offsets + + @property + def offsets_t(self): + return self._offsets_t + + @property + def column_indices_t(self): + return self._column_indices_t + + @property + def block_offsets_t(self): + return self._block_offsets_t + + @property + def dtype(self): + return self.data.dtype + + @property + def nnz(self): + return self.data.numel() + + @property + def blocking(self): + return self.data.shape[1] + + @property + def requires_grad(self): + return self.data.requires_grad + + def requires_grad_(self, x): + self.data.requires_grad_(x) + return self + + def view(self, *shape): + assert self.is_contiguous() + if shape[-1] != self.size()[-1]: + raise ValueError( + "Can't change view on compressed dimension. " + f"{self.size()[-1]} v. {shape[-1]}.") + if np.prod(shape) != np.prod(self.size()): + raise ValueError( + "Mismatch in numel of Matrix and new shape. " + f"{np.prod(self.size())} v. {np.prod(shape)}") + return Matrix(shape, + self.data, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + + @property + def grad(self): + # TODO(tgale): Make sure this mirrors torch.Tensor + # behavior in the case where we ask for the gradient + # of a non-contiguous tensor. 
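+        # Materialize the gradient in the contiguous orientation, then transpose it back if this Matrix is a transposed view.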
+ size = self.size() + if not self.is_contiguous(): + size = torch.Size((size[1], size[0])) + out = Matrix(size, + self.data.grad, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + return out if self.is_contiguous() else out.t() diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/__init__.py @@ -0,0 +1,3 @@ +from .linear_ops import dds, dsd, sdd +from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse +from .eltwise_ops import mul diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops.py @@ -0,0 +1,28 @@ +from ..matrix import Matrix + +def mul(a, b): + """Performs element-wise multiplication of matrices a and b. + + It is the user's responsibility to make sure that a and b + follow the same matrix topology. This function assumes it is safe + to use the topoplogy of a. + + Args: + a: stk.Matrix. + b: stk.Matrix with a's matrix topology. + + Returns: + stk.Matrix where the entries correspond to torch.mul(a, b). + """ + assert isinstance(a, Matrix) + assert isinstance(b, Matrix) + assert a.size() == b.size() + + return Matrix(a.size(), + a.data * b.data, + a.row_indices, + a.column_indices, + a.offsets, + a.column_indices_t, + a.offsets_t, + a.block_offsets_t) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..66bfd4f6af77042d3c5bdb1fe18d00e457478d46 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py @@ -0,0 +1,86 @@ +import unittest +import itertools +import torch +from absl.testing import parameterized + +import stk +from stk.ops.linear_ops_test import allclose, _dense_and_sparse + +_MATRIX_SIZES = ( + (128, 128, 0.0), + (256, 256, 0.5), + (2048, 1024, 0.8), + (512, 128, 0.0), + (128, 512, 0.0), + (1024, 512, 0.0), + (1024, 512, 0.5), + (1024, 512, 0.75), + (512, 1024, 0.0), + (512, 1024, 0.5), + (512, 1024, 0.75), + (1024, 1024, 0.0), + (1024, 1024, 0.5), + (1024, 1024, 0.75), +) + +_DTYPE = ( + torch.float16, torch.bfloat16 +) + +def _generate_testcases(): + testcases = itertools.product(_MATRIX_SIZES, _DTYPE) + testcases = [(*size, 128, dtype) for + (size, dtype) in testcases] + return testcases + +_ELTWISE_OP_TESTS = _generate_testcases() + +def _dense_and_sparse_like(x, std=0.1): + dense_data = torch.randn_like(x.data, device=x.device) * std + sparse = stk.Matrix(x.size(), + dense_data, + x.row_indices, + x.column_indices, + x.offsets) + dense = stk.ops.to_dense(sparse) + + return (dense.requires_grad_(True), + sparse.requires_grad_(True)) + +@parameterized.parameters(_ELTWISE_OP_TESTS) +class EltwiseOpsTest(parameterized.TestCase): + + def testEltwiseMul(self, m, n, sparsity, blocking, dtype): + + a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype) + b_dense, b = 
_dense_and_sparse_like(a) + + out = stk.ops.mul(a, b) + expected_out = torch.mul(a_dense, b_dense) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + stk.ops.sum(out).backward() + + # Validate the results. + out = stk.ops.to_dense(out) + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size(), out.size()) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = stk.ops.to_dense(a.grad) + expected_grad = a_dense.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size(), grad.size()) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = stk.ops.to_dense(b.grad) + expected_grad = b_dense.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size(), grad.size()) + self.assertTrue(allclose(grad, expected_grad)) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/linear_ops.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/linear_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/linear_ops.py @@ -0,0 +1,59 @@ +import torch + +from ..backend import sputnik +from ..matrix import Matrix + + +def dsd(a, b): + assert isinstance(a, Matrix) + assert isinstance(b, torch.Tensor) + return sputnik.dsd( + a.size(), + a.data, a.offsets, + a.row_indices, + a.column_indices, + a.offsets_t, + a.column_indices_t, + a.block_offsets_t, + not a.is_contiguous(), + b) + + +def dds(a, b): + assert isinstance(a, torch.Tensor) + assert isinstance(b, Matrix) + return sputnik.dds( + a, + b.size(), + b.data, b.offsets, + b.row_indices, + b.column_indices, + b.offsets_t, + b.column_indices_t, + b.block_offsets_t, + not b.is_contiguous()) + + +def sdd(a, b, topo): + assert isinstance(a, torch.Tensor) + assert isinstance(b, torch.Tensor) + assert isinstance(topo, Matrix) + assert topo.is_contiguous() + out = sputnik.sdd( + a, b, + topo.size(), + topo.data, + topo.offsets, + topo.row_indices, + topo.column_indices, + topo.offsets_t, + topo.column_indices_t, + topo.block_offsets_t) + return Matrix(topo.size(), + out, + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/linear_ops_test.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/linear_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ced1d782fbc9f9ca16b3449239f1588dc5ff5e00 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/linear_ops_test.py @@ -0,0 +1,216 @@ +import unittest +import itertools +import numpy as np +import torch +from absl.testing import parameterized + +import stk + + +def allclose(x, y, pct=0.25): + mask = torch.isclose(x, y, rtol=5e-2) + pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 + if pct_diff > pct: + print("{:.2f}% of values not close.".format(pct_diff)) + return False + return True + + +# An assortment of problems designed to make sure +# the bindings are operating correctly. 
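+# Each tuple below is (m, k, n, sparsity); _generate_testcases() appends the transpose flags, a blocking of 128, and the dtype.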
+_MATRIX_SIZES = ( + (128, 128, 128, 0.0), + (256, 256, 256, 0.5), + (2048, 1024, 512, 0.8), + (512, 128, 128, 0.0), + (128, 128, 512, 0.0), + (1024, 512, 512, 0.0), + (1024, 512, 512, 0.5), + (1024, 512, 512, 0.75), + (512, 512, 1024, 0.0), + (512, 512, 1024, 0.5), + (512, 512, 1024, 0.75), + (1024, 1024, 1024, 0.0), + (1024, 1024, 1024, 0.5), + (1024, 1024, 1024, 0.75), +) + +_TRANSPOSE = ( + (False, False), + (False, True), + (True, False), + (True, True), +) + +_DTYPE = ( + torch.float16, torch.bfloat16 +) + +def _generate_testcases(): + testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE) + testcases = [(*size, *trans, 128, dtype) for + (size, trans, dtype) in testcases] + return testcases + +_LINEAR_OP_TESTS = _generate_testcases() + +def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + dense = (torch.randn(rows, cols) * std * mask).type(dtype) + sparse = stk.ops.to_sparse(dense, blocking) + cuda_device = torch.device("cuda") + return (dense.to(cuda_device).requires_grad_(True), + sparse.to(cuda_device).requires_grad_(True)) + + +def _dense(rows, cols, dtype, std=0.1): + cuda_device = torch.device("cuda") + out = (torch.randn(rows, cols) * std).type(dtype) + return out.to(cuda_device).requires_grad_(True) + + +def _dense_2x(rows, cols, dtype): + a = _dense(rows, cols, dtype) + return a, a.detach().requires_grad_(True) + + +def _with_transpose(op, a, b, trans_a, trans_b): + a = a.t() if trans_a else a + b = b.t() if trans_b else b + return op(a, b) + + +def _mmm(a, b, topo): + mask = stk.ops.to_dense(stk.ops.ones_like(topo)) + return torch.mm(a, b) * mask + + +def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b): + a = a.t() if trans_a else a + b = b.t() if trans_b else b + return op(a, b, topo) + + +def _mask(x, mask): + mask = stk.ops.to_dense(stk.ops.ones_like(mask)) + return x * mask + + +@parameterized.parameters(*_LINEAR_OP_TESTS) +class LinearOpsTest(parameterized.TestCase): + + def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. + a_shape = (k, m) if trans_a else (m, k) + a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype) + b_shape = (n, k) if trans_b else (k, n) + b, bcp = _dense_2x(*b_shape, dtype) + + # Execute the matmul. + out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b) + expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + out.sum().backward() + + # Validate the results. + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = stk.ops.to_dense(a.grad) + expected_grad = _mask(a_dense.grad, a.grad) + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = b.grad + expected_grad = bcp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. 
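+        # The dds case uses a dense LHS (duplicated via _dense_2x so gradients can be checked against torch.mm) and a sparse RHS.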
+ a_shape = (k, m) if trans_a else (m, k) + a, acp = _dense_2x(*a_shape, dtype) + b_shape = (n, k) if trans_b else (k, n) + b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype) + + # Execute the matmul. + out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b) + expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + out.sum().backward() + + # Validate the results. + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = a.grad + expected_grad = acp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = stk.ops.to_dense(b.grad) + expected_grad = _mask(b_dense.grad, b.grad) + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. + a_shape = (k, m) if trans_a else (m, k) + a, acp = _dense_2x(*a_shape, dtype) + b_shape = (n, k) if trans_b else (k, n) + b, bcp = _dense_2x(*b_shape, dtype) + _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype) + + # Execute the matmul. + out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b) + expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + stk.ops.sum(out).backward() + + # Validate the results. + out = stk.ops.to_dense(out) + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = a.grad + expected_grad = acp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = b.grad + expected_grad = bcp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops.py @@ -0,0 +1,98 @@ +from ..backend import sputnik +from ..matrix import Matrix +import torch +import numpy as np + + +@torch.no_grad() +def row_indices(shape, data, offsets, column_indices): + return sputnik.row_indices(shape, data, offsets, column_indices) + + +# TODO(tgale): Replace this helper with a custom kernel. This operation +# is much simpler to do than how it's currently implemented. 
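+# _expand_for_blocking maps each (block_row, block_col) index pair to the dense-matrix indices of every element in that block,
+# e.g. with blocking=2 the block index (1, 0) expands to the element indices (2, 0), (2, 1), (3, 0) and (3, 1).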
+@torch.no_grad() +def _expand_for_blocking(idxs, blocking): + # Duplicate for block column dimension. + idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1) + + # Update the column indices. + idxs[:, :, 1] *= blocking + idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking]) + + # Duplicate for block row dimension. + idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2]) + idxs = idxs.repeat(1, blocking, 1, 1) + + # Update the row indices. + idxs[:, :, :, 0] *= blocking + idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1]) + idxs = torch.reshape(idxs, [-1, 2]) + return idxs + + +# TODO(tgale): Add input type checking. +@torch.no_grad() +def to_dense(x): + assert isinstance(x, Matrix) + + shape = (np.prod(x.shape[:-1]), x.shape[-1]) + row_idxs = x.row_indices.type(torch.int32) + col_idxs = x.column_indices.type(torch.int32) + indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking) + indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64) + + out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device) + out.scatter_(0, indices, x.data.flatten()) + return out.reshape(x.size()) + + +@torch.no_grad() +def _mask(x, blocking=1): + assert x.dim() == 2 + assert x.size()[0] % blocking == 0 + assert x.size()[1] % blocking == 0 + block_rows = x.size()[0] // blocking + block_cols = x.size()[1] // blocking + x = torch.reshape(x, [block_rows, blocking, block_cols, blocking]) + x = torch.sum(torch.abs(x), dim=(1, 3)) + return x != 0 + + +# TODO(tgale): Add input type checking. +@torch.no_grad() +def to_sparse(x, blocking=1): + m = _mask(x, blocking) + + # TODO(tgale): Set to appropriate type for input matrix. + row_nnzs = torch.sum(m, dim=1).type(torch.int32) + zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device) + offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)]) + offsets = offsets.type(torch.int32) + + indices = torch.nonzero(m).type(torch.int16) + row_indices = indices[:, 0] + column_indices = indices[:, 1] + + # Nonzero indices in the dense matrix. + nonzero_indices = torch.nonzero(m) + nonzero_indices = _expand_for_blocking(nonzero_indices, blocking) + nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1] + + # Gather the data and construct the sparse matrix. 
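+    # nonzero_indices enumerates the nonzero blocks in row-major order, so the gathered data lands in the BCSR layout that Matrix expects.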
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices) + data = torch.reshape(data, [-1, blocking, blocking]) + return Matrix(x.size(), data, row_indices, column_indices, offsets) + + +@torch.no_grad() +def ones_like(x): + return Matrix(x.size(), + torch.ones_like(x.data), + x.row_indices, + x.column_indices, x.offsets) + + +def sum(x): + assert isinstance(x, Matrix) + return x.data.sum() diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3af04c0760483e578f93303dc457415948a2a34c --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py @@ -0,0 +1,62 @@ +import unittest + +from absl.testing import parameterized +import stk +import torch + + +@parameterized.parameters( + (8, 16, 0.0, 1), + (8, 16, 0.5, 1), + (8, 16, .95, 1), + (16, 8, 0.0, 1), + (16, 8, 0.5, 1), + (16, 8, .95, 1), + (8, 16, 0.0, 8), + (8, 16, 0.5, 8), + (8, 16, 1.0, 8), + (16, 8, 0.0, 8), + (16, 8, 0.5, 8), + (16, 8, 1.0, 8), + (128, 256, 0.5, 16), + (256, 128, 0.75, 32), + (512, 512, .875, 128)) +class MatrixOpsTest(parameterized.TestCase): + + def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + x = (torch.randn(rows, cols) * mask).type(torch.float16) + + # Convert the matrix to sparse format. + sparse_x = stk.ops.to_sparse(x, blocking) + + # Validate the matrix. + sparse_x.validate() + + # Validate the shape. + self.assertEqual(sparse_x.dim(), 2) + self.assertEqual(sparse_x.size()[0], rows) + self.assertEqual(sparse_x.size()[1], cols) + + # Validate the sparsity. + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual(sparse_x.nnz, nnz) + + # Convert back to dense format. + dense_x = stk.ops.to_dense(sparse_x) + + # Validate the shape. + self.assertEqual(dense_x.dim(), 2) + self.assertEqual(dense_x.size()[0], rows) + self.assertEqual(dense_x.size()[1], cols) + + # Validate the sparsity + self.assertEqual(torch.count_nonzero(dense_x).item(), nnz) + + # Validate the output. 
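+        # The sparse -> dense round trip should be exact: to_sparse copies the nonzero blocks verbatim and to_dense scatters them back.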
+ self.assertTrue(torch.all(torch.eq(x, dense_x))) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/random/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/random/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/random/__init__.py @@ -0,0 +1,2 @@ +# from stk.random.random_ops import dense_mask, mask, randn +from .random_ops import dense_mask, mask, randn diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/random/random_ops.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/random/random_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/random/random_ops.py @@ -0,0 +1,36 @@ +import numpy as np +import torch +from ..ops import matrix_ops + + +@torch.no_grad() +def dense_mask(rows, cols, sparsity, blocking=1): + assert sparsity >= 0.0 and sparsity <= 1.0 + assert rows % blocking == 0 and cols % blocking == 0 + + block_rows, block_cols = (rows // blocking, cols // blocking) + nnz = round(block_rows * block_cols * (1 - sparsity)) + + out = np.ones(block_rows * block_cols) + mask = np.random.choice(out.size, out.size - nnz, replace=False) + out[mask] = 0.0 + + out = np.tile( + np.reshape(out, [block_rows, 1, block_cols, 1]), + (1, blocking, 1, blocking)) + out = np.reshape(out, [rows, cols]) + return torch.from_numpy(out.astype(np.float32)) + + +@torch.no_grad() +def mask(m, n, sparsity, blocking=1): + out = dense_mask(m, n, sparsity, blocking).type(torch.float16) + return matrix_ops.to_sparse(out, blocking=blocking) + + +@torch.no_grad() +def randn(shape, sparsity, blocking=1): + shape_2d = (np.prod(shape[:-1]), shape[-1]) + out = mask(*shape_2d, sparsity, blocking) + out.data.copy_(torch.randn(*out.data.shape)) + return out.view(*shape) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/random/random_ops_test.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/random/random_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..587b44ec890c861879c6296b8f9028f5d99ab82f --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/random/random_ops_test.py @@ -0,0 +1,73 @@ +import unittest + +from absl.testing import parameterized +from . import random +import torch + + +@parameterized.parameters( + (8, 16, 0.0, 1), + (8, 16, 0.5, 1), + (8, 16, .95, 1), + (16, 8, 0.0, 1), + (16, 8, 0.5, 1), + (16, 8, .95, 1), + (8, 16, 0.0, 8), + (8, 16, 0.5, 8), + (8, 16, 1.0, 8), + (16, 8, 0.0, 8), + (16, 8, 0.5, 8), + (16, 8, 1.0, 8), + (128, 256, 0.5, 16), + (256, 128, 0.75, 32), + (512, 512, .875, 128)) +class RandomOpsTest(parameterized.TestCase): + + def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking): + mask = random.dense_mask( + rows, cols, sparsity, blocking) + + # Validate the shape. + self.assertEqual(mask.dim(), 2) + self.assertEqual(mask.size()[0], rows) + self.assertEqual(mask.size()[1], cols) + + # Validate the sparsity + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual( + torch.count_nonzero(mask).item(), + nnz) + + # Check values are zero or one. 
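+        # dense_mask returns a float32 tensor containing only zeros and ones, so exact equality checks are appropriate here.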
+ self.assertTrue( + torch.all(torch.logical_or( + torch.eq(mask, 0), + torch.eq(mask, 1)))) + + def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking): + mask = random.mask( + rows, cols, sparsity, blocking) + + # Validate the matrix. + mask.validate() + + # Validate the shape. + self.assertEqual(mask.dim(), 2) + self.assertEqual(mask.size()[0], rows) + self.assertEqual(mask.size()[1], cols) + + # Validate the sparsity. + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual(mask.nnz, nnz) + + # Check values are zero or one. + self.assertTrue( + torch.all(torch.logical_or( + torch.eq(mask.data, 0), + torch.eq(mask.data, 1)))) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/activation_fn.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/activation_fn.py index a31770ba179fed06abf2da10102ccaeed1d3ee4e..0e1d956704840aa4daf7d1d71d24e051567feab9 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/activation_fn.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/activation_fn.py @@ -4,7 +4,7 @@ from typing import Any, Callable, Union import torch -from stk import Matrix +from ..stk import Matrix def act_fn( diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/dmoe.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/dmoe.py index 561fd7982254513e9a9b4f0514edd326b0561256..6d0375a4df2f27134c4127e60be04f3b45693050 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/dmoe.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/dmoe.py @@ -2,15 +2,22 @@ # SPDX-License-Identifier: Apache-2.0 import numpy as np -import stk.ops import torch -from stk import Matrix + +# try: +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.', +# ) # import megablocks.ops as ops # # from megablocks.ops import ops # from megablocks.layers import common, dmlp_registry, moe, mpu # from megablocks.layers.arguments import Arguments +from .. import stk from .. import ops from . import common, dmlp_registry, moe, mpu from .arguments import Arguments diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/gelu.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/gelu.py index 40b601d4a8eb59e80e1090f31f4172b8f7fb7549..c4c9e6532798615b5c12c96694241a4c18ee8f7b 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/gelu.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/gelu.py @@ -1,7 +1,16 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -import stk +# try: +# import stk +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.', +# ) + +from .. 
import stk + import torch import torch.nn.functional as F diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/glu.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/glu.py index 69ef9ead5461a5b91ace107652fd8d1c3d300b1e..5f297a41ff6a1a2a285f5b461951672364b898da 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/glu.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/glu.py @@ -1,7 +1,17 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -import stk.ops +# import stk.ops +# try: +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.', +# ) + +from .. import stk + import torch # from megablocks import grouped_gemm_util as gg diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/mlp.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/mlp.py index 0b6747b1eec2047eb9480592ca550b6c9e868f5e..c99afb9904c24a8b6a83e79059cd1251dbbfd99e 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/mlp.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/mlp.py @@ -3,9 +3,18 @@ from typing import Any -import stk -import stk.backend.triton_kernels -import stk.ops +# try: +# import stk +# import stk.backend.triton_kernels +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.', +# ) + +from .. import stk + import torch from packaging import version diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so deleted file mode 100755 index 8b4fdfa424e62d65fd64e4a272b1f4e33b9f6d70..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:b2f5209e69d36d632939923c20ab90c074fe0100d8a4efbabe5cdcd32ccbcfd2 -size 11923672 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_63599de.abi3.so b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_63599de.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..dd8f3345a858f0afb245680a4b6a2200cfc4c9a9 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_63599de.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:637a8c7ef51b1d35911546ef7456854f1ee7cc3278565d2e144e16f733487148 +size 11923672 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py index f7aa5a700bc5673c743a3b0fb74107fab793fbad..03e1e323ce98ecea8dda071c089f0d7ba75551b6 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_0586ba6 -ops = torch.ops._megablocks_0586ba6 +from . import _megablocks_63599de +ops = torch.ops._megablocks_63599de def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_megablocks_0586ba6::{op_name}" \ No newline at end of file + return f"_megablocks_63599de::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py index 9b2c7f062071758def741bf3af2392d3a4855f89..7ccc5dcec5e9a663794fad944c45285869c4d1c1 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -3,7 +3,19 @@ import unittest -import stk + +# import stk + +# try: +# import stk +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.', +# ) + +from .. import stk + import torch from absl.testing import parameterized diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/__init__.py @@ -0,0 +1,7 @@ +# import stk.random +# import stk.ops +# from stk.matrix import Matrix + +from . import random +from . import ops +from .matrix import Matrix diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/backend/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/backend/autocast.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/backend/autocast.py new file mode 100644 index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/backend/autocast.py @@ -0,0 +1,37 @@ +import functools +import torch + + +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, map): + return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, list) or isinstance(x, tuple): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def custom_fwd(fwd): + """Wrap a custom autograd function that always uses autocast dtype.""" + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + return decorate_fwd + + +def custom_bwd(bwd): + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + return decorate_bwd diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/backend/sputnik.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/backend/sputnik.py new file mode 100644 index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/backend/sputnik.py @@ -0,0 +1,316 @@ +import torch + +from ..backend import triton_kernels as backend +from ..backend.autocast 
import custom_bwd, custom_fwd + + +def _standardize_shape(x, transpose): + if transpose: + return torch.Size((x[1], x[0])) + return x + + +def _sparse_transpose(x): + return (torch.Size((x[0][1], x[0][0])), ) + x[1:] + + +def _transpose_helper(x, transpose): + if isinstance(x, torch.Tensor): + return x.t() if transpose else x + if transpose: + x = _sparse_transpose(x) + return x + (transpose,) + + +def _wrap(x): + if isinstance(x, torch.Tensor): + return (x,) + return x + + +def _is_transposed(x): + return (not x.is_contiguous() and + x.stride()[0] == 1 and + x.stride()[1] == x.size()[0]) + + +def _call_helper(op, out, a, b, trans_a, trans_b): + args = (_wrap(_transpose_helper(a, trans_a)) + + _wrap(_transpose_helper(b, trans_b))) + if isinstance(out, tuple): + args = args + out + return op(*args) + + +def _preprocess_inputs(lhs, rhs, dy): + if isinstance(lhs, torch.Tensor) and _is_transposed(lhs): + lhs = lhs.t() + if isinstance(rhs, torch.Tensor) and _is_transposed(rhs): + rhs = rhs.t() + if (isinstance(dy, torch.Tensor) and + not dy.is_contiguous() and + not _is_transposed(dy)): + dy = dy.contiguous() + if isinstance(dy, tuple) and not dy[1].is_contiguous(): + dy = (dy[0], dy[1].contiguous()) + dy[2:] + return lhs, rhs, dy + + +def _postprocess_outputs(x, transpose, grad): + if isinstance(x, torch.Tensor) and transpose: + return grad.t() + return grad + + +def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): + lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) + + a, b = (rhs, dy) if trans_lhs else (dy, rhs) + trans_a = trans_lhs and trans_rhs + trans_b = trans_lhs or not trans_rhs + out = _call_helper(op, lhs, a, b, trans_a, trans_b) + return _postprocess_outputs(lhs, trans_lhs, out) + + +def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): + lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) + + a, b = (dy, lhs) if trans_rhs else (lhs, dy) + trans_a = not trans_lhs or trans_rhs + trans_b = trans_lhs and trans_rhs + out = _call_helper(op, rhs, a, b, trans_a, trans_b) + return _postprocess_outputs(rhs, trans_rhs, out) + + +class DSD(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs): + ctx.save_for_backward(data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + rhs) + ctx.shape = _standardize_shape(shape, transpose_a) + ctx.transpose_a = transpose_a + + out = torch.empty( + (shape[0], rhs.size()[1]), + dtype=rhs.dtype, + device=rhs.device) + + backend.dsd(shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs, + out) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs = (ctx.shape,) + saved_tensors[:-1] + rhs = saved_tensors[-1] + trans_a = ctx.transpose_a + trans_b = _is_transposed(rhs) + + ddata = None + if ctx.needs_input_grad[1]: + ddata = _lhs_gradient(sdd, + lhs, + rhs, + dy, + trans_a, + trans_b) + drhs = None + if ctx.needs_input_grad[-1]: + op = dds if trans_b else dsd + drhs = _rhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + return None, ddata, None, None, None, None, None, None, None, drhs + + +dsd = DSD.apply + + +class DDS(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + 
block_offsets_t, + transpose_b): + ctx.save_for_backward(lhs, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t) + ctx.shape = _standardize_shape(shape, transpose_b) + ctx.transpose_b = transpose_b + out = torch.empty((lhs.size()[0], shape[1]), + dtype=lhs.dtype, + device=lhs.device) + backend.dds(lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b, + out) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs = saved_tensors[0] + rhs = (ctx.shape,) + saved_tensors[1:] + trans_a = _is_transposed(lhs) + trans_b = ctx.transpose_b + + dlhs = None + if ctx.needs_input_grad[0]: + op = dsd if trans_a else dds + dlhs = _lhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + ddata = None + if ctx.needs_input_grad[2]: + ddata = _rhs_gradient(sdd, + lhs, + rhs, + dy, + trans_a, + trans_b) + return dlhs, None, ddata, None, None, None, None, None, None, None + + +dds = DDS.apply + + +class SDD(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + lhs, + rhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t): + ctx.save_for_backward( + lhs, + rhs, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t) + ctx.shape = shape + out = torch.empty( + data.shape, + dtype=lhs.dtype, + device=lhs.device) + backend.sdd(lhs, + rhs, + shape, + out, + offsets, + row_indices, + column_indices) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs, rhs = saved_tensors[:2] + dy = (ctx.shape, dy) + saved_tensors[2:] + trans_a = _is_transposed(lhs) + trans_b = _is_transposed(rhs) + + dlhs = None + if ctx.needs_input_grad[0]: + op = dds if trans_a else dsd + dlhs = _lhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + drhs = None + if ctx.needs_input_grad[1]: + op = dsd if trans_b else dds + drhs = _rhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + return dlhs, drhs, None, None, None, None, None, None, None, None + + +sdd = SDD.apply + +class RowIndices(torch.autograd.Function): + + @staticmethod + def forward(ctx, shape, data, offsets, column_indices): + out = torch.empty( + column_indices.shape, + dtype=column_indices.dtype, + device=column_indices.device) + backend.row_indices(shape, data, offsets, column_indices, out) + return out + + +row_indices = RowIndices.apply diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/backend/triton_kernels.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/backend/triton_kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/backend/triton_kernels.py @@ -0,0 +1,393 @@ +import torch +import triton +import triton.language as tl +from dataclasses import dataclass + +@dataclass +class TritonConfig: + BLOCK_M: int = 128 + BLOCK_N: int = 128 + BLOCK_K: int = 32 + BLOCK_SIZE: int = 128 + NUM_STAGES: int = 4 + NUM_WARPS: int = 4 + +def _validate_matmul_dims(M: int, K: int, N: int): + error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}" + assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M) + assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, 
TritonConfig.BLOCK_K) + assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _sdd_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + # matrix multiplication + pid = tl.program_id(0) + pid_m = tl.load(row_indices + pid) + pid_n = tl.load(column_indices + pid) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(A) + b = tl.load(B) + acc += tl.dot(a, b) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + #Store to sparse matrix + acc = acc.to(C.dtype.element_ty) + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + cm = tl.arange(0, BLOCK_M) + cn = tl.arange(0, BLOCK_N) + C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _dsd_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, offsets, + block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + + # matrix multiplication + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) + + start_inx = tl.load(offsets + pid_m) + end_inx = tl.load(offsets + pid_m + 1) + + # pointers to sparse matrix + rm = tl.arange(0, BLOCK_M) + rak = tl.arange(0, BLOCK_K) + + A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) + + # pointers to dense matrix + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rbk = tl.arange(0, BLOCK_K) + B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) + + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) + + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + ak_sub_incr = BLOCK_K * stride_ak + bk_sub_incr = BLOCK_K * stride_bk + bk_block_incr = BLOCK_SIZE * stride_bk + + for k in range(nsub_blocks * (end_inx - start_inx)): + sub_block_inx = k % nsub_blocks + block_inx = k // nsub_blocks + 
+ if trans_A: + ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr + else: + ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr + + ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr + + a = tl.load(ptr_A) + b = tl.load(ptr_B) + acc += tl.dot(a, b) + + acc = acc.to(C.dtype.element_ty) + + cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _dds_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, offsets, + block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + + # matrix multiplication + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) + + start_inx = tl.load(offsets + pid_n) + end_inx = tl.load(offsets + pid_n + 1) + + # pointers to dense matrix + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rak = tl.arange(0, BLOCK_K) + + A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) + + # pointers to sparse matrix + rn = tl.arange(0, BLOCK_N) + rbk = tl.arange(0, BLOCK_K) + B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) + + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) + + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + + ak_sub_incr = BLOCK_K * stride_ak + ak_block_incr = BLOCK_SIZE * stride_ak + bk_sub_incr = BLOCK_K * stride_bk + + for k in range(nsub_blocks * (end_inx - start_inx)): + sub_block_inx = k % nsub_blocks + block_inx = k // nsub_blocks + + if trans_B: + ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr + else: + ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr + + ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr + a = tl.load(ptr_A) + b = tl.load(ptr_B) + acc += tl.dot(a, b) + + acc = acc.to(C.dtype.element_ty) + cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +def dsd(shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs, + out + ): + + device = rhs.device + trans_A = transpose_a + trans_B = False + + if rhs.stride(0) > 1 and rhs.stride(1) > 1: + trans_B = True + + # checks constraints + assert shape[1] == rhs.shape[0], "incompatible dimensions" + M, K = shape + _, N = rhs.shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else 
tl.int32 + + stride_am, stride_ak = data.stride(1), data.stride(2) + stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) + a_column_indices = column_indices + a_offsets = offsets + + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) + + if trans_A: + stride_am, stride_ak = data.stride(2), data.stride(1) + a_column_indices, a_offsets = column_indices_t, offsets_t + + if trans_B: + stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) + + _dsd_kernel[grid]( + data.data, rhs, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(0), out.stride(1), + row_indices, a_column_indices, a_offsets, + block_offsets_t, trans_A, trans_B, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + # return out + +def dds(lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b, + out + ): + + device = lhs.device + trans_B = transpose_b + trans_A = False + + if lhs.stride(0) > 1 and lhs.stride(1) > 1: + trans_A = True + + # checks constraints + assert lhs.shape[1] == shape[0], "incompatible dimensions" + M, K = lhs.shape + _, N = shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + stride_am, stride_ak = lhs.stride(0), lhs.stride(1) + stride_bk, stride_bn = data.stride(1), data.stride(2) + b_column_indices = column_indices_t + b_offsets = offsets_t + + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) + + if trans_A: + stride_am, stride_ak = lhs.stride(1), lhs.stride(0) + if trans_B: + stride_bk, stride_bn = data.stride(2), data.stride(1) + b_column_indices, b_offsets = column_indices, offsets + + _dds_kernel[grid]( + lhs, data, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(0), out.stride(1), + row_indices, b_column_indices, b_offsets, + block_offsets_t, trans_A, trans_B, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + +def sdd(lhs, + rhs, + shape, + out, + offsets, + row_indices, + column_indices + ): + + device = out.device + trans_A = False + trans_B = False + + if lhs.stride(0) > 1 and lhs.stride(1) > 1: + trans_A = True + if rhs.stride(0) > 1 and rhs.stride(1) > 1: + trans_B = True + + # checks constraints + assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions" + M, K = lhs.shape + _, N = rhs.shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + # launch kernel + nnz_blocks = len(row_indices) + grid = lambda META: (nnz_blocks,) + + stride_am, stride_ak = lhs.stride(0), lhs.stride(1) + stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) + + if trans_A: + stride_am, stride_ak = lhs.stride(1), lhs.stride(0) + if trans_B: + stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) + + _sdd_kernel[grid]( + lhs, rhs, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(1), out.stride(2), + row_indices, column_indices, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + +@triton.jit +def _row_indices_kernel(offsets, out): + pid = tl.program_id(0) + row_offset = tl.load(offsets + pid) + nnz_blocks = tl.load(offsets + pid + 1) - row_offset + for nnz_block in range(nnz_blocks): + tl.store(out + row_offset + nnz_block, pid) + +def row_indices( + shape, data, offsets, column_indices, out +): + block_rows = len(offsets) - 1 + _row_indices_kernel[(block_rows, )](offsets, 
out) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/matrix.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/matrix.py @@ -0,0 +1,329 @@ +import numpy as np +import torch + +# 1. Add heavyweight (data) validation helper. +# 2. Add construction helpers +# 3. Make indentation consistent +# 4. Replace asserts with descriptive errors. + +## +### Validation helpers. +## + + +def _validate_matrix(shape, data, row_indices, column_indices, offsets): + # Data should be [nnz, block_size, block_size] + if data.dim() == 1: + data = torch.reshape(data, [data.numel(), 1, 1]) + + # Blocks should be square. + if data.shape[-2] != data.shape[-1]: + raise ValueError( + "Expected square blocking in data. " + f"Got block shape {[data.shape[-2], data.shape[-1]]}") + + # Flatten batch dimensions on data - original shape preserved + # in shape argument. + block_size = data.shape[-1] + data = data.view([-1, block_size, block_size]) + + if data.dim() != 3: + raise ValueError( + "Expected 3D shape for data (nnz, block, block). " + f"Got shape {data.dim()}D shape.") + + block_size = data.shape[1] + if shape[-2] % block_size != 0 or shape[-1] % block_size != 0: + raise ValueError( + "Matrix shape must be dividible by blocking. " + f"Got shape {shape} with " + f"{[block_size, block_size]} blocking.") + + if np.prod(shape) < data.numel(): + raise ValueError( + "Invalid matrix. Number of nonzeros exceeds matrix capacity " + f"({data.numel()} v. {np.prod(shape)})") + + if row_indices.dim() != 1: + raise ValueError( + f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.") + + if column_indices.dim() != 1: + raise ValueError( + f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.") + + if offsets.dim() != 1: + raise ValueError( + f"Expected 1D offsets. Got {offsets.dim()}D offsets.") + + if row_indices.numel() != data.shape[0]: + raise ValueError( + "Expected 1 index per nonzero block. " + f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks") + + if column_indices.numel() != data.shape[0]: + raise ValueError( + "Expected 1 index per nonzero block. " + f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks") + + block_rows = np.prod(shape[:-1]) / block_size + if offsets.numel() != block_rows + 1: + raise ValueError( + "Expected one offset per block row plus one. " + f"Got {offsets.numel()} offsets with {block_rows} block rows.") + + is_cuda = (data.is_cuda and + row_indices.is_cuda and + column_indices.is_cuda and + offsets.is_cuda) + is_cpu = (not data.is_cuda and + not row_indices.is_cuda and + not column_indices.is_cuda and + not offsets.is_cuda) + if not (is_cuda or is_cpu): + raise ValueError( + "Expected data & meta-data on common device. " + f"Got data on {data.device}, row_indices on {row_indices.device} " + f"column_indices on {column_indices.device} and " + f"offsets on {offsets.device}.") + + if data.dtype != torch.float16: + raise ValueError( + f"Expected float16 data. Got {data.dtype} data.") + if row_indices.dtype != torch.int16: + raise ValueError( + f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.") + if column_indices.dtype != torch.int16: + raise ValueError( + f"Expected int16 column_indices. 
Got {column_indices.dtype} column_indices.") + if offsets.dtype != torch.int32: + raise ValueError( + f"Expected int32 offsets. Got {offsets.dtype} offsets.") + return data + + +def _transpose(size, data, row_indices, column_indices, offsets): + block_columns = size[1] // data.shape[1] + + # Sort row indices by column indices to get the transposed matrix's + # column indices. + gather_indices = column_indices.argsort() + column_indices_t = row_indices.gather(0, gather_indices) + block_offsets_t = gather_indices.int() + + # NOTE: Histogram is not implemented for any integer type on CPU. Do + # the histogram in 32-bit float, which can exactly represent 16-bit + # integers. + column_indices_float = column_indices.float() + + zero = torch.zeros((1,), dtype=torch.int32, device=data.device) + nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns) + nnz_per_column = nnz_per_column.int() + offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)]) + return column_indices_t, offsets_t, block_offsets_t + + +class Matrix(torch.nn.Module): + """A matrix stored in sparse format. + + Underlying format is block compressed sparse row (BCSR). + + TODO(tgale): Make this mirror torch.Tensor API as much as possible. + """ + + def __init__(self, + size, + data, + row_indices, + column_indices, + offsets, + column_indices_t=None, + offsets_t=None, + block_offsets_t=None): + super().__init__() + self._size = size + self._data = data + self._row_indices = row_indices + self._column_indices = column_indices + self._offsets = offsets + + # Produce the transpose meta-data if it is not passed in. + if ((column_indices_t is None) or (offsets_t is None) or + (block_offsets_t is None)): + column_indices_t, offsets_t, block_offsets_t = _transpose( + size, data, row_indices, column_indices, offsets) + self._column_indices_t = column_indices_t + self._offsets_t = offsets_t + self._block_offsets_t = block_offsets_t + + self._transposed = False + + # Validate that our metadata will not overflow. + max_dim = np.iinfo(np.int16).max * self.blocking + if column_indices.dtype == torch.int16: + if size[0] > max_dim or size[1] > max_dim: + raise ValueError( + "Sparse matrix with shape {size} exceeds representable " + "size with 16-bit indices.") + + def validate(self): + _validate_matrix(self._size, + self._data, + self._row_indices, + self._column_indices, + self._offsets) + + # TODO(tgale): Add heavyweight data validation. + + def to(self, device): + # TODO(tgale): Handle type conversions here. We + # need to set the appropriate meta-data type for + # the given floating-point type. 
+ self._data = self._data.to(device) + self._row_indices = self._row_indices.to(device) + self._column_indices = self._column_indices.to(device) + self._offsets = self._offsets.to(device) + self._column_indices_t = self._column_indices_t.to(device) + self._offsets_t = self._offsets_t.to(device) + self._block_offsets_t = self._block_offsets_t.to(device) + return self + + def cuda(self): + return self.to(torch.cuda.current_device()) + + def clone(self): + return Matrix( + self.size(), + self.data.clone(), + self.row_indices.clone(), + self.column_indices.clone(), + self.offsets.clone(), + self.column_indices_t.clone(), + self.offsets_t.clone(), + self.block_offsets_t.clone()) + + def t(self): + if self.dim() != 2: + raise ValueError( + "t() expects a tensor with <= 2 dimensions, " + f"but self is {self.dim()}D.") + out = Matrix(self.size(), + self.data, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + out._transposed = not self._transposed + out._size = torch.Size((self._size[1], self._size[0])) + return out + + def contiguous(self): + raise ValueError("Not yet implemented.") + + def is_contiguous(self): + return not self._transposed + + @property + def is_cuda(self): + return self._data.is_cuda + + @property + def device(self): + return self._data.device + + def size(self): + return self._size + + @property + def shape(self): + return self.size() + + def dim(self): + return len(self._size) + + @property + def data(self): + return self._data + + @property + def row_indices(self): + return self._row_indices + + @property + def column_indices(self): + return self._column_indices + + @property + def offsets(self): + return self._offsets + + @property + def offsets_t(self): + return self._offsets_t + + @property + def column_indices_t(self): + return self._column_indices_t + + @property + def block_offsets_t(self): + return self._block_offsets_t + + @property + def dtype(self): + return self.data.dtype + + @property + def nnz(self): + return self.data.numel() + + @property + def blocking(self): + return self.data.shape[1] + + @property + def requires_grad(self): + return self.data.requires_grad + + def requires_grad_(self, x): + self.data.requires_grad_(x) + return self + + def view(self, *shape): + assert self.is_contiguous() + if shape[-1] != self.size()[-1]: + raise ValueError( + "Can't change view on compressed dimension. " + f"{self.size()[-1]} v. {shape[-1]}.") + if np.prod(shape) != np.prod(self.size()): + raise ValueError( + "Mismatch in numel of Matrix and new shape. " + f"{np.prod(self.size())} v. {np.prod(shape)}") + return Matrix(shape, + self.data, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + + @property + def grad(self): + # TODO(tgale): Make sure this mirrors torch.Tensor + # behavior in the case where we ask for the gradient + # of a non-contiguous tensor. 
+ size = self.size() + if not self.is_contiguous(): + size = torch.Size((size[1], size[0])) + out = Matrix(size, + self.data.grad, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + return out if self.is_contiguous() else out.t() diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/__init__.py @@ -0,0 +1,3 @@ +from .linear_ops import dds, dsd, sdd +from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse +from .eltwise_ops import mul diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops.py @@ -0,0 +1,28 @@ +from ..matrix import Matrix + +def mul(a, b): + """Performs element-wise multiplication of matrices a and b. + + It is the user's responsibility to make sure that a and b + follow the same matrix topology. This function assumes it is safe + to use the topoplogy of a. + + Args: + a: stk.Matrix. + b: stk.Matrix with a's matrix topology. + + Returns: + stk.Matrix where the entries correspond to torch.mul(a, b). + """ + assert isinstance(a, Matrix) + assert isinstance(b, Matrix) + assert a.size() == b.size() + + return Matrix(a.size(), + a.data * b.data, + a.row_indices, + a.column_indices, + a.offsets, + a.column_indices_t, + a.offsets_t, + a.block_offsets_t) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..66bfd4f6af77042d3c5bdb1fe18d00e457478d46 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py @@ -0,0 +1,86 @@ +import unittest +import itertools +import torch +from absl.testing import parameterized + +import stk +from stk.ops.linear_ops_test import allclose, _dense_and_sparse + +_MATRIX_SIZES = ( + (128, 128, 0.0), + (256, 256, 0.5), + (2048, 1024, 0.8), + (512, 128, 0.0), + (128, 512, 0.0), + (1024, 512, 0.0), + (1024, 512, 0.5), + (1024, 512, 0.75), + (512, 1024, 0.0), + (512, 1024, 0.5), + (512, 1024, 0.75), + (1024, 1024, 0.0), + (1024, 1024, 0.5), + (1024, 1024, 0.75), +) + +_DTYPE = ( + torch.float16, torch.bfloat16 +) + +def _generate_testcases(): + testcases = itertools.product(_MATRIX_SIZES, _DTYPE) + testcases = [(*size, 128, dtype) for + (size, dtype) in testcases] + return testcases + +_ELTWISE_OP_TESTS = _generate_testcases() + +def _dense_and_sparse_like(x, std=0.1): + dense_data = torch.randn_like(x.data, device=x.device) * std + sparse = stk.Matrix(x.size(), + dense_data, + x.row_indices, + x.column_indices, + x.offsets) + dense = stk.ops.to_dense(sparse) + + return (dense.requires_grad_(True), + sparse.requires_grad_(True)) + +@parameterized.parameters(_ELTWISE_OP_TESTS) +class EltwiseOpsTest(parameterized.TestCase): + + def testEltwiseMul(self, m, n, sparsity, blocking, dtype): + + a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype) + b_dense, b = 
_dense_and_sparse_like(a) + + out = stk.ops.mul(a, b) + expected_out = torch.mul(a_dense, b_dense) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + stk.ops.sum(out).backward() + + # Validate the results. + out = stk.ops.to_dense(out) + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size(), out.size()) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = stk.ops.to_dense(a.grad) + expected_grad = a_dense.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size(), grad.size()) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = stk.ops.to_dense(b.grad) + expected_grad = b_dense.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size(), grad.size()) + self.assertTrue(allclose(grad, expected_grad)) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/linear_ops.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/linear_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/linear_ops.py @@ -0,0 +1,59 @@ +import torch + +from ..backend import sputnik +from ..matrix import Matrix + + +def dsd(a, b): + assert isinstance(a, Matrix) + assert isinstance(b, torch.Tensor) + return sputnik.dsd( + a.size(), + a.data, a.offsets, + a.row_indices, + a.column_indices, + a.offsets_t, + a.column_indices_t, + a.block_offsets_t, + not a.is_contiguous(), + b) + + +def dds(a, b): + assert isinstance(a, torch.Tensor) + assert isinstance(b, Matrix) + return sputnik.dds( + a, + b.size(), + b.data, b.offsets, + b.row_indices, + b.column_indices, + b.offsets_t, + b.column_indices_t, + b.block_offsets_t, + not b.is_contiguous()) + + +def sdd(a, b, topo): + assert isinstance(a, torch.Tensor) + assert isinstance(b, torch.Tensor) + assert isinstance(topo, Matrix) + assert topo.is_contiguous() + out = sputnik.sdd( + a, b, + topo.size(), + topo.data, + topo.offsets, + topo.row_indices, + topo.column_indices, + topo.offsets_t, + topo.column_indices_t, + topo.block_offsets_t) + return Matrix(topo.size(), + out, + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/linear_ops_test.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/linear_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ced1d782fbc9f9ca16b3449239f1588dc5ff5e00 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/linear_ops_test.py @@ -0,0 +1,216 @@ +import unittest +import itertools +import numpy as np +import torch +from absl.testing import parameterized + +import stk + + +def allclose(x, y, pct=0.25): + mask = torch.isclose(x, y, rtol=5e-2) + pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 + if pct_diff > pct: + print("{:.2f}% of values not close.".format(pct_diff)) + return False + return True + + +# An assortment of problems designed to make sure +# the bindings are operating correctly. 
+_MATRIX_SIZES = ( + (128, 128, 128, 0.0), + (256, 256, 256, 0.5), + (2048, 1024, 512, 0.8), + (512, 128, 128, 0.0), + (128, 128, 512, 0.0), + (1024, 512, 512, 0.0), + (1024, 512, 512, 0.5), + (1024, 512, 512, 0.75), + (512, 512, 1024, 0.0), + (512, 512, 1024, 0.5), + (512, 512, 1024, 0.75), + (1024, 1024, 1024, 0.0), + (1024, 1024, 1024, 0.5), + (1024, 1024, 1024, 0.75), +) + +_TRANSPOSE = ( + (False, False), + (False, True), + (True, False), + (True, True), +) + +_DTYPE = ( + torch.float16, torch.bfloat16 +) + +def _generate_testcases(): + testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE) + testcases = [(*size, *trans, 128, dtype) for + (size, trans, dtype) in testcases] + return testcases + +_LINEAR_OP_TESTS = _generate_testcases() + +def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + dense = (torch.randn(rows, cols) * std * mask).type(dtype) + sparse = stk.ops.to_sparse(dense, blocking) + cuda_device = torch.device("cuda") + return (dense.to(cuda_device).requires_grad_(True), + sparse.to(cuda_device).requires_grad_(True)) + + +def _dense(rows, cols, dtype, std=0.1): + cuda_device = torch.device("cuda") + out = (torch.randn(rows, cols) * std).type(dtype) + return out.to(cuda_device).requires_grad_(True) + + +def _dense_2x(rows, cols, dtype): + a = _dense(rows, cols, dtype) + return a, a.detach().requires_grad_(True) + + +def _with_transpose(op, a, b, trans_a, trans_b): + a = a.t() if trans_a else a + b = b.t() if trans_b else b + return op(a, b) + + +def _mmm(a, b, topo): + mask = stk.ops.to_dense(stk.ops.ones_like(topo)) + return torch.mm(a, b) * mask + + +def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b): + a = a.t() if trans_a else a + b = b.t() if trans_b else b + return op(a, b, topo) + + +def _mask(x, mask): + mask = stk.ops.to_dense(stk.ops.ones_like(mask)) + return x * mask + + +@parameterized.parameters(*_LINEAR_OP_TESTS) +class LinearOpsTest(parameterized.TestCase): + + def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. + a_shape = (k, m) if trans_a else (m, k) + a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype) + b_shape = (n, k) if trans_b else (k, n) + b, bcp = _dense_2x(*b_shape, dtype) + + # Execute the matmul. + out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b) + expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + out.sum().backward() + + # Validate the results. + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = stk.ops.to_dense(a.grad) + expected_grad = _mask(a_dense.grad, a.grad) + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = b.grad + expected_grad = bcp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. 
+ a_shape = (k, m) if trans_a else (m, k) + a, acp = _dense_2x(*a_shape, dtype) + b_shape = (n, k) if trans_b else (k, n) + b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype) + + # Execute the matmul. + out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b) + expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + out.sum().backward() + + # Validate the results. + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = a.grad + expected_grad = acp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = stk.ops.to_dense(b.grad) + expected_grad = _mask(b_dense.grad, b.grad) + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. + a_shape = (k, m) if trans_a else (m, k) + a, acp = _dense_2x(*a_shape, dtype) + b_shape = (n, k) if trans_b else (k, n) + b, bcp = _dense_2x(*b_shape, dtype) + _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype) + + # Execute the matmul. + out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b) + expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + stk.ops.sum(out).backward() + + # Validate the results. + out = stk.ops.to_dense(out) + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = a.grad + expected_grad = acp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = b.grad + expected_grad = bcp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops.py @@ -0,0 +1,98 @@ +from ..backend import sputnik +from ..matrix import Matrix +import torch +import numpy as np + + +@torch.no_grad() +def row_indices(shape, data, offsets, column_indices): + return sputnik.row_indices(shape, data, offsets, column_indices) + + +# TODO(tgale): Replace this helper with a custom kernel. This operation +# is much simpler to do than how it's currently implemented. 
+@torch.no_grad() +def _expand_for_blocking(idxs, blocking): + # Duplicate for block column dimension. + idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1) + + # Update the column indices. + idxs[:, :, 1] *= blocking + idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking]) + + # Duplicate for block row dimension. + idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2]) + idxs = idxs.repeat(1, blocking, 1, 1) + + # Update the row indices. + idxs[:, :, :, 0] *= blocking + idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1]) + idxs = torch.reshape(idxs, [-1, 2]) + return idxs + + +# TODO(tgale): Add input type checking. +@torch.no_grad() +def to_dense(x): + assert isinstance(x, Matrix) + + shape = (np.prod(x.shape[:-1]), x.shape[-1]) + row_idxs = x.row_indices.type(torch.int32) + col_idxs = x.column_indices.type(torch.int32) + indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking) + indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64) + + out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device) + out.scatter_(0, indices, x.data.flatten()) + return out.reshape(x.size()) + + +@torch.no_grad() +def _mask(x, blocking=1): + assert x.dim() == 2 + assert x.size()[0] % blocking == 0 + assert x.size()[1] % blocking == 0 + block_rows = x.size()[0] // blocking + block_cols = x.size()[1] // blocking + x = torch.reshape(x, [block_rows, blocking, block_cols, blocking]) + x = torch.sum(torch.abs(x), dim=(1, 3)) + return x != 0 + + +# TODO(tgale): Add input type checking. +@torch.no_grad() +def to_sparse(x, blocking=1): + m = _mask(x, blocking) + + # TODO(tgale): Set to appropriate type for input matrix. + row_nnzs = torch.sum(m, dim=1).type(torch.int32) + zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device) + offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)]) + offsets = offsets.type(torch.int32) + + indices = torch.nonzero(m).type(torch.int16) + row_indices = indices[:, 0] + column_indices = indices[:, 1] + + # Nonzero indices in the dense matrix. + nonzero_indices = torch.nonzero(m) + nonzero_indices = _expand_for_blocking(nonzero_indices, blocking) + nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1] + + # Gather the data and construct the sparse matrix. 
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices) + data = torch.reshape(data, [-1, blocking, blocking]) + return Matrix(x.size(), data, row_indices, column_indices, offsets) + + +@torch.no_grad() +def ones_like(x): + return Matrix(x.size(), + torch.ones_like(x.data), + x.row_indices, + x.column_indices, x.offsets) + + +def sum(x): + assert isinstance(x, Matrix) + return x.data.sum() diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3af04c0760483e578f93303dc457415948a2a34c --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py @@ -0,0 +1,62 @@ +import unittest + +from absl.testing import parameterized +import stk +import torch + + +@parameterized.parameters( + (8, 16, 0.0, 1), + (8, 16, 0.5, 1), + (8, 16, .95, 1), + (16, 8, 0.0, 1), + (16, 8, 0.5, 1), + (16, 8, .95, 1), + (8, 16, 0.0, 8), + (8, 16, 0.5, 8), + (8, 16, 1.0, 8), + (16, 8, 0.0, 8), + (16, 8, 0.5, 8), + (16, 8, 1.0, 8), + (128, 256, 0.5, 16), + (256, 128, 0.75, 32), + (512, 512, .875, 128)) +class MatrixOpsTest(parameterized.TestCase): + + def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + x = (torch.randn(rows, cols) * mask).type(torch.float16) + + # Convert the matrix to sparse format. + sparse_x = stk.ops.to_sparse(x, blocking) + + # Validate the matrix. + sparse_x.validate() + + # Validate the shape. + self.assertEqual(sparse_x.dim(), 2) + self.assertEqual(sparse_x.size()[0], rows) + self.assertEqual(sparse_x.size()[1], cols) + + # Validate the sparsity. + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual(sparse_x.nnz, nnz) + + # Convert back to dense format. + dense_x = stk.ops.to_dense(sparse_x) + + # Validate the shape. + self.assertEqual(dense_x.dim(), 2) + self.assertEqual(dense_x.size()[0], rows) + self.assertEqual(dense_x.size()[1], cols) + + # Validate the sparsity + self.assertEqual(torch.count_nonzero(dense_x).item(), nnz) + + # Validate the output. 
+ self.assertTrue(torch.all(torch.eq(x, dense_x))) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/random/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/random/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/random/__init__.py @@ -0,0 +1,2 @@ +# from stk.random.random_ops import dense_mask, mask, randn +from .random_ops import dense_mask, mask, randn diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/random/random_ops.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/random/random_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/random/random_ops.py @@ -0,0 +1,36 @@ +import numpy as np +import torch +from ..ops import matrix_ops + + +@torch.no_grad() +def dense_mask(rows, cols, sparsity, blocking=1): + assert sparsity >= 0.0 and sparsity <= 1.0 + assert rows % blocking == 0 and cols % blocking == 0 + + block_rows, block_cols = (rows // blocking, cols // blocking) + nnz = round(block_rows * block_cols * (1 - sparsity)) + + out = np.ones(block_rows * block_cols) + mask = np.random.choice(out.size, out.size - nnz, replace=False) + out[mask] = 0.0 + + out = np.tile( + np.reshape(out, [block_rows, 1, block_cols, 1]), + (1, blocking, 1, blocking)) + out = np.reshape(out, [rows, cols]) + return torch.from_numpy(out.astype(np.float32)) + + +@torch.no_grad() +def mask(m, n, sparsity, blocking=1): + out = dense_mask(m, n, sparsity, blocking).type(torch.float16) + return matrix_ops.to_sparse(out, blocking=blocking) + + +@torch.no_grad() +def randn(shape, sparsity, blocking=1): + shape_2d = (np.prod(shape[:-1]), shape[-1]) + out = mask(*shape_2d, sparsity, blocking) + out.data.copy_(torch.randn(*out.data.shape)) + return out.view(*shape) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/random/random_ops_test.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/random/random_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..587b44ec890c861879c6296b8f9028f5d99ab82f --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/random/random_ops_test.py @@ -0,0 +1,73 @@ +import unittest + +from absl.testing import parameterized +from . import random +import torch + + +@parameterized.parameters( + (8, 16, 0.0, 1), + (8, 16, 0.5, 1), + (8, 16, .95, 1), + (16, 8, 0.0, 1), + (16, 8, 0.5, 1), + (16, 8, .95, 1), + (8, 16, 0.0, 8), + (8, 16, 0.5, 8), + (8, 16, 1.0, 8), + (16, 8, 0.0, 8), + (16, 8, 0.5, 8), + (16, 8, 1.0, 8), + (128, 256, 0.5, 16), + (256, 128, 0.75, 32), + (512, 512, .875, 128)) +class RandomOpsTest(parameterized.TestCase): + + def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking): + mask = random.dense_mask( + rows, cols, sparsity, blocking) + + # Validate the shape. + self.assertEqual(mask.dim(), 2) + self.assertEqual(mask.size()[0], rows) + self.assertEqual(mask.size()[1], cols) + + # Validate the sparsity + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual( + torch.count_nonzero(mask).item(), + nnz) + + # Check values are zero or one. 
+ self.assertTrue( + torch.all(torch.logical_or( + torch.eq(mask, 0), + torch.eq(mask, 1)))) + + def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking): + mask = random.mask( + rows, cols, sparsity, blocking) + + # Validate the matrix. + mask.validate() + + # Validate the shape. + self.assertEqual(mask.dim(), 2) + self.assertEqual(mask.size()[0], rows) + self.assertEqual(mask.size()[1], cols) + + # Validate the sparsity. + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual(mask.nnz, nnz) + + # Check values are zero or one. + self.assertTrue( + torch.all(torch.logical_or( + torch.eq(mask.data, 0), + torch.eq(mask.data, 1)))) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py index a31770ba179fed06abf2da10102ccaeed1d3ee4e..0e1d956704840aa4daf7d1d71d24e051567feab9 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py @@ -4,7 +4,7 @@ from typing import Any, Callable, Union import torch -from stk import Matrix +from ..stk import Matrix def act_fn( diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py index 561fd7982254513e9a9b4f0514edd326b0561256..6d0375a4df2f27134c4127e60be04f3b45693050 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py @@ -2,15 +2,22 @@ # SPDX-License-Identifier: Apache-2.0 import numpy as np -import stk.ops import torch -from stk import Matrix + +# try: +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.', +# ) # import megablocks.ops as ops # # from megablocks.ops import ops # from megablocks.layers import common, dmlp_registry, moe, mpu # from megablocks.layers.arguments import Arguments +from .. import stk from .. import ops from . import common, dmlp_registry, moe, mpu from .arguments import Arguments diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py index 40b601d4a8eb59e80e1090f31f4172b8f7fb7549..c4c9e6532798615b5c12c96694241a4c18ee8f7b 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py @@ -1,7 +1,16 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -import stk +# try: +# import stk +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.', +# ) + +from .. 
import stk + import torch import torch.nn.functional as F diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py index 69ef9ead5461a5b91ace107652fd8d1c3d300b1e..5f297a41ff6a1a2a285f5b461951672364b898da 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py @@ -1,7 +1,17 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -import stk.ops +# import stk.ops +# try: +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.', +# ) + +from .. import stk + import torch # from megablocks import grouped_gemm_util as gg diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py index 0b6747b1eec2047eb9480592ca550b6c9e868f5e..c99afb9904c24a8b6a83e79059cd1251dbbfd99e 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py @@ -3,9 +3,18 @@ from typing import Any -import stk -import stk.backend.triton_kernels -import stk.ops +# try: +# import stk +# import stk.backend.triton_kernels +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.', +# ) + +from .. import stk + import torch from packaging import version diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so deleted file mode 100755 index 4668fc389f1ca9668919303afe1890f8e92abff7..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:a1326299d5c4185310ff6f43fc0b3b71d48b1bd31001c954b3388d3cb5e08fbc -size 10517816 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_63599de.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_63599de.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..e41c843ab4f5c7f9fd4300ad37fee3c5466c7c8c --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_63599de.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:002a58b415ed9e0f6418b103368c4f57f17fa86a851a02f594a33b097b33da09 +size 10517816 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py index f7aa5a700bc5673c743a3b0fb74107fab793fbad..03e1e323ce98ecea8dda071c089f0d7ba75551b6 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_0586ba6 -ops = torch.ops._megablocks_0586ba6 +from . import _megablocks_63599de +ops = torch.ops._megablocks_63599de def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_megablocks_0586ba6::{op_name}" \ No newline at end of file + return f"_megablocks_63599de::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py index 9b2c7f062071758def741bf3af2392d3a4855f89..7ccc5dcec5e9a663794fad944c45285869c4d1c1 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -3,7 +3,19 @@ import unittest -import stk + +# import stk + +# try: +# import stk +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.', +# ) + +from .. import stk + import torch from absl.testing import parameterized diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/__init__.py @@ -0,0 +1,7 @@ +# import stk.random +# import stk.ops +# from stk.matrix import Matrix + +from . import random +from . import ops +from .matrix import Matrix diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/autocast.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/autocast.py new file mode 100644 index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/autocast.py @@ -0,0 +1,37 @@ +import functools +import torch + + +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, map): + return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, list) or isinstance(x, tuple): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def custom_fwd(fwd): + """Wrap a custom autograd function that always uses autocast dtype.""" + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + return decorate_fwd + + +def custom_bwd(bwd): + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + return decorate_bwd diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/sputnik.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/sputnik.py new file mode 100644 index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/sputnik.py @@ -0,0 +1,316 @@ +import torch + +from ..backend import triton_kernels as backend +from ..backend.autocast 
import custom_bwd, custom_fwd + + +def _standardize_shape(x, transpose): + if transpose: + return torch.Size((x[1], x[0])) + return x + + +def _sparse_transpose(x): + return (torch.Size((x[0][1], x[0][0])), ) + x[1:] + + +def _transpose_helper(x, transpose): + if isinstance(x, torch.Tensor): + return x.t() if transpose else x + if transpose: + x = _sparse_transpose(x) + return x + (transpose,) + + +def _wrap(x): + if isinstance(x, torch.Tensor): + return (x,) + return x + + +def _is_transposed(x): + return (not x.is_contiguous() and + x.stride()[0] == 1 and + x.stride()[1] == x.size()[0]) + + +def _call_helper(op, out, a, b, trans_a, trans_b): + args = (_wrap(_transpose_helper(a, trans_a)) + + _wrap(_transpose_helper(b, trans_b))) + if isinstance(out, tuple): + args = args + out + return op(*args) + + +def _preprocess_inputs(lhs, rhs, dy): + if isinstance(lhs, torch.Tensor) and _is_transposed(lhs): + lhs = lhs.t() + if isinstance(rhs, torch.Tensor) and _is_transposed(rhs): + rhs = rhs.t() + if (isinstance(dy, torch.Tensor) and + not dy.is_contiguous() and + not _is_transposed(dy)): + dy = dy.contiguous() + if isinstance(dy, tuple) and not dy[1].is_contiguous(): + dy = (dy[0], dy[1].contiguous()) + dy[2:] + return lhs, rhs, dy + + +def _postprocess_outputs(x, transpose, grad): + if isinstance(x, torch.Tensor) and transpose: + return grad.t() + return grad + + +def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): + lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) + + a, b = (rhs, dy) if trans_lhs else (dy, rhs) + trans_a = trans_lhs and trans_rhs + trans_b = trans_lhs or not trans_rhs + out = _call_helper(op, lhs, a, b, trans_a, trans_b) + return _postprocess_outputs(lhs, trans_lhs, out) + + +def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): + lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) + + a, b = (dy, lhs) if trans_rhs else (lhs, dy) + trans_a = not trans_lhs or trans_rhs + trans_b = trans_lhs and trans_rhs + out = _call_helper(op, rhs, a, b, trans_a, trans_b) + return _postprocess_outputs(rhs, trans_rhs, out) + + +class DSD(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs): + ctx.save_for_backward(data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + rhs) + ctx.shape = _standardize_shape(shape, transpose_a) + ctx.transpose_a = transpose_a + + out = torch.empty( + (shape[0], rhs.size()[1]), + dtype=rhs.dtype, + device=rhs.device) + + backend.dsd(shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs, + out) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs = (ctx.shape,) + saved_tensors[:-1] + rhs = saved_tensors[-1] + trans_a = ctx.transpose_a + trans_b = _is_transposed(rhs) + + ddata = None + if ctx.needs_input_grad[1]: + ddata = _lhs_gradient(sdd, + lhs, + rhs, + dy, + trans_a, + trans_b) + drhs = None + if ctx.needs_input_grad[-1]: + op = dds if trans_b else dsd + drhs = _rhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + return None, ddata, None, None, None, None, None, None, None, drhs + + +dsd = DSD.apply + + +class DDS(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + 
block_offsets_t, + transpose_b): + ctx.save_for_backward(lhs, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t) + ctx.shape = _standardize_shape(shape, transpose_b) + ctx.transpose_b = transpose_b + out = torch.empty((lhs.size()[0], shape[1]), + dtype=lhs.dtype, + device=lhs.device) + backend.dds(lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b, + out) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs = saved_tensors[0] + rhs = (ctx.shape,) + saved_tensors[1:] + trans_a = _is_transposed(lhs) + trans_b = ctx.transpose_b + + dlhs = None + if ctx.needs_input_grad[0]: + op = dsd if trans_a else dds + dlhs = _lhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + ddata = None + if ctx.needs_input_grad[2]: + ddata = _rhs_gradient(sdd, + lhs, + rhs, + dy, + trans_a, + trans_b) + return dlhs, None, ddata, None, None, None, None, None, None, None + + +dds = DDS.apply + + +class SDD(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + lhs, + rhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t): + ctx.save_for_backward( + lhs, + rhs, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t) + ctx.shape = shape + out = torch.empty( + data.shape, + dtype=lhs.dtype, + device=lhs.device) + backend.sdd(lhs, + rhs, + shape, + out, + offsets, + row_indices, + column_indices) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs, rhs = saved_tensors[:2] + dy = (ctx.shape, dy) + saved_tensors[2:] + trans_a = _is_transposed(lhs) + trans_b = _is_transposed(rhs) + + dlhs = None + if ctx.needs_input_grad[0]: + op = dds if trans_a else dsd + dlhs = _lhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + drhs = None + if ctx.needs_input_grad[1]: + op = dsd if trans_b else dds + drhs = _rhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + return dlhs, drhs, None, None, None, None, None, None, None, None + + +sdd = SDD.apply + +class RowIndices(torch.autograd.Function): + + @staticmethod + def forward(ctx, shape, data, offsets, column_indices): + out = torch.empty( + column_indices.shape, + dtype=column_indices.dtype, + device=column_indices.device) + backend.row_indices(shape, data, offsets, column_indices, out) + return out + + +row_indices = RowIndices.apply diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/triton_kernels.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/triton_kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/triton_kernels.py @@ -0,0 +1,393 @@ +import torch +import triton +import triton.language as tl +from dataclasses import dataclass + +@dataclass +class TritonConfig: + BLOCK_M: int = 128 + BLOCK_N: int = 128 + BLOCK_K: int = 32 + BLOCK_SIZE: int = 128 + NUM_STAGES: int = 4 + NUM_WARPS: int = 4 + +def _validate_matmul_dims(M: int, K: int, N: int): + error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}" + assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M) + assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, 
TritonConfig.BLOCK_K) + assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _sdd_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + # matrix multiplication + pid = tl.program_id(0) + pid_m = tl.load(row_indices + pid) + pid_n = tl.load(column_indices + pid) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(A) + b = tl.load(B) + acc += tl.dot(a, b) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + #Store to sparse matrix + acc = acc.to(C.dtype.element_ty) + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + cm = tl.arange(0, BLOCK_M) + cn = tl.arange(0, BLOCK_N) + C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _dsd_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, offsets, + block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + + # matrix multiplication + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) + + start_inx = tl.load(offsets + pid_m) + end_inx = tl.load(offsets + pid_m + 1) + + # pointers to sparse matrix + rm = tl.arange(0, BLOCK_M) + rak = tl.arange(0, BLOCK_K) + + A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) + + # pointers to dense matrix + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rbk = tl.arange(0, BLOCK_K) + B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) + + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) + + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + ak_sub_incr = BLOCK_K * stride_ak + bk_sub_incr = BLOCK_K * stride_bk + bk_block_incr = BLOCK_SIZE * stride_bk + + for k in range(nsub_blocks * (end_inx - start_inx)): + sub_block_inx = k % nsub_blocks + block_inx = k // nsub_blocks + 
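+        # Each nonzero BLOCK_SIZE x BLOCK_SIZE block of the sparse operand is
+        # consumed in nsub_blocks = cdiv(BLOCK_SIZE, BLOCK_K) tiles along K,
+        # so the flat index k enumerates (nonzero block, K-tile) pairs:
+        # block_inx picks the block and sub_block_inx picks the tile inside it.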
+ if trans_A: + ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr + else: + ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr + + ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr + + a = tl.load(ptr_A) + b = tl.load(ptr_B) + acc += tl.dot(a, b) + + acc = acc.to(C.dtype.element_ty) + + cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _dds_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, offsets, + block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + + # matrix multiplication + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) + + start_inx = tl.load(offsets + pid_n) + end_inx = tl.load(offsets + pid_n + 1) + + # pointers to dense matrix + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rak = tl.arange(0, BLOCK_K) + + A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) + + # pointers to sparse matrix + rn = tl.arange(0, BLOCK_N) + rbk = tl.arange(0, BLOCK_K) + B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) + + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) + + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + + ak_sub_incr = BLOCK_K * stride_ak + ak_block_incr = BLOCK_SIZE * stride_ak + bk_sub_incr = BLOCK_K * stride_bk + + for k in range(nsub_blocks * (end_inx - start_inx)): + sub_block_inx = k % nsub_blocks + block_inx = k // nsub_blocks + + if trans_B: + ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr + else: + ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr + + ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr + a = tl.load(ptr_A) + b = tl.load(ptr_B) + acc += tl.dot(a, b) + + acc = acc.to(C.dtype.element_ty) + cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +def dsd(shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs, + out + ): + + device = rhs.device + trans_A = transpose_a + trans_B = False + + if rhs.stride(0) > 1 and rhs.stride(1) > 1: + trans_B = True + + # checks constraints + assert shape[1] == rhs.shape[0], "incompatible dimensions" + M, K = shape + _, N = rhs.shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else 
tl.int32 + + stride_am, stride_ak = data.stride(1), data.stride(2) + stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) + a_column_indices = column_indices + a_offsets = offsets + + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) + + if trans_A: + stride_am, stride_ak = data.stride(2), data.stride(1) + a_column_indices, a_offsets = column_indices_t, offsets_t + + if trans_B: + stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) + + _dsd_kernel[grid]( + data.data, rhs, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(0), out.stride(1), + row_indices, a_column_indices, a_offsets, + block_offsets_t, trans_A, trans_B, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + # return out + +def dds(lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b, + out + ): + + device = lhs.device + trans_B = transpose_b + trans_A = False + + if lhs.stride(0) > 1 and lhs.stride(1) > 1: + trans_A = True + + # checks constraints + assert lhs.shape[1] == shape[0], "incompatible dimensions" + M, K = lhs.shape + _, N = shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + stride_am, stride_ak = lhs.stride(0), lhs.stride(1) + stride_bk, stride_bn = data.stride(1), data.stride(2) + b_column_indices = column_indices_t + b_offsets = offsets_t + + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) + + if trans_A: + stride_am, stride_ak = lhs.stride(1), lhs.stride(0) + if trans_B: + stride_bk, stride_bn = data.stride(2), data.stride(1) + b_column_indices, b_offsets = column_indices, offsets + + _dds_kernel[grid]( + lhs, data, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(0), out.stride(1), + row_indices, b_column_indices, b_offsets, + block_offsets_t, trans_A, trans_B, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + +def sdd(lhs, + rhs, + shape, + out, + offsets, + row_indices, + column_indices + ): + + device = out.device + trans_A = False + trans_B = False + + if lhs.stride(0) > 1 and lhs.stride(1) > 1: + trans_A = True + if rhs.stride(0) > 1 and rhs.stride(1) > 1: + trans_B = True + + # checks constraints + assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions" + M, K = lhs.shape + _, N = rhs.shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + # launch kernel + nnz_blocks = len(row_indices) + grid = lambda META: (nnz_blocks,) + + stride_am, stride_ak = lhs.stride(0), lhs.stride(1) + stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) + + if trans_A: + stride_am, stride_ak = lhs.stride(1), lhs.stride(0) + if trans_B: + stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) + + _sdd_kernel[grid]( + lhs, rhs, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(1), out.stride(2), + row_indices, column_indices, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + +@triton.jit +def _row_indices_kernel(offsets, out): + pid = tl.program_id(0) + row_offset = tl.load(offsets + pid) + nnz_blocks = tl.load(offsets + pid + 1) - row_offset + for nnz_block in range(nnz_blocks): + tl.store(out + row_offset + nnz_block, pid) + +def row_indices( + shape, data, offsets, column_indices, out +): + block_rows = len(offsets) - 1 + _row_indices_kernel[(block_rows, )](offsets, 
out) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/matrix.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/matrix.py @@ -0,0 +1,329 @@ +import numpy as np +import torch + +# 1. Add heavyweight (data) validation helper. +# 2. Add construction helpers +# 3. Make indentation consistent +# 4. Replace asserts with descriptive errors. + +## +### Validation helpers. +## + + +def _validate_matrix(shape, data, row_indices, column_indices, offsets): + # Data should be [nnz, block_size, block_size] + if data.dim() == 1: + data = torch.reshape(data, [data.numel(), 1, 1]) + + # Blocks should be square. + if data.shape[-2] != data.shape[-1]: + raise ValueError( + "Expected square blocking in data. " + f"Got block shape {[data.shape[-2], data.shape[-1]]}") + + # Flatten batch dimensions on data - original shape preserved + # in shape argument. + block_size = data.shape[-1] + data = data.view([-1, block_size, block_size]) + + if data.dim() != 3: + raise ValueError( + "Expected 3D shape for data (nnz, block, block). " + f"Got shape {data.dim()}D shape.") + + block_size = data.shape[1] + if shape[-2] % block_size != 0 or shape[-1] % block_size != 0: + raise ValueError( + "Matrix shape must be dividible by blocking. " + f"Got shape {shape} with " + f"{[block_size, block_size]} blocking.") + + if np.prod(shape) < data.numel(): + raise ValueError( + "Invalid matrix. Number of nonzeros exceeds matrix capacity " + f"({data.numel()} v. {np.prod(shape)})") + + if row_indices.dim() != 1: + raise ValueError( + f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.") + + if column_indices.dim() != 1: + raise ValueError( + f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.") + + if offsets.dim() != 1: + raise ValueError( + f"Expected 1D offsets. Got {offsets.dim()}D offsets.") + + if row_indices.numel() != data.shape[0]: + raise ValueError( + "Expected 1 index per nonzero block. " + f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks") + + if column_indices.numel() != data.shape[0]: + raise ValueError( + "Expected 1 index per nonzero block. " + f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks") + + block_rows = np.prod(shape[:-1]) / block_size + if offsets.numel() != block_rows + 1: + raise ValueError( + "Expected one offset per block row plus one. " + f"Got {offsets.numel()} offsets with {block_rows} block rows.") + + is_cuda = (data.is_cuda and + row_indices.is_cuda and + column_indices.is_cuda and + offsets.is_cuda) + is_cpu = (not data.is_cuda and + not row_indices.is_cuda and + not column_indices.is_cuda and + not offsets.is_cuda) + if not (is_cuda or is_cpu): + raise ValueError( + "Expected data & meta-data on common device. " + f"Got data on {data.device}, row_indices on {row_indices.device} " + f"column_indices on {column_indices.device} and " + f"offsets on {offsets.device}.") + + if data.dtype != torch.float16: + raise ValueError( + f"Expected float16 data. Got {data.dtype} data.") + if row_indices.dtype != torch.int16: + raise ValueError( + f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.") + if column_indices.dtype != torch.int16: + raise ValueError( + f"Expected int16 column_indices. 
Got {column_indices.dtype} column_indices.") + if offsets.dtype != torch.int32: + raise ValueError( + f"Expected int32 offsets. Got {offsets.dtype} offsets.") + return data + + +def _transpose(size, data, row_indices, column_indices, offsets): + block_columns = size[1] // data.shape[1] + + # Sort row indices by column indices to get the transposed matrix's + # column indices. + gather_indices = column_indices.argsort() + column_indices_t = row_indices.gather(0, gather_indices) + block_offsets_t = gather_indices.int() + + # NOTE: Histogram is not implemented for any integer type on CPU. Do + # the histogram in 32-bit float, which can exactly represent 16-bit + # integers. + column_indices_float = column_indices.float() + + zero = torch.zeros((1,), dtype=torch.int32, device=data.device) + nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns) + nnz_per_column = nnz_per_column.int() + offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)]) + return column_indices_t, offsets_t, block_offsets_t + + +class Matrix(torch.nn.Module): + """A matrix stored in sparse format. + + Underlying format is block compressed sparse row (BCSR). + + TODO(tgale): Make this mirror torch.Tensor API as much as possible. + """ + + def __init__(self, + size, + data, + row_indices, + column_indices, + offsets, + column_indices_t=None, + offsets_t=None, + block_offsets_t=None): + super().__init__() + self._size = size + self._data = data + self._row_indices = row_indices + self._column_indices = column_indices + self._offsets = offsets + + # Produce the transpose meta-data if it is not passed in. + if ((column_indices_t is None) or (offsets_t is None) or + (block_offsets_t is None)): + column_indices_t, offsets_t, block_offsets_t = _transpose( + size, data, row_indices, column_indices, offsets) + self._column_indices_t = column_indices_t + self._offsets_t = offsets_t + self._block_offsets_t = block_offsets_t + + self._transposed = False + + # Validate that our metadata will not overflow. + max_dim = np.iinfo(np.int16).max * self.blocking + if column_indices.dtype == torch.int16: + if size[0] > max_dim or size[1] > max_dim: + raise ValueError( + "Sparse matrix with shape {size} exceeds representable " + "size with 16-bit indices.") + + def validate(self): + _validate_matrix(self._size, + self._data, + self._row_indices, + self._column_indices, + self._offsets) + + # TODO(tgale): Add heavyweight data validation. + + def to(self, device): + # TODO(tgale): Handle type conversions here. We + # need to set the appropriate meta-data type for + # the given floating-point type. 
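+        # Move the block data together with all of the BCSR metadata
+        # (row/column indices, offsets and the cached transpose metadata) so
+        # that every tensor backing this Matrix ends up on the same device.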
+ self._data = self._data.to(device) + self._row_indices = self._row_indices.to(device) + self._column_indices = self._column_indices.to(device) + self._offsets = self._offsets.to(device) + self._column_indices_t = self._column_indices_t.to(device) + self._offsets_t = self._offsets_t.to(device) + self._block_offsets_t = self._block_offsets_t.to(device) + return self + + def cuda(self): + return self.to(torch.cuda.current_device()) + + def clone(self): + return Matrix( + self.size(), + self.data.clone(), + self.row_indices.clone(), + self.column_indices.clone(), + self.offsets.clone(), + self.column_indices_t.clone(), + self.offsets_t.clone(), + self.block_offsets_t.clone()) + + def t(self): + if self.dim() != 2: + raise ValueError( + "t() expects a tensor with <= 2 dimensions, " + f"but self is {self.dim()}D.") + out = Matrix(self.size(), + self.data, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + out._transposed = not self._transposed + out._size = torch.Size((self._size[1], self._size[0])) + return out + + def contiguous(self): + raise ValueError("Not yet implemented.") + + def is_contiguous(self): + return not self._transposed + + @property + def is_cuda(self): + return self._data.is_cuda + + @property + def device(self): + return self._data.device + + def size(self): + return self._size + + @property + def shape(self): + return self.size() + + def dim(self): + return len(self._size) + + @property + def data(self): + return self._data + + @property + def row_indices(self): + return self._row_indices + + @property + def column_indices(self): + return self._column_indices + + @property + def offsets(self): + return self._offsets + + @property + def offsets_t(self): + return self._offsets_t + + @property + def column_indices_t(self): + return self._column_indices_t + + @property + def block_offsets_t(self): + return self._block_offsets_t + + @property + def dtype(self): + return self.data.dtype + + @property + def nnz(self): + return self.data.numel() + + @property + def blocking(self): + return self.data.shape[1] + + @property + def requires_grad(self): + return self.data.requires_grad + + def requires_grad_(self, x): + self.data.requires_grad_(x) + return self + + def view(self, *shape): + assert self.is_contiguous() + if shape[-1] != self.size()[-1]: + raise ValueError( + "Can't change view on compressed dimension. " + f"{self.size()[-1]} v. {shape[-1]}.") + if np.prod(shape) != np.prod(self.size()): + raise ValueError( + "Mismatch in numel of Matrix and new shape. " + f"{np.prod(self.size())} v. {np.prod(shape)}") + return Matrix(shape, + self.data, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + + @property + def grad(self): + # TODO(tgale): Make sure this mirrors torch.Tensor + # behavior in the case where we ask for the gradient + # of a non-contiguous tensor. 
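+        # The gradient Matrix is built in the storage (contiguous) orientation
+        # first and, for a transposed view, re-transposed on the way out so
+        # that its shape matches this view.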
+ size = self.size() + if not self.is_contiguous(): + size = torch.Size((size[1], size[0])) + out = Matrix(size, + self.data.grad, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + return out if self.is_contiguous() else out.t() diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/__init__.py @@ -0,0 +1,3 @@ +from .linear_ops import dds, dsd, sdd +from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse +from .eltwise_ops import mul diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops.py @@ -0,0 +1,28 @@ +from ..matrix import Matrix + +def mul(a, b): + """Performs element-wise multiplication of matrices a and b. + + It is the user's responsibility to make sure that a and b + follow the same matrix topology. This function assumes it is safe + to use the topoplogy of a. + + Args: + a: stk.Matrix. + b: stk.Matrix with a's matrix topology. + + Returns: + stk.Matrix where the entries correspond to torch.mul(a, b). + """ + assert isinstance(a, Matrix) + assert isinstance(b, Matrix) + assert a.size() == b.size() + + return Matrix(a.size(), + a.data * b.data, + a.row_indices, + a.column_indices, + a.offsets, + a.column_indices_t, + a.offsets_t, + a.block_offsets_t) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..66bfd4f6af77042d3c5bdb1fe18d00e457478d46 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py @@ -0,0 +1,86 @@ +import unittest +import itertools +import torch +from absl.testing import parameterized + +import stk +from stk.ops.linear_ops_test import allclose, _dense_and_sparse + +_MATRIX_SIZES = ( + (128, 128, 0.0), + (256, 256, 0.5), + (2048, 1024, 0.8), + (512, 128, 0.0), + (128, 512, 0.0), + (1024, 512, 0.0), + (1024, 512, 0.5), + (1024, 512, 0.75), + (512, 1024, 0.0), + (512, 1024, 0.5), + (512, 1024, 0.75), + (1024, 1024, 0.0), + (1024, 1024, 0.5), + (1024, 1024, 0.75), +) + +_DTYPE = ( + torch.float16, torch.bfloat16 +) + +def _generate_testcases(): + testcases = itertools.product(_MATRIX_SIZES, _DTYPE) + testcases = [(*size, 128, dtype) for + (size, dtype) in testcases] + return testcases + +_ELTWISE_OP_TESTS = _generate_testcases() + +def _dense_and_sparse_like(x, std=0.1): + dense_data = torch.randn_like(x.data, device=x.device) * std + sparse = stk.Matrix(x.size(), + dense_data, + x.row_indices, + x.column_indices, + x.offsets) + dense = stk.ops.to_dense(sparse) + + return (dense.requires_grad_(True), + sparse.requires_grad_(True)) + +@parameterized.parameters(_ELTWISE_OP_TESTS) +class EltwiseOpsTest(parameterized.TestCase): + + def testEltwiseMul(self, m, n, sparsity, blocking, dtype): + + a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype) + b_dense, b = 
_dense_and_sparse_like(a) + + out = stk.ops.mul(a, b) + expected_out = torch.mul(a_dense, b_dense) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + stk.ops.sum(out).backward() + + # Validate the results. + out = stk.ops.to_dense(out) + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size(), out.size()) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = stk.ops.to_dense(a.grad) + expected_grad = a_dense.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size(), grad.size()) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = stk.ops.to_dense(b.grad) + expected_grad = b_dense.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size(), grad.size()) + self.assertTrue(allclose(grad, expected_grad)) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops.py @@ -0,0 +1,59 @@ +import torch + +from ..backend import sputnik +from ..matrix import Matrix + + +def dsd(a, b): + assert isinstance(a, Matrix) + assert isinstance(b, torch.Tensor) + return sputnik.dsd( + a.size(), + a.data, a.offsets, + a.row_indices, + a.column_indices, + a.offsets_t, + a.column_indices_t, + a.block_offsets_t, + not a.is_contiguous(), + b) + + +def dds(a, b): + assert isinstance(a, torch.Tensor) + assert isinstance(b, Matrix) + return sputnik.dds( + a, + b.size(), + b.data, b.offsets, + b.row_indices, + b.column_indices, + b.offsets_t, + b.column_indices_t, + b.block_offsets_t, + not b.is_contiguous()) + + +def sdd(a, b, topo): + assert isinstance(a, torch.Tensor) + assert isinstance(b, torch.Tensor) + assert isinstance(topo, Matrix) + assert topo.is_contiguous() + out = sputnik.sdd( + a, b, + topo.size(), + topo.data, + topo.offsets, + topo.row_indices, + topo.column_indices, + topo.offsets_t, + topo.column_indices_t, + topo.block_offsets_t) + return Matrix(topo.size(), + out, + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops_test.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ced1d782fbc9f9ca16b3449239f1588dc5ff5e00 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops_test.py @@ -0,0 +1,216 @@ +import unittest +import itertools +import numpy as np +import torch +from absl.testing import parameterized + +import stk + + +def allclose(x, y, pct=0.25): + mask = torch.isclose(x, y, rtol=5e-2) + pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 + if pct_diff > pct: + print("{:.2f}% of values not close.".format(pct_diff)) + return False + return True + + +# An assortment of problems designed to make sure +# the bindings are operating correctly. 
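+# Each entry is (m, k, n, sparsity); _generate_testcases crosses these with
+# the transpose flags, a fixed blocking of 128 and the dtypes, producing
+# parameter tuples of the form
+# (m, k, n, sparsity, trans_a, trans_b, blocking, dtype).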
+_MATRIX_SIZES = ( + (128, 128, 128, 0.0), + (256, 256, 256, 0.5), + (2048, 1024, 512, 0.8), + (512, 128, 128, 0.0), + (128, 128, 512, 0.0), + (1024, 512, 512, 0.0), + (1024, 512, 512, 0.5), + (1024, 512, 512, 0.75), + (512, 512, 1024, 0.0), + (512, 512, 1024, 0.5), + (512, 512, 1024, 0.75), + (1024, 1024, 1024, 0.0), + (1024, 1024, 1024, 0.5), + (1024, 1024, 1024, 0.75), +) + +_TRANSPOSE = ( + (False, False), + (False, True), + (True, False), + (True, True), +) + +_DTYPE = ( + torch.float16, torch.bfloat16 +) + +def _generate_testcases(): + testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE) + testcases = [(*size, *trans, 128, dtype) for + (size, trans, dtype) in testcases] + return testcases + +_LINEAR_OP_TESTS = _generate_testcases() + +def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + dense = (torch.randn(rows, cols) * std * mask).type(dtype) + sparse = stk.ops.to_sparse(dense, blocking) + cuda_device = torch.device("cuda") + return (dense.to(cuda_device).requires_grad_(True), + sparse.to(cuda_device).requires_grad_(True)) + + +def _dense(rows, cols, dtype, std=0.1): + cuda_device = torch.device("cuda") + out = (torch.randn(rows, cols) * std).type(dtype) + return out.to(cuda_device).requires_grad_(True) + + +def _dense_2x(rows, cols, dtype): + a = _dense(rows, cols, dtype) + return a, a.detach().requires_grad_(True) + + +def _with_transpose(op, a, b, trans_a, trans_b): + a = a.t() if trans_a else a + b = b.t() if trans_b else b + return op(a, b) + + +def _mmm(a, b, topo): + mask = stk.ops.to_dense(stk.ops.ones_like(topo)) + return torch.mm(a, b) * mask + + +def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b): + a = a.t() if trans_a else a + b = b.t() if trans_b else b + return op(a, b, topo) + + +def _mask(x, mask): + mask = stk.ops.to_dense(stk.ops.ones_like(mask)) + return x * mask + + +@parameterized.parameters(*_LINEAR_OP_TESTS) +class LinearOpsTest(parameterized.TestCase): + + def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. + a_shape = (k, m) if trans_a else (m, k) + a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype) + b_shape = (n, k) if trans_b else (k, n) + b, bcp = _dense_2x(*b_shape, dtype) + + # Execute the matmul. + out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b) + expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + out.sum().backward() + + # Validate the results. + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = stk.ops.to_dense(a.grad) + expected_grad = _mask(a_dense.grad, a.grad) + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = b.grad + expected_grad = bcp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. 
+ a_shape = (k, m) if trans_a else (m, k) + a, acp = _dense_2x(*a_shape, dtype) + b_shape = (n, k) if trans_b else (k, n) + b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype) + + # Execute the matmul. + out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b) + expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + out.sum().backward() + + # Validate the results. + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = a.grad + expected_grad = acp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = stk.ops.to_dense(b.grad) + expected_grad = _mask(b_dense.grad, b.grad) + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. + a_shape = (k, m) if trans_a else (m, k) + a, acp = _dense_2x(*a_shape, dtype) + b_shape = (n, k) if trans_b else (k, n) + b, bcp = _dense_2x(*b_shape, dtype) + _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype) + + # Execute the matmul. + out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b) + expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + stk.ops.sum(out).backward() + + # Validate the results. + out = stk.ops.to_dense(out) + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = a.grad + expected_grad = acp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = b.grad + expected_grad = bcp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops.py @@ -0,0 +1,98 @@ +from ..backend import sputnik +from ..matrix import Matrix +import torch +import numpy as np + + +@torch.no_grad() +def row_indices(shape, data, offsets, column_indices): + return sputnik.row_indices(shape, data, offsets, column_indices) + + +# TODO(tgale): Replace this helper with a custom kernel. This operation +# is much simpler to do than how it's currently implemented. 
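+# Sketch of what the helper below computes (values are purely illustrative):
+# with blocking=2, the block coordinate (1, 0) expands to the four element
+# coordinates that block covers, e.g.
+#
+#   _expand_for_blocking(torch.tensor([[1, 0]]), blocking=2)
+#   # -> tensor([[2, 0], [2, 1], [3, 0], [3, 1]])
+#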
+@torch.no_grad() +def _expand_for_blocking(idxs, blocking): + # Duplicate for block column dimension. + idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1) + + # Update the column indices. + idxs[:, :, 1] *= blocking + idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking]) + + # Duplicate for block row dimension. + idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2]) + idxs = idxs.repeat(1, blocking, 1, 1) + + # Update the row indices. + idxs[:, :, :, 0] *= blocking + idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1]) + idxs = torch.reshape(idxs, [-1, 2]) + return idxs + + +# TODO(tgale): Add input type checking. +@torch.no_grad() +def to_dense(x): + assert isinstance(x, Matrix) + + shape = (np.prod(x.shape[:-1]), x.shape[-1]) + row_idxs = x.row_indices.type(torch.int32) + col_idxs = x.column_indices.type(torch.int32) + indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking) + indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64) + + out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device) + out.scatter_(0, indices, x.data.flatten()) + return out.reshape(x.size()) + + +@torch.no_grad() +def _mask(x, blocking=1): + assert x.dim() == 2 + assert x.size()[0] % blocking == 0 + assert x.size()[1] % blocking == 0 + block_rows = x.size()[0] // blocking + block_cols = x.size()[1] // blocking + x = torch.reshape(x, [block_rows, blocking, block_cols, blocking]) + x = torch.sum(torch.abs(x), dim=(1, 3)) + return x != 0 + + +# TODO(tgale): Add input type checking. +@torch.no_grad() +def to_sparse(x, blocking=1): + m = _mask(x, blocking) + + # TODO(tgale): Set to appropriate type for input matrix. + row_nnzs = torch.sum(m, dim=1).type(torch.int32) + zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device) + offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)]) + offsets = offsets.type(torch.int32) + + indices = torch.nonzero(m).type(torch.int16) + row_indices = indices[:, 0] + column_indices = indices[:, 1] + + # Nonzero indices in the dense matrix. + nonzero_indices = torch.nonzero(m) + nonzero_indices = _expand_for_blocking(nonzero_indices, blocking) + nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1] + + # Gather the data and construct the sparse matrix. 
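+    # Example (illustrative, not executed here): a 4x4 matrix with blocking=2
+    # and nonzero blocks at block positions (0, 1) and (1, 0) produces
+    # row_indices=[0, 1], column_indices=[1, 0], offsets=[0, 1, 2] and a data
+    # tensor of shape [2, 2, 2].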
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices) + data = torch.reshape(data, [-1, blocking, blocking]) + return Matrix(x.size(), data, row_indices, column_indices, offsets) + + +@torch.no_grad() +def ones_like(x): + return Matrix(x.size(), + torch.ones_like(x.data), + x.row_indices, + x.column_indices, x.offsets) + + +def sum(x): + assert isinstance(x, Matrix) + return x.data.sum() diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3af04c0760483e578f93303dc457415948a2a34c --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py @@ -0,0 +1,62 @@ +import unittest + +from absl.testing import parameterized +import stk +import torch + + +@parameterized.parameters( + (8, 16, 0.0, 1), + (8, 16, 0.5, 1), + (8, 16, .95, 1), + (16, 8, 0.0, 1), + (16, 8, 0.5, 1), + (16, 8, .95, 1), + (8, 16, 0.0, 8), + (8, 16, 0.5, 8), + (8, 16, 1.0, 8), + (16, 8, 0.0, 8), + (16, 8, 0.5, 8), + (16, 8, 1.0, 8), + (128, 256, 0.5, 16), + (256, 128, 0.75, 32), + (512, 512, .875, 128)) +class MatrixOpsTest(parameterized.TestCase): + + def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + x = (torch.randn(rows, cols) * mask).type(torch.float16) + + # Convert the matrix to sparse format. + sparse_x = stk.ops.to_sparse(x, blocking) + + # Validate the matrix. + sparse_x.validate() + + # Validate the shape. + self.assertEqual(sparse_x.dim(), 2) + self.assertEqual(sparse_x.size()[0], rows) + self.assertEqual(sparse_x.size()[1], cols) + + # Validate the sparsity. + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual(sparse_x.nnz, nnz) + + # Convert back to dense format. + dense_x = stk.ops.to_dense(sparse_x) + + # Validate the shape. + self.assertEqual(dense_x.dim(), 2) + self.assertEqual(dense_x.size()[0], rows) + self.assertEqual(dense_x.size()[1], cols) + + # Validate the sparsity + self.assertEqual(torch.count_nonzero(dense_x).item(), nnz) + + # Validate the output. 
+ self.assertTrue(torch.all(torch.eq(x, dense_x))) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/__init__.py @@ -0,0 +1,2 @@ +# from stk.random.random_ops import dense_mask, mask, randn +from .random_ops import dense_mask, mask, randn diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops.py @@ -0,0 +1,36 @@ +import numpy as np +import torch +from ..ops import matrix_ops + + +@torch.no_grad() +def dense_mask(rows, cols, sparsity, blocking=1): + assert sparsity >= 0.0 and sparsity <= 1.0 + assert rows % blocking == 0 and cols % blocking == 0 + + block_rows, block_cols = (rows // blocking, cols // blocking) + nnz = round(block_rows * block_cols * (1 - sparsity)) + + out = np.ones(block_rows * block_cols) + mask = np.random.choice(out.size, out.size - nnz, replace=False) + out[mask] = 0.0 + + out = np.tile( + np.reshape(out, [block_rows, 1, block_cols, 1]), + (1, blocking, 1, blocking)) + out = np.reshape(out, [rows, cols]) + return torch.from_numpy(out.astype(np.float32)) + + +@torch.no_grad() +def mask(m, n, sparsity, blocking=1): + out = dense_mask(m, n, sparsity, blocking).type(torch.float16) + return matrix_ops.to_sparse(out, blocking=blocking) + + +@torch.no_grad() +def randn(shape, sparsity, blocking=1): + shape_2d = (np.prod(shape[:-1]), shape[-1]) + out = mask(*shape_2d, sparsity, blocking) + out.data.copy_(torch.randn(*out.data.shape)) + return out.view(*shape) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops_test.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..587b44ec890c861879c6296b8f9028f5d99ab82f --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops_test.py @@ -0,0 +1,73 @@ +import unittest + +from absl.testing import parameterized +from . import random +import torch + + +@parameterized.parameters( + (8, 16, 0.0, 1), + (8, 16, 0.5, 1), + (8, 16, .95, 1), + (16, 8, 0.0, 1), + (16, 8, 0.5, 1), + (16, 8, .95, 1), + (8, 16, 0.0, 8), + (8, 16, 0.5, 8), + (8, 16, 1.0, 8), + (16, 8, 0.0, 8), + (16, 8, 0.5, 8), + (16, 8, 1.0, 8), + (128, 256, 0.5, 16), + (256, 128, 0.75, 32), + (512, 512, .875, 128)) +class RandomOpsTest(parameterized.TestCase): + + def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking): + mask = random.dense_mask( + rows, cols, sparsity, blocking) + + # Validate the shape. + self.assertEqual(mask.dim(), 2) + self.assertEqual(mask.size()[0], rows) + self.assertEqual(mask.size()[1], cols) + + # Validate the sparsity + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual( + torch.count_nonzero(mask).item(), + nnz) + + # Check values are zero or one. 
+ self.assertTrue( + torch.all(torch.logical_or( + torch.eq(mask, 0), + torch.eq(mask, 1)))) + + def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking): + mask = random.mask( + rows, cols, sparsity, blocking) + + # Validate the matrix. + mask.validate() + + # Validate the shape. + self.assertEqual(mask.dim(), 2) + self.assertEqual(mask.size()[0], rows) + self.assertEqual(mask.size()[1], cols) + + # Validate the sparsity. + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual(mask.nnz, nnz) + + # Check values are zero or one. + self.assertTrue( + torch.all(torch.logical_or( + torch.eq(mask.data, 0), + torch.eq(mask.data, 1)))) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/activation_fn.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/activation_fn.py index a31770ba179fed06abf2da10102ccaeed1d3ee4e..0e1d956704840aa4daf7d1d71d24e051567feab9 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/activation_fn.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/activation_fn.py @@ -4,7 +4,7 @@ from typing import Any, Callable, Union import torch -from stk import Matrix +from ..stk import Matrix def act_fn( diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/dmoe.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/dmoe.py index 561fd7982254513e9a9b4f0514edd326b0561256..6d0375a4df2f27134c4127e60be04f3b45693050 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/dmoe.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/dmoe.py @@ -2,15 +2,22 @@ # SPDX-License-Identifier: Apache-2.0 import numpy as np -import stk.ops import torch -from stk import Matrix + +# try: +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.', +# ) # import megablocks.ops as ops # # from megablocks.ops import ops # from megablocks.layers import common, dmlp_registry, moe, mpu # from megablocks.layers.arguments import Arguments +from .. import stk from .. import ops from . import common, dmlp_registry, moe, mpu from .arguments import Arguments diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/gelu.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/gelu.py index 40b601d4a8eb59e80e1090f31f4172b8f7fb7549..c4c9e6532798615b5c12c96694241a4c18ee8f7b 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/gelu.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/gelu.py @@ -1,7 +1,16 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -import stk +# try: +# import stk +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.', +# ) + +from .. 
import stk + import torch import torch.nn.functional as F diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/glu.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/glu.py index 69ef9ead5461a5b91ace107652fd8d1c3d300b1e..5f297a41ff6a1a2a285f5b461951672364b898da 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/glu.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/glu.py @@ -1,7 +1,17 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -import stk.ops +# import stk.ops +# try: +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.', +# ) + +from .. import stk + import torch # from megablocks import grouped_gemm_util as gg diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/mlp.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/mlp.py index 0b6747b1eec2047eb9480592ca550b6c9e868f5e..c99afb9904c24a8b6a83e79059cd1251dbbfd99e 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/mlp.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/mlp.py @@ -3,9 +3,18 @@ from typing import Any -import stk -import stk.backend.triton_kernels -import stk.ops +# try: +# import stk +# import stk.backend.triton_kernels +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.', +# ) + +from .. import stk + import torch from packaging import version diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so deleted file mode 100755 index e678ded725389f50ee596f97c1b3c7ddd4929902..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:ed275d43953fbc984a20503c0d55f56b337576bd43e2c94682b4de91a8df6c8d -size 11931080 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_63599de.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_63599de.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..7aa46281700e821f9ad125f1f72109d8e6aee46d --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_63599de.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3a243e51490184fb48e02dbc1115545ea69313a3d63058f8423c0c493e90bc5a +size 11931080 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py index f7aa5a700bc5673c743a3b0fb74107fab793fbad..03e1e323ce98ecea8dda071c089f0d7ba75551b6 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_0586ba6 -ops = torch.ops._megablocks_0586ba6 +from . import _megablocks_63599de +ops = torch.ops._megablocks_63599de def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_megablocks_0586ba6::{op_name}" \ No newline at end of file + return f"_megablocks_63599de::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py index 9b2c7f062071758def741bf3af2392d3a4855f89..7ccc5dcec5e9a663794fad944c45285869c4d1c1 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -3,7 +3,19 @@ import unittest -import stk + +# import stk + +# try: +# import stk +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.', +# ) + +from .. import stk + import torch from absl.testing import parameterized diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/__init__.py @@ -0,0 +1,7 @@ +# import stk.random +# import stk.ops +# from stk.matrix import Matrix + +from . import random +from . import ops +from .matrix import Matrix diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/autocast.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/autocast.py new file mode 100644 index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/autocast.py @@ -0,0 +1,37 @@ +import functools +import torch + + +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, map): + return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, list) or isinstance(x, tuple): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def custom_fwd(fwd): + """Wrap a custom autograd function that always uses autocast dtype.""" + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + return decorate_fwd + + +def custom_bwd(bwd): + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + return decorate_bwd diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/sputnik.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/sputnik.py new file mode 100644 index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/sputnik.py @@ -0,0 +1,316 @@ +import torch + +from ..backend import triton_kernels as backend +from ..backend.autocast 
import custom_bwd, custom_fwd + + +def _standardize_shape(x, transpose): + if transpose: + return torch.Size((x[1], x[0])) + return x + + +def _sparse_transpose(x): + return (torch.Size((x[0][1], x[0][0])), ) + x[1:] + + +def _transpose_helper(x, transpose): + if isinstance(x, torch.Tensor): + return x.t() if transpose else x + if transpose: + x = _sparse_transpose(x) + return x + (transpose,) + + +def _wrap(x): + if isinstance(x, torch.Tensor): + return (x,) + return x + + +def _is_transposed(x): + return (not x.is_contiguous() and + x.stride()[0] == 1 and + x.stride()[1] == x.size()[0]) + + +def _call_helper(op, out, a, b, trans_a, trans_b): + args = (_wrap(_transpose_helper(a, trans_a)) + + _wrap(_transpose_helper(b, trans_b))) + if isinstance(out, tuple): + args = args + out + return op(*args) + + +def _preprocess_inputs(lhs, rhs, dy): + if isinstance(lhs, torch.Tensor) and _is_transposed(lhs): + lhs = lhs.t() + if isinstance(rhs, torch.Tensor) and _is_transposed(rhs): + rhs = rhs.t() + if (isinstance(dy, torch.Tensor) and + not dy.is_contiguous() and + not _is_transposed(dy)): + dy = dy.contiguous() + if isinstance(dy, tuple) and not dy[1].is_contiguous(): + dy = (dy[0], dy[1].contiguous()) + dy[2:] + return lhs, rhs, dy + + +def _postprocess_outputs(x, transpose, grad): + if isinstance(x, torch.Tensor) and transpose: + return grad.t() + return grad + + +def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): + lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) + + a, b = (rhs, dy) if trans_lhs else (dy, rhs) + trans_a = trans_lhs and trans_rhs + trans_b = trans_lhs or not trans_rhs + out = _call_helper(op, lhs, a, b, trans_a, trans_b) + return _postprocess_outputs(lhs, trans_lhs, out) + + +def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): + lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) + + a, b = (dy, lhs) if trans_rhs else (lhs, dy) + trans_a = not trans_lhs or trans_rhs + trans_b = trans_lhs and trans_rhs + out = _call_helper(op, rhs, a, b, trans_a, trans_b) + return _postprocess_outputs(rhs, trans_rhs, out) + + +class DSD(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs): + ctx.save_for_backward(data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + rhs) + ctx.shape = _standardize_shape(shape, transpose_a) + ctx.transpose_a = transpose_a + + out = torch.empty( + (shape[0], rhs.size()[1]), + dtype=rhs.dtype, + device=rhs.device) + + backend.dsd(shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs, + out) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs = (ctx.shape,) + saved_tensors[:-1] + rhs = saved_tensors[-1] + trans_a = ctx.transpose_a + trans_b = _is_transposed(rhs) + + ddata = None + if ctx.needs_input_grad[1]: + ddata = _lhs_gradient(sdd, + lhs, + rhs, + dy, + trans_a, + trans_b) + drhs = None + if ctx.needs_input_grad[-1]: + op = dds if trans_b else dsd + drhs = _rhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + return None, ddata, None, None, None, None, None, None, None, drhs + + +dsd = DSD.apply + + +class DDS(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + 
block_offsets_t, + transpose_b): + ctx.save_for_backward(lhs, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t) + ctx.shape = _standardize_shape(shape, transpose_b) + ctx.transpose_b = transpose_b + out = torch.empty((lhs.size()[0], shape[1]), + dtype=lhs.dtype, + device=lhs.device) + backend.dds(lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b, + out) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs = saved_tensors[0] + rhs = (ctx.shape,) + saved_tensors[1:] + trans_a = _is_transposed(lhs) + trans_b = ctx.transpose_b + + dlhs = None + if ctx.needs_input_grad[0]: + op = dsd if trans_a else dds + dlhs = _lhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + ddata = None + if ctx.needs_input_grad[2]: + ddata = _rhs_gradient(sdd, + lhs, + rhs, + dy, + trans_a, + trans_b) + return dlhs, None, ddata, None, None, None, None, None, None, None + + +dds = DDS.apply + + +class SDD(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + lhs, + rhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t): + ctx.save_for_backward( + lhs, + rhs, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t) + ctx.shape = shape + out = torch.empty( + data.shape, + dtype=lhs.dtype, + device=lhs.device) + backend.sdd(lhs, + rhs, + shape, + out, + offsets, + row_indices, + column_indices) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs, rhs = saved_tensors[:2] + dy = (ctx.shape, dy) + saved_tensors[2:] + trans_a = _is_transposed(lhs) + trans_b = _is_transposed(rhs) + + dlhs = None + if ctx.needs_input_grad[0]: + op = dds if trans_a else dsd + dlhs = _lhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + drhs = None + if ctx.needs_input_grad[1]: + op = dsd if trans_b else dds + drhs = _rhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + return dlhs, drhs, None, None, None, None, None, None, None, None + + +sdd = SDD.apply + +class RowIndices(torch.autograd.Function): + + @staticmethod + def forward(ctx, shape, data, offsets, column_indices): + out = torch.empty( + column_indices.shape, + dtype=column_indices.dtype, + device=column_indices.device) + backend.row_indices(shape, data, offsets, column_indices, out) + return out + + +row_indices = RowIndices.apply diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/triton_kernels.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/triton_kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/triton_kernels.py @@ -0,0 +1,393 @@ +import torch +import triton +import triton.language as tl +from dataclasses import dataclass + +@dataclass +class TritonConfig: + BLOCK_M: int = 128 + BLOCK_N: int = 128 + BLOCK_K: int = 32 + BLOCK_SIZE: int = 128 + NUM_STAGES: int = 4 + NUM_WARPS: int = 4 + +def _validate_matmul_dims(M: int, K: int, N: int): + error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}" + assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M) + assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, 
TritonConfig.BLOCK_K) + assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _sdd_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + # matrix multiplication + pid = tl.program_id(0) + pid_m = tl.load(row_indices + pid) + pid_n = tl.load(column_indices + pid) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(A) + b = tl.load(B) + acc += tl.dot(a, b) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + #Store to sparse matrix + acc = acc.to(C.dtype.element_ty) + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + cm = tl.arange(0, BLOCK_M) + cn = tl.arange(0, BLOCK_N) + C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _dsd_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, offsets, + block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + + # matrix multiplication + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) + + start_inx = tl.load(offsets + pid_m) + end_inx = tl.load(offsets + pid_m + 1) + + # pointers to sparse matrix + rm = tl.arange(0, BLOCK_M) + rak = tl.arange(0, BLOCK_K) + + A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) + + # pointers to dense matrix + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rbk = tl.arange(0, BLOCK_K) + B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) + + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) + + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + ak_sub_incr = BLOCK_K * stride_ak + bk_sub_incr = BLOCK_K * stride_bk + bk_block_incr = BLOCK_SIZE * stride_bk + + for k in range(nsub_blocks * (end_inx - start_inx)): + sub_block_inx = k % nsub_blocks + block_inx = k // nsub_blocks + 
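        # Each nonzero block is BLOCK_SIZE x BLOCK_SIZE while the matmul tile
        # is only BLOCK_K wide, so block_inx selects the current nonzero block
        # of this block row (offsets[pid_m] .. offsets[pid_m + 1]) and
        # sub_block_inx selects the BLOCK_K-wide slice within it. When the
        # sparse operand is transposed, its data blocks are still stored in
        # the untransposed order, so block_offsets_t translates the transposed
        # block index back to its storage slot below.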
+ if trans_A: + ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr + else: + ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr + + ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr + + a = tl.load(ptr_A) + b = tl.load(ptr_B) + acc += tl.dot(a, b) + + acc = acc.to(C.dtype.element_ty) + + cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _dds_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, offsets, + block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + + # matrix multiplication + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) + + start_inx = tl.load(offsets + pid_n) + end_inx = tl.load(offsets + pid_n + 1) + + # pointers to dense matrix + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rak = tl.arange(0, BLOCK_K) + + A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) + + # pointers to sparse matrix + rn = tl.arange(0, BLOCK_N) + rbk = tl.arange(0, BLOCK_K) + B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) + + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) + + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + + ak_sub_incr = BLOCK_K * stride_ak + ak_block_incr = BLOCK_SIZE * stride_ak + bk_sub_incr = BLOCK_K * stride_bk + + for k in range(nsub_blocks * (end_inx - start_inx)): + sub_block_inx = k % nsub_blocks + block_inx = k // nsub_blocks + + if trans_B: + ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr + else: + ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr + + ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr + a = tl.load(ptr_A) + b = tl.load(ptr_B) + acc += tl.dot(a, b) + + acc = acc.to(C.dtype.element_ty) + cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +def dsd(shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs, + out + ): + + device = rhs.device + trans_A = transpose_a + trans_B = False + + if rhs.stride(0) > 1 and rhs.stride(1) > 1: + trans_B = True + + # checks constraints + assert shape[1] == rhs.shape[0], "incompatible dimensions" + M, K = shape + _, N = rhs.shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else 
tl.int32 + + stride_am, stride_ak = data.stride(1), data.stride(2) + stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) + a_column_indices = column_indices + a_offsets = offsets + + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) + + if trans_A: + stride_am, stride_ak = data.stride(2), data.stride(1) + a_column_indices, a_offsets = column_indices_t, offsets_t + + if trans_B: + stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) + + _dsd_kernel[grid]( + data.data, rhs, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(0), out.stride(1), + row_indices, a_column_indices, a_offsets, + block_offsets_t, trans_A, trans_B, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + # return out + +def dds(lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b, + out + ): + + device = lhs.device + trans_B = transpose_b + trans_A = False + + if lhs.stride(0) > 1 and lhs.stride(1) > 1: + trans_A = True + + # checks constraints + assert lhs.shape[1] == shape[0], "incompatible dimensions" + M, K = lhs.shape + _, N = shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + stride_am, stride_ak = lhs.stride(0), lhs.stride(1) + stride_bk, stride_bn = data.stride(1), data.stride(2) + b_column_indices = column_indices_t + b_offsets = offsets_t + + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) + + if trans_A: + stride_am, stride_ak = lhs.stride(1), lhs.stride(0) + if trans_B: + stride_bk, stride_bn = data.stride(2), data.stride(1) + b_column_indices, b_offsets = column_indices, offsets + + _dds_kernel[grid]( + lhs, data, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(0), out.stride(1), + row_indices, b_column_indices, b_offsets, + block_offsets_t, trans_A, trans_B, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + +def sdd(lhs, + rhs, + shape, + out, + offsets, + row_indices, + column_indices + ): + + device = out.device + trans_A = False + trans_B = False + + if lhs.stride(0) > 1 and lhs.stride(1) > 1: + trans_A = True + if rhs.stride(0) > 1 and rhs.stride(1) > 1: + trans_B = True + + # checks constraints + assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions" + M, K = lhs.shape + _, N = rhs.shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + # launch kernel + nnz_blocks = len(row_indices) + grid = lambda META: (nnz_blocks,) + + stride_am, stride_ak = lhs.stride(0), lhs.stride(1) + stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) + + if trans_A: + stride_am, stride_ak = lhs.stride(1), lhs.stride(0) + if trans_B: + stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) + + _sdd_kernel[grid]( + lhs, rhs, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(1), out.stride(2), + row_indices, column_indices, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + +@triton.jit +def _row_indices_kernel(offsets, out): + pid = tl.program_id(0) + row_offset = tl.load(offsets + pid) + nnz_blocks = tl.load(offsets + pid + 1) - row_offset + for nnz_block in range(nnz_blocks): + tl.store(out + row_offset + nnz_block, pid) + +def row_indices( + shape, data, offsets, column_indices, out +): + block_rows = len(offsets) - 1 + _row_indices_kernel[(block_rows, )](offsets, 
out) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/matrix.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/matrix.py @@ -0,0 +1,329 @@ +import numpy as np +import torch + +# 1. Add heavyweight (data) validation helper. +# 2. Add construction helpers +# 3. Make indentation consistent +# 4. Replace asserts with descriptive errors. + +## +### Validation helpers. +## + + +def _validate_matrix(shape, data, row_indices, column_indices, offsets): + # Data should be [nnz, block_size, block_size] + if data.dim() == 1: + data = torch.reshape(data, [data.numel(), 1, 1]) + + # Blocks should be square. + if data.shape[-2] != data.shape[-1]: + raise ValueError( + "Expected square blocking in data. " + f"Got block shape {[data.shape[-2], data.shape[-1]]}") + + # Flatten batch dimensions on data - original shape preserved + # in shape argument. + block_size = data.shape[-1] + data = data.view([-1, block_size, block_size]) + + if data.dim() != 3: + raise ValueError( + "Expected 3D shape for data (nnz, block, block). " + f"Got shape {data.dim()}D shape.") + + block_size = data.shape[1] + if shape[-2] % block_size != 0 or shape[-1] % block_size != 0: + raise ValueError( + "Matrix shape must be dividible by blocking. " + f"Got shape {shape} with " + f"{[block_size, block_size]} blocking.") + + if np.prod(shape) < data.numel(): + raise ValueError( + "Invalid matrix. Number of nonzeros exceeds matrix capacity " + f"({data.numel()} v. {np.prod(shape)})") + + if row_indices.dim() != 1: + raise ValueError( + f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.") + + if column_indices.dim() != 1: + raise ValueError( + f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.") + + if offsets.dim() != 1: + raise ValueError( + f"Expected 1D offsets. Got {offsets.dim()}D offsets.") + + if row_indices.numel() != data.shape[0]: + raise ValueError( + "Expected 1 index per nonzero block. " + f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks") + + if column_indices.numel() != data.shape[0]: + raise ValueError( + "Expected 1 index per nonzero block. " + f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks") + + block_rows = np.prod(shape[:-1]) / block_size + if offsets.numel() != block_rows + 1: + raise ValueError( + "Expected one offset per block row plus one. " + f"Got {offsets.numel()} offsets with {block_rows} block rows.") + + is_cuda = (data.is_cuda and + row_indices.is_cuda and + column_indices.is_cuda and + offsets.is_cuda) + is_cpu = (not data.is_cuda and + not row_indices.is_cuda and + not column_indices.is_cuda and + not offsets.is_cuda) + if not (is_cuda or is_cpu): + raise ValueError( + "Expected data & meta-data on common device. " + f"Got data on {data.device}, row_indices on {row_indices.device} " + f"column_indices on {column_indices.device} and " + f"offsets on {offsets.device}.") + + if data.dtype != torch.float16: + raise ValueError( + f"Expected float16 data. Got {data.dtype} data.") + if row_indices.dtype != torch.int16: + raise ValueError( + f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.") + if column_indices.dtype != torch.int16: + raise ValueError( + f"Expected int16 column_indices. 
Got {column_indices.dtype} column_indices.") + if offsets.dtype != torch.int32: + raise ValueError( + f"Expected int32 offsets. Got {offsets.dtype} offsets.") + return data + + +def _transpose(size, data, row_indices, column_indices, offsets): + block_columns = size[1] // data.shape[1] + + # Sort row indices by column indices to get the transposed matrix's + # column indices. + gather_indices = column_indices.argsort() + column_indices_t = row_indices.gather(0, gather_indices) + block_offsets_t = gather_indices.int() + + # NOTE: Histogram is not implemented for any integer type on CPU. Do + # the histogram in 32-bit float, which can exactly represent 16-bit + # integers. + column_indices_float = column_indices.float() + + zero = torch.zeros((1,), dtype=torch.int32, device=data.device) + nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns) + nnz_per_column = nnz_per_column.int() + offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)]) + return column_indices_t, offsets_t, block_offsets_t + + +class Matrix(torch.nn.Module): + """A matrix stored in sparse format. + + Underlying format is block compressed sparse row (BCSR). + + TODO(tgale): Make this mirror torch.Tensor API as much as possible. + """ + + def __init__(self, + size, + data, + row_indices, + column_indices, + offsets, + column_indices_t=None, + offsets_t=None, + block_offsets_t=None): + super().__init__() + self._size = size + self._data = data + self._row_indices = row_indices + self._column_indices = column_indices + self._offsets = offsets + + # Produce the transpose meta-data if it is not passed in. + if ((column_indices_t is None) or (offsets_t is None) or + (block_offsets_t is None)): + column_indices_t, offsets_t, block_offsets_t = _transpose( + size, data, row_indices, column_indices, offsets) + self._column_indices_t = column_indices_t + self._offsets_t = offsets_t + self._block_offsets_t = block_offsets_t + + self._transposed = False + + # Validate that our metadata will not overflow. + max_dim = np.iinfo(np.int16).max * self.blocking + if column_indices.dtype == torch.int16: + if size[0] > max_dim or size[1] > max_dim: + raise ValueError( + "Sparse matrix with shape {size} exceeds representable " + "size with 16-bit indices.") + + def validate(self): + _validate_matrix(self._size, + self._data, + self._row_indices, + self._column_indices, + self._offsets) + + # TODO(tgale): Add heavyweight data validation. + + def to(self, device): + # TODO(tgale): Handle type conversions here. We + # need to set the appropriate meta-data type for + # the given floating-point type. 
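        # For now .to() only moves storage: the block data and every piece of
        # index metadata (including the transpose metadata) are transferred
        # together so the kernels never see tensors on mixed devices.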
+ self._data = self._data.to(device) + self._row_indices = self._row_indices.to(device) + self._column_indices = self._column_indices.to(device) + self._offsets = self._offsets.to(device) + self._column_indices_t = self._column_indices_t.to(device) + self._offsets_t = self._offsets_t.to(device) + self._block_offsets_t = self._block_offsets_t.to(device) + return self + + def cuda(self): + return self.to(torch.cuda.current_device()) + + def clone(self): + return Matrix( + self.size(), + self.data.clone(), + self.row_indices.clone(), + self.column_indices.clone(), + self.offsets.clone(), + self.column_indices_t.clone(), + self.offsets_t.clone(), + self.block_offsets_t.clone()) + + def t(self): + if self.dim() != 2: + raise ValueError( + "t() expects a tensor with <= 2 dimensions, " + f"but self is {self.dim()}D.") + out = Matrix(self.size(), + self.data, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + out._transposed = not self._transposed + out._size = torch.Size((self._size[1], self._size[0])) + return out + + def contiguous(self): + raise ValueError("Not yet implemented.") + + def is_contiguous(self): + return not self._transposed + + @property + def is_cuda(self): + return self._data.is_cuda + + @property + def device(self): + return self._data.device + + def size(self): + return self._size + + @property + def shape(self): + return self.size() + + def dim(self): + return len(self._size) + + @property + def data(self): + return self._data + + @property + def row_indices(self): + return self._row_indices + + @property + def column_indices(self): + return self._column_indices + + @property + def offsets(self): + return self._offsets + + @property + def offsets_t(self): + return self._offsets_t + + @property + def column_indices_t(self): + return self._column_indices_t + + @property + def block_offsets_t(self): + return self._block_offsets_t + + @property + def dtype(self): + return self.data.dtype + + @property + def nnz(self): + return self.data.numel() + + @property + def blocking(self): + return self.data.shape[1] + + @property + def requires_grad(self): + return self.data.requires_grad + + def requires_grad_(self, x): + self.data.requires_grad_(x) + return self + + def view(self, *shape): + assert self.is_contiguous() + if shape[-1] != self.size()[-1]: + raise ValueError( + "Can't change view on compressed dimension. " + f"{self.size()[-1]} v. {shape[-1]}.") + if np.prod(shape) != np.prod(self.size()): + raise ValueError( + "Mismatch in numel of Matrix and new shape. " + f"{np.prod(self.size())} v. {np.prod(shape)}") + return Matrix(shape, + self.data, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + + @property + def grad(self): + # TODO(tgale): Make sure this mirrors torch.Tensor + # behavior in the case where we ask for the gradient + # of a non-contiguous tensor. 
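        # The gradient reuses this matrix's topology: data.grad is wrapped in
        # a new Matrix. For a transposed view, the Matrix is built in the
        # contiguous (untransposed) orientation and its transpose is returned,
        # so the result matches self's orientation.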
+ size = self.size() + if not self.is_contiguous(): + size = torch.Size((size[1], size[0])) + out = Matrix(size, + self.data.grad, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + return out if self.is_contiguous() else out.t() diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/__init__.py @@ -0,0 +1,3 @@ +from .linear_ops import dds, dsd, sdd +from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse +from .eltwise_ops import mul diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops.py @@ -0,0 +1,28 @@ +from ..matrix import Matrix + +def mul(a, b): + """Performs element-wise multiplication of matrices a and b. + + It is the user's responsibility to make sure that a and b + follow the same matrix topology. This function assumes it is safe + to use the topoplogy of a. + + Args: + a: stk.Matrix. + b: stk.Matrix with a's matrix topology. + + Returns: + stk.Matrix where the entries correspond to torch.mul(a, b). + """ + assert isinstance(a, Matrix) + assert isinstance(b, Matrix) + assert a.size() == b.size() + + return Matrix(a.size(), + a.data * b.data, + a.row_indices, + a.column_indices, + a.offsets, + a.column_indices_t, + a.offsets_t, + a.block_offsets_t) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..66bfd4f6af77042d3c5bdb1fe18d00e457478d46 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py @@ -0,0 +1,86 @@ +import unittest +import itertools +import torch +from absl.testing import parameterized + +import stk +from stk.ops.linear_ops_test import allclose, _dense_and_sparse + +_MATRIX_SIZES = ( + (128, 128, 0.0), + (256, 256, 0.5), + (2048, 1024, 0.8), + (512, 128, 0.0), + (128, 512, 0.0), + (1024, 512, 0.0), + (1024, 512, 0.5), + (1024, 512, 0.75), + (512, 1024, 0.0), + (512, 1024, 0.5), + (512, 1024, 0.75), + (1024, 1024, 0.0), + (1024, 1024, 0.5), + (1024, 1024, 0.75), +) + +_DTYPE = ( + torch.float16, torch.bfloat16 +) + +def _generate_testcases(): + testcases = itertools.product(_MATRIX_SIZES, _DTYPE) + testcases = [(*size, 128, dtype) for + (size, dtype) in testcases] + return testcases + +_ELTWISE_OP_TESTS = _generate_testcases() + +def _dense_and_sparse_like(x, std=0.1): + dense_data = torch.randn_like(x.data, device=x.device) * std + sparse = stk.Matrix(x.size(), + dense_data, + x.row_indices, + x.column_indices, + x.offsets) + dense = stk.ops.to_dense(sparse) + + return (dense.requires_grad_(True), + sparse.requires_grad_(True)) + +@parameterized.parameters(_ELTWISE_OP_TESTS) +class EltwiseOpsTest(parameterized.TestCase): + + def testEltwiseMul(self, m, n, sparsity, blocking, dtype): + + a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype) + b_dense, b = 
_dense_and_sparse_like(a) + + out = stk.ops.mul(a, b) + expected_out = torch.mul(a_dense, b_dense) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + stk.ops.sum(out).backward() + + # Validate the results. + out = stk.ops.to_dense(out) + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size(), out.size()) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = stk.ops.to_dense(a.grad) + expected_grad = a_dense.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size(), grad.size()) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = stk.ops.to_dense(b.grad) + expected_grad = b_dense.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size(), grad.size()) + self.assertTrue(allclose(grad, expected_grad)) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops.py @@ -0,0 +1,59 @@ +import torch + +from ..backend import sputnik +from ..matrix import Matrix + + +def dsd(a, b): + assert isinstance(a, Matrix) + assert isinstance(b, torch.Tensor) + return sputnik.dsd( + a.size(), + a.data, a.offsets, + a.row_indices, + a.column_indices, + a.offsets_t, + a.column_indices_t, + a.block_offsets_t, + not a.is_contiguous(), + b) + + +def dds(a, b): + assert isinstance(a, torch.Tensor) + assert isinstance(b, Matrix) + return sputnik.dds( + a, + b.size(), + b.data, b.offsets, + b.row_indices, + b.column_indices, + b.offsets_t, + b.column_indices_t, + b.block_offsets_t, + not b.is_contiguous()) + + +def sdd(a, b, topo): + assert isinstance(a, torch.Tensor) + assert isinstance(b, torch.Tensor) + assert isinstance(topo, Matrix) + assert topo.is_contiguous() + out = sputnik.sdd( + a, b, + topo.size(), + topo.data, + topo.offsets, + topo.row_indices, + topo.column_indices, + topo.offsets_t, + topo.column_indices_t, + topo.block_offsets_t) + return Matrix(topo.size(), + out, + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops_test.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ced1d782fbc9f9ca16b3449239f1588dc5ff5e00 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops_test.py @@ -0,0 +1,216 @@ +import unittest +import itertools +import numpy as np +import torch +from absl.testing import parameterized + +import stk + + +def allclose(x, y, pct=0.25): + mask = torch.isclose(x, y, rtol=5e-2) + pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 + if pct_diff > pct: + print("{:.2f}% of values not close.".format(pct_diff)) + return False + return True + + +# An assortment of problems designed to make sure +# the bindings are operating correctly. 
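For orientation: the three primitives these problems exercise differ only in which operand is block-sparse. dsd multiplies a sparse LHS by a dense RHS, dds a dense LHS by a sparse RHS, and sdd multiplies two dense operands into a sparse output with a prescribed topology. A minimal sketch of the calling convention (illustrative only; assumes a CUDA device and 128-wide blocking, the block size the bundled Triton config is written for, with `stk` being the same module these tests import):

    import torch
    import stk  # the vendored copy ships as megablocks.stk

    cuda = torch.device("cuda")
    blocking = 128
    mask = stk.random.dense_mask(256, 256, 0.5, blocking)
    lhs = (torch.randn(256, 256) * mask).half()
    lhs_sp = stk.ops.to_sparse(lhs, blocking).to(cuda)
    rhs = torch.randn(256, 256, dtype=torch.float16, device=cuda)

    dense_out = stk.ops.dsd(lhs_sp, rhs)                 # sparse @ dense -> torch.Tensor
    dense_out_2 = stk.ops.dds(rhs, lhs_sp)               # dense @ sparse -> torch.Tensor
    sparse_out = stk.ops.sdd(lhs.to(cuda), rhs, lhs_sp)  # dense @ dense -> Matrix with lhs_sp's pattern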
+_MATRIX_SIZES = ( + (128, 128, 128, 0.0), + (256, 256, 256, 0.5), + (2048, 1024, 512, 0.8), + (512, 128, 128, 0.0), + (128, 128, 512, 0.0), + (1024, 512, 512, 0.0), + (1024, 512, 512, 0.5), + (1024, 512, 512, 0.75), + (512, 512, 1024, 0.0), + (512, 512, 1024, 0.5), + (512, 512, 1024, 0.75), + (1024, 1024, 1024, 0.0), + (1024, 1024, 1024, 0.5), + (1024, 1024, 1024, 0.75), +) + +_TRANSPOSE = ( + (False, False), + (False, True), + (True, False), + (True, True), +) + +_DTYPE = ( + torch.float16, torch.bfloat16 +) + +def _generate_testcases(): + testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE) + testcases = [(*size, *trans, 128, dtype) for + (size, trans, dtype) in testcases] + return testcases + +_LINEAR_OP_TESTS = _generate_testcases() + +def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + dense = (torch.randn(rows, cols) * std * mask).type(dtype) + sparse = stk.ops.to_sparse(dense, blocking) + cuda_device = torch.device("cuda") + return (dense.to(cuda_device).requires_grad_(True), + sparse.to(cuda_device).requires_grad_(True)) + + +def _dense(rows, cols, dtype, std=0.1): + cuda_device = torch.device("cuda") + out = (torch.randn(rows, cols) * std).type(dtype) + return out.to(cuda_device).requires_grad_(True) + + +def _dense_2x(rows, cols, dtype): + a = _dense(rows, cols, dtype) + return a, a.detach().requires_grad_(True) + + +def _with_transpose(op, a, b, trans_a, trans_b): + a = a.t() if trans_a else a + b = b.t() if trans_b else b + return op(a, b) + + +def _mmm(a, b, topo): + mask = stk.ops.to_dense(stk.ops.ones_like(topo)) + return torch.mm(a, b) * mask + + +def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b): + a = a.t() if trans_a else a + b = b.t() if trans_b else b + return op(a, b, topo) + + +def _mask(x, mask): + mask = stk.ops.to_dense(stk.ops.ones_like(mask)) + return x * mask + + +@parameterized.parameters(*_LINEAR_OP_TESTS) +class LinearOpsTest(parameterized.TestCase): + + def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. + a_shape = (k, m) if trans_a else (m, k) + a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype) + b_shape = (n, k) if trans_b else (k, n) + b, bcp = _dense_2x(*b_shape, dtype) + + # Execute the matmul. + out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b) + expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + out.sum().backward() + + # Validate the results. + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = stk.ops.to_dense(a.grad) + expected_grad = _mask(a_dense.grad, a.grad) + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = b.grad + expected_grad = bcp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. 
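        # Operands are allocated pre-transposed ((k, m) rather than (m, k)
        # when trans_a is set) so that the .t() applied inside _with_transpose
        # exercises the strided/transposed kernel path; _dense_2x returns two
        # independent leaves over the same values so the torch.mm reference
        # and the dds kernel accumulate gradients separately for comparison.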
+ a_shape = (k, m) if trans_a else (m, k) + a, acp = _dense_2x(*a_shape, dtype) + b_shape = (n, k) if trans_b else (k, n) + b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype) + + # Execute the matmul. + out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b) + expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + out.sum().backward() + + # Validate the results. + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = a.grad + expected_grad = acp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = stk.ops.to_dense(b.grad) + expected_grad = _mask(b_dense.grad, b.grad) + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. + a_shape = (k, m) if trans_a else (m, k) + a, acp = _dense_2x(*a_shape, dtype) + b_shape = (n, k) if trans_b else (k, n) + b, bcp = _dense_2x(*b_shape, dtype) + _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype) + + # Execute the matmul. + out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b) + expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + stk.ops.sum(out).backward() + + # Validate the results. + out = stk.ops.to_dense(out) + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = a.grad + expected_grad = acp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = b.grad + expected_grad = bcp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops.py @@ -0,0 +1,98 @@ +from ..backend import sputnik +from ..matrix import Matrix +import torch +import numpy as np + + +@torch.no_grad() +def row_indices(shape, data, offsets, column_indices): + return sputnik.row_indices(shape, data, offsets, column_indices) + + +# TODO(tgale): Replace this helper with a custom kernel. This operation +# is much simpler to do than how it's currently implemented. 
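# For reference, _expand_for_blocking() turns each (block_row, block_col)
# coordinate into the element coordinates that block covers, in row-major
# order within the block. For example, with blocking=2 the single block index
# [1, 2] expands to [[2, 4], [2, 5], [3, 4], [3, 5]]; that row-major ordering
# is what lets to_dense()/to_sparse() move block data with a single
# scatter/gather against the flattened dense matrix.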
+@torch.no_grad() +def _expand_for_blocking(idxs, blocking): + # Duplicate for block column dimension. + idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1) + + # Update the column indices. + idxs[:, :, 1] *= blocking + idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking]) + + # Duplicate for block row dimension. + idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2]) + idxs = idxs.repeat(1, blocking, 1, 1) + + # Update the row indices. + idxs[:, :, :, 0] *= blocking + idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1]) + idxs = torch.reshape(idxs, [-1, 2]) + return idxs + + +# TODO(tgale): Add input type checking. +@torch.no_grad() +def to_dense(x): + assert isinstance(x, Matrix) + + shape = (np.prod(x.shape[:-1]), x.shape[-1]) + row_idxs = x.row_indices.type(torch.int32) + col_idxs = x.column_indices.type(torch.int32) + indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking) + indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64) + + out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device) + out.scatter_(0, indices, x.data.flatten()) + return out.reshape(x.size()) + + +@torch.no_grad() +def _mask(x, blocking=1): + assert x.dim() == 2 + assert x.size()[0] % blocking == 0 + assert x.size()[1] % blocking == 0 + block_rows = x.size()[0] // blocking + block_cols = x.size()[1] // blocking + x = torch.reshape(x, [block_rows, blocking, block_cols, blocking]) + x = torch.sum(torch.abs(x), dim=(1, 3)) + return x != 0 + + +# TODO(tgale): Add input type checking. +@torch.no_grad() +def to_sparse(x, blocking=1): + m = _mask(x, blocking) + + # TODO(tgale): Set to appropriate type for input matrix. + row_nnzs = torch.sum(m, dim=1).type(torch.int32) + zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device) + offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)]) + offsets = offsets.type(torch.int32) + + indices = torch.nonzero(m).type(torch.int16) + row_indices = indices[:, 0] + column_indices = indices[:, 1] + + # Nonzero indices in the dense matrix. + nonzero_indices = torch.nonzero(m) + nonzero_indices = _expand_for_blocking(nonzero_indices, blocking) + nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1] + + # Gather the data and construct the sparse matrix. 
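    # At this point the BCSR metadata is complete: row_indices and
    # column_indices hold the block coordinates of every nonzero block in
    # row-major block order, and offsets is the CSR-style prefix sum over
    # block rows. The gather below pulls the corresponding elements out of
    # the flattened dense input and reshapes them into the
    # [nnz_blocks, blocking, blocking] data layout that Matrix expects.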
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices) + data = torch.reshape(data, [-1, blocking, blocking]) + return Matrix(x.size(), data, row_indices, column_indices, offsets) + + +@torch.no_grad() +def ones_like(x): + return Matrix(x.size(), + torch.ones_like(x.data), + x.row_indices, + x.column_indices, x.offsets) + + +def sum(x): + assert isinstance(x, Matrix) + return x.data.sum() diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3af04c0760483e578f93303dc457415948a2a34c --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py @@ -0,0 +1,62 @@ +import unittest + +from absl.testing import parameterized +import stk +import torch + + +@parameterized.parameters( + (8, 16, 0.0, 1), + (8, 16, 0.5, 1), + (8, 16, .95, 1), + (16, 8, 0.0, 1), + (16, 8, 0.5, 1), + (16, 8, .95, 1), + (8, 16, 0.0, 8), + (8, 16, 0.5, 8), + (8, 16, 1.0, 8), + (16, 8, 0.0, 8), + (16, 8, 0.5, 8), + (16, 8, 1.0, 8), + (128, 256, 0.5, 16), + (256, 128, 0.75, 32), + (512, 512, .875, 128)) +class MatrixOpsTest(parameterized.TestCase): + + def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + x = (torch.randn(rows, cols) * mask).type(torch.float16) + + # Convert the matrix to sparse format. + sparse_x = stk.ops.to_sparse(x, blocking) + + # Validate the matrix. + sparse_x.validate() + + # Validate the shape. + self.assertEqual(sparse_x.dim(), 2) + self.assertEqual(sparse_x.size()[0], rows) + self.assertEqual(sparse_x.size()[1], cols) + + # Validate the sparsity. + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual(sparse_x.nnz, nnz) + + # Convert back to dense format. + dense_x = stk.ops.to_dense(sparse_x) + + # Validate the shape. + self.assertEqual(dense_x.dim(), 2) + self.assertEqual(dense_x.size()[0], rows) + self.assertEqual(dense_x.size()[1], cols) + + # Validate the sparsity + self.assertEqual(torch.count_nonzero(dense_x).item(), nnz) + + # Validate the output. 
+ self.assertTrue(torch.all(torch.eq(x, dense_x))) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/__init__.py @@ -0,0 +1,2 @@ +# from stk.random.random_ops import dense_mask, mask, randn +from .random_ops import dense_mask, mask, randn diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops.py @@ -0,0 +1,36 @@ +import numpy as np +import torch +from ..ops import matrix_ops + + +@torch.no_grad() +def dense_mask(rows, cols, sparsity, blocking=1): + assert sparsity >= 0.0 and sparsity <= 1.0 + assert rows % blocking == 0 and cols % blocking == 0 + + block_rows, block_cols = (rows // blocking, cols // blocking) + nnz = round(block_rows * block_cols * (1 - sparsity)) + + out = np.ones(block_rows * block_cols) + mask = np.random.choice(out.size, out.size - nnz, replace=False) + out[mask] = 0.0 + + out = np.tile( + np.reshape(out, [block_rows, 1, block_cols, 1]), + (1, blocking, 1, blocking)) + out = np.reshape(out, [rows, cols]) + return torch.from_numpy(out.astype(np.float32)) + + +@torch.no_grad() +def mask(m, n, sparsity, blocking=1): + out = dense_mask(m, n, sparsity, blocking).type(torch.float16) + return matrix_ops.to_sparse(out, blocking=blocking) + + +@torch.no_grad() +def randn(shape, sparsity, blocking=1): + shape_2d = (np.prod(shape[:-1]), shape[-1]) + out = mask(*shape_2d, sparsity, blocking) + out.data.copy_(torch.randn(*out.data.shape)) + return out.view(*shape) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops_test.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..587b44ec890c861879c6296b8f9028f5d99ab82f --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops_test.py @@ -0,0 +1,73 @@ +import unittest + +from absl.testing import parameterized +from . import random +import torch + + +@parameterized.parameters( + (8, 16, 0.0, 1), + (8, 16, 0.5, 1), + (8, 16, .95, 1), + (16, 8, 0.0, 1), + (16, 8, 0.5, 1), + (16, 8, .95, 1), + (8, 16, 0.0, 8), + (8, 16, 0.5, 8), + (8, 16, 1.0, 8), + (16, 8, 0.0, 8), + (16, 8, 0.5, 8), + (16, 8, 1.0, 8), + (128, 256, 0.5, 16), + (256, 128, 0.75, 32), + (512, 512, .875, 128)) +class RandomOpsTest(parameterized.TestCase): + + def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking): + mask = random.dense_mask( + rows, cols, sparsity, blocking) + + # Validate the shape. + self.assertEqual(mask.dim(), 2) + self.assertEqual(mask.size()[0], rows) + self.assertEqual(mask.size()[1], cols) + + # Validate the sparsity + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual( + torch.count_nonzero(mask).item(), + nnz) + + # Check values are zero or one. 
+ self.assertTrue( + torch.all(torch.logical_or( + torch.eq(mask, 0), + torch.eq(mask, 1)))) + + def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking): + mask = random.mask( + rows, cols, sparsity, blocking) + + # Validate the matrix. + mask.validate() + + # Validate the shape. + self.assertEqual(mask.dim(), 2) + self.assertEqual(mask.size()[0], rows) + self.assertEqual(mask.size()[1], cols) + + # Validate the sparsity. + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual(mask.nnz, nnz) + + # Check values are zero or one. + self.assertTrue( + torch.all(torch.logical_or( + torch.eq(mask.data, 0), + torch.eq(mask.data, 1)))) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/activation_fn.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/activation_fn.py index a31770ba179fed06abf2da10102ccaeed1d3ee4e..0e1d956704840aa4daf7d1d71d24e051567feab9 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/activation_fn.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/activation_fn.py @@ -4,7 +4,7 @@ from typing import Any, Callable, Union import torch -from stk import Matrix +from ..stk import Matrix def act_fn( diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/dmoe.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/dmoe.py index 561fd7982254513e9a9b4f0514edd326b0561256..6d0375a4df2f27134c4127e60be04f3b45693050 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/dmoe.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/dmoe.py @@ -2,15 +2,22 @@ # SPDX-License-Identifier: Apache-2.0 import numpy as np -import stk.ops import torch -from stk import Matrix + +# try: +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.', +# ) # import megablocks.ops as ops # # from megablocks.ops import ops # from megablocks.layers import common, dmlp_registry, moe, mpu # from megablocks.layers.arguments import Arguments +from .. import stk from .. import ops from . import common, dmlp_registry, moe, mpu from .arguments import Arguments diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/gelu.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/gelu.py index 40b601d4a8eb59e80e1090f31f4172b8f7fb7549..c4c9e6532798615b5c12c96694241a4c18ee8f7b 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/gelu.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/gelu.py @@ -1,7 +1,16 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -import stk +# try: +# import stk +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.', +# ) + +from .. 
import stk + import torch import torch.nn.functional as F diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/glu.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/glu.py index 69ef9ead5461a5b91ace107652fd8d1c3d300b1e..5f297a41ff6a1a2a285f5b461951672364b898da 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/glu.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/glu.py @@ -1,7 +1,17 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -import stk.ops +# import stk.ops +# try: +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.', +# ) + +from .. import stk + import torch # from megablocks import grouped_gemm_util as gg diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/mlp.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/mlp.py index 0b6747b1eec2047eb9480592ca550b6c9e868f5e..c99afb9904c24a8b6a83e79059cd1251dbbfd99e 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/mlp.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/mlp.py @@ -3,9 +3,18 @@ from typing import Any -import stk -import stk.backend.triton_kernels -import stk.ops +# try: +# import stk +# import stk.backend.triton_kernels +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.', +# ) + +from .. import stk + import torch from packaging import version diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so deleted file mode 100755 index e707d1a69cfad2211ae17f034737b152b601a899..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:cffb3e3e44310bba45bcf82b07fc3a2b188bbeb8f50f04784e440ba3bdf5fc0f -size 17892624 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_63599de.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_63599de.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..28ab6bbd5810fd0aaab0345efad8da900bb1d82c --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_63599de.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dadccc59929c2fdbdf3b153f564d223013924c7b617d1eb2b3ecdc04470a4a60 +size 17892624 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py index f7aa5a700bc5673c743a3b0fb74107fab793fbad..03e1e323ce98ecea8dda071c089f0d7ba75551b6 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_0586ba6 -ops = torch.ops._megablocks_0586ba6 +from . import _megablocks_63599de +ops = torch.ops._megablocks_63599de def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_megablocks_0586ba6::{op_name}" \ No newline at end of file + return f"_megablocks_63599de::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/matmul_benchmark.py index 9b2c7f062071758def741bf3af2392d3a4855f89..7ccc5dcec5e9a663794fad944c45285869c4d1c1 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -3,7 +3,19 @@ import unittest -import stk + +# import stk + +# try: +# import stk +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.', +# ) + +from .. import stk + import torch from absl.testing import parameterized diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/__init__.py @@ -0,0 +1,7 @@ +# import stk.random +# import stk.ops +# from stk.matrix import Matrix + +from . import random +from . import ops +from .matrix import Matrix diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/autocast.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/autocast.py new file mode 100644 index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/autocast.py @@ -0,0 +1,37 @@ +import functools +import torch + + +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, map): + return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, list) or isinstance(x, tuple): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def custom_fwd(fwd): + """Wrap a custom autograd function that always uses autocast dtype.""" + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + return decorate_fwd + + +def custom_bwd(bwd): + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + return decorate_bwd diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/sputnik.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/sputnik.py new file mode 100644 index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/sputnik.py @@ -0,0 +1,316 @@ +import torch + +from ..backend import triton_kernels as backend +from ..backend.autocast 
import custom_bwd, custom_fwd + + +def _standardize_shape(x, transpose): + if transpose: + return torch.Size((x[1], x[0])) + return x + + +def _sparse_transpose(x): + return (torch.Size((x[0][1], x[0][0])), ) + x[1:] + + +def _transpose_helper(x, transpose): + if isinstance(x, torch.Tensor): + return x.t() if transpose else x + if transpose: + x = _sparse_transpose(x) + return x + (transpose,) + + +def _wrap(x): + if isinstance(x, torch.Tensor): + return (x,) + return x + + +def _is_transposed(x): + return (not x.is_contiguous() and + x.stride()[0] == 1 and + x.stride()[1] == x.size()[0]) + + +def _call_helper(op, out, a, b, trans_a, trans_b): + args = (_wrap(_transpose_helper(a, trans_a)) + + _wrap(_transpose_helper(b, trans_b))) + if isinstance(out, tuple): + args = args + out + return op(*args) + + +def _preprocess_inputs(lhs, rhs, dy): + if isinstance(lhs, torch.Tensor) and _is_transposed(lhs): + lhs = lhs.t() + if isinstance(rhs, torch.Tensor) and _is_transposed(rhs): + rhs = rhs.t() + if (isinstance(dy, torch.Tensor) and + not dy.is_contiguous() and + not _is_transposed(dy)): + dy = dy.contiguous() + if isinstance(dy, tuple) and not dy[1].is_contiguous(): + dy = (dy[0], dy[1].contiguous()) + dy[2:] + return lhs, rhs, dy + + +def _postprocess_outputs(x, transpose, grad): + if isinstance(x, torch.Tensor) and transpose: + return grad.t() + return grad + + +def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): + lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) + + a, b = (rhs, dy) if trans_lhs else (dy, rhs) + trans_a = trans_lhs and trans_rhs + trans_b = trans_lhs or not trans_rhs + out = _call_helper(op, lhs, a, b, trans_a, trans_b) + return _postprocess_outputs(lhs, trans_lhs, out) + + +def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): + lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) + + a, b = (dy, lhs) if trans_rhs else (lhs, dy) + trans_a = not trans_lhs or trans_rhs + trans_b = trans_lhs and trans_rhs + out = _call_helper(op, rhs, a, b, trans_a, trans_b) + return _postprocess_outputs(rhs, trans_rhs, out) + + +class DSD(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs): + ctx.save_for_backward(data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + rhs) + ctx.shape = _standardize_shape(shape, transpose_a) + ctx.transpose_a = transpose_a + + out = torch.empty( + (shape[0], rhs.size()[1]), + dtype=rhs.dtype, + device=rhs.device) + + backend.dsd(shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs, + out) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs = (ctx.shape,) + saved_tensors[:-1] + rhs = saved_tensors[-1] + trans_a = ctx.transpose_a + trans_b = _is_transposed(rhs) + + ddata = None + if ctx.needs_input_grad[1]: + ddata = _lhs_gradient(sdd, + lhs, + rhs, + dy, + trans_a, + trans_b) + drhs = None + if ctx.needs_input_grad[-1]: + op = dds if trans_b else dsd + drhs = _rhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + return None, ddata, None, None, None, None, None, None, None, drhs + + +dsd = DSD.apply + + +class DDS(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + 
block_offsets_t, + transpose_b): + ctx.save_for_backward(lhs, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t) + ctx.shape = _standardize_shape(shape, transpose_b) + ctx.transpose_b = transpose_b + out = torch.empty((lhs.size()[0], shape[1]), + dtype=lhs.dtype, + device=lhs.device) + backend.dds(lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b, + out) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs = saved_tensors[0] + rhs = (ctx.shape,) + saved_tensors[1:] + trans_a = _is_transposed(lhs) + trans_b = ctx.transpose_b + + dlhs = None + if ctx.needs_input_grad[0]: + op = dsd if trans_a else dds + dlhs = _lhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + ddata = None + if ctx.needs_input_grad[2]: + ddata = _rhs_gradient(sdd, + lhs, + rhs, + dy, + trans_a, + trans_b) + return dlhs, None, ddata, None, None, None, None, None, None, None + + +dds = DDS.apply + + +class SDD(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + lhs, + rhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t): + ctx.save_for_backward( + lhs, + rhs, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t) + ctx.shape = shape + out = torch.empty( + data.shape, + dtype=lhs.dtype, + device=lhs.device) + backend.sdd(lhs, + rhs, + shape, + out, + offsets, + row_indices, + column_indices) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs, rhs = saved_tensors[:2] + dy = (ctx.shape, dy) + saved_tensors[2:] + trans_a = _is_transposed(lhs) + trans_b = _is_transposed(rhs) + + dlhs = None + if ctx.needs_input_grad[0]: + op = dds if trans_a else dsd + dlhs = _lhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + drhs = None + if ctx.needs_input_grad[1]: + op = dsd if trans_b else dds + drhs = _rhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + return dlhs, drhs, None, None, None, None, None, None, None, None + + +sdd = SDD.apply + +class RowIndices(torch.autograd.Function): + + @staticmethod + def forward(ctx, shape, data, offsets, column_indices): + out = torch.empty( + column_indices.shape, + dtype=column_indices.dtype, + device=column_indices.device) + backend.row_indices(shape, data, offsets, column_indices, out) + return out + + +row_indices = RowIndices.apply diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/triton_kernels.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/triton_kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/triton_kernels.py @@ -0,0 +1,393 @@ +import torch +import triton +import triton.language as tl +from dataclasses import dataclass + +@dataclass +class TritonConfig: + BLOCK_M: int = 128 + BLOCK_N: int = 128 + BLOCK_K: int = 32 + BLOCK_SIZE: int = 128 + NUM_STAGES: int = 4 + NUM_WARPS: int = 4 + +def _validate_matmul_dims(M: int, K: int, N: int): + error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}" + assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M) + assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, 
TritonConfig.BLOCK_K) + assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _sdd_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + # matrix multiplication + pid = tl.program_id(0) + pid_m = tl.load(row_indices + pid) + pid_n = tl.load(column_indices + pid) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(A) + b = tl.load(B) + acc += tl.dot(a, b) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + #Store to sparse matrix + acc = acc.to(C.dtype.element_ty) + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + cm = tl.arange(0, BLOCK_M) + cn = tl.arange(0, BLOCK_N) + C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _dsd_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, offsets, + block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + + # matrix multiplication + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) + + start_inx = tl.load(offsets + pid_m) + end_inx = tl.load(offsets + pid_m + 1) + + # pointers to sparse matrix + rm = tl.arange(0, BLOCK_M) + rak = tl.arange(0, BLOCK_K) + + A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) + + # pointers to dense matrix + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rbk = tl.arange(0, BLOCK_K) + B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) + + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) + + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + ak_sub_incr = BLOCK_K * stride_ak + bk_sub_incr = BLOCK_K * stride_bk + bk_block_incr = BLOCK_SIZE * stride_bk + + for k in range(nsub_blocks * (end_inx - start_inx)): + sub_block_inx = k % nsub_blocks + block_inx = k // nsub_blocks + 
+ if trans_A: + ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr + else: + ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr + + ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr + + a = tl.load(ptr_A) + b = tl.load(ptr_B) + acc += tl.dot(a, b) + + acc = acc.to(C.dtype.element_ty) + + cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _dds_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, offsets, + block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + + # matrix multiplication + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) + + start_inx = tl.load(offsets + pid_n) + end_inx = tl.load(offsets + pid_n + 1) + + # pointers to dense matrix + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rak = tl.arange(0, BLOCK_K) + + A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) + + # pointers to sparse matrix + rn = tl.arange(0, BLOCK_N) + rbk = tl.arange(0, BLOCK_K) + B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) + + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) + + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + + ak_sub_incr = BLOCK_K * stride_ak + ak_block_incr = BLOCK_SIZE * stride_ak + bk_sub_incr = BLOCK_K * stride_bk + + for k in range(nsub_blocks * (end_inx - start_inx)): + sub_block_inx = k % nsub_blocks + block_inx = k // nsub_blocks + + if trans_B: + ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr + else: + ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr + + ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr + a = tl.load(ptr_A) + b = tl.load(ptr_B) + acc += tl.dot(a, b) + + acc = acc.to(C.dtype.element_ty) + cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +def dsd(shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs, + out + ): + + device = rhs.device + trans_A = transpose_a + trans_B = False + + if rhs.stride(0) > 1 and rhs.stride(1) > 1: + trans_B = True + + # checks constraints + assert shape[1] == rhs.shape[0], "incompatible dimensions" + M, K = shape + _, N = rhs.shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else 
tl.int32 + + stride_am, stride_ak = data.stride(1), data.stride(2) + stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) + a_column_indices = column_indices + a_offsets = offsets + + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) + + if trans_A: + stride_am, stride_ak = data.stride(2), data.stride(1) + a_column_indices, a_offsets = column_indices_t, offsets_t + + if trans_B: + stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) + + _dsd_kernel[grid]( + data.data, rhs, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(0), out.stride(1), + row_indices, a_column_indices, a_offsets, + block_offsets_t, trans_A, trans_B, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + # return out + +def dds(lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b, + out + ): + + device = lhs.device + trans_B = transpose_b + trans_A = False + + if lhs.stride(0) > 1 and lhs.stride(1) > 1: + trans_A = True + + # checks constraints + assert lhs.shape[1] == shape[0], "incompatible dimensions" + M, K = lhs.shape + _, N = shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + stride_am, stride_ak = lhs.stride(0), lhs.stride(1) + stride_bk, stride_bn = data.stride(1), data.stride(2) + b_column_indices = column_indices_t + b_offsets = offsets_t + + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) + + if trans_A: + stride_am, stride_ak = lhs.stride(1), lhs.stride(0) + if trans_B: + stride_bk, stride_bn = data.stride(2), data.stride(1) + b_column_indices, b_offsets = column_indices, offsets + + _dds_kernel[grid]( + lhs, data, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(0), out.stride(1), + row_indices, b_column_indices, b_offsets, + block_offsets_t, trans_A, trans_B, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + +def sdd(lhs, + rhs, + shape, + out, + offsets, + row_indices, + column_indices + ): + + device = out.device + trans_A = False + trans_B = False + + if lhs.stride(0) > 1 and lhs.stride(1) > 1: + trans_A = True + if rhs.stride(0) > 1 and rhs.stride(1) > 1: + trans_B = True + + # checks constraints + assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions" + M, K = lhs.shape + _, N = rhs.shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + # launch kernel + nnz_blocks = len(row_indices) + grid = lambda META: (nnz_blocks,) + + stride_am, stride_ak = lhs.stride(0), lhs.stride(1) + stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) + + if trans_A: + stride_am, stride_ak = lhs.stride(1), lhs.stride(0) + if trans_B: + stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) + + _sdd_kernel[grid]( + lhs, rhs, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(1), out.stride(2), + row_indices, column_indices, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + +@triton.jit +def _row_indices_kernel(offsets, out): + pid = tl.program_id(0) + row_offset = tl.load(offsets + pid) + nnz_blocks = tl.load(offsets + pid + 1) - row_offset + for nnz_block in range(nnz_blocks): + tl.store(out + row_offset + nnz_block, pid) + +def row_indices( + shape, data, offsets, column_indices, out +): + block_rows = len(offsets) - 1 + _row_indices_kernel[(block_rows, )](offsets, 
out) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/matrix.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/matrix.py @@ -0,0 +1,329 @@ +import numpy as np +import torch + +# 1. Add heavyweight (data) validation helper. +# 2. Add construction helpers +# 3. Make indentation consistent +# 4. Replace asserts with descriptive errors. + +## +### Validation helpers. +## + + +def _validate_matrix(shape, data, row_indices, column_indices, offsets): + # Data should be [nnz, block_size, block_size] + if data.dim() == 1: + data = torch.reshape(data, [data.numel(), 1, 1]) + + # Blocks should be square. + if data.shape[-2] != data.shape[-1]: + raise ValueError( + "Expected square blocking in data. " + f"Got block shape {[data.shape[-2], data.shape[-1]]}") + + # Flatten batch dimensions on data - original shape preserved + # in shape argument. + block_size = data.shape[-1] + data = data.view([-1, block_size, block_size]) + + if data.dim() != 3: + raise ValueError( + "Expected 3D shape for data (nnz, block, block). " + f"Got shape {data.dim()}D shape.") + + block_size = data.shape[1] + if shape[-2] % block_size != 0 or shape[-1] % block_size != 0: + raise ValueError( + "Matrix shape must be dividible by blocking. " + f"Got shape {shape} with " + f"{[block_size, block_size]} blocking.") + + if np.prod(shape) < data.numel(): + raise ValueError( + "Invalid matrix. Number of nonzeros exceeds matrix capacity " + f"({data.numel()} v. {np.prod(shape)})") + + if row_indices.dim() != 1: + raise ValueError( + f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.") + + if column_indices.dim() != 1: + raise ValueError( + f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.") + + if offsets.dim() != 1: + raise ValueError( + f"Expected 1D offsets. Got {offsets.dim()}D offsets.") + + if row_indices.numel() != data.shape[0]: + raise ValueError( + "Expected 1 index per nonzero block. " + f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks") + + if column_indices.numel() != data.shape[0]: + raise ValueError( + "Expected 1 index per nonzero block. " + f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks") + + block_rows = np.prod(shape[:-1]) / block_size + if offsets.numel() != block_rows + 1: + raise ValueError( + "Expected one offset per block row plus one. " + f"Got {offsets.numel()} offsets with {block_rows} block rows.") + + is_cuda = (data.is_cuda and + row_indices.is_cuda and + column_indices.is_cuda and + offsets.is_cuda) + is_cpu = (not data.is_cuda and + not row_indices.is_cuda and + not column_indices.is_cuda and + not offsets.is_cuda) + if not (is_cuda or is_cpu): + raise ValueError( + "Expected data & meta-data on common device. " + f"Got data on {data.device}, row_indices on {row_indices.device} " + f"column_indices on {column_indices.device} and " + f"offsets on {offsets.device}.") + + if data.dtype != torch.float16: + raise ValueError( + f"Expected float16 data. Got {data.dtype} data.") + if row_indices.dtype != torch.int16: + raise ValueError( + f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.") + if column_indices.dtype != torch.int16: + raise ValueError( + f"Expected int16 column_indices. 
Got {column_indices.dtype} column_indices.") + if offsets.dtype != torch.int32: + raise ValueError( + f"Expected int32 offsets. Got {offsets.dtype} offsets.") + return data + + +def _transpose(size, data, row_indices, column_indices, offsets): + block_columns = size[1] // data.shape[1] + + # Sort row indices by column indices to get the transposed matrix's + # column indices. + gather_indices = column_indices.argsort() + column_indices_t = row_indices.gather(0, gather_indices) + block_offsets_t = gather_indices.int() + + # NOTE: Histogram is not implemented for any integer type on CPU. Do + # the histogram in 32-bit float, which can exactly represent 16-bit + # integers. + column_indices_float = column_indices.float() + + zero = torch.zeros((1,), dtype=torch.int32, device=data.device) + nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns) + nnz_per_column = nnz_per_column.int() + offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)]) + return column_indices_t, offsets_t, block_offsets_t + + +class Matrix(torch.nn.Module): + """A matrix stored in sparse format. + + Underlying format is block compressed sparse row (BCSR). + + TODO(tgale): Make this mirror torch.Tensor API as much as possible. + """ + + def __init__(self, + size, + data, + row_indices, + column_indices, + offsets, + column_indices_t=None, + offsets_t=None, + block_offsets_t=None): + super().__init__() + self._size = size + self._data = data + self._row_indices = row_indices + self._column_indices = column_indices + self._offsets = offsets + + # Produce the transpose meta-data if it is not passed in. + if ((column_indices_t is None) or (offsets_t is None) or + (block_offsets_t is None)): + column_indices_t, offsets_t, block_offsets_t = _transpose( + size, data, row_indices, column_indices, offsets) + self._column_indices_t = column_indices_t + self._offsets_t = offsets_t + self._block_offsets_t = block_offsets_t + + self._transposed = False + + # Validate that our metadata will not overflow. + max_dim = np.iinfo(np.int16).max * self.blocking + if column_indices.dtype == torch.int16: + if size[0] > max_dim or size[1] > max_dim: + raise ValueError( + "Sparse matrix with shape {size} exceeds representable " + "size with 16-bit indices.") + + def validate(self): + _validate_matrix(self._size, + self._data, + self._row_indices, + self._column_indices, + self._offsets) + + # TODO(tgale): Add heavyweight data validation. + + def to(self, device): + # TODO(tgale): Handle type conversions here. We + # need to set the appropriate meta-data type for + # the given floating-point type. 
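+        # Move the values, the index tensors and the transpose metadata
+        # together so the matrix never straddles devices.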
+ self._data = self._data.to(device) + self._row_indices = self._row_indices.to(device) + self._column_indices = self._column_indices.to(device) + self._offsets = self._offsets.to(device) + self._column_indices_t = self._column_indices_t.to(device) + self._offsets_t = self._offsets_t.to(device) + self._block_offsets_t = self._block_offsets_t.to(device) + return self + + def cuda(self): + return self.to(torch.cuda.current_device()) + + def clone(self): + return Matrix( + self.size(), + self.data.clone(), + self.row_indices.clone(), + self.column_indices.clone(), + self.offsets.clone(), + self.column_indices_t.clone(), + self.offsets_t.clone(), + self.block_offsets_t.clone()) + + def t(self): + if self.dim() != 2: + raise ValueError( + "t() expects a tensor with <= 2 dimensions, " + f"but self is {self.dim()}D.") + out = Matrix(self.size(), + self.data, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + out._transposed = not self._transposed + out._size = torch.Size((self._size[1], self._size[0])) + return out + + def contiguous(self): + raise ValueError("Not yet implemented.") + + def is_contiguous(self): + return not self._transposed + + @property + def is_cuda(self): + return self._data.is_cuda + + @property + def device(self): + return self._data.device + + def size(self): + return self._size + + @property + def shape(self): + return self.size() + + def dim(self): + return len(self._size) + + @property + def data(self): + return self._data + + @property + def row_indices(self): + return self._row_indices + + @property + def column_indices(self): + return self._column_indices + + @property + def offsets(self): + return self._offsets + + @property + def offsets_t(self): + return self._offsets_t + + @property + def column_indices_t(self): + return self._column_indices_t + + @property + def block_offsets_t(self): + return self._block_offsets_t + + @property + def dtype(self): + return self.data.dtype + + @property + def nnz(self): + return self.data.numel() + + @property + def blocking(self): + return self.data.shape[1] + + @property + def requires_grad(self): + return self.data.requires_grad + + def requires_grad_(self, x): + self.data.requires_grad_(x) + return self + + def view(self, *shape): + assert self.is_contiguous() + if shape[-1] != self.size()[-1]: + raise ValueError( + "Can't change view on compressed dimension. " + f"{self.size()[-1]} v. {shape[-1]}.") + if np.prod(shape) != np.prod(self.size()): + raise ValueError( + "Mismatch in numel of Matrix and new shape. " + f"{np.prod(self.size())} v. {np.prod(shape)}") + return Matrix(shape, + self.data, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + + @property + def grad(self): + # TODO(tgale): Make sure this mirrors torch.Tensor + # behavior in the case where we ask for the gradient + # of a non-contiguous tensor. 
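+        # The gradient reuses this matrix's sparsity metadata and only swaps
+        # in data.grad; for a transposed view the shape is flipped before
+        # wrapping and the transpose is re-applied on the way out.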
+ size = self.size() + if not self.is_contiguous(): + size = torch.Size((size[1], size[0])) + out = Matrix(size, + self.data.grad, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + return out if self.is_contiguous() else out.t() diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/__init__.py @@ -0,0 +1,3 @@ +from .linear_ops import dds, dsd, sdd +from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse +from .eltwise_ops import mul diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/eltwise_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/eltwise_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/eltwise_ops.py @@ -0,0 +1,28 @@ +from ..matrix import Matrix + +def mul(a, b): + """Performs element-wise multiplication of matrices a and b. + + It is the user's responsibility to make sure that a and b + follow the same matrix topology. This function assumes it is safe + to use the topoplogy of a. + + Args: + a: stk.Matrix. + b: stk.Matrix with a's matrix topology. + + Returns: + stk.Matrix where the entries correspond to torch.mul(a, b). + """ + assert isinstance(a, Matrix) + assert isinstance(b, Matrix) + assert a.size() == b.size() + + return Matrix(a.size(), + a.data * b.data, + a.row_indices, + a.column_indices, + a.offsets, + a.column_indices_t, + a.offsets_t, + a.block_offsets_t) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..66bfd4f6af77042d3c5bdb1fe18d00e457478d46 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py @@ -0,0 +1,86 @@ +import unittest +import itertools +import torch +from absl.testing import parameterized + +import stk +from stk.ops.linear_ops_test import allclose, _dense_and_sparse + +_MATRIX_SIZES = ( + (128, 128, 0.0), + (256, 256, 0.5), + (2048, 1024, 0.8), + (512, 128, 0.0), + (128, 512, 0.0), + (1024, 512, 0.0), + (1024, 512, 0.5), + (1024, 512, 0.75), + (512, 1024, 0.0), + (512, 1024, 0.5), + (512, 1024, 0.75), + (1024, 1024, 0.0), + (1024, 1024, 0.5), + (1024, 1024, 0.75), +) + +_DTYPE = ( + torch.float16, torch.bfloat16 +) + +def _generate_testcases(): + testcases = itertools.product(_MATRIX_SIZES, _DTYPE) + testcases = [(*size, 128, dtype) for + (size, dtype) in testcases] + return testcases + +_ELTWISE_OP_TESTS = _generate_testcases() + +def _dense_and_sparse_like(x, std=0.1): + dense_data = torch.randn_like(x.data, device=x.device) * std + sparse = stk.Matrix(x.size(), + dense_data, + x.row_indices, + x.column_indices, + x.offsets) + dense = stk.ops.to_dense(sparse) + + return (dense.requires_grad_(True), + sparse.requires_grad_(True)) + +@parameterized.parameters(_ELTWISE_OP_TESTS) +class EltwiseOpsTest(parameterized.TestCase): + + def testEltwiseMul(self, m, n, sparsity, blocking, dtype): + + a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype) + b_dense, b = 
_dense_and_sparse_like(a) + + out = stk.ops.mul(a, b) + expected_out = torch.mul(a_dense, b_dense) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + stk.ops.sum(out).backward() + + # Validate the results. + out = stk.ops.to_dense(out) + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size(), out.size()) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = stk.ops.to_dense(a.grad) + expected_grad = a_dense.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size(), grad.size()) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = stk.ops.to_dense(b.grad) + expected_grad = b_dense.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size(), grad.size()) + self.assertTrue(allclose(grad, expected_grad)) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/linear_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/linear_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/linear_ops.py @@ -0,0 +1,59 @@ +import torch + +from ..backend import sputnik +from ..matrix import Matrix + + +def dsd(a, b): + assert isinstance(a, Matrix) + assert isinstance(b, torch.Tensor) + return sputnik.dsd( + a.size(), + a.data, a.offsets, + a.row_indices, + a.column_indices, + a.offsets_t, + a.column_indices_t, + a.block_offsets_t, + not a.is_contiguous(), + b) + + +def dds(a, b): + assert isinstance(a, torch.Tensor) + assert isinstance(b, Matrix) + return sputnik.dds( + a, + b.size(), + b.data, b.offsets, + b.row_indices, + b.column_indices, + b.offsets_t, + b.column_indices_t, + b.block_offsets_t, + not b.is_contiguous()) + + +def sdd(a, b, topo): + assert isinstance(a, torch.Tensor) + assert isinstance(b, torch.Tensor) + assert isinstance(topo, Matrix) + assert topo.is_contiguous() + out = sputnik.sdd( + a, b, + topo.size(), + topo.data, + topo.offsets, + topo.row_indices, + topo.column_indices, + topo.offsets_t, + topo.column_indices_t, + topo.block_offsets_t) + return Matrix(topo.size(), + out, + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/linear_ops_test.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/linear_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ced1d782fbc9f9ca16b3449239f1588dc5ff5e00 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/linear_ops_test.py @@ -0,0 +1,216 @@ +import unittest +import itertools +import numpy as np +import torch +from absl.testing import parameterized + +import stk + + +def allclose(x, y, pct=0.25): + mask = torch.isclose(x, y, rtol=5e-2) + pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 + if pct_diff > pct: + print("{:.2f}% of values not close.".format(pct_diff)) + return False + return True + + +# An assortment of problems designed to make sure +# the bindings are operating correctly. 
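+# Each entry is (m, k, n, sparsity); the transpose flags, a blocking of 128
+# and the dtype are appended per case in _generate_testcases below.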
+_MATRIX_SIZES = ( + (128, 128, 128, 0.0), + (256, 256, 256, 0.5), + (2048, 1024, 512, 0.8), + (512, 128, 128, 0.0), + (128, 128, 512, 0.0), + (1024, 512, 512, 0.0), + (1024, 512, 512, 0.5), + (1024, 512, 512, 0.75), + (512, 512, 1024, 0.0), + (512, 512, 1024, 0.5), + (512, 512, 1024, 0.75), + (1024, 1024, 1024, 0.0), + (1024, 1024, 1024, 0.5), + (1024, 1024, 1024, 0.75), +) + +_TRANSPOSE = ( + (False, False), + (False, True), + (True, False), + (True, True), +) + +_DTYPE = ( + torch.float16, torch.bfloat16 +) + +def _generate_testcases(): + testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE) + testcases = [(*size, *trans, 128, dtype) for + (size, trans, dtype) in testcases] + return testcases + +_LINEAR_OP_TESTS = _generate_testcases() + +def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + dense = (torch.randn(rows, cols) * std * mask).type(dtype) + sparse = stk.ops.to_sparse(dense, blocking) + cuda_device = torch.device("cuda") + return (dense.to(cuda_device).requires_grad_(True), + sparse.to(cuda_device).requires_grad_(True)) + + +def _dense(rows, cols, dtype, std=0.1): + cuda_device = torch.device("cuda") + out = (torch.randn(rows, cols) * std).type(dtype) + return out.to(cuda_device).requires_grad_(True) + + +def _dense_2x(rows, cols, dtype): + a = _dense(rows, cols, dtype) + return a, a.detach().requires_grad_(True) + + +def _with_transpose(op, a, b, trans_a, trans_b): + a = a.t() if trans_a else a + b = b.t() if trans_b else b + return op(a, b) + + +def _mmm(a, b, topo): + mask = stk.ops.to_dense(stk.ops.ones_like(topo)) + return torch.mm(a, b) * mask + + +def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b): + a = a.t() if trans_a else a + b = b.t() if trans_b else b + return op(a, b, topo) + + +def _mask(x, mask): + mask = stk.ops.to_dense(stk.ops.ones_like(mask)) + return x * mask + + +@parameterized.parameters(*_LINEAR_OP_TESTS) +class LinearOpsTest(parameterized.TestCase): + + def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. + a_shape = (k, m) if trans_a else (m, k) + a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype) + b_shape = (n, k) if trans_b else (k, n) + b, bcp = _dense_2x(*b_shape, dtype) + + # Execute the matmul. + out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b) + expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + out.sum().backward() + + # Validate the results. + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = stk.ops.to_dense(a.grad) + expected_grad = _mask(a_dense.grad, a.grad) + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = b.grad + expected_grad = bcp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. 
+ a_shape = (k, m) if trans_a else (m, k) + a, acp = _dense_2x(*a_shape, dtype) + b_shape = (n, k) if trans_b else (k, n) + b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype) + + # Execute the matmul. + out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b) + expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + out.sum().backward() + + # Validate the results. + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = a.grad + expected_grad = acp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = stk.ops.to_dense(b.grad) + expected_grad = _mask(b_dense.grad, b.grad) + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. + a_shape = (k, m) if trans_a else (m, k) + a, acp = _dense_2x(*a_shape, dtype) + b_shape = (n, k) if trans_b else (k, n) + b, bcp = _dense_2x(*b_shape, dtype) + _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype) + + # Execute the matmul. + out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b) + expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + stk.ops.sum(out).backward() + + # Validate the results. + out = stk.ops.to_dense(out) + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = a.grad + expected_grad = acp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = b.grad + expected_grad = bcp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/matrix_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/matrix_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/matrix_ops.py @@ -0,0 +1,98 @@ +from ..backend import sputnik +from ..matrix import Matrix +import torch +import numpy as np + + +@torch.no_grad() +def row_indices(shape, data, offsets, column_indices): + return sputnik.row_indices(shape, data, offsets, column_indices) + + +# TODO(tgale): Replace this helper with a custom kernel. This operation +# is much simpler to do than how it's currently implemented. 
+@torch.no_grad() +def _expand_for_blocking(idxs, blocking): + # Duplicate for block column dimension. + idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1) + + # Update the column indices. + idxs[:, :, 1] *= blocking + idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking]) + + # Duplicate for block row dimension. + idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2]) + idxs = idxs.repeat(1, blocking, 1, 1) + + # Update the row indices. + idxs[:, :, :, 0] *= blocking + idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1]) + idxs = torch.reshape(idxs, [-1, 2]) + return idxs + + +# TODO(tgale): Add input type checking. +@torch.no_grad() +def to_dense(x): + assert isinstance(x, Matrix) + + shape = (np.prod(x.shape[:-1]), x.shape[-1]) + row_idxs = x.row_indices.type(torch.int32) + col_idxs = x.column_indices.type(torch.int32) + indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking) + indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64) + + out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device) + out.scatter_(0, indices, x.data.flatten()) + return out.reshape(x.size()) + + +@torch.no_grad() +def _mask(x, blocking=1): + assert x.dim() == 2 + assert x.size()[0] % blocking == 0 + assert x.size()[1] % blocking == 0 + block_rows = x.size()[0] // blocking + block_cols = x.size()[1] // blocking + x = torch.reshape(x, [block_rows, blocking, block_cols, blocking]) + x = torch.sum(torch.abs(x), dim=(1, 3)) + return x != 0 + + +# TODO(tgale): Add input type checking. +@torch.no_grad() +def to_sparse(x, blocking=1): + m = _mask(x, blocking) + + # TODO(tgale): Set to appropriate type for input matrix. + row_nnzs = torch.sum(m, dim=1).type(torch.int32) + zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device) + offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)]) + offsets = offsets.type(torch.int32) + + indices = torch.nonzero(m).type(torch.int16) + row_indices = indices[:, 0] + column_indices = indices[:, 1] + + # Nonzero indices in the dense matrix. + nonzero_indices = torch.nonzero(m) + nonzero_indices = _expand_for_blocking(nonzero_indices, blocking) + nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1] + + # Gather the data and construct the sparse matrix. 
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices) + data = torch.reshape(data, [-1, blocking, blocking]) + return Matrix(x.size(), data, row_indices, column_indices, offsets) + + +@torch.no_grad() +def ones_like(x): + return Matrix(x.size(), + torch.ones_like(x.data), + x.row_indices, + x.column_indices, x.offsets) + + +def sum(x): + assert isinstance(x, Matrix) + return x.data.sum() diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3af04c0760483e578f93303dc457415948a2a34c --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py @@ -0,0 +1,62 @@ +import unittest + +from absl.testing import parameterized +import stk +import torch + + +@parameterized.parameters( + (8, 16, 0.0, 1), + (8, 16, 0.5, 1), + (8, 16, .95, 1), + (16, 8, 0.0, 1), + (16, 8, 0.5, 1), + (16, 8, .95, 1), + (8, 16, 0.0, 8), + (8, 16, 0.5, 8), + (8, 16, 1.0, 8), + (16, 8, 0.0, 8), + (16, 8, 0.5, 8), + (16, 8, 1.0, 8), + (128, 256, 0.5, 16), + (256, 128, 0.75, 32), + (512, 512, .875, 128)) +class MatrixOpsTest(parameterized.TestCase): + + def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + x = (torch.randn(rows, cols) * mask).type(torch.float16) + + # Convert the matrix to sparse format. + sparse_x = stk.ops.to_sparse(x, blocking) + + # Validate the matrix. + sparse_x.validate() + + # Validate the shape. + self.assertEqual(sparse_x.dim(), 2) + self.assertEqual(sparse_x.size()[0], rows) + self.assertEqual(sparse_x.size()[1], cols) + + # Validate the sparsity. + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual(sparse_x.nnz, nnz) + + # Convert back to dense format. + dense_x = stk.ops.to_dense(sparse_x) + + # Validate the shape. + self.assertEqual(dense_x.dim(), 2) + self.assertEqual(dense_x.size()[0], rows) + self.assertEqual(dense_x.size()[1], cols) + + # Validate the sparsity + self.assertEqual(torch.count_nonzero(dense_x).item(), nnz) + + # Validate the output. 
+ self.assertTrue(torch.all(torch.eq(x, dense_x))) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/__init__.py @@ -0,0 +1,2 @@ +# from stk.random.random_ops import dense_mask, mask, randn +from .random_ops import dense_mask, mask, randn diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/random_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/random_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/random_ops.py @@ -0,0 +1,36 @@ +import numpy as np +import torch +from ..ops import matrix_ops + + +@torch.no_grad() +def dense_mask(rows, cols, sparsity, blocking=1): + assert sparsity >= 0.0 and sparsity <= 1.0 + assert rows % blocking == 0 and cols % blocking == 0 + + block_rows, block_cols = (rows // blocking, cols // blocking) + nnz = round(block_rows * block_cols * (1 - sparsity)) + + out = np.ones(block_rows * block_cols) + mask = np.random.choice(out.size, out.size - nnz, replace=False) + out[mask] = 0.0 + + out = np.tile( + np.reshape(out, [block_rows, 1, block_cols, 1]), + (1, blocking, 1, blocking)) + out = np.reshape(out, [rows, cols]) + return torch.from_numpy(out.astype(np.float32)) + + +@torch.no_grad() +def mask(m, n, sparsity, blocking=1): + out = dense_mask(m, n, sparsity, blocking).type(torch.float16) + return matrix_ops.to_sparse(out, blocking=blocking) + + +@torch.no_grad() +def randn(shape, sparsity, blocking=1): + shape_2d = (np.prod(shape[:-1]), shape[-1]) + out = mask(*shape_2d, sparsity, blocking) + out.data.copy_(torch.randn(*out.data.shape)) + return out.view(*shape) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/random_ops_test.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/random_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..587b44ec890c861879c6296b8f9028f5d99ab82f --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/random_ops_test.py @@ -0,0 +1,73 @@ +import unittest + +from absl.testing import parameterized +from . import random +import torch + + +@parameterized.parameters( + (8, 16, 0.0, 1), + (8, 16, 0.5, 1), + (8, 16, .95, 1), + (16, 8, 0.0, 1), + (16, 8, 0.5, 1), + (16, 8, .95, 1), + (8, 16, 0.0, 8), + (8, 16, 0.5, 8), + (8, 16, 1.0, 8), + (16, 8, 0.0, 8), + (16, 8, 0.5, 8), + (16, 8, 1.0, 8), + (128, 256, 0.5, 16), + (256, 128, 0.75, 32), + (512, 512, .875, 128)) +class RandomOpsTest(parameterized.TestCase): + + def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking): + mask = random.dense_mask( + rows, cols, sparsity, blocking) + + # Validate the shape. + self.assertEqual(mask.dim(), 2) + self.assertEqual(mask.size()[0], rows) + self.assertEqual(mask.size()[1], cols) + + # Validate the sparsity + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual( + torch.count_nonzero(mask).item(), + nnz) + + # Check values are zero or one. 
+ self.assertTrue( + torch.all(torch.logical_or( + torch.eq(mask, 0), + torch.eq(mask, 1)))) + + def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking): + mask = random.mask( + rows, cols, sparsity, blocking) + + # Validate the matrix. + mask.validate() + + # Validate the shape. + self.assertEqual(mask.dim(), 2) + self.assertEqual(mask.size()[0], rows) + self.assertEqual(mask.size()[1], cols) + + # Validate the sparsity. + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual(mask.nnz, nnz) + + # Check values are zero or one. + self.assertTrue( + torch.all(torch.logical_or( + torch.eq(mask.data, 0), + torch.eq(mask.data, 1)))) + + +if __name__ == '__main__': + unittest.main()
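The files above vendor `stk` inside the `megablocks` package so the `_layers` modules can import it through relative paths. The short sketch below illustrates the intended call pattern of the relocated module, mirroring the round trip in matrix_ops_test.py and the matmul entry points in linear_ops.py. It is a minimal sketch, not part of the diff: it assumes the installed wheel exposes the package as `megablocks` and that a CUDA device is available for the block-sparse matmuls.

import torch
from megablocks import stk  # assumed install path; in-tree the code uses `from .. import stk`

rows, cols, blocking, sparsity = 256, 256, 128, 0.5

# Blocked random mask and a dense fp16 operand (Matrix.validate expects float16).
mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
x = (torch.randn(rows, cols) * mask).half()

# Dense -> BCSR -> dense round trip, as exercised by matrix_ops_test.py.
sparse_x = stk.ops.to_sparse(x, blocking)
sparse_x.validate()
assert torch.equal(stk.ops.to_dense(sparse_x), x)

# The block-sparse matmuls run through the Triton backend and need CUDA.
if torch.cuda.is_available():
    a = torch.randn(rows, cols, dtype=torch.float16, device="cuda")
    b = torch.randn(cols, cols, dtype=torch.float16, device="cuda")
    topo = stk.random.mask(rows, cols, sparsity, blocking).cuda()
    out = stk.ops.sdd(a, b, topo)    # sparse output with topo's topology
    dense = stk.ops.dsd(out, b)      # sparse @ dense -> dense (rows x cols)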