"""Streamlit app: fit a BERTopic model on user-supplied lines of text.

The page shows the most frequent topics (keywords plus sample documents) and
a separate UMAP + HDBSCAN clustering of sentence-transformer embeddings.
"""

import streamlit as st
from bertopic import BERTopic
import streamlit.components.v1 as components
from sentence_transformers import SentenceTransformer
from umap import UMAP
from hdbscan import HDBSCAN

# Maximum number of example documents listed under each topic.
DOCS_PER_TOPIC = 5
# Session-state key under which fitted results survive Streamlit reruns, so
# interacting with other widgets (e.g. the slider) does not discard them.
RESULTS_KEY = "topic_results"


def _split_documents(raw_text):
    """Return the non-blank lines of *raw_text* as a list of documents."""
    # splitlines() also handles "\r\n" endings, which split("\n") would leave
    # as a trailing "\r" on every document; blank lines are not documents.
    return [line for line in raw_text.splitlines() if line.strip()]


def _fit_models(texts):
    """Fit BERTopic and a standalone UMAP + HDBSCAN clustering on *texts*.

    Returns a plain dict with everything the rendering helpers need, so it
    can be kept in ``st.session_state`` and re-rendered without refitting.
    """
    model = BERTopic()
    topics, _probabilities = model.fit_transform(texts)

    # NOTE(review): this encodes the documents a second time with a different
    # sentence-transformer than BERTopic's default embedding model; consider
    # sharing one model/embedding pass if the two views are meant to agree.
    embeddings_model = SentenceTransformer("distilbert-base-nli-mean-tokens")
    embeddings = embeddings_model.encode(texts)

    # Reduce the embeddings to 2-D before density-based clustering.
    umap_model = UMAP(n_neighbors=15, n_components=2, metric="cosine")
    umap_embeddings = umap_model.fit_transform(embeddings)

    cluster = HDBSCAN(
        min_cluster_size=15, metric="euclidean", cluster_selection_method="eom"
    ).fit(umap_embeddings)

    return {
        "texts": texts,
        "topics": list(topics),
        "topic_words": model.get_topics(),
        "topic_freq": model.get_topic_freq(),
        "cluster_labels": cluster.labels_.tolist(),
    }


def _show_topics(results, num_topics):
    """Write up to *num_topics* non-outlier topics with keywords and sample docs."""
    topic_freq = results["topic_freq"]
    # Topic -1 (outliers) is not guaranteed to be the first row, or to exist
    # at all, so drop it explicitly instead of over-fetching by one row.
    topic_freq = topic_freq[topic_freq["Topic"] != -1].head(num_topics)

    for topic_id in topic_freq["Topic"]:
        st.write(f"## Topic {topic_id}")
        # get_topics() maps each topic id to (word, score) pairs; show words only.
        keywords = [word for word, _score in results["topic_words"][topic_id]]
        st.write("Keywords:", ", ".join(keywords))
        st.write("Documents:")
        doc_ids = [
            idx for idx, topic in enumerate(results["topics"]) if topic == topic_id
        ][:DOCS_PER_TOPIC]
        for doc in doc_ids:
            st.write("-", results["texts"][doc])


def _show_clusters(results):
    """Show the HDBSCAN cluster label assigned to each document."""
    st.write("## Topic Clusters")
    # components.html() renders an HTML string; a raw list of labels is not
    # one, so present the document/label pairs as a table instead.
    st.dataframe(
        {"Document": results["texts"], "Cluster": results["cluster_labels"]},
        height=500,
        width=800,
    )


st.subheader("Topic Modeling with Topic-Wizard")

uploaded_file = st.file_uploader("Choose a text file", type=["txt"])
if uploaded_file is not None:
    # Pre-fill the text area below (widget key "text") with the file contents.
    st.session_state["text"] = uploaded_file.getvalue().decode("utf-8")

st.write("OR")

input_text = st.text_area(
    label="Enter text separated by newlines",
    value="",
    key="text",
    height=150,
)

# Created on every run (not only after the button click) so that moving the
# slider re-renders the stored results instead of hiding them.
num_topics = st.sidebar.slider("Select number of topics to display", 1, 20, 5, 1)

button = st.button("Get Segments")

if button and (uploaded_file is not None or input_text != ""):
    if uploaded_file is not None:
        raw_text = st.session_state["text"]
    else:
        raw_text = input_text

    texts = _split_documents(raw_text)
    if texts:
        with st.spinner("Fitting topic model..."):
            st.session_state[RESULTS_KEY] = _fit_models(texts)
    else:
        st.warning("No non-empty lines to model.")

# Render the most recent results; they persist until the button is pressed again.
results = st.session_state.get(RESULTS_KEY)
if results is not None:
    st.title("BERTopic Visualization")
    _show_topics(results, num_topics)
    _show_clusters(results)