# jetclustering/src/layers/GravNetConv3.py
# Source revision: commit e75a247 by gregorkrzmanc.
from typing import Optional, Tuple, Union

import dgl
import dgl.function as fn
import numpy as np
import torch
import torch.nn as nn
from dgl.nn import EdgeWeightNorm
from torch import Tensor
from torch.nn import Linear
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import OptTensor, PairTensor, PairOptTensor
from torch_scatter import scatter

# import torch_cmspepr
class GravNetConv(MessagePassing):
    r"""The GravNet operator from the `"Learning Representations of Irregular
    Particle-detector Geometry with Distance-weighted Graph
    Networks" <https://arxiv.org/abs/1902.07987>`_ paper, where the graph is
    dynamically constructed using nearest neighbors.

    The neighbors are constructed in a learnable low-dimensional projection of
    the feature space. A second projection of the input feature space is then
    propagated from the neighbors to each vertex using distance weights that
    are derived by applying a Gaussian function to the distances.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        space_dimensions (int): The dimensionality of the space used to
            construct the neighbors; referred to as :math:`S` in the paper.
        propagate_dimensions (int): The number of features to be propagated
            between the vertices; referred to as :math:`F_{\textrm{LR}}` in the
            paper.
        k (int): The number of nearest neighbors passed to the k-NN search.
        num_workers (int): Stored as ``self.num_workers`` but not used by this
            implementation. (default: :obj:`1`)
        weird_batchnom (bool): Currently unused; kept for backward
            compatibility of the constructor signature. (default: :obj:`False`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        space_dimensions: int,
        propagate_dimensions: int,
        k: int,
        num_workers: int = 1,
        weird_batchnom=False,
        **kwargs
    ):
        # flow="target_to_source": per PyG semantics, messages are built from
        # x[edge_index[1]] and aggregated onto the nodes in edge_index[0].
        super().__init__(flow="target_to_source", **kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.k = k
        self.num_workers = num_workers
        # Learned projection into the S-dimensional space used for the k-NN search.
        self.lin_s = Linear(in_channels, space_dimensions, bias=False)
        # Learned projection of the features that are propagated along edges.
        self.lin_h = Linear(in_channels, propagate_dimensions)
        # Output layer over [mean-aggregate, max-aggregate, raw input features].
        self.lin = Linear(in_channels + 2 * propagate_dimensions, out_channels)
        # Not used by forward(); kept so the module's attributes are unchanged.
        self.norm = EdgeWeightNorm(norm="right", eps=0.0)

    def reset_parameters(self):
        """Re-initialize the k-NN-space and propagation projections.

        NOTE(review): ``self.lin`` is not reset here (its reset call was
        commented out in the original code); confirm whether it should be.
        """
        self.lin_s.reset_parameters()
        self.lin_h.reset_parameters()

    def forward(
        self, g, x: Tensor, original_coords: Tensor, batch: OptTensor = None
    ) -> Tuple[Tensor, dgl.DGLGraph, Tensor, int, int]:
        """Run one GravNet step: build a k-NN graph in a learned space and
        aggregate distance-weighted neighbor features.

        Args:
            g: Batched DGL graph. It is only passed to ``knn_per_graph``, which
                uses its per-graph node counts to split ``x`` into graphs.
            x: Node features; must be 2-D with last dimension ``in_channels``.
            original_coords: Unused; kept for interface compatibility.
            batch: Unused; kept for interface compatibility.

        Returns:
            A 5-tuple ``(out, graph, s_l, 0, 0)``:

            * ``out``: output features, last dimension ``out_channels``.
            * ``graph``: the batched k-NN graph built on ``s_l``, with ``s_l``
              stored in ``graph.ndata["s_l"]``.
            * ``s_l``: the learned coordinates (``space_dimensions`` columns).
            * The last two entries are placeholders for the (disabled)
              neighbor-distance and latent-space regulariser losses. They are
              always ``0``.
        """
        assert x.dim() == 2, "Static graphs not supported in `GravNetConv`."
        h_l: Tensor = self.lin_h(x)  # features to propagate
        s_l: Tensor = self.lin_s(x)  # learned coordinates for the k-NN search

        graph = knn_per_graph(g, s_l, self.k)
        graph.ndata["s_l"] = s_l
        row, col = graph.edges()
        edge_index = torch.stack([row, col], dim=0)

        # Squared Euclidean distance between the edge endpoints in learned space.
        edge_weight = (s_l[edge_index[0]] - s_l[edge_index[1]]).pow(2).sum(-1)
        # Gaussian weight exp(-d^2). The sqrt/square round trip effectively
        # adds a 1e-6 offset and is kept as-is to preserve exact numerics.
        edge_weight = torch.sqrt(edge_weight + 1e-6)
        edge_weight = torch.exp(-torch.square(edge_weight))

        # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)
        out = self.propagate(
            edge_index,
            x=h_l,
            edge_weight=edge_weight,
            size=(s_l.size(0), s_l.size(0)),
        )
        # Concatenate the aggregated messages with the raw input features.
        out = self.lin(torch.cat([out, x], dim=-1))
        return out, graph, s_l, 0, 0

    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:
        """Scale each neighbor's features by its scalar edge weight."""
        return x_j * edge_weight.unsqueeze(1)

    def aggregate(
        self, inputs: Tensor, index: Tensor, dim_size: Optional[int] = None
    ) -> Tensor:
        """Concatenate mean- and max-reductions of the messages per node.

        The result has twice the feature width of ``inputs``, which matches
        the ``2 * propagate_dimensions`` term in ``self.lin``.
        """
        out_mean = scatter(
            inputs, index, dim=self.node_dim, dim_size=dim_size, reduce="mean"
        )
        out_max = scatter(
            inputs, index, dim=self.node_dim, dim_size=dim_size, reduce="max"
        )
        return torch.cat([out_mean, out_max], dim=-1)

    def __repr__(self):
        return "{}({}, {}, k={})".format(
            self.__class__.__name__, self.in_channels, self.out_channels, self.k
        )
