HemanM committed (verified)
Commit 558d45c · 1 Parent(s): 944fe21

Update evo_model.py

Files changed (1):
  1. evo_model.py (+11 -10)
evo_model.py CHANGED
@@ -4,14 +4,14 @@ import torch.nn as nn
 import math
 
 class PositionalEncoding(nn.Module):
-    def __init__(self, d_model, max_len=128):  # Match saved model: [1, 128, 512]
+    def __init__(self, d_model, max_len=512):  # increased to 512
         super().__init__()
         pe = torch.zeros(max_len, d_model)
-        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
-        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+        position = torch.arange(0, max_len).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
-        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
+        pe = pe.unsqueeze(0)  # shape: (1, max_len, d_model)
         self.register_buffer('pe', pe)
 
     def forward(self, x):
@@ -23,8 +23,8 @@ class PositionalEncoding(nn.Module):
 class EvoDecoderModel(nn.Module):
     def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, dim_feedforward=2048, dropout=0.1):
         super().__init__()
-        self.token_embed = nn.Embedding(vocab_size, d_model)  # ✅ matches saved key: token_embed.weight
-        self.pos_encoder = PositionalEncoding(d_model)  # ✅ fixed dimension and safe slicing
+        self.token_embed = nn.Embedding(vocab_size, d_model)
+        self.pos_encoder = PositionalEncoding(d_model)
         decoder_layer = nn.TransformerDecoderLayer(
             d_model=d_model,
             nhead=nhead,
@@ -33,12 +33,13 @@ class EvoDecoderModel(nn.Module):
             batch_first=True
         )
         self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
-        self.lm_head = nn.Linear(d_model, vocab_size)  # ✅ matches saved key: lm_head.weight
+        self.lm_head = nn.Linear(d_model, vocab_size)
 
     def forward(self, input_ids):
-        x = self.token_embed(input_ids)  # (B, T, D)
-        x = self.pos_encoder(x)  # (B, T, D)
+        input_ids = input_ids[:, :128]  # clip input to match saved model's trained length
+        x = self.token_embed(input_ids)
+        x = self.pos_encoder(x)
         seq_len = x.size(1)
         mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
         x = self.decoder(x, x, tgt_mask=mask)
-        return self.lm_head(x)  # (B, T, V)
+        return self.lm_head(x)
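
For context, a minimal usage sketch of the model as it stands after this commit. The checkpoint filename and vocabulary size below are placeholders (not part of this diff), and loading assumes the checkpoint is a plain state dict whose keys match the module names kept above (token_embed, pos_encoder, decoder, lm_head):

    import torch
    from evo_model import EvoDecoderModel

    vocab_size = 30000  # placeholder; must match the checkpoint's embedding and lm_head shapes
    model = EvoDecoderModel(vocab_size)

    state = torch.load("evo_model.pt", map_location="cpu")  # placeholder path
    model.load_state_dict(state)
    model.eval()

    # A batch longer than 128 tokens: forward() clips it to the trained length,
    # so the returned logits cover at most the first 128 positions.
    input_ids = torch.randint(0, vocab_size, (2, 200))
    with torch.no_grad():
        logits = model(input_ids)

    print(logits.shape)  # torch.Size([2, 128, 30000]), i.e. (B, clipped T, V)

Note the interplay of the two changes: PositionalEncoding now allocates a buffer for up to 512 positions, while forward() still clips inputs to 128 so that inference stays within the sequence length the saved weights were trained on.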