Yanisadel commited on
Commit
7ce7cf9
·
1 Parent(s): 7010ff2

Upload model

Browse files
Files changed (1) hide show
  1. chatNT.py +13 -0
chatNT.py CHANGED
@@ -588,6 +588,19 @@ class TorchMultiOmicsModel(PreTrainedModel):
588
  config_class = ChatNTConfig
589
 
590
  def __init__(self, config: ChatNTConfig) -> None:
 
 
 
 
 
 
 
 
 
 
 
 
 
591
  super().__init__(config=config)
592
  self.gpt_config = config.gpt_config
593
  self.esm_config = config.esm_config
 
588
  config_class = ChatNTConfig
589
 
590
  def __init__(self, config: ChatNTConfig) -> None:
591
+ if isinstance(config, dict):
592
+ # If config is a dictionary instead of ChatNTConfig (which can happen
593
+ # depending how the config was saved), we convert it to the config
594
+ config["gpt_config"]["rope_config"] = RotaryEmbeddingConfig(
595
+ **config["gpt_config"]["rope_config"]
596
+ )
597
+ config["gpt_config"] = GptConfig(**config["gpt_config"])
598
+ config["esm_config"] = ESMTransformerConfig(**config["esm_config"])
599
+ config["perceiver_resampler_config"] = PerceiverResamplerConfig(
600
+ **config["perceiver_resampler_config"]
601
+ )
602
+ config = ChatNTConfig(**config) # type: ignore
603
+
604
  super().__init__(config=config)
605
  self.gpt_config = config.gpt_config
606
  self.esm_config = config.esm_config