Spaces:
Running
Running
import streamlit as st | |
import os | |
import tempfile | |
import torch | |
import numpy as np | |
from ase import Atoms | |
from ase.io import read, write | |
from ase.optimize import BFGS, LBFGS, FIRE | |
from ase.constraints import FixAtoms | |
from ase.filters import FrechetCellFilter | |
from ase.visualize import view | |
import py3Dmol | |
from mace.calculators import mace_mp | |
from fairchem.core import pretrained_mlip, FAIRChemCalculator | |
import pandas as pd | |
import yaml # Added for FairChem reference energies | |
from huggingface_hub import login | |
# Optional Hugging Face Hub login. Without a token, some models might not be
# accessible (see the message printed below). A token can alternatively be
# managed via Streamlit secrets.
try:
    # Honour the legacy placeholder variable first (backward compatibility),
    # then fall back to the conventional HF_TOKEN environment variable.
    hf_token = os.getenv("YOUR SECRET KEY") or os.getenv("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
    else:
        print("Hugging Face token not found. Some models might not be accessible.")
except Exception as e:
    # Best effort: a failed login must not stop the app from starting.
    print(f"hf login error: {e}")
# NOTE(review): Streamlit reads its file-watcher option (`server.fileWatcherType`,
# env var STREAMLIT_SERVER_FILE_WATCHER_TYPE) when the server starts, so setting
# this variable from inside the running script is likely ineffective — confirm,
# and prefer .streamlit/config.toml.
os.environ["STREAMLIT_WATCHER_TYPE"] = "none"
# Per-element reference energies for the FairChem models, stored as YAML text.
# Keys are "<task>_elem_refs" for the UMA task types offered in the sidebar
# (oc20, odac, omat, omol, omc). Each list holds 100 entries indexed by atomic
# number Z (entry 0 is a placeholder for Z = 0). The Atomization/Cohesive Energy
# task sums ref_energies[Z] over the atoms of the structure.
# The values are combined with model energies that the app reports in eV, so
# they are presumably in eV — NOTE(review): confirm against upstream fairchem.
# NOTE(review): omc_elem_refs[97] repeats omc_elem_refs[4] (-0.03842519); this
# looks like a possible copy error — verify against the upstream reference list.
ELEMENT_REF_ENERGIES_YAML = """
oc20_elem_refs:
- 0.0
- -0.16141512
- 0.03262098
- -0.04787699
- -0.06299825
- -0.14979306
- -0.11657468
- -0.10862579
- -0.10298174
- -0.03420248
- 0.02673997
- -0.03729558
- 0.00515243
- -0.07535697
- -0.13663351
- -0.12922852
- -0.11796547
- -0.07802946
- -0.00672682
- -0.04089589
- -0.00024177
- -1.74545186
- -1.54220241
- -1.0934019
- -1.16168372
- -1.23073475
- -0.78852824
- -0.71851599
- -0.52465053
- -0.02692092
- -0.00317922
- -0.06266862
- -0.10835274
- -0.12394474
- -0.11351727
- -0.07455817
- -0.00258354
- -0.04111325
- -0.02090265
- -1.89306078
- -1.30591887
- -0.63320009
- -0.26230344
- -0.2633669
- -0.5160055
- -0.95950798
- -1.45589361
- -0.0429969
- -0.00026949
- -0.05925609
- -0.09734631
- -0.12406852
- -0.11427538
- -0.07021442
- 0.01091345
- -0.05305289
- -0.02427209
- -0.19975668
- -1.71692859
- -1.53677781
- -3.89987009
- -10.70940462
- -6.71693816
- -0.28102249
- -8.86944824
- -7.95762687
- -7.13041437
- -6.64620014
- -5.11482482
- -4.42548227
- 0.00848295
- -0.06956227
- -2.6748853
- -2.21153293
- -1.67367741
- -1.07636151
- -0.79009981
- -0.16387243
- -0.18164401
- -0.04122529
- -0.00041833
- -0.05259382
- -0.0934314
- -0.11023834
- -0.10039175
- -0.06069209
- 0.01790437
- -0.04694024
- 0.00334084
- -0.06030621
- -0.58793619
- -1.27821808
- -4.97483577
- -5.66985655
- -8.43154622
- -11.15001317
- -12.95770812
- 0.0
- -14.47602729
- 0.0
odac_elem_refs:
- 0.0
- -1.11737936
- -0.00011835
- -0.2941727
- -0.03868426
- -0.34862832
- -1.31552566
- -3.12457285
- -1.6052078
- -0.49653389
- -0.01137327
- -0.21957281
- -0.0008343
- -0.2750172
- -0.88417265
- -1.887378
- -0.94903558
- -0.31628167
- -0.02014536
- -0.15901053
- -0.00731884
- -1.96521355
- -1.89045209
- -2.53057428
- -5.43600675
- -5.09739336
- -3.03088746
- -1.23786562
- -0.40650749
- -0.2416017
- -0.01139188
- -0.26282496
- -0.82446455
- -1.70237206
- -0.84245376
- -0.28544892
- -0.02239991
- -0.14115912
- -0.02840799
- -2.09540994
- -1.85863996
- -1.12257399
- -4.32965355
- -3.30670045
- -1.19460755
- -1.26257601
- -1.46832888
- -0.19779414
- -0.0144274
- -0.23668767
- -0.70836953
- -1.43186113
- -0.71701186
- -0.24883129
- -0.01118184
- -0.13173447
- -0.0318395
- -0.41195547
- -1.23134873
- -2.03082996
- 0.1375954
- -5.45866275
- -7.59139905
- -5.99965965
- -8.43495767
- -2.6578407
- -7.77349787
- -5.30762201
- -5.15109657
- -4.41466995
- -0.02995219
- -0.2544495
- -3.23821202
- -3.45887214
- -4.53635003
- -4.60979468
- -2.90707964
- -1.28286153
- -0.57716664
- -0.18337108
- -0.01135944
- -0.22045398
- -0.66150479
- -1.32506342
- -0.66500178
- -0.22643927
- -0.00728197
- -0.11208472
- -0.00757856
- -0.21798637
- -0.91078787
- -1.78187161
- -3.89912261
- -3.94192659
- -7.59026042
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
omat_elem_refs:
- 0.0
- -1.11700253
- 0.00079886
- -0.29731164
- -0.04129868
- -0.29106192
- -1.27751531
- -3.12342715
- -1.54797136
- -0.43969356
- -0.01250908
- -0.22855413
- -0.00943179
- -0.21707638
- -0.82619133
- -1.88667434
- -0.89093583
- -0.25816211
- -0.02414768
- -0.17662425
- -0.02568319
- -2.13001165
- -2.38688845
- -3.55934233
- -5.44700879
- -5.14749562
- -3.30662847
- -1.42167737
- -0.63181379
- -0.23449167
- -0.01146636
- -0.21291259
- -0.77939897
- -1.70148487
- -0.78386705
- -0.22690657
- -0.02245409
- -0.16092396
- -0.02798717
- -2.25685695
- -2.23690495
- -2.15347771
- -4.60251809
- -3.36416792
- -2.23062607
- -1.15550917
- -1.47553527
- -0.19918102
- -0.01475888
- -0.19767692
- -0.68005773
- -1.43073368
- -0.65790462
- -0.18915279
- -0.01179476
- -0.13507902
- -0.03056979
- -0.36017439
- -0.86279246
- -0.20573327
- -0.2734463
- -0.20046965
- -0.25444338
- -8.37972664
- -9.58424928
- -0.19466184
- -0.24860115
- -0.19531288
- -0.15401392
- -0.14577898
- -0.19655747
- -0.15645898
- -3.49380556
- -3.5317097
- -4.57108006
- -4.63425205
- -2.88247063
- -1.45679675
- -0.50290184
- -0.18521704
- -0.01123956
- -0.17483649
- -0.63132037
- -1.3248562
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- -0.24135757
- -1.04601971
- -2.04574044
- -3.84544799
- -7.28626119
- -7.3136314
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
omol_elem_refs:
- 0.0
- -13.44558
- -78.82027
- -203.32564
- -398.94742
- -670.75275
- -1029.85403
- -1485.54188
- -2042.97832
- -2714.24015
- -3508.74317
- -4415.24203
- -5443.89712
- -6594.61834
- -7873.6878
- -9285.6593
- -10832.62132
- -12520.66852
- -14354.278
- -16323.54671
- -18436.47845
- -20696.18244
- -23110.5386
- -25682.99429
- -28418.37804
- -31317.92317
- -34383.42519
- -37623.46835
- -41039.92413
- -44637.38634
- -48417.14864
- -52373.87849
- -56512.76952
- -60836.14871
- -65344.28833
- -70041.24251
- -74929.56277
- -653.64777
- -833.31922
- -1038.0281
- -1273.96788
- -1542.45481
- -1850.74158
- -2193.91654
- -2577.18734
- -3004.13604
- -3477.52796
- -3997.31825
- -4563.75804
- -5171.82293
- -5828.85334
- -6535.61529
- -7291.54792
- -8099.87914
- -8962.17916
- -546.03214
- -690.6089
- -854.11237
- -12923.04096
- -14064.26124
- -15272.68689
- -16550.20551
- -17900.36515
- -19323.23406
- -20829.08848
- -22428.73258
- -24078.68008
- -25794.42097
- -27616.6819
- -29523.5526
- -31526.68012
- -33615.37779
- -1300.17791
- -1544.40924
- -1818.62298
- -2123.14417
- -2461.76028
- -2833.76287
- -3242.79895
- -3690.363
- -4174.99772
- -4691.75674
- -5245.36013
- -5838.12005
- -6469.07296
- -7140.86455
- -7854.60638
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
omc_elem_refs:
- 0.0
- -0.02831808
- 4.512e-05
- -0.03227157
- -0.03842519
- -0.05829283
- -0.0845041
- -0.08806738
- -0.09021346
- -0.06669846
- -0.01218631
- -0.03650269
- -0.00059093
- -0.05787736
- -0.08730952
- -0.0975534
- -0.09264199
- -0.07124762
- -0.02374602
- -0.05299112
- -0.02631476
- -1.7772147
- -1.25083444
- -0.79579447
- -0.49099317
- -0.31414986
- -0.20292182
- -0.14011632
- -0.09929659
- -0.03771207
- -0.01117902
- -0.06168715
- -0.08873364
- -0.09512942
- -0.09035978
- -0.06910849
- -0.02244872
- -0.05303651
- -0.02871903
- -1.94805417
- -1.33379896
- -0.69169331
- -0.26184306
- -0.20631599
- -0.48251608
- -0.96911893
- -1.47569462
- -0.03845194
- -0.0142445
- -0.07118991
- -0.09940292
- -0.09235056
- -0.08755943
- -0.06544925
- -0.01246646
- -0.04692937
- -0.03225123
- -0.26086039
- -27.20024339
- -0.08412926
- -0.08225924
- -0.07799715
- -0.07806185
- 0.00043759
- -0.07459766
- 0.0
- -0.06842841
- -0.07758266
- -0.07025152
- -0.08055003
- -0.07118177
- -0.07159568
- -2.69202862
- -2.21926765
- -1.679756
- -1.06135075
- -0.4554231
- -0.14488432
- -0.18377098
- -0.03603118
- -0.01076585
- -0.06381411
- -0.0905623
- -0.10095787
- -0.09501217
- -0.0574478
- -0.00599173
- -0.04134751
- -0.0082683
- -0.08704692
- -0.49656425
- -5.24233138
- -2.32542606
- -4.3376616
- -5.96430676
- 0.0
- 0.0
- -0.03842519
- 0.0
- 0.0
"""
# Parse the reference table once at import time with the safe YAML loader.
# On a parse error, fall back to an empty mapping so the app still starts. The
# FairChem atomization/cohesive task then reports that the list is missing.
try:
    ELEMENT_REF_ENERGIES = yaml.safe_load(ELEMENT_REF_ENERGIES_YAML)
except yaml.YAMLError as e:
    # st.error(f"Error parsing YAML reference energies: {e}") # st objects can only be used in main script flow
    print(f"Error parsing YAML reference energies: {e}")
    ELEMENT_REF_ENERGIES = {} # Fallback
# Deployment detection: True only when STREAMLIT_RUNTIME_ENV is exactly "cloud".
# NOTE(review): nothing in the code visible here reads is_streamlit_cloud;
# check_atom_limit enforces the caps below unconditionally.
is_streamlit_cloud = os.getenv('STREAMLIT_RUNTIME_ENV') == 'cloud'
# Atom-count caps used by check_atom_limit: UMA / ESEN MD models get their own
# (currently identical) cap, every other model uses MAX_ATOMS_CLOUD.
MAX_ATOMS_CLOUD_UMA = 500
MAX_ATOMS_CLOUD = 500  # Maximum atoms allowed on Streamlit Cloud
# Page-level Streamlit configuration. Streamlit has historically required
# set_page_config to run before any other st.* rendering call.
# NOTE(review): newer Streamlit releases relax this; confirm for the pinned version.
st.set_page_config(
    page_title="Molecular Structure Analysis",
    page_icon="🧪",
    layout="wide"
)
# Title and description shown at the top of the main page (rendered in this order).
st.markdown('## MLIP Playground', unsafe_allow_html=True)
st.write('#### Run, test and compare >17 state-of-the-art universal machine learning interatomic potentials (MLIPs) for atomistic simulations of molecules and materials')
st.markdown('Upload molecular structure files or select from predefined examples, then compute energies and forces using foundation models such as those from MACE or FairChem (Meta).', unsafe_allow_html=True)
# Folder (relative to the working directory) holding the bundled example files;
# created up front so that the os.path.join lookups below have a valid parent.
SAMPLE_DIR = "sample_structures"
os.makedirs(SAMPLE_DIR, exist_ok=True)
# Example name shown in the sidebar -> file name inside SAMPLE_DIR (order kept for the UI).
SAMPLE_STRUCTURES = dict(
    Water="H2O.xyz",
    Methane="CH4.xyz",
    Benzene="C6H6.xyz",
    Ethane="C2H6.xyz",
    Caffeine="caffeine.xyz",
    Ibuprofen="ibuprofen.xyz",
)
def get_structure_viz2(atoms_obj, style='stick', show_unit_cell=True, width=400, height=400):
    """Build a py3Dmol viewer for a structure, optionally outlining its unit cell.

    Args:
        atoms_obj: ASE ``Atoms`` (anything supporting ``len()``, iteration over
            atoms exposing ``symbol`` / ``position``, a ``pbc`` array and
            ``get_cell()``).
        style: ``'ball_stick'``, ``'stick'`` or ``'ball'`` (case-insensitive);
            any other value falls back to thin sticks.
        show_unit_cell: Draw the 12 cell edges as cylinders when the structure is
            periodic along at least one direction and the cell is non-zero.
        width: Viewer width in pixels.
        height: Viewer height in pixels.

    Returns:
        The configured ``py3Dmol.view`` instance.
    """
    # Serialize to XYZ text (count line, comment line, one line per atom) in a
    # single join instead of growing a string with += inside the loop.
    xyz_lines = [str(len(atoms_obj)), "Structure"]
    xyz_lines.extend(
        f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}"
        for atom in atoms_obj
    )
    xyz_str = "\n".join(xyz_lines) + "\n"

    # Named `viewer` so it does not shadow the module-level ase.visualize.view import.
    viewer = py3Dmol.view(width=width, height=height)
    viewer.addModel(xyz_str, "xyz")

    style_specs = {
        'ball_stick': {'stick': {'radius': 0.1}, 'sphere': {'scale': 0.3}},
        'stick': {'stick': {}},
        'ball': {'sphere': {'scale': 0.4}},
    }
    viewer.setStyle(style_specs.get(style.lower(), {'stick': {'radius': 0.15}}))

    if show_unit_cell and atoms_obj.pbc.any():
        cell = atoms_obj.get_cell()
        if cell is not None and cell.any():  # skip an all-zero (undefined) cell
            origin = np.array([0.0, 0.0, 0.0])
            a, b, c = cell[0], cell[1], cell[2]
            # The 12 edges of the parallelepiped spanned by the lattice vectors a, b, c.
            edges = [
                (origin, a), (origin, b), (a, a + b), (b, a + b),
                (c, c + a), (c, c + b),
                (c + a, c + a + b), (c + b, c + a + b),
                (origin, c), (a, a + c), (b, b + c),
                (a + b, a + b + c),
            ]

            def _point(vec):
                # Plain floats keep the spec JSON-friendly for py3Dmol.
                return {'x': float(vec[0]), 'y': float(vec[1]), 'z': float(vec[2])}

            for start, end in edges:
                viewer.addCylinder({
                    'start': _point(start),
                    'end': _point(end),
                    'radius': 0.05, 'color': 'black', 'alpha': 0.7,
                })

    viewer.zoomTo()
    viewer.setBackgroundColor('white')
    return viewer
