chandini2595 commited on
Commit
71bdcfa
·
1 Parent(s): 261a578

Add advanced training metrics and confusion matrix logging/visualization

Browse files
Files changed (2) hide show
  1. app.py +51 -10
  2. requirements.txt +2 -0
app.py CHANGED
@@ -17,6 +17,9 @@ from dotenv import load_dotenv
17
  from chatbot_utils import ask_receipt_chatbot
18
  import time
19
  from tensorboard.backend.event_processing import event_accumulator
 
 
 
20
 
21
  # Configure logging
22
  logging.basicConfig(level=logging.INFO)
@@ -204,35 +207,73 @@ def main():
204
 
205
  if st.button("Start Training"):
206
  epochs = 10
 
207
  losses = []
 
208
  accuracies = []
209
  progress = st.progress(0)
210
- chart = st.line_chart({"Loss": [], "Accuracy": []})
 
 
211
 
212
  for epoch in range(epochs):
213
  # Simulate training
214
  loss = np.exp(-epoch/5) + np.random.rand() * 0.05
 
215
  acc = 1 - loss + np.random.rand() * 0.02
216
  losses.append(loss)
 
217
  accuracies.append(acc)
218
- chart.add_rows({"Loss": [loss], "Accuracy": [acc]})
219
  progress.progress((epoch+1)/epochs)
220
- st.write(f"Epoch {epoch+1}: Loss={loss:.4f}, Accuracy={acc:.4f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  time.sleep(0.5)
222
 
 
223
  st.success("Training complete!")
224
- st.line_chart({"Loss": losses, "Accuracy": accuracies})
225
 
226
- logdir = "logs" # or the actual path to your logs
 
 
 
 
 
227
 
 
228
  if os.path.exists(logdir) and os.listdir(logdir):
229
  ea = event_accumulator.EventAccumulator(logdir)
230
  ea.Reload()
231
- if 'loss' in ea.Tags()['scalars']:
232
- losses = [s.value for s in ea.Scalars('loss')]
233
- st.line_chart(losses)
234
- else:
235
- st.info("No 'loss' scalar found in TensorBoard logs.")
 
 
 
 
 
236
  else:
237
  st.info("No TensorBoard logs found. Please upload logs to the 'logs' directory.")
238
 
 
17
  from chatbot_utils import ask_receipt_chatbot
18
  import time
19
  from tensorboard.backend.event_processing import event_accumulator
20
+ from torch.utils.tensorboard import SummaryWriter
21
+ import matplotlib.pyplot as plt
22
+ from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
23
 
24
  # Configure logging
25
  logging.basicConfig(level=logging.INFO)
 
207
 
208
  if st.button("Start Training"):
209
  epochs = 10
210
+ num_classes = 3 # Example: 3 classes for confusion matrix
211
  losses = []
212
+ val_losses = []
213
  accuracies = []
214
  progress = st.progress(0)
215
+ chart = st.line_chart({"Loss": [], "Val Loss": [], "Accuracy": []})
216
+
217
+ writer = SummaryWriter("logs")
218
 
219
  for epoch in range(epochs):
220
  # Simulate training
221
  loss = np.exp(-epoch/5) + np.random.rand() * 0.05
222
+ val_loss = loss + np.random.rand() * 0.02
223
  acc = 1 - loss + np.random.rand() * 0.02
224
  losses.append(loss)
225
+ val_losses.append(val_loss)
226
  accuracies.append(acc)
227
+ chart.add_rows({"Loss": [loss], "Val Loss": [val_loss], "Accuracy": [acc]})
228
  progress.progress((epoch+1)/epochs)
229
+ st.write(f"Epoch {epoch+1}: Loss={loss:.4f}, Val Loss={val_loss:.4f}, Accuracy={acc:.4f}")
230
+
231
+ # Log to TensorBoard
232
+ writer.add_scalar("loss", loss, epoch)
233
+ writer.add_scalar("val_loss", val_loss, epoch)
234
+ writer.add_scalar("accuracy", acc, epoch)
235
+
236
+ # Simulate predictions and labels for confusion matrix
237
+ y_true = np.random.randint(0, num_classes, 100)
238
+ y_pred = y_true.copy()
239
+ # Add some noise to predictions
240
+ y_pred[np.random.choice(100, 10, replace=False)] = np.random.randint(0, num_classes, 10)
241
+ cm = confusion_matrix(y_true, y_pred, labels=range(num_classes))
242
+
243
+ # Plot and log confusion matrix as image
244
+ fig, ax = plt.subplots()
245
+ disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[f"Class {i}" for i in range(num_classes)])
246
+ disp.plot(ax=ax)
247
+ plt.close(fig)
248
+ writer.add_figure("confusion_matrix", fig, epoch)
249
+
250
  time.sleep(0.5)
251
 
252
+ writer.close()
253
  st.success("Training complete!")
254
+ st.line_chart({"Loss": losses, "Val Loss": val_losses, "Accuracy": accuracies})
255
 
256
+ # Show last confusion matrix in Streamlit
257
+ st.subheader("Confusion Matrix (Last Epoch)")
258
+ fig, ax = plt.subplots()
259
+ disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[f"Class {i}" for i in range(num_classes)])
260
+ disp.plot(ax=ax)
261
+ st.pyplot(fig)
262
 
263
+ logdir = "logs"
264
  if os.path.exists(logdir) and os.listdir(logdir):
265
  ea = event_accumulator.EventAccumulator(logdir)
266
  ea.Reload()
267
+ scalars = ea.Tags()['scalars']
268
+ for tag in ['loss', 'val_loss', 'accuracy']:
269
+ if tag in scalars:
270
+ values = [s.value for s in ea.Scalars(tag)]
271
+ st.line_chart({tag: values})
272
+ # Show confusion matrix images if available
273
+ if 'confusion_matrix' in ea.Tags()['images']:
274
+ st.subheader("TensorBoard Confusion Matrices")
275
+ for img in ea.Images('confusion_matrix'):
276
+ st.image(img.encoded_image_string)
277
  else:
278
  st.info("No TensorBoard logs found. Please upload logs to the 'logs' directory.")
279
 
requirements.txt CHANGED
@@ -31,3 +31,5 @@ pydantic>=2.0.0
31
  openai
32
  streamlit
33
  plotly==5.18.0
 
 
 
31
  openai
32
  streamlit
33
  plotly==5.18.0
34
+ matplotlib
35
+ scikit-learn