"""
Grounded‑SAM Flask API (CPU only)

POST /segment
Body (multipart/form-data):
  - image: the house photo
  - prompt: text prompt, e.g. "roof sheet"
Query params:
  - overlay (bool, default=false): if true, returns the input image with the
    mask drawn in red instead of the raw single-channel mask
Returns:
  - image/png: single-channel mask  OR  red overlay on the input image
"""
import io
import argparse
import numpy as np
from PIL import Image
from flask import Flask, request, send_file
from flask_cors import CORS

import torch
import groundingdino.datasets.transforms as T
from groundingdino.util import box_ops
from groundingdino.util.inference import load_model, predict
from segment_anything import sam_model_registry, SamPredictor

# ─── Load models once ───────────────────────────────────────────────────────────
device = torch.device("cpu")

DINO_CONFIG = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
DINO_CKPT = "weights/groundingdino_swint_ogc.pth"
SAM_CKPT  = "weights/sam_vit_h_4b8939.pth"
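# NOTE: the checkpoints above are not bundled here and must be downloaded
# separately; the filenames match the official GroundingDINO Swin-T and
# SAM ViT-H release artifacts.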

# Load the GroundingDINO detector and the SAM predictor once at startup
grounder = load_model(DINO_CONFIG, DINO_CKPT, device="cpu")
sam = sam_model_registry["vit_h"](checkpoint=SAM_CKPT).to(device)
predictor = SamPredictor(sam)

# ─── Flask app ──────────────────────────────────────────────────────────────────
app = Flask(__name__)
CORS(app)

def segment(image_pil: Image.Image, prompt: str) -> np.ndarray:
    """Return a boolean HxW mask for the object described by `prompt`."""
    image_np = np.array(image_pil)  # uint8 RGB, HxWx3
    height, width = image_np.shape[:2]

    # GroundingDINO expects its own preprocessing (resize, ToTensor, ImageNet
    # normalization) and a CHW tensor, not a raw [0, 1] image.
    transform = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image_tensor, _ = transform(image_pil, None)

    # Run GroundingDINO to get boxes for the prompt (normalized cxcywh).
    boxes, _, _ = predict(
        model=grounder,
        image=image_tensor,
        caption=prompt,
        box_threshold=0.3,
        text_threshold=0.25,
        device="cpu",  # explicitly run on CPU
    )
    if len(boxes) == 0:
        raise ValueError("No boxes found for prompt.")

    # Convert normalized cxcywh boxes to absolute xyxy pixel coordinates.
    boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.tensor(
        [width, height, width, height]
    )

    # Largest box -> mask via SAM.
    areas = (boxes_xyxy[:, 2] - boxes_xyxy[:, 0]) * (boxes_xyxy[:, 3] - boxes_xyxy[:, 1])
    box = boxes_xyxy[int(torch.argmax(areas))].numpy()

    predictor.set_image(image_np)
    masks, _, _ = predictor.predict(box=box, multimask_output=False)
    mask = masks[0].astype(bool)  # boolean HxW

    return mask

@app.route("/segment", methods=["POST"])
def segment_endpoint():
    if "image" not in request.files or "prompt" not in request.form:
        return {"error": "image file and prompt are required."}, 400

    prompt = request.form["prompt"]
    image = request.files["image"]

    # Reject non-image uploads early.
    if not (image.content_type or "").startswith("image/"):
        return {"error": "unsupported file type. Only image files are allowed."}, 415

    try:
        image_pil = Image.open(image.stream).convert("RGB")
        mask = segment(image_pil, prompt)
    except ValueError as e:
        return {"error": str(e)}, 422
    except Exception:
        app.logger.exception("Segmentation failed")
        return {"error": "segmentation failed due to an internal error."}, 500

    overlay = request.args.get("overlay", "false").lower() == "true"
    if overlay:
        colored = np.array(image_pil).copy()
        colored[mask] = [255, 0, 0]  # red overlay
        out_img = Image.fromarray(colored)
    else:
        out_img = Image.fromarray((mask * 255).astype(np.uint8))

    buf = io.BytesIO()
    out_img.save(buf, format="PNG")
    buf.seek(0)
    return send_file(buf, mimetype="image/png")

# ─── CLI ────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", default=7860, type=int)
    args = parser.parse_args()
    app.run(host=args.host, port=args.port)
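
# Example launch (a sketch; "app.py" is an assumed file name for this module):
#
#   python app.py --host 0.0.0.0 --port 7860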