File size: 785 Bytes
b085dea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch

class IdentityModel(torch.nn.Module):
    """Parameter-free baseline: unit-normalised inputs plus a constant beta.

    The module has no learnable weights. ``forward`` rescales every row of
    ``data.input_vectors`` to unit Euclidean length and appends one extra
    column filled with ones.
    """

    def __init__(self, n_out_coords=3):
        """Create the model.

        Args:
            n_out_coords: Stored on the instance as ``self.n_out_coords``.
                It is not read by ``forward``; the output width is determined
                by the width of the input vectors.
        """
        super().__init__()
        self.n_out_coords = n_out_coords

    def forward(self, data):
        """Return the normalised input vectors with a trailing column of ones.

        Args:
            data: Object exposing an ``input_vectors`` tensor whose rows are
                the vectors to normalise (per the original note: an
                ``EventBatch`` holding four-momenta — confirm against caller).

        Returns:
            Tensor of shape ``(N, D + 1)`` for an ``(N, D)`` input: the first
            ``D`` columns are each row divided by its L2 norm, the last column
            is all ones (the "betas").
        """
        inputs_v = data.input_vectors  # four-momenta
        norms = torch.norm(inputs_v, dim=1, keepdim=True)
        # An all-zero row has norm 0 and would turn into NaNs when divided.
        # Clamping to the smallest positive normal float leaves every other
        # row untouched while mapping zero rows to zero.
        safe_norms = norms.clamp_min(torch.finfo(norms.dtype).tiny)
        # Allocate directly on the input's device instead of CPU + copy.
        betas = torch.ones(inputs_v.shape[0], 1, device=inputs_v.device)
        return torch.cat([inputs_v / safe_norms, betas], dim=1)


def get_model(args):
    """Build the model exposed by this module.

    Args:
        args: Accepted for interface compatibility; not read here.

    Returns:
        A new ``IdentityModel`` constructed with its default arguments.
    """
    model = IdentityModel()
    return model