HemanM committed on
Commit
32221da
·
verified ·
1 Parent(s): fdfa3ca

Update evo_model.py

Browse files
Files changed (1) hide show
  1. evo_model.py +65 -19
evo_model.py CHANGED
@@ -1,22 +1,68 @@
1
- import torch
2
  import torch.nn as nn
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  class EvoDecoderModel(nn.Module):
5
- def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=3, dim_feedforward=1024, dropout=0.1):
6
- super(EvoDecoderModel, self).__init__()
7
- self.embedding = nn.Embedding(vocab_size, d_model)
8
- self.pos_embedding = nn.Parameter(torch.zeros(1, 512, d_model)) # max length 512
9
- decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout)
10
- self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
11
- self.output_layer = nn.Linear(d_model, vocab_size)
12
-
13
- def forward(self, tgt, memory=None):
14
- seq_len = tgt.size(1)
15
- embedded = self.embedding(tgt) + self.pos_embedding[:, :seq_len, :]
16
-
17
- # If no memory is provided, use dummy memory filled with zeros
18
- if memory is None:
19
- memory = torch.zeros_like(embedded)
20
-
21
- output = self.transformer_decoder(embedded.transpose(0, 1), memory.transpose(0, 1))
22
- return self.output_layer(output.transpose(0, 1))
 
 
1
  import torch.nn as nn
2
+ import torch
3
+
4
class FeedForward(nn.Module):
    """Feed-forward sublayer: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Input and output feature size.
        hidden_dim: Size of the intermediate projection.
        dropout: Dropout probability used after the activation and after the
            output projection. Defaults to 0.1, the value that was previously
            hard-coded.
    """

    def __init__(self, dim, hidden_dim, dropout=0.1):
        super().__init__()
        # The layer order is kept unchanged so the ``net.<index>`` parameter
        # names of existing state dicts still match.
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the network to ``x`` and return the result."""
        return self.net(x)
17
+
18
class Attention(nn.Module):
    """Multi-head self-attention over a ``(B, T, C)`` input, with no mask.

    Args:
        dim: Feature size ``C`` of the input; must be divisible by ``heads``.
        heads: Number of attention heads.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self, dim, heads=4):
        super().__init__()
        if heads <= 0 or dim % heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive number of heads ({heads})"
            )
        self.heads = heads
        self.head_dim = dim // heads
        # Scaled dot-product attention divides by sqrt(d_k), where d_k is the
        # per-head width, not the full model width.
        self.scale = self.head_dim ** -0.5

        self.qkv_proj = nn.Linear(dim, dim * 3)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Return the attended features, same shape as ``x``."""
        B, T, C = x.shape
        # (B, T, 3*C) -> (3, B, heads, T, head_dim)
        qkv = self.qkv_proj(x).reshape(B, T, 3, self.heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn_scores = (q @ k.transpose(-2, -1)) * self.scale
        attn_weights = attn_scores.softmax(dim=-1)
        attn_output = attn_weights @ v
        # (B, heads, T, head_dim) -> (B, T, C): concatenate the heads again.
        attn_output = attn_output.transpose(1, 2).reshape(B, T, C)
        return self.out_proj(attn_output)
36
+
37
class TransformerBlock(nn.Module):
    """Residual block: LayerNorm -> attention, then LayerNorm -> feed-forward.

    Each sublayer sees a layer-normalised copy of its input, and its output is
    added back onto the un-normalised input.

    Args:
        dim: Feature size handed to the attention, feed-forward and norm layers.
        heads: Number of attention heads.
        hidden_dim: Hidden size of the feed-forward sublayer.
    """

    def __init__(self, dim, heads, hidden_dim):
        super().__init__()
        # Construction order is kept: it fixes parameter order and seeded init.
        self.attn = Attention(dim, heads)
        self.ffn = FeedForward(dim, hidden_dim)
        self.ln1 = nn.LayerNorm(dim)
        self.ln2 = nn.LayerNorm(dim)

    def forward(self, x):
        """Run both residual sublayers on ``x`` and return the result."""
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.ffn(self.ln2(after_attn))
49
 
50
  class EvoDecoderModel(nn.Module):
51
+ def __init__(self, vocab_size, dim=256, depth=3, heads=4, hidden_dim=512):
52
+ super().__init__()
53
+ self.token_emb = nn.Embedding(vocab_size, dim)
54
+ self.pos_emb = nn.Embedding(512, dim)
55
+ self.blocks = nn.Sequential(*[TransformerBlock(dim, heads, hidden_dim) for _ in range(depth)])
56
+ self.ln_f = nn.LayerNorm(dim)
57
+ self.fc_out = nn.Linear(dim, vocab_size)
58
+
59
+ def forward(self, x):
60
+ B, T = x.shape
61
+ pos = torch.arange(0, T, device=x.device).unsqueeze(0)
62
+ tok = self.token_emb(x)
63
+ pos = self.pos_emb(pos)
64
+ x = tok + pos
65
+ x = self.blocks(x)
66
+ x = self.ln_f(x)
67
+ logits = self.fc_out(x)
68
+ return logits