HemanM committed
Commit e7984f7 · verified · 1 Parent(s): 9bad7da

Update evo_model.py

Files changed (1)
  1. evo_model.py +14 -45
evo_model.py CHANGED
@@ -1,53 +1,22 @@
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-
-class EvoDecoderBlock(nn.Module):
-    def __init__(self, d_model=512, nhead=8, dim_feedforward=2048, dropout=0.1):
-        super(EvoDecoderBlock, self).__init__()
-        self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
-        self.qkv_proj = nn.Linear(d_model, d_model * 3)
-        self.out_proj = nn.Linear(d_model, d_model)
-        self.ffn = nn.Sequential(
-            nn.Linear(d_model, dim_feedforward),
-            nn.ReLU(),
-            nn.Dropout(dropout),
-            nn.Linear(dim_feedforward, d_model),
-        )
-        self.ln1 = nn.LayerNorm(d_model)
-        self.ln2 = nn.LayerNorm(d_model)
-
-    def forward(self, x):
-        # Self-attention with skip connection
-        qkv = self.qkv_proj(x)
-        q, k, v = torch.chunk(qkv, 3, dim=-1)
-        attn_output, _ = self.attn(q, k, v)
-        x = self.ln1(x + self.out_proj(attn_output))
-
-        # Feedforward with skip connection
-        x = self.ln2(x + self.ffn(x))
-        return x
 
 class EvoDecoderModel(nn.Module):
-    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, dim_feedforward=2048, dropout=0.1, max_len=512):
+    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=3, dim_feedforward=1024, dropout=0.1):
         super(EvoDecoderModel, self).__init__()
-        self.token_emb = nn.Embedding(vocab_size, d_model)
-        self.pos_emb = nn.Embedding(max_len, d_model)
-        self.blocks = nn.ModuleList([
-            EvoDecoderBlock(d_model, nhead, dim_feedforward, dropout)
-            for _ in range(num_layers)
-        ])
-        self.ln_f = nn.LayerNorm(d_model)
-        self.fc_out = nn.Linear(d_model, vocab_size)
+        self.embedding = nn.Embedding(vocab_size, d_model)
+        self.pos_embedding = nn.Parameter(torch.zeros(1, 512, d_model))  # max length 512
+        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout)
+        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
+        self.output_layer = nn.Linear(d_model, vocab_size)
 
-    def forward(self, x):
-        device = x.device
-        seq_len = x.size(1)
-        pos = torch.arange(0, seq_len, device=device).unsqueeze(0)
-        x = self.token_emb(x) + self.pos_emb(pos)
-
-        for block in self.blocks:
-            x = block(x)
-
-        x = self.ln_f(x)
-        return self.fc_out(x)
+    def forward(self, tgt, memory=None):
+        seq_len = tgt.size(1)
+        embedded = self.embedding(tgt) + self.pos_embedding[:, :seq_len, :]
+
+        # If no memory is provided, use dummy memory filled with zeros
+        if memory is None:
+            memory = torch.zeros_like(embedded)
+
+        output = self.transformer_decoder(embedded.transpose(0, 1), memory.transpose(0, 1))
+        return self.output_layer(output.transpose(0, 1))
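
A minimal usage sketch for the updated EvoDecoderModel, assuming evo_model.py is importable as evo_model; the vocab_size, batch size, and sequence length below are illustrative values, not taken from the repo:

import torch
from evo_model import EvoDecoderModel  # assumption: file is on the import path

model = EvoDecoderModel(vocab_size=10000)    # defaults: d_model=256, nhead=4, num_layers=3
tokens = torch.randint(0, 10000, (2, 128))   # (batch, seq_len); seq_len must be <= 512
logits = model(tokens)                       # memory defaults to zeros inside forward
print(logits.shape)                          # torch.Size([2, 128, 10000])

Note that forward as written passes no tgt_mask, so self-attention over tgt is unmasked; a causal mask would have to be added for autoregressive training.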