alexxxey123 committed on
Commit
36b3ca7
·
2 Parent(s): 093e3b6 77e2bb8

Merge branch 'main' of https://huggingface.co/spaces/alexxxey123/email_classifier

Browse files
Files changed (1) hide show
  1. app.py +105 -0
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import re
import urllib.request

import gradio as gr
import nltk
import torch
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
+
9
+
10
# Model initialisation: Hugging Face checkpoint names whose fine-tuned
# weights take part in the vote.
model_names = [
    "distilbert-base-uncased",
    "prajjwal1/bert-tiny",
    "roberta-base",
    "google/mobilebert-uncased",
    "albert-base-v2",
    "xlm-roberta-base",
]

# Loaded models and tokenizers, keyed by the entries of ``model_names``.
models = {}
tokenizers = {}

# === Model -> URL ===
# Local weight file name mapped to the URL it is downloaded from.
model_urls = {
    "best_model_albert-base-v2.pth": "https://www.dropbox.com/scl/fi/adulme5xarg6hgxbs26fm/best_model_albert-base-v2.pth?rlkey=y17x3sw1frk83yfzt8zc00458&st=43uha18d&dl=1",
    "best_model_distilbert-base-uncased.pth": "https://www.dropbox.com/scl/fi/8y3oyfbzmbmn427e1ei3d/best_model_distilbert-base-uncased.pth?rlkey=u9rd40tdd3p781r4xtv8wi5t6&st=nfzq7x8j&dl=1",
    "best_model_google_mobilebert-uncased.pth": "https://www.dropbox.com/scl/fi/7zdarid2no1fw0b8hk0tf/best_model_google_mobilebert-uncased.pth?rlkey=w13j1jampxlt8himivj090nwv&st=0zq6yofp&dl=1",
    "best_model_prajjwal1_bert-tiny.pth": "https://www.dropbox.com/scl/fi/vscwewy4uo58o7xswokxt/best_model_prajjwal1_bert-tiny.pth?rlkey=uav8aas7fxb5nl2w5iacg1qyb&st=12mzggan&dl=1",
    "best_model_roberta-base.pth": "https://www.dropbox.com/scl/fi/6rlgceyp3azbvd803efa7/best_model_roberta-base.pth?rlkey=xojr8akv2mmvjpkztrv7gg01a&st=h4g5jjf4&dl=1",
    "best_model_xlm-roberta-base.pth": "https://www.dropbox.com/scl/fi/2gao9iqesou9kb633vvan/best_model_xlm-roberta-base.pth?rlkey=acyvwt8qtle8wzle5idfo8241&st=8livizox&dl=1",
}


def _download_file(url, destination):
    """Download *url* to *destination* without leaving a partial file behind.

    The data is first written to ``<destination>.part`` and only renamed to
    *destination* once the transfer has finished. An interrupted download
    therefore never produces a file that the existence check below would
    mistake for a complete one.

    Raises whatever ``urllib.request.urlretrieve`` or ``os.replace`` raises.
    """
    partial_path = f"{destination}.part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, destination)
    finally:
        # After a successful rename the partial path no longer exists; after
        # a failure this removes whatever was written so far.
        if os.path.exists(partial_path):
            os.remove(partial_path)


# === Download the model weights that are not present locally ===
for filename, url in model_urls.items():
    if os.path.exists(filename):
        continue
    print(f"Lejupielādē: {filename}")
    try:
        _download_file(url, filename)
    except Exception as e:
        # Deliberately best-effort: report the failure and carry on with the
        # remaining files instead of aborting start-up here.
        print(f" [!] Kļūda lejupielādējot {filename}: {e}")
    else:
        print(f" → Saglabāts: {filename}")
36
+
37
+
38
# Build one tokenizer/classifier pair per checkpoint and load its fine-tuned
# weights from the matching local ``best_model_*.pth`` file.
for model_name in model_names:
    # Tokenizer. ``model_max_length`` is the keyword the tokenizer uses as its
    # default truncation limit; the previously passed ``max_length`` is a
    # per-call argument, not a constructor setting.
    tokenizers[model_name] = AutoTokenizer.from_pretrained(model_name, model_max_length=512)

    # Model with a 3-class classification head.
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)

    # Weight files use "_" where the checkpoint name contains "/".
    model_file_name = model_name.replace('/', '_')

    # NOTE(security): torch.load unpickles the file, and these files are
    # fetched from an external URL. Consider passing weights_only=True if the
    # installed torch version supports it -- TODO confirm.
    state_dict = torch.load(f"best_model_{model_file_name}.pth", map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)

    # Keep the model on the CPU and switch it to inference mode.
    models[model_name] = model.to('cpu')
    models[model_name].eval()
