File size: 4,355 Bytes
3793db4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77e2bb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3793db4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import re
import urllib.request

import gradio as gr
import torch
import nltk
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords
from transformers import AutoTokenizer, AutoModelForSequenceClassification


# Model initialisation: Hugging Face model ids that take part in the vote.
model_names = [
    "distilbert-base-uncased",
    "prajjwal1/bert-tiny",
    "roberta-base",
    "google/mobilebert-uncased",
    "albert-base-v2",
    "xlm-roberta-base",
]

# Registries filled by the loading loop further down: model id -> object.
models = dict()
tokenizers = dict()

# === Checkpoint file name -> download URL ===
model_urls = {
    "best_model_albert-base-v2.pth": "https://www.dropbox.com/scl/fi/adulme5xarg6hgxbs26fm/best_model_albert-base-v2.pth?rlkey=y17x3sw1frk83yfzt8zc00458&st=43uha18d&dl=1",
    "best_model_distilbert-base-uncased.pth": "https://www.dropbox.com/scl/fi/8y3oyfbzmbmn427e1ei3d/best_model_distilbert-base-uncased.pth?rlkey=u9rd40tdd3p781r4xtv8wi5t6&st=nfzq7x8j&dl=1",
    "best_model_google_mobilebert-uncased.pth": "https://www.dropbox.com/scl/fi/7zdarid2no1fw0b8hk0tf/best_model_google_mobilebert-uncased.pth?rlkey=w13j1jampxlt8himivj090nwv&st=0zq6yofp&dl=1",
    "best_model_prajjwal1_bert-tiny.pth": "https://www.dropbox.com/scl/fi/vscwewy4uo58o7xswokxt/best_model_prajjwal1_bert-tiny.pth?rlkey=uav8aas7fxb5nl2w5iacg1qyb&st=12mzggan&dl=1",
    "best_model_roberta-base.pth": "https://www.dropbox.com/scl/fi/6rlgceyp3azbvd803efa7/best_model_roberta-base.pth?rlkey=xojr8akv2mmvjpkztrv7gg01a&st=h4g5jjf4&dl=1",
    "best_model_xlm-roberta-base.pth": "https://www.dropbox.com/scl/fi/2gao9iqesou9kb633vvan/best_model_xlm-roberta-base.pth?rlkey=acyvwt8qtle8wzle5idfo8241&st=8livizox&dl=1",
}


# === Download model checkpoints that are not present locally ===
# Each file is first written to "<name>.part" and only renamed to its final
# name once the download has finished. Otherwise an interrupted download would
# leave a truncated .pth that the existence check below would skip on the next
# run, and torch.load would then fail on it.
for filename, url in model_urls.items():
    if os.path.exists(filename):
        continue
    print(f"Lejupielādē: {filename}")
    partial_name = f"{filename}.part"
    try:
        urllib.request.urlretrieve(url, partial_name)
        # Atomic rename on the same filesystem: the final name only ever
        # refers to a completely downloaded file.
        os.replace(partial_name, filename)
        print(f"  → Saglabāts: {filename}")
    except Exception as e:
        # Best-effort download: report and continue. The loading loop will
        # still fail loudly for any checkpoint that remains missing.
        print(f"  [!] Kļūda lejupielādējot {filename}: {e}")
        try:
            os.remove(partial_name)
        except FileNotFoundError:
            pass


# Build every tokenizer/model pair and load its fine-tuned weights.
for model_name in model_names:
    # Tokenizer. `from_pretrained` does not interpret a `max_length` kwarg;
    # the attribute that caps sequence length is `model_max_length`. Setting it
    # makes `truncation=True` at inference time cut inputs at 512 tokens, even
    # for checkpoints whose tokenizer config may not declare a maximum.
    tokenizers[model_name] = AutoTokenizer.from_pretrained(model_name, model_max_length=512)

    # Classification head with 3 classes (indices used by the `labels` mapping).
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)

    # Checkpoint files use the model id with "/" replaced by "_", matching the
    # keys of `model_urls`.
    model_file_name = model_name.replace('/', '_')
    # NOTE(review): torch.load unpickles the file, and these checkpoints are
    # downloaded from remote URLs. Consider passing `weights_only=True` if the
    # installed torch version supports it.
    state_dict = torch.load(f"best_model_{model_file_name}.pth", map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)

    # Keep the model on the CPU and switch it to evaluation mode.
    models[model_name] = model.to('cpu')
    models[model_name].eval()

