File size: 565 Bytes
3d8c93e
 
2c48bcc
 
 
3d8c93e
2c48bcc
3d8c93e
2c48bcc
 
 
 
 
3d8c93e
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch.nn as nn

class EvoTransformerArabic(nn.Module):
    """Feed-forward classification head mapping ``d_model`` features to class scores.

    Despite the name, this module contains no attention or transformer layers:
    it is a three-layer MLP (Linear -> ReLU -> Dropout -> Linear -> ReLU ->
    Dropout -> Linear). The input features presumably come from an upstream
    encoder (the 768 default matches a BERT-base hidden size) -- TODO confirm
    against the caller.

    Args:
        d_model: Size of the last dimension of the input tensor.
        hidden_dim: Width of the first hidden layer; the second hidden layer
            has ``hidden_dim // 2`` units.
        n_classes: Number of output scores per example.
        dropout: Dropout probability applied after each hidden activation.
    """

    def __init__(
        self,
        d_model: int = 768,
        hidden_dim: int = 1024,
        n_classes: int = 2,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        # Layer order is significant: nn.Sequential names parameters by index
        # (classifier.0 / .3 / .6), so reordering would break loading of
        # previously saved state_dicts.
        self.classifier = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, n_classes),
        )

    def forward(self, x):
        """Return unnormalized class scores (logits) for ``x``.

        Args:
            x: Tensor whose last dimension is ``d_model``; any leading
                dimensions are preserved by ``nn.Linear``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``n_classes``. No softmax is applied here.
        """
        return self.classifier(x)