51
+
52
# Label mapping: class index predicted by a model -> human-readable label.
labels = {0: "Safe", 1: "Spam", 2: "Phishing"}

# The stop-word list and the WordNet lemmatizer read NLTK corpora that are
# not shipped with the nltk package itself, so fetch them before first use.
# "omw-1.4" is included because some NLTK releases appear to need it for
# WordNet -- TODO confirm against the pinned NLTK version.
for _resource in ("stopwords", "wordnet", "omw-1.4"):
    nltk.download(_resource, quiet=True)

lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words('english'))
57
+
58
def preprocess(text):
    """Normalise raw e-mail text before it is handed to a tokenizer.

    Steps, in order: lower-case the text; remove every run of non-whitespace
    characters that starts with "http"; replace each character other than
    a-z and the apostrophe with a space; collapse whitespace; drop words
    contained in ``stop_words``; lemmatise the remaining words with
    ``lemmatizer``.

    Returns the surviving words joined by single spaces.
    """
    lowered = text.lower()
    without_urls = re.sub(r'http\S+', '', lowered)
    letters_only = re.sub(r"[^a-z']", ' ', without_urls)
    collapsed = re.sub(r'\s+', ' ', letters_only).strip()

    kept_words = []
    for token in collapsed.split():
        # Stop words are filtered on the raw token, before lemmatisation.
        if token in stop_words:
            continue
        kept_words.append(lemmatizer.lemmatize(token))
    return ' '.join(kept_words)
65
+
66
+ # Classification function (single model)
67
def classify_email_single_model(text, model_name):
    """Classify one e-mail with the single model registered under *model_name*.

    The text is run through ``preprocess``, tokenised with the tokenizer
    stored for *model_name*, and passed to the matching model with gradient
    tracking disabled.

    Returns the entry of ``labels`` for the class with the highest logit.
    Raises ``KeyError`` if *model_name* has no tokenizer or model registered.
    """
    cleaned = preprocess(text)
    inputs = tokenizers[model_name](
        cleaned,
        return_tensors="pt",
        padding=True,
        truncation=True,
        # Without an explicit limit, truncation falls back to the tokenizer's
        # configured maximum, which a checkpoint may not define; long e-mails
        # could then exceed what the model accepts. 512 is the limit already
        # requested when the tokenizers are created.
        max_length=512,
    )
    with torch.no_grad():
        outputs = models[model_name](**inputs)
    prediction = torch.argmax(outputs.logits, dim=1).item()
    return labels[prediction]
74
+
75
+ # Classification function (all models together)
76
def classify_email(text):
    """Classify *text* with every model in ``model_names`` and tally the votes.

    Each model contributes one vote via ``classify_email_single_model``.

    Returns a summary string listing the vote count per label, for example
    ``"Safe: 4 votes, Spam: 1 vote, Phishing: 1 vote"``.
    """
    votes = {"Safe": 0, "Spam": 0, "Phishing": 0}

    for model_name in model_names:
        votes[classify_email_single_model(text, model_name)] += 1

    # Joining the parts keeps the separators correct for any number of
    # labels, unlike a manual counter compared against a fixed label count.
    return ", ".join(
        f"{label}: {vote_count} {'vote' if vote_count == 1 else 'votes'}"
        for label, vote_count in votes.items()
    )
95
+
96
+ # Gradio UI
97
# Ten-line text box in which the user pastes the e-mail to classify.
email_input = gr.Textbox(lines=10, placeholder="Ievietojiet savu e-pastu šeit...")

# Text-in, text-out interface around ``classify_email``.
demo = gr.Interface(
    fn=classify_email,
    inputs=email_input,
    outputs="text",
    title="E-pastu klasifikators (vairāku modeļu balsošana)",
    description="Autori: Kristaps Tretjuks un Aleksejs Gorlovičs",
)

# Start the web app, requesting a shareable link.
demo.launch(share=True)