Spaces:
Sleeping
Sleeping
Commit
·
89ffe36
1
Parent(s):
d966a8e
Cleaner Layout and Tabs for Different Models
Browse files
app.py
CHANGED
|
@@ -28,13 +28,21 @@ def config_style():
|
|
| 28 |
</style>
|
| 29 |
""", unsafe_allow_html=True)
|
| 30 |
st.markdown('<h1 class="main-title">Merit Embeddings 馃帓馃搩馃弳</h1>', unsafe_allow_html=True)
|
| 31 |
-
st.markdown('<h2 class="sub-title">Donut 馃</h2>', unsafe_allow_html=True)
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
return {"real": df_real, "es-digital-seq": df_es_digital_seq}
|
| 37 |
|
|
|
|
| 38 |
def reducer_selector(df_combined, embedding_cols):
|
| 39 |
reduction_method = st.selectbox("Select Dimensionality Reduction Method:", options=["PCA", "t-SNE"])
|
| 40 |
all_embeddings = df_combined[embedding_cols].values
|
|
@@ -88,11 +96,6 @@ def split_versions(df_combined, reduced):
|
|
| 88 |
unique_es = sorted(df_es['label'].unique().tolist())
|
| 89 |
return {"real": df_real, "es-digital-seq": df_es}, {"real": unique_real, "es-digital-seq": unique_es}
|
| 90 |
|
| 91 |
-
def subset_selectors(unique_subsets: dict):
    """Show one multiselect per dataset version and return the selections.

    Args:
        unique_subsets: Mapping with the keys "real" and "es-digital-seq";
            each value is used both as the options and as the default
            selection of the corresponding widget.

    Returns:
        A dict with the same two keys, holding whatever each
        ``st.multiselect`` call returned.
    """
    # The widgets are created in this order: real first, synthetic second.
    prompts = (
        ("real", "Select Real Subsets:"),
        ("es-digital-seq", "Select Synthetic Subsets:"),
    )
    chosen = {}
    for version, prompt in prompts:
        available = unique_subsets[version]
        chosen[version] = st.multiselect(prompt, options=available, default=available)
    return chosen
|
| 95 |
-
|
| 96 |
def create_figure(dfs_reduced, selected_subsets: dict, color_maps: dict):
|
| 97 |
fig = figure(width=400, height=400, tooltips=TOOLTIPS, title="")
|
| 98 |
real_renderers = add_dataset_to_fig(fig, dfs_reduced["real"], selected_subsets["real"],
|
|
@@ -119,52 +122,61 @@ def compute_distances(centers_es: dict, centers_real: dict) -> pd.DataFrame:
|
|
| 119 |
distances[es_label][real_label] = np.sqrt((x_es - x_real)**2 + (y_es - y_real)**2)
|
| 120 |
return pd.DataFrame(distances).T
|
| 121 |
|
| 122 |
-
def
|
| 123 |
-
|
| 124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
embeddings["real"]["version"] = "real"
|
| 126 |
embeddings["es-digital-seq"]["version"] = "es_digital_seq"
|
| 127 |
embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
|
| 128 |
-
|
| 129 |
df_combined = pd.concat([embeddings["real"], embeddings["es-digital-seq"]], ignore_index=True)
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
|
| 133 |
-
selected_subsets =
|
| 134 |
color_maps = get_color_maps(selected_subsets)
|
| 135 |
-
fig, real_renderers, synthetic_renderers = create_figure(dfs_reduced, selected_subsets, color_maps)
|
| 136 |
|
|
|
|
| 137 |
centers_real = calculate_cluster_centers(dfs_reduced["real"], selected_subsets["real"])
|
| 138 |
centers_es = calculate_cluster_centers(dfs_reduced["es-digital-seq"], selected_subsets["es-digital-seq"])
|
| 139 |
df_distances = compute_distances(centers_es, centers_real)
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
df_table.reset_index(inplace=True)
|
| 144 |
-
df_table.rename(columns={'index': 'Synthetic'}, inplace=True)
|
| 145 |
-
source_table = ColumnDataSource(df_table)
|
| 146 |
-
columns = [TableColumn(field='Synthetic', title='Synthetic')]
|
| 147 |
-
for col in df_table.columns:
|
| 148 |
-
if col != 'Synthetic':
|
| 149 |
-
columns.append(TableColumn(field=col, title=col))
|
| 150 |
-
data_table = DataTable(source=source_table, columns=columns, width=400, height=300)
|
| 151 |
-
|
| 152 |
-
# Widget Select para elegir el subset real (columnas de la tabla)
|
| 153 |
-
real_subset_names = list(df_table.columns[1:]) # todas las columnas excepto 'Synthetic'
|
| 154 |
-
real_select = Select(title="Select Real Subset:", value=real_subset_names[0], options=real_subset_names)
|
| 155 |
-
|
| 156 |
-
# Botón para resetear la visualización a colores originales
|
| 157 |
reset_button = Button(label="Reset Colors", button_type="primary")
|
| 158 |
-
|
| 159 |
-
# Fuente para la línea que conecta los centros
|
| 160 |
line_source = ColumnDataSource(data={'x': [], 'y': []})
|
| 161 |
fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
|
| 162 |
|
| 163 |
-
# Preparar centros para el callback
|
| 164 |
synthetic_centers_js = {k: [v[0], v[1]] for k, v in centers_es.items()}
|
| 165 |
real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
|
| 166 |
|
| 167 |
-
# Callback para actualizar
|
| 168 |
callback = CustomJS(args=dict(source=source_table, line_source=line_source,
|
| 169 |
synthetic_centers=synthetic_centers_js,
|
| 170 |
real_centers=real_centers_js,
|
|
@@ -228,11 +240,9 @@ def main():
|
|
| 228 |
}
|
| 229 |
}
|
| 230 |
""")
|
| 231 |
-
|
| 232 |
source_table.selected.js_on_change('indices', callback)
|
| 233 |
real_select.js_on_change('value', callback)
|
| 234 |
|
| 235 |
-
# Callback para el botón de resetear: se reinician la línea y los colores a su estado original.
|
| 236 |
reset_callback = CustomJS(args=dict(line_source=line_source,
|
| 237 |
synthetic_renderers=synthetic_renderers,
|
| 238 |
real_renderers=real_renderers,
|
|
@@ -258,9 +268,21 @@ def main():
|
|
| 258 |
""")
|
| 259 |
reset_button.js_on_event("button_click", reset_callback)
|
| 260 |
|
| 261 |
-
# Organizar el layout: gráfico, dropdown, botón de reset y tabla
|
| 262 |
layout = column(fig, column(real_select, reset_button, data_table))
|
| 263 |
-
st.bokeh_chart(layout)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
|
| 265 |
if __name__ == "__main__":
|
| 266 |
main()
|
|
|
|
| 28 |
</style>
|
| 29 |
""", unsafe_allow_html=True)
|
| 30 |
st.markdown('<h1 class="main-title">Merit Embeddings 馃帓馃搩馃弳</h1>', unsafe_allow_html=True)
|
|
|
|
| 31 |
|
| 32 |
+
# Modificamos load_embeddings para aceptar el modelo a cargar
|
| 33 |
+
def load_embeddings(model):
    """Read the real and synthetic embedding CSVs for one model.

    Args:
        model: Model name; "Donut" and "Idefics2" are the recognised values.

    Returns:
        A dict with the keys "real" and "es-digital-seq", each holding the
        DataFrame read from the matching CSV. For any other model name a
        Streamlit error is shown and None is returned.
    """
    if model not in ("Donut", "Idefics2"):
        st.error("Modelo no reconocido")
        return None

    if model == "Donut":
        real_path = "data/donut_de_Rodrigo_merit_secret_all_embeddings.csv"
        synthetic_path = "data/donut_de_Rodrigo_merit_es-digital-seq_embeddings.csv"
    else:
        # NOTE(review): both Idefics2 paths name the same CSV, so "real" and
        # "es-digital-seq" are loaded from identical data. This looks like a
        # copy-paste slip -- confirm the intended synthetic file name.
        real_path = "data/idefics2_de_Rodrigo_merit_secret_britanico_embeddings.csv"
        synthetic_path = "data/idefics2_de_Rodrigo_merit_secret_britanico_embeddings.csv"

    # Each path gets its own read_csv call, so the two entries are distinct
    # DataFrame objects even when the paths coincide.
    real_frame = pd.read_csv(real_path)
    synthetic_frame = pd.read_csv(synthetic_path)
    return {"real": real_frame, "es-digital-seq": synthetic_frame}
|
| 44 |
|
| 45 |
+
# Funciones auxiliares (idénticas a las de tu código)
|
| 46 |
def reducer_selector(df_combined, embedding_cols):
|
| 47 |
reduction_method = st.selectbox("Select Dimensionality Reduction Method:", options=["PCA", "t-SNE"])
|
| 48 |
all_embeddings = df_combined[embedding_cols].values
|
|
|
|
| 96 |
unique_es = sorted(df_es['label'].unique().tolist())
|
| 97 |
return {"real": df_real, "es-digital-seq": df_es}, {"real": unique_real, "es-digital-seq": unique_es}
|
| 98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
def create_figure(dfs_reduced, selected_subsets: dict, color_maps: dict):
|
| 100 |
fig = figure(width=400, height=400, tooltips=TOOLTIPS, title="")
|
| 101 |
real_renderers = add_dataset_to_fig(fig, dfs_reduced["real"], selected_subsets["real"],
|
|
|
|
| 122 |
distances[es_label][real_label] = np.sqrt((x_es - x_real)**2 + (y_es - y_real)**2)
|
| 123 |
return pd.DataFrame(distances).T
|
| 124 |
|
| 125 |
+
def create_table(df_distances):
    """Build a Bokeh DataTable from a distance DataFrame.

    The input is left untouched: a new frame is derived from it with the
    index moved into a column, and a column named 'index' is renamed to
    'Synthetic'. The table lists 'Synthetic' first, followed by every other
    column in frame order.

    Args:
        df_distances: DataFrame of distances, as produced by
            ``compute_distances``.

    Returns:
        A tuple ``(data_table, df_table, source_table)``: the DataTable
        widget, the derived DataFrame, and the ColumnDataSource wrapping it.
    """
    table_df = df_distances.reset_index().rename(columns={'index': 'Synthetic'})
    table_source = ColumnDataSource(table_df)

    leading_column = [TableColumn(field='Synthetic', title='Synthetic')]
    remaining_columns = [
        TableColumn(field=name, title=name)
        for name in table_df.columns
        if name != 'Synthetic'
    ]

    # Height grows with the row count: a fixed header allowance plus a fixed
    # amount per data row.
    header_allowance = 30
    per_row = 28
    table_height = header_allowance + len(table_df) * per_row

    widget = DataTable(
        source=table_source,
        columns=leading_column + remaining_columns,
        sizing_mode='stretch_width',
        height=table_height,
    )
    return widget, table_df, table_source
|
| 140 |
+
|
| 141 |
+
# Función que ejecuta todo el proceso para un modelo determinado
|
| 142 |
+
def run_model(model_name):
|
| 143 |
+
embeddings = load_embeddings(model_name)
|
| 144 |
+
if embeddings is None:
|
| 145 |
+
return
|
| 146 |
+
|
| 147 |
+
# Asignamos la versión para distinguir en el split
|
| 148 |
embeddings["real"]["version"] = "real"
|
| 149 |
embeddings["es-digital-seq"]["version"] = "es_digital_seq"
|
| 150 |
embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
|
|
|
|
| 151 |
df_combined = pd.concat([embeddings["real"], embeddings["es-digital-seq"]], ignore_index=True)
|
| 152 |
+
|
| 153 |
+
st.markdown('<h6 class="sub-title">Select Dimensionality Reduction Method</h6>', unsafe_allow_html=True)
|
| 154 |
+
reduction_method = st.selectbox("", options=["t-SNE", "PCA"], key=model_name)
|
| 155 |
+
if reduction_method == "PCA":
|
| 156 |
+
reducer = PCA(n_components=2)
|
| 157 |
+
else:
|
| 158 |
+
reducer = TSNE(n_components=2, random_state=42, perplexity=30, learning_rate=200)
|
| 159 |
+
reduced = reducer.fit_transform(df_combined[embedding_cols].values)
|
| 160 |
|
| 161 |
dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
|
| 162 |
+
selected_subsets = {"real": unique_subsets["real"], "es-digital-seq": unique_subsets["es-digital-seq"]}
|
| 163 |
color_maps = get_color_maps(selected_subsets)
|
|
|
|
| 164 |
|
| 165 |
+
fig, real_renderers, synthetic_renderers = create_figure(dfs_reduced, selected_subsets, color_maps)
|
| 166 |
centers_real = calculate_cluster_centers(dfs_reduced["real"], selected_subsets["real"])
|
| 167 |
centers_es = calculate_cluster_centers(dfs_reduced["es-digital-seq"], selected_subsets["es-digital-seq"])
|
| 168 |
df_distances = compute_distances(centers_es, centers_real)
|
| 169 |
+
data_table, df_table, source_table = create_table(df_distances)
|
| 170 |
+
real_subset_names = list(df_table.columns[1:])
|
| 171 |
+
real_select = Select(title="", value=real_subset_names[0], options=real_subset_names)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
reset_button = Button(label="Reset Colors", button_type="primary")
|
|
|
|
|
|
|
| 173 |
line_source = ColumnDataSource(data={'x': [], 'y': []})
|
| 174 |
fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
|
| 175 |
|
|
|
|
| 176 |
synthetic_centers_js = {k: [v[0], v[1]] for k, v in centers_es.items()}
|
| 177 |
real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
|
| 178 |
|
| 179 |
+
# Callback para actualizar el gráfico
|
| 180 |
callback = CustomJS(args=dict(source=source_table, line_source=line_source,
|
| 181 |
synthetic_centers=synthetic_centers_js,
|
| 182 |
real_centers=real_centers_js,
|
|
|
|
| 240 |
}
|
| 241 |
}
|
| 242 |
""")
|
|
|
|
| 243 |
source_table.selected.js_on_change('indices', callback)
|
| 244 |
real_select.js_on_change('value', callback)
|
| 245 |
|
|
|
|
| 246 |
reset_callback = CustomJS(args=dict(line_source=line_source,
|
| 247 |
synthetic_renderers=synthetic_renderers,
|
| 248 |
real_renderers=real_renderers,
|
|
|
|
| 268 |
""")
|
| 269 |
reset_button.js_on_event("button_click", reset_callback)
|
| 270 |
|
|
|
|
| 271 |
layout = column(fig, column(real_select, reset_button, data_table))
|
| 272 |
+
st.bokeh_chart(layout, use_container_width=True)
|
| 273 |
+
|
| 274 |
+
# Función principal con tabs para cambiar de modelo
|
| 275 |
+
def main():
    """Render the page: apply styling, then one tab per model.

    Calls ``config_style()``, creates the "Donut" and "Idefics2" tabs, and
    inside each tab writes that model's heading before calling
    ``run_model`` with the model name.
    """
    config_style()

    # (model name, heading HTML) in tab order.
    # NOTE(review): the trailing character of each heading looks like a
    # mis-decoded emoji in this view -- confirm against the original file.
    sections = (
        ("Donut", '<h2 class="sub-title">Modelo Donut 馃</h2>'),
        ("Idefics2", '<h2 class="sub-title">Modelo Idefics2 馃</h2>'),
    )

    tab_handles = st.tabs([name for name, _ in sections])
    for handle, (name, heading) in zip(tab_handles, sections):
        with handle:
            st.markdown(heading, unsafe_allow_html=True)
            run_model(name)
|
| 286 |
|
| 287 |
if __name__ == "__main__":
|
| 288 |
main()
|