# RVEsim / predict.py
# Author: Hamidreza Eivazi
# Commit message: add app
# Commit: 7badbdd
import numpy as np
import tensorflow as tf
import jax
import jax.numpy as jnp
from scipy.interpolate import LinearNDInterpolator
class Configs:
    """Module-wide settings used when building predictors."""

    # Number of points per axis of the regular evaluation grid on [-0.5, 0.5];
    # load_model evaluates the fields on nn * nn grid points.
    nn: int = 101
# Layout constants of the trained EquiNO model, as fixed by the reshapes used
# below.
_N_SAMPLES_PER_ELEM = 9  # rows per element in the trunk stress output (presumably Gauss points — confirm)
_N_BASIS = 16  # dimension contracted between branch and trunk outputs in `forward`
_N_DISP = 2  # displacement components per point (matches periodic_disp's 2 columns)
_N_STRESS = 3  # stress components per point (presumably 2-D Voigt order — confirm)


def _mlp_apply(params, x):
    """Apply a dense stack: swish on every (w, b) layer except the last, which is linear."""
    h = x
    for w, b in params[:-1]:
        h = jax.nn.swish(jnp.dot(h, w) + b)
    w_out, b_out = params[-1]
    return jnp.dot(h, w_out) + b_out


def _split_branch_params(weights):
    """Split the branch's flat weight list into (params_u, params_s) layer lists.

    The list is consumed in groups of four arrays per layer,
    ``[u_kernel, u_bias, s_kernel, s_bias]``. This is the interleaving implied
    by reading the u-branch at offsets 0, 4, 8, ... and the s-branch at
    offsets 2, 6, 10, ...
    """
    def layers_from(start):
        return [(jnp.array(weights[i]), jnp.array(weights[i + 1]))
                for i in range(start, len(weights), 4)]

    return layers_from(0), layers_from(2)


def _gauss_to_nodes(t_s_samples, shp, elem_inc, n_nodes):
    """Project per-element trunk stress samples onto mesh nodes.

    The samples of each element are mapped to that element's nodes with the
    pseudo-inverse of ``shp[:, 0]``. They are then scattered into a nodal array
    using the 1-based node ids in ``elem_inc[:, 2:]``. Nodes not referenced by
    any element stay zero.
    """
    per_elem = t_s_samples.reshape((-1, _N_SAMPLES_PER_ELEM, _N_BASIS, _N_STRESS))
    per_node = jnp.einsum('nlrk,lj->njrk', per_elem, jnp.linalg.pinv(shp[:, 0]))
    nodal = np.zeros((n_nodes, _N_BASIS, _N_STRESS))
    # NOTE(review): a node shared by several elements is assigned more than
    # once here. Values are not averaged, and NumPy does not guarantee which
    # write wins for repeated advanced indices. Confirm this is intended.
    nodal[elem_inc[:, 2:] - 1] = np.array(per_node)
    return nodal


def _interp_to_grid(points, disp, stress, grid_x, grid_y):
    """Linearly interpolate nodal trunk fields from ``points`` onto a grid.

    Both fields are stacked column-wise and interpolated by ONE
    LinearNDInterpolator, so the Delaunay triangulation of ``points`` is built
    once rather than once per field. Each column is still interpolated
    independently, so the result equals interpolating the fields separately.
    Grid points outside the convex hull of ``points`` come back as NaN
    (LinearNDInterpolator's default fill value).

    Returns:
        (disp_grid, stress_grid) with shapes (n_grid, _N_BASIS, _N_DISP) and
        (n_grid, _N_BASIS, _N_STRESS), grid points in row-major meshgrid order.
    """
    d = disp.reshape((-1, _N_BASIS * _N_DISP))
    s = stress.reshape((-1, _N_BASIS * _N_STRESS))
    n_d = d.shape[1]
    values = np.concatenate([d, s], axis=1)
    on_grid = LinearNDInterpolator(points, values)(grid_x, grid_y)
    on_grid = on_grid.reshape((-1, values.shape[1]))
    disp_grid = on_grid[:, :n_d].reshape((-1, _N_BASIS, _N_DISP))
    stress_grid = on_grid[:, n_d:].reshape((-1, _N_BASIS, _N_STRESS))
    return disp_grid, stress_grid


