ShAnSantosh commited on
Commit
4d84659
·
1 Parent(s): 59c16a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -67,9 +67,9 @@ class ToxicModel(torch.nn.Module):
67
  def inference_fn(model, input_ids=None, attention_mask=None, token_type_ids=None):
68
  model.eval()
69
  print(input_ids, attention_mask, token_type_ids)
70
- input_ids = input_ids.to(device)
71
- attention_mask = attention_mask.to(device)
72
- token_type_ids = token_type_ids.to(device)
73
 
74
  with torch.no_grad():
75
  output = model(input_ids.unsqueeze(0), attention_mask.unsqueeze(0), token_type_ids.unsqueeze(0))
 
67
  def inference_fn(model, input_ids=None, attention_mask=None, token_type_ids=None):
68
  model.eval()
69
  print(input_ids, attention_mask, token_type_ids)
70
+ input_ids = input_ids[0].to(device)
71
+ attention_mask = attention_mask[0].to(device)
72
+ token_type_ids = token_type_ids[0].to(device)
73
 
74
  with torch.no_grad():
75
  output = model(input_ids.unsqueeze(0), attention_mask.unsqueeze(0), token_type_ids.unsqueeze(0))