APJ23 commited on
Commit
4f1a3ad
·
1 Parent(s): 70cd4ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -21,7 +21,7 @@ classes = {
21
  6: 'Identity Hate'
22
  }
23
  @st.cache(allow_output_mutation=True)
24
- def prediction(tweet,model,tokenizer):
25
  inputs = tokenizer(tweet, return_tensors="pt", padding=True, truncation=True)
26
  outputs = model(**inputs)
27
  predicted_class = torch.argmax(outputs.logits, dim=1)
@@ -40,12 +40,12 @@ def create_table(predictions):
40
  st.title('Toxicity Prediction App')
41
  tweet=st.text_input('Enter a tweet to check for toxicity')
42
  async def run_async_function():
43
- result = await prediction(tweet, model, tokenizer)
44
  return result
45
  if st.button('Predict'):
46
- predicted_class_label, predicted_prob = predict_toxicity(tweet, model, tokenizer)
47
  prediction_text = f'Prediction: {predicted_class_label} ({predicted_prob:.2f})'
48
  st.write(prediction_text)
49
- predictions = {tweet_input: (predicted_class_label, predicted_prob)}
50
  table = create_table(predictions)
51
  st.table(table)
 
21
  6: 'Identity Hate'
22
  }
23
  @st.cache(allow_output_mutation=True)
24
+ async def async_prediction(tweet,model,tokenizer):
25
  inputs = tokenizer(tweet, return_tensors="pt", padding=True, truncation=True)
26
  outputs = model(**inputs)
27
  predicted_class = torch.argmax(outputs.logits, dim=1)
 
40
  st.title('Toxicity Prediction App')
41
  tweet=st.text_input('Enter a tweet to check for toxicity')
42
  async def run_async_function():
43
+ result = await async_prediction(tweet, model, tokenizer)
44
  return result
45
  if st.button('Predict'):
46
+ predicted_class_label, predicted_prob = asyncio.run(run_async_function())
47
  prediction_text = f'Prediction: {predicted_class_label} ({predicted_prob:.2f})'
48
  st.write(prediction_text)
49
+ predictions = {tweet: (predicted_class_label, predicted_prob)}
50
  table = create_table(predictions)
51
  st.table(table)