import os
import json
import numpy as np
import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding,
)
from peft import get_peft_model, LoraConfig, TaskType
from huggingface_hub import HfApi

# === Path Setup ===
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
DATA_PATH = os.path.join(BASE_DIR, "data", "pregnancy_dataset.json")
SAVE_PATH = os.path.join(BASE_DIR, "data", "best_model")
os.makedirs(SAVE_PATH, exist_ok=True)

HF_REPO_ID = "Estherrr777/Mila-Wellnest-Backend"  # Change this if needed


# === 1. Load and preprocess dataset ===
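# Expected record shape in pregnancy_dataset.json, inferred from the fields
# used below (example values are illustrative only):
#   {"Age": 25, "SystolicBP": 120, "DiastolicBP": 80, "BS": 7.5,
#    "BodyTemp": 98.0, "HeartRate": 76, "RiskLevel": "low risk"}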
def load_and_prepare_dataset():
    with open(DATA_PATH, "r") as f:
        data = json.load(f)

    label_map = {"low risk": 0, "mid risk": 1, "high risk": 2}

    def preprocess(example):
        label_str = str(example["RiskLevel"]).strip().lower()
        prompt = (
            f"A {example['Age']}-year-old pregnant individual has a systolic "
            f"blood pressure of {example['SystolicBP']}, diastolic of "
            f"{example['DiastolicBP']}, blood sugar of {example['BS']}, "
            f"body temperature of {example['BodyTemp']}, and heart rate of "
            f"{example['HeartRate']}. What is the pregnancy risk level?"
        )
        return {"text": prompt, "label": label_map[label_str]}

    raw_ds = Dataset.from_list(data)
    print(f"📦 Loaded dataset: {len(raw_ds)} total samples")
    return raw_ds.map(preprocess)

# === 2. Tokenization ===
def tokenize_dataset(dataset, tokenizer):
    print(f"🔠 Tokenizing {len(dataset)} samples...")
    tokenized = dataset.map(
        lambda x: tokenizer(x["text"], truncation=True, max_length=128),
        batched=True,
    )
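    # Only truncation is applied here; padding is deferred to
    # DataCollatorWithPadding, which pads each batch dynamically.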
print(f"✅ Tokenized: now has keys {tokenized.column_names}") | |
return tokenized | |

# === 3. Accuracy Metric ===
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=1)
    return {"accuracy": (preds == labels).mean()}

# === 4. Upload model files to Hugging Face ===
def upload_model_to_hf(token, save_path):
    api = HfApi()
    print("🚀 Uploading model files to Hugging Face Hub...")
    for fname in os.listdir(save_path):
        fpath = os.path.join(save_path, fname)
        # Skip subdirectories (e.g. intermediate checkpoint-* folders);
        # upload_file only accepts individual files.
        if not os.path.isfile(fpath):
            continue
        api.upload_file(
            token=token,
            path_or_fileobj=fpath,
            path_in_repo=f"data/best_model/{fname}",
            repo_id=HF_REPO_ID,
            repo_type="space",
        )
    print("✅ Upload complete.")

# === 5. Main training routine ===
def train():
    print("🔧 Loading and splitting dataset...")
    ds = load_and_prepare_dataset()

    # First split: 90% train/val, 10% test
    split1 = ds.train_test_split(test_size=0.1, seed=42)
    tv, test_ds = split1["train"], split1["test"]

    # Second split: from the 90%, take ~11.1% as val → ends up ~10% of full
    split2 = tv.train_test_split(test_size=0.1111, seed=42)
    train_ds, val_ds = split2["train"], split2["test"]

    print("📊 Dataset sizes:")
    print(f" ➤ Train: {len(train_ds)}")
    print(f" ➤ Val: {len(val_ds)}")
    print(f" ➤ Test: {len(test_ds)}")
    total = len(train_ds) + len(val_ds) + len(test_ds)
    print(f" ➤ Total: {total} (expected ≈ original JSON length)")

    tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")
    train_ds = tokenize_dataset(train_ds, tokenizer)
    val_ds = tokenize_dataset(val_ds, tokenizer)
    test_ds = tokenize_dataset(test_ds, tokenizer)

    model = AutoModelForSequenceClassification.from_pretrained(
        "google/mobilebert-uncased", num_labels=3
    )
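
    # LoRA setup: only the low-rank adapter matrices (rank r) injected into the
    # attention "query" and "value" projections are trained; lora_alpha scales
    # their contribution. For SEQ_CLS tasks PEFT also keeps the classification
    # head trainable, while the MobileBERT backbone stays frozen.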
    peft_config = LoraConfig(
        task_type=TaskType.SEQ_CLS,
        inference_mode=False,
        r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        target_modules=["query", "value"],
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    training_args = TrainingArguments(
        output_dir=SAVE_PATH,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        num_train_epochs=4,
        learning_rate=2e-5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        weight_decay=0.01,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        save_total_limit=1,
        report_to="none",  # Disable W&B / other trackers
    )
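
    # Note: load_best_model_at_end=True requires evaluation_strategy and
    # save_strategy to match (both "epoch" here), so the checkpoint with the
    # best validation accuracy is restored after training.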
print(f"📦 Batch size (train): {training_args.per_device_train_batch_size}") | |
print(f"📦 Batch size (eval): {training_args.per_device_eval_batch_size}") | |
trainer = Trainer( | |
model=model, | |
args=training_args, | |
train_dataset=train_ds, | |
eval_dataset=val_ds, | |
tokenizer=tokenizer, | |
data_collator=DataCollatorWithPadding(tokenizer), | |
compute_metrics=compute_metrics, | |
) | |
trainer.train() | |
print("🧪 Evaluating on test set...") | |
test_result = trainer.evaluate(test_ds) | |
print(f"✅ Final Test Accuracy: {test_result['eval_accuracy']:.4f}") | |
hf_token = os.getenv("R_HF_TOKEN") | |
if hf_token: | |
upload_model_to_hf(hf_token, SAVE_PATH) | |
else: | |
print("⚠️ R_HF_TOKEN not set — skipping Hugging Face upload.") | |
if __name__ == "__main__": | |
train() | |
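

# Untested sketch: reloading the fine-tuned adapter for inference. Assumes the
# adapter weights and config were saved to SAVE_PATH by trainer.save_model().
#
#   from peft import PeftModel
#   base = AutoModelForSequenceClassification.from_pretrained(
#       "google/mobilebert-uncased", num_labels=3
#   )
#   model = PeftModel.from_pretrained(base, SAVE_PATH)
#   model.eval()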

# Earlier prototype (labelled "Gemma model version", actually DistilBERT-based),
# kept for reference:
# import json
# import os
# from transformers import (
#     AutoTokenizer,
#     AutoModelForSequenceClassification,
#     TrainingArguments,
#     Trainer,
#     DataCollatorWithPadding,
# )
# import torch
# from datasets import Dataset
#
# # -------- Settings --------
# MODEL_NAME = "distilbert-base-uncased"
# DATA_PATH = "./backend/data/pregnancy_dataset.json"
# SAVE_PATH = "./data/best_model"
# os.makedirs(SAVE_PATH, exist_ok=True)
#
# # -------- Load and Preprocess Dataset --------
# def load_and_prepare_dataset():
#     with open(DATA_PATH, "r") as f:
#         data = json.load(f)
#
#     # Map risk levels to integer labels
#     label_map = {"low risk": 0, "medium risk": 1, "high risk": 2}
#
#     def preprocess(example):
#         prompt = (
#             f"Age: {example['Age']}, SystolicBP: {example['SystolicBP']}, "
#             f"DiastolicBP: {example['DiastolicBP']}, BS: {example['BS']}, "
#             f"BodyTemp: {example['BodyTemp']}, HeartRate: {example['HeartRate']}. "
#             f"Predict the Risk Level."
#         )
#         # Ensure consistent and safe label mapping
#         label_str = str(example.get("RiskLevel", "")).lower()
#         label = label_map.get(label_str, 0)
#         return {"text": prompt, "label": label}
#
#     dataset = Dataset.from_list(data)
#     return dataset.map(preprocess)
#
# # -------- Tokenization --------
# def tokenize_function(example, tokenizer):
#     tokens = tokenizer(
#         example["text"],
#         truncation=True,
#         padding=True,
#         max_length=64,
#     )
#     tokens["label"] = example["label"]
#     return tokens
#
# # -------- Main Training Function --------
# def train():
#     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
#     model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)
#     dataset = load_and_prepare_dataset()
#
#     # Tokenize dataset
#     tokenized_dataset = dataset.map(lambda x: tokenize_function(x, tokenizer), batched=False)
#
#     # Remove any non-tensor-compatible fields
#     tokenized_dataset = tokenized_dataset.remove_columns(
#         [col for col in tokenized_dataset.column_names if col not in ["input_ids", "attention_mask", "label"]]
#     )
#
#     # Optional sanity check
#     print("🔎 Sample tokenized example:", tokenized_dataset[0])
#
#     training_args = TrainingArguments(
#         output_dir=SAVE_PATH,
#         num_train_epochs=3,
#         per_device_train_batch_size=4,
#         save_steps=50,
#         logging_steps=10,
#         save_total_limit=1,
#         remove_unused_columns=False,
#         report_to="none",
#     )
#     data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
#
#     trainer = Trainer(
#         model=model,
#         args=training_args,
#         train_dataset=tokenized_dataset,
#         tokenizer=tokenizer,
#         data_collator=data_collator,
#     )
#     trainer.train()
#     trainer.save_model(SAVE_PATH)
#     tokenizer.save_pretrained(SAVE_PATH)
#     print("✅ Fine-tuned model saved!")
#
# # -------- Entry Point --------
# if __name__ == "__main__":
#     train()