namespace megablocks { | |
// Forward pass: replicate values from x according to bin sizes | |
void replicate_forward(torch::Tensor x, | |
torch::Tensor bins, | |
torch::Tensor out); | |
// Backward pass: reduce gradients back to bins using segmented reduction | |
void replicate_backward(torch::Tensor grad, | |
torch::Tensor bins, | |
torch::Tensor out); | |
} // namespace megablocks |