lhallee commited on
Commit
5c163ed
·
verified ·
1 Parent(s): c2c6f85

Update modeling_esm_plusplus.py

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +6 -6
modeling_esm_plusplus.py CHANGED
@@ -467,8 +467,8 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
467
  Implements the base ESM++ architecture with a masked language modeling head.
468
  """
469
  config_class = ESMplusplusConfig
470
- def __init__(self, config: ESMplusplusConfig):
471
- super().__init__(config)
472
  self.config = config
473
  self.vocab_size = config.vocab_size
474
  self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
@@ -642,8 +642,8 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
642
 
643
  Extends the base ESM++ model with a classification head.
644
  """
645
- def __init__(self, config: ESMplusplusConfig):
646
- super().__init__(config)
647
  self.config = config
648
  self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
649
  # Large intermediate projections help with sequence classification tasks (*4)
@@ -714,8 +714,8 @@ class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
714
 
715
  Extends the base ESM++ model with a token classification head.
716
  """
717
- def __init__(self, config: ESMplusplusConfig):
718
- super().__init__(config)
719
  self.config = config
720
  self.num_labels = config.num_labels
721
  self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)
 
467
  Implements the base ESM++ architecture with a masked language modeling head.
468
  """
469
  config_class = ESMplusplusConfig
470
+ def __init__(self, config: ESMplusplusConfig, **kwargs):
471
+ super().__init__(config, **kwargs)
472
  self.config = config
473
  self.vocab_size = config.vocab_size
474
  self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
 
642
 
643
  Extends the base ESM++ model with a classification head.
644
  """
645
+ def __init__(self, config: ESMplusplusConfig, **kwargs):
646
+ super().__init__(config, **kwargs)
647
  self.config = config
648
  self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
649
  # Large intermediate projections help with sequence classification tasks (*4)
 
714
 
715
  Extends the base ESM++ model with a token classification head.
716
  """
717
+ def __init__(self, config: ESMplusplusConfig, **kwargs):
718
+ super().__init__(config, **kwargs)
719
  self.config = config
720
  self.num_labels = config.num_labels
721
  self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)