ManasSharma07 commited on
Commit
963acc4
·
verified ·
1 Parent(s): 8230d62

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +18 -5
src/streamlit_app.py CHANGED
@@ -12,6 +12,8 @@ from ase.visualize import view
12
  import py3Dmol
13
  from mace.calculators import mace_mp
14
  from fairchem.core import pretrained_mlip, FAIRChemCalculator
 
 
15
  import pandas as pd
16
  import yaml # Added for FairChem reference energies
17
 
@@ -673,7 +675,10 @@ FAIRCHEM_MODELS = {
673
  "ESEN SM Conserving All OMOL": "esen-sm-conserving-all-omol",
674
  "ESEN SM Direct All OMOL": "esen-sm-direct-all-omol"
675
  }
676
-
 
 
 
677
  @st.cache_resource
678
  def get_mace_model(model_path, device, selected_default_dtype):
679
  return mace_mp(model=model_path, device=device, default_dtype=selected_default_dtype)
@@ -742,7 +747,7 @@ if atoms is not None:
742
 
743
 
744
  st.sidebar.markdown("## Model Selection")
745
- model_type = st.sidebar.radio("Select Model Type:", ["MACE", "FairChem"])
746
 
747
  selected_task_type = None # For FairChem UMA
748
  if model_type == "MACE":
@@ -762,7 +767,12 @@ if model_type == "FairChem":
762
  spin_multiplicity = st.sidebar.number_input("Spin Multiplicity (2S + 1)", min_value=1, max_value=11, step=2, value=int(atoms.info.get("spin",0)*2+1 if atoms.info.get("spin",0) is not None else 1)) # Assuming spin in atoms.info is S
763
  atoms.info["charge"] = charge
764
  atoms.info["spin"] = spin_multiplicity # FairChem expects multiplicity
765
-
 
 
 
 
 
766
  if atoms is not None:
767
  if not check_atom_limit(atoms, selected_model):
768
  st.stop() # Stop execution if limit exceeded
@@ -842,14 +852,17 @@ if atoms is not None:
842
  if model_type == "MACE":
843
  # st.write("Setting up MACE calculator...")
844
  calc = get_mace_model(model_path, device, selected_default_dtype)
845
- else: # FairChem
846
  # st.write("Setting up FairChem calculator...")
847
  # Workaround for potential dtype issues when switching models
848
  if device == "cpu": # Ensure torch default dtype matches if needed
849
  torch.set_default_dtype(torch.float32)
850
  _ = get_mace_model(MACE_MODELS["MACE MP 0a Small"], 'cpu', 'float32') # Dummy call
851
  calc = get_fairchem_model(selected_model, model_path, device, selected_task_type)
852
-
 
 
 
853
  calc_atoms.calc = calc
854
 
855
  if task == "Energy Calculation":
 
12
  import py3Dmol
13
  from mace.calculators import mace_mp
14
  from fairchem.core import pretrained_mlip, FAIRChemCalculator
15
+ from orb_models.forcefield import pretrained
16
+ from orb_models.forcefield.calculator import ORBCalculator
17
  import pandas as pd
18
  import yaml # Added for FairChem reference energies
19
 
 
675
  "ESEN SM Conserving All OMOL": "esen-sm-conserving-all-omol",
676
  "ESEN SM Direct All OMOL": "esen-sm-direct-all-omol"
677
  }
678
+ # Define the available ORB models
679
+ ORB_MODELS = {
680
+ "V3 OMAT Conserving": "orb_v3_conservative_inf_omat",
681
+ }
682
  @st.cache_resource
683
  def get_mace_model(model_path, device, selected_default_dtype):
684
  return mace_mp(model=model_path, device=device, default_dtype=selected_default_dtype)
 
747
 
748
 
749
  st.sidebar.markdown("## Model Selection")
750
+ model_type = st.sidebar.radio("Select Model Type:", ["MACE", "FairChem", "ORB"])
751
 
752
  selected_task_type = None # For FairChem UMA
753
  if model_type == "MACE":
 
767
  spin_multiplicity = st.sidebar.number_input("Spin Multiplicity (2S + 1)", min_value=1, max_value=11, step=2, value=int(atoms.info.get("spin",0)*2+1 if atoms.info.get("spin",0) is not None else 1)) # Assuming spin in atoms.info is S
768
  atoms.info["charge"] = charge
769
  atoms.info["spin"] = spin_multiplicity # FairChem expects multiplicity
770
+ if model_type == "ORB":
771
+ selected_model = st.sidebar.selectbox("Select ORB Model:", list(ORB_MODELS.keys()))
772
+ model_path = ORB_MODELS[selected_model]
773
+ # if "omat" in selected_model:
774
+ # st.sidebar.warning("Using model under Academic Software License (ASL) license, see [https://github.com/gabor1/ASL](https://github.com/gabor1/ASL). To use this model you accept the terms of the license.")
775
+ selected_default_dtype = st.sidebar.selectbox("Select Precision (default_dtype):", ['float32-high', 'float32-highest', 'float64'])
776
  if atoms is not None:
777
  if not check_atom_limit(atoms, selected_model):
778
  st.stop() # Stop execution if limit exceeded
 
852
  if model_type == "MACE":
853
  # st.write("Setting up MACE calculator...")
854
  calc = get_mace_model(model_path, device, selected_default_dtype)
855
+ elif model_type == "FairChem": # FairChem
856
  # st.write("Setting up FairChem calculator...")
857
  # Workaround for potential dtype issues when switching models
858
  if device == "cpu": # Ensure torch default dtype matches if needed
859
  torch.set_default_dtype(torch.float32)
860
  _ = get_mace_model(MACE_MODELS["MACE MP 0a Small"], 'cpu', 'float32') # Dummy call
861
  calc = get_fairchem_model(selected_model, model_path, device, selected_task_type)
862
+ elif model_type == "ORB":
863
+ st.write("Setting up ORB calculator...")
864
+ orbff = pretrained.orb_v3_conservative_inf_omat(device=device, precision=selected_default_dtype)
865
+ calc = ORBCalculator(orbff, device=device)
866
  calc_atoms.calc = calc
867
 
868
  if task == "Energy Calculation":