File size: 3,069 Bytes
7badbdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import numpy as np
import tensorflow as tf
import jax
import jax.numpy as jnp
from scipy.interpolate import LinearNDInterpolator

class Configs:
    """Module-wide settings read by ``load_model``."""

    # Points per axis of the uniform evaluation grid built over [-0.5, 0.5].
    nn: int = 101

def _mlp_apply(params, x):
    """Apply a dense network: swish after every hidden layer, linear output layer.

    Args:
        params: non-empty sequence of ``(weight, bias)`` pairs, applied in order.
        x: input batch, right-multiplied by each weight matrix.

    Returns:
        Output of the final (un-activated) layer.
    """
    *hidden, (final_w, final_b) = params
    for w, b in hidden:
        x = jax.nn.swish(jnp.dot(x, w) + b)
    return jnp.dot(x, final_w) + final_b


def _split_branch_params(weights, offset):
    """Collect ``(kernel, bias)`` pairs taken every 4 entries of a flat weight list, from ``offset``."""
    # NOTE(review): assumes the Keras branch interleaves one displacement layer and
    # one stress layer (kernel, bias each) -- this mirrors the original slicing.
    return [(jnp.array(weights[i]), jnp.array(weights[i + 1]))
            for i in range(offset, len(weights), 4)]


def _nodal_stress_trunk(t_s, shp, elem_inc, n_nodes):
    """Project per-element stress-trunk values onto the global mesh nodes.

    Args:
        t_s: trunk output reshapeable to ``(n_elem, 9, 16, 3)``; the 9 entries per
            element are presumably quadrature-point values -- TODO confirm.
        shp: shape-function array; ``shp[:, 0]`` is pseudo-inverted to map the 9
            per-element values onto the element's 9 nodes.
        elem_inc: element incidence table; columns ``2:`` are assumed to hold
            1-based node ids (hence the ``- 1``).
        n_nodes: number of global mesh nodes.

    Returns:
        ``(n_nodes, 16, 3)`` float array. Nodes shared by several elements keep
        one element's value (NumPy fancy assignment, no averaging).
    """
    t_s = np.asarray(t_s).reshape((-1, 9, 16, 3))
    # Kept in NumPy (float64) instead of round-tripping through JAX, which would
    # silently downcast to float32 when x64 is disabled.
    t_s = np.einsum('nlrk,lj->njrk', t_s, np.linalg.pinv(shp[:, 0]))
    nodal = np.zeros((n_nodes, 16, 3))
    nodal[elem_inc[:, 2:] - 1] = t_s
    return nodal


def _interp_to_grid(nodes, disp, stress, xx, yy):
    """Linearly interpolate nodal displacement and stress trunk values onto ``(xx, yy)``.

    Returns:
        ``(disp_grid, stress_grid)`` with shapes ``(xx.size, 16, 2)`` and
        ``(xx.size, 16, 3)``. Grid points outside the convex hull of ``nodes``
        get NaN (the interpolator's default ``fill_value``).
    """
    disp = np.asarray(disp).reshape((-1, 16 * 2))
    stress = np.asarray(stress).reshape((-1, 16 * 3))
    # One interpolator over the stacked columns: each value column is interpolated
    # independently, so this matches two separate interpolators while building the
    # Delaunay triangulation of ``nodes`` only once.
    on_grid = LinearNDInterpolator(nodes, np.concatenate([disp, stress], axis=1))(xx, yy)
    n_disp = disp.shape[1]
    return (on_grid[..., :n_disp].reshape((-1, 16, 2)),
            on_grid[..., n_disp:].reshape((-1, 16, 3)))


def load_model(rve, nn=None):
    """Load a trained EquiNO model and return a JIT-compiled field evaluator.

    The Keras model in ``./trained_models/EquiNO_{rve}`` is read once. Its trunk
    outputs are evaluated on the mesh nodes and interpolated onto a uniform
    ``nn x nn`` grid over ``[-0.5, 0.5]^2``. Its branch weights are copied into
    plain JAX parameters so the Keras model can be released.

    Args:
        rve: identifier appended to the model directory name.
        nn: grid points per axis; defaults to ``Configs.nn``, read at load time.

    Returns:
        ``forward(x)``: a ``jax.jit`` function taking 3 input values (reshaped to
        ``(1, 3)``) and returning ``(u, s)`` of shapes ``(nn, nn, 2)`` and
        ``(nn, nn, 3)``.
    """
    # Captured once so the grid built here and the reshape in ``forward`` agree.
    n_grid = Configs.nn if nn is None else nn
    scale_input = 0.04

    model = tf.keras.models.load_model(f'./trained_models/EquiNO_{rve}')
    kinema_cfg = model.kinema.get_config()
    model_cfg = model.get_config()
    nodes = np.array(kinema_cfg['nodeCoord'])
    shp = np.array(kinema_cfg['shp'])
    elem_inc = np.array(kinema_cfg['elemInc'])
    # Converted once so list-valued config entries also work in JAX arithmetic.
    bu_shift, bu_scale = jnp.asarray(model_cfg['scaling_bu'][0]), jnp.asarray(model_cfg['scaling_bu'][1])
    bs_shift, bs_scale = jnp.asarray(model_cfg['scaling_bs'][0]), jnp.asarray(model_cfg['scaling_bs'][1])

    # Trunk outputs at the mesh nodes; the middle output is unused here.
    t, _, t_s = model.trunk(nodes)
    t_s = _nodal_stress_trunk(t_s.numpy(), shp, elem_inc, nodes.shape[0])

    axis = np.linspace(-0.5, 0.5, n_grid, endpoint=True)
    xx, yy = np.meshgrid(axis, axis)
    t, t_s = _interp_to_grid(nodes, t.numpy(), t_s, xx, yy)
    t, t_s = jnp.asarray(t), jnp.asarray(t_s)
    grid = np.stack((xx.flatten(), yy.flatten()), 1)

    weights = model.branch.get_weights()
    params_u = _split_branch_params(weights, 0)
    params_s = _split_branch_params(weights, 2)
    del model

    # Constant per-grid-point matrix (rows [X, 0], [0, Y], [Y/2, X/2]) so that
    # x @ affine gives u_x = x0*X + x2*Y/2 and u_y = x1*Y + x2*X/2.
    cd = np.array([[2.0, 0.0], [0.0, 2.0]])
    affine = 0.5 * np.concatenate(
        [grid[..., None] * cd[None, ...], np.flip(grid, 1)[:, None, :]], 1)
    affine = jnp.asarray(affine)

    @jax.jit
    def forward(x):
        x = x.reshape((1, 3))
        x_scaled = x / scale_input
        b = _mlp_apply(params_u, x_scaled) * bu_scale + bu_shift
        b_s = _mlp_apply(params_s, x_scaled) * bs_scale + bs_shift
        u = jnp.einsum('im,lmn->iln', b, t)
        s = jnp.einsum('im,lmn->iln', b_s, t_s)
        # Zero the network displacement at the first grid point (corner
        # (-0.5, -0.5)), then add the affine part driven directly by ``x``.
        u = u - u[:, :1] + jnp.einsum('ij,ljm->ilm', x, affine)
        return u.reshape(n_grid, n_grid, -1), s.reshape(n_grid, n_grid, -1)

    return forward