Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -26,7 +26,7 @@ def beam_prediction_task(data_percentage, task_complexity, theme='Dark'):
|
|
26 |
raw_cm_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
|
27 |
plot_confusion_matrix_beamPred(raw_cm,
|
28 |
classes=np.arange(raw_cm.shape[0]),
|
29 |
-
title=f"
|
30 |
save_path=raw_cm_path,
|
31 |
theme=theme)
|
32 |
raw_img = Image.open(raw_cm_path)
|
@@ -39,7 +39,7 @@ def beam_prediction_task(data_percentage, task_complexity, theme='Dark'):
|
|
39 |
embeddings_cm_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
|
40 |
plot_confusion_matrix_beamPred(embeddings_cm,
|
41 |
classes=np.arange(embeddings_cm.shape[0]),
|
42 |
-
title=f"
|
43 |
save_path=embeddings_cm_path,
|
44 |
theme=theme)
|
45 |
embeddings_img = Image.open(embeddings_cm_path)
|
@@ -191,14 +191,14 @@ def plot_confusion_matrix_from_csv(csv_file_path, title, save_path, light_mode=F
|
|
191 |
sns.heatmap(cm, annot=True, fmt="d", cmap=cmap, cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
|
192 |
|
193 |
# Add F1-score to the title
|
194 |
-
plt.title(f"{title}\n(F1 Score: {f1:.3f})", color=text_color, fontsize=
|
195 |
|
196 |
# Customize tick labels for light/dark mode
|
197 |
-
plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=
|
198 |
-
plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=
|
199 |
|
200 |
-
plt.ylabel('True label', color=text_color, fontsize=
|
201 |
-
plt.xlabel('Predicted label', color=text_color, fontsize=
|
202 |
plt.tight_layout()
|
203 |
|
204 |
# Save the plot as an image
|
@@ -220,14 +220,14 @@ def display_confusion_matrices_los(percentage):
|
|
220 |
raw_csv_file = os.path.join(raw_folder, f"test_predictions_raw_{percentage/100:.3f}_los.csv")
|
221 |
raw_cm_img_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
|
222 |
raw_img = plot_confusion_matrix_from_csv(raw_csv_file,
|
223 |
-
f"
|
224 |
raw_cm_img_path)
|
225 |
|
226 |
# Process embeddings confusion matrix
|
227 |
embeddings_csv_file = os.path.join(embeddings_folder, f"test_predictions_embedding_{percentage/100:.3f}_los.csv")
|
228 |
embeddings_cm_img_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
|
229 |
embeddings_img = plot_confusion_matrix_from_csv(embeddings_csv_file,
|
230 |
-
f"
|
231 |
embeddings_cm_img_path)
|
232 |
|
233 |
return raw_img, embeddings_img
|
@@ -362,14 +362,14 @@ def plot_confusion_matrix(y_true, y_pred, title, light_mode=False):
|
|
362 |
sns.heatmap(cm, annot=True, fmt="d", cmap=cmap, cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
|
363 |
|
364 |
# Add F1-score to the title
|
365 |
-
plt.title(f"{title}\
|
366 |
|
367 |
# Customize tick labels for dark mode
|
368 |
-
plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=
|
369 |
-
plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=
|
370 |
|
371 |
-
plt.ylabel('True label', color=text_color, fontsize=
|
372 |
-
plt.xlabel('Predicted label', color=text_color, fontsize=
|
373 |
plt.tight_layout()
|
374 |
|
375 |
# Save the plot as an image
|
|
|
26 |
raw_cm_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
|
27 |
plot_confusion_matrix_beamPred(raw_cm,
|
28 |
classes=np.arange(raw_cm.shape[0]),
|
29 |
+
title=f"Confusion Matrix (Raw Channels)\n{data_percentage}% data, {task_complexity} beams",
|
30 |
save_path=raw_cm_path,
|
31 |
theme=theme)
|
32 |
raw_img = Image.open(raw_cm_path)
|
|
|
39 |
embeddings_cm_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
|
40 |
plot_confusion_matrix_beamPred(embeddings_cm,
|
41 |
classes=np.arange(embeddings_cm.shape[0]),
|
42 |
+
title=f"Confusion Matrix (LWM Embeddings)\n{data_percentage}% data, {task_complexity} beams",
|
43 |
save_path=embeddings_cm_path,
|
44 |
theme=theme)
|
45 |
embeddings_img = Image.open(embeddings_cm_path)
|
|
|
191 |
sns.heatmap(cm, annot=True, fmt="d", cmap=cmap, cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
|
192 |
|
193 |
# Add F1-score to the title
|
194 |
+
plt.title(f"{title}\n(F1 Score: {f1:.3f})", color=text_color, fontsize=24)
|
195 |
|
196 |
# Customize tick labels for light/dark mode
|
197 |
+
plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=14)
|
198 |
+
plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=14)
|
199 |
|
200 |
+
plt.ylabel('True label', color=text_color, fontsize=18)
|
201 |
+
plt.xlabel('Predicted label', color=text_color, fontsize=18)
|
202 |
plt.tight_layout()
|
203 |
|
204 |
# Save the plot as an image
|
|
|
220 |
raw_csv_file = os.path.join(raw_folder, f"test_predictions_raw_{percentage/100:.3f}_los.csv")
|
221 |
raw_cm_img_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
|
222 |
raw_img = plot_confusion_matrix_from_csv(raw_csv_file,
|
223 |
+
f"Confusion Matrix (Raw Channels)\n{percentage:.1f}% data",
|
224 |
raw_cm_img_path)
|
225 |
|
226 |
# Process embeddings confusion matrix
|
227 |
embeddings_csv_file = os.path.join(embeddings_folder, f"test_predictions_embedding_{percentage/100:.3f}_los.csv")
|
228 |
embeddings_cm_img_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
|
229 |
embeddings_img = plot_confusion_matrix_from_csv(embeddings_csv_file,
|
230 |
+
f"Confusion Matrix (LWM Embeddings)\n{percentage:.1f}% data",
|
231 |
embeddings_cm_img_path)
|
232 |
|
233 |
return raw_img, embeddings_img
|
|
|
362 |
sns.heatmap(cm, annot=True, fmt="d", cmap=cmap, cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
|
363 |
|
364 |
# Add F1-score to the title
|
365 |
+
plt.title(f"{title}\nF1 Score: {f1:.3f}", color=text_color, fontsize=23)
|
366 |
|
367 |
# Customize tick labels for dark mode
|
368 |
+
plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=14)
|
369 |
+
plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=14)
|
370 |
|
371 |
+
plt.ylabel('True label', color=text_color, fontsize=18)
|
372 |
+
plt.xlabel('Predicted label', color=text_color, fontsize=18)
|
373 |
plt.tight_layout()
|
374 |
|
375 |
# Save the plot as an image
|