diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..cce51bdbf0ec2d65660f1b3cfb1580b616156d14 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7fbec6fa49d1b926d45b39b7e8393e06ee9622d0012501adaec213cb5802c86d +size 10517576 diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so deleted file mode 100755 index 779e13d3baac0ee00e89944b954a8bc75fcb432f..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:7a20cd4dc15095b8504db981c651e516e8a7d8394b99d973d632558637c8dba9 -size 10517576 diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py index 3cc5ae109015332d831f2357627d00ef67faa25d..f7aa5a700bc5673c743a3b0fb74107fab793fbad 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_dabb815 -ops = torch.ops._megablocks_dabb815 +from . import _megablocks_0586ba6 +ops = torch.ops._megablocks_0586ba6 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_dabb815::{op_name}" \ No newline at end of file + return f"_megablocks_0586ba6::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py index 8ae2ad8388b06db46a13f8fa46083619c44eefe2..189a7fa3518d660f29ea32e7a04827164af98d60 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py index 6d8654bdc718cd893f7899e9cdb6cd20544d189c..cb937c0c106662ce8108c1cb926f8f063b163d3d 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py index 4edf4541dac52abab94151efb414acaa7711f8f6..f1f87c1e7bed8d3589dd790805234976e0b05898 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from 
..backend import kernels diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py index 0ffe1369d6adbce5c4a54155c636d1a4c022a41d..c1cf4047c9494394d2a3884ba8830179013db7ff 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py index 6685b0b83e7b1d92e7a772715991ead8ad94b153..61e021b81497e472cda5d72bdac557a0ca92d262 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py index 3f4bacf3422543e11abb795c279876147f8610a8..f4605d9b46f387761b070352365f223dbfe69d47 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py @@ -4,7 +4,7 @@ from typing import Any, Optional import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/stk_autocast.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/stk_autocast.py new file mode 100644 index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/stk_autocast.py @@ -0,0 +1,39 @@ +# vendored from +# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py +import functools +import torch + + +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, map): + return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, list) or isinstance(x, tuple): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def custom_fwd(fwd): + """Wrap a custom autograd function that always uses autocast dtype.""" + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + return decorate_fwd + + +def custom_bwd(bwd): + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + return decorate_bwd \ No newline at end of file diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so 
b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..240f5140b5ff0f7c813fcb49a45c561b4980891e --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:16141033c118b488348a29f3436f778764f8f4275fe510dc36badb7c152e0d42 +size 11869392 diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so deleted file mode 100755 index e349ce35ea4dd2273fcca6389af58dd2a08bff48..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:a3f69e5978b727f08b43112c2321a222719aa824612d452029225a48976dfbb6 -size 11869392 diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py index 3cc5ae109015332d831f2357627d00ef67faa25d..f7aa5a700bc5673c743a3b0fb74107fab793fbad 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_dabb815 -ops = torch.ops._megablocks_dabb815 +from . import _megablocks_0586ba6 +ops = torch.ops._megablocks_0586ba6 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_dabb815::{op_name}" \ No newline at end of file + return f"_megablocks_0586ba6::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_gather.py index 8ae2ad8388b06db46a13f8fa46083619c44eefe2..189a7fa3518d660f29ea32e7a04827164af98d60 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_gather.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_gather.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_scatter.py index 6d8654bdc718cd893f7899e9cdb6cd20544d189c..cb937c0c106662ce8108c1cb926f8f063b163d3d 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_scatter.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_scatter.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/gather.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/gather.py index 4edf4541dac52abab94151efb414acaa7711f8f6..f1f87c1e7bed8d3589dd790805234976e0b05898 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/gather.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/gather.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git 
a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_gather.py index 0ffe1369d6adbce5c4a54155c636d1a4c022a41d..c1cf4047c9494394d2a3884ba8830179013db7ff 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_gather.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_gather.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_scatter.py index 6685b0b83e7b1d92e7a772715991ead8ad94b153..61e021b81497e472cda5d72bdac557a0ca92d262 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_scatter.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_scatter.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/scatter.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/scatter.py index 3f4bacf3422543e11abb795c279876147f8610a8..f4605d9b46f387761b070352365f223dbfe69d47 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/scatter.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/scatter.py @@ -4,7 +4,7 @@ from typing import Any, Optional import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/stk_autocast.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/stk_autocast.py new file mode 100644 index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/stk_autocast.py @@ -0,0 +1,39 @@ +# vendored from +# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py +import functools +import torch + + +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, map): + return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, list) or isinstance(x, tuple): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def custom_fwd(fwd): + """Wrap a custom autograd function that always uses autocast dtype.""" + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + return decorate_fwd + + +def custom_bwd(bwd): + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + return decorate_bwd \ No newline at end of file diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so new file mode 100755 index 
0000000000000000000000000000000000000000..9c0b2b801efed567532179b12a8f3d04450e858e --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ea768d3d4780563159dd50075ed14d51166e5c3de9f5bd132047cfa6a23ef48 +size 11931048 diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so deleted file mode 100755 index f6c25f44514edb4a00c0afaed50e5d80b8d07261..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:57540f7b6eae09c2c62826d13dfa2be53eaa37c86206df5914611a3fad9878ba -size 11931048 diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py index 3cc5ae109015332d831f2357627d00ef67faa25d..f7aa5a700bc5673c743a3b0fb74107fab793fbad 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_dabb815 -ops = torch.ops._megablocks_dabb815 +from . import _megablocks_0586ba6 +ops = torch.ops._megablocks_0586ba6 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_dabb815::{op_name}" \ No newline at end of file + return f"_megablocks_0586ba6::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py index 8ae2ad8388b06db46a13f8fa46083619c44eefe2..189a7fa3518d660f29ea32e7a04827164af98d60 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py index 6d8654bdc718cd893f7899e9cdb6cd20544d189c..cb937c0c106662ce8108c1cb926f8f063b163d3d 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py index 4edf4541dac52abab94151efb414acaa7711f8f6..f1f87c1e7bed8d3589dd790805234976e0b05898 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py index 
0ffe1369d6adbce5c4a54155c636d1a4c022a41d..c1cf4047c9494394d2a3884ba8830179013db7ff 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py index 6685b0b83e7b1d92e7a772715991ead8ad94b153..61e021b81497e472cda5d72bdac557a0ca92d262 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py index 3f4bacf3422543e11abb795c279876147f8610a8..f4605d9b46f387761b070352365f223dbfe69d47 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py @@ -4,7 +4,7 @@ from typing import Any, Optional import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/stk_autocast.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/stk_autocast.py new file mode 100644 index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/stk_autocast.py @@ -0,0 +1,39 @@ +# vendored from +# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py +import functools +import torch + + +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, map): + return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, list) or isinstance(x, tuple): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def custom_fwd(fwd): + """Wrap a custom autograd function that always uses autocast dtype.""" + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + return decorate_fwd + + +def custom_bwd(bwd): + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + return decorate_bwd \ No newline at end of file diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..5380a6cbd20a15cdbae5b3d385ace2dc2e9da000 --- /dev/null +++ 
b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:baacdb2bd8bcd004a86f63b0dc2754bac21214c9432bf6c00c464ccc26c25a83 +size 10510040 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so deleted file mode 100755 index b9f0146d0a5676550971b4bf724c153bfd6f200f..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:defeb9a48abe98940478c79c5eac52f9fc7c22088abf9d119191559787bb95a9 -size 10510040 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py index 3cc5ae109015332d831f2357627d00ef67faa25d..f7aa5a700bc5673c743a3b0fb74107fab793fbad 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_dabb815 -ops = torch.ops._megablocks_dabb815 +from . import _megablocks_0586ba6 +ops = torch.ops._megablocks_0586ba6 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_dabb815::{op_name}" \ No newline at end of file + return f"_megablocks_0586ba6::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_gather.py index 8ae2ad8388b06db46a13f8fa46083619c44eefe2..189a7fa3518d660f29ea32e7a04827164af98d60 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_gather.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_gather.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_scatter.py index 6d8654bdc718cd893f7899e9cdb6cd20544d189c..cb937c0c106662ce8108c1cb926f8f063b163d3d 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_scatter.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_scatter.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/gather.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/gather.py index 4edf4541dac52abab94151efb414acaa7711f8f6..f1f87c1e7bed8d3589dd790805234976e0b05898 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/gather.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/gather.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_gather.py index 0ffe1369d6adbce5c4a54155c636d1a4c022a41d..c1cf4047c9494394d2a3884ba8830179013db7ff 100644 --- 
a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_gather.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_gather.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_scatter.py index 6685b0b83e7b1d92e7a772715991ead8ad94b153..61e021b81497e472cda5d72bdac557a0ca92d262 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_scatter.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_scatter.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/scatter.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/scatter.py index 3f4bacf3422543e11abb795c279876147f8610a8..f4605d9b46f387761b070352365f223dbfe69d47 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/scatter.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/scatter.py @@ -4,7 +4,7 @@ from typing import Any, Optional import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/stk_autocast.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/stk_autocast.py new file mode 100644 index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/stk_autocast.py @@ -0,0 +1,39 @@ +# vendored from +# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py +import functools +import torch + + +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, map): + return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, list) or isinstance(x, tuple): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def custom_fwd(fwd): + """Wrap a custom autograd function that always uses autocast dtype.""" + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + return decorate_fwd + + +def custom_bwd(bwd): + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + return decorate_bwd \ No newline at end of file diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..b7fe786b99bfb5b162e6633b91ccdd41c041c897 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:20a8e0a793ac29bc168d10e1c9e465082c2adb1582ff79d1a083798f9a955a5f +size 11857920 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so deleted file mode 100755 index 9f9e0e7706d6ae183a1e78ed2156b4e74e2a3ffd..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:b6cc8f982e35bfa07a9121807bacd3c3572d6ecb1495bcb2b6286b967fb20d58 -size 11857920 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py index 3cc5ae109015332d831f2357627d00ef67faa25d..f7aa5a700bc5673c743a3b0fb74107fab793fbad 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_dabb815 -ops = torch.ops._megablocks_dabb815 +from . import _megablocks_0586ba6 +ops = torch.ops._megablocks_0586ba6 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_dabb815::{op_name}" \ No newline at end of file + return f"_megablocks_0586ba6::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_gather.py index 8ae2ad8388b06db46a13f8fa46083619c44eefe2..189a7fa3518d660f29ea32e7a04827164af98d60 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_gather.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_gather.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_scatter.py index 6d8654bdc718cd893f7899e9cdb6cd20544d189c..cb937c0c106662ce8108c1cb926f8f063b163d3d 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_scatter.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_scatter.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/gather.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/gather.py index 4edf4541dac52abab94151efb414acaa7711f8f6..f1f87c1e7bed8d3589dd790805234976e0b05898 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/gather.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/gather.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_gather.py index 0ffe1369d6adbce5c4a54155c636d1a4c022a41d..c1cf4047c9494394d2a3884ba8830179013db7ff 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_gather.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_gather.py @@ -3,7 +3,7 @@ 
from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_scatter.py index 6685b0b83e7b1d92e7a772715991ead8ad94b153..61e021b81497e472cda5d72bdac557a0ca92d262 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_scatter.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_scatter.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/scatter.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/scatter.py index 3f4bacf3422543e11abb795c279876147f8610a8..f4605d9b46f387761b070352365f223dbfe69d47 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/scatter.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/scatter.py @@ -4,7 +4,7 @@ from typing import Any, Optional import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/stk_autocast.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/stk_autocast.py new file mode 100644 index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/stk_autocast.py @@ -0,0 +1,39 @@ +# vendored from +# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py +import functools +import torch + + +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, map): + return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, list) or isinstance(x, tuple): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def custom_fwd(fwd): + """Wrap a custom autograd function that always uses autocast dtype.""" + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + return decorate_fwd + + +def custom_bwd(bwd): + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + return decorate_bwd \ No newline at end of file diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..8b4fdfa424e62d65fd64e4a272b1f4e33b9f6d70 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b2f5209e69d36d632939923c20ab90c074fe0100d8a4efbabe5cdcd32ccbcfd2 +size 11923672 diff --git 
a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so deleted file mode 100755 index 60b4dab63ebafbe4c9318da846c2d6814b91376b..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:be9d5e5df42fd6d0db62397eae9c462b9775e952ce7f71fb687c3ea75dfe6a74 -size 11923672 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py index 3cc5ae109015332d831f2357627d00ef67faa25d..f7aa5a700bc5673c743a3b0fb74107fab793fbad 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_dabb815 -ops = torch.ops._megablocks_dabb815 +from . import _megablocks_0586ba6 +ops = torch.ops._megablocks_0586ba6 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_dabb815::{op_name}" \ No newline at end of file + return f"_megablocks_0586ba6::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_gather.py index 8ae2ad8388b06db46a13f8fa46083619c44eefe2..189a7fa3518d660f29ea32e7a04827164af98d60 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_gather.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_gather.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_scatter.py index 6d8654bdc718cd893f7899e9cdb6cd20544d189c..cb937c0c106662ce8108c1cb926f8f063b163d3d 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_scatter.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_scatter.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/gather.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/gather.py index 4edf4541dac52abab94151efb414acaa7711f8f6..f1f87c1e7bed8d3589dd790805234976e0b05898 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/gather.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/gather.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_gather.py index 0ffe1369d6adbce5c4a54155c636d1a4c022a41d..c1cf4047c9494394d2a3884ba8830179013db7ff 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_gather.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_gather.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from 
.stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_scatter.py index 6685b0b83e7b1d92e7a772715991ead8ad94b153..61e021b81497e472cda5d72bdac557a0ca92d262 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_scatter.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_scatter.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/scatter.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/scatter.py index 3f4bacf3422543e11abb795c279876147f8610a8..f4605d9b46f387761b070352365f223dbfe69d47 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/scatter.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/scatter.py @@ -4,7 +4,7 @@ from typing import Any, Optional import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/stk_autocast.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/stk_autocast.py new file mode 100644 index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/stk_autocast.py @@ -0,0 +1,39 @@ +# vendored from +# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py +import functools +import torch + + +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, map): + return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, list) or isinstance(x, tuple): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def custom_fwd(fwd): + """Wrap a custom autograd function that always uses autocast dtype.""" + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + return decorate_fwd + + +def custom_bwd(bwd): + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + return decorate_bwd \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..4668fc389f1ca9668919303afe1890f8e92abff7 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a1326299d5c4185310ff6f43fc0b3b71d48b1bd31001c954b3388d3cb5e08fbc +size 10517816 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so deleted file 
mode 100755 index 69f059780cd3ae0839b2d079bc031ccaad4a4da6..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:9c297e2817e6b2dd9af4f1166e448844f785ef45ed769d0289766bc9169767df -size 10517816 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py index 3cc5ae109015332d831f2357627d00ef67faa25d..f7aa5a700bc5673c743a3b0fb74107fab793fbad 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_dabb815 -ops = torch.ops._megablocks_dabb815 +from . import _megablocks_0586ba6 +ops = torch.ops._megablocks_0586ba6 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_dabb815::{op_name}" \ No newline at end of file + return f"_megablocks_0586ba6::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py index 8ae2ad8388b06db46a13f8fa46083619c44eefe2..189a7fa3518d660f29ea32e7a04827164af98d60 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py index 6d8654bdc718cd893f7899e9cdb6cd20544d189c..cb937c0c106662ce8108c1cb926f8f063b163d3d 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py index 4edf4541dac52abab94151efb414acaa7711f8f6..f1f87c1e7bed8d3589dd790805234976e0b05898 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py index 0ffe1369d6adbce5c4a54155c636d1a4c022a41d..c1cf4047c9494394d2a3884ba8830179013db7ff 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py 
b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py index 6685b0b83e7b1d92e7a772715991ead8ad94b153..61e021b81497e472cda5d72bdac557a0ca92d262 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py index 3f4bacf3422543e11abb795c279876147f8610a8..f4605d9b46f387761b070352365f223dbfe69d47 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py @@ -4,7 +4,7 @@ from typing import Any, Optional import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/stk_autocast.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/stk_autocast.py new file mode 100644 index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/stk_autocast.py @@ -0,0 +1,39 @@ +# vendored from +# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py +import functools +import torch + + +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, map): + return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, list) or isinstance(x, tuple): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def custom_fwd(fwd): + """Wrap a custom autograd function that always uses autocast dtype.""" + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + return decorate_fwd + + +def custom_bwd(bwd): + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + return decorate_bwd \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..e678ded725389f50ee596f97c1b3c7ddd4929902 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ed275d43953fbc984a20503c0d55f56b337576bd43e2c94682b4de91a8df6c8d +size 11931080 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so deleted file mode 100755 index 96ff9ba92b272f12146d8eea4f96fe28724fc82d..0000000000000000000000000000000000000000 --- 
a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:7fbf97bdf349597f84616a22e3bc8c25cba2d77aef15dfa84bde284ffa51fe38 -size 11931080 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py index 3cc5ae109015332d831f2357627d00ef67faa25d..f7aa5a700bc5673c743a3b0fb74107fab793fbad 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_dabb815 -ops = torch.ops._megablocks_dabb815 +from . import _megablocks_0586ba6 +ops = torch.ops._megablocks_0586ba6 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_dabb815::{op_name}" \ No newline at end of file + return f"_megablocks_0586ba6::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py index 8ae2ad8388b06db46a13f8fa46083619c44eefe2..189a7fa3518d660f29ea32e7a04827164af98d60 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py index 6d8654bdc718cd893f7899e9cdb6cd20544d189c..cb937c0c106662ce8108c1cb926f8f063b163d3d 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py index 4edf4541dac52abab94151efb414acaa7711f8f6..f1f87c1e7bed8d3589dd790805234976e0b05898 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py index 0ffe1369d6adbce5c4a54155c636d1a4c022a41d..c1cf4047c9494394d2a3884ba8830179013db7ff 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py index 
6685b0b83e7b1d92e7a772715991ead8ad94b153..61e021b81497e472cda5d72bdac557a0ca92d262 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py index 3f4bacf3422543e11abb795c279876147f8610a8..f4605d9b46f387761b070352365f223dbfe69d47 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py @@ -4,7 +4,7 @@ from typing import Any, Optional import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/stk_autocast.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/stk_autocast.py new file mode 100644 index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/stk_autocast.py @@ -0,0 +1,39 @@ +# vendored from +# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py +import functools +import torch + + +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, map): + return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, list) or isinstance(x, tuple): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def custom_fwd(fwd): + """Wrap a custom autograd function that always uses autocast dtype.""" + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + return decorate_fwd + + +def custom_bwd(bwd): + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + return decorate_bwd \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..e707d1a69cfad2211ae17f034737b152b601a899 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_0586ba6.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cffb3e3e44310bba45bcf82b07fc3a2b188bbeb8f50f04784e440ba3bdf5fc0f +size 17892624 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so deleted file mode 100755 index a723817e44e7aec8375f871ad1bf14ee404c7d1d..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version 
https://git-lfs.github.com/spec/v1 -oid sha256:bac66fb2798cffcb10563bb1ad157a5e1d231dcf1fc615825b6c8e6f6b297d20 -size 17892624 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py index 3cc5ae109015332d831f2357627d00ef67faa25d..f7aa5a700bc5673c743a3b0fb74107fab793fbad 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_dabb815 -ops = torch.ops._megablocks_dabb815 +from . import _megablocks_0586ba6 +ops = torch.ops._megablocks_0586ba6 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_dabb815::{op_name}" \ No newline at end of file + return f"_megablocks_0586ba6::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_gather.py index 8ae2ad8388b06db46a13f8fa46083619c44eefe2..189a7fa3518d660f29ea32e7a04827164af98d60 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_gather.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_gather.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_scatter.py index 6d8654bdc718cd893f7899e9cdb6cd20544d189c..cb937c0c106662ce8108c1cb926f8f063b163d3d 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_scatter.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_scatter.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/gather.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/gather.py index 4edf4541dac52abab94151efb414acaa7711f8f6..f1f87c1e7bed8d3589dd790805234976e0b05898 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/gather.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/gather.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_gather.py index 0ffe1369d6adbce5c4a54155c636d1a4c022a41d..c1cf4047c9494394d2a3884ba8830179013db7ff 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_gather.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_gather.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter.py index 6685b0b83e7b1d92e7a772715991ead8ad94b153..61e021b81497e472cda5d72bdac557a0ca92d262 100644 --- 
a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter.py @@ -3,7 +3,7 @@ from typing import Any import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/scatter.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/scatter.py index 3f4bacf3422543e11abb795c279876147f8610a8..f4605d9b46f387761b070352365f223dbfe69d47 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/scatter.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/scatter.py @@ -4,7 +4,7 @@ from typing import Any, Optional import torch -from stk.backend.autocast import custom_bwd, custom_fwd +from .stk_autocast import custom_bwd, custom_fwd from ..backend import kernels diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/stk_autocast.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/stk_autocast.py new file mode 100644 index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/stk_autocast.py @@ -0,0 +1,39 @@ +# vendored from +# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py +import functools +import torch + + +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, map): + return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, list) or isinstance(x, tuple): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def custom_fwd(fwd): + """Wrap a custom autograd function that always uses autocast dtype.""" + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + return decorate_fwd + + +def custom_bwd(bwd): + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + return decorate_bwd \ No newline at end of file
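
Note on the vendored autocast shim: every build variant in this patch replaces the external `stk.backend.autocast` import with the local `stk_autocast` module, so the prebuilt wheels no longer need the `stk` package at runtime. The sketch below shows how the two decorators are meant to pair with a `torch.autograd.Function`. It is a minimal sketch, not part of this patch: it assumes a CUDA device, the `Scale` function is a hypothetical example, and the import path reflects where these build variants install the module.

# Minimal usage sketch for the vendored decorators (assumes CUDA;
# `Scale` is a hypothetical example, not part of this patch).
import torch
from megablocks.ops.stk_autocast import custom_bwd, custom_fwd


class Scale(torch.autograd.Function):

    @staticmethod
    @custom_fwd  # under autocast, casts eligible CUDA float tensors to the autocast dtype
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x * scale

    @staticmethod
    @custom_bwd  # runs the backward pass with autocast disabled
    def backward(ctx, grad):
        return grad * ctx.scale, None  # no gradient for the non-tensor `scale`


x = torch.randn(4, 4, device="cuda", requires_grad=True)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y = Scale.apply(x, 2.0)  # forward receives x already cast to bfloat16
y.sum().backward()           # backward runs outside autocast

Two behavioral details appear to carry over from the upstream stk source referenced in the file header: `_cast` tests `isinstance(x, map)`, which no dict ever satisfies, so dict-valued arguments pass through uncast; and `torch.get_autocast_gpu_dtype()` is the older spelling that recent PyTorch releases deprecate in favor of `torch.get_autocast_dtype("cuda")`. Neither should affect the positional tensor arguments these ops actually pass.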