import json
import os

from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
)

# -------- Settings --------
MODEL_NAME = "distilbert-base-uncased"
# Note: both paths are resolved relative to the current working directory.
DATA_PATH = "./backend/data/pregnancy_dataset.json"
SAVE_PATH = "./data/best_model"

os.makedirs(SAVE_PATH, exist_ok=True)


# -------- Load and Preprocess Dataset --------
def load_and_prepare_dataset():
    with open(DATA_PATH, "r") as f:
        data = json.load(f)
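
    # Each record is assumed to look roughly like this (illustrative values,
    # not taken from the real dataset):
    # {"Age": 25, "SystolicBP": 120, "DiastolicBP": 80, "BS": 7.0,
    #  "BodyTemp": 98.0, "HeartRate": 76, "RiskLevel": "low risk"}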

    # Map risk levels to integer labels.
    label_map = {"low risk": 0, "medium risk": 1, "high risk": 2}

    def preprocess(example):
        prompt = (
            f"Age: {example['Age']}, SystolicBP: {example['SystolicBP']}, "
            f"DiastolicBP: {example['DiastolicBP']}, BS: {example['BS']}, "
            f"BodyTemp: {example['BodyTemp']}, HeartRate: {example['HeartRate']}. "
            f"Predict the Risk Level."
        )
        # Normalize the label string before lookup. Any value missing from
        # label_map silently falls back to 0 ("low risk"), so verify the
        # dataset's exact label spellings before trusting the class balance.
        label_str = str(example.get("RiskLevel", "")).lower()
        label = label_map.get(label_str, 0)
        return {"text": prompt, "label": label}

    dataset = Dataset.from_list(data)
    return dataset.map(preprocess)
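
# Design note: the tabular features are serialized into a single prompt string
# so a text classifier (DistilBERT) can consume them; no separate numeric
# feature scaling is applied.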


# -------- Tokenization --------
def tokenize_function(example, tokenizer):
    # Truncate long prompts; leave padding to DataCollatorWithPadding, which
    # pads each batch dynamically to its longest sequence at training time.
    tokens = tokenizer(
        example["text"],
        truncation=True,
        max_length=64,
    )
    tokens["label"] = example["label"]
    return tokens


# -------- Main Training Function --------
def train():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)

    dataset = load_and_prepare_dataset()

    # Tokenize one example at a time (batched=False keeps tokenize_function simple).
    tokenized_dataset = dataset.map(lambda x: tokenize_function(x, tokenizer), batched=False)

    # Keep only the fields the model's forward pass expects.
    tokenized_dataset = tokenized_dataset.remove_columns(
        [col for col in tokenized_dataset.column_names if col not in ["input_ids", "attention_mask", "label"]]
    )

    # Optional sanity check
    print("🔎 Sample tokenized example:", tokenized_dataset[0])

    training_args = TrainingArguments(
        output_dir=SAVE_PATH,
        num_train_epochs=3,
        per_device_train_batch_size=4,
        save_steps=50,
        logging_steps=10,
        save_total_limit=1,  # keep only the most recent checkpoint
        remove_unused_columns=False,  # columns were already pruned above
        report_to="none",  # disable external experiment trackers
    )

    # Pads each batch to the length of its longest sequence.
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    trainer.train()
    trainer.save_model(SAVE_PATH)
    tokenizer.save_pretrained(SAVE_PATH)
    print("✅ Fine-tuned model saved!")
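

# Example reload for a quick prediction (untested sketch; the prompt string is
# hypothetical and must follow the same format built in preprocess()):
#
#   from transformers import pipeline
#   classifier = pipeline("text-classification", model=SAVE_PATH)
#   print(classifier(
#       "Age: 25, SystolicBP: 120, DiastolicBP: 80, BS: 7.0, "
#       "BodyTemp: 98.0, HeartRate: 76. Predict the Risk Level."
#   ))
#
# Since id2label is not set on the model config, predictions come back as
# LABEL_0/LABEL_1/LABEL_2, matching the low/medium/high mapping above.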

# -------- Entry Point --------
if __name__ == "__main__":
    train()