import torch
from torch.autograd import Function

from pointops._C import (
    attention_relation_step_forward_cuda,
    attention_relation_step_backward_cuda,
    attention_fusion_step_forward_cuda,
    attention_fusion_step_backward_cuda,
)

class AttentionRelationStep(Function):
    @staticmethod
    def forward(ctx, query, key, weight, index_target, index_refer):
        """
        input - query: (n, g, c), key: (n, g, c),
                weight: (c), typically the all-ones vector 1_c for scatter attention,
                index_target: (m), index_refer: (m)
        output - relation: (m, g)
        """
        assert (
            query.is_contiguous()
            and key.is_contiguous()
            and index_target.is_contiguous()
            and index_refer.is_contiguous()
            and weight.is_contiguous()
        )
        assert index_target.shape[0] == index_refer.shape[0]
        _, g, c = query.shape
        m = index_target.shape[0]
        output = torch.cuda.FloatTensor(m, g).zero_()
        attention_relation_step_forward_cuda(
            m, g, c, query, key, weight, index_target.int(), index_refer.int(), output
        )
        ctx.save_for_backward(query, key, weight, index_target, index_refer)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        input - grad_output: (m, g)
        output - grad_query: (n, g, c), grad_key: (n, g, c), None, None, None
        """
        query, key, weight, index_target, index_refer = ctx.saved_tensors
        n, g, c = query.shape
        m = index_target.shape[0]
        grad_query = torch.cuda.FloatTensor(n, g, c).zero_()
        grad_key = torch.cuda.FloatTensor(n, g, c).zero_()
        grad_weight = torch.cuda.FloatTensor(c).zero_()
        attention_relation_step_backward_cuda(
            m,
            g,
            c,
            query,
            grad_query,
            key,
            grad_key,
            weight,
            grad_weight,
            index_target.int(),
            index_refer.int(),
            grad_output,
        )
        # grad_weight is filled by the kernel but not returned below, so no
        # gradient is propagated to weight (it is treated as a fixed vector).
        return grad_query, grad_key, None, None, None

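# A minimal pure-PyTorch sketch of what the relation kernel is presumed to
# compute, based only on the shapes documented above: a weighted per-group
# dot product between gathered query and key rows. This helper is an
# illustrative assumption, not part of the library; it is only useful as a
# slow reference (e.g. for sanity-checking on small inputs).
def attention_relation_step_reference(query, key, weight, index_target, index_refer):
    # query/key: (n, g, c), weight: (c), index_*: (m) -> relation: (m, g)
    q = query[index_target.long()]  # (m, g, c)
    k = key[index_refer.long()]     # (m, g, c)
    return (q * k * weight.view(1, 1, -1)).sum(dim=-1)  # (m, g)
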
class AttentionFusionStep(Function):
    @staticmethod
    def forward(ctx, weight, value, index_target, index_refer):
        """
        input - weight: (m, g), value: (n, g, c),
                index_target: (m), index_refer: (m)
        output - output: (n, g, c)
        """
        assert (
            weight.is_contiguous()
            and value.is_contiguous()
            and index_target.is_contiguous()
            and index_refer.is_contiguous()
        )
        assert index_target.shape[0] == index_refer.shape[0]
        n, g, c = value.shape
        m = index_refer.shape[0]
        output = torch.cuda.FloatTensor(n, g, c).zero_()
        attention_fusion_step_forward_cuda(
            m, g, c, weight, value, index_target.int(), index_refer.int(), output
        )
        ctx.save_for_backward(weight, value, index_target, index_refer)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        input - grad_output: (n, g, c)
        output - grad_weight: (m, g), grad_value: (n, g, c), None, None
        """
        weight, value, index_target, index_refer = ctx.saved_tensors
        n, g, c = value.shape
        m = index_target.shape[0]
        grad_weight = torch.cuda.FloatTensor(m, g).zero_()
        grad_value = torch.cuda.FloatTensor(n, g, c).zero_()
        attention_fusion_step_backward_cuda(
            m,
            g,
            c,
            weight,
            grad_weight,
            value,
            grad_value,
            index_target.int(),
            index_refer.int(),
            grad_output,
        )
        return grad_weight, grad_value, None, None

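# A minimal pure-PyTorch sketch of the fusion kernel's presumed semantics:
# each (target, refer) pair scatter-adds weight[i] * value[index_refer[i]]
# into the target row. Again an assumption derived from the documented
# shapes, intended only as a slow reference, not the library implementation.
def attention_fusion_step_reference(weight, value, index_target, index_refer):
    # weight: (m, g), value: (n, g, c), index_*: (m) -> output: (n, g, c)
    n, g, c = value.shape
    contrib = weight.unsqueeze(-1) * value[index_refer.long()]  # (m, g, c)
    output = torch.zeros(n, g, c, dtype=value.dtype, device=value.device)
    output.index_add_(0, index_target.long(), contrib)  # scatter-add per target index
    return output
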
attention_relation_step = AttentionRelationStep.apply
attention_fusion_step = AttentionFusionStep.apply
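
# Hedged usage sketch (the shapes, random tensors, and the softmax
# normalization below are illustrative choices, not taken from the library):
# compute pairwise relations, normalize them, then fuse values back onto the
# target points. Requires a CUDA device and the compiled pointops._C extension.
if __name__ == "__main__" and torch.cuda.is_available():
    n, g, c, m = 8, 4, 16, 32
    query = torch.rand(n, g, c, device="cuda")
    key = torch.rand(n, g, c, device="cuda")
    value = torch.rand(n, g, c, device="cuda")
    ones = torch.ones(c, device="cuda")  # the 1_c weight vector
    index_target = torch.randint(0, n, (m,), device="cuda", dtype=torch.int32)
    index_refer = torch.randint(0, n, (m,), device="cuda", dtype=torch.int32)

    relation = attention_relation_step(query, key, ones, index_target, index_refer)  # (m, g)
    attn = torch.softmax(relation, dim=0)  # toy normalization over all pairs
    fused = attention_fusion_step(attn, value, index_target, index_refer)  # (n, g, c)
    print(relation.shape, fused.shape)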