File size: 3,107 Bytes
17c4cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed481cf
17c4cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
ed481cf
 
17c4cbe
 
 
 
 
 
 
 
 
459cd3a
ed481cf
44485b6
17c4cbe
459cd3a
17c4cbe
 
 
459cd3a
17c4cbe
cad54b2
d121302
44485b6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from cProfile import label
from turtle import title
import numpy as np
import gradio as gr
import pickle
from skimage import io
from scipy.spatial import distance

# Names of every indexed image, one per line of the dataset listing.
# The listing is read once, inside a context manager, so the file handle is
# closed deterministically. Blank lines are skipped so that a stray empty
# line cannot yield an empty image name (which would break the int() parse
# below and the embedding lookups later on).
with open("holidays_images.dat", "r") as listing:
    images = [name for name in (line.strip() for line in listing) if name]

# Query images: the entries whose numeric file stem (the name without its
# ".jpg" suffix) is a multiple of 100. Order follows the listing.
query_images = [
    imname for imname in images
    if int(imname[:-len(".jpg")]) % 100 == 0
]

def _load_pickle(path):
    """Return the object unpickled from the binary file at *path*."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Precomputed embeddings, one pickle file per descriptor type. They are
# indexed by image name elsewhere in this file (presumably dicts — confirm
# against the script that produced the pickles).
cnn_embeddings = _load_pickle('saved_cnn.pkl')
bovw_embeddings = _load_pickle('saved_bovw.pkl')
naive_embeddings = _load_pickle('saved_naive.pkl')


def similarity_all(query_image_name, embeddings, metric, image_names=None):
    """Score every candidate image against the query image.

    Args:
        query_image_name: Key of the query image in ``embeddings``.
        embeddings: Mapping from image name to its embedding.
        metric: Callable ``metric(query_embedding, target_embedding)`` that
            returns the score for one pair of embeddings.
        image_names: Iterable of candidate image names to score. When omitted
            (``None``) the module-level ``images`` list is used, which is the
            original behaviour.

    Returns:
        Dict mapping each candidate image name to its score, in the iteration
        order of the candidates.

    Raises:
        KeyError: If the query name or a candidate name is missing from
            ``embeddings``.
    """
    if image_names is None:
        # Fall back to the full indexed collection loaded at module import.
        image_names = images
    query_embedding = embeddings[query_image_name]
    return {
        image_name: metric(query_embedding, embeddings[image_name])
        for image_name in image_names
    }

def euclidean_similarity_score(query_embedding, target_embedding):
    """Return the Euclidean (L2) distance between two embeddings.

    Despite the name this is a distance: 0 means identical embeddings and
    larger values mean less similar, which is what the ascending sort in
    ``retrieve`` relies on.

    Args:
        query_embedding: Array-like embedding of the query image.
        target_embedding: Array-like embedding of the candidate image, with a
            shape that broadcasts against ``query_embedding``.

    Returns:
        The L2 norm of the element-wise difference, as a NumPy float.
    """
    # Cast to float before subtracting: with unsigned integer embeddings
    # (e.g. uint8 pixel values) the difference would otherwise wrap around
    # modulo 2**bits and produce a wrong, order-dependent distance.
    query = np.asarray(query_embedding, dtype=np.float64)
    target = np.asarray(target_embedding, dtype=np.float64)
    return np.linalg.norm(query - target)

def cosine_similarity_score(query_embedding, target_embedding):
    """Return the cosine distance between two embeddings.

    Both inputs are flattened to 1-D before comparison. Despite the name this
    is a distance (``1 - cosine similarity``): 0 means the vectors point in
    the same direction and larger values mean less similar.

    Args:
        query_embedding: Array-like embedding of the query image.
        target_embedding: Array-like embedding of the candidate image, with
            the same number of elements as ``query_embedding``.

    Returns:
        The cosine distance computed by ``scipy.spatial.distance.cosine``.
    """
    # Work in float64, consistently with the Euclidean metric, so that
    # integer-typed embeddings never reach the dot products in a narrow dtype.
    query = np.asarray(query_embedding, dtype=np.float64).ravel()
    target = np.asarray(target_embedding, dtype=np.float64).ravel()
    return distance.cosine(query, target)

def retrieve(query_image_name, embeddings_type, metric_type):
    """Return the query image and its nearest neighbours in the collection.

    Args:
        query_image_name: File name of the query image; it must be a key of
            the selected embeddings and exist under ``smallholidays/``.
        embeddings_type: ``'MobileNetV2'`` or ``'BoVW'``; any other value
            selects the baseline (naive) embeddings.
        metric_type: ``'Euclidean'``; any other value selects the cosine
            distance.

    Returns:
        A pair ``(query_image, neighbours)`` where ``query_image`` is the
        query loaded with ``skimage.io.imread`` and ``neighbours`` is a list
        of up to 10 loaded images, closest first.
    """
    num_results = 10
    image_dir = 'smallholidays/'

    # Unknown choices fall back to the baseline embeddings / cosine metric,
    # exactly like the original if/else chains.
    embeddings = {
        'MobileNetV2': cnn_embeddings,
        'BoVW': bovw_embeddings,
    }.get(embeddings_type, naive_embeddings)
    metric = (
        euclidean_similarity_score
        if metric_type == 'Euclidean'
        else cosine_similarity_score
    )

    scores = similarity_all(query_image_name, embeddings, metric)

    # Rank by ascending distance, leaving the query itself out explicitly.
    # Relying on the query being ranked first is fragile: another image with
    # an identical embedding (a tie at distance 0) could take its place and
    # be displayed as the query.
    ranked = sorted(
        (name for name in scores if name != query_image_name),
        key=scores.get,
    )
    neighbours = ranked[:num_results]

    query_image = io.imread(image_dir + query_image_name)
    return query_image, [io.imread(image_dir + name) for name in neighbours]

# --- Gradio user interface -------------------------------------------------
# NOTE(review): the `gr.inputs` / `gr.outputs` namespaces and the Carousel
# output belong to older Gradio releases — confirm the pinned Gradio version
# before upgrading.

# Input widgets; their values are passed to `retrieve` in the order listed in
# `inputs=` below (query image name, embeddings type, metric type).
input_button = gr.inputs.Dropdown(query_images, label='Choice of the query image')
embeddings_selection = gr.inputs.Radio(['MobileNetV2', 'BoVW', 'Baseline'], label='Embeddings')
metric_selection = gr.inputs.Radio(['Euclidean', 'Cosine'], label='Similarity Metric')
# Output widget for the list of retrieved images returned by `retrieve`.
retrieved_images = gr.outputs.Carousel(["image"], label='Retrieved images')

# Markdown text passed to the interface as its description.
description = "This is a demo of the content-based image retrieval system developed as part of the IR course project, 2022. The indexed dataset is [INRIA Holidays](https://lear.inrialpes.fr/~jegou/data.php). \n\nSeveral image embeddings can be used :\n \n-**MobileNetV2** : feature extraction is performed using a MobileNet architecture trained on ImageNet.\n\n-**BoVW (Bag of Visual Words)** : embedding is the BoVW histogram using color histogram as a descriptor.\n\n-**Baseline** : basic descriptor that uses pixel values of the downsized images."

# Wire `retrieve` to the widgets: its first return value goes to the
# 'Query image' output, its second to the carousel.
iface = gr.Interface(fn=retrieve, 
                     inputs=[input_button, embeddings_selection, metric_selection], 
                     outputs=[gr.outputs.Image(label='Query image'), retrieved_images],
                     title='Image Retrieval on INRIA Holidays',
                     description=description)
                     
# Start the web app at import/run time (module-level side effect).
iface.launch()