# /// script
# requires-python = ">=3.12"
# dependencies = [
# "altair==5.5.0",
# "fugashi-plus",
# "marimo",
# "numpy==2.2.6",
# "pandas==2.3.0",
# "pyarrow",
# "scattertext==0.2.2",
# "scikit-learn==1.7.0",
# "scipy==1.13.1",
# ]
# ///
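# Marimo notebook: tokenizes two groups of Japanese texts with fugashi/MeCab,
# visualizes their lexical differences with Scattertext, and optionally trains a
# scikit-learn classifier on TF-IDF features of text chunks.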
import marimo
__generated_with = "0.13.15"
app = marimo.App(width="full", app_title="Scattertext on Japanese novels")
with app.setup:
import marimo as mo
import itertools
import fugashi
import pandas as pd
import scipy
import numpy as np
import random
import scattertext as st
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
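    # Fix seeds so chunk sampling (random.sample) and other stochastic steps are reproducible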
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
@app.cell
def function_export():
@mo.cache
def parse_texts(texts: list[str]) -> list[str]:
"""Tokenize a list of raw strings via fugashi (MeCab)."""
tagger = fugashi.Tagger("-Owakati -d ./unidic-novel -r ./unidic-novel/dicrc")
return [tagger.parse(txt).strip() for txt in texts]
@mo.cache
def build_corpus_cached(
texts: list[str],
categories: list[str],
) -> st.Corpus:
"""Build or reuse cached Scattertext corpus."""
df = pd.DataFrame({"text": texts, "category": categories})
return (
st.CorpusFromPandas(
df,
category_col="category",
text_col="text",
nlp=st.whitespace_nlp_with_sentences,
)
.build()
.get_unigram_corpus()
.compact(st.AssociationCompactor(2000))
)
@mo.cache
def chunk_texts(
texts: list[str],
categories: list[str],
filenames: list[str],
chunk_size: int = 2000,
) -> tuple[list[str], list[str], list[str]]:
"""Chunk each text into segments of chunk_size tokens, preserving category and filename."""
chunked_texts = []
chunked_cats = []
chunked_fnames = []
for text, cat, fname in zip(texts, categories, filenames):
tokens = text.split()
for i in range(0, len(tokens), chunk_size):
chunk = " ".join(tokens[i : i + chunk_size])
chunked_texts.append(chunk)
chunked_cats.append(cat)
chunked_fnames.append(f"{fname}#{i // chunk_size + 1}")
return chunked_texts, chunked_cats, chunked_fnames
@mo.cache
def train_scikit_cached(
texts: list[str], categories: list[str], filenames: list[str]
) -> tuple[
st.Corpus,
scipy.sparse.spmatrix,
TfidfVectorizer,
list[str],
list[str],
]:
"""Fit TF-IDF + CountVectorizer & build a st.Corpus on chunked data."""
chunk_texts_out, chunk_cats, chunk_fnames = chunk_texts(
texts, categories, filenames
)
tfv = TfidfVectorizer()
X_tfidf = tfv.fit_transform(chunk_texts_out)
        # max_features is ignored when a fixed vocabulary is supplied, so reuse the
        # TF-IDF vocabulary directly
        cv = CountVectorizer(vocabulary=tfv.vocabulary_)
        y_codes = pd.Categorical(chunk_cats).codes
scikit_corpus = st.CorpusFromScikit(
X=cv.fit_transform(chunk_texts_out),
y=y_codes,
feature_vocabulary=tfv.vocabulary_,
category_names=list(pd.Categorical(chunk_cats).categories),
raw_texts=chunk_texts_out,
).build()
return (
scikit_corpus,
X_tfidf,
tfv,
chunk_cats,
chunk_fnames,
)
return build_corpus_cached, chunk_texts, parse_texts, train_scikit_cached
@app.cell
def intro():
mo.md(
r"""
        # Scattertext on Japanese novels
        ## Overview
        Upload two groups of text files from different categories and visualize their differences with Scattertext.
        Optionally, a machine-learning model can classify the two groups, letting you check its classification accuracy and the tokens it relies on to tell them apart.
        ## Workflow
        1. Upload text files (to use the default sample corpus, just press Submit)
        2. Check and correct the data
        3. Configure chunking and sampling
        4. Visualize with Scattertext
        5. (Optional) evaluate performance with a classification model
        > Word segmentation uses the [近現代口語小説UniDic](https://clrd.ninjal.ac.jp/unidic/download_all.html#unidic_novel) dictionary, so it is not well suited to texts from other periods or genres.
"""
)
return
@app.cell
def data_settings():
# 1) Create each widget
category_name = mo.ui.text(
label="カテゴリ名(例:著者名・時代区分など)",
placeholder="例:時代・性別・著者など",
value="著者",
full_width=True,
)
label_a = mo.ui.text(
label="Aのラベル", placeholder="例:夏目漱石", value="夏目漱石", full_width=True
)
files_a = mo.ui.file(
label="Aのファイルアップロード(UTF-8、.txt形式)", multiple=True, kind="area"
)
label_b = mo.ui.text(
label="Bのラベル", placeholder="例:海野十三", value="海野十三", full_width=True
)
files_b = mo.ui.file(
label="Bのファイルアップロード(UTF-8、.txt形式)", multiple=True, kind="area"
)
tpl = r"""
    ## Data and analysis settings
    Note: by default, the sample corpus consists of two works each by 夏目漱石 and 海野十三. Submitting the form without changes runs the analysis on this sample. If you upload your own files, remember to update the category name and labels as well.
    Note: files must be plain text (.txt, UTF-8 encoding).
    {category_name}
    ### Group A
    {label_a}
    {files_a}
    ### Group B
    {label_b}
    {files_b}
"""
data_form = (
mo.md(tpl)
.batch(
# info_box=info_box,
category_name=category_name,
label_a=label_a,
files_a=files_a,
label_b=label_b,
files_b=files_b,
)
.form(show_clear_button=True, bordered=True)
)
data_form
return data_form, label_a, label_b
@app.cell
def data_check(data_form, parse_texts):
mo.stop(data_form.value is None)
from pathlib import Path
validation_messages: list[str] = []
if data_form.value["label_a"] == data_form.value["label_b"]:
print("a")
validation_messages.append(
"⚠️ **警告**: グループAとBのラベルが同じです。AとBは異なるラベルを設定してください。\n"
)
if not data_form.value["files_a"] and not data_form.value["files_b"]:
validation_messages.append(
"ℹ️ ファイルが未指定のため、デフォルトサンプルを使用しています。\n"
)
try:
# Group A: either uploaded files or default (坊っちゃん + こころ)
if data_form.value["files_a"]:
category_a_texts = (
f.contents.decode("utf-8") for f in data_form.value["files_a"]
)
category_a_names = (f.name for f in data_form.value["files_a"])
else:
natsume_1 = Path("Natsume_S_Bocchan.txt").read_text(encoding="utf-8")
natsume_2 = Path("Natsume_S_Kokoro.txt").read_text(encoding="utf-8")
category_a_texts = [natsume_1, natsume_2]
category_a_names = ["Natsume_S_Bocchan.txt", "Natsume_S_Kokoro.txt"]
# Group B: either uploaded files or default (地球要塞 + 火星兵団)
if data_form.value["files_b"]:
category_b_texts = (
f.contents.decode("utf-8") for f in data_form.value["files_b"]
)
category_b_names = (f.name for f in data_form.value["files_b"])
else:
unno_1 = Path("Unno_J_Chikyuuyousa.txt").read_text(encoding="utf-8")
unno_2 = Path("Unno_J_Kaseiheidan.txt").read_text(encoding="utf-8")
category_b_texts = [unno_1, unno_2]
category_b_names = ["Unno_J_Chikyuuyousa.txt", "Unno_J_Kaseiheidan.txt"]
data = pd.DataFrame(
{
"category": (
[data_form.value["label_a"]]
* (
len(data_form.value["files_a"])
if data_form.value["files_a"]
else 2
)
)
+ (
[data_form.value["label_b"]]
* (
len(data_form.value["files_b"])
if data_form.value["files_b"]
else 2
)
),
"filename": itertools.chain(category_a_names, category_b_names),
"text": itertools.chain(category_a_texts, category_b_texts),
}
)
with mo.status.spinner("コーパスを形態素解析中..."):
data["text"] = parse_texts(list(data["text"]))
    except Exception as e:
        data = None
        validation_messages.append(
            f"❌ **Error**: failed to read the uploaded files: {str(e)}\n"
        )
    # Stop here (showing any messages) if the data could not be built
    mo.stop(data is None, mo.md("\n".join(f"- {m}" for m in validation_messages)))
    # Maximum token count per text, as a reference for the chunk-size slider below
    max_tokens = data["text"].map(lambda s: len(s.split())).max()
    mo.md(f"""
    ## Data check
    {"**Notes**:" if validation_messages else ""}
    {"\n".join(f"- {m}" for m in validation_messages)}
    Parsed texts:
    {mo.ui.table(data, selection="multi", format_mapping={"text": lambda s: s[:20] + "..."})}
    """)
return (data,)
@app.cell
def sampling_controls_setup():
chunk_size = mo.ui.slider(
start=500,
stop=50_000,
value=2000,
step=500,
label="1チャンクあたり最大トークン数",
full_width=True,
)
sample_frac = mo.ui.slider(
start=0.1,
stop=1.0,
value=0.2,
step=0.05,
label="使用割合(1.0で全データ)",
full_width=True,
)
sampling_form = (
mo.md("{chunk_size}\n{sample_frac}")
.batch(chunk_size=chunk_size, sample_frac=sample_frac)
.form(show_clear_button=True, bordered=False)
)
sampling_form
return chunk_size, sample_frac, sampling_form
@app.cell
def _(build_corpus_cached, chunk_texts, data, sample_frac, sampling_form):
mo.stop(sampling_form.value is None)
with mo.status.spinner("コーパスをサンプリング中…"):
texts, cats, fnames = chunk_texts(
list(data.text),
list(data.category),
list(data.filename),
sampling_form.value["chunk_size"],
)
if sample_frac.value < 1.0:
N = len(texts)
k = int(N * sampling_form.value["sample_frac"])
idx = random.sample(range(N), k)
texts = [texts[i] for i in idx]
cats = [cats[i] for i in idx]
fnames = [fnames[i] for i in idx]
corpus = build_corpus_cached(
texts,
cats,
)
return cats, corpus, fnames, texts
@app.cell
def sampling_controls(chunk_size):
mo.md("トークン数を増やすと処理時間が長くなります").callout(
kind="info"
) if chunk_size.value > 30_000 else None
return
@app.cell
def plot_main_scatterplot(corpus, data_form, fnames):
cat_name = data_form.value["category_name"]
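    # metadata attaches a per-document label (filename#chunk) shown when inspecting documents in the explorer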
with mo.status.spinner("Scatterplot作成中…"):
html = st.produce_scattertext_explorer(
corpus,
category=data_form.value["label_a"],
category_name=f"{cat_name}: {data_form.value['label_a']}",
not_category_name=f"{cat_name}: {data_form.value['label_b']}",
width_in_pixels=1000,
metadata=fnames,
)
    mo.vstack(
        [
            mo.md(f"""
            # Scattertext results
            ### How to read the visualization
            - (vertical) tokens nearer the top are used relatively more often by {data_form.value["label_a"]}
            - (horizontal) tokens nearer the right are used relatively more often by {data_form.value["label_b"]}
            Downloading the HTML and opening it in a browser makes the plot easier to explore.
            """),
            mo.iframe(html),
        ]
    )
return (html,)
@app.cell
def _(html):
download_button = mo.download(
data=html.encode(),
filename="scattertext_analysis.html",
label="可視化結果をダウンロード",
)
mo.md(f"{download_button}")
return
@app.cell
def classification_toggle():
    run_model = mo.ui.switch(label="Run a classification model")
run_model
return (run_model,)
@app.cell
def _(run_model):
mo.stop(not run_model.value)
mo.md(
r"""
        # Validation with a classification model
        A model is trained to separate the two categories, and Scattertext can then show which features (words) are most effective at distinguishing them.
        Logistic regression is used by default; other classifiers can be selected below.
"""
)
return
@app.cell
def _(cats, fnames, run_model, texts, train_scikit_cached):
mo.stop(not run_model.value)
scikit_corpus, tfidf_X, vectorizer, chunk_cats, chunk_fnames = train_scikit_cached(
texts, cats, fnames
)
return chunk_cats, chunk_fnames, scikit_corpus, tfidf_X, vectorizer
@app.cell
def model_selection(run_model):
mo.stop(not run_model.value)
model_dropdown = mo.ui.dropdown(
options=[
"LogisticRegression",
"RandomForestClassifier",
"GradientBoostingClassifier",
],
value="LogisticRegression",
label="モデル選択",
)
model_dropdown
return (model_dropdown,)
@app.cell
def hyperparameters(model_dropdown):
lr_C = mo.ui.slider(0.01, 10.0, value=1.0, step=0.01, label="LR C")
lr_max_iter = mo.ui.slider(100, 2000, value=1000, step=100, label="LR max_iter")
rf_n = mo.ui.slider(10, 500, value=100, step=10, label="RF n_estimators")
rf_max_depth = mo.ui.slider(1, 50, value=10, step=1, label="RF max_depth")
gb_n = mo.ui.slider(10, 500, value=100, step=10, label="GB n_estimators")
gb_lr = mo.ui.slider(0.01, 1.0, value=0.1, step=0.01, label="GB learning_rate")
gb_md = mo.ui.slider(1, 10, value=3, step=1, label="GB max_depth")
    widgets = {}
if model_dropdown.value == "LogisticRegression":
widgets = {"lr_C": lr_C, "lr_max_iter": lr_max_iter}
elif model_dropdown.value == "RandomForestClassifier":
widgets = {"rf_n": rf_n, "rf_max_depth": rf_max_depth}
else: # GradientBoostingClassifier
widgets = {"gb_n": gb_n, "gb_lr": gb_lr, "gb_md": gb_md}
    test_size = mo.ui.slider(0.1, 0.5, value=0.3, step=0.05, label="Test set fraction")
model_form = (
mo.md("### モデルのパラメータ設定\n{widgets}\n{test_size}")
.batch(
widgets=mo.ui.dictionary(widgets),
test_size=test_size,
)
.form(show_clear_button=True, bordered=False)
)
model_form
return (model_form,)
@app.cell
def _(
chunk_cats,
label_a,
label_b,
model_dropdown,
model_form,
run_model,
tfidf_X,
vectorizer,
):
mo.stop(not run_model.value or not model_form.value)
import altair as alt
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
auc,
classification_report,
confusion_matrix,
roc_curve,
)
from sklearn.model_selection import train_test_split
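    # Hold out a test split of the TF-IDF chunk matrix; labels are the chunk categories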
X_train, X_test, y_train, y_test = train_test_split(
tfidf_X,
chunk_cats,
test_size=model_form.value["test_size"],
random_state=RANDOM_SEED,
)
name = model_dropdown.value
if name == "LogisticRegression":
clf = LogisticRegression(
C=model_form.value["widgets"]["lr_C"],
max_iter=int(model_form.value["widgets"]["lr_max_iter"]),
)
elif name == "RandomForestClassifier":
clf = RandomForestClassifier(
n_estimators=int(model_form.value["widgets"]["rf_n"]),
max_depth=int(model_form.value["widgets"]["rf_max_depth"]),
random_state=RANDOM_SEED,
)
else: # GradientBoostingClassifier
clf = GradientBoostingClassifier(
n_estimators=int(model_form.value["widgets"]["gb_n"]),
learning_rate=float(model_form.value["widgets"]["gb_lr"]),
max_depth=int(model_form.value["widgets"]["gb_md"]),
random_state=RANDOM_SEED,
)
clf.fit(X_train, y_train)
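    # Tree ensembles expose feature_importances_; linear models expose coef_, so use |coef| as the score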
if hasattr(clf, "feature_importances_"):
term_scores = clf.feature_importances_
else:
term_scores = abs(clf.coef_[0])
y_pred = clf.predict(X_test)
report = classification_report(y_test, y_pred, output_dict=True)
cm = confusion_matrix(y_test, y_pred, labels=clf.classes_)
cm_df = (
pd.DataFrame(cm, index=clf.classes_, columns=clf.classes_)
.reset_index()
.melt(
id_vars="index",
var_name="Predicted",
value_name="count",
)
.rename(columns={"index": "Actual"})
)
# pos_idx = list(clf.classes_).index(label_a.value)
# _proba, roc_auc = None, None
# roc_df = None
# if hasattr(clf, "predict_proba"):
# probs = clf.predict_proba(X_test)[:, pos_idx]
# y_test_arr = np.array(y_test)
# fpr, tpr, _ = roc_curve((y_test_arr == label_a.value).astype(int), probs)
# roc_auc = auc(fpr, tpr)
# roc_df = pd.DataFrame({"fpr": fpr, "tpr": tpr})
feature_names = vectorizer.get_feature_names_out()
    importances = (
        pd.DataFrame({"term": feature_names, "importance": term_scores})
        .sort_values("importance", ascending=False)
        .head(20)
    )
    imp_chart = (
        alt.Chart(importances)
        .mark_bar()
        .encode(
            x=alt.X("importance:Q", title="Importance"),
            y=alt.Y("term:N", sort="-x"),
        )
        .properties(title="Top 20 most important features", width=600, height=400)
    )
cm_chart = (
alt.Chart(cm_df)
.mark_rect()
.encode(
x="Predicted:N",
y="Actual:N",
color=alt.Color("count:Q", title="Count"),
tooltip=["Actual", "Predicted", "count"],
)
.properties(title="Confusion Matrix", width=250, height=250)
)
# roc_chart = (
# alt.Chart(roc_df)
# .mark_line(point=True)
# .encode(
# x=alt.X("fpr:Q", title="False Positive Rate"),
# y=alt.Y("tpr:Q", title="True Positive Rate"),
# )
# .properties(
# title=f"ROC Curve (AUC={roc_auc:.2f})",
# width=400,
# height=300,
# )
# )
mo.vstack(
[
mo.ui.altair_chart(imp_chart),
mo.ui.altair_chart(cm_chart),
# mo.ui.altair_chart(roc_chart), # Turned out to not be too informative as task is too easy?
mo.md(f"""
## テストセット上の分類性能
- {label_a.value}: 精度 {report[label_a.value]["precision"]:.2%}, 再現率 {report[label_a.value]["recall"]:.2%}
- {label_b.value}: 精度 {report[label_b.value]["precision"]:.2%}, 再現率 {report[label_b.value]["recall"]:.2%}
"""),
]
)
return (term_scores,)
@app.cell
def _(
chunk_fnames,
data_form,
model_form,
run_model,
scikit_corpus,
term_scores,
):
mo.stop(not run_model.value or not model_form.value)
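    # Terms are ranked by the classifier-derived scores; AutoTermSelector picks which
    # terms to include in the plot based on those scores (here capped at 4000)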
with mo.status.spinner("分類モデルのScatterplotを作成中…"):
scikit_html = st.produce_scattertext_explorer(
corpus=scikit_corpus,
category=data_form.value["label_a"],
category_name=data_form.value["label_a"],
not_category_name=data_form.value["label_b"],
scores=term_scores,
terms_to_include=st.AutoTermSelector.get_selected_terms(
scikit_corpus, term_scores, 4000
),
metadata=chunk_fnames,
transform=lambda freqs, _index, total: freqs / total.sum(),
rescale_x=lambda arr: arr, # identity
rescale_y=lambda arr: arr, # identity
)
mo.iframe(scikit_html)
return
if __name__ == "__main__":
app.run()