de-Rodrigo committed on
Commit
bfe8480
1 Parent(s): ff7de4b

Multiple Principal Components Combinations for Heat Map

Browse files
Files changed (1) hide show
  1. app.py +115 -91
app.py CHANGED
@@ -1145,115 +1145,139 @@ def run_model(model_name):
1145
  if 'img' not in df_all["real"].columns:
1146
  st.error("La columna 'img' no se encuentra en las muestras reales para hacer el merge con heatmaps.csv.")
1147
  else:
1148
- # Crear columna 'name' a partir del nombre final de la URL de la imagen
1149
  df_all["real"]["name"] = df_all["real"]["img"].apply(
1150
  lambda x: x.split("/")[-1].replace(".png", "") if isinstance(x, str) else x
1151
  )
1152
- # Realizar merge de las posiciones reales con el CSV de heatmaps
1153
- df_heatmap = pd.merge(df_all["real"], df_heat, on="name", how="inner")
1154
 
1155
- # Extraer las características disponibles (excluyendo 'name')
1156
  feature_options = [col for col in df_heat.columns if col != "name"]
1157
  selected_feature = st.selectbox("Select heatmap feature:",
1158
  options=feature_options, key=f"heatmap_{model_name}")
1159
  select_extra_dataset_hm = st.selectbox("Select a dataset:",
1160
  options=model_options_with_default, key=f"heatmap_extra_dataset_{model_name}")
1161
 
1162
- # Definir el rango de posiciones (x, y)
1163
- # x_min, x_max = df_heatmap['x'].min(), df_heatmap['x'].max()
1164
- # y_min, y_max = df_heatmap['y'].min(), df_heatmap['y'].max()
1165
-
1166
  x_min, x_max = -4, 4
1167
  y_min, y_max = -4, 4
1168
-
1169
-
1170
  grid_size = 50
1171
  x_bins = np.linspace(x_min, x_max, grid_size + 1)
1172
  y_bins = np.linspace(y_min, y_max, grid_size + 1)
1173
 
1174
- cat_mapping = None
1175
- if df_heatmap[selected_feature].dtype == bool or not pd.api.types.is_numeric_dtype(df_heatmap[selected_feature]):
1176
- cat = df_heatmap[selected_feature].astype('category')
1177
- cat_mapping = list(cat.cat.categories)
1178
- df_heatmap[selected_feature] = cat.cat.codes
1179
-
1180
- try:
1181
- heat_stat, x_edges, y_edges, binnumber = binned_statistic_2d(
1182
- df_heatmap['x'], df_heatmap['y'], df_heatmap[selected_feature],
1183
- statistic='mean', bins=[x_bins, y_bins]
1184
- )
1185
- except TypeError:
1186
- cat = df_heatmap[selected_feature].astype('category')
1187
- cat_mapping = list(cat.cat.categories)
1188
- df_heatmap[selected_feature] = cat.cat.codes
1189
- heat_stat, x_edges, y_edges, binnumber = binned_statistic_2d(
1190
- df_heatmap['x'], df_heatmap['y'], df_heatmap[selected_feature],
1191
- statistic='mean', bins=[x_bins, y_bins]
1192
- )
1193
-
1194
- # Transponer la matriz para alinear correctamente los ejes
1195
- heatmap_data = heat_stat.T
1196
-
1197
- if selected_feature in model_options:
1198
- color_mapper = LinearColorMapper(
1199
- palette=red_green_palette,
1200
- low=0,
1201
- high=1,
1202
- nan_color='rgba(0, 0, 0, 0)'
1203
- )
1204
- else:
1205
- color_mapper = LinearColorMapper(palette="Viridis256", low=np.nanmin(heatmap_data), high=np.nanmax(heatmap_data), nan_color='rgba(0, 0, 0, 0)')
1206
-
1207
- heatmap_fig = figure(title=f"Heatmap de '{selected_feature}'",
1208
- x_range=(x_min, x_max), y_range=(y_min, y_max),
1209
- width=600, height=600,
1210
- tools="pan,wheel_zoom,reset,save", active_scroll="wheel_zoom", tooltips=TOOLTIPS)
1211
- heatmap_fig.image(image=[heatmap_data], x=x_min, y=y_min,
1212
- dw=x_max - x_min, dh=y_max - y_min,
1213
- color_mapper=color_mapper)
1214
-
1215
- color_bar = ColorBar(color_mapper=color_mapper, location=(0, 0))
1216
- if cat_mapping is not None:
1217
- ticks = list(range(len(cat_mapping)))
1218
- color_bar.ticker = FixedTicker(ticks=ticks)
1219
- categories_json = json.dumps(cat_mapping)
1220
- color_bar.formatter = FuncTickFormatter(code=f"""
1221
- var categories = {categories_json};
1222
- var index = Math.round(tick);
1223
- if(index >= 0 && index < categories.length) {{
1224
- return categories[index];
1225
- }} else {{
1226
- return "";
1227
- }}
1228
- """)
1229
- heatmap_fig.add_layout(color_bar, 'right')
1230
-
1231
- source_points = ColumnDataSource(data={
1232
- 'x': df_heatmap['x'],
1233
- 'y': df_heatmap['y'],
1234
- 'img': df_heatmap['img'],
1235
- 'label': df_heatmap['name']
1236
- })
1237
- invisible_renderer = heatmap_fig.circle('x', 'y', size=10, source=source_points, fill_alpha=0, line_alpha=0.5)
1238
-
1239
- if select_extra_dataset_hm != "-":
1240
- df_extra = df_all["synthetic"][df_all["synthetic"]["source"] == select_extra_dataset_hm]
1241
- if 'name' not in df_extra.columns:
1242
- df_extra["name"] = df_extra["img"].apply(
1243
- lambda x: x.split("/")[-1].replace(".png", "") if isinstance(x, str) else x
1244
  )
