Prathamesh1420 commited on
Commit
0ace04f
·
verified ·
1 Parent(s): 66a0dab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -139
app.py CHANGED
@@ -1,135 +1,32 @@
1
-
2
-
3
- import os
4
- import pickle
5
- import torch
6
- import matplotlib.pyplot as plt
7
- from langchain_community.document_loaders import TextLoader
8
- from datasets import load_dataset
9
- from sentence_transformers import SentenceTransformer, util
10
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
11
- from transformers import BertModel, BertTokenizer
12
- from langchain_core.prompts import PromptTemplate
13
  import streamlit as st
 
14
  from PIL import Image
15
  import numpy as np
16
- import tensorflow as tf
 
17
  from tensorflow.keras.preprocessing import image
18
  from tensorflow.keras.layers import GlobalMaxPooling2D
19
  from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
20
  from sklearn.neighbors import NearestNeighbors
21
  from numpy.linalg import norm
22
-
23
- os.environ['HUGGINGFACEHUB_API_TOKEN'] = "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"
24
-
25
- class Chatbot:
26
- def __init__(self):
27
- self.load_data()
28
- self.load_models()
29
- self.load_embeddings()
30
- self.load_template()
31
-
32
- def load_data(self):
33
- self.data = load_dataset("ashraq/fashion-product-images-small", split="train")
34
- self.images = self.data["image"]
35
- self.product_frame = self.data.remove_columns("image").to_pandas()
36
- self.product_data = self.product_frame.reset_index(drop=True).to_dict(orient='index')
37
-
38
- def load_template(self):
39
- self.template = """
40
- You are a fashion shopping assistant that wants to convert customers based on the information given.
41
- Describe season and usage given in the context in your interaction with the customer.
42
- Use a bullet list when describing each product.
43
- If user ask general question then answer them accordingly, the question may be like when the store will open, where is your store located.
44
- Context: {context}
45
- User question: {question}
46
- Your response: {response}
47
- """
48
- self.prompt = PromptTemplate.from_template(self.template)
49
-
50
- def load_models(self):
51
- self.model = SentenceTransformer('clip-ViT-B-32')
52
- self.bert_model_name = "bert-base-uncased"
53
- self.bert_model = BertModel.from_pretrained(self.bert_model_name)
54
- self.bert_tokenizer = BertTokenizer.from_pretrained(self.bert_model_name)
55
- self.gpt2_model_name = "gpt2"
56
- self.gpt2_model = GPT2LMHeadModel.from_pretrained(self.gpt2_model_name)
57
- self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained(self.gpt2_model_name)
58
-
59
- def load_embeddings(self):
60
- if os.path.exists("embeddings_cache.pkl"):
61
- with open("embeddings_cache.pkl", "rb") as f:
62
- embeddings_cache = pickle.load(f)
63
- self.image_embeddings = embeddings_cache["image_embeddings"]
64
- self.text_embeddings = embeddings_cache["text_embeddings"]
65
- else:
66
- self.image_embeddings = self.model.encode([image for image in self.images])
67
- self.text_embeddings = self.model.encode(self.product_frame['productDisplayName'])
68
- embeddings_cache = {"image_embeddings": self.image_embeddings, "text_embeddings": self.text_embeddings}
69
- with open("embeddings_cache.pkl", "wb") as f:
70
- pickle.dump(embeddings_cache, f)
71
-
72
- def create_docs(self, results):
73
- docs = []
74
- for result in results:
75
- pid = result['corpus_id']
76
- score = result['score']
77
- result_string = ''
78
- result_string += "Product Name:" + self.product_data[pid]['productDisplayName'] + \
79
- ';' + "Category:" + self.product_data[pid]['masterCategory'] + \
80
- ';' + "Article Type:" + self.product_data[pid]['articleType'] + \
81
- ';' + "Usage:" + self.product_data[pid]['usage'] + \
82
- ';' + "Season:" + self.product_data[pid]['season'] + \
83
- ';' + "Gender:" + self.product_data[pid]['gender']
84
- # Assuming text is imported from somewhere else
85
- doc = TextLoader(page_content=result_string)
86
- doc.metadata['pid'] = str(pid)
87
- doc.metadata['score'] = score
88
- docs.append(doc)
89
- return docs
90
-
91
- def get_results(self, query, embeddings, top_k=10):
92
- query_embedding = self.model.encode([query])
93
- cos_scores = util.pytorch_cos_sim(query_embedding, embeddings)[0]
94
- top_results = torch.topk(cos_scores, k=top_k)
95
- indices = top_results.indices.tolist()
96
- scores = top_results.values.tolist()
97
- results = [{'corpus_id': idx, 'score': score} for idx, score in zip(indices, scores)]
98
- return results
99
-
100
- def display_text_and_images(self, results_text):
101
- for result in results_text:
102
- pid = result['corpus_id']
103
- product_info = self.product_data[pid]
104
- print("Product Name:", product_info['productDisplayName'])
105
- print("Category:", product_info['masterCategory'])
106
- print("Article Type:", product_info['articleType'])
107
- print("Usage:", product_info['usage'])
108
- print("Season:", product_info['season'])
109
- print("Gender:", product_info['gender'])
110
- print("Score:", result['score'])
111
- plt.imshow(self.images[pid])
112
- plt.axis('off')
113
- plt.show()
114
-
115
- @staticmethod
116
- def cos_sim(a, b):
117
- a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
118
- b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
119
- return torch.mm(a_norm.T, b_norm) # Reshape a_norm to (768, 1)
120
-
121
- def generate_response(self, query):
122
- # Process the user query and generate a response
123
- results_text = self.get_results(query, self.text_embeddings)
124
-
125
- # Generate chatbot response
126
- chatbot_response = "This is a placeholder response from the chatbot." # Placeholder, replace with actual response
127
-
128
- # Display recommended products
129
- self.display_text_and_images(results_text)
130
-
131
- # Return both chatbot response and recommended products
132
- return chatbot_response, results_text
133
 
