# Hugging Face Spaces page header captured along with this file
# ("Spaces" / "Runtime error" status lines) — not part of the training script.
import json | |
import os | |
from transformers import ( | |
AutoTokenizer, | |
AutoModelForSequenceClassification, | |
TrainingArguments, | |
Trainer, | |
DataCollatorWithPadding, | |
) | |
import torch | |
from datasets import Dataset | |
# -------- Settings --------
MODEL_NAME = "distilbert-base-uncased"  # Hugging Face checkpoint to fine-tune
DATA_PATH = "./backend/data/pregnancy_dataset.json"  # training examples (JSON list of records)
SAVE_PATH = "./data/best_model"  # output dir for checkpoints and the final model
# NOTE(review): DATA_PATH lives under ./backend/data but SAVE_PATH under ./data —
# confirm the asymmetry is intentional.
os.makedirs(SAVE_PATH, exist_ok=True)  # create the output dir up front so saves never fail
# -------- Load and Preprocess Dataset --------
def load_and_prepare_dataset():
    """Load the pregnancy-risk JSON file and turn it into (text, label) rows.

    Returns a `datasets.Dataset` where each row carries:
      * "text"  - a natural-language prompt built from the vital-sign fields
      * "label" - an int in {0, 1, 2} for low / medium / high risk
    (the original columns are kept alongside by `Dataset.map`).
    """
    # Explicit encoding so the JSON decodes identically on every platform.
    with open(DATA_PATH, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Map risk levels to integer labels.
    label_map = {"low risk": 0, "medium risk": 1, "high risk": 2}

    def preprocess(example):
        prompt = (
            f"Age: {example['Age']}, SystolicBP: {example['SystolicBP']}, "
            f"DiastolicBP: {example['DiastolicBP']}, BS: {example['BS']}, "
            f"BodyTemp: {example['BodyTemp']}, HeartRate: {example['HeartRate']}. "
            f"Predict the Risk Level."
        )
        # Normalize the raw label: lowercase AND strip stray whitespace so
        # values like "High Risk " still hit the map instead of the default.
        label_str = str(example.get("RiskLevel", "")).lower().strip()
        # NOTE(review): unknown or missing labels silently fall back to 0
        # ("low risk") — confirm that default is intended for this dataset.
        label = label_map.get(label_str, 0)
        return {"text": prompt, "label": label}

    dataset = Dataset.from_list(data)
    return dataset.map(preprocess)
# -------- Tokenization --------
def tokenize_function(example, tokenizer):
    """Tokenize a single example's prompt and carry its label through.

    Args:
        example: dict with at least "text" (the prompt) and "label" (int).
        tokenizer: a Hugging Face-style tokenizer callable.

    Returns the tokenizer's encoding with the example's "label" attached.

    Padding is deliberately left to the DataCollatorWithPadding used at
    training time: `padding=True` here was a no-op for single examples
    (a lone sequence is only "padded" to its own length), and dynamic
    per-batch padding in the collator is both correct and cheaper.
    """
    tokens = tokenizer(
        example["text"],
        truncation=True,  # clip prompts longer than max_length
        max_length=64,    # prompts are short; 64 tokens is plenty
    )
    tokens["label"] = example["label"]
    return tokens
# -------- Main Training Function --------
def train():
    """Fine-tune the base model on the pregnancy-risk dataset and save it."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)

    raw_dataset = load_and_prepare_dataset()

    # Tokenize one example at a time.
    encoded = raw_dataset.map(lambda ex: tokenize_function(ex, tokenizer), batched=False)

    # Keep only the columns the model actually consumes.
    wanted = {"input_ids", "attention_mask", "label"}
    encoded = encoded.remove_columns(
        [col for col in encoded.column_names if col not in wanted]
    )

    # Optional sanity check
    print("π Sample tokenized example:", encoded[0])

    training_args = TrainingArguments(
        output_dir=SAVE_PATH,
        num_train_epochs=3,
        per_device_train_batch_size=4,
        save_steps=50,
        logging_steps=10,
        save_total_limit=1,
        remove_unused_columns=False,  # we pruned columns ourselves above
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=encoded,
        tokenizer=tokenizer,
        # Dynamic per-batch padding at collation time.
        data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
    )

    trainer.train()
    trainer.save_model(SAVE_PATH)
    tokenizer.save_pretrained(SAVE_PATH)
    print("β Fine-tuned model saved!")
# -------- Entry Point --------
if __name__ == "__main__":
    # Run training only when executed as a script, not on import.
    train()