rshakked commited on
Commit
c71da37
Β·
1 Parent(s): d08a52d

feat: add TrainerCallback to stream live training logs to UI

Browse files

- Implemented GradioLoggerCallback to forward Hugging Face Trainer logs to Gradio
- Replaced pre-loop simulated logging with true per-step feedback
- UI now shows step-by-step progress without freezing or blocking

Files changed (1) hide show
  1. train_abuse_model.py +18 -1
train_abuse_model.py CHANGED
@@ -7,6 +7,7 @@ import os
7
  import time
8
  import gradio as gr # βœ… required for progress bar
9
  from pathlib import Path
 
10
 
11
  # Python standard + ML packages
12
  import pandas as pd
@@ -23,6 +24,7 @@ from huggingface_hub import hf_hub_download
23
  # Hugging Face transformers
24
  import transformers
25
  from transformers import (
 
26
  AutoTokenizer,
27
  DebertaV2Tokenizer,
28
  BertTokenizer,
@@ -66,6 +68,15 @@ logger.info(f"Transformers version: {transformers.__version__}")
66
  logger.info("torch.cuda.is_available(): %s", torch.cuda.is_available())
67
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
68
 
 
 
 
 
 
 
 
 
 
69
 
70
  def evaluate_model_with_thresholds(trainer, test_dataset):
71
  """Run full evaluation with automatic threshold tuning."""
@@ -191,6 +202,7 @@ train_texts, val_texts, train_labels, val_labels = train_test_split(
191
  model_name = "microsoft/deberta-v3-base"
192
 
193
  def run_training(progress=gr.Progress(track_tqdm=True)):
 
194
  if os.path.exists("saved_model/"):
195
  yield "βœ… Trained model found! Skipping training...\n"
196
  for line in evaluate_saved_model():
@@ -239,7 +251,8 @@ def run_training(progress=gr.Progress(track_tqdm=True)):
239
  model=model,
240
  args=training_args,
241
  train_dataset=train_dataset,
242
- eval_dataset=val_dataset
 
243
  )
244
 
245
  logger.info("Training started with %d samples", len(train_dataset))
@@ -262,6 +275,10 @@ def run_training(progress=gr.Progress(track_tqdm=True)):
262
  # Start training!
263
  trainer.train()
264
 
 
 
 
 
265
  progress(1.0)
266
  yield "βœ… Progress: 100%\n"
267
 
 
7
  import time
8
  import gradio as gr # βœ… required for progress bar
9
  from pathlib import Path
10
+ import queue
11
 
12
  # Python standard + ML packages
13
  import pandas as pd
 
24
  # Hugging Face transformers
25
  import transformers
26
  from transformers import (
27
+ TrainerCallback,
28
  AutoTokenizer,
29
  DebertaV2Tokenizer,
30
  BertTokenizer,
 
68
  logger.info("torch.cuda.is_available(): %s", torch.cuda.is_available())
69
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
70
 
71
+ class GradioLoggerCallback(TrainerCallback):
72
+ def __init__(self, gr_queue):
73
+ self.gr_queue = gr_queue
74
+
75
+ def on_log(self, args, state, control, logs=None, **kwargs):
76
+ if logs:
77
+ msg = f"πŸ“Š Step {state.global_step}: {logs}"
78
+ logger.info(msg)
79
+ self.gr_queue.put(msg)
80
 
81
  def evaluate_model_with_thresholds(trainer, test_dataset):
82
  """Run full evaluation with automatic threshold tuning."""
 
202
  model_name = "microsoft/deberta-v3-base"
203
 
204
  def run_training(progress=gr.Progress(track_tqdm=True)):
205
+ log_queue = queue.Queue()
206
  if os.path.exists("saved_model/"):
207
  yield "βœ… Trained model found! Skipping training...\n"
208
  for line in evaluate_saved_model():
 
251
  model=model,
252
  args=training_args,
253
  train_dataset=train_dataset,
254
+ eval_dataset=val_dataset,
255
+ callbacks=[GradioLoggerCallback(log_queue)]
256
  )
257
 
258
  logger.info("Training started with %d samples", len(train_dataset))
 
275
  # Start training!
276
  trainer.train()
277
 
278
+ # Drain queue to UI
279
+ while not log_queue.empty():
280
+ yield log_queue.get()
281
+
282
  progress(1.0)
283
  yield "βœ… Progress: 100%\n"
284