from typing import Optional, Union
from torch_geometric.typing import OptTensor, PairTensor, PairOptTensor
import torch
from torch import Tensor
from torch.nn import Linear
from torch_scatter import scatter
from torch_geometric.nn.conv import MessagePassing
import torch.nn as nn
import dgl
import dgl.function as fn
import numpy as np
from dgl.nn import EdgeWeightNorm
import torch_cmspepr  # required by knn_per_graph below


class GravNetConv(MessagePassing):
    r"""The GravNet operator from the `"Learning Representations of Irregular
    Particle-detector Geometry with Distance-weighted Graph
    Networks" <https://arxiv.org/abs/1902.07987>`_ paper, where the graph is
    dynamically constructed using nearest neighbors.

    The neighbors are constructed in a learnable low-dimensional projection of
    the feature space.
    A second projection of the input feature space is then propagated from the
    neighbors to each vertex using distance weights that are derived by
    applying a Gaussian function to the distances.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        space_dimensions (int): The dimensionality of the space used to
            construct the neighbors; referred to as :math:`S` in the paper.
        propagate_dimensions (int): The number of features to be propagated
            between the vertices; referred to as :math:`F_{\textrm{LR}}` in the
            paper.
        k (int): The number of nearest neighbors.
        num_workers (int): Number of workers to use for k-NN computation.
            Has no effect in case :obj:`batch` is not :obj:`None`, or the input
            lies on the GPU. (default: :obj:`1`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        space_dimensions: int,
        propagate_dimensions: int,
        k: int,
        num_workers: int = 1,
        weird_batchnom=False,
        **kwargs
    ):
        super(GravNetConv, self).__init__(flow="target_to_source", **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.k = k
        self.num_workers = num_workers
        # if weird_batchnom:
        #     self.batchnorm_gravconv = WeirdBatchNorm(out_channels)
        # else:
        #     self.batchnorm_gravconv = nn.BatchNorm1d(out_channels)
        self.lin_s = Linear(in_channels, space_dimensions, bias=False)
        # self.lin_s.weight.data.copy_(torch.eye(space_dimensions, in_channels))
        # torch.nn.init.xavier_uniform_(self.lin_s.weight, gain=0.001)
        self.lin_h = Linear(in_channels, propagate_dimensions)
        self.lin = Linear(in_channels + 2 * propagate_dimensions, out_channels)
        self.norm = EdgeWeightNorm(norm="right", eps=0.0)
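        # Note: self.norm (EdgeWeightNorm) is instantiated but currently unused,
        # since the corresponding normalisation call in forward is commented out.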
        # self.reset_parameters()

    def reset_parameters(self):
        self.lin_s.reset_parameters()
        self.lin_h.reset_parameters()
        # self.lin.reset_parameters()

    def forward(
        self, g, x: Tensor, original_coords: Tensor, batch: OptTensor = None
    ) -> Tensor:
        """"""
        assert x.dim() == 2, "Static graphs not supported in `GravNetConv`."

        b: OptTensor = None
        if isinstance(batch, Tensor):
            b = batch

        h_l: Tensor = self.lin_h(x)  #! input_feature_transform
        # print("weights input_feature_transform", self.lin_h.weight.data)
        # print("bias input_feature_transform", self.lin_h.bias.data)
        s_l: Tensor = self.lin_s(x)
        # print("weights input_spatial_transform", self.lin_s.weight.data)
        # print("bias input_spatial_transform", self.lin_s.bias.data)
        # print("coordinates INPUTS TO FIRST LAYER")
        # print(s_l)
        graph = knn_per_graph(g, s_l, self.k)
        graph.ndata["s_l"] = s_l
        row = graph.edges()[0]
        col = graph.edges()[1]
        edge_index = torch.stack([row, col], dim=0)
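        # Squared Euclidean distance between connected nodes in the learned
        # S-dimensional coordinate space.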
        edge_weight = (s_l[edge_index[0]] - s_l[edge_index[1]]).pow(2).sum(-1)
        # print("distancesq distancesq distancesq")
        # print(edge_weight)
        # edge_weight = edge_weight + 1e-5
        #! normalized edge weight
        # print("edge weight", edge_weight)
        # edge_weight = self.norm(graph, edge_weight)
        # print("normalized edge weight", edge_weight)
        # edge_weight = torch.exp(-10.0 * edge_weight)  # 10 gives a better spread
        #! AverageDistanceRegularizer
        # dist = edge_weight
        # dist = torch.sqrt(dist + 1e-6)
        # graph.edata["dist"] = dist
        # graph.ndata["ones"] = torch.ones_like(s_l)
        # # average dist per node, divided by the number of neighbours
        # graph.update_all(fn.u_mul_e("ones", "dist", "m"), fn.mean("m", "dist"))
        # avdist = graph.ndata["dist"]
        # loss_regularizing_neig = 1e-3 * torch.mean(torch.square(avdist - 0.5))
        # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)
        #! LLRegulariseGravNetSpace
        #! mean distance in original space vs distance in gravnet space between the neighbours
        # ? This code was checked on 12.01.24 and is correct
        # graph.edata["_edge_w"] = dist
        # graph.update_all(fn.copy_e("_edge_w", "m"), fn.sum("m", "in_weight"))
        # degs = graph.dstdata["in_weight"] + 1e-4
        # graph.dstdata["_dst_in_w"] = 1 / degs
        # graph.apply_edges(
        #     lambda e: {"_norm_edge_weights": e.dst["_dst_in_w"] * e.data["_edge_w"]}
        # )
        # dist = graph.edata["_norm_edge_weights"]
        # original_coord = g.ndata["pos_hits_xyz"]
        # #! distance in original coordinates
        # gndist = (
        #     (original_coord[edge_index[0]] - original_coord[edge_index[1]])
        #     .pow(2)
        #     .sum(-1)
        # )
        # gndist = torch.sqrt(gndist + 1e-6)
        # graph.edata["_edge_w_gndist"] = dist
        # graph.update_all(fn.copy_e("_edge_w_gndist", "m"), fn.sum("m", "in_weight"))
        # degs = graph.dstdata["in_weight"] + 1e-4
        # graph.dstdata["_dst_in_w"] = 1 / degs
        # graph.apply_edges(
        #     lambda e: {"_norm_edge_weights_gn": e.dst["_dst_in_w"] * e.data["_edge_w"]}
        # )
        # gndist = graph.edata["_norm_edge_weights_gn"]
        # loss_llregulariser = torch.mean(torch.square(dist - gndist))
        # print(torch.square(dist - gndist))
        #! this is the output_feature_transform
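        # Distance-weighted aggregation: each neighbour's features are scaled by a
        # Gaussian of its distance in the learned space, w_ij = exp(-(d_ij^2 + 1e-6)),
        # consistent with the Gaussian weighting described in the class docstring.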
        edge_weight = torch.sqrt(edge_weight + 1e-6)
        edge_weight = torch.exp(-torch.square(edge_weight))
        out = self.propagate(
            edge_index,
            x=h_l,
            edge_weight=edge_weight,
            size=(s_l.size(0), s_l.size(0)),
        )
        # print("outfeats", out)
        #! not sure this concatenation is exactly the same as in RaggedGravNet, but that also concatenates
        out = self.lin(torch.cat([out, x], dim=-1))
        # out = self.batchnorm_gravconv(out)
        return (
            out,
            graph,
            s_l,
            0,  # loss_regularizing_neig
            0,  # loss_llregulariser
        )

    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:
        return x_j * edge_weight.unsqueeze(1)

    def aggregate(
        self, inputs: Tensor, index: Tensor, dim_size: Optional[int] = None
    ) -> Tensor:
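        # Mean- and max-aggregations are concatenated, so the output width is
        # 2 * propagate_dimensions, matching the input size of self.lin
        # (in_channels + 2 * propagate_dimensions).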
        out_mean = scatter(
            inputs, index, dim=self.node_dim, dim_size=dim_size, reduce="mean"
        )
        out_max = scatter(
            inputs, index, dim=self.node_dim, dim_size=dim_size, reduce="max"
        )
        return torch.cat([out_mean, out_max], dim=-1)

    def __repr__(self):
        return "{}({}, {}, k={})".format(
            self.__class__.__name__, self.in_channels, self.out_channels, self.k
        )


def knn_per_graph(g, sl, k):
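    """Rebuild a k-NN graph in the learned coordinate space ``sl`` for every
    sub-graph of the batched DGL graph ``g`` (uses ``torch_cmspepr.knn_graph``)."""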
    graphs_list = dgl.unbatch(g)
    node_counter = 0
    new_graphs = []
    for graph in graphs_list:
        non = graph.number_of_nodes()
        sls_graph = sl[node_counter : node_counter + non]
        # new_graph = dgl.knn_graph(sls_graph, k, exclude_self=True)
        edge_index = torch_cmspepr.knn_graph(sls_graph, k=k)
        new_graph = dgl.graph(
            (edge_index[0], edge_index[1]), num_nodes=sls_graph.shape[0]
        )
        new_graph = dgl.remove_self_loop(new_graph)
        new_graphs.append(new_graph)
        node_counter = node_counter + non
    return dgl.batch(new_graphs)


class WeirdBatchNorm(nn.Module):
    def __init__(self, n_neurons, eps=1e-5):
        super(WeirdBatchNorm, self).__init__()

        # store the number of neurons
        self.n_neurons = n_neurons
        # initialize batch normalization parameters
        self.gamma = nn.Parameter(torch.ones(self.n_neurons))
        self.beta = nn.Parameter(torch.zeros(self.n_neurons))
        print("self beta requires grad", self.beta.requires_grad)
        self.mean = torch.zeros(self.n_neurons)
        self.den = torch.ones(self.n_neurons)
        self.viscosity = 0.5
        self.epsilon = eps
        self.fluidity_decay = 1e-4
        self.max_viscosity = 1
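        # Running statistics are blended with the batch statistics using a weight
        # of (1 - viscosity); as viscosity relaxes towards max_viscosity = 1, the
        # running mean/variance gradually freeze.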

    def forward(self, input):
        x = input.detach()
        mu = x.mean(dim=0)
        var = x.var(dim=0, unbiased=False)
        mu_update = self._calc_update(self.mean, mu)
        self.mean = mu_update
        var_update = self._calc_update(self.den, var)
        self.den = var_update
        # normalization
        center_input = x - self.mean
        denominator = self.den + self.epsilon
        denominator = denominator.sqrt()
        in_hat = center_input / denominator
        self._update_viscosity()
        # scale and shift
        out = self.gamma * in_hat + self.beta
        return out

    def _calc_update(self, old, new):
        delta = new - old.to(new.device)
        update = old.to(new.device) + (1 - self.viscosity) * delta.to(new.device)
        update = update.to(new.device)
        return update

    def _update_viscosity(self):
        if self.fluidity_decay > 0:
            newvisc = (
                self.viscosity
                + (self.max_viscosity - self.viscosity) * self.fluidity_decay
            )
            self.viscosity = newvisc
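

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module): build one
    # random DGL graph and run a single GravNetConv pass. Assumes torch_cmspepr is
    # installed, since knn_per_graph relies on it; all sizes below are arbitrary.
    num_nodes, in_channels = 32, 16
    g = dgl.graph(([], []), num_nodes=num_nodes)  # edges are rebuilt dynamically in forward
    x = torch.randn(num_nodes, in_channels)
    coords = torch.randn(num_nodes, 3)  # stand-in for the original hit coordinates
    conv = GravNetConv(
        in_channels=in_channels,
        out_channels=32,
        space_dimensions=4,
        propagate_dimensions=8,
        k=5,
    )
    out, knn_g, s_l, _, _ = conv(g, x, coords)
    print(out.shape)  # expected: torch.Size([32, 32])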