Spaces:
Runtime error
Runtime error
import html

import streamlit as st
import streamlit.components.v1 as components
from bertopic import BERTopic
from hdbscan import HDBSCAN
from sentence_transformers import SentenceTransformer
from umap import UMAP

# Streamlit app: the user supplies newline-separated documents (uploaded
# .txt file or text area). On "Get Segments" the script fits BERTopic on the
# documents, separately embeds them with a SentenceTransformer, reduces the
# embeddings to 2-D with UMAP, clusters them with HDBSCAN, and then shows the
# most frequent topics plus the per-document HDBSCAN cluster labels.

# Initialize BERTopic model
model = BERTopic()

st.subheader("Topic Modeling with Topic-Wizard")

uploaded_file = st.file_uploader("Choose a text file", type=["txt"])
if uploaded_file is not None:
    # Seed the text area (key="text") with the file content. This has to run
    # before the widget is created. Undecodable bytes are replaced instead of
    # aborting the run with UnicodeDecodeError.
    st.session_state["text"] = uploaded_file.getvalue().decode(
        "utf-8", errors="replace"
    )

st.write("OR")
input_text = st.text_area(
    label="Enter text separated by newlines",
    value="",
    key="text",
    height=150,
)

# The slider is created on every run, before the button is evaluated, so the
# chosen value is already available when the button is clicked. Creating it
# inside the button branch made it appear only after the click, and moving it
# triggered a rerun in which the button was False and the results vanished.
num_topics = st.sidebar.slider("Select number of topics to display", 1, 20, 5, 1)

button = st.button("Get Segments")

if button and (uploaded_file is not None or input_text != ""):
    raw_text = st.session_state["text"] if uploaded_file is not None else input_text

    # One document per line; splitlines() also handles "\r\n" files, and blank
    # lines are dropped so they are not treated as documents.
    texts = [line.strip() for line in raw_text.splitlines() if line.strip()]
    if not texts:
        st.warning("No non-empty lines found in the input.")
        st.stop()

    # Fit BERTopic model
    topics, probabilities = model.fit_transform(texts)

    # Create embeddings
    embeddings_model = SentenceTransformer("distilbert-base-nli-mean-tokens")
    embeddings = embeddings_model.encode(texts)

    # Reduce dimensionality of embeddings using UMAP
    umap_model = UMAP(n_neighbors=15, n_components=2, metric="cosine")
    umap_embeddings = umap_model.fit_transform(embeddings)

    # Cluster topics using HDBSCAN
    cluster = HDBSCAN(
        min_cluster_size=15, metric="euclidean", cluster_selection_method="eom"
    ).fit(umap_embeddings)

    # Visualize BERTopic results with Streamlit
    st.title("BERTopic Visualization")

    # Display top N most representative topics and their documents
    topic_words = model.get_topics()
    topic_freq = model.get_topic_freq()
    # Drop the outlier topic (-1) before truncating, so exactly `num_topics`
    # real topics are shown whether or not an outlier topic exists.
    topic_freq = topic_freq[topic_freq["Topic"] != -1].head(num_topics)

    for _, row in topic_freq.iterrows():
        topic_id = int(row["Topic"])
        st.write(f"## Topic {topic_id}")
        # get_topics() maps each topic id to (word, score) pairs; only the
        # words are joined (joining the tuples raised TypeError).
        st.write("Keywords:", ", ".join(word for word, _ in topic_words[topic_id]))
        st.write("Documents:")
        # Show at most five documents assigned to this topic.
        doc_ids = [idx for idx, topic in enumerate(topics) if topic == topic_id][:5]
        for doc in doc_ids:
            st.write("-", texts[doc])

    # Display topic clusters
    st.write("## Topic Clusters")
    # components.html() expects an HTML string, not a list. Render the HDBSCAN
    # label of every document as a table; the document text is user input, so
    # it is escaped before being embedded in the markup.
    table_rows = "".join(
        f"<tr><td>{label}</td><td>{html.escape(doc)}</td></tr>"
        for label, doc in zip(cluster.labels_.tolist(), texts)
    )
    cluster_table = (
        "<table>"
        "<thead><tr><th>Cluster</th><th>Document</th></tr></thead>"
        f"<tbody>{table_rows}</tbody>"
        "</table>"
    )
    components.html(cluster_table, height=500, width=800, scrolling=True)