Spaces:
Starting
on
T4
Starting
on
T4
File size: 1,284 Bytes
f8402f9 87c0dbc f8402f9 87c0dbc f8402f9 87c0dbc f8402f9 87c0dbc f8402f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
from io import StringIO
from urllib import request
import torch
from Bio.PDB import PDBParser, Structure
from transformers import T5EncoderModel, T5Tokenizer
def get_structure(pdb_code: str) -> Structure:
    """
    Download a PDB entry from RCSB and parse it with Biopython.

    Args:
        pdb_code: PDB identifier, inserted verbatim into the RCSB download URL
            ``https://files.rcsb.org/download/{pdb_code}.pdb``.

    Returns:
        The object returned by ``PDBParser().get_structure``, whose id is
        ``pdb_code``.

    Raises:
        urllib.error.URLError / HTTPError: propagated from ``urlopen`` when the
        download fails (e.g. an unknown PDB code).
    """
    # NOTE(review): confirm that ``Structure`` imported from ``Bio.PDB`` is the
    # class (Bio.PDB.Structure.Structure) and not the submodule of that name.
    pdb_url = f"https://files.rcsb.org/download/{pdb_code}.pdb"
    # Use the response as a context manager so the HTTP connection is closed
    # even if reading or decoding raises.
    with request.urlopen(pdb_url) as response:
        pdb_data = response.read().decode("utf-8")
    # PDBParser accepts any file-like handle, so parse from memory rather than
    # writing a temporary file.
    parser = PDBParser()
    return parser.get_structure(pdb_code, StringIO(pdb_data))
def get_protT5(
    model_name: str = "Rostlab/prot_t5_xl_half_uniref50-enc",
) -> tuple[T5Tokenizer, T5EncoderModel]:
    """
    Load the ProtT5 tokenizer and encoder, on GPU when one is available.

    Args:
        model_name: Hugging Face Hub id (or local path) of the checkpoint,
            used for both the tokenizer and the encoder. Defaults to the
            half-precision ProtT5-XL UniRef50 encoder.

    Returns:
        ``(tokenizer, model)``. The model is on ``cuda:0`` in half precision
        when CUDA is available, otherwise on CPU in float32.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tokenizer = T5Tokenizer.from_pretrained(model_name, do_lower_case=False)
    model = T5EncoderModel.from_pretrained(model_name).to(device)
    # Compare ``device.type`` (a plain str), not the torch.device object.
    # On CPU, upcast to float32 because many CPU kernels lack fp16 support.
    # On GPU, keep fp16 to halve memory.
    model = model.float() if device.type == "cpu" else model.half()
    return tokenizer, model
def get_attention(
    pdb_code: str, chain_ids: list[str], layer: int, head: int, min_attn: float = 0.2
):
    """
    Get attention from ProtT5 for a PDB entry (work in progress).

    Currently this only downloads and parses the structure for ``pdb_code`` and
    loads the ProtT5 tokenizer and encoder. No attention is computed yet, so
    the function returns ``None``.

    Args:
        pdb_code: PDB identifier forwarded to ``get_structure``.
        chain_ids: Not used yet (presumably the chains whose sequences should
            be embedded; TODO confirm).
        layer: Not used yet (presumably the encoder layer whose attention is
            read; TODO confirm).
        head: Not used yet (presumably the attention head within ``layer``;
            TODO confirm).
        min_attn: Not used yet (presumably a cut-off below which attention
            weights are ignored; TODO confirm). Defaults to 0.2.
    """
    # Step 1: download and parse the structure.
    pdb_structure = get_structure(pdb_code)
    # Step 2: load the tokenizer/encoder pair.
    t5_tokenizer, t5_model = get_protT5()
    # TODO: run t5_model on the sequences of the requested chains in
    # pdb_structure (tokenized with t5_tokenizer), then select and threshold
    # the attention for the requested layer/head.
    return None
|