import os

import torch
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig, AutoTokenizer
from model import EvoTransformer  # assumes your core model is in model.py
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim

class EvoTransformerConfig(PretrainedConfig):
    model_type = "evo-transformer"

    def __init__(
        self,
        vocab_size=30522,
        d_model=256,
        nhead=4,
        dim_feedforward=512,
        num_hidden_layers=4,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.nhead = nhead
        self.dim_feedforward = dim_feedforward
        self.num_hidden_layers = num_hidden_layers

class EvoTransformerForClassification(PreTrainedModel):
    config_class = EvoTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = EvoTransformer(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            num_layers=config.num_hidden_layers
        )
        self.classifier = nn.Linear(config.d_model, 2)

    def forward(self, input_ids, attention_mask=None):
        # attention_mask is accepted for interface compatibility, but the
        # underlying EvoTransformer does not currently use it.
        x = self.model(input_ids)  # (batch_size, seq_len, d_model)
        pooled = x[:, 0, :]  # take the first token as a [CLS]-style summary
        logits = self.classifier(pooled)
        return logits

    def save_pretrained(self, save_directory):
        # Ensure the target directory exists before writing the weights.
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.state_dict(), f"{save_directory}/pytorch_model.bin")
        self.config.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, load_directory):
        config = EvoTransformerConfig.from_pretrained(load_directory)
        model = cls(config)
        # map_location="cpu" lets checkpoints saved on GPU load on CPU-only machines.
        state_dict = torch.load(f"{load_directory}/pytorch_model.bin", map_location="cpu")
        model.load_state_dict(state_dict)
        return model
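
# Hedged illustration (not part of the original training flow): a quick
# save/load round-trip using the overrides above. The directory name
# "evo_checkpoint" is an arbitrary example path, and running this still
# requires EvoTransformer to be importable from model.py.
def _example_save_load_roundtrip(save_dir="evo_checkpoint"):
    config = EvoTransformerConfig()
    model = EvoTransformerForClassification(config)
    model.save_pretrained(save_dir)
    restored = EvoTransformerForClassification.from_pretrained(save_dir)
    return restored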

# Retraining logic: fine-tune Evo on a DataFrame with 'goal', 'sol1', 'sol2' and 'correct' columns.
def train_evo_transformer(df, epochs=1):
    class EvoDataset(Dataset):
        def __init__(self, dataframe, tokenizer):
            self.df = dataframe
            self.tokenizer = tokenizer

        def __len__(self):
            return len(self.df)

        def __getitem__(self, idx):
            row = self.df.iloc[idx]
            text = f"{row['goal']} [SEP] {row['sol1']} [SEP] {row['sol2']}"
            encoding = self.tokenizer(text, truncation=True, padding='max_length', max_length=64, return_tensors='pt')
            input_ids = encoding['input_ids'].squeeze(0)
            attention_mask = encoding['attention_mask'].squeeze(0)
            label = torch.tensor(0 if row['correct'] == 'Solution 1' else 1)
            return input_ids, attention_mask, label

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    config = EvoTransformerConfig()
    model = EvoTransformerForClassification(config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.train()

    dataset = EvoDataset(df, tokenizer)
    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    optimizer = optim.Adam(model.parameters(), lr=2e-5)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        for input_ids, attention_mask, labels in loader:
            input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
            logits = model(input_ids, attention_mask)
            loss = criterion(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    torch.save(model.state_dict(), "trained_model.pt")
    return True
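
# Minimal usage sketch (illustrative only): the tiny DataFrame below mirrors
# the 'goal' / 'sol1' / 'sol2' / 'correct' columns the Dataset above expects;
# real training data would come from elsewhere.
if __name__ == "__main__":
    import pandas as pd

    demo_df = pd.DataFrame(
        {
            "goal": ["Boil water for tea"],
            "sol1": ["Fill the kettle and heat it on the stove."],
            "sol2": ["Place the kettle in the freezer overnight."],
            "correct": ["Solution 1"],
        }
    )
    train_evo_transformer(demo_df, epochs=1)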