Bor Hodošček committed on
Commit
d479194
·
unverified ·
1 Parent(s): e8b9f04

fix: actually use tfvectorizer; feat: improved plots

Browse files
Files changed (3) hide show
  1. app.py +103 -90
  2. pyproject.toml +1 -0
  3. uv.lock +15 -0
app.py CHANGED
@@ -8,6 +8,7 @@
8
  # "numpy==2.2.6",
9
  # "pandas==2.3.0",
10
  # "pca==2.10.0",
 
11
  # "pyarrow",
12
  # "scattertext==0.2.2",
13
  # "scikit-learn==1.7.0",
@@ -34,13 +35,12 @@ with app.setup:
34
  import numpy as np
35
  import random
36
  import re
37
- import altair as alt
38
  import scattertext as st
39
  from pca import pca
40
  import matplotlib.pyplot as plt
41
  from pathlib import Path
42
  from types import SimpleNamespace
43
- from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
44
 
45
  RANDOM_SEED = 42
46
  random.seed(RANDOM_SEED)
@@ -92,20 +92,28 @@ def function_export():
92
  chunk_size: int = 2000,
93
  ) -> tuple[list[str], list[str], list[str]]:
94
  """Chunk each text into segments of chunk_size tokens, preserving category and filename."""
95
- chunked_texts = []
96
- chunked_cats = []
97
- chunked_fnames = []
98
  for text, cat, fname in zip(texts, categories, filenames):
 
 
 
 
 
 
 
 
99
  tokens = text.split()
100
  for i in range(0, len(tokens), chunk_size):
101
  chunk = " ".join(tokens[i : i + chunk_size])
102
  chunked_texts.append(chunk)
103
  chunked_cats.append(cat)
104
- chunked_fnames.append(f"{fname}#{i // chunk_size + 1}")
105
- else: # chunk_size is larger then the text
106
  chunked_texts.append(chunk)
107
  chunked_cats.append(cat)
108
- chunked_fnames.append(f"{fname}#leftover")
109
  return chunked_texts, chunked_cats, chunked_fnames
110
 
111
  @mo.cache
@@ -113,7 +121,9 @@ def function_export():
113
  texts: list[str],
114
  categories: list[str],
115
  filenames: list[str],
