Spaces:
Sleeping
Sleeping
import torch | |
class IdentityModel(torch.nn.Module):
    """Parameter-free model: L2-normalized inputs plus a constant beta column.

    For each row of ``data.input_vectors`` the output holds that row divided by
    its Euclidean norm, followed by a single column of ones ("betas").
    """

    def __init__(self, n_out_coords=3):
        """Store ``n_out_coords``.

        Args:
            n_out_coords: Saved as ``self.n_out_coords``. ``forward`` does not
                read it; the output width is ``input_vectors.shape[1] + 1``.
                NOTE(review): confirm whether callers rely on this attribute.
        """
        super().__init__()
        self.n_out_coords = n_out_coords

    def forward(self, data):
        """Normalize each input vector and append a column of ones.

        Args:
            data: Object exposing an ``input_vectors`` tensor of shape
                ``(N, D)``. The original comment describes it as an
                ``EventBatch`` instance whose rows are four-momenta.

        Returns:
            Tensor of shape ``(N, D + 1)`` with the dtype and device of
            ``data.input_vectors``. Columns ``0..D-1`` are the L2-normalized
            rows and column ``D`` is all ones.
        """
        inputs_v = data.input_vectors
        # F.normalize divides by max(||x||, eps). An all-zero row therefore
        # maps to zeros instead of NaN; the previous x / ||x|| computed 0/0.
        unit_v = torch.nn.functional.normalize(inputs_v, p=2.0, dim=1)
        # Create the betas directly with the input's dtype and device, instead
        # of allocating a default-dtype CPU tensor and moving it afterwards.
        betas = torch.ones(
            inputs_v.shape[0], 1, dtype=inputs_v.dtype, device=inputs_v.device
        )
        return torch.cat([unit_v, betas], dim=1)
def get_model(args):
    """Construct and return a new ``IdentityModel`` with default settings.

    Args:
        args: Not read here. Presumably kept so every model module exposes the
            same ``get_model(args)`` factory signature; verify against callers.
    """
    del args  # intentionally unused
    return IdentityModel()