File size: 4,401 Bytes
d41ddc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5cfe82
d41ddc1
 
5ee91ca
d41ddc1
46331c0
d41ddc1
 
 
 
bc2ee7a
d41ddc1
 
 
5ee91ca
 
d41ddc1
 
 
 
 
 
f5cfe82
d41ddc1
 
46331c0
 
 
 
 
 
 
 
 
5ee91ca
 
 
8dfdfe9
46331c0
5ee91ca
 
7d4aa82
46331c0
5ee91ca
d41ddc1
 
 
 
 
cbb1938
d41ddc1
 
 
 
 
 
 
 
 
 
 
cdcf202
 
 
 
 
d41ddc1
 
cdcf202
 
d41ddc1
f5cfe82
d41ddc1
cdcf202
f5cfe82
 
d41ddc1
 
 
cdcf202
7dc5d22
 
d41ddc1
 
7dc5d22
d41ddc1
 
 
 
 
 
 
 
 
 
 
 
f5cfe82
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
Grounded‑SAM Flask API (CPU only)

POST /segment
Body (multipart/form-data):
  - image: the house photo
  - prompt: text prompt, e.g. "roof sheet"
Query params:
  - overlay (bool, default=false): if true, returns a PNG overlay instead
Returns:
  - image/png mask (single channel)  OR  overlay
