Spaces:
Running
Duplicate from sayakpaul/fetch-similar-images
Co-authored-by: Sayak Paul <[email protected]>
- .gitattributes +34 -0
- 0.png +0 -0
- 1.png +0 -0
- 2.png +0 -0
- README.md +14 -0
- app.py +73 -0
- lsh.pickle +3 -0
- requirements.txt +5 -0
- similarity_utils.py +175 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
0.png
ADDED
1.png
ADDED
2.png
ADDED
README.md
ADDED
@@ -0,0 +1,14 @@
+---
+title: Fetch Similar Beans 🪴
+emoji: ⚡
+colorFrom: yellow
+colorTo: blue
+sdk: gradio
+sdk_version: 3.12.0
+app_file: app.py
+pinned: false
+license: apache-2.0
+duplicated_from: sayakpaul/fetch-similar-images
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,73 @@
+"""
+Thanks to Freddy Boulton (https://github.com/freddyaboulton) for helping with this.
+"""
+
+
+import pickle
+
+import gradio as gr
+from datasets import load_dataset
+from transformers import AutoModel
+
+# The `LSH` and `Table` imports are necessary for the
+# `lsh.pickle` file to load successfully.
+from similarity_utils import LSH, BuildLSHTable, Table
+
+seed = 42
+
+# Only runs once when the script is first run.
+with open("lsh.pickle", "rb") as handle:
+    loaded_lsh = pickle.load(handle)
+
+# Load the model for computing embeddings.
+model_ckpt = "nateraw/vit-base-beans"
+model = AutoModel.from_pretrained(model_ckpt)
+lsh_builder = BuildLSHTable(model)
+lsh_builder.lsh = loaded_lsh
+
+# Candidate images.
+dataset = load_dataset("beans")
+candidate_dataset = dataset["train"].shuffle(seed=seed)
+
+
+def query(image, top_k):
+    results = lsh_builder.query(image)
+
+    # Should be a list of string file paths for gr.Gallery to work.
+    images = []
+    # List of labels for each image in the gallery.
+    labels = []
+
+    candidates = []
+
+    for idx, r in enumerate(sorted(results, key=results.get, reverse=True)):
+        if idx == top_k:
+            break
+        image_id, label = r.split("_")[0], r.split("_")[1]
+        candidates.append(candidate_dataset[int(image_id)]["image"])
+        labels.append(f"Label: {label}")
+
+    for i, candidate in enumerate(candidates):
+        filename = f"similar_{i}.png"
+        candidate.save(filename)
+        images.append(filename)
+
+    # The gallery component can be a list of tuples, where the first element is a path to a file
+    # and the second element is an optional caption for that image.
+    return list(zip(images, labels))
+
+
+title = "Fetch Similar Beans 🪴"
+description = "This Space demos an image similarity system. You can refer to [this notebook](TODO) to learn the details of the system. You can pick any image from the available samples below. On the right-hand side, you'll find the similar images returned by the system. The example images have been named with their corresponding integer class labels for easier identification. The fetched images will also have their integer labels tagged so that you can validate the correctness of the results."
+
+# You can set the type of gr.Image to PIL, numpy, or str (filepath).
+# Not sure what the best choice for this demo is.
+gr.Interface(
+    query,
+    inputs=[gr.Image(type="pil"), gr.Slider(value=5, minimum=1, maximum=10, step=1)],
+    outputs=gr.Gallery().style(grid=[3], height="auto"),
+    # Filenames denote the integer labels. Learn more here: https://hf.co/datasets/beans
+    title=title,
+    description=description,
+    examples=[["0.png", 5], ["1.png", 5], ["2.png", 5]],
+).launch()
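
For reference, here is a minimal sketch of running the same lookup outside Gradio. It assumes the Space's files (`lsh.pickle`, `similarity_utils.py`, and one of the example images such as `0.png`) are in the working directory, and it simply prints the top matches instead of rendering a gallery. The `top_k` value and the printing loop are illustrative, not part of the Space.

```python
# Sketch: query the pre-built LSH index from a plain script.
import pickle

from transformers import AutoModel

# As noted in app.py, `LSH` and `Table` must be importable for unpickling to work.
from similarity_utils import LSH, BuildLSHTable, Table

with open("lsh.pickle", "rb") as handle:
    loaded_lsh = pickle.load(handle)

model = AutoModel.from_pretrained("nateraw/vit-base-beans")
lsh_builder = BuildLSHTable(model)
lsh_builder.lsh = loaded_lsh

# `BuildLSHTable.query` accepts a file path or a PIL image and returns a dict
# mapping "<candidate_id>_<label>" to a similarity score (hash-bucket match
# count normalized by the embedding dimension).
counts = lsh_builder.query("0.png")

top_k = 5
for key in sorted(counts, key=counts.get, reverse=True)[:top_k]:
    print(key, round(counts[key], 4))
```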
lsh.pickle
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:caa1727832f2279a4026b03b9f17638ff4a4deffa0a28586e74db59332dce732
+size 136667
requirements.txt
ADDED
@@ -0,0 +1,5 @@
+transformers==4.25.1
+datasets==2.7.1
+numpy==1.21.6
+torch==1.12.1
+torchvision
similarity_utils.py
ADDED
@@ -0,0 +1,175 @@
+from typing import List, Union
+
+import datasets
+import numpy as np
+import torch
+import torchvision.transforms as T
+from PIL import Image
+from tqdm.auto import tqdm
+from transformers import AutoFeatureExtractor, AutoModel
+
+seed = 42
+hash_size = 8
+hidden_dim = 768  # ViT-base
+np.random.seed(seed)
+
+
+# Device.
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Load the feature extractor for computing embeddings.
+model_ckpt = "nateraw/vit-base-beans"
+extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
+
+# Data transformation chain.
+transformation_chain = T.Compose(
+    [
+        # We first resize the input image and then take a center crop.
+        T.Resize(int((256 / 224) * extractor.size["height"])),
+        T.CenterCrop(extractor.size["height"]),
+        T.ToTensor(),
+        T.Normalize(mean=extractor.image_mean, std=extractor.image_std),
+    ]
+)
+
+
+# Define random vectors to project with.
+random_vectors = np.random.randn(hash_size, hidden_dim).T
+
+
+def hash_func(embedding, random_vectors=random_vectors):
+    """Randomly projects the embeddings and then computes bit-wise hashes."""
+    if not isinstance(embedding, np.ndarray):
+        embedding = np.array(embedding)
+    if len(embedding.shape) < 2:
+        embedding = np.expand_dims(embedding, 0)
+
+    # Random projection.
+    bools = np.dot(embedding, random_vectors) > 0
+    return [bool2int(bool_vec) for bool_vec in bools]
+
+
+def bool2int(x):
+    y = 0
+    for i, j in enumerate(x):
+        if j:
+            y += 1 << i
+    return y
+
+
+def compute_hash(model: Union[torch.nn.Module, str]):
+    """Computes hashes on a given dataset."""
+    device = model.device
+
+    def pp(example_batch):
+        # Prepare the input images for the model.
+        image_batch = example_batch["image"]
+        image_batch_transformed = torch.stack(
+            [transformation_chain(image) for image in image_batch]
+        )
+        new_batch = {"pixel_values": image_batch_transformed.to(device)}
+
+        # Compute embeddings and pool them, i.e., take the representations from the [CLS]
+        # token.
+        with torch.no_grad():
+            embeddings = model(**new_batch).last_hidden_state[:, 0].cpu().numpy()
+
+        # Compute hashes for the batch of images.
+        hashes = [hash_func(embeddings[i]) for i in range(len(embeddings))]
+        example_batch["hashes"] = hashes
+        return example_batch
+
+    return pp
+
+
+class Table:
+    def __init__(self, hash_size: int):
+        self.table = {}
+        self.hash_size = hash_size
+
+    def add(self, id: int, hashes: List[int], label: int):
+        # Create a unique identifier.
+        entry = {"id_label": str(id) + "_" + str(label)}
+
+        # Add the hash values to the current table.
+        for h in hashes:
+            if h in self.table:
+                self.table[h].append(entry)
+            else:
+                self.table[h] = [entry]
+
+    def query(self, hashes: List[int]):
+        results = []
+
+        # Loop over the query hashes and determine if they exist in
+        # the current table.
+        for h in hashes:
+            if h in self.table:
+                results.extend(self.table[h])
+        return results
+
+
+class LSH:
+    def __init__(self, hash_size, num_tables):
+        self.num_tables = num_tables
+        self.tables = []
+        for i in range(self.num_tables):
+            self.tables.append(Table(hash_size))
+
+    def add(self, id: int, hash: List[int], label: int):
+        for table in self.tables:
+            table.add(id, hash, label)
+
+    def query(self, hashes: List[int]):
+        results = []
+        for table in self.tables:
+            results.extend(table.query(hashes))
+        return results
+
+
+class BuildLSHTable:
+    def __init__(
+        self,
+        model: Union[torch.nn.Module, None],
+        batch_size: int = 48,
+        hash_size: int = hash_size,
+        dim: int = hidden_dim,
+        num_tables: int = 10,
+    ):
+        self.hash_size = hash_size
+        self.dim = dim
+        self.num_tables = num_tables
+        self.lsh = LSH(self.hash_size, self.num_tables)
+
+        self.batch_size = batch_size
+        self.hash_fn = compute_hash(model.to(device))
+
+    def build(self, ds: datasets.DatasetDict):
+        dataset_hashed = ds.map(self.hash_fn, batched=True, batch_size=self.batch_size)
+
+        for id in tqdm(range(len(dataset_hashed))):
+            hash, label = dataset_hashed[id]["hashes"], dataset_hashed[id]["labels"]
+            self.lsh.add(id, hash, label)
+
+    def query(self, image, verbose=True):
+        if isinstance(image, str):
+            image = Image.open(image).convert("RGB")
+
+        # Compute the hashes of the query image and fetch the results.
+        example_batch = dict(image=[image])
+        hashes = self.hash_fn(example_batch)["hashes"][0]
+
+        results = self.lsh.query(hashes)
+        if verbose:
+            print("Matches:", len(results))
+
+        # Calculate the Jaccard index to quantify the similarity.
+        counts = {}
+        for r in results:
+            if r["id_label"] in counts:
+                counts[r["id_label"]] += 1
+            else:
+                counts[r["id_label"]] = 1
+        for k in counts:
+            counts[k] = float(counts[k]) / self.dim
+        return counts
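
The `lsh.pickle` file committed above is a binary artifact, so its diff only shows the Git LFS pointer. For reference, a sketch of how such an index could be built and serialized with `similarity_utils.py` follows. The use of the shuffled `beans` train split is an assumption inferred from how `app.py` indexes into `candidate_dataset`; this commit does not include the actual build script.

```python
# Sketch (assumed, not part of this commit): build and serialize the LSH index.
import pickle

from datasets import load_dataset
from transformers import AutoModel

from similarity_utils import BuildLSHTable

seed = 42

# Shuffle with the same seed as app.py so stored ids line up with candidate_dataset.
dataset = load_dataset("beans")
candidate_dataset = dataset["train"].shuffle(seed=seed)

model = AutoModel.from_pretrained("nateraw/vit-base-beans")
lsh_builder = BuildLSHTable(model)

# Hash every candidate image (8-bit random-projection hashes across 10 tables)
# and fill the LSH index.
lsh_builder.build(candidate_dataset)

# Persist only the index; app.py re-creates BuildLSHTable and attaches it.
with open("lsh.pickle", "wb") as handle:
    pickle.dump(lsh_builder.lsh, handle)
```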