# flask-docker / api_server.py
# Last commit: "Update api_server.py" (105430f, verified) by ning8429
# File size: 5.42 kB
import os
import time
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from pathlib import Path
# Disable tensorflow warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from tensorflow import keras
from flask import Flask, jsonify, request, render_template
import torch
# How the model is obtained: 'local', 'remote_hub_download' or
# 'remote_hub_from_pretrained'.
load_type = 'local'
MODEL_NAME = "yolo11_detect_best_241018_1.pt"
MODEL_DIR = "./artifacts/models"
# Hugging Face Hub repository ("<user>/<repo>") that hosts MODEL_NAME.
# It must be set before either remote load type can be used; it was previously
# only present as a commented-out line, so the remote branches hit a NameError.
# REPO_ID = "1vash/mnist_demo_model"
REPO_ID = None


def _load_torch_model(path):
    """Deserialize the model stored at *path* and put it in inference mode.

    :param path: Filesystem path of the serialized model file.
    :return: The loaded object, after ``eval()`` has been called on it.
    """
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only load model files that come from a trusted source.
    # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
    # host and keeps the model on the same device as the CPU tensors produced
    # by preprocess_image().
    # NOTE(review): on PyTorch releases where torch.load defaults to
    # weights_only=True, a fully pickled model needs weights_only=False --
    # confirm against the pinned torch version.
    # NOTE(review): this assumes the file holds a module-like object with an
    # eval() method rather than a checkpoint dict -- confirm for this model.
    loaded = torch.load(path, map_location='cpu')
    loaded.eval()  # inference mode
    return loaded


def _require_repo_id():
    """Return REPO_ID, failing with a clear message when it is not configured."""
    if not REPO_ID:
        raise ValueError(
            f"REPO_ID must be set to use load_type={load_type!r}"
        )
    return REPO_ID


# Load the saved YOLO model into memory
if load_type == 'local':
    # Path of the model on the local filesystem
    model_path = f'{MODEL_DIR}/{MODEL_NAME}'
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found at {model_path}")
    model = _load_torch_model(model_path)
elif load_type == 'remote_hub_download':
    from huggingface_hub import hf_hub_download

    # Download the model file from the Hugging Face Hub
    model_path = hf_hub_download(
        repo_id=_require_repo_id(), filename=MODEL_NAME
    )
    model = _load_torch_model(model_path)
elif load_type == 'remote_hub_from_pretrained':
    # Download from the Hugging Face Hub, caching the file under MODEL_DIR.
    # huggingface_hub exposes no top-level ``from_pretrained`` function, so
    # the file is fetched with hf_hub_download and loaded with torch instead.
    os.environ['TRANSFORMERS_CACHE'] = str(Path(MODEL_DIR).absolute())
    from huggingface_hub import hf_hub_download

    model_path = hf_hub_download(
        repo_id=_require_repo_id(),
        filename=MODEL_NAME,
        cache_dir=MODEL_DIR,
    )
    model = _load_torch_model(model_path)
else:
    raise AssertionError('No load type is specified!')
