# gg / app.py
# IZERE HIRWA Roger
# u
# 21e15c5
"""
Grounded‑SAM Flask API (CPU only)
POST /segment
Body (multipart/form-data):
- image: the house photo
- prompt: text prompt, e.g. "roof sheet"
Query params:
- overlay (bool, default=false): if true, returns the photo with the masked pixels painted red instead of the bare mask
Returns:
- image/png mask (single channel) OR overlay
"""
import io
import os
import argparse
import numpy as np
from PIL import Image
from flask import Flask, request, send_file
from flask_cors import CORS
import logging
import torch
from groundingdino.util.inference import load_model, predict
from segment_anything import sam_model_registry, SamPredictor
import groundingdino.datasets.transforms as T
# ─── Load models once ───────────────────────────────────────────────────────────
# Both networks are loaded at import time so every request reuses the same
# instances instead of paying the checkpoint-loading cost per call.
device = torch.device("cpu")
DINO_CONFIG = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
DINO_CKPT = "weights/groundingdino_swint_ogc.pth"
SAM_CKPT = "weights/sam_vit_h_4b8939.pth"
# GroundingDINO (text prompt -> boxes). load_model() defaults to device="cuda",
# so the CPU device is passed explicitly to keep the model config consistent
# with the rest of this CPU-only service.
grounder = load_model(DINO_CONFIG, DINO_CKPT, device=str(device))
# SAM ViT-H (box -> mask), wrapped in a predictor that caches image embeddings.
sam = sam_model_registry["vit_h"](checkpoint=SAM_CKPT).to(device)
predictor = SamPredictor(sam)
# ─── Flask app ──────────────────────────────────────────────────────────────────
# One application object for the whole module; flask-cors is applied with its
# default settings.
app = Flask(__name__)
CORS(app)
# NOTE(review): DEBUG is forced on for every import of this module — confirm
# this is intended outside local development.
app.config.update(DEBUG=True)
def _largest_box_xyxy(boxes, width, height):
    """Return the largest detection as an absolute-pixel XYXY numpy box.

    GroundingDINO's ``predict`` yields rows of normalised ``(cx, cy, w, h)``,
    whereas ``SamPredictor.predict`` expects a single box in pixel XYXY.
    """
    boxes = torch.as_tensor(boxes).detach().cpu().float()
    scale = torch.tensor([width, height, width, height], dtype=boxes.dtype)
    cx, cy, w, h = (boxes * scale).unbind(-1)
    best = int(torch.argmax(w * h))  # area is simply w * h in cxcywh form
    half_w, half_h = w[best] / 2, h[best] / 2
    xyxy = torch.stack(
        [cx[best] - half_w, cy[best] - half_h, cx[best] + half_w, cy[best] + half_h]
    )
    return xyxy.numpy()


def segment(
    image_pil: Image.Image,
    prompt: str,
    *,
    box_threshold: float = 0.3,
    text_threshold: float = 0.25,
) -> np.ndarray:
    """Segment the largest object matching ``prompt`` in ``image_pil``.

    Args:
        image_pil: RGB input image.
        prompt: Text description of the object, passed to GroundingDINO.
        box_threshold: Minimum box confidence forwarded to GroundingDINO.
        text_threshold: Minimum text-match confidence forwarded to GroundingDINO.

    Returns:
        The SAM mask for the largest detected box (boolean HxW array).

    Raises:
        ValueError: If GroundingDINO finds no box for the prompt.
    """
    # 1) Preprocess exactly as GroundingDINO expects, then detect boxes.
    transform = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image_transformed, _ = transform(image_pil, None)
    image_transformed = image_transformed.to(device)
    boxes, _, _ = predict(
        model=grounder,
        image=image_transformed,
        caption=prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        device=str(device),
    )
    # ``boxes.size`` is a method on tensors, so compare the row count instead.
    if boxes is None or len(boxes) == 0:
        raise ValueError("No boxes found for prompt.")

    # 2) Largest box → mask via SAM (box converted to pixel XYXY first).
    width, height = image_pil.size
    box_np = _largest_box_xyxy(boxes, width, height)
    # NOTE(review): ``predictor`` is shared module state and set_image() stores
    # the current image on it — confirm requests are not served concurrently.
    predictor.set_image(np.array(image_pil))
    # A box is an unambiguous prompt, so ask SAM for a single mask.
    masks, _, _ = predictor.predict(box=box_np, multimask_output=False)
    return masks[0]
@app.route("/segment", methods=["POST"])
def segment_endpoint():
    """Handle ``POST /segment``: run segmentation and return a PNG.

    Expects multipart form data with an ``image`` file and a ``prompt`` field.
    The ``overlay`` query parameter selects between the raw mask (default) and
    the input photo with masked pixels painted red.

    Returns:
        A PNG response on success, otherwise a JSON error with status
        400 (missing/invalid input), 415 (non-image upload),
        422 (``ValueError`` from segmentation) or 500 (any other failure).
    """
    upload = request.files.get("image")
    prompt = request.form.get("prompt", "").strip()
    if upload is None or not prompt:
        return {"error": "image file and prompt are required."}, 400

    # content_type is None when the part carries no Content-Type header.
    if not (upload.content_type or "").startswith("image/"):
        return {"error": "unsupported file type. Only image files are allowed."}, 415

    try:
        try:
            image_pil = Image.open(upload.stream).convert("RGB")
        except OSError as e:
            # Covers PIL's UnidentifiedImageError and truncated files: the
            # client sent something that is not a decodable image.
            app.logger.error("Could not decode uploaded image: %s", e)
            return {"error": "uploaded file is not a readable image."}, 400
        mask = segment(image_pil, prompt)
    except ValueError as e:
        app.logger.error("ValueError: %s", e)
        return {"error": str(e)}, 422
    except Exception as e:
        app.logger.exception("Exception in /segment endpoint")
        return {"error": str(e)}, 500  # Return actual exception message

    # Normalise the mask once: tensor or array in, boolean numpy array out.
    mask_np = np.asarray(mask.cpu() if hasattr(mask, "cpu") else mask).astype(bool)

    overlay = request.args.get("overlay", "false").strip().lower() in {
        "1", "true", "yes", "on",
    }
    if overlay:
        colored = np.array(image_pil)  # np.array already returns a fresh copy
        colored[mask_np] = (255, 0, 0)  # red overlay
        out_img = Image.fromarray(colored)
    else:
        out_img = Image.fromarray(mask_np.astype(np.uint8) * 255)

    buf = io.BytesIO()
    out_img.save(buf, format="PNG")
    buf.seek(0)
    return send_file(buf, mimetype="image/png")
# ─── CLI ────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the Grounded-SAM Flask API.")
    parser.add_argument("--host", default="127.0.0.1", help="interface to bind to")
    parser.add_argument("--port", default=7860, type=int, help="port to listen on")
    # Debug mode enables the interactive Werkzeug debugger (remote code
    # execution if the host is reachable) and the reloader, which imports this
    # module — and therefore both model checkpoints — a second time. Keep it
    # opt-in.
    parser.add_argument(
        "--debug",
        action="store_true",
        help="enable Flask debug mode (local development only)",
    )
    args = parser.parse_args()
    # An explicit ``debug`` value passed to run() takes precedence over
    # app.config["DEBUG"].
    app.run(host=args.host, port=args.port, debug=args.debug)