Yanisadel commited on
Commit
7974b5e
·
1 Parent(s): 76813c2

Delete chatNT_config.py

Browse files
Files changed (1) hide show
  1. chatNT_config.py +0 -50
chatNT_config.py DELETED
@@ -1,50 +0,0 @@
1
- from dataclasses import dataclass
2
-
3
- from transformers import PretrainedConfig
4
-
5
- from genomics_research.biobrain_p1.porting_to_pytorch.configs.esm_config import (
6
- ESMTransformerConfig,
7
- )
8
- from genomics_research.biobrain_p1.porting_to_pytorch.configs.gpt_config import (
9
- GptConfig,
10
- )
11
- from genomics_research.biobrain_p1.porting_to_pytorch.configs.perceiver_resampler_config import ( # noqa
12
- PerceiverResamplerConfig,
13
- )
14
-
15
-
16
- @dataclass
17
- class ChatNTConfig(PretrainedConfig):
18
- model_type = "ChatNT"
19
-
20
- def __init__(self, **kwargs): # type: ignore
21
- self.gpt_config: GptConfig = kwargs.get("gpt_config", GptConfig(32000, 3))
22
- self.esm_config: ESMTransformerConfig = kwargs.get(
23
- "esm_config", ESMTransformerConfig(4000, 1, 4)
24
- )
25
- self.perceiver_resampler_config: PerceiverResamplerConfig = kwargs.get(
26
- "perceiver_resampler_config", PerceiverResamplerConfig()
27
- )
28
- self.seq_token_id: int = kwargs.get("seq_token_id", 32000)
29
- self.bio_pad_token_id: int = kwargs.get("bio_pad_token_id", 1)
30
- self.english_pad_token_id: int = kwargs.get("english_pad_token_id", 2)
31
- super().__init__(**kwargs)
32
-
33
- def to_dict(self): # type: ignore
34
- print("(debug) Going into ChatNTConfig to_dict")
35
- output = super().to_dict()
36
-
37
- def serialize(obj): # type: ignore
38
- return obj.to_dict() if hasattr(obj, "to_dict") else vars(obj)
39
-
40
- print("(debug) Before serialize gpt_config to_dict")
41
- output["gpt_config"] = serialize(self.gpt_config) # type: ignore
42
- print("(debug) Before serialize esm_config to_dict")
43
- output["esm_config"] = serialize(self.esm_config) # type: ignore
44
- print("(debug) Before serialize perceiver_resampler_config to_dict")
45
- output["perceiver_resampler_config"] = serialize( # type: ignore
46
- self.perceiver_resampler_config
47
- )
48
- print("(debug) after serializing all ")
49
- print("(debug) output : ", output)
50
- return output