from transformers import TrainingArguments, Trainer
import torch
from datasets import DatasetDict
from crop_disease_monitor.datasets import load_disease_dataset
from crop_disease_monitor.transforms import build_transforms
from crop_disease_monitor.model.architecture import DiseaseModel


def compute_metrics(eval_pred):
    """Top-1 accuracy from the Trainer's (logits, labels) eval tuple."""
    logits, labels = eval_pred
    preds = logits.argmax(-1)
    acc = (preds == labels).astype(float).mean()
    return {"accuracy": acc}
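
# Quick sanity check for compute_metrics (a sketch with dummy numpy inputs):
#   import numpy as np
#   logits = np.array([[0.1, 0.9], [0.8, 0.2]])
#   labels = np.array([1, 0])
#   compute_metrics((logits, labels))  # -> {"accuracy": 1.0}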


def main():
    # Load data
    train_ds, val_ds = load_disease_dataset()

    # Build transforms (assumed to return torch tensors suitable for ViT)
    transforms = build_transforms()

    def preprocess(batch):
        # Applied lazily by with_transform on each accessed batch
        imgs = [transforms(img.convert("RGB")) for img in batch["image"]]
        return {"pixel_values": imgs, "labels": batch["label"]}

    train_ds = train_ds.with_transform(preprocess)
    val_ds = val_ds.with_transform(preprocess)
    ds = DatasetDict({"train": train_ds, "validation": val_ds})
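
    # Optional, a minimal sketch: an explicit collator that stacks image
    # tensors into a batch. The default data collator typically handles this
    # when the transforms return torch tensors, so this is only needed for
    # custom batching; pass data_collator=collate_fn to Trainer to use it.
    def collate_fn(batch):
        pixel_values = torch.stack([item["pixel_values"] for item in batch])
        labels = torch.tensor([item["labels"] for item in batch])
        return {"pixel_values": pixel_values, "labels": labels}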

    # Instantiate model
    num_classes = len(train_ds.features["label"].names)
    model = DiseaseModel("google/vit-base-patch16-224-in21k", num_classes)
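    # DiseaseModel is project-local; assumed to wrap the ViT backbone named
    # above with a classification head sized to num_classes.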

    # Training arguments
    args = TrainingArguments(
        output_dir="./outputs",
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=10,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=3e-5,
        # Keep the raw "image" column so the with_transform preprocessing
        # still sees it; Trainer would otherwise drop it as unused.
        remove_unused_columns=False,
    )

    # Trainer
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=ds["train"],
        eval_dataset=ds["validation"],
        compute_metrics=compute_metrics,
        tokenizer=None,  # not needed for image tasks
    )

    # Train
    trainer.train()
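
    # Optional follow-ups (a sketch, not in the original flow):
    # trainer.save_model("./outputs/final")  # persist the final weights
    # print(trainer.evaluate())              # final metrics on the validation set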


if __name__ == "__main__":
    main()