ligeti committed on
Commit 74d3fa9 (verified)
Parent: b4a9b81

Upload ProkBertForSequenceClassification

Files changed (3)
  1. config.json +13 -4
  2. model.safetensors +3 -0
  3. models.py +285 -0
config.json CHANGED
@@ -1,23 +1,32 @@
 {
-  "_name_or_path": "neuralbioinfo/prokbert-mini",
   "architectures": [
-    "MegatronBertForMaskedLM"
+    "ProkBertForSequenceClassification"
   ],
   "attention_probs_dropout_prob": 0.1,
+  "auto_map": {
+    "AutoConfig": "models.ProkBertConfig",
+    "AutoModel": "neuralbioinfo/prokbert-mini--models.ProkBertModel",
+    "AutoModelForMaskedLM": "neuralbioinfo/prokbert-mini--models.ProkBertForMaskedLM",
+    "AutoModelForSequenceClassification": "models.ProkBertForSequenceClassification"
+  },
+  "classification_dropout_rate": 0.1,
   "hidden_act": "gelu",
   "hidden_dropout_prob": 0.1,
   "hidden_size": 384,
   "initializer_range": 0.02,
   "intermediate_size": 3072,
+  "kmer": 6,
   "layer_norm_eps": 1e-12,
   "max_position_embeddings": 1024,
-  "model_type": "megatron-bert",
+  "model_type": "prokbert",
   "num_attention_heads": 6,
+  "num_class_labels": 2,
   "num_hidden_layers": 6,
   "pad_token_id": 0,
   "position_embedding_type": "relative_key_query",
+  "shift": 1,
   "torch_dtype": "float32",
-  "transformers_version": "4.39.3",
+  "transformers_version": "4.53.0.dev0",
   "type_vocab_size": 2,
   "use_cache": true,
   "vocab_size": 4101
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1cb733aa7286052134e9149c9d7713982467eb226eb421950a81c185f97fc6cb
+size 82566500
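Only this Git LFS pointer lives in the repository; the actual weights are stored by LFS and can be checked against the pointer's size and sha256 oid after download. A small sketch, assuming huggingface_hub is available and again using a placeholder repository id:

import hashlib
import os
from huggingface_hub import hf_hub_download

repo_id = "<this-repo-id>"  # placeholder
path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")

# The downloaded file should match the LFS pointer: 82,566,500 bytes and the sha256 oid above.
assert os.path.getsize(path) == 82566500
digest = hashlib.sha256(open(path, "rb").read()).hexdigest()
assert digest == "1cb733aa7286052134e9149c9d7713982467eb226eb421950a81c185f97fc6cb"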
models.py ADDED
@@ -0,0 +1,285 @@
+# coding=utf-8
+import warnings
+import logging
+from typing import Optional, Tuple, Union
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import MegatronBertConfig, MegatronBertModel, MegatronBertForMaskedLM, MegatronBertPreTrainedModel, PreTrainedModel
+from transformers.modeling_outputs import SequenceClassifierOutput
+from transformers.utils.hub import cached_file
+#from prokbert.training_utils import compute_metrics_eval_prediction
+
+class BertForBinaryClassificationWithPooling(nn.Module):
+    """
+    ProkBERT model for binary classification with custom pooling.
+
+    This model extends a pre-trained `MegatronBertModel` by adding a weighting layer
+    to compute a weighted sum over the sequence outputs, followed by a classifier.
+
+    Attributes:
+        base_model (MegatronBertModel): The base BERT model.
+        weighting_layer (nn.Linear): Linear layer to compute weights for each token.
+        dropout (nn.Dropout): Dropout layer.
+        classifier (nn.Linear): Linear layer for classification.
+    """
+    def __init__(self, base_model: MegatronBertModel):
+        """
+        Initialize the BertForBinaryClassificationWithPooling model.
+
+        Args:
+            base_model (MegatronBertModel): A pre-trained `MegatronBertModel` instance.
+        """
+
+        super(BertForBinaryClassificationWithPooling, self).__init__()
+        self.base_model = base_model
+        self.base_model_config_dict = base_model.config.to_dict()
+        self.hidden_size = self.base_model_config_dict['hidden_size']
+        self.dropout_rate = self.base_model_config_dict['hidden_dropout_prob']
+
+        self.weighting_layer = nn.Linear(self.hidden_size, 1)
+        self.dropout = nn.Dropout(self.dropout_rate)
+        self.classifier = nn.Linear(self.hidden_size, 2)
+
+    def forward(self, input_ids, attention_mask=None, labels=None, output_hidden_states=False, output_pooled_output=False):
+        # Modified call to base model to include output_hidden_states
+        outputs = self.base_model(input_ids, attention_mask=attention_mask, output_hidden_states=output_hidden_states)
+        sequence_output = outputs[0]
+
+        # Compute weights for each position in the sequence
+        weights = self.weighting_layer(sequence_output)
+        weights = torch.nn.functional.softmax(weights, dim=1)
+
+        # Compute weighted sum
+        pooled_output = torch.sum(weights * sequence_output, dim=1)
+
+        # Classification head
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+
+        # Prepare the output as a dictionary
+        output = {"logits": logits}
+
+        # Include hidden states in output if requested
+        if output_hidden_states:
+            output["hidden_states"] = outputs.hidden_states
+        if output_pooled_output:
+            output["pooled_output"] = pooled_output
+
+        # If labels are provided, compute the loss
+        if labels is not None:
+            loss_fct = torch.nn.CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, 2), labels.view(-1))
+            output["loss"] = loss
+
+        return output
+
+    def save_pretrained(self, save_directory):
+        """
+        Save the model weights and configuration in a directory.
+
+        Args:
+            save_directory (str): Directory where the model and configuration can be saved.
+        """
+        print('The save pretrained is called!')
+        if not os.path.exists(save_directory):
+            os.makedirs(save_directory)
+
+        model_path = os.path.join(save_directory, "pytorch_model.bin")
+        torch.save(self.state_dict(), model_path)
+        print(f'The save directory is: {save_directory}')
+        self.base_model.config.save_pretrained(save_directory)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        """
+        Load the model weights and configuration from a local directory or Hugging Face Hub.
+
+        Args:
+            pretrained_model_name_or_path (str): Directory path where the model and configuration were saved, or name of the model in Hugging Face Hub.
+
+        Returns:
+            model: Instance of BertForBinaryClassificationWithPooling.
+        """
+        # Determine if the path is local or from Hugging Face Hub
+        if os.path.exists(pretrained_model_name_or_path):
+            # Path is local
+            if 'config' in kwargs:
+                print('Config is in the parameters')
+                config = kwargs['config']
+
+            else:
+                config = MegatronBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            base_model = MegatronBertModel(config=config)
+            model = cls(base_model=base_model)
+            model_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
+            model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=True))
+        else:
+            # Path is from Hugging Face Hub
+            config = kwargs.pop('config', None)
+            if config is None:
+                config = MegatronBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+            base_model = MegatronBertModel(config=config)
+            model = cls(base_model=base_model)
+            model_file = cached_file(pretrained_model_name_or_path, "pytorch_model.bin")
+            model.load_state_dict(torch.load(model_file, map_location=torch.device('cpu'), weights_only=True))
+
+        return model
+
+
+
+
+class ProkBertConfig(MegatronBertConfig):
+    model_type = "prokbert"
+
+    def __init__(
+        self,
+        kmer: int = 6,
+        shift: int = 1,
+        num_class_labels: int = 2,
+        classification_dropout_rate: float = 0.1,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.kmer = kmer
+        self.shift = shift
+        self.num_class_labels = num_class_labels
+        self.classification_dropout_rate = classification_dropout_rate
+
+
+
+
+class ProkBertClassificationConfig(ProkBertConfig):
+    model_type = "prokbert"
+    def __init__(
+        self,
+        num_labels: int = 2,
+        classification_dropout_rate: float = 0.1,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        # Some extra steps will be added here later; for now we just try the plain config.
+        self.num_labels = num_labels
+        self.classification_dropout_rate = classification_dropout_rate
+
+class ProkBertPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = ProkBertConfig
+    base_model_prefix = "bert"
+    supports_gradient_checkpointing = True
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        if isinstance(module, nn.Linear) and module.bias is not None:
+            module.bias.data.zero_()
+
+
+
+
+class ProkBertModel(MegatronBertModel):
+    config_class = ProkBertConfig
+
+    def __init__(self, config: ProkBertConfig, **kwargs):
+        if not isinstance(config, ProkBertConfig):
+            raise ValueError(f"Expected `ProkBertConfig`, got {config.__class__.__module__}.{config.__class__.__name__}")
+
+        super().__init__(config, **kwargs)
+        self.config = config
+        # One should check whether this is a proper ProkBERT config and, if not, craft one.
+
+
+class ProkBertForMaskedLM(MegatronBertForMaskedLM):
+    config_class = ProkBertConfig
+
+    def __init__(self, config: ProkBertConfig, **kwargs):
+        if not isinstance(config, ProkBertConfig):
+            raise ValueError(f"Expected `ProkBertConfig`, got {config.__class__.__module__}.{config.__class__.__name__}")
+
+        super().__init__(config, **kwargs)
+        self.config = config
+        # One should check whether this is a proper ProkBERT config and, if not, craft one.
+
+
+class ProkBertForSequenceClassification(ProkBertPreTrainedModel):
+    config_class = ProkBertConfig
+    base_model_prefix = "bert"
+
+    def __init__(self, config):
+
+        super().__init__(config)
+        self.config = config
+        self.bert = ProkBertModel(config)
+        self.weighting_layer = nn.Linear(self.config.hidden_size, 1)
+        self.dropout = nn.Dropout(self.config.classification_dropout_rate)
+        self.classifier = nn.Linear(self.config.hidden_size, self.config.num_class_labels)
+        self.loss_fct = torch.nn.CrossEntropyLoss()
+
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_class_labels - 1]`. If `config.num_class_labels == 1`, a regression loss is computed (Mean-Square loss); if
+            `config.num_class_labels > 1`, a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs[0]
+
+        # Compute weights for each position in the sequence
+        weights = self.weighting_layer(sequence_output)
+        weights = torch.nn.functional.softmax(weights, dim=1)
+        # Compute weighted sum
+        pooled_output = torch.sum(weights * sequence_output, dim=1)
+        # Classification head
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+        loss = None
+        if labels is not None:
+            loss = self.loss_fct(logits.view(-1, self.config.num_class_labels), labels.view(-1))
+
+        classification_output = SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+        return classification_output
+
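To see the new head end to end (softmax-weighted pooling over the token embeddings, then dropout and a linear classifier with `num_class_labels` outputs), a self-contained sketch along these lines should run against the uploaded models.py. The config values mirror config.json; the inputs and labels are dummy tensors rather than real tokenized sequences:

import torch
from models import ProkBertConfig, ProkBertForSequenceClassification

# Randomly initialised weights, just to exercise the forward pass; real use would load
# the uploaded model.safetensors (e.g. via the auto_map / from_pretrained route above).
config = ProkBertConfig(
    vocab_size=4101, hidden_size=384, intermediate_size=3072,
    num_hidden_layers=6, num_attention_heads=6, max_position_embeddings=1024,
    position_embedding_type="relative_key_query",
    kmer=6, shift=1, num_class_labels=2, classification_dropout_rate=0.1,
)
model = ProkBertForSequenceClassification(config)
model.eval()

input_ids = torch.randint(5, config.vocab_size, (2, 128))  # two dummy k-mer token sequences
attention_mask = torch.ones_like(input_ids)
labels = torch.tensor([0, 1])

with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

print(out.logits.shape)  # torch.Size([2, 2]): one score per class
print(out.loss)          # cross-entropy against the dummy labels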