Spaces:
Runtime error
Runtime error
File size: 5,468 Bytes
afebd23 d5321e7 afebd23 de33a84 a776cb5 d5321e7 afebd23 a776cb5 de33a84 afebd23 a776cb5 afebd23 de33a84 a776cb5 afebd23 4c0e93e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import multiprocessing
import random
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from PIL.Image import Image, ANTIALIAS
import gradio as gr
from faiss import METRIC_INNER_PRODUCT
import requests
import pandas as pd
import os
import backoff
from functools import lru_cache
from huggingface_hub import list_models, ModelFilter
# Module-level setup; runs once when the app is imported.
# Presumably opts in to the `hf_transfer` download backend for Hub downloads.
# NOTE(review): this is assigned after `datasets` / `huggingface_hub` have
# already been imported above; confirm the flag is still honoured when set
# this late, otherwise it needs to move above those imports.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# CPU count as reported by multiprocessing.
# NOTE(review): not referenced anywhere else in the visible code.
cpu_count = multiprocessing.cpu_count()
# Sentence-Transformers CLIP checkpoint; get_nearest_k_examples calls
# model.encode() on the search text to build the query vector.
model = SentenceTransformer("clip-ViT-B-16")
def resize_image(image: Image, size: int = 224) -> Image:
    """Resize an image so its shorter side equals ``size``, keeping aspect ratio.

    Args:
        image: Image to resize; only ``image.size`` and ``image.resize`` are used.
        size: Target length in pixels for the shorter side.

    Returns:
        The resized image. Square inputs become ``size`` x ``size``; otherwise
        the longer side is scaled by the same factor as the shorter one and
        truncated to an integer.
    """
    # NOTE(review): ANTIALIAS comes from the file-level PIL import; newer
    # Pillow releases dropped this alias of LANCZOS -- confirm the pinned
    # Pillow version still provides it.
    w, h = image.size
    if w == h:
        return image.resize((size, size), ANTIALIAS)
    if w > h:
        # Landscape: height is the short side, so it pins the scale factor.
        scale = size / float(h)
        return image.resize((int(float(w) * scale), size), ANTIALIAS)
    # Portrait: width is the short side. The new height must be derived from
    # the ORIGINAL HEIGHT; deriving it from the width (as before) always gave
    # exactly `size` and squashed every portrait image into a square.
    scale = size / float(w)
    return image.resize((size, int(float(h) * scale)), ANTIALIAS)
# Load the pre-embedded image dataset and keep only the rows that actually
# carry an embedding, then index that column for inner-product search.
dataset = load_dataset(
    "davanstrien/ia-loaded-embedded-gpu", split="train"
).filter(lambda row: row["embedding"] is not None)
dataset.add_faiss_index("embedding", metric_type=METRIC_INNER_PRODUCT)
def get_nearest_k_examples(input, k):
    """Return the ``k`` indexed images nearest to a text query, with captions.

    Args:
        input: Search text; encoded with the module-level ``model``.
        k: Number of neighbours to request from the "embedding" index.

    Returns:
        A list of ``(image, caption)`` tuples, where the caption reports the
        row's ``last_modified_date`` and ``crawl_date``.
    """
    query_vector = model.encode(input)
    # TODO maybe add range? i.e. query the raw faiss index with
    # range_search(x=query, thresh=0.95) instead of a fixed k.
    _scores, neighbours = dataset.get_nearest_examples(
        "embedding", query=query_vector, k=k
    )
    pictures = neighbours["image"][:k]
    captions = []
    for modified, crawl in zip(
        neighbours["last_modified_date"], neighbours["crawl_date"]
    ):
        captions.append(f"last_modified {modified}, crawl date:{crawl}")
    return list(zip(pictures, captions))
def return_random_sample(k=27):
    """Pick ``k`` distinct random rows and return their images as RGB thumbnails.

    Args:
        k: How many images to draw from the module-level ``dataset``.

    Returns:
        A list of images, each passed through ``resize_image`` (default size)
        and converted to RGB.
    """
    row_indices = random.sample(range(len(dataset)), k)
    thumbnails = []
    for picture in dataset[row_indices]["image"]:
        thumbnails.append(resize_image(picture).convert("RGB"))
    return thumbnails
@lru_cache()
def get_valid_hub_image_classification_model_ids():
    """Return the set of Hub model ids tagged with the image-classification task.

    The result is memoised, so the Hub listing is fetched only on the first
    call for the lifetime of the process.
    """
    task_filter = ModelFilter(task="image-classification")
    listing = list_models(limit=None, filter=task_filter)
    # `info` rather than `model`, to avoid confusion with the module-level encoder.
    return {info.id for info in listing}
