Commit 63599de · drbh committed · 1 parent: ef966b6

fix: fully vendor stk and fix imports
flake.lock CHANGED
@@ -1,21 +1,5 @@
 {
   "nodes": {
-    "composer": {
-      "flake": false,
-      "locked": {
-        "lastModified": 1749592532,
-        "narHash": "sha256-VKfSWtf+Z20nP1cHiBNwFzYCuMGL0xelvd6HMyDnIhc=",
-        "owner": "mosaicml",
-        "repo": "composer",
-        "rev": "0eec49da42e7f617329f035853800211f0a54ca3",
-        "type": "github"
-      },
-      "original": {
-        "owner": "mosaicml",
-        "repo": "composer",
-        "type": "github"
-      }
-    },
     "flake-compat": {
       "locked": {
         "lastModified": 1747046372,
@@ -89,10 +73,11 @@
       "nixpkgs": "nixpkgs"
     },
     "locked": {
-      "lastModified": 1748598786,
+      "lastModified": 1750234878,
+      "narHash": "sha256-q9DRC9zdpzUf88qqg1qbhP1qgJbE2cMtn8oUmosuyT8=",
       "owner": "huggingface",
       "repo": "hf-nix",
-      "rev": "6ca679441494139fde1f2355691ddb5dc8170269",
+      "rev": "c7132f90763d756da3e77da62e01be0a4546dc57",
       "type": "github"
     },
     "original": {
@@ -113,16 +98,15 @@
       ]
     },
     "locked": {
-      "lastModified": 1749765513,
-      "narHash": "sha256-/Tyhxb1v4ks5G7eewZiK9BHSn2606wH/KdCMmdeAw2c=",
+      "lastModified": 1751014803,
+      "narHash": "sha256-9Xfq2k3uPfB602NwQF+zAY2GQZiKUN1G7Q6XiDCUR8Y=",
       "owner": "huggingface",
       "repo": "kernel-builder",
-      "rev": "a9c26e450a81296525128e3860c21ab028aa8d07",
+      "rev": "bbc4e712ff2046e217818e97de2201e2b996756e",
       "type": "github"
     },
     "original": {
       "owner": "huggingface",
-      "ref": "support-custom-python-libraries-in-dev-shell-nixland",
       "repo": "kernel-builder",
       "type": "github"
     }
@@ -145,25 +129,7 @@
     },
     "root": {
       "inputs": {
-        "composer": "composer",
-        "kernel-builder": "kernel-builder",
-        "stk": "stk"
-      }
-    },
-    "stk": {
-      "flake": false,
-      "locked": {
-        "lastModified": 1724272107,
-        "narHash": "sha256-f6eydO4u6jasvepP25a6jacSvoUNyfKW51FxahMtz1Q=",
-        "owner": "stanford-futuredata",
-        "repo": "stk",
-        "rev": "736313768ef697ce13a0594a41b2512a0fbc9884",
-        "type": "github"
-      },
-      "original": {
-        "owner": "stanford-futuredata",
-        "repo": "stk",
-        "type": "github"
+        "kernel-builder": "kernel-builder"
       }
     },
     "systems": {
flake.nix CHANGED
@@ -1,52 +1,24 @@
 {
   description = "Flake for megablocks_moe kernel";
-
-  inputs = {
-    kernel-builder.url = "github:huggingface/kernel-builder/support-custom-python-libraries-in-dev-shell-nixland";
-    # Add libraries as inputs
-    composer = {
-      url = "github:mosaicml/composer";
-      flake = false;
-    };
-    stk = {
-      url = "github:stanford-futuredata/stk";
-      flake = false;
-    };
 
-    # TODO: update to build with the correct torch version
-    # grouped_gemm = {
-    #   url = "github:tgale96/grouped_gemm";
-    #   flake = false;
-    # };
+  inputs = {
+    kernel-builder.url = "github:huggingface/kernel-builder";
   };
-
-  outputs = {
-    self,
-    kernel-builder,
-    composer,
-    stk,
-    # grouped_gemm,
-  }:
+
+  outputs =
+    {
+      self,
+      kernel-builder,
+    }:
     kernel-builder.lib.genFlakeOutputs {
       path = ./.;
       rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
-
-      # Map custom packages to their sources
-      customPythonPackages = {
-        composer = composer;
-        stk = stk;
-        # grouped_gemm = grouped_gemm;
-      };
-
-      pythonTestDeps = [
-        "tqdm"
-        "py-cpuinfo"
-        "importlib-metadata"
-        "torchmetrics"
-        "composer"
-        "stk"
-        # "grouped_gemm"
-        # "yahp" # may be needed for some testing plugin
+
+      pythonCheckInputs = pkgs: with pkgs; [
+        tqdm
+        py-cpuinfo
+        importlib-metadata
+        torchmetrics
       ];
     };
-}
+}
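With `composer` and `stk` gone from the flake inputs, the vendored copy under `torch-ext/megablocks/stk/` is the only stk the kernel sees. A minimal sketch (not part of the commit) of the resulting import surface, assuming the built wheel exposes the `megablocks` package from `torch-ext/`:

```python
# Before this commit, _layers code depended on the external stanford-stk wheel:
#   import stk
#   from stk import Matrix
# After vendoring, the same symbols resolve from inside the package itself:
from megablocks.stk import Matrix  # vendored copy under torch-ext/megablocks/stk/

print(Matrix)  # the block-compressed sparse row (BCSR) container
```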
torch-ext/megablocks/_layers/activation_fn.py CHANGED
@@ -4,7 +4,7 @@
 from typing import Any, Callable, Union
 
 import torch
-from stk import Matrix
+from ..stk import Matrix
 
 
 def act_fn(
torch-ext/megablocks/_layers/dmoe.py CHANGED
@@ -2,15 +2,22 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import numpy as np
-import stk.ops
 import torch
-from stk import Matrix
+
+# try:
+#     import stk.ops
+# except ImportError:
+#     import warnings
+#     warnings.warn(
+#         'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+#     )
 
 # import megablocks.ops as ops
 # # from megablocks.ops import ops
 # from megablocks.layers import common, dmlp_registry, moe, mpu
 # from megablocks.layers.arguments import Arguments
 
+from .. import stk
 from .. import ops
 from . import common, dmlp_registry, moe, mpu
 from .arguments import Arguments
torch-ext/megablocks/_layers/gelu.py CHANGED
@@ -1,7 +1,16 @@
 # Copyright 2024 Databricks
 # SPDX-License-Identifier: Apache-2.0
 
-import stk
+# try:
+#     import stk
+# except ImportError:
+#     import warnings
+#     warnings.warn(
+#         'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+#     )
+
+from .. import stk
+
 import torch
 import torch.nn.functional as F
torch-ext/megablocks/_layers/glu.py CHANGED
@@ -1,7 +1,17 @@
 # Copyright 2024 Databricks
 # SPDX-License-Identifier: Apache-2.0
 
-import stk.ops
+# import stk.ops
+# try:
+#     import stk.ops
+# except ImportError:
+#     import warnings
+#     warnings.warn(
+#         'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+#     )
+
+from .. import stk
+
 import torch
 
 # from megablocks import grouped_gemm_util as gg
torch-ext/megablocks/_layers/mlp.py CHANGED
@@ -3,9 +3,18 @@
 
 from typing import Any
 
-import stk
-import stk.backend.triton_kernels
-import stk.ops
+# try:
+#     import stk
+#     import stk.backend.triton_kernels
+#     import stk.ops
+# except ImportError:
+#     import warnings
+#     warnings.warn(
+#         'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+#     )
+
+from .. import stk
+
 import torch
 from packaging import version
torch-ext/megablocks/ops/matmul_benchmark.py CHANGED
@@ -3,7 +3,19 @@
 
 import unittest
 
-import stk
+
+# import stk
+
+# try:
+#     import stk
+# except ImportError:
+#     import warnings
+#     warnings.warn(
+#         'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+#     )
+
+from .. import stk
+
 import torch
 from absl.testing import parameterized
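The `try`/`except` guard kept in comments throughout these files documents the old optional-dependency behavior. For reference, this is roughly what that guard does when enabled (a sketch reconstructed from the comments, not code from the commit):

```python
import warnings

try:
    import stk  # the external stanford-stk package this commit vendors
except ImportError:
    stk = None
    warnings.warn(
        "Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.",
    )
```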
torch-ext/megablocks/stk/__init__.py ADDED
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
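The vendored `__init__.py` mirrors upstream stk's public surface, so callers only swap `import stk` for a relative import. A hedged sketch of the resulting API (the `megablocks` import path is an assumption from the `torch-ext/` layout, and `stk.random` is referenced by the tests even though it is not shown in this diff):

```python
from megablocks import stk

print(stk.Matrix)   # BCSR container from stk/matrix.py
print(stk.ops.sdd)  # sparse = dense @ dense, from stk/ops/linear_ops.py
print(stk.random)   # topology sampling helpers the test files expect
```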
torch-ext/megablocks/stk/backend/__init__.py ADDED
File without changes
torch-ext/megablocks/stk/backend/autocast.py ADDED
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+    return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+    if isinstance(x, torch.Tensor) and _is_eligible(x):
+        return x.to(dtype)
+    elif isinstance(x, dict):  # upstream checked `isinstance(x, map)`, which never matches a mapping
+        return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+    elif isinstance(x, list) or isinstance(x, tuple):
+        return type(x)(map(lambda y: _cast(y, dtype), x))
+    return x
+
+
+def custom_fwd(fwd):
+    """Wrap a custom autograd function that always uses autocast dtype."""
+
+    @functools.wraps(fwd)
+    def decorate_fwd(*args, **kwargs):
+        if torch.is_autocast_enabled():
+            with torch.autocast(device_type="cuda", enabled=False):
+                dtype = torch.get_autocast_gpu_dtype()
+                return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+        return fwd(*args, **kwargs)
+    return decorate_fwd
+
+
+def custom_bwd(bwd):
+    @functools.wraps(bwd)
+    def decorate_bwd(*args, **kwargs):
+        with torch.autocast(device_type="cuda", enabled=False):
+            return bwd(*args, **kwargs)
+    return decorate_bwd
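A toy `autograd.Function` showing how `custom_fwd`/`custom_bwd` are meant to be used; illustrative only — the real consumers are `DSD`, `DDS`, and `SDD` in `sputnik.py`. Requires a CUDA device, since `_is_eligible` only casts CUDA tensors:

```python
import torch

from megablocks.stk.backend.autocast import custom_bwd, custom_fwd


class Scale(torch.autograd.Function):

    @staticmethod
    @custom_fwd  # casts eligible tensor args to the active autocast dtype
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x * alpha

    @staticmethod
    @custom_bwd  # runs the backward pass with autocast disabled
    def backward(ctx, dy):
        return dy * ctx.alpha, None


x = torch.randn(8, 8, device="cuda")  # float32 input
with torch.autocast(device_type="cuda", dtype=torch.float16):
    y = Scale.apply(x, 2.0)
print(y.dtype)  # torch.float16: the decorator cast x before forward ran
```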
torch-ext/megablocks/stk/backend/sputnik.py ADDED
@@ -0,0 +1,316 @@
1
+ import torch
2
+
3
+ from ..backend import triton_kernels as backend
4
+ from ..backend.autocast import custom_bwd, custom_fwd
5
+
6
+
7
+ def _standardize_shape(x, transpose):
8
+ if transpose:
9
+ return torch.Size((x[1], x[0]))
10
+ return x
11
+
12
+
13
+ def _sparse_transpose(x):
14
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
15
+
16
+
17
+ def _transpose_helper(x, transpose):
18
+ if isinstance(x, torch.Tensor):
19
+ return x.t() if transpose else x
20
+ if transpose:
21
+ x = _sparse_transpose(x)
22
+ return x + (transpose,)
23
+
24
+
25
+ def _wrap(x):
26
+ if isinstance(x, torch.Tensor):
27
+ return (x,)
28
+ return x
29
+
30
+
31
+ def _is_transposed(x):
32
+ return (not x.is_contiguous() and
33
+ x.stride()[0] == 1 and
34
+ x.stride()[1] == x.size()[0])
35
+
36
+
37
+ def _call_helper(op, out, a, b, trans_a, trans_b):
38
+ args = (_wrap(_transpose_helper(a, trans_a)) +
39
+ _wrap(_transpose_helper(b, trans_b)))
40
+ if isinstance(out, tuple):
41
+ args = args + out
42
+ return op(*args)
43
+
44
+
45
+ def _preprocess_inputs(lhs, rhs, dy):
46
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
47
+ lhs = lhs.t()
48
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
49
+ rhs = rhs.t()
50
+ if (isinstance(dy, torch.Tensor) and
51
+ not dy.is_contiguous() and
52
+ not _is_transposed(dy)):
53
+ dy = dy.contiguous()
54
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
55
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
56
+ return lhs, rhs, dy
57
+
58
+
59
+ def _postprocess_outputs(x, transpose, grad):
60
+ if isinstance(x, torch.Tensor) and transpose:
61
+ return grad.t()
62
+ return grad
63
+
64
+
65
+ def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
66
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
67
+
68
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
69
+ trans_a = trans_lhs and trans_rhs
70
+ trans_b = trans_lhs or not trans_rhs
71
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
72
+ return _postprocess_outputs(lhs, trans_lhs, out)
73
+
74
+
75
+ def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
76
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
77
+
78
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
79
+ trans_a = not trans_lhs or trans_rhs
80
+ trans_b = trans_lhs and trans_rhs
81
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
82
+ return _postprocess_outputs(rhs, trans_rhs, out)
83
+
84
+
85
+ class DSD(torch.autograd.Function):
86
+
87
+ @staticmethod
88
+ @custom_fwd
89
+ def forward(ctx,
90
+ shape,
91
+ data,
92
+ offsets,
93
+ row_indices,
94
+ column_indices,
95
+ offsets_t,
96
+ column_indices_t,
97
+ block_offsets_t,
98
+ transpose_a,
99
+ rhs):
100
+ ctx.save_for_backward(data,
101
+ offsets,
102
+ row_indices,
103
+ column_indices,
104
+ offsets_t,
105
+ column_indices_t,
106
+ block_offsets_t,
107
+ rhs)
108
+ ctx.shape = _standardize_shape(shape, transpose_a)
109
+ ctx.transpose_a = transpose_a
110
+
111
+ out = torch.empty(
112
+ (shape[0], rhs.size()[1]),
113
+ dtype=rhs.dtype,
114
+ device=rhs.device)
115
+
116
+ backend.dsd(shape,
117
+ data,
118
+ offsets,
119
+ row_indices,
120
+ column_indices,
121
+ offsets_t,
122
+ column_indices_t,
123
+ block_offsets_t,
124
+ transpose_a,
125
+ rhs,
126
+ out)
127
+ return out
128
+
129
+ @staticmethod
130
+ @custom_bwd
131
+ def backward(ctx, dy):
132
+ saved_tensors = ctx.saved_tensors
133
+ lhs = (ctx.shape,) + saved_tensors[:-1]
134
+ rhs = saved_tensors[-1]
135
+ trans_a = ctx.transpose_a
136
+ trans_b = _is_transposed(rhs)
137
+
138
+ ddata = None
139
+ if ctx.needs_input_grad[1]:
140
+ ddata = _lhs_gradient(sdd,
141
+ lhs,
142
+ rhs,
143
+ dy,
144
+ trans_a,
145
+ trans_b)
146
+ drhs = None
147
+ if ctx.needs_input_grad[-1]:
148
+ op = dds if trans_b else dsd
149
+ drhs = _rhs_gradient(op,
150
+ lhs,
151
+ rhs,
152
+ dy,
153
+ trans_a,
154
+ trans_b)
155
+ return None, ddata, None, None, None, None, None, None, None, drhs
156
+
157
+
158
+ dsd = DSD.apply
159
+
160
+
161
+ class DDS(torch.autograd.Function):
162
+
163
+ @staticmethod
164
+ @custom_fwd
165
+ def forward(ctx,
166
+ lhs,
167
+ shape,
168
+ data,
169
+ offsets,
170
+ row_indices,
171
+ column_indices,
172
+ offsets_t,
173
+ column_indices_t,
174
+ block_offsets_t,
175
+ transpose_b):
176
+ ctx.save_for_backward(lhs,
177
+ data,
178
+ offsets,
179
+ row_indices,
180
+ column_indices,
181
+ offsets_t,
182
+ column_indices_t,
183
+ block_offsets_t)
184
+ ctx.shape = _standardize_shape(shape, transpose_b)
185
+ ctx.transpose_b = transpose_b
186
+ out = torch.empty((lhs.size()[0], shape[1]),
187
+ dtype=lhs.dtype,
188
+ device=lhs.device)
189
+ backend.dds(lhs,
190
+ shape,
191
+ data,
192
+ offsets,
193
+ row_indices,
194
+ column_indices,
195
+ offsets_t,
196
+ column_indices_t,
197
+ block_offsets_t,
198
+ transpose_b,
199
+ out)
200
+ return out
201
+
202
+ @staticmethod
203
+ @custom_bwd
204
+ def backward(ctx, dy):
205
+ saved_tensors = ctx.saved_tensors
206
+ lhs = saved_tensors[0]
207
+ rhs = (ctx.shape,) + saved_tensors[1:]
208
+ trans_a = _is_transposed(lhs)
209
+ trans_b = ctx.transpose_b
210
+
211
+ dlhs = None
212
+ if ctx.needs_input_grad[0]:
213
+ op = dsd if trans_a else dds
214
+ dlhs = _lhs_gradient(op,
215
+ lhs,
216
+ rhs,
217
+ dy,
218
+ trans_a,
219
+ trans_b)
220
+ ddata = None
221
+ if ctx.needs_input_grad[2]:
222
+ ddata = _rhs_gradient(sdd,
223
+ lhs,
224
+ rhs,
225
+ dy,
226
+ trans_a,
227
+ trans_b)
228
+ return dlhs, None, ddata, None, None, None, None, None, None, None
229
+
230
+
231
+ dds = DDS.apply
232
+
233
+
234
+ class SDD(torch.autograd.Function):
235
+
236
+ @staticmethod
237
+ @custom_fwd
238
+ def forward(ctx,
239
+ lhs,
240
+ rhs,
241
+ shape,
242
+ data,
243
+ offsets,
244
+ row_indices,
245
+ column_indices,
246
+ offsets_t,
247
+ column_indices_t,
248
+ block_offsets_t):
249
+ ctx.save_for_backward(
250
+ lhs,
251
+ rhs,
252
+ offsets,
253
+ row_indices,
254
+ column_indices,
255
+ offsets_t,
256
+ column_indices_t,
257
+ block_offsets_t)
258
+ ctx.shape = shape
259
+ out = torch.empty(
260
+ data.shape,
261
+ dtype=lhs.dtype,
262
+ device=lhs.device)
263
+ backend.sdd(lhs,
264
+ rhs,
265
+ shape,
266
+ out,
267
+ offsets,
268
+ row_indices,
269
+ column_indices)
270
+ return out
271
+
272
+ @staticmethod
273
+ @custom_bwd
274
+ def backward(ctx, dy):
275
+ saved_tensors = ctx.saved_tensors
276
+ lhs, rhs = saved_tensors[:2]
277
+ dy = (ctx.shape, dy) + saved_tensors[2:]
278
+ trans_a = _is_transposed(lhs)
279
+ trans_b = _is_transposed(rhs)
280
+
281
+ dlhs = None
282
+ if ctx.needs_input_grad[0]:
283
+ op = dds if trans_a else dsd
284
+ dlhs = _lhs_gradient(op,
285
+ lhs,
286
+ rhs,
287
+ dy,
288
+ trans_a,
289
+ trans_b)
290
+ drhs = None
291
+ if ctx.needs_input_grad[1]:
292
+ op = dsd if trans_b else dds
293
+ drhs = _rhs_gradient(op,
294
+ lhs,
295
+ rhs,
296
+ dy,
297
+ trans_a,
298
+ trans_b)
299
+ return dlhs, drhs, None, None, None, None, None, None, None, None
300
+
301
+
302
+ sdd = SDD.apply
303
+
304
+ class RowIndices(torch.autograd.Function):
305
+
306
+ @staticmethod
307
+ def forward(ctx, shape, data, offsets, column_indices):
308
+ out = torch.empty(
309
+ column_indices.shape,
310
+ dtype=column_indices.dtype,
311
+ device=column_indices.device)
312
+ backend.row_indices(shape, data, offsets, column_indices, out)
313
+ return out
314
+
315
+
316
+ row_indices = RowIndices.apply
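A detail worth noting in the file above: `_is_transposed` detects a lazily transposed dense operand purely from its strides, so `DSD`/`DDS`/`SDD` can pick the right kernel layout without copying data. A small self-contained illustration:

```python
import torch

a = torch.randn(4, 8)
t = a.t()  # a transpose in PyTorch is a view; only the strides change

print(a.stride(), t.stride())  # (8, 1) vs. (1, 8)

# The exact test used by _is_transposed above:
is_transposed = (not t.is_contiguous() and
                 t.stride()[0] == 1 and
                 t.stride()[1] == t.size()[0])
print(is_transposed)  # True
```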
torch-ext/megablocks/stk/backend/triton_kernels.py ADDED
@@ -0,0 +1,393 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+ from dataclasses import dataclass
5
+
6
+ @dataclass
7
+ class TritonConfig:
8
+ BLOCK_M: int = 128
9
+ BLOCK_N: int = 128
10
+ BLOCK_K: int = 32
11
+ BLOCK_SIZE: int = 128
12
+ NUM_STAGES: int = 4
13
+ NUM_WARPS: int = 4
14
+
15
+ def _validate_matmul_dims(M: int, K: int, N: int):
16
+ error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}"
17
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
18
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
19
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
20
+
21
+ @triton.autotune(
22
+ configs=[
23
+ # basic configs for compute-bound matmuls
24
+ triton.Config({
25
+ 'BLOCK_M': TritonConfig.BLOCK_M,
26
+ 'BLOCK_N': TritonConfig.BLOCK_N,
27
+ 'BLOCK_K': TritonConfig.BLOCK_K,
28
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
29
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
30
+ ],
31
+ key=['M', 'N', 'K'],
32
+ )
33
+ @triton.jit
34
+ def _sdd_kernel(A, B, C, M, N, K,
35
+ stride_am, stride_ak,
36
+ stride_bk, stride_bn,
37
+ stride_cm, stride_cn,
38
+ row_indices, column_indices,
39
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
40
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
41
+ ):
42
+ # matrix multiplication
43
+ pid = tl.program_id(0)
44
+ pid_m = tl.load(row_indices + pid)
45
+ pid_n = tl.load(column_indices + pid)
46
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
47
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
48
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
49
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
50
+ rk = tl.arange(0, BLOCK_K)
51
+ # pointers
52
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
53
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
54
+ # do matrix multiplication
55
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
56
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
57
+ a = tl.load(A)
58
+ b = tl.load(B)
59
+ acc += tl.dot(a, b)
60
+ A += BLOCK_K * stride_ak
61
+ B += BLOCK_K * stride_bk
62
+ #Store to sparse matrix
63
+ acc = acc.to(C.dtype.element_ty)
64
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
65
+ cm = tl.arange(0, BLOCK_M)
66
+ cn = tl.arange(0, BLOCK_N)
67
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
68
+ tl.store(C, acc, mask=True)
69
+
70
+ @triton.autotune(
71
+ configs=[
72
+ # basic configs for compute-bound matmuls
73
+ triton.Config({
74
+ 'BLOCK_M': TritonConfig.BLOCK_M,
75
+ 'BLOCK_N': TritonConfig.BLOCK_N,
76
+ 'BLOCK_K': TritonConfig.BLOCK_K,
77
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
78
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
79
+ ],
80
+ key=['M', 'N', 'K'],
81
+ )
82
+ @triton.jit
83
+ def _dsd_kernel(A, B, C, M, N, K,
84
+ stride_am, stride_ak,
85
+ stride_bk, stride_bn,
86
+ stride_cm, stride_cn,
87
+ row_indices, column_indices, offsets,
88
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
89
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
90
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
91
+ ):
92
+
93
+ # matrix multiplication
94
+ pid_m = tl.program_id(0)
95
+ pid_n = tl.program_id(1)
96
+
97
+ num_pid_m = tl.num_programs(0)
98
+ num_pid_n = tl.num_programs(1)
99
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
100
+
101
+ start_inx = tl.load(offsets + pid_m)
102
+ end_inx = tl.load(offsets + pid_m + 1)
103
+
104
+ # pointers to sparse matrix
105
+ rm = tl.arange(0, BLOCK_M)
106
+ rak = tl.arange(0, BLOCK_K)
107
+
108
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
109
+
110
+ # pointers to dense matrix
111
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
112
+ rbk = tl.arange(0, BLOCK_K)
113
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
114
+
115
+ # do matrix multiplication
116
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
117
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
118
+
119
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
120
+ ak_sub_incr = BLOCK_K * stride_ak
121
+ bk_sub_incr = BLOCK_K * stride_bk
122
+ bk_block_incr = BLOCK_SIZE * stride_bk
123
+
124
+ for k in range(nsub_blocks * (end_inx - start_inx)):
125
+ sub_block_inx = k % nsub_blocks
126
+ block_inx = k // nsub_blocks
127
+
128
+ if trans_A:
129
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
130
+ else:
131
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
132
+
133
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
134
+
135
+ a = tl.load(ptr_A)
136
+ b = tl.load(ptr_B)
137
+ acc += tl.dot(a, b)
138
+
139
+ acc = acc.to(C.dtype.element_ty)
140
+
141
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
142
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
143
+
144
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
145
+ tl.store(C, acc, mask=True)
146
+
147
+ @triton.autotune(
148
+ configs=[
149
+ # basic configs for compute-bound matmuls
150
+ triton.Config({
151
+ 'BLOCK_M': TritonConfig.BLOCK_M,
152
+ 'BLOCK_N': TritonConfig.BLOCK_N,
153
+ 'BLOCK_K': TritonConfig.BLOCK_K,
154
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
155
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
156
+ ],
157
+ key=['M', 'N', 'K'],
158
+ )
159
+ @triton.jit
160
+ def _dds_kernel(A, B, C, M, N, K,
161
+ stride_am, stride_ak,
162
+ stride_bk, stride_bn,
163
+ stride_cm, stride_cn,
164
+ row_indices, column_indices, offsets,
165
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
166
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
167
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
168
+ ):
169
+
170
+ # matrix multiplication
171
+ pid_m = tl.program_id(0)
172
+ pid_n = tl.program_id(1)
173
+
174
+ num_pid_m = tl.num_programs(0)
175
+ num_pid_n = tl.num_programs(1)
176
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
177
+
178
+ start_inx = tl.load(offsets + pid_n)
179
+ end_inx = tl.load(offsets + pid_n + 1)
180
+
181
+ # pointers to dense matrix
182
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
183
+ rak = tl.arange(0, BLOCK_K)
184
+
185
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
186
+
187
+ # pointers to sparse matrix
188
+ rn = tl.arange(0, BLOCK_N)
189
+ rbk = tl.arange(0, BLOCK_K)
190
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
191
+
192
+ # do matrix multiplication
193
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
194
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
195
+
196
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
197
+
198
+ ak_sub_incr = BLOCK_K * stride_ak
199
+ ak_block_incr = BLOCK_SIZE * stride_ak
200
+ bk_sub_incr = BLOCK_K * stride_bk
201
+
202
+ for k in range(nsub_blocks * (end_inx - start_inx)):
203
+ sub_block_inx = k % nsub_blocks
204
+ block_inx = k // nsub_blocks
205
+
206
+ if trans_B:
207
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
208
+ else:
209
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
210
+
211
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
212
+ a = tl.load(ptr_A)
213
+ b = tl.load(ptr_B)
214
+ acc += tl.dot(a, b)
215
+
216
+ acc = acc.to(C.dtype.element_ty)
217
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
218
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
219
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
220
+ tl.store(C, acc, mask=True)
221
+
222
+ def dsd(shape,
223
+ data,
224
+ offsets,
225
+ row_indices,
226
+ column_indices,
227
+ offsets_t,
228
+ column_indices_t,
229
+ block_offsets_t,
230
+ transpose_a,
231
+ rhs,
232
+ out
233
+ ):
234
+
235
+ device = rhs.device
236
+ trans_A = transpose_a
237
+ trans_B = False
238
+
239
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
240
+ trans_B = True
241
+
242
+ # checks constraints
243
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
244
+ M, K = shape
245
+ _, N = rhs.shape
246
+
247
+ _validate_matmul_dims(M, K, N)
248
+
249
+ # accumulator types
250
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
251
+
252
+ stride_am, stride_ak = data.stride(1), data.stride(2)
253
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
254
+ a_column_indices = column_indices
255
+ a_offsets = offsets
256
+
257
+ # launch kernel
258
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
259
+
260
+ if trans_A:
261
+ stride_am, stride_ak = data.stride(2), data.stride(1)
262
+ a_column_indices, a_offsets = column_indices_t, offsets_t
263
+
264
+ if trans_B:
265
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
266
+
267
+ _dsd_kernel[grid](
268
+ data.data, rhs, out, M, N, K,
269
+ stride_am, stride_ak,
270
+ stride_bk, stride_bn,
271
+ out.stride(0), out.stride(1),
272
+ row_indices, a_column_indices, a_offsets,
273
+ block_offsets_t, trans_A, trans_B,
274
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
275
+ )
276
+ # return out
277
+
278
+ def dds(lhs,
279
+ shape,
280
+ data,
281
+ offsets,
282
+ row_indices,
283
+ column_indices,
284
+ offsets_t,
285
+ column_indices_t,
286
+ block_offsets_t,
287
+ transpose_b,
288
+ out
289
+ ):
290
+
291
+ device = lhs.device
292
+ trans_B = transpose_b
293
+ trans_A = False
294
+
295
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
296
+ trans_A = True
297
+
298
+ # checks constraints
299
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
300
+ M, K = lhs.shape
301
+ _, N = shape
302
+
303
+ _validate_matmul_dims(M, K, N)
304
+
305
+ # accumulator types
306
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
307
+
308
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
309
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
310
+ b_column_indices = column_indices_t
311
+ b_offsets = offsets_t
312
+
313
+ # launch kernel
314
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
315
+
316
+ if trans_A:
317
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
318
+ if trans_B:
319
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
320
+ b_column_indices, b_offsets = column_indices, offsets
321
+
322
+ _dds_kernel[grid](
323
+ lhs, data, out, M, N, K,
324
+ stride_am, stride_ak,
325
+ stride_bk, stride_bn,
326
+ out.stride(0), out.stride(1),
327
+ row_indices, b_column_indices, b_offsets,
328
+ block_offsets_t, trans_A, trans_B,
329
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
330
+ )
331
+
332
+ def sdd(lhs,
333
+ rhs,
334
+ shape,
335
+ out,
336
+ offsets,
337
+ row_indices,
338
+ column_indices
339
+ ):
340
+
341
+ device = out.device
342
+ trans_A = False
343
+ trans_B = False
344
+
345
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
346
+ trans_A = True
347
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
348
+ trans_B = True
349
+
350
+ # checks constraints
351
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
352
+ M, K = lhs.shape
353
+ _, N = rhs.shape
354
+
355
+ _validate_matmul_dims(M, K, N)
356
+
357
+ # accumulator types
358
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
359
+
360
+ # launch kernel
361
+ nnz_blocks = len(row_indices)
362
+ grid = lambda META: (nnz_blocks,)
363
+
364
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
365
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
366
+
367
+ if trans_A:
368
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
369
+ if trans_B:
370
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
371
+
372
+ _sdd_kernel[grid](
373
+ lhs, rhs, out, M, N, K,
374
+ stride_am, stride_ak,
375
+ stride_bk, stride_bn,
376
+ out.stride(1), out.stride(2),
377
+ row_indices, column_indices,
378
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
379
+ )
380
+
381
+ @triton.jit
382
+ def _row_indices_kernel(offsets, out):
383
+ pid = tl.program_id(0)
384
+ row_offset = tl.load(offsets + pid)
385
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
386
+ for nnz_block in range(nnz_blocks):
387
+ tl.store(out + row_offset + nnz_block, pid)
388
+
389
+ def row_indices(
390
+ shape, data, offsets, column_indices, out
391
+ ):
392
+ block_rows = len(offsets) - 1
393
+ _row_indices_kernel[(block_rows, )](offsets, out)
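The single autotune config above pins `BLOCK_M = BLOCK_N = BLOCK_SIZE = 128` and `BLOCK_K = 32`, and `_validate_matmul_dims` rejects anything that does not tile evenly, which is why the bundled tests use multiples of 128. A quick sketch of the constraint:

```python
# Shapes accepted by the kernels above must divide evenly by the tile sizes:
M, K, N = 256, 512, 384
assert M % 128 == 0 and K % 32 == 0 and N % 128 == 0  # fine

# M = 200 would trip the assertion in _validate_matmul_dims:
# "incompatible dimensions: tensor has dim with length: 200,
#  which must be divisible by 128"
```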
torch-ext/megablocks/stk/matrix.py ADDED
@@ -0,0 +1,329 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ # 1. Add heavyweight (data) validation helper.
5
+ # 2. Add construction helpers
6
+ # 3. Make indentation consistent
7
+ # 4. Replace asserts with descriptive errors.
8
+
9
+ ##
10
+ ### Validation helpers.
11
+ ##
12
+
13
+
14
+ def _validate_matrix(shape, data, row_indices, column_indices, offsets):
15
+ # Data should be [nnz, block_size, block_size]
16
+ if data.dim() == 1:
17
+ data = torch.reshape(data, [data.numel(), 1, 1])
18
+
19
+ # Blocks should be square.
20
+ if data.shape[-2] != data.shape[-1]:
21
+ raise ValueError(
22
+ "Expected square blocking in data. "
23
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
24
+
25
+ # Flatten batch dimensions on data - original shape preserved
26
+ # in shape argument.
27
+ block_size = data.shape[-1]
28
+ data = data.view([-1, block_size, block_size])
29
+
30
+ if data.dim() != 3:
31
+ raise ValueError(
32
+ "Expected 3D shape for data (nnz, block, block). "
33
+ f"Got shape {data.dim()}D shape.")
34
+
35
+ block_size = data.shape[1]
36
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
37
+ raise ValueError(
38
+ "Matrix shape must be dividible by blocking. "
39
+ f"Got shape {shape} with "
40
+ f"{[block_size, block_size]} blocking.")
41
+
42
+ if np.prod(shape) < data.numel():
43
+ raise ValueError(
44
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
45
+ f"({data.numel()} v. {np.prod(shape)})")
46
+
47
+ if row_indices.dim() != 1:
48
+ raise ValueError(
49
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
50
+
51
+ if column_indices.dim() != 1:
52
+ raise ValueError(
53
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
54
+
55
+ if offsets.dim() != 1:
56
+ raise ValueError(
57
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
58
+
59
+ if row_indices.numel() != data.shape[0]:
60
+ raise ValueError(
61
+ "Expected 1 index per nonzero block. "
62
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
63
+
64
+ if column_indices.numel() != data.shape[0]:
65
+ raise ValueError(
66
+ "Expected 1 index per nonzero block. "
67
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
68
+
69
+ block_rows = np.prod(shape[:-1]) / block_size
70
+ if offsets.numel() != block_rows + 1:
71
+ raise ValueError(
72
+ "Expected one offset per block row plus one. "
73
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
74
+
75
+ is_cuda = (data.is_cuda and
76
+ row_indices.is_cuda and
77
+ column_indices.is_cuda and
78
+ offsets.is_cuda)
79
+ is_cpu = (not data.is_cuda and
80
+ not row_indices.is_cuda and
81
+ not column_indices.is_cuda and
82
+ not offsets.is_cuda)
83
+ if not (is_cuda or is_cpu):
84
+ raise ValueError(
85
+ "Expected data & meta-data on common device. "
86
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
87
+ f"column_indices on {column_indices.device} and "
88
+ f"offsets on {offsets.device}.")
89
+
90
+ if data.dtype != torch.float16:
91
+ raise ValueError(
92
+ f"Expected float16 data. Got {data.dtype} data.")
93
+ if row_indices.dtype != torch.int16:
94
+ raise ValueError(
95
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
96
+ if column_indices.dtype != torch.int16:
97
+ raise ValueError(
98
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
99
+ if offsets.dtype != torch.int32:
100
+ raise ValueError(
101
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
102
+ return data
103
+
104
+
105
+ def _transpose(size, data, row_indices, column_indices, offsets):
106
+ block_columns = size[1] // data.shape[1]
107
+
108
+ # Sort row indices by column indices to get the transposed matrix's
109
+ # column indices.
110
+ gather_indices = column_indices.argsort()
111
+ column_indices_t = row_indices.gather(0, gather_indices)
112
+ block_offsets_t = gather_indices.int()
113
+
114
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
115
+ # the histogram in 32-bit float, which can exactly represent 16-bit
116
+ # integers.
117
+ column_indices_float = column_indices.float()
118
+
119
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
120
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
121
+ nnz_per_column = nnz_per_column.int()
122
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
123
+ return column_indices_t, offsets_t, block_offsets_t
124
+
125
+
126
+ class Matrix(torch.nn.Module):
127
+ """A matrix stored in sparse format.
128
+
129
+ Underlying format is block compressed sparse row (BCSR).
130
+
131
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
132
+ """
133
+
134
+ def __init__(self,
135
+ size,
136
+ data,
137
+ row_indices,
138
+ column_indices,
139
+ offsets,
140
+ column_indices_t=None,
141
+ offsets_t=None,
142
+ block_offsets_t=None):
143
+ super().__init__()
144
+ self._size = size
145
+ self._data = data
146
+ self._row_indices = row_indices
147
+ self._column_indices = column_indices
148
+ self._offsets = offsets
149
+
150
+ # Produce the transpose meta-data if it is not passed in.
151
+ if ((column_indices_t is None) or (offsets_t is None) or
152
+ (block_offsets_t is None)):
153
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
154
+ size, data, row_indices, column_indices, offsets)
155
+ self._column_indices_t = column_indices_t
156
+ self._offsets_t = offsets_t
157
+ self._block_offsets_t = block_offsets_t
158
+
159
+ self._transposed = False
160
+
161
+ # Validate that our metadata will not overflow.
162
+ max_dim = np.iinfo(np.int16).max * self.blocking
163
+ if column_indices.dtype == torch.int16:
164
+ if size[0] > max_dim or size[1] > max_dim:
165
+ raise ValueError(
166
+ "Sparse matrix with shape {size} exceeds representable "
167
+ "size with 16-bit indices.")
168
+
169
+ def validate(self):
170
+ _validate_matrix(self._size,
171
+ self._data,
172
+ self._row_indices,
173
+ self._column_indices,
174
+ self._offsets)
175
+
176
+ # TODO(tgale): Add heavyweight data validation.
177
+
178
+ def to(self, device):
179
+ # TODO(tgale): Handle type conversions here. We
180
+ # need to set the appropriate meta-data type for
181
+ # the given floating-point type.
182
+ self._data = self._data.to(device)
183
+ self._row_indices = self._row_indices.to(device)
184
+ self._column_indices = self._column_indices.to(device)
185
+ self._offsets = self._offsets.to(device)
186
+ self._column_indices_t = self._column_indices_t.to(device)
187
+ self._offsets_t = self._offsets_t.to(device)
188
+ self._block_offsets_t = self._block_offsets_t.to(device)
189
+ return self
190
+
191
+ def cuda(self):
192
+ return self.to(torch.cuda.current_device())
193
+
194
+ def clone(self):
195
+ return Matrix(
196
+ self.size(),
197
+ self.data.clone(),
198
+ self.row_indices.clone(),
199
+ self.column_indices.clone(),
200
+ self.offsets.clone(),
201
+ self.column_indices_t.clone(),
202
+ self.offsets_t.clone(),
203
+ self.block_offsets_t.clone())
204
+
205
+ def t(self):
206
+ if self.dim() != 2:
207
+ raise ValueError(
208
+ "t() expects a tensor with <= 2 dimensions, "
209
+ f"but self is {self.dim()}D.")
210
+ out = Matrix(self.size(),
211
+ self.data,
212
+ self.row_indices,
213
+ self.column_indices,
214
+ self.offsets,
215
+ self.column_indices_t,
216
+ self.offsets_t,
217
+ self.block_offsets_t)
218
+ out._transposed = not self._transposed
219
+ out._size = torch.Size((self._size[1], self._size[0]))
220
+ return out
221
+
222
+ def contiguous(self):
223
+ raise ValueError("Not yet implemented.")
224
+
225
+ def is_contiguous(self):
226
+ return not self._transposed
227
+
228
+ @property
229
+ def is_cuda(self):
230
+ return self._data.is_cuda
231
+
232
+ @property
233
+ def device(self):
234
+ return self._data.device
235
+
236
+ def size(self):
237
+ return self._size
238
+
239
+ @property
240
+ def shape(self):
241
+ return self.size()
242
+
243
+ def dim(self):
244
+ return len(self._size)
245
+
246
+ @property
247
+ def data(self):
248
+ return self._data
249
+
250
+ @property
251
+ def row_indices(self):
252
+ return self._row_indices
253
+
254
+ @property
255
+ def column_indices(self):
256
+ return self._column_indices
257
+
258
+ @property
259
+ def offsets(self):
260
+ return self._offsets
261
+
262
+ @property
263
+ def offsets_t(self):
264
+ return self._offsets_t
265
+
266
+ @property
267
+ def column_indices_t(self):
268
+ return self._column_indices_t
269
+
270
+ @property
271
+ def block_offsets_t(self):
272
+ return self._block_offsets_t
273
+
274
+ @property
275
+ def dtype(self):
276
+ return self.data.dtype
277
+
278
+ @property
279
+ def nnz(self):
280
+ return self.data.numel()
281
+
282
+ @property
283
+ def blocking(self):
284
+ return self.data.shape[1]
285
+
286
+ @property
287
+ def requires_grad(self):
288
+ return self.data.requires_grad
289
+
290
+ def requires_grad_(self, x):
291
+ self.data.requires_grad_(x)
292
+ return self
293
+
294
+ def view(self, *shape):
295
+ assert self.is_contiguous()
296
+ if shape[-1] != self.size()[-1]:
297
+ raise ValueError(
298
+ "Can't change view on compressed dimension. "
299
+ f"{self.size()[-1]} v. {shape[-1]}.")
300
+ if np.prod(shape) != np.prod(self.size()):
301
+ raise ValueError(
302
+ "Mismatch in numel of Matrix and new shape. "
303
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
304
+ return Matrix(shape,
305
+ self.data,
306
+ self.row_indices,
307
+ self.column_indices,
308
+ self.offsets,
309
+ self.column_indices_t,
310
+ self.offsets_t,
311
+ self.block_offsets_t)
312
+
313
+ @property
314
+ def grad(self):
315
+ # TODO(tgale): Make sure this mirrors torch.Tensor
316
+ # behavior in the case where we ask for the gradient
317
+ # of a non-contiguous tensor.
318
+ size = self.size()
319
+ if not self.is_contiguous():
320
+ size = torch.Size((size[1], size[0]))
321
+ out = Matrix(size,
322
+ self.data.grad,
323
+ self.row_indices,
324
+ self.column_indices,
325
+ self.offsets,
326
+ self.column_indices_t,
327
+ self.offsets_t,
328
+ self.block_offsets_t)
329
+ return out if self.is_contiguous() else out.t()
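A worked example of the BCSR layout `Matrix` stores (hypothetical values; runs on CPU since `validate()` is not invoked at construction): a 4x4 matrix with 2x2 blocking and nonzero blocks at block coordinates (0, 0) and (1, 1):

```python
import torch

from megablocks.stk.matrix import Matrix

data = torch.randn(2, 2, 2, dtype=torch.float16)          # [nnz_blocks, block, block]
row_indices = torch.tensor([0, 1], dtype=torch.int16)     # block row of each block
column_indices = torch.tensor([0, 1], dtype=torch.int16)  # block column of each block
offsets = torch.tensor([0, 1, 2], dtype=torch.int32)      # CSR-style block-row pointers

m = Matrix((4, 4), data, row_indices, column_indices, offsets)
print(m.shape, m.nnz, m.blocking)  # (4, 4) 8 2
```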
torch-ext/megablocks/stk/ops/__init__.py ADDED
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
torch-ext/megablocks/stk/ops/eltwise_ops.py ADDED
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+    """Performs element-wise multiplication of matrices a and b.
+
+    It is the user's responsibility to make sure that a and b
+    follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+    Args:
+        a: stk.Matrix.
+        b: stk.Matrix with a's matrix topology.
+
+    Returns:
+        stk.Matrix where the entries correspond to torch.mul(a, b).
+    """
+    assert isinstance(a, Matrix)
+    assert isinstance(b, Matrix)
+    assert a.size() == b.size()
+
+    return Matrix(a.size(),
+                  a.data * b.data,
+                  a.row_indices,
+                  a.column_indices,
+                  a.offsets,
+                  a.column_indices_t,
+                  a.offsets_t,
+                  a.block_offsets_t)
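`mul` trusts the caller on topology: it multiplies the packed `data` tensors directly and reuses `a`'s metadata. A small CPU sketch (the `megablocks` import path is an assumption from the vendored layout):

```python
import torch

from megablocks import stk

x = torch.randn(4, 4, dtype=torch.float16)
a = stk.ops.to_sparse(x, blocking=2)
b = stk.ops.ones_like(a)  # same topology, every stored block is ones

c = stk.ops.mul(a, b)     # element-wise product; reuses a's topology
assert torch.equal(stk.ops.to_dense(c), stk.ops.to_dense(a))
```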
torch-ext/megablocks/stk/ops/eltwise_ops_test.py ADDED
@@ -0,0 +1,86 @@
1
+ import unittest
2
+ import itertools
3
+ import torch
4
+ from absl.testing import parameterized
5
+
6
+ import stk
7
+ from stk.ops.linear_ops_test import allclose, _dense_and_sparse
8
+
9
+ _MATRIX_SIZES = (
10
+ (128, 128, 0.0),
11
+ (256, 256, 0.5),
12
+ (2048, 1024, 0.8),
13
+ (512, 128, 0.0),
14
+ (128, 512, 0.0),
15
+ (1024, 512, 0.0),
16
+ (1024, 512, 0.5),
17
+ (1024, 512, 0.75),
18
+ (512, 1024, 0.0),
19
+ (512, 1024, 0.5),
20
+ (512, 1024, 0.75),
21
+ (1024, 1024, 0.0),
22
+ (1024, 1024, 0.5),
23
+ (1024, 1024, 0.75),
24
+ )
25
+
26
+ _DTYPE = (
27
+ torch.float16, torch.bfloat16
28
+ )
29
+
30
+ def _generate_testcases():
31
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
32
+ testcases = [(*size, 128, dtype) for
33
+ (size, dtype) in testcases]
34
+ return testcases
35
+
36
+ _ELTWISE_OP_TESTS = _generate_testcases()
37
+
38
+ def _dense_and_sparse_like(x, std=0.1):
39
+ dense_data = torch.randn_like(x.data, device=x.device) * std
40
+ sparse = stk.Matrix(x.size(),
41
+ dense_data,
42
+ x.row_indices,
43
+ x.column_indices,
44
+ x.offsets)
45
+ dense = stk.ops.to_dense(sparse)
46
+
47
+ return (dense.requires_grad_(True),
48
+ sparse.requires_grad_(True))
49
+
50
+ @parameterized.parameters(_ELTWISE_OP_TESTS)
51
+ class EltwiseOpsTest(parameterized.TestCase):
52
+
53
+ def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
54
+
55
+ a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
56
+ b_dense, b = _dense_and_sparse_like(a)
57
+
58
+ out = stk.ops.mul(a, b)
59
+ expected_out = torch.mul(a_dense, b_dense)
60
+
61
+ # Compute the gradients w.r.t. the inputs.
62
+ expected_out.sum().backward()
63
+ stk.ops.sum(out).backward()
64
+
65
+ # Validate the results.
66
+ out = stk.ops.to_dense(out)
67
+ self.assertEqual(out.dim(), 2)
68
+ self.assertEqual(expected_out.size(), out.size())
69
+ self.assertTrue(allclose(out, expected_out))
70
+
71
+ # LHS gradient.
72
+ grad = stk.ops.to_dense(a.grad)
73
+ expected_grad = a_dense.grad
74
+ self.assertEqual(grad.dim(), 2)
75
+ self.assertEqual(expected_grad.size(), grad.size())
76
+ self.assertTrue(allclose(grad, expected_grad))
77
+
78
+ # RHS gradient.
79
+ grad = stk.ops.to_dense(b.grad)
80
+ expected_grad = b_dense.grad
81
+ self.assertEqual(grad.dim(), 2)
82
+ self.assertEqual(expected_grad.size(), grad.size())
83
+ self.assertTrue(allclose(grad, expected_grad))
84
+
85
+ if __name__ == '__main__':
86
+ unittest.main()
torch-ext/megablocks/stk/ops/linear_ops.py ADDED
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
+def dsd(a, b):
+    assert isinstance(a, Matrix)
+    assert isinstance(b, torch.Tensor)
+    return sputnik.dsd(
+        a.size(),
+        a.data, a.offsets,
+        a.row_indices,
+        a.column_indices,
+        a.offsets_t,
+        a.column_indices_t,
+        a.block_offsets_t,
+        not a.is_contiguous(),
+        b)
+
+
+def dds(a, b):
+    assert isinstance(a, torch.Tensor)
+    assert isinstance(b, Matrix)
+    return sputnik.dds(
+        a,
+        b.size(),
+        b.data, b.offsets,
+        b.row_indices,
+        b.column_indices,
+        b.offsets_t,
+        b.column_indices_t,
+        b.block_offsets_t,
+        not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+    assert isinstance(a, torch.Tensor)
+    assert isinstance(b, torch.Tensor)
+    assert isinstance(topo, Matrix)
+    assert topo.is_contiguous()
+    out = sputnik.sdd(
+        a, b,
+        topo.size(),
+        topo.data,
+        topo.offsets,
+        topo.row_indices,
+        topo.column_indices,
+        topo.offsets_t,
+        topo.column_indices_t,
+        topo.block_offsets_t)
+    return Matrix(topo.size(),
+                  out,
+                  topo.row_indices,
+                  topo.column_indices,
+                  topo.offsets,
+                  topo.column_indices_t,
+                  topo.offsets_t,
+                  topo.block_offsets_t)
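The three-letter names encode (output, lhs, rhs) with `d` for dense and `s` for sparse, so `sdd` computes a sparse output from two dense inputs, with `topo` supplying only the output sparsity pattern. A hedged usage sketch (needs a CUDA device and tile-aligned shapes, per the Triton kernels; the import path is assumed from the vendored layout):

```python
import torch

from megablocks import stk

dev = torch.device("cuda")
a = torch.randn(256, 256, dtype=torch.float16, device=dev)
b = torch.randn(256, 256, dtype=torch.float16, device=dev)

# A fully dense topology: every 128x128 block is present.
topo = stk.ops.to_sparse(torch.ones(256, 256, dtype=torch.float16, device=dev), 128)

s = stk.ops.sdd(a, b, topo)  # sparse = dense  @ dense, masked to topo
d = stk.ops.dsd(s, b)        # dense  = sparse @ dense
e = stk.ops.dds(a, s)        # dense  = dense  @ sparse
```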
torch-ext/megablocks/stk/ops/linear_ops_test.py ADDED
@@ -0,0 +1,216 @@
1
+ import unittest
2
+ import itertools
3
+ import numpy as np
4
+ import torch
5
+ from absl.testing import parameterized
6
+
7
+ import stk
8
+
9
+
10
+ def allclose(x, y, pct=0.25):
11
+ mask = torch.isclose(x, y, rtol=5e-2)
12
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
13
+ if pct_diff > pct:
14
+ print("{:.2f}% of values not close.".format(pct_diff))
15
+ return False
16
+ return True
17
+
18
+
19
+ # An assortment of problems designed to make sure
20
+ # the bindings are operating correctly.
21
+ _MATRIX_SIZES = (
22
+ (128, 128, 128, 0.0),
23
+ (256, 256, 256, 0.5),
24
+ (2048, 1024, 512, 0.8),
25
+ (512, 128, 128, 0.0),
26
+ (128, 128, 512, 0.0),
27
+ (1024, 512, 512, 0.0),
28
+ (1024, 512, 512, 0.5),
29
+ (1024, 512, 512, 0.75),
30
+ (512, 512, 1024, 0.0),
31
+ (512, 512, 1024, 0.5),
32
+ (512, 512, 1024, 0.75),
33
+ (1024, 1024, 1024, 0.0),
34
+ (1024, 1024, 1024, 0.5),
35
+ (1024, 1024, 1024, 0.75),
36
+ )
37
+
38
+ _TRANSPOSE = (
39
+ (False, False),
40
+ (False, True),
41
+ (True, False),
42
+ (True, True),
43
+ )
44
+
45
+ _DTYPE = (
46
+ torch.float16, torch.bfloat16
47
+ )
48
+
49
+ def _generate_testcases():
50
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
51
+ testcases = [(*size, *trans, 128, dtype) for
52
+ (size, trans, dtype) in testcases]
53
+ return testcases
54
+
55
+ _LINEAR_OP_TESTS = _generate_testcases()
56
+
57
+ def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
58
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
59
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
60
+ sparse = stk.ops.to_sparse(dense, blocking)
61
+ cuda_device = torch.device("cuda")
62
+ return (dense.to(cuda_device).requires_grad_(True),
63
+ sparse.to(cuda_device).requires_grad_(True))
64
+
65
+
66
+ def _dense(rows, cols, dtype, std=0.1):
67
+ cuda_device = torch.device("cuda")
68
+ out = (torch.randn(rows, cols) * std).type(dtype)
69
+ return out.to(cuda_device).requires_grad_(True)
70
+
71
+
72
+ def _dense_2x(rows, cols, dtype):
73
+ a = _dense(rows, cols, dtype)
74
+ return a, a.detach().requires_grad_(True)
75
+
76
+
77
+ def _with_transpose(op, a, b, trans_a, trans_b):
78
+ a = a.t() if trans_a else a
79
+ b = b.t() if trans_b else b
80
+ return op(a, b)
81
+
82
+
83
+ def _mmm(a, b, topo):
84
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
85
+ return torch.mm(a, b) * mask
86
+
87
+
88
+ def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
89
+ a = a.t() if trans_a else a
90
+ b = b.t() if trans_b else b
91
+ return op(a, b, topo)
92
+
93
+
94
+ def _mask(x, mask):
95
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
96
+ return x * mask
97
+
98
+
99
+ @parameterized.parameters(*_LINEAR_OP_TESTS)
100
+ class LinearOpsTest(parameterized.TestCase):
101
+
102
+ def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
103
+ # Construct the operands.
104
+ a_shape = (k, m) if trans_a else (m, k)
105
+ a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
106
+ b_shape = (n, k) if trans_b else (k, n)
107
+ b, bcp = _dense_2x(*b_shape, dtype)
108
+
109
+ # Execute the matmul.
110
+ out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
111
+ expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
112
+
113
+ # Compute the gradients w.r.t. the inputs.
114
+ expected_out.sum().backward()
115
+ out.sum().backward()
116
+
117
+ # Validate the results.
118
+ self.assertEqual(out.dim(), 2)
119
+ self.assertEqual(expected_out.size()[0], out.size()[0])
120
+ self.assertEqual(expected_out.size()[1], out.size()[1])
121
+ self.assertTrue(allclose(out, expected_out))
122
+
123
+ # LHS gradient.
124
+ grad = stk.ops.to_dense(a.grad)
125
+ expected_grad = _mask(a_dense.grad, a.grad)
126
+ self.assertEqual(grad.dim(), 2)
127
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
128
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
129
+ self.assertTrue(allclose(grad, expected_grad))
130
+
131
+ # RHS gradient.
132
+ grad = b.grad
133
+ expected_grad = bcp.grad
134
+ self.assertEqual(grad.dim(), 2)
135
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
136
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
137
+ self.assertTrue(allclose(grad, expected_grad))
138
+
139
+ def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
140
+ # Construct the operands.
141
+ a_shape = (k, m) if trans_a else (m, k)
142
+ a, acp = _dense_2x(*a_shape, dtype)
143
+ b_shape = (n, k) if trans_b else (k, n)
144
+ b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
145
+
146
+ # Execute the matmul.
147
+ out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
148
+ expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
149
+
150
+ # Compute the gradients w.r.t. the inputs.
151
+ expected_out.sum().backward()
152
+ out.sum().backward()
153
+
154
+ # Validate the results.
155
+ self.assertEqual(out.dim(), 2)
156
+ self.assertEqual(expected_out.size()[0], out.size()[0])
157
+ self.assertEqual(expected_out.size()[1], out.size()[1])
158
+ self.assertTrue(allclose(out, expected_out))
159
+
160
+ # LHS gradient.
161
+ grad = a.grad
162
+ expected_grad = acp.grad
163
+ self.assertEqual(grad.dim(), 2)
164
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
165
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
166
+ self.assertTrue(allclose(grad, expected_grad))
167
+
168
+ # RHS gradient.
169
+ grad = stk.ops.to_dense(b.grad)
170
+ expected_grad = _mask(b_dense.grad, b.grad)
171
+ self.assertEqual(grad.dim(), 2)
172
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
173
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
174
+ self.assertTrue(allclose(grad, expected_grad))
175
+
176
+ def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
177
+ # Construct the operands.
178
+ a_shape = (k, m) if trans_a else (m, k)
179
+ a, acp = _dense_2x(*a_shape, dtype)
180
+ b_shape = (n, k) if trans_b else (k, n)
181
+ b, bcp = _dense_2x(*b_shape, dtype)
182
+ _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
183
+
184
+ # Execute the matmul.
185
+ out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
186
+ expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
187
+
188
+ # Compute the gradients w.r.t. the inputs.
189
+ expected_out.sum().backward()
190
+ stk.ops.sum(out).backward()
191
+
192
+ # Validate the results.
193
+ out = stk.ops.to_dense(out)
194
+ self.assertEqual(out.dim(), 2)
195
+ self.assertEqual(expected_out.size()[0], out.size()[0])
196
+ self.assertEqual(expected_out.size()[1], out.size()[1])
197
+ self.assertTrue(allclose(out, expected_out))
198
+
199
+ # LHS gradient.
200
+ grad = a.grad
201
+ expected_grad = acp.grad
202
+ self.assertEqual(grad.dim(), 2)
203
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
204
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
205
+ self.assertTrue(allclose(grad, expected_grad))
206
+
207
+ # RHS gradient.
208
+ grad = b.grad
209
+ expected_grad = bcp.grad
210
+ self.assertEqual(grad.dim(), 2)
211
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
212
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
213
+ self.assertTrue(allclose(grad, expected_grad))
214
+
215
+ if __name__ == '__main__':
216
+ unittest.main()
torch-ext/megablocks/stk/ops/matrix_ops.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+    return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+    # Duplicate for block column dimension.
+    idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+    # Update the column indices.
+    idxs[:, :, 1] *= blocking
+    idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+    # Duplicate for block row dimension.
+    idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+    idxs = idxs.repeat(1, blocking, 1, 1)
+
+    # Update the row indices.
+    idxs[:, :, :, 0] *= blocking
+    idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+    idxs = torch.reshape(idxs, [-1, 2])
+    return idxs
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+    assert isinstance(x, Matrix)
+
+    shape = (np.prod(x.shape[:-1]), x.shape[-1])
+    row_idxs = x.row_indices.type(torch.int32)
+    col_idxs = x.column_indices.type(torch.int32)
+    indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+    indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+    out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+    out.scatter_(0, indices, x.data.flatten())
+    return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+    assert x.dim() == 2
+    assert x.size()[0] % blocking == 0
+    assert x.size()[1] % blocking == 0
+    block_rows = x.size()[0] // blocking
+    block_cols = x.size()[1] // blocking
+    x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+    x = torch.sum(torch.abs(x), dim=(1, 3))
+    return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+    m = _mask(x, blocking)
+
+    # TODO(tgale): Set to appropriate type for input matrix.
+    row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+    zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+    offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+    offsets = offsets.type(torch.int32)
+
+    indices = torch.nonzero(m).type(torch.int16)
+    row_indices = indices[:, 0]
+    column_indices = indices[:, 1]
+
+    # Nonzero indices in the dense matrix.
+    nonzero_indices = torch.nonzero(m)
+    nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+    nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+    # Gather the data and construct the sparse matrix.
+    data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+    data = torch.reshape(data, [-1, blocking, blocking])
+    return Matrix(x.size(), data, row_indices, column_indices, offsets)
+
+
+@torch.no_grad()
+def ones_like(x):
+    return Matrix(x.size(),
+                  torch.ones_like(x.data),
+                  x.row_indices,
+                  x.column_indices, x.offsets)
+
+
+def sum(x):
+    assert isinstance(x, Matrix)
+    return x.data.sum()
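For intuition about the block layout these helpers produce, a small CPU round-trip sketch; `to_sparse`/`to_dense` are the pure-PyTorch functions in this file, so no sputnik kernels are involved (the calling context is assumed):

    import torch

    x = torch.tensor([[1., 1., 0., 0.],
                      [1., 1., 0., 0.],
                      [0., 0., 2., 2.],
                      [0., 0., 2., 2.]], dtype=torch.float16)
    sp = to_sparse(x, blocking=2)
    # Blocks (0, 0) and (1, 1) of the 2x2 block grid survive, stored as two
    # 2x2 tiles; _expand_for_blocking turns each block index pair into the
    # blocking**2 element coordinates it covers.
    assert sp.data.shape == (2, 2, 2)
    assert torch.equal(to_dense(sp), x)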
torch-ext/megablocks/stk/ops/matrix_ops_test.py ADDED
@@ -0,0 +1,62 @@
+import unittest
+
+from absl.testing import parameterized
+from ... import stk  # vendored copy; resolves to megablocks.stk
+import torch
+
+
+@parameterized.parameters(
+    (8, 16, 0.0, 1),
+    (8, 16, 0.5, 1),
+    (8, 16, .95, 1),
+    (16, 8, 0.0, 1),
+    (16, 8, 0.5, 1),
+    (16, 8, .95, 1),
+    (8, 16, 0.0, 8),
+    (8, 16, 0.5, 8),
+    (8, 16, 1.0, 8),
+    (16, 8, 0.0, 8),
+    (16, 8, 0.5, 8),
+    (16, 8, 1.0, 8),
+    (128, 256, 0.5, 16),
+    (256, 128, 0.75, 32),
+    (512, 512, .875, 128))
+class MatrixOpsTest(parameterized.TestCase):
+
+    def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+        mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+        x = (torch.randn(rows, cols) * mask).type(torch.float16)
+
+        # Convert the matrix to sparse format.
+        sparse_x = stk.ops.to_sparse(x, blocking)
+
+        # Validate the matrix.
+        sparse_x.validate()
+
+        # Validate the shape.
+        self.assertEqual(sparse_x.dim(), 2)
+        self.assertEqual(sparse_x.size()[0], rows)
+        self.assertEqual(sparse_x.size()[1], cols)
+
+        # Validate the sparsity.
+        numblocks = rows // blocking * cols // blocking
+        nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+        self.assertEqual(sparse_x.nnz, nnz)
+
+        # Convert back to dense format.
+        dense_x = stk.ops.to_dense(sparse_x)
+
+        # Validate the shape.
+        self.assertEqual(dense_x.dim(), 2)
+        self.assertEqual(dense_x.size()[0], rows)
+        self.assertEqual(dense_x.size()[1], cols)
+
+        # Validate the sparsity.
+        self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+
+        # Validate the output.
+        self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+    unittest.main()
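The expected-nnz arithmetic above, worked for one parameter set: with rows=128, cols=256, sparsity=0.5, blocking=16, numblocks = (128 // 16) * (256 // 16) = 128; round(128 * (1 - 0.5)) = 64 blocks survive, each holding 16 ** 2 = 256 entries, so the test expects nnz = 64 * 256 = 16384 both for the sparse matrix and for the densified round trip.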
torch-ext/megablocks/stk/random/__init__.py ADDED
@@ -0,0 +1 @@
+from .random_ops import dense_mask, mask, randn
torch-ext/megablocks/stk/random/random_ops.py ADDED
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+    assert sparsity >= 0.0 and sparsity <= 1.0
+    assert rows % blocking == 0 and cols % blocking == 0
+
+    block_rows, block_cols = (rows // blocking, cols // blocking)
+    nnz = round(block_rows * block_cols * (1 - sparsity))
+
+    out = np.ones(block_rows * block_cols)
+    mask = np.random.choice(out.size, out.size - nnz, replace=False)
+    out[mask] = 0.0
+
+    out = np.tile(
+        np.reshape(out, [block_rows, 1, block_cols, 1]),
+        (1, blocking, 1, blocking))
+    out = np.reshape(out, [rows, cols])
+    return torch.from_numpy(out.astype(np.float32))
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+    out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+    return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+    shape_2d = (np.prod(shape[:-1]), shape[-1])
+    out = mask(*shape_2d, sparsity, blocking)
+    out.data.copy_(torch.randn(*out.data.shape))
+    return out.view(*shape)
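The nnz accounting in `dense_mask` is exact rather than probabilistic: exactly `out.size - nnz` blocks are zeroed, whatever the random draw. A small CPU check (calls the function defined in this file; no kernels involved):

    import torch

    m = dense_mask(4, 4, sparsity=0.5, blocking=2)
    # 2x2 grid of 2x2 blocks; round(4 * 0.5) = 2 blocks survive -> 8 nonzeros.
    assert m.shape == (4, 4)
    assert torch.count_nonzero(m).item() == 8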
torch-ext/megablocks/stk/random/random_ops_test.py ADDED
@@ -0,0 +1,73 @@
+import unittest
+
+from absl.testing import parameterized
+from .. import random  # this test lives inside stk.random, so import from the parent
+import torch
+
+
+@parameterized.parameters(
+    (8, 16, 0.0, 1),
+    (8, 16, 0.5, 1),
+    (8, 16, .95, 1),
+    (16, 8, 0.0, 1),
+    (16, 8, 0.5, 1),
+    (16, 8, .95, 1),
+    (8, 16, 0.0, 8),
+    (8, 16, 0.5, 8),
+    (8, 16, 1.0, 8),
+    (16, 8, 0.0, 8),
+    (16, 8, 0.5, 8),
+    (16, 8, 1.0, 8),
+    (128, 256, 0.5, 16),
+    (256, 128, 0.75, 32),
+    (512, 512, .875, 128))
+class RandomOpsTest(parameterized.TestCase):
+
+    def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+        mask = random.dense_mask(
+            rows, cols, sparsity, blocking)
+
+        # Validate the shape.
+        self.assertEqual(mask.dim(), 2)
+        self.assertEqual(mask.size()[0], rows)
+        self.assertEqual(mask.size()[1], cols)
+
+        # Validate the sparsity.
+        numblocks = rows // blocking * cols // blocking
+        nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+        self.assertEqual(
+            torch.count_nonzero(mask).item(),
+            nnz)
+
+        # Check values are zero or one.
+        self.assertTrue(
+            torch.all(torch.logical_or(
+                torch.eq(mask, 0),
+                torch.eq(mask, 1))))
+
+    def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+        mask = random.mask(
+            rows, cols, sparsity, blocking)
+
+        # Validate the matrix.
+        mask.validate()
+
+        # Validate the shape.
+        self.assertEqual(mask.dim(), 2)
+        self.assertEqual(mask.size()[0], rows)
+        self.assertEqual(mask.size()[1], cols)
+
+        # Validate the sparsity.
+        numblocks = rows // blocking * cols // blocking
+        nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+        self.assertEqual(mask.nnz, nnz)
+
+        # Check values are zero or one.
+        self.assertTrue(
+            torch.all(torch.logical_or(
+                torch.eq(mask.data, 0),
+                torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+    unittest.main()
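Since the vendored modules now use relative imports, these tests need package context; running them as modules should work, assuming the package root is `megablocks` per this repo's layout (the exact invocation is an assumption):

    python -m megablocks.stk.ops.matrix_ops_test
    python -m megablocks.stk.random.random_ops_test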