116
- max_features: int = 100,
 
 
117
  ) -> tuple[
118
  st.Corpus,
119
  scipy.sparse.spmatrix,
@@ -124,15 +134,14 @@ def function_export():
124
  """Fit TF-IDF + CountVectorizer & build a st.Corpus on already‐chunked data."""
125
 
126
  # texts, categories, filenames are assumed already chunked upstream
127
- tfv = TfidfVectorizer()
128
  X_tfidf = tfv.fit_transform(texts)
129
- cv = CountVectorizer(vocabulary=tfv.vocabulary_, max_features=max_features)
130
  y_codes = pd.Categorical(
131
  categories, categories=pd.Categorical(categories).categories
132
  ).codes
133
 
134
  scikit_corpus = st.CorpusFromScikit(
135
- X=cv.fit_transform(texts),
136
  y=y_codes,
137
  feature_vocabulary=tfv.vocabulary_,
138
  category_names=list(pd.Categorical(categories).categories),
@@ -570,6 +579,9 @@ def _():
570
  # 探索的検証
571
 
572
  クラスター分析のデンドログラムと主成分分析(biplot)による探索的検証を行います。
 
 
 
573
  """
574
  )
575
  return
@@ -608,108 +620,109 @@ def pca_biplot(chunk_cats, tfidf_X, vectorizer):
608
  row_labels=chunk_cats,
609
  )
610
 
611
- model.biplot(legend=True, figsize=(12, 8), PC=[0, 1])
 
 
 
 
 
 
 
 
 
 
 
 
 
612
  # labels=np.array(chunk_fnames)
613
  topfeat = results["topfeat"]
614
 
615
  mo.vstack(
616
  [
617
  mo.md(
618
- "## [PCA](https://erdogant.github.io/pca/pages/html/index.html)によるbiplot"
 
619
  ),
620
  mo.mpl.interactive(plt.gcf()),
621
  topfeat,
622
  ]
623
  )
624
- return (X,)
625
 
626
 
627
  @app.cell
628
- def _(X, chunk_fnames):
629
- import scipy.cluster.hierarchy as sch
630
- import scipy.spatial.distance as ssd
631
-
632
- # 2. compute linkage on cosine distance
633
- dists = ssd.pdist(X, metric="cosine")
634
- Z = sch.linkage(dists, method="average")
635
-
636
- # 3. get a truncated dendrogram (no_plot=True just to get data)
637
- # Use our filenames for leaf labels
638
- den = sch.dendrogram(
639
- Z,
640
- no_plot=True,
641
- truncate_mode="level",
642
- p=3,
643
- labels=chunk_fnames,
644
- )
645
-
646
- # 4. helpers to reshape the SciPy output
647
- def get_leaf_loc(den):
648
- # leaves are spaced every 10 units in icoord
649
- mn = int(np.min(den["icoord"]))
650
- mx = int(np.max(den["icoord"]) + 1)
651
- return list(range(mn, mx, 10))
652
-
653
- def get_df_coord(den):
654
- cols_x = ["xk1", "xk2", "xk3", "xk4"]
655
- cols_y = ["yk1", "yk2", "yk3", "yk4"]
656
- dfx = pd.DataFrame(den["icoord"], columns=cols_x)
657
- dfy = pd.DataFrame(den["dcoord"], columns=cols_y)
658
- return dfx.merge(dfy, left_index=True, right_index=True)
659
-
660
- source = get_df_coord(den)
661
-
662
- # 5. build the U‐shapes with three mark_rule layers
663
- base = alt.Chart(source)
664
- shoulder = base.mark_rule().encode(
665
- alt.X("xk2:Q", title=""),
666
- alt.X2("xk3:Q"),
667
- alt.Y("yk2:Q", title=""),
668
- )
669
- arm1 = base.mark_rule().encode(
670
- alt.X("xk1:Q"),
671
- alt.Y("yk1:Q"),
672
- alt.Y2("yk2:Q"),
673
  )
674
- arm2 = base.mark_rule().encode(
675
- alt.X("xk3:Q"),
676
- alt.Y("yk3:Q"),
677
- alt.Y2("yk4:Q"),
678
  )
679
- chart_den = shoulder + arm1 + arm2
680
-
681
- # 6. leaf labels
682
- # den["ivl"] now contains the correct filenames for each displayed leaf
683
- df_text = pd.DataFrame({
684
- "labels": den["ivl"],
685
- "x": get_leaf_loc(den),
686
- })
687
- chart_text = (
688
- alt.Chart(df_text)
689
- .mark_text(dy=0, angle=0, align="center")
690
- .encode(
691
- x=alt.X("x:Q", axis={"grid": False, "title": "Leaf nodes"}),
692
- text=alt.Text("labels:N"),
693
- )
694
  )
695
 
696
- # 7. combine and configure
697
- final = (
698
- (chart_den & chart_text)
699
- .resolve_scale(x="shared")
700
- .configure(padding={"top": 10, "left": 10})
701
- .configure_concat(spacing=0)
702
- .configure_axis(labels=False, ticks=False, grid=False)
703
- .properties(title="Hierarchical Clustering Dendrogram")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
704
  )
 
705
 
706
- # 8. hand off to Marimo
707
- mo.ui.altair_chart(final)
708
  return
709
 
710
 
711
  @app.cell
712
- def _():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
713
  return
714
 
715
 
 
8
  # "numpy==2.2.6",
9
  # "pandas==2.3.0",
10
  # "pca==2.10.0",
11
+ # "plotly==6.2.0",
12
  # "pyarrow",
13
  # "scattertext==0.2.2",
14
  # "scikit-learn==1.7.0",
 
35
  import numpy as np
36
  import random
37
  import re
 
38
  import scattertext as st
39
  from pca import pca
40
  import matplotlib.pyplot as plt
41
  from pathlib import Path
42
  from types import SimpleNamespace
43
+ from sklearn.feature_extraction.text import TfidfVectorizer
44
 
45
  RANDOM_SEED = 42
46
  random.seed(RANDOM_SEED)
 
92
  chunk_size: int = 2000,
93
  ) -> tuple[list[str], list[str], list[str]]:
94
  """Chunk each text into segments of chunk_size tokens, preserving category and filename."""
95
+ chunked_texts: list[str] = []
96
+ chunked_cats: list[str] = []
97
+ chunked_fnames: list[str] = []
98
  for text, cat, fname in zip(texts, categories, filenames):
99
+ # compute a short “Initials‐Initials” label for author‐title
100
+ stem = Path(fname).stem.replace("_advanced", "")
101
+ author, title = stem.split("_", 1)
102
+
103
+ def _initials(s: str) -> str:
104
+ return "".join(tok[0].upper() for tok in s.split("-"))
105
+
106
+ short_label = f"{_initials(author)}-{_initials(title)}"
107
  tokens = text.split()
108
  for i in range(0, len(tokens), chunk_size):
109
  chunk = " ".join(tokens[i : i + chunk_size])
110
  chunked_texts.append(chunk)
111
  chunked_cats.append(cat)
112
+ chunked_fnames.append(f"{short_label}({cat})#{i // chunk_size + 1}")
113
+ else:
114
  chunked_texts.append(chunk)
115
  chunked_cats.append(cat)
116
+ chunked_fnames.append(f"{short_label}({cat})#last")
117
  return chunked_texts, chunked_cats, chunked_fnames
118
 
119
  @mo.cache
 
121
  texts: list[str],
122
  categories: list[str],
123
  filenames: list[str],
124
+ min_df: float = 0.25,
125
+ max_df: float = 0.8,
126
+ max_features: int = 200,
127
  ) -> tuple[
128
  st.Corpus,
129
  scipy.sparse.spmatrix,
 
134
  """Fit TF-IDF + CountVectorizer & build a st.Corpus on already‐chunked data."""
135
 
136
  # texts, categories, filenames are assumed already chunked upstream
137
+ tfv = TfidfVectorizer(min_df=min_df, max_df=max_df, max_features=max_features)
138
  X_tfidf = tfv.fit_transform(texts)
 
139
  y_codes = pd.Categorical(
140
  categories, categories=pd.Categorical(categories).categories
141
  ).codes
142
 
143
  scikit_corpus = st.CorpusFromScikit(
144
+ X=tfv.fit_transform(texts),
145
  y=y_codes,
146
  feature_vocabulary=tfv.vocabulary_,
147
  category_names=list(pd.Categorical(categories).categories),
 
579
  # 探索的検証
580
 
581
  クラスター分析のデンドログラムと主成分分析(biplot)による探索的検証を行います。
582
+
583
+ Biplotでは各テキストが丸点で、各素性が矢印で同じプロットで示されています。
584
+ 矢印の色が赤の場合、その素性の負荷量の絶対値が高く、色が青の場合は、どの主成分でも高くないという意味になります。
585
  """
586
  )
587
  return
 
620
  row_labels=chunk_cats,
621
  )
