johannesschmude's picture
Initial commit
b73936d
"""
Perceiver code is based on Aurora: https://github.com/microsoft/aurora/blob/main/aurora/model/perceiver.py
Some conventions for notation:
B - Batch
T - Time
H - Height (pixel space)
W - Width (pixel space)
HT - Height (token space)
WT - Width (token space)
ST - Sequence (token space)
C - Input channels
D - Model (embedding) dimension
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.layers import trunc_normal_
class PatchEmbed3D(nn.Module):
    """Turn a short image time series into a sequence of patch tokens.

    The time axis is folded into the channel axis, so one 2-D convolution with
    kernel and stride equal to ``patch_size`` sees all ``in_chans * time_dim``
    planes of a patch at once.
    """

    def __init__(
        self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, time_dim=2
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.time_dim = time_dim
        patch_shape = (patch_size, patch_size)
        self.proj = nn.Conv2d(
            in_chans * time_dim,
            embed_dim,
            kernel_size=patch_shape,
            stride=patch_shape,
        )

    def forward(self, x):
        """Project ``x`` to patch tokens.

        Args:
            x: Tensor of shape (B, C, T, H, W); C * T must equal
                ``in_chans * time_dim``.

        Returns:
            Tensor of shape (B, ST, D) with ST = HT * WT.
        """
        # Unpacking the shape doubles as a check that x is 5-D.
        _batch, _chans, _steps, _height, _width = x.shape
        # (B, C, T, H, W) -> (B, (C T), H, W) -> (B, D, HT, WT)
        feature_map = self.proj(torch.flatten(x, start_dim=1, end_dim=2))
        return rearrange(feature_map, "B D HT WT -> B (HT WT) D")
class LinearEmbedding(nn.Module):
    """Patch embedding plus a fixed (non-learned) 2-D Fourier positional encoding."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        time_dim=2,
        embed_dim=768,
        drop_rate=0.0,
    ):
        """
        Args:
            img_size: Side length of the input image. The token grid is assumed
                square, with (img_size // patch_size) tokens per side.
            patch_size: Side length of each square patch.
            in_chans: Number of input channels C.
            time_dim: Number of time steps T folded into the channel axis.
            embed_dim: Embedding dimension D. Must be divisible by 4, since the
                positional encoding emits four channels per frequency.
            drop_rate: Dropout probability applied after adding the positional
                encoding.

        Raises:
            ValueError: If ``embed_dim`` is not divisible by 4.
        """
        super().__init__()
        if embed_dim % 4 != 0:
            # The encoding would have 4 * (embed_dim // 4) < embed_dim channels,
            # which cannot be added to the (B, ST, embed_dim) patch tokens.
            raise ValueError(
                "Fourier positional encoding requires `embed_dim` to be "
                f"divisible by 4. Currently we have {embed_dim}."
            )
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_embed = PatchEmbed3D(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            time_dim=time_dim,
        )
        self._generate_position_encoding(img_size, patch_size, embed_dim)
        self.pos_drop = nn.Dropout(p=drop_rate)

    def _generate_position_encoding(self, img_size, patch_size, embed_dim):
        """
        Generate a fixed Fourier positional encoding and register it as the
        buffer ``self.pos_embed``.

        The buffer has shape (1, HT * WT, 4 * (embed_dim // 4)), with
        HT == WT == img_size // patch_size. ``embed_dim // 4`` frequencies are
        spaced linearly in [1, (img_size // patch_size) / 2]. Each frequency
        contributes four channels: cos/sin of the x coordinate and cos/sin of
        the y coordinate. Both coordinates span [0, 1].

        Args:
            img_size (int): The size of the input image.
            patch_size (int): The size of each patch in the image.
            embed_dim (int): The embedding dimension of the model.
        Returns:
            None.
        """
        x = torch.linspace(0.0, 1.0, img_size // patch_size)
        y = torch.linspace(0.0, 1.0, img_size // patch_size)
        # indexing="xy": rows follow y (height) and columns follow x (width).
        x, y = torch.meshgrid(x, y, indexing="xy")
        fourier_signal = []
        frequencies = torch.linspace(1, (img_size // patch_size) / 2.0, embed_dim // 4)
        for f in frequencies:
            fourier_signal.extend(
                [
                    torch.cos(2.0 * torch.pi * f * x),
                    torch.sin(2.0 * torch.pi * f * x),
                    torch.cos(2.0 * torch.pi * f * y),
                    torch.sin(2.0 * torch.pi * f * y),
                ]
            )
        # Stack along a new last axis: (HT, WT, channels).
        fourier_signal = torch.stack(fourier_signal, dim=2)
        fourier_signal = rearrange(fourier_signal, "h w c -> 1 (h w) c")
        self.register_buffer("pos_embed", fourier_signal)

    def forward(self, x, dt):
        """
        Args:
            x: Tensor of shape (B, C, T, H, W).
            dt: Tensor of shape (B, T). Not used; accepted so the call
                signature matches ``PerceiverChannelEmbedding.forward``.
        Returns:
            Tensor of shape (B, ST, D)
        """
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        return x
class LinearDecoder(nn.Module):
    """Map a square grid of tokens back to pixel space (1x1 conv + pixel shuffle)."""

    def __init__(
        self,
        patch_size: int,
        out_chans: int,
        embed_dim: int,
    ):
        """
        Args:
            patch_size: patch size; also the upscaling factor of the pixel shuffle
            out_chans: number of output channels
            embed_dim: embedding dimension of the incoming tokens
        """
        super().__init__()
        self.unembed = nn.Sequential(
            nn.Conv2d(
                in_channels=embed_dim,
                out_channels=(patch_size**2) * out_chans,
                kernel_size=1,
            ),
            nn.PixelShuffle(patch_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (B, L, D). For ensembles, we have implicitly B = (B E).
                L must be a perfect square.
        Returns:
            Tensor of shape (B C H W).
            Here
            - C equals out_chans
            - H == W == sqrt(L) x patch_size
        Raises:
            ValueError: If L is not a perfect square.
        """
        _, L, _ = x.shape
        # Fail early with a clear message instead of inside einops.
        side = round(L**0.5)
        if side * side != L:
            raise ValueError(f"Expected a square number of tokens, got L={L}.")
        H_token = W_token = side
        # Reshape the tokens to 2d token space: (B, D, H_token, W_token)
        x = rearrange(x, "B (H W) D -> B D H W", H=H_token, W=W_token)
        # Unembed the tokens. Convolution + pixel shuffle.
        x = self.unembed(x)
        return x
class MLP(nn.Module):
    """Feed-forward block: Linear -> GELU -> Linear -> Dropout."""

    def __init__(self, dim: int, hidden_features: int, dropout: float = 0.0) -> None:
        """Initialise.

        Args:
            dim (int): Size of the input and output features.
            hidden_features (int): Size of the single hidden layer.
            dropout (float, optional): Drop-out rate applied to the output.
                Defaults to `0.0` (no drop-out).
        """
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_features),
            nn.GELU(),
            nn.Linear(hidden_features, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward block to the last dimension of ``x``."""
        return self.net(x)
class PerceiverAttention(nn.Module):
    """Multi-head cross attention in which latents attend to a context."""

    def __init__(
        self,
        latent_dim: int,
        context_dim: int,
        head_dim: int = 64,
        num_heads: int = 8,
    ) -> None:
        """Initialise.

        Args:
            latent_dim (int): Feature size of the latents (queries and output).
            context_dim (int): Feature size of the context (keys and values).
            head_dim (int): Size of each attention head.
            num_heads (int): Number of attention heads.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        inner = head_dim * num_heads
        self.inner_dim = inner
        # All projections are bias-free.
        self.to_q = nn.Linear(latent_dim, inner, bias=False)
        self.to_kv = nn.Linear(context_dim, 2 * inner, bias=False)
        self.to_out = nn.Linear(inner, latent_dim, bias=False)

    def forward(self, latents: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Attend from ``latents`` to ``x``.

        Args:
            latents (:class:`torch.Tensor`): Queries of shape `(B, L1, latent_dim)`.
            x (:class:`torch.Tensor`): Context of shape `(B, L2, context_dim)`.

        Returns:
            :class:`torch.Tensor`: Tensor of shape `(B, L1, latent_dim)`.
        """
        heads = self.num_heads
        split_pattern = "b l (h d) -> b h l d"
        # Project, then split the inner dimension into heads:
        # (B, L, heads * head_dim) -> (B, heads, L, head_dim)
        query, key, value = (
            rearrange(t, split_pattern, h=heads)
            for t in (self.to_q(latents), *self.to_kv(x).chunk(2, dim=-1))
        )
        attended = F.scaled_dot_product_attention(query, key, value)
        # Merge heads back: (B, heads, L1, head_dim) -> (B, L1, inner_dim)
        merged = rearrange(attended, "B H L1 D -> B L1 (H D)")
        return self.to_out(merged)
class PerceiverResampler(nn.Module):
    """Perceiver Resampler module from the Flamingo paper.

    A stack of ``depth`` blocks. Each block applies cross attention from the
    latents to the context, then an MLP, each followed by a layer norm.
    """

    def __init__(
        self,
        latent_dim: int,
        context_dim: int,
        depth: int = 1,
        head_dim: int = 64,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        drop: float = 0.0,
        residual_latent: bool = True,
        ln_eps: float = 1e-5,
    ) -> None:
        """Initialise.

        Args:
            latent_dim (int): Dimensionality of the latent features given as input.
            context_dim (int): Dimensionality of the context features also given as input.
            depth (int, optional): Number of attention blocks. Defaults to `1`.
            head_dim (int, optional): Attention head dimensionality. Defaults to `64`.
            num_heads (int, optional): Number of heads. Defaults to `16`.
            mlp_ratio (float, optional): Width of the MLP hidden layer divided by
                `latent_dim`. Defaults to `4.0`.
            drop (float, optional): Drop-out rate of the MLPs. Defaults to no drop-out.
            residual_latent (bool, optional): Add the incoming latents onto the
                attention output. Defaults to `True`.
            ln_eps (float, optional): Epsilon of the layer normalisation layers.
                Defaults to `1e-5`.
        """
        super().__init__()
        self.residual_latent = residual_latent
        hidden = int(latent_dim * mlp_ratio)

        def make_block() -> nn.ModuleList:
            # Fixed order: attention, MLP, post-attention norm, post-MLP norm.
            return nn.ModuleList(
                [
                    PerceiverAttention(
                        latent_dim=latent_dim,
                        context_dim=context_dim,
                        head_dim=head_dim,
                        num_heads=num_heads,
                    ),
                    MLP(dim=latent_dim, hidden_features=hidden, dropout=drop),
                    nn.LayerNorm(latent_dim, eps=ln_eps),
                    nn.LayerNorm(latent_dim, eps=ln_eps),
                ]
            )

        self.layers = nn.ModuleList([make_block() for _ in range(depth)])

    def forward(self, latents: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Run the module.

        Args:
            latents (:class:`torch.Tensor`): Latent features of shape `(B, L1, latent_dim)`.
            x (:class:`torch.Tensor`): Context features of shape `(B, L2, context_dim)`.

        Returns:
            torch.Tensor: Latent features of shape `(B, L1, latent_dim)`.
        """
        for cross_attn, feed_forward, norm_attn, norm_ff in self.layers:
            # Post-res-norm (as in Swin v2) rather than the original Perceiver's
            # pre-norm; the norm is applied to the sub-layer output.
            update = norm_attn(cross_attn(latents, x))
            # Non-residual attention may work better when query and output
            # semantics differ, see
            # https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/models/perceiver/modeling_perceiver.py#L398
            if self.residual_latent:
                update = update + latents
            latents = norm_ff(feed_forward(update)) + update
        return latents
class PerceiverChannelEmbedding(nn.Module):
    """Embed every input channel on its own, then fuse the channels per patch.

    Each channel's T time steps are patchified by its own group of a grouped
    convolution. A learned positional embedding, shared by all channels, is
    then added. At every patch location a Perceiver resampler lets
    ``num_queries`` learned queries attend over the C channel tokens. The
    resulting queries are concatenated and linearly projected to ``embed_dim``.
    """

    def __init__(
        self,
        in_chans: int,
        img_size: int,
        patch_size: int,
        time_dim: int,
        num_queries: int,
        embed_dim: int,
        drop_rate: float,
    ):
        """
        Args:
            in_chans: Number of input channels C.
            img_size: Side length of the input image.
            patch_size: Side length of each square patch.
            time_dim: Number of time steps T per channel.
            num_queries: Number of latent queries NQ per patch.
            embed_dim: Embedding dimension D. Must be even.
            drop_rate: Dropout probability applied to the output tokens.

        Raises:
            ValueError: If ``embed_dim`` is odd.
        """
        super().__init__()
        # NOTE(review): nothing in this class visibly needs an even `embed_dim`
        # (no temporal embedding is built here); check kept unchanged.
        if embed_dim % 2:
            raise ValueError(
                f"Temporal embeddings require `embed_dim` to be even. Currently we have {embed_dim}."
            )
        tokens_per_side = img_size // patch_size
        self.num_patches = tokens_per_side**2
        self.num_queries = num_queries
        self.embed_dim = embed_dim
        # groups=in_chans: each channel sees only its own T input planes and
        # produces its own block of embed_dim output planes.
        self.proj = nn.Conv2d(
            in_channels=in_chans * time_dim,
            out_channels=in_chans * embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            groups=in_chans,
        )
        self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, self.num_patches))
        trunc_normal_(self.pos_embed, std=0.02)
        self.latent_queries = nn.Parameter(torch.zeros(1, num_queries, embed_dim))
        trunc_normal_(self.latent_queries, std=0.02)
        heads = 16
        # If embed_dim is not a multiple of 16, the attention inner width
        # 16 * (embed_dim // 16) is smaller than embed_dim.
        self.perceiver = PerceiverResampler(
            latent_dim=embed_dim,
            context_dim=embed_dim,
            depth=1,
            head_dim=embed_dim // heads,
            num_heads=heads,
            mlp_ratio=4.0,
            drop=0.0,
            residual_latent=False,
            ln_eps=1e-5,
        )
        self.latent_aggregation = nn.Linear(num_queries * embed_dim, embed_dim)
        self.pos_drop = nn.Dropout(p=drop_rate)

    def forward(self, x, dt):
        """
        Args:
            x: Tensor of shape (B, C, T, H, W)
            dt: Tensor of shape (B, T) identifying time deltas. Not used by this
                module.
        Returns:
            Tensor of shape (B, ST, D)
        """
        batch, chans, _, _, _ = x.shape
        dims = dict(B=batch, C=chans, D=self.embed_dim)
        # B C T H W -> B (C T) H W -> B (C D) HT WT -> B (C D) ST
        patches = self.proj(rearrange(x, "B C T H W -> B (C T) H W")).flatten(2, 3)
        seq_len = patches.shape[2]
        assert seq_len == self.num_patches
        # Move channels into the batch so one positional embedding is shared by
        # all channels: B (C D) ST -> (B C) D ST.
        patches = rearrange(patches, "B (C D) ST -> (B C) D ST", ST=seq_len, **dims)
        patches = patches + self.pos_embed
        # One short sequence of C channel tokens per (sample, patch).
        channel_tokens = rearrange(
            patches, "(B C) D ST -> (B ST) C D", ST=seq_len, **dims
        )
        queries = self.latent_queries.expand(batch * seq_len, -1, -1)
        # ((B ST) NQ D), ((B ST) C D) -> ((B ST) NQ D)
        fused = self.perceiver(queries, channel_tokens)
        fused = rearrange(
            fused,
            "(B ST) NQ D -> B ST (NQ D)",
            B=batch,
            ST=self.num_patches,
            NQ=self.num_queries,
            D=self.embed_dim,
        )
        out = self.latent_aggregation(fused)  # B ST (NQ D) -> B ST D
        assert out.shape[1] == self.num_patches
        assert out.shape[2] == self.embed_dim
        return self.pos_drop(out)
class PerceiverDecoder(nn.Module):
    """Decode a square grid of tokens to pixels via per-channel latent queries.

    Each token is resampled into ``out_chans`` latents. A grouped 1x1
    convolution maps each latent to a patch of pixels, and a pixel shuffle
    places them on the image grid.
    """

    def __init__(
        self,
        embed_dim: int,
        patch_size: int,
        out_chans: int,
    ):
        """
        Args:
            embed_dim: embedding dimension
            patch_size: patch size; also the upscaling factor of the pixel shuffle
            out_chans: number of output channels. This determines the number of latent queries.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.out_chans = out_chans
        self.latent_queries = nn.Parameter(torch.zeros(1, out_chans, embed_dim))
        trunc_normal_(self.latent_queries, std=0.02)
        self.perceiver = PerceiverResampler(
            latent_dim=embed_dim,
            context_dim=embed_dim,
            depth=1,
            head_dim=embed_dim // 16,
            num_heads=16,
            mlp_ratio=4.0,
            drop=0.0,
            residual_latent=False,
            ln_eps=1e-5,
        )
        # groups=out_chans: output channel c is computed only from the
        # embed_dim features of latent query c.
        self.proj = nn.Conv2d(
            in_channels=out_chans * embed_dim,
            out_channels=out_chans * patch_size**2,
            kernel_size=1,
            padding=0,
            groups=out_chans,
        )
        self.pixel_shuffle = nn.PixelShuffle(patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (B, L, D) For ensembles, we have implicitly B = (B E).
                L must be a perfect square.
        Returns:
            Tensor of shape (B C H W).
            Here
            - C equals out_chans
            - H == W == sqrt(L) x patch_size
        Raises:
            ValueError: If L is not a perfect square.
        """
        B, L, D = x.shape
        # Check before running the perceiver instead of failing later in einops.
        side = round(L**0.5)
        if side * side != L:
            raise ValueError(f"Expected a square number of tokens, got L={L}.")
        H_token = W_token = side
        # Each token becomes a length-1 context sequence for cross attention.
        x = rearrange(x, "B L D -> (B L) 1 D")
        # NOTE(review): with a single context token every attention weight is 1,
        # so all queries receive the same attention output. With
        # residual_latent=False the values of `latent_queries` therefore do not
        # affect the result; only the grouped `proj` differs per channel.
        # Confirm this is intended.
        # (B L) 1 D -> (B L) C D
        x = self.perceiver(self.latent_queries.expand(B * L, -1, -1), x)
        x = rearrange(x, "(B H W) C D -> B (C D) H W", H=H_token, W=W_token)
        # B (C D) H_token W_token -> B (C patch_size patch_size) H_token W_token
        x = self.proj(x)
        # B (C patch_size patch_size) H_token W_token -> B C H W
        x = self.pixel_shuffle(x)
        return x