HemanM committed on
Commit 0aced0a · verified · 1 Parent(s): 00ea8bb

Update evo_model.py

Files changed (1)
  1. evo_model.py +6 -12
evo_model.py CHANGED
@@ -4,20 +4,21 @@ import torch.nn as nn
 import math
 
 class PositionalEncoding(nn.Module):
-    def __init__(self, d_model, max_len=512):  # ⬅ increased to 512
+    def __init__(self, d_model, max_len=128):
         super().__init__()
         pe = torch.zeros(max_len, d_model)
         position = torch.arange(0, max_len).unsqueeze(1)
         div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
-        pe = pe.unsqueeze(0)  # shape: (1, max_len, d_model)
+        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
         self.register_buffer('pe', pe)
+        self.max_len = max_len
 
     def forward(self, x):
         seq_len = x.size(1)
-        if seq_len > self.pe.size(1):
-            raise ValueError(f"Input length {seq_len} exceeds max_len {self.pe.size(1)}")
+        if seq_len > self.max_len:
+            raise ValueError(f"Input length {seq_len} exceeds max_len {self.max_len}")
         return x + self.pe[:, :seq_len]
 
 class EvoDecoderModel(nn.Module):
@@ -25,18 +26,11 @@ class EvoDecoderModel(nn.Module):
         super().__init__()
         self.token_embed = nn.Embedding(vocab_size, d_model)
         self.pos_encoder = PositionalEncoding(d_model)
-        decoder_layer = nn.TransformerDecoderLayer(
-            d_model=d_model,
-            nhead=nhead,
-            dim_feedforward=dim_feedforward,
-            dropout=dropout,
-            batch_first=True
-        )
+        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
         self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
         self.lm_head = nn.Linear(d_model, vocab_size)
 
     def forward(self, input_ids):
-        input_ids = input_ids[:, :128]  # ⬅ clip input to match saved model's trained length
         x = self.token_embed(input_ids)
         x = self.pos_encoder(x)
         seq_len = x.size(1)
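
For context (not part of the commit), a minimal usage sketch of the updated PositionalEncoding, assuming evo_model.py is importable from the working directory and d_model=64 is chosen only for illustration: after this change the positional table covers 128 positions, so longer sequences raise a ValueError instead of EvoDecoderModel silently clipping input_ids to 128 tokens.

import torch
from evo_model import PositionalEncoding  # import path assumed; class is defined in this file

enc = PositionalEncoding(d_model=64)   # max_len now defaults to 128
x = torch.zeros(2, 128, 64)            # (batch, seq_len, d_model) -- exactly at the limit
print(enc(x).shape)                    # torch.Size([2, 128, 64])

try:
    enc(torch.zeros(2, 129, 64))       # one position past max_len
except ValueError as err:
    print(err)                         # Input length 129 exceeds max_len 128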