Spaces:
Sleeping
Sleeping
""" | |
Utility functions to make predictions. | |
""" | |
import torch | |
import torchvision | |
from torchvision import transforms | |
import matplotlib.pyplot as plt | |
from typing import List, Tuple | |
from PIL import Image | |
# Set device | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
# Predict on a target image with a target model | |
def pred_and_plot_image( | |
model: torch.nn.Module, | |
class_names: List[str], | |
image_path: str, | |
image_size: Tuple[int, int] = (288, 288), | |
transform: torchvision.transforms = None, | |
device: torch.device = device): | |
"""Predicts on a target image with a target model. | |
Args: | |
model (torch.nn.Module): A trained (or untrained) PyTorch model to predict on an image. | |
class_names (List[str]): A list of target classes to map predictions to. | |
image_path (str): Filepath to target image to predict on. | |
image_size (Tuple[int, int], optional): Size to transform target image to. Defaults to (224, 224). | |
transform (torchvision.transforms, optional): Transform to perform on image. Defaults to None which uses ImageNet normalization. | |
device (torch.device, optional): Target device to perform prediction on. Defaults to device. | |
""" | |
# Open image | |
img = Image.open(image_path) | |
# Create transformation for image (if one doesn't exist) | |
if transform is not None: | |
image_transform = transform | |
else: | |
image_transform = transforms.Compose([ | |
transforms.Resize(image_size), | |
transforms.ToTensor(), | |
transforms.Normalize( | |
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |
]) | |
### Predict on image ### | |
# Make sure the model is on the target device | |
model.to(device) | |
# Turn on model evaluation mode and inference mode | |
model.eval() | |
with torch.inference_mode(): | |
# Transform and add an extra dimension to image (model requires samples in [batch_size, color_channels, | |
# height, width]) | |
transformed_image = image_transform(img).unsqueeze(dim=0) | |
# Make a prediction on image with an extra dimension and send it to the target device | |
target_image_pred = model(transformed_image.to(device)) | |
# Convert logits -> prediction probabilities (using torch.softmax() for multi-class classification) | |
target_image_pred_probs = torch.softmax(target_image_pred, dim=1) | |
# Convert prediction probabilities -> prediction labels | |
target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1) | |
# Plot image with predicted label and probability | |
plt.figure() | |
plt.imshow(img) | |
plt.title( | |
f"Pred: {class_names[target_image_pred_label]} | Prob: {target_image_pred_probs.max():.3f}" | |
) | |
plt.axis(False) | |
plt.show() | |