de-Rodrigo commited on
Commit
ff7de4b
·
1 Parent(s): ce4ccf6

Improve Data Visualization

Browse files
Files changed (1) hide show
  1. app.py +31 -6
app.py CHANGED
@@ -4,7 +4,7 @@ import numpy as np
4
  from bokeh.plotting import figure
5
  from bokeh.models import ColumnDataSource, DataTable, TableColumn, CustomJS, Select, Button, HoverTool, LinearColorMapper, ColorBar, FuncTickFormatter, FixedTicker
6
  from bokeh.layouts import column
7
- from bokeh.palettes import Reds9, Blues9, Oranges9, Purples9, Greys9, BuGn9, Greens9
8
  from sklearn.decomposition import PCA
9
  from sklearn.manifold import TSNE, trustworthiness
10
  from sklearn.metrics import pairwise_distances
@@ -14,6 +14,8 @@ from sklearn.linear_model import LinearRegression
14
  from scipy.stats import binned_statistic_2d
15
  import json
16
  import itertools
 
 
17
 
18
 
19
  N_COMPONENTS = 3
@@ -1041,6 +1043,17 @@ def run_model(model_name):
1041
 
1042
  # -------------------------------------------------------------------------
1043
  # 4. Cálculo de distancias y scatter plot: Distance vs F1 (usando PC1 y PC2 globales)
 
 
 
 
 
 
 
 
 
 
 
1044
  real_labels_new = sorted(df_all["real"]['label'].unique().tolist())
1045
  df_distances_new = compute_cluster_distances_synthetic_individual(
1046
  df_all["synthetic"],
@@ -1144,11 +1157,15 @@ def run_model(model_name):
1144
  selected_feature = st.selectbox("Select heatmap feature:",
1145
  options=feature_options, key=f"heatmap_{model_name}")
1146
  select_extra_dataset_hm = st.selectbox("Select a dataset:",
1147
- options=["-", "es-digital-line-degradation-seq", "es-digital-seq", "es-digital-rotation-degradation-seq", "es-digital-zoom-degradation-seq", "es-render-seq"], key=f"heatmap_extra_dataset_{model_name}")
1148
 
1149
  # Definir el rango de posiciones (x, y)
1150
- x_min, x_max = df_heatmap['x'].min(), df_heatmap['x'].max()
1151
- y_min, y_max = df_heatmap['y'].min(), df_heatmap['y'].max()
 
 
 
 
1152
 
1153
  grid_size = 50
1154
  x_bins = np.linspace(x_min, x_max, grid_size + 1)
@@ -1177,7 +1194,15 @@ def run_model(model_name):
1177
  # Transponer la matriz para alinear correctamente los ejes
1178
  heatmap_data = heat_stat.T
1179
 
1180
- color_mapper = LinearColorMapper(palette="Viridis256", low=np.nanmin(heatmap_data), high=np.nanmax(heatmap_data), nan_color='rgba(0, 0, 0, 0)')
 
 
 
 
 
 
 
 
1181
 
1182
  heatmap_fig = figure(title=f"Heatmap de '{selected_feature}'",
1183
  x_range=(x_min, x_max), y_range=(y_min, y_max),
@@ -1223,7 +1248,7 @@ def run_model(model_name):
1223
  'img': df_extra['img'],
1224
  'label': df_extra['name']
1225
  })
1226
- extra_renderer = heatmap_fig.circle('x', 'y', size=10, source=source_extra_points, fill_alpha=0, line_alpha=0.5, color="red")
1227
 
1228
  hover_tool_points = HoverTool(renderers=[invisible_renderer], tooltips=TOOLTIPS)
1229
  heatmap_fig.add_tools(hover_tool_points)
 
4
  from bokeh.plotting import figure
5
  from bokeh.models import ColumnDataSource, DataTable, TableColumn, CustomJS, Select, Button, HoverTool, LinearColorMapper, ColorBar, FuncTickFormatter, FixedTicker
6
  from bokeh.layouts import column
7
+ from bokeh.palettes import Reds9, Blues9, Oranges9, Purples9, Greys9, BuGn9, Greens9, RdYlGn11, linear_palette
8
  from sklearn.decomposition import PCA
9
  from sklearn.manifold import TSNE, trustworthiness
10
  from sklearn.metrics import pairwise_distances
 
14
  from scipy.stats import binned_statistic_2d
15
  import json
16
  import itertools
17
+ import matplotlib.pyplot as plt
18
+ import matplotlib.colors as mcolors
19
 
20
 
21
  N_COMPONENTS = 3
 
1043
 
1044
  # -------------------------------------------------------------------------
1045
  # 4. Cálculo de distancias y scatter plot: Distance vs F1 (usando PC1 y PC2 globales)
1046
+ model_options = ["es-digital-paragraph-degradation-seq", "es-digital-line-degradation-seq", "es-digital-seq", "es-digital-rotation-degradation-seq", "es-digital-zoom-degradation-seq", "es-render-seq"]
1047
+ model_options_with_default = [""]
1048
+ model_options_with_default.extend(model_options)
1049
+
1050
+
1051
+ # Genera una paleta de 256 colores basada en RdYlGn11
1052
+ cmap = plt.get_cmap("RdYlGn")
1053
+ red_green_palette = [mcolors.rgb2hex(cmap(i)) for i in np.linspace(0, 1, 256)]
1054
+
1055
+
1056
+
1057
  real_labels_new = sorted(df_all["real"]['label'].unique().tolist())
1058
  df_distances_new = compute_cluster_distances_synthetic_individual(
1059
  df_all["synthetic"],
 
1157
  selected_feature = st.selectbox("Select heatmap feature:",
1158
  options=feature_options, key=f"heatmap_{model_name}")
1159
  select_extra_dataset_hm = st.selectbox("Select a dataset:",
1160
+ options=model_options_with_default, key=f"heatmap_extra_dataset_{model_name}")
1161
 
1162
  # Definir el rango de posiciones (x, y)
1163
+ # x_min, x_max = df_heatmap['x'].min(), df_heatmap['x'].max()
1164
+ # y_min, y_max = df_heatmap['y'].min(), df_heatmap['y'].max()
1165
+
1166
+ x_min, x_max = -4, 4
1167
+ y_min, y_max = -4, 4
1168
+
1169
 
1170
  grid_size = 50
1171
  x_bins = np.linspace(x_min, x_max, grid_size + 1)
 
1194
  # Transponer la matriz para alinear correctamente los ejes
1195
  heatmap_data = heat_stat.T
1196
 
1197
+ if selected_feature in model_options:
1198
+ color_mapper = LinearColorMapper(
1199
+ palette=red_green_palette,
1200
+ low=0,
1201
+ high=1,
1202
+ nan_color='rgba(0, 0, 0, 0)'
1203
+ )
1204
+ else:
1205
+ color_mapper = LinearColorMapper(palette="Viridis256", low=np.nanmin(heatmap_data), high=np.nanmax(heatmap_data), nan_color='rgba(0, 0, 0, 0)')
1206
 
1207
  heatmap_fig = figure(title=f"Heatmap de '{selected_feature}'",
1208
  x_range=(x_min, x_max), y_range=(y_min, y_max),
 
1248
  'img': df_extra['img'],
1249
  'label': df_extra['name']
1250
  })
1251
+ extra_renderer = heatmap_fig.circle('x', 'y', size=5, source=source_extra_points, fill_alpha=0, line_alpha=0.5, color="purple")
1252
 
1253
  hover_tool_points = HoverTool(renderers=[invisible_renderer], tooltips=TOOLTIPS)
1254
  heatmap_fig.add_tools(hover_tool_points)