Spaces:
Running
Running
Sadjad Alikhani
commited on
Update app.py
Browse files
app.py
CHANGED
|
@@ -34,7 +34,7 @@ def beam_prediction_task(data_percentage, task_complexity):
|
|
| 34 |
embeddings_cm = compute_average_confusion_matrix(embeddings_folder)
|
| 35 |
if embeddings_cm is not None:
|
| 36 |
embeddings_cm_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
|
| 37 |
-
plot_confusion_matrix_beamPred(embeddings_cm, classes=np.arange(embeddings_cm.shape[0]), title=f"Embeddings Confusion Matrix
|
| 38 |
embeddings_img = Image.open(embeddings_cm_path)
|
| 39 |
else:
|
| 40 |
embeddings_img = None
|
|
@@ -71,7 +71,7 @@ def plot_confusion_matrix_beamPred(cm, classes, title, save_path):
|
|
| 71 |
plt.figure(figsize=(5, 5))
|
| 72 |
|
| 73 |
# Plot the confusion matrix with a dark-mode compatible colormap
|
| 74 |
-
sns.heatmap(cm, cmap="magma", cbar=
|
| 75 |
|
| 76 |
# Add F1-score to the title
|
| 77 |
plt.title(f"{title}\n(F1 Score: {avg_f1:.3f})", color="white", fontsize=14)
|
|
@@ -316,34 +316,31 @@ def plot_confusion_matrix(y_true, y_pred, title):
|
|
| 316 |
|
| 317 |
# Calculate F1 Score
|
| 318 |
f1 = f1_score(y_true, y_pred, average='weighted')
|
| 319 |
-
|
| 320 |
-
|
| 321 |
plt.figure(figsize=(5, 5))
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
plt.
|
| 339 |
-
|
| 340 |
-
# Save the plot
|
| 341 |
-
plt.savefig(f"{title}.png", facecolor='black') # Set background to black for dark mode
|
| 342 |
plt.close()
|
| 343 |
|
|
|
|
| 344 |
return Image.open(f"{title}.png")
|
| 345 |
-
|
| 346 |
-
|
| 347 |
def identical_train_test_split(output_emb, output_raw, labels, percentage):
|
| 348 |
N = output_emb.shape[0] # Get the total number of samples
|
| 349 |
|
|
|
|
| 34 |
embeddings_cm = compute_average_confusion_matrix(embeddings_folder)
|
| 35 |
if embeddings_cm is not None:
|
| 36 |
embeddings_cm_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
|
| 37 |
+
plot_confusion_matrix_beamPred(embeddings_cm, classes=np.arange(embeddings_cm.shape[0]), title=f"Embeddings Confusion Matrix\n({data_percentage}% data, {task_complexity} beams)", save_path=embeddings_cm_path)
|
| 38 |
embeddings_img = Image.open(embeddings_cm_path)
|
| 39 |
else:
|
| 40 |
embeddings_img = None
|
|
|
|
| 71 |
plt.figure(figsize=(5, 5))
|
| 72 |
|
| 73 |
# Plot the confusion matrix with a dark-mode compatible colormap
|
| 74 |
+
sns.heatmap(cm, cmap="magma", cbar=True, linecolor='white')
|
| 75 |
|
| 76 |
# Add F1-score to the title
|
| 77 |
plt.title(f"{title}\n(F1 Score: {avg_f1:.3f})", color="white", fontsize=14)
|
|
|
|
| 316 |
|
| 317 |
# Calculate F1 Score
|
| 318 |
f1 = f1_score(y_true, y_pred, average='weighted')
|
| 319 |
+
|
| 320 |
+
plt.style.use('dark_background')
|
| 321 |
plt.figure(figsize=(5, 5))
|
| 322 |
+
|
| 323 |
+
# Plot the confusion matrix with a dark-mode compatible colormap
|
| 324 |
+
sns.heatmap(cm, annot=True, fmt="d", cmap="magma", cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
|
| 325 |
+
|
| 326 |
+
# Add F1-score to the title
|
| 327 |
+
plt.title(f"{title}\n(F1 Score: {f1:.3f})", color="white", fontsize=14)
|
| 328 |
+
|
| 329 |
+
# Customize tick labels for dark mode
|
| 330 |
+
plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
|
| 331 |
+
plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
|
| 332 |
+
|
| 333 |
+
plt.ylabel('True label', color="white", fontsize=12)
|
| 334 |
+
plt.xlabel('Predicted label', color="white", fontsize=12)
|
| 335 |
+
plt.tight_layout()
|
| 336 |
+
|
| 337 |
+
# Save the plot as an image
|
| 338 |
+
plt.savefig(f"{title}.png", transparent=True) # Use transparent to blend with the dark mode website
|
|
|
|
|
|
|
|
|
|
| 339 |
plt.close()
|
| 340 |
|
| 341 |
+
# Return the saved image
|
| 342 |
return Image.open(f"{title}.png")
|
| 343 |
+
|
|
|
|
| 344 |
def identical_train_test_split(output_emb, output_raw, labels, percentage):
|
| 345 |
N = output_emb.shape[0] # Get the total number of samples
|
| 346 |
|