Spaces:
Sleeping
Sleeping
File size: 4,842 Bytes
e75a247 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
from typing import Optional, Tuple, Union

from torch_geometric.typing import OptTensor, PairTensor, PairOptTensor
import torch
from torch import Tensor
from torch.nn import Linear
from torch_scatter import scatter
from torch_geometric.nn.conv import MessagePassing
import dgl
class GravNetConv(MessagePassing):
    r"""The GravNet operator from the `"Learning Representations of Irregular
    Particle-detector Geometry with Distance-weighted Graph
    Networks" <https://arxiv.org/abs/1902.07987>`_ paper, where the graph is
    dynamically constructed using nearest neighbors.

    The neighbors are constructed in a learnable low-dimensional projection of
    the feature space.
    A second projection of the input feature space is then propagated along
    the edges of that graph using distance weights that are derived by
    applying a Gaussian function to the distances.

    In this implementation the k-NN graph is built with DGL, independently
    for every graph contained in the batched graph handed to :meth:`forward`,
    and :meth:`forward` returns that graph and the learned coordinates in
    addition to the output features.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        space_dimensions (int): The dimensionality of the space used to
            construct the neighbors; referred to as :math:`S` in the paper.
        propagate_dimensions (int): The number of features to be propagated
            between the vertices; referred to as :math:`F_{\textrm{LR}}` in
            the paper.
        k (int): The number of nearest neighbors.
        num_workers (int): Stored as :obj:`self.num_workers`; the k-NN
            construction in this class does not read it.
            (default: :obj:`1`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`. :obj:`flow`
            defaults to :obj:`"target_to_source"` when not given.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        space_dimensions: int,
        propagate_dimensions: int,
        k: int,
        num_workers: int = 1,
        **kwargs
    ):
        # Use `setdefault` so that an explicit `flow=` from the caller is
        # honoured instead of clashing with a hard-coded keyword argument.
        kwargs.setdefault("flow", "target_to_source")
        super().__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.k = k
        self.num_workers = num_workers

        # Projection into the coordinate space used for the neighbor search.
        # Its weight is overwritten with a (rectangular) identity matrix, so
        # initially the coordinates are a copy of the leading input features.
        self.lin_s = Linear(in_channels, space_dimensions, bias=False)
        self.lin_s.weight.data.copy_(torch.eye(space_dimensions, in_channels))

        # Projection producing the features exchanged between vertices.
        self.lin_h = Linear(in_channels, propagate_dimensions)

        # Output transform; it consumes the mean- and max-aggregated messages
        # (2 * propagate_dimensions) concatenated with the raw input.
        self.lin = Linear(in_channels + 2 * propagate_dimensions, out_channels)

        # `reset_parameters()` is not invoked here: calling it would replace
        # the identity initialisation of `lin_s` with Linear's default one.

    def reset_parameters(self):
        """Re-initialise all three linear layers with their default scheme.

        This also discards the identity initialisation that ``__init__``
        applies to ``lin_s``.
        """
        self.lin_s.reset_parameters()
        self.lin_h.reset_parameters()
        self.lin.reset_parameters()

    def forward(
        self, g: dgl.DGLGraph, x: Tensor, batch: OptTensor = None
    ) -> Tuple[Tensor, dgl.DGLGraph, Tensor]:
        """Run one GravNet layer on the node features ``x``.

        Args:
            g: Batched DGL graph; it is only used by ``knn_per_graph`` to
                split the nodes into their individual graphs.
            x: Two-dimensional node feature matrix.
            batch: Accepted for interface compatibility; not used, since the
                graph membership of the nodes is taken from ``g``.

        Returns:
            A tuple ``(out, graph, s_l)`` holding the transformed node
            features, the batched k-NN graph that was built (with the learned
            coordinates stored in ``graph.ndata["s_l"]``) and the learned
            coordinates themselves.
        """
        assert x.dim() == 2, "Static graphs not supported in `GravNetConv`."

        h_l: Tensor = self.lin_h(x)  # features to propagate
        s_l: Tensor = self.lin_s(x)  # learned coordinates

        # Build the k-NN graph in coordinate space, one input graph at a time.
        graph = knn_per_graph(g, s_l, self.k)
        graph.ndata["s_l"] = s_l

        edges = graph.edges()
        row, col = edges[0], edges[1]
        edge_index = torch.stack([row, col], dim=0)

        # Gaussian weight of the squared distance between the edge endpoints.
        edge_weight = (s_l[row] - s_l[col]).pow(2).sum(-1)
        edge_weight = torch.exp(-10.0 * edge_weight)  # 10 gives a better spread

        # NOTE(review): `x` is passed as the pair `[h_l, None]`. PyG selects
        # the pair element that feeds `x_j` according to `flow`; confirm that
        # the configured flow picks `h_l` and not the `None` entry.
        # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)
        # This is the output feature transform.
        out = self.propagate(
            edge_index,
            x=[h_l, None],
            edge_weight=edge_weight,
            size=(s_l.size(0), s_l.size(0)),
        )

        # NOTE: unverified whether this concatenation matches RaggedGravNet
        # exactly; that implementation concatenates as well.
        return self.lin(torch.cat([out, x], dim=-1)), graph, s_l

    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:
        """Scale every per-edge feature vector by its scalar edge weight."""
        return x_j * edge_weight.unsqueeze(1)

    def aggregate(
        self, inputs: Tensor, index: Tensor, dim_size: Optional[int] = None
    ) -> Tensor:
        """Reduce the messages per node with both mean and max.

        The two reductions are concatenated along the last dimension, so the
        result is twice as wide as ``inputs``.
        """
        out_mean = scatter(
            inputs, index, dim=self.node_dim, dim_size=dim_size, reduce="mean"
        )
        out_max = scatter(
            inputs, index, dim=self.node_dim, dim_size=dim_size, reduce="max"
        )
        return torch.cat([out_mean, out_max], dim=-1)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}"
            f"({self.in_channels}, {self.out_channels}, k={self.k})"
        )
def knn_per_graph(g, sl, k):
    """Build a k-NN graph separately for each graph batched in ``g``.

    ``g`` is unbatched with DGL and ``sl`` is cut into consecutive row
    slices, one per graph, sized by that graph's node count (this assumes the
    rows of ``sl`` follow the node order of ``g``). Each slice is turned into
    a k-nearest-neighbor graph without self loops, and the per-graph results
    are batched back together.

    Args:
        g: Batched DGL graph that defines how many nodes belong to each graph.
        sl: Per-node coordinates used for the neighbor search.
        k: Number of nearest neighbors per node.

    Returns:
        A batched DGL graph containing the k-NN edges of every input graph.
    """
    knn_graphs = []
    start = 0
    for sub_graph in dgl.unbatch(g):
        # Rows [start, stop) of `sl` belong to the current graph.
        stop = start + sub_graph.number_of_nodes()
        knn_graphs.append(dgl.knn_graph(sl[start:stop], k, exclude_self=True))
        start = stop
    return dgl.batch(knn_graphs)
|