HemanM committed on
Commit 646f772 · verified · 1 Parent(s): 6ea260a

Update model.py

Files changed (1)
  1. model.py +8 -52
model.py CHANGED
@@ -1,57 +1,13 @@
-import torch
 import torch.nn as nn
-from transformers import PreTrainedModel, PretrainedConfig
 
-class EvoTransformerConfig(PretrainedConfig):
-    def __init__(self, hidden_size=384, num_layers=6, num_labels=2, **kwargs):
-        super().__init__(**kwargs)
-        self.hidden_size = hidden_size
-        self.num_layers = num_layers
-        self.num_labels = num_labels
-
-class EvoTransformerForClassification(PreTrainedModel):
-    config_class = EvoTransformerConfig
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.config = config
-        self.embedding = nn.Embedding(30522, config.hidden_size)
-        self.layers = nn.ModuleList([
-            nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=6, dim_feedforward=1024)
-            for _ in range(config.num_layers)
-        ])
-        self.classifier = nn.Sequential(
-            nn.Linear(config.hidden_size, 256),
+class SimpleEvoModel(nn.Module):
+    def __init__(self, input_dim=768, hidden_dim=256, output_dim=2):
+        super().__init__()
+        self.model = nn.Sequential(
+            nn.Linear(input_dim, hidden_dim),
             nn.ReLU(),
-            nn.Linear(256, config.num_labels)
+            nn.Linear(hidden_dim, output_dim)
         )
-        self.init_weights()
-
-    def forward(self, input_ids, attention_mask=None, labels=None):
-        x = self.embedding(input_ids)
-        x = x.transpose(0, 1)
-        for layer in self.layers:
-            x = layer(x, src_key_padding_mask=(attention_mask == 0) if attention_mask is not None else None)
-        x = x.mean(dim=0)
-        logits = self.classifier(x)
-
-        if labels is not None:
-            loss = nn.functional.cross_entropy(logits, labels)
-            return {"loss": loss, "logits": logits}
-        return {"logits": logits}
-
-    def save_pretrained(self, save_directory):
-        import os
-        os.makedirs(save_directory, exist_ok=True)
-        torch.save(self.state_dict(), f"{save_directory}/pytorch_model.bin")
-        with open(f"{save_directory}/config.json", "w") as f:
-            f.write(self.config.to_json_string())
 
-    @classmethod
-    def from_pretrained(cls, load_directory):
-        config_path = f"{load_directory}/config.json"
-        model_path = f"{load_directory}/pytorch_model.bin"
-        config = EvoTransformerConfig.from_json_file(config_path)
-        model = cls(config)
-        model.load_state_dict(torch.load(model_path, map_location="cpu"))
-        return model
+    def forward(self, x):
+        return self.model(x)
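
For quick verification, here is a minimal usage sketch of the post-commit interface (not part of the commit itself). It assumes model.py from this commit is on the import path, and that callers now supply precomputed input_dim-sized feature vectors, since the token embedding layer, attention-mask handling, internal loss computation, and save/load helpers were all removed.

import torch
import torch.nn.functional as F

from model import SimpleEvoModel  # the class added in this commit

model = SimpleEvoModel()            # defaults: input_dim=768, hidden_dim=256, output_dim=2
features = torch.randn(4, 768)      # synthetic stand-in for a batch of pooled embeddings
logits = model(features)            # shape (4, 2): one score per class

# The old forward() computed cross-entropy internally when labels were given;
# callers must now compute the loss themselves.
labels = torch.tensor([0, 1, 1, 0])
loss = F.cross_entropy(logits, labels)
print(logits.shape, loss.item())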