chandini2595 commited on
Commit
9e80388
·
1 Parent(s): cbc6443

Show only the last confusion matrix image after training

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -241,18 +241,18 @@ def main():
241
  y_pred[np.random.choice(100, 10, replace=False)] = np.random.randint(0, num_classes, 10)
242
  cm = confusion_matrix(y_true, y_pred, labels=range(num_classes))
243
 
244
- # Plot and log confusion matrix as image
245
- fig, ax = plt.subplots()
246
- disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[f"Class {i}" for i in range(num_classes)])
247
- disp.plot(ax=ax)
248
- plt.close(fig)
249
- writer.add_figure("confusion_matrix", fig, epoch)
 
250
 
251
  time.sleep(0.5)
252
 
253
  writer.close()
254
  st.success("Training complete!")
255
- st.line_chart({"Loss": losses, "Val Loss": val_losses, "Accuracy": accuracies})
256
 
257
  # Wait a moment to ensure logs are written
258
  time.sleep(1)
 
241
  y_pred[np.random.choice(100, 10, replace=False)] = np.random.randint(0, num_classes, 10)
242
  cm = confusion_matrix(y_true, y_pred, labels=range(num_classes))
243
 
244
+ # Only log confusion matrix in the last epoch
245
+ if epoch == epochs - 1:
246
+ fig, ax = plt.subplots()
247
+ disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[f"Class {i}" for i in range(num_classes)])
248
+ disp.plot(ax=ax)
249
+ plt.close(fig)
250
+ writer.add_figure("confusion_matrix", fig, epoch)
251
 
252
  time.sleep(0.5)
253
 
254
  writer.close()
255
  st.success("Training complete!")
 
256
 
257
  # Wait a moment to ensure logs are written
258
  time.sleep(1)