""" | |
GroundedβSAM Flask API (CPU only) | |
POST /segment | |
Body (multipart/form-data): | |
- image: the house photo | |
- prompt: text prompt, e.g. "roof sheet" | |
Query params: | |
- overlay (bool, default=false): if true, returns a PNG overlay instead | |
Returns: | |
- image/png mask (single channel) OR overlay | |
""" | |
import io
import os
import argparse
import numpy as np
from PIL import Image
from flask import Flask, request, send_file
from flask_cors import CORS
import torch
from torchvision.ops import box_convert
from groundingdino.util.inference import load_model, predict
from segment_anything import sam_model_registry, SamPredictor
import groundingdino.datasets.transforms as T
# ─── Load models once ───────────────────────────────────────────────────────────
device = torch.device("cpu")
DINO_CONFIG = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
DINO_CKPT = "weights/groundingdino_swint_ogc.pth"
SAM_CKPT = "weights/sam_vit_h_4b8939.pth"

# Load GroundingDINO explicitly on CPU (load_model defaults to CUDA otherwise).
grounder = load_model(DINO_CONFIG, DINO_CKPT, device="cpu")
sam = sam_model_registry["vit_h"](checkpoint=SAM_CKPT).to(device)
predictor = SamPredictor(sam)
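
# Optional (deployment-specific, not part of the original setup): cap PyTorch's
# CPU thread pool for more predictable latency under concurrent requests, e.g.:
#   torch.set_num_threads(max(1, (os.cpu_count() or 2) // 2))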
# ─── Flask app ──────────────────────────────────────────────────────────────────
app = Flask(__name__)
CORS(app)
app.config["DEBUG"] = True
def segment(image_pil: Image.Image, prompt: str):
    # Standard GroundingDINO preprocessing (DETR-style resize + ImageNet normalize).
    transform = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
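    # With only one size in the list, RandomResize is deterministic here: the
    # shorter side is scaled to 800 px, with the longer side capped at 1333.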
    image_transformed, _ = transform(image_pil, None)
    image_transformed = image_transformed.to(device)
    # Run GroundingDINO to get boxes for the prompt
    boxes, _, _ = predict(
        model=grounder,
        image=image_transformed,
        caption=prompt,
        box_threshold=0.3,
        text_threshold=0.25,
        device="cpu",
    )
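    # box_threshold drops boxes whose best caption-token logit falls below it;
    # text_threshold controls which tokens count toward the predicted phrase.
    # Lowering either recalls more (but noisier) detections.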
    if len(boxes) == 0:
        raise ValueError("No boxes found for prompt.")
    # GroundingDINO returns normalized (cx, cy, w, h) boxes; convert to pixel
    # xyxy coordinates, which is the format SAM expects.
    w, h = image_pil.size
    boxes_xyxy = box_convert(boxes * torch.tensor([w, h, w, h]), in_fmt="cxcywh", out_fmt="xyxy")
    # Largest box → mask via SAM
    areas = (boxes_xyxy[:, 2] - boxes_xyxy[:, 0]) * (boxes_xyxy[:, 3] - boxes_xyxy[:, 1])
    box_np = boxes_xyxy[int(areas.argmax())].cpu().numpy()
    predictor.set_image(np.array(image_pil))
    masks, _, _ = predictor.predict(box=box_np, multimask_output=False)
    mask = masks[0]  # boolean HxW
    return mask
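
# Note: segment() keeps only the single largest detection. A sketch for merging
# every detected instance instead would loop over boxes_xyxy and OR the
# per-box masks together before returning.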
@app.route("/segment", methods=["POST"])
def segment_endpoint():
    if "image" not in request.files or "prompt" not in request.form:
        return {"error": "image file and prompt are required."}, 400
    prompt = request.form["prompt"]
    image = request.files["image"]
    # Reject non-image uploads early (content_type may be missing entirely).
    if not image.content_type or not image.content_type.startswith("image/"):
        return {"error": "unsupported file type. Only image files are allowed."}, 415
    try:
        image_pil = Image.open(image.stream).convert("RGB")
        mask = segment(image_pil, prompt)
    except ValueError as e:
        app.logger.error(f"ValueError: {e}")
        return {"error": str(e)}, 422
    except Exception as e:
        app.logger.exception("Exception in /segment endpoint")
        return {"error": str(e)}, 500  # Return actual exception message
    overlay = request.args.get("overlay", "false").lower() == "true"
    if overlay:
        # Paint masked pixels solid red on a copy of the original image.
        colored = np.array(image_pil).copy()
        colored[mask] = [255, 0, 0]
        out_img = Image.fromarray(colored)
    else:
        # Single-channel mask: 255 inside the object, 0 elsewhere.
        out_img = Image.fromarray(mask.astype(np.uint8) * 255)
    buf = io.BytesIO()
    out_img.save(buf, format="PNG")
    buf.seek(0)
    return send_file(buf, mimetype="image/png")
# ─── CLI ────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", default=7860, type=int)
    args = parser.parse_args()
    app.run(host=args.host, port=args.port, debug=True)