dheena committed on
Commit e8ee9c0 · Parent: 7c5bfda

initial commit

Files changed (4)
  1. requirements.txt +18 -1
  2. src/model.py +65 -0
  3. src/segmentation.py +193 -0
  4. src/streamlit_app.py +37 -38
requirements.txt CHANGED
@@ -1,3 +1,20 @@
+# Existing packages
 altair
 pandas
-streamlit
+streamlit
+
+# New packages based on imports
+faiss-cpu  # or faiss-gpu if running on GPU
+torch  # PyTorch
+ftfy  # often required by CLIP
+git+https://github.com/openai/CLIP.git  # OpenAI CLIP
+openai  # OpenAI API client
+numpy
+Pillow  # PIL
+fastapi
+segmentation-models-pytorch  # NOTE: "import segmentation" in model.py is the local src/segmentation.py, not this package
+
+# Additional NLP/ML utilities
+opencv-python  # cv2
+requests
+transformers  # Hugging Face Transformers
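Since requirements.txt now carries the full ML stack, a quick smoke test helps confirm the environment resolves before launching the app. A minimal sketch (assumes the packages above installed cleanly; clip.available_models() only proves the CLIP install is live):

# Hypothetical smoke test: the newly added packages should all import.
import faiss, torch, clip, ftfy, openai, numpy, PIL, fastapi, cv2, requests, transformers
print("torch", torch.__version__, "| CLIP models:", clip.available_models()[:3])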
src/model.py ADDED
@@ -0,0 +1,58 @@
+import faiss
+import torch
+import clip
+import numpy as np
+from PIL import Image
+from typing import List, Union
+
+import segmentation  # local module: src/segmentation.py
+
+device = "cpu"
+model, preprocess = clip.load("ViT-B/32", device=device)
+
+def get_image_features(image: Image.Image) -> np.ndarray:
+    """Extract a CLIP embedding for an image."""
+    image_input = preprocess(image).unsqueeze(0).to(device)
+    with torch.no_grad():
+        image_features = model.encode_image(image_input).float()
+    return image_features.cpu().numpy()
+
+# FAISS setup: inner product over L2-normalized vectors equals cosine similarity.
+index = faiss.IndexFlatIP(512)  # ViT-B/32 embeddings are 512-dimensional
+meta_data_store = []
+
+def save_image_in_index(image_features: np.ndarray, metadata: dict):
+    """L2-normalize features and add them to the index."""
+    faiss.normalize_L2(image_features)
+    index.add(image_features)
+    meta_data_store.append(metadata)
+
+def process_image_embedding(image: Union[Image.Image, str], labels=['clothes']) -> Image.Image:
+    """Segment the query image, crop the first detection, and return a PIL image for CLIP."""
+    search_image, search_detections = segmentation.grounded_segmentation(image=image, labels=labels)
+    cropped_image = segmentation.cut_image(search_image, search_detections[0].mask, search_detections[0].box)
+
+    # Convert to valid RGB
+    if cropped_image.dtype != np.uint8:
+        cropped_image = (cropped_image * 255).astype(np.uint8)
+    if cropped_image.ndim == 2:
+        cropped_image = np.stack([cropped_image] * 3, axis=-1)
+
+    return Image.fromarray(cropped_image)
+
+def get_top_k_results(image_url: str, k: int = 10) -> List[dict]:
+    """Find the top-k most similar indexed images for a query image."""
+    processed_image = process_image_embedding(image_url)
+    image_search_embedding = get_image_features(processed_image)
+    faiss.normalize_L2(image_search_embedding)
+    distances, indices = index.search(image_search_embedding.reshape(1, -1), k)
+
+    results = []
+    for i, dist in zip(indices[0], distances[0]):
+        # FAISS pads missing results with -1, so guard both ends of the range.
+        if 0 <= i < len(meta_data_store):
+            results.append({
+                'metadata': meta_data_store[i],
+                'score': float(dist)
+            })
+    return results
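A minimal round-trip through this module, as a sketch: the catalog path, SKU metadata, and query URL below are hypothetical, and the first call downloads CLIP, Grounding DINO, and SAM weights.

from PIL import Image
import model

# Index one catalog image (hypothetical path and metadata)...
crop = model.process_image_embedding(Image.open("catalog/dress_001.jpg"))
model.save_image_in_index(model.get_image_features(crop), {"sku": "dress_001"})

# ...then search the index with a query image URL.
for hit in model.get_top_k_results("https://example.com/query.jpg", k=3):
    print(hit["metadata"], hit["score"])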
src/segmentation.py ADDED
@@ -0,0 +1,166 @@
+from dataclasses import dataclass
+from typing import List, Dict, Optional, Union, Tuple
+
+import cv2
+import torch
+import requests
+import numpy as np
+from PIL import Image
+from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
+
+@dataclass
+class BoundingBox:
+    xmin: int
+    ymin: int
+    xmax: int
+    ymax: int
+
+    @property
+    def xyxy(self) -> List[float]:
+        return [self.xmin, self.ymin, self.xmax, self.ymax]
+
+@dataclass
+class DetectionResult:
+    score: float
+    label: str
+    box: BoundingBox
+    mask: Optional[np.ndarray] = None
+
+    @classmethod
+    def from_dict(cls, detection_dict: Dict) -> 'DetectionResult':
+        return cls(score=detection_dict['score'],
+                   label=detection_dict['label'],
+                   box=BoundingBox(xmin=detection_dict['box']['xmin'],
+                                   ymin=detection_dict['box']['ymin'],
+                                   xmax=detection_dict['box']['xmax'],
+                                   ymax=detection_dict['box']['ymax']))
+
+def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
+    # Find contours in the binary mask
+    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+    # Keep the contour with the largest area
+    largest_contour = max(contours, key=cv2.contourArea)
+
+    # Extract the vertices of the contour
+    return largest_contour.reshape(-1, 2).tolist()
+
+def polygon_to_mask(polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]) -> np.ndarray:
+    """
+    Convert a polygon to a segmentation mask.
+
+    Args:
+        polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
+        image_shape (tuple): Shape of the image (height, width) for the mask.
+
+    Returns:
+        np.ndarray: Segmentation mask with the polygon filled.
+    """
+    mask = np.zeros(image_shape, dtype=np.uint8)
+    pts = np.array(polygon, dtype=np.int32)
+    cv2.fillPoly(mask, [pts], color=(255,))
+    return mask
+
+def load_image(image_str: str) -> Image.Image:
+    """Load an image from a URL or a local path."""
+    if image_str.startswith("http"):
+        return Image.open(requests.get(image_str, stream=True).raw).convert("RGB")
+    return Image.open(image_str).convert("RGB")
+
+def get_boxes(results: List[DetectionResult]) -> List[List[List[float]]]:
+    """Collect detection boxes in the nested per-image format SAM's processor expects."""
+    return [[result.box.xyxy for result in results]]
+
+def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
+    # Collapse SAM's per-box mask channels into one binary mask per detection.
+    masks = masks.cpu().float()
+    masks = masks.permute(0, 2, 3, 1)
+    masks = masks.mean(dim=-1)
+    masks = (masks > 0).int()
+    masks = list(masks.numpy().astype(np.uint8))
+
+    if polygon_refinement:
+        # Re-draw each mask from its largest contour to remove speckle.
+        for idx, mask in enumerate(masks):
+            polygon = mask_to_polygon(mask)
+            masks[idx] = polygon_to_mask(polygon, mask.shape)
+
+    return masks
+
+def detect(
+    image: Image.Image,
+    labels: List[str],
+    threshold: float = 0.3,
+    detector_id: Optional[str] = None
+) -> List[DetectionResult]:
+    """
+    Use Grounding DINO to detect a set of labels in an image in a zero-shot fashion.
+    """
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    detector_id = detector_id if detector_id is not None else "IDEA-Research/grounding-dino-tiny"
+    object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=device)
+
+    # Grounding DINO expects each label to end with a period.
+    labels = [label if label.endswith(".") else label + "." for label in labels]
+
+    results = object_detector(image, candidate_labels=labels, threshold=threshold)
+    return [DetectionResult.from_dict(result) for result in results]
+
+def segment(
+    image: Image.Image,
+    detection_results: List[DetectionResult],
+    polygon_refinement: bool = False,
+    segmenter_id: Optional[str] = None
+) -> List[DetectionResult]:
+    """
+    Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes.
+    """
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    segmenter_id = segmenter_id if segmenter_id is not None else "facebook/sam-vit-base"
+
+    segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)
+    processor = AutoProcessor.from_pretrained(segmenter_id)
+
+    boxes = get_boxes(detection_results)
+    inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(device)
+
+    with torch.no_grad():
+        outputs = segmentator(**inputs)
+    masks = processor.post_process_masks(
+        masks=outputs.pred_masks,
+        original_sizes=inputs.original_sizes,
+        reshaped_input_sizes=inputs.reshaped_input_sizes
+    )[0]
+
+    masks = refine_masks(masks, polygon_refinement)
+
+    for detection_result, mask in zip(detection_results, masks):
+        detection_result.mask = mask
+
+    return detection_results
+
+def grounded_segmentation(
+    image: Union[Image.Image, str],
+    labels: List[str],
+    threshold: float = 0.3,
+    polygon_refinement: bool = False,
+    detector_id: Optional[str] = None,
+    segmenter_id: Optional[str] = None
+) -> Tuple[Image.Image, List[DetectionResult]]:
+    """Detect labels with Grounding DINO, then mask each detection with SAM."""
+    if isinstance(image, str):
+        image = load_image(image)
+
+    detections = detect(image, labels, threshold, detector_id)
+    detections = segment(image, detections, polygon_refinement, segmenter_id)
+
+    return image, detections
+
+def cut_image(image, mask, box):
+    """Apply a detection mask to the image and crop to its bounding box (returns RGB)."""
+    np_image = np.array(image)
+    binary_mask = (mask > 0).astype(np.uint8)  # cv2 treats any nonzero value as "keep"
+    cut = cv2.bitwise_and(np_image, np_image, mask=binary_mask)
+    x0, y0, x1, y1 = map(int, box.xyxy)
+    # Keep RGB channel order: the crop is fed to PIL/CLIP downstream, not cv2.imwrite.
+    return cut[y0:y1, x0:x1]
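The detect-then-segment flow can also be exercised on its own. A sketch with a placeholder URL (model IDs fall back to grounding-dino-tiny and sam-vit-base, as above):

import segmentation
from PIL import Image

# Detect "clothes" with Grounding DINO, mask it with SAM, crop the first hit.
image, detections = segmentation.grounded_segmentation(
    image="https://example.com/photo.jpg", labels=["clothes"], polygon_refinement=True)
crop = segmentation.cut_image(image, detections[0].mask, detections[0].box)
Image.fromarray(crop).save("crop.png")
print(detections[0].label, detections[0].score)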
src/streamlit_app.py CHANGED
@@ -1,40 +1,39 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
1
  import streamlit as st
2
+ from PIL import Image
3
+ import model
4
 
5
+
6
+ image_url_input = st.text_input("Enter the image URL:")
7
+ k_value_input = st.number_input("Enter k_value:", min_value=1, value=5)
8
+ if st.button("Get Results"):
9
+ results = model.get_top_k_results(image_url_input, int(k_value_input))
10
+ st.json({"results": [{"metadata": r["metadata"], "score": r["score"]} for r in results]})
11
+
12
+
13
+ if 'metadata_inputs' not in st.session_state:
14
+ st.session_state['metadata_inputs'] = {}
15
+
16
+ uploaded_files = st.file_uploader("Choose images...", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
17
+
18
+ if uploaded_files:
19
+ for uploaded_file in uploaded_files:
20
+ file_key = uploaded_file.name
21
+
22
+
23
+ image = Image.open(uploaded_file)
24
+
25
+ st.session_state['metadata_inputs'][file_key] = st.text_input(
26
+ f"Metadata for {uploaded_file.name}",
27
+ value=st.session_state['metadata_inputs'].get(file_key, ""),
28
+ key=f"metadata_{file_key}"
29
+ )
30
+
31
+ if st.button("Upload Images"):
32
+ for uploaded_file in uploaded_files:
33
+ metadata = st.session_state['metadata_inputs'][uploaded_file.name]
34
+ if metadata:
35
+ image = Image.open(uploaded_file)
36
+ cropped_image = model.process_image_embedding(image)
37
+ feature = model.get_image_features(cropped_image)
38
+ model.save_image_in_index(feature, metadata)
39
+ st.success("Images uploaded successfully.")
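The app starts with "streamlit run src/streamlit_app.py". One caveat: the FAISS index in model.py lives in process memory, so anything indexed is lost when Streamlit restarts. A minimal persistence sketch (the file names index.faiss and metadata.json are hypothetical):

import json
import faiss
import model

# Save the index and its metadata side by side...
faiss.write_index(model.index, "index.faiss")
with open("metadata.json", "w") as f:
    json.dump(model.meta_data_store, f)

# ...and restore them on a later run.
model.index = faiss.read_index("index.faiss")
with open("metadata.json") as f:
    model.meta_data_store = json.load(f)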