134
  # Function to save uploaded file
135
  def save_uploaded_file(uploaded_file):
@@ -147,13 +44,13 @@ def show_dashboard():
147
  # Load ResNet model for image feature extraction
148
  model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
149
  model.trainable = False
150
- model = tf.keras.Sequential([
151
  model,
152
  GlobalMaxPooling2D()
153
  ])
154
 
155
  feature_list = np.array(pickle.load(open('embeddings.pkl', 'rb')))
156
- # filenames = pickle.load(open('filenames.pkl', 'rb')) # No longer needed
157
 
158
  # File upload section
159
  uploaded_file = st.file_uploader("Choose an image")
@@ -169,19 +66,18 @@ def show_dashboard():
169
  # Recommendation
170
  indices = recommend(features, feature_list)
171
 
172
- # Display recommended products using the dataset images
173
- st.write("Recommended Products:")
174
- cols = st.columns(5)
175
- for i, idx in enumerate(indices[0][:5]):
176
- with cols[i]:
177
- st.image(chatbot.images[idx])
178
- product_info = chatbot.product_data[idx]
179
- st.write("Product Name:", product_info['productDisplayName'])
180
- st.write("Category:", product_info['masterCategory'])
181
- st.write("Article Type:", product_info['articleType'])
182
- st.write("Usage:", product_info['usage'])
183
- st.write("Season:", product_info['season'])
184
- st.write("Gender:", product_info['gender'])
185
 
186
  else:
187
  st.header("Some error occurred in file upload")
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import os
3
  from PIL import Image
4
  import numpy as np
5
+ import pickle
6
+ import tensorflow
7
  from tensorflow.keras.preprocessing import image
8
  from tensorflow.keras.layers import GlobalMaxPooling2D
9
  from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
10
  from sklearn.neighbors import NearestNeighbors
11
  from numpy.linalg import norm
12
+ from chatbot import Chatbot # Assuming you have a chatbot module
13
+
14
+ # Define function for feature extraction
15
+ def feature_extraction(img_path, model):
16
+ img = image.load_img(img_path, target_size=(224, 224))
17
+ img_array = image.img_to_array(img)
18
+ expanded_img_array = np.expand_dims(img_array, axis=0)
19
+ preprocessed_img = preprocess_input(expanded_img_array)
20
+ result = model.predict(preprocessed_img).flatten()
21
+ normalized_result = result / norm(result)
22
+ return normalized_result
23
+
24
+ # Define function for recommendation
25
+ def recommend(features, feature_list):
26
+ neighbors = NearestNeighbors(n_neighbors=6, algorithm='brute', metric='euclidean')
27
+ neighbors.fit(feature_list)
28
+ distances, indices = neighbors.kneighbors([features])
29
+ return indices
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  # Function to save uploaded file
32
  def save_uploaded_file(uploaded_file):
 
44
  # Load ResNet model for image feature extraction
45
  model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
46
  model.trainable = False
47
+ model = tensorflow.keras.Sequential([
48
  model,
49
  GlobalMaxPooling2D()
50
  ])
51
 
52
  feature_list = np.array(pickle.load(open('embeddings.pkl', 'rb')))
53
+ filenames = pickle.load(open('filenames.pkl', 'rb'))
54
 
55
  # File upload section
56
  uploaded_file = st.file_uploader("Choose an image")
 
66
  # Recommendation
67
  indices = recommend(features, feature_list)
68
 
69
+ # Display recommended products
70
+ col1, col2, col3, col4, col5 = st.columns(5)
71
+ with col1:
72
+ st.image(filenames[indices[0][0]])
73
+ with col2:
74
+ st.image(filenames[indices[0][1]])
75
+ with col3:
76
+ st.image(filenames[indices[0][2]])
77
+ with col4:
78
+ st.image(filenames[indices[0][3]])
79
+ with col5:
80
+ st.image(filenames[indices[0][4]])
 
81
 
82
  else:
83
  st.header("Some error occurred in file upload")