kernel
drbh committed

Commit ef966b6 · Parent: 0586ba6

fix: bump build

Note: this view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.

Files changed (50)
  1. build/torch26-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_dabb815.abi3.so → _megablocks_0586ba6.abi3.so} +1 -1
  2. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py +3 -3
  3. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py +1 -1
  4. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py +1 -1
  5. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py +1 -1
  6. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py +1 -1
  7. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py +1 -1
  8. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py +1 -1
  9. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/stk_autocast.py +39 -0
  10. build/torch26-cxx11-cu124-x86_64-linux/megablocks/{_megablocks_dabb815.abi3.so → _megablocks_0586ba6.abi3.so} +1 -1
  11. build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py +3 -3
  12. build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_gather.py +1 -1
  13. build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_scatter.py +1 -1
  14. build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/gather.py +1 -1
  15. build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_gather.py +1 -1
  16. build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_scatter.py +1 -1
  17. build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/scatter.py +1 -1
  18. build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/stk_autocast.py +39 -0
  19. build/torch26-cxx11-cu126-x86_64-linux/megablocks/{_megablocks_dabb815.abi3.so → _megablocks_0586ba6.abi3.so} +1 -1
  20. build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py +3 -3
  21. build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py +1 -1
  22. build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py +1 -1
  23. build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py +1 -1
  24. build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py +1 -1
  25. build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py +1 -1
  26. build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py +1 -1
  27. build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/stk_autocast.py +39 -0
  28. build/torch26-cxx98-cu118-x86_64-linux/megablocks/{_megablocks_dabb815.abi3.so → _megablocks_0586ba6.abi3.so} +1 -1
  29. build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py +3 -3
  30. build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_gather.py +1 -1
  31. build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_scatter.py +1 -1
  32. build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/gather.py +1 -1
  33. build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_gather.py +1 -1
  34. build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_scatter.py +1 -1
  35. build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/scatter.py +1 -1
  36. build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/stk_autocast.py +39 -0
  37. build/torch26-cxx98-cu124-x86_64-linux/megablocks/{_megablocks_dabb815.abi3.so → _megablocks_0586ba6.abi3.so} +1 -1
  38. build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py +3 -3
  39. build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_gather.py +1 -1
  40. build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_scatter.py +1 -1
  41. build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/gather.py +1 -1
  42. build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_gather.py +1 -1
  43. build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_scatter.py +1 -1
  44. build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/scatter.py +1 -1
  45. build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/stk_autocast.py +39 -0
  46. build/torch26-cxx98-cu126-x86_64-linux/megablocks/{_megablocks_dabb815.abi3.so → _megablocks_0586ba6.abi3.so} +1 -1
  47. build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py +3 -3
  48. build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_gather.py +1 -1
  49. build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_scatter.py +1 -1
  50. build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/gather.py +1 -1
build/torch26-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_dabb815.abi3.so → _megablocks_0586ba6.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:7a20cd4dc15095b8504db981c651e516e8a7d8394b99d973d632558637c8dba9
+oid sha256:7fbec6fa49d1b926d45b39b7e8393e06ee9622d0012501adaec213cb5802c86d
 size 10517576
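
The `.abi3.so` binaries are tracked with Git LFS, so this rename (and the analogous ones below) diffs only the pointer file: the sha256 oid of the stored binary changes while the byte size happens to stay the same.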
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_dabb815
-ops = torch.ops._megablocks_dabb815
+from . import _megablocks_0586ba6
+ops = torch.ops._megablocks_0586ba6
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_dabb815::{op_name}"
+    return f"_megablocks_0586ba6::{op_name}"
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py CHANGED
@@ -4,7 +4,7 @@
 from typing import Any, Optional
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/stk_autocast.py ADDED
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+    return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+    if isinstance(x, torch.Tensor) and _is_eligible(x):
+        return x.to(dtype)
+    elif isinstance(x, map):
+        return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+    elif isinstance(x, list) or isinstance(x, tuple):
+        return type(x)(map(lambda y: _cast(y, dtype), x))
+    return x
+
+
+def custom_fwd(fwd):
+    """Wrap a custom autograd function that always uses autocast dtype."""
+
+    @functools.wraps(fwd)
+    def decorate_fwd(*args, **kwargs):
+        if torch.is_autocast_enabled():
+            with torch.autocast(device_type="cuda", enabled=False):
+                dtype = torch.get_autocast_gpu_dtype()
+                return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+        return fwd(*args, **kwargs)
+    return decorate_fwd
+
+
+def custom_bwd(bwd):
+    @functools.wraps(bwd)
+    def decorate_bwd(*args, **kwargs):
+        with torch.autocast(device_type="cuda", enabled=False):
+            return bwd(*args, **kwargs)
+    return decorate_bwd
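
This vendored module removes the runtime dependency on the external `stk` package for just these two decorators: `custom_fwd` casts eligible floating-point CUDA tensors in the arguments to the current autocast dtype and then runs the forward with autocast disabled, while `custom_bwd` simply disables autocast around the backward. One quirk carried over verbatim from upstream: `_cast` tests `isinstance(x, map)`, which in Python 3 matches `map` iterator objects rather than dicts, so dict-valued arguments pass through uncast; `isinstance(x, dict)` was presumably intended. A minimal usage sketch (not from this repo; assumes a CUDA device):

import torch
from megablocks.ops.stk_autocast import custom_bwd, custom_fwd


class Scale(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, x, alpha):
        # Under autocast, x arrives here already cast to the autocast
        # dtype and autocast itself is disabled for the body.
        ctx.alpha = alpha
        return x * alpha

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        # Backward also runs with autocast disabled.
        return grad_output * ctx.alpha, None


x = torch.randn(4, device="cuda")  # float32
with torch.autocast(device_type="cuda", dtype=torch.float16):
    y = Scale.apply(x, 2.0)
print(y.dtype)  # torch.float16: x was cast on the way into forward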
build/torch26-cxx11-cu124-x86_64-linux/megablocks/{_megablocks_dabb815.abi3.so → _megablocks_0586ba6.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:a3f69e5978b727f08b43112c2321a222719aa824612d452029225a48976dfbb6
+oid sha256:16141033c118b488348a29f3436f778764f8f4275fe510dc36badb7c152e0d42
 size 11869392
build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_dabb815
-ops = torch.ops._megablocks_dabb815
+from . import _megablocks_0586ba6
+ops = torch.ops._megablocks_0586ba6
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_dabb815::{op_name}"
+    return f"_megablocks_0586ba6::{op_name}"
build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_gather.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_scatter.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/gather.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_gather.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_scatter.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/scatter.py CHANGED
@@ -4,7 +4,7 @@
 from typing import Any, Optional
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/stk_autocast.py ADDED
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+    return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+    if isinstance(x, torch.Tensor) and _is_eligible(x):
+        return x.to(dtype)
+    elif isinstance(x, map):
+        return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+    elif isinstance(x, list) or isinstance(x, tuple):
+        return type(x)(map(lambda y: _cast(y, dtype), x))
+    return x
+
+
+def custom_fwd(fwd):
+    """Wrap a custom autograd function that always uses autocast dtype."""
+
+    @functools.wraps(fwd)
+    def decorate_fwd(*args, **kwargs):
+        if torch.is_autocast_enabled():
+            with torch.autocast(device_type="cuda", enabled=False):
+                dtype = torch.get_autocast_gpu_dtype()
+                return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+        return fwd(*args, **kwargs)
+    return decorate_fwd
+
+
+def custom_bwd(bwd):
+    @functools.wraps(bwd)
+    def decorate_bwd(*args, **kwargs):
+        with torch.autocast(device_type="cuda", enabled=False):
+            return bwd(*args, **kwargs)
+    return decorate_bwd
build/torch26-cxx11-cu126-x86_64-linux/megablocks/{_megablocks_dabb815.abi3.so → _megablocks_0586ba6.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:57540f7b6eae09c2c62826d13dfa2be53eaa37c86206df5914611a3fad9878ba
+oid sha256:3ea768d3d4780563159dd50075ed14d51166e5c3de9f5bd132047cfa6a23ef48
 size 11931048
build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_dabb815
-ops = torch.ops._megablocks_dabb815
+from . import _megablocks_0586ba6
+ops = torch.ops._megablocks_0586ba6
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_dabb815::{op_name}"
+    return f"_megablocks_0586ba6::{op_name}"
build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py CHANGED
@@ -4,7 +4,7 @@
 from typing import Any, Optional
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/stk_autocast.py ADDED
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+    return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+    if isinstance(x, torch.Tensor) and _is_eligible(x):
+        return x.to(dtype)
+    elif isinstance(x, map):
+        return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+    elif isinstance(x, list) or isinstance(x, tuple):
+        return type(x)(map(lambda y: _cast(y, dtype), x))
+    return x
+
+
+def custom_fwd(fwd):
+    """Wrap a custom autograd function that always uses autocast dtype."""
+
+    @functools.wraps(fwd)
+    def decorate_fwd(*args, **kwargs):
+        if torch.is_autocast_enabled():
+            with torch.autocast(device_type="cuda", enabled=False):
+                dtype = torch.get_autocast_gpu_dtype()
+                return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+        return fwd(*args, **kwargs)
+    return decorate_fwd
+
+
+def custom_bwd(bwd):
+    @functools.wraps(bwd)
+    def decorate_bwd(*args, **kwargs):
+        with torch.autocast(device_type="cuda", enabled=False):
+            return bwd(*args, **kwargs)
+    return decorate_bwd
build/torch26-cxx98-cu118-x86_64-linux/megablocks/{_megablocks_dabb815.abi3.so → _megablocks_0586ba6.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:defeb9a48abe98940478c79c5eac52f9fc7c22088abf9d119191559787bb95a9
+oid sha256:baacdb2bd8bcd004a86f63b0dc2754bac21214c9432bf6c00c464ccc26c25a83
 size 10510040
build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_dabb815
-ops = torch.ops._megablocks_dabb815
+from . import _megablocks_0586ba6
+ops = torch.ops._megablocks_0586ba6
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_dabb815::{op_name}"
+    return f"_megablocks_0586ba6::{op_name}"
build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_gather.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_scatter.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/gather.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_gather.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_scatter.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/scatter.py CHANGED
@@ -4,7 +4,7 @@
 from typing import Any, Optional
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/stk_autocast.py ADDED
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+    return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+    if isinstance(x, torch.Tensor) and _is_eligible(x):
+        return x.to(dtype)
+    elif isinstance(x, map):
+        return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+    elif isinstance(x, list) or isinstance(x, tuple):
+        return type(x)(map(lambda y: _cast(y, dtype), x))
+    return x
+
+
+def custom_fwd(fwd):
+    """Wrap a custom autograd function that always uses autocast dtype."""
+
+    @functools.wraps(fwd)
+    def decorate_fwd(*args, **kwargs):
+        if torch.is_autocast_enabled():
+            with torch.autocast(device_type="cuda", enabled=False):
+                dtype = torch.get_autocast_gpu_dtype()
+                return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+        return fwd(*args, **kwargs)
+    return decorate_fwd
+
+
+def custom_bwd(bwd):
+    @functools.wraps(bwd)
+    def decorate_bwd(*args, **kwargs):
+        with torch.autocast(device_type="cuda", enabled=False):
+            return bwd(*args, **kwargs)
+    return decorate_bwd
build/torch26-cxx98-cu124-x86_64-linux/megablocks/{_megablocks_dabb815.abi3.so → _megablocks_0586ba6.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:b6cc8f982e35bfa07a9121807bacd3c3572d6ecb1495bcb2b6286b967fb20d58
+oid sha256:20a8e0a793ac29bc168d10e1c9e465082c2adb1582ff79d1a083798f9a955a5f
 size 11857920
build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_dabb815
-ops = torch.ops._megablocks_dabb815
+from . import _megablocks_0586ba6
+ops = torch.ops._megablocks_0586ba6
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_dabb815::{op_name}"
+    return f"_megablocks_0586ba6::{op_name}"
build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_gather.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_scatter.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/gather.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_gather.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_scatter.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/scatter.py CHANGED
@@ -4,7 +4,7 @@
 from typing import Any, Optional
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/stk_autocast.py ADDED
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+    return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+    if isinstance(x, torch.Tensor) and _is_eligible(x):
+        return x.to(dtype)
+    elif isinstance(x, map):
+        return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+    elif isinstance(x, list) or isinstance(x, tuple):
+        return type(x)(map(lambda y: _cast(y, dtype), x))
+    return x
+
+
+def custom_fwd(fwd):
+    """Wrap a custom autograd function that always uses autocast dtype."""
+
+    @functools.wraps(fwd)
+    def decorate_fwd(*args, **kwargs):
+        if torch.is_autocast_enabled():
+            with torch.autocast(device_type="cuda", enabled=False):
+                dtype = torch.get_autocast_gpu_dtype()
+                return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+        return fwd(*args, **kwargs)
+    return decorate_fwd
+
+
+def custom_bwd(bwd):
+    @functools.wraps(bwd)
+    def decorate_bwd(*args, **kwargs):
+        with torch.autocast(device_type="cuda", enabled=False):
+            return bwd(*args, **kwargs)
+    return decorate_bwd
build/torch26-cxx98-cu126-x86_64-linux/megablocks/{_megablocks_dabb815.abi3.so → _megablocks_0586ba6.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:be9d5e5df42fd6d0db62397eae9c462b9775e952ce7f71fb687c3ea75dfe6a74
+oid sha256:b2f5209e69d36d632939923c20ab90c074fe0100d8a4efbabe5cdcd32ccbcfd2
 size 11923672
build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_dabb815
-ops = torch.ops._megablocks_dabb815
+from . import _megablocks_0586ba6
+ops = torch.ops._megablocks_0586ba6
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_dabb815::{op_name}"
+    return f"_megablocks_0586ba6::{op_name}"
build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_gather.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_scatter.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/gather.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 