622
 
623
+ three_switch = mo.ui.switch(label="3D")
624
+ three_switch
625
+ return X, model, results, three_switch
626
+
627
+
628
+ @app.cell
629
+ def _(model, results, three_switch):
630
+ model.biplot(
631
+ legend=True,
632
+ figsize=(12, 8),
633
+ fontsize=12,
634
+ s=20,
635
+ PC=[0, 1, 2] if three_switch.value else [0, 1],
636
+ )
637
  # labels=np.array(chunk_fnames)
638
  topfeat = results["topfeat"]
639
 
640
  mo.vstack(
641
  [
642
  mo.md(
643
+ """## [PCA](https://erdogant.github.io/pca/pages/html/index.html)biplot
644
+ """
645
  ),
646
  mo.mpl.interactive(plt.gcf()),
647
  topfeat,
648
  ]
649
  )
650
+ return
651
 
652
 
653
  @app.cell
654
+ def _():
655
+ linkage_methods = mo.ui.dropdown(
656
+ options=[
657
+ "ward",
658
+ "single",
659
+ "complete",
660
+ "average",
661
+ ],
662
+ value="ward",
663
+ label="Linkage Method",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
664
  )
665
+ distance_metrics = mo.ui.dropdown(
666
+ options=["cosine", "euclidean", "cityblock", "hamming"],
667
+ value="cosine",
668
+ label="Distance Metric",
669
  )
