HemanM committed on
Commit 6604c50 · verified · 1 Parent(s): 5876a92

Update evo_model.py

Files changed (1)
  1. evo_model.py +8 -53
evo_model.py CHANGED
@@ -1,9 +1,7 @@
 import torch
 from torch import nn
-from transformers import PreTrainedModel, PretrainedConfig, AutoTokenizer
+from transformers import PreTrainedModel, PretrainedConfig
 from model import EvoTransformer  # assumes your core model is in model.py
-from torch.utils.data import DataLoader, Dataset
-import torch.optim as optim
 
 class EvoTransformerConfig(PretrainedConfig):
     model_type = "evo-transformer"
@@ -36,64 +34,21 @@ class EvoTransformerForClassification(PreTrainedModel):
             dim_feedforward=config.dim_feedforward,
             num_layers=config.num_hidden_layers
         )
-        self.classifier = nn.Linear(config.d_model, 2)
+        self.classifier = nn.Linear(config.d_model, 2)  # 2-way classification
 
-    def forward(self, input_ids, attention_mask=None):
-        x = self.model(input_ids)  # (batch_size, seq_len, hidden_size)
-        pooled = x[:, 0, :]  # Take [CLS]-like first token
-        logits = self.classifier(pooled)
+    def forward(self, input_ids):
+        hidden = self.model(input_ids)  # (batch_size, seq_len, d_model)
+        pooled = hidden[:, 0, :]  # Use the first token as a summary
+        logits = self.classifier(pooled)  # (batch_size, 2)
         return logits
 
     def save_pretrained(self, save_directory):
-        torch.save(self.state_dict(), f"{save_directory}/pytorch_model.bin")
+        torch.save(self.model.state_dict(), f"{save_directory}/pytorch_model.bin")
        self.config.save_pretrained(save_directory)
 
     @classmethod
     def from_pretrained(cls, load_directory):
         config = EvoTransformerConfig.from_pretrained(load_directory)
         model = cls(config)
-        model.load_state_dict(torch.load(f"{load_directory}/pytorch_model.bin"))
+        model.model.load_state_dict(torch.load(f"{load_directory}/pytorch_model.bin"))
         return model
-
-# ✅ Add this retraining logic
-def train_evo_transformer(df, epochs=1):
-    class EvoDataset(Dataset):
-        def __init__(self, dataframe, tokenizer):
-            self.df = dataframe
-            self.tokenizer = tokenizer
-
-        def __len__(self):
-            return len(self.df)
-
-        def __getitem__(self, idx):
-            row = self.df.iloc[idx]
-            text = f"{row['goal']} [SEP] {row['sol1']} [SEP] {row['sol2']}"
-            encoding = self.tokenizer(text, truncation=True, padding='max_length', max_length=64, return_tensors='pt')
-            input_ids = encoding['input_ids'].squeeze(0)
-            attention_mask = encoding['attention_mask'].squeeze(0)
-            label = torch.tensor(0 if row['correct'] == 'Solution 1' else 1)
-            return input_ids, attention_mask, label
-
-    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
-    config = EvoTransformerConfig()
-    model = EvoTransformerForClassification(config)
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    model.to(device)
-    model.train()
-
-    dataset = EvoDataset(df, tokenizer)
-    loader = DataLoader(dataset, batch_size=8, shuffle=True)
-    optimizer = optim.Adam(model.parameters(), lr=2e-5)
-    criterion = nn.CrossEntropyLoss()
-
-    for epoch in range(epochs):
-        for input_ids, attention_mask, labels in loader:
-            input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
-            logits = model(input_ids, attention_mask)
-            loss = criterion(logits, labels)
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-
-    torch.save(model.state_dict(), "trained_model.pt")
-    return True
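
For reference, a minimal round-trip with the updated class might look like the sketch below. The checkpoint directory name, the dummy vocabulary size (30522, matching the bert-base-uncased tokenizer the removed training helper used), and the random input batch are illustrative assumptions, not part of this commit. Note that after this change save_pretrained serializes only the backbone's weights (self.model), so the classifier head comes back freshly initialized after from_pretrained.

# Sketch: save/load round-trip with the updated API (assumptions noted above).
import os
import torch
from evo_model import EvoTransformerConfig, EvoTransformerForClassification

config = EvoTransformerConfig()
model = EvoTransformerForClassification(config)

save_dir = "./evo_checkpoint"          # hypothetical path
os.makedirs(save_dir, exist_ok=True)   # torch.save needs the directory to exist
model.save_pretrained(save_dir)        # writes pytorch_model.bin + config.json

restored = EvoTransformerForClassification.from_pretrained(save_dir)

# forward() now takes input_ids only; the attention_mask argument was dropped.
input_ids = torch.randint(0, 30522, (1, 64))  # dummy batch; vocab size assumed
logits = restored(input_ids)                  # shape: (1, 2)

Persisting only self.model presumably keeps the checkpoint interchangeable with the bare EvoTransformer backbone, at the cost of not carrying the two-way head's weights across a save/load cycle.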