File size: 4,414 Bytes
3793db4
 
 
 
39ccd13
3793db4
 
6c73b87
3793db4
3adc964
3793db4
 
 
 
 
 
 
 
77e2bb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3793db4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import gradio as gr
import torch
import re
import nltk
import os
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords
nltk.download('stopwords')
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import urllib.request


# Model initialisation: Hugging Face checkpoint names that take part in the vote.
model_names = [
    "distilbert-base-uncased",
    "prajjwal1/bert-tiny",
    "roberta-base",
    "google/mobilebert-uncased",
    "albert-base-v2",
    "xlm-roberta-base",
]

# Registries filled in by the loading loop below, keyed by checkpoint name.
models = {}
tokenizers = {}

# === Checkpoint file name -> download URL ===
# Keys follow "best_model_<checkpoint name>.pth" with "/" replaced by "_".
# Long URLs are split with implicit string concatenation; the resulting
# values are unchanged.
model_urls = {
    "best_model_albert-base-v2.pth": (
        "https://www.dropbox.com/scl/fi/adulme5xarg6hgxbs26fm/"
        "best_model_albert-base-v2.pth"
        "?rlkey=y17x3sw1frk83yfzt8zc00458&st=43uha18d&dl=1"
    ),
    "best_model_distilbert-base-uncased.pth": (
        "https://www.dropbox.com/scl/fi/8y3oyfbzmbmn427e1ei3d/"
        "best_model_distilbert-base-uncased.pth"
        "?rlkey=u9rd40tdd3p781r4xtv8wi5t6&st=nfzq7x8j&dl=1"
    ),
    "best_model_google_mobilebert-uncased.pth": (
        "https://www.dropbox.com/scl/fi/7zdarid2no1fw0b8hk0tf/"
        "best_model_google_mobilebert-uncased.pth"
        "?rlkey=w13j1jampxlt8himivj090nwv&st=0zq6yofp&dl=1"
    ),
    "best_model_prajjwal1_bert-tiny.pth": (
        "https://www.dropbox.com/scl/fi/vscwewy4uo58o7xswokxt/"
        "best_model_prajjwal1_bert-tiny.pth"
        "?rlkey=uav8aas7fxb5nl2w5iacg1qyb&st=12mzggan&dl=1"
    ),
    "best_model_roberta-base.pth": (
        "https://www.dropbox.com/scl/fi/6rlgceyp3azbvd803efa7/"
        "best_model_roberta-base.pth"
        "?rlkey=xojr8akv2mmvjpkztrv7gg01a&st=h4g5jjf4&dl=1"
    ),
    "best_model_xlm-roberta-base.pth": (
        "https://www.dropbox.com/scl/fi/2gao9iqesou9kb633vvan/"
        "best_model_xlm-roberta-base.pth"
        "?rlkey=acyvwt8qtle8wzle5idfo8241&st=8livizox&dl=1"
    ),
}


# === Download any checkpoint that is not already on disk ===
for filename, url in model_urls.items():
    if os.path.exists(filename):
        continue

    print(f"Lejupielādē: {filename}")
    # Fetch into a sibling ".part" file and rename only after the transfer
    # finished. Writing straight to `filename` would leave a truncated file
    # behind on an interrupted download, and the existence check above would
    # then accept that broken checkpoint on the next start.
    partial_path = f"{filename}.part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, filename)
        print(f"  → Saglabāts: {filename}")
    except Exception as e:
        # Best effort, as before: report the problem and continue with the
        # remaining files instead of aborting start-up here.
        print(f"  [!] Kļūda lejupielādējot {filename}: {e}")
        try:
            if os.path.exists(partial_path):
                os.remove(partial_path)
        except OSError:
            # Failing to clean up the temp file must not mask the report above.
            pass


# Load a tokenizer and a fine-tuned 3-class classifier for every checkpoint.
for model_name in model_names:
    # Tokenizer
    # NOTE(review): `max_length` here does not look like the option that sets
    # the tokenizer's truncation limit (that one is `model_max_length`) —
    # confirm against the transformers docs before relying on it.
    tokenizers[model_name] = AutoTokenizer.from_pretrained(model_name, max_length=512)

    # Base architecture with a 3-class classification head
    classifier = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)

    # Checkpoint files use "_" where the checkpoint name has "/".
    checkpoint_path = "best_model_" + model_name.replace('/', '_') + ".pth"
    # NOTE(review): torch.load unpickles a file that was downloaded from the
    # network; consider `weights_only=True` if the installed torch supports it.
    fine_tuned_weights = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    classifier.load_state_dict(fine_tuned_weights)

    # Keep the model on the CPU and switch it to inference mode
    classifier = classifier.to('cpu')
    classifier.eval()
    models[model_name] = classifier

