Spaces: Running on Zero
Commit a0e970b
Parent(s): 86d5c5f
Implement model caching and preload functionality to optimize loading in Zero GPU spaces
app.py CHANGED
@@ -4,6 +4,11 @@ Tranception Design App - Hugging Face Spaces Version (Zero GPU Fixed)
 """
 import os
 import sys
+
+# Set up caching to avoid re-downloading models
+os.environ['HF_HOME'] = '/tmp/huggingface'
+os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface/transformers'
+os.environ['HF_DATASETS_CACHE'] = '/tmp/huggingface/datasets'
 import torch
 import transformers
 from transformers import PreTrainedTokenizerFast
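The point of this hunk is ordering: the cache environment variables are set after import os / import sys but before torch and transformers are imported, so every later from_pretrained call resolves its cache under /tmp/huggingface. Recent transformers releases treat TRANSFORMERS_CACHE as deprecated in favor of HF_HOME, so setting both, as this commit does, should cover old and new versions. A minimal standalone sketch of the same idea (not part of app.py; the printed path is only illustrative):

import os

# Redirect the Hugging Face caches under /tmp before the libraries are imported.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"

# Import only after the environment is set; the cache location is resolved when
# transformers is first imported, so setting the variables later may not take effect.
from transformers import PreTrainedTokenizerFast  # noqa: F401

print("HF_HOME ->", os.environ["HF_HOME"])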
@@ -63,15 +68,49 @@ if not os.path.exists("tranception"):
 import tranception
 from tranception import config, model_pytorch
 
-#
-
-
-
-
-
-
-
-
+# Model loading configuration
+MODEL_CACHE = {}
+
+def get_model_path(model_name):
+    """Get model path - always use HF Hub for Zero GPU spaces"""
+    # In HF Spaces, models are cached automatically by the transformers library
+    # Always return the HF Hub path to leverage this caching
+    return f"PascalNotin/{model_name}"
+
+def load_model_cached(model_type):
+    """Load model with caching to avoid re-downloading"""
+    global MODEL_CACHE
+
+    # Check if model is already in cache
+    if model_type in MODEL_CACHE:
+        print(f"Using cached {model_type} model")
+        return MODEL_CACHE[model_type]
+
+    print(f"Loading {model_type} model...")
+    model_name = f"Tranception_{model_type}"
+    model_path = get_model_path(model_name)
+
+    try:
+        # Create cache directory if it doesn't exist
+        cache_dir = "/tmp/huggingface/transformers"
+        os.makedirs(cache_dir, exist_ok=True)
+
+        model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(
+            pretrained_model_name_or_path=model_path,
+            cache_dir=cache_dir,
+            local_files_only=False,  # Allow downloading if not cached
+            resume_download=True  # Resume incomplete downloads
+        )
+        MODEL_CACHE[model_type] = model
+        print(f"{model_type} model loaded and cached")
+        return model
+    except Exception as e:
+        print(f"Error loading {model_type} model: {e}")
+        # Fallback to Medium if requested model fails
+        if model_type != "Medium":
+            print("Falling back to Medium model...")
+            return load_model_cached("Medium")
+        raise
 
 AA_vocab = "ACDEFGHIKLMNPQRSTVWY"
 tokenizer = PreTrainedTokenizerFast(tokenizer_file="./tranception/utils/tokenizers/Basic_tokenizer",
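load_model_cached() is plain in-process memoization: the first call per model size pays the download and instantiation cost, every later call returns the same object from MODEL_CACHE. A minimal sketch of the pattern with a stand-in loader (DummyModel is hypothetical, used only to keep the example runnable without the Tranception weights):

MODEL_CACHE = {}

class DummyModel:
    def __init__(self, name):
        self.name = name

def load_model_cached(model_type):
    if model_type in MODEL_CACHE:                      # cache hit: reuse the in-memory object
        return MODEL_CACHE[model_type]
    model = DummyModel(f"Tranception_{model_type}")    # cache miss: construct/download once
    MODEL_CACHE[model_type] = model
    return model

first = load_model_cached("Small")
second = load_model_cached("Small")
assert first is second                                 # same object; nothing was re-loaded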
@@ -265,22 +304,11 @@ def score_and_create_matrix_all_singles_impl(sequence,mutation_range_start=None,
     assert mutation_range_start <= mutation_range_end, "mutation range is invalid"
     assert mutation_range_end <= len(sequence), f"End position ({mutation_range_end}) exceeds sequence length ({len(sequence)})"
 
-    # Load model with
-
-
-
-
-        elif model_type=="Medium":
-            model_path = download_model_from_hf("Tranception_Medium")
-            model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path=model_path)
-        elif model_type=="Large":
-            model_path = download_model_from_hf("Tranception_Large")
-            model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path=model_path)
-    except Exception as e:
-        print(f"Error loading {model_type} model: {e}")
-        print("Falling back to Medium model...")
-        model_path = download_model_from_hf("Tranception_Medium")
-        model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path=model_path)
+    # Load model with caching
+    model = load_model_cached(model_type)
+
+    # Move model to appropriate device INSIDE the GPU decorated function
+    # This is crucial for Zero GPU - the model must be moved to GPU inside the decorated function
 
     # Device selection - Zero GPU will provide CUDA when decorated with @spaces.GPU
     print(f"GPU Available: {torch.cuda.is_available()}")
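The comment about moving the model to the device inside the decorated function is the core Zero GPU constraint: CUDA is only attached while a @spaces.GPU-decorated call is running, so a .to(device) at import time would either fail or leave the weights on CPU. A rough, self-contained sketch of that calling pattern (nn.Linear stands in for the cached Tranception model; it falls back to CPU when run outside a Space):

import torch
import torch.nn as nn

try:
    import spaces                     # present on Hugging Face Spaces hardware
    SPACES_AVAILABLE = True
except ImportError:
    SPACES_AVAILABLE = False

_cached_model = nn.Linear(4, 1)       # stand-in for a model held in MODEL_CACHE

def _score(x):
    # Device is resolved inside the call: on Zero GPU, CUDA only exists here.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = _cached_model.to(device)
    return model(x.to(device)).detach().cpu()

# Wrap with the Zero GPU decorator only when the spaces package is available.
score = spaces.GPU(_score) if SPACES_AVAILABLE else _score
print(score(torch.randn(2, 4)))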
@@ -347,12 +375,13 @@ def score_and_create_matrix_all_singles_impl(sequence,mutation_range_start=None,
         return score_heatmaps, suggest_mutations(scores), csv_files
 
     finally:
-        #
+        # Clean up GPU memory but keep model in cache
+        # Move model back to CPU to free GPU memory
         if 'model' in locals():
-
-
-
-
+            model.cpu()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        gc.collect()
 
 # Apply Zero GPU decorator if available
 if SPACES_AVAILABLE:
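The new finally block keeps the weights in the process (so MODEL_CACHE stays warm) but pushes them off the GPU before the Zero GPU allocation is released. A minimal sketch of the same cleanup, assuming gc is imported elsewhere in app.py (the diff does not add that import):

import gc
import torch
import torch.nn as nn

model = nn.Linear(4, 1)               # stand-in for the cached Tranception model

try:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    _ = model(torch.randn(2, 4, device=device))
finally:
    model.cpu()                       # weights stay in RAM for the next request
    if torch.cuda.is_available():
        torch.cuda.empty_cache()      # release cached CUDA allocator blocks
    gc.collect()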
@@ -497,7 +526,25 @@ with tranception_design:
     gr.Markdown("<p><b>Tranception: Protein Fitness Prediction with Autoregressive Transformers and Inference-time Retrieval</b><br>Pascal Notin, Mafalda Dias, Jonathan Frazer, Javier Marchena-Hurtado, Aidan N. Gomez, Debora S. Marks<sup>*</sup>, Yarin Gal<sup>*</sup><br><sup>* equal senior authorship</sup></p>")
     gr.Markdown("Links: <a href='https://proceedings.mlr.press/v162/notin22a.html' target='_blank'>Paper</a> <a href='https://github.com/OATML-Markslab/Tranception' target='_blank'>Code</a> <a href='https://sites.google.com/view/proteingym/substitutions' target='_blank'>ProteinGym</a> <a href='https://igem.org/teams/5247' target='_blank'>BASIS-China iGEM Team</a>")
 
+# Preload models function
+def preload_models():
+    """Preload models at startup to avoid downloading during inference"""
+    print("Preloading models at startup...")
+    try:
+        # Try to load Small model (fastest)
+        load_model_cached("Small")
+        print("Small model preloaded successfully")
+    except Exception as e:
+        print(f"Could not preload Small model: {e}")
+
+    # Optionally preload other models
+    # load_model_cached("Medium")
+    # load_model_cached("Large")
+
 if __name__ == "__main__":
+    # Preload models before launching
+    preload_models()
+
     # Simple launch without queue to avoid Zero GPU conflicts
     tranception_design.launch(
         server_name="0.0.0.0",
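preload_models() runs in the main process before tranception_design.launch(), so the Small checkpoint is already in MODEL_CACHE (and on disk under /tmp/huggingface) when the first request arrives; the Medium and Large preloads stay commented out to keep startup time and memory down. A compact sketch of that startup order with a placeholder Blocks app (demo and the stub preload are illustrative, not the real app objects):

import gradio as gr

def preload_models():
    # In app.py this calls load_model_cached("Small") to warm MODEL_CACHE.
    print("Preloading models at startup...")

with gr.Blocks() as demo:
    gr.Markdown("Tranception Design App (placeholder UI)")

if __name__ == "__main__":
    preload_models()                  # runs once, before the server starts accepting requests
    demo.launch(server_name="0.0.0.0")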