#pragma once #include namespace megablocks { // Forward pass: replicate values from x according to bin sizes void replicate_forward(torch::Tensor x, torch::Tensor bins, torch::Tensor out); // Backward pass: reduce gradients back to bins using segmented reduction void replicate_backward(torch::Tensor grad, torch::Tensor bins, torch::Tensor out); } // namespace megablocks