import torch
import torch.nn as nn


class SimpleEvoModel(nn.Module):
    """Two-layer feed-forward network: ``Linear -> ReLU -> Linear``.

    The layers live in a single ``nn.Sequential`` stored as ``self.model``,
    so parameter names in ``state_dict()`` are ``model.0.weight`` /
    ``model.0.bias`` (first linear layer) and ``model.2.weight`` /
    ``model.2.bias`` (second linear layer).

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden_dim: Width of the hidden layer.
        output_dim: Size of the last dimension of the output tensor.
    """

    def __init__(
        self,
        input_dim: int = 768,
        hidden_dim: int = 256,
        output_dim: int = 2,
    ) -> None:
        super().__init__()
        # state_dict keys are derived from these Sequential indices
        # (0 = first Linear, 1 = ReLU, 2 = second Linear); reordering or
        # inserting layers would break loading of previously saved weights.
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the network to ``x``.

        Args:
            x: Tensor whose last dimension is ``input_dim``. Any leading
                dimensions (e.g. a batch dimension) are preserved, as with
                ``nn.Linear``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and last
            dimension ``output_dim``. No activation is applied to the output.
        """
        return self.model(x)