Estherrr777 committed
Commit 8c837b6 · verified · 1 parent: 8dc105c

Update backend/app/train.py

Files changed (1)
  1. backend/app/train.py +14 -3
backend/app/train.py CHANGED
@@ -30,7 +30,9 @@ def load_and_prepare_dataset():
             f"BodyTemp: {example['BodyTemp']}, HeartRate: {example['HeartRate']}. "
             f"Predict the Risk Level."
         )
-        label = label_map.get(example["RiskLevel"].lower(), 0)  # Default to 0 if unknown
+        # Ensure consistent and safe label mapping
+        label_str = str(example.get("RiskLevel", "")).lower()
+        label = label_map.get(label_str, 0)
         return {"text": prompt, "label": label}
 
     dataset = Dataset.from_list(data)
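Why the rewrite is safer: the old one-liner assumed RiskLevel is always present and always a string. A minimal sketch of the difference, assuming a three-class label_map consistent with num_labels=3 (the actual dict is defined earlier in train.py and not shown in this diff):

# Hypothetical label_map; the real one is defined elsewhere in train.py
label_map = {"low risk": 0, "mid risk": 1, "high risk": 2}

def to_label(example: dict) -> int:
    # .get() tolerates a missing key and str() tolerates non-string values,
    # both of which would crash example["RiskLevel"].lower()
    label_str = str(example.get("RiskLevel", "")).lower()
    return label_map.get(label_str, 0)  # default to 0 if unknown

print(to_label({"RiskLevel": "High Risk"}))  # 2
print(to_label({"RiskLevel": None}))         # 0 (old code: AttributeError)
print(to_label({}))                          # 0 (old code: KeyError)

Note that both the old and new versions silently fold unknown or missing labels into class 0, so a typo in RiskLevel quietly becomes a class-0 training example; raising on unknown labels would be the stricter choice.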
@@ -44,7 +46,7 @@ def tokenize_function(example, tokenizer):
         padding=True,
         max_length=256,
     )
-    tokens["label"] = example["label"]  # ✅ Keep label after tokenization
+    tokens["label"] = example["label"]
     return tokens
 
 # -------- Main Training Function --------
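This hunk only drops a stale inline comment; behavior is unchanged. For readers without the full file, the function presumably has the shape sketched below, where the example["text"] argument and the truncation=True flag are assumptions (the diff shows only the tail):

def tokenize_function(example, tokenizer):
    # Assumed head; only the tail of this function appears in the diff
    tokens = tokenizer(
        example["text"],
        truncation=True,
        padding=True,
        max_length=256,
    )
    # The tokenizer output contains only input_ids / attention_mask
    # (plus token_type_ids for some models), so the label must be
    # copied across explicitly for Trainer to compute a loss.
    tokens["label"] = example["label"]
    return tokens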
@@ -53,8 +55,17 @@ def train():
     model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)
 
     dataset = load_and_prepare_dataset()
+
+    # Tokenize dataset
     tokenized_dataset = dataset.map(lambda x: tokenize_function(x, tokenizer), batched=False)
-    tokenized_dataset = tokenized_dataset.remove_columns(["text"])
+
+    # Remove any non-tensor-compatible fields
+    tokenized_dataset = tokenized_dataset.remove_columns(
+        [col for col in tokenized_dataset.column_names if col not in ["input_ids", "attention_mask", "label"]]
+    )
+
+    # Optional sanity check
+    print("🔎 Sample tokenized example:", tokenized_dataset[0])
 
     training_args = TrainingArguments(
         output_dir=SAVE_PATH,
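The new allow-list column filter can be exercised in isolation; the toy record below is purely illustrative:

from datasets import Dataset

# One fake mapped row (values are illustrative only)
ds = Dataset.from_list([
    {"text": "Age: 29, BodyTemp: 98.6. Predict the Risk Level.",
     "input_ids": [101, 2287, 102], "attention_mask": [1, 1, 1], "label": 1}
])

keep = ["input_ids", "attention_mask", "label"]
ds = ds.remove_columns([col for col in ds.column_names if col not in keep])

print(ds.column_names)  # ['input_ids', 'attention_mask', 'label']

Unlike the old remove_columns(["text"]), which raises if "text" is already gone, the allow-list version never references a missing column. It does also drop token_type_ids if the tokenizer emits them, which is worth double-checking for BERT-style choices of MODEL_NAME.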