wi-lab commited on
Commit
ff9221d
·
verified ·
1 Parent(s): f938ef0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -14
app.py CHANGED
@@ -26,7 +26,7 @@ def beam_prediction_task(data_percentage, task_complexity, theme='Dark'):
26
  raw_cm_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
27
  plot_confusion_matrix_beamPred(raw_cm,
28
  classes=np.arange(raw_cm.shape[0]),
29
- title=f"Raw Confusion Matrix\n{data_percentage}% data, {task_complexity} beams",
30
  save_path=raw_cm_path,
31
  theme=theme)
32
  raw_img = Image.open(raw_cm_path)
@@ -39,7 +39,7 @@ def beam_prediction_task(data_percentage, task_complexity, theme='Dark'):
39
  embeddings_cm_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
40
  plot_confusion_matrix_beamPred(embeddings_cm,
41
  classes=np.arange(embeddings_cm.shape[0]),
42
- title=f"Embeddings Confusion Matrix\n{data_percentage}% data, {task_complexity} beams",
43
  save_path=embeddings_cm_path,
44
  theme=theme)
45
  embeddings_img = Image.open(embeddings_cm_path)
@@ -191,14 +191,14 @@ def plot_confusion_matrix_from_csv(csv_file_path, title, save_path, light_mode=F
191
  sns.heatmap(cm, annot=True, fmt="d", cmap=cmap, cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
192
 
193
  # Add F1-score to the title
194
- plt.title(f"{title}\n(F1 Score: {f1:.3f})", color=text_color, fontsize=14)
195
 
196
  # Customize tick labels for light/dark mode
197
- plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=10)
198
- plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=10)
199
 
200
- plt.ylabel('True label', color=text_color, fontsize=12)
201
- plt.xlabel('Predicted label', color=text_color, fontsize=12)
202
  plt.tight_layout()
203
 
204
  # Save the plot as an image
@@ -220,14 +220,14 @@ def display_confusion_matrices_los(percentage):
220
  raw_csv_file = os.path.join(raw_folder, f"test_predictions_raw_{percentage/100:.3f}_los.csv")
221
  raw_cm_img_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
222
  raw_img = plot_confusion_matrix_from_csv(raw_csv_file,
223
- f"Raw Confusion Matrix ({percentage:.1f}% data)",
224
  raw_cm_img_path)
225
 
226
  # Process embeddings confusion matrix
227
  embeddings_csv_file = os.path.join(embeddings_folder, f"test_predictions_embedding_{percentage/100:.3f}_los.csv")
228
  embeddings_cm_img_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
229
  embeddings_img = plot_confusion_matrix_from_csv(embeddings_csv_file,
230
- f"Embeddings Confusion Matrix ({percentage:.1f}% data)",
231
  embeddings_cm_img_path)
232
 
233
  return raw_img, embeddings_img
@@ -362,14 +362,14 @@ def plot_confusion_matrix(y_true, y_pred, title, light_mode=False):
362
  sns.heatmap(cm, annot=True, fmt="d", cmap=cmap, cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
363
 
364
  # Add F1-score to the title
365
- plt.title(f"{title}\n(F1 Score: {f1:.3f})", color=text_color, fontsize=14)
366
 
367
  # Customize tick labels for dark mode
368
- plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=10)
369
- plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=10)
370
 
371
- plt.ylabel('True label', color=text_color, fontsize=12)
372
- plt.xlabel('Predicted label', color=text_color, fontsize=12)
373
  plt.tight_layout()
374
 
375
  # Save the plot as an image
 
26
  raw_cm_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
27
  plot_confusion_matrix_beamPred(raw_cm,
28
  classes=np.arange(raw_cm.shape[0]),
29
+ title=f"Confusion Matrix (Raw Channels)\n{data_percentage}% data, {task_complexity} beams",
30
  save_path=raw_cm_path,
31
  theme=theme)
32
  raw_img = Image.open(raw_cm_path)
 
39
  embeddings_cm_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
40
  plot_confusion_matrix_beamPred(embeddings_cm,
41
  classes=np.arange(embeddings_cm.shape[0]),
42
+ title=f"Confusion Matrix (LWM Embeddings)\n{data_percentage}% data, {task_complexity} beams",
43
  save_path=embeddings_cm_path,
44
  theme=theme)
45
  embeddings_img = Image.open(embeddings_cm_path)
 
191
  sns.heatmap(cm, annot=True, fmt="d", cmap=cmap, cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
192
 
193
  # Add F1-score to the title
194
+ plt.title(f"{title}\n(F1 Score: {f1:.3f})", color=text_color, fontsize=24)
195
 
196
  # Customize tick labels for light/dark mode
197
+ plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=14)
198
+ plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=14)
199
 
200
+ plt.ylabel('True label', color=text_color, fontsize=18)
201
+ plt.xlabel('Predicted label', color=text_color, fontsize=18)
202
  plt.tight_layout()
203
 
204
  # Save the plot as an image
 
220
  raw_csv_file = os.path.join(raw_folder, f"test_predictions_raw_{percentage/100:.3f}_los.csv")
221
  raw_cm_img_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
222
  raw_img = plot_confusion_matrix_from_csv(raw_csv_file,
223
+ f"Confusion Matrix (Raw Channels)\n{percentage:.1f}% data",
224
  raw_cm_img_path)
225
 
226
  # Process embeddings confusion matrix
227
  embeddings_csv_file = os.path.join(embeddings_folder, f"test_predictions_embedding_{percentage/100:.3f}_los.csv")
228
  embeddings_cm_img_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
229
  embeddings_img = plot_confusion_matrix_from_csv(embeddings_csv_file,
230
+ f"Confusion Matrix (LWM Embeddings)\n{percentage:.1f}% data",
231
  embeddings_cm_img_path)
232
 
233
  return raw_img, embeddings_img
 
362
  sns.heatmap(cm, annot=True, fmt="d", cmap=cmap, cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
363
 
364
  # Add F1-score to the title
365
+ plt.title(f"{title}\nF1 Score: {f1:.3f}", color=text_color, fontsize=23)
366
 
367
  # Customize tick labels for dark mode
368
+ plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=14)
369
+ plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=14)
370
 
371
+ plt.ylabel('True label', color=text_color, fontsize=18)
372
+ plt.xlabel('Predicted label', color=text_color, fontsize=18)
373
  plt.tight_layout()
374
 
375
  # Save the plot as an image