Prathamesh1420 commited on
Commit
66c277d
·
verified ·
1 Parent(s): 60bee17

Upload 11 files

Browse files
Files changed (11) hide show
  1. .gitattributes +1 -34
  2. README.md +46 -12
  3. app.py +54 -0
  4. chatbot.py +123 -0
  5. embeddings.pkl +3 -0
  6. embeddings_cache.pkl +3 -0
  7. filenames.pkl +3 -0
  8. main.py +80 -0
  9. main_1.py +228 -0
  10. requirements.txt +13 -0
  11. test.py +48 -0
.gitattributes CHANGED
@@ -1,35 +1,2 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
  *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  *.pkl filter=lfs diff=lfs merge=lfs -text
2
+ *.psd filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -1,12 +1,46 @@
1
- ---
2
- title: Recommendation System
3
- emoji: ⚡
4
- colorFrom: purple
5
- colorTo: purple
6
- sdk: streamlit
7
- sdk_version: 1.36.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Gen-AI-based-fashion-product-recommendation-system
2
+ Fashion Recommender System with Chatbot Integration This project implements a Fashion Recommender System using state-of-the-art machine learning models and a chatbot interface. The system recommends fashion products based on uploaded images and responds to user queries through a chat interface.
3
+
4
+ Technologies Used:
5
+
6
+ Streamlit: For building the user interface and dashboard.
7
+
8
+ TensorFlow: Utilized for image feature extraction and recommendation.
9
+
10
+ Sentence Transformers: Used for text embeddings and similarity calculations.
11
+
12
+ Hugging Face Transformers: Integrated for natural language processing tasks.
13
+
14
+ MySQL Connector: For user authentication and database operations.
15
+
16
+ PIL (Python Imaging Library): Used for image processing tasks.
17
+
18
+ scikit-learn: Employed for nearest neighbor search for recommendation.
19
+
20
+ PyTorch: Utilized for text and image embeddings.
21
+
22
+ Matplotlib: Used for displaying images and visualizations.
23
+
24
+ Key Features:
25
+
26
+ Login and Registration: Users can authenticate themselves before accessing the system.
27
+
28
+ Image Upload and Recommendation: Users can upload images of fashion products, and the system provides recommendations based on image similarity.
29
+
30
+ Chatbot Integration: Includes a chatbot interface that responds to user queries and provides additional product information.
31
+
32
+ Dashboard: Provides a user-friendly dashboard for navigation and interaction with the system.
33
+
34
+ Usage:
35
+
36
+ Clone the repository.
37
+
38
+ Install the required dependencies using pip install -r requirements.txt.
39
+
40
+ Run the application using streamlit run main.py.
41
+
42
+ Access the application through the provided URL.
43
+
44
+ Note:
45
+
46
+ This project is a prototype and may require additional enhancements for production use. Contributions and feedback are welcome!
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow
2
+ from tensorflow.keras.preprocessing import image
3
+ from tensorflow.keras.layers import GlobalMaxPooling2D
4
+ from tensorflow.keras.applications.resnet50 import ResNet50,preprocess_input
5
+ import numpy as np
6
+ from numpy.linalg import norm
7
+ import os
8
+ from tqdm import tqdm
9
+ import pickle
10
+
11
+ # Load pre-trained ResNet50 model
12
+
13
+ model = ResNet50(weights='imagenet',include_top=False,input_shape=(224,224,3))
14
+ model.trainable = False
15
+
16
+ # Create a Sequential model with ResNet50 as base and GlobalMaxPooling2D layer
17
+
18
+ model = tensorflow.keras.Sequential([
19
+ model,
20
+ GlobalMaxPooling2D()
21
+ ])
22
+
23
+ #print(model.summary())
24
+
25
+ # Function to extract features from an image using the model
26
+ def extract_features(img_path,model):
27
+ img = image.load_img(img_path,target_size=(224,224))
28
+ img_array = image.img_to_array(img)
29
+ expanded_img_array = np.expand_dims(img_array, axis=0)
30
+ preprocessed_img = preprocess_input(expanded_img_array)
31
+
32
+ # Extract features using the model and normalize the result
33
+
34
+ result = model.predict(preprocessed_img).flatten()
35
+ normalized_result = result / norm(result)
36
+
37
+ return normalized_result
38
+
39
+ # List all filenames in the 'images' directory
40
+
41
+ filenames = []
42
+
43
+ for file in os.listdir('images'):
44
+ filenames.append(os.path.join('images',file))
45
+
46
+ feature_list = [] # List to store extracted features
47
+
48
+ # Loop through each image file, extract features, and append to the feature_list
49
+
50
+ for file in tqdm(filenames):
51
+ feature_list.append(extract_features(file,model))
52
+ # Save the extracted features and filenames to pickle files
53
+ pickle.dump(feature_list,open('embeddings.pkl','wb')) # Save features
54
+ pickle.dump(filenames,open('filenames.pkl','wb')) # save filenames
chatbot.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import torch
4
+ import pickle
5
+ import matplotlib.pyplot as plt
6
+ from langchain_community.document_loaders import TextLoader
7
+ from datasets import load_dataset
8
+ from sentence_transformers import SentenceTransformer, util
9
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
10
+ from transformers import BertModel, BertTokenizer
11
+ from langchain_core.prompts import PromptTemplate
12
+
13
+ os.environ['HUGGINGFACEHUB_API_TOKEN'] = "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"
14
+
15
+ class Chatbot:
16
+ def __init__(self):
17
+ self.load_data()
18
+ self.load_models()
19
+ self.load_embeddings()
20
+ self.load_template()
21
+
22
+ def load_data(self):
23
+ self.data = load_dataset("ashraq/fashion-product-images-small", split="train")
24
+ self.images = self.data["image"]
25
+ self.product_frame = self.data.remove_columns("image").to_pandas()
26
+ self.product_data = self.product_frame.reset_index(drop=True).to_dict(orient='index')
27
+
28
+ def load_template(self):
29
+ self.template = """
30
+ You are a fashion shopping assistant that wants to convert customers based on the information given.
31
+ Describe season and usage given in the context in your interaction with the customer.
32
+ Use a bullet list when describing each product.
33
+ If user ask general question then answer them accordingly, the question may be like when the store will open, where is your store located.
34
+ Context: {context}
35
+ User question: {question}
36
+ Your response: {response}
37
+ """
38
+ self.prompt = PromptTemplate.from_template(self.template)
39
+
40
+ def load_models(self):
41
+ self.model = SentenceTransformer('clip-ViT-B-32')
42
+ self.bert_model_name = "bert-base-uncased"
43
+ self.bert_model = BertModel.from_pretrained(self.bert_model_name)
44
+ self.bert_tokenizer = BertTokenizer.from_pretrained(self.bert_model_name)
45
+ self.gpt2_model_name = "gpt2"
46
+ self.gpt2_model = GPT2LMHeadModel.from_pretrained(self.gpt2_model_name)
47
+ self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained(self.gpt2_model_name)
48
+
49
+ def load_embeddings(self):
50
+ if os.path.exists("embeddings_cache.pkl"):
51
+ with open("embeddings_cache.pkl", "rb") as f:
52
+ embeddings_cache = pickle.load(f)
53
+ self.image_embeddings = embeddings_cache["image_embeddings"]
54
+ self.text_embeddings = embeddings_cache["text_embeddings"]
55
+ else:
56
+ self.image_embeddings = self.model.encode([image for image in self.images])
57
+ self.text_embeddings = self.model.encode(self.product_frame['productDisplayName'])
58
+ embeddings_cache = {"image_embeddings": self.image_embeddings, "text_embeddings": self.text_embeddings}
59
+ with open("embeddings_cache.pkl", "wb") as f:
60
+ pickle.dump(embeddings_cache, f)
61
+
62
+ def create_docs(self, results):
63
+ docs = []
64
+ for result in results:
65
+ pid = result['corpus_id']
66
+ score = result['score']
67
+ result_string = ''
68
+ result_string += "Product Name:" + self.product_data[pid]['productDisplayName'] + \
69
+ ';' + "Category:" + self.product_data[pid]['masterCategory'] + \
70
+ ';' + "Article Type:" + self.product_data[pid]['articleType'] + \
71
+ ';' + "Usage:" + self.product_data[pid]['usage'] + \
72
+ ';' + "Season:" + self.product_data[pid]['season'] + \
73
+ ';' + "Gender:" + self.product_data[pid]['gender']
74
+ # Assuming text is imported from somewhere else
75
+ doc = text(page_content=result_string)
76
+ doc.metadata['pid'] = str(pid)
77
+ doc.metadata['score'] = score
78
+ docs.append(doc)
79
+ return docs
80
+
81
+ def get_results(self, query, embeddings, top_k=10):
82
+ query_embedding = self.model.encode([query])
83
+ cos_scores = util.pytorch_cos_sim(query_embedding, embeddings)[0]
84
+ top_results = torch.topk(cos_scores, k=top_k)
85
+ indices = top_results.indices.tolist()
86
+ scores = top_results.values.tolist()
87
+ results = [{'corpus_id': idx, 'score': score} for idx, score in zip(indices, scores)]
88
+ return results
89
+
90
+ def display_text_and_images(self, results_text):
91
+ for result in results_text:
92
+ pid = result['corpus_id']
93
+ product_info = self.product_data[pid]
94
+ print("Product Name:", product_info['productDisplayName'])
95
+ print("Category:", product_info['masterCategory'])
96
+ print("Article Type:", product_info['articleType'])
97
+ print("Usage:", product_info['usage'])
98
+ print("Season:", product_info['season'])
99
+ print("Gender:", product_info['gender'])
100
+ print("Score:", result['score'])
101
+ plt.imshow(self.images[pid])
102
+ plt.axis('off')
103
+ plt.show()
104
+
105
+ @staticmethod
106
+ def cos_sim(a, b):
107
+ a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
108
+ b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
109
+ return torch.mm(a_norm.T, b_norm) # Reshape a_norm to (768, 1)
110
+
111
+ def generate_response(self, query):
112
+ # Process the user query and generate a response
113
+ results_text = self.get_results(query, self.text_embeddings)
114
+
115
+ # Generate chatbot response
116
+ chatbot_response = "This is a placeholder response from the chatbot." # Placeholder, replace with actual response
117
+
118
+ # Display recommended products
119
+ self.display_text_and_images(results_text)
120
+
121
+ # Return both chatbot response and recommended products
122
+ return chatbot_response,results_text
123
+
embeddings.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:88399617409c087450736ddfa9c64f99365cce80055cdcd2075f8dac76fe2a3e
3
+ size 134
embeddings_cache.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f3c0e69a8744b7f05c185f17a4d546e41e94982c2d72b9bee915fe174a2ecd1
3
+ size 134
filenames.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d0dfe2467c266620c47ca39138c3d4ad1c3d68e365ab007f44cba5d083b9fde
3
+ size 131
main.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import altair as alt
3
+ import os
4
+ from PIL import Image
5
+ import numpy as np
6
+ import pickle
7
+ import tensorflow
8
+ from tensorflow.keras.preprocessing import image
9
+ from tensorflow.keras.layers import GlobalMaxPooling2D
10
+ from tensorflow.keras.applications.resnet50 import ResNet50,preprocess_input
11
+ from sklearn.neighbors import NearestNeighbors
12
+ from numpy.linalg import norm
13
+
14
+ feature_list = np.array(pickle.load(open('embeddings.pkl','rb')))
15
+ filenames = pickle.load(open('filenames.pkl','rb'))
16
+
17
+ model = ResNet50(weights='imagenet',include_top=False,input_shape=(224,224,3))
18
+ model.trainable = False
19
+
20
+ model = tensorflow.keras.Sequential([
21
+ model,
22
+ GlobalMaxPooling2D()
23
+ ])
24
+
25
+ st.title('Fashion Recommender System')
26
+
27
+ def save_uploaded_file(uploaded_file):
28
+ try:
29
+ with open(os.path.join('uploads',uploaded_file.name),'wb') as f:
30
+ f.write(uploaded_file.getbuffer())
31
+ return 1
32
+ except:
33
+ return 0
34
+
35
+ def feature_extraction(img_path,model):
36
+ img = image.load_img(img_path, target_size=(224, 224))
37
+ img_array = image.img_to_array(img)
38
+ expanded_img_array = np.expand_dims(img_array, axis=0)
39
+ preprocessed_img = preprocess_input(expanded_img_array)
40
+ result = model.predict(preprocessed_img).flatten()
41
+ normalized_result = result / norm(result)
42
+
43
+ return normalized_result
44
+
45
+ def recommend(features,feature_list):
46
+ neighbors = NearestNeighbors(n_neighbors=6, algorithm='brute', metric='euclidean')
47
+ neighbors.fit(feature_list)
48
+
49
+ distances, indices = neighbors.kneighbors([features])
50
+
51
+ return indices
52
+
53
+ # steps
54
+ # file upload -> save
55
+ uploaded_file = st.file_uploader("Choose an image")
56
+ if uploaded_file is not None:
57
+ if save_uploaded_file(uploaded_file):
58
+ # display the file
59
+ display_image = Image.open(uploaded_file)
60
+ st.image(display_image)
61
+ # feature extract
62
+ features = feature_extraction(os.path.join("uploads",uploaded_file.name),model)
63
+ #st.text(features)
64
+ # recommendention
65
+ indices = recommend(features,feature_list)
66
+ # show
67
+ col1,col2,col3,col4,col5 = st.beta_columns(5)
68
+
69
+ with col1:
70
+ st.image(filenames[indices[0][0]])
71
+ with col2:
72
+ st.image(filenames[indices[0][1]])
73
+ with col3:
74
+ st.image(filenames[indices[0][2]])
75
+ with col4:
76
+ st.image(filenames[indices[0][3]])
77
+ with col5:
78
+ st.image(filenames[indices[0][4]])
79
+ else:
80
+ st.header("Some error occured in file upload")
main_1.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import mysql.connector
3
+ import os
4
+ from PIL import Image
5
+ import numpy as np
6
+ import pickle
7
+ import tensorflow
8
+ from tensorflow.keras.preprocessing import image
9
+ from tensorflow.keras.layers import GlobalMaxPooling2D
10
+ from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
11
+ from sklearn.neighbors import NearestNeighbors
12
+ from numpy.linalg import norm
13
+ from chatbot import Chatbot # Assuming you have a chatbot module
14
+
15
+ # Function to authenticate user credentials from the database
16
+ def authenticate_user(username, password):
17
+ try:
18
+ # Connect to the MySQL database
19
+ connection = mysql.connector.connect(
20
+ host="sql6.freemysqlhosting.net",
21
+ database="sql6697353",
22
+ user="sql6697353",
23
+ password="wvfSQJbmMs"
24
+ )
25
+
26
+ # Create a cursor object to execute SQL queries
27
+ cursor = connection.cursor()
28
+
29
+ # Query to check if the username and password match
30
+ query = "SELECT * FROM users WHERE username = %s AND password = %s"
31
+ cursor.execute(query, (username, password))
32
+ user = cursor.fetchone() # Fetch the first row
33
+
34
+ # If user exists, return True (authentication successful)
35
+ if user:
36
+ return True
37
+ else:
38
+ return False
39
+
40
+ except mysql.connector.Error as error:
41
+ st.error(f"Error: {error}")
42
+ return False
43
+
44
+ finally:
45
+ # Close the cursor and database connection
46
+ if 'connection' in locals() and connection.is_connected():
47
+ cursor.close()
48
+ connection.close()
49
+
50
+ # Define your login function
51
+ def login():
52
+ st.header("Login Page")
53
+ # Add your login code here
54
+ username = st.text_input("Username")
55
+ password = st.text_input("Password", type="password")
56
+ if st.button("Login"):
57
+ # Authenticate user with provided credentials
58
+ if authenticate_user(username, password):
59
+ st.session_state.logged_in = True
60
+ query_params = st.experimental_get_query_params()
61
+ query_params["page"] = ["Dashboard"]
62
+ st.experimental_set_query_params(**query_params)
63
+ # Redirect to dashboard
64
+ st.success("Login successful!")
65
+ else:
66
+ st.error("Invalid username or password")
67
+
68
+ # Define your registration function
69
+ def register():
70
+ st.header("Registration Page")
71
+ # Add your registration code here
72
+ username = st.text_input("Username")
73
+ password = st.text_input("Password", type="password")
74
+ email = st.text_input("Email")
75
+
76
+ if st.button("Register"):
77
+ # Validate input
78
+ if not username or not password or not email:
79
+ st.error("Please enter username, password, and email.")
80
+ else:
81
+ try:
82
+ # Connect to the MySQL database
83
+ connection = mysql.connector.connect(
84
+ host="sql6.freemysqlhosting.net",
85
+ database="sql6697353",
86
+ user="sql6697353",
87
+ password="wvfSQJbmMs"
88
+ )
89
+
90
+ # Create a cursor object to execute SQL queries
91
+ cursor = connection.cursor()
92
+
93
+ # Check if username or email already exists
94
+ cursor.execute("SELECT * FROM users WHERE username = %s OR email = %s", (username, email))
95
+ result = cursor.fetchone()
96
+ if result:
97
+ st.error("Username or email already exists. Please choose different ones.")
98
+ else:
99
+ # Insert new user into database
100
+ cursor.execute("INSERT INTO users (username, password, email) VALUES (%s, %s, %s)", (username, password, email))
101
+ connection.commit()
102
+ st.success("Registration successful. You can now log in.")
103
+ except mysql.connector.Error as error:
104
+ st.error(f"Error: {error}")
105
+ finally:
106
+ # Close the cursor and database connection
107
+ if 'connection' in locals() and connection.is_connected():
108
+ cursor.close()
109
+ connection.close()
110
+
111
+
112
+ # Define function for feature extraction
113
+ def feature_extraction(img_path, model):
114
+ img = image.load_img(img_path, target_size=(224, 224))
115
+ img_array = image.img_to_array(img)
116
+ expanded_img_array = np.expand_dims(img_array, axis=0)
117
+ preprocessed_img = preprocess_input(expanded_img_array)
118
+ result = model.predict(preprocessed_img).flatten()
119
+ normalized_result = result / norm(result)
120
+ return normalized_result
121
+
122
+ # Define function for recommendation
123
+ def recommend(features, feature_list):
124
+ neighbors = NearestNeighbors(n_neighbors=6, algorithm='brute', metric='euclidean')
125
+ neighbors.fit(feature_list)
126
+ distances, indices = neighbors.kneighbors([features])
127
+ return indices
128
+
129
+ # Function to save uploaded file
130
+ def save_uploaded_file(uploaded_file):
131
+ try:
132
+ with open(os.path.join('uploads', uploaded_file.name), 'wb') as f:
133
+ f.write(uploaded_file.getbuffer())
134
+ return True
135
+ except:
136
+ return False
137
+
138
+ # Function to show dashboard content
139
+ def show_dashboard():
140
+ st.header("Fashion Recommender System")
141
+ chatbot = Chatbot()
142
+ # Load ResNet model for image feature extraction
143
+ model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
144
+ model.trainable = False
145
+ model = tensorflow.keras.Sequential([
146
+ model,
147
+ GlobalMaxPooling2D()
148
+ ])
149
+
150
+ feature_list = np.array(pickle.load(open('embeddings.pkl', 'rb')))
151
+ filenames = pickle.load(open('filenames.pkl', 'rb'))
152
+
153
+ # File upload section
154
+ uploaded_file = st.file_uploader("Choose an image")
155
+ if uploaded_file is not None:
156
+ if save_uploaded_file(uploaded_file):
157
+ # Display the uploaded image
158
+ display_image = Image.open(uploaded_file)
159
+ st.image(display_image)
160
+
161
+ # Feature extraction
162
+ features = feature_extraction(os.path.join("uploads", uploaded_file.name), model)
163
+
164
+ # Recommendation
165
+ indices = recommend(features, feature_list)
166
+
167
+ # Display recommended products
168
+ col1, col2, col3, col4, col5 = st.beta_columns(5)
169
+ with col1:
170
+ st.image(filenames[indices[0][0]])
171
+ with col2:
172
+ st.image(filenames[indices[0][1]])
173
+ with col3:
174
+ st.image(filenames[indices[0][2]])
175
+ with col4:
176
+ st.image(filenames[indices[0][3]])
177
+ with col5:
178
+ st.image(filenames[indices[0][4]])
179
+
180
+ else:
181
+ st.header("Some error occurred in file upload")
182
+
183
+ # Chatbot section
184
+ user_question = st.text_input("Ask a question:")
185
+ if user_question:
186
+ bot_response, recommended_products = chatbot.generate_response(user_question)
187
+ st.write("Chatbot:", bot_response)
188
+
189
+ # Display recommended products
190
+ for result in recommended_products:
191
+ pid = result['corpus_id']
192
+ product_info = chatbot.product_data[pid]
193
+ st.write("Product Name:", product_info['productDisplayName'])
194
+ st.write("Category:", product_info['masterCategory'])
195
+ st.write("Article Type:", product_info['articleType'])
196
+ st.write("Usage:", product_info['usage'])
197
+ st.write("Season:", product_info['season'])
198
+ st.write("Gender:", product_info['gender'])
199
+ st.image(chatbot.images[pid])
200
+
201
+ # Main Streamlit app
202
+ def main():
203
+ # Give title to the app
204
+ st.title("Fashion Recommender System")
205
+
206
+ # Sidebar navigation
207
+ page = st.sidebar.radio("Navigation", ["Login", "Register", "Dashboard"])
208
+
209
+ # Check if user is logged in
210
+ if "logged_in" not in st.session_state:
211
+ st.session_state.logged_in = False
212
+
213
+ # Show login page if user is not logged in
214
+ if not st.session_state.logged_in:
215
+ if page == "Login":
216
+ login()
217
+
218
+ # Show registration page if selected in the sidebar
219
+ if page == "Register":
220
+ register()
221
+
222
+ # Show dashboard if selected in the sidebar and user is logged in
223
+ if page == "Dashboard" and st.session_state.logged_in:
224
+ show_dashboard()
225
+
226
+ # Run the main app
227
+ if __name__ == "__main__":
228
+ main()
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ torch
3
+ matplotlib
4
+ langchain-core # Might need adjustment based on availability
5
+ datasets
6
+ sentence-transformers
7
+ transformers
8
+
9
+ # Additional libraries based on your code (if applicable)
10
+ mysql-connector-python # If using MySQL database
11
+ Pillow # If using PIL for image processing
12
+ scikit-learn # If using scikit-learn (used for sklearn in your code)
13
+ langchain_community
test.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import tensorflow
3
+ import numpy as np
4
+ from numpy.linalg import norm
5
+ from tensorflow.keras.preprocessing import image
6
+ from tensorflow.keras.layers import GlobalMaxPooling2D
7
+ from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
8
+ from sklearn.neighbors import NearestNeighbors
9
+ import cv2
10
+
11
+ # Load the precomputed feature vectors and filenames from pickle files
12
+ feature_list = np.array(pickle.load(open('embeddings.pkl', 'rb')))
13
+ filenames = pickle.load(open('filenames.pkl', 'rb'))
14
+
15
+ # Load ResNet50 model without the top layer for feature extraction
16
+ model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
17
+ model.trainable = False
18
+
19
+ # Create a Sequential model with ResNet50 and GlobalMaxPooling2D layers
20
+ model = tensorflow.keras.Sequential([
21
+ model,
22
+ GlobalMaxPooling2D()
23
+ ])
24
+
25
+ # Load and preprocess the query image
26
+ img = image.load_img('sample/khade.jpg', target_size=(224, 224))
27
+ img_array = image.img_to_array(img)
28
+ expanded_img_array = np.expand_dims(img_array, axis=0)
29
+ preprocessed_img = preprocess_input(expanded_img_array)
30
+
31
+ # Extract features from the query image and normalize
32
+ result = model.predict(preprocessed_img).flatten()
33
+ normalized_result = result / norm(result)
34
+
35
+ # Initialize NearestNeighbors model and fit with the feature vectors
36
+ neighbors = NearestNeighbors(n_neighbors=6, algorithm='brute', metric='euclidean')
37
+ neighbors.fit(feature_list)
38
+
39
+ # Find the nearest neighbors (excluding itself)
40
+ distances, indices = neighbors.kneighbors([normalized_result])
41
+
42
+ print(indices) # Print the indices of nearest neighbors
43
+
44
+ # Display the nearest neighbor images
45
+ for file in indices[0][0:5]:
46
+ temp_img = cv2.imread(filenames[file])
47
+ cv2.imshow('output', cv2.resize(temp_img, (312, 312)))
48
+ cv2.waitKey(0)