# Label mapping: class index predicted by the models -> human-readable label.
labels = {0: "Safe", 1: "Spam", 2: "Phishing"}


def _ensure_nltk_resource(resource_path, package):
    """Download NLTK data *package* if *resource_path* cannot be found locally."""
    try:
        nltk.data.find(resource_path)
    except LookupError:
        nltk.download(package, quiet=True)


# The stop-word list and the WordNet corpus are separate NLTK data packages.
# On a fresh machine they are absent, and `stopwords.words` / `lemmatize`
# raise LookupError, so fetch them once if they are missing.
_ensure_nltk_resource('corpora/stopwords', 'stopwords')
_ensure_nltk_resource('corpora/wordnet', 'wordnet')

lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words('english'))

def preprocess(text):
    """Normalise raw e-mail text before it is tokenised.

    Processing steps:
    - lower-case the text;
    - remove anything starting with "http";
    - replace every character other than a-z and the apostrophe with a space;
    - collapse runs of whitespace;
    - drop English stop words;
    - lemmatise the remaining words with the module-level WordNet lemmatizer.

    Returns:
        The surviving words joined by single spaces.
    """
    lowered = text.lower()
    without_urls = re.sub(r'http\S+', '', lowered)
    letters_only = re.sub(r"[^a-z']", ' ', without_urls)
    squeezed = re.sub(r'\s+', ' ', letters_only).strip()

    kept_words = []
    for token in squeezed.split():
        if token in stop_words:
            continue
        kept_words.append(lemmatizer.lemmatize(token))
    return ' '.join(kept_words)

# Classification function (single model)
def classify_email_single_model(text, model_name):
    """Classify one e-mail with a single fine-tuned model.

    Args:
        text: Raw e-mail text; it is normalised with ``preprocess`` first.
        model_name: Key into the module-level ``models`` / ``tokenizers``
            dicts (one of ``model_names``).

    Returns:
        The predicted label from ``labels`` ("Safe", "Spam" or "Phishing").
    """
    cleaned = preprocess(text)
    # Pass max_length explicitly. Without it, truncation falls back to the
    # tokenizer's model_max_length, which a checkpoint may leave undefined,
    # and then long e-mails are not truncated at all. 512 is the limit the
    # loading code already intended for these tokenizers.
    inputs = tokenizers[model_name](
        cleaned, return_tensors="pt", padding=True, truncation=True, max_length=512
    )
    with torch.no_grad():
        logits = models[model_name](**inputs).logits
    # One input -> a batch of one; pick the highest-scoring class index.
    prediction = torch.argmax(logits, dim=1).item()
    return labels[prediction]

# Classification function (all models together)
def classify_email(text):
    """Run every model in ``model_names`` on *text* and summarise the votes.

    Args:
        text: Raw e-mail text from the UI.

    Returns:
        A string such as ``"Safe: 4 votes, Spam: 1 vote, Phishing: 1 vote"``.
        Labels appear in the order of the ``labels`` mapping, including those
        that received no votes.
    """
    # Start each label at zero so that labels nobody voted for still appear.
    votes = {label: 0 for label in labels.values()}

    for model_name in model_names:
        votes[classify_email_single_model(text, model_name)] += 1

    # ", ".join puts separators only between entries. This replaces the manual
    # counter that special-cased the third label, and works for any label count.
    return ", ".join(
        f"{label}: {count} {'vote' if count == 1 else 'votes'}"
        for label, count in votes.items()
    )
    
# Gradio UI: a 10-line textbox goes in, the plain-text vote summary comes out.
_interface_options = dict(
    fn=classify_email,
    inputs=gr.Textbox(lines=10, placeholder="Ievietojiet savu e-pastu šeit..."),
    outputs="text",
    title="E-pastu klasifikators (vairāku modeļu balsošana)",
    description="Autori: Kristaps Tretjuks un Aleksejs Gorlovičs",
)
demo = gr.Interface(**_interface_options)

# share=True also asks Gradio to create a public share link.
demo.launch(share=True)