Spaces:
Running
Running
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +18 -5
src/streamlit_app.py
CHANGED
@@ -12,6 +12,8 @@ from ase.visualize import view
|
|
12 |
import py3Dmol
|
13 |
from mace.calculators import mace_mp
|
14 |
from fairchem.core import pretrained_mlip, FAIRChemCalculator
|
|
|
|
|
15 |
import pandas as pd
|
16 |
import yaml # Added for FairChem reference energies
|
17 |
|
@@ -673,7 +675,10 @@ FAIRCHEM_MODELS = {
|
|
673 |
"ESEN SM Conserving All OMOL": "esen-sm-conserving-all-omol",
|
674 |
"ESEN SM Direct All OMOL": "esen-sm-direct-all-omol"
|
675 |
}
|
676 |
-
|
|
|
|
|
|
|
677 |
@st.cache_resource
|
678 |
def get_mace_model(model_path, device, selected_default_dtype):
|
679 |
return mace_mp(model=model_path, device=device, default_dtype=selected_default_dtype)
|
@@ -742,7 +747,7 @@ if atoms is not None:
|
|
742 |
|
743 |
|
744 |
st.sidebar.markdown("## Model Selection")
|
745 |
-
model_type = st.sidebar.radio("Select Model Type:", ["MACE", "FairChem"])
|
746 |
|
747 |
selected_task_type = None # For FairChem UMA
|
748 |
if model_type == "MACE":
|
@@ -762,7 +767,12 @@ if model_type == "FairChem":
|
|
762 |
spin_multiplicity = st.sidebar.number_input("Spin Multiplicity (2S + 1)", min_value=1, max_value=11, step=2, value=int(atoms.info.get("spin",0)*2+1 if atoms.info.get("spin",0) is not None else 1)) # Assuming spin in atoms.info is S
|
763 |
atoms.info["charge"] = charge
|
764 |
atoms.info["spin"] = spin_multiplicity # FairChem expects multiplicity
|
765 |
-
|
|
|
|
|
|
|
|
|
|
|
766 |
if atoms is not None:
|
767 |
if not check_atom_limit(atoms, selected_model):
|
768 |
st.stop() # Stop execution if limit exceeded
|
@@ -842,14 +852,17 @@ if atoms is not None:
|
|
842 |
if model_type == "MACE":
|
843 |
# st.write("Setting up MACE calculator...")
|
844 |
calc = get_mace_model(model_path, device, selected_default_dtype)
|
845 |
-
|
846 |
# st.write("Setting up FairChem calculator...")
|
847 |
# Workaround for potential dtype issues when switching models
|
848 |
if device == "cpu": # Ensure torch default dtype matches if needed
|
849 |
torch.set_default_dtype(torch.float32)
|
850 |
_ = get_mace_model(MACE_MODELS["MACE MP 0a Small"], 'cpu', 'float32') # Dummy call
|
851 |
calc = get_fairchem_model(selected_model, model_path, device, selected_task_type)
|
852 |
-
|
|
|
|
|
|
|
853 |
calc_atoms.calc = calc
|
854 |
|
855 |
if task == "Energy Calculation":
|
|
|
12 |
import py3Dmol
|
13 |
from mace.calculators import mace_mp
|
14 |
from fairchem.core import pretrained_mlip, FAIRChemCalculator
|
15 |
+
from orb_models.forcefield import pretrained
|
16 |
+
from orb_models.forcefield.calculator import ORBCalculator
|
17 |
import pandas as pd
|
18 |
import yaml # Added for FairChem reference energies
|
19 |
|
|
|
675 |
"ESEN SM Conserving All OMOL": "esen-sm-conserving-all-omol",
|
676 |
"ESEN SM Direct All OMOL": "esen-sm-direct-all-omol"
|
677 |
}
|
678 |
+
# Define the available ORB models
|
679 |
+
ORB_MODELS = {
|
680 |
+
"V3 OMAT Conserving": "orb_v3_conservative_inf_omat",
|
681 |
+
}
|
682 |
@st.cache_resource
|
683 |
def get_mace_model(model_path, device, selected_default_dtype):
|
684 |
return mace_mp(model=model_path, device=device, default_dtype=selected_default_dtype)
|
|
|
747 |
|
748 |
|
749 |
st.sidebar.markdown("## Model Selection")
|
750 |
+
model_type = st.sidebar.radio("Select Model Type:", ["MACE", "FairChem", "ORB"])
|
751 |
|
752 |
selected_task_type = None # For FairChem UMA
|
753 |
if model_type == "MACE":
|
|
|
767 |
spin_multiplicity = st.sidebar.number_input("Spin Multiplicity (2S + 1)", min_value=1, max_value=11, step=2, value=int(atoms.info.get("spin",0)*2+1 if atoms.info.get("spin",0) is not None else 1)) # Assuming spin in atoms.info is S
|
768 |
atoms.info["charge"] = charge
|
769 |
atoms.info["spin"] = spin_multiplicity # FairChem expects multiplicity
|
770 |
+
if model_type == "ORB":
|
771 |
+
selected_model = st.sidebar.selectbox("Select ORB Model:", list(ORB_MODELS.keys()))
|
772 |
+
model_path = ORB_MODELS[selected_model]
|
773 |
+
# if "omat" in selected_model:
|
774 |
+
# st.sidebar.warning("Using model under Academic Software License (ASL) license, see [https://github.com/gabor1/ASL](https://github.com/gabor1/ASL). To use this model you accept the terms of the license.")
|
775 |
+
selected_default_dtype = st.sidebar.selectbox("Select Precision (default_dtype):", ['float32-high', 'float32-highest', 'float64'])
|
776 |
if atoms is not None:
|
777 |
if not check_atom_limit(atoms, selected_model):
|
778 |
st.stop() # Stop execution if limit exceeded
|
|
|
852 |
if model_type == "MACE":
|
853 |
# st.write("Setting up MACE calculator...")
|
854 |
calc = get_mace_model(model_path, device, selected_default_dtype)
|
855 |
+
elif model_type == "FairChem": # FairChem
|
856 |
# st.write("Setting up FairChem calculator...")
|
857 |
# Workaround for potential dtype issues when switching models
|
858 |
if device == "cpu": # Ensure torch default dtype matches if needed
|
859 |
torch.set_default_dtype(torch.float32)
|
860 |
_ = get_mace_model(MACE_MODELS["MACE MP 0a Small"], 'cpu', 'float32') # Dummy call
|
861 |
calc = get_fairchem_model(selected_model, model_path, device, selected_task_type)
|
862 |
+
elif model_type == "ORB":
|
863 |
+
st.write("Setting up ORB calculator...")
|
864 |
+
orbff = pretrained.orb_v3_conservative_inf_omat(device=device, precision=selected_default_dtype)
|
865 |
+
calc = ORBCalculator(orbff, device=device)
|
866 |
calc_atoms.calc = calc
|
867 |
|
868 |
if task == "Energy Calculation":
|