Sa-m's picture
Update app.py
7d155ba
raw
history blame
2.51 kB
"""
Run a REST API exposing a custom YOLOv5 object detection model (weights loaded from best2.pt)
"""
import argparse
import base64
import functools
import io
import os
import subprocess

import torch
from flask import Flask, request
from PIL import Image
from waitress import serve
'''
#subprocess.run(["export", "FLASK_APP","=","app.py"])
app = Flask(__name__)
DETECTION_URL = "/v1/detect"
@app.route(DETECTION_URL,methods=["POST"])
def predict():
#model = torch.hub.load('ultralytics/yolov5', 'custom', path='best2.pt', force_reload=True) # force_reload to recache
if not request.method == "POST":
return
if request.files.get("image"):
image_file = request.files["image"]
image_bytes = image_file.read()
img = Image.open(io.BytesIO(image_bytes))
results = model(img, size=640) # reduce size=320 for faster inference
results=results.pandas().xyxy[0].to_json(orient="records")
return f"{results}"
if __name__ == "__main__":
#subprocess.run(["export","FLASK_ENV","=","development"])
app.run(host="0.0.0.0", port=7860) # debug=True causes Restarting with stat
#serve(app, host="0.0.0.0", port=7860)
if __name__ == "__main__":
model = torch.hub.load('ultralytics/yolov5', 'custom', path='best2.pt', force_reload=True) # force_reload to recache
app.run(host="0.0.0.0", port=7860,debug =True) # debug=True causes Restarting with stat
'''
app = Flask(__name__)


@functools.lru_cache(maxsize=1)
def _get_model():
    """Return the custom YOLOv5 model built from ``best2.pt``, loading it only once.

    The handler previously called ``torch.hub.load`` on every request. With
    ``force_reload=True`` that re-fetched the hub repo and re-read the weights
    each time. The cache keeps the same repo, weights and options but pays that
    cost once per process. A failed load is not cached, so the next request
    retries.
    """
    return torch.hub.load(
        'ultralytics/yolov5', 'custom', path='best2.pt', force_reload=True
    )  # force_reload to recache (now once per process instead of per request)


def _rendered_images(results):
    """Draw detections onto the images held by *results* and return those images.

    ``results.render()`` updates the stored images with boxes and labels in
    place. Current yolov5 releases expose the images as ``results.ims``; older
    releases used ``results.imgs``. Both are supported because ``force_reload``
    always pulls the latest hub code.
    """
    results.render()
    images = getattr(results, "ims", None)
    if images is None:
        images = results.imgs
    return images


# GET is kept for backward compatibility, but file uploads arrive via a
# multipart POST, which the route previously rejected.
@app.route('/', methods=['GET', 'POST'])
def index():
    """Run object detection on an uploaded image and return the annotated image.

    Expects the image file in the ``image`` field of a multipart form upload.

    Returns:
        The annotated image as a base64-encoded JPEG string (UTF-8 text) on
        success. Otherwise a ``(message, 400)`` response when no file was
        uploaded or the upload is not an image PIL can open.
    """
    image_file = request.files.get("image")
    if not image_file:
        # The handler used to fall through and return None, which Flask
        # turns into a 500 error.
        return "No image uploaded: send a file in the 'image' form field.", 400

    try:
        img = Image.open(io.BytesIO(image_file.read()))
    except OSError:  # PIL's UnidentifiedImageError is a subclass of OSError
        return "Uploaded file is not a readable image.", 400

    results = _get_model()(img, size=640)  # reduce size=320 for faster inference

    # Exactly one image was submitted, so only the first rendered image matters.
    annotated = Image.fromarray(_rendered_images(results)[0])
    buffered = io.BytesIO()
    annotated.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')  # base64 encoded image with results
if __name__ == '__main__':
    # Development-server entry point: listen on every interface.
    # NOTE(review): 7860 is presumably the port the hosting environment
    # expects; confirm before changing it.
    server_options = {"port": 7860, "host": "0.0.0.0"}
    app.run(**server_options)