kernel

drbh committed
Commit 9354548 · 1 Parent(s): 63599de

feat: bump build for fully vendored version

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.

Files changed (50)
  1. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py +1 -1
  2. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py +9 -2
  3. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py +10 -1
  4. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py +11 -1
  5. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py +12 -3
  6. build/torch26-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_0586ba6.abi3.so → _megablocks_63599de.abi3.so} +1 -1
  7. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py +3 -3
  8. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py +13 -1
  9. build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/__init__.py +7 -0
  10. build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/__init__.py +0 -0
  11. build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/autocast.py +37 -0
  12. build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/sputnik.py +316 -0
  13. build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/triton_kernels.py +393 -0
  14. build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/matrix.py +329 -0
  15. build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/__init__.py +3 -0
  16. build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops.py +28 -0
  17. build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py +86 -0
  18. build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops.py +59 -0
  19. build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops_test.py +216 -0
  20. build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops.py +98 -0
  21. build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py +62 -0
  22. build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/random/__init__.py +2 -0
  23. build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops.py +36 -0
  24. build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops_test.py +73 -0
  25. build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/activation_fn.py +1 -1
  26. build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/dmoe.py +9 -2
  27. build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/gelu.py +10 -1
  28. build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/glu.py +11 -1
  29. build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/mlp.py +12 -3
  30. build/torch26-cxx11-cu124-x86_64-linux/megablocks/{_megablocks_0586ba6.abi3.so → _megablocks_63599de.abi3.so} +1 -1
  31. build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py +3 -3
  32. build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py +13 -1
  33. build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/__init__.py +7 -0
  34. build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/__init__.py +0 -0
  35. build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/autocast.py +37 -0
  36. build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/sputnik.py +316 -0
  37. build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/triton_kernels.py +393 -0
  38. build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/matrix.py +329 -0
  39. build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/__init__.py +3 -0
  40. build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops.py +28 -0
  41. build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py +86 -0
  42. build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/linear_ops.py +59 -0
  43. build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/linear_ops_test.py +216 -0
  44. build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops.py +98 -0
  45. build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py +62 -0
  46. build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/random/__init__.py +2 -0
  47. build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/random/random_ops.py +36 -0
  48. build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/random/random_ops_test.py +73 -0
  49. build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/activation_fn.py +1 -1
  50. build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/dmoe.py +9 -2
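The recurring change across the `_layers` and `ops` modules below is that the external `stanford-stk` dependency is replaced by an `stk` subpackage vendored inside each built wheel (`megablocks/stk/`), and the compiled ops library is renamed to match the new commit hash. As a rough, hedged sketch of what that means for downstream imports (only the module paths shown in the diff are taken as given; anything else here is illustrative):

    # Before this commit, layers pulled in the external package:
    #   import stk
    #   from stk import Matrix

    # After this commit, the copy bundled with the wheel is used instead.
    # Inside the package the layers do `from .. import stk`; from user code the
    # equivalent (assumed) spelling would be:
    from megablocks import stk            # vendored subpackage megablocks/stk
    from megablocks.stk import Matrix     # re-exported by stk/__init__.py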
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py CHANGED
@@ -4,7 +4,7 @@
 from typing import Any, Callable, Union
 
 import torch
-from stk import Matrix
+from ..stk import Matrix
 
 
 def act_fn(
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py CHANGED
@@ -2,15 +2,22 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import numpy as np
-import stk.ops
 import torch
-from stk import Matrix
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
 
 # import megablocks.ops as ops
 # # from megablocks.ops import ops
 # from megablocks.layers import common, dmlp_registry, moe, mpu
 # from megablocks.layers.arguments import Arguments
 
+from .. import stk
 from .. import ops
 from . import common, dmlp_registry, moe, mpu
 from .arguments import Arguments
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py CHANGED
@@ -1,7 +1,16 @@
 # Copyright 2024 Databricks
 # SPDX-License-Identifier: Apache-2.0
 
-import stk
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
 import torch
 import torch.nn.functional as F
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py CHANGED
@@ -1,7 +1,17 @@
 # Copyright 2024 Databricks
 # SPDX-License-Identifier: Apache-2.0
 
-import stk.ops
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
 import torch
 
 # from megablocks import grouped_gemm_util as gg
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py CHANGED
@@ -3,9 +3,18 @@
 
 from typing import Any
 
-import stk
-import stk.backend.triton_kernels
-import stk.ops
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
 import torch
 from packaging import version
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_0586ba6.abi3.so → _megablocks_63599de.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:7fbec6fa49d1b926d45b39b7e8393e06ee9622d0012501adaec213cb5802c86d
+oid sha256:9b35f3f60e0cbf0ce9e84e1224754d353f9de646cf30df5828168222889d312f
 size 10517576
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_0586ba6
-ops = torch.ops._megablocks_0586ba6
+from . import _megablocks_63599de
+ops = torch.ops._megablocks_63599de
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_0586ba6::{op_name}"
+    return f"_megablocks_63599de::{op_name}"
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py CHANGED
@@ -3,7 +3,19 @@
 
 import unittest
 
-import stk
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
 import torch
 from absl.testing import parameterized
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/__init__.py ADDED
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/__init__.py ADDED
File without changes
build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/autocast.py ADDED
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+    return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+    if isinstance(x, torch.Tensor) and _is_eligible(x):
+        return x.to(dtype)
+    elif isinstance(x, map):
+        return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+    elif isinstance(x, list) or isinstance(x, tuple):
+        return type(x)(map(lambda y: _cast(y, dtype), x))
+    return x
+
+
+def custom_fwd(fwd):
+    """Wrap a custom autograd function that always uses autocast dtype."""
+
+    @functools.wraps(fwd)
+    def decorate_fwd(*args, **kwargs):
+        if torch.is_autocast_enabled():
+            with torch.autocast(device_type="cuda", enabled=False):
+                dtype = torch.get_autocast_gpu_dtype()
+                return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+        return fwd(*args, **kwargs)
+    return decorate_fwd
+
+
+def custom_bwd(bwd):
+    @functools.wraps(bwd)
+    def decorate_bwd(*args, **kwargs):
+        with torch.autocast(device_type="cuda", enabled=False):
+            return bwd(*args, **kwargs)
+    return decorate_bwd
build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/sputnik.py ADDED
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+    if transpose:
+        return torch.Size((x[1], x[0]))
+    return x
+
+
+def _sparse_transpose(x):
+    return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+    if isinstance(x, torch.Tensor):
+        return x.t() if transpose else x
+    if transpose:
+        x = _sparse_transpose(x)
+    return x + (transpose,)
+
+
+def _wrap(x):
+    if isinstance(x, torch.Tensor):
+        return (x,)
+    return x
+
+
+def _is_transposed(x):
+    return (not x.is_contiguous() and
+            x.stride()[0] == 1 and
+            x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+    args = (_wrap(_transpose_helper(a, trans_a)) +
+            _wrap(_transpose_helper(b, trans_b)))
+    if isinstance(out, tuple):
+        args = args + out
+    return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+    if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+        lhs = lhs.t()
+    if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+        rhs = rhs.t()
+    if (isinstance(dy, torch.Tensor) and
+            not dy.is_contiguous() and
+            not _is_transposed(dy)):
+        dy = dy.contiguous()
+    if isinstance(dy, tuple) and not dy[1].is_contiguous():
+        dy = (dy[0], dy[1].contiguous()) + dy[2:]
+    return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+    if isinstance(x, torch.Tensor) and transpose:
+        return grad.t()
+    return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+    lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+    a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+    trans_a = trans_lhs and trans_rhs
+    trans_b = trans_lhs or not trans_rhs
+    out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+    return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+    lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+    a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+    trans_a = not trans_lhs or trans_rhs
+    trans_b = trans_lhs and trans_rhs
+    out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+    return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+    @staticmethod
+    @custom_fwd
+    def forward(ctx,
+                shape,
+                data,
+                offsets,
+                row_indices,
+                column_indices,
+                offsets_t,
+                column_indices_t,
+                block_offsets_t,
+                transpose_a,
+                rhs):
+        ctx.save_for_backward(data,
+                              offsets,
+                              row_indices,
+                              column_indices,
+                              offsets_t,
+                              column_indices_t,
+                              block_offsets_t,
+                              rhs)
+        ctx.shape = _standardize_shape(shape, transpose_a)
+        ctx.transpose_a = transpose_a
+
+        out = torch.empty(
+            (shape[0], rhs.size()[1]),
+            dtype=rhs.dtype,
+            device=rhs.device)
+
+        backend.dsd(shape,
+                    data,
+                    offsets,
+                    row_indices,
+                    column_indices,
+                    offsets_t,
+                    column_indices_t,
+                    block_offsets_t,
+                    transpose_a,
+                    rhs,
+                    out)
+        return out
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, dy):
+        saved_tensors = ctx.saved_tensors
+        lhs = (ctx.shape,) + saved_tensors[:-1]
+        rhs = saved_tensors[-1]
+        trans_a = ctx.transpose_a
+        trans_b = _is_transposed(rhs)
+
+        ddata = None
+        if ctx.needs_input_grad[1]:
+            ddata = _lhs_gradient(sdd,
+                                  lhs,
+                                  rhs,
+                                  dy,
+                                  trans_a,
+                                  trans_b)
+        drhs = None
+        if ctx.needs_input_grad[-1]:
+            op = dds if trans_b else dsd
+            drhs = _rhs_gradient(op,
+                                 lhs,
+                                 rhs,
+                                 dy,
+                                 trans_a,
+                                 trans_b)
+        return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+    @staticmethod
+    @custom_fwd
+    def forward(ctx,
+                lhs,
+                shape,
+                data,
+                offsets,
+                row_indices,
+                column_indices,
+                offsets_t,
+                column_indices_t,
+                block_offsets_t,
+                transpose_b):
+        ctx.save_for_backward(lhs,
+                              data,
+                              offsets,
+                              row_indices,
+                              column_indices,
+                              offsets_t,
+                              column_indices_t,
+                              block_offsets_t)
+        ctx.shape = _standardize_shape(shape, transpose_b)
+        ctx.transpose_b = transpose_b
+        out = torch.empty((lhs.size()[0], shape[1]),
+                          dtype=lhs.dtype,
+                          device=lhs.device)
+        backend.dds(lhs,
+                    shape,
+                    data,
+                    offsets,
+                    row_indices,
+                    column_indices,
+                    offsets_t,
+                    column_indices_t,
+                    block_offsets_t,
+                    transpose_b,
+                    out)
+        return out
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, dy):
+        saved_tensors = ctx.saved_tensors
+        lhs = saved_tensors[0]
+        rhs = (ctx.shape,) + saved_tensors[1:]
+        trans_a = _is_transposed(lhs)
+        trans_b = ctx.transpose_b
+
+        dlhs = None
+        if ctx.needs_input_grad[0]:
+            op = dsd if trans_a else dds
+            dlhs = _lhs_gradient(op,
+                                 lhs,
+                                 rhs,
+                                 dy,
+                                 trans_a,
+                                 trans_b)
+        ddata = None
+        if ctx.needs_input_grad[2]:
+            ddata = _rhs_gradient(sdd,
+                                  lhs,
+                                  rhs,
+                                  dy,
+                                  trans_a,
+                                  trans_b)
+        return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+    @staticmethod
+    @custom_fwd
+    def forward(ctx,
+                lhs,
+                rhs,
+                shape,
+                data,
+                offsets,
+                row_indices,
+                column_indices,
+                offsets_t,
+                column_indices_t,
+                block_offsets_t):
+        ctx.save_for_backward(
+            lhs,
+            rhs,
+            offsets,
+            row_indices,
+            column_indices,
+            offsets_t,
+            column_indices_t,
+            block_offsets_t)
+        ctx.shape = shape
+        out = torch.empty(
+            data.shape,
+            dtype=lhs.dtype,
+            device=lhs.device)
+        backend.sdd(lhs,
+                    rhs,
+                    shape,
+                    out,
+                    offsets,
+                    row_indices,
+                    column_indices)
+        return out
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, dy):
+        saved_tensors = ctx.saved_tensors
+        lhs, rhs = saved_tensors[:2]
+        dy = (ctx.shape, dy) + saved_tensors[2:]
+        trans_a = _is_transposed(lhs)
+        trans_b = _is_transposed(rhs)
+
+        dlhs = None
+        if ctx.needs_input_grad[0]:
+            op = dds if trans_a else dsd
+            dlhs = _lhs_gradient(op,
+                                 lhs,
+                                 rhs,
+                                 dy,
+                                 trans_a,
+                                 trans_b)
+        drhs = None
+        if ctx.needs_input_grad[1]:
+            op = dsd if trans_b else dds
+            drhs = _rhs_gradient(op,
+                                 lhs,
+                                 rhs,
+                                 dy,
+                                 trans_a,
+                                 trans_b)
+        return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, shape, data, offsets, column_indices):
+        out = torch.empty(
+            column_indices.shape,
+            dtype=column_indices.dtype,
+            device=column_indices.device)
+        backend.row_indices(shape, data, offsets, column_indices, out)
+        return out
+
+
+row_indices = RowIndices.apply
build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/triton_kernels.py ADDED
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+    BLOCK_M: int = 128
+    BLOCK_N: int = 128
+    BLOCK_K: int = 32
+    BLOCK_SIZE: int = 128
+    NUM_STAGES: int = 4
+    NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+    error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}"
+    assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+    assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+    assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+    configs=[
+        # basic configs for compute-bound matmuls
+        triton.Config({
+            'BLOCK_M': TritonConfig.BLOCK_M,
+            'BLOCK_N': TritonConfig.BLOCK_N,
+            'BLOCK_K': TritonConfig.BLOCK_K,
+            'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+        }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+    ],
+    key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+                stride_am, stride_ak,
+                stride_bk, stride_bn,
+                stride_cm, stride_cn,
+                row_indices, column_indices,
+                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+                BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+                ):
+    # matrix multiplication
+    pid = tl.program_id(0)
+    pid_m = tl.load(row_indices + pid)
+    pid_n = tl.load(column_indices + pid)
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+    rk = tl.arange(0, BLOCK_K)
+    # pointers
+    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+    # do matrix multiplication
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+    for k in range(0, tl.cdiv(K, BLOCK_K)):
+        a = tl.load(A)
+        b = tl.load(B)
+        acc += tl.dot(a, b)
+        A += BLOCK_K * stride_ak
+        B += BLOCK_K * stride_bk
+    # Store to sparse matrix
+    acc = acc.to(C.dtype.element_ty)
+    BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+    cm = tl.arange(0, BLOCK_M)
+    cn = tl.arange(0, BLOCK_N)
+    C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+    tl.store(C, acc, mask=True)
+
+@triton.autotune(
+    configs=[
+        # basic configs for compute-bound matmuls
+        triton.Config({
+            'BLOCK_M': TritonConfig.BLOCK_M,
+            'BLOCK_N': TritonConfig.BLOCK_N,
+            'BLOCK_K': TritonConfig.BLOCK_K,
+            'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+        }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+    ],
+    key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+                stride_am, stride_ak,
+                stride_bk, stride_bn,
+                stride_cm, stride_cn,
+                row_indices, column_indices, offsets,
+                block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+                BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+                ):
+
+    # matrix multiplication
+    pid_m = tl.program_id(0)
+    pid_n = tl.program_id(1)
+
+    num_pid_m = tl.num_programs(0)
+    num_pid_n = tl.num_programs(1)
+    pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+    start_inx = tl.load(offsets + pid_m)
+    end_inx = tl.load(offsets + pid_m + 1)
+
+    # pointers to sparse matrix
+    rm = tl.arange(0, BLOCK_M)
+    rak = tl.arange(0, BLOCK_K)
+
+    A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+    # pointers to dense matrix
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    rbk = tl.arange(0, BLOCK_K)
+    B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+    # do matrix multiplication
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+    nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+    BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+    ak_sub_incr = BLOCK_K * stride_ak
+    bk_sub_incr = BLOCK_K * stride_bk
+    bk_block_incr = BLOCK_SIZE * stride_bk
+
+    for k in range(nsub_blocks * (end_inx - start_inx)):
+        sub_block_inx = k % nsub_blocks
+        block_inx = k // nsub_blocks
+
+        if trans_A:
+            ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+        else:
+            ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+        ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+        a = tl.load(ptr_A)
+        b = tl.load(ptr_B)
+        acc += tl.dot(a, b)
+
+    acc = acc.to(C.dtype.element_ty)
+
+    cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+    tl.store(C, acc, mask=True)
+
+@triton.autotune(
+    configs=[
+        # basic configs for compute-bound matmuls
+        triton.Config({
+            'BLOCK_M': TritonConfig.BLOCK_M,
+            'BLOCK_N': TritonConfig.BLOCK_N,
+            'BLOCK_K': TritonConfig.BLOCK_K,
+            'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+        }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+    ],
+    key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+                stride_am, stride_ak,
+                stride_bk, stride_bn,
+                stride_cm, stride_cn,
+                row_indices, column_indices, offsets,
+                block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+                BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+                ):
+
+    # matrix multiplication
+    pid_m = tl.program_id(0)
+    pid_n = tl.program_id(1)
+
+    num_pid_m = tl.num_programs(0)
+    num_pid_n = tl.num_programs(1)
+    pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+    start_inx = tl.load(offsets + pid_n)
+    end_inx = tl.load(offsets + pid_n + 1)
+
+    # pointers to dense matrix
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rak = tl.arange(0, BLOCK_K)
+
+    A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+    # pointers to sparse matrix
+    rn = tl.arange(0, BLOCK_N)
+    rbk = tl.arange(0, BLOCK_K)
+    B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+    # do matrix multiplication
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+    nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+    BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+    ak_sub_incr = BLOCK_K * stride_ak
+    ak_block_incr = BLOCK_SIZE * stride_ak
+    bk_sub_incr = BLOCK_K * stride_bk
+
+    for k in range(nsub_blocks * (end_inx - start_inx)):
+        sub_block_inx = k % nsub_blocks
+        block_inx = k // nsub_blocks
+
+        if trans_B:
+            ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+        else:
+            ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+        ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+        a = tl.load(ptr_A)
+        b = tl.load(ptr_B)
+        acc += tl.dot(a, b)
+
+    acc = acc.to(C.dtype.element_ty)
+    cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+    tl.store(C, acc, mask=True)
+
+def dsd(shape,
+        data,
+        offsets,
+        row_indices,
+        column_indices,
+        offsets_t,
+        column_indices_t,
+        block_offsets_t,
+        transpose_a,
+        rhs,
+        out
+        ):
+
+    device = rhs.device
+    trans_A = transpose_a
+    trans_B = False
+
+    if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+        trans_B = True
+
+    # checks constraints
+    assert shape[1] == rhs.shape[0], "incompatible dimensions"
+    M, K = shape
+    _, N = rhs.shape
+
+    _validate_matmul_dims(M, K, N)
+
+    # accumulator types
+    ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+    stride_am, stride_ak = data.stride(1), data.stride(2)
+    stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+    a_column_indices = column_indices
+    a_offsets = offsets
+
+    # launch kernel
+    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+    if trans_A:
+        stride_am, stride_ak = data.stride(2), data.stride(1)
+        a_column_indices, a_offsets = column_indices_t, offsets_t
+
+    if trans_B:
+        stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+    _dsd_kernel[grid](
+        data.data, rhs, out, M, N, K,
+        stride_am, stride_ak,
+        stride_bk, stride_bn,
+        out.stride(0), out.stride(1),
+        row_indices, a_column_indices, a_offsets,
+        block_offsets_t, trans_A, trans_B,
+        GROUP_M=128, ACC_TYPE=ACC_TYPE
+    )
+    # return out
+
+def dds(lhs,
+        shape,
+        data,
+        offsets,
+        row_indices,
+        column_indices,
+        offsets_t,
+        column_indices_t,
+        block_offsets_t,
+        transpose_b,
+        out
+        ):
+
+    device = lhs.device
+    trans_B = transpose_b
+    trans_A = False
+
+    if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+        trans_A = True
+
+    # checks constraints
+    assert lhs.shape[1] == shape[0], "incompatible dimensions"
+    M, K = lhs.shape
+    _, N = shape
+
+    _validate_matmul_dims(M, K, N)
+
+    # accumulator types
+    ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+    stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+    stride_bk, stride_bn = data.stride(1), data.stride(2)
+    b_column_indices = column_indices_t
+    b_offsets = offsets_t
+
+    # launch kernel
+    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+    if trans_A:
+        stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+    if trans_B:
+        stride_bk, stride_bn = data.stride(2), data.stride(1)
+        b_column_indices, b_offsets = column_indices, offsets
+
+    _dds_kernel[grid](
+        lhs, data, out, M, N, K,
+        stride_am, stride_ak,
+        stride_bk, stride_bn,
+        out.stride(0), out.stride(1),
+        row_indices, b_column_indices, b_offsets,
+        block_offsets_t, trans_A, trans_B,
+        GROUP_M=128, ACC_TYPE=ACC_TYPE
+    )
+
+def sdd(lhs,
+        rhs,
+        shape,
+        out,
+        offsets,
+        row_indices,
+        column_indices
+        ):
+
+    device = out.device
+    trans_A = False
+    trans_B = False
+
+    if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+        trans_A = True
+    if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+        trans_B = True
+
+    # checks constraints
+    assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+    M, K = lhs.shape
+    _, N = rhs.shape
+
+    _validate_matmul_dims(M, K, N)
+
+    # accumulator types
+    ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+    # launch kernel
+    nnz_blocks = len(row_indices)
+    grid = lambda META: (nnz_blocks,)
+
+    stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+    stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+    if trans_A:
+        stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+    if trans_B:
+        stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+    _sdd_kernel[grid](
+        lhs, rhs, out, M, N, K,
+        stride_am, stride_ak,
+        stride_bk, stride_bn,
+        out.stride(1), out.stride(2),
+        row_indices, column_indices,
+        GROUP_M=128, ACC_TYPE=ACC_TYPE
+    )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
+    pid = tl.program_id(0)
+    row_offset = tl.load(offsets + pid)
+    nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+    for nnz_block in range(nnz_blocks):
+        tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+    shape, data, offsets, column_indices, out
+):
+    block_rows = len(offsets) - 1
+    _row_indices_kernel[(block_rows, )](offsets, out)
build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/matrix.py ADDED
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers
+# 3. Make indentation consistent
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+    # Data should be [nnz, block_size, block_size]
+    if data.dim() == 1:
+        data = torch.reshape(data, [data.numel(), 1, 1])
+
+    # Blocks should be square.
+    if data.shape[-2] != data.shape[-1]:
+        raise ValueError(
+            "Expected square blocking in data. "
+            f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+    # Flatten batch dimensions on data - original shape preserved
+    # in shape argument.
+    block_size = data.shape[-1]
+    data = data.view([-1, block_size, block_size])
+
+    if data.dim() != 3:
+        raise ValueError(
+            "Expected 3D shape for data (nnz, block, block). "
+            f"Got shape {data.dim()}D shape.")
+
+    block_size = data.shape[1]
+    if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+        raise ValueError(
+            "Matrix shape must be dividible by blocking. "
+            f"Got shape {shape} with "
+            f"{[block_size, block_size]} blocking.")
+
+    if np.prod(shape) < data.numel():
+        raise ValueError(
+            "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+            f"({data.numel()} v. {np.prod(shape)})")
+
+    if row_indices.dim() != 1:
+        raise ValueError(
+            f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+    if column_indices.dim() != 1:
+        raise ValueError(
+            f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+    if offsets.dim() != 1:
+        raise ValueError(
+            f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+    if row_indices.numel() != data.shape[0]:
+        raise ValueError(
+            "Expected 1 index per nonzero block. "
+            f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+    if column_indices.numel() != data.shape[0]:
+        raise ValueError(
+            "Expected 1 index per nonzero block. "
+            f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+    block_rows = np.prod(shape[:-1]) / block_size
+    if offsets.numel() != block_rows + 1:
+        raise ValueError(
+            "Expected one offset per block row plus one. "
+            f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+    is_cuda = (data.is_cuda and
+               row_indices.is_cuda and
+               column_indices.is_cuda and
+               offsets.is_cuda)
+    is_cpu = (not data.is_cuda and
+              not row_indices.is_cuda and
+              not column_indices.is_cuda and
+              not offsets.is_cuda)
+    if not (is_cuda or is_cpu):
+        raise ValueError(
+            "Expected data & meta-data on common device. "
+            f"Got data on {data.device}, row_indices on {row_indices.device} "
+            f"column_indices on {column_indices.device} and "
+            f"offsets on {offsets.device}.")
+
+    if data.dtype != torch.float16:
+        raise ValueError(
+            f"Expected float16 data. Got {data.dtype} data.")
+    if row_indices.dtype != torch.int16:
+        raise ValueError(
+            f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+    if column_indices.dtype != torch.int16:
+        raise ValueError(
+            f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+    if offsets.dtype != torch.int32:
+        raise ValueError(
+            f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+    return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+    block_columns = size[1] // data.shape[1]
+
+    # Sort row indices by column indices to get the transposed matrix's
+    # column indices.
+    gather_indices = column_indices.argsort()
+    column_indices_t = row_indices.gather(0, gather_indices)
+    block_offsets_t = gather_indices.int()
+
+    # NOTE: Histogram is not implemented for any integer type on CPU. Do
+    # the histogram in 32-bit float, which can exactly represent 16-bit
+    # integers.
+    column_indices_float = column_indices.float()
+
+    zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+    nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+    nnz_per_column = nnz_per_column.int()
+    offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+    return column_indices_t, offsets_t, block_offsets_t
+
+
+class Matrix(torch.nn.Module):
+    """A matrix stored in sparse format.
+
+    Underlying format is block compressed sparse row (BCSR).
+
+    TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+    """
+
+    def __init__(self,
+                 size,
+                 data,
+                 row_indices,
+                 column_indices,
+                 offsets,
+                 column_indices_t=None,
+                 offsets_t=None,
+                 block_offsets_t=None):
+        super().__init__()
+        self._size = size
+        self._data = data
+        self._row_indices = row_indices
+        self._column_indices = column_indices
+        self._offsets = offsets
+
+        # Produce the transpose meta-data if it is not passed in.
+        if ((column_indices_t is None) or (offsets_t is None) or
+            (block_offsets_t is None)):
+            column_indices_t, offsets_t, block_offsets_t = _transpose(
+                size, data, row_indices, column_indices, offsets)
+        self._column_indices_t = column_indices_t
+        self._offsets_t = offsets_t
+        self._block_offsets_t = block_offsets_t
+
+        self._transposed = False
+
+        # Validate that our metadata will not overflow.
+        max_dim = np.iinfo(np.int16).max * self.blocking
+        if column_indices.dtype == torch.int16:
+            if size[0] > max_dim or size[1] > max_dim:
+                raise ValueError(
+                    "Sparse matrix with shape {size} exceeds representable "
+                    "size with 16-bit indices.")
+
+    def validate(self):
+        _validate_matrix(self._size,
+                         self._data,
+                         self._row_indices,
+                         self._column_indices,
+                         self._offsets)
+
+        # TODO(tgale): Add heavyweight data validation.
+
+    def to(self, device):
+        # TODO(tgale): Handle type conversions here. We
+        # need to set the appropriate meta-data type for
+        # the given floating-point type.
+        self._data = self._data.to(device)
+        self._row_indices = self._row_indices.to(device)
+        self._column_indices = self._column_indices.to(device)
+        self._offsets = self._offsets.to(device)
+        self._column_indices_t = self._column_indices_t.to(device)
+        self._offsets_t = self._offsets_t.to(device)
+        self._block_offsets_t = self._block_offsets_t.to(device)
+        return self
+
+    def cuda(self):
+        return self.to(torch.cuda.current_device())
+
+    def clone(self):
+        return Matrix(
+            self.size(),
+            self.data.clone(),
+            self.row_indices.clone(),
+            self.column_indices.clone(),
+            self.offsets.clone(),
+            self.column_indices_t.clone(),
+            self.offsets_t.clone(),
+            self.block_offsets_t.clone())
+
+    def t(self):
+        if self.dim() != 2:
+            raise ValueError(
+                "t() expects a tensor with <= 2 dimensions, "
+                f"but self is {self.dim()}D.")
+        out = Matrix(self.size(),
+                     self.data,
+                     self.row_indices,
+                     self.column_indices,
+                     self.offsets,
+                     self.column_indices_t,
+                     self.offsets_t,
+                     self.block_offsets_t)
+        out._transposed = not self._transposed
+        out._size = torch.Size((self._size[1], self._size[0]))
+        return out
+
+    def contiguous(self):
+        raise ValueError("Not yet implemented.")
+
+    def is_contiguous(self):
+        return not self._transposed
+
+    @property
+    def is_cuda(self):
+        return self._data.is_cuda
+
+    @property
+    def device(self):
+        return self._data.device
+
+    def size(self):
+        return self._size
+
+    @property
+    def shape(self):
+        return self.size()
+
+    def dim(self):
+        return len(self._size)
+
+    @property
+    def data(self):
+        return self._data
+
+    @property
+    def row_indices(self):
+        return self._row_indices
+
+    @property
+    def column_indices(self):
+        return self._column_indices
+
+    @property
+    def offsets(self):
+        return self._offsets
+
+    @property
+    def offsets_t(self):
+        return self._offsets_t
+
+    @property
+    def column_indices_t(self):
+        return self._column_indices_t
+
+    @property
+    def block_offsets_t(self):
+        return self._block_offsets_t
+
+    @property
+    def dtype(self):
+        return self.data.dtype
+
+    @property
+    def nnz(self):
+        return self.data.numel()
+
+    @property
+    def blocking(self):
+        return self.data.shape[1]
+
+    @property
+    def requires_grad(self):
+        return self.data.requires_grad
+
+    def requires_grad_(self, x):
+        self.data.requires_grad_(x)
+        return self
+
+    def view(self, *shape):
+        assert self.is_contiguous()
+        if shape[-1] != self.size()[-1]:
+            raise ValueError(
+                "Can't change view on compressed dimension. "
+                f"{self.size()[-1]} v. {shape[-1]}.")
+        if np.prod(shape) != np.prod(self.size()):
+            raise ValueError(
+                "Mismatch in numel of Matrix and new shape. "
+                f"{np.prod(self.size())} v. {np.prod(shape)}")
+        return Matrix(shape,
+                      self.data,
+                      self.row_indices,
+                      self.column_indices,
+                      self.offsets,
+                      self.column_indices_t,
+                      self.offsets_t,
+                      self.block_offsets_t)
+
+    @property
+    def grad(self):
+        # TODO(tgale): Make sure this mirrors torch.Tensor
+        # behavior in the case where we ask for the gradient
+        # of a non-contiguous tensor.
+        size = self.size()
+        if not self.is_contiguous():
+            size = torch.Size((size[1], size[0]))
+        out = Matrix(size,
+                     self.data.grad,
+                     self.row_indices,
+                     self.column_indices,
+                     self.offsets,
+                     self.column_indices_t,
+                     self.offsets_t,
+                     self.block_offsets_t)
+        return out if self.is_contiguous() else out.t()
build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/__init__.py ADDED
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops.py ADDED
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+    """Performs element-wise multiplication of matrices a and b.
+
+    It is the user's responsibility to make sure that a and b
+    follow the same matrix topology. This function assumes it is safe
+    to use the topoplogy of a.
+
+    Args:
+        a: stk.Matrix.
+        b: stk.Matrix with a's matrix topology.
+
+    Returns:
+        stk.Matrix where the entries correspond to torch.mul(a, b).
+    """
+    assert isinstance(a, Matrix)
+    assert isinstance(b, Matrix)
+    assert a.size() == b.size()
+
+    return Matrix(a.size(),
+                  a.data * b.data,
+                  a.row_indices,
+                  a.column_indices,
+                  a.offsets,
+                  a.column_indices_t,
+                  a.offsets_t,
+                  a.block_offsets_t)
build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py ADDED
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+    (128, 128, 0.0),
+    (256, 256, 0.5),
+    (2048, 1024, 0.8),
+    (512, 128, 0.0),
+    (128, 512, 0.0),
+    (1024, 512, 0.0),
+    (1024, 512, 0.5),
+    (1024, 512, 0.75),
+    (512, 1024, 0.0),
+    (512, 1024, 0.5),
+    (512, 1024, 0.75),
+    (1024, 1024, 0.0),
+    (1024, 1024, 0.5),
+    (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+    torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+    testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+    testcases = [(*size, 128, dtype) for
+                 (size, dtype) in testcases]
+    return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+    dense_data = torch.randn_like(x.data, device=x.device) * std
+    sparse = stk.Matrix(x.size(),
+                        dense_data,
+                        x.row_indices,
+                        x.column_indices,
+                        x.offsets)
+    dense = stk.ops.to_dense(sparse)
+
+    return (dense.requires_grad_(True),
+            sparse.requires_grad_(True))
+
+@parameterized.parameters(_ELTWISE_OP_TESTS)
+class EltwiseOpsTest(parameterized.TestCase):
+
+    def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+
+        a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+        b_dense, b = _dense_and_sparse_like(a)
+
+        out = stk.ops.mul(a, b)
+        expected_out = torch.mul(a_dense, b_dense)
+
+        # Compute the gradients w.r.t. the inputs.
+        expected_out.sum().backward()
+        stk.ops.sum(out).backward()
+
+        # Validate the results.
+        out = stk.ops.to_dense(out)
+        self.assertEqual(out.dim(), 2)
+        self.assertEqual(expected_out.size(), out.size())
+        self.assertTrue(allclose(out, expected_out))
+
+        # LHS gradient.
+        grad = stk.ops.to_dense(a.grad)
+        expected_grad = a_dense.grad
+        self.assertEqual(grad.dim(), 2)
+        self.assertEqual(expected_grad.size(), grad.size())
+        self.assertTrue(allclose(grad, expected_grad))
+
+        # RHS gradient.
+        grad = stk.ops.to_dense(b.grad)
+        expected_grad = b_dense.grad
+        self.assertEqual(grad.dim(), 2)
+        self.assertEqual(expected_grad.size(), grad.size())
+        self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+    unittest.main()
build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops.py ADDED
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
+def dsd(a, b):
+    assert isinstance(a, Matrix)
+    assert isinstance(b, torch.Tensor)
+    return sputnik.dsd(
+        a.size(),
+        a.data, a.offsets,
+        a.row_indices,
+        a.column_indices,
+        a.offsets_t,
+        a.column_indices_t,
+        a.block_offsets_t,
+        not a.is_contiguous(),
+        b)
+
+
+def dds(a, b):
+    assert isinstance(a, torch.Tensor)
+    assert isinstance(b, Matrix)
+    return sputnik.dds(
+        a,
+        b.size(),
+        b.data, b.offsets,
+        b.row_indices,
+        b.column_indices,
+        b.offsets_t,
+        b.column_indices_t,
+        b.block_offsets_t,
+        not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+    assert isinstance(a, torch.Tensor)
+    assert isinstance(b, torch.Tensor)
+    assert isinstance(topo, Matrix)
+    assert topo.is_contiguous()
+    out = sputnik.sdd(
+        a, b,
+        topo.size(),
+        topo.data,
+        topo.offsets,
+        topo.row_indices,
+        topo.column_indices,
+        topo.offsets_t,
+        topo.column_indices_t,
+        topo.block_offsets_t)
+    return Matrix(topo.size(),
+                  out,
+                  topo.row_indices,
+                  topo.column_indices,
+                  topo.offsets,
+                  topo.column_indices_t,
+                  topo.offsets_t,
+                  topo.block_offsets_t)
build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops_test.py ADDED
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+    mask = torch.isclose(x, y, rtol=5e-2)
+    pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+    if pct_diff > pct:
+        print("{:.2f}% of values not close.".format(pct_diff))
+        return False
+    return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+    (128, 128, 128, 0.0),
+    (256, 256, 256, 0.5),
+    (2048, 1024, 512, 0.8),
+    (512, 128, 128, 0.0),
+    (128, 128, 512, 0.0),
+    (1024, 512, 512, 0.0),
+    (1024, 512, 512, 0.5),
+    (1024, 512, 512, 0.75),
+    (512, 512, 1024, 0.0),
+    (512, 512, 1024, 0.5),
+    (512, 512, 1024, 0.75),
+    (1024, 1024, 1024, 0.0),
+    (1024, 1024, 1024, 0.5),
+    (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+    (False, False),
+    (False, True),
+    (True, False),
+    (True, True),
+)
+
+_DTYPE = (
+    torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+    testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+    testcases = [(*size, *trans, 128, dtype) for
+                 (size, trans, dtype) in testcases]
+    return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+    mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+    dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+    sparse = stk.ops.to_sparse(dense, blocking)
+    cuda_device = torch.device("cuda")
+    return (dense.to(cuda_device).requires_grad_(True),
+            sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+    cuda_device = torch.device("cuda")
+    out = (torch.randn(rows, cols) * std).type(dtype)
+    return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+    a = _dense(rows, cols, dtype)
+    return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+    a = a.t() if trans_a else a
+    b = b.t() if trans_b else b
+    return op(a, b)
+
+
+def _mmm(a, b, topo):
+    mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+    return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+    a = a.t() if trans_a else a
+    b = b.t() if trans_b else b
+    return op(a, b, topo)
+
+
+def _mask(x, mask):
+    mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+    return x * mask
+
+
+@parameterized.parameters(*_LINEAR_OP_TESTS)
+class LinearOpsTest(parameterized.TestCase):
+
+    def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+        # Construct the operands.
+        a_shape = (k, m) if trans_a else (m, k)
+        a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+        b_shape = (n, k) if trans_b else (k, n)
+        b, bcp = _dense_2x(*b_shape, dtype)
+
+        # Execute the matmul.
+        out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+        expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+
+        # Compute the gradients w.r.t. the inputs.
+        expected_out.sum().backward()
+        out.sum().backward()
+
+        # Validate the results.
+        self.assertEqual(out.dim(), 2)
+        self.assertEqual(expected_out.size()[0], out.size()[0])
+        self.assertEqual(expected_out.size()[1], out.size()[1])
+        self.assertTrue(allclose(out, expected_out))
+
+        # LHS gradient.
+        grad = stk.ops.to_dense(a.grad)
+        expected_grad = _mask(a_dense.grad, a.grad)
+        self.assertEqual(grad.dim(), 2)
+        self.assertEqual(expected_grad.size()[0], grad.size()[0])
+        self.assertEqual(expected_grad.size()[1], grad.size()[1])
+        self.assertTrue(allclose(grad, expected_grad))
+
+        # RHS gradient.
+        grad = b.grad
+        expected_grad = bcp.grad
+        self.assertEqual(grad.dim(), 2)
+        self.assertEqual(expected_grad.size()[0], grad.size()[0])
+        self.assertEqual(expected_grad.size()[1], grad.size()[1])
+        self.assertTrue(allclose(grad, expected_grad))
+
+    def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+        # Construct the operands.
+        a_shape = (k, m) if trans_a else (m, k)
+        a, acp = _dense_2x(*a_shape, dtype)
+        b_shape = (n, k) if trans_b else (k, n)
+        b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+
+        # Execute the matmul.
+        out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+        expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+
+        # Compute the gradients w.r.t. the inputs.
+        expected_out.sum().backward()
+        out.sum().backward()
+
+        # Validate the results.
+        self.assertEqual(out.dim(), 2)
+        self.assertEqual(expected_out.size()[0], out.size()[0])
+        self.assertEqual(expected_out.size()[1], out.size()[1])
+        self.assertTrue(allclose(out, expected_out))
+
+        # LHS gradient.
+        grad = a.grad
+        expected_grad = acp.grad
+        self.assertEqual(grad.dim(), 2)
+        self.assertEqual(expected_grad.size()[0], grad.size()[0])
+        self.assertEqual(expected_grad.size()[1], grad.size()[1])
+        self.assertTrue(allclose(grad, expected_grad))
+
+        # RHS gradient.
+        grad = stk.ops.to_dense(b.grad)
+        expected_grad = _mask(b_dense.grad, b.grad)
+        self.assertEqual(grad.dim(), 2)
+        self.assertEqual(expected_grad.size()[0], grad.size()[0])
+        self.assertEqual(expected_grad.size()[1], grad.size()[1])
+        self.assertTrue(allclose(grad, expected_grad))
+
+    def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+        # Construct the operands.
+        a_shape = (k, m) if trans_a else (m, k)
+        a, acp = _dense_2x(*a_shape, dtype)
+        b_shape = (n, k) if trans_b else (k, n)
+        b, bcp = _dense_2x(*b_shape, dtype)
+        _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+
+        # Execute the matmul.
+        out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+        expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+
+        # Compute the gradients w.r.t. the inputs.
+        expected_out.sum().backward()
+        stk.ops.sum(out).backward()
+
+        # Validate the results.
+        out = stk.ops.to_dense(out)
+        self.assertEqual(out.dim(), 2)
+        self.assertEqual(expected_out.size()[0], out.size()[0])
+        self.assertEqual(expected_out.size()[1], out.size()[1])
+        self.assertTrue(allclose(out, expected_out))
+
+        # LHS gradient.
+        grad = a.grad
+        expected_grad = acp.grad
+        self.assertEqual(grad.dim(), 2)
+        self.assertEqual(expected_grad.size()[0], grad.size()[0])
+        self.assertEqual(expected_grad.size()[1], grad.size()[1])
+        self.assertTrue(allclose(grad, expected_grad))
+
+        # RHS gradient.
+        grad = b.grad
+        expected_grad = bcp.grad
+        self.assertEqual(grad.dim(), 2)
+        self.assertEqual(expected_grad.size()[0], grad.size()[0])
+        self.assertEqual(expected_grad.size()[1], grad.size()[1])
+        self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+    unittest.main()
build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops.py ADDED
@@ -0,0 +1,98 @@
1
+ from ..backend import sputnik
2
+ from ..matrix import Matrix
3
+ import torch
4
+ import numpy as np
5
+
6
+
7
+ @torch.no_grad()
8
+ def row_indices(shape, data, offsets, column_indices):
9
+ return sputnik.row_indices(shape, data, offsets, column_indices)
10
+
11
+
12
+ # TODO(tgale): Replace this helper with a custom kernel. This operation
13
+ # is much simpler to do than how it's currently implemented.
14
+ @torch.no_grad()
15
+ def _expand_for_blocking(idxs, blocking):
16
+ # Duplicate for block column dimension.
17
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
18
+
19
+ # Update the column indices.
20
+ idxs[:, :, 1] *= blocking
21
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
22
+
23
+ # Duplicate for block row dimension.
24
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
25
+ idxs = idxs.repeat(1, blocking, 1, 1)
26
+
27
+ # Update the row indices.
28
+ idxs[:, :, :, 0] *= blocking
29
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
30
+ idxs = torch.reshape(idxs, [-1, 2])
31
+ return idxs
32
+
33
+
34
+ # TODO(tgale): Add input type checking.
35
+ @torch.no_grad()
36
+ def to_dense(x):
37
+ assert isinstance(x, Matrix)
38
+
39
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
40
+ row_idxs = x.row_indices.type(torch.int32)
41
+ col_idxs = x.column_indices.type(torch.int32)
42
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
43
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
44
+
45
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
46
+ out.scatter_(0, indices, x.data.flatten())
47
+ return out.reshape(x.size())
48
+
49
+
50
+ @torch.no_grad()
51
+ def _mask(x, blocking=1):
52
+ assert x.dim() == 2
53
+ assert x.size()[0] % blocking == 0
54
+ assert x.size()[1] % blocking == 0
55
+ block_rows = x.size()[0] // blocking
56
+ block_cols = x.size()[1] // blocking
57
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
58
+ x = torch.sum(torch.abs(x), dim=(1, 3))
59
+ return x != 0
60
+
61
+
62
+ # TODO(tgale): Add input type checking.
63
+ @torch.no_grad()
64
+ def to_sparse(x, blocking=1):
65
+ m = _mask(x, blocking)
66
+
67
+ # TODO(tgale): Set to appropriate type for input matrix.
68
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
69
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
70
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
71
+ offsets = offsets.type(torch.int32)
72
+
73
+ indices = torch.nonzero(m).type(torch.int16)
74
+ row_indices = indices[:, 0]
75
+ column_indices = indices[:, 1]
76
+
77
+ # Nonzero indices in the dense matrix.
78
+ nonzero_indices = torch.nonzero(m)
79
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
80
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
81
+
82
+ # Gather the data and construct the sparse matrix.
83
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
84
+ data = torch.reshape(data, [-1, blocking, blocking])
85
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
86
+
87
+
88
+ @torch.no_grad()
89
+ def ones_like(x):
90
+ return Matrix(x.size(),
91
+ torch.ones_like(x.data),
92
+ x.row_indices,
93
+ x.column_indices, x.offsets)
94
+
95
+
96
+ def sum(x):
97
+ assert isinstance(x, Matrix)
98
+ return x.data.sum()
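Illustrative note (not part of this commit): the `to_sparse`/`to_dense` pair above converts between a dense tensor and the block-sparse `Matrix` format, and the test file below exercises exactly this round trip. A minimal sketch, assuming the vendored package is importable as `megablocks.stk` and that triton is installed (importing `stk.ops` pulls in the triton-backed linear ops):

import torch
from megablocks import stk  # assumed vendored import path

blocking = 8
# Blockwise 0/1 mask, zero out the masked blocks, then round-trip the matrix.
mask = stk.random.dense_mask(16, 32, 0.5, blocking)
x = (torch.randn(16, 32) * mask).half()

sparse_x = stk.ops.to_sparse(x, blocking)  # dense -> block compressed sparse row
dense_x = stk.ops.to_dense(sparse_x)       # and back
assert torch.equal(x, dense_x)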
build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py ADDED
@@ -0,0 +1,62 @@
1
+ import unittest
2
+
3
+ from absl.testing import parameterized
4
+ import stk
5
+ import torch
6
+
7
+
8
+ @parameterized.parameters(
9
+ (8, 16, 0.0, 1),
10
+ (8, 16, 0.5, 1),
11
+ (8, 16, .95, 1),
12
+ (16, 8, 0.0, 1),
13
+ (16, 8, 0.5, 1),
14
+ (16, 8, .95, 1),
15
+ (8, 16, 0.0, 8),
16
+ (8, 16, 0.5, 8),
17
+ (8, 16, 1.0, 8),
18
+ (16, 8, 0.0, 8),
19
+ (16, 8, 0.5, 8),
20
+ (16, 8, 1.0, 8),
21
+ (128, 256, 0.5, 16),
22
+ (256, 128, 0.75, 32),
23
+ (512, 512, .875, 128))
24
+ class MatrixOpsTest(parameterized.TestCase):
25
+
26
+ def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
27
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
28
+ x = (torch.randn(rows, cols) * mask).type(torch.float16)
29
+
30
+ # Convert the matrix to sparse format.
31
+ sparse_x = stk.ops.to_sparse(x, blocking)
32
+
33
+ # Validate the matrix.
34
+ sparse_x.validate()
35
+
36
+ # Validate the shape.
37
+ self.assertEqual(sparse_x.dim(), 2)
38
+ self.assertEqual(sparse_x.size()[0], rows)
39
+ self.assertEqual(sparse_x.size()[1], cols)
40
+
41
+ # Validate the sparsity.
42
+ numblocks = rows // blocking * cols // blocking
43
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
44
+ self.assertEqual(sparse_x.nnz, nnz)
45
+
46
+ # Convert back to dense format.
47
+ dense_x = stk.ops.to_dense(sparse_x)
48
+
49
+ # Validate the shape.
50
+ self.assertEqual(dense_x.dim(), 2)
51
+ self.assertEqual(dense_x.size()[0], rows)
52
+ self.assertEqual(dense_x.size()[1], cols)
53
+
54
+ # Validate the sparsity
55
+ self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
56
+
57
+ # Validate the output.
58
+ self.assertTrue(torch.all(torch.eq(x, dense_x)))
59
+
60
+
61
+ if __name__ == '__main__':
62
+ unittest.main()
build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/random/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # from stk.random.random_ops import dense_mask, mask, randn
2
+ from .random_ops import dense_mask, mask, randn
build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops.py ADDED
@@ -0,0 +1,36 @@
1
+ import numpy as np
2
+ import torch
3
+ from ..ops import matrix_ops
4
+
5
+
6
+ @torch.no_grad()
7
+ def dense_mask(rows, cols, sparsity, blocking=1):
8
+ assert sparsity >= 0.0 and sparsity <= 1.0
9
+ assert rows % blocking == 0 and cols % blocking == 0
10
+
11
+ block_rows, block_cols = (rows // blocking, cols // blocking)
12
+ nnz = round(block_rows * block_cols * (1 - sparsity))
13
+
14
+ out = np.ones(block_rows * block_cols)
15
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
16
+ out[mask] = 0.0
17
+
18
+ out = np.tile(
19
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
20
+ (1, blocking, 1, blocking))
21
+ out = np.reshape(out, [rows, cols])
22
+ return torch.from_numpy(out.astype(np.float32))
23
+
24
+
25
+ @torch.no_grad()
26
+ def mask(m, n, sparsity, blocking=1):
27
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
28
+ return matrix_ops.to_sparse(out, blocking=blocking)
29
+
30
+
31
+ @torch.no_grad()
32
+ def randn(shape, sparsity, blocking=1):
33
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
34
+ out = mask(*shape_2d, sparsity, blocking)
35
+ out.data.copy_(torch.randn(*out.data.shape))
36
+ return out.view(*shape)
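Illustrative note (not part of this commit): `dense_mask` returns a dense 0/1 blockwise mask, while `mask` returns the same pattern already converted to the sparse `Matrix` type and `randn` additionally fills it with random values. A rough sketch under the same `megablocks.stk` import assumption:

from megablocks import stk  # assumed vendored import path

# 16x32 with 8x8 blocks gives 2x4 = 8 blocks; sparsity 0.75 keeps round(8 * 0.25) = 2 of them.
m = stk.random.mask(16, 32, 0.75, blocking=8)
print(m.nnz)      # 2 blocks * 8 * 8 = 128 stored values
print(m.offsets)  # CSR-style block-row offsets into the stored blocks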
build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops_test.py ADDED
@@ -0,0 +1,73 @@
1
+ import unittest
2
+
3
+ from absl.testing import parameterized
4
+ from .. import random
5
+ import torch
6
+
7
+
8
+ @parameterized.parameters(
9
+ (8, 16, 0.0, 1),
10
+ (8, 16, 0.5, 1),
11
+ (8, 16, .95, 1),
12
+ (16, 8, 0.0, 1),
13
+ (16, 8, 0.5, 1),
14
+ (16, 8, .95, 1),
15
+ (8, 16, 0.0, 8),
16
+ (8, 16, 0.5, 8),
17
+ (8, 16, 1.0, 8),
18
+ (16, 8, 0.0, 8),
19
+ (16, 8, 0.5, 8),
20
+ (16, 8, 1.0, 8),
21
+ (128, 256, 0.5, 16),
22
+ (256, 128, 0.75, 32),
23
+ (512, 512, .875, 128))
24
+ class RandomOpsTest(parameterized.TestCase):
25
+
26
+ def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
27
+ mask = random.dense_mask(
28
+ rows, cols, sparsity, blocking)
29
+
30
+ # Validate the shape.
31
+ self.assertEqual(mask.dim(), 2)
32
+ self.assertEqual(mask.size()[0], rows)
33
+ self.assertEqual(mask.size()[1], cols)
34
+
35
+ # Validate the sparsity
36
+ numblocks = rows // blocking * cols // blocking
37
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
38
+ self.assertEqual(
39
+ torch.count_nonzero(mask).item(),
40
+ nnz)
41
+
42
+ # Check values are zero or one.
43
+ self.assertTrue(
44
+ torch.all(torch.logical_or(
45
+ torch.eq(mask, 0),
46
+ torch.eq(mask, 1))))
47
+
48
+ def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
49
+ mask = random.mask(
50
+ rows, cols, sparsity, blocking)
51
+
52
+ # Validate the matrix.
53
+ mask.validate()
54
+
55
+ # Validate the shape.
56
+ self.assertEqual(mask.dim(), 2)
57
+ self.assertEqual(mask.size()[0], rows)
58
+ self.assertEqual(mask.size()[1], cols)
59
+
60
+ # Validate the sparsity.
61
+ numblocks = rows // blocking * cols // blocking
62
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
63
+ self.assertEqual(mask.nnz, nnz)
64
+
65
+ # Check values are zero or one.
66
+ self.assertTrue(
67
+ torch.all(torch.logical_or(
68
+ torch.eq(mask.data, 0),
69
+ torch.eq(mask.data, 1))))
70
+
71
+
72
+ if __name__ == '__main__':
73
+ unittest.main()
build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/activation_fn.py CHANGED
@@ -4,7 +4,7 @@
4
  from typing import Any, Callable, Union
5
 
6
  import torch
7
- from stk import Matrix
8
 
9
 
10
  def act_fn(
 
4
  from typing import Any, Callable, Union
5
 
6
  import torch
7
+ from ..stk import Matrix
8
 
9
 
10
  def act_fn(
build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/dmoe.py CHANGED
@@ -2,15 +2,22 @@
2
  # SPDX-License-Identifier: Apache-2.0
3
 
4
  import numpy as np
5
- import stk.ops
6
  import torch
7
- from stk import Matrix
 
 
 
 
 
 
 
8
 
9
  # import megablocks.ops as ops
10
  # # from megablocks.ops import ops
11
  # from megablocks.layers import common, dmlp_registry, moe, mpu
12
  # from megablocks.layers.arguments import Arguments
13
 
 
14
  from .. import ops
15
  from . import common, dmlp_registry, moe, mpu
16
  from .arguments import Arguments
 
2
  # SPDX-License-Identifier: Apache-2.0
3
 
4
  import numpy as np
 
5
  import torch
6
+
7
+ # try:
8
+ # import stk.ops
9
+ # except ImportError:
10
+ # import warnings
11
+ # warnings.warn(
12
+ # 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
13
+ # )
14
 
15
  # import megablocks.ops as ops
16
  # # from megablocks.ops import ops
17
  # from megablocks.layers import common, dmlp_registry, moe, mpu
18
  # from megablocks.layers.arguments import Arguments
19
 
20
+ from .. import stk
21
  from .. import ops
22
  from . import common, dmlp_registry, moe, mpu
23
  from .arguments import Arguments
build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/gelu.py CHANGED
@@ -1,7 +1,16 @@
1
  # Copyright 2024 Databricks
2
  # SPDX-License-Identifier: Apache-2.0
3
 
4
- import stk
 
 
 
 
 
 
 
 
 
5
  import torch
6
  import torch.nn.functional as F
7
 
 
1
  # Copyright 2024 Databricks
2
  # SPDX-License-Identifier: Apache-2.0
3
 
4
+ # try:
5
+ # import stk
6
+ # except ImportError:
7
+ # import warnings
8
+ # warnings.warn(
9
+ # 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
10
+ # )
11
+
12
+ from .. import stk
13
+
14
  import torch
15
  import torch.nn.functional as F
16
 
build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/glu.py CHANGED
@@ -1,7 +1,17 @@
1
  # Copyright 2024 Databricks
2
  # SPDX-License-Identifier: Apache-2.0
3
 
4
- import stk.ops
 
 
 
 
 
 
 
 
 
 
5
  import torch
6
 
7
  # from megablocks import grouped_gemm_util as gg
 
1
  # Copyright 2024 Databricks
2
  # SPDX-License-Identifier: Apache-2.0
3
 
4
+ # import stk.ops
5
+ # try:
6
+ # import stk.ops
7
+ # except ImportError:
8
+ # import warnings
9
+ # warnings.warn(
10
+ # 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
11
+ # )
12
+
13
+ from .. import stk
14
+
15
  import torch
16
 
17
  # from megablocks import grouped_gemm_util as gg
build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/mlp.py CHANGED
@@ -3,9 +3,18 @@
3
 
4
  from typing import Any
5
 
6
- import stk
7
- import stk.backend.triton_kernels
8
- import stk.ops
 
 
 
 
 
 
 
 
 
9
  import torch
10
  from packaging import version
11
 
 
3
 
4
  from typing import Any
5
 
6
+ # try:
7
+ # import stk
8
+ # import stk.backend.triton_kernels
9
+ # import stk.ops
10
+ # except ImportError:
11
+ # import warnings
12
+ # warnings.warn(
13
+ # 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
14
+ # )
15
+
16
+ from .. import stk
17
+
18
  import torch
19
  from packaging import version
20
 
build/torch26-cxx11-cu124-x86_64-linux/megablocks/{_megablocks_0586ba6.abi3.so → _megablocks_63599de.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:16141033c118b488348a29f3436f778764f8f4275fe510dc36badb7c152e0d42
3
  size 11869392
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:05d38f81524501b75940bfad8686f4f502b5c6af1de85fb1fe5b20da765d4c3c
3
  size 11869392
build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_0586ba6
3
- ops = torch.ops._megablocks_0586ba6
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_0586ba6::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_63599de
3
+ ops = torch.ops._megablocks_63599de
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_63599de::{op_name}"
build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py CHANGED
@@ -3,7 +3,19 @@
3
 
4
  import unittest
5
 
6
- import stk
 
 
 
 
 
 
 
 
 
 
 
 
7
  import torch
8
  from absl.testing import parameterized
9
 
 
3
 
4
  import unittest
5
 
6
+
7
+ # import stk
8
+
9
+ # try:
10
+ # import stk
11
+ # except ImportError:
12
+ # import warnings
13
+ # warnings.warn(
14
+ # 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
15
+ # )
16
+
17
+ from .. import stk
18
+
19
  import torch
20
  from absl.testing import parameterized
21
 
build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # import stk.random
2
+ # import stk.ops
3
+ # from stk.matrix import Matrix
4
+
5
+ from . import random
6
+ from . import ops
7
+ from .matrix import Matrix
build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/__init__.py ADDED
File without changes
build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/autocast.py ADDED
@@ -0,0 +1,37 @@
1
+ import functools
2
+ import torch
3
+
4
+
5
+ def _is_eligible(x):
6
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
7
+
8
+
9
+ def _cast(x, dtype):
10
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
11
+ return x.to(dtype)
12
+ elif isinstance(x, dict):
13
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
14
+ elif isinstance(x, list) or isinstance(x, tuple):
15
+ return type(x)(map(lambda y: _cast(y, dtype), x))
16
+ return x
17
+
18
+
19
+ def custom_fwd(fwd):
20
+ """Wrap a custom autograd function that always uses autocast dtype."""
21
+
22
+ @functools.wraps(fwd)
23
+ def decorate_fwd(*args, **kwargs):
24
+ if torch.is_autocast_enabled():
25
+ with torch.autocast(device_type="cuda", enabled=False):
26
+ dtype = torch.get_autocast_gpu_dtype()
27
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
28
+ return fwd(*args, **kwargs)
29
+ return decorate_fwd
30
+
31
+
32
+ def custom_bwd(bwd):
33
+ @functools.wraps(bwd)
34
+ def decorate_bwd(*args, **kwargs):
35
+ with torch.autocast(device_type="cuda", enabled=False):
36
+ return bwd(*args, **kwargs)
37
+ return decorate_bwd
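Illustrative note (not part of this commit): the decorators above re-apply the ambient autocast dtype to the tensor arguments of a custom autograd Function (which torch.autocast does not do on its own) and disable autocast inside the op. A hedged sketch of the intended usage; the `MatMul` Function below is invented for illustration:

import torch
from megablocks.stk.backend.autocast import custom_fwd, custom_bwd  # assumed vendored path

class MatMul(torch.autograd.Function):
    """Toy custom op used only to demonstrate the decorators."""

    @staticmethod
    @custom_fwd
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a @ b

    @staticmethod
    @custom_bwd
    def backward(ctx, dy):
        a, b = ctx.saved_tensors
        return dy @ b.t(), a.t() @ dy

if torch.cuda.is_available():
    a = torch.randn(8, 8, device="cuda", requires_grad=True)
    b = torch.randn(8, 8, device="cuda", requires_grad=True)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        out = MatMul.apply(a, b)
    print(out.dtype)  # torch.float16: custom_fwd cast the float32 inputs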
build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/sputnik.py ADDED
@@ -0,0 +1,316 @@
1
+ import torch
2
+
3
+ from ..backend import triton_kernels as backend
4
+ from ..backend.autocast import custom_bwd, custom_fwd
5
+
6
+
7
+ def _standardize_shape(x, transpose):
8
+ if transpose:
9
+ return torch.Size((x[1], x[0]))
10
+ return x
11
+
12
+
13
+ def _sparse_transpose(x):
14
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
15
+
16
+
17
+ def _transpose_helper(x, transpose):
18
+ if isinstance(x, torch.Tensor):
19
+ return x.t() if transpose else x
20
+ if transpose:
21
+ x = _sparse_transpose(x)
22
+ return x + (transpose,)
23
+
24
+
25
+ def _wrap(x):
26
+ if isinstance(x, torch.Tensor):
27
+ return (x,)
28
+ return x
29
+
30
+
31
+ def _is_transposed(x):
32
+ return (not x.is_contiguous() and
33
+ x.stride()[0] == 1 and
34
+ x.stride()[1] == x.size()[0])
35
+
36
+
37
+ def _call_helper(op, out, a, b, trans_a, trans_b):
38
+ args = (_wrap(_transpose_helper(a, trans_a)) +
39
+ _wrap(_transpose_helper(b, trans_b)))
40
+ if isinstance(out, tuple):
41
+ args = args + out
42
+ return op(*args)
43
+
44
+
45
+ def _preprocess_inputs(lhs, rhs, dy):
46
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
47
+ lhs = lhs.t()
48
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
49
+ rhs = rhs.t()
50
+ if (isinstance(dy, torch.Tensor) and
51
+ not dy.is_contiguous() and
52
+ not _is_transposed(dy)):
53
+ dy = dy.contiguous()
54
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
55
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
56
+ return lhs, rhs, dy
57
+
58
+
59
+ def _postprocess_outputs(x, transpose, grad):
60
+ if isinstance(x, torch.Tensor) and transpose:
61
+ return grad.t()
62
+ return grad
63
+
64
+
65
+ def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
66
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
67
+
68
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
69
+ trans_a = trans_lhs and trans_rhs
70
+ trans_b = trans_lhs or not trans_rhs
71
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
72
+ return _postprocess_outputs(lhs, trans_lhs, out)
73
+
74
+
75
+ def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
76
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
77
+
78
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
79
+ trans_a = not trans_lhs or trans_rhs
80
+ trans_b = trans_lhs and trans_rhs
81
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
82
+ return _postprocess_outputs(rhs, trans_rhs, out)
83
+
84
+
85
+ class DSD(torch.autograd.Function):
86
+
87
+ @staticmethod
88
+ @custom_fwd
89
+ def forward(ctx,
90
+ shape,
91
+ data,
92
+ offsets,
93
+ row_indices,
94
+ column_indices,
95
+ offsets_t,
96
+ column_indices_t,
97
+ block_offsets_t,
98
+ transpose_a,
99
+ rhs):
100
+ ctx.save_for_backward(data,
101
+ offsets,
102
+ row_indices,
103
+ column_indices,
104
+ offsets_t,
105
+ column_indices_t,
106
+ block_offsets_t,
107
+ rhs)
108
+ ctx.shape = _standardize_shape(shape, transpose_a)
109
+ ctx.transpose_a = transpose_a
110
+
111
+ out = torch.empty(
112
+ (shape[0], rhs.size()[1]),
113
+ dtype=rhs.dtype,
114
+ device=rhs.device)
115
+
116
+ backend.dsd(shape,
117
+ data,
118
+ offsets,
119
+ row_indices,
120
+ column_indices,
121
+ offsets_t,
122
+ column_indices_t,
123
+ block_offsets_t,
124
+ transpose_a,
125
+ rhs,
126
+ out)
127
+ return out
128
+
129
+ @staticmethod
130
+ @custom_bwd
131
+ def backward(ctx, dy):
132
+ saved_tensors = ctx.saved_tensors
133
+ lhs = (ctx.shape,) + saved_tensors[:-1]
134
+ rhs = saved_tensors[-1]
135
+ trans_a = ctx.transpose_a
136
+ trans_b = _is_transposed(rhs)
137
+
138
+ ddata = None
139
+ if ctx.needs_input_grad[1]:
140
+ ddata = _lhs_gradient(sdd,
141
+ lhs,
142
+ rhs,
143
+ dy,
144
+ trans_a,
145
+ trans_b)
146
+ drhs = None
147
+ if ctx.needs_input_grad[-1]:
148
+ op = dds if trans_b else dsd
149
+ drhs = _rhs_gradient(op,
150
+ lhs,
151
+ rhs,
152
+ dy,
153
+ trans_a,
154
+ trans_b)
155
+ return None, ddata, None, None, None, None, None, None, None, drhs
156
+
157
+
158
+ dsd = DSD.apply
159
+
160
+
161
+ class DDS(torch.autograd.Function):
162
+
163
+ @staticmethod
164
+ @custom_fwd
165
+ def forward(ctx,
166
+ lhs,
167
+ shape,
168
+ data,
169
+ offsets,
170
+ row_indices,
171
+ column_indices,
172
+ offsets_t,
173
+ column_indices_t,
174
+ block_offsets_t,
175
+ transpose_b):
176
+ ctx.save_for_backward(lhs,
177
+ data,
178
+ offsets,
179
+ row_indices,
180
+ column_indices,
181
+ offsets_t,
182
+ column_indices_t,
183
+ block_offsets_t)
184
+ ctx.shape = _standardize_shape(shape, transpose_b)
185
+ ctx.transpose_b = transpose_b
186
+ out = torch.empty((lhs.size()[0], shape[1]),
187
+ dtype=lhs.dtype,
188
+ device=lhs.device)
189
+ backend.dds(lhs,
190
+ shape,
191
+ data,
192
+ offsets,
193
+ row_indices,
194
+ column_indices,
195
+ offsets_t,
196
+ column_indices_t,
197
+ block_offsets_t,
198
+ transpose_b,
199
+ out)
200
+ return out
201
+
202
+ @staticmethod
203
+ @custom_bwd
204
+ def backward(ctx, dy):
205
+ saved_tensors = ctx.saved_tensors
206
+ lhs = saved_tensors[0]
207
+ rhs = (ctx.shape,) + saved_tensors[1:]
208
+ trans_a = _is_transposed(lhs)
209
+ trans_b = ctx.transpose_b
210
+
211
+ dlhs = None
212
+ if ctx.needs_input_grad[0]:
213
+ op = dsd if trans_a else dds
214
+ dlhs = _lhs_gradient(op,
215
+ lhs,
216
+ rhs,
217
+ dy,
218
+ trans_a,
219
+ trans_b)
220
+ ddata = None
221
+ if ctx.needs_input_grad[2]:
222
+ ddata = _rhs_gradient(sdd,
223
+ lhs,
224
+ rhs,
225
+ dy,
226
+ trans_a,
227
+ trans_b)
228
+ return dlhs, None, ddata, None, None, None, None, None, None, None
229
+
230
+
231
+ dds = DDS.apply
232
+
233
+
234
+ class SDD(torch.autograd.Function):
235
+
236
+ @staticmethod
237
+ @custom_fwd
238
+ def forward(ctx,
239
+ lhs,
240
+ rhs,
241
+ shape,
242
+ data,
243
+ offsets,
244
+ row_indices,
245
+ column_indices,
246
+ offsets_t,
247
+ column_indices_t,
248
+ block_offsets_t):
249
+ ctx.save_for_backward(
250
+ lhs,
251
+ rhs,
252
+ offsets,
253
+ row_indices,
254
+ column_indices,
255
+ offsets_t,
256
+ column_indices_t,
257
+ block_offsets_t)
258
+ ctx.shape = shape
259
+ out = torch.empty(
260
+ data.shape,
261
+ dtype=lhs.dtype,
262
+ device=lhs.device)
263
+ backend.sdd(lhs,
264
+ rhs,
265
+ shape,
266
+ out,
267
+ offsets,
268
+ row_indices,
269
+ column_indices)
270
+ return out
271
+
272
+ @staticmethod
273
+ @custom_bwd
274
+ def backward(ctx, dy):
275
+ saved_tensors = ctx.saved_tensors
276
+ lhs, rhs = saved_tensors[:2]
277
+ dy = (ctx.shape, dy) + saved_tensors[2:]
278
+ trans_a = _is_transposed(lhs)
279
+ trans_b = _is_transposed(rhs)
280
+
281
+ dlhs = None
282
+ if ctx.needs_input_grad[0]:
283
+ op = dds if trans_a else dsd
284
+ dlhs = _lhs_gradient(op,
285
+ lhs,
286
+ rhs,
287
+ dy,
288
+ trans_a,
289
+ trans_b)
290
+ drhs = None
291
+ if ctx.needs_input_grad[1]:
292
+ op = dsd if trans_b else dds
293
+ drhs = _rhs_gradient(op,
294
+ lhs,
295
+ rhs,
296
+ dy,
297
+ trans_a,
298
+ trans_b)
299
+ return dlhs, drhs, None, None, None, None, None, None, None, None
300
+
301
+
302
+ sdd = SDD.apply
303
+
304
+ class RowIndices(torch.autograd.Function):
305
+
306
+ @staticmethod
307
+ def forward(ctx, shape, data, offsets, column_indices):
308
+ out = torch.empty(
309
+ column_indices.shape,
310
+ dtype=column_indices.dtype,
311
+ device=column_indices.device)
312
+ backend.row_indices(shape, data, offsets, column_indices, out)
313
+ return out
314
+
315
+
316
+ row_indices = RowIndices.apply
build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/triton_kernels.py ADDED
@@ -0,0 +1,393 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+ from dataclasses import dataclass
5
+
6
+ @dataclass
7
+ class TritonConfig:
8
+ BLOCK_M: int = 128
9
+ BLOCK_N: int = 128
10
+ BLOCK_K: int = 32
11
+ BLOCK_SIZE: int = 128
12
+ NUM_STAGES: int = 4
13
+ NUM_WARPS: int = 4
14
+
15
+ def _validate_matmul_dims(M: int, K: int, N: int):
16
+ error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}"
17
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
18
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
19
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
20
+
21
+ @triton.autotune(
22
+ configs=[
23
+ # basic configs for compute-bound matmuls
24
+ triton.Config({
25
+ 'BLOCK_M': TritonConfig.BLOCK_M,
26
+ 'BLOCK_N': TritonConfig.BLOCK_N,
27
+ 'BLOCK_K': TritonConfig.BLOCK_K,
28
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
29
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
30
+ ],
31
+ key=['M', 'N', 'K'],
32
+ )
33
+ @triton.jit
34
+ def _sdd_kernel(A, B, C, M, N, K,
35
+ stride_am, stride_ak,
36
+ stride_bk, stride_bn,
37
+ stride_cm, stride_cn,
38
+ row_indices, column_indices,
39
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
40
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
41
+ ):
42
+ # matrix multiplication
43
+ pid = tl.program_id(0)
44
+ pid_m = tl.load(row_indices + pid)
45
+ pid_n = tl.load(column_indices + pid)
46
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
47
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
48
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
49
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
50
+ rk = tl.arange(0, BLOCK_K)
51
+ # pointers
52
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
53
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
54
+ # do matrix multiplication
55
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
56
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
57
+ a = tl.load(A)
58
+ b = tl.load(B)
59
+ acc += tl.dot(a, b)
60
+ A += BLOCK_K * stride_ak
61
+ B += BLOCK_K * stride_bk
62
+ #Store to sparse matrix
63
+ acc = acc.to(C.dtype.element_ty)
64
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
65
+ cm = tl.arange(0, BLOCK_M)
66
+ cn = tl.arange(0, BLOCK_N)
67
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
68
+ tl.store(C, acc, mask=True)
69
+
70
+ @triton.autotune(
71
+ configs=[
72
+ # basic configs for compute-bound matmuls
73
+ triton.Config({
74
+ 'BLOCK_M': TritonConfig.BLOCK_M,
75
+ 'BLOCK_N': TritonConfig.BLOCK_N,
76
+ 'BLOCK_K': TritonConfig.BLOCK_K,
77
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
78
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
79
+ ],
80
+ key=['M', 'N', 'K'],
81
+ )
82
+ @triton.jit
83
+ def _dsd_kernel(A, B, C, M, N, K,
84
+ stride_am, stride_ak,
85
+ stride_bk, stride_bn,
86
+ stride_cm, stride_cn,
87
+ row_indices, column_indices, offsets,
88
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
89
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
90
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
91
+ ):
92
+
93
+ # matrix multiplication
94
+ pid_m = tl.program_id(0)
95
+ pid_n = tl.program_id(1)
96
+
97
+ num_pid_m = tl.num_programs(0)
98
+ num_pid_n = tl.num_programs(1)
99
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
100
+
101
+ start_inx = tl.load(offsets + pid_m)
102
+ end_inx = tl.load(offsets + pid_m + 1)
103
+
104
+ # pointers to sparse matrix
105
+ rm = tl.arange(0, BLOCK_M)
106
+ rak = tl.arange(0, BLOCK_K)
107
+
108
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
109
+
110
+ # pointers to dense matrix
111
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
112
+ rbk = tl.arange(0, BLOCK_K)
113
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
114
+
115
+ # do matrix multiplication
116
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
117
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
118
+
119
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
120
+ ak_sub_incr = BLOCK_K * stride_ak
121
+ bk_sub_incr = BLOCK_K * stride_bk
122
+ bk_block_incr = BLOCK_SIZE * stride_bk
123
+
124
+ for k in range(nsub_blocks * (end_inx - start_inx)):
125
+ sub_block_inx = k % nsub_blocks
126
+ block_inx = k // nsub_blocks
127
+
128
+ if trans_A:
129
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
130
+ else:
131
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
132
+
133
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
134
+
135
+ a = tl.load(ptr_A)
136
+ b = tl.load(ptr_B)
137
+ acc += tl.dot(a, b)
138
+
139
+ acc = acc.to(C.dtype.element_ty)
140
+
141
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
142
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
143
+
144
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
145
+ tl.store(C, acc, mask=True)
146
+
147
+ @triton.autotune(
148
+ configs=[
149
+ # basic configs for compute-bound matmuls
150
+ triton.Config({
151
+ 'BLOCK_M': TritonConfig.BLOCK_M,
152
+ 'BLOCK_N': TritonConfig.BLOCK_N,
153
+ 'BLOCK_K': TritonConfig.BLOCK_K,
154
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
155
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
156
+ ],
157
+ key=['M', 'N', 'K'],
158
+ )
159
+ @triton.jit
160
+ def _dds_kernel(A, B, C, M, N, K,
161
+ stride_am, stride_ak,
162
+ stride_bk, stride_bn,
163
+ stride_cm, stride_cn,
164
+ row_indices, column_indices, offsets,
165
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
166
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
167
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
168
+ ):
169
+
170
+ # matrix multiplication
171
+ pid_m = tl.program_id(0)
172
+ pid_n = tl.program_id(1)
173
+
174
+ num_pid_m = tl.num_programs(0)
175
+ num_pid_n = tl.num_programs(1)
176
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
177
+
178
+ start_inx = tl.load(offsets + pid_n)
179
+ end_inx = tl.load(offsets + pid_n + 1)
180
+
181
+ # pointers to dense matrix
182
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
183
+ rak = tl.arange(0, BLOCK_K)
184
+
185
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
186
+
187
+ # pointers to sparse matrix
188
+ rn = tl.arange(0, BLOCK_N)
189
+ rbk = tl.arange(0, BLOCK_K)
190
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
191
+
192
+ # do matrix multiplication
193
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
194
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
195
+
196
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
197
+
198
+ ak_sub_incr = BLOCK_K * stride_ak
199
+ ak_block_incr = BLOCK_SIZE * stride_ak
200
+ bk_sub_incr = BLOCK_K * stride_bk
201
+
202
+ for k in range(nsub_blocks * (end_inx - start_inx)):
203
+ sub_block_inx = k % nsub_blocks
204
+ block_inx = k // nsub_blocks
205
+
206
+ if trans_B:
207
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
208
+ else:
209
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
210
+
211
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
212
+ a = tl.load(ptr_A)
213
+ b = tl.load(ptr_B)
214
+ acc += tl.dot(a, b)
215
+
216
+ acc = acc.to(C.dtype.element_ty)
217
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
218
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
219
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
220
+ tl.store(C, acc, mask=True)
221
+
222
+ def dsd(shape,
223
+ data,
224
+ offsets,
225
+ row_indices,
226
+ column_indices,
227
+ offsets_t,
228
+ column_indices_t,
229
+ block_offsets_t,
230
+ transpose_a,
231
+ rhs,
232
+ out
233
+ ):
234
+
235
+ device = rhs.device
236
+ trans_A = transpose_a
237
+ trans_B = False
238
+
239
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
240
+ trans_B = True
241
+
242
+ # checks constraints
243
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
244
+ M, K = shape
245
+ _, N = rhs.shape
246
+
247
+ _validate_matmul_dims(M, K, N)
248
+
249
+ # accumulator types
250
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
251
+
252
+ stride_am, stride_ak = data.stride(1), data.stride(2)
253
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
254
+ a_column_indices = column_indices
255
+ a_offsets = offsets
256
+
257
+ # launch kernel
258
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
259
+
260
+ if trans_A:
261
+ stride_am, stride_ak = data.stride(2), data.stride(1)
262
+ a_column_indices, a_offsets = column_indices_t, offsets_t
263
+
264
+ if trans_B:
265
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
266
+
267
+ _dsd_kernel[grid](
268
+ data.data, rhs, out, M, N, K,
269
+ stride_am, stride_ak,
270
+ stride_bk, stride_bn,
271
+ out.stride(0), out.stride(1),
272
+ row_indices, a_column_indices, a_offsets,
273
+ block_offsets_t, trans_A, trans_B,
274
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
275
+ )
276
+ # return out
277
+
278
+ def dds(lhs,
279
+ shape,
280
+ data,
281
+ offsets,
282
+ row_indices,
283
+ column_indices,
284
+ offsets_t,
285
+ column_indices_t,
286
+ block_offsets_t,
287
+ transpose_b,
288
+ out
289
+ ):
290
+
291
+ device = lhs.device
292
+ trans_B = transpose_b
293
+ trans_A = False
294
+
295
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
296
+ trans_A = True
297
+
298
+ # checks constraints
299
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
300
+ M, K = lhs.shape
301
+ _, N = shape
302
+
303
+ _validate_matmul_dims(M, K, N)
304
+
305
+ # accumulator types
306
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
307
+
308
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
309
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
310
+ b_column_indices = column_indices_t
311
+ b_offsets = offsets_t
312
+
313
+ # launch kernel
314
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
315
+
316
+ if trans_A:
317
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
318
+ if trans_B:
319
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
320
+ b_column_indices, b_offsets = column_indices, offsets
321
+
322
+ _dds_kernel[grid](
323
+ lhs, data, out, M, N, K,
324
+ stride_am, stride_ak,
325
+ stride_bk, stride_bn,
326
+ out.stride(0), out.stride(1),
327
+ row_indices, b_column_indices, b_offsets,
328
+ block_offsets_t, trans_A, trans_B,
329
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
330
+ )
331
+
332
+ def sdd(lhs,
333
+ rhs,
334
+ shape,
335
+ out,
336
+ offsets,
337
+ row_indices,
338
+ column_indices
339
+ ):
340
+
341
+ device = out.device
342
+ trans_A = False
343
+ trans_B = False
344
+
345
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
346
+ trans_A = True
347
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
348
+ trans_B = True
349
+
350
+ # checks constraints
351
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
352
+ M, K = lhs.shape
353
+ _, N = rhs.shape
354
+
355
+ _validate_matmul_dims(M, K, N)
356
+
357
+ # accumulator types
358
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
359
+
360
+ # launch kernel
361
+ nnz_blocks = len(row_indices)
362
+ grid = lambda META: (nnz_blocks,)
363
+
364
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
365
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
366
+
367
+ if trans_A:
368
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
369
+ if trans_B:
370
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
371
+
372
+ _sdd_kernel[grid](
373
+ lhs, rhs, out, M, N, K,
374
+ stride_am, stride_ak,
375
+ stride_bk, stride_bn,
376
+ out.stride(1), out.stride(2),
377
+ row_indices, column_indices,
378
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
379
+ )
380
+
381
+ @triton.jit
382
+ def _row_indices_kernel(offsets, out):
383
+ pid = tl.program_id(0)
384
+ row_offset = tl.load(offsets + pid)
385
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
386
+ for nnz_block in range(nnz_blocks):
387
+ tl.store(out + row_offset + nnz_block, pid)
388
+
389
+ def row_indices(
390
+ shape, data, offsets, column_indices, out
391
+ ):
392
+ block_rows = len(offsets) - 1
393
+ _row_indices_kernel[(block_rows, )](offsets, out)
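Illustrative note (not part of this commit): `_row_indices_kernel` expands the CSR-style block-row offsets into one row id per nonzero block. A plain-PyTorch reference of the same computation, handy for checking the kernel on small inputs:

import torch

def row_indices_reference(offsets: torch.Tensor) -> torch.Tensor:
    # offsets[i + 1] - offsets[i] is the number of nonzero blocks in block-row i.
    counts = offsets[1:] - offsets[:-1]
    rows = torch.arange(counts.numel(), dtype=offsets.dtype, device=offsets.device)
    return torch.repeat_interleave(rows, counts)

print(row_indices_reference(torch.tensor([0, 2, 2, 5])))  # tensor([0, 0, 2, 2, 2])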
build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/matrix.py ADDED
@@ -0,0 +1,329 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ # 1. Add heavyweight (data) validation helper.
5
+ # 2. Add construction helpers
6
+ # 3. Make indentation consistent
7
+ # 4. Replace asserts with descriptive errors.
8
+
9
+ ##
10
+ ### Validation helpers.
11
+ ##
12
+
13
+
14
+ def _validate_matrix(shape, data, row_indices, column_indices, offsets):
15
+ # Data should be [nnz, block_size, block_size]
16
+ if data.dim() == 1:
17
+ data = torch.reshape(data, [data.numel(), 1, 1])
18
+
19
+ # Blocks should be square.
20
+ if data.shape[-2] != data.shape[-1]:
21
+ raise ValueError(
22
+ "Expected square blocking in data. "
23
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
24
+
25
+ # Flatten batch dimensions on data - original shape preserved
26
+ # in shape argument.
27
+ block_size = data.shape[-1]
28
+ data = data.view([-1, block_size, block_size])
29
+
30
+ if data.dim() != 3:
31
+ raise ValueError(
32
+ "Expected 3D shape for data (nnz, block, block). "
33
+ f"Got shape {data.dim()}D shape.")
34
+
35
+ block_size = data.shape[1]
36
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
37
+ raise ValueError(
38
+ "Matrix shape must be dividible by blocking. "
39
+ f"Got shape {shape} with "
40
+ f"{[block_size, block_size]} blocking.")
41
+
42
+ if np.prod(shape) < data.numel():
43
+ raise ValueError(
44
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
45
+ f"({data.numel()} v. {np.prod(shape)})")
46
+
47
+ if row_indices.dim() != 1:
48
+ raise ValueError(
49
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
50
+
51
+ if column_indices.dim() != 1:
52
+ raise ValueError(
53
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
54
+
55
+ if offsets.dim() != 1:
56
+ raise ValueError(
57
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
58
+
59
+ if row_indices.numel() != data.shape[0]:
60
+ raise ValueError(
61
+ "Expected 1 index per nonzero block. "
62
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
63
+
64
+ if column_indices.numel() != data.shape[0]:
65
+ raise ValueError(
66
+ "Expected 1 index per nonzero block. "
67
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
68
+
69
+ block_rows = np.prod(shape[:-1]) / block_size
70
+ if offsets.numel() != block_rows + 1:
71
+ raise ValueError(
72
+ "Expected one offset per block row plus one. "
73
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
74
+
75
+ is_cuda = (data.is_cuda and
76
+ row_indices.is_cuda and
77
+ column_indices.is_cuda and
78
+ offsets.is_cuda)
79
+ is_cpu = (not data.is_cuda and
80
+ not row_indices.is_cuda and
81
+ not column_indices.is_cuda and
82
+ not offsets.is_cuda)
83
+ if not (is_cuda or is_cpu):
84
+ raise ValueError(
85
+ "Expected data & meta-data on common device. "
86
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
87
+ f"column_indices on {column_indices.device} and "
88
+ f"offsets on {offsets.device}.")
89
+
90
+ if data.dtype != torch.float16:
91
+ raise ValueError(
92
+ f"Expected float16 data. Got {data.dtype} data.")
93
+ if row_indices.dtype != torch.int16:
94
+ raise ValueError(
95
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
96
+ if column_indices.dtype != torch.int16:
97
+ raise ValueError(
98
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
99
+ if offsets.dtype != torch.int32:
100
+ raise ValueError(
101
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
102
+ return data
103
+
104
+
105
+ def _transpose(size, data, row_indices, column_indices, offsets):
106
+ block_columns = size[1] // data.shape[1]
107
+
108
+ # Sort row indices by column indices to get the transposed matrix's
109
+ # column indices.
110
+ gather_indices = column_indices.argsort()
111
+ column_indices_t = row_indices.gather(0, gather_indices)
112
+ block_offsets_t = gather_indices.int()
113
+
114
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
115
+ # the histogram in 32-bit float, which can exactly represent 16-bit
116
+ # integers.
117
+ column_indices_float = column_indices.float()
118
+
119
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
120
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
121
+ nnz_per_column = nnz_per_column.int()
122
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
123
+ return column_indices_t, offsets_t, block_offsets_t
124
+
125
+
126
+ class Matrix(torch.nn.Module):
127
+ """A matrix stored in sparse format.
128
+
129
+ Underlying format is block compressed sparse row (BCSR).
130
+
131
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
132
+ """
133
+
134
+ def __init__(self,
135
+ size,
136
+ data,
137
+ row_indices,
138
+ column_indices,
139
+ offsets,
140
+ column_indices_t=None,
141
+ offsets_t=None,
142
+ block_offsets_t=None):
143
+ super().__init__()
144
+ self._size = size
145
+ self._data = data
146
+ self._row_indices = row_indices
147
+ self._column_indices = column_indices
148
+ self._offsets = offsets
149
+
150
+ # Produce the transpose meta-data if it is not passed in.
151
+ if ((column_indices_t is None) or (offsets_t is None) or
152
+ (block_offsets_t is None)):
153
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
154
+ size, data, row_indices, column_indices, offsets)
155
+ self._column_indices_t = column_indices_t
156
+ self._offsets_t = offsets_t
157
+ self._block_offsets_t = block_offsets_t
158
+
159
+ self._transposed = False
160
+
161
+ # Validate that our metadata will not overflow.
162
+ max_dim = np.iinfo(np.int16).max * self.blocking
163
+ if column_indices.dtype == torch.int16:
164
+ if size[0] > max_dim or size[1] > max_dim:
165
+ raise ValueError(
166
+ "Sparse matrix with shape {size} exceeds representable "
167
+ "size with 16-bit indices.")
168
+
169
+ def validate(self):
170
+ _validate_matrix(self._size,
171
+ self._data,
172
+ self._row_indices,
173
+ self._column_indices,
174
+ self._offsets)
175
+
176
+ # TODO(tgale): Add heavyweight data validation.
177
+
178
+ def to(self, device):
179
+ # TODO(tgale): Handle type conversions here. We
180
+ # need to set the appropriate meta-data type for
181
+ # the given floating-point type.
182
+ self._data = self._data.to(device)
183
+ self._row_indices = self._row_indices.to(device)
184
+ self._column_indices = self._column_indices.to(device)
185
+ self._offsets = self._offsets.to(device)
186
+ self._column_indices_t = self._column_indices_t.to(device)
187
+ self._offsets_t = self._offsets_t.to(device)
188
+ self._block_offsets_t = self._block_offsets_t.to(device)
189
+ return self
190
+
191
+ def cuda(self):
192
+ return self.to(torch.cuda.current_device())
193
+
194
+ def clone(self):
195
+ return Matrix(
196
+ self.size(),
197
+ self.data.clone(),
198
+ self.row_indices.clone(),
199
+ self.column_indices.clone(),
200
+ self.offsets.clone(),
201
+ self.column_indices_t.clone(),
202
+ self.offsets_t.clone(),
203
+ self.block_offsets_t.clone())
204
+
205
+ def t(self):
206
+ if self.dim() != 2:
207
+ raise ValueError(
208
+ "t() expects a tensor with <= 2 dimensions, "
209
+ f"but self is {self.dim()}D.")
210
+ out = Matrix(self.size(),
211
+ self.data,
212
+ self.row_indices,
213
+ self.column_indices,
214
+ self.offsets,
215
+ self.column_indices_t,
216
+ self.offsets_t,
217
+ self.block_offsets_t)
218
+ out._transposed = not self._transposed
219
+ out._size = torch.Size((self._size[1], self._size[0]))
220
+ return out
221
+
222
+ def contiguous(self):
223
+ raise ValueError("Not yet implemented.")
224
+
225
+ def is_contiguous(self):
226
+ return not self._transposed
227
+
228
+ @property
229
+ def is_cuda(self):
230
+ return self._data.is_cuda
231
+
232
+ @property
233
+ def device(self):
234
+ return self._data.device
235
+
236
+ def size(self):
237
+ return self._size
238
+
239
+ @property
240
+ def shape(self):
241
+ return self.size()
242
+
243
+ def dim(self):
244
+ return len(self._size)
245
+
246
+ @property
247
+ def data(self):
248
+ return self._data
249
+
250
+ @property
251
+ def row_indices(self):
252
+ return self._row_indices
253
+
254
+ @property
255
+ def column_indices(self):
256
+ return self._column_indices
257
+
258
+ @property
259
+ def offsets(self):
260
+ return self._offsets
261
+
262
+ @property
263
+ def offsets_t(self):
264
+ return self._offsets_t
265
+
266
+ @property
267
+ def column_indices_t(self):
268
+ return self._column_indices_t
269
+
270
+ @property
271
+ def block_offsets_t(self):
272
+ return self._block_offsets_t
273
+
274
+ @property
275
+ def dtype(self):
276
+ return self.data.dtype
277
+
278
+ @property
279
+ def nnz(self):
280
+ return self.data.numel()
281
+
282
+ @property
283
+ def blocking(self):
284
+ return self.data.shape[1]
285
+
286
+ @property
287
+ def requires_grad(self):
288
+ return self.data.requires_grad
289
+
290
+ def requires_grad_(self, x):
291
+ self.data.requires_grad_(x)
292
+ return self
293
+
294
+ def view(self, *shape):
295
+ assert self.is_contiguous()
296
+ if shape[-1] != self.size()[-1]:
297
+ raise ValueError(
298
+ "Can't change view on compressed dimension. "
299
+ f"{self.size()[-1]} v. {shape[-1]}.")
300
+ if np.prod(shape) != np.prod(self.size()):
301
+ raise ValueError(
302
+ "Mismatch in numel of Matrix and new shape. "
303
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
304
+ return Matrix(shape,
305
+ self.data,
306
+ self.row_indices,
307
+ self.column_indices,
308
+ self.offsets,
309
+ self.column_indices_t,
310
+ self.offsets_t,
311
+ self.block_offsets_t)
312
+
313
+ @property
314
+ def grad(self):
315
+ # TODO(tgale): Make sure this mirrors torch.Tensor
316
+ # behavior in the case where we ask for the gradient
317
+ # of a non-contiguous tensor.
318
+ size = self.size()
319
+ if not self.is_contiguous():
320
+ size = torch.Size((size[1], size[0]))
321
+ out = Matrix(size,
322
+ self.data.grad,
323
+ self.row_indices,
324
+ self.column_indices,
325
+ self.offsets,
326
+ self.column_indices_t,
327
+ self.offsets_t,
328
+ self.block_offsets_t)
329
+ return out if self.is_contiguous() else out.t()
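Illustrative note (not part of this commit): a small sketch of the BCSR layout that `Matrix` stores, built with `stk.ops.to_sparse` (defined in ops/matrix_ops.py) and assuming the vendored `megablocks.stk` import path:

import torch
from megablocks import stk  # assumed vendored import path

x = torch.zeros(16, 16, dtype=torch.float16)
x[:8, :8] = 1.0   # nonzero block at block-row 0, block-column 0
x[8:, 8:] = 2.0   # nonzero block at block-row 1, block-column 1

m = stk.ops.to_sparse(x, blocking=8)
print(m.data.shape)      # torch.Size([2, 8, 8]): one 8x8 tile per nonzero block
print(m.row_indices)     # tensor([0, 1], dtype=torch.int16)
print(m.column_indices)  # tensor([0, 1], dtype=torch.int16)
print(m.offsets)         # tensor([0, 1, 2], dtype=torch.int32)
m.validate()             # runs the lightweight checks from _validate_matrix above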
build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .linear_ops import dds, dsd, sdd
2
+ from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
3
+ from .eltwise_ops import mul
build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops.py ADDED
@@ -0,0 +1,28 @@
1
+ from ..matrix import Matrix
2
+
3
+ def mul(a, b):
4
+ """Performs element-wise multiplication of matrices a and b.
5
+
6
+ It is the user's responsibility to make sure that a and b
7
+ follow the same matrix topology. This function assumes it is safe
8
+ to use the topology of a.
9
+
10
+ Args:
11
+ a: stk.Matrix.
12
+ b: stk.Matrix with a's matrix topology.
13
+
14
+ Returns:
15
+ stk.Matrix where the entries correspond to torch.mul(a, b).
16
+ """
17
+ assert isinstance(a, Matrix)
18
+ assert isinstance(b, Matrix)
19
+ assert a.size() == b.size()
20
+
21
+ return Matrix(a.size(),
22
+ a.data * b.data,
23
+ a.row_indices,
24
+ a.column_indices,
25
+ a.offsets,
26
+ a.column_indices_t,
27
+ a.offsets_t,
28
+ a.block_offsets_t)
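Illustrative note (not part of this commit): `mul` combines the stored blocks of two `Matrix` objects that share a topology; one way to build a second operand with the same topology is to reuse the first one's metadata, which is what the test below does. A rough sketch under the `megablocks.stk` import assumption:

import torch
from megablocks import stk  # assumed vendored import path

x = torch.zeros(16, 16, dtype=torch.float16)
x[:8, :8] = 3.0
a = stk.ops.to_sparse(x, blocking=8)

# Same sparsity pattern, different values: reuse a's indices and offsets.
b = stk.Matrix(a.size(), torch.full_like(a.data, 2.0),
               a.row_indices, a.column_indices, a.offsets)

c = stk.ops.mul(a, b)                        # element-wise product on the stored blocks
print(stk.ops.to_dense(c)[:8, :8].unique())  # tensor([6.], dtype=torch.float16)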
build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py ADDED
@@ -0,0 +1,86 @@
1
+ import unittest
2
+ import itertools
3
+ import torch
4
+ from absl.testing import parameterized
5
+
6
+ import stk
7
+ from stk.ops.linear_ops_test import allclose, _dense_and_sparse
8
+
9
+ _MATRIX_SIZES = (
10
+ (128, 128, 0.0),
11
+ (256, 256, 0.5),
12
+ (2048, 1024, 0.8),
13
+ (512, 128, 0.0),
14
+ (128, 512, 0.0),
15
+ (1024, 512, 0.0),
16
+ (1024, 512, 0.5),
17
+ (1024, 512, 0.75),
18
+ (512, 1024, 0.0),
19
+ (512, 1024, 0.5),
20
+ (512, 1024, 0.75),
21
+ (1024, 1024, 0.0),
22
+ (1024, 1024, 0.5),
23
+ (1024, 1024, 0.75),
24
+ )
25
+
26
+ _DTYPE = (
27
+ torch.float16, torch.bfloat16
28
+ )
29
+
30
+ def _generate_testcases():
31
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
32
+ testcases = [(*size, 128, dtype) for
33
+ (size, dtype) in testcases]
34
+ return testcases
35
+
36
+ _ELTWISE_OP_TESTS = _generate_testcases()
37
+
38
+ def _dense_and_sparse_like(x, std=0.1):
39
+ dense_data = torch.randn_like(x.data, device=x.device) * std
40
+ sparse = stk.Matrix(x.size(),
41
+ dense_data,
42
+ x.row_indices,
43
+ x.column_indices,
44
+ x.offsets)
45
+ dense = stk.ops.to_dense(sparse)
46
+
47
+ return (dense.requires_grad_(True),
48
+ sparse.requires_grad_(True))
49
+
50
+ @parameterized.parameters(_ELTWISE_OP_TESTS)
51
+ class EltwiseOpsTest(parameterized.TestCase):
52
+
53
+ def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
54
+
55
+ a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
56
+ b_dense, b = _dense_and_sparse_like(a)
57
+
58
+ out = stk.ops.mul(a, b)
59
+ expected_out = torch.mul(a_dense, b_dense)
60
+
61
+ # Compute the gradients w.r.t. the inputs.
62
+ expected_out.sum().backward()
63
+ stk.ops.sum(out).backward()
64
+
65
+ # Validate the results.
66
+ out = stk.ops.to_dense(out)
67
+ self.assertEqual(out.dim(), 2)
68
+ self.assertEqual(expected_out.size(), out.size())
69
+ self.assertTrue(allclose(out, expected_out))
70
+
71
+ # LHS gradient.
72
+ grad = stk.ops.to_dense(a.grad)
73
+ expected_grad = a_dense.grad
74
+ self.assertEqual(grad.dim(), 2)
75
+ self.assertEqual(expected_grad.size(), grad.size())
76
+ self.assertTrue(allclose(grad, expected_grad))
77
+
78
+ # RHS gradient.
79
+ grad = stk.ops.to_dense(b.grad)
80
+ expected_grad = b_dense.grad
81
+ self.assertEqual(grad.dim(), 2)
82
+ self.assertEqual(expected_grad.size(), grad.size())
83
+ self.assertTrue(allclose(grad, expected_grad))
84
+
85
+ if __name__ == '__main__':
86
+ unittest.main()
build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/linear_ops.py ADDED
@@ -0,0 +1,59 @@
1
+ import torch
2
+
3
+ from ..backend import sputnik
4
+ from ..matrix import Matrix
5
+
6
+
7
+ def dsd(a, b):
8
+ assert isinstance(a, Matrix)
9
+ assert isinstance(b, torch.Tensor)
10
+ return sputnik.dsd(
11
+ a.size(),
12
+ a.data, a.offsets,
13
+ a.row_indices,
14
+ a.column_indices,
15
+ a.offsets_t,
16
+ a.column_indices_t,
17
+ a.block_offsets_t,
18
+ not a.is_contiguous(),
19
+ b)
20
+
21
+
22
+ def dds(a, b):
23
+ assert isinstance(a, torch.Tensor)
24
+ assert isinstance(b, Matrix)
25
+ return sputnik.dds(
26
+ a,
27
+ b.size(),
28
+ b.data, b.offsets,
29
+ b.row_indices,
30
+ b.column_indices,
31
+ b.offsets_t,
32
+ b.column_indices_t,
33
+ b.block_offsets_t,
34
+ not b.is_contiguous())
35
+
36
+
37
+ def sdd(a, b, topo):
38
+ assert isinstance(a, torch.Tensor)
39
+ assert isinstance(b, torch.Tensor)
40
+ assert isinstance(topo, Matrix)
41
+ assert topo.is_contiguous()
42
+ out = sputnik.sdd(
43
+ a, b,
44
+ topo.size(),
45
+ topo.data,
46
+ topo.offsets,
47
+ topo.row_indices,
48
+ topo.column_indices,
49
+ topo.offsets_t,
50
+ topo.column_indices_t,
51
+ topo.block_offsets_t)
52
+ return Matrix(topo.size(),
53
+ out,
54
+ topo.row_indices,
55
+ topo.column_indices,
56
+ topo.offsets,
57
+ topo.column_indices_t,
58
+ topo.offsets_t,
59
+ topo.block_offsets_t)
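Illustrative note (not part of this commit): the three products above map onto dense = sparse x dense (`dsd`), dense = dense x sparse (`dds`), and sparse = dense x dense restricted to a prescribed topology (`sdd`). They run through the triton kernels, so they need a CUDA device and dimensions divisible by the 128 block size. A hedged sketch, assuming the vendored `megablocks.stk` import path:

import torch
from megablocks import stk  # assumed vendored import path

if torch.cuda.is_available():
    a = torch.randn(256, 128, device="cuda", dtype=torch.float16)
    b = torch.randn(128, 256, device="cuda", dtype=torch.float16)

    # Sparsity pattern for the 256x256 output, with 128x128 blocks.
    topo = stk.random.mask(256, 256, 0.5, blocking=128).cuda()

    s = stk.ops.sdd(a, b, topo)   # sparse output restricted to topo's nonzero blocks
    rhs = torch.randn(256, 128, device="cuda", dtype=torch.float16)
    d = stk.ops.dsd(s, rhs)       # dense output: (256x256 sparse) @ (256x128 dense)
    print(d.shape)                # torch.Size([256, 128])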
build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/linear_ops_test.py ADDED
@@ -0,0 +1,216 @@
+ import unittest
+ import itertools
+ import numpy as np
+ import torch
+ from absl.testing import parameterized
+
+ import stk
+
+
+ def allclose(x, y, pct=0.25):
+     mask = torch.isclose(x, y, rtol=5e-2)
+     pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+     if pct_diff > pct:
+         print("{:.2f}% of values not close.".format(pct_diff))
+         return False
+     return True
+
+
+ # An assortment of problems designed to make sure
+ # the bindings are operating correctly.
+ _MATRIX_SIZES = (
+     (128, 128, 128, 0.0),
+     (256, 256, 256, 0.5),
+     (2048, 1024, 512, 0.8),
+     (512, 128, 128, 0.0),
+     (128, 128, 512, 0.0),
+     (1024, 512, 512, 0.0),
+     (1024, 512, 512, 0.5),
+     (1024, 512, 512, 0.75),
+     (512, 512, 1024, 0.0),
+     (512, 512, 1024, 0.5),
+     (512, 512, 1024, 0.75),
+     (1024, 1024, 1024, 0.0),
+     (1024, 1024, 1024, 0.5),
+     (1024, 1024, 1024, 0.75),
+ )
+
+ _TRANSPOSE = (
+     (False, False),
+     (False, True),
+     (True, False),
+     (True, True),
+ )
+
+ _DTYPE = (
+     torch.float16, torch.bfloat16
+ )
+
+ def _generate_testcases():
+     testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+     testcases = [(*size, *trans, 128, dtype) for
+                  (size, trans, dtype) in testcases]
+     return testcases
+
+ _LINEAR_OP_TESTS = _generate_testcases()
+
+ def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+     mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+     dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+     sparse = stk.ops.to_sparse(dense, blocking)
+     cuda_device = torch.device("cuda")
+     return (dense.to(cuda_device).requires_grad_(True),
+             sparse.to(cuda_device).requires_grad_(True))
+
+
+ def _dense(rows, cols, dtype, std=0.1):
+     cuda_device = torch.device("cuda")
+     out = (torch.randn(rows, cols) * std).type(dtype)
+     return out.to(cuda_device).requires_grad_(True)
+
+
+ def _dense_2x(rows, cols, dtype):
+     a = _dense(rows, cols, dtype)
+     return a, a.detach().requires_grad_(True)
+
+
+ def _with_transpose(op, a, b, trans_a, trans_b):
+     a = a.t() if trans_a else a
+     b = b.t() if trans_b else b
+     return op(a, b)
+
+
+ def _mmm(a, b, topo):
+     mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+     return torch.mm(a, b) * mask
+
+
+ def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+     a = a.t() if trans_a else a
+     b = b.t() if trans_b else b
+     return op(a, b, topo)
+
+
+ def _mask(x, mask):
+     mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+     return x * mask
+
+
+ @parameterized.parameters(*_LINEAR_OP_TESTS)
+ class LinearOpsTest(parameterized.TestCase):
+
+     def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+         # Construct the operands.
+         a_shape = (k, m) if trans_a else (m, k)
+         a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+         b_shape = (n, k) if trans_b else (k, n)
+         b, bcp = _dense_2x(*b_shape, dtype)
+
+         # Execute the matmul.
+         out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+         expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+
+         # Compute the gradients w.r.t. the inputs.
+         expected_out.sum().backward()
+         out.sum().backward()
+
+         # Validate the results.
+         self.assertEqual(out.dim(), 2)
+         self.assertEqual(expected_out.size()[0], out.size()[0])
+         self.assertEqual(expected_out.size()[1], out.size()[1])
+         self.assertTrue(allclose(out, expected_out))
+
+         # LHS gradient.
+         grad = stk.ops.to_dense(a.grad)
+         expected_grad = _mask(a_dense.grad, a.grad)
+         self.assertEqual(grad.dim(), 2)
+         self.assertEqual(expected_grad.size()[0], grad.size()[0])
+         self.assertEqual(expected_grad.size()[1], grad.size()[1])
+         self.assertTrue(allclose(grad, expected_grad))
+
+         # RHS gradient.
+         grad = b.grad
+         expected_grad = bcp.grad
+         self.assertEqual(grad.dim(), 2)
+         self.assertEqual(expected_grad.size()[0], grad.size()[0])
+         self.assertEqual(expected_grad.size()[1], grad.size()[1])
+         self.assertTrue(allclose(grad, expected_grad))
+
+     def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+         # Construct the operands.
+         a_shape = (k, m) if trans_a else (m, k)
+         a, acp = _dense_2x(*a_shape, dtype)
+         b_shape = (n, k) if trans_b else (k, n)
+         b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+
+         # Execute the matmul.
+         out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+         expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+
+         # Compute the gradients w.r.t. the inputs.
+         expected_out.sum().backward()
+         out.sum().backward()
+
+         # Validate the results.
+         self.assertEqual(out.dim(), 2)
+         self.assertEqual(expected_out.size()[0], out.size()[0])
+         self.assertEqual(expected_out.size()[1], out.size()[1])
+         self.assertTrue(allclose(out, expected_out))
+
+         # LHS gradient.
+         grad = a.grad
+         expected_grad = acp.grad
+         self.assertEqual(grad.dim(), 2)
+         self.assertEqual(expected_grad.size()[0], grad.size()[0])
+         self.assertEqual(expected_grad.size()[1], grad.size()[1])
+         self.assertTrue(allclose(grad, expected_grad))
+
+         # RHS gradient.
+         grad = stk.ops.to_dense(b.grad)
+         expected_grad = _mask(b_dense.grad, b.grad)
+         self.assertEqual(grad.dim(), 2)
+         self.assertEqual(expected_grad.size()[0], grad.size()[0])
+         self.assertEqual(expected_grad.size()[1], grad.size()[1])
+         self.assertTrue(allclose(grad, expected_grad))
+
+     def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+         # Construct the operands.
+         a_shape = (k, m) if trans_a else (m, k)
+         a, acp = _dense_2x(*a_shape, dtype)
+         b_shape = (n, k) if trans_b else (k, n)
+         b, bcp = _dense_2x(*b_shape, dtype)
+         _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+
+         # Execute the matmul.
+         out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+         expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+
+         # Compute the gradients w.r.t. the inputs.
+         expected_out.sum().backward()
+         stk.ops.sum(out).backward()
+
+         # Validate the results.
+         out = stk.ops.to_dense(out)
+         self.assertEqual(out.dim(), 2)
+         self.assertEqual(expected_out.size()[0], out.size()[0])
+         self.assertEqual(expected_out.size()[1], out.size()[1])
+         self.assertTrue(allclose(out, expected_out))
+
+         # LHS gradient.
+         grad = a.grad
+         expected_grad = acp.grad
+         self.assertEqual(grad.dim(), 2)
+         self.assertEqual(expected_grad.size()[0], grad.size()[0])
+         self.assertEqual(expected_grad.size()[1], grad.size()[1])
+         self.assertTrue(allclose(grad, expected_grad))
+
+         # RHS gradient.
+         grad = b.grad
+         expected_grad = bcp.grad
+         self.assertEqual(grad.dim(), 2)
+         self.assertEqual(expected_grad.size()[0], grad.size()[0])
+         self.assertEqual(expected_grad.size()[1], grad.size()[1])
+         self.assertTrue(allclose(grad, expected_grad))
+
+ if __name__ == '__main__':
+     unittest.main()
build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops.py ADDED
@@ -0,0 +1,98 @@
+ from ..backend import sputnik
+ from ..matrix import Matrix
+ import torch
+ import numpy as np
+
+
+ @torch.no_grad()
+ def row_indices(shape, data, offsets, column_indices):
+     return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+ # TODO(tgale): Replace this helper with a custom kernel. This operation
+ # is much simpler to do than how it's currently implemented.
+ @torch.no_grad()
+ def _expand_for_blocking(idxs, blocking):
+     # Duplicate for block column dimension.
+     idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+     # Update the column indices.
+     idxs[:, :, 1] *= blocking
+     idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+     # Duplicate for block row dimension.
+     idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+     idxs = idxs.repeat(1, blocking, 1, 1)
+
+     # Update the row indices.
+     idxs[:, :, :, 0] *= blocking
+     idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+     idxs = torch.reshape(idxs, [-1, 2])
+     return idxs
+
+
+ # TODO(tgale): Add input type checking.
+ @torch.no_grad()
+ def to_dense(x):
+     assert isinstance(x, Matrix)
+
+     shape = (np.prod(x.shape[:-1]), x.shape[-1])
+     row_idxs = x.row_indices.type(torch.int32)
+     col_idxs = x.column_indices.type(torch.int32)
+     indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+     indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+     out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+     out.scatter_(0, indices, x.data.flatten())
+     return out.reshape(x.size())
+
+
+ @torch.no_grad()
+ def _mask(x, blocking=1):
+     assert x.dim() == 2
+     assert x.size()[0] % blocking == 0
+     assert x.size()[1] % blocking == 0
+     block_rows = x.size()[0] // blocking
+     block_cols = x.size()[1] // blocking
+     x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+     x = torch.sum(torch.abs(x), dim=(1, 3))
+     return x != 0
+
+
+ # TODO(tgale): Add input type checking.
+ @torch.no_grad()
+ def to_sparse(x, blocking=1):
+     m = _mask(x, blocking)
+
+     # TODO(tgale): Set to appropriate type for input matrix.
+     row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+     zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+     offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+     offsets = offsets.type(torch.int32)
+
+     indices = torch.nonzero(m).type(torch.int16)
+     row_indices = indices[:, 0]
+     column_indices = indices[:, 1]
+
+     # Nonzero indices in the dense matrix.
+     nonzero_indices = torch.nonzero(m)
+     nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+     nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+     # Gather the data and construct the sparse matrix.
+     data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+     data = torch.reshape(data, [-1, blocking, blocking])
+     return Matrix(x.size(), data, row_indices, column_indices, offsets)
+
+
+ @torch.no_grad()
+ def ones_like(x):
+     return Matrix(x.size(),
+                   torch.ones_like(x.data),
+                   x.row_indices,
+                   x.column_indices, x.offsets)
+
+
+ def sum(x):
+     assert isinstance(x, Matrix)
+     return x.data.sum()
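
For reference, a minimal dense/sparse round-trip sketch mirroring the format-conversion test that follows, assuming the vendored package is importable as `megablocks.stk` (CPU is sufficient for these helpers):

import torch
from megablocks import stk  # assumed import path for the vendored copy

blocking = 8
mask = stk.random.dense_mask(16, 32, sparsity=0.5, blocking=blocking)
x = (torch.randn(16, 32) * mask).type(torch.float16)

sparse_x = stk.ops.to_sparse(x, blocking)  # dense tensor -> block-sparse Matrix
dense_x = stk.ops.to_dense(sparse_x)       # block-sparse Matrix -> dense tensor
assert torch.all(torch.eq(x, dense_x))     # round trip is exact for block-aligned zeros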
build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py ADDED
@@ -0,0 +1,62 @@
+ import unittest
+
+ from absl.testing import parameterized
+ import stk
+ import torch
+
+
+ @parameterized.parameters(
+     (8, 16, 0.0, 1),
+     (8, 16, 0.5, 1),
+     (8, 16, .95, 1),
+     (16, 8, 0.0, 1),
+     (16, 8, 0.5, 1),
+     (16, 8, .95, 1),
+     (8, 16, 0.0, 8),
+     (8, 16, 0.5, 8),
+     (8, 16, 1.0, 8),
+     (16, 8, 0.0, 8),
+     (16, 8, 0.5, 8),
+     (16, 8, 1.0, 8),
+     (128, 256, 0.5, 16),
+     (256, 128, 0.75, 32),
+     (512, 512, .875, 128))
+ class MatrixOpsTest(parameterized.TestCase):
+
+     def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+         mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+         x = (torch.randn(rows, cols) * mask).type(torch.float16)
+
+         # Convert the matrix to sparse format.
+         sparse_x = stk.ops.to_sparse(x, blocking)
+
+         # Validate the matrix.
+         sparse_x.validate()
+
+         # Validate the shape.
+         self.assertEqual(sparse_x.dim(), 2)
+         self.assertEqual(sparse_x.size()[0], rows)
+         self.assertEqual(sparse_x.size()[1], cols)
+
+         # Validate the sparsity.
+         numblocks = rows // blocking * cols // blocking
+         nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+         self.assertEqual(sparse_x.nnz, nnz)
+
+         # Convert back to dense format.
+         dense_x = stk.ops.to_dense(sparse_x)
+
+         # Validate the shape.
+         self.assertEqual(dense_x.dim(), 2)
+         self.assertEqual(dense_x.size()[0], rows)
+         self.assertEqual(dense_x.size()[1], cols)
+
+         # Validate the sparsity.
+         self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+
+         # Validate the output.
+         self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+ if __name__ == '__main__':
+     unittest.main()
build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/random/__init__.py ADDED
@@ -0,0 +1,2 @@
+ # from stk.random.random_ops import dense_mask, mask, randn
+ from .random_ops import dense_mask, mask, randn
build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/random/random_ops.py ADDED
@@ -0,0 +1,36 @@
+ import numpy as np
+ import torch
+ from ..ops import matrix_ops
+
+
+ @torch.no_grad()
+ def dense_mask(rows, cols, sparsity, blocking=1):
+     assert sparsity >= 0.0 and sparsity <= 1.0
+     assert rows % blocking == 0 and cols % blocking == 0
+
+     block_rows, block_cols = (rows // blocking, cols // blocking)
+     nnz = round(block_rows * block_cols * (1 - sparsity))
+
+     out = np.ones(block_rows * block_cols)
+     mask = np.random.choice(out.size, out.size - nnz, replace=False)
+     out[mask] = 0.0
+
+     out = np.tile(
+         np.reshape(out, [block_rows, 1, block_cols, 1]),
+         (1, blocking, 1, blocking))
+     out = np.reshape(out, [rows, cols])
+     return torch.from_numpy(out.astype(np.float32))
+
+
+ @torch.no_grad()
+ def mask(m, n, sparsity, blocking=1):
+     out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+     return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+ @torch.no_grad()
+ def randn(shape, sparsity, blocking=1):
+     shape_2d = (np.prod(shape[:-1]), shape[-1])
+     out = mask(*shape_2d, sparsity, blocking)
+     out.data.copy_(torch.randn(*out.data.shape))
+     return out.view(*shape)
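
For reference, a minimal sketch of the three random helpers above, assuming the vendored package is importable as `megablocks.stk` (all three run on CPU):

from megablocks import stk  # assumed import path for the vendored copy

# Dense 0/1 mask with roughly 75% of the 16x16 blocks zeroed out.
dense = stk.random.dense_mask(64, 64, sparsity=0.75, blocking=16)

# The same kind of mask, materialized directly as a block-sparse Matrix.
sparse = stk.random.mask(64, 64, sparsity=0.75, blocking=16)

# Block-sparse Matrix whose retained blocks hold normal random values.
noise = stk.random.randn((64, 64), sparsity=0.75, blocking=16)

print(int(dense.count_nonzero()), sparse.nnz, noise.nnz)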
build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/random/random_ops_test.py ADDED
@@ -0,0 +1,73 @@
+ import unittest
+
+ from absl.testing import parameterized
+ from . import random
+ import torch
+
+
+ @parameterized.parameters(
+     (8, 16, 0.0, 1),
+     (8, 16, 0.5, 1),
+     (8, 16, .95, 1),
+     (16, 8, 0.0, 1),
+     (16, 8, 0.5, 1),
+     (16, 8, .95, 1),
+     (8, 16, 0.0, 8),
+     (8, 16, 0.5, 8),
+     (8, 16, 1.0, 8),
+     (16, 8, 0.0, 8),
+     (16, 8, 0.5, 8),
+     (16, 8, 1.0, 8),
+     (128, 256, 0.5, 16),
+     (256, 128, 0.75, 32),
+     (512, 512, .875, 128))
+ class RandomOpsTest(parameterized.TestCase):
+
+     def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+         mask = random.dense_mask(
+             rows, cols, sparsity, blocking)
+
+         # Validate the shape.
+         self.assertEqual(mask.dim(), 2)
+         self.assertEqual(mask.size()[0], rows)
+         self.assertEqual(mask.size()[1], cols)
+
+         # Validate the sparsity.
+         numblocks = rows // blocking * cols // blocking
+         nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+         self.assertEqual(
+             torch.count_nonzero(mask).item(),
+             nnz)
+
+         # Check values are zero or one.
+         self.assertTrue(
+             torch.all(torch.logical_or(
+                 torch.eq(mask, 0),
+                 torch.eq(mask, 1))))
+
+     def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+         mask = random.mask(
+             rows, cols, sparsity, blocking)
+
+         # Validate the matrix.
+         mask.validate()
+
+         # Validate the shape.
+         self.assertEqual(mask.dim(), 2)
+         self.assertEqual(mask.size()[0], rows)
+         self.assertEqual(mask.size()[1], cols)
+
+         # Validate the sparsity.
+         numblocks = rows // blocking * cols // blocking
+         nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+         self.assertEqual(mask.nnz, nnz)
+
+         # Check values are zero or one.
+         self.assertTrue(
+             torch.all(torch.logical_or(
+                 torch.eq(mask.data, 0),
+                 torch.eq(mask.data, 1))))
+
+
+ if __name__ == '__main__':
+     unittest.main()
build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/activation_fn.py CHANGED
@@ -4,7 +4,7 @@
  from typing import Any, Callable, Union

  import torch
- from stk import Matrix
+ from ..stk import Matrix


  def act_fn(
build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/dmoe.py CHANGED
@@ -2,15 +2,22 @@
  # SPDX-License-Identifier: Apache-2.0

  import numpy as np
- import stk.ops
  import torch
- from stk import Matrix
+
+ # try:
+ #     import stk.ops
+ # except ImportError:
+ #     import warnings
+ #     warnings.warn(
+ #         'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+ #     )

  # import megablocks.ops as ops
  # # from megablocks.ops import ops
  # from megablocks.layers import common, dmlp_registry, moe, mpu
  # from megablocks.layers.arguments import Arguments

+ from .. import stk
  from .. import ops
  from . import common, dmlp_registry, moe, mpu
  from .arguments import Arguments
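
For reference, the net effect of this change is that callers resolve stk from inside the package rather than from the external `stanford-stk` distribution. A minimal sketch, with the top-level import path assumed:

from megablocks import stk          # vendored copy, replaces `import stk.ops`
from megablocks.stk import Matrix   # replaces `from stk import Matrix`

topo = stk.random.mask(256, 256, sparsity=0.5, blocking=128)
assert isinstance(topo, Matrix)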