HemanM committed
Commit 0bd71c9 · verified · 1 Parent(s): e5b7ec9

Update evo_model.py

Files changed (1)
  1. evo_model.py +49 -10
evo_model.py CHANGED
@@ -1,14 +1,53 @@
 
  import torch.nn as nn
 
  class EvoDecoderModel(nn.Module):
-     def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, dim_feedforward=2048, dropout=0.1):
          super(EvoDecoderModel, self).__init__()
-         self.embedding = nn.Embedding(vocab_size, d_model)
-         decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout)
-         self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
-         self.output_layer = nn.Linear(d_model, vocab_size)
-
-     def forward(self, tgt, memory):
-         embedded = self.embedding(tgt)
-         output = self.transformer_decoder(embedded, memory)
-         return self.output_layer(output)

+ import torch
  import torch.nn as nn
+ import torch.nn.functional as F
+
+ class EvoDecoderBlock(nn.Module):
+     def __init__(self, d_model=512, nhead=8, dim_feedforward=2048, dropout=0.1):
+         super(EvoDecoderBlock, self).__init__()
+         self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
+         self.qkv_proj = nn.Linear(d_model, d_model * 3)
+         self.out_proj = nn.Linear(d_model, d_model)
+         self.ffn = nn.Sequential(
+             nn.Linear(d_model, dim_feedforward),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+             nn.Linear(dim_feedforward, d_model),
+         )
+         self.ln1 = nn.LayerNorm(d_model)
+         self.ln2 = nn.LayerNorm(d_model)
+
+     def forward(self, x):
+         # Self-attention with skip connection
+         qkv = self.qkv_proj(x)
+         q, k, v = torch.chunk(qkv, 3, dim=-1)
+         attn_output, _ = self.attn(q, k, v)
+         x = self.ln1(x + self.out_proj(attn_output))
+
+         # Feedforward with skip connection
+         x = self.ln2(x + self.ffn(x))
+         return x
 
  class EvoDecoderModel(nn.Module):
+     def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, dim_feedforward=2048, dropout=0.1, max_len=512):
          super(EvoDecoderModel, self).__init__()
+         self.token_emb = nn.Embedding(vocab_size, d_model)
+         self.pos_emb = nn.Embedding(max_len, d_model)
+         self.blocks = nn.ModuleList([
+             EvoDecoderBlock(d_model, nhead, dim_feedforward, dropout)
+             for _ in range(num_layers)
+         ])
+         self.ln_f = nn.LayerNorm(d_model)
+         self.fc_out = nn.Linear(d_model, vocab_size)
+
+     def forward(self, x):
+         device = x.device
+         seq_len = x.size(1)
+         pos = torch.arange(0, seq_len, device=device).unsqueeze(0)
+         x = self.token_emb(x) + self.pos_emb(pos)
+
+         for block in self.blocks:
+             x = block(x)
+
+         x = self.ln_f(x)
+         return self.fc_out(x)
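
Below is a minimal usage sketch of the rewritten decoder-only model, for orientation only; the vocab_size, batch size, and sequence length are illustrative assumptions, not values taken from this repo.

import torch

# Hypothetical sizes for illustration; the real vocab size comes from whatever tokenizer is used with this model.
vocab_size, batch_size, seq_len = 10000, 2, 16

model = EvoDecoderModel(vocab_size=vocab_size, d_model=512, nhead=8,
                        num_layers=6, dim_feedforward=2048, dropout=0.1, max_len=512)

tokens = torch.randint(0, vocab_size, (batch_size, seq_len))  # (batch, seq_len) token ids
logits = model(tokens)                                        # (batch, seq_len, vocab_size)
print(logits.shape)  # torch.Size([2, 16, 10000])

Unlike the previous version, whose forward(tgt, memory) expected an external encoder memory, the updated model takes a single token-id tensor and returns per-position logits. As committed, EvoDecoderBlock calls self-attention without a causal mask, so each position can attend to later positions.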