HemanM committed
Commit c6b6ef5 · verified · 1 Parent(s): ca61944

Update evo_model.py

Files changed (1)
  1. evo_model.py +36 -33
evo_model.py CHANGED
@@ -1,54 +1,57 @@
  import torch
- from torch import nn
  from transformers import PreTrainedModel, PretrainedConfig
- from model import EvoTransformer  # assumes your core model is in model.py

  class EvoTransformerConfig(PretrainedConfig):
-     model_type = "evo-transformer"
-
-     def __init__(
-         self,
-         vocab_size=30522,
-         d_model=256,
-         nhead=4,
-         dim_feedforward=512,
-         num_hidden_layers=4,
-         **kwargs
-     ):
          super().__init__(**kwargs)
-         self.vocab_size = vocab_size
-         self.d_model = d_model
-         self.nhead = nhead
-         self.dim_feedforward = dim_feedforward
-         self.num_hidden_layers = num_hidden_layers

  class EvoTransformerForClassification(PreTrainedModel):
      config_class = EvoTransformerConfig

      def __init__(self, config):
          super().__init__(config)
-         self.model = EvoTransformer(
-             vocab_size=config.vocab_size,
-             d_model=config.d_model,
-             nhead=config.nhead,
-             dim_feedforward=config.dim_feedforward,
-             num_layers=config.num_hidden_layers
          )
-         self.classifier = nn.Linear(config.d_model, 2)  # 2-way classification

-     def forward(self, input_ids):
-         hidden = self.model(input_ids)  # (batch_size, seq_len, d_model)
-         pooled = hidden[:, 0, :]  # Use the first token as a summary
-         logits = self.classifier(pooled)  # (batch_size, 2)
          return logits

      def save_pretrained(self, save_directory):
-         torch.save(self.model.state_dict(), f"{save_directory}/pytorch_model.bin")
-         self.config.save_pretrained(save_directory)

      @classmethod
      def from_pretrained(cls, load_directory):
-         config = EvoTransformerConfig.from_pretrained(load_directory)
          model = cls(config)
-         model.model.load_state_dict(torch.load(f"{load_directory}/pytorch_model.bin"))
          return model
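
The version being replaced imports EvoTransformer from a separate model.py that is not part of this commit. Judging only from the constructor call and the (batch_size, seq_len, d_model) output its forward() is expected to return, the interface looks roughly like the sketch below; this is a hypothetical reconstruction, not the actual contents of model.py.

import torch.nn as nn

class EvoTransformer(nn.Module):
    # Hypothetical sketch of the interface the removed wrapper calls into;
    # the real model.py may differ in architecture and details.
    def __init__(self, vocab_size, d_model, nhead, dim_feedforward, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead,
            dim_feedforward=dim_feedforward, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, input_ids):
        # Returns (batch_size, seq_len, d_model), which the old forward() pools at token 0.
        return self.encoder(self.embedding(input_ids))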
 
  import torch
+ import torch.nn as nn
  from transformers import PreTrainedModel, PretrainedConfig

  class EvoTransformerConfig(PretrainedConfig):
+     def __init__(self, hidden_size=384, num_layers=6, num_labels=2, **kwargs):
          super().__init__(**kwargs)
+         self.hidden_size = hidden_size
+         self.num_layers = num_layers
+         self.num_labels = num_labels

  class EvoTransformerForClassification(PreTrainedModel):
      config_class = EvoTransformerConfig

      def __init__(self, config):
          super().__init__(config)
+         self.config = config
+         self.embedding = nn.Embedding(30522, config.hidden_size)  # BERT vocab size
+         self.layers = nn.ModuleList([
+             nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=6, dim_feedforward=1024)
+             for _ in range(config.num_layers)
+         ])
+         self.classifier = nn.Sequential(
+             nn.Linear(config.hidden_size, 256),
+             nn.ReLU(),
+             nn.Linear(256, config.num_labels)
          )
+         self.init_weights()
+
+     def forward(self, input_ids, attention_mask=None, labels=None):
+         x = self.embedding(input_ids)  # [batch, seq_len, hidden_size]
+         x = x.transpose(0, 1)  # Transformer expects [seq_len, batch, hidden_size]
+         for layer in self.layers:
+             x = layer(x, src_key_padding_mask=(attention_mask == 0) if attention_mask is not None else None)
+         x = x.mean(dim=0)  # mean pooling over seq_len
+         logits = self.classifier(x)

+         if labels is not None:
+             loss = nn.functional.cross_entropy(logits, labels)
+             return loss, logits
          return logits

      def save_pretrained(self, save_directory):
+         import os, json
+         os.makedirs(save_directory, exist_ok=True)
+         torch.save(self.state_dict(), f"{save_directory}/pytorch_model.bin")
+         with open(f"{save_directory}/config.json", "w") as f:
+             f.write(self.config.to_json_string())

      @classmethod
      def from_pretrained(cls, load_directory):
+         config_path = f"{load_directory}/config.json"
+         model_path = f"{load_directory}/pytorch_model.bin"
+         config = EvoTransformerConfig.from_json_file(config_path)
          model = cls(config)
+         model.load_state_dict(torch.load(model_path, map_location="cpu"))
          return model
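
For reference, a minimal usage sketch of the updated model. The tokenizer choice (bert-base-uncased, which matches the hard-coded 30522-entry embedding), the evo_model import path, and the checkpoint directory name are assumptions made for illustration, not part of the commit.

import torch
from transformers import AutoTokenizer
from evo_model import EvoTransformerConfig, EvoTransformerForClassification  # assumed import path

# Any tokenizer whose ids stay below 30522 works with the hard-coded embedding size.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

config = EvoTransformerConfig(hidden_size=384, num_layers=6, num_labels=2)
model = EvoTransformerForClassification(config)

batch = tokenizer(["an example sentence", "another example"],
                  padding=True, return_tensors="pt")
labels = torch.tensor([0, 1])

# With labels, forward() returns (loss, logits); without labels, it returns logits only.
loss, logits = model(batch["input_ids"],
                     attention_mask=batch["attention_mask"],
                     labels=labels)

# Round-trip through the custom save/load overrides defined above.
model.save_pretrained("evo_checkpoint")  # hypothetical directory name
reloaded = EvoTransformerForClassification.from_pretrained("evo_checkpoint")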