Spaces:
Running
Running
Commit
·
ed8f744
1
Parent(s):
14bdc44
Draw Lines between Cluster Centers
Browse files
app.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
import streamlit as st
|
2 |
import pandas as pd
|
|
|
3 |
from bokeh.plotting import figure
|
4 |
-
from bokeh.models import ColumnDataSource
|
|
|
5 |
from bokeh.palettes import Reds9, Blues9
|
6 |
from sklearn.decomposition import PCA
|
7 |
from sklearn.manifold import TSNE
|
@@ -17,7 +19,6 @@ TOOLTIPS = """
|
|
17 |
</div>
|
18 |
"""
|
19 |
|
20 |
-
|
21 |
def config_style():
|
22 |
st.markdown("""
|
23 |
<style>
|
@@ -28,171 +29,212 @@ def config_style():
|
|
28 |
""", unsafe_allow_html=True)
|
29 |
st.markdown('<h1 class="main-title">Merit Embeddings 馃帓馃搩馃弳</h1>', unsafe_allow_html=True)
|
30 |
st.markdown('<h2 class="sub-title">Donut 馃</h2>', unsafe_allow_html=True)
|
31 |
-
st.markdown(
|
32 |
-
"""
|
33 |
-
<p class="custom-text">
|
34 |
-
Se cargan ambas versiones de los embeddings y se aplica una reducci贸n dimensional sobre el conjunto combinado.
|
35 |
-
Los puntos de la versi贸n real se muestran como <strong>c铆rculos</strong> (tonos de rojo)
|
36 |
-
y los de la es_digital_seq como <strong>cuadrados</strong> (tonos de azul).
|
37 |
-
</p>
|
38 |
-
""", unsafe_allow_html=True)
|
39 |
-
|
40 |
|
41 |
def load_embeddings():
|
42 |
df_real = pd.read_csv("data/donut_de_Rodrigo_merit_secret_all_embeddings.csv")
|
43 |
df_es_digital_seq = pd.read_csv("data/donut_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
|
44 |
-
|
45 |
-
embeddings = {
|
46 |
-
"real": df_real,
|
47 |
-
"es-digital-seq": df_es_digital_seq
|
48 |
-
}
|
49 |
-
|
50 |
-
return embeddings
|
51 |
-
|
52 |
|
53 |
def reducer_selector(df_combined, embedding_cols):
|
54 |
-
|
55 |
-
reduction_method = st.selectbox("Seleccione m茅todo de reducci贸n:", options=["PCA", "t-SNE"])
|
56 |
all_embeddings = df_combined[embedding_cols].values
|
57 |
if reduction_method == "PCA":
|
58 |
reducer = PCA(n_components=2)
|
59 |
else:
|
60 |
reducer = TSNE(n_components=2, random_state=42, perplexity=30, learning_rate=200)
|
61 |
-
|
62 |
-
|
63 |
-
return reduced
|
64 |
-
|
65 |
|
66 |
def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping):
|
|
|
67 |
for label in selected_labels:
|
68 |
subset = df[df['label'] == label]
|
69 |
if subset.empty:
|
70 |
continue
|
71 |
source = ColumnDataSource(data=dict(
|
72 |
-
x
|
73 |
-
y
|
74 |
-
label
|
75 |
-
img
|
76 |
))
|
77 |
color = color_mapping[label]
|
78 |
if marker == "circle":
|
79 |
-
fig.circle('x', 'y', size=10, source=source,
|
80 |
-
|
81 |
-
|
82 |
elif marker == "square":
|
83 |
-
fig.square('x', 'y', size=
|
84 |
-
|
85 |
-
|
|
|
|
|
86 |
|
87 |
def get_color_maps(selected_subsets: dict):
|
88 |
-
|
89 |
-
# real
|
90 |
num_real = len(selected_subsets["real"])
|
91 |
-
if num_real <= 9:
|
92 |
-
red_palette = Reds9[:num_real]
|
93 |
-
else:
|
94 |
-
red_palette = (Reds9 * ((num_real // 9) + 1))[:num_real]
|
95 |
color_mapping_real = {label: red_palette[i] for i, label in enumerate(sorted(selected_subsets["real"]))}
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
else:
|
102 |
-
blue_palette = (Blues9 * ((num_es_digital_seq // 9) + 1))[:num_es_digital_seq]
|
103 |
-
color_mapping_es_digital_seq = {label: blue_palette[i] for i, label in enumerate(sorted(selected_subsets["es-digital-seq"]))}
|
104 |
-
|
105 |
-
# Gather color maps
|
106 |
-
color_maps = {
|
107 |
-
"real": color_mapping_real,
|
108 |
-
"es-digital-seq": color_mapping_es_digital_seq
|
109 |
-
}
|
110 |
-
|
111 |
-
return color_maps
|
112 |
-
|
113 |
|
114 |
def split_versions(df_combined, reduced):
|
115 |
-
|
116 |
df_combined['x'] = reduced[:, 0]
|
117 |
df_combined['y'] = reduced[:, 1]
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
unique_subsets_real = sorted(df_real_reduced['label'].unique().tolist())
|
124 |
-
unique_subsets_es_digital_seq = sorted(df_es_digital_seq_reduced['label'].unique().tolist())
|
125 |
-
|
126 |
-
unique_subsets = {
|
127 |
-
"real": unique_subsets_real,
|
128 |
-
"es-digital-seq": unique_subsets_es_digital_seq,
|
129 |
-
}
|
130 |
-
|
131 |
-
dfs_reduced = {
|
132 |
-
"real": df_real_reduced,
|
133 |
-
"es-digital-seq": df_es_digital_seq_reduced,
|
134 |
-
}
|
135 |
-
|
136 |
-
return dfs_reduced, unique_subsets
|
137 |
-
|
138 |
|
139 |
def subset_selectors(unique_subsets: dict):
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
default=unique_subsets["real"])
|
144 |
-
selected_subsets_es_digital_seq = st.multiselect("Seleccione subsets para visualizar (Sint茅tico):",
|
145 |
-
options=unique_subsets["es-digital-seq"],
|
146 |
-
default=unique_subsets["es-digital-seq"])
|
147 |
-
|
148 |
-
selected_subsets = {
|
149 |
-
"real": selected_subsets_real,
|
150 |
-
"es-digital-seq": selected_subsets_es_digital_seq
|
151 |
-
}
|
152 |
-
|
153 |
-
return selected_subsets
|
154 |
-
|
155 |
|
156 |
def create_figure(dfs_reduced, selected_subsets: dict, color_maps: dict):
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
marker="circle", color_mapping=color_maps["real"])
|
163 |
-
add_dataset_to_fig(fig, dfs_reduced["es-digital-seq"], selected_subsets["es-digital-seq"],
|
164 |
-
marker="square", color_mapping=color_maps["es-digital-seq"])
|
165 |
-
|
166 |
fig.legend.location = "top_right"
|
167 |
fig.legend.click_policy = "hide"
|
|
|
168 |
|
169 |
-
|
170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
171 |
|
172 |
def main():
|
173 |
-
|
174 |
config_style()
|
|
|
|
|
|
|
|
|
175 |
|
176 |
-
|
177 |
-
|
178 |
-
embeddings_dfs["real"]["version"] = "real"
|
179 |
-
embeddings_dfs["es-digital-seq"]["version"] = "es_digital_seq"
|
180 |
-
|
181 |
-
embedding_cols = [col for col in embeddings_dfs["real"].columns if col.startswith("dim_")]
|
182 |
-
|
183 |
-
# Combine dataframes to apply method reduction
|
184 |
-
df_combined = pd.concat([embeddings_dfs["real"], embeddings_dfs["es-digital-seq"]], ignore_index=True)
|
185 |
-
|
186 |
reduced = reducer_selector(df_combined, embedding_cols)
|
187 |
|
188 |
-
# Split back the different versions
|
189 |
dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
|
190 |
-
|
191 |
selected_subsets = subset_selectors(unique_subsets)
|
192 |
color_maps = get_color_maps(selected_subsets)
|
193 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
|
195 |
-
|
|
|
|
|
196 |
|
197 |
if __name__ == "__main__":
|
198 |
main()
|
|
|
1 |
import streamlit as st
|
2 |
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
from bokeh.plotting import figure
|
5 |
+
from bokeh.models import ColumnDataSource, DataTable, TableColumn, CustomJS, Select
|
6 |
+
from bokeh.layouts import row, column
|
7 |
from bokeh.palettes import Reds9, Blues9
|
8 |
from sklearn.decomposition import PCA
|
9 |
from sklearn.manifold import TSNE
|
|
|
19 |
</div>
|
20 |
"""
|
21 |
|
|
|
22 |
def config_style():
|
23 |
st.markdown("""
|
24 |
<style>
|
|
|
29 |
""", unsafe_allow_html=True)
|
30 |
st.markdown('<h1 class="main-title">Merit Embeddings 馃帓馃搩馃弳</h1>', unsafe_allow_html=True)
|
31 |
st.markdown('<h2 class="sub-title">Donut 馃</h2>', unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
def load_embeddings():
    """Load both embedding CSV files from the ``data/`` directory.

    Returns:
        dict: maps ``"real"`` and ``"es-digital-seq"`` to the DataFrame read
        from the corresponding CSV file.
    """
    csv_by_version = {
        "real": "data/donut_de_Rodrigo_merit_secret_all_embeddings.csv",
        "es-digital-seq": "data/donut_de_Rodrigo_merit_es-digital-seq_embeddings.csv",
    }
    # Dict order is insertion order, so the "real" file is still read first.
    return {version: pd.read_csv(path) for version, path in csv_by_version.items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
def reducer_selector(df_combined, embedding_cols):
    """Let the user pick a reduction method and project the embeddings to 2-D.

    Args:
        df_combined: DataFrame holding the embedding columns.
        embedding_cols: names of the columns that form the embedding vector.

    Returns:
        The result of ``fit_transform`` on the selected columns, using a
        2-component PCA or t-SNE model depending on the select box value.
    """
    method = st.selectbox("Select Dimensionality Reduction Method:", options=["PCA", "t-SNE"])
    features = df_combined[embedding_cols].values
    if method != "PCA":
        # Any non-PCA choice falls through to t-SNE with a fixed seed.
        model = TSNE(n_components=2, random_state=42, perplexity=30, learning_rate=200)
    else:
        model = PCA(n_components=2)
    return model.fit_transform(features)
|
|
|
|
|
|
|
46 |
|
47 |
def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping):
    """Draw one glyph renderer per selected label and return them by label.

    Args:
        fig: figure object exposing ``circle`` and ``square`` glyph methods.
        df: DataFrame with ``label``, ``x``, ``y`` and ``img`` columns.
        selected_labels: labels to draw; labels with no rows in ``df`` are
            skipped.
        marker: ``"circle"`` (legend suffix "Real") or ``"square"`` (legend
            suffix "Synthetic").
        color_mapping: maps each label to the color used for both the fill
            and the outline of its glyphs.

    Returns:
        dict: label -> renderer returned by the glyph method.

    Raises:
        ValueError: if ``marker`` is not one of the supported values.
            Previously an unknown marker left the renderer variable unbound
            (or silently reused the one from the previous label).
    """
    legend_suffix_by_marker = {"circle": "Real", "square": "Synthetic"}
    if marker not in legend_suffix_by_marker:
        raise ValueError(
            f"Unsupported marker {marker!r}; expected one of {sorted(legend_suffix_by_marker)}"
        )
    legend_suffix = legend_suffix_by_marker[marker]

    renderers = {}
    for label in selected_labels:
        subset = df[df['label'] == label]
        if subset.empty:
            continue
        source = ColumnDataSource(data=dict(
            x=subset['x'],
            y=subset['y'],
            label=subset['label'],
            img=subset['img'],
        ))
        color = color_mapping[label]
        # Resolved here (not before the loop) so nothing is touched on `fig`
        # when there is nothing to draw.
        draw = fig.circle if marker == "circle" else fig.square
        renderers[label] = draw('x', 'y', size=10, source=source,
                                fill_color=color, line_color=color,
                                legend_label=f"{label} ({legend_suffix})")
    return renderers
|
70 |
|
71 |
def get_color_maps(selected_subsets: dict):
    """Assign a palette color to every selected label of each version.

    Real labels take colors from ``Reds9`` and synthetic ("es-digital-seq")
    labels from ``Blues9``. Labels are sorted before colors are assigned, and
    the palette is repeated when more than nine labels are selected.

    Args:
        selected_subsets: dict with ``"real"`` and ``"es-digital-seq"`` keys,
            each holding the selected labels.

    Returns:
        dict: ``{"real": {label: color}, "es-digital-seq": {label: color}}``.
    """
    def _assign(palette, labels):
        count = len(labels)
        if count > 9:
            # Tile the 9-color palette until it covers every label.
            shades = (palette * ((count // 9) + 1))[:count]
        else:
            shades = palette[:count]
        return dict(zip(sorted(labels), shades))

    real_colors = _assign(Reds9, selected_subsets["real"])
    synthetic_colors = _assign(Blues9, selected_subsets["es-digital-seq"])
    return {"real": real_colors, "es-digital-seq": synthetic_colors}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
def split_versions(df_combined, reduced):
    """Attach the 2-D coordinates and split the rows back by version.

    ``df_combined`` is modified in place: it gains ``x`` and ``y`` columns
    taken from the first and second column of ``reduced``.

    Args:
        df_combined: DataFrame with ``version`` and ``label`` columns.
        reduced: 2-D array indexed as ``reduced[:, 0]`` / ``reduced[:, 1]``.

    Returns:
        tuple: ``(frames, labels)`` where both are dicts keyed by ``"real"``
        and ``"es-digital-seq"``; ``frames`` holds a copy of the matching
        rows and ``labels`` the sorted unique labels of those rows.
    """
    df_combined['x'] = reduced[:, 0]
    df_combined['y'] = reduced[:, 1]

    frames = {}
    labels = {}
    # The output key uses hyphens while the stored version tag uses underscores.
    for key, version_tag in (("real", "real"), ("es-digital-seq", "es_digital_seq")):
        part = df_combined[df_combined["version"] == version_tag].copy()
        frames[key] = part
        labels[key] = sorted(part['label'].unique().tolist())
    return frames, labels
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
def subset_selectors(unique_subsets: dict):
    """Render one multiselect per version, with every subset pre-selected.

    Args:
        unique_subsets: dict with ``"real"`` and ``"es-digital-seq"`` keys,
            each holding the available labels.

    Returns:
        dict: the same two keys mapped to the values returned by the
        corresponding ``st.multiselect`` widget.
    """
    prompts = (
        ("real", "Select Real Subsets:"),
        ("es-digital-seq", "Select Synthetic Subsets:"),
    )
    chosen = {}
    for key, prompt in prompts:
        available = unique_subsets[key]
        chosen[key] = st.multiselect(prompt, options=available, default=available)
    return chosen
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
|
96 |
def create_figure(dfs_reduced, selected_subsets: dict, color_maps: dict):
    """Build the scatter figure with real (circle) and synthetic (square) points.

    Args:
        dfs_reduced: dict of DataFrames keyed by ``"real"`` / ``"es-digital-seq"``.
        selected_subsets: dict of selected labels, same keys.
        color_maps: dict of ``{label: color}`` mappings, same keys.

    Returns:
        tuple: ``(fig, real_renderers, synthetic_renderers)``; the last two
        are the values returned by ``add_dataset_to_fig`` for each version.
    """
    fig = figure(width=600, height=600, tooltips=TOOLTIPS, title="")

    renderers_by_version = {}
    for version, marker in (("real", "circle"), ("es-digital-seq", "square")):
        renderers_by_version[version] = add_dataset_to_fig(
            fig,
            dfs_reduced[version],
            selected_subsets[version],
            marker=marker,
            color_mapping=color_maps[version],
        )

    fig.legend.location = "top_right"
    fig.legend.click_policy = "hide"
    return fig, renderers_by_version["real"], renderers_by_version["es-digital-seq"]
|
105 |
|
106 |
+
def calculate_cluster_centers(df: pd.DataFrame, selected_labels: list) -> dict:
    """Compute the mean ``(x, y)`` position of every selected label.

    Args:
        df: DataFrame with ``label``, ``x`` and ``y`` columns.
        selected_labels: labels to compute a center for.

    Returns:
        dict: label -> ``(mean_x, mean_y)``. Labels with no rows in ``df``
        are left out.
    """
    centers = {}
    for label in selected_labels:
        points = df.loc[df['label'] == label, ['x', 'y']]
        if points.empty:
            continue
        centers[label] = (points['x'].mean(), points['y'].mean())
    return centers
|
113 |
+
|
114 |
+
def compute_distances(centers_es: dict, centers_real: dict) -> pd.DataFrame:
    """Build the table of Euclidean distances between cluster centers.

    Args:
        centers_es: synthetic label -> ``(x, y)`` center.
        centers_real: real label -> ``(x, y)`` center.

    Returns:
        DataFrame with one row per synthetic label and one column per real
        label, each cell holding the distance between the two centers.
    """
    by_synthetic = {
        es_label: {
            real_label: np.sqrt((x_es - x_real)**2 + (y_es - y_real)**2)
            for real_label, (x_real, y_real) in centers_real.items()
        }
        for es_label, (x_es, y_es) in centers_es.items()
    }
    # A dict of dicts yields synthetic labels as columns; transpose so they become rows.
    return pd.DataFrame(by_synthetic).T
|
121 |
|
122 |
def main():
    """Run the Streamlit app: reduce embeddings, plot them, and link cluster centers.

    Loads both embedding sets, projects them together to 2-D, draws the
    selected subsets, and shows a distance table between synthetic and real
    cluster centers. Selecting a table row (synthetic subset) together with
    the dropdown value (real subset) draws a line between the two centers
    and grays out every other subset.

    If no real subset or no synthetic subset has a cluster center (for
    example because the user deselected all of them), only the figure is
    shown: the distance table would have no real columns and the dropdown
    would have no initial value.
    """
    config_style()
    embeddings = load_embeddings()
    embeddings["real"]["version"] = "real"
    embeddings["es-digital-seq"]["version"] = "es_digital_seq"
    embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]

    df_combined = pd.concat([embeddings["real"], embeddings["es-digital-seq"]], ignore_index=True)
    reduced = reducer_selector(df_combined, embedding_cols)

    dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
    selected_subsets = subset_selectors(unique_subsets)
    color_maps = get_color_maps(selected_subsets)
    fig, real_renderers, synthetic_renderers = create_figure(dfs_reduced, selected_subsets, color_maps)

    centers_real = calculate_cluster_centers(dfs_reduced["real"], selected_subsets["real"])
    centers_es = calculate_cluster_centers(dfs_reduced["es-digital-seq"], selected_subsets["es-digital-seq"])

    # Guard: with either side empty, `real_subset_names[0]` below would raise IndexError.
    if not centers_real or not centers_es:
        st.info("Select at least one real and one synthetic subset to compare cluster centers.")
        st.bokeh_chart(fig)
        return

    df_distances = compute_distances(centers_es, centers_real)

    # Distance table: one row per synthetic subset, one column per real subset.
    df_table = df_distances.copy()
    df_table.reset_index(inplace=True)
    df_table.rename(columns={'index': 'Synthetic'}, inplace=True)
    source_table = ColumnDataSource(df_table)
    columns = [TableColumn(field='Synthetic', title='Synthetic')]
    columns.extend(
        TableColumn(field=col, title=col)
        for col in df_table.columns
        if col != 'Synthetic'
    )
    # The callback below reacts to the row selected in this table.
    data_table = DataTable(source=source_table, columns=columns, width=400, height=300)

    # Dropdown for the real subset: every table column except 'Synthetic'.
    real_subset_names = list(df_table.columns[1:])
    real_select = Select(title="Select Real Subset:", value=real_subset_names[0], options=real_subset_names)

    # Data source for the line connecting the two selected centers; starts empty.
    line_source = ColumnDataSource(data={'x': [], 'y': []})
    fig.line('x', 'y', source=line_source, line_width=2, line_color='black')

    # Centers as plain lists so they can be passed to the JS callback.
    synthetic_centers_js = {k: [v[0], v[1]] for k, v in centers_es.items()}
    real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}

    # Updates the line and the glyph colors from the selected row and the dropdown value.
    callback = CustomJS(args=dict(source=source_table, line_source=line_source,
                                  synthetic_centers=synthetic_centers_js,
                                  real_centers=real_centers_js,
                                  synthetic_renderers=synthetic_renderers,
                                  real_renderers=real_renderers,
                                  synthetic_colors=color_maps["es-digital-seq"],
                                  real_colors=color_maps["real"],
                                  real_select=real_select),
                        code="""
        var selected = source.selected.indices;
        if (selected.length > 0) {
            var row = selected[0];
            var data = source.data;
            var synthetic_label = data['Synthetic'][row];
            var real_label = real_select.value;
            var syn_coords = synthetic_centers[synthetic_label];
            var real_coords = real_centers[real_label];
            line_source.data = { 'x': [syn_coords[0], real_coords[0]], 'y': [syn_coords[1], real_coords[1]] };
            line_source.change.emit();

            // Actualizar colores: resaltar 煤nicamente los puntos implicados
            for (var key in synthetic_renderers) {
                if (synthetic_renderers.hasOwnProperty(key)) {
                    var renderer = synthetic_renderers[key];
                    if (key === synthetic_label) {
                        renderer.glyph.fill_color = synthetic_colors[key];
                        renderer.glyph.line_color = synthetic_colors[key];
                    } else {
                        renderer.glyph.fill_color = "lightgray";
                        renderer.glyph.line_color = "lightgray";
                    }
                }
            }
            for (var key in real_renderers) {
                if (real_renderers.hasOwnProperty(key)) {
                    var renderer = real_renderers[key];
                    if (key === real_label) {
                        renderer.glyph.fill_color = real_colors[key];
                        renderer.glyph.line_color = real_colors[key];
                    } else {
                        renderer.glyph.fill_color = "lightgray";
                        renderer.glyph.line_color = "lightgray";
                    }
                }
            }
        } else {
            // Sin selecci贸n: reiniciar l铆nea y colores
            line_source.data = { 'x': [], 'y': [] };
            line_source.change.emit();
            for (var key in synthetic_renderers) {
                if (synthetic_renderers.hasOwnProperty(key)) {
                    var renderer = synthetic_renderers[key];
                    renderer.glyph.fill_color = synthetic_colors[key];
                    renderer.glyph.line_color = synthetic_colors[key];
                }
            }
            for (var key in real_renderers) {
                if (real_renderers.hasOwnProperty(key)) {
                    var renderer = real_renderers[key];
                    renderer.glyph.fill_color = real_colors[key];
                    renderer.glyph.line_color = real_colors[key];
                }
            }
        }
    """)

    # Run the callback on both row-selection changes and dropdown changes.
    source_table.selected.js_on_change('indices', callback)
    real_select.js_on_change('value', callback)

    # Layout: figure on the left, dropdown above the table on the right.
    layout = row(fig, column(real_select, data_table))
    st.bokeh_chart(layout)
|
238 |
|
239 |
# Script entry point: run the app only when this file is executed directly, not on import.
if __name__ == "__main__":
|
240 |
main()
|