Vela committed
Commit 0930d33 · 1 Parent(s): 13f7670

Create a project using Streamlit, a CLIP model, and Pinecone DB

.gitignore ADDED
@@ -0,0 +1,3 @@
+ config.yaml
+ .venv
+ logs
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ streamlit
+ requests
+ Pillow
+ pandas
+ fastapi[standard]
+ pinecone
+ transformers
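
Note: as committed, this list probably under-specifies the environment — transformers' CLIPModel needs a tensor backend such as torch, and src/config/config.py imports yaml (provided by the pyyaml package) — so a local setup would likely also need:

pip install torch pyyaml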
src/app/__pycache__/app.cpython-313.pyc ADDED
Binary file (772 Bytes)
 
src/app/__pycache__/homepage.cpython-313.pyc ADDED
Binary file (4.2 kB)
 
src/app/app.py ADDED
@@ -0,0 +1,16 @@
+ import homepage
+ 
+ search_option = ['Select an option', 'Search by text', 'Search by image']
+ 
+ homepage.setup_page()
+ 
+ chosen_option = homepage.get_user_selection(search_option)
+ if chosen_option.lower() == 'search by text':
+     user_query = homepage.get_search_text_input()
+     if user_query:
+         homepage.get_images_by_text(user_query)
+ elif chosen_option.lower() == 'search by image':
+     image_input = homepage.get_search_image_input()
+     if image_input:
+         homepage.get_images_by_image(image_input)
+ 
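
A hedged usage note: since config.py, logger.py and data_set.py all open paths relative to the working directory, the app is presumably launched from the repository root, e.g.:

streamlit run src/app/app.py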
src/app/homepage.py ADDED
@@ -0,0 +1,61 @@
+ import os
+ import sys
+ src_directory = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..", "src"))
+ sys.path.append(src_directory)
+ import streamlit as st
+ from utils import logger
+ from database_pinecone import querry_database
+ from model.clip_model import ClipModel
+ 
+ clip_model = ClipModel()
+ logger = logger.get_logger()
+ 
+ PAGE_TITLE = "Look A Like - Image Finder"
+ PAGE_LAYOUT = "centered"
+ SIDEBAR_TITLE = "Find Similar Images"
+ 
+ def setup_page():
+     if 'is_page_configured' not in st.session_state:
+         st.set_page_config(page_title=PAGE_TITLE, layout=PAGE_LAYOUT)
+         st.title(PAGE_TITLE)
+         st.sidebar.title(SIDEBAR_TITLE)
+         logger.info(f"Page configured with title '{PAGE_TITLE}', layout '{PAGE_LAYOUT}', and sidebar title '{SIDEBAR_TITLE}'")
+         st.session_state.is_page_configured = True
+     else:
+         logger.info("Page configuration already completed. Skipping setup.")
+ 
+ def get_user_selection(options):
+     selected_option = st.sidebar.selectbox("Select an option", options)
+     return selected_option
+ 
+ def get_search_image_input():
+     uploaded_image = st.sidebar.file_uploader("Upload an image to find similar images", type=['png', 'jpg', 'jpeg'])
+     return uploaded_image
+ 
+ def get_search_text_input():
+     user_search = st.sidebar.text_input("Enter the text to search")
+     return user_search
+ 
+ def display_images(response):
+     if response:
+         cols = st.columns(2)
+         for i, result in enumerate(response.matches):
+             with cols[i % 2]:
+                 st.image(result.metadata["url"])
+ 
+ def write_message(message):
+     st.write(message)
+ 
+ def get_images_by_text(query):
+     embedding = clip_model.get_text_embedding(query)
+     response = querry_database.fetch_data(embedding)
+     message = f"Showing search results for {query}"
+     write_message(message)
+     display_images(response)
+ 
+ def get_images_by_image(query):
+     embedding = clip_model.get_uploaded_image_embedding(query)
+     response = querry_database.fetch_data(embedding)
+     message = "Showing search results for relevant images"
+     write_message(message)
+     display_images(response)
src/config/__pycache__/config.cpython-313.pyc ADDED
Binary file (1.09 kB)
 
src/config/config.py ADDED
@@ -0,0 +1,14 @@
+ import yaml
+ from utils import logger
+ 
+ logger = logger.get_logger()
+ 
+ def load_config():
+     try:
+         with open('config.yaml', 'r') as file:
+             config_data = yaml.load(file, Loader=yaml.FullLoader)
+             logger.info("Successfully loaded the config.")
+             return config_data
+     except Exception as e:
+         logger.error(f"Unexpected error occurred while loading the config: {e}")
+         raise Exception(f"Error loading configuration: {e}")
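
config.yaml itself is gitignored, so its contents are not part of this commit; judging from create_database.py, load_config() is expected to return at least a nested mapping along these lines (key names inferred from the code, including its 'pincone_api_key' spelling; the placeholder value is hypothetical):

# Hypothetical shape of the dict returned by load_config(); inferred from create_database.get_index().
config_data = {
    'pinecone_db': {
        'pincone_api_key': '<your-pinecone-api-key>'
    }
}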
src/data/__pycache__/data_set.cpython-313.pyc ADDED
Binary file (1.38 kB)
 
src/data/__pycache__/images.cpython-313.pyc ADDED
Binary file (1.13 kB)
 
src/data/__pycache__/request_images.cpython-313.pyc ADDED
Binary file (1.6 kB)
 
src/data/data_set.py ADDED
@@ -0,0 +1,26 @@
+ import pandas as pd
+ from utils import logger
+ 
+ logger = logger.get_logger()
+ 
+ file_name = 'src/data/image_dataset.csv'
+ 
+ tsv_file = 'src/data/photos.tsv000'
+ 
+ def convert_tsv_to_csv(tsv_file):
+     df = pd.read_csv(tsv_file, sep='\t', header=0)
+     df.to_csv(file_name, index=False)
+     return file_name
+ 
+ def get_df(start_index, end_index):
+     try:
+         logger.info("Loading the dataframe")
+         image_df = pd.read_csv(file_name)
+         final_df = image_df[['photo_id', 'photo_image_url']]
+         df = final_df[start_index:end_index]
+         logger.info("Successfully loaded the data frame")
+         return df
+     except Exception as e:
+         logger.error(f"Unable to load the dataframe: {e}")
+         raise
+ 
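
A plausible one-off data-preparation flow, run from the repository root with src on sys.path (as the other modules arrange); photos.tsv000 is the committed tab-separated photo listing:

from data import data_set

# Convert the TSV listing to CSV once, then load a slice of (photo_id, photo_image_url) rows.
data_set.convert_tsv_to_csv(data_set.tsv_file)
df = data_set.get_df(0, 500)
print(df.head())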
src/data/image_dataset.csv ADDED
The diff for this file is too large to render.
 
src/data/photos.tsv000 ADDED
The diff for this file is too large to render.
 
src/data/request_images.py ADDED
@@ -0,0 +1,24 @@
+ import requests
+ from PIL import Image
+ from utils import logger
+ 
+ logger = logger.get_logger()
+ 
+ def get_image_url(url):
+     try:
+         logger.info("Loading image from url to embed")
+         res = requests.get(url, stream=True).raw
+         img = Image.open(res)
+         logger.info("Loaded the image to embed successfully")
+         return img
+     except Exception as e:
+         logger.error(f"Unable to load the image to embed: {e}")
+ 
+ def convert_image_to_embedding_format(query_image):
+     try:
+         logger.info("Loading the image to embed")
+         image = Image.open(query_image)
+         logger.info("Loaded the image to embed successfully")
+         return image
+     except Exception as e:
+         logger.error(f"Unable to load the image to embed: {e}")
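
Both helpers return a PIL Image (or None if loading fails) that ClipModel's processor can consume; a minimal sketch with a placeholder URL and a hypothetical local file:

from data import request_images

# From a remote URL, as used while building the index...
remote_img = request_images.get_image_url('https://example.com/photo.jpg')  # placeholder URL

# ...or from a file-like object, as used for the Streamlit uploader.
with open('query.jpg', 'rb') as f:  # hypothetical local file
    uploaded_img = request_images.convert_image_to_embedding_format(f)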
src/database_pinecone/__pycache__/create_database.cpython-313.pyc ADDED
Binary file (4.87 kB)
 
src/database_pinecone/__pycache__/querry_database.cpython-313.pyc ADDED
Binary file (1.31 kB)
 
src/database_pinecone/create_database.py ADDED
@@ -0,0 +1,94 @@
+ import os
+ import sys
+ src_directory = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..", "src"))
+ sys.path.append(src_directory)
+ from pinecone import Pinecone, ServerlessSpec
+ import time
+ from model.clip_model import ClipModel
+ from data import request_images
+ from data import data_set
+ from config import config
+ from utils import logger
+ 
+ config = config.load_config()
+ logger = logger.get_logger()
+ 
+ clip_model = ClipModel()
+ 
+ def create_index(pinecone, index_name):
+     pinecone.create_index(
+         name=index_name,
+         dimension=512,
+         metric="cosine",
+         spec=ServerlessSpec(
+             cloud="aws",
+             region="us-east-1"
+         )
+     )
+ 
+ def wait_till_index_loaded(pinecone, index_name):
+     while True:
+         index = pinecone.describe_index(index_name)
+         if index.status.get("ready", False):
+             index = pinecone.Index(index_name)
+             logger.info(f"Index '{index_name}' is ready and is now accessible.")
+             return index
+         else:
+             logger.debug(f"Index '{index_name}' is not ready yet. Checking again in 1 second.")
+             time.sleep(1)
+ 
+ def get_index():
+     try:
+         pincone_api_key = config['pinecone_db']['pincone_api_key']
+         pc = Pinecone(api_key=pincone_api_key)
+         index = None
+         index_name = "imagesearch"
+         logger.info(f"Checking if the index '{index_name}' exists...")
+         if not pc.has_index(index_name):
+             logger.info(f"Index '{index_name}' does not exist. Creating a new index...")
+             create_index(pc, index_name)
+             logger.info(f"Index '{index_name}' creation initiated. Waiting for it to be ready...")
+             index = wait_till_index_loaded(pc, index_name)
+         else:
+             index = pc.Index(index_name)
+             logger.info(f"Index '{index_name}' already exists. Returning the existing index.")
+         return index
+     except Exception as e:
+         logger.error(f"Error occurred while getting or creating the Pinecone index: {str(e)}", exc_info=True)
+         return None
+ 
+ def upsert_data(index, embeddings, id, url):
+     try:
+         logger.info("Started to upsert the data")
+         index.upsert(
+             vectors=[{
+                 "id": id,
+                 "values": embeddings,
+                 "metadata": {
+                     "url": url,
+                     "photo_id": id
+                 }
+             }],
+             namespace="image-search-dataset",
+         )
+         logger.info("Successfully upserted the data in database")
+     except Exception as e:
+         logger.error(f"Unable to upsert the data: {e}")
+         raise
+ 
+ def add_data_to_database(df):
+     try:
+         index = get_index()
+         logger.info("Starting to add the embeddings to the database")
+         for _, data in df.iterrows():
+             url = data['photo_image_url']
+             id = data['photo_id']
+             embeddings = clip_model.get_image_embedding(url)
+             upsert_data(index, embeddings, id, url)
+         logger.info("Added embeddings to the database successfully")
+     except Exception as e:
+         logger.error(f"Unable to add the data. Error: {e}")
+ 
+ 
+ # df = data_set.get_df(8000,8500)
+ # add_data_to_database(df)
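
The commented-out driver lines suggest how the index is populated in slices; presumably one uncomments them (or calls the two functions from a small script) and runs the module from the repository root so that config.yaml and the relative dataset paths resolve:

python src/database_pinecone/create_database.py

Each such run embeds one slice of the CSV (rows 8000-8500 in the example) and upserts 512-dimensional vectors into the 'imagesearch' index under the 'image-search-dataset' namespace.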
src/database_pinecone/querry_database.py ADDED
@@ -0,0 +1,25 @@
+ import os
+ import sys
+ src_directory = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..", "src"))
+ sys.path.append(src_directory)
+ from utils import logger
+ from model.clip_model import ClipModel
+ from database_pinecone import create_database
+ 
+ clip_model = ClipModel()
+ logger = logger.get_logger()
+ 
+ index = create_database.get_index()
+ namespace = 'image-search-dataset'
+ 
+ def fetch_data(embedding):
+     try:
+         response = index.query(
+             top_k=10,
+             vector=embedding,
+             namespace=namespace,
+             include_metadata=True)
+         return response
+     except Exception as e:
+         logger.error(f"Unable to query the index: {e}")
+         raise
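
fetch_data returns a standard Pinecone query response whose matches carry the metadata stored at upsert time; a minimal sketch of consuming it, mirroring what homepage.display_images does (the query text is just an example):

from model.clip_model import ClipModel
from database_pinecone import querry_database

clip_model = ClipModel()
embedding = clip_model.get_text_embedding('a dog on a beach')  # example query text
response = querry_database.fetch_data(embedding)

# Each match exposes an id, a similarity score, and the metadata set in create_database.upsert_data.
for match in response.matches:
    print(match.id, match.score, match.metadata['url'])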
src/model/__pycache__/clip_model.cpython-313.pyc ADDED
Binary file (5.91 kB)
 
src/model/clip_model.py ADDED
@@ -0,0 +1,80 @@
+ import os
+ import sys
+ from transformers import AutoProcessor, CLIPModel, AutoTokenizer
+ src_directory = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..", "src"))
+ sys.path.append(src_directory)
+ from data import request_images
+ from utils import logger
+ 
+ logger = logger.get_logger()
+ 
+ class ClipModel:
+     _models = {}
+ 
+     def __init__(self, model_name: str = "openai/clip-vit-base-patch32", tokenizer_name: str = "openai/clip-vit-large-patch14"):
+         self.model_name = model_name
+         self.tokenizer_name = tokenizer_name
+ 
+         if model_name not in ClipModel._models:
+             ClipModel._models[model_name] = self.load_models()
+ 
+     def load_models(self):
+         try:
+             logger.info(f"Loading the models: {self.model_name}")
+             model = CLIPModel.from_pretrained(self.model_name)
+             processor = AutoProcessor.from_pretrained(self.model_name)
+             tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
+             return {
+                 'model': model,
+                 'processor': processor,
+                 'tokenizer': tokenizer
+             }
+         except Exception as e:
+             logger.error(f"Unable to load the model: {e}")
+             raise
+ 
+     def get_text_embedding(self, text: str):
+         try:
+             logger.info(f"Getting embedding for the text: {text}")
+             inputs = self._models[self.model_name]['tokenizer']([text], padding=True, return_tensors="pt")
+             text_features = self._models[self.model_name]['model'].get_text_features(**inputs)
+             text_embedding = text_features.detach().numpy().flatten().tolist()
+             logger.info("Text embedding successfully retrieved.")
+             return text_embedding
+         except Exception as e:
+             logger.error(f"Error while getting embedding for text: {e}")
+             raise
+ 
+     def get_image_embedding(self, image):
+         try:
+             logger.info("Getting embedding for the image")
+             image = request_images.get_image_url(image)
+             inputs = self._models[self.model_name]['processor'](images=image, return_tensors="pt")
+             image_features = self._models[self.model_name]['model'].get_image_features(**inputs)
+             embeddings = image_features.detach().cpu().numpy().flatten().tolist()
+             logger.info("Image embedding successfully retrieved.")
+             return embeddings
+         except Exception as e:
+             logger.error(f"Error while getting embedding for image: {e}")
+             raise
+ 
+     def get_uploaded_image_embedding(self, image):
+         try:
+             logger.info("Getting embedding for the uploaded image")
+             image = request_images.convert_image_to_embedding_format(image)
+             inputs = self._models[self.model_name]['processor'](images=image, return_tensors="pt")
+             image_features = self._models[self.model_name]['model'].get_image_features(**inputs)
+             embeddings = image_features.detach().cpu().numpy().flatten().tolist()
+             logger.info("Image embedding successfully retrieved.")
+             return embeddings
+         except Exception as e:
+             logger.error(f"Error while getting embedding for image: {e}")
+             raise
+ 
+ if __name__ == "__main__":
+     try:
+         logger.info("Starting the initialization of the ClipModel class...")
+         clip_model = ClipModel()
+         logger.info("ClipModel class initialized successfully.")
+     except Exception as e:
+         logger.error(f"Error during ClipModel initialization: {str(e)}")
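
A quick sanity check under the default openai/clip-vit-base-patch32 checkpoint: its projection size is 512, which is exactly the dimension create_index() declares for the Pinecone index.

from model.clip_model import ClipModel

clip_model = ClipModel()
embedding = clip_model.get_text_embedding('a red bicycle')  # example query text

# clip-vit-base-patch32 projects text and images into a shared 512-dimensional space,
# matching the 'imagesearch' index created with dimension=512.
assert len(embedding) == 512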
src/utils/__pycache__/logger.cpython-313.pyc ADDED
Binary file (1.8 kB)
 
src/utils/logger.py ADDED
@@ -0,0 +1,35 @@
+ import logging
+ from logging.handlers import RotatingFileHandler
+ import os
+ 
+ log_file = 'image_search.log'
+ log_dir = 'src/logs'
+ log_level = logging.INFO
+ 
+ def get_logger():
+ 
+     if not os.path.exists(log_dir):
+         os.makedirs(log_dir)
+ 
+     log_file_path = os.path.join(log_dir, log_file)
+ 
+     logger = logging.getLogger(__name__)
+ 
+     if not logger.hasHandlers():
+         logger.setLevel(log_level)
+ 
+         console_handler = logging.StreamHandler()
+         console_handler.setLevel(logging.DEBUG)
+ 
+         file_handler = RotatingFileHandler(log_file_path, maxBytes=5*1024*1024, backupCount=3)
+         file_handler.setLevel(logging.INFO)
+ 
+         log_format = '%(asctime)s - %(levelname)s - %(message)s'
+         formatter = logging.Formatter(log_format, datefmt='%Y-%m-%d %H:%M')
+         console_handler.setFormatter(formatter)
+         file_handler.setFormatter(formatter)
+ 
+         logger.addHandler(console_handler)
+         logger.addHandler(file_handler)
+ 
+     return logger