Spaces:
Runtime error
Runtime error
| import torch | |
| from torch.autograd import Function | |
| from pointops._C import subtraction_forward_cuda, subtraction_backward_cuda | |
class Subtraction(Function):
    """Autograd wrapper around the pointops subtraction CUDA kernels.

    Shapes (as documented on forward/backward):
        input1: (n, c), input2: (n, c), idx: (n, nsample)
        -> output: (n, nsample, c)

    The per-element arithmetic lives inside ``subtraction_forward_cuda`` /
    ``subtraction_backward_cuda``, which are not visible here.
    """

    @staticmethod
    def forward(ctx, input1, input2, idx):
        """
        input: input1: (n, c), input2: (n, c), idx: (n, nsample)
        output: (n, nsample, c)
        """
        # The kernel receives raw tensors plus (n, nsample, c), so it
        # presumably indexes them as dense buffers -- hence the check.
        assert input1.is_contiguous() and input2.is_contiguous()
        n, c = input1.shape
        nsample = idx.shape[-1]
        # Allocate on the inputs' device instead of the "current" CUDA device
        # (what the deprecated torch.cuda.FloatTensor constructor used).
        # dtype stays float32 and the buffer is zero-initialised, as before.
        output = torch.zeros(
            (n, nsample, c), dtype=torch.float32, device=input1.device
        )
        subtraction_forward_cuda(n, nsample, c, input1, input2, idx, output)
        # Only idx is needed by backward() to route gradients.
        ctx.save_for_backward(idx)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        input: grad_out: (n, nsample, c)
        output: grad_input1: (n, c), grad_input2: (n, c)
        """
        (idx,) = ctx.saved_tensors
        # Autograd may hand us a non-contiguous gradient (e.g. produced by an
        # expand/transpose downstream). forward() requires dense inputs for
        # the kernel, so enforce the same here; no-op if already contiguous.
        grad_output = grad_output.contiguous()
        n, nsample, c = grad_output.shape
        device = grad_output.device
        # Zero-initialised float32 buffers on the gradient's device, matching
        # what the original .zero_() calls produced.
        grad_input1 = torch.zeros((n, c), dtype=torch.float32, device=device)
        grad_input2 = torch.zeros((n, c), dtype=torch.float32, device=device)
        subtraction_backward_cuda(
            n, nsample, c, idx, grad_output, grad_input1, grad_input2
        )
        # idx receives no gradient.
        return grad_input1, grad_input2, None
# Functional alias for Subtraction.apply; call as
# subtraction(input1, input2, idx) with the shapes documented on
# Subtraction.forward.
subtraction = Subtraction.apply