Spaces:
Sleeping
Sleeping
Commit
·
bb9f60a
1
Parent(s):
e8dbc8c
Remove post-training TensorBoard log charts
Browse files
app.py
CHANGED
@@ -20,6 +20,8 @@ from tensorboard.backend.event_processing import event_accumulator
|
|
20 |
from torch.utils.tensorboard import SummaryWriter
|
21 |
import matplotlib.pyplot as plt
|
22 |
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
|
|
|
|
|
23 |
|
24 |
# Configure logging
|
25 |
logging.basicConfig(level=logging.INFO)
|
@@ -252,11 +254,14 @@ def main():
|
|
252 |
st.success("Training complete!")
|
253 |
|
254 |
# Show last confusion matrix in Streamlit
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
|
|
|
|
|
|
260 |
|
261 |
if __name__ == "__main__":
|
262 |
main()
|
|
|
20 |
from torch.utils.tensorboard import SummaryWriter
|
21 |
import matplotlib.pyplot as plt
|
22 |
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
|
23 |
+
import matplotlib
|
24 |
+
matplotlib.use('Agg')
|
25 |
|
26 |
# Configure logging
|
27 |
logging.basicConfig(level=logging.INFO)
|
|
|
254 |
st.success("Training complete!")
|
255 |
|
256 |
# Show last confusion matrix in Streamlit
|
257 |
+
if 'cm' in locals():
|
258 |
+
st.subheader("Confusion Matrix (Last Epoch)")
|
259 |
+
fig, ax = plt.subplots()
|
260 |
+
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[f"Class {i}" for i in range(num_classes)])
|
261 |
+
disp.plot(ax=ax)
|
262 |
+
st.pyplot(fig)
|
263 |
+
else:
|
264 |
+
st.info("Confusion matrix not found.")
|
265 |
|
266 |
if __name__ == "__main__":
|
267 |
main()
|