# Page header captured together with this file (not part of the program):
#   CZerion's picture
#   Create train.py
#   846c17a verified
from transformers import TrainingArguments, Trainer
import torch
from datasets import DatasetDict
from crop_disease_monitor.datasets import load_disease_dataset
from crop_disease_monitor.transforms import build_transforms
from crop_disease_monitor.model.architecture import DiseaseModel
def compute_metrics(eval_pred):
    """Compute top-1 classification accuracy for one evaluation pass.

    Args:
        eval_pred: A ``(logits, labels)`` pair. ``logits`` is a NumPy-style
            array with the class scores on its last axis, or a tuple of
            model outputs whose first element is that array. ``labels`` is
            an array of class ids that can be compared element-wise with
            the argmax of ``logits``.

    Returns:
        A dict with a single ``"accuracy"`` entry: the fraction of samples
        whose highest-scoring class equals the label, as a built-in float.
    """
    logits, labels = eval_pred
    # Models that return more than the logits (e.g. hidden states) hand the
    # predictions over as a tuple; the logits are expected to come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = logits.argmax(-1)
    # Cast to a plain float so the metric logs/serializes cleanly.
    return {"accuracy": float((preds == labels).astype(float).mean())}
def main():
    """Fine-tune ``DiseaseModel`` on the crop disease dataset.

    Loads the train/validation splits, attaches on-the-fly image
    preprocessing to both, instantiates ``DiseaseModel`` for the
    ``google/vit-base-patch16-224-in21k`` backbone name, and runs a
    ``Trainer`` for 10 epochs with evaluation and checkpointing (under
    ``./outputs``) after every epoch.
    """
    import dataclasses

    # Load data
    train_ds, val_ds = load_disease_dataset()

    # Build transforms
    transforms = build_transforms()

    def preprocess(batch):
        """Transform a batch of images and expose them under the model's keys."""
        imgs = [transforms(img.convert("RGB")) for img in batch["image"]]
        return {"pixel_values": imgs, "labels": batch["label"]}

    # with_transform is lazy: preprocess runs each time rows are fetched.
    train_ds = train_ds.with_transform(preprocess)
    val_ds = val_ds.with_transform(preprocess)
    ds = DatasetDict({"train": train_ds, "validation": val_ds})

    # Instantiate model
    num_classes = len(train_ds.features["label"].names)
    model = DiseaseModel("google/vit-base-patch16-224-in21k", num_classes)

    # Newer transformers releases renamed ``evaluation_strategy`` to
    # ``eval_strategy``; use whichever keyword the installed version declares.
    known_fields = {f.name for f in dataclasses.fields(TrainingArguments)}
    if "eval_strategy" in known_fields:
        eval_key = "eval_strategy"
    else:
        eval_key = "evaluation_strategy"

    # Training arguments
    args = TrainingArguments(
        output_dir="./outputs",
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=10,
        save_strategy="epoch",
        learning_rate=3e-5,
        # Keep the raw "image" column: by default the Trainer drops every
        # column that is not a model input, which would remove "image"
        # before the lazy preprocess transform gets to read it.
        remove_unused_columns=False,
        **{eval_key: "epoch"},
    )

    # Trainer (no tokenizer / processor is needed for this image task)
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=ds["train"],
        eval_dataset=ds["validation"],
        compute_metrics=compute_metrics,
    )

    # Train
    trainer.train()
# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()