ShAnSantosh commited on
Commit
fe0532d
·
1 Parent(s): eb844ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -0
app.py CHANGED
@@ -65,6 +65,7 @@ class ToxicModel(torch.nn.Module):
65
 
66
  def inference_fn(model, input_ids=None, attention_mask=None, token_type_ids=None):
67
  model.eval()
 
68
  input_ids = input_ids.to(device)
69
  attention_mask = attention_mask.to(device)
70
  token_type_ids = token_type_ids.to(device)
@@ -91,6 +92,10 @@ def predict(comment=None) -> dict:
91
  mask = inputs['attention_mask']
92
  token_type_ids = inputs["token_type_ids"]
93
 
 
 
 
 
94
  model = ToxicModel()
95
 
96
  model.load_state_dict(torch.load("toxicx_model_0.pth", map_location=torch.device(device)))
 
65
 
66
  def inference_fn(model, input_ids=None, attention_mask=None, token_type_ids=None):
67
  model.eval()
68
+
69
  input_ids = input_ids.to(device)
70
  attention_mask = attention_mask.to(device)
71
  token_type_ids = token_type_ids.to(device)
 
92
  mask = inputs['attention_mask']
93
  token_type_ids = inputs["token_type_ids"]
94
 
95
+ ids = torch.tensor(ids, dtype=torch.long),
96
+ mask = torch.tensor(mask, dtype=torch.long),
97
+ token_type_ids = torch.tensor(token_type_ids, dtype=torch.long),
98
+
99
  model = ToxicModel()
100
 
101
  model.load_state_dict(torch.load("toxicx_model_0.pth", map_location=torch.device(device)))