# Hugging Face file-viewer metadata (not part of the program):
# author awacke1 · commit 367bf39 "Update app.py" (verified) · 2.56 kB
import base64
import random
import re

import nltk
import pandas as pd
import streamlit as st
import torch
import torch.nn as nn
from nltk.corpus import stopwords
def _ensure_nltk_resources():
    """Download the NLTK data this app needs, skipping anything already installed.

    ``nltk.download`` fetches the remote NLTK index even when a package is
    already present, and Streamlit re-executes this module on every rerun, so
    the local data path is checked first and the network is only touched when
    something is actually missing.

    ``punkt_tab`` is the tokenizer data that newer NLTK releases load for
    ``word_tokenize``; ``punkt`` is still fetched for older releases.
    """
    for package, resource_path in (
        ("punkt", "tokenizers/punkt"),
        ("punkt_tab", "tokenizers/punkt_tab"),
        ("stopwords", "corpora/stopwords"),
    ):
        try:
            nltk.data.find(resource_path)
        except LookupError:
            nltk.download(package, quiet=True)


# Ensure NLTK resources are downloaded
_ensure_nltk_resources()
def text_convolution(input_text, kernel_size=3):
    """Tokenize *input_text* and run a single 1-D convolution over the tokens.

    English stop words are removed (case-insensitively). Each remaining token
    is mapped to a number with Python's built-in ``hash``. The sequence is then
    passed through a freshly created ``nn.Conv1d(1, 1, kernel_size)``.

    Note: the convolution weights are randomly initialised on every call, and
    ``hash(str)`` is salted per process (PYTHONHASHSEED). The numeric output is
    therefore illustrative only and not reproducible across calls or runs.

    Args:
        input_text: Raw text to tokenize.
        kernel_size: Width of the convolution window (default 3).

    Returns:
        A tuple ``(output, words)``:

        - ``output`` is a tensor of shape ``(1, 1, len(words) - kernel_size + 1)``,
          or an empty ``(1, 1, 0)`` tensor when there are fewer than
          ``kernel_size`` tokens.
        - ``words`` is the list of tokens kept after stop-word removal.
    """
    # Build the stop-word set once. stopwords.words() re-reads the corpus on
    # every call, and a list scan per token made filtering O(tokens * stopwords).
    stop_words = set(stopwords.words("english"))
    # NLTK's stop-word list is lower-case; compare case-insensitively so that
    # sentence-initial "The"/"It" are filtered too (original casing is kept).
    words = [
        word
        for word in nltk.word_tokenize(input_text)
        if word.lower() not in stop_words
    ]

    if len(words) < kernel_size:
        # Conv1d raises when the input is shorter than the kernel. Return an
        # empty result so short or empty uploads do not crash the caller.
        return torch.empty(1, 1, 0), words

    tensor_input = torch.tensor([hash(word) for word in words], dtype=torch.float)
    conv_layer = nn.Conv1d(1, 1, kernel_size, stride=1)
    # Display-only forward pass: no autograd graph is needed.
    with torch.no_grad():
        output = conv_layer(tensor_input.view(1, 1, -1))
    return output, words
def color_bars(words):
    """Return one random ``"#rrggbb"`` colour per entry in *words*.

    Each distinct word is assigned a colour drawn with ``random.randint`` the
    first time it is seen. Repeated words reuse that colour, so the result
    lines up index-for-index with *words*. Distinct words are not guaranteed
    distinct colours; collisions are merely unlikely.

    Args:
        words: Iterable of hashable labels (e.g. a pandas Index of tokens).
            It is consumed once, so generators are supported.

    Returns:
        A list of hex colour strings with the same length as *words*.
    """
    # Materialise once: the input is needed for both the assignment pass and
    # the output pass, and a generator would be exhausted after the first.
    labels = list(words)
    palette = {}
    for word in labels:
        if word not in palette:
            palette[word] = f"#{random.randint(0, 0xFFFFFF):06x}"
    return [palette[word] for word in labels]
def _safe_file_stem(user_email):
    """Reduce *user_email* to characters that are safe in a file name and in HTML.

    Anything outside ``[A-Za-z0-9._@+-]`` is replaced with ``_``, notably
    ``/``, ``\\``, quotes and angle brackets. The address therefore cannot
    escape the working directory or inject markup into the download link.
    Ordinary email addresses pass through unchanged.
    """
    return re.sub(r"[^A-Za-z0-9._@+-]", "_", user_email.strip())


def _append_prompt(file_name, text_data):
    """Append *text_data* and a newline to *file_name* as UTF-8, creating it if needed."""
    # Explicit encoding: the upload was decoded as UTF-8, and the platform
    # default encoding may not be able to represent it.
    with open(file_name, "a", encoding="utf-8") as file:
        file.write(text_data + "\n")


def _download_link(file_name):
    """Return an HTML anchor that downloads *file_name* through a base64 data URI.

    *file_name* is embedded verbatim in the HTML, so it must already be
    sanitised (see ``_safe_file_stem``).
    """
    with open(file_name, "rb") as f:
        b64 = base64.b64encode(f.read()).decode()
    return f'<a href="data:text/plain;base64,{b64}" download="{file_name}">Download {file_name}</a>'


# Streamlit UI
def main():
    """Render the app.

    Once both a text file and an email are provided, the app shows the
    convolution result and a word-frequency chart. It then appends the text
    to ``<email>_prompts.txt`` and offers that file for download.
    """
    st.title("Text Convolution Demonstration")
    st.write("This app demonstrates how text convolution works. Upload a text file and see the convolution result along with a distribution plot of word tokens.")
    uploaded_file = st.file_uploader("Choose a text file (TXT only)", type=["txt"])
    user_email = st.text_input("Enter your email to save your prompts:")
    # Nothing to do until both inputs are provided.
    if uploaded_file is None or not user_email:
        return

    try:
        text_data = uploaded_file.read().decode("utf-8")
    except UnicodeDecodeError:
        st.error("The uploaded file is not valid UTF-8 text.")
        return

    conv_result, words = text_convolution(text_data)
    st.write("Convolution result:", conv_result)

    # Visualization: value_counts() already sorts by descending frequency.
    top_counts = pd.Series(words).value_counts().head(20)
    # A list passed as st.bar_chart(color=...) means one colour per *series*.
    # This chart has a single series, so a list of many colours is rejected.
    # Use long format with a colour column instead; Streamlit uses hex values
    # in that column directly, giving each bar its own colour.
    chart_data = pd.DataFrame(
        {
            "word": top_counts.index,
            "count": top_counts.to_numpy(),
            "color": color_bars(top_counts.index),
        }
    )
    st.bar_chart(chart_data, x="word", y="count", color="color")

    # Saving user prompts. The email is sanitised before it becomes a path.
    # NOTE(review): if the script reruns with the same inputs still set, the
    # same text is appended again — confirm whether de-duplication is wanted.
    user_file_name = f"{_safe_file_stem(user_email)}_prompts.txt"
    _append_prompt(user_file_name, text_data)
    st.success(f"Your prompts have been added to {user_file_name}")

    # Download link for the file. The name is sanitised, so it is safe to
    # embed in HTML.
    st.markdown(_download_link(user_file_name), unsafe_allow_html=True)
if __name__ == "__main__":
    # Render the UI when this file is executed as a script; `streamlit run`
    # executes it with __name__ set to "__main__".
    main()