from typing import Optional, Tuple, Union

import dgl
import torch
from torch import Tensor
from torch.nn import Linear
from torch_scatter import scatter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import OptTensor, PairTensor, PairOptTensor


class GravNetConv(MessagePassing):
    r"""The GravNet operator from the `"Learning Representations of Irregular
    Particle-detector Geometry with Distance-weighted Graph Networks"
    <https://arxiv.org/abs/1902.07987>`_ paper, where the graph is dynamically
    constructed using nearest neighbors.

    The neighbors are constructed in a learnable low-dimensional projection of
    the feature space. A second projection of the input feature space is then
    propagated from the neighbors to each vertex using distance weights that
    are derived by applying a Gaussian function to the distances.

    In this variant the k-NN graph is built with DGL, separately for every
    graph contained in the (batched) DGL graph passed to :meth:`forward`.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        space_dimensions (int): The dimensionality of the space used to
            construct the neighbors; referred to as :math:`S` in the paper.
        propagate_dimensions (int): The number of features to be propagated
            between the vertices; referred to as :math:`F_{\textrm{LR}}` in
            the paper.
        k (int): The number of nearest neighbors (clamped per graph to the
            number of other nodes available).
        num_workers (int): Stored for API compatibility with the PyG operator;
            not used by this implementation. (default: :obj:`1`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        space_dimensions: int,
        propagate_dimensions: int,
        k: int,
        num_workers: int = 1,
        **kwargs
    ):
        super().__init__(flow="target_to_source", **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.k = k
        # Unused here (k-NN is delegated to DGL); kept so callers and
        # checkpoints that pass/expect it keep working.
        self.num_workers = num_workers

        self.lin_s = Linear(in_channels, space_dimensions, bias=False)
        # Initialise the learned space to the first ``space_dimensions`` input
        # features (rows beyond ``in_channels`` are zero).
        self.lin_s.weight.data.copy_(torch.eye(space_dimensions, in_channels))
        self.lin_h = Linear(in_channels, propagate_dimensions)
        # mean- and max-aggregated messages (2 * F_LR) concatenated with the input.
        self.lin = Linear(in_channels + 2 * propagate_dimensions, out_channels)

        # reset_parameters() is intentionally not called here: it would
        # re-randomise ``lin_s`` and discard the identity initialisation above.

    def reset_parameters(self):
        """Re-initialise all linear layers.

        Note that this re-randomises ``lin_s``, replacing the identity
        initialisation applied in ``__init__``.
        """
        self.lin_s.reset_parameters()
        self.lin_h.reset_parameters()
        self.lin.reset_parameters()

    def forward(
        self, g, x: Tensor, batch: OptTensor = None
    ) -> Tuple[Tensor, dgl.DGLGraph, Tensor]:
        """Apply one GravNet step.

        Args:
            g (dgl.DGLGraph): (Batched) input graph. Only its per-graph node
                counts are used, to keep the k-NN search inside each graph;
                its edges are ignored.
            x (Tensor): Node features, one row per node of ``g``.
            batch (Tensor, optional): Accepted for signature compatibility
                with the PyG operator; unused, graph membership comes from
                ``g``.

        Returns:
            tuple: ``(out, graph, s_l)`` where ``out`` has ``out_channels``
            columns, ``graph`` is the dynamically built k-NN graph (with
            ``s_l`` stored in ``graph.ndata["s_l"]``) and ``s_l`` are the
            learned space coordinates.
        """
        assert x.dim() == 2, "Static graphs not supported in `GravNetConv`."

        h_l: Tensor = self.lin_h(x)  # features to propagate (F_LR)
        s_l: Tensor = self.lin_s(x)  # learned coordinates (S)

        graph = knn_per_graph(g, s_l, self.k)
        graph.ndata["s_l"] = s_l

        # dgl.knn_graph points each edge from a neighbour (src) to the node
        # whose neighbourhood it belongs to (dst). With
        # flow="target_to_source" PyG aggregates at edge_index[0], so the
        # centre node must come first for every node to aggregate over its
        # own k nearest neighbours.
        src, dst = graph.edges()
        edge_index = torch.stack([dst, src], dim=0)

        edge_weight = (s_l[edge_index[0]] - s_l[edge_index[1]]).pow(2).sum(-1)
        edge_weight = torch.exp(-10.0 * edge_weight)  # 10 gives a better spread

        # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)
        # Distance-weighted propagation of h_l (the output feature transform).
        out = self.propagate(
            edge_index,
            x=[h_l, None],
            edge_weight=edge_weight,
            size=(s_l.size(0), s_l.size(0)),
        )

        # NOTE(review): concatenating the aggregated features with the raw
        # input follows the PyG operator; whether it matches RaggedGravNet
        # exactly is unverified.
        return self.lin(torch.cat([out, x], dim=-1)), graph, s_l

    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:
        """Scale each neighbour's propagated features by its Gaussian weight."""
        return x_j * edge_weight.unsqueeze(1)

    def aggregate(
        self, inputs: Tensor, index: Tensor, dim_size: Optional[int] = None
    ) -> Tensor:
        """Concatenate the mean- and max-aggregation of incoming messages."""
        out_mean = scatter(
            inputs, index, dim=self.node_dim, dim_size=dim_size, reduce="mean"
        )
        out_max = scatter(
            inputs, index, dim=self.node_dim, dim_size=dim_size, reduce="max"
        )
        return torch.cat([out_mean, out_max], dim=-1)

    def __repr__(self):
        return "{}({}, {}, k={})".format(
            self.__class__.__name__, self.in_channels, self.out_channels, self.k
        )


def knn_per_graph(g, sl, k):
    """Build a k-nearest-neighbour graph independently for each graph in a batch.

    Args:
        g (dgl.DGLGraph): (Batched) graph whose per-graph node counts decide
            how the rows of ``sl`` are split.
        sl (Tensor): Coordinates used for the neighbour search, one row per
            node of ``g`` in the same order.
        k (int): Number of neighbours per node (self excluded). Clamped to
            ``num_nodes - 1`` for graphs that are too small.

    Returns:
        dgl.DGLGraph: Batched graph with the same per-graph node counts as
        ``g``; each node's predecessors are its nearest neighbours within its
        own graph.
    """
    new_graphs = []
    for coords in torch.split(sl, g.batch_num_nodes().tolist()):
        num_nodes = coords.size(0)
        if num_nodes < 2:
            # No other node to connect to: keep an edgeless graph so the
            # batch layout still matches ``g``.
            no_edges = torch.empty(0, dtype=torch.int64, device=sl.device)
            new_graphs.append(dgl.graph((no_edges, no_edges), num_nodes=num_nodes))
            continue
        # With self-loops excluded a graph can supply at most num_nodes - 1
        # neighbours; clamp instead of requesting more than exist.
        new_graphs.append(
            dgl.knn_graph(coords, min(k, num_nodes - 1), exclude_self=True)
        )
    return dgl.batch(new_graphs)