rshakked committed on
Commit
5c16708
·
1 Parent(s): 2032430

fix: pass tokenizer explicitly to AbuseDataset and safeguard evaluation step

Browse files
Files changed (1) hide show
  1. train_abuse_model.py +8 -8
train_abuse_model.py CHANGED
@@ -54,7 +54,7 @@ logger.info("PyTorch version:", torch.__version__)
54
  # Custom Dataset class
55
 
56
  class AbuseDataset(Dataset):
57
- def __init__(self, texts, labels):
58
  self.encodings = tokenizer(texts, truncation=True, padding=True, max_length=512)
59
  self.labels = labels
60
 
@@ -223,10 +223,9 @@ def run_training():
223
  param.requires_grad = False
224
 
225
 
226
- train_dataset = AbuseDataset(train_texts, train_labels)
227
- val_dataset = AbuseDataset(val_texts, val_labels)
228
- test_dataset = AbuseDataset(test_texts, test_labels)
229
-
230
 
231
  # TrainingArguments for HuggingFace Trainer (logging, saving)
232
  training_args = TrainingArguments(
@@ -270,9 +269,10 @@ def run_training():
270
 
271
  # Evaluation
272
  try:
273
- label_map = {0.0: "no", 0.5: "plausibly", 1.0: "yes"}
274
- evaluate_model_with_thresholds(trainer, test_dataset)
275
- logger.info("Evaluation completed")
 
276
  except Exception as e:
277
  logger.exception(f"Evaluation failed: {e}")
278
  log_buffer.seek(0)
 
54
  # Custom Dataset class
55
 
56
  class AbuseDataset(Dataset):
57
+ def __init__(self, texts, labels, tokenizer):
58
  self.encodings = tokenizer(texts, truncation=True, padding=True, max_length=512)
59
  self.labels = labels
60
 
 
223
  param.requires_grad = False
224
 
225
 
226
+ train_dataset = AbuseDataset(train_texts, train_labels,tokenizer)
227
+ val_dataset = AbuseDataset(val_texts, val_labels,tokenizer)
228
+ test_dataset = AbuseDataset(test_texts, test_labels,tokenizer)
 
229
 
230
  # TrainingArguments for HuggingFace Trainer (logging, saving)
231
  training_args = TrainingArguments(
 
269
 
270
  # Evaluation
271
  try:
272
+ if 'trainer' in locals():
273
+ label_map = {0.0: "no", 0.5: "plausibly", 1.0: "yes"}
274
+ evaluate_model_with_thresholds(trainer, test_dataset)
275
+ logger.info("Evaluation completed")
276
  except Exception as e:
277
  logger.exception(f"Evaluation failed: {e}")
278
  log_buffer.seek(0)