de-Rodrigo commited on
Commit
ce05869
1 Parent(s): 789e1f0

Scatter Plot with Regression

Browse files
Files changed (1) hide show
  1. app.py +87 -4
app.py CHANGED
@@ -2,13 +2,14 @@ import streamlit as st
2
  import pandas as pd
3
  import numpy as np
4
  from bokeh.plotting import figure
5
- from bokeh.models import ColumnDataSource, DataTable, TableColumn, CustomJS, Select, Button
6
  from bokeh.layouts import column
7
  from bokeh.palettes import Reds9, Blues9, Oranges9, Purples9, Greys9, BuGn9, Greens9
8
  from sklearn.decomposition import PCA
9
  from sklearn.manifold import TSNE
10
  import io
11
  import ot
 
12
 
13
  TOOLTIPS = """
14
  <div>
@@ -81,7 +82,9 @@ def reducer_selector(df_combined, embedding_cols):
81
  if reduction_method == "PCA":
82
  reducer = PCA(n_components=2)
83
  else:
84
- reducer = TSNE(n_components=2, random_state=42, perplexity=30, learning_rate=200)
 
 
85
  return reducer.fit_transform(all_embeddings)
86
 
87
  # Funci贸n para agregar datos reales (por cada etiqueta)
@@ -330,7 +333,86 @@ def run_model(model_name):
330
 
331
  centers_real = calculate_cluster_centers(dfs_reduced["real"], unique_subsets["real"])
332
 
333
- df_distances = compute_wasserstein_distances_synthetic_individual(dfs_reduced["synthetic"], dfs_reduced["real"], unique_subsets["real"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  data_table, df_table, source_table = create_table(df_distances)
335
 
336
  real_subset_names = list(df_table.columns[1:])
@@ -380,7 +462,7 @@ def run_model(model_name):
380
  df_table.to_excel(buffer, index=False)
381
  buffer.seek(0)
382
 
383
- layout = column(fig, column(real_select, reset_button, data_table))
384
  st.bokeh_chart(layout, use_container_width=True)
385
 
386
  st.download_button(
@@ -391,6 +473,7 @@ def run_model(model_name):
391
  key=f"download_button_excel_{model_name}"
392
  )
393
 
 
394
  def main():
395
  config_style()
396
  tabs = st.tabs(["Donut", "Idefics2"])
 
2
  import pandas as pd
3
  import numpy as np
4
  from bokeh.plotting import figure
5
+ from bokeh.models import ColumnDataSource, DataTable, TableColumn, CustomJS, Select, Button, HoverTool
6
  from bokeh.layouts import column
7
  from bokeh.palettes import Reds9, Blues9, Oranges9, Purples9, Greys9, BuGn9, Greens9
8
  from sklearn.decomposition import PCA
9
  from sklearn.manifold import TSNE
10
  import io
11
  import ot
12
+ from sklearn.linear_model import LinearRegression
13
 
14
  TOOLTIPS = """
15
  <div>
 
82
  if reduction_method == "PCA":
83
  reducer = PCA(n_components=2)
84
  else:
85
+ perplexity_val = st.number_input("Perplexity", min_value=5, max_value=50, value=30, step=1)
86
+ learning_rate_val = st.number_input("Learning Rate", min_value=10, max_value=1000, value=200, step=10)
87
+ reducer = TSNE(n_components=2, random_state=42, perplexity=perplexity_val, learning_rate=learning_rate_val)
88
  return reducer.fit_transform(all_embeddings)
89
 
90
  # Funci贸n para agregar datos reales (por cada etiqueta)
 
333
 
334
  centers_real = calculate_cluster_centers(dfs_reduced["real"], unique_subsets["real"])
335
 
