import torch
import torch.nn as nn
from collections import Counter


class BeastTokenizer:
    """Whitespace tokenizer with a frequency-ranked vocabulary.

    Index 0 is reserved for padding and index 1 for out-of-vocabulary words.
    The remaining ``vocab_size - 2`` slots go to the most frequent words in
    ``texts``, so every index produced is ``< len(self)``.
    """

    PAD_TOKEN = "<PAD>"
    UNK_TOKEN = "<UNK>"
    PAD_IDX = 0
    UNK_IDX = 1

    def __init__(self, texts=None, vocab_size=5000):
        """Build the vocabulary.

        Args:
            texts: Iterable of strings to count words from. ``None`` (the
                default) yields a vocabulary with only the two special tokens.
            vocab_size: Upper bound on vocabulary size, special tokens included.
        """
        # Two distinct keys: the previous literal used '' for both, which
        # silently collapsed into one entry and made len(word2idx) too small.
        self.word2idx = {self.PAD_TOKEN: self.PAD_IDX, self.UNK_TOKEN: self.UNK_IDX}
        if texts:
            counter = Counter(word for text in texts for word in text.split())
            common = counter.most_common(vocab_size - 2)
            self.word2idx.update(
                {word: idx + 2 for idx, (word, _) in enumerate(common)}
            )

    def __len__(self):
        """Return the number of vocabulary entries (a valid embedding size)."""
        return len(self.word2idx)

    def encode(self, text, max_len=100):
        """Convert ``text`` to exactly ``max_len`` token ids.

        Unknown words map to ``UNK_IDX``. Longer inputs are truncated and
        shorter inputs are right-padded with ``PAD_IDX``.
        """
        tokens = [self.word2idx.get(word, self.UNK_IDX) for word in text.split()]
        return tokens[:max_len] + [self.PAD_IDX] * (max_len - len(tokens))


class BeastSpamModel(nn.Module):
    """Embedding -> Conv1d -> bidirectional LSTM -> linear -> sigmoid classifier.

    Outputs one score in (0, 1) per input sequence.
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=64):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.conv = nn.Conv1d(embed_dim, 128, kernel_size=5, padding=2)
        self.lstm = nn.LSTM(128, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Score a batch of token-id sequences.

        Args:
            x: Integer tensor of shape (batch, seq_len). Padding id 0 is
                assumed to appear only as trailing padding, which is what
                ``BeastTokenizer.encode`` produces.

        Returns:
            Tensor of shape (batch,) with values in (0, 1).
        """
        # Real (non-pad) length per row; clamp so all-pad rows stay packable.
        lengths = (x != 0).sum(dim=1).clamp(min=1).cpu()

        emb = self.embedding(x)                                  # (B, T, E)
        feats = self.conv(emb.permute(0, 2, 1)).permute(0, 2, 1)  # (B, T, 128)

        # Packing keeps the LSTM from running over trailing padding. Reading
        # lstm_out[:, -1] would take a pad position for short rows, and its
        # backward half would have seen only that final token.
        packed = nn.utils.rnn.pack_padded_sequence(
            feats, lengths, batch_first=True, enforce_sorted=False
        )
        _, (h_n, _) = self.lstm(packed)
        # h_n: (num_layers * 2, B, H); the last two rows are the top layer's
        # forward and backward final states.
        final = torch.cat((h_n[-2], h_n[-1]), dim=1)

        out = self.fc(final)
        return self.sigmoid(out).squeeze(1)