import torch
import torch.nn as nn
import torch.nn.functional as F


class SlotClassifier(nn.Module):
    """
    MLP classification head that maps feature vectors to slot-type logits.

    The network consists of ``num_layers - 1`` hidden blocks
    (Linear -> LayerNorm -> ReLU -> Dropout) followed by a single output
    Linear layer that projects to ``num_slots`` logits.
    """

    def __init__(
        self,
        input_dim: int,
        num_slots: int,
        hidden_dim: int = 256,
        dropout: float = 0.1,
        num_layers: int = 2,
    ):
        """
        Initialize the slot classifier.

        input_dim: Dimension of the input features (usually d_model from the transformer)
        num_slots: Number of different slot types to classify
        hidden_dim: Dimension of hidden layers in the MLP
        dropout: Dropout probability applied after every hidden block
        num_layers: Total number of Linear layers in the MLP, i.e.
            ``num_layers - 1`` hidden blocks plus the output projection.
            Values <= 1 yield a single Linear(input_dim, num_slots).
        """
        super().__init__()

        layers = []
        prev_dim = input_dim

        # Hidden blocks; only the first one sees input_dim, the rest are
        # hidden_dim -> hidden_dim.
        for _ in range(num_layers - 1):
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            prev_dim = hidden_dim

        # Output projection to one logit per slot type (no activation).
        layers.append(nn.Linear(prev_dim, num_slots))

        self.mlp = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the slot classifier.

        x: Input tensor of shape [batch_size, input_dim]
            Usually the [CLS] token representation from the transformer

        Returns the unnormalized logits of shape [batch_size, num_slots].
        """
        return self.mlp(x)

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Get predictions from the classifier.

        x: Input tensor of shape [batch_size, input_dim]

        Returns the index of the highest-scoring slot for each input,
        shape [batch_size].

        Note: this does not switch the module to eval mode, so dropout is
        still active if the module is currently in training mode.
        """
        # Call the module rather than .forward() so registered hooks run.
        logits = self(x)
        return torch.argmax(logits, dim=-1)

    def get_probabilities(self, x: torch.Tensor) -> torch.Tensor:
        """
        Get probability distribution over slots.

        x: Input tensor of shape [batch_size, input_dim]

        Returns softmax-normalized probabilities of shape
        [batch_size, num_slots]; each row sums to 1.

        Note: this does not switch the module to eval mode, so dropout is
        still active if the module is currently in training mode.
        """
        # Call the module rather than .forward() so registered hooks run.
        logits = self(x)
        return F.softmax(logits, dim=-1)