Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
9957ccd
1
Parent(s):
77cfeaa
Update model selection for HF Spaces to use Large model
Browse files- Remove hardcoded fallback from Large to Medium model
- Simplify download_model_from_hf function to support all model sizes
- Add try-catch for model loading with automatic fallback to Medium if needed
- All Tranception models (Small, Medium, Large) are now available on HF Hub
app.py
CHANGED
@@ -42,19 +42,9 @@ def download_model_from_hf(model_name):
|
|
42 |
"""Download model from Hugging Face Hub if not present locally"""
|
43 |
model_path = f"./{model_name}"
|
44 |
if not os.path.exists(model_path):
|
45 |
-
print(f"
|
46 |
-
|
47 |
-
|
48 |
-
if model_name in ["Tranception_Small", "Tranception_Medium"]:
|
49 |
-
return f"PascalNotin/{model_name}"
|
50 |
-
else:
|
51 |
-
# For Large model, we need to download from the original source
|
52 |
-
print("Note: Large model needs to be downloaded from the original source.")
|
53 |
-
print("Using Medium model as fallback...")
|
54 |
-
return "PascalNotin/Tranception_Medium"
|
55 |
-
except Exception as e:
|
56 |
-
print(f"Error downloading {model_name}: {e}")
|
57 |
-
return None
|
58 |
return model_path
|
59 |
|
60 |
AA_vocab = "ACDEFGHIKLMNPQRSTVWY"
|
@@ -239,15 +229,21 @@ def score_and_create_matrix_all_singles(sequence,mutation_range_start=None,mutat
|
|
239 |
assert mutation_range_end <= len(sequence), f"End position ({mutation_range_end}) exceeds sequence length ({len(sequence)})"
|
240 |
|
241 |
# Load model with HF Space compatibility
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
246 |
model_path = download_model_from_hf("Tranception_Medium")
|
247 |
model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path=model_path)
|
248 |
-
elif model_type=="Large":
|
249 |
-
model_path = download_model_from_hf("Tranception_Large")
|
250 |
-
model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path=model_path)
|
251 |
|
252 |
# Device selection - for HF Spaces, typically CPU
|
253 |
if torch.cuda.is_available():
|
|
|
42 |
"""Download model from Hugging Face Hub if not present locally"""
|
43 |
model_path = f"./{model_name}"
|
44 |
if not os.path.exists(model_path):
|
45 |
+
print(f"Loading {model_name} model from Hugging Face Hub...")
|
46 |
+
# All models are available on HF Hub
|
47 |
+
return f"PascalNotin/{model_name}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
return model_path
|
49 |
|
50 |
AA_vocab = "ACDEFGHIKLMNPQRSTVWY"
|
|
|
229 |
assert mutation_range_end <= len(sequence), f"End position ({mutation_range_end}) exceeds sequence length ({len(sequence)})"
|
230 |
|
231 |
# Load model with HF Space compatibility
|
232 |
+
try:
|
233 |
+
if model_type=="Small":
|
234 |
+
model_path = download_model_from_hf("Tranception_Small")
|
235 |
+
model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path=model_path)
|
236 |
+
elif model_type=="Medium":
|
237 |
+
model_path = download_model_from_hf("Tranception_Medium")
|
238 |
+
model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path=model_path)
|
239 |
+
elif model_type=="Large":
|
240 |
+
model_path = download_model_from_hf("Tranception_Large")
|
241 |
+
model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path=model_path)
|
242 |
+
except Exception as e:
|
243 |
+
print(f"Error loading {model_type} model: {e}")
|
244 |
+
print("Falling back to Medium model...")
|
245 |
model_path = download_model_from_hf("Tranception_Medium")
|
246 |
model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path=model_path)
|
|
|
|
|
|
|
247 |
|
248 |
# Device selection - for HF Spaces, typically CPU
|
249 |
if torch.cuda.is_available():
|