File size: 670 Bytes
44e9845
 
 
 
e85ecc9
 
f3b99fb
 
 
1
2
3
4
5
6
7
8
9
10
#pragma once

#include <torch/torch.h>

/// Forward pass of PolyNorm.
///
/// Writes its result into @p out (taken by non-const reference).
/// NOTE(review): presumably this is the PolyNorm activation from
/// "Polynomial Composition Activations" (Zhuo et al., 2024). That is, a
/// @p weight-scaled sum of RMS-normalized x^3, x^2 and x, plus @p bias.
/// Confirm against the kernel implementation.
///
/// @param out     Output tensor, written in place. Assumed to be
///                pre-allocated by the caller with the shape/dtype of
///                @p input (TODO confirm).
/// @param input   Input activations.
/// @param weight  Scale parameter(s); its gradient is produced by
///                poly_norm_backward.
/// @param bias    Bias parameter; its gradient is produced by
///                poly_norm_backward.
/// @param eps     Small constant, presumably added inside the RMS
///                denominator for numerical stability.
void poly_norm(torch::Tensor& out, const torch::Tensor& input,
               const torch::Tensor& weight, const torch::Tensor& bias,
               double eps);

/// Backward pass of poly_norm.
///
/// Given the upstream gradient @p output_grad, writes the gradients of the
/// forward inputs into @p input_grad, @p weight_grad and @p bias_grad. All
/// three are taken by non-const reference and written in place. They are
/// presumably pre-allocated to match @p input, @p weight and the forward
/// `bias` respectively (TODO confirm).
///
/// The forward `bias` value is not an argument here. Presumably it enters
/// the forward pass additively, so none of the gradients depend on it.
///
/// @param input_grad   Output: gradient w.r.t. the forward `input`.
/// @param weight_grad  Output: gradient w.r.t. the forward `weight`.
/// @param bias_grad    Output: gradient w.r.t. the forward `bias`.
/// @param output_grad  Upstream gradient w.r.t. the forward `out`.
/// @param input        Presumably the same tensor passed to poly_norm.
/// @param weight       Presumably the same tensor passed to poly_norm.
/// @param eps          Presumably must equal the eps used in the forward call.
void poly_norm_backward(torch::Tensor& input_grad, torch::Tensor& weight_grad,
                        torch::Tensor& bias_grad,
                        const torch::Tensor& output_grad,
                        const torch::Tensor& input,
                        const torch::Tensor& weight, double eps);

/// Forward pass of RMSNorm.
///
/// Writes its result into @p out (taken by non-const reference).
/// NOTE(review): presumably the standard RMSNorm,
/// out = input / sqrt(mean(input^2) + eps) * weight, with the mean taken
/// over the last dimension. Confirm against the kernel implementation.
///
/// @param out     Output tensor, written in place. Assumed to be
///                pre-allocated by the caller with the shape/dtype of
///                @p input (TODO confirm).
/// @param input   Input activations.
/// @param weight  Scale parameter; its gradient is produced by
///                rms_norm_backward.
/// @param eps     Small constant, presumably added inside the RMS
///                denominator for numerical stability.
void rms_norm(torch::Tensor& out, const torch::Tensor& input,
              const torch::Tensor& weight, double eps);

/// Backward pass of rms_norm.
///
/// Given the upstream gradient @p output_grad, writes the gradients of the
/// forward inputs into @p input_grad and @p weight_grad. Both are taken by
/// non-const reference and written in place. They are presumably
/// pre-allocated to match @p input and @p weight (TODO confirm).
///
/// @param input_grad   Output: gradient w.r.t. the forward `input`.
/// @param weight_grad  Output: gradient w.r.t. the forward `weight`.
/// @param output_grad  Upstream gradient w.r.t. the forward `out`.
/// @param input        Presumably the same tensor passed to rms_norm.
/// @param weight       Presumably the same tensor passed to rms_norm.
/// @param eps          Presumably must equal the eps used in the forward call.
void rms_norm_backward(torch::Tensor& input_grad, torch::Tensor& weight_grad,
                       const torch::Tensor& output_grad,
                       const torch::Tensor& input,
                       const torch::Tensor& weight, double eps);