File size: 3,335 Bytes
df0d440
 
 
a7f0ae3
df0d440
 
 
 
 
 
 
a87ccd3
a7f0ae3
df0d440
a7f0ae3
ce989ee
 
a7f0ae3
 
 
df0d440
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce989ee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import time
import numpy as np
from huggingface_hub import hf_hub_download

# Disable tensorflow warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

from tensorflow import keras
from flask import Flask, jsonify, request

# When True, load the model from the local artifacts directory instead of
# downloading it from the Hugging Face Hub.
LOCAL = False

# Load the saved model into memory
if LOCAL:  # idiomatic truthiness check instead of "LOCAL is True"
    model = keras.models.load_model('artifacts/models/mnist_model.h5')
else:
    # Download the weights from the Hub (hf_hub_download reuses its local
    # cache on subsequent runs) and load them into a Keras model.
    REPO_ID = "1vash/mnist-model"
    FILENAME = "mnist_model.h5"
    model = keras.models.load_model(hf_hub_download(repo_id=REPO_ID, filename=FILENAME))

# Initialize the Flask application
app = Flask(__name__)


# API route for prediction
# API route for prediction
@app.route('/predict', methods=['POST'])
def predict():
    """
    Predicts the class label of an input image.

    Request format:
    {
        "image": [[pixel_values_gray]]
    }

    Response format:
    {
        "prediction": predicted_label,
        "ml-latency-ms": latency_in_milliseconds
            (Measures time only for ML operations preprocessing with predict)
    }

    Returns HTTP 400 with an "error" field when the request body is not
    JSON, lacks an "image" key, or carries pixel data that cannot be
    shaped into the model input (previously these raised and surfaced as
    an opaque HTTP 500).
    """
    start_time = time.time()

    # Get the image data from the request.  silent=True makes get_json()
    # return None for a missing/non-JSON body instead of raising.
    payload = request.get_json(silent=True)
    if payload is None or 'image' not in payload:
        return jsonify({'error': "Request body must be JSON with an 'image' key"}), 400

    # Preprocess the image; malformed pixel data (wrong element count,
    # non-numeric values) raises ValueError/TypeError inside numpy.
    try:
        processed_image = preprocess_image(payload['image'])
    except (ValueError, TypeError) as e:
        return jsonify({'error': f'Invalid image data: {e}'}), 400

    # Make a prediction, verbose=0 to disable progress bar in logs
    prediction = model.predict(processed_image, verbose=0)

    # Get the predicted class label (index of the highest score)
    predicted_label = np.argmax(prediction)

    # Calculate latency in milliseconds
    latency_ms = (time.time() - start_time) * 1000

    # np.argmax returns a numpy integer, which is not JSON-serializable;
    # cast to a builtin int before building the response.
    response = {'prediction': int(predicted_label),
                'ml-latency-ms': round(latency_ms, 4)}

    # dictionary is not a JSON: https://www.quora.com/What-is-the-difference-between-JSON-and-a-dictionary
    # flask.jsonify vs json.dumps https://sentry.io/answers/difference-between-json-dumps-and-flask-jsonify/
    # The flask.jsonify() function returns a Response object with Serializable JSON and content_type=application/json.
    return jsonify(response)


# Helper function to preprocess the image
def preprocess_image(image_data):
    """Convert raw grayscale pixel data into the model's input tensor.

    :param image_data: Raw image — nested (or flat) sequence of 784
        grayscale pixel values in the 0-255 range
    :return: float32 array of shape (1, 28, 28) with values in [0, 1]
    """
    # Lay the pixels out as a single-sample batch of 28x28.
    batch = np.asarray(image_data, dtype='float32').reshape(1, 28, 28)

    # Scale the 0-255 grayscale values down to the [0, 1] range.
    return batch / 255.0


# API route for health check
@app.route('/health', methods=['GET'])
def health():
    """
    Liveness probe confirming the application is up and serving requests.
    Responds with the plain-text body "OK".
    Demo Usage: "curl http://localhost:5000/health" or using alias "curl http://127.0.0.1:5000/health"
    """
    # Flask wraps a bare string into a 200 text/html response.
    status = 'OK'
    return status


# API route for version
@app.route('/version', methods=['GET'])
def version():
    """
    Reports the application version as a plain-text string.
    Demo Usage: "curl http://localhost:5000/version" or using alias "curl http://127.0.0.1:5000/version"
    """
    # Hard-coded version string; bump when the API contract changes.
    app_version = '1.0'
    return app_version


# Start the Flask application
if __name__ == '__main__':
    # Runs Flask's built-in development server with default settings
    # (host 127.0.0.1, port 5000, debug off). Use a production WSGI
    # server (e.g. gunicorn) for deployment.
    app.run()


##################
# Flask API usage patterns:
# 1. A thin wrapper over a third-party model API (e.g. OpenAI)
# 2. Chained calls to a third-party model API
# 3. Serving your own ML model (as this file does), optionally combined with third-party API functionality
# 4. ...
##################