HemanM committed
Commit db7c38f · verified · 1 Parent(s): 08624da

Update evo_model.py

Files changed (1):
  1. evo_model.py +4 -5
evo_model.py CHANGED
@@ -3,7 +3,7 @@ import torch.nn as nn
 import math
 
 class MultiHeadSelfAttention(nn.Module):
-    def __init__(self, d_model, nhead):
+    def __init__(self, d_model=256, nhead=4):
         super().__init__()
         self.nhead = nhead
         self.d_head = d_model // nhead
@@ -15,7 +15,6 @@ class MultiHeadSelfAttention(nn.Module):
         qkv = self.qkv_proj(x).view(B, T, self.nhead, 3 * self.d_head)
         q, k, v = qkv.chunk(3, dim=-1)
         q, k, v = [t.transpose(1, 2) for t in (q, k, v)]
-
         scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)
         attn = torch.softmax(scores, dim=-1)
         context = attn @ v
@@ -23,7 +22,7 @@ class MultiHeadSelfAttention(nn.Module):
         return self.out_proj(context)
 
 class FeedForward(nn.Module):
-    def __init__(self, d_model, d_ff):
+    def __init__(self, d_model=256, d_ff=512):
         super().__init__()
         self.net = nn.Sequential(
             nn.Linear(d_model, d_ff),
@@ -35,7 +34,7 @@ class FeedForward(nn.Module):
         return self.net(x)
 
 class DecoderBlock(nn.Module):
-    def __init__(self, d_model, nhead, d_ff):
+    def __init__(self, d_model=256, nhead=4, d_ff=512):
         super().__init__()
         self.ln1 = nn.LayerNorm(d_model)
         self.attn = MultiHeadSelfAttention(d_model, nhead)
@@ -48,7 +47,7 @@ class DecoderBlock(nn.Module):
         return x
 
 class EvoDecoder(nn.Module):
-    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, d_ff=2048, max_len=512):
+    def __init__(self, vocab_size=50257, d_model=256, nhead=4, num_layers=3, d_ff=512, max_len=512):
         super().__init__()
         self.token_emb = nn.Embedding(vocab_size, d_model)
         self.pos_emb = nn.Embedding(max_len, d_model)
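
For orientation, here is a minimal standalone sketch of the attention shape math these hunks touch, using the new defaults (d_model=256, nhead=4, hence d_head=64). The qkv_proj below is a stand-in nn.Linear, since its real definition sits outside the diff; only the tensor manipulation mirrors the diff's MultiHeadSelfAttention forward pass:

import math
import torch
import torch.nn as nn

# Illustrative shapes only; B and T are arbitrary batch and sequence sizes.
B, T, d_model, nhead = 2, 16, 256, 4
d_head = d_model // nhead                                 # 64 under the new defaults

x = torch.randn(B, T, d_model)
qkv_proj = nn.Linear(d_model, 3 * d_model)                # stand-in for self.qkv_proj

qkv = qkv_proj(x).view(B, T, nhead, 3 * d_head)           # (B, T, nhead, 3*d_head)
q, k, v = qkv.chunk(3, dim=-1)                            # each (B, T, nhead, d_head)
q, k, v = [t.transpose(1, 2) for t in (q, k, v)]          # each (B, nhead, T, d_head)

scores = (q @ k.transpose(-2, -1)) / math.sqrt(d_head)    # (B, nhead, T, T)
attn = torch.softmax(scores, dim=-1)
context = attn @ v                                        # (B, nhead, T, d_head)
print(context.shape)                                      # torch.Size([2, 4, 16, 64])

With vocab_size now defaulting to 50257 (the GPT-2 BPE vocabulary size) and the smaller d_model, num_layers, and d_ff defaults, a bare EvoDecoder() builds a much lighter model than the previous 512-dim, 6-layer, 2048-FFN configuration; callers that relied on the old defaults should pass those values explicitly.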