mlip-playground / src /streamlit_app.py
ManasSharma07's picture
Update src/streamlit_app.py
61998ea verified
raw
history blame
37.4 kB
import streamlit as st
import os
import tempfile
import torch
import numpy as np
from ase import Atoms
from ase.io import read, write
from ase.optimize import BFGS, LBFGS, FIRE
from ase.constraints import FixAtoms
from ase.filters import FrechetCellFilter
from ase.visualize import view
import py3Dmol
from mace.calculators import mace_mp
from fairchem.core import pretrained_mlip, FAIRChemCalculator
import pandas as pd
import yaml # Added for FairChem reference energies
from huggingface_hub import login
# Authenticate with the Hugging Face Hub so that gated model checkpoints can be
# downloaded. The token is taken from the environment: the standard HF_TOKEN
# variable is tried first, and the legacy placeholder name is kept as a
# fallback so existing deployments that set it keep working.
# Any failure here is reported on stdout and does not abort the app; models
# that need authentication may then be inaccessible.
try:
    hf_token = os.getenv("HF_TOKEN") or os.getenv("YOUR SECRET KEY")
    if hf_token:
        login(token=hf_token)
    else:
        print("Hugging Face token not found. Some models might not be accessible.")
except Exception as e:
    print(f"hf login error: {e}")

# Ask Streamlit not to use a file watcher.
# NOTE(review): this is set after `import streamlit`; confirm it still takes effect.
os.environ["STREAMLIT_WATCHER_TYPE"] = "none"
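# Per-element reference energies for the FairChem models, stored as YAML: one list per task ("<task>_elem_refs"), indexed by atomic number Z (index 0 is never read)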
ELEMENT_REF_ENERGIES_YAML = """
oc20_elem_refs:
- 0.0
- -0.16141512
- 0.03262098
- -0.04787699
- -0.06299825
- -0.14979306
- -0.11657468
- -0.10862579
- -0.10298174
- -0.03420248
- 0.02673997
- -0.03729558
- 0.00515243
- -0.07535697
- -0.13663351
- -0.12922852
- -0.11796547
- -0.07802946
- -0.00672682
- -0.04089589
- -0.00024177
- -1.74545186
- -1.54220241
- -1.0934019
- -1.16168372
- -1.23073475
- -0.78852824
- -0.71851599
- -0.52465053
- -0.02692092
- -0.00317922
- -0.06266862
- -0.10835274
- -0.12394474
- -0.11351727
- -0.07455817
- -0.00258354
- -0.04111325
- -0.02090265
- -1.89306078
- -1.30591887
- -0.63320009
- -0.26230344
- -0.2633669
- -0.5160055
- -0.95950798
- -1.45589361
- -0.0429969
- -0.00026949
- -0.05925609
- -0.09734631
- -0.12406852
- -0.11427538
- -0.07021442
- 0.01091345
- -0.05305289
- -0.02427209
- -0.19975668
- -1.71692859
- -1.53677781
- -3.89987009
- -10.70940462
- -6.71693816
- -0.28102249
- -8.86944824
- -7.95762687
- -7.13041437
- -6.64620014
- -5.11482482
- -4.42548227
- 0.00848295
- -0.06956227
- -2.6748853
- -2.21153293
- -1.67367741
- -1.07636151
- -0.79009981
- -0.16387243
- -0.18164401
- -0.04122529
- -0.00041833
- -0.05259382
- -0.0934314
- -0.11023834
- -0.10039175
- -0.06069209
- 0.01790437
- -0.04694024
- 0.00334084
- -0.06030621
- -0.58793619
- -1.27821808
- -4.97483577
- -5.66985655
- -8.43154622
- -11.15001317
- -12.95770812
- 0.0
- -14.47602729
- 0.0
odac_elem_refs:
- 0.0
- -1.11737936
- -0.00011835
- -0.2941727
- -0.03868426
- -0.34862832
- -1.31552566
- -3.12457285
- -1.6052078
- -0.49653389
- -0.01137327
- -0.21957281
- -0.0008343
- -0.2750172
- -0.88417265
- -1.887378
- -0.94903558
- -0.31628167
- -0.02014536
- -0.15901053
- -0.00731884
- -1.96521355
- -1.89045209
- -2.53057428
- -5.43600675
- -5.09739336
- -3.03088746
- -1.23786562
- -0.40650749
- -0.2416017
- -0.01139188
- -0.26282496
- -0.82446455
- -1.70237206
- -0.84245376
- -0.28544892
- -0.02239991
- -0.14115912
- -0.02840799
- -2.09540994
- -1.85863996
- -1.12257399
- -4.32965355
- -3.30670045
- -1.19460755
- -1.26257601
- -1.46832888
- -0.19779414
- -0.0144274
- -0.23668767
- -0.70836953
- -1.43186113
- -0.71701186
- -0.24883129
- -0.01118184
- -0.13173447
- -0.0318395
- -0.41195547
- -1.23134873
- -2.03082996
- 0.1375954
- -5.45866275
- -7.59139905
- -5.99965965
- -8.43495767
- -2.6578407
- -7.77349787
- -5.30762201
- -5.15109657
- -4.41466995
- -0.02995219
- -0.2544495
- -3.23821202
- -3.45887214
- -4.53635003
- -4.60979468
- -2.90707964
- -1.28286153
- -0.57716664
- -0.18337108
- -0.01135944
- -0.22045398
- -0.66150479
- -1.32506342
- -0.66500178
- -0.22643927
- -0.00728197
- -0.11208472
- -0.00757856
- -0.21798637
- -0.91078787
- -1.78187161
- -3.89912261
- -3.94192659
- -7.59026042
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
omat_elem_refs:
- 0.0
- -1.11700253
- 0.00079886
- -0.29731164
- -0.04129868
- -0.29106192
- -1.27751531
- -3.12342715
- -1.54797136
- -0.43969356
- -0.01250908
- -0.22855413
- -0.00943179
- -0.21707638
- -0.82619133
- -1.88667434
- -0.89093583
- -0.25816211
- -0.02414768
- -0.17662425
- -0.02568319
- -2.13001165
- -2.38688845
- -3.55934233
- -5.44700879
- -5.14749562
- -3.30662847
- -1.42167737
- -0.63181379
- -0.23449167
- -0.01146636
- -0.21291259
- -0.77939897
- -1.70148487
- -0.78386705
- -0.22690657
- -0.02245409
- -0.16092396
- -0.02798717
- -2.25685695
- -2.23690495
- -2.15347771
- -4.60251809
- -3.36416792
- -2.23062607
- -1.15550917
- -1.47553527
- -0.19918102
- -0.01475888
- -0.19767692
- -0.68005773
- -1.43073368
- -0.65790462
- -0.18915279
- -0.01179476
- -0.13507902
- -0.03056979
- -0.36017439
- -0.86279246
- -0.20573327
- -0.2734463
- -0.20046965
- -0.25444338
- -8.37972664
- -9.58424928
- -0.19466184
- -0.24860115
- -0.19531288
- -0.15401392
- -0.14577898
- -0.19655747
- -0.15645898
- -3.49380556
- -3.5317097
- -4.57108006
- -4.63425205
- -2.88247063
- -1.45679675
- -0.50290184
- -0.18521704
- -0.01123956
- -0.17483649
- -0.63132037
- -1.3248562
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- -0.24135757
- -1.04601971
- -2.04574044
- -3.84544799
- -7.28626119
- -7.3136314
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
omol_elem_refs:
- 0.0
- -13.44558
- -78.82027
- -203.32564
- -398.94742
- -670.75275
- -1029.85403
- -1485.54188
- -2042.97832
- -2714.24015
- -3508.74317
- -4415.24203
- -5443.89712
- -6594.61834
- -7873.6878
- -9285.6593
- -10832.62132
- -12520.66852
- -14354.278
- -16323.54671
- -18436.47845
- -20696.18244
- -23110.5386
- -25682.99429
- -28418.37804
- -31317.92317
- -34383.42519
- -37623.46835
- -41039.92413
- -44637.38634
- -48417.14864
- -52373.87849
- -56512.76952
- -60836.14871
- -65344.28833
- -70041.24251
- -74929.56277
- -653.64777
- -833.31922
- -1038.0281
- -1273.96788
- -1542.45481
- -1850.74158
- -2193.91654
- -2577.18734
- -3004.13604
- -3477.52796
- -3997.31825
- -4563.75804
- -5171.82293
- -5828.85334
- -6535.61529
- -7291.54792
- -8099.87914
- -8962.17916
- -546.03214
- -690.6089
- -854.11237
- -12923.04096
- -14064.26124
- -15272.68689
- -16550.20551
- -17900.36515
- -19323.23406
- -20829.08848
- -22428.73258
- -24078.68008
- -25794.42097
- -27616.6819
- -29523.5526
- -31526.68012
- -33615.37779
- -1300.17791
- -1544.40924
- -1818.62298
- -2123.14417
- -2461.76028
- -2833.76287
- -3242.79895
- -3690.363
- -4174.99772
- -4691.75674
- -5245.36013
- -5838.12005
- -6469.07296
- -7140.86455
- -7854.60638
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
omc_elem_refs:
- 0.0
- -0.02831808
- 4.512e-05
- -0.03227157
- -0.03842519
- -0.05829283
- -0.0845041
- -0.08806738
- -0.09021346
- -0.06669846
- -0.01218631
- -0.03650269
- -0.00059093
- -0.05787736
- -0.08730952
- -0.0975534
- -0.09264199
- -0.07124762
- -0.02374602
- -0.05299112
- -0.02631476
- -1.7772147
- -1.25083444
- -0.79579447
- -0.49099317
- -0.31414986
- -0.20292182
- -0.14011632
- -0.09929659
- -0.03771207
- -0.01117902
- -0.06168715
- -0.08873364
- -0.09512942
- -0.09035978
- -0.06910849
- -0.02244872
- -0.05303651
- -0.02871903
- -1.94805417
- -1.33379896
- -0.69169331
- -0.26184306
- -0.20631599
- -0.48251608
- -0.96911893
- -1.47569462
- -0.03845194
- -0.0142445
- -0.07118991
- -0.09940292
- -0.09235056
- -0.08755943
- -0.06544925
- -0.01246646
- -0.04692937
- -0.03225123
- -0.26086039
- -27.20024339
- -0.08412926
- -0.08225924
- -0.07799715
- -0.07806185
- 0.00043759
- -0.07459766
- 0.0
- -0.06842841
- -0.07758266
- -0.07025152
- -0.08055003
- -0.07118177
- -0.07159568
- -2.69202862
- -2.21926765
- -1.679756
- -1.06135075
- -0.4554231
- -0.14488432
- -0.18377098
- -0.03603118
- -0.01076585
- -0.06381411
- -0.0905623
- -0.10095787
- -0.09501217
- -0.0574478
- -0.00599173
- -0.04134751
- -0.0082683
- -0.08704692
- -0.49656425
- -5.24233138
- -2.32542606
- -4.3376616
- -5.96430676
- 0.0
- 0.0
- -0.03842519
- 0.0
- 0.0
"""
# Parse the embedded YAML table once at import time into a dict that maps
# "<task>_elem_refs" to a list of per-element reference energies.
try:
    ELEMENT_REF_ENERGIES = yaml.safe_load(ELEMENT_REF_ENERGIES_YAML)
except yaml.YAMLError as e:
    # The error is printed instead of shown with st.error (author's note:
    # st objects can only be used in the main script flow).
    print(f"Error parsing YAML reference energies: {e}")
    # Fallback: with an empty table the atomization/cohesive task reports
    # "Missing FairChem reference energies." instead of crashing.
    ELEMENT_REF_ENERGIES = {}
# Flag for Streamlit Cloud vs. local runs, read from STREAMLIT_RUNTIME_ENV.
# NOTE(review): this flag is not referenced anywhere else in the visible part
# of the file; the atom limits below are applied unconditionally.
is_streamlit_cloud = os.environ.get('STREAMLIT_RUNTIME_ENV') == 'cloud'
MAX_ATOMS_CLOUD = 500  # Maximum atoms allowed for the default (non-UMA) models
MAX_ATOMS_CLOUD_UMA = 500  # Maximum atoms allowed for UMA / "ESEN MD" models

# Page configuration (title, icon, wide layout).
st.set_page_config(
    page_title="Molecular Structure Analysis",
    page_icon="🧪",
    layout="wide"
)

# Title and description shown at the top of the main panel.
st.markdown('## MLIP Playground', unsafe_allow_html=True)
st.write('#### Run, test and compare >17 state-of-the-art universal machine learning interatomic potentials (MLIPs) for atomistic simulations of molecules and materials')
st.markdown('Upload molecular structure files or select from predefined examples, then compute energies and forces using foundation models such as those from MACE or FairChem (Meta).', unsafe_allow_html=True)

# Directory holding the bundled example structures; created (empty) if it
# does not exist, relative to the current working directory.
SAMPLE_DIR = "sample_structures"
os.makedirs(SAMPLE_DIR, exist_ok=True)

# Display name -> file name (inside SAMPLE_DIR) of each example structure.
SAMPLE_STRUCTURES = {
    "Water": "H2O.xyz",
    "Methane": "CH4.xyz",
    "Benzene": "C6H6.xyz",
    "Ethane": "C2H6.xyz",
    "Caffeine": "caffeine.xyz",
    "Ibuprofen": "ibuprofen.xyz"
}
def get_structure_viz2(atoms_obj, style='stick', show_unit_cell=True, width=400, height=400):
    """Return a py3Dmol view of ``atoms_obj``.

    Args:
        atoms_obj: Iterable of atoms exposing ``symbol`` and ``position``, with
            ``pbc`` and ``get_cell()`` available on the container.
        style: One of ``'ball_stick'``, ``'stick'`` or ``'ball'`` (case
            insensitive); any other value falls back to thin sticks.
        show_unit_cell: Draw the cell edges when any direction is periodic
            and the cell is non-zero.
        width: Viewer width passed to py3Dmol.
        height: Viewer height passed to py3Dmol.

    Returns:
        The configured ``py3Dmol.view`` object (white background, zoomed to fit).
    """
    # Serialise the structure as an XYZ document: count, comment line, atoms.
    xyz_lines = [f"{len(atoms_obj)}", "Structure"]
    xyz_lines.extend(
        f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}"
        for atom in atoms_obj
    )
    xyz_str = "\n".join(xyz_lines) + "\n"

    viewer = py3Dmol.view(width=width, height=height)
    viewer.addModel(xyz_str, "xyz")

    # Known style names map to a py3Dmol style spec; unknown names get thin sticks.
    style_specs = {
        'ball_stick': {'stick': {'radius': 0.1}, 'sphere': {'scale': 0.3}},
        'stick': {'stick': {}},
        'ball': {'sphere': {'scale': 0.4}},
    }
    viewer.setStyle(style_specs.get(style.lower(), {'stick': {'radius': 0.15}}))

    if show_unit_cell and atoms_obj.pbc.any():
        cell = atoms_obj.get_cell()
        origin = np.array([0.0, 0.0, 0.0])
        if cell is not None and cell.any():
            a, b, c = cell[0], cell[1], cell[2]
            # The twelve edges of the parallelepiped spanned by a, b and c.
            segments = [
                (origin, a), (origin, b), (a, a + b), (b, a + b),
                (c, c + a), (c, c + b),
                (c + a, c + a + b), (c + b, c + a + b),
                (origin, c), (a, a + c), (b, b + c),
                (a + b, a + b + c),
            ]
            for p_start, p_end in segments:
                viewer.addCylinder({
                    'start': {'x': p_start[0], 'y': p_start[1], 'z': p_start[2]},
                    'end': {'x': p_end[0], 'y': p_end[1], 'z': p_end[2]},
                    'radius': 0.05, 'color': 'black', 'alpha': 0.7
                })

    viewer.zoomTo()
    viewer.setBackgroundColor('white')
    return viewer
# Module-level state shared between the script flow and the optimizer callback:
# the per-step rows collected so far and the placeholder the table is drawn in.
opt_log = []
table_placeholder = st.empty()


def streamlit_log(opt):
    """Append the current optimizer step to ``opt_log`` and redraw the table.

    Args:
        opt: Optimizer-like object exposing ``atoms`` (with
            ``get_potential_energy()`` / ``get_forces()``) and ``nsteps``.

    Any exception raised while logging is shown as a Streamlit warning
    instead of being propagated.
    """
    global opt_log, table_placeholder
    try:
        energy = opt.atoms.get_potential_energy()
        forces = opt.atoms.get_forces()
        # Largest per-row force norm; 0.0 when there are no force rows.
        if forces.shape[0] > 0:
            fmax_step = np.max(np.linalg.norm(forces, axis=1))
        else:
            fmax_step = 0.0
        row = {
            "Step": opt.nsteps,
            "Energy (eV)": round(energy, 6),
            "Fmax (eV/Å)": round(fmax_step, 6),
        }
        opt_log.append(row)
        table_placeholder.dataframe(pd.DataFrame(opt_log))
    except Exception as e:
        st.warning(f"Error in optimization logger: {e}")
def check_atom_limit(atoms_obj, selected_model):
    """Check that a structure does not exceed the atom limit for a model.

    Args:
        atoms_obj: Structure supporting ``len()``, or ``None``.
        selected_model: Display name of the chosen model; names containing
            ``'UMA'`` or ``'ESEN MD'`` use ``MAX_ATOMS_CLOUD_UMA``, all others
            use ``MAX_ATOMS_CLOUD``.

    Returns:
        ``True`` when there is no structure or it is within the limit;
        ``False`` (after showing a Streamlit error) when it is too large.
    """
    if atoms_obj is None:
        return True

    num_atoms = len(atoms_obj)
    if 'UMA' in selected_model or 'ESEN MD' in selected_model:
        limit = MAX_ATOMS_CLOUD_UMA
    else:
        limit = MAX_ATOMS_CLOUD

    if num_atoms <= limit:
        return True

    st.error(f"⚠️ Error: Your structure contains {num_atoms} atoms, exceeding the {limit} atom limit for this model on Streamlit Cloud. Please run locally for larger systems.")
    return False
# Display name -> download URL of each selectable MACE foundation model.
# The URL is passed as the `model` argument of `mace_mp` in get_mace_model.
MACE_MODELS = {
    "MACE MPA Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model",
    "MACE OMAT Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_omat_0/mace-omat-0-medium.model",
    "MACE MATPES r2SCAN Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-r2scan-omat-ft.model",
    "MACE MATPES PBE Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-pbe-omat-ft.model",
    "MACE MP 0a Small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model",
    "MACE MP 0a Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model",
    "MACE MP 0a Large": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2024-01-07-mace-128-L2_epoch-199.model",
    "MACE MP 0b Small": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b/mace_agnesi_small.model",
    "MACE MP 0b Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b/mace_agnesi_medium.model",
    "MACE MP 0b2 Small": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b2/mace-small-density-agnesi-stress.model",  # author's note: file name was corrected relative to an earlier version
    "MACE MP 0b2 Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b2/mace-medium-density-agnesi-stress.model",
    "MACE MP 0b2 Large": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b2/mace-large-density-agnesi-stress.model",
    "MACE MP 0b3 Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b3/mace-mp-0b3-medium.model",
}

# Display name -> FairChem model identifier passed to
# `pretrained_mlip.get_predict_unit` in get_fairchem_model.
FAIRCHEM_MODELS = {
    "UMA Small": "uma-sm",
    "ESEN MD Direct All OMOL": "esen-md-direct-all-omol",
    "ESEN SM Conserving All OMOL": "esen-sm-conserving-all-omol",
    "ESEN SM Direct All OMOL": "esen-sm-direct-all-omol"
}
@st.cache_resource
def get_mace_model(model_path, device, selected_default_dtype):
    """Build a MACE calculator; the result is cached by ``st.cache_resource``.

    Args:
        model_path: Value forwarded as ``model`` to ``mace_mp``.
        device: Value forwarded as ``device`` to ``mace_mp``.
        selected_default_dtype: Value forwarded as ``default_dtype``.

    Returns:
        The calculator object returned by ``mace_mp``.
    """
    mace_calculator = mace_mp(
        model=model_path,
        device=device,
        default_dtype=selected_default_dtype,
    )
    return mace_calculator
@st.cache_resource
def get_fairchem_model(selected_model_name, model_path_or_name, device, selected_task_type_fc):
    """Build a FairChem calculator; the result is cached by ``st.cache_resource``.

    Args:
        selected_model_name: Display name of the model; only ``"UMA Small"``
            receives an explicit task name.
        model_path_or_name: Identifier passed to
            ``pretrained_mlip.get_predict_unit``.
        device: Device passed to ``get_predict_unit``.
        selected_task_type_fc: Task name forwarded as ``task_name`` for
            ``"UMA Small"``; ignored for every other model.

    Returns:
        A ``FAIRChemCalculator`` wrapping the loaded predictor.
    """
    predictor = pretrained_mlip.get_predict_unit(model_path_or_name, device=device)
    if selected_model_name != "UMA Small":
        return FAIRChemCalculator(predictor)
    return FAIRChemCalculator(predictor, task_name=selected_task_type_fc)
# ---------------------------------------------------------------------------
# Sidebar: structure input. Exactly one of three input methods is active per
# script run; on success `atoms` holds the structure returned by ase.io.read,
# otherwise it stays None and the main panel shows the welcome message.
# ---------------------------------------------------------------------------
st.sidebar.markdown("## Input Options")
input_method = st.sidebar.radio("Choose Input Method:", ["Select Example", "Upload File", "Paste Content"])
atoms = None
if input_method == "Upload File":
    uploaded_file = st.sidebar.file_uploader("Upload structure file", type=["xyz", "cif", "POSCAR", "mol", "tmol", "vasp", "sdf", "CONTCAR"])
    if uploaded_file:
        # The upload is spooled to a temporary file that keeps the original
        # extension, then read back from that path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
            tmp_file.write(uploaded_file.getvalue())
            tmp_filepath = tmp_file.name
        try:
            atoms = read(tmp_filepath)
            st.sidebar.success(f"Successfully loaded structure with {len(atoms)} atoms!")
        except Exception as e:
            st.sidebar.error(f"Error loading file: {str(e)}")
        finally:
            # delete=False above means the file must be removed by hand.
            # At module level locals() is the module namespace, so this guard
            # checks whether tmp_filepath was ever assigned.
            if 'tmp_filepath' in locals() and os.path.exists(tmp_filepath):
                os.unlink(tmp_filepath)
elif input_method == "Select Example":
    example_name = st.sidebar.selectbox("Select Example Structure:", list(SAMPLE_STRUCTURES.keys()))
    if example_name:
        file_path = os.path.join(SAMPLE_DIR, SAMPLE_STRUCTURES[example_name])
        try:
            atoms = read(file_path)
            st.sidebar.success(f"Loaded {example_name} with {len(atoms)} atoms!")
        except Exception as e:
            st.sidebar.error(f"Error loading example: {str(e)}")
elif input_method == "Paste Content":
    file_format = st.sidebar.selectbox("File Format:", ["XYZ", "CIF", "extXYZ", "POSCAR (VASP)", "Turbomole", "MOL"])
    content = st.sidebar.text_area("Paste file content here:", height=200)
    if content:
        try:
            # The chosen format only selects the temp-file suffix; unknown
            # choices fall back to ".xyz".
            suffix_map = {"XYZ": ".xyz", "CIF": ".cif", "extXYZ": ".extxyz", "POSCAR (VASP)": ".vasp", "Turbomole": ".tmol", "MOL": ".mol"}
            suffix = suffix_map.get(file_format, ".xyz")
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
                tmp_file.write(content.encode())
                tmp_filepath = tmp_file.name
            atoms = read(tmp_filepath)
            st.sidebar.success(f"Successfully parsed structure with {len(atoms)} atoms!")
        except Exception as e:
            st.sidebar.error(f"Error parsing content: {str(e)}")
        finally:
            # Same manual cleanup as in the upload branch.
            if 'tmp_filepath' in locals() and os.path.exists(tmp_filepath):
                os.unlink(tmp_filepath)

# Make sure charge and spin are always present in atoms.info, keeping any
# values that came with the file.
if atoms is not None:
    if not hasattr(atoms, 'info'):
        atoms.info = {}
    atoms.info["charge"] = atoms.info.get("charge", 0)  # Default charge
    atoms.info["spin"] = atoms.info.get("spin", 0)  # Default spin (author's note: usually 2S for ASE, model might want 2S+1)
# ---------------------------------------------------------------------------
# Sidebar: model and device selection. Sets `model_type`, `selected_model`,
# `model_path`, `device`, and (depending on the branch) `selected_default_dtype`
# or `selected_task_type`, which the calculation section reads later.
# ---------------------------------------------------------------------------
st.sidebar.markdown("## Model Selection")
model_type = st.sidebar.radio("Select Model Type:", ["MACE", "FairChem"])
selected_task_type = None  # Only set for FairChem "UMA Small"
if model_type == "MACE":
    selected_model = st.sidebar.selectbox("Select MACE Model:", list(MACE_MODELS.keys()))
    model_path = MACE_MODELS[selected_model]
    if selected_model == "MACE OMAT Medium":
        st.sidebar.warning("Using model under Academic Software License (ASL).")
    # Only defined in the MACE branch; it is only read when model_type == "MACE".
    selected_default_dtype = st.sidebar.selectbox("Select Precision (default_dtype):", ['float32', 'float64'])
if model_type == "FairChem":
    selected_model = st.sidebar.selectbox("Select FairChem Model:", list(FAIRCHEM_MODELS.keys()))
    model_path = FAIRCHEM_MODELS[selected_model]
    if selected_model == "UMA Small":
        st.sidebar.warning("Meta FAIR Acceptable Use Policy applies.")
        selected_task_type = st.sidebar.selectbox("Select UMA Model Task Type:", ["omol", "omat", "omc", "odac", "oc20"])
        if selected_task_type == "omol" and atoms is not None:
            # For the omol task the user can override charge and spin.
            charge = st.sidebar.number_input("Total Charge", min_value=-10, max_value=10, value=atoms.info.get("charge",0))
            # The default value converts the stored spin S to a multiplicity 2S + 1.
            # NOTE(review): step=2 starting from 1 steps through odd
            # multiplicities only — confirm this is intended.
            spin_multiplicity = st.sidebar.number_input("Spin Multiplicity (2S + 1)", min_value=1, max_value=11, step=2, value=int(atoms.info.get("spin",0)*2+1 if atoms.info.get("spin",0) is not None else 1))  # Assuming spin in atoms.info is S
            atoms.info["charge"] = charge
            atoms.info["spin"] = spin_multiplicity  # Author's note: FairChem expects multiplicity
if atoms is not None:
    if not check_atom_limit(atoms, selected_model):
        st.stop()  # Stop execution if limit exceeded

# The radio defaults to the GPU entry when CUDA is available; the choice is
# then mapped to a torch device string, falling back to "cpu" without CUDA.
device = st.sidebar.radio("Computation Device:", ["CPU", "CUDA (GPU)"], index=0 if not torch.cuda.is_available() else 1)
device = "cuda" if device == "CUDA (GPU)" and torch.cuda.is_available() else "cpu"
if device == "cpu" and torch.cuda.is_available():
    st.sidebar.info("GPU is available but CPU was selected.")
elif device == "cpu" and not torch.cuda.is_available():
    st.sidebar.info("No GPU detected. Using CPU.")
# ---------------------------------------------------------------------------
# Sidebar: task selection. `task` is matched by exact string (or by the
# substring "Optimization") in the calculation section.
# ---------------------------------------------------------------------------
st.sidebar.markdown("## Task Selection")
task = st.sidebar.selectbox("Select Calculation Task:",
                            ["Energy Calculation",
                             "Energy + Forces Calculation",
                             "Atomization/Cohesive Energy",
                             "Geometry Optimization",
                             "Cell + Geometry Optimization"])
# Optimizer settings are only shown (and only defined) for the two
# optimization tasks.
if "Optimization" in task:
    st.sidebar.markdown("### Optimization Parameters")
    max_steps = st.sidebar.slider("Maximum Steps:", min_value=10, max_value=200, value=50, step=1)
    fmax = st.sidebar.slider("Convergence Threshold (eV/Å):", min_value=0.001, max_value=0.1, value=0.01, step=0.001, format="%.3f")
    optimizer_type = st.sidebar.selectbox("Optimizer:", ["BFGS", "LBFGS", "FIRE"], index=1)  # Default: LBFGS
# ---------------------------------------------------------------------------
# Main panel. With a structure loaded, the left column shows the 3D preview
# and basic structure information; the right column shows the chosen setup,
# the "Run Calculation" button and, after a run, the results. Without a
# structure only a welcome message is shown.
#
# NOTE(review): the nesting below was reconstructed from a copy of this file
# whose indentation had been stripped; confirm that the run button and the
# results are meant to live inside `col2`.
# ---------------------------------------------------------------------------
if atoms is not None:
    col1, col2 = st.columns(2)
    with col1:
        st.markdown('### Structure Visualization', unsafe_allow_html=True)
        view_3d = get_structure_viz2(atoms, style='stick', show_unit_cell=True, width=400, height=400)
        st.components.v1.html(view_3d._make_html(), width=400, height=400)
        st.markdown("### Structure Information")
        atoms_info = {
            "Number of Atoms": len(atoms),
            "Chemical Formula": atoms.get_chemical_formula(),
            "Periodic Boundary Conditions (PBC)": atoms.pbc.tolist(),
            "Cell Dimensions": np.round(atoms.cell.cellpar(),3).tolist() if atoms.pbc.any() and atoms.cell is not None and atoms.cell.any() else "No cell / Non-periodic",
            "Atom Types": ", ".join(sorted(list(set(atoms.get_chemical_symbols()))))
        }
        for key, value in atoms_info.items():
            st.write(f"**{key}:** {value}")
    with col2:
        st.markdown('## Calculation Setup', unsafe_allow_html=True)
        st.markdown("### Selected Model")
        st.write(f"**Model Type:** {model_type}")
        st.write(f"**Model:** {selected_model}")
        if model_type == "FairChem" and selected_model == "UMA Small":
            st.write(f"**UMA Task Type:** {selected_task_type}")
        st.write(f"**Device:** {device}")
        st.markdown("### Selected Task")
        st.write(f"**Task:** {task}")
        if "Optimization" in task:
            st.write(f"**Max Steps:** {max_steps}")
            st.write(f"**Convergence Threshold:** {fmax} eV/Å")
            st.write(f"**Optimizer:** {optimizer_type}")
        run_calculation = st.button("Run Calculation", type="primary")
        if run_calculation:
            results = {}
            # Both names are module-level state read by streamlit_log.
            opt_log = []  # Reset log for each run
            if "Optimization" in task:
                table_placeholder = st.empty()  # Fresh placeholder for the step table
            try:
                with st.spinner("Running calculation... Please wait."):
                    # Work on a copy so the displayed input structure is untouched.
                    calc_atoms = atoms.copy()
                    if model_type == "MACE":
                        calc = get_mace_model(model_path, device, selected_default_dtype)
                    else:  # FairChem
                        # Author's workaround for potential dtype issues when
                        # switching models: reset torch's default dtype and
                        # load a small MACE model before the FairChem one.
                        if device == "cpu":
                            torch.set_default_dtype(torch.float32)
                        _ = get_mace_model(MACE_MODELS["MACE MP 0a Small"], 'cpu', 'float32')  # Dummy call
                        calc = get_fairchem_model(selected_model, model_path, device, selected_task_type)
                    calc_atoms.calc = calc

                    if task == "Energy Calculation":
                        energy = calc_atoms.get_potential_energy()
                        results["Energy"] = f"{energy:.6f} eV"
                    elif task == "Energy + Forces Calculation":
                        energy = calc_atoms.get_potential_energy()
                        forces = calc_atoms.get_forces()
                        max_force = np.max(np.sqrt(np.sum(forces**2, axis=1))) if forces.shape[0] > 0 else 0.0
                        results["Energy"] = f"{energy:.6f} eV"
                        results["Maximum Force"] = f"{max_force:.6f} eV/Å"
                    elif task == "Atomization/Cohesive Energy":
                        st.write("Calculating system energy...")
                        E_system = calc_atoms.get_potential_energy()
                        num_atoms = len(calc_atoms)
                        if num_atoms == 0:
                            st.error("Cannot calculate atomization/cohesive energy for a system with zero atoms.")
                            results["Error"] = "System has no atoms."
                        else:
                            atomic_numbers = calc_atoms.get_atomic_numbers()
                            E_isolated_atoms_total = 0.0
                            calculation_possible = True
                            if model_type == "FairChem":
                                # FairChem: isolated-atom energies come from the
                                # tabulated ELEMENT_REF_ENERGIES lists.
                                st.write("Fetching FairChem reference energies for isolated atoms...")
                                ref_key_suffix = "_elem_refs"
                                chosen_ref_list_name = None
                                if selected_model == "UMA Small":
                                    if selected_task_type:
                                        chosen_ref_list_name = selected_task_type + ref_key_suffix
                                elif "ESEN" in selected_model:
                                    chosen_ref_list_name = "omol" + ref_key_suffix
                                if chosen_ref_list_name and chosen_ref_list_name in ELEMENT_REF_ENERGIES:
                                    ref_energies = ELEMENT_REF_ENERGIES[chosen_ref_list_name]
                                    missing_Z_refs = []
                                    for Z_val in atomic_numbers:
                                        if Z_val > 0 and Z_val < len(ref_energies):
                                            E_isolated_atoms_total += ref_energies[Z_val]
                                        else:
                                            # Out-of-table elements contribute 0 and are reported once.
                                            if Z_val not in missing_Z_refs:
                                                missing_Z_refs.append(Z_val)
                                    if missing_Z_refs:
                                        st.warning(f"Reference energy for atomic number(s) {sorted(list(set(missing_Z_refs)))} "
                                                   f"not found in '{chosen_ref_list_name}' list (max Z defined: {len(ref_energies)-1}). "
                                                   "These atoms are treated as having 0 reference energy.")
                                else:
                                    st.error(f"Could not find or determine reference energy list for FairChem model: '{selected_model}' "
                                             f"and UMA task type: '{selected_task_type}'. Cannot calculate atomization/cohesive energy.")
                                    results["Error"] = "Missing FairChem reference energies."
                                    calculation_possible = False
                            elif model_type == "MACE":
                                # MACE: evaluate one isolated atom per element with
                                # the same calculator and weight it by its count.
                                st.write("Calculating isolated atom energies with MACE...")
                                unique_atomic_numbers = sorted(list(set(atomic_numbers)))
                                atom_counts = {Z_unique: np.count_nonzero(atomic_numbers == Z_unique) for Z_unique in unique_atomic_numbers}
                                progress_text = "Calculating isolated atom energies: 0% complete"
                                mace_progress_bar = st.progress(0, text=progress_text)
                                for i, Z_unique in enumerate(unique_atomic_numbers):
                                    isolated_atom = Atoms(numbers=[Z_unique], cell=[20, 20, 20], pbc=False)
                                    if not hasattr(isolated_atom, 'info'):
                                        isolated_atom.info = {}
                                    isolated_atom.info["charge"] = 0
                                    isolated_atom.info["spin"] = 0
                                    isolated_atom.calc = calc  # Use the same MACE calculator
                                    E_isolated_atom_type = isolated_atom.get_potential_energy()
                                    E_isolated_atoms_total += E_isolated_atom_type * atom_counts[Z_unique]
                                    progress_val = (i + 1) / len(unique_atomic_numbers)
                                    mace_progress_bar.progress(progress_val, text=f"Calculating isolated atom energies for Z={Z_unique}: {int(progress_val*100)}% complete")
                                mace_progress_bar.empty()
                            if calculation_possible:
                                # Periodic systems report a per-atom cohesive energy,
                                # non-periodic ones a total atomization energy.
                                is_periodic = any(calc_atoms.pbc)
                                if is_periodic:
                                    cohesive_E = (E_isolated_atoms_total - E_system) / num_atoms
                                    results["Cohesive Energy"] = f"{cohesive_E:.6f} eV/atom"
                                else:
                                    atomization_E = E_isolated_atoms_total - E_system
                                    results["Atomization Energy"] = f"{atomization_E:.6f} eV"
                                results["System Energy ($E_{system}$)"] = f"{E_system:.6f} eV"
                                # The backslash is escaped explicitly: "\s" is not a
                                # valid escape sequence. The resulting key is unchanged.
                                results["Total Isolated Atom Energy ($\\sum E_{atoms}$)"] = f"{E_isolated_atoms_total:.6f} eV"
                    elif "Optimization" in task:  # Handles both Geometry and Cell+Geometry Opt
                        # Wrapping in FrechetCellFilter lets the optimizer relax the cell too.
                        opt_atoms_obj = FrechetCellFilter(calc_atoms) if task == "Cell + Geometry Optimization" else calc_atoms
                        if optimizer_type == "BFGS":
                            opt = BFGS(opt_atoms_obj)
                        elif optimizer_type == "LBFGS":
                            opt = LBFGS(opt_atoms_obj)
                        else:  # FIRE
                            opt = FIRE(opt_atoms_obj)
                        # Log every step into the live table.
                        opt.attach(lambda: streamlit_log(opt), interval=1)
                        st.write(f"Running {task.lower()}...")
                        opt.run(fmax=fmax, steps=max_steps)
                        energy = calc_atoms.get_potential_energy()
                        forces = calc_atoms.get_forces()
                        max_force = np.max(np.sqrt(np.sum(forces**2, axis=1))) if forces.shape[0] > 0 else 0.0
                        results["Final Energy"] = f"{energy:.6f} eV"
                        results["Final Maximum Force"] = f"{max_force:.6f} eV/Å"
                        results["Steps Taken"] = opt.get_number_of_steps()
                        results["Converged"] = "Yes" if opt.converged() else "No"
                        if task == "Cell + Geometry Optimization":
                            results["Final Cell Parameters"] = np.round(calc_atoms.cell.cellpar(), 4).tolist()

                st.success("Calculation completed successfully!")
                st.markdown("### Results")
                for key, value in results.items():
                    st.write(f"**{key}:** {value}")

                if "Optimization" in task and "Final Energy" in results:  # Check if opt was successful
                    st.markdown("### Optimized Structure")
                    opt_view = get_structure_viz2(calc_atoms, style='stick', show_unit_cell=True, width=400, height=400)
                    st.components.v1.html(opt_view._make_html(), width=400, height=400)
                    # Round-trip through a temporary file to obtain the XYZ
                    # text for the download button, then remove the file.
                    with tempfile.NamedTemporaryFile(delete=False, suffix=".xyz", mode="w+") as tmp_file_opt:
                        write(tmp_file_opt.name, calc_atoms, format="xyz")
                        tmp_filepath_opt = tmp_file_opt.name
                    with open(tmp_filepath_opt, 'r') as file_opt:
                        xyz_content_opt = file_opt.read()
                    st.download_button(
                        label="Download Optimized Structure (XYZ)",
                        data=xyz_content_opt,
                        file_name="optimized_structure.xyz",
                        mime="chemical/x-xyz"
                    )
                    os.unlink(tmp_filepath_opt)
            except Exception as e:
                # Top-level boundary: show the error and the traceback in the UI.
                st.error(f"🔴 Calculation error: {str(e)}")
                st.error("Please check the structure, model compatibility, and parameters. For FairChem UMA, ensure the task type (omol, omat etc.) is appropriate for your system (e.g. omol for molecules, omat for materials).")
                import traceback
                st.error(f"Traceback: {traceback.format_exc()}")
else:
    st.info("👋 Welcome! Please select or upload a structure using the sidebar options to begin.")
# ---------------------------------------------------------------------------
# Footer: collapsible "about" text and credits.
# ---------------------------------------------------------------------------
st.markdown("---")
with st.expander('ℹ️ About This App & Foundational MLIPs'):
    # Raw string: the text contains LaTeX commands such as \text and \sum.
    # In a normal string "\t" would turn into a TAB and break the formulas.
    st.write(r"""
**Test, compare, and benchmark universal machine learning interatomic potentials (MLIPs).**
This application allows you to perform atomistic simulations using pre-trained foundational MLIPs
from the MACE and FairChem (by Meta AI) libraries.
**Features:**
- Upload structure files (XYZ, CIF, POSCAR, etc.) or use built-in examples.
- Select from various MACE and FairChem models.
- Calculate energies, forces, and perform geometry/cell optimizations.
- **New**: Calculate atomization energy (for molecules) or cohesive energy (for periodic systems).
- Visualize atomic structures in 3D and download results.
**Quick Start:**
1. **Input**: Choose an input method in the sidebar (e.g., "Select Example").
2. **Model**: Pick a model type (MACE/FairChem) and specific model. For FairChem UMA, select the appropriate task type (e.g., `omol` for molecules, `omat` for materials).
3. **Task**: Select a calculation task (e.g., "Energy Calculation", "Atomization/Cohesive Energy", "Geometry Optimization").
4. **Run**: Click "Run Calculation" and view the results.
**Atomization/Cohesive Energy Notes:**
- **Atomization Energy** ($E_{\text{atomization}} = \sum E_{\text{isolated atoms}} - E_{\text{molecule}}$) is typically for non-periodic systems (molecules).
- **Cohesive Energy** ($E_{\text{cohesive}} = (\sum E_{\text{isolated atoms}} - E_{\text{bulk system}}) / N_{\text{atoms}}$) is for periodic systems.
- For **MACE models**, isolated atom energies are computed on-the-fly.
- For **FairChem models**, isolated atom energies are based on pre-tabulated reference values (provided in a YAML-like structure within the app). Ensure the selected FairChem task type (`omol`, `omat`, etc. for UMA models) or model type (ESEN models use `omol` references) aligns with the system and has the necessary elemental references.
""")
st.markdown("Universal MLIP Playground App | Created with Streamlit, ASE, MACE, FairChem and ❤️")
st.markdown("Developed by [Manas Sharma](https://manas.bragitoff.com/) ([Fundamental AI Research (FAIR) team, Meta AI](https://ai.meta.com/research/fair/) and [Ananth Govind Rajan Group, IISc Bangalore](https://www.agrgroup.org/))")