ethanp55 committed on
Commit
e9b0055
·
verified ·
1 Parent(s): da3874c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -0
app.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from classes import classes
2
+ import numpy as np
3
+ from sentence_transformers import SentenceTransformer, util
4
+ import streamlit as st
5
+
6
+
7
# Simple sentence transformer
model_checkpoint = 'sentence-transformers/paraphrase-distilroberta-base-v1'


@st.cache_resource
def _load_model(checkpoint: str) -> SentenceTransformer:
    """Instantiate the sentence transformer for ``checkpoint``.

    Streamlit re-executes this script on every widget interaction; caching the
    model keeps it from being re-instantiated on each rerun.
    """
    return SentenceTransformer(checkpoint)


model = _load_model(model_checkpoint)

# Predefined messages and their embeddings
classes_text = np.array(classes)


@st.cache_resource
def _embed_classes() -> np.ndarray:
    """Encode the predefined messages once and reuse the result across reruns."""
    return model.encode(classes_text, convert_to_numpy=True)


classes_embeddings = _embed_classes()
# Sanity check: one embedding row per predefined message
assert classes_embeddings.shape[0] == len(classes)
15
+
16
+ # Function to compare the embedding of the human chat/text message with the embeddings of the
17
+ # predefined messages
18
def convert(sentence_embedding: np.ndarray, class_embeddings: np.ndarray, top_n=5) -> np.ndarray:
    """Return the indices of the class embeddings most similar to a sentence.

    Args:
        sentence_embedding: Embedding of the human chat/text message.
        class_embeddings: Embeddings of the predefined messages.
        top_n: Maximum number of indices to return.

    Returns:
        Indices into ``class_embeddings`` ordered from highest to lowest
        cosine similarity, at most ``top_n`` of them.
    """
    # Flatten the cosine-similarity scores into a 1-D array, one score per class
    similarities = np.array(util.cos_sim(sentence_embedding, class_embeddings)).reshape(-1)
    # argsort is ascending, so reverse it before taking the leading top_n entries
    top_n_indices = np.argsort(similarities)[::-1][:top_n]

    return top_n_indices
23
+
24
# Simple title and description for the app
st.title('JHG Chat Message Converter')
st.write('Converts human chat/text messages into predefined chat messages via a sentence transformer')

# Text box to enter a chat/text message
text = st.text_area('Enter chat message')

# Skip empty and whitespace-only input: there is nothing meaningful to match
if text.strip():
    # Use the sentence transformer and "convert" function to display predicted, predefined messages
    text_embedding = model.encode(text, convert_to_numpy=True)
    indices = convert(text_embedding, classes_embeddings)
    predicted_classes = classes_text[indices]

    # Show the closest predefined messages, best match first
    for converted_message in predicted_classes:
        st.write(converted_message)