# RVEsim / app.py
# Hamidreza Eivazi
# Commit 169245e: "update fig size"
import os
import numpy as np
import plotly.graph_objects as go
from dash import Dash, dcc, html, Input, Output
import dash_bootstrap_components as dbc
from predict import load_model, Configs
import jax.numpy as jnp
# Dash application using the Bootstrap "DARKLY" dark theme.
app = Dash(__name__, external_stylesheets=[dbc.themes.DARKLY])

# Number of grid points per axis, read from predict.Configs.
nn = Configs.nn
# Plotting-grid coordinates on [-0.5, 0.5] along each axis (both endpoints included).
x = np.linspace(-0.5, 0.5, nn, endpoint=True)
y = np.linspace(-0.5, 0.5, nn, endpoint=True)

# One pre-loaded model per selectable RVE, loaded once at import time in RVE1, RVE2, RVE3 order.
forward1, forward2, forward3 = (load_model(name) for name in ('RVE1', 'RVE2', 'RVE3'))
# Nord-inspired dark palette shared by the figures and the page layout.
nord_bg = "#2E3440"    # page and plot background
nord_fg = "#D8DEE9"    # text and title colour
nord_grid = "#4C566A"  # axis grid-line colour

# Plotly's diverging red-blue colorscale, applied to every contour panel.
colormap = 'RdBu'
# PDE "solver": evaluates the pre-loaded surrogate model of the selected RVE.
def solve_pde(boundary_conditions, rve):
    """Run the surrogate model for ``rve`` on the prescribed global strains.

    Args:
        boundary_conditions: Sequence of three prescribed strains, in the
            order (xx, yy, xy) used by the slider callback.
        rve: Name of the model to use: ``'RVE1'``, ``'RVE2'`` or ``'RVE3'``.

    Returns:
        The selected model's output. Callers in this module unpack it as
        ``(displacement, stress)`` and index the last axis by component.

    Raises:
        ValueError: If ``rve`` is not one of the known RVE names. Unknown names
            are rejected explicitly instead of silently falling back to RVE3.
    """
    models = {'RVE1': forward1, 'RVE2': forward2, 'RVE3': forward3}
    try:
        model = models[rve]
    except KeyError:
        raise ValueError(
            f"Unknown RVE {rve!r}; expected one of {sorted(models)}"
        ) from None
    return model(jnp.array(boundary_conditions))
# Start-up strains (xx, yy, xy), used as the sliders' initial values and for the first render.
initial_bc = [0.02] * 3
# Pre-compute the initial displacement / stress fields for RVE1.
d, s = solve_pde(initial_bc, rve='RVE1')
# Build one dark-themed contour panel for a single field component.
def create_figure(z_data, title, width=250, height=250):
    """Return an axis-free contour plot of ``z_data`` over the module's (x, y) grid.

    Args:
        z_data: 2-D array-like of field values. It is converted with
            ``np.asarray``, so NumPy arrays, JAX arrays and nested lists are all
            accepted by Plotly.
        title: Panel title drawn above the plot.
        width: Figure width in pixels (default 250, the original fixed size).
        height: Figure height in pixels (default 250, the original fixed size).

    Returns:
        plotly.graph_objects.Figure: contour figure using ``colormap`` with
        the colour bar hidden.
    """
    z = np.asarray(z_data)
    fig = go.Figure(
        go.Contour(
            x=x,
            y=y,
            z=z,
            colorscale=colormap,
            showscale=False,
            line_smoothing=1.2,
            line_width=0,  # hide contour lines
        )
    )
    fig.update_layout(
        title=title,
        title_font=dict(color=nord_fg, size=18),
        paper_bgcolor=nord_bg,
        plot_bgcolor=nord_bg,
        font=dict(color=nord_fg),
        # scaleanchor ties the x scale to y so cells are not stretched.
        xaxis=dict(gridcolor=nord_grid, tickvals=[], scaleanchor="y",
                   showgrid=False, zeroline=False, showline=False),
        yaxis=dict(gridcolor=nord_grid, tickvals=[],
                   showgrid=False, zeroline=False, showline=False),
        autosize=False,
        width=width,
        height=height,
        margin=dict(l=0, r=0, t=30, b=0),
    )
    return fig
