HemanM committed
Commit e897bf0 · verified · 1 Parent(s): f997552

Update evo_model.py

Files changed (1)
  1. evo_model.py +55 -8
evo_model.py CHANGED
@@ -1,9 +1,9 @@
-# ✅ evo_model.py – HF-compatible wrapper for EvoTransformer
-
 import torch
 from torch import nn
-from transformers import PreTrainedModel, PretrainedConfig
+from transformers import PreTrainedModel, PretrainedConfig, AutoTokenizer
 from model import EvoTransformer  # assumes your core model is in model.py
+from torch.utils.data import DataLoader, Dataset
+import torch.optim as optim
 
 class EvoTransformerConfig(PretrainedConfig):
     model_type = "evo-transformer"
@@ -36,17 +36,64 @@ class EvoTransformerForClassification(PreTrainedModel):
             dim_feedforward=config.dim_feedforward,
             num_layers=config.num_hidden_layers
         )
+        self.classifier = nn.Linear(config.d_model, 2)
 
-    def forward(self, input_ids):
-        return self.model(input_ids)
+    def forward(self, input_ids, attention_mask=None):
+        x = self.model(input_ids)  # (batch_size, seq_len, hidden_size)
+        pooled = x[:, 0, :]  # Take [CLS]-like first token
+        logits = self.classifier(pooled)
+        return logits
 
     def save_pretrained(self, save_directory):
-        torch.save(self.model.state_dict(), f"{save_directory}/pytorch_model.bin")
+        torch.save(self.state_dict(), f"{save_directory}/pytorch_model.bin")
         self.config.save_pretrained(save_directory)
 
     @classmethod
     def from_pretrained(cls, load_directory):
         config = EvoTransformerConfig.from_pretrained(load_directory)
         model = cls(config)
-        model.model.load_state_dict(torch.load(f"{load_directory}/pytorch_model.bin"))
-        return model
+        model.load_state_dict(torch.load(f"{load_directory}/pytorch_model.bin"))
+        return model
+
+# ✅ Add this retraining logic
+def train_evo_transformer(df, epochs=1):
+    class EvoDataset(Dataset):
+        def __init__(self, dataframe, tokenizer):
+            self.df = dataframe
+            self.tokenizer = tokenizer
+
+        def __len__(self):
+            return len(self.df)
+
+        def __getitem__(self, idx):
+            row = self.df.iloc[idx]
+            text = f"{row['goal']} [SEP] {row['sol1']} [SEP] {row['sol2']}"
+            encoding = self.tokenizer(text, truncation=True, padding='max_length', max_length=64, return_tensors='pt')
+            input_ids = encoding['input_ids'].squeeze(0)
+            attention_mask = encoding['attention_mask'].squeeze(0)
+            label = torch.tensor(0 if row['correct'] == 'Solution 1' else 1)
+            return input_ids, attention_mask, label
+
+    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+    config = EvoTransformerConfig()
+    model = EvoTransformerForClassification(config)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+    model.train()
+
+    dataset = EvoDataset(df, tokenizer)
+    loader = DataLoader(dataset, batch_size=8, shuffle=True)
+    optimizer = optim.Adam(model.parameters(), lr=2e-5)
+    criterion = nn.CrossEntropyLoss()
+
+    for epoch in range(epochs):
+        for input_ids, attention_mask, labels in loader:
+            input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
+            logits = model(input_ids, attention_mask)
+            loss = criterion(logits, labels)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+    torch.save(model.state_dict(), "trained_model.pt")
+    return True
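
For reference, a minimal usage sketch of the updated wrapper and the new train_evo_transformer helper. The sample rows, the pandas import, and the "evo_checkpoint" directory name are illustrative and not part of the commit; the DataFrame columns simply mirror the fields read by EvoDataset.__getitem__, and the sketch assumes EvoTransformerConfig supplies defaults for d_model and the other hyperparameters.

# Usage sketch (assumes model.py is importable and EvoTransformerConfig has defaults)
import os
import pandas as pd  # assumption: pandas is available for building the DataFrame
from evo_model import EvoTransformerConfig, EvoTransformerForClassification, train_evo_transformer

# Illustrative rows with the columns EvoDataset expects: goal, sol1, sol2, correct
df = pd.DataFrame([
    {"goal": "Boil water", "sol1": "Use a kettle", "sol2": "Use a freezer", "correct": "Solution 1"},
    {"goal": "Dry clothes", "sol1": "Put them in the fridge", "sol2": "Hang them outside", "correct": "Solution 2"},
])

train_evo_transformer(df, epochs=1)  # fine-tunes a fresh model and writes trained_model.pt

# Round-trip the HF-style save/load helpers defined in the diff
config = EvoTransformerConfig()
model = EvoTransformerForClassification(config)
os.makedirs("evo_checkpoint", exist_ok=True)  # save_pretrained writes into an existing directory
model.save_pretrained("evo_checkpoint")
reloaded = EvoTransformerForClassification.from_pretrained("evo_checkpoint")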