Update chatNT.py
Browse files
chatNT.py
CHANGED
@@ -588,7 +588,9 @@ class TorchMultiOmicsModel(PreTrainedModel):
|
|
588 |
config_class = ChatNTConfig
|
589 |
|
590 |
def __init__(self, config: ChatNTConfig) -> None:
|
|
|
591 |
if isinstance(config, dict):
|
|
|
592 |
# If config is a dictionary instead of ChatNTConfig (which can happen
|
593 |
# depending how the config was saved), we convert it to the config
|
594 |
config["gpt_config"]["rope_config"] = RotaryEmbeddingConfig(
|
@@ -596,10 +598,15 @@ class TorchMultiOmicsModel(PreTrainedModel):
|
|
596 |
)
|
597 |
config["gpt_config"] = GptConfig(**config["gpt_config"])
|
598 |
config["esm_config"] = ESMTransformerConfig(**config["esm_config"])
|
|
|
|
|
599 |
config["perceiver_resampler_config"] = PerceiverResamplerConfig(
|
600 |
**config["perceiver_resampler_config"]
|
601 |
)
|
602 |
config = ChatNTConfig(**config) # type: ignore
|
|
|
|
|
|
|
603 |
|
604 |
super().__init__(config=config)
|
605 |
self.gpt_config = config.gpt_config
|
|
|
588 |
config_class = ChatNTConfig
|
589 |
|
590 |
def __init__(self, config: ChatNTConfig) -> None:
|
591 |
+
print("(debug) Entering in class")
|
592 |
if isinstance(config, dict):
|
593 |
+
print("(debug) going in if condition")
|
594 |
# If config is a dictionary instead of ChatNTConfig (which can happen
|
595 |
# depending how the config was saved), we convert it to the config
|
596 |
config["gpt_config"]["rope_config"] = RotaryEmbeddingConfig(
|
|
|
598 |
)
|
599 |
config["gpt_config"] = GptConfig(**config["gpt_config"])
|
600 |
config["esm_config"] = ESMTransformerConfig(**config["esm_config"])
|
601 |
+
print("(debug) Type esm_config : ", type(config["esm_config"]))
|
602 |
+
print("(debug) esm_config : ", config["esm_config"])
|
603 |
config["perceiver_resampler_config"] = PerceiverResamplerConfig(
|
604 |
**config["perceiver_resampler_config"]
|
605 |
)
|
606 |
config = ChatNTConfig(**config) # type: ignore
|
607 |
+
print("(debug) Type config : ", type(config))
|
608 |
+
|
609 |
+
print("(debug) config : ", config)
|
610 |
|
611 |
super().__init__(config=config)
|
612 |
self.gpt_config = config.gpt_config
|