CZerion committed on
Commit
846c17a
·
verified ·
1 Parent(s): 49a5910

Create train.py

Browse files
Files changed (1) hide show
  1. model/train.py +59 -0
model/train.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import TrainingArguments, Trainer
2
+ import torch
3
+ from datasets import DatasetDict
4
+ from crop_disease_monitor.datasets import load_disease_dataset
5
+ from crop_disease_monitor.transforms import build_transforms
6
+ from crop_disease_monitor.model.architecture import DiseaseModel
7
+
8
def compute_metrics(eval_pred):
    """Compute top-1 classification accuracy for an evaluation pass.

    Args:
        eval_pred: A ``(predictions, labels)`` pair. ``predictions`` is
            either an array of per-class scores or a tuple/list of arrays
            whose first element is assumed to hold those scores.
            ``labels`` is the array of integer class ids.

    Returns:
        dict: ``{"accuracy": value}`` where ``value`` is the fraction of
        samples whose arg-max prediction equals the label, as a Python
        ``float``.
    """
    logits, labels = eval_pred
    # A model that emits more than one output yields a tuple of arrays here;
    # calling ``.argmax`` on the tuple itself would raise AttributeError.
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    preds = logits.argmax(-1)
    accuracy = (preds == labels).astype(float).mean()
    # Return a plain float so the metric logs/serializes uniformly.
    return {"accuracy": float(accuracy)}
13
+
14
+
15
def main():
    """Fine-tune ``DiseaseModel`` on the crop disease dataset.

    Loads the train/validation splits, attaches an on-the-fly image
    transform to both, builds the model with one output per label name,
    and runs a ``Trainer`` that evaluates and checkpoints once per epoch
    under ``./outputs``.
    """
    import inspect

    # Load data
    train_ds, val_ds = load_disease_dataset()

    # Build transforms
    transforms = build_transforms()

    def preprocess(batch):
        """Convert each image to RGB and apply the transform pipeline."""
        imgs = [transforms(img.convert("RGB")) for img in batch["image"]]
        return {"pixel_values": imgs, "labels": batch["label"]}

    train_ds = train_ds.with_transform(preprocess)
    val_ds = val_ds.with_transform(preprocess)

    ds = DatasetDict({"train": train_ds, "validation": val_ds})

    # Instantiate model
    num_classes = len(train_ds.features["label"].names)
    model = DiseaseModel("google/vit-base-patch16-224-in21k", num_classes)

    # The per-epoch evaluation option is spelled ``eval_strategy`` in newer
    # transformers releases and ``evaluation_strategy`` in older ones; pick
    # whichever the installed TrainingArguments accepts.
    accepted = inspect.signature(TrainingArguments.__init__).parameters
    if "eval_strategy" in accepted:
        eval_key = "eval_strategy"
    else:
        eval_key = "evaluation_strategy"

    # Training arguments
    args = TrainingArguments(
        output_dir="./outputs",
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=10,
        save_strategy="epoch",
        learning_rate=3e-5,
        # ``preprocess`` runs lazily and reads the raw "image"/"label"
        # columns, so the Trainer must not drop columns that the model's
        # forward signature does not name.
        remove_unused_columns=False,
        **{eval_key: "epoch"},
    )

    # Trainer (no tokenizer/processor is passed: not needed for image tasks)
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=ds["train"],
        eval_dataset=ds["validation"],
        compute_metrics=compute_metrics,
    )

    # Train
    trainer.train()
57
+
58
# Run the training entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()