Update app.py
Browse files
app.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import streamlit as st
|
2 |
import torch
|
3 |
import torch.nn as nn
|
@@ -5,7 +6,6 @@ import nltk
|
|
5 |
from nltk.corpus import stopwords
|
6 |
import pandas as pd
|
7 |
import base64
|
8 |
-
import random
|
9 |
|
10 |
# Ensure NLTK resources are downloaded
|
11 |
nltk.download('punkt')
|
@@ -21,19 +21,6 @@ def text_convolution(input_text, kernel_size=3):
|
|
21 |
output = conv_layer(tensor_input)
|
22 |
return output, words
|
23 |
|
24 |
-
|
25 |
-
# Function to color the bars based on whether they appear together or not
|
26 |
-
def color_bars(words):
|
27 |
-
color_map = {}
|
28 |
-
color_index = 0
|
29 |
-
for word in words:
|
30 |
-
if word not in color_map:
|
31 |
-
color_map[word] = color_index
|
32 |
-
color_index += 1
|
33 |
-
colors = [f"#{random.randint(0, 0xFFFFFF):06x}" for _ in range(len(color_map))]
|
34 |
-
return [colors[color_map[word]] for word in words]
|
35 |
-
|
36 |
-
|
37 |
# Streamlit UI
|
38 |
def main():
|
39 |
st.title("Text Convolution Demonstration")
|
@@ -49,9 +36,7 @@ def main():
|
|
49 |
|
50 |
# Visualization
|
51 |
word_counts = pd.Series(words).value_counts()
|
52 |
-
word_counts
|
53 |
-
colors = color_bars(word_counts.index)
|
54 |
-
st.bar_chart(word_counts.head(20), color=colors)
|
55 |
|
56 |
# Saving user prompts
|
57 |
user_file_name = f"{user_email}_prompts.txt"
|
@@ -67,4 +52,3 @@ def main():
|
|
67 |
|
68 |
if __name__ == "__main__":
|
69 |
main()
|
70 |
-
|
|
|
1 |
+
|
2 |
import streamlit as st
|
3 |
import torch
|
4 |
import torch.nn as nn
|
|
|
6 |
from nltk.corpus import stopwords
|
7 |
import pandas as pd
|
8 |
import base64
|
|
|
9 |
|
10 |
# Ensure NLTK resources are downloaded
|
11 |
nltk.download('punkt')
|
|
|
21 |
output = conv_layer(tensor_input)
|
22 |
return output, words
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
# Streamlit UI
|
25 |
def main():
|
26 |
st.title("Text Convolution Demonstration")
|
|
|
36 |
|
37 |
# Visualization
|
38 |
word_counts = pd.Series(words).value_counts()
|
39 |
+
st.bar_chart(word_counts.head(20))
|
|
|
|
|
40 |
|
41 |
# Saving user prompts
|
42 |
user_file_name = f"{user_email}_prompts.txt"
|
|
|
52 |
|
53 |
if __name__ == "__main__":
|
54 |
main()
|
|