Spaces: Runtime error
Upload 11 files
- .gitattributes +1 -34
- README.md +46 -12
- app.py +54 -0
- chatbot.py +123 -0
- embeddings.pkl +3 -0
- embeddings_cache.pkl +3 -0
- filenames.pkl +3 -0
- main.py +80 -0
- main_1.py +228 -0
- requirements.txt +13 -0
- test.py +48 -0
.gitattributes
CHANGED
@@ -1,35 +1,2 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
 *.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.psd filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,12 +1,46 @@
(12 lines removed; their content is not shown in this view)
+# Gen-AI-based-fashion-product-recommendation-system
+Fashion Recommender System with Chatbot Integration. This project implements a Fashion Recommender System using state-of-the-art machine learning models and a chatbot interface. The system recommends fashion products based on uploaded images and responds to user queries through a chat interface.
+
+Technologies Used:
+
+Streamlit: For building the user interface and dashboard.
+
+TensorFlow: Utilized for image feature extraction and recommendation.
+
+Sentence Transformers: Used for text embeddings and similarity calculations.
+
+Hugging Face Transformers: Integrated for natural language processing tasks.
+
+MySQL Connector: For user authentication and database operations.
+
+PIL (Python Imaging Library): Used for image processing tasks.
+
+scikit-learn: Employed for nearest-neighbor search for recommendation.
+
+PyTorch: Utilized for text and image embeddings.
+
+Matplotlib: Used for displaying images and visualizations.
+
+Key Features:
+
+Login and Registration: Users can authenticate themselves before accessing the system.
+
+Image Upload and Recommendation: Users can upload images of fashion products, and the system provides recommendations based on image similarity.
+
+Chatbot Integration: Includes a chatbot interface that responds to user queries and provides additional product information.
+
+Dashboard: Provides a user-friendly dashboard for navigation and interaction with the system.
+
+Usage:
+
+Clone the repository.
+
+Install the required dependencies using pip install -r requirements.txt.
+
+Run the application using streamlit run main.py.
+
+Access the application through the provided URL.
+
+Note:
+
+This project is a prototype and may require additional enhancements for production use. Contributions and feedback are welcome!
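The image-similarity recommendation described above is implemented in app.py and main.py later in this commit. Purely for orientation, the core idea (ResNet50 features plus nearest-neighbor search) condenses to the sketch below; it reuses the repository's 'images' folder, while 'query.jpg' is a placeholder path, not something from the commit.

# Minimal sketch: embed catalog images with ResNet50 + global max pooling,
# then recommend by nearest-neighbor search over L2-normalized features.
import os
import numpy as np
from numpy.linalg import norm
from tensorflow.keras import Sequential
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.layers import GlobalMaxPooling2D
from tensorflow.keras.preprocessing import image
from sklearn.neighbors import NearestNeighbors

base = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
base.trainable = False
embedder = Sequential([base, GlobalMaxPooling2D()])

def embed(path):
    img = image.img_to_array(image.load_img(path, target_size=(224, 224)))
    feat = embedder.predict(preprocess_input(np.expand_dims(img, axis=0))).flatten()
    return feat / norm(feat)  # normalize so Euclidean distance tracks cosine similarity

paths = [os.path.join('images', f) for f in os.listdir('images')]
catalog = np.array([embed(p) for p in paths])
nn = NearestNeighbors(n_neighbors=5, algorithm='brute', metric='euclidean').fit(catalog)
_, idx = nn.kneighbors([embed('query.jpg')])  # placeholder query image
print([paths[i] for i in idx[0]])  # most similar catalog images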
app.py
ADDED
@@ -0,0 +1,54 @@
+import tensorflow
+from tensorflow.keras.preprocessing import image
+from tensorflow.keras.layers import GlobalMaxPooling2D
+from tensorflow.keras.applications.resnet50 import ResNet50,preprocess_input
+import numpy as np
+from numpy.linalg import norm
+import os
+from tqdm import tqdm
+import pickle
+
+# Load pre-trained ResNet50 model
+
+model = ResNet50(weights='imagenet',include_top=False,input_shape=(224,224,3))
+model.trainable = False
+
+# Create a Sequential model with ResNet50 as base and GlobalMaxPooling2D layer
+
+model = tensorflow.keras.Sequential([
+    model,
+    GlobalMaxPooling2D()
+])
+
+#print(model.summary())
+
+# Function to extract features from an image using the model
+def extract_features(img_path,model):
+    img = image.load_img(img_path,target_size=(224,224))
+    img_array = image.img_to_array(img)
+    expanded_img_array = np.expand_dims(img_array, axis=0)
+    preprocessed_img = preprocess_input(expanded_img_array)
+
+    # Extract features using the model and normalize the result
+
+    result = model.predict(preprocessed_img).flatten()
+    normalized_result = result / norm(result)
+
+    return normalized_result
+
+# List all filenames in the 'images' directory
+
+filenames = []
+
+for file in os.listdir('images'):
+    filenames.append(os.path.join('images',file))
+
+feature_list = []  # List to store extracted features
+
+# Loop through each image file, extract features, and append to the feature_list
+
+for file in tqdm(filenames):
+    feature_list.append(extract_features(file,model))
+# Save the extracted features and filenames to pickle files
+pickle.dump(feature_list,open('embeddings.pkl','wb'))  # Save features
+pickle.dump(filenames,open('filenames.pkl','wb'))  # Save filenames
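One caveat in the scan above: os.listdir returns entries in arbitrary order and does not filter by file type, so a stray non-image file in images/ would make load_img raise. A small optional variant of that loop, sorted for reproducibility and restricted to common image extensions (the extension set is an assumption, not from the commit):

# Optional variant of the filename scan in app.py.
import os

IMAGE_EXTS = {'.jpg', '.jpeg', '.png'}
filenames = sorted(
    os.path.join('images', f)
    for f in os.listdir('images')
    if os.path.splitext(f)[1].lower() in IMAGE_EXTS
)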
chatbot.py
ADDED
@@ -0,0 +1,123 @@
+import os
+import pickle
+import torch
+import pickle
+import matplotlib.pyplot as plt
+from langchain_community.document_loaders import TextLoader
+from datasets import load_dataset
+from sentence_transformers import SentenceTransformer, util
+from transformers import GPT2LMHeadModel, GPT2Tokenizer
+from transformers import BertModel, BertTokenizer
+from langchain_core.prompts import PromptTemplate
+
+os.environ['HUGGINGFACEHUB_API_TOKEN'] = "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"
+
+class Chatbot:
+    def __init__(self):
+        self.load_data()
+        self.load_models()
+        self.load_embeddings()
+        self.load_template()
+
+    def load_data(self):
+        self.data = load_dataset("ashraq/fashion-product-images-small", split="train")
+        self.images = self.data["image"]
+        self.product_frame = self.data.remove_columns("image").to_pandas()
+        self.product_data = self.product_frame.reset_index(drop=True).to_dict(orient='index')
+
+    def load_template(self):
+        self.template = """
+        You are a fashion shopping assistant that wants to convert customers based on the information given.
+        Describe season and usage given in the context in your interaction with the customer.
+        Use a bullet list when describing each product.
+        If user ask general question then answer them accordingly, the question may be like when the store will open, where is your store located.
+        Context: {context}
+        User question: {question}
+        Your response: {response}
+        """
+        self.prompt = PromptTemplate.from_template(self.template)
+
+    def load_models(self):
+        self.model = SentenceTransformer('clip-ViT-B-32')
+        self.bert_model_name = "bert-base-uncased"
+        self.bert_model = BertModel.from_pretrained(self.bert_model_name)
+        self.bert_tokenizer = BertTokenizer.from_pretrained(self.bert_model_name)
+        self.gpt2_model_name = "gpt2"
+        self.gpt2_model = GPT2LMHeadModel.from_pretrained(self.gpt2_model_name)
+        self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained(self.gpt2_model_name)
+
+    def load_embeddings(self):
+        if os.path.exists("embeddings_cache.pkl"):
+            with open("embeddings_cache.pkl", "rb") as f:
+                embeddings_cache = pickle.load(f)
+            self.image_embeddings = embeddings_cache["image_embeddings"]
+            self.text_embeddings = embeddings_cache["text_embeddings"]
+        else:
+            self.image_embeddings = self.model.encode([image for image in self.images])
+            self.text_embeddings = self.model.encode(self.product_frame['productDisplayName'])
+            embeddings_cache = {"image_embeddings": self.image_embeddings, "text_embeddings": self.text_embeddings}
+            with open("embeddings_cache.pkl", "wb") as f:
+                pickle.dump(embeddings_cache, f)
+
+    def create_docs(self, results):
+        docs = []
+        for result in results:
+            pid = result['corpus_id']
+            score = result['score']
+            result_string = ''
+            result_string += "Product Name:" + self.product_data[pid]['productDisplayName'] + \
+                ';' + "Category:" + self.product_data[pid]['masterCategory'] + \
+                ';' + "Article Type:" + self.product_data[pid]['articleType'] + \
+                ';' + "Usage:" + self.product_data[pid]['usage'] + \
+                ';' + "Season:" + self.product_data[pid]['season'] + \
+                ';' + "Gender:" + self.product_data[pid]['gender']
+            # Assuming text is imported from somewhere else
+            doc = text(page_content=result_string)
+            doc.metadata['pid'] = str(pid)
+            doc.metadata['score'] = score
+            docs.append(doc)
+        return docs
+
+    def get_results(self, query, embeddings, top_k=10):
+        query_embedding = self.model.encode([query])
+        cos_scores = util.pytorch_cos_sim(query_embedding, embeddings)[0]
+        top_results = torch.topk(cos_scores, k=top_k)
+        indices = top_results.indices.tolist()
+        scores = top_results.values.tolist()
+        results = [{'corpus_id': idx, 'score': score} for idx, score in zip(indices, scores)]
+        return results
+
+    def display_text_and_images(self, results_text):
+        for result in results_text:
+            pid = result['corpus_id']
+            product_info = self.product_data[pid]
+            print("Product Name:", product_info['productDisplayName'])
+            print("Category:", product_info['masterCategory'])
+            print("Article Type:", product_info['articleType'])
+            print("Usage:", product_info['usage'])
+            print("Season:", product_info['season'])
+            print("Gender:", product_info['gender'])
+            print("Score:", result['score'])
+            plt.imshow(self.images[pid])
+            plt.axis('off')
+            plt.show()
+
+    @staticmethod
+    def cos_sim(a, b):
+        a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
+        b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
+        return torch.mm(a_norm.T, b_norm)  # Reshape a_norm to (768, 1)
+
+    def generate_response(self, query):
+        # Process the user query and generate a response
+        results_text = self.get_results(query, self.text_embeddings)
+
+        # Generate chatbot response
+        chatbot_response = "This is a placeholder response from the chatbot."  # Placeholder, replace with actual response
+
+        # Display recommended products
+        self.display_text_and_images(results_text)
+
+        # Return both chatbot response and recommended products
+        return chatbot_response, results_text
+
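create_docs above calls text(page_content=...), which is never defined or imported (the inline comment acknowledges this), and the method is not reached from generate_response. If LangChain-style documents are what the PromptTemplate is meant to consume, one way to fill that gap, offered as an assumption rather than as part of the commit, is the Document class from langchain_core:

# Hypothetical stand-in for the undefined text(...) call in create_docs,
# assuming a LangChain Document is the intended container.
from langchain_core.documents import Document

def make_doc(product: dict, pid: int, score: float) -> Document:
    # Pack the same fields create_docs concatenates into one Document.
    labels = [('Product Name', 'productDisplayName'), ('Category', 'masterCategory'),
              ('Article Type', 'articleType'), ('Usage', 'usage'),
              ('Season', 'season'), ('Gender', 'gender')]
    result_string = ';'.join(f"{label}:{product[key]}" for label, key in labels)
    return Document(page_content=result_string, metadata={'pid': str(pid), 'score': score})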
embeddings.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:88399617409c087450736ddfa9c64f99365cce80055cdcd2075f8dac76fe2a3e
+size 134
embeddings_cache.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0f3c0e69a8744b7f05c185f17a4d546e41e94982c2d72b9bee915fe174a2ecd1
+size 134
filenames.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9d0dfe2467c266620c47ca39138c3d4ad1c3d68e365ab007f44cba5d083b9fde
+size 131
main.py
ADDED
@@ -0,0 +1,80 @@
+import streamlit as st
+import altair as alt
+import os
+from PIL import Image
+import numpy as np
+import pickle
+import tensorflow
+from tensorflow.keras.preprocessing import image
+from tensorflow.keras.layers import GlobalMaxPooling2D
+from tensorflow.keras.applications.resnet50 import ResNet50,preprocess_input
+from sklearn.neighbors import NearestNeighbors
+from numpy.linalg import norm
+
+feature_list = np.array(pickle.load(open('embeddings.pkl','rb')))
+filenames = pickle.load(open('filenames.pkl','rb'))
+
+model = ResNet50(weights='imagenet',include_top=False,input_shape=(224,224,3))
+model.trainable = False
+
+model = tensorflow.keras.Sequential([
+    model,
+    GlobalMaxPooling2D()
+])
+
+st.title('Fashion Recommender System')
+
+def save_uploaded_file(uploaded_file):
+    try:
+        with open(os.path.join('uploads',uploaded_file.name),'wb') as f:
+            f.write(uploaded_file.getbuffer())
+        return 1
+    except:
+        return 0
+
+def feature_extraction(img_path,model):
+    img = image.load_img(img_path, target_size=(224, 224))
+    img_array = image.img_to_array(img)
+    expanded_img_array = np.expand_dims(img_array, axis=0)
+    preprocessed_img = preprocess_input(expanded_img_array)
+    result = model.predict(preprocessed_img).flatten()
+    normalized_result = result / norm(result)
+
+    return normalized_result
+
+def recommend(features,feature_list):
+    neighbors = NearestNeighbors(n_neighbors=6, algorithm='brute', metric='euclidean')
+    neighbors.fit(feature_list)
+
+    distances, indices = neighbors.kneighbors([features])
+
+    return indices
+
+# steps
+# file upload -> save
+uploaded_file = st.file_uploader("Choose an image")
+if uploaded_file is not None:
+    if save_uploaded_file(uploaded_file):
+        # display the file
+        display_image = Image.open(uploaded_file)
+        st.image(display_image)
+        # feature extract
+        features = feature_extraction(os.path.join("uploads",uploaded_file.name),model)
+        #st.text(features)
+        # recommendation
+        indices = recommend(features,feature_list)
+        # show
+        col1,col2,col3,col4,col5 = st.beta_columns(5)
+
+        with col1:
+            st.image(filenames[indices[0][0]])
+        with col2:
+            st.image(filenames[indices[0][1]])
+        with col3:
+            st.image(filenames[indices[0][2]])
+        with col4:
+            st.image(filenames[indices[0][3]])
+        with col5:
+            st.image(filenames[indices[0][4]])
+    else:
+        st.header("Some error occurred in file upload")
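Note that st.beta_columns is a pre-1.0 Streamlit API that has since been removed in favour of st.columns, which may be one reason the Space shows a runtime error (st.experimental_get_query_params and st.experimental_set_query_params, used in main_1.py below, are likewise deprecated in recent releases). A sketch of the five-image layout against current Streamlit, not part of the commit:

# Sketch: render the top-n recommendations with st.columns,
# which replaced st.beta_columns in Streamlit 1.x.
import streamlit as st

def show_recommendations(filenames, indices, n=5):
    cols = st.columns(n)
    for col, idx in zip(cols, indices[0][:n]):
        with col:
            st.image(filenames[idx])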
main_1.py
ADDED
@@ -0,0 +1,228 @@
+import streamlit as st
+import mysql.connector
+import os
+from PIL import Image
+import numpy as np
+import pickle
+import tensorflow
+from tensorflow.keras.preprocessing import image
+from tensorflow.keras.layers import GlobalMaxPooling2D
+from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
+from sklearn.neighbors import NearestNeighbors
+from numpy.linalg import norm
+from chatbot import Chatbot  # Assuming you have a chatbot module
+
+# Function to authenticate user credentials from the database
+def authenticate_user(username, password):
+    try:
+        # Connect to the MySQL database
+        connection = mysql.connector.connect(
+            host="sql6.freemysqlhosting.net",
+            database="sql6697353",
+            user="sql6697353",
+            password="wvfSQJbmMs"
+        )
+
+        # Create a cursor object to execute SQL queries
+        cursor = connection.cursor()
+
+        # Query to check if the username and password match
+        query = "SELECT * FROM users WHERE username = %s AND password = %s"
+        cursor.execute(query, (username, password))
+        user = cursor.fetchone()  # Fetch the first row
+
+        # If user exists, return True (authentication successful)
+        if user:
+            return True
+        else:
+            return False
+
+    except mysql.connector.Error as error:
+        st.error(f"Error: {error}")
+        return False
+
+    finally:
+        # Close the cursor and database connection
+        if 'connection' in locals() and connection.is_connected():
+            cursor.close()
+            connection.close()
+
+# Define your login function
+def login():
+    st.header("Login Page")
+    # Add your login code here
+    username = st.text_input("Username")
+    password = st.text_input("Password", type="password")
+    if st.button("Login"):
+        # Authenticate user with provided credentials
+        if authenticate_user(username, password):
+            st.session_state.logged_in = True
+            query_params = st.experimental_get_query_params()
+            query_params["page"] = ["Dashboard"]
+            st.experimental_set_query_params(**query_params)
+            # Redirect to dashboard
+            st.success("Login successful!")
+        else:
+            st.error("Invalid username or password")
+
+# Define your registration function
+def register():
+    st.header("Registration Page")
+    # Add your registration code here
+    username = st.text_input("Username")
+    password = st.text_input("Password", type="password")
+    email = st.text_input("Email")
+
+    if st.button("Register"):
+        # Validate input
+        if not username or not password or not email:
+            st.error("Please enter username, password, and email.")
+        else:
+            try:
+                # Connect to the MySQL database
+                connection = mysql.connector.connect(
+                    host="sql6.freemysqlhosting.net",
+                    database="sql6697353",
+                    user="sql6697353",
+                    password="wvfSQJbmMs"
+                )
+
+                # Create a cursor object to execute SQL queries
+                cursor = connection.cursor()
+
+                # Check if username or email already exists
+                cursor.execute("SELECT * FROM users WHERE username = %s OR email = %s", (username, email))
+                result = cursor.fetchone()
+                if result:
+                    st.error("Username or email already exists. Please choose different ones.")
+                else:
+                    # Insert new user into database
+                    cursor.execute("INSERT INTO users (username, password, email) VALUES (%s, %s, %s)", (username, password, email))
+                    connection.commit()
+                    st.success("Registration successful. You can now log in.")
+            except mysql.connector.Error as error:
+                st.error(f"Error: {error}")
+            finally:
+                # Close the cursor and database connection
+                if 'connection' in locals() and connection.is_connected():
+                    cursor.close()
+                    connection.close()
+
+
+# Define function for feature extraction
+def feature_extraction(img_path, model):
+    img = image.load_img(img_path, target_size=(224, 224))
+    img_array = image.img_to_array(img)
+    expanded_img_array = np.expand_dims(img_array, axis=0)
+    preprocessed_img = preprocess_input(expanded_img_array)
+    result = model.predict(preprocessed_img).flatten()
+    normalized_result = result / norm(result)
+    return normalized_result
+
+# Define function for recommendation
+def recommend(features, feature_list):
+    neighbors = NearestNeighbors(n_neighbors=6, algorithm='brute', metric='euclidean')
+    neighbors.fit(feature_list)
+    distances, indices = neighbors.kneighbors([features])
+    return indices
+
+# Function to save uploaded file
+def save_uploaded_file(uploaded_file):
+    try:
+        with open(os.path.join('uploads', uploaded_file.name), 'wb') as f:
+            f.write(uploaded_file.getbuffer())
+        return True
+    except:
+        return False
+
+# Function to show dashboard content
+def show_dashboard():
+    st.header("Fashion Recommender System")
+    chatbot = Chatbot()
+    # Load ResNet model for image feature extraction
+    model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
+    model.trainable = False
+    model = tensorflow.keras.Sequential([
+        model,
+        GlobalMaxPooling2D()
+    ])
+
+    feature_list = np.array(pickle.load(open('embeddings.pkl', 'rb')))
+    filenames = pickle.load(open('filenames.pkl', 'rb'))
+
+    # File upload section
+    uploaded_file = st.file_uploader("Choose an image")
+    if uploaded_file is not None:
+        if save_uploaded_file(uploaded_file):
+            # Display the uploaded image
+            display_image = Image.open(uploaded_file)
+            st.image(display_image)
+
+            # Feature extraction
+            features = feature_extraction(os.path.join("uploads", uploaded_file.name), model)
+
+            # Recommendation
+            indices = recommend(features, feature_list)
+
+            # Display recommended products
+            col1, col2, col3, col4, col5 = st.beta_columns(5)
+            with col1:
+                st.image(filenames[indices[0][0]])
+            with col2:
+                st.image(filenames[indices[0][1]])
+            with col3:
+                st.image(filenames[indices[0][2]])
+            with col4:
+                st.image(filenames[indices[0][3]])
+            with col5:
+                st.image(filenames[indices[0][4]])
+
+        else:
+            st.header("Some error occurred in file upload")
+
+    # Chatbot section
+    user_question = st.text_input("Ask a question:")
+    if user_question:
+        bot_response, recommended_products = chatbot.generate_response(user_question)
+        st.write("Chatbot:", bot_response)
+
+        # Display recommended products
+        for result in recommended_products:
+            pid = result['corpus_id']
+            product_info = chatbot.product_data[pid]
+            st.write("Product Name:", product_info['productDisplayName'])
+            st.write("Category:", product_info['masterCategory'])
+            st.write("Article Type:", product_info['articleType'])
+            st.write("Usage:", product_info['usage'])
+            st.write("Season:", product_info['season'])
+            st.write("Gender:", product_info['gender'])
+            st.image(chatbot.images[pid])
+
+# Main Streamlit app
+def main():
+    # Give title to the app
+    st.title("Fashion Recommender System")
+
+    # Sidebar navigation
+    page = st.sidebar.radio("Navigation", ["Login", "Register", "Dashboard"])
+
+    # Check if user is logged in
+    if "logged_in" not in st.session_state:
+        st.session_state.logged_in = False
+
+    # Show login page if user is not logged in
+    if not st.session_state.logged_in:
+        if page == "Login":
+            login()
+
+    # Show registration page if selected in the sidebar
+    if page == "Register":
+        register()
+
+    # Show dashboard if selected in the sidebar and user is logged in
+    if page == "Dashboard" and st.session_state.logged_in:
+        show_dashboard()
+
+# Run the main app
+if __name__ == "__main__":
+    main()
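The MySQL host, database name, user and password are hardcoded above, as is the Hugging Face token in chatbot.py. On a public Space these would normally come from Space secrets or environment variables rather than from the repository; a minimal sketch, with placeholder variable names that are not part of the commit:

# Sketch: read database credentials from environment variables
# (for example, Space secrets exposed to the container).
import os
import mysql.connector

def get_connection():
    # DB_HOST, DB_NAME, DB_USER, DB_PASSWORD are placeholder names.
    return mysql.connector.connect(
        host=os.environ["DB_HOST"],
        database=os.environ["DB_NAME"],
        user=os.environ["DB_USER"],
        password=os.environ["DB_PASSWORD"],
    )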
requirements.txt
ADDED
@@ -0,0 +1,13 @@
+streamlit
+torch
+matplotlib
+langchain-core # Might need adjustment based on availability
+datasets
+sentence-transformers
+transformers
+
+# Additional libraries based on your code (if applicable)
+mysql-connector-python # If using MySQL database
+Pillow # If using PIL for image processing
+scikit-learn # If using scikit-learn (used for sklearn in your code)
+langchain_community
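As the comment block in the file itself hints, several packages imported by this commit are not listed: tensorflow, numpy and tqdm (app.py), altair (main.py), and opencv-python or opencv-python-headless for cv2 (test.py). The extra lines such a list would need, left unpinned like the rest of the file, would be roughly:

tensorflow
numpy
tqdm
altair
opencv-python-headless # provides cv2 without GUI dependencies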
test.py
ADDED
@@ -0,0 +1,48 @@
+import pickle
+import tensorflow
+import numpy as np
+from numpy.linalg import norm
+from tensorflow.keras.preprocessing import image
+from tensorflow.keras.layers import GlobalMaxPooling2D
+from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
+from sklearn.neighbors import NearestNeighbors
+import cv2
+
+# Load the precomputed feature vectors and filenames from pickle files
+feature_list = np.array(pickle.load(open('embeddings.pkl', 'rb')))
+filenames = pickle.load(open('filenames.pkl', 'rb'))
+
+# Load ResNet50 model without the top layer for feature extraction
+model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
+model.trainable = False
+
+# Create a Sequential model with ResNet50 and GlobalMaxPooling2D layers
+model = tensorflow.keras.Sequential([
+    model,
+    GlobalMaxPooling2D()
+])
+
+# Load and preprocess the query image
+img = image.load_img('sample/khade.jpg', target_size=(224, 224))
+img_array = image.img_to_array(img)
+expanded_img_array = np.expand_dims(img_array, axis=0)
+preprocessed_img = preprocess_input(expanded_img_array)
+
+# Extract features from the query image and normalize
+result = model.predict(preprocessed_img).flatten()
+normalized_result = result / norm(result)
+
+# Initialize NearestNeighbors model and fit with the feature vectors
+neighbors = NearestNeighbors(n_neighbors=6, algorithm='brute', metric='euclidean')
+neighbors.fit(feature_list)
+
+# Find the nearest neighbors (excluding itself)
+distances, indices = neighbors.kneighbors([normalized_result])
+
+print(indices)  # Print the indices of nearest neighbors
+
+# Display the nearest neighbor images
+for file in indices[0][0:5]:
+    temp_img = cv2.imread(filenames[file])
+    cv2.imshow('output', cv2.resize(temp_img, (312, 312)))
+    cv2.waitKey(0)
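cv2.imshow opens a native window, so the display loop above only runs on a machine with a GUI; in a headless environment such as a Space, the neighbours can be written to an image file instead. A sketch using matplotlib (already a dependency of this project) that takes the filenames and indices computed above:

# Headless alternative to the cv2.imshow loop: save the neighbour images to one figure.
import cv2
import matplotlib
matplotlib.use('Agg')  # render without a display
import matplotlib.pyplot as plt

def save_neighbours(filenames, indices, out_path='neighbours.png', n=5):
    fig, axes = plt.subplots(1, n, figsize=(3 * n, 3))
    for ax, file_idx in zip(axes, indices[0][:n]):
        img = cv2.cvtColor(cv2.imread(filenames[file_idx]), cv2.COLOR_BGR2RGB)  # BGR -> RGB
        ax.imshow(img)
        ax.axis('off')
    fig.savefig(out_path)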