import os
import re
import urllib.request

import gradio as gr
import nltk
import torch
import torch.nn.functional as F
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Download the NLTK resources used for stop-word removal and lemmatization
nltk.download('stopwords')
nltk.download('wordnet')
# Model initialization
model_names = ["distilbert-base-uncased", "prajjwal1/bert-tiny", "roberta-base", "google/mobilebert-uncased", "albert-base-v2", "xlm-roberta-base"]
models = {}
tokenizers = {}
# === Checkpoint file → Dropbox URL (dl=1 forces a direct download) ===
model_urls = {
    "best_model_albert-base-v2.pth": "https://www.dropbox.com/scl/fi/adulme5xarg6hgxbs26fm/best_model_albert-base-v2.pth?rlkey=y17x3sw1frk83yfzt8zc00458&st=43uha18d&dl=1",
    "best_model_distilbert-base-uncased.pth": "https://www.dropbox.com/scl/fi/8y3oyfbzmbmn427e1ei3d/best_model_distilbert-base-uncased.pth?rlkey=u9rd40tdd3p781r4xtv8wi5t6&st=nfzq7x8j&dl=1",
    "best_model_google_mobilebert-uncased.pth": "https://www.dropbox.com/scl/fi/7zdarid2no1fw0b8hk0tf/best_model_google_mobilebert-uncased.pth?rlkey=w13j1jampxlt8himivj090nwv&st=0zq6yofp&dl=1",
    "best_model_prajjwal1_bert-tiny.pth": "https://www.dropbox.com/scl/fi/vscwewy4uo58o7xswokxt/best_model_prajjwal1_bert-tiny.pth?rlkey=uav8aas7fxb5nl2w5iacg1qyb&st=12mzggan&dl=1",
    "best_model_roberta-base.pth": "https://www.dropbox.com/scl/fi/6rlgceyp3azbvd803efa7/best_model_roberta-base.pth?rlkey=xojr8akv2mmvjpkztrv7gg01a&st=h4g5jjf4&dl=1",
    "best_model_xlm-roberta-base.pth": "https://www.dropbox.com/scl/fi/2gao9iqesou9kb633vvan/best_model_xlm-roberta-base.pth?rlkey=acyvwt8qtle8wzle5idfo8241&st=8livizox&dl=1",
}
# === Download the checkpoints if they are not present locally ===
for filename, url in model_urls.items():
    if not os.path.exists(filename):
        print(f"Downloading: {filename}")
        try:
            urllib.request.urlretrieve(url, filename)
            print(f" → Saved: {filename}")
        except Exception as e:
            print(f" [!] Error downloading {filename}: {e}")
for model_name in model_names:
    # Tokenizer
    tokenizers[model_name] = AutoTokenizer.from_pretrained(model_name, max_length=512)
    # Model with 3 output classes
    models[model_name] = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
    # Fine-tuned weights are stored as best_model_<name>.pth, with '/' in the model name replaced by '_'
    model_file_name = re.sub(r'/', '_', model_name)
    models[model_name].load_state_dict(torch.load(f"best_model_{model_file_name}.pth", map_location=torch.device('cpu')))
    # Run on CPU in inference mode
    models[model_name] = models[model_name].to('cpu')
    models[model_name].eval()
# Label mapping
labels = {0: "Safe", 1: "Spam", 2: "Phishing"}
lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words('english'))
def preprocess(text):
    text = text.lower()                       # lowercase the text
    text = re.sub(r'http\S+', '', text)       # strip URLs
    text = re.sub(r"[^a-z']", ' ', text)      # drop characters that are not letters or apostrophes
    text = re.sub(r'\s+', ' ', text).strip()  # collapse extra whitespace
    text = ' '.join([lemmatizer.lemmatize(word) for word in text.split() if word not in stop_words])  # remove stop words and lemmatize
    return text
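# Rough illustration of the cleaning steps (the exact output depends on the
# installed NLTK stop-word list and lemmatizer data):
#   preprocess("Claim your FREE prizes at http://example.com now!!!")
#   -> "claim free prize"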
# Classification function (single model)
def classify_email_single_model(text, model_name):
    text = preprocess(text)
    inputs = tokenizers[model_name](text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = models[model_name](**inputs)
    prediction = torch.argmax(outputs.logits, dim=1).item()
    probs = F.softmax(outputs.logits, dim=1)
    probs_percent = probs.cpu().numpy() * 100
    response = {"prediction": labels[prediction], "probabilities": probs_percent}
    return response
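# Returns a dict such as (values purely illustrative):
#   {"prediction": "Spam", "probabilities": array([[ 3.1, 94.2,  2.7]])}
# where the probability columns follow the `labels` order: Safe, Spam, Phishing.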
# Classification function (all models together)
def classify_email(text):
    votes = {"Safe": 0, "Spam": 0, "Phishing": 0}
    probabilities = {}
    # Collect one vote and one probability vector per model
    for model_name in model_names:
        response = classify_email_single_model(text, model_name)
        vote = response['prediction']
        votes[vote] += 1
        probabilities[model_name] = response['probabilities']
    # Summarize the vote counts on a single line
    response = ""
    i = 1
    for label, vote_count in votes.items():
        vote_or_votes = "vote" if vote_count == 1 else "votes"
        if i != 3:
            response += f"{label}: {vote_count} {vote_or_votes}, "
        else:
            response += f"{label}: {vote_count} {vote_or_votes}\n"
        i += 1
    response += "\n"
    # Append the per-model class probabilities
    for model_name in model_names:
        response += f"{model_name}: "
        for j, prob in enumerate(probabilities[model_name][0]):
            response += f"{labels[j]}: {prob:.2f}% "
        response += "\n"
    return response
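# Example of the returned text (numbers purely illustrative):
#   Safe: 1 vote, Spam: 4 votes, Phishing: 1 vote
#
#   distilbert-base-uncased: Safe: 2.10% Spam: 95.30% Phishing: 2.60%
#   prajjwal1/bert-tiny: Safe: 10.05% Spam: 80.12% Phishing: 9.83%
#   ... (one line per model)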
# Gradio UI
demo = gr.Interface(
    fn=classify_email,
    inputs=gr.Textbox(lines=10, placeholder="Paste your email here..."),
    outputs="text",
    title="Email classifier (multi-model voting)",
    description="Authors: Kristaps Tretjuks and Aleksejs Gorlovičs"
)
demo.launch(share=True)