#pragma once #include void poly_norm(torch::Tensor &out, const torch::Tensor &input, const torch::Tensor &weights, const torch::Tensor &bias, double eps); void poly_norm_backward(torch::Tensor &input_grad, torch::Tensor &weight_grad, torch::Tensor &bias_grad, const torch::Tensor &output_grad, const torch::Tensor &input, const torch::Tensor &weight, double eps); torch::Tensor rms_norm(const torch::Tensor &input, const torch::Tensor &weights, double eps); std::tuple rms_norm_backward(const torch::Tensor &output_grad, const torch::Tensor &input, const torch::Tensor &weight, double eps); void fused_mul_poly_norm(torch::Tensor &out, const torch::Tensor &input, const torch::Tensor &mul, const torch::Tensor &weights, const torch::Tensor &bias, double eps); void fused_mul_poly_norm_backward( torch::Tensor &input_grad, torch::Tensor &mul_grad, torch::Tensor &weight_grad, torch::Tensor &bias_grad, const torch::Tensor &output_grad, const torch::Tensor &input, const torch::Tensor &mul, const torch::Tensor &weight, const torch::Tensor &bias, double eps); std::tuple fused_add_rms_norm(const torch::Tensor &input, const torch::Tensor &residual, const torch::Tensor &weight, double eps); std::tuple fused_add_rms_norm_backward( const torch::Tensor &output_grad, const torch::Tensor &add_output_grad, const torch::Tensor &input, const torch::Tensor &weight, double eps, bool need_input_grad);