from typing import Optional

import torch
from torch import nn


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        direction_input_dim: int,
        conditioning_input_dim: int,
        latent_dim: int,
        num_heads: int,
    ):
        """
        Multi-Head Attention module.

        Args:
            direction_input_dim (int): The input dimension of the directional input.
            conditioning_input_dim (int): The input dimension of the conditioning input.
            latent_dim (int): The latent dimension of the module.
            num_heads (int): The number of heads to use in the attention mechanism.
        """
        super().__init__()
        assert latent_dim % num_heads == 0, "latent_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = latent_dim // num_heads
        self.scale = self.head_dim**-0.5

        # Cross-attention: queries come from the directional input, while keys
        # and values come from the conditioning input.
        self.query = nn.Linear(direction_input_dim, latent_dim)
        self.key = nn.Linear(conditioning_input_dim, latent_dim)
        self.value = nn.Linear(conditioning_input_dim, latent_dim)
        self.fc_out = nn.Linear(latent_dim, latent_dim)

    def forward(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass of the Multi-Head Attention module.

        Args:
            query (torch.Tensor): The directional input tensor.
            key (torch.Tensor): The conditioning input tensor for the keys.
            value (torch.Tensor): The conditioning input tensor for the values.

        Returns:
            torch.Tensor: The output tensor of the Multi-Head Attention module.
        """
        batch_size = query.size(0)

        # Project and split into heads:
        # (batch, seq_len, latent_dim) -> (batch, num_heads, seq_len, head_dim).
        Q = (
            self.query(query)
            .view(batch_size, -1, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        K = (
            self.key(key)
            .view(batch_size, -1, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        V = (
            self.value(value)
            .view(batch_size, -1, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )

        # Scaled dot-product attention scores: (batch, num_heads, q_len, k_len).
        attention = (
            torch.einsum("bnqk,bnkh->bnqh", [Q, K.transpose(-2, -1)]) * self.scale
        )
        attention = torch.softmax(attention, dim=-1)

        # Attention-weighted sum of values: (batch, num_heads, q_len, head_dim).
        out = torch.einsum("bnqh,bnhv->bnqv", [attention, V])
        # Merge heads back: (batch, q_len, latent_dim).
        out = (
            out.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.num_heads * self.head_dim)
        )

        # squeeze(1) drops the query-length dimension when it is 1, so a 2D
        # directional input yields a 2D output.
        out = self.fc_out(out).squeeze(1)
        return out
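

# Usage sketch for MultiHeadAttention (the sizes are illustrative assumptions, not
# values required by the module):
#
#     mha = MultiHeadAttention(
#         direction_input_dim=32, conditioning_input_dim=64, latent_dim=128, num_heads=4
#     )
#     q = torch.randn(8, 1, 32)    # directional queries, q_len == 1
#     kv = torch.randn(8, 10, 64)  # conditioning tokens used for keys and values
#     out = mha(q, kv, kv)         # -> (8, 128) after the trailing squeeze(1)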


class AttentionLayer(nn.Module):
    def __init__(
        self,
        direction_input_dim: int,
        conditioning_input_dim: int,
        latent_dim: int,
        num_heads: int,
    ):
        """
        Attention Layer module.

        Args:
            direction_input_dim (int): The input dimension of the directional input.
            conditioning_input_dim (int): The input dimension of the conditioning input.
            latent_dim (int): The latent dimension of the module.
            num_heads (int): The number of heads to use in the attention mechanism.
        """
        super().__init__()
        self.mha = MultiHeadAttention(
            direction_input_dim, conditioning_input_dim, latent_dim, num_heads
        )
        self.norm1 = nn.LayerNorm(latent_dim)
        self.norm2 = nn.LayerNorm(latent_dim)
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, latent_dim),
        )

    def forward(
        self, directional_input: torch.Tensor, conditioning_input: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass of the Attention Layer module.

        Args:
            directional_input (torch.Tensor): The directional input tensor.
            conditioning_input (torch.Tensor): The conditioning input tensor.

        Returns:
            torch.Tensor: The output tensor of the Attention Layer module.
        """
        # Cross-attend from the directional input to the conditioning input, then
        # apply a residual connection; this requires directional_input to already
        # have latent_dim features.
        attn_output = self.mha(
            directional_input, conditioning_input, conditioning_input
        )
        out1 = self.norm1(attn_output + directional_input)
        # Position-wise feed-forward block with its own residual connection.
        fc_output = self.fc(out1)
        out2 = self.norm2(fc_output + out1)
        return out2
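

# Usage sketch for AttentionLayer (illustrative sizes): because of the residual
# connection in forward(), direction_input_dim must equal latent_dim in practice,
# which is how Decoder instantiates it.
#
#     layer = AttentionLayer(
#         direction_input_dim=128, conditioning_input_dim=64, latent_dim=128, num_heads=4
#     )
#     d = torch.randn(8, 128)     # directional features, already at latent_dim
#     c = torch.randn(8, 10, 64)  # conditioning tokens
#     out = layer(d, c)           # -> (8, 128)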


class Decoder(nn.Module):
    def __init__(
        self,
        in_dim: int,
        conditioning_input_dim: int,
        hidden_features: int,
        num_heads: int,
        num_layers: int,
        out_activation: Optional[nn.Module],
    ):
        """
        Decoder module.

        Args:
            in_dim (int): The input dimension of the module.
            conditioning_input_dim (int): The input dimension of the conditioning input.
            hidden_features (int): The number of hidden features in the module.
            num_heads (int): The number of heads to use in the attention mechanism.
            num_layers (int): The number of layers in the module.
            out_activation (Optional[nn.Module]): The activation function to apply to
                the output tensor, or None to return the raw linear output.
        """
        super().__init__()
        # Project the input to hidden_features so it can serve as the residual
        # stream of the attention layers.
        self.residual_projection = nn.Linear(in_dim, hidden_features)
        self.layers = nn.ModuleList(
            [
                AttentionLayer(
                    hidden_features, conditioning_input_dim, hidden_features, num_heads
                )
                for _ in range(num_layers)
            ]
        )
        self.fc = nn.Linear(hidden_features, 3)
        self.out_activation = out_activation

    def forward(
        self, x: torch.Tensor, conditioning_input: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass of the Decoder module.

        Args:
            x (torch.Tensor): The input tensor.
            conditioning_input (torch.Tensor): The conditioning input tensor.

        Returns:
            torch.Tensor: The output tensor of the Decoder module.
        """
        x = self.residual_projection(x)
        for layer in self.layers:
            x = layer(x, conditioning_input)
        x = self.fc(x)
        if self.out_activation is not None:
            x = self.out_activation(x)
        return x
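

if __name__ == "__main__":
    # Minimal smoke test; a sketch with assumed sizes. The 3-channel output is
    # fixed by Decoder.fc, everything else below is an arbitrary choice.
    decoder = Decoder(
        in_dim=3,
        conditioning_input_dim=64,
        hidden_features=128,
        num_heads=4,
        num_layers=2,
        out_activation=nn.Sigmoid(),
    )
    directions = torch.randn(8, 3)         # e.g. per-sample direction vectors
    conditioning = torch.randn(8, 10, 64)  # e.g. a set of conditioning tokens
    output = decoder(directions, conditioning)
    print(output.shape)  # torch.Size([8, 3])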