"""Fine-tune DiseaseModel on the crop disease dataset with the Hugging Face Trainer."""

from datasets import DatasetDict
from transformers import Trainer, TrainingArguments

from crop_disease_monitor.datasets import load_disease_dataset
from crop_disease_monitor.transforms import build_transforms
from crop_disease_monitor.model.architecture import DiseaseModel


def compute_metrics(eval_pred):
    """Compute top-1 accuracy from the Trainer's (logits, labels) eval tuple."""
    logits, labels = eval_pred
    preds = logits.argmax(-1)
    acc = float((preds == labels).astype(float).mean())
    return {"accuracy": acc}


def main():
    # Load the train/validation splits
    train_ds, val_ds = load_disease_dataset()

    # Build the image transforms applied on the fly at batch time
    transforms = build_transforms()

    def preprocess(batch):
        imgs = [transforms(img.convert("RGB")) for img in batch["image"]]
        return {"pixel_values": imgs, "labels": batch["label"]}

    # Read the class names before attaching the transform, then apply it lazily
    num_classes = len(train_ds.features["label"].names)
    train_ds = train_ds.with_transform(preprocess)
    val_ds = val_ds.with_transform(preprocess)
    ds = DatasetDict({"train": train_ds, "validation": val_ds})

    # Instantiate the ViT-backed classifier
    model = DiseaseModel("google/vit-base-patch16-224-in21k", num_classes)

    # Training arguments
    args = TrainingArguments(
        output_dir="./outputs",
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=10,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=3e-5,
        # Keep the raw "image" column: Trainer would otherwise drop it
        # before the with_transform preprocessing can see it.
        remove_unused_columns=False,
    )

    # Trainer
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=ds["train"],
        eval_dataset=ds["validation"],
        compute_metrics=compute_metrics,
        tokenizer=None,  # not needed for image tasks
    )

    # Train
    trainer.train()


if __name__ == "__main__":
    main()