# Page header captured together with this script ("Spaces: Runtime error");
# it is not part of the program.
import os
import time
from typing import List

import numpy as np
import requests
import matplotlib.pyplot as plt

# Suppress TensorFlow's native (C++) log output. This is set before the keras
# import so it is already in the environment when the backend initializes.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from keras.datasets import mnist  # noqa: E402  (must follow the env setup above)

# Seed NumPy's global RNG so the same test images are sampled on every run.
np.random.seed(50)

# Number of images sampled from the MNIST test split and sent for prediction.
N_IMAGES = 9
def get_image_prediction(
    image: List,
    url: str = 'http://127.0.0.1:5000/predict',
    timeout: float = 30.0,
):
    """Send one grayscale image to the prediction API and return its JSON reply.

    :param image: Grayscale image as a (nested) Python list, e.g. the result
        of ``ndarray.tolist()``; sent as ``{"image": image}``.
    :param url: Prediction endpoint. Defaults to the local Flask server.
    :param timeout: Seconds to wait for the server. Without a timeout,
        ``requests`` may block indefinitely if the server never answers.
    :return: The parsed JSON body. The expected response format is::

            {
                "prediction": predicted_label,
                "ml-latency-ms": latency_in_milliseconds
            }

        ``ml-latency-ms`` measures only the server-side ML operations
        (preprocessing plus predict).
    :raises requests.HTTPError: if the server replies with a 4xx/5xx status.
    :raises requests.Timeout: if no response arrives within ``timeout``.
    """
    response = requests.post(url=url, json={'image': image}, timeout=timeout)
    # Fail loudly on HTTP errors. Otherwise an error payload would be returned
    # and only surface later as a confusing KeyError on 'prediction'.
    response.raise_for_status()
    return response.json()
# Load MNIST via keras.datasets; only the test split is used below.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Pick N_IMAGES distinct random indices from the test split.
indices = np.random.choice(len(x_test), N_IMAGES, replace=False)

# Gather the sampled images and their true labels; predictions are filled in below.
images, labels, predictions = x_test[indices], y_test[indices], []

# Query the prediction API once per image and record each predicted label.
for image in images:
    # perf_counter is monotonic, so the round-trip time is not skewed by
    # wall-clock adjustments the way time.time() can be. The stop time is taken
    # right after the request, so console printing is not counted.
    start_time = time.perf_counter()
    model_response = get_image_prediction(image.tolist())
    elapsed_ms = (time.perf_counter() - start_time) * 1000
    print('Model Response:', model_response)
    print('Total Measured Time (ms):', round(elapsed_ms, 3))
    predictions.append(model_response['prediction'])
def plot_images_and_results_plot(images_, labels_, predictions_):
    """Plot each image in its own row, titled with its true label and prediction.

    :param images_: Sequence of images to draw in grayscale (presumably 2-D
        MNIST digits here; anything ``imshow`` accepts works).
    :param labels_: Ground-truth labels, aligned index-by-index with ``images_``.
    :param predictions_: Predicted labels, aligned index-by-index with ``images_``.
    """
    n_rows = len(images_)
    if n_rows == 0:
        # plt.subplots rejects a zero-row grid, and there is nothing to draw.
        return
    # squeeze=False always returns a 2-D array of Axes. Without it, a single
    # row comes back as a bare Axes object and indexing it would fail.
    _, axes = plt.subplots(n_rows, 1, figsize=(6, 10), squeeze=False)
    for ax, image, label, prediction in zip(axes[:, 0], images_, labels_, predictions_):
        ax.imshow(image, cmap='gray')
        ax.axis('off')
        ax.set_title("Label/Prediction: {}/{}".format(label, prediction))
    plt.tight_layout()
    plt.show()
# Show every sampled image with its true label and the API's prediction.
plot_images_and_results_plot(images, labels, predictions)