|
import numpy as np |
|
import tensorflow as tf |
|
import jax |
|
import jax.numpy as jnp |
|
from scipy.interpolate import LinearNDInterpolator |
|
|
|
class Configs:
    """Shared settings for the surrogate evaluation grid."""

    # Points per axis of the regular evaluation grid spanning [-0.5, 0.5] in x and y.
    nn: int = 101
|
|
|
def _swish_mlp(params, x):
    """Evaluate a dense MLP with swish hidden layers and a linear output layer.

    Args:
        params: Sequence of ``(weight, bias)`` pairs applied in order; every
            pair except the last is followed by ``jax.nn.swish``.
        x: Input array whose last axis matches the first weight's rows.

    Returns:
        The output of the final (linear) layer.
    """
    for w, b in params[:-1]:
        x = jax.nn.swish(jnp.dot(x, w) + b)
    w_out, b_out = params[-1]
    return jnp.dot(x, w_out) + b_out


def _gauss_to_nodal(values_pt, shp, elem_inc, n_nodes):
    """Map per-element point values onto the global mesh nodes.

    Each element's values are mapped to its own nodes with the pseudo-inverse
    of ``shp[:, 0]`` (presumably shape functions sampled at the integration
    points -- confirm). Nodes shared by several elements receive the mean of
    the element contributions. A plain fancy-index assignment would keep an
    arbitrary one, because NumPy does not define the write order for
    repeated indices.

    Args:
        values_pt: Array ``(n_elem, n_pts, ...)`` of per-element point values.
        shp: Array whose ``[:, 0]`` slice is pseudo-inverted for the mapping.
        elem_inc: Integer connectivity; columns ``2:`` hold 1-based node ids.
        n_nodes: Number of global nodes.

    Returns:
        Array ``(n_nodes, ...)``. Nodes referenced by no element stay 0.
    """
    elem_nodal = np.einsum('nlrk,lj->njrk', values_pt, np.linalg.pinv(shp[:, 0]))
    node_ids = elem_inc[:, 2:] - 1  # connectivity ids are 1-based
    summed = np.zeros((n_nodes,) + elem_nodal.shape[2:])
    np.add.at(summed, node_ids, elem_nodal)  # unbuffered: duplicates accumulate
    counts = np.bincount(node_ids.ravel(), minlength=n_nodes)
    counts = np.maximum(counts, 1).reshape((-1,) + (1,) * (summed.ndim - 1))
    return summed / counts


def _interp_to_grid(points, disp, stress, xx, yy, n_modes):
    """Linearly interpolate nodal displacement and stress modes onto a grid.

    Both fields are interpolated in one ``LinearNDInterpolator`` call, so the
    Delaunay triangulation of ``points`` is built only once.

    Returns:
        ``(disp_grid, stress_grid)`` with shapes ``(n_grid_pts, n_modes, 2)``
        and ``(n_grid_pts, n_modes, 3)``. Grid points outside the convex
        hull of ``points`` are NaN (SciPy's default fill value).
    """
    d_flat = disp.reshape((-1, n_modes * 2))
    s_flat = stress.reshape((-1, n_modes * 3))
    split = d_flat.shape[1]
    values = LinearNDInterpolator(points, np.concatenate([d_flat, s_flat], axis=1))(xx, yy)
    values = values.reshape((-1, values.shape[-1]))
    return (values[:, :split].reshape((-1, n_modes, 2)),
            values[:, split:].reshape((-1, n_modes, 3)))