"""
import argparse
import io
import logging
import os
import threading

import numpy as np
import torch
from flask import Flask, request, send_file
from flask_cors import CORS
from PIL import Image

import groundingdino.datasets.transforms as T
from groundingdino.util.inference import load_model, predict
from segment_anything import sam_model_registry, SamPredictor

# ─── Load models once ───────────────────────────────────────────────────────────
# This service is CPU-only; models and input tensors are placed on this device.
device = torch.device("cpu")

# Default paths are relative to the current working directory. Each one can be
# overridden with an environment variable of the same name.
DINO_CONFIG = os.environ.get(
    "DINO_CONFIG", "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
)
DINO_CKPT = os.environ.get("DINO_CKPT", "weights/groundingdino_swint_ogc.pth")
SAM_CKPT = os.environ.get("SAM_CKPT", "weights/sam_vit_h_4b8939.pth")

# GroundingDINO's load_model() defaults to device="cuda"; pass the CPU device
# explicitly so the model is built for the hardware this service runs on.
grounder = load_model(DINO_CONFIG, DINO_CKPT, device=str(device))
# SAM with the ViT-H backbone; `predictor` is used as set_image() then predict().
sam = sam_model_registry["vit_h"](checkpoint=SAM_CKPT).to(device)
predictor = SamPredictor(sam)

# ─── Flask app ──────────────────────────────────────────────────────────────────
app = Flask(__name__)
# flask_cors with no options: CORS headers are added to every route.
CORS(app)
# Debug mode enables Werkzeug's interactive debugger, which lets anyone who can
# reach the server run arbitrary Python. Keep it opt-in via FLASK_DEBUG.
app.config["DEBUG"] = os.environ.get("FLASK_DEBUG", "0").lower() in ("1", "true", "yes")

# `predictor` is a single module-level object whose predict() works on the
# image passed to the preceding set_image(). The server may handle requests
# concurrently, so this lock keeps each request's set_image()/predict() pair
# together.
_PREDICTOR_LOCK = threading.Lock()

# GroundingDINO preprocessing: resize (target 800, max 1333), tensor
# conversion, ImageNet mean/std normalisation. Built once instead of per call.
_DINO_TRANSFORM = T.Compose([
    T.RandomResize([800], max_size=1333),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def _largest_box_xyxy(boxes_cxcywh: np.ndarray, width: int, height: int) -> np.ndarray:
    """Pick the largest box and convert it to absolute pixel ``[x0, y0, x1, y1]``.

    Args:
        boxes_cxcywh: (N, 4) array of normalised ``(cx, cy, w, h)`` boxes, the
            format GroundingDINO's ``predict`` returns. N must be >= 1.
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        float32 array of shape (4,), clipped to the image bounds.
    """
    scale = np.array([width, height, width, height], dtype=np.float32)
    cx, cy, w, h = (np.asarray(boxes_cxcywh, dtype=np.float32) * scale).T
    best = int(np.argmax(w * h))  # in (cx, cy, w, h) form the area is just w * h
    half_w, half_h = w[best] / 2, h[best] / 2
    box = np.array(
        [cx[best] - half_w, cy[best] - half_h, cx[best] + half_w, cy[best] + half_h],
        dtype=np.float32,
    )
    return np.clip(box, 0, scale)


def segment(image_pil: Image.Image, prompt: str) -> np.ndarray:
    """Return a boolean HxW mask of the largest object matching *prompt*.

    GroundingDINO proposes boxes for the text prompt. The box with the largest
    area is converted to pixel coordinates and refined into a mask by SAM.

    Args:
        image_pil: RGB input image.
        prompt: Free-text description of the object to segment.

    Returns:
        Boolean numpy array with the same height and width as *image_pil*.

    Raises:
        ValueError: If GroundingDINO returns no box above the thresholds.
    """
    image_transformed, _ = _DINO_TRANSFORM(image_pil, None)
    image_transformed = image_transformed.to(device)

    # 1) Text prompt -> candidate boxes with GroundingDINO.
    boxes, _, _ = predict(
        model=grounder,
        image=image_transformed,
        caption=prompt,
        box_threshold=0.3,
        text_threshold=0.25,
        device="cpu",
    )
    # `boxes` is a torch tensor. Convert it before testing for emptiness; on a
    # tensor `.size` is a method, so `boxes.size == 0` would never be true.
    boxes_np = boxes.cpu().numpy()
    if boxes_np.shape[0] == 0:
        raise ValueError("No boxes found for prompt.")

    # 2) Largest box -> mask via SAM. SAM expects a numpy XYXY box in pixel
    #    coordinates of the image passed to set_image().
    width, height = image_pil.size
    box = _largest_box_xyxy(boxes_np, width, height)
    with _PREDICTOR_LOCK:
        predictor.set_image(np.array(image_pil))
        # One box identifies a single object, so request one mask instead of
        # several candidates followed by an arbitrary pick of the first.
        masks, _, _ = predictor.predict(box=box, multimask_output=False)
    return masks[0]  # boolean HxW

@app.route("/segment", methods=["POST"])
def segment_endpoint():
    """Segment the uploaded image with the text prompt and return a PNG.

    Expects a multipart form with an ``image`` file and a ``prompt`` field.
    With ``?overlay=true`` the response is the RGB photo with the masked
    pixels painted red; otherwise it is a single-channel mask (255 = object).

    Error responses are JSON ``{"error": ...}`` with status 400 (missing or
    blank field, undecodable image), 415 (upload is not an image), 422
    (nothing matched the prompt) or 500 (unexpected failure).
    """
    if "image" not in request.files or "prompt" not in request.form:
        return {"error": "image file and prompt are required."}, 400

    prompt = request.form["prompt"].strip()
    image = request.files["image"]
    if not prompt:
        return {"error": "image file and prompt are required."}, 400

    # `mimetype` is lower-cased and "" when the client sent no Content-Type,
    # whereas `content_type` can be None and would crash on `.startswith`.
    if not image.mimetype.startswith("image/"):
        return {"error": "unsupported file type. Only image files are allowed."}, 415

    try:
        image_pil = Image.open(image.stream).convert("RGB")
    except OSError as e:  # includes PIL's UnidentifiedImageError and truncated files
        app.logger.warning("Could not decode uploaded image: %s", e)
        return {"error": "could not decode the uploaded image."}, 400

    try:
        mask = segment(image_pil, prompt)
    except ValueError as e:
        app.logger.error("ValueError: %s", e)
        return {"error": str(e)}, 422
    except Exception as e:
        app.logger.exception("Exception in /segment endpoint")
        # NOTE(review): this exposes internal error text to clients; consider
        # a generic message if the API is reachable by untrusted callers.
        return {"error": str(e)}, 500  # Return actual exception message

    # Boolean indexing below needs a real bool array (an integer array would
    # be treated as fancy row indices).
    mask_np = mask.cpu().numpy() if hasattr(mask, "cpu") else np.asarray(mask)
    mask_np = mask_np.astype(bool, copy=False)

    if request.args.get("overlay", "false").lower() == "true":
        colored = np.array(image_pil)  # np.array copies, leaving image_pil untouched
        colored[mask_np] = [255, 0, 0]  # red overlay
        out_img = Image.fromarray(colored)
    else:
        out_img = Image.fromarray(mask_np.astype(np.uint8) * 255)

    buf = io.BytesIO()
    out_img.save(buf, format="PNG")
    buf.seek(0)
    return send_file(buf, mimetype="image/png")

# ─── CLI ────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", default=7860, type=int)
    args = parser.parse_args()
    app.run(host=args.host, port=args.port, debug=True)