feat: add live training log streaming using TrainerCallback and background thread
Browse files
- Added GradioLoggerCallback to capture trainer log events
- Ran trainer.train() in a background thread to avoid UI blocking
- Streamed logs from queue during training to Gradio UI using yield
- Replaced simulated progress loop with real-time progress and log updates
- Fixes issue where UI showed only progress bar and froze at 99%
- train_abuse_model.py +33 -19
train_abuse_model.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
# # Install core packages
|
2 |
# !pip install -U transformers datasets accelerate
|
3 |
-
|
4 |
import logging
|
5 |
import io
|
6 |
import os
|
@@ -256,28 +256,42 @@ def run_training(progress=gr.Progress(track_tqdm=True)):
|
|
256 |
)
|
257 |
|
258 |
logger.info("Training started with %d samples", len(train_dataset))
|
259 |
-
yield "🚀 Training
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
|
275 |
# Start training!
|
276 |
trainer.train()
|
277 |
|
278 |
-
# Drain queue to UI
|
279 |
-
while not log_queue.empty():
|
280 |
-
|
281 |
|
282 |
progress(1.0)
|
283 |
yield "✅ Progress: 100%\n"
|
|
|
1 |
# # Install core packages
|
2 |
# !pip install -U transformers datasets accelerate
|
3 |
+
import threading
|
4 |
import logging
|
5 |
import io
|
6 |
import os
|
|
|
256 |
)
|
257 |
|
258 |
logger.info("Training started with %d samples", len(train_dataset))
|
259 |
+
yield "🚀 Training started...\n"
|
260 |
+
|
261 |
+
progress(0.01)
|
262 |
+
|
263 |
+
# Run training in background thread
|
264 |
+
trainer_training = [True]
|
265 |
+
|
266 |
+
def background_train():
|
267 |
+
trainer.train()
|
268 |
+
trainer_training[0] = False # Mark as done
|
269 |
+
|
270 |
+
train_thread = threading.Thread(target=background_train)
|
271 |
+
train_thread.start()
|
272 |
+
|
273 |
+
# Drain log queue live while training runs
|
274 |
+
percent = 0
|
275 |
+
while train_thread.is_alive() or not log_queue.empty():
|
276 |
+
while not log_queue.empty():
|
277 |
+
log_msg = log_queue.get()
|
278 |
+
yield log_msg
|
279 |
+
# Optional: update progress bar slowly toward 1.0
|
280 |
+
if percent < 98:
|
281 |
+
percent += 1
|
282 |
+
progress(percent / 100)
|
283 |
+
time.sleep(1)
|
284 |
+
|
285 |
+
progress(1.0)
|
286 |
+
yield "✅ Progress: 100%\n"
|
287 |
+
|
288 |
|
289 |
# Start training!
|
290 |
trainer.train()
|
291 |
|
292 |
+
# # Drain queue to UI
|
293 |
+
# while not log_queue.empty():
|
294 |
+
# yield log_queue.get()
|
295 |
|
296 |
progress(1.0)
|
297 |
yield "✅ Progress: 100%\n"
|