kernel
File size: 383 Bytes
3224250
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#pragma once

// // Set default if not already defined
// #ifndef GROUPED_GEMM_CUTLASS
// #define GROUPED_GEMM_CUTLASS 0
// #endif

// #include <torch/extension.h>
#include <torch/torch.h>

namespace grouped_gemm {

void GroupedGemm(torch::Tensor a,
		 torch::Tensor b,
		 torch::Tensor c,
		 torch::Tensor batch_sizes,
		 bool trans_a, bool trans_b);

}  // namespace grouped_gemm