Spaces:
Runtime error
Runtime error
Update backend/app/train.py
Browse files- backend/app/train.py +5 -2
backend/app/train.py
CHANGED
@@ -5,7 +5,7 @@ from transformers import (
|
|
5 |
AutoModelForSequenceClassification,
|
6 |
TrainingArguments,
|
7 |
Trainer,
|
8 |
-
|
9 |
)
|
10 |
import torch
|
11 |
from datasets import Dataset
|
@@ -38,12 +38,15 @@ def load_and_prepare_dataset():
|
|
38 |
|
39 |
# -------- Tokenization --------
|
40 |
def tokenize_function(example, tokenizer):
|
41 |
-
|
42 |
example["text"],
|
43 |
truncation=True,
|
44 |
padding=True,
|
45 |
max_length=256,
|
46 |
)
|
|
|
|
|
|
|
47 |
|
48 |
# -------- Main Training Function --------
|
49 |
def train():
|
|
|
5 |
AutoModelForSequenceClassification,
|
6 |
TrainingArguments,
|
7 |
Trainer,
|
8 |
+
Dad taCollatorWithPadding,
|
9 |
)
|
10 |
import torch
|
11 |
from datasets import Dataset
|
|
|
38 |
|
39 |
# -------- Tokenization --------
|
40 |
def tokenize_function(example, tokenizer):
|
41 |
+
tokens = tokenizer(
|
42 |
example["text"],
|
43 |
truncation=True,
|
44 |
padding=True,
|
45 |
max_length=256,
|
46 |
)
|
47 |
+
tokens["label"] = example["label"] # ✅ Keep label after tokenization
|
48 |
+
return tokens
|
49 |
+
|
50 |
|
51 |
# -------- Main Training Function --------
|
52 |
def train():
|