rshakked commited on
Commit
d08a52d
Β·
1 Parent(s): cf43376

fix: stream logs to Gradio UI using yield instead of logger only

Browse files

- Added yield statements for all important messages during training and evaluation
- Replaced yield evaluate_saved_model() with for-loop to stream output
- Removed ineffective return log_buffer.read() which had no effect in Gradio

Files changed (1) hide show
  1. train_abuse_model.py +11 -6
train_abuse_model.py CHANGED
@@ -60,10 +60,9 @@ logger = logging.getLogger(__name__)
60
 
61
 
62
  # Check versions
63
- logger.info("Transformers version:", transformers.__version__)
64
 
65
  # Check for GPU availability
66
- logger.info("Transformers version: %s", torch.__version__)
67
  logger.info("torch.cuda.is_available(): %s", torch.cuda.is_available())
68
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
69
 
@@ -194,7 +193,8 @@ model_name = "microsoft/deberta-v3-base"
194
  def run_training(progress=gr.Progress(track_tqdm=True)):
195
  if os.path.exists("saved_model/"):
196
  yield "βœ… Trained model found! Skipping training...\n"
197
- yield evaluate_saved_model()
 
198
  return
199
  yield "πŸš€ Starting training...\n"
200
  try:
@@ -262,6 +262,9 @@ def run_training(progress=gr.Progress(track_tqdm=True)):
262
  # Start training!
263
  trainer.train()
264
 
 
 
 
265
  # Save the model and tokenizer
266
  MODEL_DIR.mkdir(parents=True, exist_ok=True)
267
  model.save_pretrained(MODEL_DIR)
@@ -277,12 +280,14 @@ def run_training(progress=gr.Progress(track_tqdm=True)):
277
  # Evaluation
278
  try:
279
  if 'trainer' in locals():
280
- evaluate_model_with_thresholds(trainer, test_dataset)
 
281
  logger.info("Evaluation completed")
 
 
282
  except Exception as e:
283
  logger.exception(f"Evaluation failed: {e}")
284
- log_buffer.seek(0)
285
- return log_buffer.read()
286
 
287
  def push_model_to_hub():
288
  try:
 
60
 
61
 
62
  # Check versions
63
+ logger.info(f"Transformers version: {transformers.__version__}")
64
 
65
  # Check for GPU availability
 
66
  logger.info("torch.cuda.is_available(): %s", torch.cuda.is_available())
67
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
68
 
 
193
  def run_training(progress=gr.Progress(track_tqdm=True)):
194
  if os.path.exists("saved_model/"):
195
  yield "βœ… Trained model found! Skipping training...\n"
196
+ for line in evaluate_saved_model():
197
+ yield line
198
  return
199
  yield "πŸš€ Starting training...\n"
200
  try:
 
262
  # Start training!
263
  trainer.train()
264
 
265
+ progress(1.0)
266
+ yield "βœ… Progress: 100%\n"
267
+
268
  # Save the model and tokenizer
269
  MODEL_DIR.mkdir(parents=True, exist_ok=True)
270
  model.save_pretrained(MODEL_DIR)
 
280
  # Evaluation
281
  try:
282
  if 'trainer' in locals():
283
+ for line in evaluate_model_with_thresholds(trainer, test_dataset):
284
+ yield line
285
  logger.info("Evaluation completed")
286
+ logger.info("Evaluation completed")
287
+ yield "πŸ“ˆ Evaluation completed\n"
288
  except Exception as e:
289
  logger.exception(f"Evaluation failed: {e}")
290
+ return
 
291
 
292
  def push_model_to_hub():
293
  try: