FoodVision / Recipe /prediction.py
rajatsingh0702's picture
foodvision
3a664f3
raw
history blame
2.8 kB
"""
Utility functions to make predictions.
"""
import torch
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
from typing import List, Tuple
from PIL import Image
# Set device: "cuda" when torch reports that CUDA is available, otherwise "cpu".
# This module-level value is used as the default `device` argument of pred_and_plot_image.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Predict on a target image with a target model
def pred_and_plot_image(
    model: torch.nn.Module,
    class_names: List[str],
    image_path: str,
    image_size: Tuple[int, int] = (288, 288),
    transform: torchvision.transforms = None,
    device: torch.device = device,
) -> None:
    """Predicts on a target image with a target model and plots the result.

    The image is shown with matplotlib, titled with the predicted class name
    and the probability assigned to that class.

    Args:
        model (torch.nn.Module): A trained (or untrained) PyTorch model to predict on an image.
        class_names (List[str]): A list of target classes to map predictions to.
        image_path (str): Filepath to target image to predict on.
        image_size (Tuple[int, int], optional): Size to resize the target image to when the
            default transform is used. Defaults to (288, 288).
        transform (torchvision.transforms, optional): Transform to perform on image. Defaults to
            None, which resizes to ``image_size``, converts to a tensor and applies ImageNet
            normalization. A custom transform receives the image exactly as PIL opened it.
        device (torch.device, optional): Target device to perform prediction on. Defaults to device.

    Returns:
        None. The function moves ``model`` to ``device``, puts it in eval mode and
        displays a matplotlib figure.
    """
    # Open image; copy() reads the pixel data so the file can be closed immediately
    with Image.open(image_path) as opened_image:
        img = opened_image.copy()

    # Create transformation for image (if one doesn't exist)
    if transform is not None:
        image_transform = transform
        model_input = img
    else:
        image_transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # The default Normalize has 3-channel statistics, so RGBA / greyscale /
        # palette images must be converted to RGB or normalization would fail.
        model_input = img.convert("RGB")

    ### Predict on image ###
    # Make sure the model is on the target device
    model.to(device)

    # Turn on model evaluation mode and inference mode
    model.eval()
    with torch.inference_mode():
        # Transform and add an extra dimension to image (model requires samples in
        # [batch_size, color_channels, height, width])
        transformed_image = image_transform(model_input).unsqueeze(dim=0)

        # Make a prediction on image with an extra dimension and send it to the target device
        target_image_pred = model(transformed_image.to(device))

    # Convert logits -> prediction probabilities (using torch.softmax() for multi-class classification)
    target_image_pred_probs = torch.softmax(target_image_pred, dim=1)

    # Convert prediction probabilities -> prediction label (plain Python int for list indexing)
    target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1)
    predicted_index = int(target_image_pred_label.item())
    predicted_prob = target_image_pred_probs.max().item()

    # Plot the image as opened, with predicted label and probability
    plt.figure()
    plt.imshow(img)
    plt.title(
        f"Pred: {class_names[predicted_index]} | Prob: {predicted_prob:.3f}"
    )
    plt.axis(False)
    plt.show()