dheena committed
Commit · e8ee9c0
Parent(s): 7c5bfda

initial commit

Browse files
- requirements.txt +18 -1
- src/model.py +65 -0
- src/segmentation.py +193 -0
- src/streamlit_app.py +37 -38
requirements.txt
CHANGED
@@ -1,3 +1,20 @@
+# Existing packages
 altair
 pandas
-streamlit
+streamlit
+
+# New packages based on imports
+faiss-cpu                                 # or faiss-gpu if running on a GPU
+torch                                     # PyTorch
+ftfy                                      # often required by CLIP
+git+https://github.com/openai/CLIP.git    # OpenAI CLIP
+openai                                    # OpenAI API client
+numpy
+Pillow                                    # PIL
+fastapi
+segmentation-models-pytorch               # note: model.py's `import segmentation` is the local src/segmentation.py, so this package is likely unnecessary
+
+# Additional utilities
+opencv-python                             # cv2
+requests
+transformers                              # Hugging Face Transformers
src/model.py
ADDED
@@ -0,0 +1,65 @@
import faiss
import torch
import clip
from openai import OpenAI
import numpy as np
from PIL import Image
from fastapi import FastAPI
from typing import List, Union
import segmentation

client = OpenAI()  # requires OPENAI_API_KEY in the environment; not yet used elsewhere in this module
device = "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

def get_image_features(image: Image.Image) -> np.ndarray:
    """Extract CLIP features from an image."""
    image_input = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image_input).float()
    return image_features.cpu().numpy()

# FAISS setup: inner-product index over 512-d CLIP ViT-B/32 embeddings
index = faiss.IndexFlatIP(512)
meta_data_store = []

def save_image_in_index(image_features: np.ndarray, metadata: dict):
    """Normalize features and add them to the index."""
    faiss.normalize_L2(image_features)
    index.add(image_features)
    meta_data_store.append(metadata)

def process_image_embedding(image_url: Union[str, Image.Image], labels=['clothes']) -> Image.Image:
    """Segment and crop the query image, returning a PIL image ready for CLIP."""
    search_image, search_detections = segmentation.grounded_segmentation(image=image_url, labels=labels)
    cropped_image = segmentation.cut_image(search_image, search_detections[0].mask, search_detections[0].box)

    # Convert to valid RGB
    if cropped_image.dtype != np.uint8:
        cropped_image = (cropped_image * 255).astype(np.uint8)
    if cropped_image.ndim == 2:
        cropped_image = np.stack([cropped_image] * 3, axis=-1)

    pil_image = Image.fromarray(cropped_image)
    return pil_image

def get_top_k_results(image_url: str, k: int = 10) -> List[dict]:
    """Find the top-k most similar images in the index."""
    processed_image = process_image_embedding(image_url)
    image_search_embedding = get_image_features(processed_image)
    faiss.normalize_L2(image_search_embedding)
    distances, indices = index.search(image_search_embedding.reshape(1, -1), k)

    results = []
    for i, dist in zip(indices[0], distances[0]):
        if 0 <= i < len(meta_data_store):  # FAISS returns -1 for missing neighbors
            results.append({
                'metadata': meta_data_store[i],
                'score': float(dist)
            })
    return results
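For reference, a minimal usage sketch of the module above; the URLs, metadata, and the assumption of running from src/ are illustrative, not part of the commit:

# Hypothetical usage sketch (not part of the commit); assumes it is run from src/
# so that `import model` resolves, and that OPENAI_API_KEY is set.
import model

catalog_url = "https://example.com/catalog/jacket.jpg"    # placeholder URL
query_url = "https://example.com/query/street_photo.jpg"  # placeholder URL

# Index a catalog image: segment and crop it, embed it with CLIP, store its metadata.
cropped = model.process_image_embedding(catalog_url)
features = model.get_image_features(cropped)
model.save_image_in_index(features, {"sku": "JKT-001", "url": catalog_url})

# Query: top-k most similar indexed items by inner product over L2-normalized embeddings.
for result in model.get_top_k_results(query_url, k=5):
    print(result["score"], result["metadata"])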
src/segmentation.py
ADDED
@@ -0,0 +1,193 @@
from dataclasses import dataclass
from typing import Any, List, Dict, Optional, Union, Tuple
import os

import cv2
import torch
import requests
import numpy as np
from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline


@dataclass
class BoundingBox:
    xmin: int
    ymin: int
    xmax: int
    ymax: int

    @property
    def xyxy(self) -> List[float]:
        return [self.xmin, self.ymin, self.xmax, self.ymax]

@dataclass
class DetectionResult:
    score: float
    label: str
    box: BoundingBox
    mask: Optional[np.ndarray] = None

    @classmethod
    def from_dict(cls, detection_dict: Dict) -> 'DetectionResult':
        return cls(score=detection_dict['score'],
                   label=detection_dict['label'],
                   box=BoundingBox(xmin=detection_dict['box']['xmin'],
                                   ymin=detection_dict['box']['ymin'],
                                   xmax=detection_dict['box']['xmax'],
                                   ymax=detection_dict['box']['ymax']))

def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
    # Find contours in the binary mask
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Find the contour with the largest area
    largest_contour = max(contours, key=cv2.contourArea)

    # Extract the vertices of the contour
    polygon = largest_contour.reshape(-1, 2).tolist()

    return polygon

def polygon_to_mask(polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]) -> np.ndarray:
    """
    Convert a polygon to a segmentation mask.

    Args:
    - polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
    - image_shape (tuple): Shape of the image (height, width) for the mask.

    Returns:
    - np.ndarray: Segmentation mask with the polygon filled.
    """
    # Create an empty mask
    mask = np.zeros(image_shape, dtype=np.uint8)

    # Convert polygon to an array of points
    pts = np.array(polygon, dtype=np.int32)

    # Fill the polygon with white color (255)
    cv2.fillPoly(mask, [pts], color=(255,))

    return mask

def load_image(image_str: str) -> Image.Image:
    if image_str.startswith("http"):
        image = Image.open(requests.get(image_str, stream=True).raw).convert("RGB")
    else:
        image = Image.open(image_str).convert("RGB")

    return image

def get_boxes(results: List[DetectionResult]) -> List[List[List[float]]]:
    boxes = []
    for result in results:
        xyxy = result.box.xyxy
        boxes.append(xyxy)

    return [boxes]

def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
    masks = masks.cpu().float()
    masks = masks.permute(0, 2, 3, 1)
    masks = masks.mean(axis=-1)
    masks = (masks > 0).int()
    masks = masks.numpy().astype(np.uint8)
    masks = list(masks)

    if polygon_refinement:
        for idx, mask in enumerate(masks):
            shape = mask.shape
            polygon = mask_to_polygon(mask)
            mask = polygon_to_mask(polygon, shape)
            masks[idx] = mask

    return masks


def detect(
    image: Image.Image,
    labels: List[str],
    threshold: float = 0.3,
    detector_id: Optional[str] = None
) -> List[DetectionResult]:
    """
    Use Grounding DINO to detect a set of labels in an image in a zero-shot fashion.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    detector_id = detector_id if detector_id is not None else "IDEA-Research/grounding-dino-tiny"
    object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=device)

    labels = [label if label.endswith(".") else label + "." for label in labels]

    results = object_detector(image, candidate_labels=labels, threshold=threshold)
    results = [DetectionResult.from_dict(result) for result in results]

    return results

def segment(
    image: Image.Image,
    detection_results: List[DetectionResult],
    polygon_refinement: bool = False,
    segmenter_id: Optional[str] = None
) -> List[DetectionResult]:
    """
    Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    segmenter_id = segmenter_id if segmenter_id is not None else "facebook/sam-vit-base"

    segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)
    processor = AutoProcessor.from_pretrained(segmenter_id)

    boxes = get_boxes(detection_results)
    inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(device)

    outputs = segmentator(**inputs)
    masks = processor.post_process_masks(
        masks=outputs.pred_masks,
        original_sizes=inputs.original_sizes,
        reshaped_input_sizes=inputs.reshaped_input_sizes
    )[0]

    masks = refine_masks(masks, polygon_refinement)

    for detection_result, mask in zip(detection_results, masks):
        detection_result.mask = mask

    return detection_results

def grounded_segmentation(
    image: Union[Image.Image, str],
    labels: List[str],
    threshold: float = 0.3,
    polygon_refinement: bool = False,
    detector_id: Optional[str] = None,
    segmenter_id: Optional[str] = None
) -> Tuple[Image.Image, List[DetectionResult]]:
    if isinstance(image, str):
        image = load_image(image)

    detections = detect(image, labels, threshold, detector_id)
    detections = segment(image, detections, polygon_refinement, segmenter_id)

    return image, detections


# Crop the detected object out of the image using its mask and bounding box
def cut_image(image, mask, box):
    np_image = np.array(image)
    cut = cv2.bitwise_and(np_image, np_image, mask=mask.astype(np.uint8) * 255)
    x0, y0, x1, y1 = map(int, box.xyxy)
    cropped = cut[y0:y1, x0:x1]
    # Keep the crop in RGB; it is fed back into PIL/CLIP rather than written with cv2
    return cropped
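A minimal standalone sketch of the grounded-segmentation flow; the URL and output filename are illustrative placeholders:

# Hypothetical usage sketch (not part of the commit); assumes it is run from src/.
from PIL import Image
import segmentation

url = "https://example.com/street_photo.jpg"  # placeholder URL

# Grounding DINO finds boxes for the label, SAM turns each box into a mask.
image, detections = segmentation.grounded_segmentation(image=url, labels=["clothes"])

# Crop the first detection out of the image using its mask and bounding box.
first = detections[0]
crop = segmentation.cut_image(image, first.mask, first.box)
Image.fromarray(crop).save("crop.png")  # RGB crop, placeholder filename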
src/streamlit_app.py
CHANGED
@@ -1,40 +1,39 @@
-import altair as alt
-import numpy as np
-import pandas as pd
 import streamlit as st
+from PIL import Image
+import model
 
+
+image_url_input = st.text_input("Enter the image URL:")
+k_value_input = st.number_input("Enter k_value:", min_value=1, value=5)
+if st.button("Get Results"):
+    results = model.get_top_k_results(image_url_input, int(k_value_input))
+    st.json({"results": [{"metadata": r["metadata"], "score": r["score"]} for r in results]})
+
+
+if 'metadata_inputs' not in st.session_state:
+    st.session_state['metadata_inputs'] = {}
+
+uploaded_files = st.file_uploader("Choose images...", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
+
+if uploaded_files:
+    for uploaded_file in uploaded_files:
+        file_key = uploaded_file.name
+
+        image = Image.open(uploaded_file)
+
+        st.session_state['metadata_inputs'][file_key] = st.text_input(
+            f"Metadata for {uploaded_file.name}",
+            value=st.session_state['metadata_inputs'].get(file_key, ""),
+            key=f"metadata_{file_key}"
+        )
+
+    if st.button("Upload Images"):
+        for uploaded_file in uploaded_files:
+            metadata = st.session_state['metadata_inputs'][uploaded_file.name]
+            if metadata:
+                image = Image.open(uploaded_file)
+                cropped_image = model.process_image_embedding(image)
+                feature = model.get_image_features(cropped_image)
+                model.save_image_in_index(feature, metadata)
+        st.success("Images uploaded successfully.")
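Worth noting: model.py keeps its FAISS index and metadata store in process memory, so anything indexed through the uploader lives only for the current run; restarting the app (typically launched with `streamlit run src/streamlit_app.py`) starts from an empty index.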