Update vit_model_test.py
Browse files- vit_model_test.py +2 -1
vit_model_test.py
CHANGED
|
@@ -29,13 +29,14 @@ if __name__ == "__main__":
|
|
| 29 |
# Check for GPU availability
|
| 30 |
device = torch.device('cuda')
|
| 31 |
|
|
|
|
| 32 |
# Load the pre-trained ViT model and move it to GPU
|
| 33 |
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)
|
| 34 |
|
| 35 |
|
| 36 |
|
| 37 |
model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)
|
| 38 |
-
#
|
| 39 |
preprocess = transforms.Compose([
|
| 40 |
transforms.Resize((224, 224)),
|
| 41 |
transforms.ToTensor()
|
|
|
|
| 29 |
# Check for GPU availability
|
| 30 |
device = torch.device('cuda')
|
| 31 |
|
| 32 |
+
#this code runs only with nvidia gpu
|
| 33 |
# Load the pre-trained ViT model and move it to GPU
|
| 34 |
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)
|
| 35 |
|
| 36 |
|
| 37 |
|
| 38 |
model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)
|
| 39 |
+
# resize image and make it a tensor (add dimension)
|
| 40 |
preprocess = transforms.Compose([
|
| 41 |
transforms.Resize((224, 224)),
|
| 42 |
transforms.ToTensor()
|