Hamidreza Eivazi
commited on
Commit
·
7badbdd
1
Parent(s):
b3e496c
add app
Browse files- Dockerfile +18 -0
- app.py +101 -0
- predict.py +97 -0
- requirements.txt +10 -0
Dockerfile
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Use an official Python runtime
|
2 |
+
FROM python:3.9
|
3 |
+
|
4 |
+
# Set the working directory
|
5 |
+
WORKDIR /app
|
6 |
+
|
7 |
+
# Copy all files to the container
|
8 |
+
COPY . /app
|
9 |
+
|
10 |
+
# Install dependencies
|
11 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
12 |
+
|
13 |
+
# Expose the port Dash runs on
|
14 |
+
EXPOSE 7860
|
15 |
+
|
16 |
+
# Command to run the app
|
17 |
+
CMD ["python", "app.py"]
|
18 |
+
|
app.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import plotly.graph_objects as go
|
4 |
+
from dash import Dash, dcc, html, Input, Output
|
5 |
+
import dash_bootstrap_components as dbc
|
6 |
+
from predict import load_model, Configs
|
7 |
+
import jax.numpy as jnp
|
8 |
+
|
9 |
+
# Initialize Dash with a dark theme
|
10 |
+
app = Dash(__name__, external_stylesheets=[dbc.themes.DARKLY])
|
11 |
+
|
12 |
+
nn = Configs.nn
|
13 |
+
x = np.linspace(-0.5, 0.5, nn, endpoint=True)
|
14 |
+
y = np.linspace(-0.5, 0.5, nn, endpoint=True)
|
15 |
+
|
16 |
+
forward1 = load_model('RVE1')
|
17 |
+
forward2 = load_model('RVE2')
|
18 |
+
forward3 = load_model('RVE3')
|
19 |
+
|
20 |
+
# Nord color palette
|
21 |
+
nord_bg = "#2E3440" # Background
|
22 |
+
nord_fg = "#D8DEE9" # Text color
|
23 |
+
nord_grid = "#4C566A" # Grid color
|
24 |
+
colormap = 'RdBu' # Better for dark themes
|
25 |
+
|
26 |
+
# PDE Solver
|
27 |
+
def solve_pde(boundary_conditions, rve):
|
28 |
+
x = jnp.array(boundary_conditions)
|
29 |
+
return forward1(x) if rve == 'RVE1' else forward2(x) if rve == 'RVE2' else forward3(x)
|
30 |
+
|
31 |
+
initial_bc = [0.02, 0.02, 0.02]
|
32 |
+
d, s = solve_pde(initial_bc, 'RVE1')
|
33 |
+
|
34 |
+
# Function to create Heatmap figures with a dark theme
|
35 |
+
def create_figure(z_data, title):
|
36 |
+
fig = go.Figure(go.Contour(x=x, y=y, z=z_data, colorscale=colormap, showscale=False, line_smoothing=1.2, line_width=0))
|
37 |
+
fig.update_layout(
|
38 |
+
title=title, title_font=dict(color=nord_fg, size=18),
|
39 |
+
paper_bgcolor=nord_bg, plot_bgcolor=nord_bg, font=dict(color=nord_fg),
|
40 |
+
xaxis=dict(gridcolor=nord_grid, tickvals=[], scaleanchor="y", showgrid=False, zeroline=False, showline=False),
|
41 |
+
yaxis=dict(gridcolor=nord_grid, tickvals=[], showgrid=False, zeroline=False, showline=False),
|
42 |
+
autosize=False, width=300, height=300, margin=dict(l=0, r=0, t=30, b=0)
|
43 |
+
)
|
44 |
+
return fig
|
45 |
+
|
46 |
+
titles = ["Disp. x", "Disp. y", "Stress xx", "Stress yy", "Stress xy"]
|
47 |
+
pde_figs = [create_figure(d[:,:,0], titles[0]), create_figure(d[:,:,1], titles[1]),
|
48 |
+
create_figure(s[:,:,0], titles[2]), create_figure(s[:,:,1], titles[3]),
|
49 |
+
create_figure(s[:,:,2], titles[4])]
|
50 |
+
|
51 |
+
# Layout
|
52 |
+
app.layout = dbc.Container([
|
53 |
+
html.H1("RVE simulation using Equilibrium Neural Operator (EquiNO)", style={'textAlign': 'center', 'color': nord_fg, 'padding': '100px'}),
|
54 |
+
|
55 |
+
dbc.Row([
|
56 |
+
dbc.Col([
|
57 |
+
html.H3("Prescribed Global Strains", style={'textAlign': 'center', 'color': nord_fg, 'padding': '10px'}),
|
58 |
+
html.Label("Strain xx", style={'font-size': '18px', 'color': nord_fg}),
|
59 |
+
dcc.Slider(id='bc1-slider', min=-0.04, max=0.04, step=0.005, value=initial_bc[0], marks=None,
|
60 |
+
tooltip={"placement": "top", "always_visible": True}, updatemode='drag',
|
61 |
+
className='slider-no-border'),
|
62 |
+
|
63 |
+
html.Label("Strain yy", style={'font-size': '18px', 'color': nord_fg}),
|
64 |
+
dcc.Slider(id='bc2-slider', min=-0.04, max=0.04, step=0.005, value=initial_bc[1], marks=None,
|
65 |
+
tooltip={"placement": "top", "always_visible": True}, updatemode='drag',
|
66 |
+
className='slider-no-border'),
|
67 |
+
|
68 |
+
html.Label("Strain xy", style={'font-size': '18px', 'color': nord_fg}),
|
69 |
+
dcc.Slider(id='bc3-slider', min=-0.04, max=0.04, step=0.005, value=initial_bc[2], marks=None,
|
70 |
+
tooltip={"placement": "top", "always_visible": True}, updatemode='drag',
|
71 |
+
className='slider-no-border'),
|
72 |
+
|
73 |
+
html.Label("RVE Selector", style={'font-size': '18px', 'color': nord_fg}),
|
74 |
+
dcc.Dropdown(id='rve-selector', options=[
|
75 |
+
{'label': 'RVE1', 'value': 'RVE1'}, {'label': 'RVE2', 'value': 'RVE2'}, {'label': 'RVE3', 'value': 'RVE3'}
|
76 |
+
], value='RVE1', clearable=False, style={'backgroundColor': 'white', 'color': nord_bg}),
|
77 |
+
], width=3, style={'padding': '100px', 'backgroundColor': nord_bg}),
|
78 |
+
|
79 |
+
dbc.Col([
|
80 |
+
dbc.Row([dcc.Graph(id=f'pde-figure{i+1}', figure=pde_figs[i], style={'width': '19%', 'display': 'inline-block'}) for i in range(5)]),
|
81 |
+
], width=9, style={'padding': '100px', 'backgroundColor': nord_bg})
|
82 |
+
])
|
83 |
+
], fluid=True, style={'backgroundColor': nord_bg, 'minHeight': '200vh'})
|
84 |
+
|
85 |
+
# Callback
|
86 |
+
@app.callback(
|
87 |
+
[Output(f'pde-figure{i+1}', 'figure') for i in range(5)],
|
88 |
+
[Input('bc1-slider', 'value'), Input('bc2-slider', 'value'), Input('bc3-slider', 'value'), Input('rve-selector', 'value')]
|
89 |
+
)
|
90 |
+
def update_pde(bc1, bc2, bc3, selected_rve):
|
91 |
+
new_bc = [bc1, bc2, bc3]
|
92 |
+
d, s = solve_pde(new_bc, selected_rve)
|
93 |
+
return [create_figure(d[:,:,0], titles[0]), create_figure(d[:,:,1], titles[1]),
|
94 |
+
create_figure(s[:,:,0], titles[2]), create_figure(s[:,:,1], titles[3]),
|
95 |
+
create_figure(s[:,:,2], titles[4])]
|
96 |
+
|
97 |
+
# Run the app
|
98 |
+
server = app.server
|
99 |
+
if __name__ == "__main__":
|
100 |
+
app.run_server(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
|
101 |
+
|
predict.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import tensorflow as tf
|
3 |
+
import jax
|
4 |
+
import jax.numpy as jnp
|
5 |
+
from scipy.interpolate import LinearNDInterpolator
|
6 |
+
|
7 |
+
class Configs:
|
8 |
+
nn = 101
|
9 |
+
|
10 |
+
def load_model(rve):
|
11 |
+
mdir = f'./trained_models/EquiNO_{rve}'
|
12 |
+
model = tf.keras.models.load_model(mdir)
|
13 |
+
nodes = np.array(model.kinema.get_config()['nodeCoord'])
|
14 |
+
scaling_bu = model.get_config()['scaling_bu']
|
15 |
+
scaling_bs = model.get_config()['scaling_bs']
|
16 |
+
scale_input = 0.04
|
17 |
+
|
18 |
+
shp = np.array(model.kinema.get_config()['shp'])
|
19 |
+
elemInc = np.array(model.kinema.get_config()['elemInc'])
|
20 |
+
|
21 |
+
# Extract model outputs and reshape arrays
|
22 |
+
t, _, t_s = model.trunk(nodes)
|
23 |
+
t = t.numpy()
|
24 |
+
t_s = t_s.numpy()
|
25 |
+
t_s = t_s.reshape((-1, 9, 16, 3))
|
26 |
+
t_s = jnp.einsum('nlrk,lj->njrk', t_s, jnp.linalg.pinv(shp[:, 0]))
|
27 |
+
|
28 |
+
t_sg = np.zeros((nodes.shape[0], 16, 3))
|
29 |
+
t_sg[elemInc[:, 2:] - 1] = np.array(t_s)
|
30 |
+
t_s = t_sg
|
31 |
+
|
32 |
+
x = np.linspace(-0.5, 0.5, Configs.nn, endpoint=True)
|
33 |
+
y = np.linspace(-0.5, 0.5, Configs.nn, endpoint=True)
|
34 |
+
xx, yy = np.meshgrid(x, y)
|
35 |
+
|
36 |
+
def interp(d, s):
|
37 |
+
d = d.reshape((-1, 16*2))
|
38 |
+
s = s.reshape((-1, 16*3))
|
39 |
+
d = LinearNDInterpolator(nodes, d)(xx, yy)
|
40 |
+
s = LinearNDInterpolator(nodes, s)(xx, yy)
|
41 |
+
d = d.reshape((-1, 16, 2))
|
42 |
+
s = s.reshape((-1, 16, 3))
|
43 |
+
return d, s
|
44 |
+
|
45 |
+
t, t_s = interp(t, t_s)
|
46 |
+
|
47 |
+
nodes = np.stack((xx.flatten(), yy.flatten()), 1)
|
48 |
+
|
49 |
+
weights = model.branch.get_weights()
|
50 |
+
|
51 |
+
weights_u = [weights[i:i+2] for i in range(0, len(weights), 4)]
|
52 |
+
params_u = [(jnp.array(w[0]), jnp.array(w[1])) for w in weights_u]
|
53 |
+
|
54 |
+
weights_s = [weights[i:i+2] for i in range(2, len(weights), 4)]
|
55 |
+
params_s = [(jnp.array(w[0]), jnp.array(w[1])) for w in weights_s]
|
56 |
+
|
57 |
+
cd = jnp.array([[2.0, 0.0], [0.0, 2.0]])
|
58 |
+
|
59 |
+
del model
|
60 |
+
|
61 |
+
@jax.jit
|
62 |
+
def periodic_disp(x):
|
63 |
+
matrix = nodes[..., None] * cd[None, ...]
|
64 |
+
matrix = 0.5 * jnp.concatenate([matrix, jnp.flip(nodes, 1)[:, None, :]], 1)
|
65 |
+
return jnp.einsum('ij,ljm->ilm', x, matrix)
|
66 |
+
|
67 |
+
@jax.jit
|
68 |
+
def jax_branch_s(x):
|
69 |
+
x_n = x
|
70 |
+
for (w, b) in params_s[:-1]:
|
71 |
+
x_n = jax.nn.swish(jnp.dot(x_n, w) + b)
|
72 |
+
final_w, final_b = params_s[-1]
|
73 |
+
return jnp.dot(x_n, final_w) + final_b
|
74 |
+
|
75 |
+
@jax.jit
|
76 |
+
def jax_branch_u(x):
|
77 |
+
x_n = x
|
78 |
+
for (w, b) in params_u[:-1]:
|
79 |
+
x_n = jax.nn.swish(jnp.dot(x_n, w) + b)
|
80 |
+
final_w, final_b = params_u[-1]
|
81 |
+
return jnp.dot(x_n, final_w) + final_b
|
82 |
+
|
83 |
+
@jax.jit
|
84 |
+
def forward(x):
|
85 |
+
x = x.reshape((1, 3))
|
86 |
+
b = jax_branch_u(x / scale_input)
|
87 |
+
b_s = jax_branch_s(x / scale_input)
|
88 |
+
b = b * scaling_bu[1] + scaling_bu[0]
|
89 |
+
b_s = b_s * scaling_bs[1] + scaling_bs[0]
|
90 |
+
u = jnp.einsum('im,lmn->iln', b, t)
|
91 |
+
s = jnp.einsum('im,lmn->iln', b_s, t_s)
|
92 |
+
u = u - u[:, :1] + periodic_disp(x)
|
93 |
+
u = u.reshape(Configs.nn, Configs.nn, -1)
|
94 |
+
s = s.reshape(Configs.nn, Configs.nn, -1)
|
95 |
+
return u, s
|
96 |
+
|
97 |
+
return forward
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dash
|
2 |
+
dash_bootstrap_components
|
3 |
+
jax==0.4.20
|
4 |
+
jaxlib==0.4.20
|
5 |
+
numpy==1.24.3
|
6 |
+
plotly==5.9.0
|
7 |
+
scipy==1.10.1
|
8 |
+
tensorflow==2.14.0
|
9 |
+
Flask
|
10 |
+
gunicorn
|