Spaces:
Sleeping
Sleeping
Commit
·
9e80388
1
Parent(s):
cbc6443
Show only the last confusion matrix image after training
Browse files
app.py
CHANGED
@@ -241,18 +241,18 @@ def main():
|
|
241 |
y_pred[np.random.choice(100, 10, replace=False)] = np.random.randint(0, num_classes, 10)
|
242 |
cm = confusion_matrix(y_true, y_pred, labels=range(num_classes))
|
243 |
|
244 |
-
#
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
|
|
250 |
|
251 |
time.sleep(0.5)
|
252 |
|
253 |
writer.close()
|
254 |
st.success("Training complete!")
|
255 |
-
st.line_chart({"Loss": losses, "Val Loss": val_losses, "Accuracy": accuracies})
|
256 |
|
257 |
# Wait a moment to ensure logs are written
|
258 |
time.sleep(1)
|
|
|
241 |
y_pred[np.random.choice(100, 10, replace=False)] = np.random.randint(0, num_classes, 10)
|
242 |
cm = confusion_matrix(y_true, y_pred, labels=range(num_classes))
|
243 |
|
244 |
+
# Only log confusion matrix in the last epoch
|
245 |
+
if epoch == epochs - 1:
|
246 |
+
fig, ax = plt.subplots()
|
247 |
+
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[f"Class {i}" for i in range(num_classes)])
|
248 |
+
disp.plot(ax=ax)
|
249 |
+
plt.close(fig)
|
250 |
+
writer.add_figure("confusion_matrix", fig, epoch)
|
251 |
|
252 |
time.sleep(0.5)
|
253 |
|
254 |
writer.close()
|
255 |
st.success("Training complete!")
|
|
|
256 |
|
257 |
# Wait a moment to ensure logs are written
|
258 |
time.sleep(1)
|