Chanlefe committed on
Commit
c66d381
·
verified ·
1 Parent(s): 21bab57

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +40 -0
train.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoProcessor, AutoModelForImageClassification, TrainingArguments, Trainer
2
+ from datasets import load_dataset
3
+ import torch
4
+
5
# Load the training split from the local 'dataset' folder.
# The "imagefolder" builder derives the class column (named "label") from
# the sub-directory names by itself. It has no `label_column` option, and
# unrecognised keyword arguments are forwarded to the builder config, which
# rejects them, so none is passed here.
dataset = load_dataset("imagefolder", data_dir="dataset", split="train")
7
+
8
# Load the processor and the classification model from one checkpoint.
CHECKPOINT = "google/siglip2-base-patch16-naflex"
processor = AutoProcessor.from_pretrained(CHECKPOINT)
# num_labels=2 requests a two-class classification head.
model = AutoModelForImageClassification.from_pretrained(CHECKPOINT, num_labels=2)
11
+
12
# Preprocess the dataset
def transform(example):
    """Run the image processor over a batch and carry the labels through.

    This is applied with ``dataset.map(..., batched=True)``, so ``example``
    is a batch: a mapping from column name to a list of values.

    Args:
        example: Batch holding an "image" column (images decoded by the
            dataset loader; presumably PIL images, since ``.convert`` is
            called on them) and a "label" column.

    Returns:
        The processor output with an added "label" entry copied from the
        input batch.
    """
    # Normalise every image to 3-channel RGB before processing: grayscale,
    # palette or RGBA files would otherwise reach the processor with a
    # different channel layout than the model expects.
    images = [image.convert("RGB") for image in example["image"]]
    inputs = processor(images=images, return_tensors="pt")
    inputs["label"] = example["label"]
    return inputs
17
+
18
# Apply the preprocessing function to the dataset, one batch per call.
dataset = dataset.map(function=transform, batched=True)
19
+
20
# Training setup: numeric hyperparameters are grouped, paths are spelled out.
hyperparameters = {
    "per_device_train_batch_size": 8,
    "num_train_epochs": 3,
    "save_steps": 100,
}

training_args = TrainingArguments(
    output_dir="./siglip2-meme-classifier",
    logging_dir="./logs",
    **hyperparameters,
)

# Only a training set is supplied; no evaluation dataset is configured.
trainer = Trainer(
    train_dataset=dataset,
    args=training_args,
    model=model,
)
34
+
35
# Run fine-tuning
trainer.train()

# Write the fine-tuned model and its processor into the same "model" directory
for artifact in (model, processor):
    artifact.save_pretrained("model")