# Author: walidadebayo - commit 7c3493b:
# Normalize float types to uint8 in image_to_base64 function
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
import numpy as np
import cv2
import base64
from src.core import process_inpaint, s_image
import os
# Base URL passed to the index.html template by the "/" route.
# Defaults to the hosted deployment; set the API_URL environment variable
# to advertise a different address.
_DEFAULT_API_URL = "https://walidadebayo-image-eraser-api.hf.space/"
API_URL = os.environ.get("API_URL", _DEFAULT_API_URL)

# Flask application object; CORS(app) applies flask_cors's default
# settings to every route of the app.
app = Flask(__name__)
CORS(app)
# endpoint for health checks
@app.route('/', methods=['GET'])
def health_check():
# For browser requests, return HTML documentation
if request.headers.get('Accept', '').find('text/html') != -1:
return render_template('index.html', api_url=API_URL)
# For API health checks, return JSON
return jsonify({"status": "API is running"})
@app.route('/api/inpaint', methods=['POST'])
def inpaint():
    """Erase the masked region of an image using ``process_inpaint``.

    Expects a JSON object body with:
        image: base64-encoded image (typically a data URL such as
               ``data:image/png;base64,...``).
        mask:  base64-encoded mask in the same form.

    Returns:
        200 with ``{"result": "data:image/png;base64,..."}`` on success, or
        400 with ``{"error": "..."}`` when the body is not a JSON object, a
        field is missing or not a string, or an image cannot be decoded.
    """
    # silent=True yields None (instead of raising) for a missing/invalid JSON
    # body, so every malformed request gets the same JSON 400 response.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    image_data = data.get('image')
    mask_data = data.get('mask')
    if not isinstance(image_data, str) or not isinstance(mask_data, str):
        return jsonify({'error': "'image' and 'mask' must be base64 strings"}), 400

    # Decoding can raise on malformed input (missing data-URL comma, invalid
    # base64, empty payload) or yield None when OpenCV cannot decode the
    # bytes. All of these are client errors, not server faults.
    try:
        image = base64_to_image(image_data)
        mask = base64_to_image(mask_data)
    except (ValueError, IndexError, cv2.error):
        return jsonify({'error': 'Could not decode image or mask'}), 400
    if image is None or mask is None:
        return jsonify({'error': 'Could not decode image or mask'}), 400

    result = process_inpaint(image, mask)
    return jsonify({'result': image_to_base64(result)})
@app.route('/api/seam-carve', methods=['POST'])
def seam_carve():
    """Run seam carving on an image via ``s_image``.

    Expects a JSON object body with:
        image: base64-encoded image (typically a data URL).
        mask:  base64-encoded mask in the same form.
        vs:    integer count of vertical seams (default 0).
        hs:    integer count of horizontal seams (default 0).
        mode:  'resize' (default) or 'remove'; passed to s_image unchanged.

    Returns:
        200 with ``{"result": "data:image/png;base64,..."}`` on success, or
        400 with ``{"error": "..."}`` for a malformed request (non-JSON body,
        missing/non-string image or mask, non-integer seam counts, or an
        image that cannot be decoded).
    """
    # silent=True yields None (instead of raising) for a missing/invalid JSON
    # body, so every malformed request gets the same JSON 400 response.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    image_data = data.get('image')
    mask_data = data.get('mask')
    if not isinstance(image_data, str) or not isinstance(mask_data, str):
        return jsonify({'error': "'image' and 'mask' must be base64 strings"}), 400

    # int() raises ValueError for non-numeric strings, TypeError for null,
    # and OverflowError for an infinite float; all are client errors.
    try:
        vs = int(data.get('vs', 0))  # vertical seams
        hs = int(data.get('hs', 0))  # horizontal seams
    except (TypeError, ValueError, OverflowError):
        return jsonify({'error': "'vs' and 'hs' must be integers"}), 400
    mode = data.get('mode', 'resize')  # 'resize' or 'remove'

    # Decoding can raise on malformed input or yield None when OpenCV cannot
    # decode the bytes; report either as a bad request.
    try:
        image = base64_to_image(image_data)
        mask = base64_to_image(mask_data)
    except (ValueError, IndexError, cv2.error):
        return jsonify({'error': 'Could not decode image or mask'}), 400
    if image is None or mask is None:
        return jsonify({'error': 'Could not decode image or mask'}), 400

    result = s_image(image, mask, vs, hs, mode)
    return jsonify({'result': image_to_base64(result)})
