Commit 72a0576 by ShAnSantosh · Parent(s): eab1088

Update app.py

Files changed (1): app.py (+3 -5)
app.py CHANGED
@@ -49,8 +49,7 @@ class ToxicModel(torch.nn.Module):
  self.dropout3 = nn.Dropout(0.3)
  self.dropout4 = nn.Dropout(0.4)
  self.dropout5 = nn.Dropout(0.5)
- self.output = nn.Linear(config.hidden_size, NUM_CLASSES)
-
+ self.output = nn.Linear(config.hidden_size, NUM_CLASSES)

  def forward(self, input_ids, attention_mask, token_type_ids):
  transformer_out = self.transformer(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
@@ -64,7 +63,7 @@ class ToxicModel(torch.nn.Module):
  logits = (logits1 + logits2 + logits3 + logits4 + logits5) / 5
  return logits

- def inference_fn(model, input_ids=None, attention_mask=None, token_type_ids=None):
+ def inference_fn(model, input_ids=None, attention_mask=None, token_type_ids=None):
  model.eval()
  input_ids = input_ids.to(device)
  attention_mask = attention_mask.to(device)
@@ -75,8 +74,7 @@ class ToxicModel(torch.nn.Module):
  out = output.sigmoid().detach().cpu().numpy().flatten()

  return out
-
-
+
  def predict(comment=None) -> dict:
  text = str(comment)
  text = " ".join(text.split())