# Aegis-ATIS-Demo / built_transformer / slot_classifier.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class SlotClassifier(nn.Module):
    """MLP head that maps encoder features to slot-type logits."""

    def __init__(
        self,
        input_dim: int,
        num_slots: int,
        hidden_dim: int = 256,
        dropout: float = 0.1,
        num_layers: int = 2,
    ):
        """
        Initialize the slot classifier.

        input_dim: Dimension of the input features (usually d_model from the transformer)
        num_slots: Number of distinct slot types to classify
        hidden_dim: Dimension of the hidden layers in the MLP
        dropout: Dropout probability for regularization
        num_layers: Total number of linear layers in the MLP; the first
            num_layers - 1 are hidden blocks (Linear -> LayerNorm -> ReLU -> Dropout)
            and the last is the classification layer, so num_layers=1 yields a
            plain linear classifier
        """
        super().__init__()
        # Build the MLP layer by layer.
        layers = []
        prev_dim = input_dim
        # Hidden blocks: Linear -> LayerNorm -> ReLU -> Dropout.
        for _ in range(num_layers - 1):
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            prev_dim = hidden_dim
        # Final projection onto the slot vocabulary (raw logits, no activation).
        layers.append(nn.Linear(prev_dim, num_slots))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the slot classifier.

        x: Input tensor of shape [batch_size, input_dim], usually the [CLS]
            token representation from the transformer. Since nn.Linear acts on
            the last dimension, inputs of shape [batch_size, seq_len, input_dim]
            (per-token slot tagging) also work.

        Returns raw logits with num_slots as the final dimension.
        """
        return self.mlp(x)

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Get hard predictions from the classifier.

        x: Input tensor of shape [batch_size, input_dim]

        Returns slot indices of shape [batch_size] (argmax over the logits).
        """
        logits = self.forward(x)
        return torch.argmax(logits, dim=-1)

    def get_probabilities(self, x: torch.Tensor) -> torch.Tensor:
        """
        Get the probability distribution over slot types.

        x: Input tensor of shape [batch_size, input_dim]

        Returns probabilities of shape [batch_size, num_slots] (softmax over the logits).
        """
        logits = self.forward(x)
        return F.softmax(logits, dim=-1)