from typing import Optional

import jax
import jax.numpy as jnp
import flax.linen as nn
import einops

from diffusers.models.attention_flax import FlaxAttention


class TransformerPseudo3DModel(nn.Module):
    """Spatio-temporal transformer: spatial attention blocks over each frame
    plus temporal attention over the frame axis of a video latent."""
    in_channels: int
    num_attention_heads: int
    attention_head_dim: int
    num_layers: int = 1
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        # inner_dim must equal in_channels for the residual connection
        # in __call__ to be shape-compatible.
        inner_dim = self.num_attention_heads * self.attention_head_dim
        self.norm = nn.GroupNorm(
            num_groups = 32,
            epsilon = 1e-5
        )
        self.proj_in = nn.Conv(
            inner_dim,
            kernel_size = (1, 1),
            strides = (1, 1),
            padding = 'VALID',
            dtype = self.dtype
        )
        # Gradient checkpointing (nn.checkpoint around
        # BasicTransformerBlockPseudo3D with static_argnums = (2, 3, 4))
        # is currently disabled.
        CheckpointTransformerBlock = BasicTransformerBlockPseudo3D
        transformer_blocks = []
        for _ in range(self.num_layers):
            transformer_blocks.append(CheckpointTransformerBlock(
                dim = inner_dim,
                num_attention_heads = self.num_attention_heads,
                attention_head_dim = self.attention_head_dim,
                use_memory_efficient_attention = self.use_memory_efficient_attention,
                dtype = self.dtype
            ))
        self.transformer_blocks = transformer_blocks
        self.proj_out = nn.Conv(
            inner_dim,
            kernel_size = (1, 1),
            strides = (1, 1),
            padding = 'VALID',
            dtype = self.dtype
        )

    def __call__(self,
            hidden_states: jax.Array,
            encoder_hidden_states: Optional[jax.Array] = None
    ) -> jax.Array:
        # Flax is channels-last: video input is (b, f, h, w, c), not (b, c, f, h, w).
        is_video = hidden_states.ndim == 5
        f: Optional[int] = None
        if is_video:
            # Fold the frame axis into the batch so the spatial layers see
            # ordinary (batch, h, w, c) feature maps.
            b, f, h, w, c = hidden_states.shape
            hidden_states = einops.rearrange(hidden_states, 'b f h w c -> (b f) h w c')
        batch, height, width, channels = hidden_states.shape
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        hidden_states = self.proj_in(hidden_states)
        # Flatten spatial positions into a token axis: (batch, h * w, inner_dim).
        # This relies on inner_dim == channels (see setup).
        hidden_states = hidden_states.reshape(batch, height * width, channels)
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states,
                f,
                height,
                width
            )
        hidden_states = hidden_states.reshape(batch, height, width, channels)
        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states + residual
        if is_video:
            # Unfold the frame axis back out of the batch.
            hidden_states = einops.rearrange(hidden_states, '(b f) h w c -> b f h w c', b = b)
        return hidden_states
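
# Illustrative shape flow for a video input (the dimensions below are assumed
# for the example, not taken from any particular checkpoint):
#   input                    (b, f, h, w, c)          e.g. (1, 4, 16, 16, 64)
#   fold frames into batch   ((b f), h, w, c)              (4, 16, 16, 64)
#   proj_in + flatten        ((b f), h * w, inner_dim)     (4, 256, 64)
#   transformer blocks       spatial/cross attention over h * w tokens,
#                            temporal attention over f tokens
#   unflatten + proj_out + residual, unfold frames    ->   (1, 4, 16, 16, 64)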


class BasicTransformerBlockPseudo3D(nn.Module):
    """Transformer block with spatial self-attention, cross-attention to the
    encoder states, temporal self-attention over frames, and a feed-forward."""
    dim: int
    num_attention_heads: int
    attention_head_dim: int
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        # spatial self-attention
        self.attn1 = FlaxAttention(
            query_dim = self.dim,
            heads = self.num_attention_heads,
            dim_head = self.attention_head_dim,
            use_memory_efficient_attention = self.use_memory_efficient_attention,
            dtype = self.dtype
        )
        self.ff = FeedForward(dim = self.dim, dtype = self.dtype)
        # cross-attention to the text/context embeddings
        self.attn2 = FlaxAttention(
            query_dim = self.dim,
            heads = self.num_attention_heads,
            dim_head = self.attention_head_dim,
            use_memory_efficient_attention = self.use_memory_efficient_attention,
            dtype = self.dtype
        )
        # temporal self-attention across frames
        self.attn_temporal = FlaxAttention(
            query_dim = self.dim,
            heads = self.num_attention_heads,
            dim_head = self.attention_head_dim,
            use_memory_efficient_attention = self.use_memory_efficient_attention,
            dtype = self.dtype
        )
        self.norm1 = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
        self.norm2 = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
        self.norm_temporal = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
        self.norm3 = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)

    def __call__(self,
            hidden_states: jax.Array,
            context: Optional[jax.Array] = None,
            frames_length: Optional[int] = None,
            height: Optional[int] = None,
            width: Optional[int] = None
    ) -> jax.Array:
        if context is not None and frames_length is not None:
            # The spatial layers run on (b * f) folded batches, so the
            # context has to be repeated once per frame.
            context = context.repeat(frames_length, axis = 0)
        # self attention
        norm_hidden_states = self.norm1(hidden_states)
        hidden_states = self.attn1(norm_hidden_states) + hidden_states
        # cross attention
        norm_hidden_states = self.norm2(hidden_states)
        hidden_states = self.attn2(
            norm_hidden_states,
            context = context
        ) + hidden_states
        # temporal attention (only when a frame axis exists)
        if frames_length is not None:
            # (b f) (h w) c -> (b h w) f c: every spatial position attends
            # over the frame axis.
            hidden_states = einops.rearrange(
                hidden_states,
                '(b f) (h w) c -> (b h w) f c',
                f = frames_length,
                h = height,
                w = width
            )
            norm_hidden_states = self.norm_temporal(hidden_states)
            hidden_states = self.attn_temporal(norm_hidden_states) + hidden_states
            # (b h w) f c -> (b f) (h w) c: restore the spatial token layout.
            hidden_states = einops.rearrange(
                hidden_states,
                '(b h w) f c -> (b f) (h w) c',
                f = frames_length,
                h = height,
                w = width
            )
        # feed-forward
        norm_hidden_states = self.norm3(hidden_states)
        hidden_states = self.ff(norm_hidden_states) + hidden_states
        return hidden_states
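
# Concretely (illustrative numbers, assuming b = 1, f = 4, h = w = 16, c = 64):
# the temporal step rearranges (4, 256, 64) into (256, 4, 64), runs attention
# over the 4 frames at each of the 256 spatial positions, and the inverse
# rearrange restores (4, 256, 64).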


class FeedForward(nn.Module):
    """GEGLU feed-forward: expansion to 4 * dim with a gated GELU, then a
    linear projection back to dim."""
    dim: int
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        # Names mirror the diffusers Flax layout (net_0 = GEGLU, net_2 = Dense).
        self.net_0 = GEGLU(self.dim, self.dtype)
        self.net_2 = nn.Dense(self.dim, dtype = self.dtype)

    def __call__(self, hidden_states: jax.Array) -> jax.Array:
        hidden_states = self.net_0(hidden_states)
        hidden_states = self.net_2(hidden_states)
        return hidden_states


class GEGLU(nn.Module):
    """Gated GELU projection: Dense to 2 * (4 * dim), split into a linear half
    and a gate half, combined as linear * gelu(gate)."""
    dim: int
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        inner_dim = self.dim * 4
        self.proj = nn.Dense(inner_dim * 2, dtype = self.dtype)

    def __call__(self, hidden_states: jax.Array) -> jax.Array:
        hidden_states = self.proj(hidden_states)
        # Split the projection into the value half and the gate half.
        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis = -1)
        return hidden_linear * nn.gelu(hidden_gelu)
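

# Minimal smoke test: a sketch only, not part of the original module. It
# assumes a diffusers version whose FlaxAttention accepts the keyword
# arguments used above, and the shapes below are made up for the example
# (batch 1, 4 frames, 16x16 latents, 64 channels, 77 text tokens).
if __name__ == '__main__':
    model = TransformerPseudo3DModel(
        in_channels = 64,
        num_attention_heads = 8,
        attention_head_dim = 8,   # inner_dim = 8 * 8 = 64 matches in_channels
        num_layers = 1
    )
    rng = jax.random.PRNGKey(0)
    video = jnp.zeros((1, 4, 16, 16, 64))   # (b, f, h, w, c), channels-last
    text = jnp.zeros((1, 77, 768))          # (b, tokens, text-encoder width)
    params = model.init(rng, video, text)
    out = model.apply(params, video, text)
    print(out.shape)                        # expected: (1, 4, 16, 16, 64)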