Den4ikAI committed
Commit 2223ef6 · verified · 1 Parent(s): ddb6473

Update app.py

Files changed (1):
  app.py (+19 -6)
app.py CHANGED
@@ -2,11 +2,22 @@ import gradio as gr
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
 import torch
 
-model_path = "RUSpam/spam_deberta_v4"
-tokenizer = AutoTokenizer.from_pretrained(model_path)
-model = AutoModelForSequenceClassification.from_pretrained(model_path)
-
-def predict_spam(text):
+models = {
+    "RUSpam/spam_deberta_v4": "RUSpam/spam_deberta_v4",
+    "RUSpam/spamNS_v1": "RUSpam/spamNS_v1"
+}
+
+tokenizers = {}
+model_instances = {}
+
+for name, path in models.items():
+    tokenizers[name] = AutoTokenizer.from_pretrained(path)
+    model_instances[name] = AutoModelForSequenceClassification.from_pretrained(path)
+
+def predict_spam(text, model_choice):
+    tokenizer = tokenizers[model_choice]
+    model = model_instances[model_choice]
+
     inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
     with torch.no_grad():
         outputs = model(**inputs)
@@ -21,11 +32,13 @@ def predict_spam(text):
 
     return result
 
-
 # Создание интерфейса Gradio
 iface = gr.Interface(
     fn=predict_spam,
-    inputs=gr.Textbox(lines=5, label="Введите текст"),
+    inputs=[
+        gr.Textbox(lines=5, label="Введите текст"),
+        gr.Radio(choices=list(models.keys()), label="Выберите модель", value="RUSpam/spam_deberta_v4")
+    ],
     outputs=[
         gr.Label(label="Результат")
     ],
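
For context, here is a minimal sketch of how app.py could look after this commit. The post-processing that builds `result` (new lines 24-31, not shown in the diff) is an assumption here: the sketch applies a softmax over two logits and returns a label-to-probability dict for gr.Label. Everything else mirrors the diff above.

# Hedged sketch of the full app.py after this commit. The mapping from model
# outputs to `result` is an assumption, not code from this repository.
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

models = {
    "RUSpam/spam_deberta_v4": "RUSpam/spam_deberta_v4",
    "RUSpam/spamNS_v1": "RUSpam/spamNS_v1"
}

# Load every tokenizer/model pair once at startup, so switching models in the UI
# does not reload weights on each request.
tokenizers = {}
model_instances = {}
for name, path in models.items():
    tokenizers[name] = AutoTokenizer.from_pretrained(path)
    model_instances[name] = AutoModelForSequenceClassification.from_pretrained(path)

def predict_spam(text, model_choice):
    tokenizer = tokenizers[model_choice]
    model = model_instances[model_choice]

    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
    with torch.no_grad():
        outputs = model(**inputs)

    # Assumption: a binary spam classifier with two logits; gr.Label accepts a
    # dict of label -> confidence. The repository's actual post-processing may differ.
    probs = torch.softmax(outputs.logits, dim=-1)[0]
    result = {"spam": probs[1].item(), "not spam": probs[0].item()}
    return result

iface = gr.Interface(
    fn=predict_spam,
    inputs=[
        gr.Textbox(lines=5, label="Введите текст"),
        gr.Radio(choices=list(models.keys()), label="Выберите модель",
                 value="RUSpam/spam_deberta_v4")
    ],
    outputs=[
        gr.Label(label="Результат")
    ]
)

if __name__ == "__main__":
    iface.launch()

Preloading both models trades a longer startup and higher memory use for instant switching between them via the Radio selector.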