# Panel titles, in display order: two displacement components, then three stress components.
titles = ["Disp. x", "Disp. y", "Stress xx", "Stress yy", "Stress xy"]
# Initial figures: each panel takes one channel of the last axis of d (displacement) or s (stress).
pde_figs = [
    create_figure(field[:, :, channel], label)
    for (field, channel), label in zip([(d, 0), (d, 1), (s, 0), (s, 1), (s, 2)], titles)
]
# Layout
def _strain_slider(label, slider_id, value):
    """Return ``[label, slider]`` components for one prescribed-strain control.

    All strain sliders share the same range (-0.04 to 0.04), step (0.005),
    always-visible tooltip and ``'drag'`` update mode.
    """
    return [
        # Dash style dicts take camelCase CSS names (fontSize, not font-size).
        html.Label(label, style={'fontSize': '18px', 'color': nord_fg}),
        dcc.Slider(
            id=slider_id, min=-0.04, max=0.04, step=0.005, value=value, marks=None,
            tooltip={"placement": "top", "always_visible": True},
            updatemode='drag',  # update the value continuously while dragging
            className='slider-no-border',
        ),
    ]


app.layout = dbc.Container([
    html.H1(
        "RVE simulation using Equilibrium Neural Operator (EquiNO)",
        style={'textAlign': 'center', 'color': nord_fg, 'padding': '100px'},
    ),
    dbc.Row([
        # Left column: strain sliders and the RVE selector.
        dbc.Col([
            html.H3(
                "Prescribed Global Strains",
                style={'textAlign': 'center', 'color': nord_fg, 'padding': '10px'},
            ),
            *_strain_slider("Strain xx", 'bc1-slider', initial_bc[0]),
            *_strain_slider("Strain yy", 'bc2-slider', initial_bc[1]),
            *_strain_slider("Strain xy", 'bc3-slider', initial_bc[2]),
            html.Label("RVE Selector", style={'fontSize': '18px', 'color': nord_fg}),
            dcc.Dropdown(
                id='rve-selector',
                options=[{'label': name, 'value': name} for name in ('RVE1', 'RVE2', 'RVE3')],
                value='RVE1',
                clearable=False,
                style={'backgroundColor': 'white', 'color': nord_bg},
            ),
        ], width=3, style={'padding': '100px', 'backgroundColor': nord_bg}),
        # Right column: the five result panels as inline blocks of 19% width each.
        dbc.Col([
            dbc.Row([
                dcc.Graph(id=f'pde-figure{i+1}', figure=pde_figs[i],
                          style={'width': '19%', 'display': 'inline-block'})
                for i in range(5)
            ]),
        ], width=9, style={'padding': '100px', 'backgroundColor': nord_bg}),
    ]),
], fluid=True, style={'backgroundColor': nord_bg, 'minHeight': '200vh'})
# Callback: redraw all five panels whenever a strain slider or the RVE selection changes.
@app.callback(
    [Output(f'pde-figure{i+1}', 'figure') for i in range(5)],
    [
        Input('bc1-slider', 'value'),
        Input('bc2-slider', 'value'),
        Input('bc3-slider', 'value'),
        Input('rve-selector', 'value'),
    ],
)
def update_pde(bc1, bc2, bc3, selected_rve):
    """Re-run the selected model on the slider strains and return the five panel figures.

    The returned list is ordered like ``titles``: displacement x/y, then stress xx/yy/xy.
    """
    disp, stress = solve_pde([bc1, bc2, bc3], selected_rve)
    panels = ((disp, 0), (disp, 1), (stress, 0), (stress, 1), (stress, 2))
    return [
        create_figure(field[:, :, comp], titles[idx])
        for idx, (field, comp) in enumerate(panels)
    ]
# Module-level handle on Dash's underlying Flask server, so a WSGI host can import it.
server = app.server

if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT and falls back to 7860.
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port)