Spaces:
Running
Running
from flask import Flask, request, jsonify, render_template | |
from flask_cors import CORS | |
import numpy as np | |
import cv2 | |
import base64 | |
from src.core import process_inpaint, s_image | |
import os | |
# Flask application object; CORS is applied to it with flask_cors defaults.
app = Flask(__name__)
CORS(app)

# Base URL of the deployed API. The API_URL environment variable overrides
# the hosted default below.
API_URL = os.getenv(
    "API_URL",
    "https://walidadebayo-image-eraser-api.hf.space/",
)
# endpoint for health checks | |
def health_check():
    """Return the HTML documentation page to browsers, JSON status otherwise.

    The decision is based on whether the request's ``Accept`` header
    mentions ``text/html``.

    NOTE(review): no route decorator is visible on this function in this
    view of the file -- confirm how it is registered on ``app``.
    """
    accept_header = request.headers.get('Accept', '')
    wants_html = 'text/html' in accept_header

    if not wants_html:
        # API clients / health probes get a small JSON payload.
        return jsonify({"status": "API is running"})

    # Browser request: render the documentation template with the API URL.
    return render_template('index.html', api_url=API_URL)
def inpaint():
    """Inpaint an image using a mask, both supplied as base64 strings.

    Expects a JSON object body with the keys ``image`` and ``mask``. Each is
    decoded with ``base64_to_image``, passed to ``process_inpaint``, and the
    result is returned as ``{"result": <data URL>}``.

    Returns:
        A JSON response with the encoded result, or a ``400`` JSON error
        response when the body is not a JSON object, a required field is
        missing / not a string, or an input cannot be decoded as an image.

    NOTE(review): no route decorator is visible on this function in this
    view of the file -- confirm how it is registered on ``app``.
    """
    # silent=True yields None instead of raising on a missing/invalid JSON
    # body, so the failure can be reported uniformly below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    image_data = data.get('image')
    mask_data = data.get('mask')
    for field, value in (('image', image_data), ('mask', mask_data)):
        if not isinstance(value, str) or not value:
            return jsonify(
                {'error': f"'{field}' must be a non-empty base64 string"}
            ), 400

    # Convert base64 to numpy arrays. Malformed input from the client must
    # not surface as an unhandled exception (HTTP 500).
    try:
        image = base64_to_image(image_data)
        mask = base64_to_image(mask_data)
    except (ValueError, IndexError, cv2.error):
        return jsonify({'error': 'Could not decode image or mask'}), 400
    if image is None or mask is None:
        # cv2.imdecode signals undecodable data by returning None.
        return jsonify({'error': 'Could not decode image or mask'}), 400

    # Process the image and convert the result back to base64.
    result = process_inpaint(image, mask)
    return jsonify({'result': image_to_base64(result)})
def seam_carve():
    """Run seam carving on an image, with inputs supplied as JSON.

    Expects a JSON object body with:
        image: base64 string of the source image (required).
        mask:  base64 string of the mask (required).
        vs:    number of vertical seams, integer-convertible (default 0).
        hs:    number of horizontal seams, integer-convertible (default 0).
        mode:  passed through to ``s_image`` unchanged (default 'resize').

    Returns:
        A JSON response ``{"result": <data URL>}``, or a ``400`` JSON error
        response when the body is not a JSON object, a required field is
        missing / not a string, ``vs``/``hs`` are not integers, or an input
        cannot be decoded as an image.

    NOTE(review): no route decorator is visible on this function in this
    view of the file -- confirm how it is registered on ``app``.
    """
    # silent=True yields None instead of raising on a missing/invalid JSON
    # body, so the failure can be reported uniformly below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    image_data = data.get('image')
    mask_data = data.get('mask')
    for field, value in (('image', image_data), ('mask', mask_data)):
        if not isinstance(value, str) or not value:
            return jsonify(
                {'error': f"'{field}' must be a non-empty base64 string"}
            ), 400

    # Seam counts come from the client; reject values int() cannot convert.
    try:
        vs = int(data.get('vs', 0))  # vertical seams
        hs = int(data.get('hs', 0))  # horizontal seams
    except (TypeError, ValueError):
        return jsonify({'error': "'vs' and 'hs' must be integers"}), 400
    mode = data.get('mode', 'resize')  # resize or remove

    # Convert base64 to numpy arrays. Malformed input from the client must
    # not surface as an unhandled exception (HTTP 500).
    try:
        image = base64_to_image(image_data)
        mask = base64_to_image(mask_data)
    except (ValueError, IndexError, cv2.error):
        return jsonify({'error': 'Could not decode image or mask'}), 400
    if image is None or mask is None:
        # cv2.imdecode signals undecodable data by returning None.
        return jsonify({'error': 'Could not decode image or mask'}), 400

    # Process the image and convert the result back to base64.
    result = s_image(image, mask, vs, hs, mode)
    return jsonify({'result': image_to_base64(result)})
def base64_to_image(base64_str):
    """Decode a base64-encoded image into an array via ``cv2.imdecode``.

    Args:
        base64_str: Either a data URL (``data:image/png;base64,<payload>``)
            or the bare base64 payload without any prefix.

    Returns:
        The decoded image as returned by ``cv2.imdecode`` with
        ``cv2.IMREAD_UNCHANGED``, or ``None`` when the payload is empty.

    Raises:
        ValueError: (``binascii.Error``) if the payload is not valid base64.
    """
    # A data URL carries the payload after the first comma; a bare payload
    # has no comma at all. Previously the latter raised IndexError.
    parts = base64_str.split(',')
    payload = parts[1] if len(parts) > 1 else parts[0]

    img_bytes = base64.b64decode(payload)
    img_array = np.frombuffer(img_bytes, np.uint8)
    if img_array.size == 0:
        # Nothing to decode; avoid handing cv2.imdecode an empty buffer.
        return None
    return cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED)
def image_to_base64(image):
    """Encode an image array as a PNG data URL string.

    float32/float64 arrays are converted to uint8 first: if the array's
    maximum is <= 1.0 it is multiplied by 255, then values are clipped to
    0..255. Three-channel arrays are passed through an RGB->BGR conversion
    before PNG encoding.

    Returns:
        A string of the form ``data:image/png;base64,<payload>``.
    """
    if image.dtype in (np.float64, np.float32):
        # Scale unit-range floats up to the 0-255 range before clipping.
        scaled = image * 255 if image.max() <= 1.0 else image
        image = np.clip(scaled, 0, 255).astype(np.uint8)

    # Three-channel input is converted from RGB to BGR before encoding.
    has_three_channels = image.ndim > 2 and image.shape[2] == 3
    if has_three_channels:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    _, png_buffer = cv2.imencode('.png', image)
    payload = base64.b64encode(png_buffer).decode('utf-8')
    return "data:image/png;base64," + payload
if __name__ == '__main__':
    # Listen on the port given by the PORT environment variable (default 7860).
    port = int(os.environ.get("PORT", 7860))
    # The server binds to all interfaces, so the interactive debugger must
    # not be on by default: it lets anyone who can reach it run code.
    # Opt in explicitly with FLASK_DEBUG=1 for local development.
    debug = os.environ.get("FLASK_DEBUG", "").strip().lower() in (
        "1", "true", "yes", "on",
    )
    app.run(debug=debug, host='0.0.0.0', port=port)