Prathamesh1420 commited on
Commit
682642b
·
verified ·
1 Parent(s): e8f45ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -57
app.py CHANGED
@@ -1,16 +1,15 @@
1
- import os
2
- import pickle
3
- import numpy as np
4
- import pandas as pd
5
  import streamlit as st
 
6
  from PIL import Image
7
- import tensorflow as tf
 
 
8
  from tensorflow.keras.preprocessing import image
9
  from tensorflow.keras.layers import GlobalMaxPooling2D
10
  from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
11
  from sklearn.neighbors import NearestNeighbors
12
- from datasets import load_dataset
13
- import zipfile
14
 
15
  # Define function for feature extraction
16
  def feature_extraction(img_path, model):
@@ -19,7 +18,7 @@ def feature_extraction(img_path, model):
19
  expanded_img_array = np.expand_dims(img_array, axis=0)
20
  preprocessed_img = preprocess_input(expanded_img_array)
21
  result = model.predict(preprocessed_img).flatten()
22
- normalized_result = result / np.linalg.norm(result)
23
  return normalized_result
24
 
25
  # Define function for recommendation
@@ -40,41 +39,19 @@ def save_uploaded_file(uploaded_file):
40
  with open(file_path, 'wb') as f:
41
  f.write(uploaded_file.getbuffer())
42
  st.success(f"File saved to {file_path}")
43
- return file_path
44
  except Exception as e:
45
  st.error(f"Error saving file: {e}")
46
- return None
47
-
48
- # Function to load image data from dataset
49
- def load_image_data():
50
- dataset = load_dataset("ashraq/fashion-product-images-small", split="train")
51
- images = dataset["image"]
52
- product_frame = dataset.remove_columns("image").to_pandas()
53
- product_data = product_frame.reset_index(drop=True).to_dict(orient='index')
54
- return images, product_data
55
-
56
- # Function to show similar images
57
- def display_similar_images(indices, filenames, images):
58
- col1, col2, col3, col4, col5 = st.columns(5)
59
- columns = [col1, col2, col3, col4, col5]
60
-
61
- for col, idx in zip(columns, indices[0]):
62
- try:
63
- img = images[idx]
64
- with col:
65
- st.image(img)
66
- except Exception as e:
67
- st.error(f"Error displaying image {idx}: {e}")
68
 
69
  # Function to show dashboard content
70
  def show_dashboard():
71
  st.header("Fashion Recommender System")
72
-
73
- # Load the dataset and models
74
- images, product_data = load_image_data()
75
  model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
76
  model.trainable = False
77
- model = tf.keras.Sequential([
78
  model,
79
  GlobalMaxPooling2D()
80
  ])
@@ -89,32 +66,35 @@ def show_dashboard():
89
  # File upload section
90
  uploaded_file = st.file_uploader("Choose an image")
91
  if uploaded_file is not None:
92
- file_path = save_uploaded_file(uploaded_file)
93
- if file_path:
94
- try:
95
- display_image = Image.open(file_path)
96
- st.image(display_image)
97
- except Exception as e:
98
- st.error(f"Error displaying uploaded image: {e}")
99
-
100
- try:
101
- features = feature_extraction(file_path, model)
102
- except Exception as e:
103
- st.error(f"Error extracting features: {e}")
104
- return
105
-
106
- try:
107
- indices = recommend(features, feature_list)
108
- display_similar_images(indices, filenames, images)
109
- except Exception as e:
110
- st.error(f"Error in recommendation: {e}")
111
- return
 
 
 
 
 
112
 
113
  # Chatbot section
114
  user_question = st.text_input("Ask a question:")
115
  if user_question:
116
- from chatbot import Chatbot
117
- chatbot = Chatbot()
118
  bot_response, recommended_products = chatbot.generate_response(user_question)
119
  st.write("Chatbot:", bot_response)
120
 
 
 
 
 
 
1
  import streamlit as st
2
+ import os
3
  from PIL import Image
4
+ import numpy as np
5
+ import pickle
6
+ import tensorflow
7
  from tensorflow.keras.preprocessing import image
8
  from tensorflow.keras.layers import GlobalMaxPooling2D
9
  from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
10
  from sklearn.neighbors import NearestNeighbors
11
+ from numpy.linalg import norm
12
+ from chatbot import Chatbot # Assuming you have a chatbot module
13
 
14
  # Define function for feature extraction
15
  def feature_extraction(img_path, model):
 
18
  expanded_img_array = np.expand_dims(img_array, axis=0)
19
  preprocessed_img = preprocess_input(expanded_img_array)
20
  result = model.predict(preprocessed_img).flatten()
21
+ normalized_result = result / norm(result)
22
  return normalized_result
23
 
24
  # Define function for recommendation
 
39
  with open(file_path, 'wb') as f:
40
  f.write(uploaded_file.getbuffer())
41
  st.success(f"File saved to {file_path}")
42
+ return True
43
  except Exception as e:
44
  st.error(f"Error saving file: {e}")
45
+ return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  # Function to show dashboard content
48
  def show_dashboard():
49
  st.header("Fashion Recommender System")
50
+ chatbot = Chatbot()
51
+ # Load ResNet model for image feature extraction
 
52
  model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
53
  model.trainable = False
54
+ model = tensorflow.keras.Sequential([
55
  model,
56
  GlobalMaxPooling2D()
57
  ])
 
66
  # File upload section
67
  uploaded_file = st.file_uploader("Choose an image")
68
  if uploaded_file is not None:
69
+ if save_uploaded_file(uploaded_file):
70
+ # Display the uploaded image
71
+ display_image = Image.open(uploaded_file)
72
+ st.image(display_image)
73
+
74
+ # Feature extraction
75
+ features = feature_extraction(os.path.join("uploads", uploaded_file.name), model)
76
+
77
+ # Recommendation
78
+ indices = recommend(features, feature_list)
79
+
80
+ # Display recommended products
81
+ col1, col2, col3, col4, col5 = st.columns(5)
82
+ with col1:
83
+ st.image(filenames[indices[0][0]])
84
+ with col2:
85
+ st.image(filenames[indices[0][1]])
86
+ with col3:
87
+ st.image(filenames[indices[0][2]])
88
+ with col4:
89
+ st.image(filenames[indices[0][3]])
90
+ with col5:
91
+ st.image(filenames[indices[0][4]])
92
+ else:
93
+ st.error("Some error occurred in file upload")
94
 
95
  # Chatbot section
96
  user_question = st.text_input("Ask a question:")
97
  if user_question:
 
 
98
  bot_response, recommended_products = chatbot.generate_response(user_question)
99
  st.write("Chatbot:", bot_response)
100