Spaces:
Running
Running
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +495 -37
src/streamlit_app.py
CHANGED
@@ -1,40 +1,498 @@
|
|
1 |
-
import
|
|
|
|
|
|
|
2 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
import pandas as pd
|
4 |
-
import streamlit as st
|
5 |
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
"
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import os
|
3 |
+
import tempfile
|
4 |
+
import torch
|
5 |
import numpy as np
|
6 |
+
from ase import Atoms
|
7 |
+
from ase.io import read, write
|
8 |
+
from ase.optimize import BFGS, LBFGS, FIRE
|
9 |
+
from ase.constraints import FixAtoms
|
10 |
+
from ase.filters import FrechetCellFilter
|
11 |
+
from ase.visualize import view
|
12 |
+
import py3Dmol
|
13 |
+
from mace.calculators import mace_mp
|
14 |
+
from fairchem.core import pretrained_mlip, FAIRChemCalculator
|
15 |
import pandas as pd
|
|
|
16 |
|
17 |
+
from huggingface_hub import login
|
18 |
+
|
19 |
+
try:
|
20 |
+
hf_token = st.secrets["HF_TOKEN"]["token"]
|
21 |
+
os.environ["HF_TOKEN"] = hf_token
|
22 |
+
login(token=hf_token)
|
23 |
+
except Exception as e:
|
24 |
+
print("streamlit hf secret not defined/assigned")
|
25 |
+
|
26 |
+
import os
|
27 |
+
os.environ["STREAMLIT_WATCHER_TYPE"] = "none"
|
28 |
+
|
29 |
+
# Check if running on Streamlit Cloud vs locally
|
30 |
+
is_streamlit_cloud = os.environ.get('STREAMLIT_RUNTIME_ENV') == 'cloud'
|
31 |
+
MAX_ATOMS_CLOUD = 50 # Maximum atoms allowed on Streamlit Cloud
|
32 |
+
MAX_ATOMS_CLOUD_UMA = 15
|
33 |
+
|
34 |
+
# Set page configuration
|
35 |
+
st.set_page_config(
|
36 |
+
page_title="Molecular Structure Analysis",
|
37 |
+
page_icon="🧪",
|
38 |
+
layout="wide"
|
39 |
+
)
|
40 |
+
|
41 |
+
# Add CSS for better formatting
|
42 |
+
# st.markdown("""
|
43 |
+
# <style>
|
44 |
+
# .stApp {
|
45 |
+
# max-width: 1200px;
|
46 |
+
# margin: 0 auto;
|
47 |
+
# }
|
48 |
+
# .main-header {
|
49 |
+
# font-size: 2.5rem;
|
50 |
+
# font-weight: bold;
|
51 |
+
# margin-bottom: 1rem;
|
52 |
+
# }
|
53 |
+
# .section-header {
|
54 |
+
# font-size: 1.5rem;
|
55 |
+
# font-weight: bold;
|
56 |
+
# margin-top: 1.5rem;
|
57 |
+
# margin-bottom: 1rem;
|
58 |
+
# }
|
59 |
+
# .info-text {
|
60 |
+
# font-size: 1rem;
|
61 |
+
# color: #555;
|
62 |
+
# }
|
63 |
+
# </style>
|
64 |
+
# """, unsafe_allow_html=True)
|
65 |
+
|
66 |
+
# Title and description
|
67 |
+
st.markdown('## MLIP Playground', unsafe_allow_html=True)
|
68 |
+
st.write('#### Run atomistic simulations with state-of-the-art universal machine learning interatomic potentials (MLIPs) for molecules and materials')
|
69 |
+
st.markdown('Upload molecular structure files or select from predefined examples, then compute energies and forces using foundation models such as those from MACE or FairChem (Meta).', unsafe_allow_html=True)
|
70 |
+
|
71 |
+
# Create a directory for sample structures if it doesn't exist
|
72 |
+
SAMPLE_DIR = "sample_structures"
|
73 |
+
os.makedirs(SAMPLE_DIR, exist_ok=True)
|
74 |
+
|
75 |
+
# Dictionary of sample structures
|
76 |
+
SAMPLE_STRUCTURES = {
|
77 |
+
"Water": "H2O.xyz",
|
78 |
+
"Methane": "CH4.xyz",
|
79 |
+
"Benzene": "C6H6.xyz",
|
80 |
+
"Ethane": "C2H6.xyz",
|
81 |
+
"Caffeine": "caffeine.xyz",
|
82 |
+
"Ibuprofen": "ibuprofen.xyz"
|
83 |
+
}
|
84 |
+
|
85 |
+
|
86 |
+
|
87 |
+
# Custom logger that updates the table
|
88 |
+
def streamlit_log(opt):
|
89 |
+
energy = opt.atoms.get_potential_energy()
|
90 |
+
forces = opt.atoms.get_forces()
|
91 |
+
fmax_step = np.max(np.linalg.norm(forces, axis=1))
|
92 |
+
opt_log.append({
|
93 |
+
"Step": opt.nsteps,
|
94 |
+
"Energy (eV)": round(energy, 6),
|
95 |
+
"Fmax (eV/Å)": round(fmax_step, 6)
|
96 |
+
})
|
97 |
+
df = pd.DataFrame(opt_log)
|
98 |
+
table_placeholder.dataframe(df)
|
99 |
+
|
100 |
+
# Function to check atom count limits
|
101 |
+
def check_atom_limit(atoms_obj, selected_model):
|
102 |
+
if atoms_obj is None:
|
103 |
+
return True
|
104 |
+
|
105 |
+
num_atoms = len(atoms_obj)
|
106 |
+
if ('UMA' in selected_model or 'ESEN MD' in selected_model) and num_atoms > MAX_ATOMS_CLOUD_UMA:
|
107 |
+
st.error(f"⚠️ Error: Your structure contains {num_atoms} atoms, which exceeds the {MAX_ATOMS_CLOUD_UMA} atom limit for Streamlit Cloud deployments for large sized FairChem models. For larger systems, please download the repository from GitHub and run it locally on your machine where no atom limit applies.")
|
108 |
+
st.info("💡 Running locally allows you to process much larger structures and use your own computational resources more efficiently.")
|
109 |
+
return False
|
110 |
+
if num_atoms > MAX_ATOMS_CLOUD:
|
111 |
+
st.error(f"⚠️ Error: Your structure contains {num_atoms} atoms, which exceeds the {MAX_ATOMS_CLOUD} atom limit for Streamlit Cloud deployments. For larger systems, please download the repository from GitHub and run it locally on your machine where no atom limit applies.")
|
112 |
+
st.info("💡 Running locally allows you to process much larger structures and use your own computational resources more efficiently.")
|
113 |
+
return False
|
114 |
+
return True
|
115 |
+
|
116 |
+
|
117 |
+
# Define the available MACE models
|
118 |
+
MACE_MODELS = {
|
119 |
+
"MACE MPA Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model",
|
120 |
+
"MACE OMAT Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_omat_0/mace-omat-0-medium.model",
|
121 |
+
"MACE MATPES r2SCAN Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-r2scan-omat-ft.model",
|
122 |
+
"MACE MATPES PBE Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-pbe-omat-ft.model",
|
123 |
+
"MACE MP 0a Small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model",
|
124 |
+
"MACE MP 0a Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model",
|
125 |
+
"MACE MP 0a Large": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2024-01-07-mace-128-L2_epoch-199.model",
|
126 |
+
"MACE MP 0b Small": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b/mace_agnesi_small.model",
|
127 |
+
"MACE MP 0b Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b/mace_agnesi_medium.model",
|
128 |
+
"MACE MP 0b2 Small": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b2/mace-large-density-agnesi-stress.model",
|
129 |
+
"MACE MP 0b2 Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b2/mace-medium-density-agnesi-stress.model",
|
130 |
+
"MACE MP 0b2 Large": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b2/mace-large-density-agnesi-stress.model",
|
131 |
+
"MACE MP 0b3 Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b3/mace-mp-0b3-medium.model",
|
132 |
+
}
|
133 |
+
|
134 |
+
# Define the available FairChem models
|
135 |
+
FAIRCHEM_MODELS = {
|
136 |
+
"UMA Small": "uma-sm",
|
137 |
+
"ESEN MD Direct All OMOL": "esen-md-direct-all-omol",
|
138 |
+
"ESEN SM Conserving All OMOL": "esen-sm-conserving-all-omol",
|
139 |
+
"ESEN SM Direct All OMOL": "esen-sm-direct-all-omol"
|
140 |
+
}
|
141 |
+
|
142 |
+
@st.cache_resource
|
143 |
+
def get_mace_model(model_path, device, selected_default_dtype):
|
144 |
+
# Create a model of the specified type.
|
145 |
+
return mace_mp(model=model_path, device=device, default_dtype=selected_default_dtype)
|
146 |
+
|
147 |
+
@st.cache_resource
|
148 |
+
def get_fairchem_model(selected_model, model_path, device, selected_task_type):
|
149 |
+
predictor = pretrained_mlip.get_predict_unit(model_path, device=device)
|
150 |
+
if selected_model == "UMA Small":
|
151 |
+
calc = FAIRChemCalculator(predictor, task_name=selected_task_type)
|
152 |
+
else:
|
153 |
+
calc = FAIRChemCalculator(predictor)
|
154 |
+
return calc
|
155 |
+
|
156 |
+
# Sidebar for file input and parameters
|
157 |
+
st.sidebar.markdown("## Input Options")
|
158 |
+
|
159 |
+
# Input method selection
|
160 |
+
input_method = st.sidebar.radio("Choose Input Method:", ["Select Example", "Upload File", "Paste Content"])
|
161 |
+
|
162 |
+
# Initialize atoms variable
|
163 |
+
atoms = None
|
164 |
+
|
165 |
+
# File upload option
|
166 |
+
if input_method == "Upload File":
|
167 |
+
uploaded_file = st.sidebar.file_uploader("Upload structure file",
|
168 |
+
type=["xyz", "cif", "POSCAR", "mol", "tmol"])
|
169 |
+
|
170 |
+
if uploaded_file is not None:
|
171 |
+
# Create a temporary file to save the uploaded content
|
172 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
|
173 |
+
tmp_file.write(uploaded_file.getvalue())
|
174 |
+
tmp_filepath = tmp_file.name
|
175 |
+
|
176 |
+
try:
|
177 |
+
# Read the structure using ASE
|
178 |
+
atoms = read(tmp_filepath)
|
179 |
+
st.sidebar.success(f"Successfully loaded structure with {len(atoms)} atoms!")
|
180 |
+
except Exception as e:
|
181 |
+
st.sidebar.error(f"Error loading file: {str(e)}")
|
182 |
+
|
183 |
+
# Clean up the temporary file
|
184 |
+
os.unlink(tmp_filepath)
|
185 |
+
|
186 |
+
# Example structure selection
|
187 |
+
elif input_method == "Select Example":
|
188 |
+
example_name = st.sidebar.selectbox("Select Example Structure:", list(SAMPLE_STRUCTURES.keys()))
|
189 |
+
|
190 |
+
if example_name:
|
191 |
+
file_path = os.path.join(SAMPLE_DIR, SAMPLE_STRUCTURES[example_name])
|
192 |
+
try:
|
193 |
+
atoms = read(file_path)
|
194 |
+
st.sidebar.success(f"Loaded {example_name} with {len(atoms)} atoms!")
|
195 |
+
except Exception as e:
|
196 |
+
st.sidebar.error(f"Error loading example: {str(e)}")
|
197 |
+
|
198 |
+
# Paste content option
|
199 |
+
elif input_method == "Paste Content":
|
200 |
+
file_format = st.sidebar.selectbox("File Format:",
|
201 |
+
["XYZ", "CIF", "extXYZ", "POSCAR (VASP)", "Turbomole", "MOL"])
|
202 |
+
|
203 |
+
content = st.sidebar.text_area("Paste file content here:", height=200)
|
204 |
+
|
205 |
+
if content and st.sidebar.button("Parse Content"):
|
206 |
+
try:
|
207 |
+
# Create a temporary file with the pasted content
|
208 |
+
suffix_map = {"XYZ": ".xyz", "CIF": ".cif", "extXYZ": ".extxyz",
|
209 |
+
"POSCAR (VASP)": ".POSCAR", "Turbomole": ".tmol", "MOL": ".mol"}
|
210 |
+
|
211 |
+
suffix = suffix_map.get(file_format, ".xyz")
|
212 |
+
|
213 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
|
214 |
+
tmp_file.write(content.encode())
|
215 |
+
tmp_filepath = tmp_file.name
|
216 |
+
|
217 |
+
# Read the structure using ASE
|
218 |
+
atoms = read(tmp_filepath)
|
219 |
+
st.sidebar.success(f"Successfully parsed structure with {len(atoms)} atoms!")
|
220 |
+
|
221 |
+
# Clean up the temporary file
|
222 |
+
os.unlink(tmp_filepath)
|
223 |
+
except Exception as e:
|
224 |
+
st.sidebar.error(f"Error parsing content: {str(e)}")
|
225 |
+
|
226 |
+
|
227 |
+
# Model selection
|
228 |
+
st.sidebar.markdown("## Model Selection")
|
229 |
+
model_type = st.sidebar.radio("Select Model Type:", ["MACE", "FairChem"])
|
230 |
+
|
231 |
+
selected_task_type = None
|
232 |
+
if model_type == "MACE":
|
233 |
+
selected_model = st.sidebar.selectbox("Select MACE Model:", list(MACE_MODELS.keys()))
|
234 |
+
model_path = MACE_MODELS[selected_model]
|
235 |
+
if selected_model == "MACE OMAT Medium":
|
236 |
+
st.sidebar.warning("Using model under Academic Software License (ASL) license, see [https://github.com/gabor1/ASL](https://github.com/gabor1/ASL). To use this model you accept the terms of the license.")
|
237 |
+
selected_default_dtype = st.sidebar.selectbox("Select Precision (default_dtype):", ['float32', 'float64'])
|
238 |
+
if model_type == "FairChem":
|
239 |
+
selected_model = st.sidebar.selectbox("Select FairChem Model:", list(FAIRCHEM_MODELS.keys()))
|
240 |
+
model_path = FAIRCHEM_MODELS[selected_model]
|
241 |
+
if selected_model == "UMA Small":
|
242 |
+
st.sidebar.warning("Meta FAIR Acceptable Use Policy. This model was developed by the Fundamental AI Research (FAIR) team at Meta. By using it, you agree to their acceptable use policy, which prohibits using their models to violate the law or others' rights, plan or develop activities that present a risk of death or harm, and deceive or mislead others.")
|
243 |
+
selected_task_type = st.sidebar.selectbox("Select UMA Model Task Type:", ["omol", "omat", "omc", "odac", "oc20"])
|
244 |
+
# Check atom count limit
|
245 |
+
if atoms is not None:
|
246 |
+
check_atom_limit(atoms, selected_model)
|
247 |
+
#st.sidebar.success(f"Successfully parsed structure with {len(atoms)} atoms!")
|
248 |
+
# Device selection
|
249 |
+
device = st.sidebar.radio("Computation Device:", ["CPU", "CUDA (GPU)"],
|
250 |
+
index=0 if not torch.cuda.is_available() else 1)
|
251 |
+
device = "cuda" if device == "CUDA (GPU)" and torch.cuda.is_available() else "cpu"
|
252 |
+
|
253 |
+
if device == "cpu" and torch.cuda.is_available():
|
254 |
+
st.sidebar.info("GPU is available but CPU was selected. Calculations will be slower.")
|
255 |
+
elif device == "cpu" and not torch.cuda.is_available():
|
256 |
+
st.sidebar.info("No GPU detected. Using CPU for calculations.")
|
257 |
+
|
258 |
+
# Task selection
|
259 |
+
st.sidebar.markdown("## Task Selection")
|
260 |
+
task = st.sidebar.selectbox("Select Calculation Task:",
|
261 |
+
["Energy Calculation",
|
262 |
+
"Energy + Forces Calculation",
|
263 |
+
"Geometry Optimization",
|
264 |
+
"Cell + Geometry Optimization"])
|
265 |
+
|
266 |
+
# Optimization parameters
|
267 |
+
if "Optimization" in task:
|
268 |
+
st.sidebar.markdown("### Optimization Parameters")
|
269 |
+
max_steps = st.sidebar.slider("Maximum Steps:", min_value=10, max_value=25, value=15, step=1)
|
270 |
+
fmax = st.sidebar.slider("Convergence Threshold (eV/Å):",
|
271 |
+
min_value=0.001, max_value=0.1, value=0.05, step=0.001, format="%.3f")
|
272 |
+
optimizer = st.sidebar.selectbox("Optimizer:", ["BFGS", "LBFGS", "FIRE"], index=1)
|
273 |
+
|
274 |
+
# Main content area
|
275 |
+
if atoms is not None:
|
276 |
+
col1, col2 = st.columns(2)
|
277 |
+
|
278 |
+
with col1:
|
279 |
+
st.markdown('### Structure Visualization', unsafe_allow_html=True)
|
280 |
+
|
281 |
+
# Generate visualization
|
282 |
+
def get_structure_viz(atoms_obj):
|
283 |
+
# Convert atoms to XYZ format
|
284 |
+
xyz_str = ""
|
285 |
+
xyz_str += f"{len(atoms_obj)}\n"
|
286 |
+
xyz_str += "Structure\n"
|
287 |
+
for atom in atoms_obj:
|
288 |
+
xyz_str += f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}\n"
|
289 |
+
|
290 |
+
# Create a py3Dmol visualization
|
291 |
+
view = py3Dmol.view(width=400, height=400)
|
292 |
+
view.addModel(xyz_str, "xyz")
|
293 |
+
view.setStyle({'stick': {}})
|
294 |
+
view.zoomTo()
|
295 |
+
view.setBackgroundColor('white')
|
296 |
+
|
297 |
+
return view
|
298 |
+
|
299 |
+
# Display the 3D structure
|
300 |
+
view = get_structure_viz(atoms)
|
301 |
+
html_str = view._make_html()
|
302 |
+
st.components.v1.html(html_str, width=400, height=400)
|
303 |
+
|
304 |
+
# Display structure information
|
305 |
+
st.markdown("### Structure Information")
|
306 |
+
atoms_info = {
|
307 |
+
"Number of Atoms": len(atoms),
|
308 |
+
"Chemical Formula": atoms.get_chemical_formula(),
|
309 |
+
"Cell Dimensions": atoms.cell.cellpar() if atoms.cell else "No cell defined",
|
310 |
+
"Atom Types": ", ".join(set(atoms.get_chemical_symbols()))
|
311 |
+
}
|
312 |
+
|
313 |
+
for key, value in atoms_info.items():
|
314 |
+
st.write(f"**{key}:** {value}")
|
315 |
+
|
316 |
+
with col2:
|
317 |
+
st.markdown('## Calculation Setup', unsafe_allow_html=True)
|
318 |
+
|
319 |
+
# Display calculation details
|
320 |
+
st.markdown("### Selected Model")
|
321 |
+
st.write(f"**Model Type:** {model_type}")
|
322 |
+
st.write(f"**Model:** {selected_model}")
|
323 |
+
st.write(f"**Device:** {device}")
|
324 |
+
|
325 |
+
st.markdown("### Selected Task")
|
326 |
+
st.write(f"**Task:** {task}")
|
327 |
+
|
328 |
+
if "Optimization" in task:
|
329 |
+
st.write(f"**Max Steps:** {max_steps}")
|
330 |
+
st.write(f"**Convergence Threshold:** {fmax} eV/Å")
|
331 |
+
st.write(f"**Optimizer:** {optimizer}")
|
332 |
+
|
333 |
+
# Run calculation button
|
334 |
+
run_calculation = st.button("Run Calculation", type="primary")
|
335 |
+
|
336 |
+
if run_calculation:
|
337 |
+
try:
|
338 |
+
with st.spinner("Running calculation..."):
|
339 |
+
# Copy atoms to avoid modifying the original
|
340 |
+
calc_atoms = atoms.copy()
|
341 |
+
|
342 |
+
# Set up calculator based on selected model
|
343 |
+
if model_type == "MACE":
|
344 |
+
st.write("Setting up MACE calculator...")
|
345 |
+
calc = get_mace_model(model_path, device, selected_default_dtype)
|
346 |
+
else: # FairChem
|
347 |
+
st.write("Setting up FairChem calculator...")
|
348 |
+
# Seems like the FairChem models use float32 and when switching from MACE 64 model to FairChem float32 model we get an error
|
349 |
+
# probably due to both sharing the underlying torch implementation
|
350 |
+
# So just a dummy statement to swithc torch to 32 bit
|
351 |
+
calc = get_mace_model('https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model', 'cpu', 'float32')
|
352 |
+
calc = get_fairchem_model(selected_model, model_path, device, selected_task_type)
|
353 |
+
# Attach calculator to atoms
|
354 |
+
calc_atoms.calc = calc
|
355 |
+
|
356 |
+
# Perform the selected task
|
357 |
+
results = {}
|
358 |
+
|
359 |
+
if task == "Energy Calculation":
|
360 |
+
# Calculate energy
|
361 |
+
energy = calc_atoms.get_potential_energy()
|
362 |
+
results["Energy"] = f"{energy:.6f} eV"
|
363 |
+
|
364 |
+
elif task == "Energy + Forces Calculation":
|
365 |
+
# Calculate energy and forces
|
366 |
+
energy = calc_atoms.get_potential_energy()
|
367 |
+
forces = calc_atoms.get_forces()
|
368 |
+
max_force = np.max(np.sqrt(np.sum(forces**2, axis=1)))
|
369 |
+
|
370 |
+
results["Energy"] = f"{energy:.6f} eV"
|
371 |
+
results["Maximum Force"] = f"{max_force:.6f} eV/Å"
|
372 |
+
|
373 |
+
elif task == "Geometry Optimization":
|
374 |
+
# Set up optimizer
|
375 |
+
if optimizer == "BFGS":
|
376 |
+
opt = BFGS(calc_atoms)
|
377 |
+
elif optimizer == "LBFGS":
|
378 |
+
opt = LBFGS(calc_atoms)
|
379 |
+
else: # FIRE
|
380 |
+
opt = FIRE(calc_atoms)
|
381 |
+
|
382 |
+
# Streamlit placeholder for live-updating table
|
383 |
+
table_placeholder = st.empty()
|
384 |
+
|
385 |
+
# Container for log data
|
386 |
+
opt_log = []
|
387 |
+
# Attach the Streamlit logger to the optimizer
|
388 |
+
opt.attach(lambda: streamlit_log(opt), interval=1)
|
389 |
+
# Run optimization
|
390 |
+
st.write("Running geometry optimization...")
|
391 |
+
opt.run(fmax=fmax, steps=max_steps)
|
392 |
+
|
393 |
+
# Get results
|
394 |
+
energy = calc_atoms.get_potential_energy()
|
395 |
+
forces = calc_atoms.get_forces()
|
396 |
+
max_force = np.max(np.sqrt(np.sum(forces**2, axis=1)))
|
397 |
+
|
398 |
+
results["Final Energy"] = f"{energy:.6f} eV"
|
399 |
+
results["Final Maximum Force"] = f"{max_force:.6f} eV/Å"
|
400 |
+
results["Steps Taken"] = opt.get_number_of_steps()
|
401 |
+
results["Converged"] = "Yes" if opt.converged() else "No"
|
402 |
+
|
403 |
+
elif task == "Cell + Geometry Optimization":
|
404 |
+
# Set up optimizer with FrechetCellFilter
|
405 |
+
fcf = FrechetCellFilter(calc_atoms)
|
406 |
+
|
407 |
+
if optimizer == "BFGS":
|
408 |
+
opt = BFGS(fcf)
|
409 |
+
elif optimizer == "LBFGS":
|
410 |
+
opt = LBFGS(fcf)
|
411 |
+
else: # FIRE
|
412 |
+
opt = FIRE(fcf)
|
413 |
+
|
414 |
+
# Streamlit placeholder for live-updating table
|
415 |
+
table_placeholder = st.empty()
|
416 |
+
|
417 |
+
# Container for log data
|
418 |
+
opt_log = []
|
419 |
+
# Attach the Streamlit logger to the optimizer
|
420 |
+
opt.attach(lambda: streamlit_log(opt), interval=1)
|
421 |
+
# Run optimization
|
422 |
+
st.write("Running cell + geometry optimization...")
|
423 |
+
opt.run(fmax=fmax, steps=max_steps)
|
424 |
+
|
425 |
+
# Get results
|
426 |
+
energy = calc_atoms.get_potential_energy()
|
427 |
+
forces = calc_atoms.get_forces()
|
428 |
+
max_force = np.max(np.sqrt(np.sum(forces**2, axis=1)))
|
429 |
+
|
430 |
+
results["Final Energy"] = f"{energy:.6f} eV"
|
431 |
+
results["Final Maximum Force"] = f"{max_force:.6f} eV/Å"
|
432 |
+
results["Steps Taken"] = opt.get_number_of_steps()
|
433 |
+
results["Converged"] = "Yes" if opt.converged() else "No"
|
434 |
+
results["Final Cell Parameters"] = np.round(calc_atoms.cell.cellpar(), 4)
|
435 |
+
|
436 |
+
# Show results
|
437 |
+
st.success("Calculation completed successfully!")
|
438 |
+
st.markdown("### Results")
|
439 |
+
for key, value in results.items():
|
440 |
+
st.write(f"**{key}:** {value}")
|
441 |
+
|
442 |
+
# If we did an optimization, show the final structure
|
443 |
+
if "Optimization" in task:
|
444 |
+
st.markdown("### Optimized Structure")
|
445 |
+
view = get_structure_viz(calc_atoms)
|
446 |
+
html_str = view._make_html()
|
447 |
+
st.components.v1.html(html_str, width=400, height=400)
|
448 |
+
|
449 |
+
# Add download option for optimized structure
|
450 |
+
# First save the structure to a file
|
451 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".xyz") as tmp_file:
|
452 |
+
write(tmp_file.name, calc_atoms)
|
453 |
+
tmp_filepath = tmp_file.name
|
454 |
+
|
455 |
+
# Read the content for downloading
|
456 |
+
with open(tmp_filepath, 'r') as file:
|
457 |
+
xyz_content = file.read()
|
458 |
+
|
459 |
+
st.download_button(
|
460 |
+
label="Download Optimized Structure (XYZ)",
|
461 |
+
data=xyz_content,
|
462 |
+
file_name="optimized_structure.xyz",
|
463 |
+
mime="chemical/x-xyz"
|
464 |
+
)
|
465 |
+
|
466 |
+
# Clean up the temp file
|
467 |
+
os.unlink(tmp_filepath)
|
468 |
+
|
469 |
+
except Exception as e:
|
470 |
+
st.error(f"Calculation error: {str(e)}")
|
471 |
+
st.error("Please make sure the structure is valid and compatible with the selected model.")
|
472 |
+
else:
|
473 |
+
# Display instructions if no structure is loaded
|
474 |
+
st.info("Please select a structure using the sidebar options to begin.")
|
475 |
+
|
476 |
+
|
477 |
+
# Footer
|
478 |
+
st.markdown("---")
|
479 |
+
with st.expander('## About This App'):
|
480 |
+
# Show some information about the app
|
481 |
+
st.write("""
|
482 |
+
This app allows you to perform atomistic simulations using pre-trained foundational machine learning interatomic potentials (MLIPs) such as those from the MACE and FairChem libraries.
|
483 |
+
|
484 |
+
### Features:
|
485 |
+
- Upload structure files (XYZ, CIF, POSCAR, etc.) or select from examples
|
486 |
+
- Choose between MACE and FairChem ML models (more models coming soon)
|
487 |
+
- Perform energy calculations, forces calculations, or geometry optimizations
|
488 |
+
- Visualize structures in 3D
|
489 |
+
- Download optimized structures
|
490 |
+
|
491 |
+
### Getting Started:
|
492 |
+
1. Select an input method in the sidebar
|
493 |
+
2. Choose a model and computational parameters
|
494 |
+
3. Select a calculation task
|
495 |
+
4. Run the calculation and analyze the results
|
496 |
+
""")
|
497 |
+
st.markdown("Universal MLIP Playground App | Created with Streamlit, ASE, MACE, FairChem and ❤️")
|
498 |
+
st.markdown("Made by [Manas Sharma](https://manas.bragitoff.com/) in the groups of [Prof. Ananth Govind Rajan](https://www.agrgroup.org/) and [Prof. Sudeep Punnathanam](https://chemeng.iisc.ac.in/sudeep/)")
|