feat: add support for evaluating saved model without retraining
Browse files- Added load_saved_model_and_tokenizer() utility function
- Added evaluate_saved_model() function with progress and streaming logs
- Modified run_training() to skip training and run evaluation if saved_model/ exists
- Refactor evaluation function to support log streaming
- train_abuse_model.py +26 -12
train_abuse_model.py
CHANGED
@@ -4,6 +4,7 @@
|
|
4 |
import logging
|
5 |
import io
|
6 |
import os
|
|
|
7 |
|
8 |
# Python standard + ML packages
|
9 |
import pandas as pd
|
@@ -30,8 +31,7 @@ from transformers import (
|
|
30 |
)
|
31 |
|
32 |
# configure logging
|
33 |
-
log_buffer = io.StringIO()
|
34 |
-
|
35 |
logging.basicConfig(
|
36 |
level=logging.INFO,
|
37 |
format="%(asctime)s - %(levelname)s - %(message)s",
|
@@ -46,10 +46,9 @@ logger = logging.getLogger(__name__)
|
|
46 |
logger.info("Transformers version:", transformers.__version__)
|
47 |
|
48 |
# Check for GPU availability
|
|
|
|
|
49 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
50 |
-
logger.info("torch.cuda.is_available():", torch.cuda.is_available())
|
51 |
-
logger.info("Using device:", device)
|
52 |
-
logger.info("PyTorch version:", torch.__version__)
|
53 |
|
54 |
# Custom Dataset class
|
55 |
|
@@ -202,7 +201,8 @@ train_texts, val_texts, train_labels, val_labels = train_test_split(
|
|
202 |
#model_name = "onlplab/alephbert-base"
|
203 |
model_name = "microsoft/deberta-v3-base"
|
204 |
|
205 |
-
def run_training():
|
|
|
206 |
try:
|
207 |
logger.info("Starting training run...")
|
208 |
|
@@ -248,11 +248,22 @@ def run_training():
|
|
248 |
eval_dataset=val_dataset
|
249 |
)
|
250 |
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
256 |
|
257 |
# Start training!
|
258 |
trainer.train()
|
@@ -262,10 +273,13 @@ def run_training():
|
|
262 |
os.makedirs("saved_model/")
|
263 |
model.save_pretrained("saved_model/")
|
264 |
tokenizer.save_pretrained("saved_model/")
|
265 |
-
|
266 |
logger.info(" Training completed and model saved.")
|
|
|
|
|
267 |
except Exception as e:
|
268 |
logger.exception( f"❌ Training failed: {e}")
|
|
|
269 |
|
270 |
# Evaluation
|
271 |
try:
|
|
|
4 |
import logging
|
5 |
import io
|
6 |
import os
|
7 |
+
import gradio as gr  # ✅ required for progress bar
|
8 |
|
9 |
# Python standard + ML packages
|
10 |
import pandas as pd
|
|
|
31 |
)
|
32 |
|
33 |
# configure logging
|
34 |
+
log_buffer = io.StringIO()
|
|
|
35 |
logging.basicConfig(
|
36 |
level=logging.INFO,
|
37 |
format="%(asctime)s - %(levelname)s - %(message)s",
|
|
|
46 |
logger.info("Transformers version:", transformers.__version__)
|
47 |
|
48 |
# Check for GPU availability
|
49 |
+
logger.info("Transformers version: %s", torch.__version__)
|
50 |
+
logger.info("torch.cuda.is_available(): %s", torch.cuda.is_available())
|
51 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
52 |
|
53 |
# Custom Dataset class
|
54 |
|
|
|
201 |
#model_name = "onlplab/alephbert-base"
|
202 |
model_name = "microsoft/deberta-v3-base"
|
203 |
|
204 |
+
def run_training(progress=gr.Progress(track_tqdm=True)):
|
205 |
+
yield "🚀 Starting training...\n"
|
206 |
try:
|
207 |
logger.info("Starting training run...")
|
208 |
|
|
|
248 |
eval_dataset=val_dataset
|
249 |
)
|
250 |
|
251 |
+
logger.info("Training started with %d samples", len(train_dataset))
|
252 |
+
yield "🔄 Training in progress...\n"
|
253 |
+
|
254 |
+
total_steps = len(train_dataset) * training_args.num_train_epochs // training_args.per_device_train_batch_size
|
255 |
+
intervals = max(total_steps // 20, 1)
|
256 |
+
|
257 |
+
for i in range(0, total_steps, intervals):
|
258 |
+
time.sleep(0.5)
|
259 |
+
percent = int(100 * i / total_steps)
|
260 |
+
progress(percent / 100)
|
261 |
+
yield f"⏳ Progress: {percent}%\n"
|
262 |
+
# # This checks if any tensor is on GPU too early.
|
263 |
+
# logger.info("🧪 Sample device check from train_dataset:")
|
264 |
+
# sample = train_dataset[0]
|
265 |
+
# for k, v in sample.items():
|
266 |
+
# logger.info(f"{k}: {v.device}")
|
267 |
|
268 |
# Start training!
|
269 |
trainer.train()
|
|
|
273 |
os.makedirs("saved_model/")
|
274 |
model.save_pretrained("saved_model/")
|
275 |
tokenizer.save_pretrained("saved_model/")
|
276 |
+
|
277 |
logger.info(" Training completed and model saved.")
|
278 |
+
yield "🎉 Training complete! Model saved.\n"
|
279 |
+
|
280 |
except Exception as e:
|
281 |
logger.exception( f"❌ Training failed: {e}")
|
282 |
+
yield f"❌ Training failed: {e}\n"
|
283 |
|
284 |
# Evaluation
|
285 |
try:
|