File size: 5,468 Bytes
afebd23
 
 
 
 
 
 
 
 
d5321e7
afebd23
 
de33a84
a776cb5
d5321e7
 
 
afebd23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a776cb5
de33a84
 
 
 
afebd23
a776cb5
afebd23
de33a84
 
a776cb5
 
 
 
 
 
 
afebd23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c0e93e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import multiprocessing
import random
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from PIL.Image import Image, ANTIALIAS
import gradio as gr
from faiss import METRIC_INNER_PRODUCT
import requests
import pandas as pd
import os
import backoff
from functools import lru_cache
from huggingface_hub import list_models, ModelFilter

# Presumably opts in to the Hugging Face Hub "hf_transfer" download backend.
# NOTE(review): this assignment runs after `datasets` and `huggingface_hub`
# were imported above; if the library reads the flag at import time it will
# have no effect here -- confirm, and move it above the imports if needed.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"


# CPU count of the host. Not referenced anywhere else in the visible part of
# this file.
cpu_count = multiprocessing.cpu_count()

# Encoder used to embed search queries (see `model.encode` in
# `get_nearest_k_examples`).
model = SentenceTransformer("clip-ViT-B-16")


def resize_image(image: Image, size: int = 224) -> Image:
    """Resize an image so its shorter side equals ``size``, keeping aspect ratio.

    Args:
        image: Image to resize; only its ``size`` attribute and ``resize``
            method are used.
        size: Target length, in pixels, of the shorter side.

    Returns:
        The resized image as returned by ``image.resize``.
    """
    w, h = image.size
    if w == h:
        new_size = (size, size)
    elif w > h:
        # Landscape: pin the height to ``size`` and scale the width to match.
        new_size = (int(float(w) * (size / float(h))), size)
    else:
        # Portrait: pin the width to ``size`` and scale the *height*. The
        # scaled dimension has to be derived from ``h``; deriving it from
        # ``w`` always yields ``size`` and squashes the image into a square.
        new_size = (size, int(float(h) * (size / float(w))))
    return image.resize(new_size, ANTIALIAS)


# Load the "train" split and keep only the rows that actually carry an
# embedding, so every indexed row has a vector.
dataset = load_dataset("davanstrien/ia-loaded-embedded-gpu", split="train").filter(
    lambda row: row["embedding"] is not None
)
# Build a FAISS index over the embedding column using inner-product scoring.
dataset.add_faiss_index("embedding", metric_type=METRIC_INNER_PRODUCT)


def get_nearest_k_examples(input, k):
    """Return the ``k`` dataset images closest to a query, with captions.

    The query is embedded with the module-level ``model`` and looked up in the
    FAISS index attached to the module-level ``dataset``.

    Args:
        input: Query passed straight to ``model.encode``.
        k: Number of neighbours to retrieve.

    Returns:
        A list of ``(image, caption)`` pairs, where the caption reports the
        row's ``last_modified_date`` and ``crawl_date``.
    """
    embedding = model.encode(input)
    # TODO: consider a thresholded FAISS ``range_search`` (e.g. thresh=0.95)
    # on ``dataset.get_index("embedding").faiss_index`` instead of a fixed k.
    _scores, hits = dataset.get_nearest_examples("embedding", query=embedding, k=k)
    pictures = hits["image"][:k]
    modified_dates = hits["last_modified_date"]
    crawl_dates = hits["crawl_date"]
    captions = []
    for modified, crawl in zip(modified_dates, crawl_dates):
        captions.append(f"last_modified {modified}, crawl date:{crawl}")
    return list(zip(pictures, captions))


def return_random_sample(k=27):
    """Return ``k`` randomly chosen dataset images, resized and in RGB mode.

    Args:
        k: How many distinct rows to draw from the module-level ``dataset``.

    Returns:
        A list of images, each passed through ``resize_image`` and then
        converted with ``.convert("RGB")``.
    """
    picked_rows = random.sample(range(len(dataset)), k)
    thumbnails = []
    for picture in dataset[picked_rows]["image"]:
        thumbnails.append(resize_image(picture).convert("RGB"))
    return thumbnails


@lru_cache()
def get_valid_hub_image_classification_model_ids():
    """Return the set of Hub model ids tagged for image classification.

    The listing is fetched with ``list_models`` on first use; ``lru_cache``
    then serves the same set for the rest of the process lifetime.
    """
    task_filter = ModelFilter(task="image-classification")
    listing = list_models(limit=None, filter=task_filter)
    return {hub_model.id for hub_model in listing}


