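"""Gradio demo for exploring an image dataset (with crawl metadata) via CLIP embeddings.

Three tabs: a random-image gallery, text-to-image search over a FAISS index of
pre-computed CLIP embeddings, and image-classification predictions on a random
sample via the Hugging Face Inference API.

Note: the UI code targets Gradio 3.x (`Gallery().style(...)` and
`launch(enable_queue=True)` were removed in Gradio 4).
"""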
import multiprocessing
import random
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from PIL.Image import Image, LANCZOS  # ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
import gradio as gr
from faiss import METRIC_INNER_PRODUCT
import requests
import pandas as pd
from toolz import frequencies

import backoff
from functools import lru_cache

cpu_count = multiprocessing.cpu_count()

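# CLIP text/image encoder; the dataset's pre-computed image embeddings are assumed to come from the same model.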
model = SentenceTransformer("clip-ViT-B-16")


def resize_image(image: Image, size: int = 224) -> Image:
    """Resizes an image so its shorter side is `size`, retaining the aspect ratio."""
    w, h = image.size
    if w == h:
        return image.resize((size, size), LANCZOS)
    if w > h:
        # Landscape: fix the height at `size` and scale the width by the same factor.
        scale = size / float(h)
        return image.resize((int(w * scale), size), LANCZOS)
    # Portrait: fix the width at `size` and scale the height by the same factor.
    scale = size / float(w)
    return image.resize((size, int(h * scale)), LANCZOS)


dataset = load_dataset("davanstrien/ia-loaded-embedded-gpu", split="train")
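# Keep only rows that have an embedding, then build a FAISS inner-product index over them for similarity search.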
dataset = dataset.filter(lambda x: x["embedding"] is not None)
dataset.add_faiss_index("embedding", metric_type=METRIC_INNER_PRODUCT)


def get_nearest_k_examples(text, k):
    """Embeds the query text with CLIP and returns the k nearest images plus their metadata."""
    query = model.encode(text)
    # faiss_index = dataset.get_index("embedding").faiss_index # TODO maybe add range?
    # threshold = 0.95
    # limits, distances, indices = faiss_index.range_search(x=query, thresh=threshold)
    # images = dataset[indices]
    _, retrieved_examples = dataset.get_nearest_examples("embedding", query=query, k=int(k))
    images = retrieved_examples["image"]
    last_modified = retrieved_examples["last_modified_date"]
    crawl_date = retrieved_examples["crawl_date"]
    metadata = [
        f"last modified: {modified}, crawl date: {crawl}"
        for modified, crawl in zip(last_modified, crawl_date)
    ]
    return list(zip(images, metadata))


def return_random_sample(k=27):
    sample = random.sample(range(len(dataset)), k)
    images = dataset[sample]["image"]
    return [resize_image(image).convert("RGB") for image in images]


def predict_subset(model_id, token):
    API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
    headers = {"Authorization": f"Bearer {token}"}

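    # The Inference API returns HTTP 503 while a model is still loading; back off and retry for up to 30 seconds.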
    @backoff.on_predicate(backoff.expo, lambda x: x.status_code == 503, max_time=30)
    def _query(url):
        r = requests.post(API_URL, headers=headers, data=url)
        print(r)
        return r

    @lru_cache(maxsize=1000)
    def query(url):
        response = _query(url)
        try:
            data = response.json()
            argmax = data[0]
            return {"score": argmax["score"], "label": argmax["label"]}
        except Exception:
            return {"score": None, "label": None}

    # dataset2 = copy.deepcopy(dataset)
    # dataset2.drop_index("embedding")
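    # Re-load a plain copy of the dataset: the module-level one has a FAISS index attached
    # (mirroring the commented-out deepcopy/drop_index approach above).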
    dataset = load_dataset("davanstrien/ia-loaded-embedded-gpu", split="train")
    sample = random.sample(range(len(dataset)), 10)
    sample = dataset.select(sample)
    print("predicting...")
    predictions = []
    for row in sample:
        url = row["url"]
        predictions.append(query(url))
    gallery = []
    for url, prediction in zip(sample["url"], predictions):
        gallery.append((url, f"{prediction['label']}: {prediction['score']}"))
    # sample = sample.map(lambda x:  query(x['url']))
    labels = [d["label"] for d in predictions]
    label_counts = frequencies(labels)
    df = pd.DataFrame(
        {"labels": list(label_counts.keys()), "freqs": list(label_counts.values())}
    )
    return gallery, df


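# UI: three tabs - random gallery, CLIP text search, and Inference API predictions over a random sample.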
with gr.Blocks() as demo:
    with gr.Tab("Random image gallery"):
        button = gr.Button("Refresh")
        gallery = gr.Gallery().style(grid=9, height="1400")
        button.click(return_random_sample, [], [gallery])
    with gr.Tab("image search"):
        text = gr.Textbox(label="Search for images")
        k = gr.Slider(minimum=3, maximum=18, step=1, label="Number of results")
        button = gr.Button("search")
        gallery = gr.Gallery().style(grid=3)
        button.click(get_nearest_k_examples, [text, k], [gallery])
    # with gr.Tab("Export for label studio"):
    #     button = gr.Button("Export")
    #     dataset2 = copy.deepcopy(dataset)
    #     # dataset2 = dataset2.remove_columns('image')
    #     # dataset2 = dataset2.rename_column("url", "image")
    #     csv = dataset2.to_csv("label_studio.csv")
    #     csv_file = gr.File("label_studio.csv")
    #     button.click(dataset.save_to_disk, [], [csv_file])
    with gr.Tab("predict"):
        token = gr.Textbox(label="token", type="password")
        model_id = gr.Textbox(label="model_id")
        button = gr.Button("predict")
        plot = gr.BarPlot(x="labels", y="freqs", width=600, height=400, vertical=False)
        gallery = gr.Gallery()
        button.click(predict_subset, [model_id, token], [gallery, plot])

demo.launch(enable_queue=True)