Spaces:
Sleeping
Sleeping
File size: 4,265 Bytes
3a664f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
"""
Trains a PyTorch image classification model using device-agnostic code.
"""
import torch
import torchvision
from torch import nn
import data_setup
import engine
import model_builder
import utils
# ------------------------------------------ Hyperparameters ------------------------------------------------------------------#
# One entry per training phase, in order:
#   [feature extraction, fine-tuning part 1, fine-tuning part 2]
NUM_EPOCHS = [7, 4, 3]
LEARNING_RATE = [1e-3, 1e-4, 1e-5]
BATCH_SIZE = 32

# Train on a CUDA GPU when one is available, otherwise on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
# ------------------------------------------ DataLoaders ----------------------------------------------------------------------#
# Preprocessing that matches the pretrained EfficientNet-B2 weights; it is
# handed to data_setup below.
data_transform = torchvision.models.EfficientNet_B2_Weights.DEFAULT.transforms()
# data_setup.create_dataloaders yields the train loader, the test loader and
# the class names, in that order.
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
    batch_size=BATCH_SIZE,
    transform=data_transform,
)
print("dataloaders created")
# ------------------------------------------ Model ----------------------------------------------------------------------------#
# Build the network through model_builder, handing it the target device.
model = model_builder.model_build(
    device=device,
)
print("model created")
# ------------------------------------------ Feature Extraction ---------------------------------------------------------------#
# Freeze every pretrained parameter; only the new head created below trains.
for params in model.parameters():
    params.requires_grad = False
# Replace the classification head with one sized for this dataset's classes.
# in_features=1408 assumes model_builder returns torchvision's EfficientNet-B2
# (pooled feature width 1408) -- confirm against model_builder.py.
# Freshly constructed layers live on the CPU; model_build receives `device`
# and presumably places the backbone there, so move the head to match
# (on CUDA a CPU head fails with a device-mismatch error on the first batch).
model.classifier = nn.Sequential(
    nn.Dropout(p=0.3, inplace=True),
    nn.Linear(in_features=1408, out_features=len(class_names)),
).to(device)
# NOTE(review): if engine.train calls model.train(), the frozen backbone's
# BatchNorm running statistics are still updated during this phase
# (requires_grad=False does not stop that) -- confirm this is intended.
loss_fn = torch.nn.CrossEntropyLoss()
# Hand the optimizer only the parameters that can actually change (the head).
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad],
    lr=LEARNING_RATE[0],
)
# Train the head with help from engine.py.
feature_extraction_results = engine.train(model=model,
                                          train_dataloader=train_dataloader,
                                          test_dataloader=test_dataloader,
                                          loss_fn=loss_fn,
                                          optimizer=optimizer,
                                          epochs=NUM_EPOCHS[0],
                                          device=device)
print(feature_extraction_results)
# ------------------------------------------ Fine Tuning Part 1 ---------------------------------------------------------------#
def _pin_batchnorm_eval(bn: nn.Module, _args: tuple) -> None:
    """Forward pre-hook: run ``bn`` in inference mode for this forward pass.

    ``Module.train()`` on any ancestor sets ``bn.training`` back to True, so a
    one-off ``bn.eval()`` is undone by a training loop that calls
    ``model.train()``. Resetting the flag right before every forward keeps
    the layer frozen no matter what the loop does.
    """
    bn.training = False


def freeze_batchnorm(module: nn.Module) -> list:
    """Freeze every ``nn.BatchNorm2d`` inside ``module`` in place.

    Frozen layers stop updating their running mean/variance and always
    normalize with those stored statistics, in training and evaluation
    passes alike. Their affine weight/bias keep their current
    ``requires_grad`` setting.

    Args:
        module: Network whose BatchNorm2d submodules are frozen.

    Returns:
        The forward-pre-hook handles; call ``.remove()`` on each to unpin.
    """
    handles = []
    for m in module.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.track_running_stats = False  # never update the running statistics
            m.eval()  # normalize with the stored running statistics
            handles.append(m.register_forward_pre_hook(_pin_batchnorm_eval))
    return handles


# Unfreeze the upper feature blocks (features[5:]); features[:5] stay frozen.
for params in model.features[5:].parameters():
    params.requires_grad = True
# Freeze BatchNorm once; the hooks keep it frozen through Part 2 as well.
# NOTE(review): engine.train is not visible here; it presumably calls
# model.train() every epoch, which is why a plain m.eval() did not hold.
bn_freeze_handles = freeze_batchnorm(model)
# Fresh loss/optimizer with the Part 1 learning rate, trainable params only.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad],
    lr=LEARNING_RATE[1],
)
# Start training with help from engine.py.
fine_tuning_p1_results = engine.train(model=model,
                                      train_dataloader=train_dataloader,
                                      test_dataloader=test_dataloader,
                                      loss_fn=loss_fn,
                                      optimizer=optimizer,
                                      epochs=NUM_EPOCHS[1],
                                      device=device)
print(fine_tuning_p1_results)
# ------------------------------------------ Fine Tuning Part 2 ---------------------------------------------------------------#
# Unfreeze the entire feature extractor. BatchNorm layers remain frozen by
# the hooks registered before Part 1.
for params in model.features.parameters():
    params.requires_grad = True
# Fresh loss/optimizer with the even lower Part 2 learning rate.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad],
    lr=LEARNING_RATE[2],
)
# Start training with help from engine.py.
fine_tuning_p2_results = engine.train(model=model,
                                      train_dataloader=train_dataloader,
                                      test_dataloader=test_dataloader,
                                      loss_fn=loss_fn,
                                      optimizer=optimizer,
                                      epochs=NUM_EPOCHS[2],
                                      device=device)
print(fine_tuning_p2_results)
# Training is finished: drop the training-only hooks before the model is saved.
for handle in bn_freeze_handles:
    handle.remove()
# ------------------------------------------ Save model -----------------------------------------------------------------------#
# Persist the trained model into the "models" directory via utils.save_model.
utils.save_model(
    model=model,
    model_name="Image_Classification_EfficientNet_B2.pth",
    target_dir="models",
)
|