"""Dash app: interactive RVE simulation with an Equilibrium Neural Operator (EquiNO).

Three sliders prescribe the global strain components (xx, yy, xy) and a dropdown
selects one of the models loaded through ``predict.load_model``. Every change
re-runs the selected model and redraws five contour plots: two displacement
components and three stress components.
"""
# NOTE: import order intentionally kept as in the original file; ``predict`` is
# imported before ``jax.numpy`` in case it configures JAX at import time.
import os
import numpy as np
import plotly.graph_objects as go
from dash import Dash, dcc, html, Input, Output
import dash_bootstrap_components as dbc
from predict import load_model, Configs
import jax.numpy as jnp

# Initialize Dash with a dark Bootstrap theme.
app = Dash(__name__, external_stylesheets=[dbc.themes.DARKLY])

# Plot axes: Configs.nn evenly spaced points on [-0.5, 0.5] in each direction
# (assumed to match the grid the models predict on — confirm in predict.Configs).
nn = Configs.nn
x = np.linspace(-0.5, 0.5, nn, endpoint=True)
y = np.linspace(-0.5, 0.5, nn, endpoint=True)

# One model per RVE, looked up by name in solve_pde.
RVE_NAMES = ('RVE1', 'RVE2', 'RVE3')
forward1 = load_model('RVE1')
forward2 = load_model('RVE2')
forward3 = load_model('RVE3')
FORWARD_MODELS = dict(zip(RVE_NAMES, (forward1, forward2, forward3)))

# Nord color palette
nord_bg = "#2E3440"    # Background
nord_fg = "#D8DEE9"    # Text color
nord_grid = "#4C566A"  # Grid color
colormap = 'RdBu'      # Diverging colorscale; reads well on a dark background

# Shared style for control labels. Dash/React style dicts expect camelCase CSS
# property names ('fontSize', not 'font-size').
LABEL_STYLE = {'fontSize': '18px', 'color': nord_fg}

# Plot titles, in the same order as the field components returned by _field_figures.
titles = ["Disp. x", "Disp. y", "Stress xx", "Stress yy", "Stress xy"]

# Initial strains (xx, yy, xy) for the sliders and the first render.
initial_bc = [0.02, 0.02, 0.02]


def solve_pde(boundary_conditions, rve):
    """Evaluate the model of one RVE under prescribed global strains.

    Args:
        boundary_conditions: three strain values in slider order (xx, yy, xy).
        rve: model name, one of ``RVE_NAMES``.

    Returns:
        The selected model's output; callers unpack it as ``(d, s)``
        (displacement and stress fields).

    Raises:
        ValueError: if ``rve`` is not one of ``RVE_NAMES``.
    """
    try:
        forward = FORWARD_MODELS[rve]
    except KeyError:
        raise ValueError(
            f"Unknown RVE {rve!r}; expected one of {RVE_NAMES}"
        ) from None
    return forward(jnp.array(boundary_conditions))


def create_figure(z_data, title):
    """Build a 250x250 px dark-themed contour plot of ``z_data`` over the (x, y) grid.

    Axis ticks, grid lines and the colour bar are hidden; only the title is shown.
    """
    fig = go.Figure(go.Contour(
        x=x, y=y, z=z_data,
        colorscale=colormap,
        showscale=False,
        line_smoothing=1.2,
        line_width=0,  # no contour lines, only the filled bands
    ))
    fig.update_layout(
        title=title,
        title_font=dict(color=nord_fg, size=18),
        paper_bgcolor=nord_bg,
        plot_bgcolor=nord_bg,
        font=dict(color=nord_fg),
        xaxis=dict(gridcolor=nord_grid, tickvals=[], scaleanchor="y",
                   showgrid=False, zeroline=False, showline=False),
        yaxis=dict(gridcolor=nord_grid, tickvals=[],
                   showgrid=False, zeroline=False, showline=False),
        autosize=False,
        width=250,
        height=250,
        margin=dict(l=0, r=0, t=30, b=0),
    )
    return fig


def _field_figures(d, s):
    """Return the five contour figures, in ``titles`` order, for one model output.

    Assumes ``d`` is (nn, nn, >=2) and ``s`` is (nn, nn, >=3), with the last axis
    indexing the field component — TODO confirm against predict.load_model.
    """
    # Convert once to host NumPy arrays instead of slicing device arrays per plot.
    d = np.asarray(d)
    s = np.asarray(s)
    components = [d[:, :, 0], d[:, :, 1], s[:, :, 0], s[:, :, 1], s[:, :, 2]]
    return [create_figure(z, title) for z, title in zip(components, titles)]


def _strain_control(label, slider_id, value):
    """Return the label + drag-updating slider pair for one strain component."""
    return [
        html.Label(label, style=LABEL_STYLE),
        dcc.Slider(id=slider_id, min=-0.04, max=0.04, step=0.005, value=value,
                   marks=None,
                   tooltip={"placement": "top", "always_visible": True},
                   updatemode='drag', className='slider-no-border'),
    ]


d, s = solve_pde(initial_bc, 'RVE1')
pde_figs = _field_figures(d, s)

# Layout
app.layout = dbc.Container([
    html.H1("RVE simulation using Equilibrium Neural Operator (EquiNO)",
            style={'textAlign': 'center', 'color': nord_fg, 'padding': '100px'}),
    dbc.Row([
        dbc.Col([
            html.H3("Prescribed Global Strains",
                    style={'textAlign': 'center', 'color': nord_fg, 'padding': '10px'}),
            *_strain_control("Strain xx", 'bc1-slider', initial_bc[0]),
            *_strain_control("Strain yy", 'bc2-slider', initial_bc[1]),
            *_strain_control("Strain xy", 'bc3-slider', initial_bc[2]),
            html.Label("RVE Selector", style=LABEL_STYLE),
            dcc.Dropdown(id='rve-selector',
                         options=[{'label': name, 'value': name} for name in RVE_NAMES],
                         value='RVE1', clearable=False,
                         style={'backgroundColor': 'white', 'color': nord_bg}),
        ], width=3, style={'padding': '100px', 'backgroundColor': nord_bg}),
        dbc.Col([
            dbc.Row([
                dcc.Graph(id=f'pde-figure{i + 1}', figure=fig,
                          style={'width': '19%', 'display': 'inline-block'})
                for i, fig in enumerate(pde_figs)
            ]),
        ], width=9, style={'padding': '100px', 'backgroundColor': nord_bg}),
    ]),
], fluid=True, style={'backgroundColor': nord_bg, 'minHeight': '200vh'})


# Callback
@app.callback(
    [Output(f'pde-figure{i + 1}', 'figure') for i in range(len(titles))],
    [Input('bc1-slider', 'value'), Input('bc2-slider', 'value'),
     Input('bc3-slider', 'value'), Input('rve-selector', 'value')],
)
def update_pde(bc1, bc2, bc3, selected_rve):
    """Recompute all five field plots when a strain slider or the RVE selection changes."""
    d, s = solve_pde([bc1, bc2, bc3], selected_rve)
    return _field_figures(d, s)


# Expose the underlying Flask server for external WSGI hosting.
server = app.server

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))