670
+ dendrogram_height = mo.ui.number(
671
+ label="Dendrogram plot height (increase if hard to see labels)",
672
+ start=800,
673
+ value=1600,
 
 
 
 
 
 
 
 
 
 
 
674
  )
675
 
676
+ d_stack = mo.hstack([linkage_methods, distance_metrics], justify="start")
677
+
678
+ mo.md(f"""
679
+ ## 階層的クラスタリング
680
+
681
+ {d_stack}
682
+ {dendrogram_height}
683
+ """)
684
+ return dendrogram_height, distance_metrics, linkage_methods
685
+
686
+
687
+ @app.cell
688
+ def _(X, chunk_fnames, dendrogram_height, distance_metrics, linkage_methods):
689
+ import plotly.figure_factory as ff
690
+ import scipy.spatial.distance as ssd
691
+ import scipy.cluster.hierarchy as sch
692
+
693
+ distfun = lambda M: ssd.pdist(M, metric=distance_metrics.value)
694
+ linkagefun = lambda D: sch.linkage(D, method=linkage_methods.value)
695
+
696
+ fig = ff.create_dendrogram(
697
+ X,
698
+ orientation="left",
699
+ labels=list(chunk_fnames),
700
+ distfun=distfun,
701
+ linkagefun=linkagefun,
702
  )
703
+ fig.update_layout(width=800, height=dendrogram_height.value)
704
 
705
+ mo.ui.plotly(fig)
 
706
  return
707
 
708
 
709
  @app.cell
710
+ def sample_selector(fnames):
711
+ text_selector = mo.ui.dropdown(
712
+ options=list(sorted(fnames)),
713
+ value=fnames[0] if fnames else None,
714
+ label="Select a sample to view",
715
+ )
716
+ text_selector
717
+ return (text_selector,)
718
+
719
+
720
+ @app.cell
721
+ def sample_viewer(fnames, text_selector, texts):
722
+ mo.stop(not text_selector.value, "No sample selected.")
723
+
724
+ selected_idx = fnames.index(text_selector.value)
725
+ mo.md(f"### {text_selector.value}\n\n{texts[selected_idx]}")
726
  return
727
 
728
 
