Spaces:
Sleeping
Sleeping
Commit
·
6b1f66d
1
Parent(s):
b961047
Multiple Dataset Versions
Browse files
app.py
CHANGED
|
@@ -3,11 +3,12 @@ import pandas as pd
|
|
| 3 |
import numpy as np
|
| 4 |
from bokeh.plotting import figure
|
| 5 |
from bokeh.models import ColumnDataSource, DataTable, TableColumn, CustomJS, Select, Button
|
| 6 |
-
from bokeh.layouts import
|
| 7 |
-
from bokeh.palettes import Reds9, Blues9
|
| 8 |
from sklearn.decomposition import PCA
|
| 9 |
from sklearn.manifold import TSNE
|
| 10 |
import io
|
|
|
|
| 11 |
|
| 12 |
TOOLTIPS = """
|
| 13 |
<div>
|
|
@@ -30,20 +31,31 @@ def config_style():
|
|
| 30 |
""", unsafe_allow_html=True)
|
| 31 |
st.markdown('<h1 class="main-title">Merit Embeddings 🎒📃🏆</h1>', unsafe_allow_html=True)
|
| 32 |
|
| 33 |
-
#
|
| 34 |
def load_embeddings(model):
|
| 35 |
if model == "Donut":
|
| 36 |
df_real = pd.read_csv("data/donut_de_Rodrigo_merit_secret_all_embeddings.csv")
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
elif model == "Idefics2":
|
| 39 |
df_real = pd.read_csv("data/idefics2_de_Rodrigo_merit_secret_britanico_embeddings.csv")
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
else:
|
| 42 |
st.error("Modelo no reconocido")
|
| 43 |
return None
|
| 44 |
-
return {"real": df_real, "es-digital-seq": df_es_digital_seq}
|
| 45 |
|
| 46 |
-
#
|
| 47 |
def reducer_selector(df_combined, embedding_cols):
|
| 48 |
reduction_method = st.selectbox("Select Dimensionality Reduction Method:", options=["PCA", "t-SNE"])
|
| 49 |
all_embeddings = df_combined[embedding_cols].values
|
|
@@ -53,7 +65,8 @@ def reducer_selector(df_combined, embedding_cols):
|
|
| 53 |
reducer = TSNE(n_components=2, random_state=42, perplexity=30, learning_rate=200)
|
| 54 |
return reducer.fit_transform(all_embeddings)
|
| 55 |
|
| 56 |
-
|
|
|
|
| 57 |
renderers = {}
|
| 58 |
for label in selected_labels:
|
| 59 |
subset = df[df['label'] == label]
|
|
@@ -63,112 +76,153 @@ def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping):
|
|
| 63 |
x=subset['x'],
|
| 64 |
y=subset['y'],
|
| 65 |
label=subset['label'],
|
| 66 |
-
img=subset
|
| 67 |
))
|
| 68 |
color = color_mapping[label]
|
|
|
|
|
|
|
| 69 |
if marker == "circle":
|
| 70 |
r = fig.circle('x', 'y', size=10, source=source,
|
| 71 |
fill_color=color, line_color=color,
|
| 72 |
-
legend_label=
|
| 73 |
elif marker == "square":
|
| 74 |
-
r = fig.square('x', 'y', size=
|
| 75 |
fill_color=color, line_color=color,
|
| 76 |
-
legend_label=
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
return renderers
|
| 79 |
|
| 80 |
-
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
| 82 |
red_palette = Reds9[:num_real] if num_real <= 9 else (Reds9 * ((num_real // 9) + 1))[:num_real]
|
| 83 |
-
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
|
|
|
|
|
|
| 90 |
|
|
|
|
| 91 |
def split_versions(df_combined, reduced):
|
| 92 |
df_combined['x'] = reduced[:, 0]
|
| 93 |
df_combined['y'] = reduced[:, 1]
|
| 94 |
df_real = df_combined[df_combined["version"] == "real"].copy()
|
| 95 |
-
|
|
|
|
| 96 |
unique_real = sorted(df_real['label'].unique().tolist())
|
| 97 |
-
|
| 98 |
-
|
|
|
|
|
|
|
| 99 |
|
| 100 |
-
|
|
|
|
| 101 |
fig = figure(width=400, height=400, tooltips=TOOLTIPS, title="")
|
| 102 |
-
real_renderers = add_dataset_to_fig(fig,
|
| 103 |
-
marker="circle", color_mapping=color_maps["real"]
|
| 104 |
-
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
fig.legend.location = "top_right"
|
| 107 |
fig.legend.click_policy = "hide"
|
| 108 |
return fig, real_renderers, synthetic_renderers
|
| 109 |
|
| 110 |
-
|
|
|
|
| 111 |
centers = {}
|
| 112 |
-
for label in
|
| 113 |
subset = df[df['label'] == label]
|
| 114 |
if not subset.empty:
|
| 115 |
centers[label] = (subset['x'].mean(), subset['y'].mean())
|
| 116 |
return centers
|
| 117 |
|
| 118 |
-
|
|
|
|
| 119 |
distances = {}
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
return pd.DataFrame(distances).T
|
| 125 |
|
| 126 |
def create_table(df_distances):
    """Build a Bokeh DataTable from a distance matrix plus summary rows.

    The index of ``df_distances`` is turned into a regular column (renamed
    from 'index' to 'Synthetic'), and three extra rows holding the
    per-column minimum, mean and maximum are appended at the bottom.

    Returns:
        A tuple ``(data_table, df_table, source_table)``: the Bokeh widget,
        the DataFrame it displays, and the ColumnDataSource backing it.
    """
    table = df_distances.reset_index().rename(columns={'index': 'Synthetic'})

    # One summary row per reducer, filled column by column below.
    labelled_reducers = (("Min.", "min"), ("Mean", "mean"), ("Max.", "max"))
    summary_rows = [{"Synthetic": tag} for tag, _ in labelled_reducers]
    for col in table.columns:
        if col == "Synthetic":
            continue
        for row, (_, reducer) in zip(summary_rows, labelled_reducers):
            row[col] = getattr(table[col], reducer)()

    # Append the min / mean / max rows at the end of the table.
    table = pd.concat([table, pd.DataFrame(summary_rows)], ignore_index=True)

    source_table = ColumnDataSource(table)
    columns = [TableColumn(field='Synthetic', title='Synthetic')]
    columns.extend(
        TableColumn(field=col, title=col)
        for col in table.columns
        if col != 'Synthetic'
    )

    # Size the widget so that every row is visible.
    row_height = 28
    header_height = 30
    total_height = header_height + len(table) * row_height

    data_table = DataTable(
        source=source_table,
        columns=columns,
        sizing_mode='stretch_width',
        height=total_height,
    )
    return data_table, table, source_table
|
| 157 |
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
# Función que ejecuta todo el proceso para un modelo determinado
|
| 161 |
def run_model(model_name):
|
| 162 |
embeddings = load_embeddings(model_name)
|
| 163 |
if embeddings is None:
|
| 164 |
return
|
| 165 |
-
|
| 166 |
-
# Asignamos la versión para distinguir en el split
|
| 167 |
-
embeddings["real"]["version"] = "real"
|
| 168 |
-
embeddings["es-digital-seq"]["version"] = "es_digital_seq"
|
| 169 |
embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
|
| 170 |
-
|
| 171 |
-
|
| 172 |
st.markdown('<h6 class="sub-title">Select Dimensionality Reduction Method</h6>', unsafe_allow_html=True)
|
| 173 |
reduction_method = st.selectbox("", options=["t-SNE", "PCA"], key=model_name)
|
| 174 |
if reduction_method == "PCA":
|
|
@@ -176,125 +230,72 @@ def run_model(model_name):
|
|
| 176 |
else:
|
| 177 |
reducer = TSNE(n_components=2, random_state=42, perplexity=30, learning_rate=200)
|
| 178 |
reduced = reducer.fit_transform(df_combined[embedding_cols].values)
|
| 179 |
-
|
| 180 |
dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
|
| 181 |
-
selected_subsets = {"real": unique_subsets["real"], "es-digital-seq": unique_subsets["es-digital-seq"]}
|
| 182 |
-
color_maps = get_color_maps(selected_subsets)
|
| 183 |
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
data_table, df_table, source_table = create_table(df_distances)
|
|
|
|
| 189 |
real_subset_names = list(df_table.columns[1:])
|
| 190 |
real_select = Select(title="", value=real_subset_names[0], options=real_subset_names)
|
| 191 |
reset_button = Button(label="Reset Colors", button_type="primary")
|
| 192 |
line_source = ColumnDataSource(data={'x': [], 'y': []})
|
| 193 |
fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
|
| 194 |
|
| 195 |
-
|
| 196 |
real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
|
| 197 |
|
| 198 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
callback = CustomJS(args=dict(source=source_table, line_source=line_source,
|
| 200 |
-
synthetic_centers=
|
| 201 |
real_centers=real_centers_js,
|
| 202 |
-
synthetic_renderers=synthetic_renderers,
|
| 203 |
-
real_renderers=real_renderers,
|
| 204 |
-
synthetic_colors=color_maps["es-digital-seq"],
|
| 205 |
-
real_colors=color_maps["real"],
|
| 206 |
real_select=real_select),
|
| 207 |
code="""
|
| 208 |
var selected = source.selected.indices;
|
| 209 |
if (selected.length > 0) {
|
| 210 |
-
var
|
| 211 |
var data = source.data;
|
| 212 |
-
var
|
| 213 |
var real_label = real_select.value;
|
| 214 |
-
var syn_coords = synthetic_centers[
|
| 215 |
var real_coords = real_centers[real_label];
|
| 216 |
-
line_source.data = {
|
| 217 |
line_source.change.emit();
|
| 218 |
-
|
| 219 |
-
for (var key in synthetic_renderers) {
|
| 220 |
-
if (synthetic_renderers.hasOwnProperty(key)) {
|
| 221 |
-
var renderer = synthetic_renderers[key];
|
| 222 |
-
if (key === synthetic_label) {
|
| 223 |
-
renderer.glyph.fill_color = synthetic_colors[key];
|
| 224 |
-
renderer.glyph.line_color = synthetic_colors[key];
|
| 225 |
-
} else {
|
| 226 |
-
renderer.glyph.fill_color = "lightgray";
|
| 227 |
-
renderer.glyph.line_color = "lightgray";
|
| 228 |
-
}
|
| 229 |
-
}
|
| 230 |
-
}
|
| 231 |
-
for (var key in real_renderers) {
|
| 232 |
-
if (real_renderers.hasOwnProperty(key)) {
|
| 233 |
-
var renderer = real_renderers[key];
|
| 234 |
-
if (key === real_label) {
|
| 235 |
-
renderer.glyph.fill_color = real_colors[key];
|
| 236 |
-
renderer.glyph.line_color = real_colors[key];
|
| 237 |
-
} else {
|
| 238 |
-
renderer.glyph.fill_color = "lightgray";
|
| 239 |
-
renderer.glyph.line_color = "lightgray";
|
| 240 |
-
}
|
| 241 |
-
}
|
| 242 |
-
}
|
| 243 |
} else {
|
| 244 |
-
line_source.data = {
|
| 245 |
line_source.change.emit();
|
| 246 |
-
for (var key in synthetic_renderers) {
|
| 247 |
-
if (synthetic_renderers.hasOwnProperty(key)) {
|
| 248 |
-
var renderer = synthetic_renderers[key];
|
| 249 |
-
renderer.glyph.fill_color = synthetic_colors[key];
|
| 250 |
-
renderer.glyph.line_color = synthetic_colors[key];
|
| 251 |
-
}
|
| 252 |
-
}
|
| 253 |
-
for (var key in real_renderers) {
|
| 254 |
-
if (real_renderers.hasOwnProperty(key)) {
|
| 255 |
-
var renderer = real_renderers[key];
|
| 256 |
-
renderer.glyph.fill_color = real_colors[key];
|
| 257 |
-
renderer.glyph.line_color = real_colors[key];
|
| 258 |
-
}
|
| 259 |
-
}
|
| 260 |
}
|
| 261 |
""")
|
| 262 |
source_table.selected.js_on_change('indices', callback)
|
| 263 |
real_select.js_on_change('value', callback)
|
| 264 |
|
| 265 |
-
reset_callback = CustomJS(args=dict(line_source=line_source,
|
| 266 |
-
synthetic_renderers=synthetic_renderers,
|
| 267 |
-
real_renderers=real_renderers,
|
| 268 |
-
synthetic_colors=color_maps["es-digital-seq"],
|
| 269 |
-
real_colors=color_maps["real"]),
|
| 270 |
code="""
|
| 271 |
-
line_source.data = {
|
| 272 |
line_source.change.emit();
|
| 273 |
-
for (var key in synthetic_renderers) {
|
| 274 |
-
if (synthetic_renderers.hasOwnProperty(key)) {
|
| 275 |
-
var renderer = synthetic_renderers[key];
|
| 276 |
-
renderer.glyph.fill_color = synthetic_colors[key];
|
| 277 |
-
renderer.glyph.line_color = synthetic_colors[key];
|
| 278 |
-
}
|
| 279 |
-
}
|
| 280 |
-
for (var key in real_renderers) {
|
| 281 |
-
if (real_renderers.hasOwnProperty(key)) {
|
| 282 |
-
var renderer = real_renderers[key];
|
| 283 |
-
renderer.glyph.fill_color = real_colors[key];
|
| 284 |
-
renderer.glyph.line_color = real_colors[key];
|
| 285 |
-
}
|
| 286 |
-
}
|
| 287 |
""")
|
| 288 |
reset_button.js_on_event("button_click", reset_callback)
|
| 289 |
-
|
| 290 |
buffer = io.BytesIO()
|
| 291 |
df_table.to_excel(buffer, index=False)
|
| 292 |
buffer.seek(0)
|
| 293 |
-
|
| 294 |
layout = column(fig, column(real_select, reset_button, data_table))
|
| 295 |
st.bokeh_chart(layout, use_container_width=True)
|
| 296 |
-
|
| 297 |
-
# Agregar un botón de descarga en Streamlit
|
| 298 |
st.download_button(
|
| 299 |
label="Export Table",
|
| 300 |
data=buffer,
|
|
@@ -302,18 +303,13 @@ def run_model(model_name):
|
|
| 302 |
mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
| 303 |
key=f"download_button_excel_{model_name}"
|
| 304 |
)
|
| 305 |
-
|
| 306 |
-
|
| 307 |
|
| 308 |
-
# Función principal con tabs para cambiar de modelo
|
| 309 |
def main():
    """Render the page: apply styling, then one tab per model."""
    config_style()
    model_names = ["Donut", "Idefics2"]
    headings = {
        "Donut": '<h2 class="sub-title">Donut 🤗</h2>',
        "Idefics2": '<h2 class="sub-title">Idefics2 🤗</h2>',
    }
    # Tabs are created in the same order as model_names, so zip pairs them up.
    for tab, model_name in zip(st.tabs(model_names), model_names):
        with tab:
            st.markdown(headings[model_name], unsafe_allow_html=True)
            run_model(model_name)
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
from bokeh.plotting import figure
|
| 5 |
from bokeh.models import ColumnDataSource, DataTable, TableColumn, CustomJS, Select, Button
|
| 6 |
+
from bokeh.layouts import column
|
| 7 |
+
from bokeh.palettes import Reds9, Blues9, Oranges9, Purples9
|
| 8 |
from sklearn.decomposition import PCA
|
| 9 |
from sklearn.manifold import TSNE
|
| 10 |
import io
|
| 11 |
+
import ot
|
| 12 |
|
| 13 |
TOOLTIPS = """
|
| 14 |
<div>
|
|
|
|
| 31 |
""", unsafe_allow_html=True)
|
| 32 |
st.markdown('<h1 class="main-title">Merit Embeddings 🎒📃🏆</h1>', unsafe_allow_html=True)
|
| 33 |
|
| 34 |
+
# Carga los datos y asigna versiones de forma uniforme
|
| 35 |
def load_embeddings(model):
    """Load the real and synthetic embedding CSVs for ``model``.

    Every loaded frame gets a ``version`` column ("real" or "synthetic");
    synthetic frames additionally get a ``source`` column naming the
    dataset they were read as.

    Args:
        model: Either "Donut" or "Idefics2".

    Returns:
        A dict with keys "real" and "synthetic" mapping to DataFrames, or
        ``None`` (after reporting through ``st.error``) for any other value.
    """
    def _read_tagged(path, version, source=None):
        # Read one CSV and stamp it with its version (and source, if given).
        frame = pd.read_csv(path)
        frame["version"] = version
        if source is not None:
            frame["source"] = source
        return frame

    if model not in ("Donut", "Idefics2"):
        st.error("Modelo no reconocido")
        return None

    if model == "Idefics2":
        real = _read_tagged(
            "data/idefics2_de_Rodrigo_merit_secret_britanico_embeddings.csv",
            "real",
        )
        synthetic = _read_tagged(
            "data/idefics2_de_Rodrigo_merit_es-digital-seq_embeddings.csv",
            "synthetic",
            "es-digital-seq",
        )
        return {"real": real, "synthetic": synthetic}

    # Donut: two synthetic sources are stacked into a single frame.
    real = _read_tagged(
        "data/donut_de_Rodrigo_merit_secret_all_embeddings.csv",
        "real",
    )
    seq = _read_tagged(
        "data/donut_de_Rodrigo_merit_es-digital-seq_embeddings.csv",
        "synthetic",
        "es-digital-seq",
    )
    # NOTE(review): this path names an Idefics2 "es-digital-seq" file even
    # though it is loaded in the Donut branch and tagged as the
    # line-degradation source -- confirm this is the intended file.
    line = _read_tagged(
        "data/idefics2_de_Rodrigo_merit_es-digital-seq_embeddings.csv",
        "synthetic",
        "es-digital-line-degradation-seq",
    )
    return {"real": real, "synthetic": pd.concat([seq, line], ignore_index=True)}
|
|
|
|
| 57 |
|
| 58 |
+
# Selección de reducción dimensional
|
| 59 |
def reducer_selector(df_combined, embedding_cols):
|
| 60 |
reduction_method = st.selectbox("Select Dimensionality Reduction Method:", options=["PCA", "t-SNE"])
|
| 61 |
all_embeddings = df_combined[embedding_cols].values
|
|
|
|
| 65 |
reducer = TSNE(n_components=2, random_state=42, perplexity=30, learning_rate=200)
|
| 66 |
return reducer.fit_transform(all_embeddings)
|
| 67 |
|
| 68 |
+
# Función genérica para agregar datos al gráfico
|
| 69 |
+
def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping, group_label):
|
| 70 |
renderers = {}
|
| 71 |
for label in selected_labels:
|
| 72 |
subset = df[df['label'] == label]
|
|
|
|
| 76 |
x=subset['x'],
|
| 77 |
y=subset['y'],
|
| 78 |
label=subset['label'],
|
| 79 |
+
img=subset.get('img', "")
|
| 80 |
))
|
| 81 |
color = color_mapping[label]
|
| 82 |
+
# Se añade el identificador de la fuente en la leyenda
|
| 83 |
+
legend_label = f"{label} ({group_label})"
|
| 84 |
if marker == "circle":
|
| 85 |
r = fig.circle('x', 'y', size=10, source=source,
|
| 86 |
fill_color=color, line_color=color,
|
| 87 |
+
legend_label=legend_label)
|
| 88 |
elif marker == "square":
|
| 89 |
+
r = fig.square('x', 'y', size=10, source=source,
|
| 90 |
fill_color=color, line_color=color,
|
| 91 |
+
legend_label=legend_label)
|
| 92 |
+
elif marker == "triangle":
|
| 93 |
+
r = fig.triangle('x', 'y', size=12, source=source,
|
| 94 |
+
fill_color=color, line_color=color,
|
| 95 |
+
legend_label=legend_label)
|
| 96 |
+
renderers[label + f" ({group_label})"] = r
|
| 97 |
return renderers
|
| 98 |
|
| 99 |
+
# Asigna paletas de colores de forma genérica para cada grupo (real y para cada fuente sintética)
|
| 100 |
+
def get_color_maps(unique_subsets):
    """Assign one colour per label for the real and synthetic groups.

    Real labels draw from ``Reds9`` and synthetic labels from ``Blues9``.
    Labels are sorted before colours are assigned, and the palette is
    repeated when a group has more than nine labels.

    Args:
        unique_subsets: Mapping with keys "real" and "synthetic", each
            holding the labels of that group.

    Returns:
        ``{"real": {label: colour}, "synthetic": {label: colour}}``.
    """
    def _assign(palette, labels):
        ordered = sorted(labels)
        count = len(ordered)
        if count <= 9:
            colours = palette[:count]
        else:
            # Repeat the 9-colour palette often enough to cover every label.
            colours = (palette * ((count // 9) + 1))[:count]
        return {label: colours[position] for position, label in enumerate(ordered)}

    color_map = {}
    color_map["real"] = _assign(Reds9, unique_subsets["real"])
    # A single palette is shared by all synthetic labels, whatever their source.
    color_map["synthetic"] = _assign(Blues9, unique_subsets["synthetic"])
    return color_map
|
| 114 |
|
| 115 |
+
# Separa los datos reducidos en "real" y "synthetic" y extrae los subsets (clusters)
|
| 116 |
def split_versions(df_combined, reduced):
    """Attach 2-D coordinates and split the frame by its ``version`` column.

    Columns 0 and 1 of ``reduced`` are written to ``df_combined`` as 'x'
    and 'y' (the caller's frame is modified in place). The rows are then
    copied into one frame per version.

    Returns:
        ``(frames, labels)`` where both are dicts keyed by "real" and
        "synthetic": ``frames`` holds the per-version DataFrame copies and
        ``labels`` the sorted unique values of their 'label' column.
    """
    for axis, name in enumerate(('x', 'y')):
        df_combined[name] = reduced[:, axis]

    frames = {}
    labels = {}
    for version in ("real", "synthetic"):
        part = df_combined[df_combined["version"] == version].copy()
        frames[version] = part
        # Clusters (subsets) are identified by the 'label' column.
        labels[version] = sorted(part['label'].unique().tolist())
    return frames, labels
|
| 127 |
|
| 128 |
+
# Crea el gráfico; se tratan de forma uniforme ambos conjuntos sintéticos
|
| 129 |
+
def create_figure(dfs, unique_subsets, color_maps):
    """Build the scatter figure for the real and synthetic points.

    Real clusters are drawn as circles. Synthetic points are split by
    their 'source' column so each source gets its own marker
    ("es-digital-seq" as squares, "es-digital-line-degradation-seq" as
    triangles); both share the synthetic colour mapping.

    Returns:
        ``(fig, real_renderers, synthetic_renderers)``; the last item
        merges the renderer dicts of every synthetic source.
    """
    fig = figure(width=400, height=400, tooltips=TOOLTIPS, title="")

    real_renderers = add_dataset_to_fig(
        fig,
        dfs["real"],
        unique_subsets["real"],
        marker="circle",
        color_mapping=color_maps["real"],
        group_label="Real",
    )

    synth_df = dfs["synthetic"]
    marker_by_source = (
        ("es-digital-seq", "square"),
        ("es-digital-line-degradation-seq", "triangle"),
    )
    synthetic_renderers = {}
    for source_name, marker in marker_by_source:
        per_source = synth_df[synth_df["source"] == source_name]
        # Clusters present for this source (may be an empty list).
        source_labels = sorted(per_source['label'].unique().tolist())
        synthetic_renderers.update(
            add_dataset_to_fig(
                fig,
                per_source,
                source_labels,
                marker=marker,
                color_mapping=color_maps["synthetic"],
                group_label=source_name,
            )
        )

    fig.legend.location = "top_right"
    fig.legend.click_policy = "hide"
    return fig, real_renderers, synthetic_renderers
|
| 156 |
|
| 157 |
+
# Calcula los centros de cada cluster (por grupo)
|
| 158 |
+
def calculate_cluster_centers(df, labels):
    """Return the mean (x, y) position of each labelled cluster.

    Args:
        df: DataFrame with 'label', 'x' and 'y' columns.
        labels: Labels to look up; those with no matching rows are omitted.

    Returns:
        Dict mapping each label to an ``(x_mean, y_mean)`` tuple.
    """
    centers = {}
    for name in labels:
        members = df[df['label'] == name]
        if members.empty:
            continue
        centers[name] = (members['x'].mean(), members['y'].mean())
    return centers
|
| 165 |
|
| 166 |
+
# Calcula la distancia Wasserstein de cada subset sintético respecto a cada cluster real (por cluster y global)
|
| 167 |
+
def compute_wasserstein_distances_all_synthetics(df_synth, df_real, labels_real):
    """Compute Wasserstein distances from synthetic clusters to real clusters.

    For every label in ``df_synth`` (sorted), and once more for the whole
    synthetic set under the key "Global synthetic", the distance to each
    real cluster is computed with ``ot.emd2`` using uniform weights and a
    Euclidean cost matrix over the 'x'/'y' columns.

    Args:
        df_synth: Synthetic points with 'label', 'x' and 'y' columns.
        df_real: Real points with 'label', 'x' and 'y' columns.
        labels_real: Labels of the real clusters to compare against.

    Returns:
        DataFrame with one row per synthetic label plus a final
        "Global synthetic" row, and one column per real label.
    """
    def _points_and_weights(frame):
        # 2-D coordinates of the frame plus a uniform weight per point.
        points = frame[['x', 'y']].values
        count = points.shape[0]
        return points, np.ones(count) / count

    # The real clusters do not depend on the synthetic label, so extract
    # them once instead of re-filtering df_real inside every iteration.
    real_clusters = {
        real_label: _points_and_weights(df_real[df_real['label'] == real_label])
        for real_label in labels_real
    }

    def _distances_to_real(points, weights):
        # Distance from one synthetic point cloud to every real cluster.
        row = {}
        for real_label, (real_points, real_weights) in real_clusters.items():
            cost = ot.dist(points, real_points, metric='euclidean')
            row[real_label] = ot.emd2(weights, real_weights, cost)
        return row

    distances = {}
    for label in sorted(df_synth['label'].unique().tolist()):
        points, weights = _points_and_weights(df_synth[df_synth['label'] == label])
        distances[f"{label}"] = _distances_to_real(points, weights)

    # Distance from the synthetic set as a whole to each real cluster.
    global_points, global_weights = _points_and_weights(df_synth)
    distances["Global synthetic"] = _distances_to_real(global_points, global_weights)

    return pd.DataFrame(distances).T
|
| 196 |
|
| 197 |
def create_table(df_distances):
    """Build a Bokeh DataTable from a distance matrix plus summary rows.

    The index of ``df_distances`` becomes a regular column (renamed from
    'index' to 'Synthetic'), and "Min.", "Mean" and "Max." rows computed
    per column are appended at the bottom.

    Returns:
        A tuple ``(data_table, df_table, source_table)``: the Bokeh widget,
        the DataFrame it displays, and the ColumnDataSource backing it.
    """
    table = df_distances.reset_index().rename(columns={'index': 'Synthetic'})

    # Collect the three statistics column by column.
    stats = {"Min.": {}, "Mean": {}, "Max.": {}}
    for col in table.columns:
        if col == "Synthetic":
            continue
        series = table[col]
        stats["Min."][col] = series.min()
        stats["Mean"][col] = series.mean()
        stats["Max."][col] = series.max()
    summary_rows = [{"Synthetic": tag, **values} for tag, values in stats.items()]

    table = pd.concat([table, pd.DataFrame(summary_rows)], ignore_index=True)

    source_table = ColumnDataSource(table)
    value_columns = [col for col in table.columns if col != 'Synthetic']
    columns = [TableColumn(field='Synthetic', title='Synthetic')] + [
        TableColumn(field=col, title=col) for col in value_columns
    ]

    # 30 px for the header plus 28 px per row, so no row is hidden.
    total_height = 30 + len(table) * 28
    data_table = DataTable(
        source=source_table,
        columns=columns,
        sizing_mode='stretch_width',
        height=total_height,
    )
    return data_table, table, source_table
|
| 218 |
|
|
|
|
|
|
|
|
|
|
| 219 |
def run_model(model_name):
|
| 220 |
embeddings = load_embeddings(model_name)
|
| 221 |
if embeddings is None:
|
| 222 |
return
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
|
| 224 |
+
# Combina todos los DataFrames
|
| 225 |
+
df_combined = pd.concat(list(embeddings.values()), ignore_index=True)
|
| 226 |
st.markdown('<h6 class="sub-title">Select Dimensionality Reduction Method</h6>', unsafe_allow_html=True)
|
| 227 |
reduction_method = st.selectbox("", options=["t-SNE", "PCA"], key=model_name)
|
| 228 |
if reduction_method == "PCA":
|
|
|
|
| 230 |
else:
|
| 231 |
reducer = TSNE(n_components=2, random_state=42, perplexity=30, learning_rate=200)
|
| 232 |
reduced = reducer.fit_transform(df_combined[embedding_cols].values)
|
|
|
|
| 233 |
dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
|
|
|
|
|
|
|
| 234 |
|
| 235 |
+
# Se espera que unique_subsets tenga claves "real" y "synthetic"
|
| 236 |
+
color_maps = get_color_maps(unique_subsets)
|
| 237 |
+
fig, real_renderers, synthetic_renderers = create_figure(dfs_reduced, unique_subsets, color_maps)
|
| 238 |
+
|
| 239 |
+
centers_real = calculate_cluster_centers(dfs_reduced["real"], unique_subsets["real"])
|
| 240 |
+
|
| 241 |
+
df_distances = compute_wasserstein_distances_all_synthetics(dfs_reduced["synthetic"],
|
| 242 |
+
dfs_reduced["real"],
|
| 243 |
+
unique_subsets["real"])
|
| 244 |
data_table, df_table, source_table = create_table(df_distances)
|
| 245 |
+
|
| 246 |
real_subset_names = list(df_table.columns[1:])
|
| 247 |
real_select = Select(title="", value=real_subset_names[0], options=real_subset_names)
|
| 248 |
reset_button = Button(label="Reset Colors", button_type="primary")
|
| 249 |
line_source = ColumnDataSource(data={'x': [], 'y': []})
|
| 250 |
fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
|
| 251 |
|
| 252 |
+
# Preparar centros para callback (para trazar líneas entre centros)
|
| 253 |
real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
|
| 254 |
|
| 255 |
+
# Se podría preparar también los centros sintéticos si se requiere
|
| 256 |
+
synthetic_centers = {}
|
| 257 |
+
synth_labels = sorted(dfs_reduced["synthetic"]['label'].unique().tolist())
|
| 258 |
+
for label in synth_labels:
|
| 259 |
+
subset = dfs_reduced["synthetic"][dfs_reduced["synthetic"]['label'] == label]
|
| 260 |
+
synthetic_centers[label] = [subset['x'].mean(), subset['y'].mean()]
|
| 261 |
+
|
| 262 |
callback = CustomJS(args=dict(source=source_table, line_source=line_source,
|
| 263 |
+
synthetic_centers=synthetic_centers,
|
| 264 |
real_centers=real_centers_js,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
real_select=real_select),
|
| 266 |
code="""
|
| 267 |
var selected = source.selected.indices;
|
| 268 |
if (selected.length > 0) {
|
| 269 |
+
var idx = selected[0];
|
| 270 |
var data = source.data;
|
| 271 |
+
var synth_label = data['Synthetic'][idx];
|
| 272 |
var real_label = real_select.value;
|
| 273 |
+
var syn_coords = synthetic_centers[synth_label];
|
| 274 |
var real_coords = real_centers[real_label];
|
| 275 |
+
line_source.data = {'x': [syn_coords[0], real_coords[0]], 'y': [syn_coords[1], real_coords[1]]};
|
| 276 |
line_source.change.emit();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
} else {
|
| 278 |
+
line_source.data = {'x': [], 'y': []};
|
| 279 |
line_source.change.emit();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
}
|
| 281 |
""")
|
| 282 |
source_table.selected.js_on_change('indices', callback)
|
| 283 |
real_select.js_on_change('value', callback)
|
| 284 |
|
| 285 |
+
reset_callback = CustomJS(args=dict(line_source=line_source),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
code="""
|
| 287 |
+
line_source.data = {'x': [], 'y': []};
|
| 288 |
line_source.change.emit();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
""")
|
| 290 |
reset_button.js_on_event("button_click", reset_callback)
|
| 291 |
+
|
| 292 |
buffer = io.BytesIO()
|
| 293 |
df_table.to_excel(buffer, index=False)
|
| 294 |
buffer.seek(0)
|
| 295 |
+
|
| 296 |
layout = column(fig, column(real_select, reset_button, data_table))
|
| 297 |
st.bokeh_chart(layout, use_container_width=True)
|
| 298 |
+
|
|
|
|
| 299 |
st.download_button(
|
| 300 |
label="Export Table",
|
| 301 |
data=buffer,
|
|
|
|
| 303 |
mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
| 304 |
key=f"download_button_excel_{model_name}"
|
| 305 |
)
|
|
|
|
|
|
|
| 306 |
|
|
|
|
| 307 |
def main():
    """Render the page: apply styling, then show one tab per model."""
    config_style()
    donut_tab, idefics_tab = st.tabs(["Donut", "Idefics2"])
    sections = (
        (donut_tab, '<h2 class="sub-title">Donut 🤗</h2>', "Donut"),
        (idefics_tab, '<h2 class="sub-title">Idefics2 🤗</h2>', "Idefics2"),
    )
    for tab, heading, model_name in sections:
        with tab:
            st.markdown(heading, unsafe_allow_html=True)
            run_model(model_name)
|