def load_model(rve, *, nn=None, model_root='./trained_models'):
    """Load a trained EquiNO Keras model and return a JIT-compiled evaluator.

    The trunk is evaluated once at the FE mesh nodes. Its per-element stress
    modes are mapped to nodes, and both displacement and stress modes are
    interpolated onto a regular ``nn x nn`` grid over ``[-0.5, 0.5]^2``.
    The branch weights are copied into JAX arrays and the Keras model
    reference is then dropped.

    Args:
        rve: Suffix of the model directory name ``EquiNO_<rve>``.
        nn: Grid points per axis; defaults to ``Configs.nn`` read at load time.
        model_root: Directory that contains the ``EquiNO_<rve>`` saved model.

    Returns:
        ``forward(x) -> (u, s)``. ``x`` must hold 3 values (it is reshaped to
        ``(1, 3)``). ``u`` has shape ``(nn, nn, 2)`` and ``s`` has shape
        ``(nn, nn, 3)``.
    """
    n_grid = Configs.nn if nn is None else nn
    n_modes = 16       # basis size expected by the trunk/branch outputs
    pts_per_elem = 9   # presumably integration points per element -- confirm
    scale_input = 0.04 # NOTE(review): must match the training-time input scaling

    model = tf.keras.models.load_model(f'{model_root}/EquiNO_{rve}')
    model_cfg = model.get_config()
    kinema_cfg = model.kinema.get_config()
    fe_nodes = np.array(kinema_cfg['nodeCoord'])
    shp = np.array(kinema_cfg['shp'])
    elem_inc = np.array(kinema_cfg['elemInc'])
    # get_config() values may be plain lists, which JAX arithmetic rejects.
    bu_shift, bu_scale = (jnp.asarray(v) for v in model_cfg['scaling_bu'][:2])
    bs_shift, bs_scale = (jnp.asarray(v) for v in model_cfg['scaling_bs'][:2])

    t, _, t_s = model.trunk(fe_nodes)
    t = t.numpy()
    t_s = t_s.numpy().reshape((-1, pts_per_elem, n_modes, 3))
    t_s = _gauss_to_nodal(t_s, shp, elem_inc, fe_nodes.shape[0])

    axis = np.linspace(-0.5, 0.5, n_grid, endpoint=True)
    xx, yy = np.meshgrid(axis, axis)
    t, t_s = _interp_to_grid(fe_nodes, t, t_s, xx, yy, n_modes)
    grid_nodes = np.stack((xx.ravel(), yy.ravel()), 1)

    # Branch weights alternate [W_u, b_u, W_s, b_s, ...] per layer.
    weights = model.branch.get_weights()
    params_u = [(jnp.array(w), jnp.array(b)) for w, b in zip(weights[0::4], weights[1::4])]
    params_s = [(jnp.array(w), jnp.array(b)) for w, b in zip(weights[2::4], weights[3::4])]
    del model  # drop the Keras model reference; only NumPy/JAX copies are used below

    # Affine map from the 3 inputs to grid displacements, built once:
    # per node the rows are [X, 0], [0, Y], [Y/2, X/2].
    cd = jnp.array([[2.0, 0.0], [0.0, 2.0]])
    affine = grid_nodes[..., None] * cd[None, ...]
    affine = 0.5 * jnp.concatenate([affine, jnp.flip(grid_nodes, 1)[:, None, :]], 1)

    @jax.jit
    def forward(x):
        """Evaluate displacement ``u`` and stress ``s`` on the grid for input ``x``.

        ``x`` is presumably a macroscopic strain (e11, e22, gamma12). The third
        component enters both off-diagonal terms with factor 0.5, consistent
        with engineering shear -- confirm.
        """
        x = x.reshape((1, 3))
        x_scaled = x / scale_input
        b_u = _swish_mlp(params_u, x_scaled) * bu_scale + bu_shift
        b_s = _swish_mlp(params_s, x_scaled) * bs_scale + bs_shift
        u = jnp.einsum('im,lmn->iln', b_u, t)
        s = jnp.einsum('im,lmn->iln', b_s, t_s)
        # Zero the mode field at the first grid point (corner (-0.5, -0.5)),
        # then add the affine displacement driven by x.
        u = u - u[:, :1] + jnp.einsum('ij,ljm->ilm', x, affine)
        return u.reshape(n_grid, n_grid, -1), s.reshape(n_grid, n_grid, -1)

    return forward