# Spaces: Runtime error
from html import escape

import streamlit as st
import streamlit.components.v1 as components
from bertopic import BERTopic
from hdbscan import HDBSCAN
from sentence_transformers import SentenceTransformer
from umap import UMAP
# Streamlit app: fit BERTopic on newline-separated documents, list the
# discovered topics with their keywords and a few member documents, and show
# a separate UMAP + HDBSCAN clustering of sentence embeddings.

# Initialize BERTopic model
model = BERTopic()


def _parse_documents(raw_text):
    """Return one document per non-blank line of *raw_text*, whitespace-trimmed.

    ``splitlines`` is used so Windows line endings do not leave a trailing
    carriage return on each document, and blank lines are dropped so empty
    strings are never handed to the models.
    """
    return [line.strip() for line in raw_text.splitlines() if line.strip()]


def _group_indices_by_topic(topic_assignments):
    """Map each topic id to the indices of the documents assigned to it.

    *topic_assignments* is the per-document topic list returned by
    ``BERTopic.fit_transform``; indices keep their original document order.
    """
    grouped = {}
    for doc_index, topic_id in enumerate(topic_assignments):
        grouped.setdefault(topic_id, []).append(doc_index)
    return grouped


def _cluster_table_html(documents, labels):
    """Build an HTML table pairing each document with its cluster label.

    Document text comes from the user, so it is escaped before being embedded
    in markup.
    """
    rows = "".join(
        f"<tr><td>{label}</td><td>{escape(document)}</td></tr>"
        for document, label in zip(documents, labels)
    )
    return (
        "<table border='1' cellpadding='4'>"
        "<tr><th>Cluster</th><th>Document</th></tr>"
        f"{rows}</table>"
    )


def _fit_models(documents):
    """Fit BERTopic and the separate UMAP + HDBSCAN clustering.

    Returns a ``(topics, cluster_labels)`` pair: the per-document topic ids
    from BERTopic and the per-document HDBSCAN labels as a plain list.
    """
    # Fit BERTopic model
    topics, _ = model.fit_transform(documents)

    # Create embeddings
    embeddings_model = SentenceTransformer("distilbert-base-nli-mean-tokens")
    embeddings = embeddings_model.encode(documents)

    # Reduce dimensionality of embeddings using UMAP
    umap_model = UMAP(n_neighbors=15, n_components=2, metric="cosine")
    umap_embeddings = umap_model.fit_transform(embeddings)

    # Cluster topics using HDBSCAN
    cluster = HDBSCAN(
        min_cluster_size=15, metric="euclidean", cluster_selection_method="eom"
    ).fit(umap_embeddings)
    return topics, cluster.labels_.tolist()


def _render_results(documents, topics, cluster_labels, max_topics):
    """Write the topic listing and the cluster table to the page."""
    # Visualize BERTopic results with Streamlit
    st.title("BERTopic Visualization")

    doc_indices_by_topic = _group_indices_by_topic(topics)
    # get_topics() maps topic id -> [(word, score), ...]. Topic -1 is
    # BERTopic's outlier bucket rather than a real topic, so it is not listed.
    real_topics = [
        (topic_id, keywords)
        for topic_id, keywords in model.get_topics().items()
        if topic_id != -1
    ]
    if not real_topics:
        st.info("No topics were found; every document was treated as an outlier.")

    # Display top N topics and up to five of their documents
    for topic_id, keywords in real_topics[:max_topics]:
        st.write(f"## Topic {topic_id}")
        st.write("Keywords:", ", ".join(word for word, _score in keywords))
        st.write("Documents:")
        for doc_index in doc_indices_by_topic.get(topic_id, [])[:5]:
            st.write("-", documents[doc_index])

    # Display topic clusters
    st.write("## Topic Clusters")
    # components.html expects an HTML string; scrolling is enabled because the
    # table can be taller than the fixed 500px frame.
    components.html(
        _cluster_table_html(documents, cluster_labels),
        height=500,
        width=800,
        scrolling=True,
    )


st.subheader("Topic Modeling with Topic-Wizard")

uploaded_file = st.file_uploader("Choose a text file", type=["txt"])
if uploaded_file is not None:
    # Seeding session state under the text area's key, before that widget is
    # created, makes the uploaded content show up in the text area below.
    st.session_state["text"] = uploaded_file.getvalue().decode("utf-8")

st.write("OR")
input_text = st.text_area(
    label="Enter text separated by newlines",
    value="",
    key="text",
    height=150,
)

# The slider lives outside the button branch: st.button is only True on the
# rerun triggered by the click, so a widget created inside that branch would
# vanish as soon as the user touched it.
num_topics = st.sidebar.slider("Select number of topics to display", 1, 20, 5, 1)

button = st.button("Get Segments")
if button:
    # The text area already reflects an uploaded file (see the session-state
    # seeding above), so its value is the single source of documents.
    texts = _parse_documents(input_text)
    if not texts:
        st.warning("Please upload a file or enter at least one non-empty line of text.")
    else:
        try:
            topics, cluster_labels = _fit_models(texts)
        except (ValueError, TypeError) as err:
            # NOTE(review): very small inputs are expected to make the
            # UMAP/HDBSCAN stages fail; the exact exception types should be
            # confirmed against the installed library versions.
            st.error(f"Topic modeling failed (try providing more lines of text): {err}")
        else:
            _render_results(texts, topics, cluster_labels, num_topics)