# MIND-states-LDA / streamlit_app_LDA.py
# Last change: commit c35872d ("fix display") by Dana Atzil
import streamlit as st
import json
import random
import numpy as np
from gensim import corpora, models
import pyLDAvis.gensim_models as gensimvis
import pyLDAvis
import pandas as pd
import streamlit.components.v1 as components
from MIND_utils import df_to_self_states_json, element_short_desc_map
# ---------------------------
# Streamlit App Layout
# ---------------------------
st.title("Prototypical Self-States via Topic Modeling")

st.sidebar.header("Model Parameters")
num_topics = st.sidebar.slider("Number of Topics", min_value=2, max_value=20, value=5)
num_passes = st.sidebar.slider("Number of Passes", min_value=5, max_value=50, value=10)
# NOTE(review): this radio (and the "Show full element name" checkbox below) render in
# the main page rather than the sidebar, unlike the other parameters -- confirm intended.
lda_document_is = st.radio("A 'Document' in the topic model will correspond to a:", ("self-state", "segment"))
# np.random.seed and gensim's random_state only accept integer seeds in
# [0, 2**32 - 1]; bounding the widget prevents an out-of-range entry from
# crashing the whole app with a ValueError.
seed_value = st.sidebar.number_input(
    "Random Seed",
    min_value=0,
    max_value=2**32 - 1,
    value=42,
    step=1,
)

st.sidebar.header("Display")
num_top_elements_to_show = st.sidebar.slider("# top element to show in a topic", min_value=2, max_value=15, value=5)
show_long_elements = st.checkbox("Show full element name")
# ---------------------------
# Load Data
# ---------------------------
# Users could instead be allowed to upload their own file via st.file_uploader.
def load_data(path="clean_annotations_safe.csv"):
    """Read the annotation table into a DataFrame.

    Args:
        path: CSV file to read. Defaults to the annotations file bundled with the app.

    Returns:
        The ``pandas.DataFrame`` produced by ``pd.read_csv(path)``.
    """
    # NOTE(review): Streamlit reruns this script on every widget change, so the CSV is
    # re-read each time. On Streamlit >= 1.18, decorating this function with
    # ``@st.cache_data`` would avoid that (``st.cache`` is deprecated).
    return pd.read_csv(path)


df = load_data()
# ---------------------------
# Preprocess Data: Build Documents
# ---------------------------
# Seed Python's and NumPy's global RNGs with the user-chosen seed, in that order.
# (The LDA model itself receives the seed separately via ``random_state``.)
for _seed_fn in (random.seed, np.random.seed):
    _seed_fn(seed_value)
# Helpers that turn a segment / self-state into LDA "words": one "<dim>:<category>"
# element string per annotated dimension.
def extract_elements_from_selfstate(selfstate):
    """Return the "<dim>:<category>" elements of a single self-state.

    Every key of *selfstate* except ``"is_adaptive"`` is treated as a dimension whose
    value is a mapping; a dimension contributes an element only when it has a
    ``"Category"`` entry that is not NaN/None. Element order follows the dict's order.
    """
    return [
        f"{dimension}:{info['Category']}"
        for dimension, info in selfstate.items()
        if dimension != "is_adaptive"
        and "Category" in info
        and not pd.isna(info["Category"])
    ]
def extract_elements_from_segment(segment):
    """Return the elements of all self-states in *segment*, concatenated in order.

    Each entry of ``segment["self-states"]`` is passed through
    ``extract_elements_from_selfstate``; duplicate elements are kept.
    """
    return [
        element
        for state in segment["self-states"]
        for element in extract_elements_from_selfstate(state)
    ]
# Build the LDA "documents": one per segment or one per self-state, depending on
# ``lda_document_is``. Each document is the list of "<dim>:<category>" elements it
# contains, and ``lda_document_ids[k]`` is a readable id for ``lda_documents[k]``.
lda_documents = []
lda_document_ids = []
for (doc_id, annotator), df_ in df.groupby(["document", "annotator"]):
    doc_json = df_to_self_states_json(df_, doc_id, annotator)
    for segment in doc_json["segments"]:
        # Rows are grouped by (document, annotator), so one document can appear under
        # several annotators; the annotator keeps those ids from colliding.
        segment_id = f"{doc_id}_{annotator}_seg{segment['segment']}"
        if lda_document_is == "segment":
            lda_doc = extract_elements_from_segment(segment)
            if lda_doc:  # skip segments without any categorized element
                lda_documents.append(lda_doc)
                lda_document_ids.append(segment_id)
        elif lda_document_is == "self-state":
            for state_number, selfstate in enumerate(segment["self-states"], start=1):
                lda_doc = extract_elements_from_selfstate(selfstate)
                if lda_doc:  # skip self-states without any categorized element
                    lda_documents.append(lda_doc)
                    lda_document_ids.append(f"{segment_id}_state{state_number}")

