HemanM commited on
Commit
3d8c93e
·
verified ·
1 Parent(s): b663f38

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +13 -0
model.py CHANGED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ class EvoTransformer(nn.Module):
4
+ def __init__(self, d_model=768, n_classes=2):
5
+ super(EvoTransformer, self).__init__()
6
+ self.classifier = nn.Sequential(
7
+ nn.Linear(d_model, 384),
8
+ nn.ReLU(),
9
+ nn.Linear(384, n_classes)
10
+ )
11
+
12
+ def forward(self, x):
13
+ return self.classifier(x)