File size: 2,105 Bytes
9622166 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import torch
import torch.nn as nn
import torch.nn.functional as F
class SlotClassifier(nn.Module):
    """MLP head that classifies a feature vector into one of ``num_slots`` slots.

    The network is ``num_layers - 1`` hidden blocks
    (Linear -> LayerNorm -> ReLU -> Dropout) followed by a final Linear
    projection to the slot logits — i.e. ``num_layers`` Linear layers total.
    """

    def __init__(
        self,
        input_dim: int,
        num_slots: int,
        hidden_dim: int = 256,
        dropout: float = 0.1,
        num_layers: int = 2,
    ) -> None:
        """Initialize the slot classifier.

        Args:
            input_dim: Dimension of the input features (usually ``d_model``
                from the upstream transformer).
            num_slots: Number of different slot types to classify.
            hidden_dim: Dimension of the hidden layers in the MLP.
            dropout: Dropout probability for regularization.
            num_layers: Total number of Linear layers in the MLP:
                ``num_layers - 1`` hidden blocks plus the output projection.
                (The previous docstring called this the "number of hidden
                layers", which did not match the code — with the default
                ``num_layers=2`` there is exactly one hidden block.)
        """
        super().__init__()
        layers: list[nn.Module] = []
        prev_dim = input_dim
        # Hidden blocks: Linear -> LayerNorm -> ReLU -> Dropout.
        for _ in range(num_layers - 1):
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            prev_dim = hidden_dim
        # Final projection to slot logits. No activation here: raw scores
        # are expected by losses such as nn.CrossEntropyLoss.
        layers.append(nn.Linear(prev_dim, num_slots))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute raw slot logits.

        Args:
            x: Input tensor of shape ``[batch_size, input_dim]`` — typically
                the [CLS] token representation from the transformer.

        Returns:
            Logits of shape ``[batch_size, num_slots]``.
        """
        return self.mlp(x)

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """Return the index of the most likely slot for each example.

        argmax is not differentiable, so gradient tracking is disabled to
        avoid building an unused autograd graph; the returned indices are
        identical either way.

        Args:
            x: Input tensor of shape ``[batch_size, input_dim]``.

        Returns:
            Long tensor of slot indices, shape ``[batch_size]``.
        """
        with torch.no_grad():
            return torch.argmax(self.forward(x), dim=-1)

    def get_probabilities(self, x: torch.Tensor) -> torch.Tensor:
        """Return a probability distribution over slots.

        Gradients are preserved so this output can participate in a
        training objective.

        Args:
            x: Input tensor of shape ``[batch_size, input_dim]``.

        Returns:
            Softmax probabilities of shape ``[batch_size, num_slots]``;
            each row sums to 1.
        """
        return F.softmax(self.forward(x), dim=-1)