# Spaces: Runtime error
import torch | |
from torch.autograd import Function | |
from pointops._C import subtraction_forward_cuda, subtraction_backward_cuda | |
class Subtraction(Function):
    """Autograd op wrapping the ``pointops`` subtraction CUDA kernels.

    Forward calls ``subtraction_forward_cuda`` and backward calls
    ``subtraction_backward_cuda``. Shapes, as read by the code below:

    * ``input1``: ``(n, c)``
    * ``input2``: ``(n, c)``
    * ``idx``: ``(n, nsample)``
    * output: ``(n, nsample, c)``

    NOTE(review): the exact formula lives in the CUDA kernel, which is not
    visible here. It presumably computes
    ``output[i, j] = input1[i] - input2[idx[i, j]]`` — confirm against the
    kernel source. The kernels are also presumably float32-only and expect
    CUDA tensors; this wrapper does not check dtype or device.
    """

    @staticmethod
    def forward(ctx, input1, input2, idx):
        """Run the forward kernel.

        Args:
            ctx: autograd context; ``idx`` is saved on it for ``backward``.
            input1: tensor of shape ``(n, c)``.
            input2: tensor of shape ``(n, c)``.
            idx: tensor of shape ``(n, nsample)``; its last dimension sizes
                the output.

        Returns:
            A zero-initialised float32 tensor of shape ``(n, nsample, c)`` on
            ``input1``'s device, filled in place by ``subtraction_forward_cuda``.
        """
        # The previous contract required contiguous inputs (the kernel
        # presumably walks raw row-major buffers). ``.contiguous()`` is a
        # no-op for tensors that already are, and unlike ``assert`` it is
        # not stripped under ``python -O``.
        input1 = input1.contiguous()
        input2 = input2.contiguous()
        idx = idx.contiguous()

        n, c = input1.shape
        nsample = idx.shape[-1]
        # Allocate on the inputs' device rather than the "current" CUDA
        # device that the deprecated ``torch.cuda.FloatTensor`` would use.
        output = torch.zeros(
            (n, nsample, c), dtype=torch.float32, device=input1.device
        )
        subtraction_forward_cuda(n, nsample, c, input1, input2, idx, output)
        ctx.save_for_backward(idx)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Run the backward kernel.

        Args:
            ctx: autograd context holding the ``idx`` saved in ``forward``.
            grad_output: gradient w.r.t. the output, shape ``(n, nsample, c)``.

        Returns:
            ``(grad_input1, grad_input2, None)``: float32 tensors of shape
            ``(n, c)`` filled by ``subtraction_backward_cuda``. ``idx`` is an
            index tensor and receives no gradient.
        """
        (idx,) = ctx.saved_tensors
        # Autograd may pass a non-contiguous gradient (e.g. the stride-0
        # expanded gradient produced by ``.sum()``). The kernel needs a
        # dense buffer, so materialise one.
        grad_output = grad_output.contiguous()
        n, nsample, c = grad_output.shape
        grad_input1 = torch.zeros(
            (n, c), dtype=torch.float32, device=grad_output.device
        )
        grad_input2 = torch.zeros(
            (n, c), dtype=torch.float32, device=grad_output.device
        )
        subtraction_backward_cuda(
            n, nsample, c, idx, grad_output, grad_input1, grad_input2
        )
        return grad_input1, grad_input2, None
# Functional entry point: ``subtraction(input1, input2, idx)`` runs the
# ``Subtraction`` autograd op and returns its ``(n, nsample, c)`` output.
subtraction = Subtraction.apply