Spaces:
Running
Running
Create train.py
Browse files- model/train.py +59 -0
model/train.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import TrainingArguments, Trainer
|
2 |
+
import torch
|
3 |
+
from datasets import DatasetDict
|
4 |
+
from crop_disease_monitor.datasets import load_disease_dataset
|
5 |
+
from crop_disease_monitor.transforms import build_transforms
|
6 |
+
from crop_disease_monitor.model.architecture import DiseaseModel
|
7 |
+
|
8 |
+
def compute_metrics(eval_pred):
    """Compute top-1 accuracy for an evaluation pass.

    Args:
        eval_pred: A ``(predictions, labels)`` pair, as unpacked from the
            object ``Trainer`` passes to ``compute_metrics``. ``predictions``
            is either an array of per-class scores whose last axis is the
            class axis, or a tuple/list of model outputs whose first element
            is that array.

    Returns:
        dict: ``{"accuracy": <float>}`` -- the fraction of samples whose
        arg-max prediction equals the label, as a plain Python ``float``.
    """
    logits, labels = eval_pred
    # Models that return more than one output hand back a tuple; the class
    # scores are expected to be the first element.
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    preds = logits.argmax(-1)
    # Cast to a builtin float so the metric is a plain scalar for logging.
    acc = float((preds == labels).astype(float).mean())
    return {"accuracy": acc}
|
13 |
+
|
14 |
+
|
15 |
+
def main():
    """Fine-tune ``DiseaseModel`` on the crop-disease dataset with ``Trainer``.

    Loads the train/validation splits, attaches an on-the-fly image
    transform to both, builds the model with one output per label name,
    and runs training with per-epoch evaluation and checkpointing.
    Checkpoints are written under ``./outputs``.
    """
    # Load data
    train_ds, val_ds = load_disease_dataset()

    # Build transforms
    transforms = build_transforms()

    def preprocess(batch):
        """Convert a batch of images to RGB and apply the transforms."""
        imgs = [transforms(img.convert("RGB")) for img in batch["image"]]
        return {"pixel_values": imgs, "labels": batch["label"]}

    # with_transform is lazy: preprocess runs each time rows are accessed,
    # so the raw "image" column must still exist at that point.
    train_ds = train_ds.with_transform(preprocess)
    val_ds = val_ds.with_transform(preprocess)

    ds = DatasetDict({"train": train_ds, "validation": val_ds})

    # Instantiate model
    num_classes = len(train_ds.features["label"].names)
    model = DiseaseModel("google/vit-base-patch16-224-in21k", num_classes)

    # Training arguments
    # NOTE(review): newer transformers releases rename `evaluation_strategy`
    # to `eval_strategy`; confirm against the pinned version before upgrading.
    args = TrainingArguments(
        output_dir="./outputs",
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=10,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=3e-5,
        # Trainer otherwise drops dataset columns that are not arguments of
        # the model's forward(), which would remove "image" before
        # preprocess can read it.
        remove_unused_columns=False,
    )

    # Trainer (no tokenizer/processor is passed: not needed for image tasks)
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=ds["train"],
        eval_dataset=ds["validation"],
        compute_metrics=compute_metrics,
    )

    # Train
    trainer.train()
|
57 |
+
|
58 |
+
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|