IZERE HIRWA Roger committed on
Commit
5ee91ca
·
1 Parent(s): 85b6fb8
Files changed (2)
  1. app.py +15 -5
  2. result.png +5 -0
app.py CHANGED
@@ -19,7 +19,7 @@ from flask import Flask, request, send_file
 from flask_cors import CORS
 
 import torch
-from groundingdino.util.inference import Model as GroundingModel
+from groundingdino.util.inference import load_model, predict
 from segment_anything import sam_model_registry, SamPredictor
 
 # ─── Load models once ───────────────────────────────────────────────────────────
@@ -29,7 +29,8 @@ DINO_CONFIG = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
 DINO_CKPT = "weights/groundingdino_swint_ogc.pth"
 SAM_CKPT = "weights/sam_vit_h_4b8939.pth"
 
-grounder = GroundingModel(model_config_path=DINO_CONFIG, model_checkpoint_path=DINO_CKPT, device=device)
+# Load GroundingDINO model
+grounder = load_model(DINO_CONFIG, DINO_CKPT)
 sam = sam_model_registry["vit_h"](checkpoint=SAM_CKPT).to(device)
 predictor = SamPredictor(sam)
 
@@ -38,14 +39,23 @@ app = Flask(__name__)
 CORS(app)
 
 def segment(image_pil: Image.Image, prompt: str):
-    # 1) Run GroundingDINO to get boxes for the prompt
-    boxes, _, _ = grounder.predict(image_pil, prompt=prompt, box_threshold=0.3, text_threshold=0.25)
+    # Convert PIL image to numpy array
+    image_np = np.array(image_pil)
+
+    # Run GroundingDINO to get boxes for the prompt
+    boxes, _, _ = predict(
+        model=grounder,
+        image=image_np,
+        caption=prompt,
+        box_threshold=0.3,
+        text_threshold=0.25
+    )
     if boxes.size == 0:
         raise ValueError("No boxes found for prompt.")
 
     # 2) Largest box → mask via SAM
     box = boxes[np.argmax((boxes[:,2]-boxes[:,0])*(boxes[:,3]-boxes[:,1]))]
-    predictor.set_image(np.array(image_pil))
+    predictor.set_image(image_np)
     masks, _, _ = predictor.predict(box=box)
     mask = masks[0]  # boolean HxW
 
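A note on the new API: in the upstream GroundingDINO repo, `load_model` and `predict` do come from `groundingdino.util.inference`, but there `predict` expects the normalized image tensor produced by `load_image`, not a raw numpy array, and it returns boxes as normalized cxcywh torch tensors (on which `boxes.size == 0` is not a valid emptiness check, since `Tensor.size` is a method). Whether the GroundingDINO version pinned in this Space behaves differently is worth verifying. For reference, a minimal end-to-end sketch following the upstream README, reusing the config and checkpoint paths from this commit; the image path and prompt below are placeholders:

import numpy as np
import torch
from torchvision.ops import box_convert
from groundingdino.util.inference import load_model, load_image, predict
from segment_anything import sam_model_registry, SamPredictor

device = "cuda" if torch.cuda.is_available() else "cpu"

grounder = load_model(
    "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
    "weights/groundingdino_swint_ogc.pth",
)
sam = sam_model_registry["vit_h"](checkpoint="weights/sam_vit_h_4b8939.pth").to(device)
predictor = SamPredictor(sam)

# load_image returns both an HxWx3 RGB array (for SAM) and the
# normalized tensor that GroundingDINO's predict() expects.
image_np, image_tensor = load_image("input.jpg")  # placeholder path

boxes, logits, phrases = predict(
    model=grounder,
    image=image_tensor,   # transformed tensor, not the raw array
    caption="a dog",      # placeholder prompt
    box_threshold=0.3,
    text_threshold=0.25,
    device=device,
)
if boxes.shape[0] == 0:
    raise ValueError("No boxes found for prompt.")

# GroundingDINO returns normalized cxcywh boxes; SAM wants pixel xyxy.
h, w = image_np.shape[:2]
boxes_xyxy = box_convert(boxes * torch.tensor([w, h, w, h]), "cxcywh", "xyxy").numpy()

# Keep the largest box, mirroring segment() above.
areas = (boxes_xyxy[:, 2] - boxes_xyxy[:, 0]) * (boxes_xyxy[:, 3] - boxes_xyxy[:, 1])
box = boxes_xyxy[np.argmax(areas)]

predictor.set_image(image_np)
masks, _, _ = predictor.predict(box=box, multimask_output=False)
mask = masks[0]  # boolean HxW

Converting to pixel xyxy before SamPredictor.predict matters because SAM interprets box prompts in absolute pixel coordinates.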
result.png ADDED
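The route handler itself is outside this diff, but given the `Flask`, `request`, and `send_file` imports and the newly committed result.png, the app presumably exposes an endpoint along these lines. This is a hypothetical sketch: the route path, form field names, and PNG serialization are illustrative assumptions, not taken from the commit:

import io
from PIL import Image

# Hypothetical route sketch; endpoint and field names are illustrative.
@app.route("/segment", methods=["POST"])
def segment_endpoint():
    image_pil = Image.open(request.files["image"].stream).convert("RGB")
    prompt = request.form.get("prompt", "")
    mask = segment(image_pil, prompt)  # boolean HxW array from segment() above

    # Serialize the mask as a PNG and stream it back to the client.
    buf = io.BytesIO()
    Image.fromarray(mask.astype("uint8") * 255).save(buf, format="PNG")
    buf.seek(0)
    return send_file(buf, mimetype="image/png")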