# Author: Ivan Shelonik
# Commit: df0d440 ("first commit"), file size 2.27 kB
import os
import time
import numpy as np
import requests
import matplotlib.pyplot as plt
# Suppress TensorFlow C++ log messages; this must be set before keras (and thus TensorFlow) is imported
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from keras.datasets import mnist
from typing import List
# Fix NumPy's global RNG seed so the same test images are sampled on every run.
np.random.seed(seed=50)
# How many test-set images are sent to the prediction API and plotted.
N_IMAGES: int = 9
def get_image_prediction(image: List, *, url: str = 'http://127.0.0.1:5000/predict',
                         timeout: float = 30.0) -> dict:
    """Request a model prediction for a single image from the prediction API.

    :param image: Grayscale image as a nested list of pixel values
        (this script passes ``ndarray.tolist()`` of an MNIST test image).
    :param url: Endpoint of the prediction API. Defaults to the local
        server address this script was written against.
    :param timeout: Seconds to wait for the server before giving up, so an
        unreachable or hung server cannot block the script forever.
    :return: Parsed JSON body of the HTTP response, expected in the format::

            {
                "prediction": predicted_label,
                "ml-latency-ms": latency_in_milliseconds
            }

        ``ml-latency-ms`` measures only the ML operations
        (preprocessing plus predict), not the network round trip.
    :raises requests.HTTPError: if the server answers with a 4xx/5xx status.
    :raises requests.Timeout: if no response arrives within ``timeout`` seconds.
    """
    response = requests.post(url=url, json={'image': image}, timeout=timeout)
    # Fail with a clear HTTP error instead of trying to parse an error page as JSON.
    response.raise_for_status()
    return response.json()
# Load the MNIST dataset via keras.datasets; only the test split is used below.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Pick N_IMAGES distinct random indices into the test set.
indices = np.random.choice(len(x_test), N_IMAGES, replace=False)
# Gather the selected images and their ground-truth labels.
images, labels = x_test[indices], y_test[indices]
predictions = []
# Query the prediction API for each image and collect the predicted labels.
for image in images:
    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # elapsed time; time.time() follows the wall clock and can jump.
    start_time = time.perf_counter()
    model_response = get_image_prediction(image.tolist())
    # Stop the clock before printing so console I/O is not counted as request time.
    elapsed_ms = (time.perf_counter() - start_time) * 1000
    print('Model Response:', model_response)
    print('Total Measured Time (ms):', round(elapsed_ms, 3))
    predictions.append(model_response['prediction'])
def plot_images_and_results_plot(images_, labels_, predictions_):
    """Plot each image in its own row, titled with its true label and prediction.

    :param images_: Sequence of grayscale images (anything ``imshow`` accepts).
    :param labels_: Ground-truth labels, one per image.
    :param predictions_: Predicted labels, one per image.
    :raises ValueError: if the three sequences differ in length.
    """
    n_images = len(images_)
    # Derive the row count from the data instead of a global, and fail clearly
    # on mismatched inputs rather than with an IndexError part-way through.
    if not len(labels_) == len(predictions_) == n_images:
        raise ValueError(
            "images_, labels_ and predictions_ must have the same length, got "
            "{}, {} and {}".format(n_images, len(labels_), len(predictions_)))
    # squeeze=False keeps ``axes`` a 2-D array even for a single image, so
    # indexing is uniform; column 0 holds the single axis of each row.
    _, axes = plt.subplots(n_images, 1, figsize=(6, 10), squeeze=False)
    for ax, image, label, prediction in zip(axes[:, 0], images_, labels_, predictions_):
        ax.imshow(image, cmap='gray')
        ax.axis('off')
        ax.set_title("Label/Prediction: {}/{}".format(label, prediction))
    plt.tight_layout()
    plt.show()
# Show every sampled image alongside its true label and the model's prediction.
plot_images_and_results_plot(images_=images, labels_=labels, predictions_=predictions)