MoraxCheng committed on
Commit
9957ccd
·
1 Parent(s): 77cfeaa

Update model selection for HF Spaces to use Large model

Browse files

- Remove hardcoded fallback from Large to Medium model
- Simplify download_model_from_hf function to support all model sizes
- Add try/except for model loading with automatic fallback to Medium if needed
- All Tranception models (Small, Medium, Large) are now available on HF Hub

Files changed (1) hide show
  1. app.py +16 -20
app.py CHANGED
@@ -42,19 +42,9 @@ def download_model_from_hf(model_name):
42
  """Download model from Hugging Face Hub if not present locally"""
43
  model_path = f"./{model_name}"
44
  if not os.path.exists(model_path):
45
- print(f"Downloading {model_name} model...")
46
- try:
47
- # For Small and Medium models, they are available on HF Hub
48
- if model_name in ["Tranception_Small", "Tranception_Medium"]:
49
- return f"PascalNotin/{model_name}"
50
- else:
51
- # For Large model, we need to download from the original source
52
- print("Note: Large model needs to be downloaded from the original source.")
53
- print("Using Medium model as fallback...")
54
- return "PascalNotin/Tranception_Medium"
55
- except Exception as e:
56
- print(f"Error downloading {model_name}: {e}")
57
- return None
58
  return model_path
59
 
60
  AA_vocab = "ACDEFGHIKLMNPQRSTVWY"
@@ -239,15 +229,21 @@ def score_and_create_matrix_all_singles(sequence,mutation_range_start=None,mutat
239
  assert mutation_range_end <= len(sequence), f"End position ({mutation_range_end}) exceeds sequence length ({len(sequence)})"
240
 
241
  # Load model with HF Space compatibility
242
- if model_type=="Small":
243
- model_path = download_model_from_hf("Tranception_Small")
244
- model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path=model_path)
245
- elif model_type=="Medium":
 
 
 
 
 
 
 
 
 
246
  model_path = download_model_from_hf("Tranception_Medium")
247
  model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path=model_path)
248
- elif model_type=="Large":
249
- model_path = download_model_from_hf("Tranception_Large")
250
- model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path=model_path)
251
 
252
  # Device selection - for HF Spaces, typically CPU
253
  if torch.cuda.is_available():
 
42
  """Download model from Hugging Face Hub if not present locally"""
43
  model_path = f"./{model_name}"
44
  if not os.path.exists(model_path):
45
+ print(f"Loading {model_name} model from Hugging Face Hub...")
46
+ # All models are available on HF Hub
47
+ return f"PascalNotin/{model_name}"
 
 
 
 
 
 
 
 
 
 
48
  return model_path
49
 
50
  AA_vocab = "ACDEFGHIKLMNPQRSTVWY"
 
229
  assert mutation_range_end <= len(sequence), f"End position ({mutation_range_end}) exceeds sequence length ({len(sequence)})"
230
 
231
  # Load model with HF Space compatibility
232
+ try:
233
+ if model_type=="Small":
234
+ model_path = download_model_from_hf("Tranception_Small")
235
+ model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path=model_path)
236
+ elif model_type=="Medium":
237
+ model_path = download_model_from_hf("Tranception_Medium")
238
+ model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path=model_path)
239
+ elif model_type=="Large":
240
+ model_path = download_model_from_hf("Tranception_Large")
241
+ model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path=model_path)
242
+ except Exception as e:
243
+ print(f"Error loading {model_type} model: {e}")
244
+ print("Falling back to Medium model...")
245
  model_path = download_model_from_hf("Tranception_Medium")
246
  model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path=model_path)
 
 
 
247
 
248
  # Device selection - for HF Spaces, typically CPU
249
  if torch.cuda.is_available():