import base64
import os

import cv2
import numpy as np
from flask import Flask, jsonify, render_template, request
from flask_cors import CORS

from src.core import process_inpaint, s_image

app = Flask(__name__)
CORS(app)

# Public base URL of the deployed API; rendered into the HTML docs page.
API_URL = os.environ.get("API_URL", "https://walidadebayo-image-eraser-api.hf.space/")


def _bad_request(message):
    """Return a JSON ``{"error": message}`` response with HTTP status 400."""
    return jsonify({'error': message}), 400


def _json_body():
    """Return the request's JSON body if it is a JSON object, else ``None``.

    ``silent=True`` makes a missing or non-JSON body yield ``None`` instead of
    Flask aborting with its own (HTML) error page, so endpoints can answer
    with a consistent JSON error.
    """
    data = request.get_json(silent=True)
    return data if isinstance(data, dict) else None


def _decode_field(data, key):
    """Decode the base64 image stored under ``data[key]``.

    Raises:
        ValueError: with the field name included, if the field is missing or
            does not contain a decodable image.
    """
    try:
        return base64_to_image(data.get(key))
    except ValueError as exc:
        raise ValueError(f"Invalid '{key}': {exc}") from exc


def _int_field(data, key, default=0):
    """Read ``data[key]`` (or ``default``) as an int.

    Raises:
        ValueError: if the value cannot be converted with ``int()``.
    """
    value = data.get(key, default)
    try:
        return int(value)
    except (TypeError, ValueError) as exc:
        raise ValueError(f"'{key}' must be an integer") from exc


# endpoint for health checks
@app.route('/', methods=['GET'])
def health_check():
    """Serve the HTML docs page to browsers and a JSON status to other clients."""
    if 'text/html' in request.headers.get('Accept', ''):
        return render_template('index.html', api_url=API_URL)
    return jsonify({"status": "API is running"})


@app.route('/api/inpaint', methods=['POST'])
def inpaint():
    """Run ``process_inpaint`` on an uploaded image and mask.

    Request JSON: ``image`` and ``mask``, each base64-encoded (a ``data:`` URL
    prefix is optional).
    Response: ``{"result": <PNG data URL>}``, or 400 ``{"error": ...}`` when
    the request body or an image field cannot be decoded.
    """
    data = _json_body()
    if data is None:
        return _bad_request("Request body must be a JSON object")
    try:
        image = _decode_field(data, 'image')
        mask = _decode_field(data, 'mask')
    except ValueError as exc:
        return _bad_request(str(exc))

    result = process_inpaint(image, mask)
    return jsonify({'result': image_to_base64(result)})


@app.route('/api/seam-carve', methods=['POST'])
def seam_carve():
    """Run ``s_image`` (seam carving) on an uploaded image and mask.

    Request JSON: ``image`` and ``mask`` (base64), optional integer ``vs``
    (vertical seams, default 0) and ``hs`` (horizontal seams, default 0), and
    optional ``mode`` (default ``'resize'``), passed through unchanged.
    Response: ``{"result": <PNG data URL>}``, or 400 ``{"error": ...}`` on
    undecodable input.
    """
    data = _json_body()
    if data is None:
        return _bad_request("Request body must be a JSON object")
    try:
        image = _decode_field(data, 'image')
        mask = _decode_field(data, 'mask')
        vs = _int_field(data, 'vs')  # vertical seams
        hs = _int_field(data, 'hs')  # horizontal seams
    except ValueError as exc:
        return _bad_request(str(exc))
    mode = data.get('mode', 'resize')  # resize or remove

    result = s_image(image, mask, vs, hs, mode)
    return jsonify({'result': image_to_base64(result)})


def base64_to_image(base64_str):
    """Decode a base64 string (optionally a ``data:`` URL) into an image array.

    The returned array is whatever ``cv2.imdecode`` yields with
    ``IMREAD_UNCHANGED`` (OpenCV channel order; alpha kept if present).

    Raises:
        ValueError: if the input is not a non-empty string, is not valid
            base64, or does not decode to an image OpenCV can read.
    """
    if not isinstance(base64_str, str) or not base64_str:
        raise ValueError("expected a non-empty base64 string")
    # Accept "data:image/png;base64,<payload>" as well as a bare payload.
    # For a data URL this takes the text after the first comma, as before.
    payload = base64_str.split(',', 1)[-1]
    # binascii.Error (bad padding) is a subclass of ValueError.
    img_bytes = base64.b64decode(payload)
    if not img_bytes:
        # cv2.imdecode raises cv2.error on an empty buffer; report it cleanly.
        raise ValueError("decoded image data is empty")
    img = cv2.imdecode(np.frombuffer(img_bytes, np.uint8), cv2.IMREAD_UNCHANGED)
    if img is None:
        raise ValueError("data is not a supported image format")
    return img


def image_to_base64(image):
    """Encode an image array as a PNG ``data:`` URL string.

    Floating-point images are scaled to 0-255 (multiplied by 255 when their
    maximum is <= 1.0), clipped, and cast to uint8. 3-channel images are
    converted RGB -> BGR before encoding because ``cv2.imencode`` expects BGR.
    NOTE(review): that conversion assumes the processing functions return RGB,
    while ``base64_to_image`` produces OpenCV (BGR) order — confirm against
    ``src.core``.

    Raises:
        ValueError: if ``image`` is None or empty.
        RuntimeError: if OpenCV reports failure encoding the array as PNG.
    """
    if image is None or image.size == 0:
        raise ValueError("cannot encode an empty image")
    # Covers float16 as well as float32/float64.
    if np.issubdtype(image.dtype, np.floating):
        scaled = image * 255 if image.max() <= 1.0 else image
        image = np.clip(scaled, 0, 255).astype(np.uint8)
    if image.ndim > 2 and image.shape[2] == 3:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    ok, buffer = cv2.imencode('.png', image)
    if not ok:
        raise RuntimeError("failed to encode image as PNG")
    encoded = base64.b64encode(buffer).decode('utf-8')
    return f"data:image/png;base64,{encoded}"


if __name__ == '__main__':
    port = int(os.environ.get("PORT", 7860))
    # Debug mode enables the Werkzeug interactive debugger, which allows
    # arbitrary code execution; with host 0.0.0.0 that is reachable from the
    # network. Keep it off unless explicitly requested via FLASK_DEBUG.
    debug = os.environ.get("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes")
    app.run(debug=debug, host='0.0.0.0', port=port)