lhallee committed (verified) · Commit 9862082 · Parent: 5f37df1

Upload README.md with huggingface_hub

Files changed (1): README.md (+139 -119)

README.md CHANGED

---
library_name: transformers
tags: []
---

# FastESM
FastESM is a Hugging Face-compatible, plug-in version of ESM2 rewritten with a newer PyTorch attention implementation.

Load any ESM2 model into a FastEsm model to dramatically speed up training and inference without **ANY** cost in performance.

Outputting attention maps (or using the contact prediction head) is not natively possible with SDPA. You can still pass `output_attentions=True` to have attention calculated manually and returned (see the attention-map example below).
Various other optimizations also make the base implementation slightly different from the one in transformers.

## Use with 🤗 transformers

### Supported models
```python
model_dict = {
    # Synthyra/ESM2-8M
    'ESM2-8M': 'facebook/esm2_t6_8M_UR50D',
    # Synthyra/ESM2-35M
    'ESM2-35M': 'facebook/esm2_t12_35M_UR50D',
    # Synthyra/ESM2-150M
    'ESM2-150M': 'facebook/esm2_t30_150M_UR50D',
    # Synthyra/ESM2-650M
    'ESM2-650M': 'facebook/esm2_t33_650M_UR50D',
    # Synthyra/ESM2-3B
    'ESM2-3B': 'facebook/esm2_t36_3B_UR50D',
}
```

### For working with embeddings
```python
import torch
from transformers import AutoModel

model_path = 'Synthyra/ESM2-8M'
model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
tokenizer = model.tokenizer

sequences = ['MPRTEIN', 'MSEQWENCE']
tokenized = tokenizer(sequences, padding=True, return_tensors='pt')
with torch.no_grad():
    embeddings = model(**tokenized).last_hidden_state

print(embeddings.shape) # (2, 11, 320) - the hidden size of ESM2-8M is 320
```
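If you want a single embedding per sequence instead of per-residue embeddings, here is a minimal mean-pooling sketch over the attention mask (the `embed_dataset` helper below does this kind of pooling for you):

```python
# Mean-pool residue embeddings over non-padding positions (illustrative; embed_dataset handles pooling internally)
mask = tokenized['attention_mask'].unsqueeze(-1).to(embeddings.dtype)  # (batch, seq_len, 1)
pooled = (embeddings * mask).sum(dim=1) / mask.sum(dim=1)              # (batch, hidden_size)
print(pooled.shape)  # (2, 320)
```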

### For working with sequence logits
```python
import torch
from transformers import AutoModelForMaskedLM

# model_path and tokenized are reused from the embeddings example above
model = AutoModelForMaskedLM.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
with torch.no_grad():
    logits = model(**tokenized).logits

print(logits.shape) # (2, 11, 33) - 33 is the ESM2 vocabulary size
```
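As an illustration of using these logits, the sketch below masks one residue and ranks candidate amino acids at that position. It assumes the tokenizer exposes the standard `mask_token` / `mask_token_id` attributes of ESM-style tokenizers; adapt as needed.

```python
# Illustrative masked-residue prediction (assumes standard <mask> token handling)
masked_input = tokenizer(['MPR' + tokenizer.mask_token + 'EIN'], return_tensors='pt')
with torch.no_grad():
    masked_logits = model(**masked_input).logits  # (1, seq_len, vocab_size)

mask_pos = (masked_input['input_ids'][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
probs = masked_logits[0, mask_pos].float().softmax(dim=-1)
top5 = probs.topk(5, dim=-1)
print(tokenizer.convert_ids_to_tokens(top5.indices[0].tolist()))  # most likely residues at the masked position
```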

### For working with attention maps
```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
with torch.no_grad():
    attentions = model(**tokenized, output_attentions=True).attentions # tuple of tensors, one per layer, each (batch_size, num_heads, seq_len, seq_len)

print(attentions[-1].shape) # (2, 20, 11, 11)
```
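To work with the per-layer tuple as a single tensor, for example to average attention over layers and heads, a minimal sketch:

```python
# Stack per-layer attentions: (num_layers, batch, num_heads, seq_len, seq_len)
attn = torch.stack(attentions, dim=0)
# Average over layers (dim 0) and heads (dim 2) for one map per sequence
avg_attn = attn.mean(dim=(0, 2))
print(avg_attn.shape)  # (2, 11, 11)
```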

### Contact prediction
Because we can output attentions using the naive attention implementation, the contact prediction head is also supported:
```python
with torch.no_grad():
    contact_map = model.predict_contacts(**tokenized).squeeze().cpu().numpy() # (seq_len, seq_len)
```
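A minimal plotting sketch, assuming matplotlib is installed (the image below shows the kind of map this produces):

```python
import matplotlib.pyplot as plt

# predict_contacts returns one map per sequence; with a batch, index the map you want
single_map = contact_map[0] if contact_map.ndim == 3 else contact_map
plt.imshow(single_map)
plt.xlabel('Residue index')
plt.ylabel('Residue index')
plt.colorbar(label='Contact probability')
plt.show()
```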
![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/9707OSXZ3Wdgn0Ni-55T-.png)

## Embed entire datasets with no new code
To embed a list of protein sequences **fast**, just call `embed_dataset`. Sequences are sorted to reduce padding tokens, so the initial progress bar estimate is usually much longer than the actual time it will take.

Example:
```python
embedding_dict = model.embed_dataset(
    sequences=[
        'MALWMRLLPLLALLALWGPDPAAA', ... # list of protein sequences
    ],
    batch_size=2, # adjust for your GPU memory
    max_len=512, # adjust for your needs
    full_embeddings=False, # if True, no pooling is performed
    embed_dtype=torch.float32, # cast to what dtype you want
    pooling_type=['mean', 'cls'], # more than one pooling type will be concatenated together
    num_workers=0, # if you have many CPU cores, we find that num_workers=4 is fast for large datasets
    sql=False, # if True, embeddings will be stored in a SQLite database
    sql_db_path='embeddings.db',
    save=True, # if True, embeddings will be saved as a .pth file
    save_path='embeddings.pth',
)
# embedding_dict maps each sequence to its embedding: torch tensors when saved to .pth, numpy arrays when stored via sql
```

```
model.embed_dataset()
Args:
    sequences: List of protein sequences
    batch_size: Batch size for processing
    max_len: Maximum sequence length
    full_embeddings: Whether to return full residue-wise embeddings (True) or pooled embeddings (False)
    pooling_type: Type of pooling ('mean' or 'cls')
    num_workers: Number of workers for data loading, 0 for the main process
    sql: Whether to store embeddings in a SQLite database - will be stored in float32
    sql_db_path: Path to the SQLite database

Returns:
    Dictionary mapping sequences to embeddings, or None if sql=True

Note:
    - If sql=True, embeddings can only be stored in float32
    - sql is ideal if you need to stream a very large dataset for training in real time
    - save=True is ideal if you can store the entire embedding dictionary in RAM
    - If sql=True, the SQLite database is used whether save is True or False
    - If your SQL database or .pth file is already present, it will be scanned first for already-embedded sequences
    - Sequences will be truncated to max_len and sorted by length in descending order for faster processing
```
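A hedged sketch of reading saved embeddings back in, assuming only what is stated above: `save=True` writes the returned sequence-to-embedding dictionary to the `.pth` file at `save_path`.

```python
import torch

# Reload the dictionary written by embed_dataset(save=True, save_path='embeddings.pth')
embedding_dict = torch.load('embeddings.pth')
emb = embedding_dict['MALWMRLLPLLALLALWGPDPAAA']  # tensor for this sequence
print(emb.shape)  # pooled vector, or (seq_len, hidden_size) if full_embeddings=True
```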

### Citation
If you use any of this implementation or work, please cite it (as well as the [ESM2](https://www.science.org/doi/10.1126/science.ade2574) paper).
```
@misc {FastESM2,
    author = { Hallee, L. and Bichara, D. and Gleghorn, J. P. },
    title = { FastESM2 },
    year = 2024,
    url = { https://huggingface.co/Synthyra/FastESM2_650 },
    doi = { 10.57967/hf/3729 },
    publisher = { Hugging Face }
}
```