# Module-level state read by the optimizer callback `streamlit_log`:
# `opt_log` collects one row per optimizer step, and `table_placeholder` is the
# Streamlit slot the log table is drawn into (st.empty() reserves that slot at
# this point in the page). Both are rebound by the "Run Calculation" handler
# further below: the log on every run, the placeholder for optimization tasks.
opt_log = [] # Define globally or pass around if necessary
table_placeholder = st.empty() # Define globally if updated from callback
def streamlit_log(opt):
    """Record the optimizer's current step and redraw the live log table.

    Intended to be attached to an ASE optimizer so it runs once per step. It
    appends ``{"Step", "Energy (eV)", "Fmax (eV/Å)"}`` to the module-level
    ``opt_log`` list and renders the whole log into ``table_placeholder``.
    Any exception is reported as a Streamlit warning instead of aborting the run.

    Args:
        opt: The optimizer; its ``atoms`` and ``nsteps`` attributes are read.
    """
    try:
        current = opt.atoms
        step_energy = current.get_potential_energy()
        step_forces = current.get_forces()
        # Largest per-atom force norm; 0.0 when there are no force rows.
        largest_force = np.linalg.norm(step_forces, axis=1).max() if step_forces.shape[0] > 0 else 0.0
        row = {
            "Step": opt.nsteps,
            "Energy (eV)": round(step_energy, 6),
            "Fmax (eV/Å)": round(largest_force, 6),
        }
        opt_log.append(row)
        table_placeholder.dataframe(pd.DataFrame(opt_log))
    except Exception as err:
        st.warning(f"Error in optimization logger: {err}")
