Hamidreza Eivazi commited on
Commit
7badbdd
·
1 Parent(s): b3e496c
Files changed (4) hide show
  1. Dockerfile +18 -0
  2. app.py +101 -0
  3. predict.py +97 -0
  4. requirements.txt +10 -0
Dockerfile ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use an official Python runtime
2
+ FROM python:3.9
3
+
4
+ # Set the working directory
5
+ WORKDIR /app
6
+
7
+ # Copy all files to the container
8
+ COPY . /app
9
+
10
+ # Install dependencies
11
+ RUN pip install --no-cache-dir -r requirements.txt
12
+
13
+ # Expose the port Dash runs on
14
+ EXPOSE 7860
15
+
16
+ # Command to run the app
17
+ CMD ["python", "app.py"]
18
+
app.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import plotly.graph_objects as go
4
+ from dash import Dash, dcc, html, Input, Output
5
+ import dash_bootstrap_components as dbc
6
+ from predict import load_model, Configs
7
+ import jax.numpy as jnp
8
+
9
+ # Initialize Dash with a dark theme
10
+ app = Dash(__name__, external_stylesheets=[dbc.themes.DARKLY])
11
+
12
+ nn = Configs.nn
13
+ x = np.linspace(-0.5, 0.5, nn, endpoint=True)
14
+ y = np.linspace(-0.5, 0.5, nn, endpoint=True)
15
+
16
+ forward1 = load_model('RVE1')
17
+ forward2 = load_model('RVE2')
18
+ forward3 = load_model('RVE3')
19
+
20
+ # Nord color palette
21
+ nord_bg = "#2E3440" # Background
22
+ nord_fg = "#D8DEE9" # Text color
23
+ nord_grid = "#4C566A" # Grid color
24
+ colormap = 'RdBu' # Better for dark themes
25
+
26
+ # PDE Solver
27
+ def solve_pde(boundary_conditions, rve):
28
+ x = jnp.array(boundary_conditions)
29
+ return forward1(x) if rve == 'RVE1' else forward2(x) if rve == 'RVE2' else forward3(x)
30
+
31
+ initial_bc = [0.02, 0.02, 0.02]
32
+ d, s = solve_pde(initial_bc, 'RVE1')
33
+
34
+ # Function to create Heatmap figures with a dark theme
35
+ def create_figure(z_data, title):
36
+ fig = go.Figure(go.Contour(x=x, y=y, z=z_data, colorscale=colormap, showscale=False, line_smoothing=1.2, line_width=0))
37
+ fig.update_layout(
38
+ title=title, title_font=dict(color=nord_fg, size=18),
39
+ paper_bgcolor=nord_bg, plot_bgcolor=nord_bg, font=dict(color=nord_fg),
40
+ xaxis=dict(gridcolor=nord_grid, tickvals=[], scaleanchor="y", showgrid=False, zeroline=False, showline=False),
41
+ yaxis=dict(gridcolor=nord_grid, tickvals=[], showgrid=False, zeroline=False, showline=False),
42
+ autosize=False, width=300, height=300, margin=dict(l=0, r=0, t=30, b=0)
43
+ )
44
+ return fig
45
+
46
+ titles = ["Disp. x", "Disp. y", "Stress xx", "Stress yy", "Stress xy"]
47
+ pde_figs = [create_figure(d[:,:,0], titles[0]), create_figure(d[:,:,1], titles[1]),
48
+ create_figure(s[:,:,0], titles[2]), create_figure(s[:,:,1], titles[3]),
49
+ create_figure(s[:,:,2], titles[4])]
50
+
51
+ # Layout
52
+ app.layout = dbc.Container([
53
+ html.H1("RVE simulation using Equilibrium Neural Operator (EquiNO)", style={'textAlign': 'center', 'color': nord_fg, 'padding': '100px'}),
54
+
55
+ dbc.Row([
56
+ dbc.Col([
57
+ html.H3("Prescribed Global Strains", style={'textAlign': 'center', 'color': nord_fg, 'padding': '10px'}),
58
+ html.Label("Strain xx", style={'font-size': '18px', 'color': nord_fg}),
59
+ dcc.Slider(id='bc1-slider', min=-0.04, max=0.04, step=0.005, value=initial_bc[0], marks=None,
60
+ tooltip={"placement": "top", "always_visible": True}, updatemode='drag',
61
+ className='slider-no-border'),
62
+
63
+ html.Label("Strain yy", style={'font-size': '18px', 'color': nord_fg}),
64
+ dcc.Slider(id='bc2-slider', min=-0.04, max=0.04, step=0.005, value=initial_bc[1], marks=None,
65
+ tooltip={"placement": "top", "always_visible": True}, updatemode='drag',
66
+ className='slider-no-border'),
67
+
68
+ html.Label("Strain xy", style={'font-size': '18px', 'color': nord_fg}),
69
+ dcc.Slider(id='bc3-slider', min=-0.04, max=0.04, step=0.005, value=initial_bc[2], marks=None,
70
+ tooltip={"placement": "top", "always_visible": True}, updatemode='drag',
71
+ className='slider-no-border'),
72
+
73
+ html.Label("RVE Selector", style={'font-size': '18px', 'color': nord_fg}),
74
+ dcc.Dropdown(id='rve-selector', options=[
75
+ {'label': 'RVE1', 'value': 'RVE1'}, {'label': 'RVE2', 'value': 'RVE2'}, {'label': 'RVE3', 'value': 'RVE3'}
76
+ ], value='RVE1', clearable=False, style={'backgroundColor': 'white', 'color': nord_bg}),
77
+ ], width=3, style={'padding': '100px', 'backgroundColor': nord_bg}),
78
+
79
+ dbc.Col([
80
+ dbc.Row([dcc.Graph(id=f'pde-figure{i+1}', figure=pde_figs[i], style={'width': '19%', 'display': 'inline-block'}) for i in range(5)]),
81
+ ], width=9, style={'padding': '100px', 'backgroundColor': nord_bg})
82
+ ])
83
+ ], fluid=True, style={'backgroundColor': nord_bg, 'minHeight': '200vh'})
84
+
85
+ # Callback
86
+ @app.callback(
87
+ [Output(f'pde-figure{i+1}', 'figure') for i in range(5)],
88
+ [Input('bc1-slider', 'value'), Input('bc2-slider', 'value'), Input('bc3-slider', 'value'), Input('rve-selector', 'value')]
89
+ )
90
+ def update_pde(bc1, bc2, bc3, selected_rve):
91
+ new_bc = [bc1, bc2, bc3]
92
+ d, s = solve_pde(new_bc, selected_rve)
93
+ return [create_figure(d[:,:,0], titles[0]), create_figure(d[:,:,1], titles[1]),
94
+ create_figure(s[:,:,0], titles[2]), create_figure(s[:,:,1], titles[3]),
95
+ create_figure(s[:,:,2], titles[4])]
96
+
97
+ # Run the app
98
+ server = app.server
99
+ if __name__ == "__main__":
100
+ app.run_server(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
101
+
predict.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ import jax
4
+ import jax.numpy as jnp
5
+ from scipy.interpolate import LinearNDInterpolator
6
+
7
+ class Configs:
8
+ nn = 101
9
+
10
+ def load_model(rve):
11
+ mdir = f'./trained_models/EquiNO_{rve}'
12
+ model = tf.keras.models.load_model(mdir)
13
+ nodes = np.array(model.kinema.get_config()['nodeCoord'])
14
+ scaling_bu = model.get_config()['scaling_bu']
15
+ scaling_bs = model.get_config()['scaling_bs']
16
+ scale_input = 0.04
17
+
18
+ shp = np.array(model.kinema.get_config()['shp'])
19
+ elemInc = np.array(model.kinema.get_config()['elemInc'])
20
+
21
+ # Extract model outputs and reshape arrays
22
+ t, _, t_s = model.trunk(nodes)
23
+ t = t.numpy()
24
+ t_s = t_s.numpy()
25
+ t_s = t_s.reshape((-1, 9, 16, 3))
26
+ t_s = jnp.einsum('nlrk,lj->njrk', t_s, jnp.linalg.pinv(shp[:, 0]))
27
+
28
+ t_sg = np.zeros((nodes.shape[0], 16, 3))
29
+ t_sg[elemInc[:, 2:] - 1] = np.array(t_s)
30
+ t_s = t_sg
31
+
32
+ x = np.linspace(-0.5, 0.5, Configs.nn, endpoint=True)
33
+ y = np.linspace(-0.5, 0.5, Configs.nn, endpoint=True)
34
+ xx, yy = np.meshgrid(x, y)
35
+
36
+ def interp(d, s):
37
+ d = d.reshape((-1, 16*2))
38
+ s = s.reshape((-1, 16*3))
39
+ d = LinearNDInterpolator(nodes, d)(xx, yy)
40
+ s = LinearNDInterpolator(nodes, s)(xx, yy)
41
+ d = d.reshape((-1, 16, 2))
42
+ s = s.reshape((-1, 16, 3))
43
+ return d, s
44
+
45
+ t, t_s = interp(t, t_s)
46
+
47
+ nodes = np.stack((xx.flatten(), yy.flatten()), 1)
48
+
49
+ weights = model.branch.get_weights()
50
+
51
+ weights_u = [weights[i:i+2] for i in range(0, len(weights), 4)]
52
+ params_u = [(jnp.array(w[0]), jnp.array(w[1])) for w in weights_u]
53
+
54
+ weights_s = [weights[i:i+2] for i in range(2, len(weights), 4)]
55
+ params_s = [(jnp.array(w[0]), jnp.array(w[1])) for w in weights_s]
56
+
57
+ cd = jnp.array([[2.0, 0.0], [0.0, 2.0]])
58
+
59
+ del model
60
+
61
+ @jax.jit
62
+ def periodic_disp(x):
63
+ matrix = nodes[..., None] * cd[None, ...]
64
+ matrix = 0.5 * jnp.concatenate([matrix, jnp.flip(nodes, 1)[:, None, :]], 1)
65
+ return jnp.einsum('ij,ljm->ilm', x, matrix)
66
+
67
+ @jax.jit
68
+ def jax_branch_s(x):
69
+ x_n = x
70
+ for (w, b) in params_s[:-1]:
71
+ x_n = jax.nn.swish(jnp.dot(x_n, w) + b)
72
+ final_w, final_b = params_s[-1]
73
+ return jnp.dot(x_n, final_w) + final_b
74
+
75
+ @jax.jit
76
+ def jax_branch_u(x):
77
+ x_n = x
78
+ for (w, b) in params_u[:-1]:
79
+ x_n = jax.nn.swish(jnp.dot(x_n, w) + b)
80
+ final_w, final_b = params_u[-1]
81
+ return jnp.dot(x_n, final_w) + final_b
82
+
83
+ @jax.jit
84
+ def forward(x):
85
+ x = x.reshape((1, 3))
86
+ b = jax_branch_u(x / scale_input)
87
+ b_s = jax_branch_s(x / scale_input)
88
+ b = b * scaling_bu[1] + scaling_bu[0]
89
+ b_s = b_s * scaling_bs[1] + scaling_bs[0]
90
+ u = jnp.einsum('im,lmn->iln', b, t)
91
+ s = jnp.einsum('im,lmn->iln', b_s, t_s)
92
+ u = u - u[:, :1] + periodic_disp(x)
93
+ u = u.reshape(Configs.nn, Configs.nn, -1)
94
+ s = s.reshape(Configs.nn, Configs.nn, -1)
95
+ return u, s
96
+
97
+ return forward
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ dash
2
+ dash_bootstrap_components
3
+ jax==0.4.20
4
+ jaxlib==0.4.20
5
+ numpy==1.24.3
6
+ plotly==5.9.0
7
+ scipy==1.10.1
8
+ tensorflow==2.14.0
9
+ Flask
10
+ gunicorn