NeuroSpaceX commited on
Commit
81ceb92
·
verified ·
1 Parent(s): 5f9bfe0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -29,14 +29,18 @@ def predict_spam_deberta(text):
29
  model = model_instances["RUSpam/spam_deberta_v4"]
30
 
31
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
 
 
 
32
  with torch.no_grad():
33
- outputs = model(**inputs)
34
- logits = outputs.logits
35
- predicted_class = torch.argmax(logits, dim=1).item()
36
-
37
- result = "Спам" if predicted_class == 1 else "Не спам"
38
  return result
39
 
 
40
  def predict_spam_spamns(text):
41
  tokenizer = tokenizers["RUSpam/spamNS_v1"]
42
  model = model_instances["RUSpam/spamNS_v1"]
 
29
  model = model_instances["RUSpam/spam_deberta_v4"]
30
 
31
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
32
+ input_ids = inputs['input_ids'].to(device)
33
+ attention_mask = inputs['attention_mask'].to(device)
34
+
35
  with torch.no_grad():
36
+ outputs = model(input_ids, attention_mask=attention_mask).logits
37
+ pred = torch.sigmoid(outputs).cpu().numpy()[0][0]
38
+
39
+ is_spam = int(pred >= 0.5)
40
+ result = "Спам" if is_spam == 1 else "Не спам"
41
  return result
42
 
43
+
44
  def predict_spam_spamns(text):
45
  tokenizer = tokenizers["RUSpam/spamNS_v1"]
46
  model = model_instances["RUSpam/spamNS_v1"]