# Source: j1j2_square_10x10_05 / attentions.py
# Uploaded by rrende — commit d30db4c ("Upload model", verified), 2.08 kB.
import jax
import jax.numpy as jnp
from flax import linen as nn
import jax.numpy as jnp
from einops import rearrange
def roll(J, shift, axis=-1):
    """Return ``J`` cyclically shifted by ``shift`` positions along ``axis``.

    Elements pushed past the end of ``axis`` wrap around to its start; a
    negative ``shift`` moves them the other way. ``axis`` defaults to the last
    dimension.
    """
    shifted = jnp.roll(J, shift=shift, axis=axis)
    return shifted
from functools import partial
@partial(jax.vmap, in_axes=(None, 0, None), out_axes=1)
@partial(jax.vmap, in_axes=(None, None, 0), out_axes=1)
def roll2d(spins, i, j):
    """Cyclically shift a flattened square grid by every combination of (i, j).

    Written for a single pair of shifts: ``spins`` (assumed 2-D,
    ``(rows, side*side)``) is viewed as ``rows`` grids of shape
    ``(side, side)``. Each grid is rolled by ``i`` along its rows and by ``j``
    along its columns, then flattened back.

    The two ``jax.vmap`` decorators map over 1-D arrays of shifts ``i`` (n_i,)
    and ``j`` (n_j,) while sharing ``spins``. The decorated function therefore
    returns shape ``(rows, n_i, n_j, side*side)``, where ``out[:, a, b]`` is the
    grid shifted by ``(i[a], j[b])``.

    ``side`` is ``int(sqrt(spins.shape[-1]))``. A last dimension that is not a
    perfect square makes the reshape fail.
    """
    n_rows, n_sites = spins.shape[0], spins.shape[-1]
    side = int(n_sites ** 0.5)
    lattice = spins.reshape(n_rows, side, side)
    # Rolls along different axes commute, so the column shift can go first.
    lattice = jnp.roll(lattice, j, axis=-1)
    lattice = jnp.roll(lattice, i, axis=-2)
    return lattice.reshape(n_rows, -1)
class FMHA(nn.Module):
    """Multi-head token mixer whose mixing weights ``J`` do not depend on the input.

    Forward pass:
    1. A linear value projection ``v``.
    2. For every head, ``J[head]`` (an ``L_eff x L_eff`` matrix) is multiplied
       into that head's values along the position axis.
    3. The heads are concatenated and projected by ``W``.

    Attributes:
        d_model: output width of both projections; must be divisible by ``h``
            (each head gets ``d_model // h`` features).
        h: number of heads.
        L_eff: number of positions (axis 1 of the input).
        transl_invariant: if True, learn one length-``L_eff`` vector per head and
            build ``J`` from its cyclic shifts. ``J[head, s, t]`` then depends
            only on the (periodic) offset between positions ``s`` and ``t``.
            If False, learn a full ``(h, L_eff, L_eff)`` matrix.
        two_dimensional: only used when ``transl_invariant`` is True. If True,
            positions are a flattened square grid of side ``sqrt(L_eff)`` and
            the shifts are 2-D cyclic shifts of that grid. Otherwise they are
            1-D cyclic shifts.

    Raises:
        ValueError: from ``setup`` when ``transl_invariant`` and
            ``two_dimensional`` are both True but ``L_eff`` is not a perfect
            square.
    """
    d_model: int
    h: int
    L_eff: int
    transl_invariant: bool = True
    two_dimensional: bool = False

    def setup(self):
        # NOTE: float64 params/compute only take effect when JAX x64 mode is enabled.
        self.v = nn.Dense(self.d_model, kernel_init=nn.initializers.xavier_uniform(), param_dtype=jnp.float64, dtype=jnp.float64)
        if self.transl_invariant:
            # One learned row per head; the full mixing matrix is derived from it.
            # Build the matrix in a local and assign self.J once. The parameter
            # is still registered under the name "J".
            J = self.param("J", nn.initializers.xavier_uniform(), (self.h, self.L_eff), jnp.float64)
            if self.two_dimensional:
                side = int(self.L_eff ** 0.5)
                # Explicit error rather than `assert`, which `python -O` strips.
                if side * side != self.L_eff:
                    raise ValueError(
                        "two_dimensional=True requires L_eff to be a perfect square, "
                        f"got L_eff={self.L_eff}"
                    )
                # roll2d returns (h, side, side, L_eff). Flattening the two shift
                # axes gives row a*side + b = grid shifted by (a, b).
                J = roll2d(J, jnp.arange(side), jnp.arange(side))
                J = J.reshape(self.h, -1, self.L_eff)
            else:
                # Row s is the learned vector rolled by s along its last axis,
                # giving shape (h, L_eff, L_eff).
                J = jax.vmap(roll, (None, 0), out_axes=1)(J, jnp.arange(self.L_eff))
            self.J = J
        else:
            self.J = self.param("J", nn.initializers.xavier_uniform(), (self.h, self.L_eff, self.L_eff), jnp.float64)
        self.W = nn.Dense(self.d_model, kernel_init=nn.initializers.xavier_uniform(), param_dtype=jnp.float64, dtype=jnp.float64)

    def __call__(self, x):
        """Mix ``x`` of shape (batch, L_eff, features) across positions, per head.

        Returns an array of shape (batch, L_eff, d_model).
        """
        v = self.v(x)
        # Split the features into heads and move the head axis before positions.
        v = rearrange(v, 'batch L_eff (h d_eff) -> batch h L_eff d_eff', h=self.h)
        # (h, L, L) @ (batch, h, L, d) broadcasts over batch:
        # out[b, k, s] = sum_t J[k, s, t] * v[b, k, t].
        x = jnp.matmul(self.J, v)
        # Merge the heads back into the feature axis.
        x = rearrange(x, 'batch h L_eff d_eff -> batch L_eff (h d_eff)')
        return self.W(x)