# gensim's LdaModel refuses a corpus with an empty vocabulary; stop with a readable
# message instead of a stack trace when nothing survived extraction.
if not lda_documents:
    st.error("No annotated elements were found for the selected document type; nothing to model.")
    st.stop()

# Create a dictionary and corpus for LDA
dictionary = corpora.Dictionary(lda_documents)
corpus = [dictionary.doc2bow(doc) for doc in lda_documents]
# ---------------------------
# Run LDA Model
# ---------------------------
# Fit the topic model on the bag-of-words corpus; passing the seed as
# ``random_state`` makes the fit repeatable for a given seed.
lda_model = models.LdaModel(
    corpus=corpus,
    id2word=dictionary,
    num_topics=num_topics,
    passes=num_passes,
    random_state=seed_value,
)
# ---------------------------
# Display Pretty Printed Topics
# ---------------------------
st.header("Pretty Printed Topics")

# How many of the most probable documents to list under each topic.
TOP_DOCS_PER_TOPIC = 3


def _top_documents_per_topic(model, bow_corpus, k):
    """Return ``{topic_id: [(doc_index, prob), ...]}`` holding each topic's *k* most probable docs.

    Probabilities come from ``model.get_document_topics``; each list is sorted by
    descending probability and has at most *k* entries.
    """
    topic_docs = {topic_id: [] for topic_id in range(model.num_topics)}
    for doc_index, doc_bow in enumerate(bow_corpus):
        # minimum_probability=0 asks gensim for (practically) every topic of the document.
        for topic_id, prob in model.get_document_topics(doc_bow, minimum_probability=0):
            topic_docs[topic_id].append((doc_index, prob))
    return {
        topic_id: sorted(doc_list, key=lambda item: item[1], reverse=True)[:k]
        for topic_id, doc_list in topic_docs.items()
    }


top_docs = _top_documents_per_topic(lda_model, corpus, TOP_DOCS_PER_TOPIC)

# ``show_topic`` yields (element, weight) pairs directly, so element names are no
# longer re-parsed out of print_topics' '0.123*"token" + ...' string, which broke on
# names containing "*" or " + ".
lines = ["Identified Prototypical Self-States (Topics):", ""]
for topic_id in range(lda_model.num_topics):
    lines.append(f"Topic {topic_id}:")
    for token, weight in lda_model.show_topic(topic_id, topn=num_top_elements_to_show):
        lines.append(f" {weight:.3f} -> {token}")
    # Documents are segments or self-states depending on the selected mode.
    lines.append(f" Top {TOP_DOCS_PER_TOPIC} Documents ({lda_document_is}s) for this topic:")
    for doc_index, prob in top_docs[topic_id]:
        lines.append(f" Doc {doc_index} ({lda_document_ids[doc_index]}) with probability {prob:.3f}")
    lines.append("-" * 60)
output_str = "\n".join(lines) + "\n"

st.text(output_str)
# ---------------------------
# Prepare and Display pyLDAvis Visualization
# ---------------------------
st.header("Interactive Topic Visualization")


def _short_label_dictionary(full_dictionary, bow_corpus):
    """Return a Dictionary with the same term ids as *full_dictionary* but short labels.

    Each element is relabelled with ``element_short_desc_map[element]``. An element
    missing from the map keeps its full name. A short label already taken by another
    element gets the full element name appended. Labels must be unique because a
    label->id mapping can hold only one id per label; a collision would drop model
    terms and misalign the vocabulary with the model's term ids.
    """
    labels_by_id = {}
    used_labels = set()
    for token, token_id in full_dictionary.token2id.items():
        try:
            label = element_short_desc_map[token]
        except KeyError:
            label = token
        if label in used_labels:
            label = f"{label} [{token}]"
        used_labels.add(label)
        labels_by_id[token_id] = label
    # from_corpus keeps the given ids (so they match the trained model) and fills in
    # document frequencies from the corpus.
    return corpora.Dictionary.from_corpus(bow_corpus, id2word=labels_by_id)


if show_long_elements:
    vis_dictionary = dictionary
else:
    vis_dictionary = _short_label_dictionary(dictionary, corpus)
vis_data = gensimvis.prepare(lda_model, corpus, vis_dictionary)
html_string = pyLDAvis.prepared_data_to_html(vis_data)
components.html(html_string, width=2300, height=800, scrolling=True)