import torch
import torch.nn as nn

class TransformerEncoder(nn.Module):
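    """Transformer encoder that prepends a learnable memory token to the input
    sequence and returns a projection of that token's final hidden state."""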
    def __init__(self, config):
        super().__init__()
        self.embedding = nn.Embedding(config["vocab_size"], config["d_model"])
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config["d_model"],
            nhead=config["nhead"],
            dim_feedforward=config["ff_dim"],
            dropout=0.1,
            activation="gelu",
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=config["num_layers"])
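        # Learnable token prepended to every sequence (CLS-style summary slot)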
        self.memory_token = nn.Parameter(torch.randn(1, 1, config["d_model"]))
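        # Projection applied to the memory token's final hidden state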
        self.memory_proj = nn.Linear(config["d_model"], config["d_model"])

    def forward(self, x):
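        # x: (batch, seq_len) integer token ids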
        x = self.embedding(x)
        B, T, D = x.shape
        memory = self.memory_token.repeat(B, 1, 1)
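        # Prepend the memory token so it can attend to the full sequence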
        x = torch.cat([memory, x], dim=1)
        x = self.transformer(x)
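        # Use the memory token's output (position 0) as the sequence representation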
        memory_out = x[:, 0]
        return self.memory_proj(memory_out)

class EvoTransformer(nn.Module):
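    """Binary classifier: memory-token encoder followed by a linear classification head."""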
    def __init__(self):
        super().__init__()
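        # Model hyperparameters; 30,522 is the bert-base-uncased vocabulary size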
        config = {
            "vocab_size": 30522,
            "d_model": 384,
            "nhead": 6,
            "ff_dim": 1024,
            "num_layers": 6,
        }
        self.encoder = TransformerEncoder(config)
        self.classifier = nn.Linear(config["d_model"], 2)

    def forward(self, x):
        x = self.encoder(x)
        return self.classifier(x)
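
# Minimal usage sketch (assumption: inputs are token ids in [0, 30522), e.g. from a
# bert-base-uncased tokenizer; the 2-way head would be trained on a downstream task):
if __name__ == "__main__":
    model = EvoTransformer()
    dummy_ids = torch.randint(0, 30522, (2, 16))  # batch of 2 sequences, 16 tokens each
    logits = model(dummy_ids)                     # -> (2, 2) class logits
    print(logits.shape)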