ManasSharma07 commited on
Commit
8f100da
·
verified ·
1 Parent(s): 4bf2972

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +12 -0
src/streamlit_app.py CHANGED
@@ -14,6 +14,7 @@ from mace.calculators import mace_mp
14
  from fairchem.core import pretrained_mlip, FAIRChemCalculator
15
  from orb_models.forcefield import pretrained
16
  from orb_models.forcefield.calculator import ORBCalculator
 
17
  import pandas as pd
18
  import yaml # Added for FairChem reference energies
19
 
@@ -686,6 +687,11 @@ ORB_MODELS = {
686
  "V3 MPA Direct (inf)": "orb-v3-direct-inf-mpa",
687
  "V3 MPA Direct (20)": "orb-v3-direct-20-mpa",
688
  }
 
 
 
 
 
689
  @st.cache_resource
690
  def get_mace_model(model_path, device, selected_default_dtype):
691
  return mace_mp(model=model_path, device=device, default_dtype=selected_default_dtype)
@@ -780,6 +786,9 @@ if model_type == "ORB":
780
  # if "omat" in selected_model:
781
  # st.sidebar.warning("Using model under Academic Software License (ASL) license, see [https://github.com/gabor1/ASL](https://github.com/gabor1/ASL). To use this model you accept the terms of the license.")
782
  selected_default_dtype = st.sidebar.selectbox("Select Precision (default_dtype):", ['float32-high', 'float32-highest', 'float64'])
 
 
 
783
  if atoms is not None:
784
  if not check_atom_limit(atoms, selected_model):
785
  st.stop() # Stop execution if limit exceeded
@@ -870,6 +879,9 @@ if atoms is not None:
870
  st.write("Setting up ORB calculator...")
871
  orbff = pretrained.orb_v3_conservative_inf_omat(device=device, precision=selected_default_dtype)
872
  calc = ORBCalculator(orbff, device=device)
 
 
 
873
  calc_atoms.calc = calc
874
 
875
  if task == "Energy Calculation":
 
14
  from fairchem.core import pretrained_mlip, FAIRChemCalculator
15
  from orb_models.forcefield import pretrained
16
  from orb_models.forcefield.calculator import ORBCalculator
17
+ from mattersim.forcefield import MatterSimCalculator
18
  import pandas as pd
19
  import yaml # Added for FairChem reference energies
20
 
 
687
  "V3 MPA Direct (inf)": "orb-v3-direct-inf-mpa",
688
  "V3 MPA Direct (20)": "orb-v3-direct-20-mpa",
689
  }
690
+ # Define the available MatterSim models
691
+ MATTERSIM_MODELS = {
692
+ "V1 SMALL: MatterSim-v1.0.0-1M.pth",
693
+ "V1 LARGE: MatterSim-v1.0.0-5M.pth"
694
+ }
695
  @st.cache_resource
696
  def get_mace_model(model_path, device, selected_default_dtype):
697
  return mace_mp(model=model_path, device=device, default_dtype=selected_default_dtype)
 
786
  # if "omat" in selected_model:
787
  # st.sidebar.warning("Using model under Academic Software License (ASL) license, see [https://github.com/gabor1/ASL](https://github.com/gabor1/ASL). To use this model you accept the terms of the license.")
788
  selected_default_dtype = st.sidebar.selectbox("Select Precision (default_dtype):", ['float32-high', 'float32-highest', 'float64'])
789
+ if model_type == "MatterSim":
790
+ selected_model = st.sidebar.selectbox("Select MatterSim Model:", list(MATTERSIM_MODELS.keys()))
791
+ model_path = MATTERSIM_MODELS[selected_model]
792
  if atoms is not None:
793
  if not check_atom_limit(atoms, selected_model):
794
  st.stop() # Stop execution if limit exceeded
 
879
  st.write("Setting up ORB calculator...")
880
  orbff = pretrained.orb_v3_conservative_inf_omat(device=device, precision=selected_default_dtype)
881
  calc = ORBCalculator(orbff, device=device)
882
+ elif model_type == "MatterSim":
883
+ st.write("Setting up MatterSim calculator...")
884
+ calc = MatterSimCalculator(load_path=model_path, device=device)
885
  calc_atoms.calc = calc
886
 
887
  if task == "Energy Calculation":