Spaces:
Sleeping
Sleeping
File size: 787 Bytes
03ce4c5 c0953f1 03ce4c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
import torch.nn as nn
# Updated BiLSTM to handle variable layers
class BiLSTM(nn.Module):
def __init__(self, input_size, hidden_size, output_size, num_layers=2, dropout=0.1):
super(BiLSTM, self).__init__()
self.bilstm = nn.LSTM(
input_size,
hidden_size,
num_layers=num_layers,
batch_first=True,
bidirectional=True,
dropout=dropout if num_layers > 1 else 0 # Dropout only applies for num_layers > 1
)
self.fc = nn.Linear(hidden_size * 2, output_size) # Multiply hidden_size by 2 for bidirectional
def forward(self, x):
bilstm_output, _ = self.bilstm(x)
output = self.fc(bilstm_output[:, -1, :]) # Use the last time step
return output |