Spaces:
Sleeping
Sleeping
File size: 4,401 Bytes
d41ddc1 f5cfe82 d41ddc1 5ee91ca d41ddc1 46331c0 d41ddc1 bc2ee7a d41ddc1 5ee91ca d41ddc1 f5cfe82 d41ddc1 46331c0 5ee91ca 8dfdfe9 46331c0 5ee91ca 7d4aa82 46331c0 5ee91ca d41ddc1 cbb1938 d41ddc1 cdcf202 d41ddc1 cdcf202 d41ddc1 f5cfe82 d41ddc1 cdcf202 f5cfe82 d41ddc1 cdcf202 7dc5d22 d41ddc1 7dc5d22 d41ddc1 f5cfe82 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
"""
Grounded-SAM Flask API (CPU only)
POST /segment
Body (multipart/form-data):
- image: the house photo
- prompt: text prompt, e.g. "roof sheet"
Query params:
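- overlay (bool, default=false): if true, return the input photo with the
  segmented pixels painted red instead of the bare mask
Returns:
- image/png: single-channel mask (255 = object, 0 = background), or the
  red overlay when overlay=true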
"""
import argparse
import io
import logging
import os
import threading

import numpy as np
import torch
from flask import Flask, request, send_file
from flask_cors import CORS
from PIL import Image

import groundingdino.datasets.transforms as T
from groundingdino.util.inference import load_model, predict
from segment_anything import sam_model_registry, SamPredictor
# --- Load models once (at import time) ----------------------------------------
device = torch.device("cpu")
# Model files can be overridden through environment variables; the defaults are
# paths relative to the current working directory.
DINO_CONFIG = os.environ.get(
    "DINO_CONFIG", "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
)
DINO_CKPT = os.environ.get("DINO_CKPT", "weights/groundingdino_swint_ogc.pth")
SAM_CKPT = os.environ.get("SAM_CKPT", "weights/sam_vit_h_4b8939.pth")
# GroundingDINO: text prompt -> bounding boxes. load_model() defaults to
# device="cuda", so pass the CPU device explicitly for this CPU-only service.
grounder = load_model(DINO_CONFIG, DINO_CKPT, device=str(device))
# SAM (ViT-H): box prompt -> pixel mask. SamPredictor keeps the embedding of the
# image most recently passed to set_image() and answers predict() against it.
sam = sam_model_registry["vit_h"](checkpoint=SAM_CKPT).to(device)
predictor = SamPredictor(sam)
# --- Flask app ----------------------------------------------------------------
app = Flask(__name__)
CORS(app)
# Debug mode must be opt-in for a network-facing service; enable it explicitly
# with FLASK_DEBUG=1 (or true/yes).
app.config["DEBUG"] = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes"}
# GroundingDINO input preprocessing, built once instead of on every request.
# Mirrors the transform used by groundingdino.util.inference.load_image().
_DINO_TRANSFORM = T.Compose([
    T.RandomResize([800], max_size=1333),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# The shared SamPredictor stores the embedding from set_image() and uses it in
# predict(), so the two calls must run as one unit. Flask's built-in server
# handles requests on multiple threads by default.
_predictor_lock = threading.Lock()


def segment(image_pil: Image.Image, prompt: str) -> np.ndarray:
    """Return a mask of the region in *image_pil* that best matches *prompt*.

    GroundingDINO proposes boxes for the text prompt. The largest box is
    converted to pixel XYXY coordinates and handed to SAM, which returns a
    pixel mask for it.

    Args:
        image_pil: Input image; expected to be RGB (the endpoint converts it).
        prompt: Free-text description of the target, e.g. "roof sheet".

    Returns:
        Boolean array of shape (height, width) matching *image_pil*.

    Raises:
        ValueError: If GroundingDINO finds no box above the thresholds.
    """
    image_transformed, _ = _DINO_TRANSFORM(image_pil, None)
    image_transformed = image_transformed.to(device)

    # 1) Text prompt -> boxes. GroundingDINO returns them as normalized
    #    (cx, cy, w, h) in [0, 1], not as pixel XYXY.
    boxes, _, _ = predict(
        model=grounder,
        image=image_transformed,
        caption=prompt,
        box_threshold=0.3,
        text_threshold=0.25,
        device=str(device),
    )
    # `boxes` is a torch.Tensor whose `.size` is a method; count elements instead.
    if boxes.numel() == 0:
        raise ValueError("No boxes found for prompt.")

    # 2) Largest box -> mask via SAM. In cxcywh format the area is just w * h.
    cx, cy, w, h = boxes[torch.argmax(boxes[:, 2] * boxes[:, 3])].tolist()
    width, height = image_pil.size
    # SamPredictor.predict expects a NumPy XYXY box in original-image pixels.
    box_xyxy = np.array([
        (cx - w / 2) * width,
        (cy - h / 2) * height,
        (cx + w / 2) * width,
        (cy + h / 2) * height,
    ])
    box_xyxy = np.clip(box_xyxy, 0, [width, height, width, height])

    with _predictor_lock:
        predictor.set_image(np.array(image_pil))
        # A single box is an unambiguous prompt: ask for one mask rather than
        # three candidates of which an arbitrary one would be picked.
        masks, _, _ = predictor.predict(box=box_xyxy, multimask_output=False)
    return masks[0]  # boolean HxW
@app.route("/segment", methods=["POST"])
def segment_endpoint():
    """Handle POST /segment: segment the uploaded image for a text prompt.

    Expects multipart/form-data with an ``image`` file and a ``prompt`` field.
    The optional ``overlay`` query parameter (``true``/``1``/``yes``,
    case-insensitive) returns the photo with masked pixels painted red instead
    of the bare mask.

    Returns:
        A PNG response on success. Otherwise a JSON ``{"error": ...}`` body with
        status 400 (missing/empty fields or undecodable image), 415 (non-image
        content type), 422 (``segment`` raised ValueError, e.g. no box found),
        or 500 (any other failure).
    """
    if "image" not in request.files or "prompt" not in request.form:
        return {"error": "image file and prompt are required."}, 400
    prompt = request.form["prompt"]
    image = request.files["image"]
    if not prompt.strip():
        return {"error": "prompt must not be empty."}, 400

    # content_type is None when the client omits the part's Content-Type header.
    if not (image.content_type or "").startswith("image/"):
        return {"error": "unsupported file type. Only image files are allowed."}, 415

    try:
        # convert() forces a full decode, so corrupt/truncated files fail here.
        image_pil = Image.open(image.stream).convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:  # OSError covers UnidentifiedImageError
        app.logger.warning("Could not decode uploaded image: %s", e)
        return {"error": f"could not decode image: {e}"}, 400

    try:
        mask = segment(image_pil, prompt)
    except ValueError as e:
        app.logger.error("ValueError: %s", e)
        return {"error": str(e)}, 422
    except Exception as e:
        app.logger.exception("Exception in /segment endpoint")
        return {"error": str(e)}, 500  # Return actual exception message

    overlay = request.args.get("overlay", "false").strip().lower() in {"true", "1", "yes"}
    mask_np = np.asarray(mask, dtype=bool)
    if overlay:
        colored = np.array(image_pil)  # np.array copies, so image_pil is untouched
        colored[mask_np] = [255, 0, 0]  # paint masked pixels red
        out_img = Image.fromarray(colored)
    else:
        # 8-bit single-channel image: 255 where masked, 0 elsewhere.
        out_img = Image.fromarray(mask_np.astype(np.uint8) * 255)

    buf = io.BytesIO()
    out_img.save(buf, format="PNG")
    buf.seek(0)
    return send_file(buf, mimetype="image/png")
# --- CLI ----------------------------------------------------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Serve the Grounded-SAM segmentation API.")
    parser.add_argument("--host", default="127.0.0.1", help="Interface to bind (default: 127.0.0.1).")
    parser.add_argument("--port", default=7860, type=int, help="Port to listen on (default: 7860).")
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable Flask debug mode (interactive debugger + auto-reloader). "
        "Never enable on a publicly reachable host.",
    )
    args = parser.parse_args()
    # Debug mode is opt-in: it exposes the Werkzeug debugger, and its reloader
    # re-executes this script in a child process, loading both models again.
    app.run(host=args.host, port=args.port, debug=args.debug)
|