HemanM committed (verified) · Commit a9b4cfb · 1 Parent(s): db7c38f

Update evo_model.py

Files changed (1): evo_model.py (+33, -29)
evo_model.py CHANGED
Before (original file; removed lines marked "-"):

@@ -1,45 +1,50 @@
 import torch
 import torch.nn as nn
-import math

-class MultiHeadSelfAttention(nn.Module):
-    def __init__(self, d_model=256, nhead=4):
        super().__init__()
-        self.nhead = nhead
-        self.d_head = d_model // nhead
-        self.qkv_proj = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, T, C = x.size()
-        qkv = self.qkv_proj(x).view(B, T, self.nhead, 3 * self.d_head)
        q, k, v = qkv.chunk(3, dim=-1)
-        q, k, v = [t.transpose(1, 2) for t in (q, k, v)]
-        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)
        attn = torch.softmax(scores, dim=-1)
-        context = attn @ v
-        context = context.transpose(1, 2).contiguous().view(B, T, C)
-        return self.out_proj(context)

class FeedForward(nn.Module):
-    def __init__(self, d_model=256, d_ff=512):
        super().__init__()
        self.net = nn.Sequential(
-            nn.Linear(d_model, d_ff),
            nn.ReLU(),
-            nn.Linear(d_ff, d_model)
        )

    def forward(self, x):
        return self.net(x)

-class DecoderBlock(nn.Module):
-    def __init__(self, d_model=256, nhead=4, d_ff=512):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
-        self.attn = MultiHeadSelfAttention(d_model, nhead)
        self.ln2 = nn.LayerNorm(d_model)
-        self.ffn = FeedForward(d_model, d_ff)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
@@ -47,22 +52,21 @@ class DecoderBlock(nn.Module):
        return x

class EvoDecoder(nn.Module):
-    def __init__(self, vocab_size=50257, d_model=256, nhead=4, num_layers=3, d_ff=512, max_len=512):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
-        self.pos_emb = nn.Embedding(max_len, d_model)
-        self.blocks = nn.ModuleList([
-            DecoderBlock(d_model, nhead, d_ff) for _ in range(num_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
-        B, T = x.shape
-        token = self.token_emb(x)
-        pos = self.pos_emb(torch.arange(T, device=x.device).unsqueeze(0))
-        x = token + pos
-        for block in self.blocks:
-            x = block(x)
        x = self.ln_f(x)
        return self.fc_out(x)
 
After (updated file; added lines marked "+"):

 import torch
 import torch.nn as nn
+import torch.nn.functional as F

+class SelfAttention(nn.Module):
+    def __init__(self, d_model, nhead):
        super().__init__()
+        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
+        self.nhead = nhead
+        self.d_model = d_model

    def forward(self, x):
        B, T, C = x.size()
+        qkv = self.qkv_proj(x)
        q, k, v = qkv.chunk(3, dim=-1)
+
+        q = q.view(B, T, self.nhead, C // self.nhead).transpose(1, 2)
+        k = k.view(B, T, self.nhead, C // self.nhead).transpose(1, 2)
+        v = v.view(B, T, self.nhead, C // self.nhead).transpose(1, 2)
+
+        scores = torch.matmul(q, k.transpose(-2, -1)) / (C // self.nhead) ** 0.5
        attn = torch.softmax(scores, dim=-1)
+        out = torch.matmul(attn, v)
+
+        out = out.transpose(1, 2).contiguous().view(B, T, C)
+        return self.out_proj(out)

class FeedForward(nn.Module):
+    def __init__(self, d_model, dim_feedforward):
        super().__init__()
        self.net = nn.Sequential(
+            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
+            nn.Linear(dim_feedforward, d_model)
        )

    def forward(self, x):
        return self.net(x)

+class TransformerBlock(nn.Module):
+    def __init__(self, d_model, nhead, dim_feedforward):
        super().__init__()
+        self.attn = SelfAttention(d_model, nhead)
        self.ln1 = nn.LayerNorm(d_model)
+        self.ffn = FeedForward(d_model, dim_feedforward)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

class EvoDecoder(nn.Module):
+    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=3, dim_feedforward=1024):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
+        self.pos_emb = nn.Embedding(512, d_model)
+        self.blocks = nn.Sequential(*[
+            TransformerBlock(d_model, nhead, dim_feedforward) for _ in range(num_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
+        B, T = x.size()
+        tok = self.token_emb(x)
+        pos = self.pos_emb(torch.arange(T, device=x.device).unsqueeze(0).expand(B, T))
+        x = tok + pos
+        x = self.blocks(x)
        x = self.ln_f(x)
        return self.fc_out(x)
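For reference, a minimal usage sketch of the updated EvoDecoder, assuming evo_model.py is importable from the working directory; the batch size, sequence length, and vocabulary size below are illustrative, and sequences must stay within the 512-position embedding table:

import torch
from evo_model import EvoDecoder  # assumes evo_model.py is on the Python path

model = EvoDecoder(vocab_size=50257)        # d_model=256, nhead=4, num_layers=3 by default
tokens = torch.randint(0, 50257, (2, 128))  # (batch, seq_len) token IDs; seq_len <= 512 (pos_emb size)
logits = model(tokens)                      # shape (2, 128, 50257): per-position vocabulary logits
print(logits.shape)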