import gradio as gr
import torch
import re
import nltk
import os
import urllib.request
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# NLTK resources: stopwords for filtering, wordnet for the lemmatizer
nltk.download('stopwords')
nltk.download('wordnet')
# Model initialization
model_names = ["distilbert-base-uncased", "prajjwal1/bert-tiny", "roberta-base", "google/mobilebert-uncased", "albert-base-v2", "xlm-roberta-base"]
models = {}
tokenizers = {}
# === Model checkpoint → download URL ===
model_urls = {
"best_model_albert-base-v2.pth": "https://www.dropbox.com/scl/fi/adulme5xarg6hgxbs26fm/best_model_albert-base-v2.pth?rlkey=y17x3sw1frk83yfzt8zc00458&st=43uha18d&dl=1",
"best_model_distilbert-base-uncased.pth": "https://www.dropbox.com/scl/fi/8y3oyfbzmbmn427e1ei3d/best_model_distilbert-base-uncased.pth?rlkey=u9rd40tdd3p781r4xtv8wi5t6&st=nfzq7x8j&dl=1",
"best_model_google_mobilebert-uncased.pth": "https://www.dropbox.com/scl/fi/7zdarid2no1fw0b8hk0tf/best_model_google_mobilebert-uncased.pth?rlkey=w13j1jampxlt8himivj090nwv&st=0zq6yofp&dl=1",
"best_model_prajjwal1_bert-tiny.pth": "https://www.dropbox.com/scl/fi/vscwewy4uo58o7xswokxt/best_model_prajjwal1_bert-tiny.pth?rlkey=uav8aas7fxb5nl2w5iacg1qyb&st=12mzggan&dl=1",
"best_model_roberta-base.pth": "https://www.dropbox.com/scl/fi/6rlgceyp3azbvd803efa7/best_model_roberta-base.pth?rlkey=xojr8akv2mmvjpkztrv7gg01a&st=h4g5jjf4&dl=1",
"best_model_xlm-roberta-base.pth": "https://www.dropbox.com/scl/fi/2gao9iqesou9kb633vvan/best_model_xlm-roberta-base.pth?rlkey=acyvwt8qtle8wzle5idfo8241&st=8livizox&dl=1",
}
# === Download model weights if they are not already present ===
for filename, url in model_urls.items():
    if not os.path.exists(filename):
        print(f"Downloading: {filename}")
        try:
            urllib.request.urlretrieve(url, filename)
            print(f" → Saved: {filename}")
        except Exception as e:
            print(f" [!] Error downloading {filename}: {e}")
for model_name in model_names:
    # Tokenizer
    tokenizers[model_name] = AutoTokenizer.from_pretrained(model_name, max_length=512)
    # Classification head with 3 classes
    models[model_name] = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
    # Fine-tuned weights are saved as best_model_<model_name>.pth with '/' replaced by '_'
    model_file_name = re.sub(r'/', '_', model_name)
    models[model_name].load_state_dict(torch.load(f"best_model_{model_file_name}.pth", map_location=torch.device('cpu')))
    # Keep everything on CPU and switch to inference mode
    models[model_name] = models[model_name].to('cpu')
    models[model_name].eval()
# Label mapping
labels = {0: "Safe", 1: "Spam", 2: "Phishing"}
lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words('english'))
def preprocess(text):
    text = text.lower()                       # lowercase the text
    text = re.sub(r'http\S+', '', text)       # strip URLs
    text = re.sub(r"[^a-z']", ' ', text)      # drop everything that is not a letter or apostrophe
    text = re.sub(r'\s+', ' ', text).strip()  # collapse redundant whitespace
    text = ' '.join([lemmatizer.lemmatize(word) for word in text.split() if word not in stop_words])  # remove stopwords and lemmatize
    return text
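# Illustrative example of the preprocessing above (output assumes NLTK's English
# stopword list and the default WordNet lemmatizer):
#   preprocess("Visit https://example.com and claim your FREE prizes!!!")
#   -> "visit claim free prize"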
# Classification function (single model)
def classify_email_single_model(text, model_name):
    text = preprocess(text)
    inputs = tokenizers[model_name](text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = models[model_name](**inputs)
    prediction = torch.argmax(outputs.logits, dim=1).item()
    return labels[prediction]
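# Illustrative sketch (not part of the original app): the same forward pass can also
# expose per-class probabilities, which is handy when inspecting a single model.
# The function name below is a hypothetical addition.
def classify_email_single_model_scores(text, model_name):
    text = preprocess(text)
    inputs = tokenizers[model_name](text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        logits = models[model_name](**inputs).logits
    probs = torch.softmax(logits, dim=1).squeeze(0)
    # e.g. {"Safe": 0.91, "Spam": 0.05, "Phishing": 0.04}
    return {labels[i]: round(probs[i].item(), 4) for i in range(len(labels))}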
# Classification function (all models together)
def classify_email(text):
    votes = {"Safe": 0, "Spam": 0, "Phishing": 0}
    for model_name in model_names:
        vote = classify_email_single_model(text, model_name)
        votes[vote] += 1
    # Summarise the votes across all models
    return ", ".join(
        f"{label}: {vote_count} {'vote' if vote_count == 1 else 'votes'}"
        for label, vote_count in votes.items()
    )
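# Illustrative note (not part of the original app): with six models voting, the string
# above might read "Safe: 4 votes, Spam: 1 vote, Phishing: 1 vote". If a single winning
# label were wanted instead of the tally, a minimal sketch inside classify_email could be:
#   majority_label = max(votes, key=votes.get)  # hypothetical; ties fall to dict order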
# Gradio UI
demo = gr.Interface(
    fn=classify_email,
    inputs=gr.Textbox(lines=10, placeholder="Paste your email here..."),
    outputs="text",
    title="Email classifier (multi-model voting)",
    description="Authors: Kristaps Tretjuks and Aleksejs Gorlovičs"
)
demo.launch(share=True)