Spaces:
Running
Running
Sadjad Alikhani
commited on
Update app.py
Browse files
app.py
CHANGED
|
@@ -65,27 +65,32 @@ def compute_f1_score(cm):
|
|
| 65 |
def plot_confusion_matrix_beamPred(cm, classes, title, save_path):
|
| 66 |
# Compute the average F1-score
|
| 67 |
avg_f1 = compute_f1_score(cm)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
-
#
|
| 70 |
-
|
| 71 |
|
| 72 |
-
#
|
| 73 |
-
plt.
|
| 74 |
-
plt.imshow(cm, interpolation='nearest', cmap='coolwarm') # Dark mode color scheme
|
| 75 |
-
plt.title(full_title, color='white', pad=20) # Add padding to prevent title clipping, white text for dark mode
|
| 76 |
-
plt.colorbar()
|
| 77 |
|
| 78 |
-
|
| 79 |
-
plt.xticks(
|
| 80 |
-
plt.yticks(
|
| 81 |
-
|
| 82 |
-
plt.
|
| 83 |
-
plt.
|
| 84 |
-
plt.
|
| 85 |
-
|
| 86 |
-
# Save the plot
|
| 87 |
-
plt.savefig(save_path,
|
| 88 |
plt.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
def compute_average_confusion_matrix(folder):
|
| 91 |
confusion_matrices = []
|
|
@@ -158,7 +163,7 @@ def plot_confusion_matrix_from_csv(csv_file_path, title, save_path):
|
|
| 158 |
sns.heatmap(cm, annot=True, fmt="d", cmap="magma", cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
|
| 159 |
|
| 160 |
# Add F1-score to the title
|
| 161 |
-
plt.title(f"{title}
|
| 162 |
|
| 163 |
# Customize tick labels for dark mode
|
| 164 |
plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
|
|
|
|
| 65 |
def plot_confusion_matrix_beamPred(cm, classes, title, save_path):
|
| 66 |
# Compute the average F1-score
|
| 67 |
avg_f1 = compute_f1_score(cm)
|
| 68 |
+
|
| 69 |
+
# Set dark mode styling
|
| 70 |
+
plt.style.use('dark_background')
|
| 71 |
+
plt.figure(figsize=(5, 5))
|
| 72 |
|
| 73 |
+
# Plot the confusion matrix with a dark-mode compatible colormap
|
| 74 |
+
sns.heatmap(cm, annot=True, fmt="d", cmap="magma", cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
|
| 75 |
|
| 76 |
+
# Add F1-score to the title
|
| 77 |
+
plt.title(f"{title}\n(F1 Score: {avg_f1:.3f})", color="white", fontsize=14)
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
+
# Customize tick labels for dark mode
|
| 80 |
+
plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
|
| 81 |
+
plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
|
| 82 |
+
|
| 83 |
+
plt.ylabel('True label', color="white", fontsize=12)
|
| 84 |
+
plt.xlabel('Predicted label', color="white", fontsize=12)
|
| 85 |
+
plt.tight_layout()
|
| 86 |
+
|
| 87 |
+
# Save the plot as an image
|
| 88 |
+
plt.savefig(save_path, transparent=True) # Use transparent to blend with the dark mode website
|
| 89 |
plt.close()
|
| 90 |
+
|
| 91 |
+
# Return the saved image
|
| 92 |
+
return Image.open(save_path)
|
| 93 |
+
|
| 94 |
|
| 95 |
def compute_average_confusion_matrix(folder):
|
| 96 |
confusion_matrices = []
|
|
|
|
| 163 |
sns.heatmap(cm, annot=True, fmt="d", cmap="magma", cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
|
| 164 |
|
| 165 |
# Add F1-score to the title
|
| 166 |
+
plt.title(f"{title}\n(F1 Score: {f1:.3f})", color="white", fontsize=14)
|
| 167 |
|
| 168 |
# Customize tick labels for dark mode
|
| 169 |
plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
|