# Training script: fine-tune a ViT-based crop disease classifier with the
# Hugging Face Trainer.
import torch
from datasets import DatasetDict
from transformers import TrainingArguments, Trainer

from crop_disease_monitor.datasets import load_disease_dataset
from crop_disease_monitor.transforms import build_transforms
from crop_disease_monitor.model.architecture import DiseaseModel
def compute_metrics(eval_pred):
    """Compute top-1 accuracy for Trainer evaluation.

    Args:
        eval_pred: ``(predictions, label_ids)`` pair supplied by the Trainer.
            ``predictions`` may be a tuple when the model returns extra
            outputs (e.g. hidden states); the logits are the first element.

    Returns:
        dict: ``{"accuracy": float}`` — fraction of correctly classified
        samples.
    """
    logits, labels = eval_pred
    # Models that return (logits, aux, ...) yield a tuple here; unwrap it
    # so argmax operates on the logits array, not the tuple.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = logits.argmax(-1)
    # Cast to a plain Python float so the metric serializes cleanly in
    # the Trainer's JSON logs (numpy scalars can trip serialization).
    acc = float((preds == labels).astype(float).mean())
    return {"accuracy": acc}
def main():
    """Fine-tune ``DiseaseModel`` on the crop disease dataset.

    Loads the dataset, attaches on-the-fly image preprocessing, and runs
    the Hugging Face ``Trainer`` with per-epoch evaluation/checkpointing.
    """
    # Load data
    train_ds, val_ds = load_disease_dataset()

    # Number of classes comes from the dataset's label feature.
    num_classes = len(train_ds.features["label"].names)

    # Build transforms (PIL image -> normalized tensor), applied lazily
    # per batch via with_transform below.
    transforms = build_transforms()

    def preprocess(batch):
        # Force RGB: some images may be grayscale/RGBA - TODO confirm
        # against the dataset.
        imgs = [transforms(img.convert("RGB")) for img in batch["image"]]
        return {"pixel_values": imgs, "labels": batch["label"]}

    train_ds = train_ds.with_transform(preprocess)
    val_ds = val_ds.with_transform(preprocess)
    ds = DatasetDict({"train": train_ds, "validation": val_ds})

    # Instantiate model
    model = DiseaseModel("google/vit-base-patch16-224-in21k", num_classes)

    # Training arguments
    args = TrainingArguments(
        output_dir="./outputs",
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=10,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=3e-5,
        # BUGFIX: by default the Trainer removes dataset columns that are
        # not in the model's forward signature. That would strip the raw
        # "image" column *before* with_transform runs, breaking the lazy
        # preprocessing above. Keep all columns.
        remove_unused_columns=False,
    )

    # Trainer
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=ds["train"],
        eval_dataset=ds["validation"],
        compute_metrics=compute_metrics,
        tokenizer=None,  # not needed for image tasks
    )

    # Train
    trainer.train()
# Script entry point: run training only when executed directly, not on
# import. (Fixes the " | |" extraction artifacts that broke the syntax.)
if __name__ == "__main__":
    main()