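"""Gradio Space for exploring the davanstrien/ia-loaded-embedded-gpu image dataset.

Three tabs: a random image gallery, CLIP-based semantic image search over a
FAISS index of precomputed embeddings, and classification of a random sample
of images via the Hugging Face Inference API.
"""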
import multiprocessing
import random
from functools import lru_cache

import backoff
import gradio as gr
import pandas as pd
import requests
from datasets import load_dataset
from faiss import METRIC_INNER_PRODUCT
from PIL.Image import Image, ANTIALIAS
from sentence_transformers import SentenceTransformer

cpu_count = multiprocessing.cpu_count()
model = SentenceTransformer("clip-ViT-B-16")

def resize_image(image: Image, size: int = 224) -> Image:
    """Resize an image so its shorter side equals `size`, retaining the aspect ratio."""
    w, h = image.size
    if w == h:
        return image.resize((size, size), ANTIALIAS)
    if w > h:
        # Landscape: fix the height at `size` and scale the width proportionally.
        height_percent = size / float(h)
        width_size = int(float(w) * height_percent)
        return image.resize((width_size, size), ANTIALIAS)
    # Portrait: fix the width at `size` and scale the height proportionally.
    width_percent = size / float(w)
    height_size = int(float(h) * width_percent)
    return image.resize((size, height_size), ANTIALIAS)

# Load the pre-embedded dataset, drop rows without an embedding, and build a
# FAISS inner-product index over the CLIP embeddings for similarity search.
dataset = load_dataset("davanstrien/ia-loaded-embedded-gpu", split="train")
dataset = dataset.filter(lambda x: x["embedding"] is not None)
dataset.add_faiss_index("embedding", metric_type=METRIC_INNER_PRODUCT)

def get_nearest_k_examples(input, k):
    """Return the k images closest to the text query in CLIP embedding space."""
    query = model.encode(input)
    # faiss_index = dataset.get_index("embedding").faiss_index  # TODO maybe add range?
    # threshold = 0.95
    # limits, distances, indices = faiss_index.range_search(x=query, thresh=threshold)
    # images = dataset[indices]
    k = int(k)  # the slider value may arrive as a float
    _, retrieved_examples = dataset.get_nearest_examples("embedding", query=query, k=k)
    images = retrieved_examples["image"][:k]
    last_modified = retrieved_examples["last_modified_date"]
    crawl_date = retrieved_examples["crawl_date"]
    metadata = [
        f"last modified: {modified}, crawl date: {crawl}"
        for modified, crawl in zip(last_modified, crawl_date)
    ]
    return list(zip(images, metadata))

def return_random_sample(k=27):
    """Return k randomly chosen images from the dataset, resized and converted to RGB."""
    sample = random.sample(range(len(dataset)), k)
    images = dataset[sample]["image"]
    return [resize_image(image).convert("RGB") for image in images]

def predict_subset(model_id, token):
    """Classify a random sample of ten images with the Hugging Face Inference API."""
    API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
    headers = {"Authorization": f"Bearer {token}"}

    # Retry with exponential backoff while the model is still loading (HTTP 503).
    @backoff.on_predicate(backoff.expo, lambda x: x.status_code == 503, max_time=30)
    def _query(url):
        return requests.post(API_URL, headers=headers, data=url)

    @lru_cache(maxsize=1000)
    def query(url):
        response = _query(url)
        try:
            data = response.json()
            # Assume the first entry is the highest-scoring prediction.
            argmax = data[0]
            return {"score": argmax["score"], "label": argmax["label"]}
        except Exception:
            return {"score": None, "label": None}

    # Reload the dataset without the FAISS index so rows can be selected freely.
    dataset = load_dataset("davanstrien/ia-loaded-embedded-gpu", split="train")
    sample = random.sample(range(len(dataset)), 10)
    sample = dataset.select(sample)
    print("predicting...")
    predictions = [query(row["url"]) for row in sample]
    gallery = [
        (url, f"{prediction['label']}: {prediction['score']}")
        for url, prediction in zip(sample["url"], predictions)
    ]
    labels = [d["label"] for d in predictions]
    from toolz import frequencies

    label_counts = frequencies(labels)
    df = pd.DataFrame(
        {"labels": list(label_counts.keys()), "freqs": list(label_counts.values())}
    )
    return gallery, df

with gr.Blocks() as demo:
    with gr.Tab("Random image gallery"):
        button = gr.Button("Refresh")
        gallery = gr.Gallery().style(grid=9, height="1400")
        button.click(return_random_sample, [], [gallery])
    with gr.Tab("Image search"):
        text = gr.Textbox(label="Search for images")
        k = gr.Slider(minimum=3, maximum=18, step=1, label="Number of results")
        button = gr.Button("Search")
        gallery = gr.Gallery().style(grid=3)
        button.click(get_nearest_k_examples, [text, k], [gallery])
    # with gr.Tab("Export for label studio"):
    #     button = gr.Button("Export")
    #     dataset2 = copy.deepcopy(dataset)
    #     # dataset2 = dataset2.remove_columns('image')
    #     # dataset2 = dataset2.rename_column("url", "image")
    #     csv = dataset2.to_csv("label_studio.csv")
    #     csv_file = gr.File("label_studio.csv")
    #     button.click(dataset.save_to_disk, [], [csv_file])
    with gr.Tab("Predict"):
        token = gr.Textbox(label="token", type="password")
        model_id = gr.Textbox(label="model_id")
        button = gr.Button("Predict")
        plot = gr.BarPlot(x="labels", y="freqs", width=600, height=400, vertical=False)
        gallery = gr.Gallery()
        button.click(predict_subset, [model_id, token], [gallery, plot])

demo.launch(enable_queue=True)