def check_atom_limit(atoms_obj, selected_model):
    """Return whether the structure fits the atom cap for the chosen model.

    Models whose name contains 'UMA' or 'ESEN MD' use MAX_ATOMS_CLOUD_UMA; all
    others use MAX_ATOMS_CLOUD. A missing structure (None) always passes. When
    the cap is exceeded, an st.error message is shown and False is returned.
    """
    if atoms_obj is None:
        return True
    n_atoms = len(atoms_obj)
    uses_uma_cap = any(tag in selected_model for tag in ('UMA', 'ESEN MD'))
    cap = MAX_ATOMS_CLOUD_UMA if uses_uma_cap else MAX_ATOMS_CLOUD
    if n_atoms <= cap:
        return True
    st.error(f"⚠️ Error: Your structure contains {n_atoms} atoms, exceeding the {cap} atom limit for this model on Streamlit Cloud. Please run locally for larger systems.")
    return False
# MACE foundation models offered in the sidebar: display name -> checkpoint URL.
# The URL is passed to mace_mp(model=...) through get_mace_model.
MACE_MODELS = {
    "MACE MPA Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model",
    "MACE OMAT Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_omat_0/mace-omat-0-medium.model",
    "MACE MATPES r2SCAN Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-r2scan-omat-ft.model",
    "MACE MATPES PBE Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-pbe-omat-ft.model",
    "MACE MP 0a Small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model",
    "MACE MP 0a Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model",
    "MACE MP 0a Large": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2024-01-07-mace-128-L2_epoch-199.model",
    "MACE MP 0b Small": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b/mace_agnesi_small.model",
    "MACE MP 0b Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b/mace_agnesi_medium.model",
    "MACE MP 0b2 Small": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b2/mace-small-density-agnesi-stress.model", # Corrected name from original code
    "MACE MP 0b2 Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b2/mace-medium-density-agnesi-stress.model",
    "MACE MP 0b2 Large": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b2/mace-large-density-agnesi-stress.model",
    "MACE MP 0b3 Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b3/mace-mp-0b3-medium.model",
}
# FairChem models: display name -> pretrained model name passed to
# pretrained_mlip.get_predict_unit via get_fairchem_model. Only "UMA Small" is
# given an explicit task type; the ESEN entries use the "omol" references in the
# Atomization/Cohesive Energy task.
FAIRCHEM_MODELS = {
    "UMA Small": "uma-sm",
    "ESEN MD Direct All OMOL": "esen-md-direct-all-omol",
    "ESEN SM Conserving All OMOL": "esen-sm-conserving-all-omol",
    "ESEN SM Direct All OMOL": "esen-sm-direct-all-omol"
}
def get_mace_model(model_path, device, selected_default_dtype):
    """Create a MACE calculator for a foundation-model checkpoint.

    Args:
        model_path: Checkpoint URL or identifier forwarded as ``mace_mp(model=...)``.
        device: Torch device string, e.g. ``"cpu"`` or ``"cuda"``.
        selected_default_dtype: Precision name, e.g. ``"float32"`` or ``"float64"``.

    Returns:
        The object returned by ``mace_mp`` (attached as an ASE calculator by callers).
    """
    mace_kwargs = dict(model=model_path, device=device, default_dtype=selected_default_dtype)
    return mace_mp(**mace_kwargs)
