Delete chatNT_config.py
Browse files- chatNT_config.py +0 -50
chatNT_config.py
DELETED
@@ -1,50 +0,0 @@
|
|
1 |
-
from dataclasses import dataclass
|
2 |
-
|
3 |
-
from transformers import PretrainedConfig
|
4 |
-
|
5 |
-
from genomics_research.biobrain_p1.porting_to_pytorch.configs.esm_config import (
|
6 |
-
ESMTransformerConfig,
|
7 |
-
)
|
8 |
-
from genomics_research.biobrain_p1.porting_to_pytorch.configs.gpt_config import (
|
9 |
-
GptConfig,
|
10 |
-
)
|
11 |
-
from genomics_research.biobrain_p1.porting_to_pytorch.configs.perceiver_resampler_config import ( # noqa
|
12 |
-
PerceiverResamplerConfig,
|
13 |
-
)
|
14 |
-
|
15 |
-
|
16 |
-
@dataclass
|
17 |
-
class ChatNTConfig(PretrainedConfig):
|
18 |
-
model_type = "ChatNT"
|
19 |
-
|
20 |
-
def __init__(self, **kwargs): # type: ignore
|
21 |
-
self.gpt_config: GptConfig = kwargs.get("gpt_config", GptConfig(32000, 3))
|
22 |
-
self.esm_config: ESMTransformerConfig = kwargs.get(
|
23 |
-
"esm_config", ESMTransformerConfig(4000, 1, 4)
|
24 |
-
)
|
25 |
-
self.perceiver_resampler_config: PerceiverResamplerConfig = kwargs.get(
|
26 |
-
"perceiver_resampler_config", PerceiverResamplerConfig()
|
27 |
-
)
|
28 |
-
self.seq_token_id: int = kwargs.get("seq_token_id", 32000)
|
29 |
-
self.bio_pad_token_id: int = kwargs.get("bio_pad_token_id", 1)
|
30 |
-
self.english_pad_token_id: int = kwargs.get("english_pad_token_id", 2)
|
31 |
-
super().__init__(**kwargs)
|
32 |
-
|
33 |
-
def to_dict(self): # type: ignore
|
34 |
-
print("(debug) Going into ChatNTConfig to_dict")
|
35 |
-
output = super().to_dict()
|
36 |
-
|
37 |
-
def serialize(obj): # type: ignore
|
38 |
-
return obj.to_dict() if hasattr(obj, "to_dict") else vars(obj)
|
39 |
-
|
40 |
-
print("(debug) Before serialize gpt_config to_dict")
|
41 |
-
output["gpt_config"] = serialize(self.gpt_config) # type: ignore
|
42 |
-
print("(debug) Before serialize esm_config to_dict")
|
43 |
-
output["esm_config"] = serialize(self.esm_config) # type: ignore
|
44 |
-
print("(debug) Before serialize perceiver_resampler_config to_dict")
|
45 |
-
output["perceiver_resampler_config"] = serialize( # type: ignore
|
46 |
-
self.perceiver_resampler_config
|
47 |
-
)
|
48 |
-
print("(debug) after serializing all ")
|
49 |
-
print("(debug) output : ", output)
|
50 |
-
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|