Sephfox commited on
Commit
ddc2af9
·
verified ·
1 Parent(s): e153d74

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -156,7 +156,8 @@ def predict_emotion(context):
156
  emotion_prediction_pipeline = pipeline('text-classification', model=emotion_prediction_model, tokenizer=emotion_prediction_tokenizer, top_k=None)
157
  predictions = emotion_prediction_pipeline(context)
158
  emotion_scores = predictions[0]
159
- emotion_pred = emotion_classes[np.argmax(emotion_scores)]
 
160
  return emotion_pred
161
 
162
  def generate_text(prompt, max_length=100, emotion=None):
@@ -202,4 +203,4 @@ with gr.Blocks() as demo:
202
 
203
  predict_btn.click(fn=lambda context: (predict_emotion(context), generate_response(context, emotion=predict_emotion(context))), inputs=context_input, outputs=[emotion_output, generated_text_output])
204
 
205
- demo.launch(share=True)
 
156
  emotion_prediction_pipeline = pipeline('text-classification', model=emotion_prediction_model, tokenizer=emotion_prediction_tokenizer, top_k=None)
157
  predictions = emotion_prediction_pipeline(context)
158
  emotion_scores = predictions[0]
159
+ sorted_scores = sorted(emotion_scores.items(), key=lambda x: x[1], reverse=True)
160
+ emotion_pred = sorted_scores[0][0]
161
  return emotion_pred
162
 
163
  def generate_text(prompt, max_length=100, emotion=None):
 
203
 
204
  predict_btn.click(fn=lambda context: (predict_emotion(context), generate_response(context, emotion=predict_emotion(context))), inputs=context_input, outputs=[emotion_output, generated_text_output])
205
 
206
+ demo.launch()