venite's picture
initial
f670afc
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
from torch.autograd import Function
import voxlib
"""
It takes world coordinates as input instead of block-local coordinates. Corner IDs are looked up on the fly to
save memory.
"""
class SparseTrilinearWorldCoordFunction(Function):
    """Autograd wrapper around voxlib's world-coordinate sparse trilinear kernel.

    Only ``in_feature`` is differentiable. ``corner_lut_t`` and ``ign_zero``
    never receive a gradient. Requesting a gradient for ``in_worldcoord``
    is rejected in ``backward``.
    """

    @staticmethod
    def forward(ctx, in_feature, corner_lut_t, in_worldcoord, ign_zero):
        """Interpolate ``in_feature`` at the world coordinates ``in_worldcoord``.

        Args:
            ctx: Autograd context; stores the inputs needed by ``backward``.
            in_feature: Features to interpolate from.
            corner_lut_t: Corner lookup table forwarded to voxlib
                (presumably used to look up corner IDs on the fly -- confirm).
            in_worldcoord: Query positions in world coordinates.
            ign_zero: Flag forwarded unchanged to the voxlib kernel.

        Returns:
            The interpolated features returned by
            ``voxlib.sp_trilinear_worldcoord``.
        """
        # NOTE(review): the trailing -1 is a voxlib kernel argument whose
        # meaning is defined in voxlib; confirm before changing.
        out_feature = voxlib.sp_trilinear_worldcoord(
            in_feature, corner_lut_t, in_worldcoord, ign_zero, -1)
        # ign_zero is not a tensor, so it is kept as a plain ctx attribute.
        ctx.ign_zero = ign_zero
        ctx.save_for_backward(in_feature, corner_lut_t, in_worldcoord)
        return out_feature

    @staticmethod
    def backward(ctx, out_feature_grad):
        """Propagate ``out_feature_grad`` back to ``in_feature``.

        Returns:
            One gradient per ``forward`` input, in order:
            ``(in_feature_grad, None, None, None)``.

        Raises:
            NotImplementedError: if a gradient w.r.t. ``in_worldcoord`` is
                requested; this function does not compute it.
        """
        # Explicit check instead of ``assert`` so it still fires under -O.
        if ctx.needs_input_grad[2]:
            raise NotImplementedError(
                "SparseTrilinearWorldCoordFunction does not support "
                "gradients with respect to in_worldcoord.")
        in_feature, corner_lut_t, in_worldcoord = ctx.saved_tensors
        in_feature_grad = None
        # Skip the kernel when in_feature does not need a gradient.
        if ctx.needs_input_grad[0]:
            # NOTE(review): the trailing False is a voxlib backward flag
            # (possibly "compute coordinate gradient") -- confirm in voxlib.
            in_feature_grad, = voxlib.sp_trilinear_worldcoord_backward(
                out_feature_grad, in_feature, corner_lut_t, in_worldcoord,
                ctx.ign_zero, False)
        # Exactly one entry per forward input (4 inputs).
        return in_feature_grad, None, None, None
def sparse_trilinear_interp_worldcoord(in_feature, corner_lut_t, in_worldcoord, ign_zero=False):
    """Run sparse trilinear interpolation at world coordinates.

    Functional entry point for ``SparseTrilinearWorldCoordFunction``. It
    forwards all four arguments to ``.apply`` unchanged, and ``ign_zero``
    defaults to ``False``.

    Returns:
        Whatever ``SparseTrilinearWorldCoordFunction.apply`` returns for
        these inputs.
    """
    apply_fn = SparseTrilinearWorldCoordFunction.apply
    args = (in_feature, corner_lut_t, in_worldcoord, ign_zero)
    return apply_fn(*args)