# app.py — added by awacke1 in commit e84e6e5 ("Create app.py", verified; 302 bytes).
from moe import MoE
import torch
# Demo: run a Mixture-of-Experts layer forward in both training and
# evaluation mode on a random batch.

BATCH_SIZE = 32
# Shared by the layer config and the dummy input so the two cannot drift apart.
INPUT_SIZE = 1000
OUTPUT_SIZE = 20

# Instantiate the MoE layer: 10 experts, hidden_size=66 (presumably each
# expert's hidden width), k=4 (presumably each example is routed to its top-4
# experts) — confirm both against moe.MoE.
model = MoE(
    input_size=INPUT_SIZE,
    output_size=OUTPUT_SIZE,
    num_experts=10,
    hidden_size=66,
    k=4,
    noisy_gating=True,
)

# Random dummy batch of shape (BATCH_SIZE, INPUT_SIZE), values uniform in [0, 1).
X = torch.rand(BATCH_SIZE, INPUT_SIZE)

# Training-mode forward pass. model.train() only switches the module's mode;
# nothing is optimized here. A real training step would add aux_loss
# (presumably the gating load-balancing loss — confirm in moe.MoE) to the task
# loss, call backward() and step an optimizer.
model.train()
y_hat, aux_loss = model(X)

# Evaluation-mode forward pass. no_grad() skips building the autograd graph,
# which inference does not need and which would otherwise keep intermediate
# activations alive in memory.
model.eval()
with torch.no_grad():
    y_hat, aux_loss = model(X)