MoraxCheng committed
Commit a0e970b · 1 Parent(s): 86d5c5f

Implement model caching and preload functionality to optimize loading in Zero GPU spaces
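The change is driven by how Zero GPU works: a GPU is only attached while a function decorated with `@spaces.GPU` is running, so the model should be loaded (and cached) on CPU outside that function and only moved to CUDA inside it, then moved back afterwards so the cached copy survives for the next request. A minimal sketch of that flow, assuming the `spaces` package is available; the `score()` wrapper below is only illustrative (the real entry point in app.py is `score_and_create_matrix_all_singles_impl`), and `load_model_cached` is the helper added by this commit:

    import torch
    import spaces  # Hugging Face Zero GPU helper

    @spaces.GPU  # GPU is attached only while this function runs
    def score(sequence, model_type="Small"):
        model = load_model_cached(model_type)  # served from MODEL_CACHE after the first call
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model.to(device)  # move to GPU inside the decorated function
            # ... tokenize `sequence` and run inference here ...
        finally:
            model.cpu()  # keep the cached copy resident on CPU
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

Because the model object is kept in the module-level cache rather than deleted, later requests skip the download and checkpoint load entirely.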

Files changed (1)
  1. app.py +77 -30
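The bulk of the +77 lines is a module-level MODEL_CACHE dict plus a `load_model_cached()` helper. Illustratively (this check is not part of the commit), repeated calls for the same model size return the same cached object instead of re-downloading:

    m1 = load_model_cached("Small")  # first call: downloads the checkpoint and stores it in MODEL_CACHE
    m2 = load_model_cached("Small")  # second call: returned straight from MODEL_CACHE
    assert m1 is m2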
app.py CHANGED
@@ -4,6 +4,11 @@ Tranception Design App - Hugging Face Spaces Version (Zero GPU Fixed)
 """
 import os
 import sys
+
+# Set up caching to avoid re-downloading models
+os.environ['HF_HOME'] = '/tmp/huggingface'
+os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface/transformers'
+os.environ['HF_DATASETS_CACHE'] = '/tmp/huggingface/datasets'
 import torch
 import transformers
 from transformers import PreTrainedTokenizerFast
@@ -63,15 +68,49 @@ if not os.path.exists("tranception"):
 import tranception
 from tranception import config, model_pytorch
 
-# Download model checkpoints if not present
-def download_model_from_hf(model_name):
-    """Download model from Hugging Face Hub if not present locally"""
-    model_path = f"./{model_name}"
-    if not os.path.exists(model_path):
-        print(f"Loading {model_name} model from Hugging Face Hub...")
-        # All models are available on HF Hub
-        return f"PascalNotin/{model_name}"
-    return model_path
+# Model loading configuration
+MODEL_CACHE = {}
+
+def get_model_path(model_name):
+    """Get model path - always use HF Hub for Zero GPU spaces"""
+    # In HF Spaces, models are cached automatically by the transformers library
+    # Always return the HF Hub path to leverage this caching
+    return f"PascalNotin/{model_name}"
+
+def load_model_cached(model_type):
+    """Load model with caching to avoid re-downloading"""
+    global MODEL_CACHE
+
+    # Check if model is already in cache
+    if model_type in MODEL_CACHE:
+        print(f"Using cached {model_type} model")
+        return MODEL_CACHE[model_type]
+
+    print(f"Loading {model_type} model...")
+    model_name = f"Tranception_{model_type}"
+    model_path = get_model_path(model_name)
+
+    try:
+        # Create cache directory if it doesn't exist
+        cache_dir = "/tmp/huggingface/transformers"
+        os.makedirs(cache_dir, exist_ok=True)
+
+        model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(
+            pretrained_model_name_or_path=model_path,
+            cache_dir=cache_dir,
+            local_files_only=False,  # Allow downloading if not cached
+            resume_download=True  # Resume incomplete downloads
+        )
+        MODEL_CACHE[model_type] = model
+        print(f"{model_type} model loaded and cached")
+        return model
+    except Exception as e:
+        print(f"Error loading {model_type} model: {e}")
+        # Fallback to Medium if requested model fails
+        if model_type != "Medium":
+            print("Falling back to Medium model...")
+            return load_model_cached("Medium")
+        raise
 
 AA_vocab = "ACDEFGHIKLMNPQRSTVWY"
 tokenizer = PreTrainedTokenizerFast(tokenizer_file="./tranception/utils/tokenizers/Basic_tokenizer",
@@ -265,22 +304,11 @@ def score_and_create_matrix_all_singles_impl(sequence,mutation_range_start=None,
     assert mutation_range_start <= mutation_range_end, "mutation range is invalid"
     assert mutation_range_end <= len(sequence), f"End position ({mutation_range_end}) exceeds sequence length ({len(sequence)})"
 
-    # Load model with HF Space compatibility
-    try:
-        if model_type=="Small":
-            model_path = download_model_from_hf("Tranception_Small")
-            model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path=model_path)
-        elif model_type=="Medium":
-            model_path = download_model_from_hf("Tranception_Medium")
-            model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path=model_path)
-        elif model_type=="Large":
-            model_path = download_model_from_hf("Tranception_Large")
-            model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path=model_path)
-    except Exception as e:
-        print(f"Error loading {model_type} model: {e}")
-        print("Falling back to Medium model...")
-        model_path = download_model_from_hf("Tranception_Medium")
-        model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path=model_path)
+    # Load model with caching
+    model = load_model_cached(model_type)
+
+    # Move model to appropriate device INSIDE the GPU decorated function
+    # This is crucial for Zero GPU - the model must be moved to GPU inside the decorated function
 
     # Device selection - Zero GPU will provide CUDA when decorated with @spaces.GPU
     print(f"GPU Available: {torch.cuda.is_available()}")
@@ -347,12 +375,13 @@ def score_and_create_matrix_all_singles_impl(sequence,mutation_range_start=None,
         return score_heatmaps, suggest_mutations(scores), csv_files
 
     finally:
-        # Always clean up model from memory
+        # Clean up GPU memory but keep model in cache
+        # Move model back to CPU to free GPU memory
        if 'model' in locals():
-            del model
-            gc.collect()
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            model.cpu()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            gc.collect()
 
 # Apply Zero GPU decorator if available
 if SPACES_AVAILABLE:
@@ -497,7 +526,25 @@ with tranception_design:
     gr.Markdown("<p><b>Tranception: Protein Fitness Prediction with Autoregressive Transformers and Inference-time Retrieval</b><br>Pascal Notin, Mafalda Dias, Jonathan Frazer, Javier Marchena-Hurtado, Aidan N. Gomez, Debora S. Marks<sup>*</sup>, Yarin Gal<sup>*</sup><br><sup>* equal senior authorship</sup></p>")
     gr.Markdown("Links: <a href='https://proceedings.mlr.press/v162/notin22a.html' target='_blank'>Paper</a> <a href='https://github.com/OATML-Markslab/Tranception' target='_blank'>Code</a> <a href='https://sites.google.com/view/proteingym/substitutions' target='_blank'>ProteinGym</a> <a href='https://igem.org/teams/5247' target='_blank'>BASIS-China iGEM Team</a>")
 
+# Preload models function
+def preload_models():
+    """Preload models at startup to avoid downloading during inference"""
+    print("Preloading models at startup...")
+    try:
+        # Try to load Small model (fastest)
+        load_model_cached("Small")
+        print("Small model preloaded successfully")
+    except Exception as e:
+        print(f"Could not preload Small model: {e}")
+
+    # Optionally preload other models
+    # load_model_cached("Medium")
+    # load_model_cached("Large")
+
 if __name__ == "__main__":
+    # Preload models before launching
+    preload_models()
+
     # Simple launch without queue to avoid Zero GPU conflicts
     tranception_design.launch(
         server_name="0.0.0.0",