def knn_per_graph(g, sl, k):
    """Build a k-NN graph separately for each graph of a batched DGL graph.

    ``sl`` is split into consecutive row slices, one per graph in
    ``dgl.unbatch(g)``, sized by that graph's node count. A k-NN graph is built
    on each slice with ``torch_cmspepr.knn_graph``, its self-loops are removed,
    and the per-graph results are re-batched.

    Args:
        g: Batched DGL graph; only its per-graph node counts are used.
        sl: Per-node coordinates, row-aligned with the nodes of ``g``.
        k: Number of neighbors passed to ``torch_cmspepr.knn_graph``.

    Returns:
        A batched DGL graph of the new k-NN graphs (without self-loops).

    Raises:
        ImportError: If ``torch_cmspepr`` is not installed.
    """
    # The module-level import of torch_cmspepr is commented out, so the
    # original body raised NameError on every call. Import it lazily here so
    # the rest of the module still imports when the package is absent.
    try:
        import torch_cmspepr
    except ImportError as err:
        raise ImportError(
            "knn_per_graph requires the `torch_cmspepr` package "
            "(torch_cmspepr.knn_graph) to build the k-NN graphs."
        ) from err

    new_graphs = []
    node_counter = 0
    for graph in dgl.unbatch(g):
        num_nodes = graph.number_of_nodes()
        sls_graph = sl[node_counter : node_counter + num_nodes]
        edge_index = torch_cmspepr.knn_graph(sls_graph, k=k)
        new_graph = dgl.graph(
            (edge_index[0], edge_index[1]), num_nodes=sls_graph.shape[0]
        )
        new_graphs.append(dgl.remove_self_loop(new_graph))
        node_counter += num_nodes
    return dgl.batch(new_graphs)
class WeirdBatchNorm(nn.Module):
    """Batch normalization with "viscous" running statistics.

    On every forward call the running mean and variance move towards the
    current batch statistics by a fraction ``1 - viscosity`` of the
    difference. ``viscosity`` is then nudged towards ``max_viscosity`` by
    ``fluidity_decay``, so later updates change the statistics less. The input
    is normalized with the running statistics (not the raw batch statistics),
    then scaled by ``gamma`` and shifted by ``beta``.

    NOTE(review): ``mean`` and ``den`` are plain tensor attributes, not
    registered buffers. They are therefore not part of ``state_dict()`` and are
    not moved by ``.to()``; ``_calc_update`` moves them to the input's device
    instead. ``forward`` never checks ``self.training``, so the statistics are
    also updated in eval mode.

    Args:
        n_neurons: Number of features; the size of ``gamma``/``beta``.
            Statistics are taken over dim 0 of the input, presumably of
            shape ``(num_samples, n_neurons)`` -- confirm with callers.
        eps: Added to the running variance before the square root.
    """

    def __init__(self, n_neurons, eps=1e-5):
        super(WeirdBatchNorm, self).__init__()
        self.n_neurons = n_neurons
        # Learnable affine parameters (scale and shift).
        self.gamma = nn.Parameter(torch.ones(self.n_neurons))
        self.beta = nn.Parameter(torch.zeros(self.n_neurons))
        # Running statistics (plain tensors, see class note).
        self.mean = torch.zeros(self.n_neurons)
        self.den = torch.ones(self.n_neurons)
        # Fraction of the old statistic kept on each update; grows over time.
        self.viscosity = 0.5
        self.epsilon = eps
        self.fluidity_decay = 1e-4
        self.max_viscosity = 1

    def forward(self, input):
        """Update the running statistics, then normalize, scale and shift ``input``."""
        # Statistics are computed without gradient tracking ...
        x = input.detach()
        mu = x.mean(dim=0)
        var = x.var(dim=0, unbiased=False)
        self.mean = self._calc_update(self.mean, mu)
        self.den = self._calc_update(self.den, var)

        # ... but the normalization is applied to the un-detached input.
        # Previously the detached copy was normalized, which cut the gradient
        # to every layer producing `input`.
        center_input = input - self.mean
        denominator = (self.den + self.epsilon).sqrt()
        in_hat = center_input / denominator
        self._update_viscosity()
        return self.gamma * in_hat + self.beta

    def _calc_update(self, old, new):
        """Return ``old + (1 - viscosity) * (new - old)`` on ``new``'s device."""
        old = old.to(new.device)
        return old + (1 - self.viscosity) * (new - old)

    def _update_viscosity(self):
        """Move ``viscosity`` a ``fluidity_decay`` fraction towards ``max_viscosity``."""
        if self.fluidity_decay > 0:
            self.viscosity = (
                self.viscosity
                + (self.max_viscosity - self.viscosity) * self.fluidity_decay
            )