import os
import json
import numpy as np
import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding,
)
from peft import get_peft_model, LoraConfig, TaskType
from huggingface_hub import HfApi
# === Path Setup ===
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
DATA_PATH = os.path.join(BASE_DIR, "data", "pregnancy_dataset.json")
SAVE_PATH = os.path.join(BASE_DIR, "data", "best_model")
os.makedirs(SAVE_PATH, exist_ok=True)
HF_REPO_ID = "Estherrr777/Mila-Wellnest-Backend" # Change this if needed
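# SAVE_PATH serves double duty: it is the Trainer's output_dir for checkpoints and the local
# folder whose files get uploaded to the Hub once training finishes.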
# === 1. Load and preprocess dataset ===
def load_and_prepare_dataset():
    with open(DATA_PATH, "r") as f:
        data = json.load(f)
    label_map = {"low risk": 0, "mid risk": 1, "high risk": 2}

    def preprocess(example):
        label_str = str(example["RiskLevel"]).strip().lower()
        prompt = (
            f"A {example['Age']}-year-old pregnant individual has a systolic "
            f"blood pressure of {example['SystolicBP']}, diastolic of "
            f"{example['DiastolicBP']}, blood sugar of {example['BS']}, "
            f"body temperature of {example['BodyTemp']}, and heart rate of "
            f"{example['HeartRate']}. What is the pregnancy risk level?"
        )
        return {"text": prompt, "label": label_map[label_str]}

    raw_ds = Dataset.from_list(data)
    print(f"📦 Loaded dataset: {len(raw_ds)} total samples")
    return raw_ds.map(preprocess)
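# Example of the text each record is turned into (illustrative values, not taken from the dataset):
#   "A 29-year-old pregnant individual has a systolic blood pressure of 120, diastolic of 80,
#    blood sugar of 7.5, body temperature of 98.0, and heart rate of 76. What is the pregnancy risk level?"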
# === 2. Tokenization ===
def tokenize_dataset(dataset, tokenizer):
print(f"🔠 Tokenizing {len(dataset)} samples...")
tokenized = dataset.map(lambda x: tokenizer(x["text"], truncation=True, max_length=128), batched=True)
print(f"✅ Tokenized: now has keys {tokenized.column_names}")
return tokenized
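# No padding here: DataCollatorWithPadding (passed to the Trainer below) pads each batch
# dynamically to the longest sequence in that batch.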
# === 3. Accuracy Metric ===
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=1)
    return {"accuracy": (preds == labels).mean()}
# === 4. Upload model files to Hugging Face ===
def upload_model_to_hf(token, save_path):
    api = HfApi()
    print("🚀 Uploading model files to Hugging Face Hub...")
    for fname in os.listdir(save_path):
        fpath = os.path.join(save_path, fname)
        if not os.path.isfile(fpath):
            continue  # upload_file only accepts files; skip checkpoint-* subdirectories
        api.upload_file(
            token=token,
            path_or_fileobj=fpath,
            path_in_repo=f"data/best_model/{fname}",  # keep the same layout as the local data/best_model folder
            repo_id=HF_REPO_ID,
            repo_type="space",
        )
    print("✅ Upload complete.")
# === 5. Main training routine ===
def train():
print("🔧 Loading and splitting dataset...")
ds = load_and_prepare_dataset()
# First split: 90% train/val, 10% test
split1 = ds.train_test_split(test_size=0.1, seed=42)
tv, test_ds = split1["train"], split1["test"]
# Second split: from the 90%, take ~11.1% as val → ends up ~10% of full
split2 = tv.train_test_split(test_size=0.1111, seed=42)
train_ds, val_ds = split2["train"], split2["test"]
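    # Split arithmetic: 0.9 * 0.1111 ≈ 0.10 of the full dataset, i.e. roughly an 80/10/10 split.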
print("📊 Dataset sizes:")
print(f" ➤ Train: {len(train_ds)}")
print(f" ➤ Val: {len(val_ds)}")
print(f" ➤ Test: {len(test_ds)}")
total = len(train_ds) + len(val_ds) + len(test_ds)
print(f" ➤ Total: {total} (expected ≈ original JSON length)")
    tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")
    train_ds = tokenize_dataset(train_ds, tokenizer)
    val_ds = tokenize_dataset(val_ds, tokenizer)
    test_ds = tokenize_dataset(test_ds, tokenizer)
    model = AutoModelForSequenceClassification.from_pretrained(
        "google/mobilebert-uncased", num_labels=3
    )
    peft_config = LoraConfig(
        task_type=TaskType.SEQ_CLS,
        inference_mode=False,
        r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        target_modules=["query", "value"],
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
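    # LoRA injects rank-8 adapters (scaled by lora_alpha=16, with 10% adapter dropout) into the
    # attention query/value projections only; the MobileBERT backbone stays frozen, so just the
    # adapters and the new 3-way classification head are trained. print_trainable_parameters()
    # reports how small that trainable fraction is.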
    training_args = TrainingArguments(
        output_dir=SAVE_PATH,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        num_train_epochs=4,
        learning_rate=2e-5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        weight_decay=0.01,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        save_total_limit=1,
        report_to="none",  # Disable W&B / other trackers
    )
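    # load_best_model_at_end requires the evaluation and save strategies to match (both "epoch"
    # here); the checkpoint with the best eval accuracy is restored when training finishes.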
print(f"📦 Batch size (train): {training_args.per_device_train_batch_size}")
print(f"📦 Batch size (eval): {training_args.per_device_eval_batch_size}")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        tokenizer=tokenizer,
        data_collator=DataCollatorWithPadding(tokenizer),
        compute_metrics=compute_metrics,
    )
    trainer.train()
    print("🧪 Evaluating on test set...")
    test_result = trainer.evaluate(test_ds)
    print(f"✅ Final Test Accuracy: {test_result['eval_accuracy']:.4f}")
    # Persist the fine-tuned model and tokenizer at the top level of SAVE_PATH so the upload
    # step has plain files to push (Trainer checkpoints live in checkpoint-* subfolders).
    trainer.save_model(SAVE_PATH)
    tokenizer.save_pretrained(SAVE_PATH)

    hf_token = os.getenv("R_HF_TOKEN")
    if hf_token:
        upload_model_to_hf(hf_token, SAVE_PATH)
    else:
        print("⚠️ R_HF_TOKEN not set — skipping Hugging Face upload.")
if __name__ == "__main__":
    train()
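# To run locally (a sketch, assuming the dependencies above are installed and
# backend/data/pregnancy_dataset.json exists):
#   R_HF_TOKEN=<your-token> python backend/app/train.py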
# Gemma model version (this earlier, commented-out variant actually uses distilbert-base-uncased)
# import json
# import os
# from transformers import (
# AutoTokenizer,
# AutoModelForSequenceClassification,
# TrainingArguments,
# Trainer,
# DataCollatorWithPadding,
# )
# import torch
# from datasets import Dataset
# # -------- Settings --------
# MODEL_NAME = "distilbert-base-uncased"
# DATA_PATH = "./backend/data/pregnancy_dataset.json"
# SAVE_PATH = "./data/best_model"
# os.makedirs(SAVE_PATH, exist_ok=True)
# # -------- Load and Preprocess Dataset --------
# def load_and_prepare_dataset():
# with open(DATA_PATH, "r") as f:
# data = json.load(f)
# # Map risk levels to integer labels
# label_map = {"low risk": 0, "medium risk": 1, "high risk": 2}
# def preprocess(example):
# prompt = (
# f"Age: {example['Age']}, SystolicBP: {example['SystolicBdon’t P']}, "
# f"DiastolicBP: {example['DiastolicBP']}, BS: {example['BS']}, "
# f"BodyTemp: {example['BodyTemp']}, HeartRate: {example['HeartRate']}. "
# f"Predict the Risk Level."
# )
# # Ensure consistent and safe label mapping
# label_str = str(example.get("RiskLevel", "")).lower()
# label = label_map.get(label_str, 0)
# return {"text": prompt, "label": label}
# dataset = Dataset.from_list(data)
# return dataset.map(preprocess)
# # -------- Tokenization --------
# def tokenize_function(example, tokenizer):
# tokens = tokenizer(
# example["text"],
# truncation=True,
# padding=True,
# max_length=64,
# )
# tokens["label"] = exdon’t ample["label"]
# return tokens
# # -------- Main Training Function --------
# def train():
# tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)
# dataset = load_and_prepare_dataset()
# # Tokenize dataset
# tokenized_dataset = dataset.map(lambda x: tokenize_function(x, tokenizer), batched=False)
# # Remove any non-tensor-compatible fields
# tokenized_dataset = tokenized_dataset.remove_columns(
# [col for col in tokenized_dataset.column_names if col not in ["input_ids", "attention_mask", "label"]]
# )
# # Optional sanity check
# print("🔎 Sample tokenized example:", tokenized_dataset[0])
# training_args = TrainingArguments(
# output_dir=SAVE_PATH,
# num_train_epochs=3,
# per_device_train_batch_size=4,
# save_steps=50,
# logging_steps=10,
# save_total_limit=1,
# remove_unused_columns=False,
# report_to="none",
# )
# data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# trainer = Trainer(
# model=model,
# args=training_args,
# train_dataset=tokenized_dataset,
# tokenizer=tokenizer,
# data_collator=data_collator,
# )
# trainer.train()
# trainer.save_model(SAVE_PATH)
# tokenizer.save_pretrained(SAVE_PATH)
# print("✅ Fine-tuned model saved!")
# # -------- Entry Point --------
# if __name__ == "__main__":
# train()