HemanM committed (verified)
Commit 08624da · 1 Parent(s): 1cb13b6

Update evo_model.py

Files changed (1)
  1. evo_model.py +47 -46
evo_model.py CHANGED
@@ -1,68 +1,69 @@
- import torch.nn as nn
  import torch

- class FeedForward(nn.Module):
-     def __init__(self, dim, hidden_dim):
          super().__init__()
-         self.net = nn.Sequential(
-             nn.Linear(dim, hidden_dim),
-             nn.GELU(),
-             nn.Dropout(0.1),
-             nn.Linear(hidden_dim, dim),
-             nn.Dropout(0.1),
-         )

      def forward(self, x):
-         return self.net(x)

- class Attention(nn.Module):
-     def __init__(self, dim, heads=4):
-         super().__init__()
-         self.heads = heads
-         self.scale = dim ** -0.5

-         self.qkv_proj = nn.Linear(dim, dim * 3)
-         self.out_proj = nn.Linear(dim, dim)

      def forward(self, x):
-         B, T, C = x.shape
-         qkv = self.qkv_proj(x).reshape(B, T, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
-         q, k, v = qkv[0], qkv[1], qkv[2]
-         attn_scores = (q @ k.transpose(-2, -1)) * self.scale
-         attn_weights = attn_scores.softmax(dim=-1)
-         attn_output = attn_weights @ v
-         attn_output = attn_output.transpose(1, 2).reshape(B, T, C)
-         return self.out_proj(attn_output)

- class TransformerBlock(nn.Module):
-     def __init__(self, dim, heads, hidden_dim):
          super().__init__()
-         self.attn = Attention(dim, heads)
-         self.ffn = FeedForward(dim, hidden_dim)
-         self.ln1 = nn.LayerNorm(dim)
-         self.ln2 = nn.LayerNorm(dim)

      def forward(self, x):
          x = x + self.attn(self.ln1(x))
          x = x + self.ffn(self.ln2(x))
          return x

- class EvoDecoderModel(nn.Module):
-     def __init__(self, vocab_size, dim=256, depth=3, heads=4, hidden_dim=512):
          super().__init__()
-         self.token_emb = nn.Embedding(vocab_size, dim)
-         self.pos_emb = nn.Embedding(512, dim)
-         self.blocks = nn.Sequential(*[TransformerBlock(dim, heads, hidden_dim) for _ in range(depth)])
-         self.ln_f = nn.LayerNorm(dim)
-         self.fc_out = nn.Linear(dim, vocab_size)

      def forward(self, x):
          B, T = x.shape
-         pos = torch.arange(0, T, device=x.device).unsqueeze(0)
-         tok = self.token_emb(x)
-         pos = self.pos_emb(pos)
-         x = tok + pos
-         x = self.blocks(x)
          x = self.ln_f(x)
-         logits = self.fc_out(x)
-         return logits
 
 
  import torch
+ import torch.nn as nn
+ import math

+ class MultiHeadSelfAttention(nn.Module):
+     def __init__(self, d_model, nhead):
          super().__init__()
+         self.nhead = nhead
+         self.d_head = d_model // nhead
+         self.qkv_proj = nn.Linear(d_model, d_model * 3)
+         self.out_proj = nn.Linear(d_model, d_model)

      def forward(self, x):
+         B, T, C = x.size()
+         qkv = self.qkv_proj(x).view(B, T, self.nhead, 3 * self.d_head)
+         q, k, v = qkv.chunk(3, dim=-1)
+         q, k, v = [t.transpose(1, 2) for t in (q, k, v)]

+         scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)
+         attn = torch.softmax(scores, dim=-1)
+         context = attn @ v
+         context = context.transpose(1, 2).contiguous().view(B, T, C)
+         return self.out_proj(context)

+ class FeedForward(nn.Module):
+     def __init__(self, d_model, d_ff):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(d_model, d_ff),
+             nn.ReLU(),
+             nn.Linear(d_ff, d_model)
+         )

      def forward(self, x):
+         return self.net(x)

+ class DecoderBlock(nn.Module):
+     def __init__(self, d_model, nhead, d_ff):
          super().__init__()
+         self.ln1 = nn.LayerNorm(d_model)
+         self.attn = MultiHeadSelfAttention(d_model, nhead)
+         self.ln2 = nn.LayerNorm(d_model)
+         self.ffn = FeedForward(d_model, d_ff)

      def forward(self, x):
          x = x + self.attn(self.ln1(x))
          x = x + self.ffn(self.ln2(x))
          return x

+ class EvoDecoder(nn.Module):
+     def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, d_ff=2048, max_len=512):
          super().__init__()
+         self.token_emb = nn.Embedding(vocab_size, d_model)
+         self.pos_emb = nn.Embedding(max_len, d_model)
+         self.blocks = nn.ModuleList([
+             DecoderBlock(d_model, nhead, d_ff) for _ in range(num_layers)
+         ])
+         self.ln_f = nn.LayerNorm(d_model)
+         self.fc_out = nn.Linear(d_model, vocab_size)

      def forward(self, x):
          B, T = x.shape
+         token = self.token_emb(x)
+         pos = self.pos_emb(torch.arange(T, device=x.device).unsqueeze(0))
+         x = token + pos
+         for block in self.blocks:
+             x = block(x)
          x = self.ln_f(x)
+         return self.fc_out(x)
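
For reference, a minimal usage sketch of the updated EvoDecoder (not part of the commit): it assumes evo_model.py is importable from the working directory, and the vocabulary size, batch size, and sequence length below are illustrative values only.

import torch
from evo_model import EvoDecoder  # assumes this file is on the import path

# Defaults from the commit: d_model=512, nhead=8, num_layers=6, d_ff=2048, max_len=512
model = EvoDecoder(vocab_size=1000)

# Dummy batch of token ids: (batch=2, seq_len=16); seq_len must stay <= max_len
tokens = torch.randint(0, 1000, (2, 16))

logits = model(tokens)   # per-position logits over the vocabulary
print(logits.shape)      # torch.Size([2, 16, 1000])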