Spaces:
Sleeping
Sleeping
File size: 785 Bytes
b085dea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import torch
class IdentityModel(torch.nn.Module):
    """Parameter-free baseline model that passes its inputs straight through.

    Each input row is scaled to unit L2 norm and a constant ``beta`` column of
    ones is appended, so an ``(N, D)`` input yields an ``(N, D + 1)`` output.
    """

    # Class-level default so instances pickled before ``eps`` existed still
    # resolve ``self.eps`` after loading.
    eps = 1e-12

    def __init__(self, n_out_coords=3, eps=1e-12):
        """
        Args:
            n_out_coords: Stored on the instance (presumably for interface
                parity with other models); ``forward`` does not read it, and
                the output width is always ``D + 1``.
            eps: Lower bound on the row norm used as the divisor, so an
                all-zero input row yields zeros instead of NaNs.
        """
        super().__init__()
        self.n_out_coords = n_out_coords
        self.eps = eps

    def forward(self, data):
        """Normalize each input vector to unit length and append a beta of 1.

        Args:
            data: Object exposing ``input_vectors``, a 2-D tensor of shape
                ``(N, D)``. The original code calls this an ``EventBatch``
                holding four-momenta; confirm against the caller.

        Returns:
            Tensor of shape ``(N, D + 1)``: the first ``D`` columns are the
            row-normalized inputs, the last column is all ones.
        """
        inputs_v = data.input_vectors
        # normalize() divides by max(||x||_2, eps): identical to x / ||x|| for
        # non-zero rows, but avoids 0/0 = NaN for all-zero rows.
        unit_v = torch.nn.functional.normalize(inputs_v, p=2.0, dim=1, eps=self.eps)
        # Allocate directly on the input's device/dtype (no CPU alloc + copy),
        # already shaped (N, 1) for concatenation.
        betas = torch.ones(
            inputs_v.shape[0], 1, device=inputs_v.device, dtype=inputs_v.dtype
        )
        return torch.cat([unit_v, betas], dim=1)
def get_model(args):
    """Factory hook: build an ``IdentityModel`` with its default settings.

    ``args`` is part of the factory signature but is not consulted here.
    """
    model = IdentityModel()
    return model
|