# Source snapshot: file size 302 bytes, commit e84e6e5.
from moe import MoE
import torch

# Minimal usage demo for the MoE layer: one forward pass in training mode and
# one in evaluation mode on the same random batch.

# Instantiate the MoE layer. Argument meanings beyond their names are defined
# in moe.py; k=4 is presumably the number of experts routed per sample and
# noisy_gating presumably adds gating noise -- confirm against the MoE class.
model = MoE(
    input_size=1000,
    output_size=20,
    num_experts=10,
    hidden_size=66,
    k=4,
    noisy_gating=True,
)

# Random batch of 32 samples with 1000 features each (second dim matches
# input_size above).
X = torch.rand(32, 1000)

# Training mode: switches the module (and its submodules) to training
# behaviour. NOTE: this demo only runs a forward pass -- no loss, backward,
# or optimizer step is performed. The model returns an (output, aux_loss) pair.
model.train()
y_hat, aux_loss = model(X)

# Evaluation mode. Inference runs under torch.no_grad() so autograd does not
# record a computation graph for a pass that is never backpropagated.
model.eval()
with torch.no_grad():
    y_hat, aux_loss = model(X)