awacke1 committed on
Commit
e88272a
·
verified ·
1 Parent(s): c76e6a0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Hugging Face's logo
2
+ Hugging Face
3
+ Search models, datasets, users...
4
+ Models
5
+ Datasets
6
+ Spaces
7
+ Posts
8
+ Docs
9
+ Pricing
10
+
11
+
12
+
13
+ Spaces:
14
+
15
+ awacke1
16
+ /
17
+ NLPSentenceSimilarityHeatmap
18
+
19
+ like
20
+ 3
21
+
22
+ App
23
+ Files
24
+ Community
25
+ Settings
26
+ NLPSentenceSimilarityHeatmap
27
+ /
28
+ app.py
29
+ awacke1's picture
30
+ awacke1
31
+ Update app.py
32
+ c4d6857
33
+ 12 months ago
34
+ raw
35
+ history
36
+ blame
37
+ edit
38
+ delete
39
+ No virus
40
+ 3.06 kB
41
+ import streamlit as st
42
+ import nltk
43
+ from transformers import pipeline
44
+ from sentence_transformers import SentenceTransformer
45
+ from scipy.spatial.distance import cosine
46
+ import numpy as np
47
+ import seaborn as sns
48
+ import matplotlib.pyplot as plt
49
+ from sklearn.cluster import KMeans
50
+ import tensorflow as tf
51
+ import tensorflow_hub as hub
52
+
53
+
54
def cluster_examples(messages, embed, nc=3):
    """Group messages by the K-Means cluster of their embeddings.

    Args:
        messages: Sequence of items to group; ``messages[i]`` is assumed to
            correspond to row ``i`` of ``embed``.
        embed: Embedding rows handed to ``KMeans.fit_predict``.
        nc: Number of clusters requested.

    Returns:
        A list of ``nc`` lists. Each inner list holds the messages assigned
        to that cluster label, in their original order. When fewer messages
        than ``nc`` are supplied, the trailing lists are empty.
    """
    n_samples = len(messages)
    if n_samples == 0 or nc < 1:
        # Nothing to fit; keep the "one list per requested cluster" shape.
        return [[] for _ in range(nc)]

    # KMeans rejects n_clusters > n_samples, so never ask for more clusters
    # than there are messages.
    n_clusters = min(nc, n_samples)
    km = KMeans(
        n_clusters=n_clusters, init='random',
        n_init=10, max_iter=300,
        tol=1e-04, random_state=0
    )
    assignments = km.fit_predict(embed)

    # Single pass over the assignments instead of one index scan per cluster.
    cluster_list = [[] for _ in range(nc)]
    for message, label in zip(messages, assignments):
        cluster_list[label].append(message)
    return cluster_list
67
+
68
+
69
def plot_heatmap(labels, heatmap, rotation=90):
    """Draw a labelled similarity heatmap and render it in the Streamlit page.

    Args:
        labels: Tick labels used for both the x and y axes.
        heatmap: 2-D matrix of values to plot; the colour scale is fixed to
            the range [-1, 1].
        rotation: Rotation, in degrees, applied to the x tick labels.
    """
    sns.set(font_scale=1.2)
    fig, ax = plt.subplots()
    # Draw on the axes created above explicitly rather than relying on
    # pyplot's implicit "current axes".
    sns.heatmap(
        heatmap,
        xticklabels=labels,
        yticklabels=labels,
        vmin=-1,
        vmax=1,
        cmap="coolwarm",
        ax=ax,
    )
    ax.set_xticklabels(labels, rotation=rotation)
    ax.set_title("Textual Similarity")

    try:
        st.pyplot(fig)
    finally:
        # pyplot keeps a reference to every figure it creates; close it so
        # figures do not pile up across Streamlit reruns.
        plt.close(fig)
83
+
84
# Streamlit app setup: page metadata plus the sidebar input widgets.
st.set_page_config(page_title="Sentence Similarity Demo")

sidebar = st.sidebar
sidebar.title("Sentence Similarity Demo")

# Default example text, one sentence per line.
_DEFAULT_SENTENCES = (
    "Self confidence in outcomes helps us win and to make us successful.",
    "She has a seriously impressive intellect and mind.",
    "Stimulating and deep conversation helps us develop and grow.",
    "From basic quantum particles we get aerodynamics, friction, surface tension, weather, electromagnetism.",
    "If she actively engages and comments positively, her anger disappears adapting into win-win's favor.",
    "I love interesting topics of conversation and the understanding and exploration of thoughts.",
    "There is the ability to manipulate things the way you want in your mind to go how you want when you are self confident, that we don’t understand yet.",
)

text = sidebar.text_area(
    'Enter sentences:',
    value="\n".join(_DEFAULT_SENTENCES),
)

nc = sidebar.slider(
    'Select a number of clusters:',
    min_value=1,
    max_value=15,
    value=3,
)

model_type = sidebar.radio(
    "Choose model:",
    ('Sentence Transformer', 'Universal Sentence Encoder'),
    index=0,
)
94
+
95
# Model setup
def _load_model(kind):
    """Load the embedding model selected in the sidebar.

    Args:
        kind: Either "Sentence Transformer" or "Universal Sentence Encoder".

    Returns:
        The loaded model object for the requested kind.

    Raises:
        ValueError: If ``kind`` is not one of the two supported names.
    """
    if kind == "Sentence Transformer":
        return SentenceTransformer('paraphrase-distilroberta-base-v1')
    if kind == "Universal Sentence Encoder":
        model_url = "https://tfhub.dev/google/universal-sentence-encoder-large/5"
        return hub.load(model_url)
    raise ValueError(f"Unknown model type: {kind!r}")


# Streamlit re-executes this script on every widget interaction; cache the
# loaded model so it is not reloaded each time. Older Streamlit releases
# without st.cache_resource fall back to uncached loading.
if hasattr(st, "cache_resource"):
    _load_model = st.cache_resource(_load_model)

model = _load_model(model_type)

nltk.download('punkt')
# Newer NLTK releases read sentence-tokenizer data from 'punkt_tab' instead of
# 'punkt'; fetch both so sent_tokenize works on either.
nltk.download('punkt_tab')
103
+
104
# Run model: embed each sentence, show pairwise cosine similarity, then cluster.
if text:
    sentences = nltk.tokenize.sent_tokenize(text)
    # Non-empty text can still yield no sentences (e.g. whitespace only);
    # there is nothing to embed or plot in that case.
    if sentences:
        if model_type == "Sentence Transformer":
            embed = model.encode(sentences)
        elif model_type == "Universal Sentence Encoder":
            embed = model(sentences).numpy()
        embed = np.asarray(embed)

        # Pairwise cosine similarity for all sentence pairs at once:
        # normalise each row to unit length, then take dot products.
        norms = np.linalg.norm(embed, axis=1, keepdims=True)
        unit = embed / norms
        sim = unit @ unit.T

        st.subheader("Similarity Heatmap")
        plot_heatmap(sentences, sim)

        # K-Means cannot produce more clusters than there are sentences.
        n_clusters = min(nc, len(sentences))
        if n_clusters < nc:
            st.warning(
                f"Only {len(sentences)} sentence(s) available; "
                f"using {n_clusters} cluster(s) instead of {nc}."
            )
        cluster_list = cluster_examples(sentences, embed, n_clusters)
        st.subheader("Results from K-Means Clustering")
        cluster_table = st.table(cluster_list)
120
+