valegro committed
Commit ed309ba · verified · 1 Parent(s): cf5fa6f

Update app.py

Files changed (1)
  1. app.py +51 -93
app.py CHANGED
@@ -1,96 +1,54 @@
-import streamlit as st
-import numpy as np
-import torch
 from PIL import Image
-import cv2
-import matplotlib.pyplot as plt
-from huggingface_hub import hf_hub_download
-from segment_anything import SamPredictor, sam_model_registry
-from groundingdino.util.inference import load_model, predict, annotate
-
-# App title
-st.title("🔍 Zero-Shot Recognition with GroundingDINO + SAM")
-
-# Model setup from the Hugging Face Hub
-@st.cache_resource
-def load_sam():
-    checkpoint = hf_hub_download(
-        repo_id="SegmentAnything/sam_vit_b",
-        filename="sam_vit_b_01ec64.pth"
-    )
-    model = sam_model_registry["vit_b"](checkpoint=checkpoint)
-    return SamPredictor(model.to("cuda" if torch.cuda.is_available() else "cpu"))
-
-@st.cache_resource
-def load_grounding_dino():
-    config_path = hf_hub_download(
-        repo_id="IDEA-Research/grounding-dino-tiny",
-        filename="GroundingDINO_SwinT_OGC.py"
-    )
-    checkpoint_path = hf_hub_download(
-        repo_id="IDEA-Research/grounding-dino-tiny",
-        filename="groundingdino_tiny.pth"
-    )
-    model = load_model(config_path, checkpoint_path)
-    return model
-
-sam = load_sam()
-grounding_dino = load_grounding_dino()
-
-# User image upload
-uploaded_image = st.file_uploader("📷 Upload an image", type=['jpg', 'jpeg', 'png'])
-
-prompt = st.text_input("📝 Enter the classes to detect (comma-separated)",
-                       value="lamiera, foro circolare, vite, bullone, scanalatura")
-
-if uploaded_image is not None:
-    image = Image.open(uploaded_image).convert("RGB")
-    img_array = np.array(image)
-
-    st.image(image, caption="Uploaded image", use_column_width=True)
-
-    if st.button("▶️ Run detection"):
-        # GroundingDINO prediction
-        boxes, logits, phrases = predict(
-            model=grounding_dino,
-            image=img_array,
-            caption=prompt,
-            box_threshold=0.3,
-            text_threshold=0.25,
-            device="cuda" if torch.cuda.is_available() else "cpu"
-        )
-
-        annotated_frame = annotate(image_source=img_array, boxes=boxes, logits=logits, phrases=phrases)
-
-        st.subheader("GroundingDINO result")
-        st.image(annotated_frame, caption="GroundingDINO annotation")
-
-        # SAM segmentation
-        sam.set_image(img_array)
-        H, W, _ = img_array.shape
-        boxes_scaled = boxes * torch.tensor([W, H, W, H], device=boxes.device)
-        boxes_scaled = boxes_scaled.cpu().numpy()
-
-        masks, scores, _ = sam.predict_torch(
-            point_coords=None,
-            point_labels=None,
-            boxes=torch.tensor(boxes_scaled, device=sam.device),
-            multimask_output=False,
         )

-        # Display the segmented masks
-        st.subheader("Segment Anything (SAM) result")
-        plt.figure(figsize=(10, 10))
-        plt.imshow(img_array)
-        for mask in masks:
-            mask_np = mask[0].cpu().numpy()
-            plt.contour(mask_np, colors='red', linewidths=1.5)
-        plt.axis('off')
-
-        st.pyplot(plt.gcf())
-        plt.close()
-
-        # Results table
-        st.subheader("🔖 Results table")
-        result_data = [{"Class": phrase, "Confidence": round(logit.item(), 2)} for phrase, logit in zip(phrases, logits)]
-        st.table(result_data)
 
+import gradio as gr, numpy as np
+from utils import SAM, GD
+from groundingdino.util.utils import clean_text
 from PIL import Image
+import cv2, torch
+
+def pipeline(image, prompt):
+    # 1. Segment with SAM (the utils wrapper is assumed to handle fully unprompted prediction)
+    img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+    SAM.set_image(img_cv)
+    masks, _, _ = SAM.predict(box=None, point_coords=None, point_labels=None, multimask_output=False)
+
+    annotated = np.array(image)  # cv2 drawing functions need a NumPy array, not a PIL Image
+    boxes = []
+
+    # Derive a bounding box from each mask
+    for m in masks:
+        coords = np.argwhere(m)
+        y1, x1 = coords.min(0)
+        y2, x2 = coords.max(0)
+        boxes.append(np.array([x1, y1, x2, y2]))
+
+    if boxes:
+        # 2. Zero-shot grounding with GroundingDINO
+        dino_out = GD.predict_with_caption(
+            image=np.array(image),
+            captions=[prompt] * len(boxes),
+            boxes=np.vstack(boxes)
         )
+        for box, text in zip(dino_out["boxes"], dino_out["captions"]):
+            x1, y1, x2, y2 = map(int, box)
+            cv2.rectangle(annotated, (x1, y1), (x2, y2), (255, 0, 0), 2)
+            cv2.putText(annotated, clean_text(text), (x1, y1 - 6),
+                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 2)
+
+    return Image.fromarray(annotated)
+
+demo = gr.Interface(
+    fn=pipeline,
+    inputs=[
+        gr.Image(type="pil"),
+        gr.Textbox(value="lamiera, foro circolare, vite, bullone, scanalatura")
+    ],
+    outputs=gr.Image(type="pil"),
+    title="Zero-Shot Mechanical Part Finder",
+    description=(
+        "Upload a photo of end-of-life mechanical components and enter the labels "
+        "you want to search for (comma-separated). The system segments with SAM "
+        "and grounds the results zero-shot with GroundingDINO."
+    )
+)

+if __name__ == "__main__":
+    demo.launch()
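
The new app.py depends on a project-local utils module (providing SAM and GD) that is not part of this commit. Below is a minimal sketch of what such a module could look like, assuming it reuses the same Hub checkpoints the removed Streamlit version downloaded; the _GroundingDINO wrapper class, its predict_with_caption return shape (a dict with "boxes" and "captions" keys), and the single-caption simplification are all inferred from the call sites in app.py, not taken from the real utils.py.

# utils.py — hypothetical sketch, not part of this commit
import numpy as np
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from segment_anything import SamPredictor, sam_model_registry
from groundingdino.util.inference import load_model, predict
import groundingdino.datasets.transforms as T
from torchvision.ops import box_convert

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def _load_sam():
    # Same repo/filename the removed app.py downloaded
    checkpoint = hf_hub_download(repo_id="SegmentAnything/sam_vit_b",
                                 filename="sam_vit_b_01ec64.pth")
    model = sam_model_registry["vit_b"](checkpoint=checkpoint)
    return SamPredictor(model.to(DEVICE))

class _GroundingDINO:
    def __init__(self):
        # Config/checkpoint IDs mirror the removed app.py
        config = hf_hub_download(repo_id="IDEA-Research/grounding-dino-tiny",
                                 filename="GroundingDINO_SwinT_OGC.py")
        ckpt = hf_hub_download(repo_id="IDEA-Research/grounding-dino-tiny",
                               filename="groundingdino_tiny.pth")
        self.model = load_model(config, ckpt)
        # Standard GroundingDINO preprocessing (same as load_image in the repo)
        self.transform = T.Compose([
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def predict_with_caption(self, image, captions, boxes=None):
        # app.py passes one caption per SAM box, but every entry carries the same
        # prompt, so a single forward pass suffices; `boxes` is accepted only for
        # signature compatibility and ignored here.
        tensor, _ = self.transform(Image.fromarray(image), None)
        pred_boxes, logits, phrases = predict(
            model=self.model, image=tensor, caption=captions[0],
            box_threshold=0.3, text_threshold=0.25, device=DEVICE)
        h, w, _ = image.shape
        # predict() returns normalized cxcywh boxes; convert to pixel xyxy for drawing
        xyxy = box_convert(pred_boxes * torch.tensor([w, h, w, h]),
                           in_fmt="cxcywh", out_fmt="xyxy").numpy()
        return {"boxes": xyxy, "captions": phrases}

# NB: a plain SamPredictor requires at least one prompt (point or box); the fully
# unprompted SAM.predict(...) call in app.py suggests the real utils module wraps
# something like SamAutomaticMaskGenerator instead.
SAM = _load_sam()
GD = _GroundingDINO()

Loading both models at import time plays the same role as the @st.cache_resource decorators in the removed version: the heavy checkpoints are fetched once per process rather than once per request. One thing worth verifying against the installed GroundingDINO revision is the clean_text helper that app.py imports from groundingdino.util.utils; the commonly used helpers in that module are clean_state_dict and get_phrases_from_posmap.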