import json
import os

from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
)

# -------- Settings --------
MODEL_NAME = "distilbert-base-uncased"
DATA_PATH = "./backend/data/pregnancy_dataset.json"
SAVE_PATH = "./data/best_model"

os.makedirs(SAVE_PATH, exist_ok=True)


# -------- Load and Preprocess Dataset --------
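# The loader below assumes DATA_PATH holds a JSON array of flat records.
# Illustrative shape only (these values are made up, not from the dataset):
# [
#   {"Age": 25, "SystolicBP": 130, "DiastolicBP": 80, "BS": 15.0,
#    "BodyTemp": 98.0, "HeartRate": 86, "RiskLevel": "high risk"},
#   ...
# ]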
def load_and_prepare_dataset():
    with open(DATA_PATH, "r") as f:
        data = json.load(f)

    # Map risk levels to integer labels
    label_map = {"low risk": 0, "medium risk": 1, "high risk": 2}

    def preprocess(example):
        prompt = (
            f"Age: {example['Age']}, SystolicBP: {example['SystolicBP']}, "
            f"DiastolicBP: {example['DiastolicBP']}, BS: {example['BS']}, "
            f"BodyTemp: {example['BodyTemp']}, HeartRate: {example['HeartRate']}. "
            f"Predict the Risk Level."
        )
        # Normalize the label string; values missing from label_map
        # deliberately fall back to 0 ("low risk")
        label_str = str(example.get("RiskLevel", "")).lower()
        label = label_map.get(label_str, 0)
        return {"text": prompt, "label": label}

    dataset = Dataset.from_list(data)
    return dataset.map(preprocess)

# -------- Tokenization --------
def tokenize_function(example, tokenizer):
    # Padding is left to DataCollatorWithPadding, which pads each batch
    # dynamically at training time, so no per-example padding is applied here.
    tokens = tokenizer(
        example["text"],
        truncation=True,
        padding=False,
        max_length=64,
    )
    # Carry the integer label through so it survives tokenization.
    tokens["label"] = example["label"]
    return tokens
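
# For a single example, tokenize_function returns a dict along these lines
# (token ids illustrative, not actual output):
#   {"input_ids": [101, 2287, ...], "attention_mask": [1, 1, ...], "label": 2}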

# -------- Main Training Function --------
def train():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)

    dataset = load_and_prepare_dataset()

    # Tokenize dataset
    tokenized_dataset = dataset.map(lambda x: tokenize_function(x, tokenizer), batched=False)

    # Remove any non-tensor-compatible fields
    tokenized_dataset = tokenized_dataset.remove_columns(
        [col for col in tokenized_dataset.column_names if col not in ["input_ids", "attention_mask", "label"]]
    )

    # Optional sanity check
    print("🔎 Sample tokenized example:", tokenized_dataset[0])

    training_args = TrainingArguments(
        output_dir=SAVE_PATH,
        num_train_epochs=3,
        per_device_train_batch_size=4,
        save_steps=50,
        logging_steps=10,
        save_total_limit=1,
        remove_unused_columns=False,
        report_to="none",
    )

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    trainer.train()
    trainer.save_model(SAVE_PATH)
    tokenizer.save_pretrained(SAVE_PATH)
    print("✅ Fine-tuned model saved!")
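
# -------- Inference Sketch --------
# A minimal sketch of how the saved model might be queried after training.
# `predict_risk` is an illustrative helper, not part of the training pipeline;
# it assumes prompts follow the same format built in preprocess() above.
def predict_risk(prompt: str) -> str:
    from transformers import pipeline

    classifier = pipeline("text-classification", model=SAVE_PATH, tokenizer=SAVE_PATH)
    # num_labels=3 was set without a custom id2label, so the pipeline reports
    # default names like "LABEL_2"; map them back to the training labels.
    id_to_label = {0: "low risk", 1: "medium risk", 2: "high risk"}
    result = classifier(prompt)[0]  # e.g. {"label": "LABEL_2", "score": 0.97}
    return id_to_label[int(result["label"].split("_")[-1])]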

# -------- Entry Point --------
if __name__ == "__main__":
    train()