Update evo_model.py
evo_model.py  +4 -5
@@ -3,7 +3,7 @@ import torch.nn as nn
 import math
 
 class MultiHeadSelfAttention(nn.Module):
-    def __init__(self, d_model, nhead):
+    def __init__(self, d_model=256, nhead=4):
         super().__init__()
         self.nhead = nhead
         self.d_head = d_model // nhead
@@ -15,7 +15,6 @@ class MultiHeadSelfAttention(nn.Module):
         qkv = self.qkv_proj(x).view(B, T, self.nhead, 3 * self.d_head)
         q, k, v = qkv.chunk(3, dim=-1)
         q, k, v = [t.transpose(1, 2) for t in (q, k, v)]
-
         scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)
         attn = torch.softmax(scores, dim=-1)
         context = attn @ v
@@ -23,7 +22,7 @@ class MultiHeadSelfAttention(nn.Module):
         return self.out_proj(context)
 
 class FeedForward(nn.Module):
-    def __init__(self, d_model, d_ff):
+    def __init__(self, d_model=256, d_ff=512):
         super().__init__()
         self.net = nn.Sequential(
             nn.Linear(d_model, d_ff),
@@ -35,7 +34,7 @@ class FeedForward(nn.Module):
         return self.net(x)
 
 class DecoderBlock(nn.Module):
-    def __init__(self, d_model, nhead, d_ff):
+    def __init__(self, d_model=256, nhead=4, d_ff=512):
         super().__init__()
         self.ln1 = nn.LayerNorm(d_model)
         self.attn = MultiHeadSelfAttention(d_model, nhead)
@@ -48,7 +47,7 @@ class DecoderBlock(nn.Module):
         return x
 
 class EvoDecoder(nn.Module):
-    def __init__(self, vocab_size, d_model=
+    def __init__(self, vocab_size=50257, d_model=256, nhead=4, num_layers=3, d_ff=512, max_len=512):
         super().__init__()
         self.token_emb = nn.Embedding(vocab_size, d_model)
         self.pos_emb = nn.Embedding(max_len, d_model)
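
After this change every module can be constructed without arguments. The following smoke-test sketch is not part of the commit; it assumes the file is importable as evo_model, and that MultiHeadSelfAttention.forward takes a single (batch, seq_len, d_model) tensor, as the hunk at old lines 15-21 suggests. It deliberately avoids calling EvoDecoder.forward, which this diff does not show.

import torch
from evo_model import MultiHeadSelfAttention, EvoDecoder  # module name taken from this commit

# The attention block now builds with d_model=256, nhead=4 by default.
attn = MultiHeadSelfAttention()
x = torch.randn(2, 16, 256)      # (batch, seq_len, d_model)
out = attn(x)                    # assumes forward(x) with no mask argument, per the visible hunk
print(out.shape)                 # expected: torch.Size([2, 16, 256])

# The full decoder now builds with vocab_size=50257, d_model=256, nhead=4,
# num_layers=3, d_ff=512, max_len=512 by default. Only instantiation and a
# parameter count are shown here, since EvoDecoder.forward is outside this diff.
model = EvoDecoder()
print(sum(p.numel() for p in model.parameters()))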
|