def predict_subset(model_id, token):
    """Classify a random sample of dataset images with a hosted Hub model.

    Up to ten random rows are drawn from the dataset and each row's ``url`` is
    posted to the Inference API endpoint for ``model_id``. Only the first
    entry of each response is kept as that image's prediction.

    Args:
        model_id: Hub id of the model to query; must be one of the ids returned
            by ``get_valid_hub_image_classification_model_ids``.
        token: API token sent in the ``Authorization: Bearer`` header.

    Returns:
        A ``(gallery, df)`` tuple. ``gallery`` is a list of ``(url, caption)``
        pairs; ``df`` is a DataFrame with a ``labels`` column and a ``freqs``
        column counting how often each label was predicted. Images whose
        request or response could not be used are counted under ``None``.

    Raises:
        gr.Error: If ``model_id`` is not a known image-classification model.
    """
    from collections import Counter

    valid_model_ids = get_valid_hub_image_classification_model_ids()
    if model_id not in valid_model_ids:
        raise gr.Error(
            f"model_id {model_id} is not a valid image classification model id"
        )

    API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
    headers = {"Authorization": f"Bearer {token}"}

    # Retry with exponential backoff for up to 30 seconds while the endpoint
    # answers 503.
    @backoff.on_predicate(backoff.expo, lambda x: x.status_code == 503, max_time=30)
    def _query(url):
        # A timeout keeps one stalled connection from blocking the whole batch.
        return requests.post(API_URL, headers=headers, data=url, timeout=30)

    # This closure is rebuilt on every call to predict_subset, so the cache
    # only de-duplicates repeated URLs within a single call.
    @lru_cache(maxsize=1000)
    def query(url):
        try:
            data = _query(url).json()
            top = data[0]
            return {"score": top["score"], "label": top["label"]}
        except (requests.RequestException, ValueError, LookupError, TypeError):
            # Network failure, non-JSON body, or an unexpected payload shape
            # (e.g. an error object): report "no prediction" for this image
            # instead of failing the whole batch.
            return {"score": None, "label": None}

    # A fresh copy is loaded rather than reusing the module-level ``dataset``,
    # which has a FAISS index attached.
    # NOTE(review): presumably the index gets in the way of ``select`` -- confirm.
    samples = load_dataset("davanstrien/ia-loaded-embedded-gpu", split="train")
    picked = random.sample(range(len(samples)), min(10, len(samples)))
    # Read only the ``url`` column; iterating whole rows would also materialise
    # every other column of each row for no benefit.
    urls = list(samples.select(picked)["url"])
    print("predicting...")
    predictions = [query(url) for url in urls]
    # The caption is the repr of the (label, score) tuple.
    gallery = [
        (url, f"{prediction['label'], prediction['score']}")
        for url, prediction in zip(urls, predictions)
    ]
    label_counts = Counter(prediction["label"] for prediction in predictions)
    df = pd.DataFrame(
        {"labels": list(label_counts.keys()), "freqs": list(label_counts.values())}
    )
    return gallery, df


# Gradio UI: three tabs, each wiring a button to one of the functions above.
# `button` and `gallery` are rebound in every tab; each `click` call is made
# while the name still refers to that tab's own components.
with gr.Blocks() as demo:
    # Tab 1: show a fresh random sample of images on each click.
    with gr.Tab("Random image gallery"):
        button = gr.Button("Refresh")
        gallery = gr.Gallery().style(grid=9, height="1400")
        button.click(return_random_sample, [], [gallery])
    # Tab 2: text query plus result count, answered by the FAISS-backed search.
    with gr.Tab("image search"):
        text = gr.Textbox(label="Search for images")
        # Number of results to retrieve: 3 to 18 in steps of 1.
        k = gr.Slider(minimum=3, maximum=18, step=1)
        button = gr.Button("search")
        gallery = gr.Gallery().style(grid=3)
        button.click(get_nearest_k_examples, [text, k], [gallery])
    # Disabled draft of a Label Studio CSV export tab, kept for reference.
    # with gr.Tab("Export for label studio"):
    #     button = gr.Button("Export")
    #     dataset2 = copy.deepcopy(dataset)
    #     # dataset2 = dataset2.remove_columns('image')
    #     # dataset2 = dataset2.rename_column("url", "image")
    #     csv = dataset2.to_csv("label_studio.csv")
    #     csv_file = gr.File("label_studio.csv")
    #     button.click(dataset.save_to_disk, [], [csv_file])
    # Tab 3: run a Hub image-classification model over a random sample and
    # show the per-image predictions plus a bar plot of label frequencies.
    with gr.Tab("predict"):
        # Entered as a password field so the token is not displayed.
        token = gr.Textbox(label="token", type="password")
        model_id = gr.Textbox(label="model_id")
        button = gr.Button("predict")
        # x/y match the "labels"/"freqs" columns of the DataFrame returned by
        # predict_subset.
        plot = gr.BarPlot(x="labels", y="freqs", width=600, height=400, vertical=False)
        gallery = gr.Gallery()
        button.click(predict_subset, [model_id, token], [gallery, plot])

# Start the app with `enable_queue=True`.
demo.launch(enable_queue=True)