wi-lab commited on
Commit
761563c
·
verified ·
1 Parent(s): bc49788

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -7
app.py CHANGED
@@ -330,21 +330,32 @@ def plot_confusion_matrix(y_true, y_pred, title):
330
  # Calculate F1 Score
331
  f1 = f1_score(y_true, y_pred, average='weighted')
332
 
333
- plt.style.use('dark_background')
 
 
 
 
 
 
 
 
 
 
 
334
  plt.figure(figsize=(5, 5))
335
 
336
  # Plot the confusion matrix with a dark-mode compatible colormap
337
- sns.heatmap(cm, annot=True, fmt="d", cmap="magma", cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
338
 
339
  # Add F1-score to the title
340
- plt.title(f"{title}\n(F1 Score: {f1:.3f})", color="white", fontsize=14)
341
 
342
  # Customize tick labels for dark mode
343
- plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
344
- plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
345
 
346
- plt.ylabel('True label', color="white", fontsize=12)
347
- plt.xlabel('Predicted label', color="white", fontsize=12)
348
  plt.tight_layout()
349
 
350
  # Save the plot as an image
 
330
  # Calculate F1 Score
331
  f1 = f1_score(y_true, y_pred, average='weighted')
332
 
333
+ #plt.style.use('dark_background')
334
+
335
+ # Set styling based on light or dark mode
336
+ if light_mode:
337
+ plt.style.use('default') # Light mode styling
338
+ text_color = 'black'
339
+ cmap = 'Blues' # Light-mode-friendly colormap
340
+ else:
341
+ plt.style.use('dark_background') # Dark mode styling
342
+ text_color = 'white'
343
+ cmap = 'magma' # Dark-mode-friendly colormap
344
+
345
  plt.figure(figsize=(5, 5))
346
 
347
  # Plot the confusion matrix with a dark-mode compatible colormap
348
+ sns.heatmap(cm, annot=True, fmt="d", cmap=cmap, cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
349
 
350
  # Add F1-score to the title
351
+ plt.title(f"{title}\n(F1 Score: {f1:.3f})", color=text_color, fontsize=14)
352
 
353
  # Customize tick labels for dark mode
354
+ plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=10)
355
+ plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=10)
356
 
357
+ plt.ylabel('True label', color=text_color, fontsize=12)
358
+ plt.xlabel('Predicted label', color=text_color, fontsize=12)
359
  plt.tight_layout()
360
 
361
  # Save the plot as an image