Spaces:
Runtime error
Runtime error
| import jax | |
| import jax.numpy as jnp | |
| import flax.linen as nn | |
def get_sinusoidal_embeddings(
    timesteps: jax.Array,
    embedding_dim: int,
    freq_shift: float = 1,
    min_timescale: float = 1,
    max_timescale: float = 1.0e4,
    flip_sin_to_cos: bool = False,
    scale: float = 1.0,
    dtype: jnp.dtype = jnp.float32,
) -> jax.Array:
    """Return sinusoidal embeddings for a batch of timesteps.

    With ``half = embedding_dim // 2`` frequencies
    ``f_i = min_timescale * exp(-i * log(max_timescale / min_timescale) / (half - freq_shift))``
    for ``i = 0 .. half - 1``, the output row for timestep ``t`` is
    ``[sin(scale * t * f_i)..., cos(scale * t * f_i)...]``. The two halves
    are swapped when ``flip_sin_to_cos`` is True.

    Args:
        timesteps: 1-D array (or array-like) of ``N`` timesteps.
        embedding_dim: Width of each output row; must be even.
        freq_shift: Subtracted from ``half`` in the denominator of the
            log-spaced frequency increment.
        min_timescale: Multiplies every frequency ``f_i``.
            NOTE(review): it multiplies rather than divides the inverse
            timescales, so values other than 1 do not act as a true minimum
            timescale. Confirm this is intended.
        max_timescale: Together with ``min_timescale``, sets the log range
            spanned by the frequencies.
        flip_sin_to_cos: If True, place the cosine half first.
        scale: Multiplier applied to ``t * f_i`` before sin/cos.
        dtype: dtype of the returned array. Phases are always computed in at
            least float32 and cast to ``dtype`` only at the end.

    Returns:
        Array of shape ``(N, embedding_dim)`` with dtype ``dtype``.

    Raises:
        AssertionError: If ``timesteps`` is not 1-D or ``embedding_dim`` is odd.
        ValueError: If ``embedding_dim // 2 == freq_shift``. The frequency
            spacing would divide by zero and the result would be NaN.
    """
    timesteps = jnp.asarray(timesteps)
    assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
    assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"

    num_timescales = embedding_dim // 2
    denominator = num_timescales - freq_shift
    if denominator == 0:
        # jnp division does not raise; it would yield inf and then NaN embeddings.
        raise ValueError(
            f"embedding_dim // 2 ({num_timescales}) must differ from freq_shift "
            f"({freq_shift}); otherwise the frequency spacing divides by zero"
        )

    # Low-precision floats (bf16/fp16) cannot represent typical timesteps
    # exactly (bf16 rounds 999 to 1000), which makes sin/cos meaningless.
    # Compute phases in at least float32 and cast only the final result.
    compute_dtype = jnp.promote_types(dtype, jnp.float32)
    log_timescale_increment = jnp.log(max_timescale / min_timescale) / denominator
    inv_timescales = min_timescale * jnp.exp(
        jnp.arange(num_timescales, dtype=compute_dtype) * -log_timescale_increment
    )

    # Outer product: (N, 1) * (1, half) -> (N, half).
    scaled_time = scale * (timesteps.astype(compute_dtype)[:, None] * inv_timescales[None, :])

    if flip_sin_to_cos:
        halves = (jnp.cos(scaled_time), jnp.sin(scaled_time))
    else:
        halves = (jnp.sin(scaled_time), jnp.cos(scaled_time))
    # Concatenating two (N, half) arrays already gives (N, embedding_dim).
    signal = jnp.concatenate(halves, axis=1)
    return signal.astype(dtype)
class TimestepEmbedding(nn.Module):
    """Project a timestep embedding through Dense -> SiLU -> Dense.

    Both Dense layers have ``time_embed_dim`` output features and are named
    ``linear_1`` and ``linear_2``.
    """

    # Output width of both Dense layers.
    time_embed_dim: int = 32
    # Passed as ``dtype`` to both Dense layers.
    dtype: jnp.dtype = jnp.float32

    # Flax only allows submodules to be created inline inside methods wrapped
    # with @nn.compact (or inside setup()); without it, calling the module fails.
    @nn.compact
    def __call__(self, temb: jax.Array) -> jax.Array:
        """Apply the two-layer MLP to ``temb``.

        Args:
            temb: Input embedding; its last axis is projected to ``time_embed_dim``.

        Returns:
            Array whose last axis has size ``time_embed_dim``.
        """
        temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb)
        temb = nn.silu(temb)
        temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb)
        return temb
class Timesteps(nn.Module):
    """Flax wrapper that turns a 1-D batch of timesteps into sinusoidal embeddings.

    Delegates to ``get_sinusoidal_embeddings``. ``min_timescale``,
    ``max_timescale`` and ``scale`` keep that function's defaults.
    """

    # Embedding width passed as ``embedding_dim``.
    dim: int = 32
    # If True, the cosine half comes first in each row.
    flip_sin_to_cos: bool = False
    # Forwarded unchanged as ``freq_shift``.
    freq_shift: float = 1
    # Forwarded unchanged as ``dtype``.
    dtype: jnp.dtype = jnp.float32

    def __call__(self, timesteps: jax.Array) -> jax.Array:
        """Return ``get_sinusoidal_embeddings`` applied to ``timesteps`` with this module's settings."""
        options = dict(
            flip_sin_to_cos=self.flip_sin_to_cos,
            freq_shift=self.freq_shift,
            dtype=self.dtype,
        )
        embedding = get_sinusoidal_embeddings(timesteps, self.dim, **options)
        return embedding