"""Minimal usage example for the ``MoE`` layer.

Builds an ``MoE`` layer, then runs one forward pass in training mode and
one in evaluation mode on a random input batch.
"""
from moe import MoE
import torch

# Instantiate the MoE layer.
model = MoE(
    input_size=1000,
    output_size=20,
    num_experts=10,
    hidden_size=66,
    k=4,
    noisy_gating=True,
)

# Random input batch: 32 rows, each with `input_size` (1000) features.
X = torch.rand(32, 1000)

# Train: forward pass with the module in training mode.
# NOTE(review): no loss/backward step is shown here; presumably a real
# training loop would add `aux_loss` to the task loss computed from `y_hat`
# before calling `.backward()` -- confirm against the MoE implementation.
model.train()
y_hat, aux_loss = model(X)

# Evaluation: switch to eval mode and disable autograd, since inference
# does not need a gradient graph (saves memory and compute).
model.eval()
with torch.no_grad():
    y_hat, aux_loss = model(X)