NeuroSpaceX commited on
Commit
0950099
·
verified ·
1 Parent(s): 8a098b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -11,12 +11,11 @@ models = {
11
  tokenizers = {}
12
  model_instances = {}
13
 
 
 
14
  for name, path in models.items():
15
  tokenizers[name] = AutoTokenizer.from_pretrained(path)
16
- model_instances[name] = AutoModelForSequenceClassification.from_pretrained(path)
17
-
18
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
- model_instances["RUSpam/spamNS_v1"] = model_instances["RUSpam/spamNS_v1"].to(device).eval()
20
 
21
  def clean_text(text):
22
  text = re.sub(r'http\S+', '', text)
@@ -29,6 +28,8 @@ def predict_spam_deberta(text):
29
  model = model_instances["RUSpam/spam_deberta_v4"]
30
 
31
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
 
 
32
  with torch.no_grad():
33
  outputs = model(**inputs)
34
  logits = outputs.logits
 
11
  tokenizers = {}
12
  model_instances = {}
13
 
14
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
+
16
  for name, path in models.items():
17
  tokenizers[name] = AutoTokenizer.from_pretrained(path)
18
+ model_instances[name] = AutoModelForSequenceClassification.from_pretrained(path).to(device).eval()
 
 
 
19
 
20
  def clean_text(text):
21
  text = re.sub(r'http\S+', '', text)
 
28
  model = model_instances["RUSpam/spam_deberta_v4"]
29
 
30
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
31
+ inputs = {key: val.to(device) for key, val in inputs.items()} # Move inputs to the device
32
+
33
  with torch.no_grad():
34
  outputs = model(**inputs)
35
  logits = outputs.logits