Spaces:
Sleeping
Sleeping
Commit
·
71bdcfa
1
Parent(s):
261a578
Add advanced training metrics and confusion matrix logging/visualization
Browse files- app.py +51 -10
- requirements.txt +2 -0
app.py
CHANGED
@@ -17,6 +17,9 @@ from dotenv import load_dotenv
|
|
17 |
from chatbot_utils import ask_receipt_chatbot
|
18 |
import time
|
19 |
from tensorboard.backend.event_processing import event_accumulator
|
|
|
|
|
|
|
20 |
|
21 |
# Configure logging
|
22 |
logging.basicConfig(level=logging.INFO)
|
@@ -204,35 +207,73 @@ def main():
|
|
204 |
|
205 |
if st.button("Start Training"):
|
206 |
epochs = 10
|
|
|
207 |
losses = []
|
|
|
208 |
accuracies = []
|
209 |
progress = st.progress(0)
|
210 |
-
chart = st.line_chart({"Loss": [], "Accuracy": []})
|
|
|
|
|
211 |
|
212 |
for epoch in range(epochs):
|
213 |
# Simulate training
|
214 |
loss = np.exp(-epoch/5) + np.random.rand() * 0.05
|
|
|
215 |
acc = 1 - loss + np.random.rand() * 0.02
|
216 |
losses.append(loss)
|
|
|
217 |
accuracies.append(acc)
|
218 |
-
chart.add_rows({"Loss": [loss], "Accuracy": [acc]})
|
219 |
progress.progress((epoch+1)/epochs)
|
220 |
-
st.write(f"Epoch {epoch+1}: Loss={loss:.4f}, Accuracy={acc:.4f}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
time.sleep(0.5)
|
222 |
|
|
|
223 |
st.success("Training complete!")
|
224 |
-
st.line_chart({"Loss": losses, "Accuracy": accuracies})
|
225 |
|
226 |
-
|
|
|
|
|
|
|
|
|
|
|
227 |
|
|
|
228 |
if os.path.exists(logdir) and os.listdir(logdir):
|
229 |
ea = event_accumulator.EventAccumulator(logdir)
|
230 |
ea.Reload()
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
|
|
|
|
|
|
|
|
|
|
236 |
else:
|
237 |
st.info("No TensorBoard logs found. Please upload logs to the 'logs' directory.")
|
238 |
|
|
|
17 |
from chatbot_utils import ask_receipt_chatbot
|
18 |
import time
|
19 |
from tensorboard.backend.event_processing import event_accumulator
|
20 |
+
from torch.utils.tensorboard import SummaryWriter
|
21 |
+
import matplotlib.pyplot as plt
|
22 |
+
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
|
23 |
|
24 |
# Configure logging
|
25 |
logging.basicConfig(level=logging.INFO)
|
|
|
207 |
|
208 |
if st.button("Start Training"):
|
209 |
epochs = 10
|
210 |
+
num_classes = 3 # Example: 3 classes for confusion matrix
|
211 |
losses = []
|
212 |
+
val_losses = []
|
213 |
accuracies = []
|
214 |
progress = st.progress(0)
|
215 |
+
chart = st.line_chart({"Loss": [], "Val Loss": [], "Accuracy": []})
|
216 |
+
|
217 |
+
writer = SummaryWriter("logs")
|
218 |
|
219 |
for epoch in range(epochs):
|
220 |
# Simulate training
|
221 |
loss = np.exp(-epoch/5) + np.random.rand() * 0.05
|
222 |
+
val_loss = loss + np.random.rand() * 0.02
|
223 |
acc = 1 - loss + np.random.rand() * 0.02
|
224 |
losses.append(loss)
|
225 |
+
val_losses.append(val_loss)
|
226 |
accuracies.append(acc)
|
227 |
+
chart.add_rows({"Loss": [loss], "Val Loss": [val_loss], "Accuracy": [acc]})
|
228 |
progress.progress((epoch+1)/epochs)
|
229 |
+
st.write(f"Epoch {epoch+1}: Loss={loss:.4f}, Val Loss={val_loss:.4f}, Accuracy={acc:.4f}")
|
230 |
+
|
231 |
+
# Log to TensorBoard
|
232 |
+
writer.add_scalar("loss", loss, epoch)
|
233 |
+
writer.add_scalar("val_loss", val_loss, epoch)
|
234 |
+
writer.add_scalar("accuracy", acc, epoch)
|
235 |
+
|
236 |
+
# Simulate predictions and labels for confusion matrix
|
237 |
+
y_true = np.random.randint(0, num_classes, 100)
|
238 |
+
y_pred = y_true.copy()
|
239 |
+
# Add some noise to predictions
|
240 |
+
y_pred[np.random.choice(100, 10, replace=False)] = np.random.randint(0, num_classes, 10)
|
241 |
+
cm = confusion_matrix(y_true, y_pred, labels=range(num_classes))
|
242 |
+
|
243 |
+
# Plot and log confusion matrix as image
|
244 |
+
fig, ax = plt.subplots()
|
245 |
+
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[f"Class {i}" for i in range(num_classes)])
|
246 |
+
disp.plot(ax=ax)
|
247 |
+
plt.close(fig)
|
248 |
+
writer.add_figure("confusion_matrix", fig, epoch)
|
249 |
+
|
250 |
time.sleep(0.5)
|
251 |
|
252 |
+
writer.close()
|
253 |
st.success("Training complete!")
|
254 |
+
st.line_chart({"Loss": losses, "Val Loss": val_losses, "Accuracy": accuracies})
|
255 |
|
256 |
+
# Show last confusion matrix in Streamlit
|
257 |
+
st.subheader("Confusion Matrix (Last Epoch)")
|
258 |
+
fig, ax = plt.subplots()
|
259 |
+
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[f"Class {i}" for i in range(num_classes)])
|
260 |
+
disp.plot(ax=ax)
|
261 |
+
st.pyplot(fig)
|
262 |
|
263 |
+
logdir = "logs"
|
264 |
if os.path.exists(logdir) and os.listdir(logdir):
|
265 |
ea = event_accumulator.EventAccumulator(logdir)
|
266 |
ea.Reload()
|
267 |
+
scalars = ea.Tags()['scalars']
|
268 |
+
for tag in ['loss', 'val_loss', 'accuracy']:
|
269 |
+
if tag in scalars:
|
270 |
+
values = [s.value for s in ea.Scalars(tag)]
|
271 |
+
st.line_chart({tag: values})
|
272 |
+
# Show confusion matrix images if available
|
273 |
+
if 'confusion_matrix' in ea.Tags()['images']:
|
274 |
+
st.subheader("TensorBoard Confusion Matrices")
|
275 |
+
for img in ea.Images('confusion_matrix'):
|
276 |
+
st.image(img.encoded_image_string)
|
277 |
else:
|
278 |
st.info("No TensorBoard logs found. Please upload logs to the 'logs' directory.")
|
279 |
|
requirements.txt
CHANGED
@@ -31,3 +31,5 @@ pydantic>=2.0.0
|
|
31 |
openai
|
32 |
streamlit
|
33 |
plotly==5.18.0
|
|
|
|
|
|
31 |
openai
|
32 |
streamlit
|
33 |
plotly==5.18.0
|
34 |
+
matplotlib
|
35 |
+
scikit-learn
|