ksvmuralidhar committed · Commit 4b1ee17 · verified · 1 Parent(s): e7260b2

Create app.py

Files changed (1)
  1. app.py +142 -0
app.py ADDED
@@ -0,0 +1,142 @@
+ import os
+ import numpy as np
+ from matplotlib import rcParams
+ import matplotlib.pyplot as plt
+ from tensorflow.keras.models import load_model, Model
+ from tensorflow.keras.utils import load_img, save_img, img_to_array
+ from tensorflow.keras.applications.vgg19 import preprocess_input
+ from tensorflow.keras.layers import GlobalAveragePooling2D
+ from pymilvus import connections, Collection, utility
+ from requests import get
+ from shutil import rmtree
+ import streamlit as st
+ import zipfile
+
+ # unzip vegetable images
+ with zipfile.ZipFile("Vegetable Images.zip", 'r') as zip_ref:
+     zip_ref.extractall('.')
+
+
+ class ImageVectorizer:
+     '''
+     Get the vector representation of an image using a VGG19 model fine-tuned on vegetable images for classification
+     '''
+
+     def __init__(self):
+         self.__model = self.get_model()
+
+     @staticmethod
+     def get_model():
+         model = load_model('vegetable_classification_model_vgg.h5')  # load the saved VGG model fine-tuned on vegetable images for classification
+         top = model.get_layer('block5_pool').output  # output of the last convolutional block
+         top = GlobalAveragePooling2D()(top)  # pool the feature maps into a single embedding vector
+         model = Model(inputs=model.input, outputs=top)
+         return model
+
+     def vectorize(self, img_path: str):
+         model = self.__model
+         test_image = load_img(img_path, color_mode="rgb", target_size=(224, 224))
+         test_image = img_to_array(test_image)
+         test_image = preprocess_input(test_image)
+         test_image = np.array([test_image])
+         return model(test_image).numpy()[0]
+
+
+ def get_milvus_collection():
+     uri = os.environ.get("URI")
+     token = os.environ.get("TOKEN")
+     connections.connect("default", uri=uri, token=token)
+     print("Connected to DB")
+     collection_name = os.environ.get("COLLECTION_NAME")
+     collection = Collection(name=collection_name)
+     collection.load()
+     return collection
+
+
+ def plot_images(input_image_path: str, similar_img_paths: list):
+     # plot the similar images in a grid of subplots
+     rows = 5  # rows in subplots
+     cols = 3  # columns in subplots
+     fig, ax = plt.subplots(rows, cols, figsize=(12, 20))
+     r = 0
+     c = 0
+     for i in range(rows * cols):
+         sim_image = load_img(similar_img_paths[i], color_mode="rgb", target_size=(224, 224))
+         ax[r, c].axis("off")
+         ax[r, c].imshow(sim_image)
+         c += 1
+         if c == cols:
+             c = 0
+             r += 1
+     plt.subplots_adjust(wspace=0.01, hspace=0.01)
+
+     # display the input image
+     rcParams.update({'figure.autolayout': True})
+     input_image = load_img(input_image_path, color_mode="rgb", target_size=(224, 224))
+     st.markdown('<p style="font-size: 20px; font-weight: bold">Input image</p>', unsafe_allow_html=True)
+     st.image(input_image)
+
+     st.write(' \n')
+
+     # display the similar images
+     st.markdown('<p style="font-size: 20px; font-weight: bold">Similar images</p>', unsafe_allow_html=True)
+     st.pyplot(fig)
+
+
+ def find_similar_images(img_path: str, top_n: int = 15):
+     search_params = {"metric_type": "L2"}
+     search_vec = vectorizer.vectorize(img_path)
+     result = collection.search([search_vec],
+                                anns_field='image_vector',  # vector (ANNS) field specified in the schema definition
+                                param=search_params,
+                                limit=top_n,
+                                guarantee_timestamp=1,
+                                output_fields=['image_path'])  # fields to return in the output
+
+     output_dict = {"input_image_path": img_path, "similar_image_paths": [hit.entity.get('image_path') for hits in result for hit in hits]}
+     plot_images(output_dict['input_image_path'], output_dict['similar_image_paths'])
+
+
+ def delete_file(path_: str):
+     if os.path.exists(path_):
+         rmtree(path=path_, ignore_errors=True)
+
+
+ def process_input_image(img_url):
+     upload_file_path = os.path.join('.', 'uploads')
+     os.makedirs(upload_file_path, exist_ok=True)
+     upload_filename = "input.jpg"
+     upload_file_path = os.path.join(upload_file_path, upload_filename)
+     headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/102.0.0.0 Safari/537.36'}
+     r = get(img_url, headers=headers)
+     with open(upload_file_path, "wb") as file:
+         file.write(r.content)
+     return upload_file_path
+
+
+ vectorizer = ImageVectorizer()
+ collection = get_milvus_collection()
+
+
+ def main():
+     try:
+         st.markdown("<h3>Find Similar Vegetable Images</h3>", unsafe_allow_html=True)
+         desc = '''<p style="font-size: 15px;">Allowed Vegetables: Broad Bean, Bitter Gourd, Bottle Gourd,
+         Green Round Brinjal, Broccoli, Cabbage, Capsicum, Carrot, Cauliflower, Cucumber,
+         Raw Papaya, Potato, Green Pumpkin, Radish, Tomato.
+         </p>
+         <p style="font-size: 13px;">Image embeddings are extracted from a fine-tuned VGG model. The model is fine-tuned on <a href="https://www.kaggle.com/datasets/misrakahmed/vegetable-image-dataset" target="_blank">images</a> captured with a mobile phone camera.
+         Embeddings of 20,000 vegetable images are stored in a Milvus vector database. Embeddings of the input image are computed and the 15 most similar images (based on L2 distance) are displayed.</p>
+         '''
+         st.markdown(desc, unsafe_allow_html=True)
+         img_url = st.text_input("Paste the image URL of a vegetable:", "")
+         if img_url:
+             img_path = process_input_image(img_url)
+             find_similar_images(img_path, 15)
+             delete_file(os.path.dirname(img_path))
+     except Exception as e:
+         st.error(f'An unexpected error occurred: \n{e}')
+
+
+ if __name__ == "__main__":
+     main()
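
For reference, app.py assumes a Milvus collection that already exists, with a vector field named image_vector and a scalar field image_path (the two fields referenced in find_similar_images). The sketch below is not part of this commit: it shows one way such a collection could have been created and populated with pymilvus. The field names match app.py; the 512-dimensional vector follows from VGG19's block5_pool having 512 channels before global average pooling; the index type, max_length, and sample path are illustrative assumptions.

    import os
    import numpy as np
    from pymilvus import connections, Collection, CollectionSchema, FieldSchema, DataType

    # connect using the same environment variables app.py reads
    connections.connect("default", uri=os.environ["URI"], token=os.environ["TOKEN"])

    # field names match the ones app.py searches; sizes and types are assumptions
    fields = [
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name="image_path", dtype=DataType.VARCHAR, max_length=500),
        FieldSchema(name="image_vector", dtype=DataType.FLOAT_VECTOR, dim=512),
    ]
    schema = CollectionSchema(fields, description="Vegetable image embeddings")
    collection = Collection(name=os.environ["COLLECTION_NAME"], schema=schema)

    # L2 index on the vector field, matching the metric used in find_similar_images()
    collection.create_index(
        field_name="image_vector",
        index_params={"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 128}},
    )

    # column-wise insert: one list per non-auto field, in schema order;
    # in the real pipeline the vectors would come from ImageVectorizer().vectorize(path)
    image_paths = ["Vegetable Images/train/Tomato/0001.jpg"]  # hypothetical path
    image_vectors = [np.random.rand(512).tolist()]            # placeholder 512-dim embedding
    collection.insert([image_paths, image_vectors])
    collection.flush()

With a collection shaped like this, the limit=top_n search in find_similar_images returns image_path values that plot_images lays out in its 5x3 grid.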