# Author: Ivan Shelonik
# Commit: 543f7b1 "upd: load type"
# File size: 4.33 kB
import os
import time
import numpy as np
from pathlib import Path
# Point the transformers cache at ./artifacts/ (as an absolute path).
# NOTE(review): this only has an effect if it is set before transformers /
# huggingface_hub are imported, so keep it above those imports; newer releases
# may prefer HF_HOME — confirm for the installed version.
os.environ['TRANSFORMERS_CACHE'] = str(Path('./artifacts/').absolute())
# Disable tensorflow warnings.
# Must be set before tensorflow is imported (next line) to take effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from tensorflow import keras
from flask import Flask, jsonify, request
# Selects how the model is obtained at start-up:
#   'local'                      - load artifacts/models/mnist_model.h5 from disk.
#   'remote_hub_download'        - download saved_model.pb from the Hub; reported to
#                                  hit a /cache error even with TRANSFORMERS_CACHE and
#                                  cache_dir pointed at a local folder.
#   'remote_hub_from_pretrained' - huggingface_hub.from_pretrained_keras; reported to
#                                  hit the same /cache error as above.
#   'remote_hub_pipeline'        - transformers pipeline; needs a config.json, which is
#                                  not straightforward to produce for a custom model:
#                                  https://discuss.huggingface.co/t/how-to-create-a-config-json-after-saving-a-model/10459/4
load_type = 'remote_hub_from_pretrained'
REPO_ID = "1vash/mnist_demo_model"
# Load the saved model into memory
if load_type == 'local':
    model = keras.models.load_model('artifacts/models/mnist_model.h5')
elif load_type == 'remote_hub_download':
    from huggingface_hub import hf_hub_download
    # NOTE(review): keras.models.load_model usually expects a SavedModel directory
    # or an .h5/.keras file; a lone saved_model.pb may not load — confirm.
    model = keras.models.load_model(hf_hub_download(repo_id=REPO_ID, filename="saved_model.pb"))
elif load_type == 'remote_hub_from_pretrained':
    # https://huggingface.co/docs/hub/keras
    from huggingface_hub import from_pretrained_keras
    model = from_pretrained_keras(REPO_ID, cache_dir='./artifacts/')
elif load_type == 'remote_hub_pipeline':
    from transformers import pipeline
    # NOTE(review): this mode defines `classifier`, not `model`, so /predict
    # cannot serve requests when it is selected.
    classifier = pipeline("image-classification", model=REPO_ID)
else:
    # Fail fast at start-up instead of leaving `model` undefined and
    # crashing with NameError on the first /predict request.
    raise ValueError(f"Unknown load_type: {load_type!r}")
# Create the Flask application object that the route decorators below attach to
app = Flask(import_name=__name__)
# API route for prediction
@app.route('/predict', methods=['POST'])
def predict():
    """
    Predicts the class label of an input image.
    Request format:
    {
    "image": [[pixel_values_gray]]
    }
    Response format:
    {
    "prediction": predicted_label,
    "ml-latency-ms": latency_in_milliseconds
    (Measures time only for ML operations preprocessing with predict)
    }
    Malformed requests get HTTP 400 with {"error": message}. This covers a body
    that is not a JSON object, a missing "image" field, and pixel data that
    preprocess_image cannot convert.
    """
    # silent=True makes get_json return None (instead of raising) for a
    # missing or unparsable JSON body, so we can answer with a clean 400.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or 'image' not in payload:
        return jsonify({'error': "request body must be a JSON object with an 'image' field"}), 400
    image_data = payload['image']

    # Time only the ML work (preprocessing + inference), as documented above.
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    try:
        processed_image = preprocess_image(image_data)
    except (ValueError, TypeError) as err:
        # Wrong element count (reshape) or non-numeric values (astype).
        return jsonify({'error': f'invalid image data: {err}'}), 400
    # verbose=0 disables the Keras progress bar in the logs
    prediction = model.predict(processed_image, verbose=0)
    # Index of the highest score; the batch holds a single image
    predicted_label = np.argmax(prediction)
    latency_ms = (time.perf_counter() - start_time) * 1000

    # jsonify builds a Response with content_type=application/json;
    # int() converts the NumPy integer, which is not JSON-serializable.
    response = {'prediction': int(predicted_label),
                'ml-latency-ms': round(latency_ms, 4)}
    return jsonify(response)
# Helper that turns raw request pixels into a model-ready batch
def preprocess_image(image_data):
    """Convert raw pixel values into a normalized single-image batch.

    :param image_data: Pixel values (nested lists or anything ``np.array``
        accepts) containing exactly 28 * 28 elements.
    :return: float32 array of shape (1, 28, 28) with every value divided by 255.
    :raises ValueError: if the data cannot be reshaped to (1, 28, 28).
    """
    # This is a reshape, not a resize: the element count must already be 784.
    batch = np.array(image_data).reshape(1, 28, 28)
    pixels = batch.astype('float32')
    # Scale 0..255 intensities down to 0..1
    return pixels / 255.0
# Liveness-check endpoint
@app.route('/health', methods=['GET'])
def health():
    """
    Report that the application is up and serving requests.

    Always responds with the plain-text body "OK".
    Demo Usage: "curl http://localhost:5000/health" or using alias "curl http://127.0.0.1:5000/health"
    """
    status_text = 'OK'
    return status_text
# Application-version endpoint
@app.route('/version', methods=['GET'])
def version():
    """
    Report the application version as a plain-text string.

    Demo Usage: "curl http://localhost:5000/version" or using alias "curl http://127.0.0.1:5000/version"
    """
    app_version = '1.0'
    return app_version
@app.route("/")
def hello_world():
    """Serve a minimal HTML greeting at the site root."""
    greeting_html = "<p>Hello, Team!</p>"
    return greeting_html
# Start the Flask application.
# Runs only when this file is executed directly (not when imported); app.run()
# with no arguments starts Flask's built-in development server on its
# default host and port.
if __name__ == '__main__':
    app.run()
##################
# Flask API usages:
# 1. A thin wrapper over the OpenAI API
# 2. Chaining several OpenAI API calls
# 3. Combining your own ML model with OpenAI API functionality
# 4. ...
##################