fix: pass tokenizer explicitly to AbuseDataset and safeguard evaluation step
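`AbuseDataset` previously relied on a module-level `tokenizer` existing at construction time; it now receives the tokenizer as an explicit constructor argument. The evaluation step is also guarded so it only runs when a `trainer` was actually created, instead of failing inside the `try` block when training aborted earlier.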
train_abuse_model.py +8 -8
@@ -54,7 +54,7 @@ logger.info("PyTorch version:", torch.__version__)
 # Custom Dataset class
 
 class AbuseDataset(Dataset):
-    def __init__(self, texts, labels):
+    def __init__(self, texts, labels, tokenizer):
         self.encodings = tokenizer(texts, truncation=True, padding=True, max_length=512)
         self.labels = labels
 
@@ -223,10 +223,9 @@ def run_training():
         param.requires_grad = False
 
 
-    train_dataset = AbuseDataset(train_texts, train_labels)
-    val_dataset = AbuseDataset(val_texts, val_labels)
-    test_dataset = AbuseDataset(test_texts, test_labels)
-
+    train_dataset = AbuseDataset(train_texts, train_labels, tokenizer)
+    val_dataset = AbuseDataset(val_texts, val_labels, tokenizer)
+    test_dataset = AbuseDataset(test_texts, test_labels, tokenizer)
 
     # TrainingArguments for HuggingFace Trainer (logging, saving)
     training_args = TrainingArguments(
@@ -270,9 +269,10 @@ def run_training():
 
     # Evaluation
    try:
-        label_map = {0.0: "no", 0.5: "plausibly", 1.0: "yes"}
-        evaluate_model_with_thresholds(trainer, test_dataset)
-        logger.info("Evaluation completed")
+        if 'trainer' in locals():
+            label_map = {0.0: "no", 0.5: "plausibly", 1.0: "yes"}
+            evaluate_model_with_thresholds(trainer, test_dataset)
+            logger.info("Evaluation completed")
     except Exception as e:
         logger.exception(f"Evaluation failed: {e}")
         log_buffer.seek(0)
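For context, a minimal sketch of how the updated class is meant to be used with an explicitly injected Hugging Face tokenizer. The `__getitem__`/`__len__` methods and the `bert-base-uncased` checkpoint are illustrative assumptions; the diff does not show them:

import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer


class AbuseDataset(Dataset):
    def __init__(self, texts, labels, tokenizer):
        # The tokenizer is injected rather than read from module scope,
        # so the class no longer depends on a global being defined first.
        self.encodings = tokenizer(texts, truncation=True, padding=True, max_length=512)
        self.labels = labels

    # __getitem__/__len__ are assumed here; the diff does not show them.
    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)


# Any checkpoint works; bert-base-uncased is a placeholder.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
train_dataset = AbuseDataset(["an example post"], [[0.0, 0.5, 1.0]], tokenizer)

The `if 'trainer' in locals()` guard makes evaluation a no-op when training aborted before the `Trainer` was constructed; initializing `trainer = None` up front and checking `trainer is not None` would be an equivalent, arguably more explicit, alternative.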