import json
import os

import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
)

# -------- Settings --------
MODEL_NAME = "distilbert-base-uncased"
DATA_PATH = "./backend/data/pregnancy_dataset.json"
SAVE_PATH = "./data/best_model"

os.makedirs(SAVE_PATH, exist_ok=True)


# -------- Load and Preprocess Dataset --------
def load_and_prepare_dataset():
    with open(DATA_PATH, "r") as f:
        data = json.load(f)

    # Map risk levels to integer labels
    label_map = {"low risk": 0, "medium risk": 1, "high risk": 2}

    def preprocess(example):
        # Build a single text prompt from the numeric vitals columns
        prompt = (
            f"Age: {example['Age']}, SystolicBP: {example['SystolicBP']}, "
            f"DiastolicBP: {example['DiastolicBP']}, BS: {example['BS']}, "
            f"BodyTemp: {example['BodyTemp']}, HeartRate: {example['HeartRate']}. "
            f"Predict the Risk Level."
        )
        # Ensure consistent and safe label mapping; any unrecognized
        # RiskLevel string silently falls back to label 0 ("low risk")
        label_str = str(example.get("RiskLevel", "")).lower()
        label = label_map.get(label_str, 0)
        return {"text": prompt, "label": label}

    dataset = Dataset.from_list(data)
    return dataset.map(preprocess)


# -------- Tokenization --------
def tokenize_function(example, tokenizer):
    # Only truncate here; dynamic padding is handled per batch by
    # DataCollatorWithPadding at training time.
    tokens = tokenizer(
        example["text"],
        truncation=True,
        max_length=64,
    )
    tokens["label"] = example["label"]
    return tokens


# -------- Main Training Function --------
def train():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)

    dataset = load_and_prepare_dataset()

    # Tokenize dataset
    tokenized_dataset = dataset.map(
        lambda x: tokenize_function(x, tokenizer), batched=False
    )

    # Remove any non-tensor-compatible fields
    tokenized_dataset = tokenized_dataset.remove_columns(
        [
            col
            for col in tokenized_dataset.column_names
            if col not in ["input_ids", "attention_mask", "label"]
        ]
    )

    # Optional sanity check
    print("🔎 Sample tokenized example:", tokenized_dataset[0])

    training_args = TrainingArguments(
        output_dir=SAVE_PATH,
        num_train_epochs=3,
        per_device_train_batch_size=4,
        save_steps=50,
        logging_steps=10,
        save_total_limit=1,
        remove_unused_columns=False,
        report_to="none",
    )

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    trainer.train()
    trainer.save_model(SAVE_PATH)
    tokenizer.save_pretrained(SAVE_PATH)
    print("✅ Fine-tuned model saved!")


# -------- Entry Point --------
if __name__ == "__main__":
    train()
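

# -------- Optional: Inference Sketch --------
# Not part of the training flow above; a minimal sketch of how the model
# saved to SAVE_PATH could be queried afterward. The `predict_risk` helper
# and the example prompt values are illustrative assumptions, not part of
# the original script.
def predict_risk(prompt: str) -> str:
    tokenizer = AutoTokenizer.from_pretrained(SAVE_PATH)
    model = AutoModelForSequenceClassification.from_pretrained(SAVE_PATH)
    model.eval()
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=64)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Index order mirrors the label_map used during training
    return ["low risk", "medium risk", "high risk"][int(torch.argmax(logits, dim=-1))]


# Example usage (run only after training has populated SAVE_PATH):
#   print(predict_risk(
#       "Age: 29, SystolicBP: 90, DiastolicBP: 70, BS: 8.0, "
#       "BodyTemp: 100.0, HeartRate: 80. Predict the Risk Level."
#   ))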