pyproject.toml CHANGED
@@ -11,6 +11,7 @@ dependencies = [
11
  "numpy>=2.2.6",
12
  "pandas>=2.3.0",
13
  "pca>=2.10.0",
 
14
  "pyarrow>=20.0.0",
15
  "scattertext==0.2.2",
16
  "scikit-learn==1.7.0",
 
11
  "numpy>=2.2.6",
12
  "pandas>=2.3.0",
13
  "pca>=2.10.0",
14
+ "plotly>=6.2.0",
15
  "pyarrow>=20.0.0",
16
  "scattertext==0.2.2",
17
  "scikit-learn==1.7.0",
uv.lock CHANGED
@@ -914,6 +914,19 @@ wheels = [
914
  { url = "https://files.pythonhosted.org/packages/67/32/32dc030cfa91ca0fc52baebbba2e009bb001122a1daa8b6a79ad830b38d3/pillow-11.2.1-cp313-cp313t-win_arm64.whl", hash = "sha256:225c832a13326e34f212d2072982bb1adb210e0cc0b153e688743018c94a2681", size = 2417234, upload-time = "2025-04-12T17:49:08.399Z" },
915
  ]
916
 
 
 
 
 
 
 
 
 
 
 
 
 
 
917
  [[package]]
918
  name = "preshed"
919
  version = "3.0.10"
@@ -1262,6 +1275,7 @@ dependencies = [
1262
  { name = "numpy" },
1263
  { name = "pandas" },
1264
  { name = "pca" },
 
1265
  { name = "pyarrow" },
1266
  { name = "scattertext" },
1267
  { name = "scikit-learn" },
@@ -1278,6 +1292,7 @@ requires-dist = [
1278
  { name = "numpy", specifier = ">=2.2.6" },
1279
  { name = "pandas", specifier = ">=2.3.0" },
1280
  { name = "pca", specifier = ">=2.10.0" },
 
1281
  { name = "pyarrow", specifier = ">=20.0.0" },
1282
  { name = "scattertext", specifier = "==0.2.2" },
1283
  { name = "scikit-learn", specifier = "==1.7.0" },
 
914
  { url = "https://files.pythonhosted.org/packages/67/32/32dc030cfa91ca0fc52baebbba2e009bb001122a1daa8b6a79ad830b38d3/pillow-11.2.1-cp313-cp313t-win_arm64.whl", hash = "sha256:225c832a13326e34f212d2072982bb1adb210e0cc0b153e688743018c94a2681", size = 2417234, upload-time = "2025-04-12T17:49:08.399Z" },
915
  ]
916
 
917
+ [[package]]
918
+ name = "plotly"
919
+ version = "6.2.0"
920
+ source = { registry = "https://pypi.org/simple" }
921
+ dependencies = [
922
+ { name = "narwhals" },
923
+ { name = "packaging" },
924
+ ]
925
+ sdist = { url = "https://files.pythonhosted.org/packages/6e/5c/0efc297df362b88b74957a230af61cd6929f531f72f48063e8408702ffba/plotly-6.2.0.tar.gz", hash = "sha256:9dfa23c328000f16c928beb68927444c1ab9eae837d1fe648dbcda5360c7953d", size = 6801941, upload-time = "2025-06-26T16:20:45.765Z" }
926
+ wheels = [
927
+ { url = "https://files.pythonhosted.org/packages/ed/20/f2b7ac96a91cc5f70d81320adad24cc41bf52013508d649b1481db225780/plotly-6.2.0-py3-none-any.whl", hash = "sha256:32c444d4c940887219cb80738317040363deefdfee4f354498cc0b6dab8978bd", size = 9635469, upload-time = "2025-06-26T16:20:40.76Z" },
928
+ ]
929
+
930
  [[package]]
931
  name = "preshed"
932
  version = "3.0.10"
 
1275
  { name = "numpy" },
1276
  { name = "pandas" },
1277
  { name = "pca" },
1278
+ { name = "plotly" },
1279
  { name = "pyarrow" },
1280
  { name = "scattertext" },
1281
  { name = "scikit-learn" },
 
1292
  { name = "numpy", specifier = ">=2.2.6" },
1293
  { name = "pandas", specifier = ">=2.3.0" },
1294
  { name = "pca", specifier = ">=2.10.0" },
1295
+ { name = "plotly", specifier = ">=6.2.0" },
1296
  { name = "pyarrow", specifier = ">=20.0.0" },
1297
  { name = "scattertext", specifier = "==0.2.2" },
1298
  { name = "scikit-learn", specifier = "==1.7.0" },