import jax
import jax.numpy as jnp
from flax import linen as nn
from einops import rearrange

from .attentions import FMHA


def log_cosh(x):
    # Numerically stable log(cosh(x)) for complex x: flip the sign based on
    # the real part so that exp(-2x) never overflows.
    sgn_x = -2 * jnp.signbit(x.real) + 1
    x = x * sgn_x
    return x + jnp.log1p(jnp.exp(-2.0 * x)) - jnp.log(2.0)


def extract_patches1d(x, b):
    # Split a flattened chain of length L_eff*b into L_eff patches of b sites.
    return rearrange(x, 'batch (L_eff b) -> batch L_eff b', b=b)


def extract_patches2d(x, b):
    # Split a flattened (L_eff*b) x (L_eff*b) lattice into L_eff**2 patches of b*b sites.
    batch = x.shape[0]
    L_eff = int((x.shape[1] // b**2)**0.5)
    x = x.reshape(batch, L_eff, b, L_eff, b)  # [L_eff, b, L_eff, b]
    x = x.transpose(0, 1, 3, 2, 4)            # [L_eff, L_eff, b, b]
    # flatten the patches
    x = x.reshape(batch, L_eff, L_eff, -1)    # [L_eff, L_eff, b*b]
    x = x.reshape(batch, L_eff * L_eff, -1)   # [L_eff*L_eff, b*b]
    return x


class Embed(nn.Module):
    """Patch extraction followed by a linear embedding into d_model features."""
    d_model: int
    b: int
    two_dimensional: bool = False

    def setup(self):
        if self.two_dimensional:
            self.extract_patches = extract_patches2d
        else:
            self.extract_patches = extract_patches1d

        self.embed = nn.Dense(self.d_model,
                              kernel_init=nn.initializers.xavier_uniform(),
                              param_dtype=jnp.float64,
                              dtype=jnp.float64)

    def __call__(self, x):
        x = self.extract_patches(x, self.b)
        x = self.embed(x)
        return x


class EncoderBlock(nn.Module):
    """Pre-LayerNorm transformer block: FMHA attention followed by a feed-forward MLP."""
    d_model: int
    h: int
    L_eff: int
    transl_invariant: bool = True
    two_dimensional: bool = False

    def setup(self):
        self.attn = FMHA(d_model=self.d_model,
                         h=self.h,
                         L_eff=self.L_eff,
                         transl_invariant=self.transl_invariant,
                         two_dimensional=self.two_dimensional)

        self.layer_norm_1 = nn.LayerNorm(dtype=jnp.float64, param_dtype=jnp.float64)
        self.layer_norm_2 = nn.LayerNorm(dtype=jnp.float64, param_dtype=jnp.float64)

        self.ff = nn.Sequential([
            nn.Dense(4 * self.d_model,
                     kernel_init=nn.initializers.xavier_uniform(),
                     param_dtype=jnp.float64,
                     dtype=jnp.float64),
            nn.gelu,
            nn.Dense(self.d_model,
                     kernel_init=nn.initializers.xavier_uniform(),
                     param_dtype=jnp.float64,
                     dtype=jnp.float64),
        ])

    def __call__(self, x):
        # Residual connections around the attention and feed-forward sub-layers.
        x = x + self.attn(self.layer_norm_1(x))
        x = x + self.ff(self.layer_norm_2(x))
        return x


class Encoder(nn.Module):
    """Stack of num_layers encoder blocks."""
    num_layers: int
    d_model: int
    h: int
    L_eff: int
    transl_invariant: bool = True
    two_dimensional: bool = False

    def setup(self):
        self.layers = [EncoderBlock(d_model=self.d_model,
                                    h=self.h,
                                    L_eff=self.L_eff,
                                    transl_invariant=self.transl_invariant,
                                    two_dimensional=self.two_dimensional)
                       for _ in range(self.num_layers)]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class OuputHead(nn.Module):
    """Pools the encoder output and maps it to a complex value via log_cosh."""
    d_model: int

    def setup(self):
        self.out_layer_norm = nn.LayerNorm(dtype=jnp.float64, param_dtype=jnp.float64)

        self.norm2 = nn.LayerNorm(use_scale=True, use_bias=True,
                                  dtype=jnp.float64, param_dtype=jnp.float64)
        self.norm3 = nn.LayerNorm(use_scale=True, use_bias=True,
                                  dtype=jnp.float64, param_dtype=jnp.float64)

        self.output_layer0 = nn.Dense(self.d_model,
                                      param_dtype=jnp.float64,
                                      dtype=jnp.float64,
                                      kernel_init=nn.initializers.xavier_uniform(),
                                      bias_init=jax.nn.initializers.zeros)
        self.output_layer1 = nn.Dense(self.d_model,
                                      param_dtype=jnp.float64,
                                      dtype=jnp.float64,
                                      kernel_init=nn.initializers.xavier_uniform(),
                                      bias_init=jax.nn.initializers.zeros)

    def __call__(self, x, return_z=False):
        # Sum-pool over the patch dimension, then normalize.
        z = self.out_layer_norm(x.sum(axis=1))
        if return_z:
            return z

        # Two real-valued branches, combined into the real and imaginary
        # parts of the argument passed to log_cosh.
        amp = self.norm2(self.output_layer0(z))
        sign = self.norm3(self.output_layer1(z))

        out = amp + 1j * sign

        return jnp.sum(log_cosh(out), axis=-1)


class ViT(nn.Module):
    """Vision Transformer on spin configurations: patch embedding, encoder stack, and complex-valued output head."""
    num_layers: int
    d_model: int
    heads: int
    L_eff: int
    b: int
    transl_invariant: bool = True
    two_dimensional: bool = False

    def setup(self):
        self.patches_and_embed = Embed(self.d_model, self.b,
                                       two_dimensional=self.two_dimensional)

        self.encoder = Encoder(num_layers=self.num_layers,
                               d_model=self.d_model,
                               h=self.heads,
                               L_eff=self.L_eff,
                               transl_invariant=self.transl_invariant,
                               two_dimensional=self.two_dimensional)

        self.output = OuputHead(self.d_model)

    def __call__(self, spins, return_z=False):
        # Promote a single configuration to a batch of size 1.
        x = jnp.atleast_2d(spins)

        x = self.patches_and_embed(x)
        x = self.encoder(x)
        output = self.output(x, return_z=return_z)

        return output
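

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the model definition).
# It assumes the FMHA module in .attentions accepts the keyword arguments
# used in EncoderBlock, and that for a 1D chain L_eff is the number of
# patches (L_eff = L // b). The lattice size, patch size, and hyperparameters
# below are arbitrary choices. Because of the relative import at the top,
# run it as a module, e.g. `python -m <package>.<this_module>`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # The float64 dtypes used above require 64-bit mode.
    jax.config.update("jax_enable_x64", True)

    L = 16   # number of spins in a 1D chain
    b = 2    # patch size
    model = ViT(num_layers=2,
                d_model=16,
                heads=4,
                L_eff=L // b,
                b=b,
                transl_invariant=True,
                two_dimensional=False)

    key = jax.random.PRNGKey(0)
    # Batch of 4 random spin configurations with values +/-1.
    spins = jax.random.choice(key, jnp.array([-1.0, 1.0]), shape=(4, L))

    params = model.init(key, spins)
    log_psi = model.apply(params, spins)  # one complex value per configuration
    print(log_psi.shape, log_psi.dtype)   # expected: (4,) complex128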