kernel
Commit 0586ba6 · 1 parent: 2b84d84
drbh committed

fix: vendor stk decorators

torch-ext/megablocks/ops/binned_gather.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
 
torch-ext/megablocks/ops/binned_scatter.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
 
torch-ext/megablocks/ops/gather.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
 
torch-ext/megablocks/ops/padded_gather.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
 
torch-ext/megablocks/ops/padded_scatter.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Any
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
 
torch-ext/megablocks/ops/scatter.py CHANGED
@@ -4,7 +4,7 @@
 from typing import Any, Optional
 
 import torch
-from stk.backend.autocast import custom_bwd, custom_fwd
+from .stk_autocast import custom_bwd, custom_fwd
 
 from ..backend import kernels
 
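
For context on the import swap above: megablocks applies these two decorators to the staticmethods of torch.autograd.Function subclasses in each of the six ops modules, which is why every file imports the custom_bwd/custom_fwd pair. A minimal sketch of that pattern follows; the op name and body are illustrative placeholders, not megablocks' actual kernels.

import torch

from .stk_autocast import custom_bwd, custom_fwd


class _ExampleOp(torch.autograd.Function):
    """Illustrative op following the decorator pattern used above."""

    @staticmethod
    @custom_fwd
    def forward(ctx, x):
        # Under autocast, x arrives pre-cast to the autocast GPU dtype and
        # autocast itself is disabled, so the body runs in that dtype as-is.
        return x * 2

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        # Runs with autocast disabled; the gradient keeps its incoming dtype.
        return grad * 2

Calling _ExampleOp.apply(x) behaves identically before and after this commit: only the import location of the decorators changed, not their behavior.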
 
torch-ext/megablocks/ops/stk_autocast.py ADDED
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+    return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+    if isinstance(x, torch.Tensor) and _is_eligible(x):
+        return x.to(dtype)
+    elif isinstance(x, map):
+        return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+    elif isinstance(x, list) or isinstance(x, tuple):
+        return type(x)(map(lambda y: _cast(y, dtype), x))
+    return x
+
+
+def custom_fwd(fwd):
+    """Wrap a custom autograd function that always uses autocast dtype."""
+
+    @functools.wraps(fwd)
+    def decorate_fwd(*args, **kwargs):
+        if torch.is_autocast_enabled():
+            with torch.autocast(device_type="cuda", enabled=False):
+                dtype = torch.get_autocast_gpu_dtype()
+                return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+        return fwd(*args, **kwargs)
+    return decorate_fwd
+
+
+def custom_bwd(bwd):
+    @functools.wraps(bwd)
+    def decorate_bwd(*args, **kwargs):
+        with torch.autocast(device_type="cuda", enabled=False):
+            return bwd(*args, **kwargs)
+    return decorate_bwd
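
A rough behavior sketch for the vendored decorators, assuming a CUDA device and that the module is importable as megablocks.ops.stk_autocast (path inferred from this repo layout): custom_fwd casts eligible tensors (CUDA floating point, excluding float64) to the autocast GPU dtype and invokes the wrapped function with autocast disabled, while custom_bwd only disables autocast around the backward pass. One upstream quirk carries over verbatim: isinstance(x, map) tests against Python's builtin map iterator type rather than dict, so dict-valued arguments pass through uncast.

import torch

from megablocks.ops.stk_autocast import custom_fwd  # import path assumed


@custom_fwd
def dtype_probe(*tensors):
    # Report the dtypes the wrapped function actually sees.
    return tuple(t.dtype for t in tensors)


x32 = torch.randn(4, device="cuda")                       # eligible: gets cast
x64 = torch.randn(4, device="cuda", dtype=torch.float64)  # float64: left alone

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    print(dtype_probe(x32, x64))  # expected: (torch.bfloat16, torch.float64)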