def get_fairchem_model(selected_model_name, model_path_or_name, device, selected_task_type_fc): # Renamed args to avoid conflict
    """Create a FAIRChem ASE calculator for a pretrained model.

    Args:
        selected_model_name: Sidebar display name; only ``"UMA Small"`` receives
            an explicit ``task_name``.
        model_path_or_name: Pretrained model name for ``get_predict_unit``.
        device: Torch device string, e.g. ``"cpu"`` or ``"cuda"``.
        selected_task_type_fc: UMA task head (e.g. ``"omol"``, ``"omat"``); ignored
            for other models.

    Returns:
        The ``FAIRChemCalculator`` wrapping the loaded predict unit.
    """
    predict_unit = pretrained_mlip.get_predict_unit(model_path_or_name, device=device)
    if selected_model_name != "UMA Small":
        # Other FairChem models are constructed without a task name.
        return FAIRChemCalculator(predict_unit)
    return FAIRChemCalculator(predict_unit, task_name=selected_task_type_fc)
# --- Sidebar: structure input ---------------------------------------------
# Produces `atoms` (an ASE Atoms object) from one of three sources, or leaves it
# None when nothing was loaded or parsing failed (errors are shown in the sidebar).
st.sidebar.markdown("## Input Options")
input_method = st.sidebar.radio("Choose Input Method:", ["Select Example", "Upload File", "Paste Content"])
atoms = None
if input_method == "Upload File":
    uploaded_file = st.sidebar.file_uploader("Upload structure file", type=["xyz", "cif", "POSCAR", "mol", "tmol", "vasp", "sdf", "CONTCAR"])
    if uploaded_file:
        # ase.io.read needs a path; keep the original extension so ASE can infer the format.
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
            tmp_file.write(uploaded_file.getvalue())
            tmp_filepath = tmp_file.name
        try:
            atoms = read(tmp_filepath)
            st.sidebar.success(f"Successfully loaded structure with {len(atoms)} atoms!")
        except Exception as e:
            st.sidebar.error(f"Error loading file: {str(e)}")
        finally:
            # Always remove the temp file (created with delete=False).
            if 'tmp_filepath' in locals() and os.path.exists(tmp_filepath):
                os.unlink(tmp_filepath)
