Commit 5aa3fcd · Shawn Shen committed
1 Parent(s): 3aa4b4a

minor fixes

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- app.py +24 -19
- esm/.DS_Store +0 -0
- esm/._.DS_Store +0 -0
- esm/._multihead_attention.py +0 -0
- esm/._pretrained.py +0 -0
- esm/__init__.py +17 -0
- esm/__pycache__/__init__.cpython-36.pyc +0 -0
- esm/__pycache__/__init__.cpython-39.pyc +0 -0
- esm/__pycache__/axial_attention.cpython-36.pyc +0 -0
- esm/__pycache__/axial_attention.cpython-39.pyc +0 -0
- esm/__pycache__/constants.cpython-36.pyc +0 -0
- esm/__pycache__/constants.cpython-39.pyc +0 -0
- esm/__pycache__/data.cpython-36.pyc +0 -0
- esm/__pycache__/data.cpython-39.pyc +0 -0
- esm/__pycache__/data_protein.cpython-36.pyc +0 -0
- esm/__pycache__/model.cpython-36.pyc +0 -0
- esm/__pycache__/model.cpython-39.pyc +0 -0
- esm/__pycache__/modules.cpython-36.pyc +0 -0
- esm/__pycache__/modules.cpython-39.pyc +0 -0
- esm/__pycache__/multihead_attention.cpython-36.pyc +0 -0
- esm/__pycache__/multihead_attention.cpython-39.pyc +0 -0
- esm/__pycache__/pretrained.cpython-36.pyc +0 -0
- esm/__pycache__/pretrained.cpython-39.pyc +0 -0
- esm/__pycache__/rotary_embedding.cpython-36.pyc +0 -0
- esm/__pycache__/rotary_embedding.cpython-39.pyc +0 -0
- esm/__pycache__/version.cpython-36.pyc +0 -0
- esm/__pycache__/version.cpython-39.pyc +0 -0
- esm/axial_attention.py +239 -0
- esm/constants.py +14 -0
- esm/data.py +524 -0
- esm/data_supervised.py +524 -0
- esm/model/._esm2_secondarystructure.py +0 -0
- esm/model/__pycache__/esm1.cpython-36.pyc +0 -0
- esm/model/__pycache__/esm1.cpython-39.pyc +0 -0
- esm/model/__pycache__/esm2.cpython-36.pyc +0 -0
- esm/model/__pycache__/esm2.cpython-39.pyc +0 -0
- esm/model/__pycache__/esm2_only_secondarystructure.cpython-39.pyc +0 -0
- esm/model/__pycache__/esm2_secondarystructure.cpython-39.pyc +0 -0
- esm/model/__pycache__/esm2_supervised.cpython-39.pyc +0 -0
- esm/model/__pycache__/msa_transformer.cpython-36.pyc +0 -0
- esm/model/__pycache__/msa_transformer.cpython-39.pyc +0 -0
- esm/model/esm1.py +203 -0
- esm/model/esm2.py +163 -0
- esm/model/esm2_only_secondarystructure.py +179 -0
- esm/model/esm2_secondarystructure.py +179 -0
- esm/model/esm2_supervised.py +174 -0
- esm/model/msa_transformer.py +238 -0
- esm/modules.py +419 -0
- esm/multihead_attention.py +506 -0
- esm/pretrained.py +378 -0
app.py
CHANGED
@@ -2,10 +2,13 @@ import streamlit as st
 from Bio import SeqIO
 import torch
 import torch.nn as nn
-from esm.model.esm2 import ESM2 as ESM2_SISS
-from esm import Alphabet, FastaBatchedDataset
 import pandas as pd
-
+
+import esm
+from esm.data import *
+from esm.model.esm2_secondarystructure import ESM2 as ESM2_SISS
+
+from esm import Alphabet, FastaBatchedDataset
 
 from io import StringIO
 
@@ -18,14 +21,14 @@ modelfile = 'model.pkl'
 layers = 6
 heads = 16
 embed_dim = 128
-batch_toks =
+batch_toks = 1024
 
 inp_len = 50
 
 device = "cpu"
 
-alphabet = Alphabet(
-alphabet.tok_to_idx
+alphabet = Alphabet(standard_toks = 'AGCT')
+assert alphabet.tok_to_idx == {'<pad>': 0, '<eos>': 1, '<unk>': 2, 'A': 3, 'G': 4, 'C': 5, 'T': 6, '<cls>': 7, '<mask>': 8, '<sep>': 9}
 
 class CNN_linear(nn.Module):
     def __init__(self,
@@ -68,8 +71,8 @@ class CNN_linear(nn.Module):
 
     def forward(self, tokens, need_head_weights=True, return_contacts=False, return_representation=True):
 
-
-        x = self.esm2(tokens, [layers])
+        x = self.esm2(tokens, [layers], need_head_weights, return_contacts, return_representation)
+        # x = self.esm2(tokens, [layers])
 
         x = x["representations"][layers][:, 0]
         x_o = x.unsqueeze(2)
@@ -81,15 +84,14 @@ class CNN_linear(nn.Module):
         o = self.output(o_dropout)
         return o
 
-def eval_step(dataloader, model, threshold = 0.5):
+def eval_step(dataloader, model, threshold=0.5):
     model.eval()
     logits_list= []
     # y_pred_list, y_prob_list = [], []
     ids_list, strs_list = [], []
     my_bar = st.progress(0, text="Running UTR_LM")
     with torch.no_grad():
-
-        for i, (ids, strs, toks) in enumerate(dataloader):
+        for i, (ids, strs, _, toks, _, _) in enumerate(dataloader):
             ids_list.extend(ids)
             strs_list.extend(strs)
             # toks = toks.to(device)
@@ -106,6 +108,7 @@ def eval_step(dataloader, model, threshold = 0.5):
             # y_pred_list.extend(y_pred.tolist())
 
     st.success('Done', icon="✅")
+    # data_pred = pd.DataFrame({'ID':ids_list, 'Sequence':strs_list, "Translation Efficiency":logits_list, "prob":y_prob_list, "pred":y_pred_list})
     data_pred = pd.DataFrame({'ID':ids_list, 'Sequence':strs_list, "Translation Efficiency":logits_list})
     return data_pred
 
@@ -129,8 +132,9 @@ def read_raw(raw_input):
     return ids, sequences
 
 def generate_dataset_dataloader(ids, seqs):
-
-
+    dataset = FastaBatchedDataset(ids, seqs, mask_prob = 0.0)
+
+    # dataset = FastaBatchedDataset(ids, seqs)
     batches = dataset.get_batch_indices(toks_per_batch=batch_toks, extra_toks_per_seq=2)
     dataloader = torch.utils.data.DataLoader(dataset,
                                              collate_fn=alphabet.get_batch_converter(),
@@ -166,13 +170,14 @@ uploaded = st.file_uploader("Sequence file in FASTA format")
 if st.button("Predict"):
     if uploaded:
         result = predict_raw(uploaded.getvalue().decode())
-        result_file = result.to_csv(index=False)
-        st.download_button("Download", result_file, file_name="UTR_LM_prediction.csv")
-        st.dataframe(result)
+        # result_file = result.to_csv(index=False)
+        # st.download_button("Download", result_file, file_name="UTR_LM_prediction.csv")
+        # st.dataframe(result)
    else:
        result = predict_raw(seq)
-
-
-
+
+    result_file = result.to_csv(index=False)
+    st.download_button("Download", result_file, file_name="UTR_LM_prediction.csv")
+    st.dataframe(result)
 
 
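For context, the change above swaps app.py from the plain ESM2 class to the secondary-structure variant and from a 3-tuple to the 6-tuple batch converter defined in the modified esm/data.py added later in this commit. The following is a minimal sketch of how those pieces fit together outside Streamlit; the sequences and variable names (example_ids, example_seqs) are illustrative, not part of the commit:

import torch
from esm.data import Alphabet, FastaBatchedDataset

# Nucleotide alphabet used by the app; indices match the assert added in app.py.
alphabet = Alphabet(standard_toks='AGCT')

# Illustrative 5' UTR sequences; mask_prob=0.0 disables masking at inference time.
example_ids = ['utr1', 'utr2']
example_seqs = ['GGCTAGCATCGATCG', 'ATGCGGCTAGCTAGC']
dataset = FastaBatchedDataset(example_ids, example_seqs, mask_prob=0.0)

batches = dataset.get_batch_indices(toks_per_batch=1024, extra_toks_per_seq=2)
dataloader = torch.utils.data.DataLoader(
    dataset,
    collate_fn=alphabet.get_batch_converter(),
    batch_sampler=batches,
)

# The modified BatchConverter yields six items per batch; app.py keeps ids, strs and toks.
for ids, strs, masked_strs, toks, masked_toks, masked_indices in dataloader:
    print(ids, toks.shape)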
esm/.DS_Store  ADDED  Binary file (6.15 kB)
esm/._.DS_Store  ADDED  Binary file (4.1 kB)
esm/._multihead_attention.py  ADDED  Binary file (4.1 kB)
esm/._pretrained.py  ADDED  Binary file (4.1 kB)
esm/__init__.py
ADDED
@@ -0,0 +1,17 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from .version import version as __version__  # noqa
+
+from .data import Alphabet, BatchConverter, FastaBatchedDataset  # noqa
+from .model.esm1 import ProteinBertModel  # noqa
+from .model.esm2 import ESM2  # noqa
+from .model.msa_transformer import MSATransformer  # noqa
+from . import pretrained  # noqa
+
+# from .version import version as __version__  # noqa
+
+# from .data import Alphabet, BatchConverter, FastaBatchedDataset  # noqa
+# from .model import ProteinBertModel, MSATransformer, ESM2  # noqa
+# from . import pretrained  # noqa
esm/__pycache__/__init__.cpython-36.pyc  ADDED  Binary file (480 Bytes)
esm/__pycache__/__init__.cpython-39.pyc  ADDED  Binary file (458 Bytes)
esm/__pycache__/axial_attention.cpython-36.pyc  ADDED  Binary file (5.43 kB)
esm/__pycache__/axial_attention.cpython-39.pyc  ADDED  Binary file (5.39 kB)
esm/__pycache__/constants.cpython-36.pyc  ADDED  Binary file (355 Bytes)
esm/__pycache__/constants.cpython-39.pyc  ADDED  Binary file (307 Bytes)
esm/__pycache__/data.cpython-36.pyc  ADDED  Binary file (16.4 kB)
esm/__pycache__/data.cpython-39.pyc  ADDED  Binary file (16.2 kB)
esm/__pycache__/data_protein.cpython-36.pyc  ADDED  Binary file (16.4 kB)
esm/__pycache__/model.cpython-36.pyc  ADDED  Binary file (12.4 kB)
esm/__pycache__/model.cpython-39.pyc  ADDED  Binary file (9.63 kB)
esm/__pycache__/modules.cpython-36.pyc  ADDED  Binary file (12.7 kB)
esm/__pycache__/modules.cpython-39.pyc  ADDED  Binary file (12.9 kB)
esm/__pycache__/multihead_attention.cpython-36.pyc  ADDED  Binary file (11.8 kB)
esm/__pycache__/multihead_attention.cpython-39.pyc  ADDED  Binary file (11.9 kB)
esm/__pycache__/pretrained.cpython-36.pyc  ADDED  Binary file (14.3 kB)
esm/__pycache__/pretrained.cpython-39.pyc  ADDED  Binary file (14 kB)
esm/__pycache__/rotary_embedding.cpython-36.pyc  ADDED  Binary file (2.73 kB)
esm/__pycache__/rotary_embedding.cpython-39.pyc  ADDED  Binary file (2.7 kB)
esm/__pycache__/version.cpython-36.pyc  ADDED  Binary file (176 Bytes)
esm/__pycache__/version.cpython-39.pyc  ADDED  Binary file (154 Bytes)
esm/axial_attention.py
ADDED
@@ -0,0 +1,239 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import torch
+import torch.nn as nn
+
+
+class RowSelfAttention(nn.Module):
+    """Compute self-attention over rows of a 2D input."""
+
+    def __init__(
+        self,
+        embed_dim,
+        num_heads,
+        dropout=0.0,
+        max_tokens_per_msa: int = 2 ** 16,
+    ):
+        super().__init__()
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        self.scaling = self.head_dim ** -0.5
+        self.max_tokens_per_msa = max_tokens_per_msa
+        self.attn_shape = "hnij"
+
+        self.k_proj = nn.Linear(embed_dim, embed_dim)
+        self.v_proj = nn.Linear(embed_dim, embed_dim)
+        self.q_proj = nn.Linear(embed_dim, embed_dim)
+
+        self.out_proj = nn.Linear(embed_dim, embed_dim)
+        self.dropout_module = nn.Dropout(dropout)
+
+    def align_scaling(self, q):
+        num_rows = q.size(0)
+        return self.scaling / math.sqrt(num_rows)
+
+    def _batched_forward(
+        self,
+        x,
+        self_attn_mask=None,
+        self_attn_padding_mask=None,
+    ):
+        num_rows, num_cols, batch_size, embed_dim = x.size()
+        max_rows = max(1, self.max_tokens_per_msa // num_cols)
+        attns = 0
+        scaling = self.align_scaling(x)
+        for start in range(0, num_rows, max_rows):
+            attn_weights = self.compute_attention_weights(
+                x[start : start + max_rows],
+                scaling,
+                self_attn_mask=self_attn_mask,
+                self_attn_padding_mask=self_attn_padding_mask[:, start : start + max_rows]
+                if self_attn_padding_mask is not None
+                else None,
+            )
+            attns += attn_weights
+        attn_probs = attns.softmax(-1)
+        attn_probs = self.dropout_module(attn_probs)
+
+        outputs = []
+        for start in range(0, num_rows, max_rows):
+            output = self.compute_attention_update(x[start : start + max_rows], attn_probs)
+            outputs.append(output)
+
+        output = torch.cat(outputs, 0)
+        return output, attn_probs
+
+    def compute_attention_weights(
+        self,
+        x,
+        scaling: float,
+        self_attn_mask=None,
+        self_attn_padding_mask=None,
+    ):
+        num_rows, num_cols, batch_size, embed_dim = x.size()
+        q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
+        k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
+        q *= scaling
+        if self_attn_padding_mask is not None:
+            # Zero out any padded aligned positions - this is important since
+            # we take a sum across the alignment axis.
+            q *= 1 - self_attn_padding_mask.permute(1, 2, 0).unsqueeze(3).unsqueeze(4).to(q)
+
+        attn_weights = torch.einsum(f"rinhd,rjnhd->{self.attn_shape}", q, k)
+
+        if self_attn_mask is not None:
+            raise NotImplementedError
+            # Mask Size: [B x R x C], Weights Size: [H x B x C x C]
+
+        if self_attn_padding_mask is not None:
+            attn_weights = attn_weights.masked_fill(
+                self_attn_padding_mask[:, 0].unsqueeze(0).unsqueeze(2),
+                -10000,
+            )
+
+        return attn_weights
+
+    def compute_attention_update(
+        self,
+        x,
+        attn_probs,
+    ):
+        num_rows, num_cols, batch_size, embed_dim = x.size()
+        v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
+        context = torch.einsum(f"{self.attn_shape},rjnhd->rinhd", attn_probs, v)
+        context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
+        output = self.out_proj(context)
+        return output
+
+    def forward(
+        self,
+        x,
+        self_attn_mask=None,
+        self_attn_padding_mask=None,
+    ):
+        num_rows, num_cols, batch_size, embed_dim = x.size()
+        if (num_rows * num_cols > self.max_tokens_per_msa) and not torch.is_grad_enabled():
+            return self._batched_forward(x, self_attn_mask, self_attn_padding_mask)
+        else:
+            scaling = self.align_scaling(x)
+            attn_weights = self.compute_attention_weights(
+                x, scaling, self_attn_mask, self_attn_padding_mask
+            )
+            attn_probs = attn_weights.softmax(-1)
+            attn_probs = self.dropout_module(attn_probs)
+            output = self.compute_attention_update(x, attn_probs)
+            return output, attn_probs
+
+
+class ColumnSelfAttention(nn.Module):
+    """Compute self-attention over columns of a 2D input."""
+
+    def __init__(
+        self,
+        embed_dim,
+        num_heads,
+        dropout=0.0,
+        max_tokens_per_msa: int = 2 ** 16,
+    ):
+        super().__init__()
+
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        self.scaling = self.head_dim ** -0.5
+        self.max_tokens_per_msa = max_tokens_per_msa
+
+        self.k_proj = nn.Linear(embed_dim, embed_dim)
+        self.v_proj = nn.Linear(embed_dim, embed_dim)
+        self.q_proj = nn.Linear(embed_dim, embed_dim)
+
+        self.out_proj = nn.Linear(embed_dim, embed_dim)
+        self.dropout_module = nn.Dropout(dropout)
+
+    def _batched_forward(
+        self,
+        x,
+        self_attn_mask=None,
+        self_attn_padding_mask=None,
+    ):
+        num_rows, num_cols, batch_size, embed_dim = x.size()
+        max_cols = max(1, self.max_tokens_per_msa // num_rows)
+        outputs = []
+        attns = []
+        for start in range(0, num_cols, max_cols):
+            output, attn = self(
+                x[:, start : start + max_cols],
+                self_attn_mask=self_attn_mask,
+                self_attn_padding_mask=self_attn_padding_mask[:, :, start : start + max_cols]
+                if self_attn_padding_mask is not None
+                else None,
+            )
+            outputs.append(output)
+            attns.append(attn)
+        output = torch.cat(outputs, 1)
+        attns = torch.cat(attns, 1)
+        return output, attns
+
+    def compute_attention_update(
+        self,
+        x,
+        self_attn_mask=None,
+        self_attn_padding_mask=None,
+    ):
+        num_rows, num_cols, batch_size, embed_dim = x.size()
+        if num_rows == 1:
+            # if there is only 1 position, this is equivalent and doesn't break with padding
+            attn_probs = torch.ones(
+                self.num_heads,
+                num_cols,
+                batch_size,
+                num_rows,
+                num_rows,
+                device=x.device,
+                dtype=x.dtype,
+            )
+            output = self.out_proj(self.v_proj(x))
+        else:
+            q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
+            k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
+            v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
+            q *= self.scaling
+
+            attn_weights = torch.einsum("icnhd,jcnhd->hcnij", q, k)
+
+            if self_attn_mask is not None:
+                raise NotImplementedError
+            if self_attn_padding_mask is not None:
+                attn_weights = attn_weights.masked_fill(
+                    self_attn_padding_mask.permute(2, 0, 1).unsqueeze(0).unsqueeze(3),
+                    -10000,
+                )
+
+            attn_probs = attn_weights.softmax(-1)
+            attn_probs = self.dropout_module(attn_probs)
+            context = torch.einsum("hcnij,jcnhd->icnhd", attn_probs, v)
+            context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
+            output = self.out_proj(context)
+        return output, attn_probs
+
+    def forward(
+        self,
+        x,
+        self_attn_mask=None,
+        self_attn_padding_mask=None,
+    ):
+        num_rows, num_cols, batch_size, embed_dim = x.size()
+        # if False and num_rows * num_cols > 2 ** 14 and not torch.is_grad_enabled():
+        if (num_rows * num_cols) > self.max_tokens_per_msa and not torch.is_grad_enabled():
+            return self._batched_forward(
+                x,
+                self_attn_mask,
+                self_attn_padding_mask,
+            )
+        else:
+            return self.compute_attention_update(x, self_attn_mask, self_attn_padding_mask)
esm/constants.py
ADDED
@@ -0,0 +1,14 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# fmt: off
+proteinseq_toks = {
+    'toks': ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-']
+}
+
+rnaseq_toks = {
+    'toks': ['A', 'G', 'T', 'C']
+}
+# fmt: on
esm/data.py
ADDED
@@ -0,0 +1,524 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import itertools
+import os
+from typing import Sequence, Tuple, List, Union
+import pickle
+import re
+import shutil
+import torch
+from pathlib import Path
+from .constants import proteinseq_toks, rnaseq_toks
+import math
+import random
+from copy import deepcopy
+
+RawMSA = Sequence[Tuple[str, str]]
+
+
+class Alphabet(object):
+    def __init__(
+        self,
+        standard_toks: Sequence[str],
+        prepend_toks: Sequence[str] = ("<pad>", "<eos>", "<unk>"),  # "<null_0>",
+        append_toks: Sequence[str] = ("<cls>", "<mask>", "<sep>"),  #
+        prepend_bos: bool = True,
+        append_eos: bool = True,
+        use_msa: bool = False,
+        mask_prob: float = 0.15,  ###---
+    ):
+        self.mask_prob = mask_prob  ###---
+        self.standard_toks = list(standard_toks)
+        self.prepend_toks = list(prepend_toks)
+        self.append_toks = list(append_toks)
+        self.prepend_bos = prepend_bos
+        self.append_eos = append_eos
+        self.use_msa = use_msa
+
+        self.all_toks = list(self.prepend_toks)
+        self.all_toks.extend(self.standard_toks)
+        # for i in range((8 - (len(self.all_toks) % 8)) % 8):
+        #     self.all_toks.append(f"<null_{i + 1}>")
+        self.all_toks.extend(self.append_toks)
+
+        self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}
+        # print(self.tok_to_idx)
+        self.unk_idx = self.tok_to_idx["<unk>"]
+        self.padding_idx = self.get_idx("<pad>")
+        self.cls_idx = self.get_idx("<cls>")
+        self.mask_idx = self.get_idx("<mask>")
+        self.eos_idx = self.get_idx("<eos>")
+        self.all_special_tokens = ['<eos>', '<pad>', '<mask>']  # , '<unk>', '<cls>'
+        self.unique_no_split_tokens = self.all_toks
+
+    def __len__(self):
+        return len(self.all_toks)
+
+    def get_idx(self, tok):
+        return self.tok_to_idx.get(tok, self.unk_idx)
+
+    def get_tok(self, ind):
+        return self.all_toks[ind]
+
+    def to_dict(self):
+        return self.tok_to_idx.copy()
+
+    def get_batch_converter(self):
+        if self.use_msa:
+            return MSABatchConverter(self)
+        else:
+            return BatchConverter(self)
+
+    @classmethod
+    def from_architecture(cls, name: str) -> "Alphabet":
+        if name in ("ESM-1", "protein_bert_base"):
+            standard_toks = proteinseq_toks["toks"]
+            prepend_toks: Tuple[str, ...] = ("<null_0>", "<pad>", "<eos>", "<unk>")
+            append_toks: Tuple[str, ...] = ("<cls>", "<mask>", "<sep>")
+            prepend_bos = True
+            append_eos = False
+            use_msa = False
+        elif name in ("ESM-1b", "roberta_large"):
+            standard_toks = proteinseq_toks["toks"]  ###---rnaseq
+            prepend_toks = ("<cls>", "<pad>", "<eos>", "<unk>")
+            append_toks = ("<mask>",)
+            prepend_bos = True
+            append_eos = True
+            use_msa = False
+        elif name in ("MSA Transformer", "msa_transformer"):
+            standard_toks = proteinseq_toks["toks"]
+            prepend_toks = ("<cls>", "<pad>", "<eos>", "<unk>")
+            append_toks = ("<mask>",)
+            prepend_bos = True
+            append_eos = False
+            use_msa = True
+        else:
+            raise ValueError("Unknown architecture selected")
+        return cls(standard_toks, prepend_toks, append_toks, prepend_bos, append_eos, use_msa)
+
+    def _tokenize(self, text) -> str:
+        return text.split()
+
+    def tokenize(self, text, **kwargs) -> List[str]:
+        """
+        Inspired by https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_utils.py
+        Converts a string in a sequence of tokens, using the tokenizer.
+
+        Args:
+            text (:obj:`str`):
+                The sequence to be encoded.
+
+        Returns:
+            :obj:`List[str]`: The list of tokens.
+        """
+
+        def split_on_token(tok, text):
+            result = []
+            split_text = text.split(tok)
+            for i, sub_text in enumerate(split_text):
+                # AddedToken can control whitespace stripping around them.
+                # We use them for GPT2 and Roberta to have different behavior depending on the special token
+                # Cf. https://github.com/huggingface/transformers/pull/2778
+                # and https://github.com/huggingface/transformers/issues/3788
+                # We strip left and right by default
+                if i < len(split_text) - 1:
+                    sub_text = sub_text.rstrip()
+                if i > 0:
+                    sub_text = sub_text.lstrip()
+
+                if i == 0 and not sub_text:
+                    result.append(tok)
+                elif i == len(split_text) - 1:
+                    if sub_text:
+                        result.append(sub_text)
+                    else:
+                        pass
+                else:
+                    if sub_text:
+                        result.append(sub_text)
+                    result.append(tok)
+            return result
+
+        def split_on_tokens(tok_list, text):
+            if not text.strip():
+                return []
+
+            tokenized_text = []
+            text_list = [text]
+            for tok in tok_list:
+                tokenized_text = []
+                for sub_text in text_list:
+                    if sub_text not in self.unique_no_split_tokens:
+                        tokenized_text.extend(split_on_token(tok, sub_text))
+                    else:
+                        tokenized_text.append(sub_text)
+                text_list = tokenized_text
+
+            return list(
+                itertools.chain.from_iterable(
+                    (
+                        self._tokenize(token)
+                        if token not in self.unique_no_split_tokens
+                        else [token]
+                        for token in tokenized_text
+                    )
+                )
+            )
+
+        no_split_token = self.unique_no_split_tokens
+        tokenized_text = split_on_tokens(no_split_token, text)
+        return tokenized_text
+
+    def encode(self, text):
+        return [self.tok_to_idx[tok] for tok in self.tokenize(text)]
+
+class FastaBatchedDataset(object):
+    def __init__(self, sequence_labels, sequence_strs, mask_prob = 0.15):
+        self.sequence_labels = list(sequence_labels)
+        self.sequence_strs = list(sequence_strs)
+        self.mask_prob = mask_prob
+
+    @classmethod
+    def from_file(cls, fasta_file, mask_prob = 0.15):
+        sequence_labels, sequence_strs = [], []
+        cur_seq_label = None
+        buf = []
+
+        def _flush_current_seq():
+            nonlocal cur_seq_label, buf
+            if cur_seq_label is None:
+                return
+            sequence_labels.append(cur_seq_label)
+            sequence_strs.append("".join(buf))
+            cur_seq_label = None
+            buf = []
+
+        with open(fasta_file, "r") as infile:
+            for line_idx, line in enumerate(infile):
+                if line.startswith(">"):  # label line
+                    _flush_current_seq()
+                    line = line[1:].strip()
+                    if len(line) > 0:
+                        cur_seq_label = line
+                    else:
+                        cur_seq_label = f"seqnum{line_idx:09d}"
+                else:  # sequence line
+                    buf.append(line.strip())
+
+        _flush_current_seq()
+
+        assert len(set(sequence_strs)) == len(
+            sequence_strs
+        ), "Found duplicate sequence labels"
+
+        return cls(sequence_labels, sequence_strs, mask_prob)
+
+    def __len__(self):
+        return len(self.sequence_labels)
+
+    def mask_sequence(self, seq):  ###---
+        length = len(seq)
+        # print(self.mask_prob)
+        max_length = math.ceil(length * self.mask_prob)
+        rand = random.sample(range(0, length), max_length)
+        res = ''.join(['<mask>' if idx in rand else ele for idx, ele in enumerate(seq)])
+        #print(seq, rand, res)
+        return rand, res
+
+    def __getitem__(self, idx):
+        sequence_str = self.sequence_strs[idx]
+        sequence_label = self.sequence_labels[idx]
+        masked_indices, masked_sequence_str = self.mask_sequence(sequence_str)
+        return sequence_label, sequence_str, masked_sequence_str, masked_indices
+
+    def get_batch_indices(self, toks_per_batch, extra_toks_per_seq=0):
+        sizes = [(len(s), i) for i, s in enumerate(self.sequence_strs)]
+        sizes.sort()
+        batches = []
+        buf = []
+        max_len = 0
+
+        def _flush_current_buf():
+            nonlocal max_len, buf
+            if len(buf) == 0:
+                return
+            batches.append(buf)
+            buf = []
+            max_len = 0
+
+        for sz, i in sizes:
+            sz += extra_toks_per_seq
+            if max(sz, max_len) * (len(buf) + 1) > toks_per_batch:
+                _flush_current_buf()
+            max_len = max(max_len, sz)
+            buf.append(i)
+
+        _flush_current_buf()
+        return batches
+
+class BatchConverter(object):
+    """Callable to convert an unprocessed (labels + strings) batch to a
+    processed (labels + tensor) batch.
+    """
+
+    def __init__(self, alphabet):
+        self.alphabet = alphabet
+
+    def __call__(self, raw_batch: Sequence[Tuple[str, str]]):
+        # RoBERTa uses an eos token, while ESM-1 does not.
+        batch_size = len(raw_batch)
+        batch_labels, seq_str_list, masked_seq_str_list, masked_indices_list = zip(*raw_batch)
+
+        masked_seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in masked_seq_str_list]  ###---
+        seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in seq_str_list]  ###---
+        # print('====', seq_str_list)
+        # print('----', masked_seq_str_list)
+        # print('++++', masked_seq_encoded_list)
+        # print('****', seq_encoded_list)
+
+        max_len = max(len(seq_encoded) for seq_encoded in masked_seq_encoded_list)
+        tokens = torch.empty(
+            (
+                batch_size,
+                max_len + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos),
+            ),
+            dtype=torch.int64,
+        )
+        tokens.fill_(self.alphabet.padding_idx)
+        masked_tokens = deepcopy(tokens)
+
+        labels = []
+        strs, masked_strs = [], []
+        masked_indices = []
+        # print('=================')
+        for i, (label, seq_str, masked_seq_str, seq_encoded, masked_seq_encoded, indices_mask) in enumerate(
+            zip(batch_labels, seq_str_list, masked_seq_str_list, seq_encoded_list, masked_seq_encoded_list, masked_indices_list)  ###---
+        ):
+            labels.append(label)
+            strs.append(seq_str)
+            masked_strs.append(masked_seq_str)
+            masked_indices.append(indices_mask)
+
+            if self.alphabet.prepend_bos:
+                tokens[i, 0] = self.alphabet.cls_idx
+                masked_tokens[i, 0] = self.alphabet.cls_idx
+
+            seq = torch.tensor(seq_encoded, dtype=torch.int64)
+            masked_seq = torch.tensor(masked_seq_encoded, dtype=torch.int64)
+            # print(tokens, masked_tokens)
+            tokens[
+                i,
+                int(self.alphabet.prepend_bos) : len(seq_encoded)
+                + int(self.alphabet.prepend_bos),
+            ] = seq
+
+            masked_tokens[
+                i,
+                int(self.alphabet.prepend_bos) : len(masked_seq_encoded)
+                + int(self.alphabet.prepend_bos),
+            ] = masked_seq
+            # print(tokens, masked_tokens)
+            if self.alphabet.append_eos:
+                tokens[i, len(seq_encoded) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx
+                masked_tokens[i, len(masked_seq_encoded) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx
+            # print(tokens, masked_tokens)
+        return labels, strs, masked_strs, tokens, masked_tokens, masked_indices
+
+
+class MSABatchConverter(BatchConverter):
+    def __call__(self, inputs: Union[Sequence[RawMSA], RawMSA]):
+        if isinstance(inputs[0][0], str):
+            # Input is a single MSA
+            raw_batch: Sequence[RawMSA] = [inputs]  # type: ignore
+        else:
+            raw_batch = inputs  # type: ignore
+
+        batch_size = len(raw_batch)
+        max_alignments = max(len(msa) for msa in raw_batch)
+        max_seqlen = max(len(msa[0][1]) for msa in raw_batch)
+
+        tokens = torch.empty(
+            (
+                batch_size,
+                max_alignments,
+                max_seqlen + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos),
+            ),
+            dtype=torch.int64,
+        )
+        tokens.fill_(self.alphabet.padding_idx)
+        labels = []
+        strs = []
+
+        for i, msa in enumerate(raw_batch):
+            msa_seqlens = set(len(seq) for _, seq in msa)
+            if not len(msa_seqlens) == 1:
+                raise RuntimeError(
+                    "Received unaligned sequences for input to MSA, all sequence "
+                    "lengths must be equal."
+                )
+            msa_labels, msa_strs, msa_tokens = super().__call__(msa)
+            labels.append(msa_labels)
+            strs.append(msa_strs)
+            tokens[i, : msa_tokens.size(0), : msa_tokens.size(1)] = msa_tokens
+
+        return labels, strs, tokens
+
+
+def read_fasta(
+    path,
+    keep_gaps=True,
+    keep_insertions=True,
+    to_upper=False,
+):
+    with open(path, "r") as f:
+        for result in read_alignment_lines(
+            f, keep_gaps=keep_gaps, keep_insertions=keep_insertions, to_upper=to_upper
+        ):
+            yield result
+
+
+def read_alignment_lines(
+    lines,
+    keep_gaps=True,
+    keep_insertions=True,
+    to_upper=False,
+):
+    seq = desc = None
+
+    def parse(s):
+        if not keep_gaps:
+            s = re.sub("-", "", s)
+        if not keep_insertions:
+            s = re.sub("[a-z]", "", s)
+        return s.upper() if to_upper else s
+
+    for line in lines:
+        # Line may be empty if seq % file_line_width == 0
+        if len(line) > 0 and line[0] == ">":
+            if seq is not None:
+                yield desc, parse(seq)
+            desc = line.strip()
+            seq = ""
+        else:
+            assert isinstance(seq, str)
+            seq += line.strip()
+    assert isinstance(seq, str) and isinstance(desc, str)
+    yield desc, parse(seq)
+
+
+class ESMStructuralSplitDataset(torch.utils.data.Dataset):
+    """
+    Structural Split Dataset as described in section A.10 of the supplement of our paper.
+    https://doi.org/10.1101/622803
+
+    We use the full version of SCOPe 2.07, clustered at 90% sequence identity,
+    generated on January 23, 2020.
+
+    For each SCOPe domain:
+        - We extract the sequence from the corresponding PDB file
+        - We extract the 3D coordinates of the Carbon beta atoms, aligning them
+          to the sequence. We put NaN where Cb atoms are missing.
+        - From the 3D coordinates, we calculate a pairwise distance map, based
+          on L2 distance
+        - We use DSSP to generate secondary structure labels for the corresponding
+          PDB file. This is also aligned to the sequence. We put - where SSP
+          labels are missing.
+
+    For each SCOPe classification level of family/superfamily/fold (in order of difficulty),
+    we have split the data into 5 partitions for cross validation. These are provided
+    in a downloaded splits folder, in the format:
+        splits/{split_level}/{cv_partition}/{train|valid}.txt
+    where train is the partition and valid is the concatentation of the remaining 4.
+
+    For each SCOPe domain, we provide a pkl dump that contains:
+        - seq    : The domain sequence, stored as an L-length string
+        - ssp    : The secondary structure labels, stored as an L-length string
+        - dist   : The distance map, stored as an LxL numpy array
+        - coords : The 3D coordinates, stored as an Lx3 numpy array
+
+    """
+
+    base_folder = "structural-data"
+    file_list = [
+        # url tar filename filename MD5 Hash
+        (
+            "https://dl.fbaipublicfiles.com/fair-esm/structural-data/splits.tar.gz",
+            "splits.tar.gz",
+            "splits",
+            "456fe1c7f22c9d3d8dfe9735da52411d",
+        ),
+        (
+            "https://dl.fbaipublicfiles.com/fair-esm/structural-data/pkl.tar.gz",
+            "pkl.tar.gz",
+            "pkl",
+            "644ea91e56066c750cd50101d390f5db",
+        ),
+    ]
+
+    def __init__(
+        self,
+        split_level,
+        cv_partition,
+        split,
+        root_path=os.path.expanduser("~/.cache/torch/data/esm"),
+        download=False,
+    ):
+        super().__init__()
+        assert split in [
+            "train",
+            "valid",
+        ], "train_valid must be 'train' or 'valid'"
+        self.root_path = root_path
+        self.base_path = os.path.join(self.root_path, self.base_folder)
+
+        # check if root path has what you need or else download it
+        if download:
+            self.download()
+
+        self.split_file = os.path.join(
+            self.base_path, "splits", split_level, cv_partition, f"{split}.txt"
+        )
+        self.pkl_dir = os.path.join(self.base_path, "pkl")
+        self.names = []
+        with open(self.split_file) as f:
+            self.names = f.read().splitlines()
+
+    def __len__(self):
+        return len(self.names)
+
+    def _check_exists(self) -> bool:
+        for (_, _, filename, _) in self.file_list:
+            fpath = os.path.join(self.base_path, filename)
+            if not os.path.exists(fpath) or not os.path.isdir(fpath):
+                return False
+        return True
+
+    def download(self):
+
+        if self._check_exists():
+            print("Files already downloaded and verified")
+            return
+
+        from torchvision.datasets.utils import download_url
+
+        for url, tar_filename, filename, md5_hash in self.file_list:
+            download_path = os.path.join(self.base_path, tar_filename)
+            download_url(url=url, root=self.base_path, filename=tar_filename, md5=md5_hash)
+            shutil.unpack_archive(download_path, self.base_path)
+
+    def __getitem__(self, idx):
+        """
+        Returns a dict with the following entires
+            - seq : Str (domain sequence)
+            - ssp : Str (SSP labels)
+            - dist : np.array (distance map)
+            - coords : np.array (3D coordinates)
+        """
+        name = self.names[idx]
+        pkl_fname = os.path.join(self.pkl_dir, name[1:3], f"{name}.pkl")
+        with open(pkl_fname, "rb") as f:
+            obj = pickle.load(f)
+        return obj
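A short sketch of the masking pipeline this modified data.py adds, i.e. a dataset item that carries both the clean and the masked view of a sequence, and a batch converter that returns six items; the sequence below is illustrative:

from esm.data import Alphabet, FastaBatchedDataset

alphabet = Alphabet(standard_toks='AGCT')
dataset = FastaBatchedDataset(['seq1'], ['ATGCATGCAT'], mask_prob=0.15)

# __getitem__ masks ceil(0.15 * len) positions and returns both views of the sequence.
label, seq, masked_seq, masked_idx = dataset[0]
print(masked_idx)   # e.g. [7, 2], randomly chosen positions
print(masked_seq)   # original string with '<mask>' spliced in at those positions

# The batch converter encodes both the clean and the masked sequence.
batch_converter = alphabet.get_batch_converter()
labels, strs, masked_strs, toks, masked_toks, masked_indices = batch_converter([dataset[0]])
print(toks.shape, masked_toks.shape)   # identical shapes; <cls>/<eos> are added around the sequence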
esm/data_supervised.py
ADDED
|
@@ -0,0 +1,524 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import itertools
|
| 7 |
+
import os
|
| 8 |
+
from typing import Sequence, Tuple, List, Union
|
| 9 |
+
import pickle
|
| 10 |
+
import re
|
| 11 |
+
import shutil
|
| 12 |
+
import torch
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from .constants import proteinseq_toks, rnaseq_toks
|
| 15 |
+
import math
|
| 16 |
+
import random
|
| 17 |
+
from copy import deepcopy
|
| 18 |
+
|
| 19 |
+
RawMSA = Sequence[Tuple[str, str]]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Alphabet(object):
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
standard_toks: Sequence[str],
|
| 26 |
+
prepend_toks: Sequence[str] = ("<pad>", "<eos>", "<unk>"), # "<null_0>",
|
| 27 |
+
append_toks: Sequence[str] = ("<cls>", "<mask>", "<sep>"), #
|
| 28 |
+
prepend_bos: bool = True,
|
| 29 |
+
append_eos: bool = True,
|
| 30 |
+
use_msa: bool = False,
|
| 31 |
+
mask_prob: float = 0.15, ###---
|
| 32 |
+
):
|
| 33 |
+
self.mask_prob = mask_prob ###---
|
| 34 |
+
self.standard_toks = list(standard_toks)
|
| 35 |
+
self.prepend_toks = list(prepend_toks)
|
| 36 |
+
self.append_toks = list(append_toks)
|
| 37 |
+
self.prepend_bos = prepend_bos
|
| 38 |
+
self.append_eos = append_eos
|
| 39 |
+
self.use_msa = use_msa
|
| 40 |
+
|
| 41 |
+
self.all_toks = list(self.prepend_toks)
|
| 42 |
+
self.all_toks.extend(self.standard_toks)
|
| 43 |
+
# for i in range((8 - (len(self.all_toks) % 8)) % 8):
|
| 44 |
+
# self.all_toks.append(f"<null_{i + 1}>")
|
| 45 |
+
self.all_toks.extend(self.append_toks)
|
| 46 |
+
|
| 47 |
+
self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}
|
| 48 |
+
# print(self.tok_to_idx)
|
| 49 |
+
self.unk_idx = self.tok_to_idx["<unk>"]
|
| 50 |
+
self.padding_idx = self.get_idx("<pad>")
|
| 51 |
+
self.cls_idx = self.get_idx("<cls>")
|
| 52 |
+
self.mask_idx = self.get_idx("<mask>")
|
| 53 |
+
self.eos_idx = self.get_idx("<eos>")
|
| 54 |
+
self.all_special_tokens = ['<eos>', '<pad>', '<mask>'] # , '<unk>', '<cls>'
|
| 55 |
+
self.unique_no_split_tokens = self.all_toks
|
| 56 |
+
|
| 57 |
+
def __len__(self):
|
| 58 |
+
return len(self.all_toks)
|
| 59 |
+
|
| 60 |
+
def get_idx(self, tok):
|
| 61 |
+
return self.tok_to_idx.get(tok, self.unk_idx)
|
| 62 |
+
|
| 63 |
+
def get_tok(self, ind):
|
| 64 |
+
return self.all_toks[ind]
|
| 65 |
+
|
| 66 |
+
def to_dict(self):
|
| 67 |
+
return self.tok_to_idx.copy()
|
| 68 |
+
|
| 69 |
+
def get_batch_converter(self):
|
| 70 |
+
if self.use_msa:
|
| 71 |
+
return MSABatchConverter(self)
|
| 72 |
+
else:
|
| 73 |
+
return BatchConverter(self)
|
| 74 |
+
|
| 75 |
+
@classmethod
|
| 76 |
+
def from_architecture(cls, name: str) -> "Alphabet":
|
| 77 |
+
if name in ("ESM-1", "protein_bert_base"):
|
| 78 |
+
standard_toks = proteinseq_toks["toks"]
|
| 79 |
+
prepend_toks: Tuple[str, ...] = ("<null_0>", "<pad>", "<eos>", "<unk>")
|
| 80 |
+
append_toks: Tuple[str, ...] = ("<cls>", "<mask>", "<sep>")
|
| 81 |
+
prepend_bos = True
|
| 82 |
+
append_eos = False
|
| 83 |
+
use_msa = False
|
| 84 |
+
elif name in ("ESM-1b", "roberta_large"):
|
| 85 |
+
standard_toks = proteinseq_toks["toks"] ###---rnaseq
|
| 86 |
+
prepend_toks = ("<cls>", "<pad>", "<eos>", "<unk>")
|
| 87 |
+
append_toks = ("<mask>",)
|
| 88 |
+
prepend_bos = True
|
| 89 |
+
append_eos = True
|
| 90 |
+
use_msa = False
|
| 91 |
+
elif name in ("MSA Transformer", "msa_transformer"):
|
| 92 |
+
standard_toks = proteinseq_toks["toks"]
|
| 93 |
+
prepend_toks = ("<cls>", "<pad>", "<eos>", "<unk>")
|
| 94 |
+
append_toks = ("<mask>",)
|
| 95 |
+
prepend_bos = True
|
| 96 |
+
append_eos = False
|
| 97 |
+
use_msa = True
|
| 98 |
+
else:
|
| 99 |
+
raise ValueError("Unknown architecture selected")
|
| 100 |
+
return cls(standard_toks, prepend_toks, append_toks, prepend_bos, append_eos, use_msa)
|
| 101 |
+
|
| 102 |
+
def _tokenize(self, text) -> str:
|
| 103 |
+
return text.split()
|
| 104 |
+
|
| 105 |
+
def tokenize(self, text, **kwargs) -> List[str]:
|
| 106 |
+
"""
|
| 107 |
+
Inspired by https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_utils.py
|
| 108 |
+
Converts a string in a sequence of tokens, using the tokenizer.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
text (:obj:`str`):
|
| 112 |
+
The sequence to be encoded.
|
| 113 |
+
|
| 114 |
+
Returns:
|
| 115 |
+
:obj:`List[str]`: The list of tokens.
|
| 116 |
+
"""
|
| 117 |
+
|
| 118 |
+
def split_on_token(tok, text):
|
| 119 |
+
result = []
|
| 120 |
+
split_text = text.split(tok)
|
| 121 |
+
for i, sub_text in enumerate(split_text):
|
| 122 |
+
# AddedToken can control whitespace stripping around them.
|
| 123 |
+
# We use them for GPT2 and Roberta to have different behavior depending on the special token
|
| 124 |
+
# Cf. https://github.com/huggingface/transformers/pull/2778
|
| 125 |
+
# and https://github.com/huggingface/transformers/issues/3788
|
| 126 |
+
# We strip left and right by default
|
| 127 |
+
if i < len(split_text) - 1:
|
| 128 |
+
sub_text = sub_text.rstrip()
|
| 129 |
+
if i > 0:
|
| 130 |
+
sub_text = sub_text.lstrip()
|
| 131 |
+
|
| 132 |
+
if i == 0 and not sub_text:
|
| 133 |
+
result.append(tok)
|
| 134 |
+
elif i == len(split_text) - 1:
|
| 135 |
+
if sub_text:
|
| 136 |
+
result.append(sub_text)
|
| 137 |
+
else:
|
| 138 |
+
pass
|
| 139 |
+
else:
|
| 140 |
+
if sub_text:
|
| 141 |
+
result.append(sub_text)
|
| 142 |
+
result.append(tok)
|
| 143 |
+
return result
|
| 144 |
+
|
| 145 |
+
def split_on_tokens(tok_list, text):
|
| 146 |
+
if not text.strip():
|
| 147 |
+
return []
|
| 148 |
+
|
| 149 |
+
tokenized_text = []
|
| 150 |
+
text_list = [text]
|
| 151 |
+
for tok in tok_list:
|
| 152 |
+
tokenized_text = []
|
| 153 |
+
for sub_text in text_list:
|
| 154 |
+
if sub_text not in self.unique_no_split_tokens:
|
| 155 |
+
tokenized_text.extend(split_on_token(tok, sub_text))
|
| 156 |
+
else:
|
| 157 |
+
tokenized_text.append(sub_text)
|
| 158 |
+
text_list = tokenized_text
|
| 159 |
+
|
| 160 |
+
return list(
|
| 161 |
+
itertools.chain.from_iterable(
|
| 162 |
+
(
|
| 163 |
+
self._tokenize(token)
|
| 164 |
+
if token not in self.unique_no_split_tokens
|
| 165 |
+
else [token]
|
| 166 |
+
for token in tokenized_text
|
| 167 |
+
)
|
| 168 |
+
)
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
no_split_token = self.unique_no_split_tokens
|
| 172 |
+
tokenized_text = split_on_tokens(no_split_token, text)
|
| 173 |
+
return tokenized_text
|
| 174 |
+
|
| 175 |
+
def encode(self, text):
|
| 176 |
+
return [self.tok_to_idx[tok] for tok in self.tokenize(text)]
|
| 177 |
+
|
| 178 |
+
class FastaBatchedDataset(object):
|
| 179 |
+
def __init__(self, sequence_labels, sequence_strs, mask_prob = 0.15):
|
| 180 |
+
self.sequence_labels = list(sequence_labels)
|
| 181 |
+
self.sequence_strs = list(sequence_strs)
|
| 182 |
+
self.mask_prob = mask_prob
|
| 183 |
+
|
| 184 |
+
@classmethod
|
| 185 |
+
def from_file(cls, fasta_file, mask_prob = 0.15):
|
| 186 |
+
sequence_labels, sequence_strs = [], []
|
| 187 |
+
cur_seq_label = None
|
| 188 |
+
buf = []
|
| 189 |
+
|
| 190 |
+
def _flush_current_seq():
|
| 191 |
+
nonlocal cur_seq_label, buf
|
| 192 |
+
if cur_seq_label is None:
|
| 193 |
+
return
|
| 194 |
+
sequence_labels.append(cur_seq_label)
|
| 195 |
+
sequence_strs.append("".join(buf))
|
| 196 |
+
cur_seq_label = None
|
| 197 |
+
buf = []
|
| 198 |
+
|
| 199 |
+
with open(fasta_file, "r") as infile:
|
| 200 |
+
for line_idx, line in enumerate(infile):
|
| 201 |
+
if line.startswith(">"): # label line
|
| 202 |
+
_flush_current_seq()
|
| 203 |
+
line = line[1:].strip()
|
| 204 |
+
if len(line) > 0:
|
| 205 |
+
cur_seq_label = line
|
| 206 |
+
else:
|
| 207 |
+
cur_seq_label = f"seqnum{line_idx:09d}"
|
| 208 |
+
else: # sequence line
|
| 209 |
+
buf.append(line.strip())
|
| 210 |
+
|
| 211 |
+
_flush_current_seq()
|
| 212 |
+
|
| 213 |
+
assert len(set(sequence_labels)) == len(
|
| 214 |
+
sequence_labels
|
| 215 |
+
), "Found duplicate sequence labels"
|
| 216 |
+
|
| 217 |
+
return cls(sequence_labels, sequence_strs, mask_prob)
|
| 218 |
+
|
| 219 |
+
def __len__(self):
|
| 220 |
+
return len(self.sequence_labels)
|
| 221 |
+
|
| 222 |
+
def mask_sequence(self, seq): ###---
|
| 223 |
+
length = len(seq)
|
| 224 |
+
# print(self.mask_prob)
|
| 225 |
+
max_length = math.ceil(length * self.mask_prob)
|
| 226 |
+
rand = random.sample(range(0, length), max_length)
|
| 227 |
+
res = ''.join(['<mask>' if idx in rand else ele for idx, ele in enumerate(seq)])
|
| 228 |
+
#print(seq, rand, res)
|
| 229 |
+
return rand, res
|
| 230 |
+
|
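A small sketch of what mask_sequence does, assuming the constructor shown above: it replaces ceil(mask_prob * len) randomly chosen positions with the literal string "<mask>" and returns both the chosen indices and the masked string. The sequence is illustrative.

from esm.data import FastaBatchedDataset

ds = FastaBatchedDataset(["seq0"], ["AUGGCCAUGG"], mask_prob=0.15)
masked_idx, masked_str = ds.mask_sequence("AUGGCCAUGG")
# ceil(10 * 0.15) == 2 positions are masked, e.g.
# masked_idx -> [3, 7]
# masked_str -> "AUG<mask>CCA<mask>GG"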
| 231 |
+
def __getitem__(self, idx):
|
| 232 |
+
sequence_str = self.sequence_strs[idx]
|
| 233 |
+
sequence_label = self.sequence_labels[idx]
|
| 234 |
+
masked_indices, masked_sequence_str = self.mask_sequence(sequence_str)
|
| 235 |
+
return sequence_label, sequence_str, masked_sequence_str, masked_indices
|
| 236 |
+
|
| 237 |
+
def get_batch_indices(self, toks_per_batch, extra_toks_per_seq=0):
|
| 238 |
+
sizes = [(len(s), i) for i, s in enumerate(self.sequence_strs)]
|
| 239 |
+
sizes.sort()
|
| 240 |
+
batches = []
|
| 241 |
+
buf = []
|
| 242 |
+
max_len = 0
|
| 243 |
+
|
| 244 |
+
def _flush_current_buf():
|
| 245 |
+
nonlocal max_len, buf
|
| 246 |
+
if len(buf) == 0:
|
| 247 |
+
return
|
| 248 |
+
batches.append(buf)
|
| 249 |
+
buf = []
|
| 250 |
+
max_len = 0
|
| 251 |
+
|
| 252 |
+
for sz, i in sizes:
|
| 253 |
+
sz += extra_toks_per_seq
|
| 254 |
+
if max(sz, max_len) * (len(buf) + 1) > toks_per_batch:
|
| 255 |
+
_flush_current_buf()
|
| 256 |
+
max_len = max(max_len, sz)
|
| 257 |
+
buf.append(i)
|
| 258 |
+
|
| 259 |
+
_flush_current_buf()
|
| 260 |
+
return batches
|
| 261 |
+
|
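A usage sketch tying the pieces together, assuming an ESM-1b alphabet: token-budgeted batches from get_batch_indices are fed to a DataLoader whose collate_fn is the BatchConverter defined just below. The FASTA path and token budget are illustrative.

import torch
from esm.data import Alphabet, BatchConverter, FastaBatchedDataset

alphabet = Alphabet.from_architecture("ESM-1b")
dataset = FastaBatchedDataset.from_file("sequences.fasta", mask_prob=0.15)  # illustrative path
batches = dataset.get_batch_indices(toks_per_batch=4096, extra_toks_per_seq=1)
loader = torch.utils.data.DataLoader(
    dataset, collate_fn=BatchConverter(alphabet), batch_sampler=batches
)
for labels, strs, masked_strs, tokens, masked_tokens, masked_indices in loader:
    pass  # tokens / masked_tokens are padded int64 tensors ready for the model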
| 262 |
+
class BatchConverter(object):
|
| 263 |
+
"""Callable to convert an unprocessed (labels + strings) batch to a
|
| 264 |
+
processed (labels + tensor) batch.
|
| 265 |
+
"""
|
| 266 |
+
|
| 267 |
+
def __init__(self, alphabet):
|
| 268 |
+
self.alphabet = alphabet
|
| 269 |
+
|
| 270 |
+
def __call__(self, raw_batch: Sequence[Tuple[str, str]]):
|
| 271 |
+
# RoBERTa uses an eos token, while ESM-1 does not.
|
| 272 |
+
batch_size = len(raw_batch)
|
| 273 |
+
batch_labels, seq_str_list, masked_seq_str_list, masked_indices_list = zip(*raw_batch)
|
| 274 |
+
|
| 275 |
+
masked_seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in masked_seq_str_list] ###---
|
| 276 |
+
seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in seq_str_list] ###---
|
| 277 |
+
# print('====', seq_str_list)
|
| 278 |
+
# print('----', masked_seq_str_list)
|
| 279 |
+
# print('++++', masked_seq_encoded_list)
|
| 280 |
+
# print('****', seq_encoded_list)
|
| 281 |
+
|
| 282 |
+
max_len = max(len(seq_encoded) for seq_encoded in masked_seq_encoded_list)
|
| 283 |
+
tokens = torch.empty(
|
| 284 |
+
(
|
| 285 |
+
batch_size,
|
| 286 |
+
max_len + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos),
|
| 287 |
+
),
|
| 288 |
+
dtype=torch.int64,
|
| 289 |
+
)
|
| 290 |
+
tokens.fill_(self.alphabet.padding_idx)
|
| 291 |
+
masked_tokens = deepcopy(tokens)
|
| 292 |
+
|
| 293 |
+
labels = []
|
| 294 |
+
strs, masked_strs = [], []
|
| 295 |
+
masked_indices = []
|
| 296 |
+
# print('=================')
|
| 297 |
+
for i, (label, seq_str, masked_seq_str, seq_encoded, masked_seq_encoded, indices_mask) in enumerate(
|
| 298 |
+
zip(batch_labels, seq_str_list, masked_seq_str_list, seq_encoded_list, masked_seq_encoded_list, masked_indices_list) ###---
|
| 299 |
+
):
|
| 300 |
+
labels.append(label)
|
| 301 |
+
strs.append(seq_str)
|
| 302 |
+
masked_strs.append(masked_seq_str)
|
| 303 |
+
masked_indices.append(indices_mask)
|
| 304 |
+
|
| 305 |
+
if self.alphabet.prepend_bos:
|
| 306 |
+
tokens[i, 0] = self.alphabet.cls_idx
|
| 307 |
+
masked_tokens[i, 0] = self.alphabet.cls_idx
|
| 308 |
+
|
| 309 |
+
seq = torch.tensor(seq_encoded, dtype=torch.int64)
|
| 310 |
+
masked_seq = torch.tensor(masked_seq_encoded, dtype=torch.int64)
|
| 311 |
+
# print(tokens, masked_tokens)
|
| 312 |
+
tokens[
|
| 313 |
+
i,
|
| 314 |
+
int(self.alphabet.prepend_bos) : len(seq_encoded)
|
| 315 |
+
+ int(self.alphabet.prepend_bos),
|
| 316 |
+
] = seq
|
| 317 |
+
|
| 318 |
+
masked_tokens[
|
| 319 |
+
i,
|
| 320 |
+
int(self.alphabet.prepend_bos) : len(masked_seq_encoded)
|
| 321 |
+
+ int(self.alphabet.prepend_bos),
|
| 322 |
+
] = masked_seq
|
| 323 |
+
# print(tokens, masked_tokens)
|
| 324 |
+
if self.alphabet.append_eos:
|
| 325 |
+
tokens[i, len(seq_encoded) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx
|
| 326 |
+
masked_tokens[i, len(masked_seq_encoded) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx
|
| 327 |
+
# print(tokens, masked_tokens)
|
| 328 |
+
return labels, strs, masked_strs, tokens, masked_tokens, masked_indices
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
class MSABatchConverter(BatchConverter):
|
| 332 |
+
def __call__(self, inputs: Union[Sequence[RawMSA], RawMSA]):
|
| 333 |
+
if isinstance(inputs[0][0], str):
|
| 334 |
+
# Input is a single MSA
|
| 335 |
+
raw_batch: Sequence[RawMSA] = [inputs] # type: ignore
|
| 336 |
+
else:
|
| 337 |
+
raw_batch = inputs # type: ignore
|
| 338 |
+
|
| 339 |
+
batch_size = len(raw_batch)
|
| 340 |
+
max_alignments = max(len(msa) for msa in raw_batch)
|
| 341 |
+
max_seqlen = max(len(msa[0][1]) for msa in raw_batch)
|
| 342 |
+
|
| 343 |
+
tokens = torch.empty(
|
| 344 |
+
(
|
| 345 |
+
batch_size,
|
| 346 |
+
max_alignments,
|
| 347 |
+
max_seqlen + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos),
|
| 348 |
+
),
|
| 349 |
+
dtype=torch.int64,
|
| 350 |
+
)
|
| 351 |
+
tokens.fill_(self.alphabet.padding_idx)
|
| 352 |
+
labels = []
|
| 353 |
+
strs = []
|
| 354 |
+
|
| 355 |
+
for i, msa in enumerate(raw_batch):
|
| 356 |
+
msa_seqlens = set(len(seq) for _, seq in msa)
|
| 357 |
+
if not len(msa_seqlens) == 1:
|
| 358 |
+
raise RuntimeError(
|
| 359 |
+
"Received unaligned sequences for input to MSA, all sequence "
|
| 360 |
+
"lengths must be equal."
|
| 361 |
+
)
|
| 362 |
+
msa_labels, msa_strs, msa_tokens = super().__call__(msa)
|
| 363 |
+
labels.append(msa_labels)
|
| 364 |
+
strs.append(msa_strs)
|
| 365 |
+
tokens[i, : msa_tokens.size(0), : msa_tokens.size(1)] = msa_tokens
|
| 366 |
+
|
| 367 |
+
return labels, strs, tokens
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
def read_fasta(
|
| 371 |
+
path,
|
| 372 |
+
keep_gaps=True,
|
| 373 |
+
keep_insertions=True,
|
| 374 |
+
to_upper=False,
|
| 375 |
+
):
|
| 376 |
+
with open(path, "r") as f:
|
| 377 |
+
for result in read_alignment_lines(
|
| 378 |
+
f, keep_gaps=keep_gaps, keep_insertions=keep_insertions, to_upper=to_upper
|
| 379 |
+
):
|
| 380 |
+
yield result
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def read_alignment_lines(
|
| 384 |
+
lines,
|
| 385 |
+
keep_gaps=True,
|
| 386 |
+
keep_insertions=True,
|
| 387 |
+
to_upper=False,
|
| 388 |
+
):
|
| 389 |
+
seq = desc = None
|
| 390 |
+
|
| 391 |
+
def parse(s):
|
| 392 |
+
if not keep_gaps:
|
| 393 |
+
s = re.sub("-", "", s)
|
| 394 |
+
if not keep_insertions:
|
| 395 |
+
s = re.sub("[a-z]", "", s)
|
| 396 |
+
return s.upper() if to_upper else s
|
| 397 |
+
|
| 398 |
+
for line in lines:
|
| 399 |
+
# Line may be empty if seq % file_line_width == 0
|
| 400 |
+
if len(line) > 0 and line[0] == ">":
|
| 401 |
+
if seq is not None:
|
| 402 |
+
yield desc, parse(seq)
|
| 403 |
+
desc = line.strip()
|
| 404 |
+
seq = ""
|
| 405 |
+
else:
|
| 406 |
+
assert isinstance(seq, str)
|
| 407 |
+
seq += line.strip()
|
| 408 |
+
assert isinstance(seq, str) and isinstance(desc, str)
|
| 409 |
+
yield desc, parse(seq)
|
| 410 |
+
|
| 411 |
+
|
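A short sketch of the FASTA helpers above; the file name is illustrative, and keep_insertions=False strips the lowercase insertion columns of an A3M-style alignment as implemented in parse().

from esm.data import read_fasta

for desc, seq in read_fasta("alignment.a3m", keep_insertions=False, to_upper=True):
    print(desc, len(seq))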
| 412 |
+
class ESMStructuralSplitDataset(torch.utils.data.Dataset):
|
| 413 |
+
"""
|
| 414 |
+
Structural Split Dataset as described in section A.10 of the supplement of our paper.
|
| 415 |
+
https://doi.org/10.1101/622803
|
| 416 |
+
|
| 417 |
+
We use the full version of SCOPe 2.07, clustered at 90% sequence identity,
|
| 418 |
+
generated on January 23, 2020.
|
| 419 |
+
|
| 420 |
+
For each SCOPe domain:
|
| 421 |
+
- We extract the sequence from the corresponding PDB file
|
| 422 |
+
- We extract the 3D coordinates of the Carbon beta atoms, aligning them
|
| 423 |
+
to the sequence. We put NaN where Cb atoms are missing.
|
| 424 |
+
- From the 3D coordinates, we calculate a pairwise distance map, based
|
| 425 |
+
on L2 distance
|
| 426 |
+
- We use DSSP to generate secondary structure labels for the corresponding
|
| 427 |
+
PDB file. This is also aligned to the sequence. We put - where SSP
|
| 428 |
+
labels are missing.
|
| 429 |
+
|
| 430 |
+
For each SCOPe classification level of family/superfamily/fold (in order of difficulty),
|
| 431 |
+
we have split the data into 5 partitions for cross validation. These are provided
|
| 432 |
+
in a downloaded splits folder, in the format:
|
| 433 |
+
splits/{split_level}/{cv_partition}/{train|valid}.txt
|
| 434 |
+
        where train is the partition and valid is the concatenation of the remaining 4.
|
| 435 |
+
|
| 436 |
+
For each SCOPe domain, we provide a pkl dump that contains:
|
| 437 |
+
- seq : The domain sequence, stored as an L-length string
|
| 438 |
+
- ssp : The secondary structure labels, stored as an L-length string
|
| 439 |
+
- dist : The distance map, stored as an LxL numpy array
|
| 440 |
+
- coords : The 3D coordinates, stored as an Lx3 numpy array
|
| 441 |
+
|
| 442 |
+
"""
|
| 443 |
+
|
| 444 |
+
base_folder = "structural-data"
|
| 445 |
+
file_list = [
|
| 446 |
+
# url tar filename filename MD5 Hash
|
| 447 |
+
(
|
| 448 |
+
"https://dl.fbaipublicfiles.com/fair-esm/structural-data/splits.tar.gz",
|
| 449 |
+
"splits.tar.gz",
|
| 450 |
+
"splits",
|
| 451 |
+
"456fe1c7f22c9d3d8dfe9735da52411d",
|
| 452 |
+
),
|
| 453 |
+
(
|
| 454 |
+
"https://dl.fbaipublicfiles.com/fair-esm/structural-data/pkl.tar.gz",
|
| 455 |
+
"pkl.tar.gz",
|
| 456 |
+
"pkl",
|
| 457 |
+
"644ea91e56066c750cd50101d390f5db",
|
| 458 |
+
),
|
| 459 |
+
]
|
| 460 |
+
|
| 461 |
+
def __init__(
|
| 462 |
+
self,
|
| 463 |
+
split_level,
|
| 464 |
+
cv_partition,
|
| 465 |
+
split,
|
| 466 |
+
root_path=os.path.expanduser("~/.cache/torch/data/esm"),
|
| 467 |
+
download=False,
|
| 468 |
+
):
|
| 469 |
+
super().__init__()
|
| 470 |
+
assert split in [
|
| 471 |
+
"train",
|
| 472 |
+
"valid",
|
| 473 |
+
], "train_valid must be 'train' or 'valid'"
|
| 474 |
+
self.root_path = root_path
|
| 475 |
+
self.base_path = os.path.join(self.root_path, self.base_folder)
|
| 476 |
+
|
| 477 |
+
# check if root path has what you need or else download it
|
| 478 |
+
if download:
|
| 479 |
+
self.download()
|
| 480 |
+
|
| 481 |
+
self.split_file = os.path.join(
|
| 482 |
+
self.base_path, "splits", split_level, cv_partition, f"{split}.txt"
|
| 483 |
+
)
|
| 484 |
+
self.pkl_dir = os.path.join(self.base_path, "pkl")
|
| 485 |
+
self.names = []
|
| 486 |
+
with open(self.split_file) as f:
|
| 487 |
+
self.names = f.read().splitlines()
|
| 488 |
+
|
| 489 |
+
def __len__(self):
|
| 490 |
+
return len(self.names)
|
| 491 |
+
|
| 492 |
+
def _check_exists(self) -> bool:
|
| 493 |
+
for (_, _, filename, _) in self.file_list:
|
| 494 |
+
fpath = os.path.join(self.base_path, filename)
|
| 495 |
+
if not os.path.exists(fpath) or not os.path.isdir(fpath):
|
| 496 |
+
return False
|
| 497 |
+
return True
|
| 498 |
+
|
| 499 |
+
def download(self):
|
| 500 |
+
|
| 501 |
+
if self._check_exists():
|
| 502 |
+
print("Files already downloaded and verified")
|
| 503 |
+
return
|
| 504 |
+
|
| 505 |
+
from torchvision.datasets.utils import download_url
|
| 506 |
+
|
| 507 |
+
for url, tar_filename, filename, md5_hash in self.file_list:
|
| 508 |
+
download_path = os.path.join(self.base_path, tar_filename)
|
| 509 |
+
download_url(url=url, root=self.base_path, filename=tar_filename, md5=md5_hash)
|
| 510 |
+
shutil.unpack_archive(download_path, self.base_path)
|
| 511 |
+
|
| 512 |
+
def __getitem__(self, idx):
|
| 513 |
+
"""
|
| 514 |
+
        Returns a dict with the following entries
|
| 515 |
+
- seq : Str (domain sequence)
|
| 516 |
+
- ssp : Str (SSP labels)
|
| 517 |
+
- dist : np.array (distance map)
|
| 518 |
+
- coords : np.array (3D coordinates)
|
| 519 |
+
"""
|
| 520 |
+
name = self.names[idx]
|
| 521 |
+
pkl_fname = os.path.join(self.pkl_dir, name[1:3], f"{name}.pkl")
|
| 522 |
+
with open(pkl_fname, "rb") as f:
|
| 523 |
+
obj = pickle.load(f)
|
| 524 |
+
return obj
|
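A usage sketch for the structural split dataset above; the split_level and cv_partition values are illustrative, and download=True fetches and unpacks the splits and pkl archives on first use.

from esm.data import ESMStructuralSplitDataset

dataset = ESMStructuralSplitDataset(
    split_level="superfamily",
    cv_partition="4",
    split="valid",
    download=True,
)
item = dataset[0]
print(item["seq"][:10], item["ssp"][:10], item["dist"].shape, item["coords"].shape)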
esm/model/._esm2_secondarystructure.py  ADDED  Binary file (4.1 kB)
esm/model/__pycache__/esm1.cpython-36.pyc  ADDED  Binary file (5.18 kB)
esm/model/__pycache__/esm1.cpython-39.pyc  ADDED  Binary file (5.17 kB)
esm/model/__pycache__/esm2.cpython-36.pyc  ADDED  Binary file (3.51 kB)
esm/model/__pycache__/esm2.cpython-39.pyc  ADDED  Binary file (4.18 kB)
esm/model/__pycache__/esm2_only_secondarystructure.cpython-39.pyc  ADDED  Binary file (4.54 kB)
esm/model/__pycache__/esm2_secondarystructure.cpython-39.pyc  ADDED  Binary file (4.79 kB)
esm/model/__pycache__/esm2_supervised.cpython-39.pyc  ADDED  Binary file (4.55 kB)
esm/model/__pycache__/msa_transformer.cpython-36.pyc  ADDED  Binary file (5.46 kB)
esm/model/__pycache__/msa_transformer.cpython-39.pyc  ADDED  Binary file (5.5 kB)
esm/model/esm1.py
ADDED
|
@@ -0,0 +1,203 @@
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
from ..modules import (
|
| 13 |
+
TransformerLayer,
|
| 14 |
+
LearnedPositionalEmbedding,
|
| 15 |
+
SinusoidalPositionalEmbedding,
|
| 16 |
+
RobertaLMHead,
|
| 17 |
+
ESM1bLayerNorm,
|
| 18 |
+
ContactPredictionHead,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ProteinBertModel(nn.Module):
|
| 23 |
+
@classmethod
|
| 24 |
+
def add_args(cls, parser):
|
| 25 |
+
parser.add_argument(
|
| 26 |
+
"--num_layers", default=36, type=int, metavar="N", help="number of layers"
|
| 27 |
+
)
|
| 28 |
+
parser.add_argument(
|
| 29 |
+
"--embed_dim", default=1280, type=int, metavar="N", help="embedding dimension"
|
| 30 |
+
)
|
| 31 |
+
parser.add_argument(
|
| 32 |
+
"--logit_bias", action="store_true", help="whether to apply bias to logits"
|
| 33 |
+
)
|
| 34 |
+
parser.add_argument(
|
| 35 |
+
"--ffn_embed_dim",
|
| 36 |
+
default=5120,
|
| 37 |
+
type=int,
|
| 38 |
+
metavar="N",
|
| 39 |
+
help="embedding dimension for FFN",
|
| 40 |
+
)
|
| 41 |
+
parser.add_argument(
|
| 42 |
+
"--attention_heads",
|
| 43 |
+
default=20,
|
| 44 |
+
type=int,
|
| 45 |
+
metavar="N",
|
| 46 |
+
help="number of attention heads",
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
def __init__(self, args, alphabet):
|
| 50 |
+
super().__init__()
|
| 51 |
+
self.args = args
|
| 52 |
+
self.alphabet_size = len(alphabet)
|
| 53 |
+
self.padding_idx = alphabet.padding_idx
|
| 54 |
+
self.mask_idx = alphabet.mask_idx
|
| 55 |
+
self.cls_idx = alphabet.cls_idx
|
| 56 |
+
self.eos_idx = alphabet.eos_idx
|
| 57 |
+
self.prepend_bos = alphabet.prepend_bos
|
| 58 |
+
self.append_eos = alphabet.append_eos
|
| 59 |
+
self.emb_layer_norm_before = getattr(self.args, "emb_layer_norm_before", False)
|
| 60 |
+
if self.args.arch == "roberta_large":
|
| 61 |
+
self.model_version = "ESM-1b"
|
| 62 |
+
self._init_submodules_esm1b()
|
| 63 |
+
else:
|
| 64 |
+
self.model_version = "ESM-1"
|
| 65 |
+
self._init_submodules_esm1()
|
| 66 |
+
|
| 67 |
+
def _init_submodules_common(self):
|
| 68 |
+
self.embed_tokens = nn.Embedding(
|
| 69 |
+
self.alphabet_size, self.args.embed_dim, padding_idx=self.padding_idx
|
| 70 |
+
)
|
| 71 |
+
self.layers = nn.ModuleList(
|
| 72 |
+
[
|
| 73 |
+
TransformerLayer(
|
| 74 |
+
self.args.embed_dim,
|
| 75 |
+
self.args.ffn_embed_dim,
|
| 76 |
+
self.args.attention_heads,
|
| 77 |
+
add_bias_kv=(self.model_version != "ESM-1b"),
|
| 78 |
+
use_esm1b_layer_norm=(self.model_version == "ESM-1b"),
|
| 79 |
+
)
|
| 80 |
+
for _ in range(self.args.layers)
|
| 81 |
+
]
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
self.contact_head = ContactPredictionHead(
|
| 85 |
+
self.args.layers * self.args.attention_heads,
|
| 86 |
+
self.prepend_bos,
|
| 87 |
+
self.append_eos,
|
| 88 |
+
eos_idx=self.eos_idx,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
def _init_submodules_esm1b(self):
|
| 92 |
+
self._init_submodules_common()
|
| 93 |
+
self.embed_scale = 1
|
| 94 |
+
self.embed_positions = LearnedPositionalEmbedding(
|
| 95 |
+
self.args.max_positions, self.args.embed_dim, self.padding_idx
|
| 96 |
+
)
|
| 97 |
+
self.emb_layer_norm_before = (
|
| 98 |
+
ESM1bLayerNorm(self.args.embed_dim) if self.emb_layer_norm_before else None
|
| 99 |
+
)
|
| 100 |
+
self.emb_layer_norm_after = ESM1bLayerNorm(self.args.embed_dim)
|
| 101 |
+
self.lm_head = RobertaLMHead(
|
| 102 |
+
embed_dim=self.args.embed_dim,
|
| 103 |
+
output_dim=self.alphabet_size,
|
| 104 |
+
weight=self.embed_tokens.weight,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
def _init_submodules_esm1(self):
|
| 108 |
+
self._init_submodules_common()
|
| 109 |
+
self.embed_scale = math.sqrt(self.args.embed_dim)
|
| 110 |
+
self.embed_positions = SinusoidalPositionalEmbedding(self.args.embed_dim, self.padding_idx)
|
| 111 |
+
self.embed_out = nn.Parameter(torch.zeros((self.alphabet_size, self.args.embed_dim)))
|
| 112 |
+
self.embed_out_bias = None
|
| 113 |
+
if self.args.final_bias:
|
| 114 |
+
self.embed_out_bias = nn.Parameter(torch.zeros(self.alphabet_size))
|
| 115 |
+
|
| 116 |
+
def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False, return_representation=False):
|
| 117 |
+
if return_contacts:
|
| 118 |
+
need_head_weights = True
|
| 119 |
+
|
| 120 |
+
assert tokens.ndim == 2
|
| 121 |
+
padding_mask = tokens.eq(self.padding_idx) # B, T
|
| 122 |
+
|
| 123 |
+
x = self.embed_scale * self.embed_tokens(tokens)
|
| 124 |
+
|
| 125 |
+
if getattr(self.args, "token_dropout", False):
|
| 126 |
+
x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
|
| 127 |
+
# x: B x T x C
|
| 128 |
+
mask_ratio_train = 0.15 * 0.8
|
| 129 |
+
src_lengths = (~padding_mask).sum(-1)
|
| 130 |
+
mask_ratio_observed = (tokens == self.mask_idx).sum(-1).float() / src_lengths
|
| 131 |
+
x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
|
| 132 |
+
|
| 133 |
+
x = x + self.embed_positions(tokens)
|
| 134 |
+
|
| 135 |
+
if self.model_version == "ESM-1b":
|
| 136 |
+
if self.emb_layer_norm_before:
|
| 137 |
+
x = self.emb_layer_norm_before(x)
|
| 138 |
+
if padding_mask is not None:
|
| 139 |
+
x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
|
| 140 |
+
|
| 141 |
+
repr_layers = set(repr_layers)
|
| 142 |
+
hidden_representations = {}
|
| 143 |
+
if 0 in repr_layers:
|
| 144 |
+
hidden_representations[0] = x
|
| 145 |
+
|
| 146 |
+
if need_head_weights:
|
| 147 |
+
attn_weights = []
|
| 148 |
+
|
| 149 |
+
# (B, T, E) => (T, B, E)
|
| 150 |
+
x = x.transpose(0, 1)
|
| 151 |
+
|
| 152 |
+
if not padding_mask.any():
|
| 153 |
+
padding_mask = None
|
| 154 |
+
|
| 155 |
+
for layer_idx, layer in enumerate(self.layers):
|
| 156 |
+
x, attn = layer(
|
| 157 |
+
x, self_attn_padding_mask=padding_mask, need_head_weights=need_head_weights
|
| 158 |
+
)
|
| 159 |
+
if (layer_idx + 1) in repr_layers:
|
| 160 |
+
hidden_representations[layer_idx + 1] = x.transpose(0, 1)
|
| 161 |
+
if need_head_weights:
|
| 162 |
+
# (H, B, T, T) => (B, H, T, T)
|
| 163 |
+
attn_weights.append(attn.transpose(1, 0))
|
| 164 |
+
|
| 165 |
+
if self.model_version == "ESM-1b":
|
| 166 |
+
x = self.emb_layer_norm_after(x)
|
| 167 |
+
x = x.transpose(0, 1) # (T, B, E) => (B, T, E)
|
| 168 |
+
|
| 169 |
+
# last hidden representation should have layer norm applied
|
| 170 |
+
if (layer_idx + 1) in repr_layers:
|
| 171 |
+
hidden_representations[layer_idx + 1] = x
|
| 172 |
+
x = self.lm_head(x)
|
| 173 |
+
else:
|
| 174 |
+
x = F.linear(x, self.embed_out, bias=self.embed_out_bias)
|
| 175 |
+
x = x.transpose(0, 1) # (T, B, E) => (B, T, E)
|
| 176 |
+
|
| 177 |
+
if return_representation:
|
| 178 |
+
result = {"logits": x, "representations": hidden_representations}
|
| 179 |
+
else:
|
| 180 |
+
result = {"logits": x}
|
| 181 |
+
if need_head_weights:
|
| 182 |
+
# attentions: B x L x H x T x T
|
| 183 |
+
attentions = torch.stack(attn_weights, 1)
|
| 184 |
+
if self.model_version == "ESM-1":
|
| 185 |
+
# ESM-1 models have an additional null-token for attention, which we remove
|
| 186 |
+
attentions = attentions[..., :-1]
|
| 187 |
+
if padding_mask is not None:
|
| 188 |
+
attention_mask = 1 - padding_mask.type_as(attentions)
|
| 189 |
+
attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
|
| 190 |
+
attentions = attentions * attention_mask[:, None, None, :, :]
|
| 191 |
+
result["attentions"] = attentions
|
| 192 |
+
if return_contacts:
|
| 193 |
+
contacts = self.contact_head(tokens, attentions)
|
| 194 |
+
result["contacts"] = contacts
|
| 195 |
+
|
| 196 |
+
return result
|
| 197 |
+
|
| 198 |
+
def predict_contacts(self, tokens):
|
| 199 |
+
return self(tokens, return_contacts=True)["contacts"]
|
| 200 |
+
|
| 201 |
+
@property
|
| 202 |
+
def num_layers(self):
|
| 203 |
+
return self.args.layers
|
esm/model/esm2.py
ADDED
|
@@ -0,0 +1,163 @@
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from typing import Union
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
import esm
|
| 11 |
+
from esm.modules import ContactPredictionHead, ESM1bLayerNorm, RobertaLMHead, TransformerLayer
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ESM2(nn.Module):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
num_layers: int = 33,
|
| 18 |
+
embed_dim: int = 1280,
|
| 19 |
+
attention_heads: int = 20,
|
| 20 |
+
alphabet: Union[esm.data.Alphabet, str] = "ESM-1b",
|
| 21 |
+
token_dropout: bool = True,
|
| 22 |
+
):
|
| 23 |
+
super().__init__()
|
| 24 |
+
self.num_layers = num_layers
|
| 25 |
+
self.embed_dim = embed_dim
|
| 26 |
+
self.attention_heads = attention_heads
|
| 27 |
+
if not isinstance(alphabet, esm.data.Alphabet):
|
| 28 |
+
alphabet = esm.data.Alphabet.from_architecture(alphabet)
|
| 29 |
+
self.alphabet = alphabet
|
| 30 |
+
self.alphabet_size = len(alphabet)
|
| 31 |
+
self.padding_idx = alphabet.padding_idx
|
| 32 |
+
self.mask_idx = alphabet.mask_idx
|
| 33 |
+
self.cls_idx = alphabet.cls_idx
|
| 34 |
+
self.eos_idx = alphabet.eos_idx
|
| 35 |
+
self.prepend_bos = alphabet.prepend_bos
|
| 36 |
+
self.append_eos = alphabet.append_eos
|
| 37 |
+
self.token_dropout = token_dropout
|
| 38 |
+
|
| 39 |
+
self._init_submodules()
|
| 40 |
+
|
| 41 |
+
def _init_submodules(self):
|
| 42 |
+
self.embed_scale = 1
|
| 43 |
+
self.embed_tokens = nn.Embedding(
|
| 44 |
+
self.alphabet_size,
|
| 45 |
+
self.embed_dim,
|
| 46 |
+
padding_idx=self.padding_idx,
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
self.layers = nn.ModuleList(
|
| 50 |
+
[
|
| 51 |
+
TransformerLayer(
|
| 52 |
+
self.embed_dim,
|
| 53 |
+
4 * self.embed_dim,
|
| 54 |
+
self.attention_heads,
|
| 55 |
+
add_bias_kv=False,
|
| 56 |
+
use_esm1b_layer_norm=True,
|
| 57 |
+
use_rotary_embeddings=True,
|
| 58 |
+
)
|
| 59 |
+
for _ in range(self.num_layers)
|
| 60 |
+
]
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
self.contact_head = ContactPredictionHead(
|
| 64 |
+
self.num_layers * self.attention_heads,
|
| 65 |
+
self.prepend_bos,
|
| 66 |
+
self.append_eos,
|
| 67 |
+
eos_idx=self.eos_idx,
|
| 68 |
+
)
|
| 69 |
+
self.emb_layer_norm_after = ESM1bLayerNorm(self.embed_dim)
|
| 70 |
+
|
| 71 |
+
self.lm_head = RobertaLMHead(
|
| 72 |
+
embed_dim=self.embed_dim,
|
| 73 |
+
output_dim=self.alphabet_size,
|
| 74 |
+
weight=self.embed_tokens.weight,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False, return_representation=False):
|
| 78 |
+
if return_contacts:
|
| 79 |
+
need_head_weights = True
|
| 80 |
+
|
| 81 |
+
assert tokens.ndim == 2
|
| 82 |
+
padding_mask = tokens.eq(self.padding_idx) # B, T
|
| 83 |
+
|
| 84 |
+
x = self.embed_scale * self.embed_tokens(tokens)
|
| 85 |
+
|
| 86 |
+
if self.token_dropout:
|
| 87 |
+
x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
|
| 88 |
+
# x: B x T x C
|
| 89 |
+
mask_ratio_train = 0.15 * 0.8
|
| 90 |
+
src_lengths = (~padding_mask).sum(-1)
|
| 91 |
+
mask_ratio_observed = (tokens == self.mask_idx).sum(-1).to(x.dtype) / src_lengths
|
| 92 |
+
x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
|
| 93 |
+
|
| 94 |
+
if padding_mask is not None:
|
| 95 |
+
x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
|
| 96 |
+
|
| 97 |
+
repr_layers = set(repr_layers)
|
| 98 |
+
hidden_representations = {}
|
| 99 |
+
if 0 in repr_layers:
|
| 100 |
+
hidden_representations[0] = x
|
| 101 |
+
|
| 102 |
+
if need_head_weights:
|
| 103 |
+
attn_weights = []
|
| 104 |
+
|
| 105 |
+
# (B, T, E) => (T, B, E)
|
| 106 |
+
x = x.transpose(0, 1)
|
| 107 |
+
|
| 108 |
+
if not padding_mask.any():
|
| 109 |
+
padding_mask = None
|
| 110 |
+
|
| 111 |
+
for layer_idx, layer in enumerate(self.layers):
|
| 112 |
+
x, attn = layer(
|
| 113 |
+
x,
|
| 114 |
+
self_attn_padding_mask=padding_mask,
|
| 115 |
+
need_head_weights=need_head_weights,
|
| 116 |
+
)
|
| 117 |
+
if (layer_idx + 1) in repr_layers:
|
| 118 |
+
hidden_representations[layer_idx + 1] = x.transpose(0, 1)
|
| 119 |
+
if need_head_weights:
|
| 120 |
+
# (H, B, T, T) => (B, H, T, T)
|
| 121 |
+
attn_weights.append(attn.transpose(1, 0))
|
| 122 |
+
# print(x.shape) # 73, 2, 1280
|
| 123 |
+
x = self.emb_layer_norm_after(x)
|
| 124 |
+
x = x.transpose(0, 1) # (T, B, E) => (B, T, E)
|
| 125 |
+
|
| 126 |
+
# last hidden representation should have layer norm applied
|
| 127 |
+
if (layer_idx + 1) in repr_layers:
|
| 128 |
+
hidden_representations[layer_idx + 1] = x
|
| 129 |
+
x = self.lm_head(x)
|
| 130 |
+
|
| 131 |
+
if return_representation:
|
| 132 |
+
result = {"logits": x, "representations": hidden_representations}
|
| 133 |
+
else:
|
| 134 |
+
result = {"logits": x}
|
| 135 |
+
if need_head_weights:
|
| 136 |
+
# attentions: B x L x H x T x T
|
| 137 |
+
attentions = torch.stack(attn_weights, 1)
|
| 138 |
+
if padding_mask is not None:
|
| 139 |
+
attention_mask = 1 - padding_mask.type_as(attentions)
|
| 140 |
+
attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
|
| 141 |
+
attentions = attentions * attention_mask[:, None, None, :, :]
|
| 142 |
+
result["attentions"] = attentions
|
| 143 |
+
if return_contacts:
|
| 144 |
+
attentions_symm, contacts = self.contact_head(tokens, attentions)
|
| 145 |
+
result["contacts"] = contacts
|
| 146 |
+
result["attentions_symm"] = attentions_symm
|
| 147 |
+
|
| 148 |
+
return result
|
| 149 |
+
|
| 150 |
+
def predict_contacts(self, tokens):
|
| 151 |
+
return self(tokens, return_contacts=True)["contacts"]
|
| 152 |
+
|
| 153 |
+
def predict_symmetric_attentions(self, tokens):
|
| 154 |
+
return self(tokens, return_contacts=True)["attentions_symm"]
|
| 155 |
+
|
| 156 |
+
def predict_attentions(self, tokens):
|
| 157 |
+
return self(tokens, need_head_weights=True)["attentions"]
|
| 158 |
+
|
| 159 |
+
def predict_representations(self, tokens):
|
| 160 |
+
return self(tokens, return_representation=True)['representations']
|
| 161 |
+
|
| 162 |
+
def predict_logits(self, tokens):
|
| 163 |
+
return self(tokens)['logits']
|
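A minimal forward-pass sketch for the ESM2 class above, using a small illustrative configuration (6 layers, 128-dim embeddings, 16 heads) and an ESM-1b alphabet; the input sequence and layer choice are illustrative.

import torch
from esm.data import Alphabet
from esm.model.esm2 import ESM2

alphabet = Alphabet.from_architecture("ESM-1b")
model = ESM2(num_layers=6, embed_dim=128, attention_heads=16, alphabet=alphabet).eval()

toks = [alphabet.cls_idx] + alphabet.encode("M K V L T A") + [alphabet.eos_idx]
tokens = torch.tensor([toks], dtype=torch.int64)

with torch.no_grad():
    out = model(tokens, repr_layers=[6], return_representation=True)
rep = out["representations"][6]   # (batch, seq_len, embed_dim)
cls_embedding = rep[:, 0]         # per-sequence embedding a downstream head can consume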
esm/model/esm2_only_secondarystructure.py
ADDED
|
@@ -0,0 +1,179 @@
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from typing import Union
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
import esm
|
| 11 |
+
from esm.modules import ContactPredictionHead, ESM1bLayerNorm, RobertaLMHead, TransformerLayer
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ESM2(nn.Module):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
num_layers: int = 33,
|
| 18 |
+
embed_dim: int = 1280,
|
| 19 |
+
attention_heads: int = 20,
|
| 20 |
+
alphabet: Union[esm.data.Alphabet, str] = "ESM-1b",
|
| 21 |
+
token_dropout: bool = True,
|
| 22 |
+
):
|
| 23 |
+
super().__init__()
|
| 24 |
+
self.num_layers = num_layers
|
| 25 |
+
self.embed_dim = embed_dim
|
| 26 |
+
self.attention_heads = attention_heads
|
| 27 |
+
if not isinstance(alphabet, esm.data.Alphabet):
|
| 28 |
+
alphabet = esm.data.Alphabet.from_architecture(alphabet)
|
| 29 |
+
self.alphabet = alphabet
|
| 30 |
+
self.alphabet_size = len(alphabet)
|
| 31 |
+
self.padding_idx = alphabet.padding_idx
|
| 32 |
+
self.mask_idx = alphabet.mask_idx
|
| 33 |
+
self.cls_idx = alphabet.cls_idx
|
| 34 |
+
self.eos_idx = alphabet.eos_idx
|
| 35 |
+
self.prepend_bos = alphabet.prepend_bos
|
| 36 |
+
self.append_eos = alphabet.append_eos
|
| 37 |
+
self.token_dropout = token_dropout
|
| 38 |
+
|
| 39 |
+
self._init_submodules()
|
| 40 |
+
|
| 41 |
+
def _init_submodules(self):
|
| 42 |
+
self.embed_scale = 1
|
| 43 |
+
self.embed_tokens = nn.Embedding(
|
| 44 |
+
self.alphabet_size,
|
| 45 |
+
self.embed_dim,
|
| 46 |
+
padding_idx=self.padding_idx,
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
self.layers = nn.ModuleList(
|
| 50 |
+
[
|
| 51 |
+
TransformerLayer(
|
| 52 |
+
self.embed_dim,
|
| 53 |
+
4 * self.embed_dim,
|
| 54 |
+
self.attention_heads,
|
| 55 |
+
add_bias_kv=False,
|
| 56 |
+
use_esm1b_layer_norm=True,
|
| 57 |
+
use_rotary_embeddings=True,
|
| 58 |
+
)
|
| 59 |
+
for _ in range(self.num_layers)
|
| 60 |
+
]
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
self.contact_head = ContactPredictionHead(
|
| 64 |
+
self.num_layers * self.attention_heads,
|
| 65 |
+
self.prepend_bos,
|
| 66 |
+
self.append_eos,
|
| 67 |
+
eos_idx=self.eos_idx,
|
| 68 |
+
)
|
| 69 |
+
self.emb_layer_norm_after = ESM1bLayerNorm(self.embed_dim)
|
| 70 |
+
|
| 71 |
+
self.lm_head = RobertaLMHead(
|
| 72 |
+
embed_dim=self.embed_dim,
|
| 73 |
+
output_dim=self.alphabet_size,
|
| 74 |
+
weight=self.embed_tokens.weight,
|
| 75 |
+
)
|
| 76 |
+
# self.supervised_linear = nn.Linear(self.embed_dim, 1)
|
| 77 |
+
self.structure_linear = nn.Linear(self.embed_dim, 3)
|
| 78 |
+
def forward(self, tokens, repr_layers=[], need_head_weights=True, return_contacts=True, return_representation=True, return_attentions_symm = False, return_attentions = False):
|
| 79 |
+
if return_contacts:
|
| 80 |
+
need_head_weights = True
|
| 81 |
+
|
| 82 |
+
assert tokens.ndim == 2
|
| 83 |
+
padding_mask = tokens.eq(self.padding_idx) # B, T
|
| 84 |
+
|
| 85 |
+
x = self.embed_scale * self.embed_tokens(tokens)
|
| 86 |
+
|
| 87 |
+
if self.token_dropout:
|
| 88 |
+
x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
|
| 89 |
+
#print(f'tokens = {tokens}')
|
| 90 |
+
#print(f'self.mask_idx = {self.mask_idx}')
|
| 91 |
+
#print('x.shape = ', x.shape)
|
| 92 |
+
# x: B x T x C
|
| 93 |
+
mask_ratio_train = 0.15 * 0.8
|
| 94 |
+
src_lengths = (~padding_mask).sum(-1)
|
| 95 |
+
#print(f'mask_ratio_train = {mask_ratio_train}')
|
| 96 |
+
#print(f'padding_mask = {padding_mask}')
|
| 97 |
+
#print(f'src_lengths = {src_lengths}')
|
| 98 |
+
mask_ratio_observed = (tokens == self.mask_idx).sum(-1).to(x.dtype) / src_lengths
|
| 99 |
+
#print('mask_ratio_observed = ',mask_ratio_observed)
|
| 100 |
+
x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
|
| 101 |
+
#print(f'x.shape = {x.shape}:\n', x)
|
| 102 |
+
if padding_mask is not None:
|
| 103 |
+
x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
|
| 104 |
+
#print(f'x.shape = {x.shape}:\n', x)
|
| 105 |
+
repr_layers = set(repr_layers)
|
| 106 |
+
hidden_representations = {}
|
| 107 |
+
if 0 in repr_layers:
|
| 108 |
+
hidden_representations[0] = x
|
| 109 |
+
|
| 110 |
+
if need_head_weights:
|
| 111 |
+
attn_weights = []
|
| 112 |
+
|
| 113 |
+
# (B, T, E) => (T, B, E)
|
| 114 |
+
x = x.transpose(0, 1)
|
| 115 |
+
|
| 116 |
+
if not padding_mask.any():
|
| 117 |
+
padding_mask = None
|
| 118 |
+
|
| 119 |
+
for layer_idx, layer in enumerate(self.layers):
|
| 120 |
+
x, attn = layer(
|
| 121 |
+
x,
|
| 122 |
+
self_attn_padding_mask=padding_mask,
|
| 123 |
+
need_head_weights=need_head_weights,
|
| 124 |
+
)
|
| 125 |
+
if (layer_idx + 1) in repr_layers:
|
| 126 |
+
hidden_representations[layer_idx + 1] = x.transpose(0, 1)
|
| 127 |
+
if need_head_weights:
|
| 128 |
+
# (H, B, T, T) => (B, H, T, T)
|
| 129 |
+
attn_weights.append(attn.transpose(1, 0))
|
| 130 |
+
# print(x.shape) # 73, 2, 1280
|
| 131 |
+
x = self.emb_layer_norm_after(x)
|
| 132 |
+
x = x.transpose(0, 1) # (T, B, E) => (B, T, E)
|
| 133 |
+
|
| 134 |
+
# last hidden representation should have layer norm applied
|
| 135 |
+
if (layer_idx + 1) in repr_layers:
|
| 136 |
+
hidden_representations[layer_idx + 1] = x
|
| 137 |
+
# x_supervised = self.supervised_linear(x[:,0,:])
|
| 138 |
+
x_structure = self.structure_linear(x)
|
| 139 |
+
x = self.lm_head(x)
|
| 140 |
+
|
| 141 |
+
if return_representation:
|
| 142 |
+
result = {"logits": x, "logits_structure": x_structure, "representations": hidden_representations}
|
| 143 |
+
else:
|
| 144 |
+
result = {"logits": x, "logits_structure": x_structure}
|
| 145 |
+
if need_head_weights:
|
| 146 |
+
# attentions: B x L x H x T x T
|
| 147 |
+
attentions = torch.stack(attn_weights, 1)
|
| 148 |
+
if padding_mask is not None:
|
| 149 |
+
attention_mask = 1 - padding_mask.type_as(attentions)
|
| 150 |
+
attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
|
| 151 |
+
attentions = attentions * attention_mask[:, None, None, :, :]
|
| 152 |
+
if return_attentions: result["attentions"] = attentions
|
| 153 |
+
if return_contacts:
|
| 154 |
+
attentions_symm, contacts = self.contact_head(tokens, attentions)
|
| 155 |
+
result["contacts"] = contacts
|
| 156 |
+
if return_attentions_symm: result["attentions_symm"] = attentions_symm
|
| 157 |
+
|
| 158 |
+
return result
|
| 159 |
+
|
| 160 |
+
def predict_contacts(self, tokens):
|
| 161 |
+
return self(tokens, return_contacts=True)["contacts"]
|
| 162 |
+
|
| 163 |
+
def predict_symmetric_attentions(self, tokens):
|
| 164 |
+
return self(tokens, return_contacts=True)["attentions_symm"]
|
| 165 |
+
|
| 166 |
+
def predict_attentions(self, tokens):
|
| 167 |
+
return self(tokens, need_head_weights=True)["attentions"]
|
| 168 |
+
|
| 169 |
+
def predict_representations(self, tokens):
|
| 170 |
+
return self(tokens, return_representation=True)['representations']
|
| 171 |
+
|
| 172 |
+
def predict_logits(self, tokens):
|
| 173 |
+
return self(tokens)['logits']
|
| 174 |
+
|
| 175 |
+
# def predict_logits_supervised(self, tokens):
|
| 176 |
+
# return self(tokens)['logits_supervised']
|
| 177 |
+
|
| 178 |
+
def predict_logits_structure(self, tokens):
|
| 179 |
+
return self(tokens)['logits_structure']
|
esm/model/esm2_secondarystructure.py
ADDED
|
@@ -0,0 +1,179 @@
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from typing import Union
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
import esm
|
| 11 |
+
from esm.modules import ContactPredictionHead, ESM1bLayerNorm, RobertaLMHead, TransformerLayer
|
| 12 |
+
# This code defines a PyTorch model named ESM2 (a subclass of nn.Module). __init__ sets hyperparameters such as num_layers, embed_dim and attention_heads, and initializes the submodules: the embedding layer embed_tokens, a stack of TransformerLayer modules (layers), the ContactPredictionHead contact_head for contact prediction, and the linear heads lm_head, supervised_linear and structure_linear. The forward pass is defined in forward(), which takes a token sequence (tokens) and returns the predicted logits along with optional auxiliary outputs.
|
| 13 |
+
|
| 14 |
+
class ESM2(nn.Module):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
num_layers: int = 33,
|
| 18 |
+
embed_dim: int = 1280,
|
| 19 |
+
attention_heads: int = 20,
|
| 20 |
+
alphabet: Union[esm.data.Alphabet, str] = "ESM-1b",
|
| 21 |
+
token_dropout: bool = True,
|
| 22 |
+
):
|
| 23 |
+
super().__init__()
|
| 24 |
+
self.num_layers = num_layers
|
| 25 |
+
self.embed_dim = embed_dim
|
| 26 |
+
self.attention_heads = attention_heads
|
| 27 |
+
if not isinstance(alphabet, esm.data.Alphabet):
|
| 28 |
+
alphabet = esm.data.Alphabet.from_architecture(alphabet)
|
| 29 |
+
self.alphabet = alphabet
|
| 30 |
+
self.alphabet_size = len(alphabet)
|
| 31 |
+
self.padding_idx = alphabet.padding_idx
|
| 32 |
+
self.mask_idx = alphabet.mask_idx
|
| 33 |
+
self.cls_idx = alphabet.cls_idx
|
| 34 |
+
self.eos_idx = alphabet.eos_idx
|
| 35 |
+
self.prepend_bos = alphabet.prepend_bos
|
| 36 |
+
self.append_eos = alphabet.append_eos
|
| 37 |
+
self.token_dropout = token_dropout
|
| 38 |
+
|
| 39 |
+
self._init_submodules()
|
| 40 |
+
|
| 41 |
+
def _init_submodules(self):
|
| 42 |
+
self.embed_scale = 1
|
| 43 |
+
self.embed_tokens = nn.Embedding(
|
| 44 |
+
self.alphabet_size,
|
| 45 |
+
self.embed_dim,
|
| 46 |
+
padding_idx=self.padding_idx,
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
self.layers = nn.ModuleList(
|
| 50 |
+
[
|
| 51 |
+
TransformerLayer(
|
| 52 |
+
self.embed_dim,
|
| 53 |
+
4 * self.embed_dim,
|
| 54 |
+
self.attention_heads,
|
| 55 |
+
add_bias_kv=False,
|
| 56 |
+
use_esm1b_layer_norm=True,
|
| 57 |
+
use_rotary_embeddings=True,
|
| 58 |
+
)
|
| 59 |
+
for _ in range(self.num_layers)
|
| 60 |
+
]
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
self.contact_head = ContactPredictionHead(
|
| 64 |
+
self.num_layers * self.attention_heads,
|
| 65 |
+
self.prepend_bos,
|
| 66 |
+
self.append_eos,
|
| 67 |
+
eos_idx=self.eos_idx,
|
| 68 |
+
)
|
| 69 |
+
self.emb_layer_norm_after = ESM1bLayerNorm(self.embed_dim)
|
| 70 |
+
|
| 71 |
+
self.lm_head = RobertaLMHead(
|
| 72 |
+
embed_dim=self.embed_dim,
|
| 73 |
+
output_dim=self.alphabet_size,
|
| 74 |
+
weight=self.embed_tokens.weight,
|
| 75 |
+
)
|
| 76 |
+
self.supervised_linear = nn.Linear(self.embed_dim, 1)
|
| 77 |
+
self.structure_linear = nn.Linear(self.embed_dim, 3)
|
| 78 |
+
def forward(self, tokens, repr_layers=[], need_head_weights=True, return_contacts=True, return_representation=True, return_attentions_symm = False, return_attentions = False):
|
| 79 |
+
if return_contacts:
|
| 80 |
+
need_head_weights = True
|
| 81 |
+
|
| 82 |
+
assert tokens.ndim == 2
|
| 83 |
+
padding_mask = tokens.eq(self.padding_idx) # B, T
|
| 84 |
+
|
| 85 |
+
x = self.embed_scale * self.embed_tokens(tokens)
|
| 86 |
+
|
| 87 |
+
if self.token_dropout:
|
| 88 |
+
x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
|
| 89 |
+
#print(f'tokens = {tokens}')
|
| 90 |
+
#print(f'self.mask_idx = {self.mask_idx}')
|
| 91 |
+
#print('x.shape = ', x.shape)
|
| 92 |
+
# x: B x T x C
|
| 93 |
+
mask_ratio_train = 0.15 * 0.8
|
| 94 |
+
src_lengths = (~padding_mask).sum(-1)
|
| 95 |
+
#print(f'mask_ratio_train = {mask_ratio_train}')
|
| 96 |
+
#print(f'padding_mask = {padding_mask}')
|
| 97 |
+
#print(f'src_lengths = {src_lengths}')
|
| 98 |
+
mask_ratio_observed = (tokens == self.mask_idx).sum(-1).to(x.dtype) / src_lengths
|
| 99 |
+
#print('mask_ratio_observed = ',mask_ratio_observed)
|
| 100 |
+
x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
|
| 101 |
+
#print(f'x.shape = {x.shape}:\n', x)
|
| 102 |
+
if padding_mask is not None:
|
| 103 |
+
x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
|
| 104 |
+
#print(f'x.shape = {x.shape}:\n', x)
|
| 105 |
+
repr_layers = set(repr_layers)
|
| 106 |
+
hidden_representations = {}
|
| 107 |
+
if 0 in repr_layers:
|
| 108 |
+
hidden_representations[0] = x
|
| 109 |
+
|
| 110 |
+
if need_head_weights:
|
| 111 |
+
attn_weights = []
|
| 112 |
+
|
| 113 |
+
# (B, T, E) => (T, B, E)
|
| 114 |
+
x = x.transpose(0, 1)
|
| 115 |
+
|
| 116 |
+
if not padding_mask.any():
|
| 117 |
+
padding_mask = None
|
| 118 |
+
|
| 119 |
+
for layer_idx, layer in enumerate(self.layers):
|
| 120 |
+
x, attn = layer(
|
| 121 |
+
x,
|
| 122 |
+
self_attn_padding_mask=padding_mask,
|
| 123 |
+
need_head_weights=need_head_weights,
|
| 124 |
+
)
|
| 125 |
+
if (layer_idx + 1) in repr_layers:
|
| 126 |
+
hidden_representations[layer_idx + 1] = x.transpose(0, 1)
|
| 127 |
+
if need_head_weights:
|
| 128 |
+
# (H, B, T, T) => (B, H, T, T)
|
| 129 |
+
attn_weights.append(attn.transpose(1, 0))
|
| 130 |
+
# print(x.shape) # 73, 2, 1280
|
| 131 |
+
x = self.emb_layer_norm_after(x)
|
| 132 |
+
x = x.transpose(0, 1) # (T, B, E) => (B, T, E)
|
| 133 |
+
|
| 134 |
+
# last hidden representation should have layer norm applied
|
| 135 |
+
if (layer_idx + 1) in repr_layers:
|
| 136 |
+
hidden_representations[layer_idx + 1] = x
|
| 137 |
+
x_supervised = self.supervised_linear(x[:,0,:])
|
| 138 |
+
x_structure = self.structure_linear(x)
|
| 139 |
+
x = self.lm_head(x)
|
| 140 |
+
|
| 141 |
+
if return_representation:
|
| 142 |
+
result = {"logits": x, "logits_supervised": x_supervised, "logits_structure": x_structure, "representations": hidden_representations}
|
| 143 |
+
else:
|
| 144 |
+
result = {"logits": x, "logits_supervised": x_supervised, "logits_structure": x_structure}
|
| 145 |
+
if need_head_weights:
|
| 146 |
+
# attentions: B x L x H x T x T
|
| 147 |
+
attentions = torch.stack(attn_weights, 1)
|
| 148 |
+
if padding_mask is not None:
|
| 149 |
+
attention_mask = 1 - padding_mask.type_as(attentions)
|
| 150 |
+
attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
|
| 151 |
+
attentions = attentions * attention_mask[:, None, None, :, :]
|
| 152 |
+
if return_attentions: result["attentions"] = attentions
|
| 153 |
+
if return_contacts:
|
| 154 |
+
attentions_symm, contacts = self.contact_head(tokens, attentions)
|
| 155 |
+
result["contacts"] = contacts
|
| 156 |
+
if return_attentions_symm: result["attentions_symm"] = attentions_symm
|
| 157 |
+
|
| 158 |
+
return result
|
| 159 |
+
|
| 160 |
+
def predict_contacts(self, tokens):
|
| 161 |
+
return self(tokens, return_contacts=True)["contacts"]
|
| 162 |
+
|
| 163 |
+
def predict_symmetric_attentions(self, tokens):
|
| 164 |
+
return self(tokens, return_contacts=True)["attentions_symm"]
|
| 165 |
+
|
| 166 |
+
def predict_attentions(self, tokens):
|
| 167 |
+
return self(tokens, need_head_weights=True)["attentions"]
|
| 168 |
+
|
| 169 |
+
def predict_representations(self, tokens):
|
| 170 |
+
return self(tokens, return_representation=True)['representations']
|
| 171 |
+
|
| 172 |
+
def predict_logits(self, tokens):
|
| 173 |
+
return self(tokens)['logits']
|
| 174 |
+
|
| 175 |
+
def predict_logits_supervised(self, tokens):
|
| 176 |
+
return self(tokens)['logits_supervised']
|
| 177 |
+
|
| 178 |
+
def predict_logits_structure(self, tokens):
|
| 179 |
+
return self(tokens)['logits_structure']
|
esm/model/esm2_supervised.py
ADDED
|
@@ -0,0 +1,174 @@
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from typing import Union
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
import esm
|
| 11 |
+
from esm.modules import ContactPredictionHead, ESM1bLayerNorm, RobertaLMHead, TransformerLayer
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ESM2(nn.Module):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
num_layers: int = 33,
|
| 18 |
+
embed_dim: int = 1280,
|
| 19 |
+
attention_heads: int = 20,
|
| 20 |
+
alphabet: Union[esm.data.Alphabet, str] = "ESM-1b",
|
| 21 |
+
token_dropout: bool = True,
|
| 22 |
+
):
|
| 23 |
+
super().__init__()
|
| 24 |
+
self.num_layers = num_layers
|
| 25 |
+
self.embed_dim = embed_dim
|
| 26 |
+
self.attention_heads = attention_heads
|
| 27 |
+
if not isinstance(alphabet, esm.data.Alphabet):
|
| 28 |
+
alphabet = esm.data.Alphabet.from_architecture(alphabet)
|
| 29 |
+
self.alphabet = alphabet
|
| 30 |
+
self.alphabet_size = len(alphabet)
|
| 31 |
+
self.padding_idx = alphabet.padding_idx
|
| 32 |
+
self.mask_idx = alphabet.mask_idx
|
| 33 |
+
self.cls_idx = alphabet.cls_idx
|
| 34 |
+
self.eos_idx = alphabet.eos_idx
|
| 35 |
+
self.prepend_bos = alphabet.prepend_bos
|
| 36 |
+
self.append_eos = alphabet.append_eos
|
| 37 |
+
self.token_dropout = token_dropout
|
| 38 |
+
|
| 39 |
+
self._init_submodules()
|
| 40 |
+
|
| 41 |
+
def _init_submodules(self):
|
| 42 |
+
self.embed_scale = 1
|
| 43 |
+
self.embed_tokens = nn.Embedding(
|
| 44 |
+
self.alphabet_size,
|
| 45 |
+
self.embed_dim,
|
| 46 |
+
padding_idx=self.padding_idx,
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
self.layers = nn.ModuleList(
|
| 50 |
+
[
|
| 51 |
+
TransformerLayer(
|
| 52 |
+
self.embed_dim,
|
| 53 |
+
4 * self.embed_dim,
|
| 54 |
+
self.attention_heads,
|
| 55 |
+
add_bias_kv=False,
|
| 56 |
+
use_esm1b_layer_norm=True,
|
| 57 |
+
use_rotary_embeddings=True,
|
| 58 |
+
)
|
| 59 |
+
for _ in range(self.num_layers)
|
| 60 |
+
]
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
self.contact_head = ContactPredictionHead(
|
| 64 |
+
self.num_layers * self.attention_heads,
|
| 65 |
+
self.prepend_bos,
|
| 66 |
+
self.append_eos,
|
| 67 |
+
eos_idx=self.eos_idx,
|
| 68 |
+
)
|
| 69 |
+
self.emb_layer_norm_after = ESM1bLayerNorm(self.embed_dim)
|
| 70 |
+
|
| 71 |
+
self.lm_head = RobertaLMHead(
|
| 72 |
+
embed_dim=self.embed_dim,
|
| 73 |
+
output_dim=self.alphabet_size,
|
| 74 |
+
weight=self.embed_tokens.weight,
|
| 75 |
+
)
|
| 76 |
+
self.supervised_linear = nn.Linear(self.embed_dim, 1)
|
| 77 |
+
def forward(self, tokens, repr_layers=[], need_head_weights=True, return_contacts=True, return_representation=True, return_attentions_symm = False, return_attentions = False):
|
| 78 |
+
if return_contacts:
|
| 79 |
+
need_head_weights = True
|
| 80 |
+
|
| 81 |
+
assert tokens.ndim == 2
|
| 82 |
+
padding_mask = tokens.eq(self.padding_idx) # B, T
|
| 83 |
+
|
| 84 |
+
x = self.embed_scale * self.embed_tokens(tokens)
|
| 85 |
+
|
| 86 |
+
if self.token_dropout:
|
| 87 |
+
x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
|
| 88 |
+
#print(f'tokens = {tokens}')
|
| 89 |
+
#print(f'self.mask_idx = {self.mask_idx}')
|
| 90 |
+
#print('x.shape = ', x.shape)
|
| 91 |
+
# x: B x T x C
|
| 92 |
+
mask_ratio_train = 0.15 * 0.8
|
| 93 |
+
src_lengths = (~padding_mask).sum(-1)
|
| 94 |
+
#print(f'mask_ratio_train = {mask_ratio_train}')
|
| 95 |
+
#print(f'padding_mask = {padding_mask}')
|
| 96 |
+
#print(f'src_lengths = {src_lengths}')
|
| 97 |
+
mask_ratio_observed = (tokens == self.mask_idx).sum(-1).to(x.dtype) / src_lengths
|
| 98 |
+
#print('mask_ratio_observed = ',mask_ratio_observed)
|
| 99 |
+
x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
|
| 100 |
+
#print(f'x.shape = {x.shape}:\n', x)
|
| 101 |
+
if padding_mask is not None:
|
| 102 |
+
x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
|
| 103 |
+
#print(f'x.shape = {x.shape}:\n', x)
|
| 104 |
+
repr_layers = set(repr_layers)
|
| 105 |
+
hidden_representations = {}
|
| 106 |
+
if 0 in repr_layers:
|
| 107 |
+
hidden_representations[0] = x
|
| 108 |
+
|
| 109 |
+
if need_head_weights:
|
| 110 |
+
attn_weights = []
|
| 111 |
+
|
| 112 |
+
# (B, T, E) => (T, B, E)
|
| 113 |
+
x = x.transpose(0, 1)
|
| 114 |
+
|
| 115 |
+
if not padding_mask.any():
|
| 116 |
+
padding_mask = None
|
| 117 |
+
|
| 118 |
+
for layer_idx, layer in enumerate(self.layers):
|
| 119 |
+
x, attn = layer(
|
| 120 |
+
x,
|
| 121 |
+
self_attn_padding_mask=padding_mask,
|
| 122 |
+
need_head_weights=need_head_weights,
|
| 123 |
+
)
|
| 124 |
+
if (layer_idx + 1) in repr_layers:
|
| 125 |
+
hidden_representations[layer_idx + 1] = x.transpose(0, 1)
|
| 126 |
+
if need_head_weights:
|
| 127 |
+
# (H, B, T, T) => (B, H, T, T)
|
| 128 |
+
attn_weights.append(attn.transpose(1, 0))
|
| 129 |
+
# print(x.shape) # 73, 2, 1280
|
| 130 |
+
x = self.emb_layer_norm_after(x)
|
| 131 |
+
x = x.transpose(0, 1) # (T, B, E) => (B, T, E)
|
| 132 |
+
|
| 133 |
+
# last hidden representation should have layer norm applied
|
| 134 |
+
if (layer_idx + 1) in repr_layers:
|
| 135 |
+
hidden_representations[layer_idx + 1] = x
|
| 136 |
+
x_supervised = self.supervised_linear(x[:,0,:])
|
| 137 |
+
x = self.lm_head(x)
|
| 138 |
+
|
| 139 |
+
if return_representation:
|
| 140 |
+
result = {"logits": x, "logits_supervised": x_supervised, "representations": hidden_representations}
|
| 141 |
+
else:
|
| 142 |
+
result = {"logits": x, "logits_supervised": x_supervised}
|
| 143 |
+
if need_head_weights:
|
| 144 |
+
# attentions: B x L x H x T x T
|
| 145 |
+
attentions = torch.stack(attn_weights, 1)
|
| 146 |
+
if padding_mask is not None:
|
| 147 |
+
attention_mask = 1 - padding_mask.type_as(attentions)
|
| 148 |
+
attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
|
| 149 |
+
attentions = attentions * attention_mask[:, None, None, :, :]
|
| 150 |
+
if return_attentions: result["attentions"] = attentions
|
| 151 |
+
if return_contacts:
|
| 152 |
+
attentions_symm, contacts = self.contact_head(tokens, attentions)
|
| 153 |
+
result["contacts"] = contacts
|
| 154 |
+
if return_attentions_symm: result["attentions_symm"] = attentions_symm
|
| 155 |
+
|
| 156 |
+
return result
|
| 157 |
+
|
| 158 |
+
def predict_contacts(self, tokens):
|
| 159 |
+
return self(tokens, return_contacts=True)["contacts"]
|
| 160 |
+
|
| 161 |
+
def predict_symmetric_attentions(self, tokens):
|
| 162 |
+
return self(tokens, return_contacts=True)["attentions_symm"]
|
| 163 |
+
|
| 164 |
+
def predict_attentions(self, tokens):
|
| 165 |
+
return self(tokens, need_head_weights=True)["attentions"]
|
| 166 |
+
|
| 167 |
+
def predict_representations(self, tokens):
|
| 168 |
+
return self(tokens, return_representation=True)['representations']
|
| 169 |
+
|
| 170 |
+
def predict_logits(self, tokens):
|
| 171 |
+
return self(tokens)['logits']
|
| 172 |
+
|
| 173 |
+
def predict_logits_supervised(self, tokens):
|
| 174 |
+
return self(tokens)['logits_supervised']
|
esm/model/msa_transformer.py
ADDED
@@ -0,0 +1,238 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from ..modules import (
    AxialTransformerLayer,
    LearnedPositionalEmbedding,
    RobertaLMHead,
    ESM1bLayerNorm,
    ContactPredictionHead,
)

from ..axial_attention import RowSelfAttention, ColumnSelfAttention



class MSATransformer(nn.Module):
    @classmethod
    def add_args(cls, parser):
        # fmt: off
        parser.add_argument(
            "--num_layers",
            default=12,
            type=int,
            metavar="N",
            help="number of layers"
        )
        parser.add_argument(
            "--embed_dim",
            default=768,
            type=int,
            metavar="N",
            help="embedding dimension"
        )
        parser.add_argument(
            "--logit_bias",
            action="store_true",
            help="whether to apply bias to logits"
        )
        parser.add_argument(
            "--ffn_embed_dim",
            default=3072,
            type=int,
            metavar="N",
            help="embedding dimension for FFN",
        )
        parser.add_argument(
            "--attention_heads",
            default=12,
            type=int,
            metavar="N",
            help="number of attention heads",
        )
        parser.add_argument(
            "--dropout",
            default=0.1,
            type=float,
            help="Dropout to apply."
        )
        parser.add_argument(
            "--attention_dropout",
            default=0.1,
            type=float,
            help="Dropout to apply."
        )
        parser.add_argument(
            "--activation_dropout",
            default=0.1,
            type=float,
            help="Dropout to apply."
        )
        parser.add_argument(
            "--max_tokens_per_msa",
            default=2 ** 14,
            type=int,
            help=(
                "Used during inference to batch attention computations in a single "
                "forward pass. This allows increased input sizes with less memory."
            ),
        )
        # fmt: on

    def __init__(self, args, alphabet):
        super().__init__()
        self.args = args
        self.alphabet_size = len(alphabet)
        self.padding_idx = alphabet.padding_idx
        self.mask_idx = alphabet.mask_idx
        self.cls_idx = alphabet.cls_idx
        self.eos_idx = alphabet.eos_idx
        self.prepend_bos = alphabet.prepend_bos
        self.append_eos = alphabet.append_eos

        self.embed_tokens = nn.Embedding(
            self.alphabet_size, self.args.embed_dim, padding_idx=self.padding_idx
        )

        if getattr(self.args, "embed_positions_msa", False):
            emb_dim = getattr(self.args, "embed_positions_msa_dim", self.args.embed_dim)
            self.msa_position_embedding = nn.Parameter(
                0.01 * torch.randn(1, 1024, 1, emb_dim),
                requires_grad=True,
            )
        else:
            self.register_parameter("msa_position_embedding", None)

        self.dropout_module = nn.Dropout(self.args.dropout)
        self.layers = nn.ModuleList(
            [
                AxialTransformerLayer(
                    self.args.embed_dim,
                    self.args.ffn_embed_dim,
                    self.args.attention_heads,
                    self.args.dropout,
                    self.args.attention_dropout,
                    self.args.activation_dropout,
                    getattr(self.args, "max_tokens_per_msa", self.args.max_tokens),
                )
                for _ in range(self.args.layers)
            ]
        )

        self.contact_head = ContactPredictionHead(
            self.args.layers * self.args.attention_heads,
            self.prepend_bos,
            self.append_eos,
            eos_idx=self.eos_idx,
        )
        self.embed_positions = LearnedPositionalEmbedding(
            self.args.max_positions,
            self.args.embed_dim,
            self.padding_idx,
        )
        self.emb_layer_norm_before = ESM1bLayerNorm(self.args.embed_dim)
        self.emb_layer_norm_after = ESM1bLayerNorm(self.args.embed_dim)
        self.lm_head = RobertaLMHead(
            embed_dim=self.args.embed_dim,
            output_dim=self.alphabet_size,
            weight=self.embed_tokens.weight,
        )

    def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False):
        if return_contacts:
            need_head_weights = True

        assert tokens.ndim == 3
        batch_size, num_alignments, seqlen = tokens.size()
        padding_mask = tokens.eq(self.padding_idx)  # B, R, C
        if not padding_mask.any():
            padding_mask = None

        x = self.embed_tokens(tokens)
        x += self.embed_positions(tokens.view(batch_size * num_alignments, seqlen)).view(x.size())
        if self.msa_position_embedding is not None:
            if x.size(1) > 1024:
                raise RuntimeError(
                    "Using model with MSA position embedding trained on maximum MSA "
                    f"depth of 1024, but received {x.size(1)} alignments."
                )
            x += self.msa_position_embedding[:, :num_alignments]

        x = self.emb_layer_norm_before(x)

        x = self.dropout_module(x)

        if padding_mask is not None:
            x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        repr_layers = set(repr_layers)
        hidden_representations = {}
        if 0 in repr_layers:
            hidden_representations[0] = x

        if need_head_weights:
            row_attn_weights = []
            col_attn_weights = []

        # B x R x C x D -> R x C x B x D
        x = x.permute(1, 2, 0, 3)

        for layer_idx, layer in enumerate(self.layers):
            x = layer(
                x,
                self_attn_padding_mask=padding_mask,
                need_head_weights=need_head_weights,
            )
            if need_head_weights:
                x, col_attn, row_attn = x
                # H x C x B x R x R -> B x H x C x R x R
                col_attn_weights.append(col_attn.permute(2, 0, 1, 3, 4))
                # H x B x C x C -> B x H x C x C
                row_attn_weights.append(row_attn.permute(1, 0, 2, 3))
            if (layer_idx + 1) in repr_layers:
                hidden_representations[layer_idx + 1] = x.permute(2, 0, 1, 3)

        x = self.emb_layer_norm_after(x)
        x = x.permute(2, 0, 1, 3)  # R x C x B x D -> B x R x C x D

        # last hidden representation should have layer norm applied
        if (layer_idx + 1) in repr_layers:
            hidden_representations[layer_idx + 1] = x
        x = self.lm_head(x)

        result = {"logits": x, "representations": hidden_representations}
        if need_head_weights:
            # col_attentions: B x L x H x C x R x R
            col_attentions = torch.stack(col_attn_weights, 1)
            # row_attentions: B x L x H x C x C
            row_attentions = torch.stack(row_attn_weights, 1)
            result["col_attentions"] = col_attentions
            result["row_attentions"] = row_attentions
            if return_contacts:
                contacts = self.contact_head(tokens, row_attentions)
                result["contacts"] = contacts

        return result

    def predict_contacts(self, tokens):
        return self(tokens, return_contacts=True)["contacts"]

    @property
    def num_layers(self):
        return self.args.layers

    def max_tokens_per_msa_(self, value: int) -> None:
        """The MSA Transformer automatically batches attention computations when
        gradients are disabled to allow you to pass in larger MSAs at test time than
        you can fit in GPU memory. By default this occurs when more than 2^14 tokens
        are passed in the input MSA. You can set this value to infinity to disable
        this behavior.
        """
        for module in self.modules():
            if isinstance(module, (RowSelfAttention, ColumnSelfAttention)):
                module.max_tokens_per_msa = value
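
The MSA Transformer above consumes a 3-D token batch (batch x alignments x columns). A short sketch of driving it at inference time (illustrative only; `msa_model` and `msa_tokens` are assumed to exist and are not part of this commit):

    # Hypothetical usage; msa_model is an MSATransformer, msa_tokens is B x R x C.
    import torch

    msa_model.eval()
    msa_model.max_tokens_per_msa_(2 ** 16)  # raise the batched-attention threshold for large MSAs
    with torch.no_grad():
        out = msa_model(msa_tokens, repr_layers=[msa_model.num_layers], need_head_weights=True)
    row_attentions = out["row_attentions"]   # B x L x H x C x C, as stacked in forward()
    contacts = msa_model.predict_contacts(msa_tokens)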
esm/modules.py
ADDED
@@ -0,0 +1,419 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .multihead_attention import MultiheadAttention  # noqa
from .axial_attention import ColumnSelfAttention, RowSelfAttention


def gelu(x):
    """Implementation of the gelu activation function.
    For information: OpenAI GPT's gelu is slightly different
    (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def symmetrize(x):
    "Make layer symmetric in final two dimensions, used for contact prediction."
    return x + x.transpose(-1, -2)


def apc(x):
    "Perform average product correct, used for contact prediction."
    a1 = x.sum(-1, keepdims=True)
    a2 = x.sum(-2, keepdims=True)
    a12 = x.sum((-1, -2), keepdims=True)

    avg = a1 * a2
    avg.div_(a12)  # in-place to reduce memory
    normalized = x - avg
    return normalized


class ESM1LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12, affine=True):
        """Construct a layernorm layer in the TF style (eps inside the sqrt)."""
        super().__init__()
        self.hidden_size = (hidden_size,) if isinstance(hidden_size, int) else tuple(hidden_size)
        self.eps = eps
        self.affine = bool(affine)
        if self.affine:
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))
        else:
            self.weight, self.bias = None, None

    def forward(self, x):
        dims = tuple(-(i + 1) for i in range(len(self.hidden_size)))
        means = x.mean(dims, keepdim=True)
        x_zeromean = x - means
        variances = x_zeromean.pow(2).mean(dims, keepdim=True)
        x = x_zeromean / torch.sqrt(variances + self.eps)
        if self.affine:
            x = (self.weight * x) + self.bias
        return x


try:
    from apex.normalization import FusedLayerNorm as _FusedLayerNorm

    class ESM1bLayerNorm(_FusedLayerNorm):
        @torch.jit.unused
        def forward(self, x):
            if not x.is_cuda:
                return super().forward(x)
            else:
                with torch.cuda.device(x.device):
                    return super().forward(x)

except ImportError:
    from torch.nn import LayerNorm as ESM1bLayerNorm


class TransformerLayer(nn.Module):
    """Transformer layer block."""

    def __init__(
        self,
        embed_dim,
        ffn_embed_dim,
        attention_heads,
        add_bias_kv=True,
        use_esm1b_layer_norm=False,
        use_rotary_embeddings: bool = False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.ffn_embed_dim = ffn_embed_dim
        self.attention_heads = attention_heads
        self.use_rotary_embeddings = use_rotary_embeddings
        self._init_submodules(add_bias_kv, use_esm1b_layer_norm)

    def _init_submodules(self, add_bias_kv, use_esm1b_layer_norm):
        BertLayerNorm = ESM1bLayerNorm if use_esm1b_layer_norm else ESM1LayerNorm

        self.self_attn = MultiheadAttention(
            self.embed_dim,
            self.attention_heads,
            add_bias_kv=add_bias_kv,
            add_zero_attn=False,
            use_rotary_embeddings=self.use_rotary_embeddings,
        )
        self.self_attn_layer_norm = BertLayerNorm(self.embed_dim)

        self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
        self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = BertLayerNorm(self.embed_dim)

    def forward(
        self, x, self_attn_mask=None, self_attn_padding_mask=None, need_head_weights=False
    ):
        residual = x
        x = self.self_attn_layer_norm(x)
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            need_weights=True,
            need_head_weights=need_head_weights,
            attn_mask=self_attn_mask,
        )
        x = residual + x

        residual = x
        x = self.final_layer_norm(x)
        x = gelu(self.fc1(x))
        x = self.fc2(x)
        x = residual + x
        #print(f'------{attn.half().dtype}-----')

        return x, attn#.half() ###


class AxialTransformerLayer(nn.Module):
    """Implements an Axial MSA Transformer block."""

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        max_tokens_per_msa: int = 2**14,
    ) -> None:
        super().__init__()

        # Initialize parameters
        self.embedding_dim = embedding_dim
        self.dropout_prob = dropout

        row_self_attention = RowSelfAttention(
            embedding_dim,
            num_attention_heads,
            dropout=dropout,
            max_tokens_per_msa=max_tokens_per_msa,
        )

        column_self_attention = ColumnSelfAttention(
            embedding_dim,
            num_attention_heads,
            dropout=dropout,
            max_tokens_per_msa=max_tokens_per_msa,
        )

        feed_forward_layer = FeedForwardNetwork(
            embedding_dim,
            ffn_embedding_dim,
            activation_dropout=activation_dropout,
            max_tokens_per_msa=max_tokens_per_msa,
        )

        self.row_self_attention = self.build_residual(row_self_attention)
        self.column_self_attention = self.build_residual(column_self_attention)
        self.feed_forward_layer = self.build_residual(feed_forward_layer)

    def build_residual(self, layer: nn.Module):
        return NormalizedResidualBlock(
            layer,
            self.embedding_dim,
            self.dropout_prob,
        )

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_head_weights: bool = False,
    ):
        """
        LayerNorm is applied either before or after the self-attention/ffn
        modules similar to the original Transformer implementation.
        """
        x, row_attn = self.row_self_attention(
            x,
            self_attn_mask=self_attn_mask,
            self_attn_padding_mask=self_attn_padding_mask,
        )
        x, column_attn = self.column_self_attention(
            x,
            self_attn_mask=self_attn_mask,
            self_attn_padding_mask=self_attn_padding_mask,
        )
        x = self.feed_forward_layer(x)
        if need_head_weights:
            return x, column_attn, row_attn
        else:
            return x


class LearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    Padding ids are ignored by either offsetting based on padding_idx
    or by setting padding_idx to None and ensuring that the appropriate
    position ids are passed to the forward function.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
        if padding_idx is not None:
            num_embeddings_ = num_embeddings + padding_idx + 1
        else:
            num_embeddings_ = num_embeddings
        super().__init__(num_embeddings_, embedding_dim, padding_idx)
        self.max_positions = num_embeddings

    def forward(self, input: torch.Tensor):
        """Input is expected to be of size [bsz x seqlen]."""
        if input.size(1) > self.max_positions:
            raise ValueError(
                f"Sequence length {input.size(1)} above maximum "
                f" sequence length of {self.max_positions}"
            )
        mask = input.ne(self.padding_idx).int()
        positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + self.padding_idx
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )


class SinusoidalPositionalEmbedding(nn.Module):
    def __init__(self, embed_dim, padding_idx, learned=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.padding_idx = padding_idx
        self.register_buffer("_float_tensor", torch.FloatTensor(1))
        self.weights = None

    def forward(self, x):
        bsz, seq_len = x.shape
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            self.weights = self.get_embedding(max_pos)
        self.weights = self.weights.type_as(self._float_tensor)

        positions = self.make_positions(x)
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

    def make_positions(self, x):
        mask = x.ne(self.padding_idx)
        range_buf = torch.arange(x.size(1), device=x.device).expand_as(x) + self.padding_idx + 1
        positions = range_buf.expand_as(x)
        return positions * mask.long() + self.padding_idx * (1 - mask.long())

    def get_embedding(self, num_embeddings):
        half_dim = self.embed_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if self.embed_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if self.padding_idx is not None:
            emb[self.padding_idx, :] = 0
        return emb


class RobertaLMHead(nn.Module):
    """Head for masked language modeling."""

    def __init__(self, embed_dim, output_dim, weight):
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.layer_norm = ESM1bLayerNorm(embed_dim)
        self.weight = weight
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features):
        x = self.dense(features)
        x = gelu(x)
        x = self.layer_norm(x)
        # project back to size of vocabulary with bias
        x = F.linear(x, self.weight) + self.bias
        return x


class ContactPredictionHead(nn.Module):
    """Performs symmetrization, apc, and computes a logistic regression on the output features"""

    def __init__(
        self,
        in_features: int,
        prepend_bos: bool,
        append_eos: bool,
        bias=True,
        eos_idx: Optional[int] = None,
    ):
        super().__init__()
        self.in_features = in_features
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos
        if append_eos and eos_idx is None:
            raise ValueError("Using an alphabet with eos token, but no eos token was passed in.")
        self.eos_idx = eos_idx
        self.regression = nn.Linear(in_features, 1, bias)
        self.activation = nn.Sigmoid()

    def forward(self, tokens, attentions):
        # remove eos token attentions
        if self.append_eos:
            eos_mask = tokens.ne(self.eos_idx).to(attentions)
            eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
            attentions = attentions * eos_mask[:, None, None, :, :]
            attentions = attentions[..., :-1, :-1]
        # remove cls token attentions
        if self.prepend_bos:
            attentions = attentions[..., 1:, 1:]
        batch_size, layers, heads, seqlen, _ = attentions.size()
        attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)

        # features: B x C x T x T
        attentions = attentions.to(
            self.regression.weight.device
        )  # attentions always float32, may need to convert to float16
        attentions = apc(symmetrize(attentions))
        attentions = attentions.permute(0, 2, 3, 1)
        #print(f'----------{attentions.dtype, attentions.float().dtype}----')
        return attentions.sum(dim=-1), self.activation(self.regression(attentions).squeeze(3))#float().to(self.regression.weight.device)).squeeze(3))


class NormalizedResidualBlock(nn.Module):
    def __init__(
        self,
        layer: nn.Module,
        embedding_dim: int,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim

        self.layer = layer
        self.dropout_module = nn.Dropout(
            dropout,
        )
        self.layer_norm = ESM1bLayerNorm(self.embedding_dim)

    def forward(self, x, *args, **kwargs):
        residual = x
        x = self.layer_norm(x)
        outputs = self.layer(x, *args, **kwargs)
        if isinstance(outputs, tuple):
            x, *out = outputs
        else:
            x = outputs
            out = None

        x = self.dropout_module(x)
        x = residual + x

        if out is not None:
            return (x,) + tuple(out)
        else:
            return x


class FeedForwardNetwork(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        ffn_embedding_dim: int,
        activation_dropout: float = 0.1,
        max_tokens_per_msa: int = 2**14,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.ffn_embedding_dim = ffn_embedding_dim
        self.max_tokens_per_msa = max_tokens_per_msa
        self.activation_fn = nn.GELU()
        self.activation_dropout_module = nn.Dropout(
            activation_dropout,
        )
        self.fc1 = nn.Linear(embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, embedding_dim)

    def forward(self, x):
        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        return x
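
symmetrize() and apc() above are the two normalisation steps ContactPredictionHead applies to stacked attention maps before the logistic regression. A self-contained check of their behaviour on random data (illustrative only; shapes are arbitrary and it assumes the local esm package is importable):

    # Illustrative only: exercise symmetrize/apc on a dummy attention stack.
    import torch
    from esm.modules import symmetrize, apc

    attn = torch.rand(2, 12, 16, 16)              # B x (layers * heads) x T x T
    sym = symmetrize(attn)                        # symmetric in the last two dimensions
    corrected = apc(sym)                          # average product correction per map
    print(torch.allclose(sym, sym.transpose(-1, -2)), corrected.shape)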
esm/multihead_attention.py
ADDED
@@ -0,0 +1,506 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import Parameter
from esm.rotary_embedding import RotaryEmbedding

import uuid


def utils_softmax(x, dim: int, onnx_trace: bool = False):
    if onnx_trace:
        return F.softmax(x.float(), dim=dim)
    else:
        return F.softmax(x, dim=dim, dtype=torch.float32)


class FairseqIncrementalState(object):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.init_incremental_state()

    def init_incremental_state(self):
        self._incremental_state_id = str(uuid.uuid4())

    def _get_full_incremental_state_key(self, key: str) -> str:
        return "{}.{}".format(self._incremental_state_id, key)

    def get_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
    ) -> Optional[Dict[str, Optional[Tensor]]]:
        """Helper for getting incremental state for an nn.Module."""
        full_key = self._get_full_incremental_state_key(key)
        if incremental_state is None or full_key not in incremental_state:
            return None
        return incremental_state[full_key]

    def set_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
        value: Dict[str, Optional[Tensor]],
    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
        """Helper for setting incremental state for an nn.Module."""
        if incremental_state is not None:
            full_key = self._get_full_incremental_state_key(key)
            incremental_state[full_key] = value
        return incremental_state


def with_incremental_state(cls):
    cls.__bases__ = (FairseqIncrementalState,) + tuple(
        b for b in cls.__bases__ if b != FairseqIncrementalState
    )
    return cls


@with_incremental_state
class MultiheadAttention(nn.Module):
    """Multi-headed attention.
    See "Attention Is All You Need" for more details.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        self_attention: bool = False,
        encoder_decoder_attention: bool = False,
        use_rotary_embeddings: bool = False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and " "value to be of the same size"
        )

        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        self.onnx_trace = False
        self.rot_emb = None
        if use_rotary_embeddings:
            self.rot_emb = RotaryEmbedding(dim=self.head_dim)

        self.enable_torch_version = False
        if hasattr(F, "multi_head_attention_forward"):
            self.enable_torch_version = True
        else:
            self.enable_torch_version = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def reset_parameters(self):
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel
        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: False).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if (
            not self.rot_emb
            and self.enable_torch_version
            and not self.onnx_trace
            and incremental_state is None
            and not static_kv
            # A workaround for quantization to work. Otherwise JIT compilation
            # treats bias in linear module as method.
            and not torch.jit.is_scripting()
            and not need_head_weights
        ):
            assert key is not None and value is not None
            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                torch.empty([0]),
                torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                self.training,
                key_padding_mask,
                need_weights,
                attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight,
            )
        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
                    ],
                    dim=1,
                )

        if self.rot_emb:
            q, k = self.rot_emb(q, k)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = MultiheadAttention.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = utils_softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(
            attn_weights_float.type_as(attn_weights),
            p=self.dropout,
            training=self.training,
        )
        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).type_as(attn).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights

    @staticmethod
    def _append_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
            )
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_key_padding_mask is not None:
            filler = torch.zeros(
                (batch_size, src_len - prev_key_padding_mask.size(1)),
                device=prev_key_padding_mask.device,
            )
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), filler.float()], dim=1
            )
        elif key_padding_mask is not None:
            filler = torch.zeros(
                (batch_size, src_len - key_padding_mask.size(1)),
                device=key_padding_mask.device,
            )
            new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    @torch.jit.export
    def reorder_incremental_state(
        self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor
    ):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer_k = input_buffer[k]
                if input_buffer_k is not None:
                    if self.encoder_decoder_attention and input_buffer_k.size(0) == new_order.size(
                        0
                    ):
                        break
                    input_buffer[k] = input_buffer_k.index_select(0, new_order)
            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
        return incremental_state

    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        return self.set_incremental_state(incremental_state, "attn_state", buffer)

    def apply_sparse_mask(attn_weights, tgt_len: int, src_len: int, bsz: int):
        return attn_weights

    def upgrade_state_dict_named(self, state_dict, name):
        prefix = name + "." if name != "" else ""
        items_to_add = {}
        keys_to_remove = []
        for k in state_dict.keys():
            if k.endswith(prefix + "in_proj_weight"):
                # in_proj_weight used to be q + k + v with same dimensions
                dim = int(state_dict[k].shape[0] / 3)
                items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
                items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
                items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]

                keys_to_remove.append(k)

                k_bias = prefix + "in_proj_bias"
                if k_bias in state_dict.keys():
                    dim = int(state_dict[k].shape[0] / 3)
                    items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
                    items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][dim : 2 * dim]
                    items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]

                    keys_to_remove.append(prefix + "in_proj_bias")

        for k in keys_to_remove:
            del state_dict[k]

        for key, value in items_to_add.items():
            state_dict[key] = value
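
This MultiheadAttention keeps the fairseq (time, batch, channel) layout and can optionally apply rotary position embeddings to q/k. A standalone sanity check (illustrative only; sizes are arbitrary and it assumes esm.rotary_embedding.RotaryEmbedding behaves as in upstream fair-esm):

    # Illustrative only: run the module on random data in (T, B, E) layout.
    import torch
    from esm.multihead_attention import MultiheadAttention

    mha = MultiheadAttention(embed_dim=128, num_heads=16,
                             self_attention=True, use_rotary_embeddings=True)
    x = torch.randn(50, 2, 128)                   # T x B x E
    out, attn = mha(query=x, key=x, value=x, need_head_weights=True)
    print(out.shape, attn.shape)                  # (50, 2, 128) and (H, B, T, T) = (16, 2, 50, 50)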
esm/pretrained.py
ADDED
@@ -0,0 +1,378 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import re
import urllib
import warnings
from argparse import Namespace
from pathlib import Path

import torch

import esm
from esm.model.esm2 import ESM2


def _has_regression_weights(model_name):
    """Return whether we expect / require regression weights;
    Right now that is all models except ESM-1v and ESM-IF"""
    return not ("esm1v" in model_name or "esm_if" in model_name)


def load_model_and_alphabet(model_name):
    if model_name.endswith(".pt"):  # treat as filepath
        return load_model_and_alphabet_local(model_name)
    else:
        return load_model_and_alphabet_hub(model_name)


def load_hub_workaround(url):
    try:
        data = torch.hub.load_state_dict_from_url(url, progress=False, map_location="cpu")
    except RuntimeError:
        # Pytorch version issue - see https://github.com/pytorch/pytorch/issues/43106
        fn = Path(url).name
        data = torch.load(
            f"{torch.hub.get_dir()}/checkpoints/{fn}",
            map_location="cpu",
        )
    except urllib.error.HTTPError as e:
        raise Exception(f"Could not load {url}, check if you specified a correct model name?")
    return data


def load_regression_hub(model_name):
    url = f"https://dl.fbaipublicfiles.com/fair-esm/regression/{model_name}-contact-regression.pt"
    regression_data = load_hub_workaround(url)
    return regression_data


def _download_model_and_regression_data(model_name):
    url = f"https://dl.fbaipublicfiles.com/fair-esm/models/{model_name}.pt"
    model_data = load_hub_workaround(url)
    if _has_regression_weights(model_name):
        regression_data = load_regression_hub(model_name)
    else:
        regression_data = None
    return model_data, regression_data


def load_model_and_alphabet_hub(model_name):
    model_data, regression_data = _download_model_and_regression_data(model_name)
    return load_model_and_alphabet_core(model_name, model_data, regression_data)


def load_model_and_alphabet_local(model_location):
    """Load from local path. The regression weights need to be co-located"""
    model_location = Path(model_location)
    model_data = torch.load(str(model_location), map_location="cpu")
    model_name = model_location.stem
    if _has_regression_weights(model_name):
        regression_location = str(model_location.with_suffix("")) + "-contact-regression.pt"
        regression_data = torch.load(regression_location, map_location="cpu")
    else:
        regression_data = None
    return load_model_and_alphabet_core(model_name, model_data, regression_data)


def has_emb_layer_norm_before(model_state):
    """Determine whether layer norm needs to be applied before the encoder"""
    return any(k.startswith("emb_layer_norm_before") for k, param in model_state.items())


def _load_model_and_alphabet_core_v1(model_data):
    import esm  # since esm.inverse_folding is imported below, you actually have to re-import esm here

    alphabet = esm.Alphabet.from_architecture(model_data["args"].arch)

    if model_data["args"].arch == "roberta_large":
        # upgrade state dict
        pra = lambda s: "".join(s.split("encoder_")[1:] if "encoder" in s else s)
        prs1 = lambda s: "".join(s.split("encoder.")[1:] if "encoder" in s else s)
        prs2 = lambda s: "".join(
            s.split("sentence_encoder.")[1:] if "sentence_encoder" in s else s
        )
        model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()}
        model_state = {prs1(prs2(arg[0])): arg[1] for arg in model_data["model"].items()}
        model_state["embed_tokens.weight"][alphabet.mask_idx].zero_()  # For token drop
        model_args["emb_layer_norm_before"] = has_emb_layer_norm_before(model_state)
|
| 101 |
+
model_type = esm.ProteinBertModel
|
| 102 |
+
|
| 103 |
+
elif model_data["args"].arch == "protein_bert_base":
|
| 104 |
+
|
| 105 |
+
# upgrade state dict
|
| 106 |
+
pra = lambda s: "".join(s.split("decoder_")[1:] if "decoder" in s else s)
|
| 107 |
+
prs = lambda s: "".join(s.split("decoder.")[1:] if "decoder" in s else s)
|
| 108 |
+
model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()}
|
| 109 |
+
model_state = {prs(arg[0]): arg[1] for arg in model_data["model"].items()}
|
| 110 |
+
model_type = esm.ProteinBertModel
|
| 111 |
+
elif model_data["args"].arch == "msa_transformer":
|
| 112 |
+
|
| 113 |
+
# upgrade state dict
|
| 114 |
+
pra = lambda s: "".join(s.split("encoder_")[1:] if "encoder" in s else s)
|
| 115 |
+
prs1 = lambda s: "".join(s.split("encoder.")[1:] if "encoder" in s else s)
|
| 116 |
+
prs2 = lambda s: "".join(
|
| 117 |
+
s.split("sentence_encoder.")[1:] if "sentence_encoder" in s else s
|
| 118 |
+
)
|
| 119 |
+
prs3 = lambda s: s.replace("row", "column") if "row" in s else s.replace("column", "row")
|
| 120 |
+
model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()}
|
| 121 |
+
model_state = {prs1(prs2(prs3(arg[0]))): arg[1] for arg in model_data["model"].items()}
|
| 122 |
+
if model_args.get("embed_positions_msa", False):
|
| 123 |
+
emb_dim = model_state["msa_position_embedding"].size(-1)
|
| 124 |
+
model_args["embed_positions_msa_dim"] = emb_dim # initial release, bug: emb_dim==1
|
| 125 |
+
|
| 126 |
+
model_type = esm.MSATransformer
|
| 127 |
+
|
| 128 |
+
elif "invariant_gvp" in model_data["args"].arch:
|
| 129 |
+
import esm.inverse_folding
|
| 130 |
+
|
| 131 |
+
model_type = esm.inverse_folding.gvp_transformer.GVPTransformerModel
|
| 132 |
+
model_args = vars(model_data["args"]) # convert Namespace -> dict
|
| 133 |
+
|
| 134 |
+
def update_name(s):
|
| 135 |
+
# Map the module names in checkpoints trained with internal code to
|
| 136 |
+
# the updated module names in open source code
|
| 137 |
+
s = s.replace("W_v", "embed_graph.embed_node")
|
| 138 |
+
s = s.replace("W_e", "embed_graph.embed_edge")
|
| 139 |
+
s = s.replace("embed_scores.0", "embed_confidence")
|
| 140 |
+
s = s.replace("embed_score.", "embed_graph.embed_confidence.")
|
| 141 |
+
s = s.replace("seq_logits_projection.", "")
|
| 142 |
+
s = s.replace("embed_ingraham_features", "embed_dihedrals")
|
| 143 |
+
s = s.replace("embed_gvp_in_local_frame.0", "embed_gvp_output")
|
| 144 |
+
s = s.replace("embed_features_in_local_frame.0", "embed_gvp_input_features")
|
| 145 |
+
return s
|
| 146 |
+
|
| 147 |
+
model_state = {
|
| 148 |
+
update_name(sname): svalue
|
| 149 |
+
for sname, svalue in model_data["model"].items()
|
| 150 |
+
if "version" not in sname
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
else:
|
| 154 |
+
raise ValueError("Unknown architecture selected")
|
| 155 |
+
|
| 156 |
+
model = model_type(
|
| 157 |
+
Namespace(**model_args),
|
| 158 |
+
alphabet,
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
return model, alphabet, model_state
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def _load_model_and_alphabet_core_v2(model_data):
|
| 165 |
+
def upgrade_state_dict(state_dict):
|
| 166 |
+
"""Removes prefixes 'model.encoder.sentence_encoder.' and 'model.encoder.'."""
|
| 167 |
+
prefixes = ["encoder.sentence_encoder.", "encoder."]
|
| 168 |
+
pattern = re.compile("^" + "|".join(prefixes))
|
| 169 |
+
state_dict = {pattern.sub("", name): param for name, param in state_dict.items()}
|
| 170 |
+
return state_dict
|
| 171 |
+
|
| 172 |
+
cfg = model_data["cfg"]["model"]
|
| 173 |
+
state_dict = model_data["model"]
|
| 174 |
+
state_dict = upgrade_state_dict(state_dict)
|
| 175 |
+
alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
|
| 176 |
+
model = ESM2(
|
| 177 |
+
num_layers=cfg.encoder_layers,
|
| 178 |
+
embed_dim=cfg.encoder_embed_dim,
|
| 179 |
+
attention_heads=cfg.encoder_attention_heads,
|
| 180 |
+
alphabet=alphabet,
|
| 181 |
+
token_dropout=cfg.token_dropout,
|
| 182 |
+
)
|
| 183 |
+
return model, alphabet, state_dict
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def load_model_and_alphabet_core(model_name, model_data, regression_data=None):
|
| 187 |
+
if regression_data is not None:
|
| 188 |
+
model_data["model"].update(regression_data["model"])
|
| 189 |
+
|
| 190 |
+
if model_name.startswith("esm2"):
|
| 191 |
+
model, alphabet, model_state = _load_model_and_alphabet_core_v2(model_data)
|
| 192 |
+
else:
|
| 193 |
+
model, alphabet, model_state = _load_model_and_alphabet_core_v1(model_data)
|
| 194 |
+
|
| 195 |
+
expected_keys = set(model.state_dict().keys())
|
| 196 |
+
found_keys = set(model_state.keys())
|
| 197 |
+
|
| 198 |
+
if regression_data is None:
|
| 199 |
+
expected_missing = {"contact_head.regression.weight", "contact_head.regression.bias"}
|
| 200 |
+
error_msgs = []
|
| 201 |
+
missing = (expected_keys - found_keys) - expected_missing
|
| 202 |
+
if missing:
|
| 203 |
+
error_msgs.append(f"Missing key(s) in state_dict: {missing}.")
|
| 204 |
+
unexpected = found_keys - expected_keys
|
| 205 |
+
if unexpected:
|
| 206 |
+
error_msgs.append(f"Unexpected key(s) in state_dict: {unexpected}.")
|
| 207 |
+
|
| 208 |
+
if error_msgs:
|
| 209 |
+
raise RuntimeError(
|
| 210 |
+
"Error(s) in loading state_dict for {}:\n\t{}".format(
|
| 211 |
+
model.__class__.__name__, "\n\t".join(error_msgs)
|
| 212 |
+
)
|
| 213 |
+
)
|
| 214 |
+
if expected_missing - found_keys:
|
| 215 |
+
warnings.warn(
|
| 216 |
+
"Regression weights not found, predicting contacts will not produce correct results."
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
model.load_state_dict(model_state, strict=regression_data is not None)
|
| 220 |
+
|
| 221 |
+
return model, alphabet
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def esm1_t34_670M_UR50S():
|
| 225 |
+
"""34 layer transformer model with 670M params, trained on Uniref50 Sparse.
|
| 226 |
+
Returns a tuple of (Model, Alphabet).
|
| 227 |
+
"""
|
| 228 |
+
return load_model_and_alphabet_hub("esm1_t34_670M_UR50S")
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def esm1_t34_670M_UR50D():
|
| 232 |
+
"""34 layer transformer model with 670M params, trained on Uniref50 Dense.
|
| 233 |
+
Returns a tuple of (Model, Alphabet).
|
| 234 |
+
"""
|
| 235 |
+
return load_model_and_alphabet_hub("esm1_t34_670M_UR50D")
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def esm1_t34_670M_UR100():
|
| 239 |
+
"""34 layer transformer model with 670M params, trained on Uniref100.
|
| 240 |
+
Returns a tuple of (Model, Alphabet).
|
| 241 |
+
"""
|
| 242 |
+
return load_model_and_alphabet_hub("esm1_t34_670M_UR100")
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def esm1_t12_85M_UR50S():
|
| 246 |
+
"""12 layer transformer model with 85M params, trained on Uniref50 Sparse.
|
| 247 |
+
Returns a tuple of (Model, Alphabet).
|
| 248 |
+
"""
|
| 249 |
+
return load_model_and_alphabet_hub("esm1_t12_85M_UR50S")
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def esm1_t6_43M_UR50S():
|
| 253 |
+
"""6 layer transformer model with 43M params, trained on Uniref50 Sparse.
|
| 254 |
+
Returns a tuple of (Model, Alphabet).
|
| 255 |
+
"""
|
| 256 |
+
return load_model_and_alphabet_hub("esm1_t6_43M_UR50S")
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def esm1b_t33_650M_UR50S():
|
| 260 |
+
"""33 layer transformer model with 650M params, trained on Uniref50 Sparse.
|
| 261 |
+
This is our best performing model, which will be described in a future publication.
|
| 262 |
+
Returns a tuple of (Model, Alphabet).
|
| 263 |
+
"""
|
| 264 |
+
return load_model_and_alphabet_hub("esm1b_t33_650M_UR50S")
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def esm_msa1_t12_100M_UR50S():
|
| 268 |
+
warnings.warn(
|
| 269 |
+
"This model had a minor bug in the positional embeddings, "
|
| 270 |
+
"please use ESM-MSA-1b: esm.pretrained.esm_msa1b_t12_100M_UR50S()",
|
| 271 |
+
)
|
| 272 |
+
return load_model_and_alphabet_hub("esm_msa1_t12_100M_UR50S")
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def esm_msa1b_t12_100M_UR50S():
|
| 276 |
+
return load_model_and_alphabet_hub("esm_msa1b_t12_100M_UR50S")
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def esm1v_t33_650M_UR90S():
|
| 280 |
+
"""33 layer transformer model with 650M params, trained on Uniref90.
|
| 281 |
+
This is model 1 of a 5 model ensemble.
|
| 282 |
+
Returns a tuple of (Model, Alphabet).
|
| 283 |
+
"""
|
| 284 |
+
return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_1")
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def esm1v_t33_650M_UR90S_1():
|
| 288 |
+
"""33 layer transformer model with 650M params, trained on Uniref90.
|
| 289 |
+
This is model 1 of a 5 model ensemble.
|
| 290 |
+
Returns a tuple of (Model, Alphabet).
|
| 291 |
+
"""
|
| 292 |
+
return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_1")
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def esm1v_t33_650M_UR90S_2():
|
| 296 |
+
"""33 layer transformer model with 650M params, trained on Uniref90.
|
| 297 |
+
This is model 2 of a 5 model ensemble.
|
| 298 |
+
Returns a tuple of (Model, Alphabet).
|
| 299 |
+
"""
|
| 300 |
+
return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_2")
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def esm1v_t33_650M_UR90S_3():
|
| 304 |
+
"""33 layer transformer model with 650M params, trained on Uniref90.
|
| 305 |
+
This is model 3 of a 5 model ensemble.
|
| 306 |
+
Returns a tuple of (Model, Alphabet).
|
| 307 |
+
"""
|
| 308 |
+
return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_3")
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def esm1v_t33_650M_UR90S_4():
|
| 312 |
+
"""33 layer transformer model with 650M params, trained on Uniref90.
|
| 313 |
+
This is model 4 of a 5 model ensemble.
|
| 314 |
+
Returns a tuple of (Model, Alphabet).
|
| 315 |
+
"""
|
| 316 |
+
return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_4")
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def esm1v_t33_650M_UR90S_5():
|
| 320 |
+
"""33 layer transformer model with 650M params, trained on Uniref90.
|
| 321 |
+
This is model 5 of a 5 model ensemble.
|
| 322 |
+
Returns a tuple of (Model, Alphabet).
|
| 323 |
+
"""
|
| 324 |
+
return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_5")
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def esm_if1_gvp4_t16_142M_UR50():
|
| 328 |
+
"""Inverse folding model with 142M params, with 4 GVP-GNN layers, 8
|
| 329 |
+
Transformer encoder layers, and 8 Transformer decoder layers, trained on
|
| 330 |
+
CATH structures and 12 million alphafold2 predicted structures from UniRef50
|
| 331 |
+
sequences.
|
| 332 |
+
Returns a tuple of (Model, Alphabet).
|
| 333 |
+
"""
|
| 334 |
+
return load_model_and_alphabet_hub("esm_if1_gvp4_t16_142M_UR50")
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def esm2_t6_8M_UR50D():
|
| 338 |
+
"""6 layer ESM-2 model with 8M params, trained on UniRef50.
|
| 339 |
+
Returns a tuple of (Model, Alphabet).
|
| 340 |
+
"""
|
| 341 |
+
return load_model_and_alphabet_hub("esm2_t6_8M_UR50D")
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def esm2_t12_35M_UR50D():
|
| 345 |
+
"""12 layer ESM-2 model with 35M params, trained on UniRef50.
|
| 346 |
+
Returns a tuple of (Model, Alphabet).
|
| 347 |
+
"""
|
| 348 |
+
return load_model_and_alphabet_hub("esm2_t12_35M_UR50D")
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def esm2_t30_150M_UR50D():
|
| 352 |
+
"""30 layer ESM-2 model with 150M params, trained on UniRef50.
|
| 353 |
+
Returns a tuple of (Model, Alphabet).
|
| 354 |
+
"""
|
| 355 |
+
return load_model_and_alphabet_hub("esm2_t30_150M_UR50D")
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def esm2_t33_650M_UR50D():
|
| 359 |
+
"""33 layer ESM-2 model with 650M params, trained on UniRef50.
|
| 360 |
+
Returns a tuple of (Model, Alphabet).
|
| 361 |
+
"""
|
| 362 |
+
return load_model_and_alphabet_hub("esm2_t33_650M_UR50D")
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def esm2_t36_3B_UR50D():
|
| 366 |
+
"""36 layer ESM-2 model with 3B params, trained on UniRef50.
|
| 367 |
+
Returns a tuple of (Model, Alphabet).
|
| 368 |
+
"""
|
| 369 |
+
return load_model_and_alphabet_hub("esm2_t36_3B_UR50D")
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def esm2_t48_15B_UR50D():
|
| 373 |
+
"""48 layer ESM-2 model with 15B params, trained on UniRef50.
|
| 374 |
+
If you have OOM while loading this model, please refer to README
|
| 375 |
+
on how to employ FSDP and ZeRO CPU offloading
|
| 376 |
+
Returns a tuple of (Model, Alphabet).
|
| 377 |
+
"""
|
| 378 |
+
return load_model_and_alphabet_hub("esm2_t48_15B_UR50D")
|
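
A brief usage sketch (not part of the commit), assuming the vendored package exposes this module as esm.pretrained the way the upstream fair-esm release does; the local checkpoint path below is illustrative only.

import esm

# Named constructors download weights from dl.fbaipublicfiles.com on first use.
model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()

# load_model_and_alphabet() dispatches on the ".pt" suffix, so a local checkpoint
# (with its co-located *-contact-regression.pt, when required) also works:
# model, alphabet = esm.pretrained.load_model_and_alphabet("checkpoints/esm2_t6_8M_UR50D.pt")

model.eval()
batch_converter = alphabet.get_batch_converter()
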