# Streamlit app: topic modeling of newline-separated documents with BERTopic.
import streamlit as st
from bertopic import BERTopic
import streamlit.components.v1 as components
from sentence_transformers import SentenceTransformer
from umap import UMAP
from hdbscan import HDBSCAN
# Initialize BERTopic model (fitted later, once the user submits documents)
model = BERTopic()
st.subheader("Topic Modeling with Topic-Wizard")
# Input source 1: an uploaded .txt file. Its decoded contents are written
# into session state under the "text" key, which pre-fills the text_area
# below (it shares the same key) on the rerun triggered by the upload.
uploaded_file = st.file_uploader("Choose a text file", type=["txt"])
if uploaded_file is not None:
    st.session_state["text"] = uploaded_file.getvalue().decode("utf-8")
st.write("OR")
# Input source 2: free-form text, one document per line. key="text" ties
# this widget to the session-state entry populated above, so the widget
# order here is deliberate — do not reorder these statements.
input_text = st.text_area(
    label="Enter text separated by newlines",
    value="",
    key="text",
    height=150,
)
button = st.button("Get Segments")
if button and (uploaded_file is not None or input_text != ""):
    # Prefer the uploaded file's contents over the text area; one document
    # per line in either case.
    if uploaded_file is not None:
        texts = st.session_state["text"].split("\n")
    else:
        texts = input_text.split("\n")
    # Fit BERTopic: assigns each document a topic id (-1 = outliers).
    topics, probabilities = model.fit_transform(texts)
    # Create sentence embeddings for the separate clustering view below.
    embeddings_model = SentenceTransformer("distilbert-base-nli-mean-tokens")
    embeddings = embeddings_model.encode(texts)
    # Reduce dimensionality of embeddings to 2-D using UMAP.
    umap_model = UMAP(n_neighbors=15, n_components=2, metric="cosine")
    umap_embeddings = umap_model.fit_transform(embeddings)
    # Cluster the reduced embeddings with HDBSCAN.
    cluster = HDBSCAN(
        min_cluster_size=15, metric="euclidean", cluster_selection_method="eom"
    ).fit(umap_embeddings)
    # Visualize BERTopic results with Streamlit.
    st.title("BERTopic Visualization")
    # NOTE(review): widgets created inside this button branch disappear on
    # the rerun caused by interacting with them; consider hoisting the
    # slider above the button if that becomes a problem.
    num_topics = st.sidebar.slider("Select number of topics to display", 1, 20, 5, 1)
    topic_words = model.get_topics()
    topic_freq = model.get_topic_freq().head(num_topics + 1)  # +1 leaves room for the -1 (outliers) row
    for _, row in topic_freq.iterrows():
        topic_id = row["Topic"]
        if topic_id == -1:
            continue  # Skip the outliers topic
        st.write(f"## Topic {topic_id}")
        # get_topics() maps topic_id -> list of (word, weight) pairs;
        # joining the raw pairs raises TypeError, so extract just the words.
        st.write("Keywords:", ", ".join(word for word, _ in topic_words[topic_id]))
        st.write("Documents:")
        # Show up to five representative documents assigned to this topic.
        doc_ids = [idx for idx, topic in enumerate(topics) if topic == topic_id][:5]
        for doc in doc_ids:
            st.write("-", texts[doc])
    # Display per-document HDBSCAN cluster labels. components.html expects
    # an HTML string, not a list, so render the labels with st.write.
    st.write("## Topic Clusters")
    st.write(cluster.labels_.tolist())