elif input_method == "Select Example":
    example_name = st.sidebar.selectbox("Select Example Structure:", list(SAMPLE_STRUCTURES.keys()))
    if example_name:
        file_path = os.path.join(SAMPLE_DIR, SAMPLE_STRUCTURES[example_name])
        try:
            atoms = read(file_path)
            st.sidebar.success(f"Loaded {example_name} with {len(atoms)} atoms!")
        except Exception as e:
            st.sidebar.error(f"Error loading example: {str(e)}")
elif input_method == "Paste Content":
    file_format = st.sidebar.selectbox("File Format:", ["XYZ", "CIF", "extXYZ", "POSCAR (VASP)", "Turbomole", "MOL"])
    content = st.sidebar.text_area("Paste file content here:", height=200)
    if content:
        try:
            # Write the pasted text to a temp file whose suffix tells ASE the format.
            suffix_map = {"XYZ": ".xyz", "CIF": ".cif", "extXYZ": ".extxyz", "POSCAR (VASP)": ".vasp", "Turbomole": ".tmol", "MOL": ".mol"}
            suffix = suffix_map.get(file_format, ".xyz")
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
                tmp_file.write(content.encode())
                tmp_filepath = tmp_file.name
            atoms = read(tmp_filepath)
            st.sidebar.success(f"Successfully parsed structure with {len(atoms)} atoms!")
        except Exception as e:
            st.sidebar.error(f"Error parsing content: {str(e)}")
        finally:
            # tmp_filepath may be unbound if the temp file could not be created.
            if 'tmp_filepath' in locals() and os.path.exists(tmp_filepath):
                os.unlink(tmp_filepath)
if atoms is not None:
    if not hasattr(atoms, 'info'):
        atoms.info = {}
    atoms.info["charge"] = atoms.info.get("charge", 0) # Default charge
    # Default spin 0. For the UMA "omol" task, the model-selection section reads this
    # value as S and replaces it with the multiplicity 2S+1.
    atoms.info["spin"] = atoms.info.get("spin", 0) # Default spin (usually 2S for ASE, model might want 2S+1)
# --- Sidebar: model selection ---------------------------------------------
# Sets `model_type`, `selected_model` and `model_path`. It also sets
# `selected_default_dtype` (MACE only) and `selected_task_type` (FairChem UMA only).
st.sidebar.markdown("## Model Selection")
model_type = st.sidebar.radio("Select Model Type:", ["MACE", "FairChem"])
selected_task_type = None # For FairChem UMA
if model_type == "MACE":
    selected_model = st.sidebar.selectbox("Select MACE Model:", list(MACE_MODELS.keys()))
    model_path = MACE_MODELS[selected_model]
    if selected_model == "MACE OMAT Medium":
        st.sidebar.warning("Using model under Academic Software License (ASL).")
    selected_default_dtype = st.sidebar.selectbox("Select Precision (default_dtype):", ['float32', 'float64'])
