import numpy as np
import tensorflow as tf
import jax
import jax.numpy as jnp
from scipy.interpolate import LinearNDInterpolator


class Configs:
    # Number of points per axis of the regular output grid on [-0.5, 0.5]^2.
    nn = 101


def _mlp_apply(params, x):
    """Apply a dense network: swish on every hidden layer, linear output layer.

    Args:
        params: sequence of ``(weight, bias)`` pairs; all but the last pair are
            followed by ``jax.nn.swish``.
        x: input array whose last axis matches the first weight's input size.

    Returns:
        ``jnp.dot(h, w_last) + b_last`` where ``h`` is the last hidden output.
    """
    h = x
    for w, b in params[:-1]:
        h = jax.nn.swish(jnp.dot(h, w) + b)
    w_out, b_out = params[-1]
    return jnp.dot(h, w_out) + b_out


def load_model(rve):
    """Load the trained EquiNO model for ``rve`` and return a jitted ``forward``.

    The TensorFlow model is read from ``./trained_models/EquiNO_{rve}``.
    Its trunk outputs are evaluated once at the mesh nodes, linearly
    interpolated onto a regular ``Configs.nn x Configs.nn`` grid over
    ``[-0.5, 0.5]^2`` and frozen. The branch weights are copied into JAX
    arrays, so the returned function does not call TensorFlow.

    Args:
        rve: identifier appended to the model directory name.

    Returns:
        ``forward(x)``: a ``jax.jit``-compiled function. ``x`` must hold exactly
        3 values (it is reshaped to ``(1, 3)``). This is presumably a
        macroscopic strain in Voigt notation; confirm with the caller.
        ``forward`` returns ``(u, s)`` with shapes
        ``(Configs.nn, Configs.nn, 2)`` and ``(Configs.nn, Configs.nn, 3)``.
        Both are sampled on the grid: the row index follows y and the column
        index follows x (``np.meshgrid`` default ``'xy'`` indexing).
        Points outside the mesh's convex hull come out as NaN
        (``LinearNDInterpolator`` default ``fill_value``).
    """
    # Capture the grid size now so ``forward`` cannot disagree with the grid
    # built below if ``Configs.nn`` is changed later.
    nn = Configs.nn

    model = tf.keras.models.load_model(f'./trained_models/EquiNO_{rve}')
    # Read each config once instead of rebuilding it for every key.
    kinema_cfg = model.kinema.get_config()
    model_cfg = model.get_config()
    mesh_nodes = np.array(kinema_cfg['nodeCoord'])
    shp = np.array(kinema_cfg['shp'])
    elem_inc = np.array(kinema_cfg['elemInc'])
    scaling_bu = model_cfg['scaling_bu']
    scaling_bs = model_cfg['scaling_bs']
    scale_input = 0.04

    # Evaluate the trunk at the mesh nodes.
    t, _, t_s = model.trunk(mesh_nodes)
    t = t.numpy()
    # Per element: 9 sample points x 16 basis functions x 3 components.
    t_s = t_s.numpy().reshape((-1, 9, 16, 3))
    # Map the 9 per-element sample values to the element nodes with the
    # pseudo-inverse of shp[:, 0]. Presumably these are shape functions
    # evaluated at the integration points; confirm.
    t_s = jnp.einsum('nlrk,lj->njrk', t_s, jnp.linalg.pinv(shp[:, 0]))
    # Scatter element-node values into a global per-node array. Columns 2: of
    # elemInc are presumably 1-based node ids (hence the -1).
    # NOTE(review): nodes shared by several elements are assigned more than
    # once. NumPy does not guarantee which write wins for repeated indices, so
    # confirm the values agree or consider averaging instead.
    t_s_nodal = np.zeros((mesh_nodes.shape[0], 16, 3))
    t_s_nodal[elem_inc[:, 2:] - 1] = np.array(t_s)

    # Regular grid on [-0.5, 0.5]^2 (same axis for x and y).
    axis = np.linspace(-0.5, 0.5, nn, endpoint=True)
    xx, yy = np.meshgrid(axis, axis)

    # Interpolate both trunk outputs in a single call so the Delaunay
    # triangulation of the mesh is built only once. Linear interpolation acts
    # column-wise, so this equals interpolating each array separately.
    n_u, n_s = 16 * 2, 16 * 3
    stacked = np.concatenate(
        [t.reshape((-1, n_u)), t_s_nodal.reshape((-1, n_s))], axis=1)
    on_grid = LinearNDInterpolator(mesh_nodes, stacked)(xx, yy)
    on_grid = on_grid.reshape((-1, n_u + n_s))
    t = on_grid[:, :n_u].reshape((-1, 16, 2))
    t_s = on_grid[:, n_u:].reshape((-1, 16, 3))

    # Grid coordinates in row-major (y-major) order, matching the flattening
    # of the interpolated arrays above.
    grid_nodes = np.stack((xx.flatten(), yy.flatten()), 1)

    # The branch weights are consumed in groups of four: indices i, i+1 feed
    # the u-network and i+2, i+3 feed the s-network.
    weights = model.branch.get_weights()
    params_u = [(jnp.array(weights[i]), jnp.array(weights[i + 1]))
                for i in range(0, len(weights), 4)]
    params_s = [(jnp.array(weights[i]), jnp.array(weights[i + 1]))
                for i in range(2, len(weights), 4)]

    # Constant basis for the term linear in x, precomputed once per grid.
    # disp_basis[l] = [[X, 0], [0, Y], [Y/2, X/2]] for grid node (X, Y), so
    #   u_x += x0*X + x2*Y/2   and   u_y += x1*Y + x2*X/2.
    cd = np.array([[2.0, 0.0], [0.0, 2.0]])
    disp_basis = 0.5 * np.concatenate(
        [grid_nodes[..., None] * cd[None, ...],
         np.flip(grid_nodes, 1)[:, None, :]], 1)

    # Drop the TensorFlow model now that everything needed has been copied out.
    del model

    @jax.jit
    def forward(x):
        """Evaluate the surrogate for one 3-component input ``x``.

        Returns ``(u, s)`` reshaped to ``(nn, nn, 2)`` and ``(nn, nn, 3)``.
        """
        x = x.reshape((1, 3))
        scaled = x / scale_input
        # Branch coefficients, mapped back with (offset, scale) = scaling[0], scaling[1].
        b_u = _mlp_apply(params_u, scaled) * scaling_bu[1] + scaling_bu[0]
        b_s = _mlp_apply(params_s, scaled) * scaling_bs[1] + scaling_bs[0]
        u = jnp.einsum('im,lmn->iln', b_u, t)
        s = jnp.einsum('im,lmn->iln', b_s, t_s)
        # Shift u so the first grid node (corner (-0.5, -0.5)) is zero, then
        # add the term linear in x.
        u = u - u[:, :1] + jnp.einsum('ij,ljm->ilm', x, disp_basis)
        return u.reshape(nn, nn, -1), s.reshape(nn, nn, -1)

    return forward