# Label mapping: class index predicted by a model -> label shown to the user.
labels = {0: "Safe", 1: "Spam", 2: "Phishing"}

# WordNetLemmatizer needs the WordNet corpus when it lemmatizes, but the only
# corpus fetched in the visible part of this file is 'stopwords'. Download it
# here so the first classification does not fail with a LookupError on a
# fresh machine; the call is a no-op when the corpus is already up to date.
nltk.download('wordnet')

lemmatizer = WordNetLemmatizer()
# A set gives O(1) membership tests in preprocess().
stop_words = set(stopwords.words('english'))

def preprocess(text):
    """Normalise raw e-mail text before it is tokenized.

    Steps, in order: lower-case, drop anything starting with ``http`` up to
    the next whitespace, replace every character that is not ``a``-``z`` or an
    apostrophe with a space, collapse whitespace, remove English stop words
    and lemmatize what is left.

    Args:
        text: The raw text to clean.

    Returns:
        The cleaned words joined by single spaces.
    """
    lowered = text.lower()
    without_urls = re.sub(r'http\S+', '', lowered)
    letters_only = re.sub(r"[^a-z']", ' ', without_urls)
    collapsed = re.sub(r'\s+', ' ', letters_only).strip()

    kept_words = []
    for token in collapsed.split():
        if token in stop_words:
            continue
        kept_words.append(lemmatizer.lemmatize(token))
    return ' '.join(kept_words)

# Classification function (single model)
def classify_email_single_model(text, model_name, max_length=512):
    """Classify one e-mail with a single model.

    Args:
        text: Raw e-mail text; it is passed through ``preprocess`` first.
        model_name: Key into the module-level ``tokenizers`` and ``models``
            dictionaries.
        max_length: Maximum number of tokens handed to the model. Defaults to
            512, the same value this file uses when loading the tokenizers.

    Returns:
        The entry of ``labels`` for the class with the highest logit.

    Raises:
        KeyError: If ``model_name`` has no loaded tokenizer/model.
    """
    cleaned = preprocess(text)
    # truncation=True only cuts the input when a concrete limit is known.
    # Without an explicit max_length the tokenizer falls back to its own
    # model_max_length, which is not guaranteed to be set for every
    # checkpoint; a long e-mail could then exceed what the model accepts.
    inputs = tokenizers[model_name](
        cleaned,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = models[model_name](**inputs).logits
    # A single text yields a batch of one, so .item() gives the class index.
    prediction = torch.argmax(logits, dim=1).item()
    return labels[prediction]

# Classification function (all models together)
def classify_email(text):
    """Let every loaded model vote on *text* and summarise the tally.

    Args:
        text: Raw e-mail text as entered by the user.

    Returns:
        One line listing each label with its vote count, separated by
        ``", "`` — for example ``"Safe: 4 votes, Spam: 1 vote, Phishing: 1 vote"``.
    """
    # Take the tally keys from `labels` so they always match what
    # classify_email_single_model can return.
    votes = dict.fromkeys(labels.values(), 0)

    for model_name in model_names:
        votes[classify_email_single_model(text, model_name)] += 1

    # str.join places the separators, so nothing depends on how many labels
    # there are (the previous counter-based version hard-coded three).
    return ", ".join(
        f"{label}: {vote_count} {'vote' if vote_count == 1 else 'votes'}"
        for label, vote_count in votes.items()
    )
    
# Gradio UI
# One multi-line text box in, the vote summary from classify_email out as text.
demo = gr.Interface(
    title="E-pastu klasifikators (vairāku modeļu balsošana)",
    description="Autori: Kristaps Tretjuks un Aleksejs Gorlovičs",
    fn=classify_email,
    inputs=gr.Textbox(
        placeholder="Ievietojiet savu e-pastu šeit...",
        lines=10,
    ),
    outputs="text",
)

# Start the app when the script runs; share=True is passed through unchanged.
demo.launch(share=True)