if model_type == "FairChem":
    selected_model = st.sidebar.selectbox("Select FairChem Model:", list(FAIRCHEM_MODELS.keys()))
    model_path = FAIRCHEM_MODELS[selected_model]
    if selected_model == "UMA Small":
        st.sidebar.warning("Meta FAIR Acceptable Use Policy applies.")
        selected_task_type = st.sidebar.selectbox("Select UMA Model Task Type:", ["omol", "omat", "omc", "odac", "oc20"])
        if selected_task_type == "omol" and atoms is not None:
            # The omol task takes total charge and spin multiplicity from atoms.info.
            charge = st.sidebar.number_input("Total Charge", min_value=-10, max_value=10, value=atoms.info.get("charge",0))
            spin_multiplicity = st.sidebar.number_input("Spin Multiplicity (2S + 1)", min_value=1, max_value=11, step=2, value=int(atoms.info.get("spin",0)*2+1 if atoms.info.get("spin",0) is not None else 1)) # Assuming spin in atoms.info is S
            atoms.info["charge"] = charge
            # From here on atoms.info["spin"] holds the multiplicity (2S+1), not S.
            atoms.info["spin"] = spin_multiplicity # FairChem expects multiplicity
# End this script run early when the structure exceeds the atom cap for the
# selected model (check_atom_limit has already shown the error message).
if atoms is not None and not check_atom_limit(atoms, selected_model):
    st.stop()  # Stop execution if limit exceeded
# --- Sidebar: compute device ----------------------------------------------
# Preselect the GPU when CUDA is available. `device` is then normalized to the
# torch device string ("cuda" or "cpu"); choosing CUDA without a GPU falls back
# to "cpu".
device = st.sidebar.radio("Computation Device:", ["CPU", "CUDA (GPU)"], index=1 if torch.cuda.is_available() else 0)
device = "cuda" if (device == "CUDA (GPU)" and torch.cuda.is_available()) else "cpu"
if device == "cpu":
    st.sidebar.info("GPU is available but CPU was selected." if torch.cuda.is_available() else "No GPU detected. Using CPU.")
# --- Sidebar: task selection ----------------------------------------------
# `task` picks the calculation. `max_steps`, `fmax` and `optimizer_type` are only
# defined for the two optimization tasks and are only read in that case.
st.sidebar.markdown("## Task Selection")
task = st.sidebar.selectbox("Select Calculation Task:",
                            ["Energy Calculation",
                             "Energy + Forces Calculation",
                             "Atomization/Cohesive Energy", # New Task Added
                             "Geometry Optimization",
                             "Cell + Geometry Optimization"])
if "Optimization" in task:
    st.sidebar.markdown("### Optimization Parameters")
    max_steps = st.sidebar.slider("Maximum Steps:", min_value=10, max_value=200, value=50, step=1) # Increased max_steps
    fmax = st.sidebar.slider("Convergence Threshold (eV/Å):", min_value=0.001, max_value=0.1, value=0.01, step=0.001, format="%.3f") # Adjusted default fmax
    optimizer_type = st.sidebar.selectbox("Optimizer:", ["BFGS", "LBFGS", "FIRE"], index=1) # Renamed to optimizer_type
import io
import traceback

