from typing import Optional
import jax
import jax.numpy as jnp
import flax.linen as nn
import einops
#from flax_memory_efficient_attention import jax_memory_efficient_attention
#from flax_attention import FlaxAttention
from diffusers.models.attention_flax import FlaxAttention
class TransformerPseudo3DModel(nn.Module):
    in_channels: int
    num_attention_heads: int
    attention_head_dim: int
    num_layers: int = 1
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    def setup(self) -> None:
        inner_dim = self.num_attention_heads * self.attention_head_dim
        self.norm = nn.GroupNorm(
                num_groups = 32,
                epsilon = 1e-5
        )
        self.proj_in = nn.Conv(
                inner_dim,
                kernel_size = (1, 1),
                strides = (1, 1),
                padding = 'VALID',
                dtype = self.dtype
        )
        transformer_blocks = []
        #CheckpointTransformerBlock = nn.checkpoint(
        #        BasicTransformerBlockPseudo3D,
        #        static_argnums = (2,3,4)
        #        #prevent_cse = False
        #)
        CheckpointTransformerBlock = BasicTransformerBlockPseudo3D
        for _ in range(self.num_layers):
            transformer_blocks.append(CheckpointTransformerBlock(
                        dim = inner_dim,
                        num_attention_heads = self.num_attention_heads,
                        attention_head_dim = self.attention_head_dim,
                        use_memory_efficient_attention = self.use_memory_efficient_attention,
                        dtype = self.dtype
                ))
        self.transformer_blocks = transformer_blocks
        self.proj_out = nn.Conv(
                inner_dim,
                kernel_size = (1, 1),
                strides = (1, 1),
                padding = 'VALID',
                dtype = self.dtype
        )
    def __call__(self,
            hidden_states: jax.Array,
            encoder_hidden_states: Optional[jax.Array] = None
    ) -> jax.Array:
        is_video = hidden_states.ndim == 5
        f: Optional[int] = None
        if is_video:
            # jax is channels last
            # b,c,f,h,w WRONG
            # b,f,h,w,c CORRECT
            # b, c, f, h, w = hidden_states.shape
            #hidden_states = einops.rearrange(hidden_states, 'b c f h w -> (b f) c h w')
            b, f, h, w, c = hidden_states.shape
            hidden_states = einops.rearrange(hidden_states, 'b f h w c -> (b f) h w c')
        batch, height, width, channels = hidden_states.shape
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        hidden_states = self.proj_in(hidden_states)
        hidden_states = hidden_states.reshape(batch, height * width, channels)
        for block in self.transformer_blocks:
            hidden_states = block(
                    hidden_states,
                    encoder_hidden_states,
                    f,
                    height,
                    width
            )
        hidden_states = hidden_states.reshape(batch, height, width, channels)
        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states + residual
        if is_video:
            hidden_states = einops.rearrange(hidden_states, '(b f) h w c -> b f h w c', b = b)
        return hidden_states
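# Hedged sketch (not part of the original module): a small, self-contained check of
# the frame-flattening round trip used in TransformerPseudo3DModel.__call__. The
# shapes and the function name _demo_frame_flatten are made up for illustration;
# only jnp and einops from the imports above are needed.
def _demo_frame_flatten() -> None:
    b, f, h, w, c = 2, 3, 4, 4, 8
    video = jnp.zeros((b, f, h, w, c))
    # fold frames into the batch so each frame is processed as a 2D feature map
    flat = einops.rearrange(video, 'b f h w c -> (b f) h w c')
    assert flat.shape == (b * f, h, w, c)
    # restore the frame axis afterwards, exactly as __call__ does for video input
    restored = einops.rearrange(flat, '(b f) h w c -> b f h w c', b = b)
    assert restored.shape == video.shape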
class BasicTransformerBlockPseudo3D(nn.Module):
    dim: int
    num_attention_heads: int
    attention_head_dim: int
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    def setup(self) -> None:
        self.attn1 = FlaxAttention(
                query_dim = self.dim,
                heads = self.num_attention_heads,
                dim_head = self.attention_head_dim,
                use_memory_efficient_attention = self.use_memory_efficient_attention,
                dtype = self.dtype
        )
        self.ff = FeedForward(dim = self.dim, dtype = self.dtype)
        self.attn2 = FlaxAttention(
                query_dim = self.dim,
                heads = self.num_attention_heads,
                dim_head = self.attention_head_dim,
                use_memory_efficient_attention = self.use_memory_efficient_attention,
                dtype = self.dtype
        )
        self.attn_temporal = FlaxAttention(
                query_dim = self.dim,
                heads = self.num_attention_heads,
                dim_head = self.attention_head_dim,
                use_memory_efficient_attention = self.use_memory_efficient_attention,
                dtype = self.dtype
        )
        self.norm1 = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
        self.norm2 = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
        self.norm_temporal = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
        self.norm3 = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
    def __call__(self,
            hidden_states: jax.Array,
            context: Optional[jax.Array] = None,
            frames_length: Optional[int] = None,
            height: Optional[int] = None,
            width: Optional[int] = None
    ) -> jax.Array:
        if context is not None and frames_length is not None:
            context = context.repeat(frames_length, axis = 0)
        # self attention
        norm_hidden_states = self.norm1(hidden_states)
        hidden_states = self.attn1(norm_hidden_states) + hidden_states
        # cross attention
        norm_hidden_states = self.norm2(hidden_states)
        hidden_states = self.attn2(
                norm_hidden_states,
                context = context
        ) + hidden_states
        # temporal attention
        if frames_length is not None:
            #bf, hw, c = hidden_states.shape
            # (b f) (h w) c -> b f (h w) c
            #hidden_states = hidden_states.reshape(bf // frames_length, frames_length, hw, c)
            #b, f, hw, c = hidden_states.shape
            # b f (h w) c -> b (h w) f c
            #hidden_states = hidden_states.transpose(0, 2, 1, 3)
            # b (h w) f c -> (b h w) f c
            #hidden_states = hidden_states.reshape(b * hw, frames_length, c)
            hidden_states = einops.rearrange(
                    hidden_states,
                    '(b f) (h w) c -> (b h w) f c',
                    f = frames_length,
                    h = height,
                    w = width
            )
            norm_hidden_states = self.norm_temporal(hidden_states)
            hidden_states = self.attn_temporal(norm_hidden_states) + hidden_states
            # (b h w) f c -> b (h w) f c
            #hidden_states = hidden_states.reshape(b, hw, f, c)
            # b (h w) f c -> b f (h w) c
            #hidden_states = hidden_states.transpose(0, 2, 1, 3)
            # b f (h w) c -> (b f) (h w) c
            #hidden_states = hidden_states.reshape(bf, hw, c)
            hidden_states = einops.rearrange(
                    hidden_states,
                    '(b h w) f c -> (b f) (h w) c',
                    f = frames_length,
                    h = height,
                    w = width
            )
        norm_hidden_states = self.norm3(hidden_states)
        hidden_states = self.ff(norm_hidden_states) + hidden_states
        return hidden_states
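# Hedged sketch (not part of the original module): illustrates the spatial <-> temporal
# token regrouping that BasicTransformerBlockPseudo3D performs around its temporal
# attention pass. The shapes and the function name _demo_temporal_regroup are
# illustrative; the pattern strings match the einops calls above.
def _demo_temporal_regroup() -> None:
    b, f, h, w, c = 1, 4, 2, 3, 8
    # (b f) (h w) c, as seen inside the block
    tokens = jnp.zeros((b * f, h * w, c))
    # move every spatial position into the batch so attention runs across frames
    temporal = einops.rearrange(tokens, '(b f) (h w) c -> (b h w) f c', f = f, h = h, w = w)
    assert temporal.shape == (b * h * w, f, c)
    # undo the regrouping so the block can continue with spatial tokens
    spatial = einops.rearrange(temporal, '(b h w) f c -> (b f) (h w) c', f = f, h = h, w = w)
    assert spatial.shape == (b * f, h * w, c)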
class FeedForward(nn.Module):
    dim: int
    dtype: jnp.dtype = jnp.float32
    def setup(self) -> None:
        self.net_0 = GEGLU(self.dim, self.dtype)
        self.net_2 = nn.Dense(self.dim, dtype = self.dtype)
    def __call__(self, hidden_states: jax.Array) -> jax.Array:
        hidden_states = self.net_0(hidden_states)
        hidden_states = self.net_2(hidden_states)
        return hidden_states
class GEGLU(nn.Module):
    dim: int
    dtype: jnp.dtype = jnp.float32
    def setup(self) -> None:
        inner_dim = self.dim * 4
        self.proj = nn.Dense(inner_dim * 2, dtype = self.dtype)
    def __call__(self, hidden_states: jax.Array) -> jax.Array:
        hidden_states = self.proj(hidden_states)
        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis = 2)
        return hidden_linear * nn.gelu(hidden_gelu)
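# Minimal usage sketch, not from the original file. The configuration below is an
# assumption chosen so that inner_dim (num_attention_heads * attention_head_dim)
# equals the input channel count, as the residual connection requires; real
# checkpoints will use different sizes. Dummy inputs are channels-last video
# latents (batch, frames, height, width, channels) and text-encoder states
# (batch, sequence, embed_dim).
if __name__ == '__main__':
    model = TransformerPseudo3DModel(
            in_channels = 32,
            num_attention_heads = 2,
            attention_head_dim = 16,
            num_layers = 1
    )
    rng = jax.random.PRNGKey(0)
    video = jnp.zeros((1, 2, 4, 4, 32), dtype = jnp.float32)
    context = jnp.zeros((1, 77, 32), dtype = jnp.float32)
    params = model.init(rng, video, context)
    out = model.apply(params, video, context)
    print(out.shape)  # (1, 2, 4, 4, 32)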