def base64_to_image(base64_str):
    """Decode a base64-encoded image into a NumPy array using OpenCV.

    Args:
        base64_str: Either a data URL (``data:image/png;base64,<payload>``)
            or the bare base64 payload.

    Returns:
        The array produced by ``cv2.imdecode`` with ``IMREAD_UNCHANGED``.
        The channel count follows the encoded file (alpha is kept), and
        color channels are in OpenCV's BGR(A) order.
        NOTE(review): image_to_base64 treats 3-channel arrays as RGB; confirm
        which channel order process_inpaint / s_image expect.

    Raises:
        TypeError: if ``base64_str`` is not a string.
        ValueError: if the payload is not valid base64, is empty, or cannot
            be decoded as an image.
    """
    if not isinstance(base64_str, str):
        raise TypeError(
            f"expected a base64 string, got {type(base64_str).__name__}")
    # A data URL carries a "<mime>;base64," header before the payload, and
    # base64 itself never contains ',', so the payload is everything after
    # the first comma, or the whole string when there is no header.
    payload = base64_str.split(',', 1)[-1]
    img_bytes = base64.b64decode(payload)  # binascii.Error is a ValueError
    if not img_bytes:
        raise ValueError("image payload is empty")
    img_array = np.frombuffer(img_bytes, np.uint8)
    img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise ValueError("could not decode image data")
    return img
def image_to_base64(image):
    """Encode an image array as a PNG data URL.

    Args:
        image: NumPy image array. Any floating-point array is converted to
            uint8: if its maximum is <= 1.0 it is taken to be in [0, 1] and
            scaled by 255, otherwise it is taken to be in [0, 255]. Values
            are then clipped to [0, 255]. 3-channel arrays are treated as RGB
            and swapped to OpenCV's BGR order before encoding. Arrays with
            other channel counts are encoded as-is.

    Returns:
        ``"data:image/png;base64,<payload>"``.

    Raises:
        ValueError: if OpenCV reports that PNG encoding failed.
    """
    if np.issubdtype(image.dtype, np.floating):
        # NaN would make max() NaN and become an arbitrary byte on the cast.
        image = np.nan_to_num(image, nan=0.0)
        # Heuristic: a maximum <= 1.0 means normalized [0, 1] data.
        # The size check avoids max() raising on an empty array.
        if image.size and image.max() <= 1.0:
            image = image * 255
        image = np.clip(image, 0, 255).astype(np.uint8)
    # Convert to BGR if it's RGB (OpenCV encodes channels in BGR order).
    if image.ndim > 2 and image.shape[2] == 3:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    ok, buffer = cv2.imencode('.png', image)
    if not ok:
        raise ValueError("failed to encode image as PNG")
    img_b64 = base64.b64encode(buffer).decode('utf-8')
    return f"data:image/png;base64,{img_b64}"
if __name__ == '__main__':
    port = int(os.environ.get("PORT", 7860))
    # Debug mode turns on the Werkzeug interactive debugger, which can run
    # arbitrary Python from the browser. Because the server binds to
    # 0.0.0.0 it would be reachable from the network, so debug mode is
    # opt-in: set FLASK_DEBUG to any value other than "", "0", "false" or
    # "no".
    debug = os.environ.get("FLASK_DEBUG", "").lower() not in ("", "0", "false", "no")
    app.run(debug=debug, host='0.0.0.0', port=port)