# --- Main page: structure overview, calculation setup and results ---------
if atoms is not None:
    col1, col2 = st.columns(2)

    with col1:
        # Left column: interactive 3D view plus a summary of the loaded structure.
        st.markdown('### Structure Visualization', unsafe_allow_html=True)
        view_3d = get_structure_viz2(atoms, style='stick', show_unit_cell=True, width=400, height=400)
        st.components.v1.html(view_3d._make_html(), width=400, height=400)
        st.markdown("### Structure Information")
        atoms_info = {
            "Number of Atoms": len(atoms),
            "Chemical Formula": atoms.get_chemical_formula(),
            "Periodic Boundary Conditions (PBC)": atoms.pbc.tolist(),
            "Cell Dimensions": np.round(atoms.cell.cellpar(), 3).tolist() if atoms.pbc.any() and atoms.cell is not None and atoms.cell.any() else "No cell / Non-periodic",
            "Atom Types": ", ".join(sorted(set(atoms.get_chemical_symbols()))),
        }
        for key, value in atoms_info.items():
            st.write(f"**{key}:** {value}")

    with col2:
        # Right column: echo the chosen settings, then run the calculation on demand.
        st.markdown('## Calculation Setup', unsafe_allow_html=True)
        st.markdown("### Selected Model")
        st.write(f"**Model Type:** {model_type}")
        st.write(f"**Model:** {selected_model}")
        if model_type == "FairChem" and selected_model == "UMA Small":
            st.write(f"**UMA Task Type:** {selected_task_type}")
        st.write(f"**Device:** {device}")
        st.markdown("### Selected Task")
        st.write(f"**Task:** {task}")
        if "Optimization" in task:
            st.write(f"**Max Steps:** {max_steps}")
            st.write(f"**Convergence Threshold:** {fmax} eV/Å")
            st.write(f"**Optimizer:** {optimizer_type}")
        run_calculation = st.button("Run Calculation", type="primary")

        if run_calculation:
            results = {}
            # streamlit_log reads these module-level names: start a fresh log for
            # every run and a fresh table slot for optimization tasks.
            opt_log = []
            if "Optimization" in task:
                table_placeholder = st.empty()
            try:
                with st.spinner("Running calculation... Please wait."):
                    # Work on a copy so the displayed input structure stays untouched.
                    calc_atoms = atoms.copy()
                    if model_type == "MACE":
                        calc = get_mace_model(model_path, device, selected_default_dtype)
                    else:  # FairChem
                        # Workaround for dtype issues when switching models within one
                        # process. torch's default dtype is global, so reset it to float32
                        # before building the FairChem calculator on any device.
                        torch.set_default_dtype(torch.float32)
                        if device == "cpu":
                            # NOTE(review): kept from the original CPU workaround; confirm
                            # it is still needed now that the dtype is reset above.
                            _ = get_mace_model(MACE_MODELS["MACE MP 0a Small"], 'cpu', 'float32')  # Dummy call
                        calc = get_fairchem_model(selected_model, model_path, device, selected_task_type)
                    calc_atoms.calc = calc

                    if task == "Energy Calculation":
                        energy = calc_atoms.get_potential_energy()
                        results["Energy"] = f"{energy:.6f} eV"

                    elif task == "Energy + Forces Calculation":
                        energy = calc_atoms.get_potential_energy()
                        forces = calc_atoms.get_forces()
                        max_force = np.max(np.sqrt(np.sum(forces**2, axis=1))) if forces.shape[0] > 0 else 0.0
                        results["Energy"] = f"{energy:.6f} eV"
                        results["Maximum Force"] = f"{max_force:.6f} eV/Å"

                    elif task == "Atomization/Cohesive Energy":
                        st.write("Calculating system energy...")
                        E_system = calc_atoms.get_potential_energy()
                        num_atoms = len(calc_atoms)
                        if num_atoms == 0:
                            st.error("Cannot calculate atomization/cohesive energy for a system with zero atoms.")
                            results["Error"] = "System has no atoms."
                        else:
                            atomic_numbers = calc_atoms.get_atomic_numbers()
                            E_isolated_atoms_total = 0.0
                            calculation_possible = True

                            if model_type == "FairChem":
                                # FairChem: sum pre-tabulated per-element references.
                                st.write("Fetching FairChem reference energies for isolated atoms...")
                                ref_key_suffix = "_elem_refs"
                                chosen_ref_list_name = None
                                if selected_model == "UMA Small":
                                    if selected_task_type:
                                        chosen_ref_list_name = selected_task_type + ref_key_suffix
                                elif "ESEN" in selected_model:
                                    # The ESEN checkpoints offered here are OMOL models.
                                    chosen_ref_list_name = "omol" + ref_key_suffix
                                if chosen_ref_list_name and chosen_ref_list_name in ELEMENT_REF_ENERGIES:
                                    ref_energies = ELEMENT_REF_ENERGIES[chosen_ref_list_name]
                                    missing_Z_refs = []
                                    for Z_val in atomic_numbers:
                                        # Lists are indexed by atomic number; entry 0 is a placeholder.
                                        if 0 < Z_val < len(ref_energies):
                                            E_isolated_atoms_total += ref_energies[Z_val]
                                        elif Z_val not in missing_Z_refs:
                                            missing_Z_refs.append(Z_val)
                                    if missing_Z_refs:
                                        st.warning(f"Reference energy for atomic number(s) {sorted(list(set(missing_Z_refs)))} "
                                                   f"not found in '{chosen_ref_list_name}' list (max Z defined: {len(ref_energies)-1}). "
                                                   "These atoms are treated as having 0 reference energy.")
                                else:
                                    st.error(f"Could not find or determine reference energy list for FairChem model: '{selected_model}' "
                                             f"and UMA task type: '{selected_task_type}'. Cannot calculate atomization/cohesive energy.")
                                    results["Error"] = "Missing FairChem reference energies."
                                    calculation_possible = False

                            elif model_type == "MACE":
                                # MACE: evaluate one isolated atom per element with the same calculator.
                                st.write("Calculating isolated atom energies with MACE...")
                                unique_atomic_numbers = sorted(set(atomic_numbers))
                                atom_counts = {Z_unique: np.count_nonzero(atomic_numbers == Z_unique) for Z_unique in unique_atomic_numbers}
                                mace_progress_bar = st.progress(0, text="Calculating isolated atom energies: 0% complete")
                                for i, Z_unique in enumerate(unique_atomic_numbers):
                                    # Single neutral atom in a 20 x 20 x 20 non-periodic box.
                                    isolated_atom = Atoms(numbers=[Z_unique], cell=[20, 20, 20], pbc=False)
                                    isolated_atom.info["charge"] = 0
                                    isolated_atom.info["spin"] = 0
                                    isolated_atom.calc = calc
                                    E_isolated_atom_type = isolated_atom.get_potential_energy()
                                    E_isolated_atoms_total += E_isolated_atom_type * atom_counts[Z_unique]
                                    progress_val = (i + 1) / len(unique_atomic_numbers)
                                    mace_progress_bar.progress(progress_val, text=f"Calculating isolated atom energies for Z={Z_unique}: {int(progress_val*100)}% complete")
                                mace_progress_bar.empty()

                            if calculation_possible:
                                # Periodic systems report per-atom cohesive energy; molecules
                                # report the total atomization energy.
                                if any(calc_atoms.pbc):
                                    cohesive_E = (E_isolated_atoms_total - E_system) / num_atoms
                                    results["Cohesive Energy"] = f"{cohesive_E:.6f} eV/atom"
                                else:
                                    atomization_E = E_isolated_atoms_total - E_system
                                    results["Atomization Energy"] = f"{atomization_E:.6f} eV"
                                results["System Energy ($E_{system}$)"] = f"{E_system:.6f} eV"
                                # Raw string: "\s" is an invalid escape sequence in a normal literal.
                                results[r"Total Isolated Atom Energy ($\sum E_{atoms}$)"] = f"{E_isolated_atoms_total:.6f} eV"

                    elif "Optimization" in task:  # Handles both Geometry and Cell+Geometry Opt
                        opt_atoms_obj = FrechetCellFilter(calc_atoms) if task == "Cell + Geometry Optimization" else calc_atoms
                        optimizer_cls = {"BFGS": BFGS, "LBFGS": LBFGS}.get(optimizer_type, FIRE)
                        opt = optimizer_cls(opt_atoms_obj)
                        opt.attach(lambda: streamlit_log(opt), interval=1)
                        st.write(f"Running {task.lower()}...")
                        # run() reports whether fmax was reached within max_steps. Use it
                        # instead of opt.converged(), whose signature has changed between
                        # ASE releases — NOTE(review): confirm against the pinned ASE.
                        converged = opt.run(fmax=fmax, steps=max_steps)
                        energy = calc_atoms.get_potential_energy()
                        forces = calc_atoms.get_forces()
                        max_force = np.max(np.sqrt(np.sum(forces**2, axis=1))) if forces.shape[0] > 0 else 0.0
                        results["Final Energy"] = f"{energy:.6f} eV"
                        results["Final Maximum Force"] = f"{max_force:.6f} eV/Å"
                        results["Steps Taken"] = opt.nsteps
                        results["Converged"] = "Yes" if converged else "No"
                        if task == "Cell + Geometry Optimization":
                            results["Final Cell Parameters"] = np.round(calc_atoms.cell.cellpar(), 4).tolist()

                st.success("Calculation completed successfully!")
                st.markdown("### Results")
                for key, value in results.items():
                    st.write(f"**{key}:** {value}")

                if "Optimization" in task and "Final Energy" in results:  # Check if opt was successful
                    st.markdown("### Optimized Structure")
                    opt_view = get_structure_viz2(calc_atoms, style='stick', show_unit_cell=True, width=400, height=400)
                    st.components.v1.html(opt_view._make_html(), width=400, height=400)
                    # Serialize in memory, so there is no temp file to leak if a later call raises.
                    xyz_buffer = io.StringIO()
                    write(xyz_buffer, calc_atoms, format="xyz")
                    st.download_button(
                        label="Download Optimized Structure (XYZ)",
                        data=xyz_buffer.getvalue(),
                        file_name="optimized_structure.xyz",
                        mime="chemical/x-xyz",
                    )
            except Exception as e:
                st.error(f"🔴 Calculation error: {str(e)}")
                st.error("Please check the structure, model compatibility, and parameters. For FairChem UMA, ensure the task type (omol, omat etc.) is appropriate for your system (e.g. omol for molecules, omat for materials).")
                st.error(f"Traceback: {traceback.format_exc()}")
