# FoodVision / Recipe / engine.py
# rajatsingh0702 - foodvision - 3a664f3
"""
Contains functions for training and testing a PyTorch model.
"""
import torch
from tqdm.auto import tqdm
def train_step(model: torch.nn.Module,
               train_dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               device: torch.device):
    """Trains a PyTorch model for a single epoch.

    Puts the model in training mode and then, for every batch of the
    dataloader, runs the forward pass, loss calculation, backward pass
    and optimizer step.

    Args:
        model: A PyTorch model to be trained.
        train_dataloader: A DataLoader instance for the model to be trained on.
        loss_fn: A PyTorch loss function to minimize.
        optimizer: A PyTorch optimizer to help minimize the loss function.
        device: A target device to compute on (e.g. "cuda" or "cpu").

    Returns:
        A tuple of training loss and training accuracy metrics as plain
        Python floats, each the mean of the per-batch values.
        In the form (train_loss, train_accuracy). For example:

        (0.1112, 0.8743)

    Raises:
        ZeroDivisionError: If the dataloader has length zero.
    """
    # Make sure dropout / batch-norm layers behave as in training, even when
    # this function is called on its own.
    model.train()

    total_loss, total_acc = 0.0, 0.0
    for X, y in train_dataloader:
        X, y = X.to(device), y.to(device)

        # Forward pass and loss.
        y_logits = model(X)
        loss = loss_fn(y_logits, y)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() yields a Python number, so the running totals do not keep
        # each batch's autograd graph alive for the rest of the epoch.
        total_loss += loss.item()
        correct = (torch.argmax(y_logits, dim=1) == y).sum().item()
        total_acc += correct / len(y)

    num_batches = len(train_dataloader)
    return total_loss / num_batches, total_acc / num_batches
def test_step(model: torch.nn.Module,
              test_dataloader: torch.utils.data.DataLoader,
              loss_fn: torch.nn.Module,
              device: torch.device):
    """Tests a PyTorch model for a single epoch.

    Puts the model in "eval" mode and then performs a forward pass over
    every batch of the testing dataloader with gradient tracking disabled.

    Args:
        model: A PyTorch model to be tested.
        test_dataloader: A DataLoader instance for the model to be tested on.
        loss_fn: A PyTorch loss function to calculate loss on the test data.
        device: A target device to compute on (e.g. "cuda" or "cpu").

    Returns:
        A tuple of testing loss and testing accuracy metrics as plain
        Python floats, each the mean of the per-batch values.
        In the form (test_loss, test_accuracy). For example:

        (0.0223, 0.8985)

    Raises:
        ZeroDivisionError: If the dataloader has length zero.
    """
    # Make sure dropout / batch-norm layers behave as in inference, even when
    # this function is called on its own.
    model.eval()

    test_loss, test_acc = 0.0, 0.0
    with torch.inference_mode():
        for X, y in test_dataloader:
            X, y = X.to(device), y.to(device)
            y_logits = model(X)

            # Accumulate Python numbers so the result matches the documented
            # float return values rather than device tensors.
            test_loss += loss_fn(y_logits, y).item()
            correct = (torch.argmax(y_logits, dim=1) == y).sum().item()
            test_acc += correct / len(y)

    num_batches = len(test_dataloader)
    return test_loss / num_batches, test_acc / num_batches


# The name matches pytest's default ``test_*`` pattern; opt out of collection
# so importing this helper into a test module does not register it as a test.
test_step.__test__ = False
def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module,
          epochs: int,
          device: torch.device):
    """Trains and tests a PyTorch model.

    Moves the model to the target device, then passes it through
    train_step() and test_step() for a number of epochs, training and
    testing the model in the same epoch loop.

    Calculates, prints and stores evaluation metrics throughout.

    Args:
        model: A PyTorch model to be trained and tested.
        train_dataloader: A DataLoader instance for the model to be trained on.
        test_dataloader: A DataLoader instance for the model to be tested on.
        optimizer: A PyTorch optimizer to help minimize the loss function.
        loss_fn: A PyTorch loss function to calculate loss on both datasets.
        epochs: An integer indicating how many epochs to train for.
        device: A target device to compute on (e.g. "cuda" or "cpu").

    Returns:
        A dictionary of training and testing loss as well as training and
        testing accuracy metrics. Each metric has a float value in a list
        for each epoch.
        In the form: {"train_loss": [...],
                      "train_acc": [...],
                      "test_loss": [...],
                      "test_acc": [...]}
        For example if training for epochs=2:
                     {"train_loss": [2.0616, 1.0537],
                      "train_acc": [0.3945, 0.3945],
                      "test_loss": [1.2641, 1.5706],
                      "test_acc": [0.3400, 0.2973]}

    Remark: the commented-out lines below can be un-commented to also store
    the model weights after every epoch.
    """
    results = {"train_loss": [],
               "train_acc": [],
               "test_loss": [],
               "test_acc": []}

    # from pathlib import Path
    # save_path=Path("models")
    # save_path.mkdir(parents=True,exist_ok=True)

    model.to(device)

    for epoch in tqdm(range(epochs)):
        model.train()
        loss_train, acc_train = train_step(model,
                                           train_dataloader=train_dataloader,
                                           loss_fn=loss_fn,
                                           optimizer=optimizer,
                                           device=device)
        # float() accepts both 0-dim tensors and Python numbers. Storing
        # floats keeps the history free of device tensors and of any autograd
        # graph attached to them, which would otherwise stay alive for every
        # epoch recorded in `results`.
        loss_train, acc_train = float(loss_train), float(acc_train)
        results["train_loss"].append(loss_train)
        results["train_acc"].append(acc_train)

        model.eval()
        loss_test, acc_test = test_step(model,
                                        test_dataloader,
                                        loss_fn,
                                        device)
        loss_test, acc_test = float(loss_test), float(acc_test)
        results["test_loss"].append(loss_test)
        results["test_acc"].append(acc_test)

        print(
            f"epoch : {epoch + 1} | train_loss : {loss_train} | train_acc : {acc_train} | test_loss : {loss_test} | test_acc : {acc_test}")

        # model_path=save_path/f"model_efficient_b2_weights_{epoch}.pth"
        # torch.save(obj=model.state_dict(),f=model_path)

    return results