336
+ df_distances = compute_wasserstein_distances_synthetic_individual(
337
+ dfs_reduced["synthetic"],
338
+ dfs_reduced["real"],
339
+ unique_subsets["real"]
340
+ )
341
+
342
+ # --- Scatter plot usando f1-donut.csv ---
343
+ try:
344
+ df_f1 = pd.read_csv("data/f1-donut.csv", sep=';', index_col=0)
345
+ except Exception as e:
346
+ st.error(f"Error loading f1-donut.csv: {e}")
347
+ return
348
+
349
+ # Extraer los valores globales para cada fuente (sin promediar: 10 valores por fuente)
350
+ global_distances = {}
351
+ for idx in df_distances.index:
352
+ if idx.startswith("Global"):
353
+ # Ejemplo: "Global (es-digital-seq)"
354
+ source = idx.split("(")[1].rstrip(")")
355
+ global_distances[source] = df_distances.loc[idx].values
356
+
357
+ # Reutilizaci贸n de los c贸digos de colores
358
+ source_colors = {
359
+ "es-digital-paragraph-degradation-seq": "blue",
360
+ "es-digital-line-degradation-seq": "green",
361
+ "es-digital-seq": "red",
362
+ "es-digital-zoom-degradation-seq": "orange",
363
+ "es-digital-rotation-degradation-seq": "purple",
364
+ "es-digital-rotation-zoom-degradation-seq": "brown",
365
+ "es-render-seq": "cyan"
366
+ }
367
+
368
+ scatter_fig = figure(width=600, height=600, tools="pan,wheel_zoom,reset,save", title="Scatter Plot: Wasserstein vs F1")
369
+ # Variables para la regresi贸n global
370
+ all_x = []
371
+ all_y = []
372
+
373
+ # Se plotea cada fuente y se acumulan los datos para la regresi贸n global
374
+ for source in df_f1.columns:
375
+ if source in global_distances:
376
+ x_vals = global_distances[source] # 10 valores (uno por colegio)
377
+ y_vals = df_f1[source].values # 10 valores de f1, en el mismo orden
378
+ data = {"x": x_vals, "y": y_vals, "Fuente": [source] * len(x_vals)}
379
+ cds = ColumnDataSource(data=data)
380
+ scatter_fig.circle('x', 'y', size=8, alpha=0.7, source=cds,
381
+ fill_color=source_colors.get(source, "gray"),
382
+ line_color=source_colors.get(source, "gray"),
383
+ legend_label=source)
384
+ all_x.extend(x_vals)
385
+ all_y.extend(y_vals)
386
+
387
+ scatter_fig.xaxis.axis_label = "Wasserstein Distance (Global, por Colegio)"
388
+ scatter_fig.yaxis.axis_label = "F1 Score"
389
+ scatter_fig.legend.location = "top_right"
390
+
391
+ # Agregar HoverTool para mostrar x, y y la fuente al hacer hover
392
+ hover_tool = HoverTool(tooltips=[("x", "@x"), ("y", "@y"), ("Fuente", "@Fuente")])
393
+ scatter_fig.add_tools(hover_tool)
394
+ # --- Fin scatter plot ---
395
+
396
+ # --- Regresi贸n global ---
397
+ all_x_arr = np.array(all_x).reshape(-1, 1)
398
+ all_y_arr = np.array(all_y)
399
+ model_global = LinearRegression().fit(all_x_arr, all_y_arr)
400
+ slope = model_global.coef_[0]
401
+ intercept = model_global.intercept_
402
+ r2 = model_global.score(all_x_arr, all_y_arr)
403
+
404
+ # Agregar l铆nea de regresi贸n global al scatter plot
405
+ x_line = np.linspace(all_x_arr.min(), all_x_arr.max(), 100)
406
+ y_line = model_global.predict(x_line.reshape(-1, 1))
407
+ scatter_fig.line(x_line, y_line, line_width=2, line_color="black", legend_label="Global Regression")
408
+
409
+ # Mostrar m茅tricas de regresi贸n despu茅s del scatter plot
410
+ regression_metrics = {"Slope": [slope], "Intercept": [intercept], "R2": [r2]}
411
+ reg_df = pd.DataFrame(regression_metrics)
412
+ st.table(reg_df)
413
+
414
+ # --- Fin regresi贸n global ---
415
+
416
  data_table, df_table, source_table = create_table(df_distances)
417
 
418
  real_subset_names = list(df_table.columns[1:])
 
462
  df_table.to_excel(buffer, index=False)
463
  buffer.seek(0)
464
 
465
+ layout = column(fig, scatter_fig, column(real_select, reset_button, data_table))
466
  st.bokeh_chart(layout, use_container_width=True)
467
 
468
  st.download_button(
 
473
  key=f"download_button_excel_{model_name}"
474
  )
475
 
476
+
477
  def main():
478
  config_style()
479
  tabs = st.tabs(["Donut", "Idefics2"])