def load_model(rve):
    """Load the trained EquiNO model for ``rve`` and return a jitted predictor.

    The Keras model is read from ``./trained_models/EquiNO_{rve}``, a path
    relative to the current working directory. Its trunk is evaluated once at
    the mesh nodes and resampled onto a regular ``Configs.nn x Configs.nn``
    grid on [-0.5, 0.5]^2. The branch weights are copied into JAX arrays, and
    the Keras model is then released.

    Errors raised by ``tf.keras.models.load_model`` (e.g. a missing directory)
    propagate unchanged.

    Returns:
        forward(x): jitted function. ``x`` must hold 3 values (it is reshaped
        to (1, 3)). It returns ``(u, s)`` with shapes
        ``(Configs.nn, Configs.nn, 2)`` and ``(Configs.nn, Configs.nn, 3)``.
    """
    mdir = f'./trained_models/EquiNO_{rve}'
    model = tf.keras.models.load_model(mdir)

    # Read each config dict once instead of once per key.
    kin_cfg = model.kinema.get_config()
    model_cfg = model.get_config()
    mesh_nodes = np.array(kin_cfg['nodeCoord'])
    shp = np.array(kin_cfg['shp'])
    elem_inc = np.array(kin_cfg['elemInc'])
    scaling_bu = model_cfg['scaling_bu']
    scaling_bs = model_cfg['scaling_bs']
    scale_input = 0.04  # divisor applied to x before both branch networks

    # Evaluate the trunk once at the mesh nodes; the middle output is unused.
    t_disp, _, t_stress = model.trunk(mesh_nodes)
    t_disp = t_disp.numpy()
    t_stress = _gauss_to_nodes(t_stress.numpy(), shp, elem_inc, mesh_nodes.shape[0])

    # Regular evaluation grid; row-major order of (xx, yy) defines grid_nodes.
    axis = np.linspace(-0.5, 0.5, Configs.nn, endpoint=True)
    xx, yy = np.meshgrid(axis, axis)
    t_disp, t_stress = _interp_to_grid(mesh_nodes, t_disp, t_stress, xx, yy)
    grid_nodes = np.stack((xx.flatten(), yy.flatten()), 1)

    params_u, params_s = _split_branch_params(model.branch.get_weights())
    cd = jnp.array([[2.0, 0.0], [0.0, 2.0]])
    del model  # everything needed has been copied out; free the Keras model

    @jax.jit
    def periodic_disp(x):
        # Affine field at each grid node: x[0] scales X in u_x, x[1] scales Y
        # in u_y, and x[2] / 2 couples them (u_x += x[2]*Y/2, u_y += x[2]*X/2).
        # This matches a macroscopic strain (eps_xx, eps_yy, gamma_xy);
        # presumably that is what x is — confirm.
        matrix = grid_nodes[..., None] * cd[None, ...]
        matrix = 0.5 * jnp.concatenate([matrix, jnp.flip(grid_nodes, 1)[:, None, :]], 1)
        return jnp.einsum('ij,ljm->ilm', x, matrix)

    @jax.jit
    def jax_branch_s(x):
        return _mlp_apply(params_s, x)

    @jax.jit
    def jax_branch_u(x):
        return _mlp_apply(params_u, x)

    @jax.jit
    def forward(x):
        x = x.reshape((1, 3))
        b = jax_branch_u(x / scale_input)
        b_s = jax_branch_s(x / scale_input)
        # Undo the output normalisation: value * scale[1] + offset[0].
        b = b * scaling_bu[1] + scaling_bu[0]
        b_s = b_s * scaling_bs[1] + scaling_bs[0]
        u = jnp.einsum('im,lmn->iln', b, t_disp)
        s = jnp.einsum('im,lmn->iln', b_s, t_stress)
        # Pin the fluctuation to zero at the first grid node, then add the
        # affine part.
        u = u - u[:, :1] + periodic_disp(x)
        u = u.reshape(Configs.nn, Configs.nn, -1)
        s = s.reshape(Configs.nn, Configs.nn, -1)
        return u, s

    return forward