Spaces:
Runtime error
Runtime error
Commit
·
4d84659
1
Parent(s):
59c16a6
Update app.py
Browse files
app.py
CHANGED
@@ -67,9 +67,9 @@ class ToxicModel(torch.nn.Module):
|
|
67 |
def inference_fn(model, input_ids=None, attention_mask=None, token_type_ids=None):
|
68 |
model.eval()
|
69 |
print(input_ids, attention_mask, token_type_ids)
|
70 |
-
input_ids = input_ids.to(device)
|
71 |
-
attention_mask = attention_mask.to(device)
|
72 |
-
token_type_ids = token_type_ids.to(device)
|
73 |
|
74 |
with torch.no_grad():
|
75 |
output = model(input_ids.unsqueeze(0), attention_mask.unsqueeze(0), token_type_ids.unsqueeze(0))
|
|
|
67 |
def inference_fn(model, input_ids=None, attention_mask=None, token_type_ids=None):
|
68 |
model.eval()
|
69 |
print(input_ids, attention_mask, token_type_ids)
|
70 |
+
input_ids = input_ids[0].to(device)
|
71 |
+
attention_mask = attention_mask[0].to(device)
|
72 |
+
token_type_ids = token_type_ids[0].to(device)
|
73 |
|
74 |
with torch.no_grad():
|
75 |
output = model(input_ids.unsqueeze(0), attention_mask.unsqueeze(0), token_type_ids.unsqueeze(0))
|