Spaces:
Runtime error
Runtime error
Commit
·
fe0532d
1
Parent(s):
eb844ee
Update app.py
Browse files
app.py
CHANGED
@@ -65,6 +65,7 @@ class ToxicModel(torch.nn.Module):
|
|
65 |
|
66 |
def inference_fn(model, input_ids=None, attention_mask=None, token_type_ids=None):
|
67 |
model.eval()
|
|
|
68 |
input_ids = input_ids.to(device)
|
69 |
attention_mask = attention_mask.to(device)
|
70 |
token_type_ids = token_type_ids.to(device)
|
@@ -91,6 +92,10 @@ def predict(comment=None) -> dict:
|
|
91 |
mask = inputs['attention_mask']
|
92 |
token_type_ids = inputs["token_type_ids"]
|
93 |
|
|
|
|
|
|
|
|
|
94 |
model = ToxicModel()
|
95 |
|
96 |
model.load_state_dict(torch.load("toxicx_model_0.pth", map_location=torch.device(device)))
|
|
|
65 |
|
66 |
def inference_fn(model, input_ids=None, attention_mask=None, token_type_ids=None):
|
67 |
model.eval()
|
68 |
+
|
69 |
input_ids = input_ids.to(device)
|
70 |
attention_mask = attention_mask.to(device)
|
71 |
token_type_ids = token_type_ids.to(device)
|
|
|
92 |
mask = inputs['attention_mask']
|
93 |
token_type_ids = inputs["token_type_ids"]
|
94 |
|
95 |
+
ids = torch.tensor(ids, dtype=torch.long),
|
96 |
+
mask = torch.tensor(mask, dtype=torch.long),
|
97 |
+
token_type_ids = torch.tensor(token_type_ids, dtype=torch.long),
|
98 |
+
|
99 |
model = ToxicModel()
|
100 |
|
101 |
model.load_state_dict(torch.load("toxicx_model_0.pth", map_location=torch.device(device)))
|