Various small fixes
flake.lock CHANGED

@@ -98,11 +98,11 @@
       ]
     },
     "locked": {
-      "lastModified":
-      "narHash": "sha256-
+      "lastModified": 1751014803,
+      "narHash": "sha256-9Xfq2k3uPfB602NwQF+zAY2GQZiKUN1G7Q6XiDCUR8Y=",
       "owner": "huggingface",
       "repo": "kernel-builder",
-      "rev": "
+      "rev": "bbc4e712ff2046e217818e97de2201e2b996756e",
       "type": "github"
     },
     "original": {
flake.nix CHANGED

@@ -13,5 +13,37 @@
     kernel-builder.lib.genFlakeOutputs {
       path = ./.;
       rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
+      # Building with CUDA later than 12.4 fails with:
+      #
+      # error: 'ptxas' died due to signal 11 (Invalid memory reference)
+      #
+      # So, build for 12.4 only and copy to all the other build variants
+      # by hand (which works fine thanks to backward compat).
+      torchVersions = [
+        {
+          torchVersion = "2.6";
+          cudaVersion = "12.4";
+          cxx11Abi = false;
+          systems = [ "x86_64-linux" ];
+          upstreamVariant = true;
+        }
+        {
+          torchVersion = "2.6";
+          cudaVersion = "12.4";
+          cxx11Abi = true;
+          systems = [ "x86_64-linux" ];
+          upstreamVariant = true;
+        }
+        {
+          torchVersion = "2.7";
+          cudaVersion = "12.4";
+          cxx11Abi = true;
+          systems = [
+            "x86_64-linux"
+            "aarch64-linux"
+          ];
+          upstreamVariant = true;
+        }
+      ];
     };
 }
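The "copy to all the other build variants by hand" step mentioned in the comment can be scripted. Below is a minimal sketch under stated assumptions: the variant directory names are hypothetical examples of kernel-builder's `build/<variant>` layout, and the real target list depends on which variants the repository actually publishes.

```python
import shutil
from pathlib import Path

# Hypothetical variant names; the actual build/<variant> directories are
# produced by kernel-builder and may differ.
source = Path("build/torch27-cxx11-cu124-x86_64-linux")
targets = [
    Path("build/torch27-cxx11-cu126-x86_64-linux"),
    Path("build/torch27-cxx11-cu128-x86_64-linux"),
]

# Binaries built against CUDA 12.4 keep working under later toolkits
# (backward compat), so a plain copy substitutes for a native build.
for target in targets:
    shutil.copytree(source, target, dirs_exist_ok=True)
```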
torch-ext/{flash_attn → flash_attn3}/__init__.py RENAMED
File without changes

torch-ext/{flash_attn → flash_attn3}/flash_attn_interface.py RENAMED
File without changes
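Only the module path changes; callers update their imports accordingly. A sketch, assuming `flash_attn_func` is among the names the renamed interface module exposes:

```python
# Before the rename (hypothetical caller):
# from flash_attn.flash_attn_interface import flash_attn_func
# After the rename:
from flash_attn3.flash_attn_interface import flash_attn_func
```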
torch-ext/torch_binding.cpp CHANGED

@@ -5,7 +5,7 @@
 #include "torch_binding.h"
 
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
-
+  ops.def("fwd("
           "Tensor q,"
           "Tensor k,"
           "Tensor v,"
@@ -40,7 +40,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
           "int num_splits = 0,"
           "bool? pack_gqa = None,"
           "int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)");
-
+  ops.def("bwd("
           "Tensor dout,"
           "Tensor q,"
           "Tensor k,"
@@ -63,12 +63,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
           "float softcap = 0.0,"
           "bool deterministic = False,"
           "int sm_margin = 0) -> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)");
-
+  ops.def("fwd_combine("
           "Tensor out_partial,"
           "Tensor lse_partial,"
           "Tensor(out!)? out = None,"
           "ScalarType? out_dtype = None) -> (Tensor(out!), Tensor)");
-
+  ops.def("get_scheduler_metadata("
           "int batch_size,"
           "int max_seqlen_q,"
           "int max_seqlen_k,"
@@ -94,10 +94,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
           "bool? pack_gqa = None,"
           "int sm_margin = 0) -> Tensor");
 
-
-
-
-
+  ops.impl("fwd", &mha_fwd);
+  ops.impl("bwd", &mha_bwd);
+  ops.impl("fwd_combine", &mha_combine);
+  ops.impl("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata);
 }
 
 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
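The `ops.def`/`ops.impl` pairs register the kernels with the PyTorch dispatcher under whatever namespace `TORCH_EXTENSION_NAME` expands to at build time. A minimal sketch of inspecting the result, assuming the built library loads under a hypothetical `flash_attn3` namespace and path:

```python
import torch

# Hypothetical library path and namespace; both are fixed at build time
# (TORCH_EXTENSION_NAME), not by this snippet.
torch.ops.load_library(
    "build/torch27-cxx11-cu124-x86_64-linux/flash_attn3/_flash_attn3.so"
)

# Every ops.def(...) above becomes a schema the dispatcher can report back.
for name in ("fwd", "bwd", "fwd_combine", "get_scheduler_metadata"):
    op = getattr(torch.ops.flash_attn3, name)
    print(op.default._schema)  # e.g. flash_attn3::fwd(Tensor q, Tensor k, ...
```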