CZerion commited on
Commit
873e161
·
verified ·
1 Parent(s): cf62d40

Update transforms.py

Browse files
Files changed (1) hide show
  1. transforms.py +10 -9
transforms.py CHANGED
@@ -1,12 +1,13 @@
1
  from transformers import AutoImageProcessor
2
  import torchvision.transforms as T
3
 
4
- processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
5
- augment = T.Compose([
6
- T.RandomResizedCrop(224),
7
- T.RandomHorizontalFlip(),
8
- T.ColorJitter(0.2,0.2,0.2,0.1),
9
- T.RandomRotation(20),
10
- T.ToTensor(),
11
- T.Normalize(mean=processor.image_mean, std=processor.image_std)
12
- ])
 
 
1
  from transformers import AutoImageProcessor
2
  import torchvision.transforms as T
3
 
4
+ def build_transforms(backbone_model="google/vit-base-patch16-224-in21k"):
5
+ processor = AutoImageProcessor.from_pretrained(backbone_model)
6
+ return T.Compose([
7
+ T.RandomResizedCrop(224),
8
+ T.RandomHorizontalFlip(),
9
+ T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
10
+ T.RandomRotation(20),
11
+ T.ToTensor(),
12
+ T.Normalize(mean=processor.image_mean, std=processor.image_std)
13
+ ])