Spaces:
Runtime error
Runtime error
int channelnorm_cuda_forward( | |
at::Tensor& input1, | |
at::Tensor& output, | |
int norm_deg) { | |
channelnorm_kernel_forward(input1, output, norm_deg); | |
return 1; | |
} | |
int channelnorm_cuda_backward( | |
at::Tensor& input1, | |
at::Tensor& output, | |
at::Tensor& gradOutput, | |
at::Tensor& gradInput1, | |
int norm_deg) { | |
channelnorm_kernel_backward(input1, output, gradOutput, gradInput1, norm_deg); | |
return 1; | |
} | |
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | |
m.def("forward", &channelnorm_cuda_forward, "Channel norm forward (CUDA)"); | |
m.def("backward", &channelnorm_cuda_backward, "Channel norm backward (CUDA)"); | |
} | |