Estherrr777 committed
Commit 4d5a067 · verified · 1 Parent(s): 4d723eb

Update backend/app/train.py

Files changed (1)
  1. backend/app/train.py +29 -41
backend/app/train.py CHANGED
@@ -2,10 +2,10 @@ import json
 import os
 from transformers import (
     AutoTokenizer,
-    AutoModelForCausalLM,
+    AutoModelForSequenceClassification,
     TrainingArguments,
     Trainer,
-    DataCollatorForLanguageModeling,
+    DataCollatorWithPadding,
 )
 import torch
 from datasets import Dataset
@@ -15,62 +15,48 @@ MODEL_NAME = "google/gemma-1.1-2b-it"
 DATA_PATH = "./backend/data/pregnancy_dataset.json"
 SAVE_PATH = "./backend/app/checkpoints"

-# -------- Load Dataset --------
-def load_dataset():
+# -------- Load and Preprocess Dataset --------
+def load_and_prepare_dataset():
     with open(DATA_PATH, "r") as f:
         data = json.load(f)
-    dataset = Dataset.from_list(data)
-    return dataset

-# -------- Tokenization --------
-def tokenize_function(example, tokenizer):
-    prompt = (
-        f"Age: {example['Age']}, SystolicBP: {example['SystolicBP']}, "
-        f"DiastolicBP: {example['DiastolicBP']}, BS: {example['BS']}, "
-        f"BodyTemp: {example['BodyTemp']}, HeartRate: {example['HeartRate']}. "
-        f"Predict the Risk Level."
-    )
-    completion = example["RiskLevel"]
-
-    # Map string labels to integer class indices
+    # Map risk levels to integer labels
     label_map = {"low risk": 0, "medium risk": 1, "high risk": 2}
-    label = label_map.get(completion.lower(), -1)  # -1 as fallback

-    full_prompt = f"<start_of_turn> {prompt} <end_of_turn>\n{completion} <end_of_turn>"
+    def preprocess(example):
+        prompt = (
+            f"Age: {example['Age']}, SystolicBP: {example['SystolicBP']}, "
+            f"DiastolicBP: {example['DiastolicBP']}, BS: {example['BS']}, "
+            f"BodyTemp: {example['BodyTemp']}, HeartRate: {example['HeartRate']}. "
+            f"Predict the Risk Level."
+        )
+        label = label_map.get(example["RiskLevel"].lower(), 0)  # Default to 0 if unknown
+        return {"text": prompt, "label": label}
+
+    dataset = Dataset.from_list(data)
+    return dataset.map(preprocess)

-    tokenized = tokenizer(
-        full_prompt,
+# -------- Tokenization --------
+def tokenize_function(example, tokenizer):
+    return tokenizer(
+        example["text"],
         truncation=True,
-        padding="max_length",
+        padding=True,
         max_length=256,
     )
-    tokenized["labels"] = label  # ✅ Use numerical label
-
-    return tokenized

 # -------- Main Training Function --------
 def train():
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
-
-    # Optional: Freeze all except final layers
-    for param in model.base_model.parameters():
-        param.requires_grad = False
-    for param in model.lm_head.parameters():
-        param.requires_grad = True
+    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)

-    # Load and tokenize
-    raw_dataset = load_dataset()
-    tokenized_dataset = raw_dataset.map(lambda x: tokenize_function(x, tokenizer))
+    dataset = load_and_prepare_dataset()
+    tokenized_dataset = dataset.map(lambda x: tokenize_function(x, tokenizer), batched=True)

-    # Data collator
-    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
-
-    # Training arguments
     training_args = TrainingArguments(
         output_dir=SAVE_PATH,
         num_train_epochs=3,
-        per_device_train_batch_size=2,
+        per_device_train_batch_size=4,
         save_steps=50,
         logging_steps=10,
         save_total_limit=1,
@@ -78,7 +64,8 @@ def train():
         report_to="none",
     )

-    # Trainer
+    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+
     trainer = Trainer(
         model=model,
         args=training_args,
@@ -94,3 +81,4 @@ def train():

 if __name__ == "__main__":
     train()
+
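
For reference, running the retrained checkpoint as a classifier could look roughly like the sketch below. This is a minimal illustration, not part of the commit: it assumes the final model was saved directly under ./backend/app/checkpoints, reuses the same prompt template and label mapping as train.py, and the helper name predict_risk is made up.

# Minimal inference sketch (illustrative only, not part of this commit).
# Assumes the fine-tuned classifier was saved under SAVE_PATH; the tokenizer is
# reloaded from the base model in case it was not written to the checkpoint dir.
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

CHECKPOINT = "./backend/app/checkpoints"
BASE_MODEL = "google/gemma-1.1-2b-it"
ID2LABEL = {0: "low risk", 1: "medium risk", 2: "high risk"}  # mirrors label_map in train.py

def predict_risk(features: dict) -> str:
    """Build the same prompt used during training and return the predicted risk label."""
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    model = AutoModelForSequenceClassification.from_pretrained(CHECKPOINT, num_labels=3)
    model.eval()

    prompt = (
        f"Age: {features['Age']}, SystolicBP: {features['SystolicBP']}, "
        f"DiastolicBP: {features['DiastolicBP']}, BS: {features['BS']}, "
        f"BodyTemp: {features['BodyTemp']}, HeartRate: {features['HeartRate']}. "
        f"Predict the Risk Level."
    )
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=256)
    with torch.no_grad():
        logits = model(**inputs).logits
    return ID2LABEL[int(logits.argmax(dim=-1))]

One consequence of the change worth noting: replacing padding="max_length" with padding=True plus DataCollatorWithPadding means each batch is padded dynamically to its longest sequence instead of always to 256 tokens.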