def predict_subset(model_id, token):
    """Classify a small random sample of dataset images with a Hub model.

    Each sampled row's ``url`` is posted to the hosted inference endpoint for
    ``model_id``; the first entry of the response is taken as the prediction.

    Args:
        model_id: Hub id of an image-classification model.
        token: API token sent as a ``Bearer`` authorization header.

    Returns:
        A ``(gallery, df)`` tuple: ``gallery`` is a list of ``(url, caption)``
        pairs and ``df`` is a DataFrame with ``labels`` / ``freqs`` columns
        counting how often each predicted label occurred.

    Raises:
        gr.Error: If ``model_id`` is not a known image-classification model.
    """
    # Function-scope import: stdlib replacement for the previous local
    # `toolz.frequencies` import, without touching the file-level imports.
    from collections import Counter

    valid_model_ids = get_valid_hub_image_classification_model_ids()
    if model_id not in valid_model_ids:
        raise gr.Error(
            f"model_id {model_id} is not a valid image classification model id"
        )
    API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
    headers = {"Authorization": f"Bearer {token}"}
    request_timeout = 30  # seconds; without it a stalled connection hangs forever

    # Retry with exponential backoff for up to 30s while the endpoint answers
    # 503; after that the last response is returned as-is.
    @backoff.on_predicate(backoff.expo, lambda x: x.status_code == 503, max_time=30)
    def _query(url):
        return requests.post(
            API_URL, headers=headers, data=url, timeout=request_timeout
        )

    # This cache lives only for one predict_subset call (the closure is
    # rebuilt every time); it just avoids re-posting a URL that appears twice
    # in the same sample.
    @lru_cache(maxsize=1000)
    def query(url):
        try:
            data = _query(url).json()
            top = data[0]
            return {"score": top["score"], "label": top["label"]}
        except (
            requests.RequestException,  # network failure / timeout
            ValueError,  # body is not valid JSON
            KeyError,  # error payload (a dict) or missing score/label
            IndexError,  # empty prediction list
            TypeError,  # unexpected payload shape
        ):
            # Best effort: one failed image must not abort the whole batch.
            return {"score": None, "label": None}

    # NOTE(review): a fresh copy is loaded instead of reusing the module-level
    # dataset, presumably because that one has a FAISS index attached (an
    # earlier deepcopy + drop_index attempt was abandoned) -- confirm.
    dataset = load_dataset("davanstrien/ia-loaded-embedded-gpu", split="train")
    sample_size = min(10, len(dataset))  # never ask for more rows than exist
    sample = dataset.select(random.sample(range(len(dataset)), sample_size))
    print("predicting...")
    # Read only the url column; iterating whole rows would also decode every
    # image even though just the URL is needed.
    urls = sample["url"]
    predictions = [query(url) for url in urls]
    gallery = [
        (url, f"{prediction['label'], prediction['score']}")
        for url, prediction in zip(urls, predictions)
    ]
    # Count each label once; insertion order matches first occurrence.
    label_counts = Counter(prediction["label"] for prediction in predictions)
    df = pd.DataFrame(
        {
            "labels": list(label_counts.keys()),
            "freqs": list(label_counts.values()),
        }
    )
    return gallery, df
# Gradio UI: three tabs, each wiring a button to one of the functions above.
# The names `button` and `gallery` are rebound in every tab; each click
# handler is registered before the rebinding, so it keeps pointing at the
# components of its own tab.
with gr.Blocks() as demo:
# Tab 1: gallery of randomly sampled dataset images.
with gr.Tab("Random image gallery"):
button = gr.Button("Refresh")
gallery = gr.Gallery().style(grid=9, height="1400")
# No inputs, so return_random_sample runs with its default k.
button.click(return_random_sample, [], [gallery])
# Tab 2: text search; the textbox and slider values are passed as (input, k).
with gr.Tab("image search"):
text = gr.Textbox(label="Search for images")
# NOTE(review): no explicit default value is given for the slider.
k = gr.Slider(minimum=3, maximum=18, step=1)
button = gr.Button("search")
gallery = gr.Gallery().style(grid=3)
button.click(get_nearest_k_examples, [text, k], [gallery])
# Disabled draft of a Label Studio CSV export tab, kept for reference.
# with gr.Tab("Export for label studio"):
# button = gr.Button("Export")
# dataset2 = copy.deepcopy(dataset)
# # dataset2 = dataset2.remove_columns('image')
# # dataset2 = dataset2.rename_column("url", "image")
# csv = dataset2.to_csv("label_studio.csv")
# csv_file = gr.File("label_studio.csv")
# button.click(dataset.save_to_disk, [], [csv_file])
# Tab 3: run a Hub image-classification model over a random sample.
with gr.Tab("predict"):
# The token is entered in a password-type textbox.
token = gr.Textbox(label="token", type="password")
model_id = gr.Textbox(label="model_id")
button = gr.Button("predict")
# Column names match the DataFrame returned by predict_subset.
plot = gr.BarPlot(x="labels", y="freqs", width=600, height=400, vertical=False)
gallery = gr.Gallery()
# predict_subset returns (gallery, df), mapped onto [gallery, plot] in order.
button.click(predict_subset, [model_id, token], [gallery, plot])
# Start the app with request queuing enabled.
demo.launch(enable_queue=True)
|