alexxxey123 commited on
Commit
3793db4
·
1 Parent(s): c511209

Add application file

Browse files
Files changed (1) hide show
  1. app.py +83 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import re
4
+ import nltk
5
+ from nltk.stem import WordNetLemmatizer
6
+ from nltk.corpus import stopwords
7
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
+
9
+
10
+ # Modeļu inicializācija
11
+ model_names = ["distilbert-base-uncased", "prajjwal1/bert-tiny", "roberta-base", "google/mobilebert-uncased", "albert-base-v2", "xlm-roberta-base"]
12
+
13
+ models = {}
14
+ tokenizers = {}
15
+
16
+ for model_name in model_names:
17
+ # Tokenizators
18
+ tokenizers[model_name] = AutoTokenizer.from_pretrained(model_name, max_length=512)
19
+
20
+ # Modelis ar 3 klasēm
21
+ models[model_name] = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
22
+
23
+ model_file_name = re.sub(r'/', '_', model_name)
24
+ models[model_name].load_state_dict(torch.load(f"best_model_{model_file_name}.pth", map_location=torch.device('cpu')))
25
+
26
+ # Uz ierīces
27
+ models[model_name] = models[model_name].to('cpu')
28
+ models[model_name].eval()
29
+
30
+ # Label mapping
31
+ labels = {0: "Safe", 1: "Spam", 2: "Phishing"}
32
+
33
+ lemmatizer = WordNetLemmatizer()
34
+ stop_words = set(stopwords.words('english'))
35
+
36
+ def preprocess(text):
37
+ text = text.lower() # Teksta pārveide atmetot lielos burtus
38
+ text = re.sub(r'http\S+', '', text) # URL atmešana
39
+ text = re.sub(r"[^a-z']", ' ', text) # atmet simbolus, kas nav burti
40
+ text = re.sub(r'\s+', ' ', text).strip() # atmet liekās atstarpes
41
+ text = ' '.join([lemmatizer.lemmatize(word) for word in text.split() if word not in stop_words]) # lemmatizācija
42
+ return text
43
+
44
+ # Classification function (single model)
45
+ def classify_email_single_model(text, model_name):
46
+ text = preprocess(text)
47
+ inputs = tokenizers[model_name](text, return_tensors="pt", padding=True, truncation=True)
48
+ with torch.no_grad():
49
+ outputs = models[model_name](**inputs)
50
+ prediction = torch.argmax(outputs.logits, dim=1).item()
51
+ return labels[prediction]
52
+
53
+ # Classification function (all models together)
54
+ def classify_email(text):
55
+ votes = {"Safe": 0, "Spam": 0, "Phishing": 0}
56
+
57
+ for model_name in model_names:
58
+ vote = classify_email_single_model(text, model_name)
59
+ votes[vote] += 1
60
+
61
+
62
+ response = ""
63
+ i = 1
64
+ for label, vote_count in votes.items():
65
+ vote_or_votes = "vote" if vote_count == 1 else "votes"
66
+ if i != 3:
67
+ response += f"{label}: {vote_count} {vote_or_votes}, "
68
+ else:
69
+ response += f"{label}: {vote_count} {vote_or_votes}"
70
+ i += 1
71
+
72
+ return response
73
+
74
+ # Gradio UI
75
+ demo = gr.Interface(
76
+ fn=classify_email,
77
+ inputs=gr.Textbox(lines=10, placeholder="Ievietojiet savu e-pastu šeit..."),
78
+ outputs="text",
79
+ title="E-pastu klasifikators (vairāku modeļu balsošana)",
80
+ description="Autori: Kristaps Tretjuks un Aleksejs Gorlovičs"
81
+ )
82
+
83
+ demo.launch(share=True)