# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "altair==5.5.0",
#     "fugashi-plus",
#     "marimo",
#     "numpy==2.2.6",
#     "pandas==2.3.0",
#     "pyarrow",
#     "scattertext==0.2.2",
#     "scikit-learn==1.7.0",
#     "scipy==1.13.1",
# ]
# ///
import marimo

__generated_with = "0.13.15"
app = marimo.App(width="full", app_title="Scattertext on Japanese novels")

with app.setup:
    import marimo as mo
    import itertools
    import fugashi
    import pandas as pd
    import scipy
    import numpy as np
    import random
    import scattertext as st
    from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer

    RANDOM_SEED = 42
    random.seed(RANDOM_SEED)
    np.random.seed(RANDOM_SEED)
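    # Note: seeding both `random` and NumPy here makes the chunk subsampling below
    # (random.sample) reproducible across reruns; the scikit-learn estimators
    # additionally receive random_state=RANDOM_SEED where applicable.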


@app.cell
def function_export():
    def parse_texts(texts: list[str]) -> list[str]:
        """Tokenize a list of raw strings via fugashi (MeCab)."""
        tagger = fugashi.Tagger("-Owakati -d ./unidic-novel -r ./unidic-novel/dicrc")
        return [tagger.parse(txt).strip() for txt in texts]
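
    # `-Owakati` makes MeCab/fugashi emit wakati-gaki output: tokens separated by single
    # spaces. Downstream consumers (Scattertext's whitespace_nlp_with_sentences and the
    # scikit-learn vectorizers) therefore only need to split on whitespace. `-d` points
    # at the bundled 近現代口語小説UniDic directory and `-r` at its dicrc resource file.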

    def build_corpus_cached(
        texts: list[str],
        categories: list[str],
    ) -> st.Corpus:
        """Build a Scattertext unigram corpus from parsed texts and their category labels."""
        df = pd.DataFrame({"text": texts, "category": categories})
        return (
            st.CorpusFromPandas(
                df,
                category_col="category",
                text_col="text",
                nlp=st.whitespace_nlp_with_sentences,
            )
            .build()
            .get_unigram_corpus()
            .compact(st.AssociationCompactor(2000))
        )
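
    # get_unigram_corpus() keeps only single-token terms, and AssociationCompactor(2000)
    # prunes the vocabulary to roughly the 2,000 terms most strongly associated with a
    # category, which keeps the generated interactive HTML reasonably small.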

    def chunk_texts(
        texts: list[str],
        categories: list[str],
        filenames: list[str],
        chunk_size: int = 2000,
    ) -> tuple[list[str], list[str], list[str]]:
        """Chunk each text into segments of chunk_size tokens, preserving category and filename."""
        chunked_texts = []
        chunked_cats = []
        chunked_fnames = []
        for text, cat, fname in zip(texts, categories, filenames):
            tokens = text.split()
            for i in range(0, len(tokens), chunk_size):
                chunk = " ".join(tokens[i : i + chunk_size])
                chunked_texts.append(chunk)
                chunked_cats.append(cat)
                chunked_fnames.append(f"{fname}#{i // chunk_size + 1}")
        return chunked_texts, chunked_cats, chunked_fnames
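
    # Rationale for chunking: splitting each novel into fixed-size pseudo-documents
    # yields many samples of comparable length (useful for the optional classifier and
    # for sampling), and the "#n" suffix keeps every chunk traceable to its source file.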

    def train_scikit_cached(
        texts: list[str], categories: list[str], filenames: list[str]
    ) -> tuple[
        st.Corpus,
        scipy.sparse.spmatrix,
        TfidfVectorizer,
        list[str],
        list[str],
    ]:
        """Fit TF-IDF + CountVectorizer & build a st.Corpus on chunked data."""
        chunk_texts_out, chunk_cats, chunk_fnames = chunk_texts(
            texts, categories, filenames
        )
        tfv = TfidfVectorizer()
        X_tfidf = tfv.fit_transform(chunk_texts_out)
        # Reuse the TF-IDF vocabulary so the count matrix columns line up with X_tfidf
        # (scikit-learn ignores max_features when an explicit vocabulary is given, so
        # it is omitted here).
        cv = CountVectorizer(vocabulary=tfv.vocabulary_)
        y_codes = pd.Categorical(chunk_cats).codes
        scikit_corpus = st.CorpusFromScikit(
            X=cv.fit_transform(chunk_texts_out),
            y=y_codes,
            feature_vocabulary=tfv.vocabulary_,
            category_names=list(pd.Categorical(chunk_cats).categories),
            raw_texts=chunk_texts_out,
        ).build()
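
        # Because CountVectorizer reuses tfv.vocabulary_, column j of the count matrix,
        # column j of X_tfidf, and entry j of any per-feature score array (e.g. the
        # classifier importances computed later) all refer to the same term.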
        return (
            scikit_corpus,
            X_tfidf,
            tfv,
            chunk_cats,
            chunk_fnames,
        )

    return build_corpus_cached, chunk_texts, parse_texts, train_scikit_cached


@app.cell
def intro():
    mo.md(
        r"""
# Scattertext on Japanese novels / 近代文学作品のScattertext可視化

## 概要

2つの異なるカテゴリのテキストファイル群をアップロードし、その差異をScattertextで可視化します。
オプションで機械学習モデルによる分類を行い、モデルの分類精度とモデルが識別に用いるトークンも確認できます。

## ワークフロー

1. テキストファイルをアップロード(デフォルトを使う場合はそのままSubmitしてください)
2. データ内容を確認・修正
3. チャンク&サンプリング設定
4. Scattertextによる可視化
5. (任意)分類モデルによる性能検証

> 単語分割には、[近現代口語小説UniDic](https://clrd.ninjal.ac.jp/unidic/download_all.html#unidic_novel)を使用しています。異なる時代やジャンルのテキストには不向きです。
"""
    )
    return


@app.cell
def data_settings():
    # 1) Create each widget
    category_name = mo.ui.text(
        label="カテゴリ名(例:著者名・時代区分など)",
        placeholder="例:時代・性別・著者など",
        value="著者",
        full_width=True,
    )
    label_a = mo.ui.text(
        label="Aのラベル", placeholder="例:夏目漱石", value="夏目漱石", full_width=True
    )
    files_a = mo.ui.file(
        label="Aのファイルアップロード(UTF-8、.txt形式)", multiple=True, kind="area"
    )
    label_b = mo.ui.text(
        label="Bのラベル", placeholder="例:海野十三", value="海野十三", full_width=True
    )
    files_b = mo.ui.file(
        label="Bのファイルアップロード(UTF-8、.txt形式)", multiple=True, kind="area"
    )
    tpl = r"""
## データと分析の設定

※ 初期設定では夏目漱石と海野十三から各2作品をサンプルコーパスにしています。設定を変更せずSubmitすると、サンプルコーパスでの分析になります。ファイルをアップロードする場合は忘れずにカテゴリとラベルも変更してください。

※ ファイルはプレインテキスト形式必須(.txt, UTF-8エンコーディング)

{category_name}

### グループA

{label_a}
{files_a}

### グループB

{label_b}
{files_b}
"""
    data_form = (
        mo.md(tpl)
        .batch(
            # info_box=info_box,
            category_name=category_name,
            label_a=label_a,
            files_a=files_a,
            label_b=label_b,
            files_b=files_b,
        )
        .form(show_clear_button=True, bordered=True)
    )
    data_form
    return data_form, label_a, label_b


@app.cell
def data_check(data_form, parse_texts):
    mo.stop(data_form.value is None)
    from pathlib import Path

    validation_messages: list[str] = []
    if data_form.value["label_a"] == data_form.value["label_b"]:
        validation_messages.append(
            "⚠️ **警告**: グループAとBのラベルが同じです。AとBは異なるラベルを設定してください。\n"
        )
    if not data_form.value["files_a"] and not data_form.value["files_b"]:
        validation_messages.append(
            "ℹ️ ファイルが未指定のため、デフォルトサンプルを使用しています。\n"
        )
    try:
        # Group A: either uploaded files or default (坊っちゃん + こころ)
        if data_form.value["files_a"]:
            category_a_texts = (
                f.contents.decode("utf-8") for f in data_form.value["files_a"]
            )
            category_a_names = (f.name for f in data_form.value["files_a"])
        else:
            natsume_1 = Path("Natsume_S_Bocchan.txt").read_text(encoding="utf-8")
            natsume_2 = Path("Natsume_S_Kokoro.txt").read_text(encoding="utf-8")
            category_a_texts = [natsume_1, natsume_2]
            category_a_names = ["Natsume_S_Bocchan.txt", "Natsume_S_Kokoro.txt"]

        # Group B: either uploaded files or default (地球要塞 + 火星兵団)
        if data_form.value["files_b"]:
            category_b_texts = (
                f.contents.decode("utf-8") for f in data_form.value["files_b"]
            )
            category_b_names = (f.name for f in data_form.value["files_b"])
        else:
            unno_1 = Path("Unno_J_Chikyuuyousa.txt").read_text(encoding="utf-8")
            unno_2 = Path("Unno_J_Kaseiheidan.txt").read_text(encoding="utf-8")
            category_b_texts = [unno_1, unno_2]
            category_b_names = ["Unno_J_Chikyuuyousa.txt", "Unno_J_Kaseiheidan.txt"]

        data = pd.DataFrame(
            {
                "category": (
                    [data_form.value["label_a"]]
                    * (
                        len(data_form.value["files_a"])
                        if data_form.value["files_a"]
                        else 2
                    )
                )
                + (
                    [data_form.value["label_b"]]
                    * (
                        len(data_form.value["files_b"])
                        if data_form.value["files_b"]
                        else 2
                    )
                ),
                "filename": list(itertools.chain(category_a_names, category_b_names)),
                "text": list(itertools.chain(category_a_texts, category_b_texts)),
            }
        )
        with mo.status.spinner("コーパスを形態素解析中..."):
            data["text"] = parse_texts(list(data["text"]))
    except Exception as e:
        data = None
        validation_messages.append(
            f"❌ **エラー**: ファイルの読み込みに失敗しました: {str(e)}\n"
        )

    # Stop here and surface the messages if loading failed; otherwise the
    # `data["text"]` access below would fail on None.
    mo.stop(
        data is None,
        mo.md("\n".join(map(lambda x: f"- {x}", validation_messages))),
    )

    # We need the maximum number of tokens for the slider
    max_tokens = data["text"].map(lambda s: len(s.split())).max()
    mo.md(f"""
## データ確認

{"**警告**:" if validation_messages else ""}
{"\n".join(map(lambda x: f"- {x}", validation_messages))}

解析済テキスト一覧:

{mo.ui.table(data, selection="multi", format_mapping={"text": lambda s: s[:20] + "..."})}
""")
    return (data,)


@app.cell
def sampling_controls_setup():
    chunk_size = mo.ui.slider(
        start=500,
        stop=50_000,
        value=2000,
        step=500,
        label="1チャンクあたり最大トークン数",
        full_width=True,
    )
    sample_frac = mo.ui.slider(
        start=0.1,
        stop=1.0,
        value=0.2,
        step=0.05,
        label="使用割合(1.0で全データ)",
        full_width=True,
    )
    sampling_form = (
        mo.md("{chunk_size}\n{sample_frac}")
        .batch(chunk_size=chunk_size, sample_frac=sample_frac)
        .form(show_clear_button=True, bordered=False)
    )
    sampling_form
    return chunk_size, sample_frac, sampling_form


@app.cell
def _(build_corpus_cached, chunk_texts, data, sampling_form):
    mo.stop(sampling_form.value is None)
    with mo.status.spinner("コーパスをサンプリング中…"):
        texts, cats, fnames = chunk_texts(
            list(data.text),
            list(data.category),
            list(data.filename),
            sampling_form.value["chunk_size"],
        )
        if sampling_form.value["sample_frac"] < 1.0:
            N = len(texts)
            k = int(N * sampling_form.value["sample_frac"])
            idx = random.sample(range(N), k)
            texts = [texts[i] for i in idx]
            cats = [cats[i] for i in idx]
            fnames = [fnames[i] for i in idx]
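
        # random.sample draws k distinct chunk indices without replacement; thanks to
        # random.seed(RANDOM_SEED) in the setup cell, the same subsample is selected on
        # every rerun for a given chunk size and fraction.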
        corpus = build_corpus_cached(
            texts,
            cats,
        )
    return cats, corpus, fnames, texts


@app.cell
def sampling_controls(chunk_size):
    mo.md("トークン数を増やすと処理時間が長くなります").callout(
        kind="info"
    ) if chunk_size.value > 30_000 else None
    return


@app.cell
def plot_main_scatterplot(corpus, data_form, fnames):
    cat_name = data_form.value["category_name"]
    with mo.status.spinner("Scatterplot作成中…"):
        html = st.produce_scattertext_explorer(
            corpus,
            category=data_form.value["label_a"],
            category_name=f"{cat_name}: {data_form.value['label_a']}",
            not_category_name=f"{cat_name}: {data_form.value['label_b']}",
            width_in_pixels=1000,
            metadata=fnames,
        )
    mo.vstack(
        [
            mo.md(f"""
# Scattertextの結果

### Scattertext可視化の見方

- (縦)上に行くほど{data_form.value["label_a"]}で相対的に多く使われるトークン
- (横)右に行くほど{data_form.value["label_b"]}で相対的に多く使われるトークン

HTMLをダウンロードしてブラウザで開くと見やすいです。
"""),
            mo.iframe(html),
        ]
    )
    return (html,)


@app.cell
def _(html):
    download_button = mo.download(
        data=html.encode(),
        filename="scattertext_analysis.html",
        label="可視化結果をダウンロード",
    )
    mo.md(f"{download_button}")
    return


@app.cell
def classification_toggle():
    run_model = mo.ui.switch(label="分類モデルを適用する")
    run_model
    return (run_model,)


@app.cell
def _(run_model):
    mo.stop(not run_model.value)
    mo.md(
        r"""
# 分類モデルによる検証

2つのカテゴリを分類するモデルを学習し、それぞれのカテゴリを分ける有効な素性(単語)がどれなのかもScattertextで観察できます。
既定ではロジスティック回帰という機械学習モデルを使用します(下のドロップダウンで他のモデルも選択できます)。
"""
    )
    return


@app.cell
def _(cats, fnames, run_model, texts, train_scikit_cached):
    mo.stop(not run_model.value)
    scikit_corpus, tfidf_X, vectorizer, chunk_cats, chunk_fnames = train_scikit_cached(
        texts, cats, fnames
    )
    return chunk_cats, chunk_fnames, scikit_corpus, tfidf_X, vectorizer


@app.cell
def model_selection(run_model):
    mo.stop(not run_model.value)
    model_dropdown = mo.ui.dropdown(
        options=[
            "LogisticRegression",
            "RandomForestClassifier",
            "GradientBoostingClassifier",
        ],
        value="LogisticRegression",
        label="モデル選択",
    )
    model_dropdown
    return (model_dropdown,)


@app.cell
def hyperparameters(model_dropdown):
    lr_C = mo.ui.slider(0.01, 10.0, value=1.0, step=0.01, label="LR C")
    lr_max_iter = mo.ui.slider(100, 2000, value=1000, step=100, label="LR max_iter")
    rf_n = mo.ui.slider(10, 500, value=100, step=10, label="RF n_estimators")
    rf_max_depth = mo.ui.slider(1, 50, value=10, step=1, label="RF max_depth")
    gb_n = mo.ui.slider(10, 500, value=100, step=10, label="GB n_estimators")
    gb_lr = mo.ui.slider(0.01, 1.0, value=0.1, step=0.01, label="GB learning_rate")
    gb_md = mo.ui.slider(1, 10, value=3, step=1, label="GB max_depth")

    # Only the hyperparameters of the currently selected model are shown in the form.
    if model_dropdown.value == "LogisticRegression":
        widgets = {"lr_C": lr_C, "lr_max_iter": lr_max_iter}
    elif model_dropdown.value == "RandomForestClassifier":
        widgets = {"rf_n": rf_n, "rf_max_depth": rf_max_depth}
    else:  # GradientBoostingClassifier
        widgets = {"gb_n": gb_n, "gb_lr": gb_lr, "gb_md": gb_md}

    test_size = mo.ui.slider(0.1, 0.5, value=0.3, step=0.05, label="テストデータ比率")
    model_form = (
        mo.md("### モデルのパラメータ設定\n{widgets}\n{test_size}")
        .batch(
            widgets=mo.ui.dictionary(widgets),
            test_size=test_size,
        )
        .form(show_clear_button=True, bordered=False)
    )
    model_form
    return (model_form,)


@app.cell
def _(
    chunk_cats,
    label_a,
    label_b,
    model_dropdown,
    model_form,
    run_model,
    tfidf_X,
    vectorizer,
):
    mo.stop(not run_model.value or not model_form.value)
    import altair as alt
    from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import (
        auc,
        classification_report,
        confusion_matrix,
        roc_curve,
    )
    from sklearn.model_selection import train_test_split

    X_train, X_test, y_train, y_test = train_test_split(
        tfidf_X,
        chunk_cats,
        test_size=model_form.value["test_size"],
        random_state=RANDOM_SEED,
    )
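
    # The split is random but reproducible via random_state=RANDOM_SEED. If the two
    # categories were very unbalanced, passing stratify=chunk_cats here would keep the
    # class ratio similar across train and test; that is optional and not done here.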
    name = model_dropdown.value
    if name == "LogisticRegression":
        clf = LogisticRegression(
            C=model_form.value["widgets"]["lr_C"],
            max_iter=int(model_form.value["widgets"]["lr_max_iter"]),
        )
    elif name == "RandomForestClassifier":
        clf = RandomForestClassifier(
            n_estimators=int(model_form.value["widgets"]["rf_n"]),
            max_depth=int(model_form.value["widgets"]["rf_max_depth"]),
            random_state=RANDOM_SEED,
        )
    else:  # GradientBoostingClassifier
        clf = GradientBoostingClassifier(
            n_estimators=int(model_form.value["widgets"]["gb_n"]),
            learning_rate=float(model_form.value["widgets"]["gb_lr"]),
            max_depth=int(model_form.value["widgets"]["gb_md"]),
            random_state=RANDOM_SEED,
        )
    clf.fit(X_train, y_train)

    if hasattr(clf, "feature_importances_"):
        term_scores = clf.feature_importances_
    else:
        term_scores = abs(clf.coef_[0])
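
    # Per-term scores for the classifier-based Scattertext plot: tree ensembles expose
    # impurity-based feature_importances_, while binary LogisticRegression has a single
    # coefficient row whose absolute values are used, so larger always means "more
    # discriminative", regardless of which class the term favours.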

    y_pred = clf.predict(X_test)
    report = classification_report(y_test, y_pred, output_dict=True)
    cm = confusion_matrix(y_test, y_pred, labels=clf.classes_)
    cm_df = (
        pd.DataFrame(cm, index=clf.classes_, columns=clf.classes_)
        .reset_index()
        .melt(
            id_vars="index",
            var_name="Predicted",
            value_name="count",
        )
        .rename(columns={"index": "Actual"})
    )
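
    # Melting the square confusion matrix into long form (Actual, Predicted, count)
    # gives the shape that the Altair mark_rect heatmap below expects.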

    # ROC computation kept for reference; disabled because the task tends to be easy
    # enough that the ROC curve is not very informative (see roc_chart below).
    # pos_idx = list(clf.classes_).index(label_a.value)
    # _proba, roc_auc = None, None
    # roc_df = None
    # if hasattr(clf, "predict_proba"):
    #     probs = clf.predict_proba(X_test)[:, pos_idx]
    #     y_test_arr = np.array(y_test)
    #     fpr, tpr, _ = roc_curve((y_test_arr == label_a.value).astype(int), probs)
    #     roc_auc = auc(fpr, tpr)
    #     roc_df = pd.DataFrame({"fpr": fpr, "tpr": tpr})

    feature_names = vectorizer.get_feature_names_out()
    importances = (
        pd.DataFrame({"単語": feature_names, "重要度": term_scores})
        .sort_values("重要度", ascending=False)
        .head(20)
    )
    imp_chart = (
        alt.Chart(importances)
        .mark_bar()
        .encode(
            x=alt.X("重要度:Q", title="重要度"),
            y=alt.Y("単語:N", sort="-x"),
        )
        .properties(title="Top-20 重要特徴語", width=600, height=400)
    )
    cm_chart = (
        alt.Chart(cm_df)
        .mark_rect()
        .encode(
            x="Predicted:N",
            y="Actual:N",
            color=alt.Color("count:Q", title="Count"),
            tooltip=["Actual", "Predicted", "count"],
        )
        .properties(title="Confusion Matrix", width=250, height=250)
    )
    # roc_chart = (
    #     alt.Chart(roc_df)
    #     .mark_line(point=True)
    #     .encode(
    #         x=alt.X("fpr:Q", title="False Positive Rate"),
    #         y=alt.Y("tpr:Q", title="True Positive Rate"),
    #     )
    #     .properties(
    #         title=f"ROC Curve (AUC={roc_auc:.2f})",
    #         width=400,
    #         height=300,
    #     )
    # )
    mo.vstack(
        [
            mo.ui.altair_chart(imp_chart),
            mo.ui.altair_chart(cm_chart),
            # mo.ui.altair_chart(roc_chart),  # Turned out to not be too informative as task is too easy?
            mo.md(f"""
## テストセット上の分類性能

- {label_a.value}: 精度 {report[label_a.value]["precision"]:.2%}, 再現率 {report[label_a.value]["recall"]:.2%}
- {label_b.value}: 精度 {report[label_b.value]["precision"]:.2%}, 再現率 {report[label_b.value]["recall"]:.2%}
"""),
        ]
    )
    return (term_scores,)


@app.cell
def _(
    chunk_fnames,
    data_form,
    model_form,
    run_model,
    scikit_corpus,
    term_scores,
):
    mo.stop(not run_model.value or not model_form.value)
    with mo.status.spinner("分類モデルのScatterplotを作成中…"):
        scikit_html = st.produce_scattertext_explorer(
            corpus=scikit_corpus,
            category=data_form.value["label_a"],
            category_name=data_form.value["label_a"],
            not_category_name=data_form.value["label_b"],
            scores=term_scores,
            terms_to_include=st.AutoTermSelector.get_selected_terms(
                scikit_corpus, term_scores, 4000
            ),
            metadata=chunk_fnames,
            transform=lambda freqs, _index, total: freqs / total.sum(),
            rescale_x=lambda arr: arr,  # identity
            rescale_y=lambda arr: arr,  # identity
        )
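
    # `scores` ranks/colours terms by the classifier-derived importances, and
    # AutoTermSelector trims the display to at most ~4000 terms. The custom transform
    # plots raw relative frequencies, and the identity rescale_x/rescale_y keep those
    # values as coordinates instead of Scattertext's default rank-based rescaling.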
    mo.iframe(scikit_html)
    return


if __name__ == "__main__":
    app.run()