Den4ikAI commited on
Commit
5f5dede
·
verified ·
1 Parent(s): 2223ef6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -11
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
 
4
 
5
  models = {
6
  "RUSpam/spam_deberta_v4": "RUSpam/spam_deberta_v4",
@@ -14,24 +15,50 @@ for name, path in models.items():
14
  tokenizers[name] = AutoTokenizer.from_pretrained(path)
15
  model_instances[name] = AutoModelForSequenceClassification.from_pretrained(path)
16
 
17
- def predict_spam(text, model_choice):
18
- tokenizer = tokenizers[model_choice]
19
- model = model_instances[model_choice]
 
 
 
 
 
 
 
 
 
20
 
21
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
22
  with torch.no_grad():
23
  outputs = model(**inputs)
24
  logits = outputs.logits
25
- probabilities = torch.softmax(logits, dim=1)
26
  predicted_class = torch.argmax(logits, dim=1).item()
27
 
28
- spam_probability = probabilities[0][1].item()
29
- not_spam_probability = probabilities[0][0].item()
30
-
31
  result = "Спам" if predicted_class == 1 else "Не спам"
32
-
33
  return result
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  # Создание интерфейса Gradio
36
  iface = gr.Interface(
37
  fn=predict_spam,
@@ -39,9 +66,7 @@ iface = gr.Interface(
39
  gr.Textbox(lines=5, label="Введите текст"),
40
  gr.Radio(choices=list(models.keys()), label="Выберите модель", value="RUSpam/spam_deberta_v4")
41
  ],
42
- outputs=[
43
- gr.Label(label="Результат")
44
- ],
45
  title="Определение спама в русскоязычных текстах"
46
  )
47
 
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
+ import re
5
 
6
  models = {
7
  "RUSpam/spam_deberta_v4": "RUSpam/spam_deberta_v4",
 
15
  tokenizers[name] = AutoTokenizer.from_pretrained(path)
16
  model_instances[name] = AutoModelForSequenceClassification.from_pretrained(path)
17
 
18
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
+ model_instances["RUSpam/spamNS_v1"] = model_instances["RUSpam/spamNS_v1"].to(device).eval()
20
+
21
+ def clean_text(text):
22
+ text = re.sub(r'http\S+', '', text)
23
+ text = re.sub(r'[^А-Яа-я0-9 ]+', ' ', text)
24
+ text = text.lower().strip()
25
+ return text
26
+
27
+ def predict_spam_deberta(text):
28
+ tokenizer = tokenizers["RUSpam/spam_deberta_v4"]
29
+ model = model_instances["RUSpam/spam_deberta_v4"]
30
 
31
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
32
  with torch.no_grad():
33
  outputs = model(**inputs)
34
  logits = outputs.logits
 
35
  predicted_class = torch.argmax(logits, dim=1).item()
36
 
 
 
 
37
  result = "Спам" if predicted_class == 1 else "Не спам"
 
38
  return result
39
 
40
+ def predict_spam_spamns(text):
41
+ tokenizer = tokenizers["RUSpam/spamNS_v1"]
42
+ model = model_instances["RUSpam/spamNS_v1"]
43
+
44
+ text = clean_text(text)
45
+ encoding = tokenizer(text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
46
+ input_ids = encoding['input_ids'].to(device)
47
+ attention_mask = encoding['attention_mask'].to(device)
48
+
49
+ with torch.no_grad():
50
+ outputs = model(input_ids, attention_mask=attention_mask).logits
51
+ pred = torch.sigmoid(outputs).cpu().numpy()[0][0]
52
+
53
+ result = "Спам" if pred >= 0.5 else "Не спам"
54
+ return result
55
+
56
+ def predict_spam(text, model_choice):
57
+ if model_choice == "RUSpam/spam_deberta_v4":
58
+ return predict_spam_deberta(text)
59
+ elif model_choice == "RUSpam/spamNS_v1":
60
+ return predict_spam_spamns(text)
61
+
62
  # Создание интерфейса Gradio
63
  iface = gr.Interface(
64
  fn=predict_spam,
 
66
  gr.Textbox(lines=5, label="Введите текст"),
67
  gr.Radio(choices=list(models.keys()), label="Выберите модель", value="RUSpam/spam_deberta_v4")
68
  ],
69
+ outputs=gr.Label(label="Результат"),
 
 
70
  title="Определение спама в русскоязычных текстах"
71
  )
72