Spaces:
Running
Running
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +803 -317
src/streamlit_app.py
CHANGED
@@ -1,3 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
import os
|
3 |
import tempfile
|
@@ -13,6 +42,7 @@ import py3Dmol
|
|
13 |
from mace.calculators import mace_mp
|
14 |
from fairchem.core import pretrained_mlip, FAIRChemCalculator
|
15 |
import pandas as pd
|
|
|
16 |
|
17 |
from huggingface_hub import login
|
18 |
|
@@ -23,14 +53,532 @@ from huggingface_hub import login
|
|
23 |
# except Exception as e:
|
24 |
# print("streamlit hf secret not defined/assigned")
|
25 |
try:
|
26 |
-
hf_token = os.getenv("YOUR SECRET KEY")
|
27 |
-
|
|
|
|
|
|
|
28 |
except Exception as e:
|
29 |
-
print("hf
|
|
|
30 |
|
31 |
-
import os
|
32 |
os.environ["STREAMLIT_WATCHER_TYPE"] = "none"
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
# Check if running on Streamlit Cloud vs locally
|
35 |
is_streamlit_cloud = os.environ.get('STREAMLIT_RUNTIME_ENV') == 'cloud'
|
36 |
MAX_ATOMS_CLOUD = 500 # Maximum atoms allowed on Streamlit Cloud
|
@@ -43,31 +591,6 @@ st.set_page_config(
|
|
43 |
layout="wide"
|
44 |
)
|
45 |
|
46 |
-
# Add CSS for better formatting
|
47 |
-
# st.markdown("""
|
48 |
-
# <style>
|
49 |
-
# .stApp {
|
50 |
-
# max-width: 1200px;
|
51 |
-
# margin: 0 auto;
|
52 |
-
# }
|
53 |
-
# .main-header {
|
54 |
-
# font-size: 2.5rem;
|
55 |
-
# font-weight: bold;
|
56 |
-
# margin-bottom: 1rem;
|
57 |
-
# }
|
58 |
-
# .section-header {
|
59 |
-
# font-size: 1.5rem;
|
60 |
-
# font-weight: bold;
|
61 |
-
# margin-top: 1.5rem;
|
62 |
-
# margin-bottom: 1rem;
|
63 |
-
# }
|
64 |
-
# .info-text {
|
65 |
-
# font-size: 1rem;
|
66 |
-
# color: #555;
|
67 |
-
# }
|
68 |
-
# </style>
|
69 |
-
# """, unsafe_allow_html=True)
|
70 |
-
|
71 |
# Title and description
|
72 |
st.markdown('## MLIP Playground', unsafe_allow_html=True)
|
73 |
st.write('#### Run, test and compare >17 state-of-the-art universal machine learning interatomic potentials (MLIPs) for atomistic simulations of molecules and materials')
|
@@ -88,37 +611,15 @@ SAMPLE_STRUCTURES = {
|
|
88 |
}
|
89 |
|
90 |
def get_structure_viz2(atoms_obj, style='stick', show_unit_cell=True, width=400, height=400):
|
91 |
-
"""
|
92 |
-
Generate visualization of atomic structure with optional unit cell display
|
93 |
-
|
94 |
-
Parameters:
|
95 |
-
-----------
|
96 |
-
atoms_obj : ase.Atoms
|
97 |
-
ASE Atoms object containing the structure
|
98 |
-
style : str
|
99 |
-
Visualization style: 'ball_stick', 'stick', or 'ball'
|
100 |
-
show_unit_cell : bool
|
101 |
-
Whether to display unit cell for periodic systems
|
102 |
-
width, height : int
|
103 |
-
Dimensions of the visualization window
|
104 |
-
|
105 |
-
Returns:
|
106 |
-
--------
|
107 |
-
py3Dmol.view object
|
108 |
-
"""
|
109 |
-
|
110 |
-
# Convert atoms to XYZ format
|
111 |
xyz_str = ""
|
112 |
xyz_str += f"{len(atoms_obj)}\n"
|
113 |
xyz_str += "Structure\n"
|
114 |
for atom in atoms_obj:
|
115 |
xyz_str += f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}\n"
|
116 |
|
117 |
-
# Create a py3Dmol visualization
|
118 |
view = py3Dmol.view(width=width, height=height)
|
119 |
view.addModel(xyz_str, "xyz")
|
120 |
|
121 |
-
# Set molecular style based on input
|
122 |
if style.lower() == 'ball_stick':
|
123 |
view.setStyle({'stick': {'radius': 0.1}, 'sphere': {'scale': 0.3}})
|
124 |
elif style.lower() == 'stick':
|
@@ -126,80 +627,59 @@ def get_structure_viz2(atoms_obj, style='stick', show_unit_cell=True, width=400,
|
|
126 |
elif style.lower() == 'ball':
|
127 |
view.setStyle({'sphere': {'scale': 0.4}})
|
128 |
else:
|
129 |
-
# Default to stick if unknown style
|
130 |
view.setStyle({'stick': {'radius': 0.15}})
|
131 |
|
132 |
-
|
133 |
-
if show_unit_cell and any(atoms_obj.pbc):
|
134 |
cell = atoms_obj.get_cell()
|
135 |
-
|
136 |
-
# Define unit cell edges
|
137 |
origin = np.array([0.0, 0.0, 0.0])
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
(cell[1], cell[1] + cell[2]), # c from b
|
153 |
-
(cell[0] + cell[1], cell[0] + cell[1] + cell[2]) # c from a+b
|
154 |
-
]
|
155 |
-
|
156 |
-
# Add unit cell lines
|
157 |
-
for start, end in edges:
|
158 |
-
view.addCylinder({
|
159 |
-
'start': {'x': start[0], 'y': start[1], 'z': start[2]},
|
160 |
-
'end': {'x': end[0], 'y': end[1], 'z': end[2]},
|
161 |
-
'radius': 0.05,
|
162 |
-
'color': 'black',
|
163 |
-
'alpha': 0.7
|
164 |
-
})
|
165 |
-
|
166 |
view.zoomTo()
|
167 |
view.setBackgroundColor('white')
|
168 |
-
|
169 |
return view
|
170 |
|
|
|
|
|
171 |
|
172 |
-
# Custom logger that updates the table
|
173 |
def streamlit_log(opt):
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
|
|
|
|
|
|
|
|
186 |
def check_atom_limit(atoms_obj, selected_model):
|
187 |
if atoms_obj is None:
|
188 |
return True
|
189 |
-
|
190 |
num_atoms = len(atoms_obj)
|
191 |
-
if ('UMA' in selected_model or 'ESEN MD' in selected_model)
|
192 |
-
|
193 |
-
st.
|
194 |
-
return False
|
195 |
-
if num_atoms > MAX_ATOMS_CLOUD:
|
196 |
-
st.error(f"⚠️ Error: Your structure contains {num_atoms} atoms, which exceeds the {MAX_ATOMS_CLOUD} atom limit for Streamlit Cloud deployments. For larger systems, please download the repository from GitHub and run it locally on your machine where no atom limit applies.")
|
197 |
-
st.info("💡 Running locally allows you to process much larger structures and use your own computational resources more efficiently.")
|
198 |
return False
|
199 |
return True
|
200 |
|
201 |
-
|
202 |
-
# Define the available MACE models
|
203 |
MACE_MODELS = {
|
204 |
"MACE MPA Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model",
|
205 |
"MACE OMAT Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_omat_0/mace-omat-0-medium.model",
|
@@ -210,13 +690,12 @@ MACE_MODELS = {
|
|
210 |
"MACE MP 0a Large": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2024-01-07-mace-128-L2_epoch-199.model",
|
211 |
"MACE MP 0b Small": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b/mace_agnesi_small.model",
|
212 |
"MACE MP 0b Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b/mace_agnesi_medium.model",
|
213 |
-
"MACE MP 0b2 Small": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b2/mace-
|
214 |
"MACE MP 0b2 Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b2/mace-medium-density-agnesi-stress.model",
|
215 |
"MACE MP 0b2 Large": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b2/mace-large-density-agnesi-stress.model",
|
216 |
"MACE MP 0b3 Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b3/mace-mp-0b3-medium.model",
|
217 |
}
|
218 |
|
219 |
-
# Define the available FairChem models
|
220 |
FAIRCHEM_MODELS = {
|
221 |
"UMA Small": "uma-sm",
|
222 |
"ESEN MD Direct All OMOL": "esen-md-direct-all-omol",
|
@@ -224,54 +703,40 @@ FAIRCHEM_MODELS = {
|
|
224 |
"ESEN SM Direct All OMOL": "esen-sm-direct-all-omol"
|
225 |
}
|
226 |
|
227 |
-
|
228 |
def get_mace_model(model_path, device, selected_default_dtype):
|
229 |
-
# Create a model of the specified type.
|
230 |
return mace_mp(model=model_path, device=device, default_dtype=selected_default_dtype)
|
231 |
|
232 |
-
|
233 |
-
def get_fairchem_model(
|
234 |
-
predictor = pretrained_mlip.get_predict_unit(
|
235 |
-
if
|
236 |
-
calc = FAIRChemCalculator(predictor, task_name=
|
237 |
else:
|
238 |
calc = FAIRChemCalculator(predictor)
|
239 |
return calc
|
240 |
|
241 |
-
# Sidebar for file input and parameters
|
242 |
st.sidebar.markdown("## Input Options")
|
243 |
-
|
244 |
-
# Input method selection
|
245 |
input_method = st.sidebar.radio("Choose Input Method:", ["Select Example", "Upload File", "Paste Content"])
|
246 |
-
|
247 |
-
# Initialize atoms variable
|
248 |
atoms = None
|
249 |
|
250 |
-
# File upload option
|
251 |
if input_method == "Upload File":
|
252 |
-
uploaded_file = st.sidebar.file_uploader("Upload structure file",
|
253 |
-
|
254 |
-
|
255 |
-
if uploaded_file is not None:
|
256 |
-
# Create a temporary file to save the uploaded content
|
257 |
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
|
258 |
tmp_file.write(uploaded_file.getvalue())
|
259 |
tmp_filepath = tmp_file.name
|
260 |
-
|
261 |
try:
|
262 |
-
# Read the structure using ASE
|
263 |
atoms = read(tmp_filepath)
|
264 |
st.sidebar.success(f"Successfully loaded structure with {len(atoms)} atoms!")
|
265 |
except Exception as e:
|
266 |
st.sidebar.error(f"Error loading file: {str(e)}")
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
|
271 |
-
# Example structure selection
|
272 |
elif input_method == "Select Example":
|
273 |
example_name = st.sidebar.selectbox("Select Example Structure:", list(SAMPLE_STRUCTURES.keys()))
|
274 |
-
|
275 |
if example_name:
|
276 |
file_path = os.path.join(SAMPLE_DIR, SAMPLE_STRUCTURES[example_name])
|
277 |
try:
|
@@ -280,139 +745,106 @@ elif input_method == "Select Example":
|
|
280 |
except Exception as e:
|
281 |
st.sidebar.error(f"Error loading example: {str(e)}")
|
282 |
|
283 |
-
# Paste content option
|
284 |
elif input_method == "Paste Content":
|
285 |
-
file_format = st.sidebar.selectbox("File Format:",
|
286 |
-
["XYZ", "CIF", "extXYZ", "POSCAR (VASP)", "Turbomole", "MOL"])
|
287 |
-
|
288 |
content = st.sidebar.text_area("Paste file content here:", height=200)
|
289 |
-
|
290 |
-
if content: #and st.sidebar.button("Parse Content"):
|
291 |
try:
|
292 |
-
|
293 |
-
suffix_map = {"XYZ": ".xyz", "CIF": ".cif", "extXYZ": ".extxyz",
|
294 |
-
"POSCAR (VASP)": ".vasp", "Turbomole": ".tmol", "MOL": ".mol"}
|
295 |
-
|
296 |
suffix = suffix_map.get(file_format, ".xyz")
|
297 |
-
|
298 |
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
|
299 |
tmp_file.write(content.encode())
|
300 |
tmp_filepath = tmp_file.name
|
301 |
-
|
302 |
-
# Read the structure using ASE
|
303 |
atoms = read(tmp_filepath)
|
304 |
st.sidebar.success(f"Successfully parsed structure with {len(atoms)} atoms!")
|
305 |
-
|
306 |
-
# Clean up the temporary file
|
307 |
-
os.unlink(tmp_filepath)
|
308 |
except Exception as e:
|
309 |
st.sidebar.error(f"Error parsing content: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
310 |
|
311 |
-
atoms.info["charge"] = 0
|
312 |
-
atoms.info["spin"] = 0
|
313 |
-
# Model selection
|
314 |
st.sidebar.markdown("## Model Selection")
|
315 |
model_type = st.sidebar.radio("Select Model Type:", ["MACE", "FairChem"])
|
316 |
|
317 |
-
selected_task_type = None
|
318 |
if model_type == "MACE":
|
319 |
selected_model = st.sidebar.selectbox("Select MACE Model:", list(MACE_MODELS.keys()))
|
320 |
model_path = MACE_MODELS[selected_model]
|
321 |
if selected_model == "MACE OMAT Medium":
|
322 |
-
st.sidebar.warning("Using model under Academic Software License (ASL)
|
323 |
selected_default_dtype = st.sidebar.selectbox("Select Precision (default_dtype):", ['float32', 'float64'])
|
324 |
if model_type == "FairChem":
|
325 |
selected_model = st.sidebar.selectbox("Select FairChem Model:", list(FAIRCHEM_MODELS.keys()))
|
326 |
model_path = FAIRCHEM_MODELS[selected_model]
|
327 |
if selected_model == "UMA Small":
|
328 |
-
st.sidebar.warning("Meta FAIR Acceptable Use Policy
|
329 |
selected_task_type = st.sidebar.selectbox("Select UMA Model Task Type:", ["omol", "omat", "omc", "odac", "oc20"])
|
330 |
-
if selected_task_type == "omol":
|
331 |
-
charge = st.sidebar.number_input("Total Charge", min_value=-10, max_value=10, value=0)
|
332 |
-
spin_multiplicity = st.sidebar.number_input("Spin Multiplicity (2S + 1)", min_value=1, max_value=11)
|
333 |
-
# Set the total charge and spin multiplicity if using the OMol task
|
334 |
atoms.info["charge"] = charge
|
335 |
-
atoms.info["spin"] = spin_multiplicity
|
336 |
-
|
337 |
if atoms is not None:
|
338 |
-
check_atom_limit(atoms, selected_model)
|
339 |
-
|
340 |
-
|
341 |
-
device = st.sidebar.radio("Computation Device:", ["CPU", "CUDA (GPU)"],
|
342 |
-
index=0 if not torch.cuda.is_available() else 1)
|
343 |
device = "cuda" if device == "CUDA (GPU)" and torch.cuda.is_available() else "cpu"
|
344 |
|
345 |
if device == "cpu" and torch.cuda.is_available():
|
346 |
-
st.sidebar.info("GPU is available but CPU was selected.
|
347 |
elif device == "cpu" and not torch.cuda.is_available():
|
348 |
-
st.sidebar.info("No GPU detected. Using CPU
|
349 |
|
350 |
-
# Task selection
|
351 |
st.sidebar.markdown("## Task Selection")
|
352 |
task = st.sidebar.selectbox("Select Calculation Task:",
|
353 |
["Energy Calculation",
|
354 |
"Energy + Forces Calculation",
|
|
|
355 |
"Geometry Optimization",
|
356 |
"Cell + Geometry Optimization"])
|
357 |
|
358 |
-
# Optimization parameters
|
359 |
if "Optimization" in task:
|
360 |
st.sidebar.markdown("### Optimization Parameters")
|
361 |
-
max_steps = st.sidebar.slider("Maximum Steps:", min_value=10, max_value=
|
362 |
-
fmax = st.sidebar.slider("Convergence Threshold (eV/Å):",
|
363 |
-
|
364 |
-
|
365 |
|
366 |
-
# Main content area
|
367 |
if atoms is not None:
|
368 |
col1, col2 = st.columns(2)
|
369 |
|
370 |
with col1:
|
371 |
st.markdown('### Structure Visualization', unsafe_allow_html=True)
|
|
|
|
|
372 |
|
373 |
-
# Generate visualization
|
374 |
-
def get_structure_viz(atoms_obj):
|
375 |
-
# Convert atoms to XYZ format
|
376 |
-
xyz_str = ""
|
377 |
-
xyz_str += f"{len(atoms_obj)}\n"
|
378 |
-
xyz_str += "Structure\n"
|
379 |
-
for atom in atoms_obj:
|
380 |
-
xyz_str += f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}\n"
|
381 |
-
|
382 |
-
# Create a py3Dmol visualization
|
383 |
-
view = py3Dmol.view(width=400, height=400)
|
384 |
-
view.addModel(xyz_str, "xyz")
|
385 |
-
view.setStyle({'stick': {}})
|
386 |
-
view.zoomTo()
|
387 |
-
view.setBackgroundColor('white')
|
388 |
-
|
389 |
-
return view
|
390 |
-
|
391 |
-
# Display the 3D structure
|
392 |
-
view = get_structure_viz2(atoms, style='stick', show_unit_cell=True, width=400, height=400)
|
393 |
-
# view = get_structure_viz(atoms)
|
394 |
-
html_str = view._make_html()
|
395 |
-
st.components.v1.html(html_str, width=400, height=400)
|
396 |
-
|
397 |
-
# Display structure information
|
398 |
st.markdown("### Structure Information")
|
399 |
atoms_info = {
|
400 |
"Number of Atoms": len(atoms),
|
401 |
"Chemical Formula": atoms.get_chemical_formula(),
|
402 |
-
"
|
403 |
-
"
|
|
|
404 |
}
|
405 |
-
|
406 |
for key, value in atoms_info.items():
|
407 |
st.write(f"**{key}:** {value}")
|
408 |
|
409 |
with col2:
|
410 |
st.markdown('## Calculation Setup', unsafe_allow_html=True)
|
411 |
-
|
412 |
-
# Display calculation details
|
413 |
st.markdown("### Selected Model")
|
414 |
st.write(f"**Model Type:** {model_type}")
|
415 |
st.write(f"**Model:** {selected_model}")
|
|
|
|
|
416 |
st.write(f"**Device:** {device}")
|
417 |
|
418 |
st.markdown("### Selected Task")
|
@@ -421,172 +853,226 @@ if atoms is not None:
|
|
421 |
if "Optimization" in task:
|
422 |
st.write(f"**Max Steps:** {max_steps}")
|
423 |
st.write(f"**Convergence Threshold:** {fmax} eV/Å")
|
424 |
-
st.write(f"**Optimizer:** {
|
425 |
|
426 |
-
# Run calculation button
|
427 |
run_calculation = st.button("Run Calculation", type="primary")
|
428 |
|
429 |
if run_calculation:
|
|
|
|
|
|
|
|
|
|
|
|
|
430 |
try:
|
431 |
-
with st.spinner("Running calculation..."):
|
432 |
-
# Copy atoms to avoid modifying the original
|
433 |
calc_atoms = atoms.copy()
|
434 |
|
435 |
-
# Set up calculator based on selected model
|
436 |
if model_type == "MACE":
|
437 |
-
st.write("Setting up MACE calculator...")
|
438 |
calc = get_mace_model(model_path, device, selected_default_dtype)
|
439 |
else: # FairChem
|
440 |
-
st.write("Setting up FairChem calculator...")
|
441 |
-
#
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
calc = get_fairchem_model(selected_model, model_path, device, selected_task_type)
|
446 |
-
# Attach calculator to atoms
|
447 |
-
calc_atoms.calc = calc
|
448 |
|
449 |
-
|
450 |
-
results = {}
|
451 |
|
452 |
if task == "Energy Calculation":
|
453 |
-
# Calculate energy
|
454 |
energy = calc_atoms.get_potential_energy()
|
455 |
results["Energy"] = f"{energy:.6f} eV"
|
456 |
|
457 |
elif task == "Energy + Forces Calculation":
|
458 |
-
# Calculate energy and forces
|
459 |
energy = calc_atoms.get_potential_energy()
|
460 |
forces = calc_atoms.get_forces()
|
461 |
-
max_force = np.max(np.sqrt(np.sum(forces**2, axis=1)))
|
462 |
-
|
463 |
results["Energy"] = f"{energy:.6f} eV"
|
464 |
results["Maximum Force"] = f"{max_force:.6f} eV/Å"
|
465 |
-
|
466 |
-
elif task == "
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
485 |
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
|
|
|
|
|
490 |
|
491 |
-
|
492 |
-
results["Final Maximum Force"] = f"{max_force:.6f} eV/Å"
|
493 |
-
results["Steps Taken"] = opt.get_number_of_steps()
|
494 |
-
results["Converged"] = "Yes" if opt.converged() else "No"
|
495 |
-
|
496 |
-
elif task == "Cell + Geometry Optimization":
|
497 |
-
# Set up optimizer with FrechetCellFilter
|
498 |
-
fcf = FrechetCellFilter(calc_atoms)
|
499 |
|
500 |
-
|
501 |
-
opt = BFGS(fcf)
|
502 |
-
elif optimizer == "LBFGS":
|
503 |
-
opt = LBFGS(fcf)
|
504 |
-
else: # FIRE
|
505 |
-
opt = FIRE(fcf)
|
506 |
-
|
507 |
-
# Streamlit placeholder for live-updating table
|
508 |
-
table_placeholder = st.empty()
|
509 |
-
|
510 |
-
# Container for log data
|
511 |
-
opt_log = []
|
512 |
-
# Attach the Streamlit logger to the optimizer
|
513 |
-
opt.attach(lambda: streamlit_log(opt), interval=1)
|
514 |
-
# Run optimization
|
515 |
-
st.write("Running cell + geometry optimization...")
|
516 |
opt.run(fmax=fmax, steps=max_steps)
|
517 |
|
518 |
-
# Get results
|
519 |
energy = calc_atoms.get_potential_energy()
|
520 |
forces = calc_atoms.get_forces()
|
521 |
-
max_force = np.max(np.sqrt(np.sum(forces**2, axis=1)))
|
522 |
|
523 |
results["Final Energy"] = f"{energy:.6f} eV"
|
524 |
results["Final Maximum Force"] = f"{max_force:.6f} eV/Å"
|
525 |
results["Steps Taken"] = opt.get_number_of_steps()
|
526 |
results["Converged"] = "Yes" if opt.converged() else "No"
|
527 |
-
|
|
|
528 |
|
529 |
-
# Show results
|
530 |
st.success("Calculation completed successfully!")
|
531 |
st.markdown("### Results")
|
532 |
for key, value in results.items():
|
533 |
st.write(f"**{key}:** {value}")
|
534 |
|
535 |
-
|
536 |
-
if "Optimization" in task:
|
537 |
st.markdown("### Optimized Structure")
|
538 |
-
|
539 |
-
|
540 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
541 |
|
542 |
-
|
543 |
-
|
544 |
-
|
545 |
-
write(tmp_file.name, calc_atoms)
|
546 |
-
tmp_filepath = tmp_file.name
|
547 |
|
548 |
-
|
549 |
-
|
550 |
-
xyz_content = file.read()
|
551 |
|
552 |
st.download_button(
|
553 |
label="Download Optimized Structure (XYZ)",
|
554 |
-
data=
|
555 |
file_name="optimized_structure.xyz",
|
556 |
mime="chemical/x-xyz"
|
557 |
)
|
558 |
-
|
559 |
-
# Clean up the temp file
|
560 |
-
os.unlink(tmp_filepath)
|
561 |
|
562 |
except Exception as e:
|
563 |
-
st.error(f"Calculation error: {str(e)}")
|
564 |
-
st.error("Please
|
|
|
|
|
|
|
565 |
else:
|
566 |
-
|
567 |
-
st.info("Please select a structure using the sidebar options to begin.")
|
568 |
-
|
569 |
|
570 |
-
# Footer
|
571 |
st.markdown("---")
|
572 |
-
with st.expander('
|
573 |
-
# Show some information about the app
|
574 |
st.write("""
|
575 |
-
Test, compare and benchmark universal machine learning interatomic potentials (MLIPs)
|
576 |
-
This
|
|
|
577 |
|
578 |
-
|
579 |
-
- Upload structure files (XYZ, CIF, POSCAR, etc.) or
|
580 |
-
-
|
581 |
-
-
|
582 |
-
-
|
583 |
-
-
|
584 |
|
585 |
-
|
586 |
-
1.
|
587 |
-
2.
|
588 |
-
3. Select a calculation task
|
589 |
-
4. Run
|
|
|
|
|
|
|
|
|
|
|
|
|
590 |
""")
|
591 |
st.markdown("Universal MLIP Playground App | Created with Streamlit, ASE, MACE, FairChem and ❤️")
|
592 |
-
st.markdown("
|
|
|
|
|
|
1 |
+
Okay, I can help you modify the Streamlit application to include a task for calculating atomization or cohesive energy.
|
2 |
+
|
3 |
+
Here's a breakdown of the changes and the modified code:
|
4 |
+
|
5 |
+
## Summary of Modifications:
|
6 |
+
|
7 |
+
1. **Import `yaml`**: To parse the reference energy data for FairChem models.
|
8 |
+
2. **Define `ELEMENT_REF_ENERGIES_YAML`**: Added the YAML string containing reference energies for isolated atoms for FairChem models. This is parsed into the `ELEMENT_REF_ENERGIES` dictionary.
|
9 |
+
* **Note**: The YAML lists for `oc20_elem_refs`, `odac_elem_refs`, `omat_elem_refs`, and `omc_elem_refs` in your provided snippet are incomplete. I've added a few `0.0` placeholders to make them valid lists, but for accurate calculations, these should contain the correct reference energies for all relevant elements. The `omol_elem_refs` list from your example is more populated and used as such. The code will warn if a reference energy for a specific element is not found in the loaded lists.
|
10 |
+
3. **Update Task Selection**: Added "Atomization/Cohesive Energy" to the `st.sidebar.selectbox` for calculation tasks.
|
11 |
+
4. **Implement Atomization/Cohesive Energy Calculation Logic**:
|
12 |
+
* When the "Atomization/Cohesive Energy" task is selected:
|
13 |
+
* The total energy of the system (`E_system`) is calculated.
|
14 |
+
* The code determines if the system is periodic (for cohesive energy) or not (for atomization energy).
|
15 |
+
* **For FairChem Models**:
|
16 |
+
* It selects the appropriate list of reference energies from `ELEMENT_REF_ENERGIES` based on whether the model is "UMA Small" (using `selected_task_type` like `omol`, `omat`, etc.) or an "ESEN" model (which defaults to `omol_elem_refs`).
|
17 |
+
* It sums the reference energies for each atom in the system. A warning is issued if any element's reference energy is not found in the specified list (treating missing energies as 0).
|
18 |
+
* **For MACE Models**:
|
19 |
+
* It calculates the energy of each unique type of isolated atom by creating a single-atom `ase.Atoms` object (non-periodic) and using the already initialized MACE calculator.
|
20 |
+
* These individual atomic energies are summed, weighted by their counts in the system. A progress bar is shown for this step.
|
21 |
+
* **Final Calculation**:
|
22 |
+
* Atomization Energy: $E_{\text{atomization}} = \sum E_{\text{isolated atoms}} - E_{\text{system}}$
|
23 |
+
* Cohesive Energy: $E_{\text{cohesive}} = (\sum E_{\text{isolated atoms}} - E_{\text{system}}) / N_{\text{atoms}}$
|
24 |
+
* The results, including system energy and total isolated atom energy, are displayed.
|
25 |
+
* Error handling is included for cases like systems with zero atoms or missing reference energy lists for FairChem.
|
26 |
+
|
27 |
+
## Modified Code:
|
28 |
+
|
29 |
+
```python
|
30 |
import streamlit as st
|
31 |
import os
|
32 |
import tempfile
|
|
|
42 |
from mace.calculators import mace_mp
|
43 |
from fairchem.core import pretrained_mlip, FAIRChemCalculator
|
44 |
import pandas as pd
|
45 |
+
import yaml # Added for FairChem reference energies
|
46 |
|
47 |
from huggingface_hub import login
|
48 |
|
|
|
53 |
# except Exception as e:
|
54 |
# print("streamlit hf secret not defined/assigned")
|
55 |
try:
|
56 |
+
hf_token = os.getenv("YOUR SECRET KEY") # Replace with your actual Hugging Face token or manage secrets appropriately
|
57 |
+
if hf_token:
|
58 |
+
login(token = hf_token)
|
59 |
+
else:
|
60 |
+
print("Hugging Face token not found. Some models might not be accessible.")
|
61 |
except Exception as e:
|
62 |
+
print(f"hf login error: {e}")
|
63 |
+
|
64 |
|
|
|
65 |
os.environ["STREAMLIT_WATCHER_TYPE"] = "none"
|
66 |
|
67 |
+
# YAML data for FairChem reference energies
|
68 |
+
ELEMENT_REF_ENERGIES_YAML = """
|
69 |
+
oc20_elem_refs:
|
70 |
+
- 0.0
|
71 |
+
- -0.16141512
|
72 |
+
- 0.03262098
|
73 |
+
- -0.04787699
|
74 |
+
- -0.06299825
|
75 |
+
- -0.14979306
|
76 |
+
- -0.11657468
|
77 |
+
- -0.10862579
|
78 |
+
- -0.10298174
|
79 |
+
- -0.03420248
|
80 |
+
- 0.02673997
|
81 |
+
- -0.03729558
|
82 |
+
- 0.00515243
|
83 |
+
- -0.07535697
|
84 |
+
- -0.13663351
|
85 |
+
- -0.12922852
|
86 |
+
- -0.11796547
|
87 |
+
- -0.07802946
|
88 |
+
- -0.00672682
|
89 |
+
- -0.04089589
|
90 |
+
- -0.00024177
|
91 |
+
- -1.74545186
|
92 |
+
- -1.54220241
|
93 |
+
- -1.0934019
|
94 |
+
- -1.16168372
|
95 |
+
- -1.23073475
|
96 |
+
- -0.78852824
|
97 |
+
- -0.71851599
|
98 |
+
- -0.52465053
|
99 |
+
- -0.02692092
|
100 |
+
- -0.00317922
|
101 |
+
- -0.06266862
|
102 |
+
- -0.10835274
|
103 |
+
- -0.12394474
|
104 |
+
- -0.11351727
|
105 |
+
- -0.07455817
|
106 |
+
- -0.00258354
|
107 |
+
- -0.04111325
|
108 |
+
- -0.02090265
|
109 |
+
- -1.89306078
|
110 |
+
- -1.30591887
|
111 |
+
- -0.63320009
|
112 |
+
- -0.26230344
|
113 |
+
- -0.2633669
|
114 |
+
- -0.5160055
|
115 |
+
- -0.95950798
|
116 |
+
- -1.45589361
|
117 |
+
- -0.0429969
|
118 |
+
- -0.00026949
|
119 |
+
- -0.05925609
|
120 |
+
- -0.09734631
|
121 |
+
- -0.12406852
|
122 |
+
- -0.11427538
|
123 |
+
- -0.07021442
|
124 |
+
- 0.01091345
|
125 |
+
- -0.05305289
|
126 |
+
- -0.02427209
|
127 |
+
- -0.19975668
|
128 |
+
- -1.71692859
|
129 |
+
- -1.53677781
|
130 |
+
- -3.89987009
|
131 |
+
- -10.70940462
|
132 |
+
- -6.71693816
|
133 |
+
- -0.28102249
|
134 |
+
- -8.86944824
|
135 |
+
- -7.95762687
|
136 |
+
- -7.13041437
|
137 |
+
- -6.64620014
|
138 |
+
- -5.11482482
|
139 |
+
- -4.42548227
|
140 |
+
- 0.00848295
|
141 |
+
- -0.06956227
|
142 |
+
- -2.6748853
|
143 |
+
- -2.21153293
|
144 |
+
- -1.67367741
|
145 |
+
- -1.07636151
|
146 |
+
- -0.79009981
|
147 |
+
- -0.16387243
|
148 |
+
- -0.18164401
|
149 |
+
- -0.04122529
|
150 |
+
- -0.00041833
|
151 |
+
- -0.05259382
|
152 |
+
- -0.0934314
|
153 |
+
- -0.11023834
|
154 |
+
- -0.10039175
|
155 |
+
- -0.06069209
|
156 |
+
- 0.01790437
|
157 |
+
- -0.04694024
|
158 |
+
- 0.00334084
|
159 |
+
- -0.06030621
|
160 |
+
- -0.58793619
|
161 |
+
- -1.27821808
|
162 |
+
- -4.97483577
|
163 |
+
- -5.66985655
|
164 |
+
- -8.43154622
|
165 |
+
- -11.15001317
|
166 |
+
- -12.95770812
|
167 |
+
- 0.0
|
168 |
+
- -14.47602729
|
169 |
+
- 0.0
|
170 |
+
odac_elem_refs:
|
171 |
+
- 0.0
|
172 |
+
- -1.11737936
|
173 |
+
- -0.00011835
|
174 |
+
- -0.2941727
|
175 |
+
- -0.03868426
|
176 |
+
- -0.34862832
|
177 |
+
- -1.31552566
|
178 |
+
- -3.12457285
|
179 |
+
- -1.6052078
|
180 |
+
- -0.49653389
|
181 |
+
- -0.01137327
|
182 |
+
- -0.21957281
|
183 |
+
- -0.0008343
|
184 |
+
- -0.2750172
|
185 |
+
- -0.88417265
|
186 |
+
- -1.887378
|
187 |
+
- -0.94903558
|
188 |
+
- -0.31628167
|
189 |
+
- -0.02014536
|
190 |
+
- -0.15901053
|
191 |
+
- -0.00731884
|
192 |
+
- -1.96521355
|
193 |
+
- -1.89045209
|
194 |
+
- -2.53057428
|
195 |
+
- -5.43600675
|
196 |
+
- -5.09739336
|
197 |
+
- -3.03088746
|
198 |
+
- -1.23786562
|
199 |
+
- -0.40650749
|
200 |
+
- -0.2416017
|
201 |
+
- -0.01139188
|
202 |
+
- -0.26282496
|
203 |
+
- -0.82446455
|
204 |
+
- -1.70237206
|
205 |
+
- -0.84245376
|
206 |
+
- -0.28544892
|
207 |
+
- -0.02239991
|
208 |
+
- -0.14115912
|
209 |
+
- -0.02840799
|
210 |
+
- -2.09540994
|
211 |
+
- -1.85863996
|
212 |
+
- -1.12257399
|
213 |
+
- -4.32965355
|
214 |
+
- -3.30670045
|
215 |
+
- -1.19460755
|
216 |
+
- -1.26257601
|
217 |
+
- -1.46832888
|
218 |
+
- -0.19779414
|
219 |
+
- -0.0144274
|
220 |
+
- -0.23668767
|
221 |
+
- -0.70836953
|
222 |
+
- -1.43186113
|
223 |
+
- -0.71701186
|
224 |
+
- -0.24883129
|
225 |
+
- -0.01118184
|
226 |
+
- -0.13173447
|
227 |
+
- -0.0318395
|
228 |
+
- -0.41195547
|
229 |
+
- -1.23134873
|
230 |
+
- -2.03082996
|
231 |
+
- 0.1375954
|
232 |
+
- -5.45866275
|
233 |
+
- -7.59139905
|
234 |
+
- -5.99965965
|
235 |
+
- -8.43495767
|
236 |
+
- -2.6578407
|
237 |
+
- -7.77349787
|
238 |
+
- -5.30762201
|
239 |
+
- -5.15109657
|
240 |
+
- -4.41466995
|
241 |
+
- -0.02995219
|
242 |
+
- -0.2544495
|
243 |
+
- -3.23821202
|
244 |
+
- -3.45887214
|
245 |
+
- -4.53635003
|
246 |
+
- -4.60979468
|
247 |
+
- -2.90707964
|
248 |
+
- -1.28286153
|
249 |
+
- -0.57716664
|
250 |
+
- -0.18337108
|
251 |
+
- -0.01135944
|
252 |
+
- -0.22045398
|
253 |
+
- -0.66150479
|
254 |
+
- -1.32506342
|
255 |
+
- -0.66500178
|
256 |
+
- -0.22643927
|
257 |
+
- -0.00728197
|
258 |
+
- -0.11208472
|
259 |
+
- -0.00757856
|
260 |
+
- -0.21798637
|
261 |
+
- -0.91078787
|
262 |
+
- -1.78187161
|
263 |
+
- -3.89912261
|
264 |
+
- -3.94192659
|
265 |
+
- -7.59026042
|
266 |
+
- 0.0
|
267 |
+
- 0.0
|
268 |
+
- 0.0
|
269 |
+
- 0.0
|
270 |
+
- 0.0
|
271 |
+
omat_elem_refs:
|
272 |
+
- 0.0
|
273 |
+
- -1.11700253
|
274 |
+
- 0.00079886
|
275 |
+
- -0.29731164
|
276 |
+
- -0.04129868
|
277 |
+
- -0.29106192
|
278 |
+
- -1.27751531
|
279 |
+
- -3.12342715
|
280 |
+
- -1.54797136
|
281 |
+
- -0.43969356
|
282 |
+
- -0.01250908
|
283 |
+
- -0.22855413
|
284 |
+
- -0.00943179
|
285 |
+
- -0.21707638
|
286 |
+
- -0.82619133
|
287 |
+
- -1.88667434
|
288 |
+
- -0.89093583
|
289 |
+
- -0.25816211
|
290 |
+
- -0.02414768
|
291 |
+
- -0.17662425
|
292 |
+
- -0.02568319
|
293 |
+
- -2.13001165
|
294 |
+
- -2.38688845
|
295 |
+
- -3.55934233
|
296 |
+
- -5.44700879
|
297 |
+
- -5.14749562
|
298 |
+
- -3.30662847
|
299 |
+
- -1.42167737
|
300 |
+
- -0.63181379
|
301 |
+
- -0.23449167
|
302 |
+
- -0.01146636
|
303 |
+
- -0.21291259
|
304 |
+
- -0.77939897
|
305 |
+
- -1.70148487
|
306 |
+
- -0.78386705
|
307 |
+
- -0.22690657
|
308 |
+
- -0.02245409
|
309 |
+
- -0.16092396
|
310 |
+
- -0.02798717
|
311 |
+
- -2.25685695
|
312 |
+
- -2.23690495
|
313 |
+
- -2.15347771
|
314 |
+
- -4.60251809
|
315 |
+
- -3.36416792
|
316 |
+
- -2.23062607
|
317 |
+
- -1.15550917
|
318 |
+
- -1.47553527
|
319 |
+
- -0.19918102
|
320 |
+
- -0.01475888
|
321 |
+
- -0.19767692
|
322 |
+
- -0.68005773
|
323 |
+
- -1.43073368
|
324 |
+
- -0.65790462
|
325 |
+
- -0.18915279
|
326 |
+
- -0.01179476
|
327 |
+
- -0.13507902
|
328 |
+
- -0.03056979
|
329 |
+
- -0.36017439
|
330 |
+
- -0.86279246
|
331 |
+
- -0.20573327
|
332 |
+
- -0.2734463
|
333 |
+
- -0.20046965
|
334 |
+
- -0.25444338
|
335 |
+
- -8.37972664
|
336 |
+
- -9.58424928
|
337 |
+
- -0.19466184
|
338 |
+
- -0.24860115
|
339 |
+
- -0.19531288
|
340 |
+
- -0.15401392
|
341 |
+
- -0.14577898
|
342 |
+
- -0.19655747
|
343 |
+
- -0.15645898
|
344 |
+
- -3.49380556
|
345 |
+
- -3.5317097
|
346 |
+
- -4.57108006
|
347 |
+
- -4.63425205
|
348 |
+
- -2.88247063
|
349 |
+
- -1.45679675
|
350 |
+
- -0.50290184
|
351 |
+
- -0.18521704
|
352 |
+
- -0.01123956
|
353 |
+
- -0.17483649
|
354 |
+
- -0.63132037
|
355 |
+
- -1.3248562
|
356 |
+
- 0.0
|
357 |
+
- 0.0
|
358 |
+
- 0.0
|
359 |
+
- 0.0
|
360 |
+
- 0.0
|
361 |
+
- -0.24135757
|
362 |
+
- -1.04601971
|
363 |
+
- -2.04574044
|
364 |
+
- -3.84544799
|
365 |
+
- -7.28626119
|
366 |
+
- -7.3136314
|
367 |
+
- 0.0
|
368 |
+
- 0.0
|
369 |
+
- 0.0
|
370 |
+
- 0.0
|
371 |
+
- 0.0
|
372 |
+
omol_elem_refs:
|
373 |
+
- 0.0
|
374 |
+
- -13.44558
|
375 |
+
- -78.82027
|
376 |
+
- -203.32564
|
377 |
+
- -398.94742
|
378 |
+
- -670.75275
|
379 |
+
- -1029.85403
|
380 |
+
- -1485.54188
|
381 |
+
- -2042.97832
|
382 |
+
- -2714.24015
|
383 |
+
- -3508.74317
|
384 |
+
- -4415.24203
|
385 |
+
- -5443.89712
|
386 |
+
- -6594.61834
|
387 |
+
- -7873.6878
|
388 |
+
- -9285.6593
|
389 |
+
- -10832.62132
|
390 |
+
- -12520.66852
|
391 |
+
- -14354.278
|
392 |
+
- -16323.54671
|
393 |
+
- -18436.47845
|
394 |
+
- -20696.18244
|
395 |
+
- -23110.5386
|
396 |
+
- -25682.99429
|
397 |
+
- -28418.37804
|
398 |
+
- -31317.92317
|
399 |
+
- -34383.42519
|
400 |
+
- -37623.46835
|
401 |
+
- -41039.92413
|
402 |
+
- -44637.38634
|
403 |
+
- -48417.14864
|
404 |
+
- -52373.87849
|
405 |
+
- -56512.76952
|
406 |
+
- -60836.14871
|
407 |
+
- -65344.28833
|
408 |
+
- -70041.24251
|
409 |
+
- -74929.56277
|
410 |
+
- -653.64777
|
411 |
+
- -833.31922
|
412 |
+
- -1038.0281
|
413 |
+
- -1273.96788
|
414 |
+
- -1542.45481
|
415 |
+
- -1850.74158
|
416 |
+
- -2193.91654
|
417 |
+
- -2577.18734
|
418 |
+
- -3004.13604
|
419 |
+
- -3477.52796
|
420 |
+
- -3997.31825
|
421 |
+
- -4563.75804
|
422 |
+
- -5171.82293
|
423 |
+
- -5828.85334
|
424 |
+
- -6535.61529
|
425 |
+
- -7291.54792
|
426 |
+
- -8099.87914
|
427 |
+
- -8962.17916
|
428 |
+
- -546.03214
|
429 |
+
- -690.6089
|
430 |
+
- -854.11237
|
431 |
+
- -12923.04096
|
432 |
+
- -14064.26124
|
433 |
+
- -15272.68689
|
434 |
+
- -16550.20551
|
435 |
+
- -17900.36515
|
436 |
+
- -19323.23406
|
437 |
+
- -20829.08848
|
438 |
+
- -22428.73258
|
439 |
+
- -24078.68008
|
440 |
+
- -25794.42097
|
441 |
+
- -27616.6819
|
442 |
+
- -29523.5526
|
443 |
+
- -31526.68012
|
444 |
+
- -33615.37779
|
445 |
+
- -1300.17791
|
446 |
+
- -1544.40924
|
447 |
+
- -1818.62298
|
448 |
+
- -2123.14417
|
449 |
+
- -2461.76028
|
450 |
+
- -2833.76287
|
451 |
+
- -3242.79895
|
452 |
+
- -3690.363
|
453 |
+
- -4174.99772
|
454 |
+
- -4691.75674
|
455 |
+
- -5245.36013
|
456 |
+
- -5838.12005
|
457 |
+
- -6469.07296
|
458 |
+
- -7140.86455
|
459 |
+
- -7854.60638
|
460 |
+
- 0.0
|
461 |
+
- 0.0
|
462 |
+
- 0.0
|
463 |
+
- 0.0
|
464 |
+
- 0.0
|
465 |
+
- 0.0
|
466 |
+
- 0.0
|
467 |
+
- 0.0
|
468 |
+
- 0.0
|
469 |
+
- 0.0
|
470 |
+
- 0.0
|
471 |
+
- 0.0
|
472 |
+
- 0.0
|
473 |
+
omc_elem_refs:
|
474 |
+
- 0.0
|
475 |
+
- -0.02831808
|
476 |
+
- 4.512e-05
|
477 |
+
- -0.03227157
|
478 |
+
- -0.03842519
|
479 |
+
- -0.05829283
|
480 |
+
- -0.0845041
|
481 |
+
- -0.08806738
|
482 |
+
- -0.09021346
|
483 |
+
- -0.06669846
|
484 |
+
- -0.01218631
|
485 |
+
- -0.03650269
|
486 |
+
- -0.00059093
|
487 |
+
- -0.05787736
|
488 |
+
- -0.08730952
|
489 |
+
- -0.0975534
|
490 |
+
- -0.09264199
|
491 |
+
- -0.07124762
|
492 |
+
- -0.02374602
|
493 |
+
- -0.05299112
|
494 |
+
- -0.02631476
|
495 |
+
- -1.7772147
|
496 |
+
- -1.25083444
|
497 |
+
- -0.79579447
|
498 |
+
- -0.49099317
|
499 |
+
- -0.31414986
|
500 |
+
- -0.20292182
|
501 |
+
- -0.14011632
|
502 |
+
- -0.09929659
|
503 |
+
- -0.03771207
|
504 |
+
- -0.01117902
|
505 |
+
- -0.06168715
|
506 |
+
- -0.08873364
|
507 |
+
- -0.09512942
|
508 |
+
- -0.09035978
|
509 |
+
- -0.06910849
|
510 |
+
- -0.02244872
|
511 |
+
- -0.05303651
|
512 |
+
- -0.02871903
|
513 |
+
- -1.94805417
|
514 |
+
- -1.33379896
|
515 |
+
- -0.69169331
|
516 |
+
- -0.26184306
|
517 |
+
- -0.20631599
|
518 |
+
- -0.48251608
|
519 |
+
- -0.96911893
|
520 |
+
- -1.47569462
|
521 |
+
- -0.03845194
|
522 |
+
- -0.0142445
|
523 |
+
- -0.07118991
|
524 |
+
- -0.09940292
|
525 |
+
- -0.09235056
|
526 |
+
- -0.08755943
|
527 |
+
- -0.06544925
|
528 |
+
- -0.01246646
|
529 |
+
- -0.04692937
|
530 |
+
- -0.03225123
|
531 |
+
- -0.26086039
|
532 |
+
- -27.20024339
|
533 |
+
- -0.08412926
|
534 |
+
- -0.08225924
|
535 |
+
- -0.07799715
|
536 |
+
- -0.07806185
|
537 |
+
- 0.00043759
|
538 |
+
- -0.07459766
|
539 |
+
- 0.0
|
540 |
+
- -0.06842841
|
541 |
+
- -0.07758266
|
542 |
+
- -0.07025152
|
543 |
+
- -0.08055003
|
544 |
+
- -0.07118177
|
545 |
+
- -0.07159568
|
546 |
+
- -2.69202862
|
547 |
+
- -2.21926765
|
548 |
+
- -1.679756
|
549 |
+
- -1.06135075
|
550 |
+
- -0.4554231
|
551 |
+
- -0.14488432
|
552 |
+
- -0.18377098
|
553 |
+
- -0.03603118
|
554 |
+
- -0.01076585
|
555 |
+
- -0.06381411
|
556 |
+
- -0.0905623
|
557 |
+
- -0.10095787
|
558 |
+
- -0.09501217
|
559 |
+
- -0.0574478
|
560 |
+
- -0.00599173
|
561 |
+
- -0.04134751
|
562 |
+
- -0.0082683
|
563 |
+
- -0.08704692
|
564 |
+
- -0.49656425
|
565 |
+
- -5.24233138
|
566 |
+
- -2.32542606
|
567 |
+
- -4.3376616
|
568 |
+
- -5.96430676
|
569 |
+
- 0.0
|
570 |
+
- 0.0
|
571 |
+
- -0.03842519
|
572 |
+
- 0.0
|
573 |
+
- 0.0
|
574 |
+
"""
|
575 |
+
try:
|
576 |
+
ELEMENT_REF_ENERGIES = yaml.safe_load(ELEMENT_REF_ENERGIES_YAML)
|
577 |
+
except yaml.YAMLError as e:
|
578 |
+
# st.error(f"Error parsing YAML reference energies: {e}") # st objects can only be used in main script flow
|
579 |
+
print(f"Error parsing YAML reference energies: {e}")
|
580 |
+
ELEMENT_REF_ENERGIES = {} # Fallback
|
581 |
+
|
582 |
# Check if running on Streamlit Cloud vs locally
|
583 |
is_streamlit_cloud = os.environ.get('STREAMLIT_RUNTIME_ENV') == 'cloud'
|
584 |
MAX_ATOMS_CLOUD = 500 # Maximum atoms allowed on Streamlit Cloud
|
|
|
591 |
layout="wide"
|
592 |
)
|
593 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
594 |
# Title and description
|
595 |
st.markdown('## MLIP Playground', unsafe_allow_html=True)
|
596 |
st.write('#### Run, test and compare >17 state-of-the-art universal machine learning interatomic potentials (MLIPs) for atomistic simulations of molecules and materials')
|
|
|
611 |
}
|
612 |
|
613 |
def get_structure_viz2(atoms_obj, style='stick', show_unit_cell=True, width=400, height=400):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
614 |
xyz_str = ""
|
615 |
xyz_str += f"{len(atoms_obj)}\n"
|
616 |
xyz_str += "Structure\n"
|
617 |
for atom in atoms_obj:
|
618 |
xyz_str += f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}\n"
|
619 |
|
|
|
620 |
view = py3Dmol.view(width=width, height=height)
|
621 |
view.addModel(xyz_str, "xyz")
|
622 |
|
|
|
623 |
if style.lower() == 'ball_stick':
|
624 |
view.setStyle({'stick': {'radius': 0.1}, 'sphere': {'scale': 0.3}})
|
625 |
elif style.lower() == 'stick':
|
|
|
627 |
elif style.lower() == 'ball':
|
628 |
view.setStyle({'sphere': {'scale': 0.4}})
|
629 |
else:
|
|
|
630 |
view.setStyle({'stick': {'radius': 0.15}})
|
631 |
|
632 |
+
if show_unit_cell and atoms_obj.pbc.any(): # Check pbc.any()
|
|
|
633 |
cell = atoms_obj.get_cell()
|
|
|
|
|
634 |
origin = np.array([0.0, 0.0, 0.0])
|
635 |
+
if cell is not None and cell.any(): # Ensure cell is not None and not all zeros
|
636 |
+
edges = [
|
637 |
+
(origin, cell[0]), (origin, cell[1]), (cell[0], cell[0] + cell[1]), (cell[1], cell[0] + cell[1]),
|
638 |
+
(cell[2], cell[2] + cell[0]), (cell[2], cell[2] + cell[1]),
|
639 |
+
(cell[2] + cell[0], cell[2] + cell[0] + cell[1]), (cell[2] + cell[1], cell[2] + cell[0] + cell[1]),
|
640 |
+
(origin, cell[2]), (cell[0], cell[0] + cell[2]), (cell[1], cell[1] + cell[2]),
|
641 |
+
(cell[0] + cell[1], cell[0] + cell[1] + cell[2])
|
642 |
+
]
|
643 |
+
for start, end in edges:
|
644 |
+
view.addCylinder({
|
645 |
+
'start': {'x': start[0], 'y': start[1], 'z': start[2]},
|
646 |
+
'end': {'x': end[0], 'y': end[1], 'z': end[2]},
|
647 |
+
'radius': 0.05, 'color': 'black', 'alpha': 0.7
|
648 |
+
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
649 |
view.zoomTo()
|
650 |
view.setBackgroundColor('white')
|
|
|
651 |
return view
|
652 |
|
653 |
+
opt_log = [] # Define globally or pass around if necessary
|
654 |
+
table_placeholder = st.empty() # Define globally if updated from callback
|
655 |
|
|
|
656 |
def streamlit_log(opt):
|
657 |
+
global opt_log, table_placeholder
|
658 |
+
try:
|
659 |
+
energy = opt.atoms.get_potential_energy()
|
660 |
+
forces = opt.atoms.get_forces()
|
661 |
+
fmax_step = np.max(np.linalg.norm(forces, axis=1)) if forces.shape[0] > 0 else 0.0
|
662 |
+
opt_log.append({
|
663 |
+
"Step": opt.nsteps,
|
664 |
+
"Energy (eV)": round(energy, 6),
|
665 |
+
"Fmax (eV/Å)": round(fmax_step, 6)
|
666 |
+
})
|
667 |
+
df = pd.DataFrame(opt_log)
|
668 |
+
table_placeholder.dataframe(df)
|
669 |
+
except Exception as e:
|
670 |
+
st.warning(f"Error in optimization logger: {e}")
|
671 |
+
|
672 |
+
|
673 |
def check_atom_limit(atoms_obj, selected_model):
|
674 |
if atoms_obj is None:
|
675 |
return True
|
|
|
676 |
num_atoms = len(atoms_obj)
|
677 |
+
limit = MAX_ATOMS_CLOUD_UMA if ('UMA' in selected_model or 'ESEN MD' in selected_model) else MAX_ATOMS_CLOUD
|
678 |
+
if num_atoms > limit:
|
679 |
+
st.error(f"⚠️ Error: Your structure contains {num_atoms} atoms, exceeding the {limit} atom limit for this model on Streamlit Cloud. Please run locally for larger systems.")
|
|
|
|
|
|
|
|
|
680 |
return False
|
681 |
return True
|
682 |
|
|
|
|
|
683 |
MACE_MODELS = {
|
684 |
"MACE MPA Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model",
|
685 |
"MACE OMAT Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_omat_0/mace-omat-0-medium.model",
|
|
|
690 |
"MACE MP 0a Large": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2024-01-07-mace-128-L2_epoch-199.model",
|
691 |
"MACE MP 0b Small": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b/mace_agnesi_small.model",
|
692 |
"MACE MP 0b Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b/mace_agnesi_medium.model",
|
693 |
+
"MACE MP 0b2 Small": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b2/mace-small-density-agnesi-stress.model", # Corrected name from original code
|
694 |
"MACE MP 0b2 Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b2/mace-medium-density-agnesi-stress.model",
|
695 |
"MACE MP 0b2 Large": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b2/mace-large-density-agnesi-stress.model",
|
696 |
"MACE MP 0b3 Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b3/mace-mp-0b3-medium.model",
|
697 |
}
|
698 |
|
|
|
699 |
FAIRCHEM_MODELS = {
|
700 |
"UMA Small": "uma-sm",
|
701 |
"ESEN MD Direct All OMOL": "esen-md-direct-all-omol",
|
|
|
703 |
"ESEN SM Direct All OMOL": "esen-sm-direct-all-omol"
|
704 |
}
|
705 |
|
706 |
+
@st.cache_resource
|
707 |
def get_mace_model(model_path, device, selected_default_dtype):
|
|
|
708 |
return mace_mp(model=model_path, device=device, default_dtype=selected_default_dtype)
|
709 |
|
710 |
+
@st.cache_resource
|
711 |
+
def get_fairchem_model(selected_model_name, model_path_or_name, device, selected_task_type_fc): # Renamed args to avoid conflict
|
712 |
+
predictor = pretrained_mlip.get_predict_unit(model_path_or_name, device=device)
|
713 |
+
if selected_model_name == "UMA Small":
|
714 |
+
calc = FAIRChemCalculator(predictor, task_name=selected_task_type_fc)
|
715 |
else:
|
716 |
calc = FAIRChemCalculator(predictor)
|
717 |
return calc
|
718 |
|
|
|
719 |
st.sidebar.markdown("## Input Options")
|
|
|
|
|
720 |
input_method = st.sidebar.radio("Choose Input Method:", ["Select Example", "Upload File", "Paste Content"])
|
|
|
|
|
721 |
atoms = None
|
722 |
|
|
|
723 |
if input_method == "Upload File":
|
724 |
+
uploaded_file = st.sidebar.file_uploader("Upload structure file", type=["xyz", "cif", "POSCAR", "mol", "tmol", "vasp", "sdf", "CONTCAR"])
|
725 |
+
if uploaded_file:
|
|
|
|
|
|
|
726 |
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
|
727 |
tmp_file.write(uploaded_file.getvalue())
|
728 |
tmp_filepath = tmp_file.name
|
|
|
729 |
try:
|
|
|
730 |
atoms = read(tmp_filepath)
|
731 |
st.sidebar.success(f"Successfully loaded structure with {len(atoms)} atoms!")
|
732 |
except Exception as e:
|
733 |
st.sidebar.error(f"Error loading file: {str(e)}")
|
734 |
+
finally:
|
735 |
+
if 'tmp_filepath' in locals() and os.path.exists(tmp_filepath):
|
736 |
+
os.unlink(tmp_filepath)
|
737 |
|
|
|
738 |
elif input_method == "Select Example":
|
739 |
example_name = st.sidebar.selectbox("Select Example Structure:", list(SAMPLE_STRUCTURES.keys()))
|
|
|
740 |
if example_name:
|
741 |
file_path = os.path.join(SAMPLE_DIR, SAMPLE_STRUCTURES[example_name])
|
742 |
try:
|
|
|
745 |
except Exception as e:
|
746 |
st.sidebar.error(f"Error loading example: {str(e)}")
|
747 |
|
|
|
748 |
elif input_method == "Paste Content":
|
749 |
+
file_format = st.sidebar.selectbox("File Format:", ["XYZ", "CIF", "extXYZ", "POSCAR (VASP)", "Turbomole", "MOL"])
|
|
|
|
|
750 |
content = st.sidebar.text_area("Paste file content here:", height=200)
|
751 |
+
if content:
|
|
|
752 |
try:
|
753 |
+
suffix_map = {"XYZ": ".xyz", "CIF": ".cif", "extXYZ": ".extxyz", "POSCAR (VASP)": ".vasp", "Turbomole": ".tmol", "MOL": ".mol"}
|
|
|
|
|
|
|
754 |
suffix = suffix_map.get(file_format, ".xyz")
|
|
|
755 |
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
|
756 |
tmp_file.write(content.encode())
|
757 |
tmp_filepath = tmp_file.name
|
|
|
|
|
758 |
atoms = read(tmp_filepath)
|
759 |
st.sidebar.success(f"Successfully parsed structure with {len(atoms)} atoms!")
|
|
|
|
|
|
|
760 |
except Exception as e:
|
761 |
st.sidebar.error(f"Error parsing content: {str(e)}")
|
762 |
+
finally:
|
763 |
+
if 'tmp_filepath' in locals() and os.path.exists(tmp_filepath):
|
764 |
+
os.unlink(tmp_filepath)
|
765 |
+
|
766 |
+
if atoms is not None:
|
767 |
+
if not hasattr(atoms, 'info'):
|
768 |
+
atoms.info = {}
|
769 |
+
atoms.info["charge"] = atoms.info.get("charge", 0) # Default charge
|
770 |
+
atoms.info["spin"] = atoms.info.get("spin", 0) # Default spin (usually 2S for ASE, model might want 2S+1)
|
771 |
+
|
772 |
|
|
|
|
|
|
|
773 |
st.sidebar.markdown("## Model Selection")
|
774 |
model_type = st.sidebar.radio("Select Model Type:", ["MACE", "FairChem"])
|
775 |
|
776 |
+
selected_task_type = None # For FairChem UMA
|
777 |
if model_type == "MACE":
|
778 |
selected_model = st.sidebar.selectbox("Select MACE Model:", list(MACE_MODELS.keys()))
|
779 |
model_path = MACE_MODELS[selected_model]
|
780 |
if selected_model == "MACE OMAT Medium":
|
781 |
+
st.sidebar.warning("Using model under Academic Software License (ASL).")
|
782 |
selected_default_dtype = st.sidebar.selectbox("Select Precision (default_dtype):", ['float32', 'float64'])
|
783 |
if model_type == "FairChem":
|
784 |
selected_model = st.sidebar.selectbox("Select FairChem Model:", list(FAIRCHEM_MODELS.keys()))
|
785 |
model_path = FAIRCHEM_MODELS[selected_model]
|
786 |
if selected_model == "UMA Small":
|
787 |
+
st.sidebar.warning("Meta FAIR Acceptable Use Policy applies.")
|
788 |
selected_task_type = st.sidebar.selectbox("Select UMA Model Task Type:", ["omol", "omat", "omc", "odac", "oc20"])
|
789 |
+
if selected_task_type == "omol" and atoms is not None:
|
790 |
+
charge = st.sidebar.number_input("Total Charge", min_value=-10, max_value=10, value=atoms.info.get("charge",0))
|
791 |
+
spin_multiplicity = st.sidebar.number_input("Spin Multiplicity (2S + 1)", min_value=1, max_value=11, step=2, value=int(atoms.info.get("spin",0)*2+1 if atoms.info.get("spin",0) is not None else 1)) # Assuming spin in atoms.info is S
|
|
|
792 |
atoms.info["charge"] = charge
|
793 |
+
atoms.info["spin"] = spin_multiplicity # FairChem expects multiplicity
|
794 |
+
|
795 |
if atoms is not None:
|
796 |
+
if not check_atom_limit(atoms, selected_model):
|
797 |
+
st.stop() # Stop execution if limit exceeded
|
798 |
+
|
799 |
+
device = st.sidebar.radio("Computation Device:", ["CPU", "CUDA (GPU)"], index=0 if not torch.cuda.is_available() else 1)
|
|
|
800 |
device = "cuda" if device == "CUDA (GPU)" and torch.cuda.is_available() else "cpu"
|
801 |
|
802 |
if device == "cpu" and torch.cuda.is_available():
|
803 |
+
st.sidebar.info("GPU is available but CPU was selected.")
|
804 |
elif device == "cpu" and not torch.cuda.is_available():
|
805 |
+
st.sidebar.info("No GPU detected. Using CPU.")
|
806 |
|
|
|
807 |
st.sidebar.markdown("## Task Selection")
|
808 |
task = st.sidebar.selectbox("Select Calculation Task:",
|
809 |
["Energy Calculation",
|
810 |
"Energy + Forces Calculation",
|
811 |
+
"Atomization/Cohesive Energy", # New Task Added
|
812 |
"Geometry Optimization",
|
813 |
"Cell + Geometry Optimization"])
|
814 |
|
|
|
815 |
if "Optimization" in task:
|
816 |
st.sidebar.markdown("### Optimization Parameters")
|
817 |
+
max_steps = st.sidebar.slider("Maximum Steps:", min_value=10, max_value=200, value=50, step=1) # Increased max_steps
|
818 |
+
fmax = st.sidebar.slider("Convergence Threshold (eV/Å):", min_value=0.001, max_value=0.1, value=0.01, step=0.001, format="%.3f") # Adjusted default fmax
|
819 |
+
optimizer_type = st.sidebar.selectbox("Optimizer:", ["BFGS", "LBFGS", "FIRE"], index=1) # Renamed to optimizer_type
|
820 |
+
|
821 |
|
|
|
822 |
if atoms is not None:
|
823 |
col1, col2 = st.columns(2)
|
824 |
|
825 |
with col1:
|
826 |
st.markdown('### Structure Visualization', unsafe_allow_html=True)
|
827 |
+
view_3d = get_structure_viz2(atoms, style='stick', show_unit_cell=True, width=400, height=400)
|
828 |
+
st.components.v1.html(view_3d._make_html(), width=400, height=400)
|
829 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
830 |
st.markdown("### Structure Information")
|
831 |
atoms_info = {
|
832 |
"Number of Atoms": len(atoms),
|
833 |
"Chemical Formula": atoms.get_chemical_formula(),
|
834 |
+
"Periodic Boundary Conditions (PBC)": atoms.pbc.tolist(),
|
835 |
+
"Cell Dimensions": np.round(atoms.cell.cellpar(),3).tolist() if atoms.pbc.any() and atoms.cell is not None and atoms.cell.any() else "No cell / Non-periodic",
|
836 |
+
"Atom Types": ", ".join(sorted(list(set(atoms.get_chemical_symbols()))))
|
837 |
}
|
|
|
838 |
for key, value in atoms_info.items():
|
839 |
st.write(f"**{key}:** {value}")
|
840 |
|
841 |
with col2:
|
842 |
st.markdown('## Calculation Setup', unsafe_allow_html=True)
|
|
|
|
|
843 |
st.markdown("### Selected Model")
|
844 |
st.write(f"**Model Type:** {model_type}")
|
845 |
st.write(f"**Model:** {selected_model}")
|
846 |
+
if model_type == "FairChem" and selected_model == "UMA Small":
|
847 |
+
st.write(f"**UMA Task Type:** {selected_task_type}")
|
848 |
st.write(f"**Device:** {device}")
|
849 |
|
850 |
st.markdown("### Selected Task")
|
|
|
853 |
if "Optimization" in task:
|
854 |
st.write(f"**Max Steps:** {max_steps}")
|
855 |
st.write(f"**Convergence Threshold:** {fmax} eV/Å")
|
856 |
+
st.write(f"**Optimizer:** {optimizer_type}")
|
857 |
|
|
|
858 |
run_calculation = st.button("Run Calculation", type="primary")
|
859 |
|
860 |
if run_calculation:
|
861 |
+
results = {}
|
862 |
+
global opt_log, table_placeholder # Ensure they are accessible
|
863 |
+
opt_log = [] # Reset log for each run
|
864 |
+
if "Optimization" in task:
|
865 |
+
table_placeholder = st.empty() # Recreate placeholder for table
|
866 |
+
|
867 |
try:
|
868 |
+
with st.spinner("Running calculation... Please wait."):
|
|
|
869 |
calc_atoms = atoms.copy()
|
870 |
|
|
|
871 |
if model_type == "MACE":
|
872 |
+
# st.write("Setting up MACE calculator...")
|
873 |
calc = get_mace_model(model_path, device, selected_default_dtype)
|
874 |
else: # FairChem
|
875 |
+
# st.write("Setting up FairChem calculator...")
|
876 |
+
# Workaround for potential dtype issues when switching models
|
877 |
+
if device == "cpu": # Ensure torch default dtype matches if needed
|
878 |
+
torch.set_default_dtype(torch.float32)
|
879 |
+
_ = get_mace_model(MACE_MODELS["MACE MP 0a Small"], 'cpu', 'float32') # Dummy call
|
880 |
calc = get_fairchem_model(selected_model, model_path, device, selected_task_type)
|
|
|
|
|
881 |
|
882 |
+
calc_atoms.calc = calc
|
|
|
883 |
|
884 |
if task == "Energy Calculation":
|
|
|
885 |
energy = calc_atoms.get_potential_energy()
|
886 |
results["Energy"] = f"{energy:.6f} eV"
|
887 |
|
888 |
elif task == "Energy + Forces Calculation":
|
|
|
889 |
energy = calc_atoms.get_potential_energy()
|
890 |
forces = calc_atoms.get_forces()
|
891 |
+
max_force = np.max(np.sqrt(np.sum(forces**2, axis=1))) if forces.shape[0] > 0 else 0.0
|
|
|
892 |
results["Energy"] = f"{energy:.6f} eV"
|
893 |
results["Maximum Force"] = f"{max_force:.6f} eV/Å"
|
894 |
+
|
895 |
+
elif task == "Atomization/Cohesive Energy":
|
896 |
+
st.write("Calculating system energy...")
|
897 |
+
E_system = calc_atoms.get_potential_energy()
|
898 |
+
num_atoms = len(calc_atoms)
|
899 |
+
|
900 |
+
if num_atoms == 0:
|
901 |
+
st.error("Cannot calculate atomization/cohesive energy for a system with zero atoms.")
|
902 |
+
results["Error"] = "System has no atoms."
|
903 |
+
else:
|
904 |
+
atomic_numbers = calc_atoms.get_atomic_numbers()
|
905 |
+
E_isolated_atoms_total = 0.0
|
906 |
+
calculation_possible = True
|
907 |
+
|
908 |
+
if model_type == "FairChem":
|
909 |
+
st.write("Fetching FairChem reference energies for isolated atoms...")
|
910 |
+
ref_key_suffix = "_elem_refs"
|
911 |
+
chosen_ref_list_name = None
|
912 |
+
if selected_model == "UMA Small":
|
913 |
+
if selected_task_type:
|
914 |
+
chosen_ref_list_name = selected_task_type + ref_key_suffix
|
915 |
+
elif "ESEN" in selected_model:
|
916 |
+
chosen_ref_list_name = "omol" + ref_key_suffix
|
917 |
+
|
918 |
+
if chosen_ref_list_name and chosen_ref_list_name in ELEMENT_REF_ENERGIES:
|
919 |
+
ref_energies = ELEMENT_REF_ENERGIES[chosen_ref_list_name]
|
920 |
+
missing_Z_refs = []
|
921 |
+
for Z_val in atomic_numbers:
|
922 |
+
if Z_val > 0 and Z_val < len(ref_energies):
|
923 |
+
E_isolated_atoms_total += ref_energies[Z_val]
|
924 |
+
else:
|
925 |
+
if Z_val not in missing_Z_refs: missing_Z_refs.append(Z_val)
|
926 |
+
if missing_Z_refs:
|
927 |
+
st.warning(f"Reference energy for atomic number(s) {sorted(list(set(missing_Z_refs)))} "
|
928 |
+
f"not found in '{chosen_ref_list_name}' list (max Z defined: {len(ref_energies)-1}). "
|
929 |
+
"These atoms are treated as having 0 reference energy.")
|
930 |
+
else:
|
931 |
+
st.error(f"Could not find or determine reference energy list for FairChem model: '{selected_model}' "
|
932 |
+
f"and UMA task type: '{selected_task_type}'. Cannot calculate atomization/cohesive energy.")
|
933 |
+
results["Error"] = "Missing FairChem reference energies."
|
934 |
+
calculation_possible = False
|
935 |
+
|
936 |
+
elif model_type == "MACE":
|
937 |
+
st.write("Calculating isolated atom energies with MACE...")
|
938 |
+
unique_atomic_numbers = sorted(list(set(atomic_numbers)))
|
939 |
+
atom_counts = {Z_unique: np.count_nonzero(atomic_numbers == Z_unique) for Z_unique in unique_atomic_numbers}
|
940 |
+
|
941 |
+
progress_text = "Calculating isolated atom energies: 0% complete"
|
942 |
+
mace_progress_bar = st.progress(0, text=progress_text)
|
943 |
+
|
944 |
+
for i, Z_unique in enumerate(unique_atomic_numbers):
|
945 |
+
isolated_atom = Atoms(numbers=[Z_unique], cell=[20, 20, 20], pbc=False)
|
946 |
+
if not hasattr(isolated_atom, 'info'): isolated_atom.info = {}
|
947 |
+
isolated_atom.info["charge"] = 0
|
948 |
+
isolated_atom.info["spin"] = 0
|
949 |
+
isolated_atom.calc = calc # Use the same MACE calculator
|
950 |
+
|
951 |
+
E_isolated_atom_type = isolated_atom.get_potential_energy()
|
952 |
+
E_isolated_atoms_total += E_isolated_atom_type * atom_counts[Z_unique]
|
953 |
+
|
954 |
+
progress_val = (i + 1) / len(unique_atomic_numbers)
|
955 |
+
mace_progress_bar.progress(progress_val, text=f"Calculating isolated atom energies for Z={Z_unique}: {int(progress_val*100)}% complete")
|
956 |
+
mace_progress_bar.empty()
|
957 |
+
|
958 |
+
if calculation_possible:
|
959 |
+
is_periodic = any(calc_atoms.pbc)
|
960 |
+
if is_periodic:
|
961 |
+
cohesive_E = (E_isolated_atoms_total - E_system) / num_atoms
|
962 |
+
results["Cohesive Energy"] = f"{cohesive_E:.6f} eV/atom"
|
963 |
+
else:
|
964 |
+
atomization_E = E_isolated_atoms_total - E_system
|
965 |
+
results["Atomization Energy"] = f"{atomization_E:.6f} eV"
|
966 |
+
|
967 |
+
results["System Energy ($E_{system}$)"] = f"{E_system:.6f} eV"
|
968 |
+
results["Total Isolated Atom Energy ($\sum E_{atoms}$)"] = f"{E_isolated_atoms_total:.6f} eV"
|
969 |
+
|
970 |
+
elif "Optimization" in task: # Handles both Geometry and Cell+Geometry Opt
|
971 |
+
opt_atoms_obj = FrechetCellFilter(calc_atoms) if task == "Cell + Geometry Optimization" else calc_atoms
|
972 |
|
973 |
+
if optimizer_type == "BFGS":
|
974 |
+
opt = BFGS(opt_atoms_obj)
|
975 |
+
elif optimizer_type == "LBFGS":
|
976 |
+
opt = LBFGS(opt_atoms_obj)
|
977 |
+
else: # FIRE
|
978 |
+
opt = FIRE(opt_atoms_obj)
|
979 |
|
980 |
+
opt.attach(streamlit_log, interval=1) # Removed lambda for simplicity if streamlit_log is defined correctly
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
981 |
|
982 |
+
st.write(f"Running {task.lower()}...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
983 |
opt.run(fmax=fmax, steps=max_steps)
|
984 |
|
|
|
985 |
energy = calc_atoms.get_potential_energy()
|
986 |
forces = calc_atoms.get_forces()
|
987 |
+
max_force = np.max(np.sqrt(np.sum(forces**2, axis=1))) if forces.shape[0] > 0 else 0.0
|
988 |
|
989 |
results["Final Energy"] = f"{energy:.6f} eV"
|
990 |
results["Final Maximum Force"] = f"{max_force:.6f} eV/Å"
|
991 |
results["Steps Taken"] = opt.get_number_of_steps()
|
992 |
results["Converged"] = "Yes" if opt.converged() else "No"
|
993 |
+
if task == "Cell + Geometry Optimization":
|
994 |
+
results["Final Cell Parameters"] = np.round(calc_atoms.cell.cellpar(), 4).tolist()
|
995 |
|
|
|
996 |
st.success("Calculation completed successfully!")
|
997 |
st.markdown("### Results")
|
998 |
for key, value in results.items():
|
999 |
st.write(f"**{key}:** {value}")
|
1000 |
|
1001 |
+
if "Optimization" in task and "Final Energy" in results: # Check if opt was successful
|
|
|
1002 |
st.markdown("### Optimized Structure")
|
1003 |
+
# Need get_structure_viz function that takes atoms obj
|
1004 |
+
def get_structure_viz_simple(atoms_obj_viz):
|
1005 |
+
xyz_str_viz = f"{len(atoms_obj_viz)}\nStructure\n"
|
1006 |
+
for atom_viz in atoms_obj_viz:
|
1007 |
+
xyz_str_viz += f"{atom_viz.symbol} {atom_viz.position[0]:.6f} {atom_viz.position[1]:.6f} {atom_viz.position[2]:.6f}\n"
|
1008 |
+
view_viz = py3Dmol.view(width=400, height=400)
|
1009 |
+
view_viz.addModel(xyz_str_viz, "xyz")
|
1010 |
+
view_viz.setStyle({'stick': {}})
|
1011 |
+
if any(atoms_obj_viz.pbc): # Show cell for optimized periodic structures
|
1012 |
+
cell_viz = atoms_obj_viz.get_cell()
|
1013 |
+
if cell_viz is not None and cell_viz.any():
|
1014 |
+
# Simplified cell drawing for brevity, use get_structure_viz2 if full cell needed
|
1015 |
+
view_viz.addUnitCell({'box': {'lx':cell_viz.lengths()[0],'ly':cell_viz.lengths()[1],'lz':cell_viz.lengths()[2],
|
1016 |
+
'hx':cell_viz.cellpar()[3],'hy':cell_viz.cellpar()[4],'hz':cell_viz.cellpar()[5]}})
|
1017 |
+
|
1018 |
+
view_viz.zoomTo()
|
1019 |
+
view_viz.setBackgroundColor('white')
|
1020 |
+
return view_viz
|
1021 |
+
|
1022 |
+
opt_view = get_structure_viz2(calc_atoms, style='stick', show_unit_cell=True, width=400, height=400)
|
1023 |
+
st.components.v1.html(opt_view._make_html(), width=400, height=400)
|
1024 |
|
1025 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".xyz", mode="w+") as tmp_file_opt:
|
1026 |
+
write(tmp_file_opt.name, calc_atoms, format="xyz")
|
1027 |
+
tmp_filepath_opt = tmp_file_opt.name
|
|
|
|
|
1028 |
|
1029 |
+
with open(tmp_filepath_opt, 'r') as file_opt:
|
1030 |
+
xyz_content_opt = file_opt.read()
|
|
|
1031 |
|
1032 |
st.download_button(
|
1033 |
label="Download Optimized Structure (XYZ)",
|
1034 |
+
data=xyz_content_opt,
|
1035 |
file_name="optimized_structure.xyz",
|
1036 |
mime="chemical/x-xyz"
|
1037 |
)
|
1038 |
+
os.unlink(tmp_filepath_opt)
|
|
|
|
|
1039 |
|
1040 |
except Exception as e:
|
1041 |
+
st.error(f"🔴 Calculation error: {str(e)}")
|
1042 |
+
st.error("Please check the structure, model compatibility, and parameters. For FairChem UMA, ensure the task type (omol, omat etc.) is appropriate for your system (e.g. omol for molecules, omat for materials).")
|
1043 |
+
import traceback
|
1044 |
+
st.error(f"Traceback: {traceback.format_exc()}")
|
1045 |
+
|
1046 |
else:
|
1047 |
+
st.info("👋 Welcome! Please select or upload a structure using the sidebar options to begin.")
|
|
|
|
|
1048 |
|
|
|
1049 |
st.markdown("---")
|
1050 |
+
with st.expander('ℹ️ About This App & Foundational MLIPs'):
|
|
|
1051 |
st.write("""
|
1052 |
+
**Test, compare, and benchmark universal machine learning interatomic potentials (MLIPs).**
|
1053 |
+
This application allows you to perform atomistic simulations using pre-trained foundational MLIPs
|
1054 |
+
from the MACE and FairChem (by Meta AI) libraries.
|
1055 |
|
1056 |
+
**Features:**
|
1057 |
+
- Upload structure files (XYZ, CIF, POSCAR, etc.) or use built-in examples.
|
1058 |
+
- Select from various MACE and FairChem models.
|
1059 |
+
- Calculate energies, forces, and perform geometry/cell optimizations.
|
1060 |
+
- **New**: Calculate atomization energy (for molecules) or cohesive energy (for periodic systems).
|
1061 |
+
- Visualize atomic structures in 3D and download results.
|
1062 |
|
1063 |
+
**Quick Start:**
|
1064 |
+
1. **Input**: Choose an input method in the sidebar (e.g., "Select Example").
|
1065 |
+
2. **Model**: Pick a model type (MACE/FairChem) and specific model. For FairChem UMA, select the appropriate task type (e.g., `omol` for molecules, `omat` for materials).
|
1066 |
+
3. **Task**: Select a calculation task (e.g., "Energy Calculation", "Atomization/Cohesive Energy", "Geometry Optimization").
|
1067 |
+
4. **Run**: Click "Run Calculation" and view the results.
|
1068 |
+
|
1069 |
+
**Atomization/Cohesive Energy Notes:**
|
1070 |
+
- **Atomization Energy** ($E_{\text{atomization}} = \sum E_{\text{isolated atoms}} - E_{\text{molecule}}$) is typically for non-periodic systems (molecules).
|
1071 |
+
- **Cohesive Energy** ($E_{\text{cohesive}} = (\sum E_{\text{isolated atoms}} - E_{\text{bulk system}}) / N_{\text{atoms}}$) is for periodic systems.
|
1072 |
+
- For **MACE models**, isolated atom energies are computed on-the-fly.
|
1073 |
+
- For **FairChem models**, isolated atom energies are based on pre-tabulated reference values (provided in a YAML-like structure within the app). Ensure the selected FairChem task type (`omol`, `omat`, etc. for UMA models) or model type (ESEN models use `omol` references) aligns with the system and has the necessary elemental references.
|
1074 |
""")
|
1075 |
st.markdown("Universal MLIP Playground App | Created with Streamlit, ASE, MACE, FairChem and ❤️")
|
1076 |
+
st.markdown("Developed by [Manas Sharma](https://manas.bragitoff.com/) ([Fundamental AI Research (FAIR) team, Meta AI](https://ai.meta.com/research/fair/) and [Ananth Govind Rajan Group, IISc Bangalore](https://www.agrgroup.org/))")
|
1077 |
+
|
1078 |
+
```
|