# ImageRetrieval / app.py
# (hosting-page header: last commit 17c4cbe "init app" by leopoldmaillard,
#  file size 3.08 kB; raw / history / blame links omitted)
from cProfile import label
from turtle import title
import numpy as np
import gradio as gr
import pickle
from skimage import io
from scipy.spatial import distance
# All indexed image file names, one per non-blank line of holidays_images.dat.
# Read once inside a context manager so the file handle is closed promptly.
with open("holidays_images.dat", "r") as image_list_file:
    images = [line.strip() for line in image_list_file if line.strip()]

# Query images: those whose numeric file stem (name minus its last four
# characters, i.e. ".jpg") is a multiple of 100. Derived from `images`
# instead of re-reading the list file, and without leaking loop variables
# into the module namespace.
query_images = [
    imname for imname in images
    if int(imname[:-len(".jpg")]) % 100 == 0
]
def _load_pickle(path):
    """Unpickle and return the single object stored in the file at *path*.

    NOTE(review): pickle can execute arbitrary code while loading; these
    files are expected to ship with the app and must never be user-supplied.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Embedding lookups used by similarity_all (indexed there by image file
# name), one per retrieval method offered in the UI.
cnn_embeddings = _load_pickle('saved_cnn.pkl')
bovw_embeddings = _load_pickle('saved_bovw.pkl')
naive_embeddings = _load_pickle('saved_naive.pkl')
def similarity_all(query_image_name, embeddings, metric, candidates=None):
    """Score every candidate image against a query image.

    Args:
        query_image_name: key of the query image in *embeddings*.
        embeddings: mapping from image file name to its embedding.
        metric: callable ``metric(query_embedding, target_embedding)``
            returning a score for the pair (``retrieve`` ranks these
            ascending, i.e. treats them as distances).
        candidates: iterable of image names to score. Defaults to the
            module-level ``images`` list (the whole indexed dataset).

    Returns:
        dict mapping each candidate name to its score, in candidate order.

    Raises:
        KeyError: if the query or a candidate has no entry in *embeddings*.
    """
    if candidates is None:
        candidates = images
    query_embedding = embeddings[query_image_name]
    return {
        name: metric(query_embedding, embeddings[name])
        for name in candidates
    }
def l1_similarity_score(query_embedding, target_embedding):
    """Return the L1 (Manhattan) distance between two embeddings.

    Despite the name this is a distance: 0 for identical embeddings and
    larger for less similar ones, so callers rank results ascending.
    Absolute differences are summed element-wise, so embeddings of any
    shape are supported (both must have the same shape).

    ``np.linalg.norm`` with its default ``ord`` would compute the L2
    (Frobenius) norm instead, which is not what the "L1 Norm" option promises.
    """
    # Cast to float64 first: if an embedding were stored with an unsigned
    # integer dtype, plain subtraction would wrap around instead of going
    # negative.
    diff = (np.asarray(query_embedding, dtype=np.float64)
            - np.asarray(target_embedding, dtype=np.float64))
    return float(np.abs(diff).sum())
def cosine_similarity_score(query_embedding, target_embedding):
    """Return the cosine distance (1 - cosine similarity) of two embeddings.

    Despite the name this is a distance: about 0 for embeddings pointing in
    the same direction and larger for less similar ones, so callers rank
    results ascending. Both inputs are flattened first, so embeddings of any
    shape can be compared as long as they have the same number of elements.
    """
    # Flatten and cast to float64 so that, if an embedding has an integer
    # dtype, the products computed inside scipy's cosine cannot overflow
    # (whether they would depends on the SciPy version).
    query_vector = np.asarray(query_embedding, dtype=np.float64).ravel()
    target_vector = np.asarray(target_embedding, dtype=np.float64).ravel()
    return distance.cosine(query_vector, target_vector)
def retrieve(query_image_name, embeddings_type, metric_type):
    """Rank the indexed images by similarity to a query image.

    Args:
        query_image_name: file name of the query image (a key of the
            embedding mappings), as chosen in the UI dropdown.
        embeddings_type: 'MobileNetV2' or 'BoVW'; any other value selects
            the baseline (naive) embeddings.
        metric_type: 'L1 Norm'; any other value selects cosine distance.

    Returns:
        (query_image, results): the query image itself and a list of the
        (up to) 10 closest *other* images, nearest first, all read from the
        ``smallholidays/`` directory with ``skimage.io.imread``.
    """
    # Same fallbacks as before: unknown values mean baseline / cosine.
    embeddings = {
        'MobileNetV2': cnn_embeddings,
        'BoVW': bovw_embeddings,
    }.get(embeddings_type, naive_embeddings)
    metric = (l1_similarity_score if metric_type == 'L1 Norm'
              else cosine_similarity_score)

    scores = similarity_all(query_image_name, embeddings, metric)
    # Display the query itself and keep it out of the ranking. The nearest
    # match is not guaranteed to be the query: another image can tie it at
    # distance 0 (e.g. a duplicate), and cosine round-off can leave the
    # query's self-distance slightly above another image's.
    ranked = sorted(
        (name for name in scores if name != query_image_name),
        key=scores.get,
    )
    top = ranked[:10]
    return (io.imread('smallholidays/' + query_image_name),
            [io.imread('smallholidays/' + name) for name in top])
# Markdown shown below the interface (passed as the Gradio "article").
description = (
    "This is a demo of the content-based image retrieval system developed as part of the IR course project, 2022.\n"
    " \n"
    "The indexed dataset is [INRIA Holidays](https://lear.inrialpes.fr/~jegou/data.php). \n"
    "\n"
    "Several image embeddings can be used :\n"
    " \n"
    "-**MobileNetV2** : feature extraction is performed using a MobileNet architecture trained on ImageNet.\n"
    "\n"
    "-**BoVW (Bag of Visual Words)** : embedding is the BoVW histogram using color histogram as a descriptor.\n"
    "\n"
    "-**Baseline** : basic descriptor that uses pixel values of the downsized images."
)

# Inputs, in the positional order retrieve() expects:
# query image name, embedding set, distance metric.
input_button = gr.inputs.Dropdown(
    query_images,
    label='Choice of the query image',
)
embeddings_selection = gr.inputs.Radio(
    ['MobileNetV2', 'BoVW', 'Baseline'],
    label='Embeddings to use',
)
metric_selection = gr.inputs.Radio(
    ['L1 Norm', 'Cosine'],
    label='Similarity Metric',
)

# Ten image slots for the ranked results returned by retrieve().
retrieved_images = gr.outputs.Carousel(
    ["image"] * 10,
    label='Ranked retrieved images',
)

iface = gr.Interface(
    fn=retrieve,
    inputs=[input_button, embeddings_selection, metric_selection],
    outputs=[gr.outputs.Image(label='Query Image'), retrieved_images],
    title='Image Retrieval on INRIA Holidays',
    article=description,
)
# Start the Gradio server only when this file is executed as a script, so
# importing the module (e.g. from a notebook or a test) does not launch the UI.
if __name__ == "__main__":
    iface.launch()