Spaces:
Runtime error
Runtime error
| # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # This work is made available under the Nvidia Source Code License-NC. | |
| # To view a copy of this license, check out LICENSE.md | |
| import torch | |
| from torch.autograd import Function | |
| import voxlib | |
# Cheatsheet for writing torch.autograd.Function subclasses:
# mark_dirty() must be used to mark any input that is modified in place by the forward function.
# mark_non_differentiable() must be used to mark any output that is not differentiable.
class PositionalEncodingFunction(Function):
    """Autograd wrapper around the ``voxlib`` positional-encoding kernels.

    ``forward`` and ``backward`` are declared as static methods, which is
    what ``Function.apply`` requires; without the decorators PyTorch treats
    the class as a legacy (non-static) autograd function and refuses to run
    it.
    """

    @staticmethod
    def forward(ctx, in_feature, pe_degrees, dim, incl_orig):
        """Encode ``in_feature`` with ``voxlib.positional_encoding``.

        Args:
            ctx: Autograd context used to stash state for ``backward``.
            in_feature: Input tensor handed to the native kernel.
            pe_degrees: Number of encoding degrees, forwarded to the kernel.
            dim: Dimension argument, forwarded to the kernel.
            incl_orig: Flag forwarded to the kernel (presumably "include the
                original features in the output" -- confirm against voxlib).

        Returns:
            The tensor returned by ``voxlib.positional_encoding``.
        """
        out_feature = voxlib.positional_encoding(in_feature, pe_degrees, dim, incl_orig)

        # The backward kernel consumes the encoded output rather than the
        # input, so that is the tensor saved on the context.
        ctx.save_for_backward(out_feature)
        ctx.pe_degrees = pe_degrees
        ctx.dim = dim
        ctx.incl_orig = incl_orig

        return out_feature

    @staticmethod
    def backward(ctx, out_feature_grad):
        """Compute the gradient with respect to ``in_feature``.

        Args:
            ctx: Autograd context populated by ``forward``.
            out_feature_grad: Gradient of the loss w.r.t. the forward output.

        Returns:
            A 4-tuple matching the arguments of ``forward``: the input
            gradient followed by ``None`` for ``pe_degrees``, ``dim`` and
            ``incl_orig``, which receive no gradient.
        """
        out_feature, = ctx.saved_tensors

        # torch::Tensor positional_encoding_backward(const torch::Tensor& out_feature_grad,
        #     const torch::Tensor& out_feature, int ndegrees, int dim, bool incl_orig) {
        in_feature_grad = voxlib.positional_encoding_backward(
            out_feature_grad, out_feature, ctx.pe_degrees, ctx.dim, ctx.incl_orig)

        return in_feature_grad, None, None, None
def positional_encoding(in_feature, pe_degrees, dim=-1, incl_orig=False):
    """Run the differentiable, ``voxlib``-backed positional encoding.

    Thin convenience wrapper over ``PositionalEncodingFunction.apply`` that
    supplies defaults for ``dim`` and ``incl_orig``.

    Args:
        in_feature: Tensor to encode.
        pe_degrees: Number of encoding degrees.
        dim: Dimension argument passed through to the kernel (default ``-1``).
        incl_orig: Flag passed through to the kernel (default ``False``).

    Returns:
        Whatever ``PositionalEncodingFunction.apply`` returns for the given
        arguments.
    """
    # Function.apply only takes positional arguments, so pass them in the
    # exact order forward() declares them.
    encode = PositionalEncodingFunction.apply
    encoded = encode(in_feature, pe_degrees, dim, incl_orig)
    return encoded
def positional_encoding_pt(pts, pe_degrees, dim=-1, incl_orig=False):
    """Pure-PyTorch implementation of the sin/cos positional encoding.

    For each degree ``i`` in ``range(pe_degrees)`` the input is scaled by
    ``pi * 2**i`` and passed through ``sin`` and ``cos``. The pieces are
    concatenated along ``dim`` in the order
    ``sin_0, cos_0, sin_1, cos_1, ..., [pts]``.

    Example with the default ``dim=-1``: an input of shape ``(N, C)`` yields
    an output of shape ``(N, 2 * pe_degrees * C)``, or
    ``(N, (2 * pe_degrees + 1) * C)`` when ``incl_orig`` is true.

    Args:
        pts: Tensor to encode.
        pe_degrees: Number of frequency degrees.
        dim: Dimension along which the encoded pieces are concatenated.
        incl_orig: If true, append the unmodified ``pts`` after the sin/cos
            terms.

    Returns:
        The concatenated encoding tensor.

    Note:
        With ``pe_degrees == 0`` and ``incl_orig`` false there is nothing to
        concatenate and ``torch.cat`` rejects the empty list.
    """
    import math

    pieces = []
    for degree in range(pe_degrees):
        # Compute the shared sin/cos argument once per degree; the operand
        # order is kept so the result is bit-identical to evaluating
        # ``pts * pi * 2 ** degree`` separately for each call.
        scaled = pts * math.pi * 2 ** degree
        pieces.append(torch.sin(scaled))
        pieces.append(torch.cos(scaled))
    if incl_orig:
        pieces.append(pts)
    return torch.cat(pieces, dim=dim)
if __name__ == '__main__':
    # Self-check: compare the voxlib-backed encoding against the pure-PyTorch
    # one on random CUDA data and report the element-wise error.
    num_degrees = 4
    samples = torch.rand(384, 512, 5, 48).cuda() * 1024

    reference = positional_encoding_pt(samples, num_degrees, incl_orig=True)
    fused = positional_encoding(samples, num_degrees, incl_orig=True)

    abs_error = torch.abs(reference - fused)
    print(abs_error)

    matches = torch.allclose(reference, fused, rtol=1e-05, atol=1e-05)
    print(matches)