""" Grounded‑SAM Flask API (CPU only) POST /segment Body (multipart/form-data): - image: the house photo - prompt: text prompt, e.g. "roof sheet" Query params: - overlay (bool, default=false): if true, returns a PNG overlay instead Returns: - image/png mask (single channel) OR overlay """ import io import os import argparse import numpy as np from PIL import Image from flask import Flask, request, send_file from flask_cors import CORS import logging import torch from groundingdino.util.inference import load_model, predict from segment_anything import sam_model_registry, SamPredictor import groundingdino.datasets.transforms as T # ─── Load models once ─────────────────────────────────────────────────────────── device = torch.device("cpu") DINO_CONFIG = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py" DINO_CKPT = "weights/groundingdino_swint_ogc.pth" SAM_CKPT = "weights/sam_vit_h_4b8939.pth" # Load GroundingDINO model grounder = load_model(DINO_CONFIG, DINO_CKPT) sam = sam_model_registry["vit_h"](checkpoint=SAM_CKPT).to(device) predictor = SamPredictor(sam) # ─── Flask app ────────────────────────────────────────────────────────────────── app = Flask(__name__) CORS(app) app.config["DEBUG"] = True def segment(image_pil: Image.Image, prompt: str): # Use the proper image preprocessing for GroundingDINO transform = T.Compose([ T.RandomResize([800], max_size=1333), T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ]) image_transformed, _ = transform(image_pil, None) image_transformed = image_transformed.to(device) # Run GroundingDINO to get boxes for the prompt boxes, _, _ = predict( model=grounder, image=image_transformed, caption=prompt, box_threshold=0.3, text_threshold=0.25, device="cpu" ) if boxes.size == 0: raise ValueError("No boxes found for prompt.") # 2) Largest box → mask via SAM box = boxes[np.argmax((boxes[:,2]-boxes[:,0])*(boxes[:,3]-boxes[:,1]))] box_np = box.cpu().numpy() if hasattr(box, "cpu") else np.array(box) predictor.set_image(np.array(image_pil)) masks, _, _ = predictor.predict(box=box_np) mask = masks[0] # boolean HxW return mask @app.route("/segment", methods=["POST"]) def segment_endpoint(): if "image" not in request.files or "prompt" not in request.form: return {"error": "image file and prompt are required."}, 400 prompt = request.form["prompt"] image = request.files["image"] # Check for unsupported file types if not image.content_type.startswith("image/"): return {"error": "unsupported file type. Only image files are allowed."}, 415 try: image_pil = Image.open(image.stream).convert("RGB") mask = segment(image_pil, prompt) except ValueError as e: app.logger.error(f"ValueError: {e}") return {"error": str(e)}, 422 except Exception as e: app.logger.exception("Exception in /segment endpoint") return {"error": str(e)}, 500 # Return actual exception message overlay = request.args.get("overlay", "false").lower() == "true" if overlay: colored = np.array(image_pil).copy() mask_np = mask.cpu().numpy() if hasattr(mask, "cpu") else np.array(mask) colored[mask_np] = [255, 0, 0] # red overlay out_img = Image.fromarray(colored) else: out_img = Image.fromarray((mask.astype(np.uint8) * 255) if hasattr(mask, "astype") else (np.array(mask, dtype=np.uint8) * 255)) buf = io.BytesIO() out_img.save(buf, format="PNG") buf.seek(0) return send_file(buf, mimetype="image/png") # ─── CLI ──────────────────────────────────────────────────────────────────────── if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", default="127.0.0.1") parser.add_argument("--port", default=7860, type=int) args = parser.parse_args() app.run(host=args.host, port=args.port, debug=True)