Update evo_model.py
Browse files — evo_model.py (+4 -2)
evo_model.py
CHANGED
@@ -1,11 +1,13 @@
|
|
1 |
-
|
2 |
import torch
|
|
|
3 |
|
4 |
class EvoDecoderModel(nn.Module):
|
5 |
def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, dim_feedforward=2048, dropout=0.1, max_len=128):
|
6 |
super().__init__()
|
7 |
self.token_embed = nn.Embedding(vocab_size, d_model)
|
8 |
-
self.pos_encoder = nn.Parameter(torch.zeros(1, max_len, d_model)) #
|
|
|
9 |
decoder_layer = nn.TransformerDecoderLayer(
|
10 |
d_model=d_model,
|
11 |
nhead=nhead,
|
|
|
1 |
+
# evo_model.py — EvoDecoderModel matching saved checkpoint
|
2 |
import torch
|
3 |
+
import torch.nn as nn
|
4 |
|
5 |
class EvoDecoderModel(nn.Module):
|
6 |
def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, dim_feedforward=2048, dropout=0.1, max_len=128):
|
7 |
super().__init__()
|
8 |
self.token_embed = nn.Embedding(vocab_size, d_model)
|
9 |
+
self.pos_encoder = nn.Parameter(torch.zeros(1, max_len, d_model)) # matches saved shape [1, 128, 512]
|
10 |
+
|
11 |
decoder_layer = nn.TransformerDecoderLayer(
|
12 |
d_model=d_model,
|
13 |
nhead=nhead,
|