# Initialize the Flask application; the route decorators in this module
# register their view functions on this object.
app = Flask(__name__)
# API route for prediction(YOLO)
@app.route('/predict', methods=['POST'])
def predict():
    """Run object detection on an uploaded image.

    Request format:
        multipart/form-data with the image file in the ``image`` field.

    Response format (JSON):
        {
            "detections": [
                {
                    "label": class_index,
                    "confidence": detection_confidence,
                    "bbox": [x_min, y_min, width, height]
                },
                ...
            ],
            "ml-latency-ms": latency_in_milliseconds
                (covers image decoding, preprocessing, inference and
                 post-processing)
        }

    The bbox values are derived from the coordinates the model reports for
    the preprocessed image that is passed to it.

    Returns HTTP 400 with a plain-text message when the ``image`` field is
    missing or the upload cannot be read as an image.
    """
    if 'image' not in request.files:
        # A missing upload is a client error, so do not answer with 200.
        return 'No file selected', 400

    # perf_counter is monotonic, so the measured latency cannot be skewed by
    # system clock adjustments.
    start_time = time.perf_counter()

    try:
        image_data = Image.open(request.files['image'])
        # PIL decodes lazily, so corrupt data may only fail while the image
        # is being processed; keep preprocessing inside the same guard.
        processed_image = preprocess_image(image_data)
    except OSError:
        # PIL's "cannot identify image" and truncated-file errors are OSErrors.
        return 'Invalid image file', 400

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        results = model(processed_image)

    # NOTE(review): assumes the model call returns an object exposing
    # ``.xyxy`` whose first entry holds one row per detection in the order
    # (x_min, y_min, x_max, y_max, confidence, class) -- confirm that the
    # loaded model provides this interface.
    detections = [
        {
            "label": int(class_idx),
            "confidence": float(confidence),
            # Corner format converted to [x_min, y_min, width, height].
            "bbox": [
                float(x_min),
                float(y_min),
                float(x_max - x_min),
                float(y_max - y_min),
            ],
        }
        for x_min, y_min, x_max, y_max, confidence, class_idx
        in results.xyxy[0]
    ]

    latency_ms = (time.perf_counter() - start_time) * 1000

    # jsonify builds a Response with JSON content and
    # content_type=application/json.
    return jsonify({
        'detections': detections,
        'ml-latency-ms': round(latency_ms, 4),
    })
# Helper function to preprocess the image
def preprocess_image(image_data, input_size=(640, 640)):
    """Preprocess an image for YOLO model inference.

    :param image_data: Raw image (PIL.Image) in any PIL mode.
    :param input_size: Size the image is resized to, passed to
        ``transforms.Resize``. Defaults to (640, 640); adjust it to the
        model's expected input size.
    :return: Preprocessed image tensor with a leading batch dimension,
        i.e. shape (1, 3, H, W) for the given input size.
    """
    # Uploaded files may be grayscale, palette or RGBA images. Force three
    # channels so the 3-channel normalization below does not fail on them.
    rgb_image = image_data.convert('RGB')

    transform = transforms.Compose([
        # Resize to the model input size (aspect ratio is not preserved).
        transforms.Resize(input_size),
        # Convert to a PyTorch tensor laid out as
        # (channels, height, width).
        transforms.ToTensor(),
        # Mean 0 / std 1 leaves the values unchanged; kept as the hook for
        # model-specific normalization if it is ever needed.
        transforms.Normalize([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
    ])

    # Add the batch dimension: (C, H, W) -> (1, C, H, W).
    return transform(rgb_image).unsqueeze(0)
# API route for health check
@app.route('/health', methods=['GET'])
def health():
    """
    Liveness probe for the service.
    Always answers with the plain-text body "OK".
    Demo Usage: "curl http://localhost:5000/health" (equivalently "curl http://127.0.0.1:5000/health")
    """
    status = 'OK'
    return status
# API route for version
@app.route('/version', methods=['GET'])
def version():
    """
    Report the application version as a plain-text string.
    Demo Usage: "curl http://localhost:5000/version" (equivalently "curl http://127.0.0.1:5000/version")
    """
    app_version = '1.0'
    return app_version
@app.route("/")
def hello_world():
    """Serve the landing page rendered from the "index.html" template."""
    landing_page = render_template("index.html")
    return landing_page
# Start the Flask application
if __name__ == '__main__':
    # Debug mode enables Werkzeug's interactive debugger, which lets anyone
    # who can reach the server execute arbitrary code, so it must not be
    # switched on unconditionally. Opt in explicitly with FLASK_DEBUG=1.
    debug_flag = os.environ.get('FLASK_DEBUG', '').lower()
    app.run(debug=debug_flag not in {'', '0', 'false', 'no'})