aksell committed
Commit ff8eaa8 · 1 Parent(s): cfba77f

Cache get_attention and get_structure

Files changed (1): hexviz/attention.py (+4 -3)
hexviz/attention.py CHANGED

@@ -12,9 +12,9 @@ from transformers import (AutoTokenizer, GPT2LMHeadModel, T5EncoderModel,
 
 
 class ModelType(str, Enum):
-    TAPE_BERT = "bert-base"
+    TAPE_BERT = "TAPE-BERT"
     PROT_T5 = "prot_t5_xl_half_uniref50-enc"
-    ZymCTRL = "zymctrl"
+    ZymCTRL = "ZymCTRL"
 
 
 class Model:
@@ -23,7 +23,7 @@ class Model:
         self.layers: int = layers
         self.heads: int = heads
 
-
+@st.cache
 def get_structure(pdb_code: str) -> Structure:
     """
     Get structure from PDB
@@ -77,6 +77,7 @@ def get_zymctrl() -> Tuple[AutoTokenizer, GPT2LMHeadModel]:
     model = GPT2LMHeadModel.from_pretrained('nferruz/ZymCTRL').to(device)
    return tokenizer, model
 
+@st.cache
 def get_attention(
    sequence: str, model_type: ModelType = ModelType.TAPE_BERT
 ):
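
Why this helps: a Streamlit app re-executes the whole script on every widget interaction, so without caching the app would re-fetch the PDB structure and recompute attention weights on each rerun. @st.cache memoizes a function on its argument values and returns the stored result when they repeat (in later Streamlit releases it is superseded by st.cache_data / st.cache_resource). A minimal sketch of the pattern, using a hypothetical slow_lookup function that is not part of hexviz:

import time

import streamlit as st


@st.cache  # memoize on the argument values; a rerun with the same key hits the cache
def slow_lookup(key: str) -> str:
    # Stand-in for an expensive call such as fetching a structure from the PDB
    # or running a transformer forward pass to collect attention weights.
    time.sleep(2)
    return key.upper()


# The first run with "1ubq" takes ~2 s; subsequent script reruns with the
# same input return instantly from the cache.
st.write(slow_lookup("1ubq"))

One caveat with this decorator is that arguments must be hashable (or handled via custom hash_funcs), which makes small-input wrappers like get_structure(pdb_code) and get_attention(sequence, model_type) a natural fit for caching.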