else:
    st.info("👋 Welcome! Please select or upload a structure using the sidebar options to begin.")
st.markdown("---")
with st.expander('ℹ️ About This App & Foundational MLIPs'):
    # Raw string: the LaTeX below contains "\t" (in \text) and "\s" (in \sum). A
    # normal literal turns "\t" into a TAB character, which breaks the rendered
    # formulas, and "\s" triggers an invalid-escape warning.
    st.write(r"""
**Test, compare, and benchmark universal machine learning interatomic potentials (MLIPs).**
This application allows you to perform atomistic simulations using pre-trained foundational MLIPs
from the MACE and FairChem (by Meta AI) libraries.
**Features:**
- Upload structure files (XYZ, CIF, POSCAR, etc.) or use built-in examples.
- Select from various MACE and FairChem models.
- Calculate energies, forces, and perform geometry/cell optimizations.
- **New**: Calculate atomization energy (for molecules) or cohesive energy (for periodic systems).
- Visualize atomic structures in 3D and download results.
**Quick Start:**
1. **Input**: Choose an input method in the sidebar (e.g., "Select Example").
2. **Model**: Pick a model type (MACE/FairChem) and specific model. For FairChem UMA, select the appropriate task type (e.g., `omol` for molecules, `omat` for materials).
3. **Task**: Select a calculation task (e.g., "Energy Calculation", "Atomization/Cohesive Energy", "Geometry Optimization").
4. **Run**: Click "Run Calculation" and view the results.
**Atomization/Cohesive Energy Notes:**
- **Atomization Energy** ($E_{\text{atomization}} = \sum E_{\text{isolated atoms}} - E_{\text{molecule}}$) is typically for non-periodic systems (molecules).
- **Cohesive Energy** ($E_{\text{cohesive}} = (\sum E_{\text{isolated atoms}} - E_{\text{bulk system}}) / N_{\text{atoms}}$) is for periodic systems.
- For **MACE models**, isolated atom energies are computed on-the-fly.
- For **FairChem models**, isolated atom energies are based on pre-tabulated reference values (provided in a YAML-like structure within the app). Ensure the selected FairChem task type (`omol`, `omat`, etc. for UMA models) or model type (ESEN models use `omol` references) aligns with the system and has the necessary elemental references.
""")
st.markdown("Universal MLIP Playground App | Created with Streamlit, ASE, MACE, FairChem and ❤️")
st.markdown("Developed by [Manas Sharma](https://manas.bragitoff.com/) ([Fundamental AI Research (FAIR) team, Meta AI](https://ai.meta.com/research/fair/) and [Ananth Govind Rajan Group, IISc Bangalore](https://www.agrgroup.org/))")