Update app.py
Browse files
app.py
CHANGED
@@ -1,16 +1,15 @@
|
|
1 |
-
import os
|
2 |
-
import pickle
|
3 |
-
import numpy as np
|
4 |
-
import pandas as pd
|
5 |
import streamlit as st
|
|
|
6 |
from PIL import Image
|
7 |
-
import
|
|
|
|
|
8 |
from tensorflow.keras.preprocessing import image
|
9 |
from tensorflow.keras.layers import GlobalMaxPooling2D
|
10 |
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
|
11 |
from sklearn.neighbors import NearestNeighbors
|
12 |
-
from
|
13 |
-
import
|
14 |
|
15 |
# Define function for feature extraction
|
16 |
def feature_extraction(img_path, model):
|
@@ -19,7 +18,7 @@ def feature_extraction(img_path, model):
|
|
19 |
expanded_img_array = np.expand_dims(img_array, axis=0)
|
20 |
preprocessed_img = preprocess_input(expanded_img_array)
|
21 |
result = model.predict(preprocessed_img).flatten()
|
22 |
-
normalized_result = result /
|
23 |
return normalized_result
|
24 |
|
25 |
# Define function for recommendation
|
@@ -40,41 +39,19 @@ def save_uploaded_file(uploaded_file):
|
|
40 |
with open(file_path, 'wb') as f:
|
41 |
f.write(uploaded_file.getbuffer())
|
42 |
st.success(f"File saved to {file_path}")
|
43 |
-
return
|
44 |
except Exception as e:
|
45 |
st.error(f"Error saving file: {e}")
|
46 |
-
return
|
47 |
-
|
48 |
-
# Function to load image data from dataset
|
49 |
-
def load_image_data():
|
50 |
-
dataset = load_dataset("ashraq/fashion-product-images-small", split="train")
|
51 |
-
images = dataset["image"]
|
52 |
-
product_frame = dataset.remove_columns("image").to_pandas()
|
53 |
-
product_data = product_frame.reset_index(drop=True).to_dict(orient='index')
|
54 |
-
return images, product_data
|
55 |
-
|
56 |
-
# Function to show similar images
|
57 |
-
def display_similar_images(indices, filenames, images):
|
58 |
-
col1, col2, col3, col4, col5 = st.columns(5)
|
59 |
-
columns = [col1, col2, col3, col4, col5]
|
60 |
-
|
61 |
-
for col, idx in zip(columns, indices[0]):
|
62 |
-
try:
|
63 |
-
img = images[idx]
|
64 |
-
with col:
|
65 |
-
st.image(img)
|
66 |
-
except Exception as e:
|
67 |
-
st.error(f"Error displaying image {idx}: {e}")
|
68 |
|
69 |
# Function to show dashboard content
|
70 |
def show_dashboard():
|
71 |
st.header("Fashion Recommender System")
|
72 |
-
|
73 |
-
# Load
|
74 |
-
images, product_data = load_image_data()
|
75 |
model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
|
76 |
model.trainable = False
|
77 |
-
model =
|
78 |
model,
|
79 |
GlobalMaxPooling2D()
|
80 |
])
|
@@ -89,32 +66,35 @@ def show_dashboard():
|
|
89 |
# File upload section
|
90 |
uploaded_file = st.file_uploader("Choose an image")
|
91 |
if uploaded_file is not None:
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
st.
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
112 |
|
113 |
# Chatbot section
|
114 |
user_question = st.text_input("Ask a question:")
|
115 |
if user_question:
|
116 |
-
from chatbot import Chatbot
|
117 |
-
chatbot = Chatbot()
|
118 |
bot_response, recommended_products = chatbot.generate_response(user_question)
|
119 |
st.write("Chatbot:", bot_response)
|
120 |
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
+
import os
|
3 |
from PIL import Image
|
4 |
+
import numpy as np
|
5 |
+
import pickle
|
6 |
+
import tensorflow
|
7 |
from tensorflow.keras.preprocessing import image
|
8 |
from tensorflow.keras.layers import GlobalMaxPooling2D
|
9 |
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
|
10 |
from sklearn.neighbors import NearestNeighbors
|
11 |
+
from numpy.linalg import norm
|
12 |
+
from chatbot import Chatbot # Assuming you have a chatbot module
|
13 |
|
14 |
# Define function for feature extraction
|
15 |
def feature_extraction(img_path, model):
|
|
|
18 |
expanded_img_array = np.expand_dims(img_array, axis=0)
|
19 |
preprocessed_img = preprocess_input(expanded_img_array)
|
20 |
result = model.predict(preprocessed_img).flatten()
|
21 |
+
normalized_result = result / norm(result)
|
22 |
return normalized_result
|
23 |
|
24 |
# Define function for recommendation
|
|
|
39 |
with open(file_path, 'wb') as f:
|
40 |
f.write(uploaded_file.getbuffer())
|
41 |
st.success(f"File saved to {file_path}")
|
42 |
+
return True
|
43 |
except Exception as e:
|
44 |
st.error(f"Error saving file: {e}")
|
45 |
+
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
|
47 |
# Function to show dashboard content
|
48 |
def show_dashboard():
|
49 |
st.header("Fashion Recommender System")
|
50 |
+
chatbot = Chatbot()
|
51 |
+
# Load ResNet model for image feature extraction
|
|
|
52 |
model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
|
53 |
model.trainable = False
|
54 |
+
model = tensorflow.keras.Sequential([
|
55 |
model,
|
56 |
GlobalMaxPooling2D()
|
57 |
])
|
|
|
66 |
# File upload section
|
67 |
uploaded_file = st.file_uploader("Choose an image")
|
68 |
if uploaded_file is not None:
|
69 |
+
if save_uploaded_file(uploaded_file):
|
70 |
+
# Display the uploaded image
|
71 |
+
display_image = Image.open(uploaded_file)
|
72 |
+
st.image(display_image)
|
73 |
+
|
74 |
+
# Feature extraction
|
75 |
+
features = feature_extraction(os.path.join("uploads", uploaded_file.name), model)
|
76 |
+
|
77 |
+
# Recommendation
|
78 |
+
indices = recommend(features, feature_list)
|
79 |
+
|
80 |
+
# Display recommended products
|
81 |
+
col1, col2, col3, col4, col5 = st.columns(5)
|
82 |
+
with col1:
|
83 |
+
st.image(filenames[indices[0][0]])
|
84 |
+
with col2:
|
85 |
+
st.image(filenames[indices[0][1]])
|
86 |
+
with col3:
|
87 |
+
st.image(filenames[indices[0][2]])
|
88 |
+
with col4:
|
89 |
+
st.image(filenames[indices[0][3]])
|
90 |
+
with col5:
|
91 |
+
st.image(filenames[indices[0][4]])
|
92 |
+
else:
|
93 |
+
st.error("Some error occurred in file upload")
|
94 |
|
95 |
# Chatbot section
|
96 |
user_question = st.text_input("Ask a question:")
|
97 |
if user_question:
|
|
|
|
|
98 |
bot_response, recommended_products = chatbot.generate_response(user_question)
|
99 |
st.write("Chatbot:", bot_response)
|
100 |
|