1245
- source_extra_points = ColumnDataSource(data={
1246
- 'x': df_extra['x'],
1247
- 'y': df_extra['y'],
1248
- 'img': df_extra['img'],
1249
- 'label': df_extra['name']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1250
  })
1251
- extra_renderer = heatmap_fig.circle('x', 'y', size=5, source=source_extra_points, fill_alpha=0, line_alpha=0.5, color="purple")
1252
-
1253
- hover_tool_points = HoverTool(renderers=[invisible_renderer], tooltips=TOOLTIPS)
1254
- heatmap_fig.add_tools(hover_tool_points)
1255
-
1256
- st.bokeh_chart(heatmap_fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1257
 
1258
  def main():
1259
  config_style()
 
1145
  if 'img' not in df_all["real"].columns:
1146
  st.error("La columna 'img' no se encuentra en las muestras reales para hacer el merge con heatmaps.csv.")
1147
  else:
1148
+ # Crear columna 'name' en las muestras reales (si aún no existe)
1149
  df_all["real"]["name"] = df_all["real"]["img"].apply(
1150
  lambda x: x.split("/")[-1].replace(".png", "") if isinstance(x, str) else x
1151
  )
1152
+ # Merge de las posiciones reales con el CSV de heatmaps (se usa el merge base)
1153
+ df_heatmap_base = pd.merge(df_all["real"], df_heat, on="name", how="inner")
1154
 
1155
+ # Extraer opciones de feature (excluyendo 'name')
1156
  feature_options = [col for col in df_heat.columns if col != "name"]
1157
  selected_feature = st.selectbox("Select heatmap feature:",
1158
  options=feature_options, key=f"heatmap_{model_name}")
1159
  select_extra_dataset_hm = st.selectbox("Select a dataset:",
1160
  options=model_options_with_default, key=f"heatmap_extra_dataset_{model_name}")
1161
 
1162
+ # Definir un rango fijo para los ejes (por ejemplo, de -4 a 4) y rejilla
 
 
 
1163
  x_min, x_max = -4, 4
1164
  y_min, y_max = -4, 4
 
 
1165
  grid_size = 50
1166
  x_bins = np.linspace(x_min, x_max, grid_size + 1)
1167
  y_bins = np.linspace(y_min, y_max, grid_size + 1)
1168
 
1169
+ # Generar heatmaps para cada combinación de componentes
1170
+ pairs = list(itertools.combinations(range(N_COMPONENTS), 2))
1171
+ for (i, j) in pairs:
1172
+ x_comp = f'PC{i+1}'
1173
+ y_comp = f'PC{j+1}'
1174
+ st.markdown(f"### Heatmap: {x_comp} vs {y_comp}")
1175
+
1176
+ # Crear un DataFrame de heatmap para la combinación actual a partir del merge base
1177
+ df_heatmap = df_heatmap_base.copy()
1178
+ df_heatmap["x"] = df_heatmap[x_comp]
1179
+ df_heatmap["y"] = df_heatmap[y_comp]
1180
+
1181
+ # Si la feature seleccionada no es numérica, convertir a códigos y guardar la correspondencia
1182
+ cat_mapping = None
1183
+ if df_heatmap[selected_feature].dtype == bool or not pd.api.types.is_numeric_dtype(df_heatmap[selected_feature]):
1184
+ cat = df_heatmap[selected_feature].astype('category')
1185
+ cat_mapping = list(cat.cat.categories)
1186
+ df_heatmap[selected_feature] = cat.cat.codes
1187
+
1188
+ # Calcular la estadística binned (por ejemplo, la media) en la rejilla
1189
+ try:
1190
+ heat_stat, x_edges, y_edges, binnumber = binned_statistic_2d(
1191
+ df_heatmap['x'], df_heatmap['y'], df_heatmap[selected_feature],
1192
+ statistic='mean', bins=[x_bins, y_bins]
1193
+ )
1194
+ except TypeError:
1195
+ cat = df_heatmap[selected_feature].astype('category')
1196
+ cat_mapping = list(cat.cat.categories)
1197
+ df_heatmap[selected_feature] = cat.cat.codes
1198
+ heat_stat, x_edges, y_edges, binnumber = binned_statistic_2d(
1199
+ df_heatmap['x'], df_heatmap['y'], df_heatmap[selected_feature],
1200
+ statistic='mean', bins=[x_bins, y_bins]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1201
  )
1202
+ # Transponer la matriz para alinear correctamente los ejes
1203
+ heatmap_data = heat_stat.T
1204
+
1205
+ # Definir el color mapper: si la feature es de un modelo (model_options) usar la paleta rojo-verde con rango 0 a 1
1206
+ if selected_feature in model_options:
1207
+ color_mapper = LinearColorMapper(
1208
+ palette=red_green_palette,
1209
+ low=0,
1210
+ high=1,
1211
+ nan_color='rgba(0, 0, 0, 0)'
1212
+ )
1213
+ else:
1214
+ color_mapper = LinearColorMapper(
1215
+ palette="Viridis256",
1216
+ low=np.nanmin(heatmap_data),
1217
+ high=np.nanmax(heatmap_data),
1218
+ nan_color='rgba(0, 0, 0, 0)'
1219
+ )
1220
+
1221
+ # Crear la figura para el heatmap con la misma escala para x e y
1222
+ heatmap_fig = figure(title=f"Heatmap de '{selected_feature}' ({x_comp} vs {y_comp})",
1223
+ x_range=(x_min, x_max), y_range=(y_min, y_max),
1224
+ width=600, height=600,
1225
+ tools="pan,wheel_zoom,reset,save", active_scroll="wheel_zoom", tooltips=TOOLTIPS,
1226
+ sizing_mode="fixed")
1227
+ heatmap_fig.match_aspect = True
1228
+
1229
+ heatmap_fig.xaxis.axis_label = x_comp
1230
+ heatmap_fig.yaxis.axis_label = y_comp
1231
+ # Dibujar la imagen del heatmap
1232
+ heatmap_fig.image(image=[heatmap_data], x=x_min, y=y_min,
1233
+ dw=x_max - x_min, dh=y_max - y_min,
1234
+ color_mapper=color_mapper)
1235
+
1236
+ # Agregar la barra de color
1237
+ color_bar = ColorBar(color_mapper=color_mapper, location=(0, 0))
1238
+ if cat_mapping is not None:
1239
+ ticks = list(range(len(cat_mapping)))
1240
+ color_bar.ticker = FixedTicker(ticks=ticks)
1241
+ categories_json = json.dumps(cat_mapping)
1242
+ color_bar.formatter = FuncTickFormatter(code=f"""
1243
+ var categories = {categories_json};
1244
+ var index = Math.round(tick);
1245
+ if(index >= 0 && index < categories.length) {{
1246
+ return categories[index];
1247
+ }} else {{
1248
+ return "";
1249
+ }}
1250
+ """)
1251
+ heatmap_fig.add_layout(color_bar, 'right')
1252
+
1253
+ # Agregar renderer invisible para tooltips (usando puntos en cada bin)
1254
+ source_points = ColumnDataSource(data={
1255
+ 'x': df_heatmap['x'],
1256
+ 'y': df_heatmap['y'],
1257
+ 'img': df_heatmap['img'],
1258
+ 'label': df_heatmap['name']
1259
  })
1260
+ invisible_renderer = heatmap_fig.circle('x', 'y', size=10, source=source_points, fill_alpha=0, line_alpha=0.5)
1261
+
1262
+ # Si se selecciona un dataset extra, proyectar sus puntos en la combinación actual
1263
+ if select_extra_dataset_hm != "-":
1264
+ df_extra = df_all["synthetic"][df_all["synthetic"]["source"] == select_extra_dataset_hm].copy()
1265
+ df_extra["x"] = df_extra[x_comp]
1266
+ df_extra["y"] = df_extra[y_comp]
1267
+ if 'name' not in df_extra.columns:
1268
+ df_extra["name"] = df_extra["img"].apply(lambda x: x.split("/")[-1].replace(".png", "") if isinstance(x, str) else x)
1269
+ source_extra_points = ColumnDataSource(data={
1270
+ 'x': df_extra['x'],
1271
+ 'y': df_extra['y'],
1272
+ 'img': df_extra['img'],
1273
+ 'label': df_extra['name']
1274
+ })
1275
+ extra_renderer = heatmap_fig.circle('x', 'y', size=5, source=source_extra_points, fill_alpha=0, line_alpha=0.5, color="purple")
1276
+
1277
+ hover_tool_points = HoverTool(renderers=[invisible_renderer], tooltips=TOOLTIPS)
1278
+ heatmap_fig.add_tools(hover_tool_points)
1279
+
1280
+ st.bokeh_chart(heatmap_fig)
1281
 
1282
  def main():
1283
  config_style()