kadabengaran commited on
Commit
4e2ac67
·
1 Parent(s): 6657904
Files changed (1) hide show
  1. app/main.py +4 -0
app/main.py CHANGED
@@ -85,6 +85,10 @@ def predict_single(text, model, tokenizer, device):
85
  return predictions.item()
86
 
87
  def predict_multiple(data, model, tokenizer, device):
 
 
 
 
88
  input_ids = []
89
  attention_masks = []
90
  for row in data.tolist():
 
85
  return predictions.item()
86
 
87
  def predict_multiple(data, model, tokenizer, device):
88
+
89
+ if device.type == 'cuda':
90
+ model.cuda()
91
+
92
  input_ids = []
93
  attention_masks = []
94
  for row in data.tolist():