Spaces:
Running
Running
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +12 -0
src/streamlit_app.py
CHANGED
@@ -14,6 +14,7 @@ from mace.calculators import mace_mp
|
|
14 |
from fairchem.core import pretrained_mlip, FAIRChemCalculator
|
15 |
from orb_models.forcefield import pretrained
|
16 |
from orb_models.forcefield.calculator import ORBCalculator
|
|
|
17 |
import pandas as pd
|
18 |
import yaml # Added for FairChem reference energies
|
19 |
|
@@ -686,6 +687,11 @@ ORB_MODELS = {
|
|
686 |
"V3 MPA Direct (inf)": "orb-v3-direct-inf-mpa",
|
687 |
"V3 MPA Direct (20)": "orb-v3-direct-20-mpa",
|
688 |
}
|
|
|
|
|
|
|
|
|
|
|
689 |
@st.cache_resource
|
690 |
def get_mace_model(model_path, device, selected_default_dtype):
|
691 |
return mace_mp(model=model_path, device=device, default_dtype=selected_default_dtype)
|
@@ -780,6 +786,9 @@ if model_type == "ORB":
|
|
780 |
# if "omat" in selected_model:
|
781 |
# st.sidebar.warning("Using model under Academic Software License (ASL) license, see [https://github.com/gabor1/ASL](https://github.com/gabor1/ASL). To use this model you accept the terms of the license.")
|
782 |
selected_default_dtype = st.sidebar.selectbox("Select Precision (default_dtype):", ['float32-high', 'float32-highest', 'float64'])
|
|
|
|
|
|
|
783 |
if atoms is not None:
|
784 |
if not check_atom_limit(atoms, selected_model):
|
785 |
st.stop() # Stop execution if limit exceeded
|
@@ -870,6 +879,9 @@ if atoms is not None:
|
|
870 |
st.write("Setting up ORB calculator...")
|
871 |
orbff = pretrained.orb_v3_conservative_inf_omat(device=device, precision=selected_default_dtype)
|
872 |
calc = ORBCalculator(orbff, device=device)
|
|
|
|
|
|
|
873 |
calc_atoms.calc = calc
|
874 |
|
875 |
if task == "Energy Calculation":
|
|
|
14 |
from fairchem.core import pretrained_mlip, FAIRChemCalculator
|
15 |
from orb_models.forcefield import pretrained
|
16 |
from orb_models.forcefield.calculator import ORBCalculator
|
17 |
+
from mattersim.forcefield import MatterSimCalculator
|
18 |
import pandas as pd
|
19 |
import yaml # Added for FairChem reference energies
|
20 |
|
|
|
687 |
"V3 MPA Direct (inf)": "orb-v3-direct-inf-mpa",
|
688 |
"V3 MPA Direct (20)": "orb-v3-direct-20-mpa",
|
689 |
}
|
690 |
+
# Define the available MatterSim models
|
691 |
+
MATTERSIM_MODELS = {
|
692 |
+
"V1 SMALL: MatterSim-v1.0.0-1M.pth",
|
693 |
+
"V1 LARGE: MatterSim-v1.0.0-5M.pth"
|
694 |
+
}
|
695 |
@st.cache_resource
|
696 |
def get_mace_model(model_path, device, selected_default_dtype):
|
697 |
return mace_mp(model=model_path, device=device, default_dtype=selected_default_dtype)
|
|
|
786 |
# if "omat" in selected_model:
|
787 |
# st.sidebar.warning("Using model under Academic Software License (ASL) license, see [https://github.com/gabor1/ASL](https://github.com/gabor1/ASL). To use this model you accept the terms of the license.")
|
788 |
selected_default_dtype = st.sidebar.selectbox("Select Precision (default_dtype):", ['float32-high', 'float32-highest', 'float64'])
|
789 |
+
if model_type == "MatterSim":
|
790 |
+
selected_model = st.sidebar.selectbox("Select MatterSim Model:", list(MATTERSIM_MODELS.keys()))
|
791 |
+
model_path = MATTERSIM_MODELS[selected_model]
|
792 |
if atoms is not None:
|
793 |
if not check_atom_limit(atoms, selected_model):
|
794 |
st.stop() # Stop execution if limit exceeded
|
|
|
879 |
st.write("Setting up ORB calculator...")
|
880 |
orbff = pretrained.orb_v3_conservative_inf_omat(device=device, precision=selected_default_dtype)
|
881 |
calc = ORBCalculator(orbff, device=device)
|
882 |
+
elif model_type == "MatterSim":
|
883 |
+
st.write("Setting up MatterSim calculator...")
|
884 |
+
calc = MatterSimCalculator(load_path=model_path, device=device)
|
885 |
calc_atoms.calc = calc
|
886 |
|
887 |
if task == "Energy Calculation":
|