feat: add TrainerCallback to stream live training logs to UI
- Implemented GradioLoggerCallback to forward Hugging Face Trainer logs to Gradio
- Replaced pre-loop simulated logging with true per-step feedback
- UI now shows step-by-step progress without freezing or blocking
- train_abuse_model.py +18 -1
train_abuse_model.py
CHANGED
@@ -7,6 +7,7 @@ import os
|
|
7 |
import time
|
8 |
import gradio as gr  # ✅ required for progress bar
|
9 |
from pathlib import Path
|
|
|
10 |
|
11 |
# Python standard + ML packages
|
12 |
import pandas as pd
|
@@ -23,6 +24,7 @@ from huggingface_hub import hf_hub_download
|
|
23 |
# Hugging Face transformers
|
24 |
import transformers
|
25 |
from transformers import (
|
|
|
26 |
AutoTokenizer,
|
27 |
DebertaV2Tokenizer,
|
28 |
BertTokenizer,
|
@@ -66,6 +68,15 @@ logger.info(f"Transformers version: {transformers.__version__}")
|
|
66 |
logger.info("torch.cuda.is_available(): %s", torch.cuda.is_available())
|
67 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
def evaluate_model_with_thresholds(trainer, test_dataset):
|
71 |
"""Run full evaluation with automatic threshold tuning."""
|
@@ -191,6 +202,7 @@ train_texts, val_texts, train_labels, val_labels = train_test_split(
|
|
191 |
model_name = "microsoft/deberta-v3-base"
|
192 |
|
193 |
def run_training(progress=gr.Progress(track_tqdm=True)):
|
|
|
194 |
if os.path.exists("saved_model/"):
|
195 |
yield "✅ Trained model found! Skipping training...\n"
|
196 |
for line in evaluate_saved_model():
|
@@ -239,7 +251,8 @@ def run_training(progress=gr.Progress(track_tqdm=True)):
|
|
239 |
model=model,
|
240 |
args=training_args,
|
241 |
train_dataset=train_dataset,
|
242 |
-
eval_dataset=val_dataset
|
|
|
243 |
)
|
244 |
|
245 |
logger.info("Training started with %d samples", len(train_dataset))
|
@@ -262,6 +275,10 @@ def run_training(progress=gr.Progress(track_tqdm=True)):
|
|
262 |
# Start training!
|
263 |
trainer.train()
|
264 |
|
|
|
|
|
|
|
|
|
265 |
progress(1.0)
|
266 |
yield "✅ Progress: 100%\n"
|
267 |
|
|
|
7 |
import time
|
8 |
import gradio as gr  # ✅ required for progress bar
|
9 |
from pathlib import Path
|
10 |
+
import queue
|
11 |
|
12 |
# Python standard + ML packages
|
13 |
import pandas as pd
|
|
|
24 |
# Hugging Face transformers
|
25 |
import transformers
|
26 |
from transformers import (
|
27 |
+
TrainerCallback,
|
28 |
AutoTokenizer,
|
29 |
DebertaV2Tokenizer,
|
30 |
BertTokenizer,
|
|
|
68 |
logger.info("torch.cuda.is_available(): %s", torch.cuda.is_available())
|
69 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
70 |
|
71 |
+
class GradioLoggerCallback(TrainerCallback):
|
72 |
+
def __init__(self, gr_queue):
|
73 |
+
self.gr_queue = gr_queue
|
74 |
+
|
75 |
+
def on_log(self, args, state, control, logs=None, **kwargs):
|
76 |
+
if logs:
|
77 |
+
msg = f"📊 Step {state.global_step}: {logs}"
|
78 |
+
logger.info(msg)
|
79 |
+
self.gr_queue.put(msg)
|
80 |
|
81 |
def evaluate_model_with_thresholds(trainer, test_dataset):
|
82 |
"""Run full evaluation with automatic threshold tuning."""
|
|
|
202 |
model_name = "microsoft/deberta-v3-base"
|
203 |
|
204 |
def run_training(progress=gr.Progress(track_tqdm=True)):
|
205 |
+
log_queue = queue.Queue()
|
206 |
if os.path.exists("saved_model/"):
|
207 |
yield "✅ Trained model found! Skipping training...\n"
|
208 |
for line in evaluate_saved_model():
|
|
|
251 |
model=model,
|
252 |
args=training_args,
|
253 |
train_dataset=train_dataset,
|
254 |
+
eval_dataset=val_dataset,
|
255 |
+
callbacks=[GradioLoggerCallback(log_queue)]
|
256 |
)
|
257 |
|
258 |
logger.info("Training started with %d samples", len(train_dataset))
|
|
|
275 |
# Start training!
|
276 |
trainer.train()
|
277 |
|
278 |
+
# Drain queue to UI
|
279 |
+
while not log_queue.empty():
|
280 |
+
yield log_queue.get()
|
281 |
+
|
282 |
progress(1.0)
|
283 |
yield "✅ Progress: 100%\n"
|
284 |
|