File size: 1,303 Bytes
f02a16d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import torch.nn as nn
from collections import Counter

class BeastTokenizer:
    """Whitespace tokenizer backed by a frequency-ranked vocabulary.

    Ids 0 and 1 are reserved for ``<PAD>`` and ``<UNK>``. Words seen in the
    corpus are numbered from 2 upwards, most frequent first.
    """

    # Reserved ids; they match the two entries seeded into ``word2idx``.
    PAD_IDX = 0
    UNK_IDX = 1

    def __init__(self, texts=None, vocab_size=5000):
        """Build the word-to-id mapping.

        Args:
            texts: Iterable of strings whose whitespace-separated words are
                counted. ``None`` or an empty collection leaves only the two
                reserved tokens in the vocabulary.
            vocab_size: Upper bound on the vocabulary size, including the
                two reserved tokens.
        """
        # ``None`` replaces the former mutable ``[]`` default; both are falsy,
        # so callers that omit ``texts`` see the same behavior.
        self.word2idx = {'<PAD>': self.PAD_IDX, '<UNK>': self.UNK_IDX}
        if texts:
            # Skip literal '<PAD>' / '<UNK>' words so a corpus containing them
            # cannot overwrite the reserved ids that encode() relies on.
            counter = Counter(
                word
                for text in texts
                for word in text.split()
                if word not in self.word2idx
            )
            common = counter.most_common(max(vocab_size - 2, 0))
            for idx, (word, _) in enumerate(common, start=2):
                self.word2idx[word] = idx

    def encode(self, text, max_len=100):
        """Convert ``text`` into a list of ids truncated or padded to ``max_len``.

        Args:
            text: String to split on whitespace and look up word by word.
            max_len: Target length; expected to be non-negative.

        Returns:
            A list of ints. Words missing from the vocabulary map to
            ``UNK_IDX``; shorter inputs are right-padded with ``PAD_IDX``.
        """
        ids = [self.word2idx.get(word, self.UNK_IDX) for word in text.split()]
        ids = ids[:max_len]
        # A non-positive multiplier yields an empty list, so nothing is added
        # once the sequence already fills ``max_len``.
        ids.extend([self.PAD_IDX] * (max_len - len(ids)))
        return ids

class BeastSpamModel(nn.Module):
    """Embedding -> Conv1d -> bidirectional LSTM -> linear -> sigmoid classifier.

    The forward pass returns one value in (0, 1) per input sequence.
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=64):
        """Create the layers.

        Args:
            vocab_size: Number of rows in the embedding table.
            embed_dim: Width of each embedding vector.
            hidden_dim: Hidden size of the LSTM, per direction.
        """
        super().__init__()
        conv_channels = 128  # fixed width shared by the conv output and LSTM input
        # Index 0 is the padding id, so its embedding stays at zero.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=0,
        )
        # kernel_size=5 with padding=2 keeps the sequence length unchanged.
        self.conv = nn.Conv1d(
            in_channels=embed_dim,
            out_channels=conv_channels,
            kernel_size=5,
            padding=2,
        )
        self.lstm = nn.LSTM(
            input_size=conv_channels,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        # Both LSTM directions are concatenated, hence twice the hidden size.
        self.fc = nn.Linear(in_features=2 * hidden_dim, out_features=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Score a batch of token-id sequences.

        Args:
            x: Integer tensor of token ids with shape (batch, seq_len).

        Returns:
            Tensor of shape (batch,) holding sigmoid outputs.
        """
        embedded = self.embedding(x)
        # Conv1d wants channels before length; swap, convolve, swap back.
        channels_first = embedded.permute(0, 2, 1)
        features = self.conv(channels_first).permute(0, 2, 1)
        sequence_out, _state = self.lstm(features)
        # NOTE(review): index -1 is the final sequence position, which may be
        # padding when inputs are right-padded - confirm this is intended.
        final_step = sequence_out[:, -1]
        logits = self.fc(final_step)
        return self.sigmoid(logits).squeeze(1)