rshakked commited on
Commit
ac6514b
Β·
1 Parent(s): c71da37

feat: add live training log streaming using TrainerCallback and background thread

Browse files

- Added GradioLoggerCallback to capture trainer log events
- Ran trainer.train() in a background thread to avoid UI blocking
- Streamed logs from queue during training to Gradio UI using yield
- Replaced simulated progress loop with real-time progress and log updates
- Fixes issue where UI showed only progress bar and froze at 99%

Files changed (1) hide show
  1. train_abuse_model.py +33 -19
train_abuse_model.py CHANGED
@@ -1,6 +1,6 @@
1
  # # Install core packages
2
  # !pip install -U transformers datasets accelerate
3
-
4
  import logging
5
  import io
6
  import os
@@ -256,28 +256,42 @@ def run_training(progress=gr.Progress(track_tqdm=True)):
256
  )
257
 
258
  logger.info("Training started with %d samples", len(train_dataset))
259
- yield "πŸ”„ Training in progress...\n"
260
-
261
- total_steps = len(train_dataset) * training_args.num_train_epochs // training_args.per_device_train_batch_size
262
- intervals = max(total_steps // 20, 1)
263
-
264
- for i in range(0, total_steps, intervals):
265
- time.sleep(0.5)
266
- percent = int(100 * i / total_steps)
267
- progress(percent / 100)
268
- yield f"⏳ Progress: {percent}%\n"
269
- # # This checks if any tensor is on GPU too early.
270
- # logger.info("πŸ§ͺ Sample device check from train_dataset:")
271
- # sample = train_dataset[0]
272
- # for k, v in sample.items():
273
- # logger.info(f"{k}: {v.device}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
 
275
  # Start training!
276
  trainer.train()
277
 
278
- # Drain queue to UI
279
- while not log_queue.empty():
280
- yield log_queue.get()
281
 
282
  progress(1.0)
283
  yield "βœ… Progress: 100%\n"
 
1
  # # Install core packages
2
  # !pip install -U transformers datasets accelerate
3
+ import threading
4
  import logging
5
  import io
6
  import os
 
256
  )
257
 
258
  logger.info("Training started with %d samples", len(train_dataset))
259
+ yield "πŸ”„ Training started...\n"
260
+
261
+ progress(0.01)
262
+
263
+ # Run training in background thread
264
+ trainer_training = [True]
265
+
266
+ def background_train():
267
+ trainer.train()
268
+ trainer_training[0] = False # Mark as done
269
+
270
+ train_thread = threading.Thread(target=background_train)
271
+ train_thread.start()
272
+
273
+ # Drain log queue live while training runs
274
+ percent = 0
275
+ while train_thread.is_alive() or not log_queue.empty():
276
+ while not log_queue.empty():
277
+ log_msg = log_queue.get()
278
+ yield log_msg
279
+ # Optional: update progress bar slowly toward 1.0
280
+ if percent < 98:
281
+ percent += 1
282
+ progress(percent / 100)
283
+ time.sleep(1)
284
+
285
+ progress(1.0)
286
+ yield "βœ… Progress: 100%\n"
287
+
288
 
289
  # Start training!
290
  trainer.train()
291
 
292
+ # # Drain queue to UI
293
+ # while not log_queue.empty():
294
+ # yield log_queue.get()
295
 
296
  progress(1.0)
297
  yield "βœ… Progress: 100%\n"