Spaces:
Running
Running
Commit
·
bfe8480
1
Parent(s):
ff7de4b
Multiple Principal Components Combinations for Heat Map
Browse files
app.py
CHANGED
@@ -1145,115 +1145,139 @@ def run_model(model_name):
|
|
1145 |
if 'img' not in df_all["real"].columns:
|
1146 |
st.error("La columna 'img' no se encuentra en las muestras reales para hacer el merge con heatmaps.csv.")
|
1147 |
else:
|
1148 |
-
# Crear columna 'name'
|
1149 |
df_all["real"]["name"] = df_all["real"]["img"].apply(
|
1150 |
lambda x: x.split("/")[-1].replace(".png", "") if isinstance(x, str) else x
|
1151 |
)
|
1152 |
-
#
|
1153 |
-
|
1154 |
|
1155 |
-
# Extraer
|
1156 |
feature_options = [col for col in df_heat.columns if col != "name"]
|
1157 |
selected_feature = st.selectbox("Select heatmap feature:",
|
1158 |
options=feature_options, key=f"heatmap_{model_name}")
|
1159 |
select_extra_dataset_hm = st.selectbox("Select a dataset:",
|
1160 |
options=model_options_with_default, key=f"heatmap_extra_dataset_{model_name}")
|
1161 |
|
1162 |
-
# Definir
|
1163 |
-
# x_min, x_max = df_heatmap['x'].min(), df_heatmap['x'].max()
|
1164 |
-
# y_min, y_max = df_heatmap['y'].min(), df_heatmap['y'].max()
|
1165 |
-
|
1166 |
x_min, x_max = -4, 4
|
1167 |
y_min, y_max = -4, 4
|
1168 |
-
|
1169 |
-
|
1170 |
grid_size = 50
|
1171 |
x_bins = np.linspace(x_min, x_max, grid_size + 1)
|
1172 |
y_bins = np.linspace(y_min, y_max, grid_size + 1)
|
1173 |
|
1174 |
-
|
1175 |
-
|
1176 |
-
|
1177 |
-
|
1178 |
-
|
1179 |
-
|
1180 |
-
|
1181 |
-
|
1182 |
-
|
1183 |
-
|
1184 |
-
|
1185 |
-
|
1186 |
-
|
1187 |
-
cat_mapping =
|
1188 |
-
df_heatmap[selected_feature]
|
1189 |
-
|
1190 |
-
|
1191 |
-
|
1192 |
-
|
1193 |
-
|
1194 |
-
|
1195 |
-
|
1196 |
-
|
1197 |
-
|
1198 |
-
|
1199 |
-
|
1200 |
-
|
1201 |
-
|
1202 |
-
|
1203 |
-
|
1204 |
-
|
1205 |
-
|
1206 |
-
|
1207 |
-
heatmap_fig = figure(title=f"Heatmap de '{selected_feature}'",
|
1208 |
-
x_range=(x_min, x_max), y_range=(y_min, y_max),
|
1209 |
-
width=600, height=600,
|
1210 |
-
tools="pan,wheel_zoom,reset,save", active_scroll="wheel_zoom", tooltips=TOOLTIPS)
|
1211 |
-
heatmap_fig.image(image=[heatmap_data], x=x_min, y=y_min,
|
1212 |
-
dw=x_max - x_min, dh=y_max - y_min,
|
1213 |
-
color_mapper=color_mapper)
|
1214 |
-
|
1215 |
-
color_bar = ColorBar(color_mapper=color_mapper, location=(0, 0))
|
1216 |
-
if cat_mapping is not None:
|
1217 |
-
ticks = list(range(len(cat_mapping)))
|
1218 |
-
color_bar.ticker = FixedTicker(ticks=ticks)
|
1219 |
-
categories_json = json.dumps(cat_mapping)
|
1220 |
-
color_bar.formatter = FuncTickFormatter(code=f"""
|
1221 |
-
var categories = {categories_json};
|
1222 |
-
var index = Math.round(tick);
|
1223 |
-
if(index >= 0 && index < categories.length) {{
|
1224 |
-
return categories[index];
|
1225 |
-
}} else {{
|
1226 |
-
return "";
|
1227 |
-
}}
|
1228 |
-
""")
|
1229 |
-
heatmap_fig.add_layout(color_bar, 'right')
|
1230 |
-
|
1231 |
-
source_points = ColumnDataSource(data={
|
1232 |
-
'x': df_heatmap['x'],
|
1233 |
-
'y': df_heatmap['y'],
|
1234 |
-
'img': df_heatmap['img'],
|
1235 |
-
'label': df_heatmap['name']
|
1236 |
-
})
|
1237 |
-
invisible_renderer = heatmap_fig.circle('x', 'y', size=10, source=source_points, fill_alpha=0, line_alpha=0.5)
|
1238 |
-
|
1239 |
-
if select_extra_dataset_hm != "-":
|
1240 |
-
df_extra = df_all["synthetic"][df_all["synthetic"]["source"] == select_extra_dataset_hm]
|
1241 |
-
if 'name' not in df_extra.columns:
|
1242 |
-
df_extra["name"] = df_extra["img"].apply(
|
1243 |
-
lambda x: x.split("/")[-1].replace(".png", "") if isinstance(x, str) else x
|
1244 |
)
|
1245 |
-
|
1246 |
-
|
1247 |
-
|
1248 |
-
|
1249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1250 |
})
|
1251 |
-
|
1252 |
-
|
1253 |
-
|
1254 |
-
|
1255 |
-
|
1256 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1257 |
|
1258 |
def main():
|
1259 |
config_style()
|
|
|
1145 |
if 'img' not in df_all["real"].columns:
|
1146 |
st.error("La columna 'img' no se encuentra en las muestras reales para hacer el merge con heatmaps.csv.")
|
1147 |
else:
|
1148 |
+
# Crear columna 'name' en las muestras reales (si aún no existe)
|
1149 |
df_all["real"]["name"] = df_all["real"]["img"].apply(
|
1150 |
lambda x: x.split("/")[-1].replace(".png", "") if isinstance(x, str) else x
|
1151 |
)
|
1152 |
+
# Merge de las posiciones reales con el CSV de heatmaps (se usa el merge base)
|
1153 |
+
df_heatmap_base = pd.merge(df_all["real"], df_heat, on="name", how="inner")
|
1154 |
|
1155 |
+
# Extraer opciones de feature (excluyendo 'name')
|
1156 |
feature_options = [col for col in df_heat.columns if col != "name"]
|
1157 |
selected_feature = st.selectbox("Select heatmap feature:",
|
1158 |
options=feature_options, key=f"heatmap_{model_name}")
|
1159 |
select_extra_dataset_hm = st.selectbox("Select a dataset:",
|
1160 |
options=model_options_with_default, key=f"heatmap_extra_dataset_{model_name}")
|
1161 |
|
1162 |
+
# Definir un rango fijo para los ejes (por ejemplo, de -4 a 4) y rejilla
|
|
|
|
|
|
|
1163 |
x_min, x_max = -4, 4
|
1164 |
y_min, y_max = -4, 4
|
|
|
|
|
1165 |
grid_size = 50
|
1166 |
x_bins = np.linspace(x_min, x_max, grid_size + 1)
|
1167 |
y_bins = np.linspace(y_min, y_max, grid_size + 1)
|
1168 |
|
1169 |
+
# Generar heatmaps para cada combinación de componentes
|
1170 |
+
pairs = list(itertools.combinations(range(N_COMPONENTS), 2))
|
1171 |
+
for (i, j) in pairs:
|
1172 |
+
x_comp = f'PC{i+1}'
|
1173 |
+
y_comp = f'PC{j+1}'
|
1174 |
+
st.markdown(f"### Heatmap: {x_comp} vs {y_comp}")
|
1175 |
+
|
1176 |
+
# Crear un DataFrame de heatmap para la combinación actual a partir del merge base
|
1177 |
+
df_heatmap = df_heatmap_base.copy()
|
1178 |
+
df_heatmap["x"] = df_heatmap[x_comp]
|
1179 |
+
df_heatmap["y"] = df_heatmap[y_comp]
|
1180 |
+
|
1181 |
+
# Si la feature seleccionada no es numérica, convertir a códigos y guardar la correspondencia
|
1182 |
+
cat_mapping = None
|
1183 |
+
if df_heatmap[selected_feature].dtype == bool or not pd.api.types.is_numeric_dtype(df_heatmap[selected_feature]):
|
1184 |
+
cat = df_heatmap[selected_feature].astype('category')
|
1185 |
+
cat_mapping = list(cat.cat.categories)
|
1186 |
+
df_heatmap[selected_feature] = cat.cat.codes
|
1187 |
+
|
1188 |
+
# Calcular la estadística binned (por ejemplo, la media) en la rejilla
|
1189 |
+
try:
|
1190 |
+
heat_stat, x_edges, y_edges, binnumber = binned_statistic_2d(
|
1191 |
+
df_heatmap['x'], df_heatmap['y'], df_heatmap[selected_feature],
|
1192 |
+
statistic='mean', bins=[x_bins, y_bins]
|
1193 |
+
)
|
1194 |
+
except TypeError:
|
1195 |
+
cat = df_heatmap[selected_feature].astype('category')
|
1196 |
+
cat_mapping = list(cat.cat.categories)
|
1197 |
+
df_heatmap[selected_feature] = cat.cat.codes
|
1198 |
+
heat_stat, x_edges, y_edges, binnumber = binned_statistic_2d(
|
1199 |
+
df_heatmap['x'], df_heatmap['y'], df_heatmap[selected_feature],
|
1200 |
+
statistic='mean', bins=[x_bins, y_bins]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1201 |
)
|
1202 |
+
# Transponer la matriz para alinear correctamente los ejes
|
1203 |
+
heatmap_data = heat_stat.T
|
1204 |
+
|
1205 |
+
# Definir el color mapper: si la feature es de un modelo (model_options) usar la paleta rojo-verde con rango 0 a 1
|
1206 |
+
if selected_feature in model_options:
|
1207 |
+
color_mapper = LinearColorMapper(
|
1208 |
+
palette=red_green_palette,
|
1209 |
+
low=0,
|
1210 |
+
high=1,
|
1211 |
+
nan_color='rgba(0, 0, 0, 0)'
|
1212 |
+
)
|
1213 |
+
else:
|
1214 |
+
color_mapper = LinearColorMapper(
|
1215 |
+
palette="Viridis256",
|
1216 |
+
low=np.nanmin(heatmap_data),
|
1217 |
+
high=np.nanmax(heatmap_data),
|
1218 |
+
nan_color='rgba(0, 0, 0, 0)'
|
1219 |
+
)
|
1220 |
+
|
1221 |
+
# Crear la figura para el heatmap con la misma escala para x e y
|
1222 |
+
heatmap_fig = figure(title=f"Heatmap de '{selected_feature}' ({x_comp} vs {y_comp})",
|
1223 |
+
x_range=(x_min, x_max), y_range=(y_min, y_max),
|
1224 |
+
width=600, height=600,
|
1225 |
+
tools="pan,wheel_zoom,reset,save", active_scroll="wheel_zoom", tooltips=TOOLTIPS,
|
1226 |
+
sizing_mode="fixed")
|
1227 |
+
heatmap_fig.match_aspect = True
|
1228 |
+
|
1229 |
+
heatmap_fig.xaxis.axis_label = x_comp
|
1230 |
+
heatmap_fig.yaxis.axis_label = y_comp
|
1231 |
+
# Dibujar la imagen del heatmap
|
1232 |
+
heatmap_fig.image(image=[heatmap_data], x=x_min, y=y_min,
|
1233 |
+
dw=x_max - x_min, dh=y_max - y_min,
|
1234 |
+
color_mapper=color_mapper)
|
1235 |
+
|
1236 |
+
# Agregar la barra de color
|
1237 |
+
color_bar = ColorBar(color_mapper=color_mapper, location=(0, 0))
|
1238 |
+
if cat_mapping is not None:
|
1239 |
+
ticks = list(range(len(cat_mapping)))
|
1240 |
+
color_bar.ticker = FixedTicker(ticks=ticks)
|
1241 |
+
categories_json = json.dumps(cat_mapping)
|
1242 |
+
color_bar.formatter = FuncTickFormatter(code=f"""
|
1243 |
+
var categories = {categories_json};
|
1244 |
+
var index = Math.round(tick);
|
1245 |
+
if(index >= 0 && index < categories.length) {{
|
1246 |
+
return categories[index];
|
1247 |
+
}} else {{
|
1248 |
+
return "";
|
1249 |
+
}}
|
1250 |
+
""")
|
1251 |
+
heatmap_fig.add_layout(color_bar, 'right')
|
1252 |
+
|
1253 |
+
# Agregar renderer invisible para tooltips (usando puntos en cada bin)
|
1254 |
+
source_points = ColumnDataSource(data={
|
1255 |
+
'x': df_heatmap['x'],
|
1256 |
+
'y': df_heatmap['y'],
|
1257 |
+
'img': df_heatmap['img'],
|
1258 |
+
'label': df_heatmap['name']
|
1259 |
})
|
1260 |
+
invisible_renderer = heatmap_fig.circle('x', 'y', size=10, source=source_points, fill_alpha=0, line_alpha=0.5)
|
1261 |
+
|
1262 |
+
# Si se selecciona un dataset extra, proyectar sus puntos en la combinación actual
|
1263 |
+
if select_extra_dataset_hm != "-":
|
1264 |
+
df_extra = df_all["synthetic"][df_all["synthetic"]["source"] == select_extra_dataset_hm].copy()
|
1265 |
+
df_extra["x"] = df_extra[x_comp]
|
1266 |
+
df_extra["y"] = df_extra[y_comp]
|
1267 |
+
if 'name' not in df_extra.columns:
|
1268 |
+
df_extra["name"] = df_extra["img"].apply(lambda x: x.split("/")[-1].replace(".png", "") if isinstance(x, str) else x)
|
1269 |
+
source_extra_points = ColumnDataSource(data={
|
1270 |
+
'x': df_extra['x'],
|
1271 |
+
'y': df_extra['y'],
|
1272 |
+
'img': df_extra['img'],
|
1273 |
+
'label': df_extra['name']
|
1274 |
+
})
|
1275 |
+
extra_renderer = heatmap_fig.circle('x', 'y', size=5, source=source_extra_points, fill_alpha=0, line_alpha=0.5, color="purple")
|
1276 |
+
|
1277 |
+
hover_tool_points = HoverTool(renderers=[invisible_renderer], tooltips=TOOLTIPS)
|
1278 |
+
heatmap_fig.add_tools(hover_tool_points)
|
1279 |
+
|
1280 |
+
st.bokeh_chart(heatmap_fig)
|
1281 |
|
1282 |
def main():
|
1283 |
config_style()
|